cache SkSL headers

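This reduces the cost of successive shader compilations by caching the
results of compiling SkSL's headers. The vertex, fragment, and geometry
includes are now compiled once, in the Compiler constructor; each program
of those kinds then iterates over the cached include elements followed by
its own elements.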

Bug: skia:
Change-Id: If7fc21a9877021c4025ad99dd0981523a25855e0
Reviewed-on: https://skia-review.googlesource.com/123422
Reviewed-by: Brian Salomon <bsalomon@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLCPPCodeGenerator.cpp b/src/sksl/SkSLCPPCodeGenerator.cpp
index 7b740a7..851f8e8 100644
--- a/src/sksl/SkSLCPPCodeGenerator.cpp
+++ b/src/sksl/SkSLCPPCodeGenerator.cpp
@@ -295,10 +295,10 @@
         ASSERT(Expression::kVariableReference_Kind == c.fArguments[0]->fKind);
         int index = 0;
         bool found = false;
-        for (const auto& p : fProgram.fElements) {
-            if (ProgramElement::kVar_Kind == p->fKind) {
-                const VarDeclarations* decls = (const VarDeclarations*) p.get();
-                for (const auto& raw : decls->fVars) {
+        for (const auto& p : fProgram) {  // yields ProgramElement&, inherited first
+            if (ProgramElement::kVar_Kind == p.fKind) {
+                const VarDeclarations& decls = (const VarDeclarations&) p;
+                for (const auto& raw : decls.fVars) {
                     VarDeclaration& decl = (VarDeclaration&) *raw;
                     if (decl.fVar == &((VariableReference&) *c.fArguments[0]).fVariable) {
                         found = true;
@@ -435,10 +435,10 @@
 }
 
 void CPPCodeGenerator::writePrivateVars() {
-    for (const auto& p : fProgram.fElements) {
-        if (ProgramElement::kVar_Kind == p->fKind) {
-            const VarDeclarations* decls = (const VarDeclarations*) p.get();
-            for (const auto& raw : decls->fVars) {
+    for (const auto& p : fProgram) {
+        if (ProgramElement::kVar_Kind == p.fKind) {
+            const VarDeclarations& decls = (const VarDeclarations&) p;
+            for (const auto& raw : decls.fVars) {
                 VarDeclaration& decl = (VarDeclaration&) *raw;
                 if (is_private(*decl.fVar)) {
                     if (decl.fVar->fType == *fContext.fFragmentProcessor_Type) {
@@ -458,10 +458,10 @@
 }
 
 void CPPCodeGenerator::writePrivateVarValues() {
-    for (const auto& p : fProgram.fElements) {
-        if (ProgramElement::kVar_Kind == p->fKind) {
-            const VarDeclarations* decls = (const VarDeclarations*) p.get();
-            for (const auto& raw : decls->fVars) {
+    for (const auto& p : fProgram) {
+        if (ProgramElement::kVar_Kind == p.fKind) {
+            const VarDeclarations& decls = (const VarDeclarations&) p;
+            for (const auto& raw : decls.fVars) {
                 VarDeclaration& decl = (VarDeclaration&) *raw;
                 if (is_private(*decl.fVar) && decl.fValue) {
                     this->writef("%s = ", String(decl.fVar->fName).c_str());
@@ -524,10 +524,10 @@
     this->writef("        const %s& _outer = args.fFp.cast<%s>();\n"
                  "        (void) _outer;\n",
                  fFullName.c_str(), fFullName.c_str());
-    for (const auto& p : fProgram.fElements) {
-        if (ProgramElement::kVar_Kind == p->fKind) {
-            const VarDeclarations* decls = (const VarDeclarations*) p.get();
-            for (const auto& raw : decls->fVars) {
+    for (const auto& p : fProgram) {
+        if (ProgramElement::kVar_Kind == p.fKind) {
+            const VarDeclarations& decls = (const VarDeclarations&) p;
+            for (const auto& raw : decls.fVars) {
                 VarDeclaration& decl = (VarDeclaration&) *raw;
                 String nameString(decl.fVar->fName);
                 const char* name = nameString.c_str();
@@ -597,10 +597,10 @@
     }
     if (section) {
         int samplerIndex = 0;
-        for (const auto& p : fProgram.fElements) {
-            if (ProgramElement::kVar_Kind == p->fKind) {
-                const VarDeclarations* decls = (const VarDeclarations*) p.get();
-                for (const auto& raw : decls->fVars) {
+        for (const auto& p : fProgram) {
+            if (ProgramElement::kVar_Kind == p.fKind) {
+                const VarDeclarations& decls = (const VarDeclarations&) p;
+                for (const auto& raw : decls.fVars) {
                     VarDeclaration& decl = (VarDeclaration&) *raw;
                     String nameString(decl.fVar->fName);
                     const char* name = nameString.c_str();
@@ -750,10 +750,10 @@
 
 bool CPPCodeGenerator::generateCode() {
     std::vector<const Variable*> uniforms;
-    for (const auto& p : fProgram.fElements) {
-        if (ProgramElement::kVar_Kind == p->fKind) {
-            const VarDeclarations* decls = (const VarDeclarations*) p.get();
-            for (const auto& raw : decls->fVars) {
+    for (const auto& p : fProgram) {
+        if (ProgramElement::kVar_Kind == p.fKind) {
+            const VarDeclarations& decls = (const VarDeclarations&) p;
+            for (const auto& raw : decls.fVars) {
                 VarDeclaration& decl = (VarDeclaration&) *raw;
                 if ((decl.fVar->fModifiers.fFlags & Modifiers::kUniform_Flag) &&
                            decl.fVar->fType.kind() != Type::kSampler_Kind) {
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index 7db9440..a5f5225 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -214,6 +214,29 @@
         printf("Unexpected errors: %s\n", fErrorText.c_str());
     }
     ASSERT(!fErrorCount);
+
+    // Compile each stage's include once; convertProgram() reuses the resulting IR and symbols.
+    Program::Settings settings;
+    fIRGenerator->start(&settings, nullptr);
+    fIRGenerator->convertProgram(Program::kVertex_Kind, SKSL_VERT_INCLUDE,
+                                 strlen(SKSL_VERT_INCLUDE), *fTypes, &fVertexInclude);
+    fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
+    fVertexSymbolTable = fIRGenerator->fSymbolTable;
+    fIRGenerator->finish();
+
+    fIRGenerator->start(&settings, nullptr);
+    fIRGenerator->convertProgram(Program::kFragment_Kind, SKSL_FRAG_INCLUDE,
+                                 strlen(SKSL_FRAG_INCLUDE), *fTypes, &fFragmentInclude);
+    fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
+    fFragmentSymbolTable = fIRGenerator->fSymbolTable;
+    fIRGenerator->finish();
+
+    fIRGenerator->start(&settings, nullptr);
+    fIRGenerator->convertProgram(Program::kGeometry_Kind, SKSL_GEOM_INCLUDE,
+                                 strlen(SKSL_GEOM_INCLUDE), *fTypes, &fGeometryInclude);
+    fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
+    fGeometrySymbolTable = fIRGenerator->fSymbolTable;
+    fIRGenerator->finish();
 }
 
 Compiler::~Compiler() {
@@ -1186,31 +1209,41 @@
                                                   const Program::Settings& settings) {
     fErrorText = "";
     fErrorCount = 0;
-    fIRGenerator->start(&settings);
+    std::vector<std::unique_ptr<ProgramElement>>* inherited = nullptr;
     std::vector<std::unique_ptr<ProgramElement>> elements;
     switch (kind) {
         case Program::kVertex_Kind:
-            fIRGenerator->convertProgram(kind, SKSL_VERT_INCLUDE, strlen(SKSL_VERT_INCLUDE),
-                                         *fTypes, &elements);
+            inherited = &fVertexInclude;
+            fIRGenerator->fSymbolTable = fVertexSymbolTable;
+            fIRGenerator->start(&settings, inherited);
             break;
         case Program::kFragment_Kind:
-            fIRGenerator->convertProgram(kind, SKSL_FRAG_INCLUDE, strlen(SKSL_FRAG_INCLUDE),
-                                         *fTypes, &elements);
+            inherited = &fFragmentInclude;
+            fIRGenerator->fSymbolTable = fFragmentSymbolTable;
+            fIRGenerator->start(&settings, inherited);
             break;
         case Program::kGeometry_Kind:
-            fIRGenerator->convertProgram(kind, SKSL_GEOM_INCLUDE, strlen(SKSL_GEOM_INCLUDE),
-                                         *fTypes, &elements);
+            inherited = &fGeometryInclude;
+            fIRGenerator->fSymbolTable = fGeometrySymbolTable;
+            fIRGenerator->start(&settings, inherited);
             break;
         case Program::kFragmentProcessor_Kind:
+            inherited = nullptr;
+            // NOTE(review): unlike the cached stages, fSymbolTable is not reset here (or for
+            // kCPU_Kind), so start() may build on a table left by an earlier compile; confirm.
+            fIRGenerator->start(&settings, nullptr);
             fIRGenerator->convertProgram(kind, SKSL_FP_INCLUDE, strlen(SKSL_FP_INCLUDE), *fTypes,
                                          &elements);
+            fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
             break;
         case Program::kCPU_Kind:
+            inherited = nullptr;
+            fIRGenerator->start(&settings, nullptr);
             fIRGenerator->convertProgram(kind, SKSL_CPU_INCLUDE, strlen(SKSL_CPU_INCLUDE),
                                          *fTypes, &elements);
+            fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
             break;
     }
-    fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
     for (auto& element : elements) {
         if (element->fKind == ProgramElement::kEnum_Kind) {
             ((Enum&) *element).fBuiltin = true;
@@ -1230,6 +1263,7 @@
                                                        std::move(textPtr),
                                                        settings,
                                                        fContext,
+                                                       inherited,
                                                        std::move(elements),
                                                        fIRGenerator->fSymbolTable,
                                                        fIRGenerator->fInputs));
diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h
index 7a188e0..2c8c786 100644
--- a/src/sksl/SkSLCompiler.h
+++ b/src/sksl/SkSLCompiler.h
@@ -133,9 +133,15 @@
 
     Position position(int offset);
 
+    std::vector<std::unique_ptr<ProgramElement>> fVertexInclude;  // IR of SKSL_VERT_INCLUDE
+    std::shared_ptr<SymbolTable> fVertexSymbolTable;  // symbols from compiling it
+    std::vector<std::unique_ptr<ProgramElement>> fFragmentInclude;  // IR of SKSL_FRAG_INCLUDE
+    std::shared_ptr<SymbolTable> fFragmentSymbolTable;  // symbols from compiling it
+    std::vector<std::unique_ptr<ProgramElement>> fGeometryInclude;  // IR of SKSL_GEOM_INCLUDE
+    std::shared_ptr<SymbolTable> fGeometrySymbolTable;  // symbols from compiling it
+
     std::shared_ptr<SymbolTable> fTypes;
     IRGenerator* fIRGenerator;
-    String fSkiaVertText; // FIXME store parsed version instead
     int fFlags;
 
     const String* fSource;
diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp
index 84c16d4..9b95475 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.cpp
+++ b/src/sksl/SkSLGLSLCodeGenerator.cpp
@@ -1246,9 +1246,9 @@
 void GLSLCodeGenerator::writeHeader() {
     this->write(fProgram.fSettings.fCaps->versionDeclString());
     this->writeLine();
-    for (const auto& e : fProgram.fElements) {
-        if (e->fKind == ProgramElement::kExtension_Kind) {
-            this->writeExtension((Extension&) *e);
+    for (const auto& e : fProgram) {  // inherited elements are visited first
+        if (e.fKind == ProgramElement::kExtension_Kind) {
+            this->writeExtension((Extension&) e);
         }
     }
     if (!fProgram.fSettings.fCaps->canUseFragCoord()) {
@@ -1340,8 +1340,8 @@
     }
     StringStream body;
     fOut = &body;
-    for (const auto& e : fProgram.fElements) {
-        this->writeProgramElement(*e);
+    for (const auto& e : fProgram) {
+        this->writeProgramElement(e);
     }
     fOut = rawOut;
 
diff --git a/src/sksl/SkSLHCodeGenerator.cpp b/src/sksl/SkSLHCodeGenerator.cpp
index 8984bde..2c406aa 100644
--- a/src/sksl/SkSLHCodeGenerator.cpp
+++ b/src/sksl/SkSLHCodeGenerator.cpp
@@ -283,9 +283,9 @@
     this->writef("class %s : public GrFragmentProcessor {\n"
                  "public:\n",
                  fFullName.c_str());
-    for (const auto& p : fProgram.fElements) {
-        if (ProgramElement::kEnum_Kind == p->fKind && !((Enum&) *p).fBuiltin) {
-            this->writef("%s\n", p->description().c_str());
+    for (const auto& p : fProgram) {  // emits only non-builtin enums
+        if (ProgramElement::kEnum_Kind == p.fKind && !((Enum&) p).fBuiltin) {
+            this->writef("%s\n", p.description().c_str());
         }
     }
     this->writeSection(CLASS_SECTION);
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index a0de551..c081fd7 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -142,7 +142,8 @@
 #undef CAP
 }
 
-void IRGenerator::start(const Program::Settings* settings) {
+void IRGenerator::start(const Program::Settings* settings,
+                        std::vector<std::unique_ptr<ProgramElement>>* inherited) {
     fSettings = settings;
     fCapsMap.clear();
     if (settings->fCaps) {
@@ -154,6 +155,17 @@
     fSkPerVertex = nullptr;
     fRTAdjust = nullptr;
     fRTAdjustInterfaceBlock = nullptr;
+    if (inherited) {  // find the PERVERTEX_NAME interface block among inherited elements
+        for (const auto& e : *inherited) {
+            if (e->fKind == ProgramElement::kInterfaceBlock_Kind) {
+                InterfaceBlock& intf = (InterfaceBlock&) *e;
+                if (intf.fVariable.fName == Compiler::PERVERTEX_NAME) {
+                    ASSERT(!fSkPerVertex);
+                    fSkPerVertex = &intf.fVariable;
+                }
+            }
+        }
+    }
 }
 
 void IRGenerator::finish() {
@@ -861,10 +873,6 @@
                                                                       (int) i)));
         }
     }
-    if (var->fName == Compiler::PERVERTEX_NAME) {
-        ASSERT(!fSkPerVertex);
-        fSkPerVertex = var;
-    }
     return std::unique_ptr<InterfaceBlock>(new InterfaceBlock(intf.fOffset,
                                                               var,
                                                               intf.fTypeName,
diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h
index c78c195..2a52e04 100644
--- a/src/sksl/SkSLIRGenerator.h
+++ b/src/sksl/SkSLIRGenerator.h
@@ -86,7 +86,8 @@
      * Prepare to compile a program. Resets state, pushes a new symbol table, and installs the
      * settings.
      */
-    void start(const Program::Settings* settings);
+    void start(const Program::Settings* settings,
+               std::vector<std::unique_ptr<ProgramElement>>* inherited);  // may be null
 
     /**
      * Performs cleanup after compilation is complete.
@@ -182,7 +183,7 @@
     ErrorReporter& fErrors;
     int fInvocations;
     std::vector<std::unique_ptr<ProgramElement>>* fProgramElements;
-    Variable* fSkPerVertex;
+    const Variable* fSkPerVertex = nullptr;  // set by start() from inherited elements
     Variable* fRTAdjust;
     Variable* fRTAdjustInterfaceBlock;
     int fRTAdjustFieldIndex;
diff --git a/src/sksl/SkSLInterpreter.cpp b/src/sksl/SkSLInterpreter.cpp
index c9b7ceb..45e340a 100644
--- a/src/sksl/SkSLInterpreter.cpp
+++ b/src/sksl/SkSLInterpreter.cpp
@@ -29,9 +29,9 @@
 namespace SkSL {
 
 void Interpreter::run() {
-    for (const auto& e : fProgram->fElements) {
-        if (ProgramElement::kFunction_Kind == e->fKind) {
-            const FunctionDefinition& f = (const FunctionDefinition&) *e;
+    for (const auto& e : *fProgram) {  // looks for the "appendStages" function
+        if (ProgramElement::kFunction_Kind == e.fKind) {
+            const FunctionDefinition& f = (const FunctionDefinition&) e;
             if ("appendStages" == f.fDeclaration.fName) {
                 this->run(f);
                 return;
@@ -244,9 +244,9 @@
             CallbackCtx* ctx = new CallbackCtx();
             ctx->fInterpreter = this;
             ctx->fn = do_callback;
-            for (const auto& e : fProgram->fElements) {
-                if (ProgramElement::kFunction_Kind == e->fKind) {
-                    const FunctionDefinition& f = (const FunctionDefinition&) *e;
+            for (const auto& e : *fProgram) {
+                if (ProgramElement::kFunction_Kind == e.fKind) {
+                    const FunctionDefinition& f = (const FunctionDefinition&) e;
                     if (&f.fDeclaration ==
                                      ((const FunctionReference&) *a.fArguments[0]).fFunctions[0]) {
                         ctx->fFunction = &f;
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index a8301c2..99cfb5b 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -729,9 +729,9 @@
 }
 
 void MetalCodeGenerator::writeUniformStruct() {
-    for (const auto& e : fProgram.fElements) {
-        if (ProgramElement::kVar_Kind == e->fKind) {
-            VarDeclarations& decls = (VarDeclarations&) *e;
+    for (const auto& e : fProgram) {  // inherited elements are visited first
+        if (ProgramElement::kVar_Kind == e.fKind) {
+            VarDeclarations& decls = (VarDeclarations&) e;
             if (!decls.fVars.size()) {
                 continue;
             }
@@ -770,9 +770,9 @@
     if (Program::kFragment_Kind == fProgram.fKind) {
         this->write("    float4 position [[position]];\n");
     }
-    for (const auto& e : fProgram.fElements) {
-        if (ProgramElement::kVar_Kind == e->fKind) {
-            VarDeclarations& decls = (VarDeclarations&) *e;
+    for (const auto& e : fProgram) {
+        if (ProgramElement::kVar_Kind == e.fKind) {
+            VarDeclarations& decls = (VarDeclarations&) e;
             if (!decls.fVars.size()) {
                 continue;
             }
@@ -800,9 +800,9 @@
 void MetalCodeGenerator::writeOutputStruct() {
     this->write("struct Outputs {\n");
     this->write("    float4 position [[position]];\n");
-    for (const auto& e : fProgram.fElements) {
-        if (ProgramElement::kVar_Kind == e->fKind) {
-            VarDeclarations& decls = (VarDeclarations&) *e;
+    for (const auto& e : fProgram) {
+        if (ProgramElement::kVar_Kind == e.fKind) {
+            VarDeclarations& decls = (VarDeclarations&) e;
             if (!decls.fVars.size()) {
                 continue;
             }
@@ -978,9 +978,9 @@
     }
     auto found = fRequirements.find(&f);
     if (found == fRequirements.end()) {
-        for (const auto& e : fProgram.fElements) {
-            if (ProgramElement::kFunction_Kind == e->fKind) {
-                const FunctionDefinition& def = (const FunctionDefinition&) *e;
+        for (const auto& e : fProgram) {
+            if (ProgramElement::kFunction_Kind == e.fKind) {
+                const FunctionDefinition& def = (const FunctionDefinition&) e;
                 if (&def.fDeclaration == &f) {
                     Requirements reqs = this->requirements(*def.fBody);
                     fRequirements[&f] = reqs;
@@ -1004,8 +1004,8 @@
     }
     StringStream body;
     fOut = &body;
-    for (const auto& e : fProgram.fElements) {
-        this->writeProgramElement(*e);
+    for (const auto& e : fProgram) {
+        this->writeProgramElement(e);
     }
     fOut = rawOut;
 
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index 5f312af..c5a5d2c 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -2854,9 +2854,9 @@
 void SPIRVCodeGenerator::writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out) {
     ASSERT(fProgram.fKind == Program::kGeometry_Kind);
     int invocations = 1;
-    for (size_t i = 0; i < fProgram.fElements.size(); i++) {
-        if (fProgram.fElements[i]->fKind == ProgramElement::kModifiers_Kind) {
-            const Modifiers& m = ((ModifiersDeclaration&) *fProgram.fElements[i]).fModifiers;
+    for (const auto& e : fProgram) {
+        if (e.fKind == ProgramElement::kModifiers_Kind) {
+            const Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
             if (m.fFlags & Modifiers::kIn_Flag) {
                 if (m.fLayout.fInvocations != -1) {
                     invocations = m.fLayout.fInvocations;
@@ -2922,15 +2922,15 @@
     std::set<SpvId> interfaceVars;
     // assign IDs to functions, determine sk_in size
     int skInSize = -1;
-    for (size_t i = 0; i < program.fElements.size(); i++) {
-        switch (program.fElements[i]->fKind) {
+    for (const auto& e : program) {
+        switch (e.fKind) {
             case ProgramElement::kFunction_Kind: {
-                FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i];
+                FunctionDefinition& f = (FunctionDefinition&) e;
                 fFunctionMap[&f.fDeclaration] = this->nextId();
                 break;
             }
             case ProgramElement::kModifiers_Kind: {
-                Modifiers& m = ((ModifiersDeclaration&) *program.fElements[i]).fModifiers;
+                Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
                 if (m.fFlags & Modifiers::kIn_Flag) {
                     switch (m.fLayout.fPrimitive) {
                         case Layout::kPoints_Primitive: // break
@@ -2954,9 +2954,9 @@
                 break;
         }
     }
-    for (size_t i = 0; i < program.fElements.size(); i++) {
-        if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) {
-            InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i];
+    for (const auto& e : program) {  // NOTE(review): inherited elements are shared across compiles
+        if (e.fKind == ProgramElement::kInterfaceBlock_Kind) {
+            InterfaceBlock& intf = (InterfaceBlock&) e;  // casts away const; fSizes is appended below
             if (SK_IN_BUILTIN == intf.fVariable.fModifiers.fLayout.fBuiltin) {
                 ASSERT(skInSize != -1);
                 intf.fSizes.emplace_back(new IntLiteral(fContext, -1, skInSize));
@@ -2969,15 +2969,14 @@
             }
         }
     }
-    for (size_t i = 0; i < program.fElements.size(); i++) {
-        if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) {
-            this->writeGlobalVars(program.fKind, ((VarDeclarations&) *program.fElements[i]),
-                                  body);
+    for (const auto& e : program) {
+        if (e.fKind == ProgramElement::kVar_Kind) {
+            this->writeGlobalVars(program.fKind, ((VarDeclarations&) e), body);
         }
     }
-    for (size_t i = 0; i < program.fElements.size(); i++) {
-        if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
-            this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body);
+    for (const auto& e : program) {
+        if (e.fKind == ProgramElement::kFunction_Kind) {
+            this->writeFunction(((FunctionDefinition&) e), body);
         }
     }
     const FunctionDeclaration* main = nullptr;
@@ -3030,11 +3029,9 @@
                                SpvExecutionModeOriginUpperLeft,
                                out);
     }
-    for (size_t i = 0; i < program.fElements.size(); i++) {
-        if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) {
-            this->writeInstruction(SpvOpSourceExtension,
-                                   ((Extension&) *program.fElements[i]).fName.c_str(),
-                                   out);
+    for (const auto& e : program) {
+        if (e.fKind == ProgramElement::kExtension_Kind) {
+            this->writeInstruction(SpvOpSourceExtension, ((Extension&) e).fName.c_str(), out);
         }
     }
 
diff --git a/src/sksl/SkSLSectionAndParameterHelper.h b/src/sksl/SkSLSectionAndParameterHelper.h
index fccfff4..919cc78 100644
--- a/src/sksl/SkSLSectionAndParameterHelper.h
+++ b/src/sksl/SkSLSectionAndParameterHelper.h
@@ -39,11 +39,11 @@
 class SectionAndParameterHelper {
 public:
     SectionAndParameterHelper(const Program& program, ErrorReporter& errors) {
-        for (const auto& p : program.fElements) {
-            switch (p->fKind) {
+        for (const auto& p : program) {
+            switch (p.fKind) {
                 case ProgramElement::kVar_Kind: {
-                    const VarDeclarations* decls = (const VarDeclarations*) p.get();
-                    for (const auto& raw : decls->fVars) {
+                    const VarDeclarations& decls = (const VarDeclarations&) p;
+                    for (const auto& raw : decls.fVars) {
                         const VarDeclaration& decl = (VarDeclaration&) *raw;
                         if (IsParameter(*decl.fVar)) {
                             fParameters.push_back(decl.fVar);
@@ -52,28 +52,28 @@
                     break;
                 }
                 case ProgramElement::kSection_Kind: {
-                    const Section* s = (const Section*) p.get();
-                    if (IsSupportedSection(s->fName.c_str())) {
-                        if (SectionAcceptsArgument(s->fName.c_str())) {
-                            if (!s->fArgument.size()) {
-                                errors.error(s->fOffset,
-                                             ("section '@" + s->fName +
+                    const Section& s = (const Section&) p;
+                    if (IsSupportedSection(s.fName.c_str())) {
+                        if (SectionAcceptsArgument(s.fName.c_str())) {
+                            if (!s.fArgument.size()) {
+                                errors.error(s.fOffset,
+                                             ("section '@" + s.fName +
                                               "' requires one parameter").c_str());
                             }
-                        } else if (s->fArgument.size()) {
-                            errors.error(s->fOffset,
-                                         ("section '@" + s->fName + "' has no parameters").c_str());
+                        } else if (s.fArgument.size()) {
+                            errors.error(s.fOffset,
+                                         ("section '@" + s.fName + "' has no parameters").c_str());
                         }
                     } else {
-                        errors.error(s->fOffset,
-                                     ("unsupported section '@" + s->fName + "'").c_str());
+                        errors.error(s.fOffset,
+                                     ("unsupported section '@" + s.fName + "'").c_str());
                     }
-                    if (!SectionPermitsDuplicates(s->fName.c_str()) &&
-                            fSections.find(s->fName) != fSections.end()) {
-                        errors.error(s->fOffset,
-                                     ("duplicate section '@" + s->fName + "'").c_str());
+                    if (!SectionPermitsDuplicates(s.fName.c_str()) &&
+                            fSections.find(s.fName) != fSections.end()) {
+                        errors.error(s.fOffset,
+                                     ("duplicate section '@" + s.fName + "'").c_str());
                     }
-                    fSections[s->fName].push_back(s);
+                    fSections[s.fName].push_back(&s);  // stores a pointer into `program`
                     break;
                 }
                 default:
diff --git a/src/sksl/ir/SkSLProgram.h b/src/sksl/ir/SkSLProgram.h
index 03a94fa..1ad9fff 100644
--- a/src/sksl/ir/SkSLProgram.h
+++ b/src/sksl/ir/SkSLProgram.h
@@ -101,6 +101,94 @@
         }
     };
 
+    // Walks the inherited elements (if any) first, then the program's own elements.
+    class iterator {
+    public:
+        ProgramElement& operator*() const {
+            if (fIter1 != fEnd1) {
+                return **fIter1;
+            }
+            return **fIter2;
+        }
+
+        iterator& operator++() {
+            if (fIter1 != fEnd1) {
+                ++fIter1;
+                return *this;
+            }
+            ++fIter2;
+            return *this;
+        }
+
+        bool operator==(const iterator& other) const {
+            return fIter1 == other.fIter1 && fIter2 == other.fIter2;
+        }
+
+        bool operator!=(const iterator& other) const {
+            return !(*this == other);
+        }
+
+    private:
+        using inner = std::vector<std::unique_ptr<ProgramElement>>::iterator;
+
+        iterator(inner begin1, inner end1, inner begin2, inner end2)
+        : fIter1(begin1)
+        , fEnd1(end1)
+        , fIter2(begin2)
+        , fEnd2(end2) {}
+
+        inner fIter1;
+        inner fEnd1;
+        inner fIter2;
+        inner fEnd2;
+
+        friend struct Program;
+    };
+
+    // Read-only counterpart of iterator; used when iterating a const Program.
+    class const_iterator {
+    public:
+        const ProgramElement& operator*() const {
+            if (fIter1 != fEnd1) {
+                return **fIter1;
+            }
+            return **fIter2;
+        }
+
+        const_iterator& operator++() {
+            if (fIter1 != fEnd1) {
+                ++fIter1;
+                return *this;
+            }
+            ++fIter2;
+            return *this;
+        }
+
+        bool operator==(const const_iterator& other) const {
+            return fIter1 == other.fIter1 && fIter2 == other.fIter2;
+        }
+
+        bool operator!=(const const_iterator& other) const {
+            return !(*this == other);
+        }
+
+    private:
+        using inner = std::vector<std::unique_ptr<ProgramElement>>::const_iterator;
+
+        const_iterator(inner begin1, inner end1, inner begin2, inner end2)
+        : fIter1(begin1)
+        , fEnd1(end1)
+        , fIter2(begin2)
+        , fEnd2(end2) {}
+
+        inner fIter1;
+        inner fEnd1;
+        inner fIter2;
+        inner fEnd2;
+
+        friend struct Program;
+    };
+
     enum Kind {
         kFragment_Kind,
         kVertex_Kind,
@@ -113,6 +201,7 @@
             std::unique_ptr<String> source,
             Settings settings,
             std::shared_ptr<Context> context,
+            std::vector<std::unique_ptr<ProgramElement>>* inheritedElements,
             std::vector<std::unique_ptr<ProgramElement>> elements,
             std::shared_ptr<SymbolTable> symbols,
             Inputs inputs)
@@ -121,8 +210,42 @@
     , fSettings(settings)
     , fContext(context)
     , fSymbols(symbols)
-    , fElements(std::move(elements))
-    , fInputs(inputs) {}
+    , fInputs(inputs)
+    , fInheritedElements(inheritedElements)
+    , fElements(std::move(elements)) {}
+
+    // With no inherited elements, fElements is the first range and the second one is empty.
+    iterator begin() {
+        if (fInheritedElements) {
+            return iterator(fInheritedElements->begin(), fInheritedElements->end(),
+                            fElements.begin(), fElements.end());
+        }
+        return iterator(fElements.begin(), fElements.end(), fElements.end(), fElements.end());
+    }
+
+    iterator end() {
+        if (fInheritedElements) {
+            return iterator(fInheritedElements->end(), fInheritedElements->end(),
+                            fElements.end(), fElements.end());
+        }
+        return iterator(fElements.end(), fElements.end(), fElements.end(), fElements.end());
+    }
+
+    const_iterator begin() const {
+        if (fInheritedElements) {
+            return const_iterator(fInheritedElements->begin(), fInheritedElements->end(),
+                                  fElements.begin(), fElements.end());
+        }
+        return const_iterator(fElements.begin(), fElements.end(), fElements.end(), fElements.end());
+    }
+
+    const_iterator end() const {
+        if (fInheritedElements) {
+            return const_iterator(fInheritedElements->end(), fInheritedElements->end(),
+                                  fElements.end(), fElements.end());
+        }
+        return const_iterator(fElements.end(), fElements.end(), fElements.end(), fElements.end());
+    }
 
     Kind fKind;
     std::unique_ptr<String> fSource;
@@ -131,8 +254,12 @@
     // it's important to keep fElements defined after (and thus destroyed before) fSymbols,
     // because destroying elements can modify reference counts in symbols
     std::shared_ptr<SymbolTable> fSymbols;
-    std::vector<std::unique_ptr<ProgramElement>> fElements;
     Inputs fInputs;
+
+private:
+    // Not owned; may be null (Compiler passes its cached per-stage include elements).
+    std::vector<std::unique_ptr<ProgramElement>>* fInheritedElements;
+    std::vector<std::unique_ptr<ProgramElement>> fElements;
 };
 
 } // namespace