SkSL performance improvements (plus a couple of minor warning fixes)
GOLD_TRYBOT_URL= https://gold.skia.org/search?issue=2131223002

Committed: https://skia.googlesource.com/skia/+/9fd67a1f53809f5eff1210dd107241b450c48acc
Review-Url: https://codereview.chromium.org/2131223002
diff --git a/src/gpu/vk/GrVkPipelineStateBuilder.cpp b/src/gpu/vk/GrVkPipelineStateBuilder.cpp
index 323ea66..d9d1b6c 100644
--- a/src/gpu/vk/GrVkPipelineStateBuilder.cpp
+++ b/src/gpu/vk/GrVkPipelineStateBuilder.cpp
@@ -93,8 +93,6 @@
 }
 #endif
 
-#include <fstream>
-#include <sstream>
 bool GrVkPipelineStateBuilder::CreateVkShaderModule(const GrVkGpu* gpu,
                                                     VkShaderStageFlagBits stage,
                                                     const GrGLSLShaderBuilder& builder,
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index 2b4adc1..0d65b10 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -41,9 +41,10 @@
 : fErrorCount(0) {
     auto types = std::shared_ptr<SymbolTable>(new SymbolTable(*this));
     auto symbols = std::shared_ptr<SymbolTable>(new SymbolTable(types, *this));
-    fIRGenerator = new IRGenerator(symbols, *this);
+    fIRGenerator = new IRGenerator(&fContext, symbols, *this);
     fTypes = types;
-    #define ADD_TYPE(t) types->add(k ## t ## _Type->fName, k ## t ## _Type)
+    #define ADD_TYPE(t) types->addWithoutOwnership(fContext.f ## t ## _Type->fName, \
+                                                   fContext.f ## t ## _Type.get())
     ADD_TYPE(Void);
     ADD_TYPE(Float);
     ADD_TYPE(Vec2);
@@ -185,19 +186,21 @@
     fErrorText = "";
     fErrorCount = 0;
     fIRGenerator->pushSymbolTable();
-    std::vector<std::unique_ptr<ProgramElement>> result;
+    std::vector<std::unique_ptr<ProgramElement>> elements;
     switch (kind) {
         case Program::kVertex_Kind:
-            this->internalConvertProgram(SKSL_VERT_INCLUDE, &result);
+            this->internalConvertProgram(SKSL_VERT_INCLUDE, &elements);
             break;
         case Program::kFragment_Kind:
-            this->internalConvertProgram(SKSL_FRAG_INCLUDE, &result);
+            this->internalConvertProgram(SKSL_FRAG_INCLUDE, &elements);
             break;
     }
-    this->internalConvertProgram(text, &result);
+    this->internalConvertProgram(text, &elements);
+    auto result = std::unique_ptr<Program>(new Program(kind, std::move(elements), 
+                                                       fIRGenerator->fSymbolTable));;
     fIRGenerator->popSymbolTable();
     this->writeErrorCount();
-    return std::unique_ptr<Program>(new Program(kind, std::move(result)));;
+    return result;
 }
 
 void Compiler::error(Position position, std::string msg) {
@@ -224,7 +227,7 @@
 bool Compiler::toSPIRV(Program::Kind kind, std::string text, std::ostream& out) {
     auto program = this->convertProgram(kind, text);
     if (fErrorCount == 0) {
-        SkSL::SPIRVCodeGenerator cg;
+        SkSL::SPIRVCodeGenerator cg(&fContext);
         cg.generateCode(*program.get(), out);
         ASSERT(!out.rdstate());
     }
diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h
index 2209427..e63d5f4 100644
--- a/src/sksl/SkSLCompiler.h
+++ b/src/sksl/SkSLCompiler.h
@@ -11,6 +11,7 @@
 #include <vector>
 #include "ir/SkSLProgram.h"
 #include "ir/SkSLSymbolTable.h"
+#include "SkSLContext.h"
 #include "SkSLErrorReporter.h"
 
 namespace SkSL {
@@ -50,6 +51,7 @@
     IRGenerator* fIRGenerator;
     std::string fSkiaVertText; // FIXME store parsed version instead
 
+    Context fContext;
     int fErrorCount;
     std::string fErrorText;
 };
diff --git a/src/sksl/SkSLContext.h b/src/sksl/SkSLContext.h
new file mode 100644
index 0000000..1f124d0
--- /dev/null
+++ b/src/sksl/SkSLContext.h
@@ -0,0 +1,227 @@
+/*
+ * Copyright 2016 Google Inc.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+ 
+#ifndef SKSL_CONTEXT
+#define SKSL_CONTEXT
+
+#include "ir/SkSLType.h"
+
+namespace SkSL {
+
+/**
+ * Contains compiler-wide objects, which currently means the core types.
+ */
+class Context {
+public:
+    Context()
+    : fVoid_Type(new Type("void"))
+    , fDouble_Type(new Type("double", true))
+    , fDVec2_Type(new Type("dvec2", *fDouble_Type, 2))
+    , fDVec3_Type(new Type("dvec3", *fDouble_Type, 3))
+    , fDVec4_Type(new Type("dvec4", *fDouble_Type, 4))
+    , fFloat_Type(new Type("float", true, { fDouble_Type.get() }))
+    , fVec2_Type(new Type("vec2", *fFloat_Type, 2))
+    , fVec3_Type(new Type("vec3", *fFloat_Type, 3))
+    , fVec4_Type(new Type("vec4", *fFloat_Type, 4))
+    , fUInt_Type(new Type("uint", true, { fFloat_Type.get(), fDouble_Type.get() }))
+    , fUVec2_Type(new Type("uvec2", *fUInt_Type, 2))
+    , fUVec3_Type(new Type("uvec3", *fUInt_Type, 3))
+    , fUVec4_Type(new Type("uvec4", *fUInt_Type, 4))
+    , fInt_Type(new Type("int", true, { fUInt_Type.get(), fFloat_Type.get(), fDouble_Type.get() }))
+    , fIVec2_Type(new Type("ivec2", *fInt_Type, 2))
+    , fIVec3_Type(new Type("ivec3", *fInt_Type, 3))
+    , fIVec4_Type(new Type("ivec4", *fInt_Type, 4))
+    , fBool_Type(new Type("bool", false))
+    , fBVec2_Type(new Type("bvec2", *fBool_Type, 2))
+    , fBVec3_Type(new Type("bvec3", *fBool_Type, 3))
+    , fBVec4_Type(new Type("bvec4", *fBool_Type, 4))
+    , fMat2x2_Type(new Type("mat2",   *fFloat_Type, 2, 2))
+    , fMat2x3_Type(new Type("mat2x3", *fFloat_Type, 2, 3))
+    , fMat2x4_Type(new Type("mat2x4", *fFloat_Type, 2, 4))
+    , fMat3x2_Type(new Type("mat3x2", *fFloat_Type, 3, 2))
+    , fMat3x3_Type(new Type("mat3",   *fFloat_Type, 3, 3))
+    , fMat3x4_Type(new Type("mat3x4", *fFloat_Type, 3, 4))
+    , fMat4x2_Type(new Type("mat4x2", *fFloat_Type, 4, 2))
+    , fMat4x3_Type(new Type("mat4x3", *fFloat_Type, 4, 3))
+    , fMat4x4_Type(new Type("mat4",   *fFloat_Type, 4, 4))
+    , fDMat2x2_Type(new Type("dmat2",   *fFloat_Type, 2, 2))
+    , fDMat2x3_Type(new Type("dmat2x3", *fFloat_Type, 2, 3))
+    , fDMat2x4_Type(new Type("dmat2x4", *fFloat_Type, 2, 4))
+    , fDMat3x2_Type(new Type("dmat3x2", *fFloat_Type, 3, 2))
+    , fDMat3x3_Type(new Type("dmat3",   *fFloat_Type, 3, 3))
+    , fDMat3x4_Type(new Type("dmat3x4", *fFloat_Type, 3, 4))
+    , fDMat4x2_Type(new Type("dmat4x2", *fFloat_Type, 4, 2))
+    , fDMat4x3_Type(new Type("dmat4x3", *fFloat_Type, 4, 3))
+    , fDMat4x4_Type(new Type("dmat4",   *fFloat_Type, 4, 4))
+    , fSampler1D_Type(new Type("sampler1D", SpvDim1D, false, false, false, true))
+    , fSampler2D_Type(new Type("sampler2D", SpvDim2D, false, false, false, true))
+    , fSampler3D_Type(new Type("sampler3D", SpvDim3D, false, false, false, true))
+    , fSamplerCube_Type(new Type("samplerCube"))
+    , fSampler2DRect_Type(new Type("sampler2DRect"))
+    , fSampler1DArray_Type(new Type("sampler1DArray"))
+    , fSampler2DArray_Type(new Type("sampler2DArray"))
+    , fSamplerCubeArray_Type(new Type("samplerCubeArray"))
+    , fSamplerBuffer_Type(new Type("samplerBuffer"))
+    , fSampler2DMS_Type(new Type("sampler2DMS"))
+    , fSampler2DMSArray_Type(new Type("sampler2DMSArray"))
+    , fSampler1DShadow_Type(new Type("sampler1DShadow"))
+    , fSampler2DShadow_Type(new Type("sampler2DShadow"))
+    , fSamplerCubeShadow_Type(new Type("samplerCubeShadow"))
+    , fSampler2DRectShadow_Type(new Type("sampler2DRectShadow"))
+    , fSampler1DArrayShadow_Type(new Type("sampler1DArrayShadow"))
+    , fSampler2DArrayShadow_Type(new Type("sampler2DArrayShadow"))
+    , fSamplerCubeArrayShadow_Type(new Type("samplerCubeArrayShadow"))
+    // FIXME figure out what we're supposed to do with the gsampler et al. types)
+    , fGSampler1D_Type(new Type("$gsampler1D", static_type(*fSampler1D_Type)))
+    , fGSampler2D_Type(new Type("$gsampler2D", static_type(*fSampler2D_Type)))
+    , fGSampler3D_Type(new Type("$gsampler3D", static_type(*fSampler3D_Type)))
+    , fGSamplerCube_Type(new Type("$gsamplerCube", static_type(*fSamplerCube_Type)))
+    , fGSampler2DRect_Type(new Type("$gsampler2DRect", static_type(*fSampler2DRect_Type)))
+    , fGSampler1DArray_Type(new Type("$gsampler1DArray", static_type(*fSampler1DArray_Type)))
+    , fGSampler2DArray_Type(new Type("$gsampler2DArray", static_type(*fSampler2DArray_Type)))
+    , fGSamplerCubeArray_Type(new Type("$gsamplerCubeArray", static_type(*fSamplerCubeArray_Type)))
+    , fGSamplerBuffer_Type(new Type("$gsamplerBuffer", static_type(*fSamplerBuffer_Type)))
+    , fGSampler2DMS_Type(new Type("$gsampler2DMS", static_type(*fSampler2DMS_Type)))
+    , fGSampler2DMSArray_Type(new Type("$gsampler2DMSArray", static_type(*fSampler2DMSArray_Type)))
+    , fGSampler2DArrayShadow_Type(new Type("$gsampler2DArrayShadow", 
+                                           static_type(*fSampler2DArrayShadow_Type)))
+    , fGSamplerCubeArrayShadow_Type(new Type("$gsamplerCubeArrayShadow",
+                                             static_type(*fSamplerCubeArrayShadow_Type)))
+    , fGenType_Type(new Type("$genType", { fFloat_Type.get(), fVec2_Type.get(), fVec3_Type.get(), 
+                                           fVec4_Type.get() }))
+    , fGenDType_Type(new Type("$genDType", { fDouble_Type.get(), fDVec2_Type.get(), 
+                                             fDVec3_Type.get(), fDVec4_Type.get() }))
+    , fGenIType_Type(new Type("$genIType", { fInt_Type.get(), fIVec2_Type.get(), fIVec3_Type.get(), 
+                                             fIVec4_Type.get() }))
+    , fGenUType_Type(new Type("$genUType", { fUInt_Type.get(), fUVec2_Type.get(), fUVec3_Type.get(), 
+                                             fUVec4_Type.get() }))
+    , fGenBType_Type(new Type("$genBType", { fBool_Type.get(), fBVec2_Type.get(), fBVec3_Type.get(), 
+                                             fBVec4_Type.get() }))
+    , fMat_Type(new Type("$mat"))
+    , fVec_Type(new Type("$vec", { fVec2_Type.get(), fVec2_Type.get(), fVec3_Type.get(),
+                                   fVec4_Type.get() }))
+    , fGVec_Type(new Type("$gvec"))
+    , fGVec2_Type(new Type("$gvec2"))
+    , fGVec3_Type(new Type("$gvec3"))
+    , fGVec4_Type(new Type("$gvec4", static_type(*fVec4_Type)))
+    , fDVec_Type(new Type("$dvec"))
+    , fIVec_Type(new Type("$ivec"))
+    , fUVec_Type(new Type("$uvec"))
+    , fBVec_Type(new Type("$bvec", { fBVec2_Type.get(), fBVec2_Type.get(), fBVec3_Type.get(), 
+                                     fBVec4_Type.get() }))
+    , fInvalid_Type(new Type("<INVALID>")) {}
+
+    static std::vector<const Type*> static_type(const Type& t) {
+        return { &t, &t, &t, &t };   
+    }
+
+    const std::unique_ptr<Type> fVoid_Type;
+
+    const std::unique_ptr<Type> fDouble_Type;
+    const std::unique_ptr<Type> fDVec2_Type;
+    const std::unique_ptr<Type> fDVec3_Type;
+    const std::unique_ptr<Type> fDVec4_Type;
+
+    const std::unique_ptr<Type> fFloat_Type;
+    const std::unique_ptr<Type> fVec2_Type;
+    const std::unique_ptr<Type> fVec3_Type;
+    const std::unique_ptr<Type> fVec4_Type;
+
+    const std::unique_ptr<Type> fUInt_Type;
+    const std::unique_ptr<Type> fUVec2_Type;
+    const std::unique_ptr<Type> fUVec3_Type;
+    const std::unique_ptr<Type> fUVec4_Type;
+
+    const std::unique_ptr<Type> fInt_Type;
+    const std::unique_ptr<Type> fIVec2_Type;
+    const std::unique_ptr<Type> fIVec3_Type;
+    const std::unique_ptr<Type> fIVec4_Type;
+
+    const std::unique_ptr<Type> fBool_Type;
+    const std::unique_ptr<Type> fBVec2_Type;
+    const std::unique_ptr<Type> fBVec3_Type;
+    const std::unique_ptr<Type> fBVec4_Type;
+
+    const std::unique_ptr<Type> fMat2x2_Type;
+    const std::unique_ptr<Type> fMat2x3_Type;
+    const std::unique_ptr<Type> fMat2x4_Type;
+    const std::unique_ptr<Type> fMat3x2_Type;
+    const std::unique_ptr<Type> fMat3x3_Type;
+    const std::unique_ptr<Type> fMat3x4_Type;
+    const std::unique_ptr<Type> fMat4x2_Type;
+    const std::unique_ptr<Type> fMat4x3_Type;
+    const std::unique_ptr<Type> fMat4x4_Type;
+
+    const std::unique_ptr<Type> fDMat2x2_Type;
+    const std::unique_ptr<Type> fDMat2x3_Type;
+    const std::unique_ptr<Type> fDMat2x4_Type;
+    const std::unique_ptr<Type> fDMat3x2_Type;
+    const std::unique_ptr<Type> fDMat3x3_Type;
+    const std::unique_ptr<Type> fDMat3x4_Type;
+    const std::unique_ptr<Type> fDMat4x2_Type;
+    const std::unique_ptr<Type> fDMat4x3_Type;
+    const std::unique_ptr<Type> fDMat4x4_Type;
+
+    const std::unique_ptr<Type> fSampler1D_Type;
+    const std::unique_ptr<Type> fSampler2D_Type;
+    const std::unique_ptr<Type> fSampler3D_Type;
+    const std::unique_ptr<Type> fSamplerCube_Type;
+    const std::unique_ptr<Type> fSampler2DRect_Type;
+    const std::unique_ptr<Type> fSampler1DArray_Type;
+    const std::unique_ptr<Type> fSampler2DArray_Type;
+    const std::unique_ptr<Type> fSamplerCubeArray_Type;
+    const std::unique_ptr<Type> fSamplerBuffer_Type;
+    const std::unique_ptr<Type> fSampler2DMS_Type;
+    const std::unique_ptr<Type> fSampler2DMSArray_Type;
+    const std::unique_ptr<Type> fSampler1DShadow_Type;
+    const std::unique_ptr<Type> fSampler2DShadow_Type;
+    const std::unique_ptr<Type> fSamplerCubeShadow_Type;
+    const std::unique_ptr<Type> fSampler2DRectShadow_Type;
+    const std::unique_ptr<Type> fSampler1DArrayShadow_Type;
+    const std::unique_ptr<Type> fSampler2DArrayShadow_Type;
+    const std::unique_ptr<Type> fSamplerCubeArrayShadow_Type;
+
+    const std::unique_ptr<Type> fGSampler1D_Type;
+    const std::unique_ptr<Type> fGSampler2D_Type;
+    const std::unique_ptr<Type> fGSampler3D_Type;
+    const std::unique_ptr<Type> fGSamplerCube_Type;
+    const std::unique_ptr<Type> fGSampler2DRect_Type;
+    const std::unique_ptr<Type> fGSampler1DArray_Type;
+    const std::unique_ptr<Type> fGSampler2DArray_Type;
+    const std::unique_ptr<Type> fGSamplerCubeArray_Type;
+    const std::unique_ptr<Type> fGSamplerBuffer_Type;
+    const std::unique_ptr<Type> fGSampler2DMS_Type;
+    const std::unique_ptr<Type> fGSampler2DMSArray_Type;
+    const std::unique_ptr<Type> fGSampler2DArrayShadow_Type;
+    const std::unique_ptr<Type> fGSamplerCubeArrayShadow_Type;
+
+    const std::unique_ptr<Type> fGenType_Type;
+    const std::unique_ptr<Type> fGenDType_Type;
+    const std::unique_ptr<Type> fGenIType_Type;
+    const std::unique_ptr<Type> fGenUType_Type;
+    const std::unique_ptr<Type> fGenBType_Type;
+
+    const std::unique_ptr<Type> fMat_Type;
+
+    const std::unique_ptr<Type> fVec_Type;
+
+    const std::unique_ptr<Type> fGVec_Type;
+    const std::unique_ptr<Type> fGVec2_Type;
+    const std::unique_ptr<Type> fGVec3_Type;
+    const std::unique_ptr<Type> fGVec4_Type;
+    const std::unique_ptr<Type> fDVec_Type;
+    const std::unique_ptr<Type> fIVec_Type;
+    const std::unique_ptr<Type> fUVec_Type;
+
+    const std::unique_ptr<Type> fBVec_Type;
+
+    const std::unique_ptr<Type> fInvalid_Type;
+};
+
+} // namespace
+
+#endif
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index 2cc7eac..f250c4b 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -66,11 +66,12 @@
     std::shared_ptr<SymbolTable> fPrevious;
 };
 
-IRGenerator::IRGenerator(std::shared_ptr<SymbolTable> symbolTable, 
+IRGenerator::IRGenerator(const Context* context, std::shared_ptr<SymbolTable> symbolTable, 
                          ErrorReporter& errorReporter)
-: fSymbolTable(std::move(symbolTable))
-, fErrors(errorReporter) {
-}
+: fContext(*context)
+, fCurrentFunction(nullptr)
+, fSymbolTable(std::move(symbolTable))
+, fErrors(errorReporter) {}
 
 void IRGenerator::pushSymbolTable() {
     fSymbolTable.reset(new SymbolTable(std::move(fSymbolTable), fErrors));
@@ -123,7 +124,7 @@
         }
         statements.push_back(std::move(statement));
     }
-    return std::unique_ptr<Block>(new Block(block.fPosition, std::move(statements)));
+    return std::unique_ptr<Block>(new Block(block.fPosition, std::move(statements), fSymbolTable));
 }
 
 std::unique_ptr<Statement> IRGenerator::convertVarDeclarationStatement(
@@ -141,22 +142,22 @@
 
 std::unique_ptr<VarDeclaration> IRGenerator::convertVarDeclaration(const ASTVarDeclaration& decl,
                                                                    Variable::Storage storage) {
-    std::vector<std::shared_ptr<Variable>> variables;
+    std::vector<const Variable*> variables;
     std::vector<std::vector<std::unique_ptr<Expression>>> sizes;
     std::vector<std::unique_ptr<Expression>> values;
-    std::shared_ptr<Type> baseType = this->convertType(*decl.fType);
+    const Type* baseType = this->convertType(*decl.fType);
     if (!baseType) {
         return nullptr;
     }
     for (size_t i = 0; i < decl.fNames.size(); i++) {
         Modifiers modifiers = this->convertModifiers(decl.fModifiers);
-        std::shared_ptr<Type> type = baseType;
+        const Type* type = baseType;
         ASSERT(type->kind() != Type::kArray_Kind);
         std::vector<std::unique_ptr<Expression>> currentVarSizes;
         for (size_t j = 0; j < decl.fSizes[i].size(); j++) {
             if (decl.fSizes[i][j]) {
                 ASTExpression& rawSize = *decl.fSizes[i][j];
-                auto size = this->coerce(this->convertExpression(rawSize), kInt_Type);
+                auto size = this->coerce(this->convertExpression(rawSize), *fContext.fInt_Type);
                 if (!size) {
                     return nullptr;
                 }
@@ -172,27 +173,28 @@
                     count = -1;
                     name += "[]";
                 }
-                type = std::shared_ptr<Type>(new Type(name, Type::kArray_Kind, type, (int) count));
+                type = new Type(name, Type::kArray_Kind, *type, (int) count);
+                fSymbolTable->takeOwnership((Type*) type);
                 currentVarSizes.push_back(std::move(size));
             } else {
-                type = std::shared_ptr<Type>(new Type(type->fName + "[]", Type::kArray_Kind, type, 
-                                                      -1));
+                type = new Type(type->fName + "[]", Type::kArray_Kind, *type, -1);
+                fSymbolTable->takeOwnership((Type*) type);
                 currentVarSizes.push_back(nullptr);
             }
         }
         sizes.push_back(std::move(currentVarSizes));
-        auto var = std::make_shared<Variable>(decl.fPosition, modifiers, decl.fNames[i], type, 
-                                              storage);
-        variables.push_back(var);
+        auto var = std::unique_ptr<Variable>(new Variable(decl.fPosition, modifiers, decl.fNames[i], 
+                                                          *type, storage));
         std::unique_ptr<Expression> value;
         if (decl.fValues[i]) {
             value = this->convertExpression(*decl.fValues[i]);
             if (!value) {
                 return nullptr;
             }
-            value = this->coerce(std::move(value), type);
+            value = this->coerce(std::move(value), *type);
         }
-        fSymbolTable->add(var->fName, var);
+        variables.push_back(var.get());
+        fSymbolTable->add(decl.fNames[i], std::move(var));
         values.push_back(std::move(value));
     }
     return std::unique_ptr<VarDeclaration>(new VarDeclaration(decl.fPosition, std::move(variables), 
@@ -200,7 +202,8 @@
 }
 
 std::unique_ptr<Statement> IRGenerator::convertIf(const ASTIfStatement& s) {
-    std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*s.fTest), kBool_Type);
+    std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*s.fTest), 
+                                                    *fContext.fBool_Type);
     if (!test) {
         return nullptr;
     }
@@ -225,7 +228,8 @@
     if (!initializer) {
         return nullptr;
     }
-    std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*f.fTest), kBool_Type);
+    std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*f.fTest), 
+                                                    *fContext.fBool_Type);
     if (!test) {
         return nullptr;
     }
@@ -240,11 +244,12 @@
     }
     return std::unique_ptr<Statement>(new ForStatement(f.fPosition, std::move(initializer), 
                                                        std::move(test), std::move(next),
-                                                       std::move(statement)));
+                                                       std::move(statement), fSymbolTable));
 }
 
 std::unique_ptr<Statement> IRGenerator::convertWhile(const ASTWhileStatement& w) {
-    std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*w.fTest), kBool_Type);
+    std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*w.fTest), 
+                                                    *fContext.fBool_Type);
     if (!test) {
         return nullptr;
     }
@@ -257,7 +262,8 @@
 }
 
 std::unique_ptr<Statement> IRGenerator::convertDo(const ASTDoStatement& d) {
-    std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*d.fTest), kBool_Type);
+    std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*d.fTest),
+                                                    *fContext.fBool_Type);
     if (!test) {
         return nullptr;
     }
@@ -286,7 +292,7 @@
         if (!result) {
             return nullptr;
         }
-        if (fCurrentFunction->fReturnType == kVoid_Type) {
+        if (fCurrentFunction->fReturnType == *fContext.fVoid_Type) {
             fErrors.error(result->fPosition, "may not return a value from a void function");
         } else {
             result = this->coerce(std::move(result), fCurrentFunction->fReturnType);
@@ -296,9 +302,9 @@
         }
         return std::unique_ptr<Statement>(new ReturnStatement(std::move(result)));
     } else {
-        if (fCurrentFunction->fReturnType != kVoid_Type) {
+        if (fCurrentFunction->fReturnType != *fContext.fVoid_Type) {
             fErrors.error(r.fPosition, "expected function to return '" +
-                                       fCurrentFunction->fReturnType->description() + "'");
+                                       fCurrentFunction->fReturnType.description() + "'");
         }
         return std::unique_ptr<Statement>(new ReturnStatement(r.fPosition));
     }
@@ -316,80 +322,74 @@
     return std::unique_ptr<Statement>(new DiscardStatement(d.fPosition));
 }
 
-static std::shared_ptr<Type> expand_generics(std::shared_ptr<Type> type, int i) {
-    if (type->kind() == Type::kGeneric_Kind) {
-        return type->coercibleTypes()[i];
+static const Type& expand_generics(const Type& type, int i) {
+    if (type.kind() == Type::kGeneric_Kind) {
+        return *type.coercibleTypes()[i];
     }
     return type;
 }
 
-static void expand_generics(FunctionDeclaration& decl, 
-                            SymbolTable& symbolTable) {
+static void expand_generics(const FunctionDeclaration& decl, 
+                            std::shared_ptr<SymbolTable> symbolTable) {
     for (int i = 0; i < 4; i++) {
-        std::shared_ptr<Type> returnType = expand_generics(decl.fReturnType, i);
-        std::vector<std::shared_ptr<Variable>> arguments;
+        const Type& returnType = expand_generics(decl.fReturnType, i);
+        std::vector<const Variable*> parameters;
         for (const auto& p : decl.fParameters) {
-            arguments.push_back(std::shared_ptr<Variable>(new Variable(
-                                                                    p->fPosition, 
-                                                                    Modifiers(p->fModifiers), 
-                                                                    p->fName,
-                                                                    expand_generics(p->fType, i),
-                                                                    Variable::kParameter_Storage)));
+            Variable* var = new Variable(p->fPosition, Modifiers(p->fModifiers), p->fName,
+                                         expand_generics(p->fType, i),
+                                         Variable::kParameter_Storage);
+            symbolTable->takeOwnership(var);
+            parameters.push_back(var);
         }
-        std::shared_ptr<FunctionDeclaration> expanded(new FunctionDeclaration(
-                                                                            decl.fPosition, 
-                                                                            decl.fName, 
-                                                                            std::move(arguments), 
-                                                                            std::move(returnType)));
-        symbolTable.add(expanded->fName, expanded);
+        symbolTable->add(decl.fName, std::unique_ptr<FunctionDeclaration>(new FunctionDeclaration(
+                                                                           decl.fPosition,
+                                                                           decl.fName,
+                                                                           std::move(parameters),
+                                                                           std::move(returnType))));
     }
 }
 
 std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFunction& f) {
-    std::shared_ptr<SymbolTable> old = fSymbolTable;
-    AutoSymbolTable table(this);
     bool isGeneric;
-    std::shared_ptr<Type> returnType = this->convertType(*f.fReturnType);
+    const Type* returnType = this->convertType(*f.fReturnType);
     if (!returnType) {
         return nullptr;
     }
     isGeneric = returnType->kind() == Type::kGeneric_Kind;
-    std::vector<std::shared_ptr<Variable>> parameters;
+    std::vector<const Variable*> parameters;
     for (const auto& param : f.fParameters) {
-        std::shared_ptr<Type> type = this->convertType(*param->fType);
+        const Type* type = this->convertType(*param->fType);
         if (!type) {
             return nullptr;
         }
         for (int j = (int) param->fSizes.size() - 1; j >= 0; j--) {
             int size = param->fSizes[j];
             std::string name = type->name() + "[" + to_string(size) + "]";
-            type = std::shared_ptr<Type>(new Type(std::move(name), Type::kArray_Kind, 
-                                                  std::move(type), size));
+            Type* newType = new Type(std::move(name), Type::kArray_Kind, *type, size);
+            fSymbolTable->takeOwnership(newType);
+            type = newType;
         }
         std::string name = param->fName;
         Modifiers modifiers = this->convertModifiers(param->fModifiers);
         Position pos = param->fPosition;
-        std::shared_ptr<Variable> var = std::shared_ptr<Variable>(new Variable(
-                                                                     pos, 
-                                                                     modifiers, 
-                                                                     std::move(name), 
-                                                                     type,
-                                                                     Variable::kParameter_Storage));
-        parameters.push_back(std::move(var));
+        Variable* var = new Variable(pos, modifiers, std::move(name), *type,
+                                     Variable::kParameter_Storage);
+        fSymbolTable->takeOwnership(var);
+        parameters.push_back(var);
         isGeneric |= type->kind() == Type::kGeneric_Kind;
     }
 
     // find existing declaration
-    std::shared_ptr<FunctionDeclaration> decl;
-    auto entry = (*old)[f.fName];
+    const FunctionDeclaration* decl = nullptr;
+    auto entry = (*fSymbolTable)[f.fName];
     if (entry) {
-        std::vector<std::shared_ptr<FunctionDeclaration>> functions;
+        std::vector<const FunctionDeclaration*> functions;
         switch (entry->fKind) {
             case Symbol::kUnresolvedFunction_Kind:
-                functions = std::static_pointer_cast<UnresolvedFunction>(entry)->fFunctions;
+                functions = ((UnresolvedFunction*) entry)->fFunctions;
                 break;
             case Symbol::kFunctionDeclaration_Kind:
-                functions.push_back(std::static_pointer_cast<FunctionDeclaration>(entry));
+                functions.push_back((FunctionDeclaration*) entry);
                 break;
             default:
                 fErrors.error(f.fPosition, "symbol '" + f.fName + "' was already defined");
@@ -406,11 +406,8 @@
                     }
                 }
                 if (match) {
-                    if (returnType != other->fReturnType) {
-                        FunctionDeclaration newDecl = FunctionDeclaration(f.fPosition, 
-                                                                          f.fName, 
-                                                                          parameters, 
-                                                                          returnType);
+                    if (*returnType != other->fReturnType) {
+                        FunctionDeclaration newDecl(f.fPosition, f.fName, parameters, *returnType);
                         fErrors.error(f.fPosition, "functions '" + newDecl.description() +
                                                    "' and '" + other->description() + 
                                                    "' differ only in return type");
@@ -424,7 +421,6 @@
                                                        "declaration and definition");
                             return nullptr;
                         }
-                        fSymbolTable->add(parameters[i]->fName, decl->fParameters[i]);
                     }
                     if (other->fDefined) {
                         fErrors.error(f.fPosition, "duplicate definition of " + 
@@ -437,28 +433,36 @@
     }
     if (!decl) {
         // couldn't find an existing declaration
-        decl.reset(new FunctionDeclaration(f.fPosition, f.fName, parameters, returnType));
-        for (auto var : parameters) {
-            fSymbolTable->add(var->fName, var);
+        if (isGeneric) {
+            ASSERT(!f.fBody);
+            expand_generics(FunctionDeclaration(f.fPosition, f.fName, parameters, *returnType), 
+                            fSymbolTable);
+        } else {
+            auto newDecl = std::unique_ptr<FunctionDeclaration>(new FunctionDeclaration(
+                                                                                     f.fPosition, 
+                                                                                     f.fName, 
+                                                                                     parameters, 
+                                                                                     *returnType));
+            decl = newDecl.get();
+            fSymbolTable->add(decl->fName, std::move(newDecl));
         }
     }
-    if (isGeneric) {
-        ASSERT(!f.fBody);
-        expand_generics(*decl, *old);
-    } else {
-        old->add(decl->fName, decl);
-        if (f.fBody) {
-            ASSERT(!fCurrentFunction);
-            fCurrentFunction = decl;
-            decl->fDefined = true;
-            std::unique_ptr<Block> body = this->convertBlock(*f.fBody);
-            fCurrentFunction = nullptr;
-            if (!body) {
-                return nullptr;
-            }
-            return std::unique_ptr<FunctionDefinition>(new FunctionDefinition(f.fPosition, decl, 
-                                                                              std::move(body)));
+    if (f.fBody) {
+        ASSERT(!fCurrentFunction);
+        fCurrentFunction = decl;
+        decl->fDefined = true;
+        std::shared_ptr<SymbolTable> old = fSymbolTable;
+        AutoSymbolTable table(this);
+        for (size_t i = 0; i < parameters.size(); i++) {
+            fSymbolTable->addWithoutOwnership(parameters[i]->fName, decl->fParameters[i]);
         }
+        std::unique_ptr<Block> body = this->convertBlock(*f.fBody);
+        fCurrentFunction = nullptr;
+        if (!body) {
+            return nullptr;
+        }
+        return std::unique_ptr<FunctionDefinition>(new FunctionDefinition(f.fPosition, *decl, 
+                                                                          std::move(body)));
     }
     return nullptr;
 }
@@ -488,28 +492,26 @@
             }
         }        
     }
-    std::shared_ptr<Type> type = std::shared_ptr<Type>(new Type(intf.fInterfaceName, fields));
+    Type* type = new Type(intf.fInterfaceName, fields);
+    fSymbolTable->takeOwnership(type);
     std::string name = intf.fValueName.length() > 0 ? intf.fValueName : intf.fInterfaceName;
-    std::shared_ptr<Variable> var = std::shared_ptr<Variable>(new Variable(intf.fPosition, mods, 
-                                                                          name, type,
-                                                                Variable::kGlobal_Storage));
+    Variable* var = new Variable(intf.fPosition, mods, name, *type, Variable::kGlobal_Storage);
+    fSymbolTable->takeOwnership(var);
     if (intf.fValueName.length()) {
-        old->add(intf.fValueName, var);
-
+        old->addWithoutOwnership(intf.fValueName, var);
     } else {
         for (size_t i = 0; i < fields.size(); i++) {
-            std::shared_ptr<Field> field = std::shared_ptr<Field>(new Field(intf.fPosition, var, 
-                                                                            (int) i));
-            old->add(fields[i].fName, field);
+            old->add(fields[i].fName, std::unique_ptr<Field>(new Field(intf.fPosition, *var, 
+                                                                       (int) i)));
         }
     }
-    return std::unique_ptr<InterfaceBlock>(new InterfaceBlock(intf.fPosition, var));
+    return std::unique_ptr<InterfaceBlock>(new InterfaceBlock(intf.fPosition, *var, fSymbolTable));
 }
 
-std::shared_ptr<Type> IRGenerator::convertType(const ASTType& type) {
-    std::shared_ptr<Symbol> result = (*fSymbolTable)[type.fName];
+const Type* IRGenerator::convertType(const ASTType& type) {
+    const Symbol* result = (*fSymbolTable)[type.fName];
     if (result && result->fKind == Symbol::kType_Kind) {
-        return std::static_pointer_cast<Type>(result);
+        return (const Type*) result;
     }
     fErrors.error(type.fPosition, "unknown type '" + type.fName + "'");
     return nullptr;
@@ -520,13 +522,13 @@
         case ASTExpression::kIdentifier_Kind:
             return this->convertIdentifier((ASTIdentifier&) expr);
         case ASTExpression::kBool_Kind:
-            return std::unique_ptr<Expression>(new BoolLiteral(expr.fPosition,
+            return std::unique_ptr<Expression>(new BoolLiteral(fContext, expr.fPosition,
                                                                ((ASTBoolLiteral&) expr).fValue));
         case ASTExpression::kInt_Kind:
-            return std::unique_ptr<Expression>(new IntLiteral(expr.fPosition,
+            return std::unique_ptr<Expression>(new IntLiteral(fContext, expr.fPosition,
                                                               ((ASTIntLiteral&) expr).fValue));
         case ASTExpression::kFloat_Kind:
-            return std::unique_ptr<Expression>(new FloatLiteral(expr.fPosition,
+            return std::unique_ptr<Expression>(new FloatLiteral(fContext, expr.fPosition,
                                                                 ((ASTFloatLiteral&) expr).fValue));
         case ASTExpression::kBinary_Kind:
             return this->convertBinaryExpression((ASTBinaryExpression&) expr);
@@ -542,40 +544,42 @@
 }
 
 std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTIdentifier& identifier) {
-    std::shared_ptr<Symbol> result = (*fSymbolTable)[identifier.fText];
+    const Symbol* result = (*fSymbolTable)[identifier.fText];
     if (!result) {
         fErrors.error(identifier.fPosition, "unknown identifier '" + identifier.fText + "'");
         return nullptr;
     }
     switch (result->fKind) {
         case Symbol::kFunctionDeclaration_Kind: {
-            std::vector<std::shared_ptr<FunctionDeclaration>> f = {
-                std::static_pointer_cast<FunctionDeclaration>(result)
+            std::vector<const FunctionDeclaration*> f = {
+                (const FunctionDeclaration*) result
             };
-            return std::unique_ptr<FunctionReference>(new FunctionReference(identifier.fPosition,
-                                                                            std::move(f)));
+            return std::unique_ptr<FunctionReference>(new FunctionReference(fContext,
+                                                                            identifier.fPosition,
+                                                                            f));
         }
         case Symbol::kUnresolvedFunction_Kind: {
-            auto f = std::static_pointer_cast<UnresolvedFunction>(result);
-            return std::unique_ptr<FunctionReference>(new FunctionReference(identifier.fPosition,
+            const UnresolvedFunction* f = (const UnresolvedFunction*) result;
+            return std::unique_ptr<FunctionReference>(new FunctionReference(fContext,
+                                                                            identifier.fPosition,
                                                                             f->fFunctions));
         }
         case Symbol::kVariable_Kind: {
-            std::shared_ptr<Variable> var = std::static_pointer_cast<Variable>(result);
-            this->markReadFrom(var);
+            const Variable* var = (const Variable*) result;
+            this->markReadFrom(*var);
             return std::unique_ptr<VariableReference>(new VariableReference(identifier.fPosition,
-                                                                            std::move(var)));
+                                                                            *var));
         }
         case Symbol::kField_Kind: {
-            std::shared_ptr<Field> field = std::static_pointer_cast<Field>(result);
+            const Field* field = (const Field*) result;
             VariableReference* base = new VariableReference(identifier.fPosition, field->fOwner);
             return std::unique_ptr<Expression>(new FieldAccess(std::unique_ptr<Expression>(base),
                                                                field->fFieldIndex));
         }
         case Symbol::kType_Kind: {
-            auto t = std::static_pointer_cast<Type>(result);
-            return std::unique_ptr<TypeReference>(new TypeReference(identifier.fPosition, 
-                                                                    std::move(t)));
+            const Type* t = (const Type*) result;
+            return std::unique_ptr<TypeReference>(new TypeReference(fContext, identifier.fPosition, 
+                                                                    *t));
         }
         default:
             ABORT("unsupported symbol type %d\n", result->fKind);
@@ -584,43 +588,45 @@
 }
 
 std::unique_ptr<Expression> IRGenerator::coerce(std::unique_ptr<Expression> expr, 
-                                                std::shared_ptr<Type> type) {
+                                                const Type& type) {
     if (!expr) {
         return nullptr;
     }
-    if (*expr->fType == *type) {
+    if (expr->fType == type) {
         return expr;
     }
     this->checkValid(*expr);
-    if (*expr->fType == *kInvalid_Type) {
+    if (expr->fType == *fContext.fInvalid_Type) {
         return nullptr;
     }
-    if (!expr->fType->canCoerceTo(type)) {
-        fErrors.error(expr->fPosition, "expected '" + type->description() + "', but found '" + 
-                                        expr->fType->description() + "'");
+    if (!expr->fType.canCoerceTo(type)) {
+        fErrors.error(expr->fPosition, "expected '" + type.description() + "', but found '" + 
+                                        expr->fType.description() + "'");
         return nullptr;
     }
-    if (type->kind() == Type::kScalar_Kind) {
+    if (type.kind() == Type::kScalar_Kind) {
         std::vector<std::unique_ptr<Expression>> args;
         args.push_back(std::move(expr));
-        ASTIdentifier id(Position(), type->description());
+        ASTIdentifier id(Position(), type.description());
         std::unique_ptr<Expression> ctor = this->convertIdentifier(id);
         ASSERT(ctor);
         return this->call(Position(), std::move(ctor), std::move(args));
     }
-    ABORT("cannot coerce %s to %s", expr->fType->description().c_str(), 
-          type->description().c_str());
+    ABORT("cannot coerce %s to %s", expr->fType.description().c_str(), 
+          type.description().c_str());
 }
 
 /**
  * Determines the operand and result types of a binary expression. Returns true if the expression is
  * legal, false otherwise. If false, the values of the out parameters are undefined.
  */
-static bool determine_binary_type(Token::Kind op, std::shared_ptr<Type> left, 
-                                  std::shared_ptr<Type> right, 
-                                  std::shared_ptr<Type>* outLeftType,
-                                  std::shared_ptr<Type>* outRightType,
-                                  std::shared_ptr<Type>* outResultType,
+static bool determine_binary_type(const Context& context, 
+                                  Token::Kind op, 
+                                  const Type& left, 
+                                  const Type& right, 
+                                  const Type** outLeftType,
+                                  const Type** outRightType,
+                                  const Type** outResultType,
                                   bool tryFlipped) {
     bool isLogical;
     switch (op) {
@@ -638,24 +644,25 @@
         case Token::LOGICALOREQ: // fall through
         case Token::LOGICALANDEQ: // fall through
         case Token::LOGICALXOREQ:
-            *outLeftType = kBool_Type;
-            *outRightType = kBool_Type;
-            *outResultType = kBool_Type;
-            return left->canCoerceTo(kBool_Type) && right->canCoerceTo(kBool_Type);
+            *outLeftType = context.fBool_Type.get();
+            *outRightType = context.fBool_Type.get();
+            *outResultType = context.fBool_Type.get();
+            return left.canCoerceTo(*context.fBool_Type) && 
+                   right.canCoerceTo(*context.fBool_Type);
         case Token::STAR: // fall through
         case Token::STAREQ: 
             // FIXME need to handle non-square matrices
-            if (left->kind() == Type::kMatrix_Kind && right->kind() == Type::kVector_Kind) {
-                *outLeftType = left;
-                *outRightType = right;
-                *outResultType = right;
-                return left->rows() == right->columns();
+            if (left.kind() == Type::kMatrix_Kind && right.kind() == Type::kVector_Kind) {
+                *outLeftType = &left;
+                *outRightType = &right;
+                *outResultType = &right;
+                return left.rows() == right.columns();
             }  
-            if (left->kind() == Type::kVector_Kind && right->kind() == Type::kMatrix_Kind) {
-                *outLeftType = left;
-                *outRightType = right;
-                *outResultType = left;
-                return left->columns() == right->columns();
+            if (left.kind() == Type::kVector_Kind && right.kind() == Type::kMatrix_Kind) {
+                *outLeftType = &left;
+                *outRightType = &right;
+                *outResultType = &left;
+                return left.columns() == right.columns();
             }
             // fall through
         default:
@@ -664,41 +671,42 @@
     // FIXME: need to disallow illegal operations like vec3 > vec3. Also do not currently have
     // full support for numbers other than float.
     if (left == right) {
-        *outLeftType = left;
-        *outRightType = left;
+        *outLeftType = &left;
+        *outRightType = &left;
         if (isLogical) {
-            *outResultType = kBool_Type;
+            *outResultType = context.fBool_Type.get();
         } else {
-            *outResultType = left;
+            *outResultType = &left;
         }
         return true;
     }
     // FIXME: incorrect for shift operations
-    if (left->canCoerceTo(right)) {
-        *outLeftType = right;
-        *outRightType = right;
+    if (left.canCoerceTo(right)) {
+        *outLeftType = &right;
+        *outRightType = &right;
         if (isLogical) {
-            *outResultType = kBool_Type;
+            *outResultType = context.fBool_Type.get();
         } else {
-            *outResultType = right;
+            *outResultType = &right;
         }
         return true;
     }
-    if ((left->kind() == Type::kVector_Kind || left->kind() == Type::kMatrix_Kind) && 
-        (right->kind() == Type::kScalar_Kind)) {
-        if (determine_binary_type(op, left->componentType(), right, outLeftType, outRightType,
-                                  outResultType, false)) {
-            *outLeftType = (*outLeftType)->toCompound(left->columns(), left->rows());
+    if ((left.kind() == Type::kVector_Kind || left.kind() == Type::kMatrix_Kind) && 
+        (right.kind() == Type::kScalar_Kind)) {
+        if (determine_binary_type(context, op, left.componentType(), right, outLeftType, 
+                                  outRightType, outResultType, false)) {
+            *outLeftType = &(*outLeftType)->toCompound(context, left.columns(), left.rows());
             if (!isLogical) {
-                *outResultType = (*outResultType)->toCompound(left->columns(), left->rows());
+                *outResultType = &(*outResultType)->toCompound(context, left.columns(), 
+                                                               left.rows());
             }
             return true;
         }
         return false;
     }
     if (tryFlipped) {
-        return determine_binary_type(op, right, left, outRightType, outLeftType, outResultType, 
-                                     false);
+        return determine_binary_type(context, op, right, left, outRightType, outLeftType, 
+                                     outResultType, false);
     }
     return false;
 }
@@ -713,15 +721,15 @@
     if (!right) {
         return nullptr;
     }
-    std::shared_ptr<Type> leftType;
-    std::shared_ptr<Type> rightType;
-    std::shared_ptr<Type> resultType;
-    if (!determine_binary_type(expression.fOperator, left->fType, right->fType, &leftType,
+    const Type* leftType;
+    const Type* rightType;
+    const Type* resultType;
+    if (!determine_binary_type(fContext, expression.fOperator, left->fType, right->fType, &leftType,
                                &rightType, &resultType, true)) {
         fErrors.error(expression.fPosition, "type mismatch: '" + 
                                             Token::OperatorName(expression.fOperator) + 
-                                            "' cannot operate on '" + left->fType->fName + 
-                                            "', '" + right->fType->fName + "'");
+                                            "' cannot operate on '" + left->fType.fName + 
+                                            "', '" + right->fType.fName + "'");
         return nullptr;
     }
     switch (expression.fOperator) {
@@ -744,17 +752,18 @@
             break;
     }
     return std::unique_ptr<Expression>(new BinaryExpression(expression.fPosition, 
-                                                            this->coerce(std::move(left), leftType), 
+                                                            this->coerce(std::move(left), 
+                                                                         *leftType), 
                                                             expression.fOperator, 
                                                             this->coerce(std::move(right), 
-                                                                         rightType), 
-                                                            resultType));
+                                                                         *rightType), 
+                                                            *resultType));
 }
 
 std::unique_ptr<Expression> IRGenerator::convertTernaryExpression(  
                                                            const ASTTernaryExpression& expression) {
     std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*expression.fTest), 
-                                                    kBool_Type);
+                                                    *fContext.fBool_Type);
     if (!test) {
         return nullptr;
     }
@@ -766,34 +775,33 @@
     if (!ifFalse) {
         return nullptr;
     }
-    std::shared_ptr<Type> trueType;
-    std::shared_ptr<Type> falseType;
-    std::shared_ptr<Type> resultType;
-    if (!determine_binary_type(Token::EQEQ, ifTrue->fType, ifFalse->fType, &trueType,
+    const Type* trueType;
+    const Type* falseType;
+    const Type* resultType;
+    if (!determine_binary_type(fContext, Token::EQEQ, ifTrue->fType, ifFalse->fType, &trueType,
                                &falseType, &resultType, true)) {
         fErrors.error(expression.fPosition, "ternary operator result mismatch: '" + 
-                                            ifTrue->fType->fName + "', '" + 
-                                            ifFalse->fType->fName + "'");
+                                            ifTrue->fType.fName + "', '" + 
+                                            ifFalse->fType.fName + "'");
         return nullptr;
     }
     ASSERT(trueType == falseType);
-    ifTrue = this->coerce(std::move(ifTrue), trueType);
-    ifFalse = this->coerce(std::move(ifFalse), falseType);
+    ifTrue = this->coerce(std::move(ifTrue), *trueType);
+    ifFalse = this->coerce(std::move(ifFalse), *falseType);
     return std::unique_ptr<Expression>(new TernaryExpression(expression.fPosition, 
                                                              std::move(test),
                                                              std::move(ifTrue), 
                                                              std::move(ifFalse)));
 }
 
-std::unique_ptr<Expression> IRGenerator::call(
-                                         Position position, 
-                                         std::shared_ptr<FunctionDeclaration> function, 
-                                         std::vector<std::unique_ptr<Expression>> arguments) {
-    if (function->fParameters.size() != arguments.size()) {
-        std::string msg = "call to '" + function->fName + "' expected " + 
-                                 to_string(function->fParameters.size()) + 
+std::unique_ptr<Expression> IRGenerator::call(Position position, 
+                                              const FunctionDeclaration& function, 
+                                              std::vector<std::unique_ptr<Expression>> arguments) {
+    if (function.fParameters.size() != arguments.size()) {
+        std::string msg = "call to '" + function.fName + "' expected " + 
+                                 to_string(function.fParameters.size()) + 
                                  " argument";
-        if (function->fParameters.size() != 1) {
+        if (function.fParameters.size() != 1) {
             msg += "s";
         }
         msg += ", but found " + to_string(arguments.size());
@@ -801,12 +809,12 @@
         return nullptr;
     }
     for (size_t i = 0; i < arguments.size(); i++) {
-        arguments[i] = this->coerce(std::move(arguments[i]), function->fParameters[i]->fType);
-        if (arguments[i] && (function->fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) {
+        arguments[i] = this->coerce(std::move(arguments[i]), function.fParameters[i]->fType);
+        if (arguments[i] && (function.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) {
             this->markWrittenTo(*arguments[i]);
         }
     }
-    return std::unique_ptr<FunctionCall>(new FunctionCall(position, std::move(function),
+    return std::unique_ptr<FunctionCall>(new FunctionCall(position, function,
                                                           std::move(arguments)));
 }
 
@@ -815,16 +823,16 @@
  * if the cost could be computed, false if the call is not valid. Cost has no particular meaning 
  * other than "lower costs are preferred".
  */
-bool IRGenerator::determineCallCost(std::shared_ptr<FunctionDeclaration> function, 
+bool IRGenerator::determineCallCost(const FunctionDeclaration& function, 
                                     const std::vector<std::unique_ptr<Expression>>& arguments,
                                     int* outCost) {
-    if (function->fParameters.size() != arguments.size()) {
+    if (function.fParameters.size() != arguments.size()) {
         return false;
     }
     int total = 0;
     for (size_t i = 0; i < arguments.size(); i++) {
         int cost;
-        if (arguments[i]->fType->determineCoercionCost(function->fParameters[i]->fType, &cost)) {
+        if (arguments[i]->fType.determineCoercionCost(function.fParameters[i]->fType, &cost)) {
             total += cost;
         } else {
             return false;
@@ -848,97 +856,97 @@
     }
     FunctionReference* ref = (FunctionReference*) functionValue.get();
     int bestCost = INT_MAX;
-    std::shared_ptr<FunctionDeclaration> best;
+    const FunctionDeclaration* best = nullptr;
     if (ref->fFunctions.size() > 1) {
         for (const auto& f : ref->fFunctions) {
             int cost;
-            if (this->determineCallCost(f, arguments, &cost) && cost < bestCost) {
+            if (this->determineCallCost(*f, arguments, &cost) && cost < bestCost) {
                 bestCost = cost;
                 best = f;
             }
         }
         if (best) {
-            return this->call(position, std::move(best), std::move(arguments));
+            return this->call(position, *best, std::move(arguments));
         }
         std::string msg = "no match for " + ref->fFunctions[0]->fName + "(";
         std::string separator = "";
         for (size_t i = 0; i < arguments.size(); i++) {
             msg += separator;
             separator = ", ";
-            msg += arguments[i]->fType->description();
+            msg += arguments[i]->fType.description();
         }
         msg += ")";
         fErrors.error(position, msg);
         return nullptr;
     }
-    return this->call(position, ref->fFunctions[0], std::move(arguments));
+    return this->call(position, *ref->fFunctions[0], std::move(arguments));
 }
 
 std::unique_ptr<Expression> IRGenerator::convertConstructor(
                                                     Position position, 
-                                                    std::shared_ptr<Type> type, 
+                                                    const Type& type, 
                                                     std::vector<std::unique_ptr<Expression>> args) {
     // FIXME: add support for structs and arrays
-    Type::Kind kind = type->kind();
-    if (!type->isNumber() && kind != Type::kVector_Kind && kind != Type::kMatrix_Kind) {
-        fErrors.error(position, "cannot construct '" + type->description() + "'");
+    Type::Kind kind = type.kind();
+    if (!type.isNumber() && kind != Type::kVector_Kind && kind != Type::kMatrix_Kind) {
+        fErrors.error(position, "cannot construct '" + type.description() + "'");
         return nullptr;
     }
-    if (type == kFloat_Type && args.size() == 1 && 
+    if (type == *fContext.fFloat_Type && args.size() == 1 && 
         args[0]->fKind == Expression::kIntLiteral_Kind) {
         int64_t value = ((IntLiteral&) *args[0]).fValue;
-        return std::unique_ptr<Expression>(new FloatLiteral(position, (double) value));
+        return std::unique_ptr<Expression>(new FloatLiteral(fContext, position, (double) value));
     }
     if (args.size() == 1 && args[0]->fType == type) {
         // argument is already the right type, just return it
         return std::move(args[0]);
     }
-    if (type->isNumber()) {
+    if (type.isNumber()) {
         if (args.size() != 1) {
-            fErrors.error(position, "invalid arguments to '" + type->description() + 
+            fErrors.error(position, "invalid arguments to '" + type.description() + 
                                     "' constructor, (expected exactly 1 argument, but found " +
                                     to_string(args.size()) + ")");
         }
-        if (args[0]->fType == kBool_Type) {
-            std::unique_ptr<IntLiteral> zero(new IntLiteral(position, 0));
-            std::unique_ptr<IntLiteral> one(new IntLiteral(position, 1));
+        if (args[0]->fType == *fContext.fBool_Type) {
+            std::unique_ptr<IntLiteral> zero(new IntLiteral(fContext, position, 0));
+            std::unique_ptr<IntLiteral> one(new IntLiteral(fContext, position, 1));
             return std::unique_ptr<Expression>(
                                          new TernaryExpression(position, std::move(args[0]),
                                                                this->coerce(std::move(one), type),
                                                                this->coerce(std::move(zero), 
                                                                             type)));
-        } else if (!args[0]->fType->isNumber()) {
-            fErrors.error(position, "invalid argument to '" + type->description() + 
+        } else if (!args[0]->fType.isNumber()) {
+            fErrors.error(position, "invalid argument to '" + type.description() + 
                                     "' constructor (expected a number or bool, but found '" +
-                                    args[0]->fType->description() + "')");
+                                    args[0]->fType.description() + "')");
         }
     } else {
         ASSERT(kind == Type::kVector_Kind || kind == Type::kMatrix_Kind);
         int actual = 0;
         for (size_t i = 0; i < args.size(); i++) {
-            if (args[i]->fType->kind() == Type::kVector_Kind || 
-                args[i]->fType->kind() == Type::kMatrix_Kind) {
-                int columns = args[i]->fType->columns();
-                int rows = args[i]->fType->rows();
+            if (args[i]->fType.kind() == Type::kVector_Kind || 
+                args[i]->fType.kind() == Type::kMatrix_Kind) {
+                int columns = args[i]->fType.columns();
+                int rows = args[i]->fType.rows();
                 args[i] = this->coerce(std::move(args[i]), 
-                                       type->componentType()->toCompound(columns, rows));
-                actual += args[i]->fType->rows() * args[i]->fType->columns();
-            } else if (args[i]->fType->kind() == Type::kScalar_Kind) {
+                                       type.componentType().toCompound(fContext, columns, rows));
+                actual += args[i]->fType.rows() * args[i]->fType.columns();
+            } else if (args[i]->fType.kind() == Type::kScalar_Kind) {
                 actual += 1;
-                if (type->kind() != Type::kScalar_Kind) {
-                    args[i] = this->coerce(std::move(args[i]), type->componentType());
+                if (type.kind() != Type::kScalar_Kind) {
+                    args[i] = this->coerce(std::move(args[i]), type.componentType());
                 }
             } else {
-                fErrors.error(position, "'" + args[i]->fType->description() + "' is not a valid "
-                                        "parameter to '" + type->description() + "' constructor");
+                fErrors.error(position, "'" + args[i]->fType.description() + "' is not a valid "
+                                        "parameter to '" + type.description() + "' constructor");
                 return nullptr;
             }
         }
-        int min = type->rows() * type->columns();
-        int max = type->columns() > 1 ? INT_MAX : min;
+        int min = type.rows() * type.columns();
+        int max = type.columns() > 1 ? INT_MAX : min;
         if ((actual < min || actual > max) &&
             !((kind == Type::kVector_Kind || kind == Type::kMatrix_Kind) && (actual == 1))) {
-            fErrors.error(position, "invalid arguments to '" + type->description() + 
+            fErrors.error(position, "invalid arguments to '" + type.description() + 
                                     "' constructor (expected " + to_string(min) + " scalar" + 
                                     (min == 1 ? "" : "s") + ", but found " + to_string(actual) + 
                                     ")");
@@ -956,50 +964,51 @@
     }
     switch (expression.fOperator) {
         case Token::PLUS:
-            if (!base->fType->isNumber() && base->fType->kind() != Type::kVector_Kind) {
+            if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind) {
                 fErrors.error(expression.fPosition, 
-                              "'+' cannot operate on '" + base->fType->description() + "'");
+                              "'+' cannot operate on '" + base->fType.description() + "'");
                 return nullptr;
             }
             return base;
         case Token::MINUS:
-            if (!base->fType->isNumber() && base->fType->kind() != Type::kVector_Kind) {
+            if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind) {
                 fErrors.error(expression.fPosition, 
-                              "'-' cannot operate on '" + base->fType->description() + "'");
+                              "'-' cannot operate on '" + base->fType.description() + "'");
                 return nullptr;
             }
             if (base->fKind == Expression::kIntLiteral_Kind) {
-                return std::unique_ptr<Expression>(new IntLiteral(base->fPosition,
+                return std::unique_ptr<Expression>(new IntLiteral(fContext, base->fPosition,
                                                                   -((IntLiteral&) *base).fValue));
             }
             if (base->fKind == Expression::kFloatLiteral_Kind) {
                 double value = -((FloatLiteral&) *base).fValue;
-                return std::unique_ptr<Expression>(new FloatLiteral(base->fPosition, value));
+                return std::unique_ptr<Expression>(new FloatLiteral(fContext, base->fPosition, 
+                                                                    value));
             }
             return std::unique_ptr<Expression>(new PrefixExpression(Token::MINUS, std::move(base)));
         case Token::PLUSPLUS:
-            if (!base->fType->isNumber()) {
+            if (!base->fType.isNumber()) {
                 fErrors.error(expression.fPosition, 
                               "'" + Token::OperatorName(expression.fOperator) + 
-                              "' cannot operate on '" + base->fType->description() + "'");
+                              "' cannot operate on '" + base->fType.description() + "'");
                 return nullptr;
             }
             this->markWrittenTo(*base);
             break;
         case Token::MINUSMINUS: 
-            if (!base->fType->isNumber()) {
+            if (!base->fType.isNumber()) {
                 fErrors.error(expression.fPosition, 
                               "'" + Token::OperatorName(expression.fOperator) + 
-                              "' cannot operate on '" + base->fType->description() + "'");
+                              "' cannot operate on '" + base->fType.description() + "'");
                 return nullptr;
             }
             this->markWrittenTo(*base);
             break;
         case Token::NOT:
-            if (base->fType != kBool_Type) {
+            if (base->fType != *fContext.fBool_Type) {
                 fErrors.error(expression.fPosition, 
                               "'" + Token::OperatorName(expression.fOperator) + 
-                              "' cannot operate on '" + base->fType->description() + "'");
+                              "' cannot operate on '" + base->fType.description() + "'");
                 return nullptr;
             }
             break;
@@ -1012,8 +1021,8 @@
 
 std::unique_ptr<Expression> IRGenerator::convertIndex(std::unique_ptr<Expression> base,
                                                       const ASTExpression& index) {
-    if (base->fType->kind() != Type::kArray_Kind && base->fType->kind() != Type::kMatrix_Kind) {
-        fErrors.error(base->fPosition, "expected array, but found '" + base->fType->description() + 
+    if (base->fType.kind() != Type::kArray_Kind && base->fType.kind() != Type::kMatrix_Kind) {
+        fErrors.error(base->fPosition, "expected array, but found '" + base->fType.description() + 
                                        "'");
         return nullptr;
     }
@@ -1021,30 +1030,31 @@
     if (!converted) {
         return nullptr;
     }
-    converted = this->coerce(std::move(converted), kInt_Type);
+    converted = this->coerce(std::move(converted), *fContext.fInt_Type);
     if (!converted) {
         return nullptr;
     }
-    return std::unique_ptr<Expression>(new IndexExpression(std::move(base), std::move(converted)));
+    return std::unique_ptr<Expression>(new IndexExpression(fContext, std::move(base), 
+                                                           std::move(converted)));
 }
 
 std::unique_ptr<Expression> IRGenerator::convertField(std::unique_ptr<Expression> base,
                                                       const std::string& field) {
-    auto fields = base->fType->fields();
+    auto fields = base->fType.fields();
     for (size_t i = 0; i < fields.size(); i++) {
         if (fields[i].fName == field) {
             return std::unique_ptr<Expression>(new FieldAccess(std::move(base), (int) i));
         }
     }
-    fErrors.error(base->fPosition, "type '" + base->fType->description() + "' does not have a "
+    fErrors.error(base->fPosition, "type '" + base->fType.description() + "' does not have a "
                                    "field named '" + field + "");
     return nullptr;
 }
 
 std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expression> base,
                                                         const std::string& fields) {
-    if (base->fType->kind() != Type::kVector_Kind) {
-        fErrors.error(base->fPosition, "cannot swizzle type '" + base->fType->description() + "'");
+    if (base->fType.kind() != Type::kVector_Kind) {
+        fErrors.error(base->fPosition, "cannot swizzle type '" + base->fType.description() + "'");
         return nullptr;
     }
     std::vector<int> swizzleComponents;
@@ -1058,7 +1068,7 @@
             case 'y': // fall through
             case 'g': // fall through
             case 't':
-                if (base->fType->columns() >= 2) {
+                if (base->fType.columns() >= 2) {
                     swizzleComponents.push_back(1);
                     break;
                 }
@@ -1066,7 +1076,7 @@
             case 'z': // fall through
             case 'b': // fall through
             case 'p': 
-                if (base->fType->columns() >= 3) {
+                if (base->fType.columns() >= 3) {
                     swizzleComponents.push_back(2);
                     break;
                 }
@@ -1074,7 +1084,7 @@
             case 'w': // fall through
             case 'a': // fall through
             case 'q':
-                if (base->fType->columns() >= 4) {
+                if (base->fType.columns() >= 4) {
                     swizzleComponents.push_back(3);
                     break;
                 }
@@ -1090,7 +1100,7 @@
         fErrors.error(base->fPosition, "too many components in swizzle mask '" + fields + "'");
         return nullptr;
     }
-    return std::unique_ptr<Expression>(new Swizzle(std::move(base), swizzleComponents));
+    return std::unique_ptr<Expression>(new Swizzle(fContext, std::move(base), swizzleComponents));
 }
 
 std::unique_ptr<Expression> IRGenerator::convertSuffixExpression(
@@ -1117,7 +1127,7 @@
             return this->call(expression.fPosition, std::move(base), std::move(arguments));
         }
         case ASTSuffix::kField_Kind: {
-            switch (base->fType->kind()) {
+            switch (base->fType.kind()) {
                 case Type::kVector_Kind:
                     return this->convertSwizzle(std::move(base), 
                                                 ((ASTFieldSuffix&) *expression.fSuffix).fField);
@@ -1126,23 +1136,23 @@
                                               ((ASTFieldSuffix&) *expression.fSuffix).fField);
                 default:
                     fErrors.error(base->fPosition, "cannot swizzle value of type '" + 
-                                                   base->fType->description() + "'");
+                                                   base->fType.description() + "'");
                     return nullptr;
             }
         }
         case ASTSuffix::kPostIncrement_Kind:
-            if (!base->fType->isNumber()) {
+            if (!base->fType.isNumber()) {
                 fErrors.error(expression.fPosition, 
-                              "'++' cannot operate on '" + base->fType->description() + "'");
+                              "'++' cannot operate on '" + base->fType.description() + "'");
                 return nullptr;
             }
             this->markWrittenTo(*base);
             return std::unique_ptr<Expression>(new PostfixExpression(std::move(base), 
                                                                      Token::PLUSPLUS));
         case ASTSuffix::kPostDecrement_Kind:
-            if (!base->fType->isNumber()) {
+            if (!base->fType.isNumber()) {
                 fErrors.error(expression.fPosition, 
-                              "'--' cannot operate on '" + base->fType->description() + "'");
+                              "'--' cannot operate on '" + base->fType.description() + "'");
                 return nullptr;
             }
             this->markWrittenTo(*base);
@@ -1162,13 +1172,13 @@
             fErrors.error(expr.fPosition, "expected '(' to begin constructor invocation");
             break;
         default:
-            ASSERT(expr.fType != kInvalid_Type);
+            ASSERT(expr.fType != *fContext.fInvalid_Type);
             break;
     }
 }
 
-void IRGenerator::markReadFrom(std::shared_ptr<Variable> var) {
-    var->fIsReadFrom = true;
+void IRGenerator::markReadFrom(const Variable& var) {
+    var.fIsReadFrom = true;
 }
 
 static bool has_duplicates(const Swizzle& swizzle) {
@@ -1187,7 +1197,7 @@
 void IRGenerator::markWrittenTo(const Expression& expr) {
     switch (expr.fKind) {
         case Expression::kVariableReference_Kind: {
-            const Variable& var = *((VariableReference&) expr).fVariable;
+            const Variable& var = ((VariableReference&) expr).fVariable;
             if (var.fModifiers.fFlags & (Modifiers::kConst_Flag | Modifiers::kUniform_Flag)) {
                 fErrors.error(expr.fPosition, 
                               "cannot modify immutable variable '" + var.fName + "'");
diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h
index d23e5a1..2384b2d 100644
--- a/src/sksl/SkSLIRGenerator.h
+++ b/src/sksl/SkSLIRGenerator.h
@@ -53,7 +53,8 @@
  */
 class IRGenerator {
 public:
-    IRGenerator(std::shared_ptr<SymbolTable> root, ErrorReporter& errorReporter);
+    IRGenerator(const Context* context, std::shared_ptr<SymbolTable> root, 
+                ErrorReporter& errorReporter);
 
     std::unique_ptr<VarDeclaration> convertVarDeclaration(const ASTVarDeclaration& decl, 
                                                           Variable::Storage storage);
@@ -65,21 +66,20 @@
     void pushSymbolTable();
     void popSymbolTable();
 
-    std::shared_ptr<Type> convertType(const ASTType& type);
+    const Type* convertType(const ASTType& type);
     std::unique_ptr<Expression> call(Position position, 
-                                     std::shared_ptr<FunctionDeclaration> function, 
+                                     const FunctionDeclaration& function, 
                                      std::vector<std::unique_ptr<Expression>> arguments);
-    bool determineCallCost(std::shared_ptr<FunctionDeclaration> function, 
+    bool determineCallCost(const FunctionDeclaration& function, 
                            const std::vector<std::unique_ptr<Expression>>& arguments,
                            int* outCost);
     std::unique_ptr<Expression> call(Position position, std::unique_ptr<Expression> function, 
                                      std::vector<std::unique_ptr<Expression>> arguments);
-    std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, 
-                                       std::shared_ptr<Type> type);
+    std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, const Type& type);
     std::unique_ptr<Block> convertBlock(const ASTBlock& block);
     std::unique_ptr<Statement> convertBreak(const ASTBreakStatement& b);
     std::unique_ptr<Expression> convertConstructor(Position position, 
-                                                   std::shared_ptr<Type> type, 
+                                                   const Type& type, 
                                                    std::vector<std::unique_ptr<Expression>> params);
     std::unique_ptr<Statement> convertContinue(const ASTContinueStatement& c);
     std::unique_ptr<Statement> convertDiscard(const ASTDiscardStatement& d);
@@ -106,10 +106,11 @@
     std::unique_ptr<Statement> convertWhile(const ASTWhileStatement& w);
 
     void checkValid(const Expression& expr);
-    void markReadFrom(std::shared_ptr<Variable> var);
+    void markReadFrom(const Variable& var);
     void markWrittenTo(const Expression& expr);
 
-    std::shared_ptr<FunctionDeclaration> fCurrentFunction;
+    const Context& fContext;
+    const FunctionDeclaration* fCurrentFunction;
     std::shared_ptr<SymbolTable> fSymbolTable;
     ErrorReporter& fErrors;
 
diff --git a/src/sksl/SkSLParser.cpp b/src/sksl/SkSLParser.cpp
index fa302af..edff0c6 100644
--- a/src/sksl/SkSLParser.cpp
+++ b/src/sksl/SkSLParser.cpp
@@ -52,6 +52,7 @@
 #include "ast/SkSLASTVarDeclarationStatement.h"
 #include "ast/SkSLASTWhileStatement.h"
 #include "ir/SkSLSymbolTable.h"
+#include "ir/SkSLType.h"
 
 namespace SkSL {
 
@@ -290,17 +291,17 @@
             return nullptr;
         }
         for (size_t i = 0; i < decl->fNames.size(); i++) {
-            auto type = std::static_pointer_cast<Type>(fTypes[decl->fType->fName]);
+            auto type = (const Type*) fTypes[decl->fType->fName];
             for (int j = (int) decl->fSizes[i].size() - 1; j >= 0; j--) {
-                if (decl->fSizes[i][j]->fKind == ASTExpression::kInt_Kind) {
+                if (decl->fSizes[i][j]->fKind != ASTExpression::kInt_Kind) {
                     this->error(decl->fPosition, "array size in struct field must be a constant");
                 }
                 uint64_t columns = ((ASTIntLiteral&) *decl->fSizes[i][j]).fValue;
                 std::string name = type->name() + "[" + to_string(columns) + "]";
-                type = std::shared_ptr<Type>(new Type(name, Type::kArray_Kind, std::move(type), 
-                                                      (int) columns));
+                type = new Type(name, Type::kArray_Kind, *type, (int) columns);
+                fTypes.takeOwnership((Type*) type);
             }
-            fields.push_back(Type::Field(decl->fModifiers, decl->fNames[i], std::move(type)));
+            fields.push_back(Type::Field(decl->fModifiers, decl->fNames[i], *type));
             if (decl->fValues[i]) {
                 this->error(decl->fPosition, "initializers are not permitted on struct fields");
             }
@@ -309,9 +310,8 @@
     if (!this->expect(Token::RBRACE, "'}'")) {
         return nullptr;
     }
-    std::shared_ptr<Type> type(new Type(name.fText, fields));
-    fTypes.add(type->fName, type);
-    return std::unique_ptr<ASTType>(new ASTType(name.fPosition, type->fName, 
+    fTypes.add(name.fText, std::unique_ptr<Type>(new Type(name.fText, fields)));
+    return std::unique_ptr<ASTType>(new ASTType(name.fPosition, name.fText, 
                                                 ASTType::kStruct_Kind));
 }
 
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index 0a2dab3..2771e02 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -141,36 +141,36 @@
 #endif
 }
 
-static bool is_float(const Type& type) {
+static bool is_float(const Context& context, const Type& type) {
     if (type.kind() == Type::kVector_Kind) {
-        return is_float(*type.componentType());
+        return is_float(context, type.componentType());
     }
-    return type == *kFloat_Type || type == *kDouble_Type;
+    return type == *context.fFloat_Type || type == *context.fDouble_Type;
 }
 
-static bool is_signed(const Type& type) {
+static bool is_signed(const Context& context, const Type& type) {
     if (type.kind() == Type::kVector_Kind) {
-        return is_signed(*type.componentType());
+        return is_signed(context, type.componentType());
     }
-    return type == *kInt_Type;
+    return type == *context.fInt_Type;
 }
 
-static bool is_unsigned(const Type& type) {
+static bool is_unsigned(const Context& context, const Type& type) {
     if (type.kind() == Type::kVector_Kind) {
-        return is_unsigned(*type.componentType());
+        return is_unsigned(context, type.componentType());
     }
-    return type == *kUInt_Type;
+    return type == *context.fUInt_Type;
 }
 
-static bool is_bool(const Type& type) {
+static bool is_bool(const Context& context, const Type& type) {
     if (type.kind() == Type::kVector_Kind) {
-        return is_bool(*type.componentType());
+        return is_bool(context, type.componentType());
     }
-    return type == *kBool_Type;
+    return type == *context.fBool_Type;
 }
 
-static bool is_out(std::shared_ptr<Variable> var) {
-    return (var->fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
+static bool is_out(const Variable& var) {
+    return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
 }
 
 #if SPIRV_DEBUG
@@ -973,7 +973,7 @@
     // in the middle of writing the struct instruction
     std::vector<SpvId> types;
     for (const auto& f : type.fields()) {
-        types.push_back(this->getType(*f.fType));
+        types.push_back(this->getType(f.fType));
     }
     this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
     this->writeWord(resultId, fConstantBuffer);
@@ -982,8 +982,8 @@
     }
     size_t offset = 0;
     for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
-        size_t size = type.fields()[i].fType->size();
-        size_t alignment = type.fields()[i].fType->alignment();
+        size_t size = type.fields()[i].fType.size();
+        size_t alignment = type.fields()[i].fType.alignment();
         size_t mod = offset % alignment;
         if (mod != 0) {
             offset += alignment - mod;
@@ -995,14 +995,14 @@
             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset, 
                                    (SpvId) offset, fDecorationBuffer);
         }
-        if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
+        if (type.fields()[i].fType.kind() == Type::kMatrix_Kind) {
             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor, 
                                    fDecorationBuffer);
             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride, 
-                                   (SpvId) type.fields()[i].fType->stride(), fDecorationBuffer);
+                                   (SpvId) type.fields()[i].fType.stride(), fDecorationBuffer);
         }
         offset += size;
-        Type::Kind kind = type.fields()[i].fType->kind();
+        Type::Kind kind = type.fields()[i].fType.kind();
         if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
             offset += alignment - offset % alignment;
         }
@@ -1016,15 +1016,15 @@
         SpvId result = this->nextId();
         switch (type.kind()) {
             case Type::kScalar_Kind:
-                if (type == *kBool_Type) {
+                if (type == *fContext.fBool_Type) {
                     this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
-                } else if (type == *kInt_Type) {
+                } else if (type == *fContext.fInt_Type) {
                     this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
-                } else if (type == *kUInt_Type) {
+                } else if (type == *fContext.fUInt_Type) {
                     this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
-                } else if (type == *kFloat_Type) {
+                } else if (type == *fContext.fFloat_Type) {
                     this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
-                } else if (type == *kDouble_Type) {
+                } else if (type == *fContext.fDouble_Type) {
                     this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
                 } else {
                     ASSERT(false);
@@ -1032,11 +1032,12 @@
                 break;
             case Type::kVector_Kind:
                 this->writeInstruction(SpvOpTypeVector, result, 
-                                       this->getType(*type.componentType()),
+                                       this->getType(type.componentType()),
                                        type.columns(), fConstantBuffer);
                 break;
             case Type::kMatrix_Kind:
-                this->writeInstruction(SpvOpTypeMatrix, result, this->getType(*index_type(type)), 
+                this->writeInstruction(SpvOpTypeMatrix, result, 
+                                       this->getType(index_type(fContext, type)), 
                                        type.columns(), fConstantBuffer);
                 break;
             case Type::kStruct_Kind:
@@ -1044,22 +1045,22 @@
                 break;
             case Type::kArray_Kind: {
                 if (type.columns() > 0) {
-                    IntLiteral count(Position(), type.columns());
+                    IntLiteral count(fContext, Position(), type.columns());
                     this->writeInstruction(SpvOpTypeArray, result, 
-                                           this->getType(*type.componentType()), 
+                                           this->getType(type.componentType()), 
                                            this->writeIntLiteral(count), fConstantBuffer);
                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride, 
                                            (int32_t) type.stride(), fDecorationBuffer);
                 } else {
                     ABORT("runtime-sized arrays are not yet supported");
                     this->writeInstruction(SpvOpTypeRuntimeArray, result, 
-                                           this->getType(*type.componentType()), fConstantBuffer);
+                                           this->getType(type.componentType()), fConstantBuffer);
                 }
                 break;
             }
             case Type::kSampler_Kind: {
                 SpvId image = this->nextId();
-                this->writeInstruction(SpvOpTypeImage, image, this->getType(*kFloat_Type), 
+                this->writeInstruction(SpvOpTypeImage, image, this->getType(*fContext.fFloat_Type), 
                                        type.dimensions(), type.isDepth(), type.isArrayed(),
                                        type.isMultisampled(), type.isSampled(), 
                                        SpvImageFormatUnknown, fConstantBuffer);
@@ -1067,7 +1068,7 @@
                 break;
             }
             default:
-                if (type == *kVoid_Type) {
+                if (type == *fContext.fVoid_Type) {
                     this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
                 } else {
                     ABORT("invalid type: %s", type.description().c_str());
@@ -1079,22 +1080,22 @@
     return entry->second;
 }
 
-SpvId SPIRVCodeGenerator::getFunctionType(std::shared_ptr<FunctionDeclaration> function) {
-    std::string key = function->fReturnType->description() + "(";
+SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
+    std::string key = function.fReturnType.description() + "(";
     std::string separator = "";
-    for (size_t i = 0; i < function->fParameters.size(); i++) {
+    for (size_t i = 0; i < function.fParameters.size(); i++) {
         key += separator;
         separator = ", ";
-        key += function->fParameters[i]->fType->description();
+        key += function.fParameters[i]->fType.description();
     }
     key += ")";
     auto entry = fTypeMap.find(key);
     if (entry == fTypeMap.end()) {
         SpvId result = this->nextId();
-        int32_t length = 3 + (int32_t) function->fParameters.size();
-        SpvId returnType = this->getType(*function->fReturnType);
+        int32_t length = 3 + (int32_t) function.fParameters.size();
+        SpvId returnType = this->getType(function.fReturnType);
         std::vector<SpvId> parameterTypes;
-        for (size_t i = 0; i < function->fParameters.size(); i++) {
+        for (size_t i = 0; i < function.fParameters.size(); i++) {
             // glslang seems to treat all function arguments as pointers whether they need to be or 
             // not. I  was initially puzzled by this until I ran bizarre failures with certain 
             // patterns of function calls and control constructs, as exemplified by this minimal 
@@ -1118,10 +1119,10 @@
             // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
             // the spec makes this make sense.
 //            if (is_out(function->fParameters[i])) {
-                parameterTypes.push_back(this->getPointerType(function->fParameters[i]->fType,
+                parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
                                                               SpvStorageClassFunction));
 //            } else {
-//                parameterTypes.push_back(this->getType(*function->fParameters[i]->fType));
+//                parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
 //            }
         }
         this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
@@ -1136,14 +1137,14 @@
     return entry->second;
 }
 
-SpvId SPIRVCodeGenerator::getPointerType(std::shared_ptr<Type> type, 
+SpvId SPIRVCodeGenerator::getPointerType(const Type& type, 
                                          SpvStorageClass_ storageClass) {
-    std::string key = type->description() + "*" + to_string(storageClass);
+    std::string key = type.description() + "*" + to_string(storageClass);
     auto entry = fTypeMap.find(key);
     if (entry == fTypeMap.end()) {
         SpvId result = this->nextId();
         this->writeInstruction(SpvOpTypePointer, result, storageClass, 
-                               this->getType(*type), fConstantBuffer);
+                               this->getType(type), fConstantBuffer);
         fTypeMap[key] = result;
         return result;
     }
@@ -1185,21 +1186,21 @@
 }
 
 SpvId SPIRVCodeGenerator::writeIntrinsicCall(FunctionCall& c, std::ostream& out) {
-    auto intrinsic = fIntrinsicMap.find(c.fFunction->fName);
+    auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
     ASSERT(intrinsic != fIntrinsicMap.end());
-    std::shared_ptr<Type> type = c.fArguments[0]->fType;
+    const Type& type = c.fArguments[0]->fType;
     int32_t intrinsicId;
-    if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(*type)) {
+    if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
         intrinsicId = std::get<1>(intrinsic->second);
-    } else if (is_signed(*type)) {
+    } else if (is_signed(fContext, type)) {
         intrinsicId = std::get<2>(intrinsic->second);
-    } else if (is_unsigned(*type)) {
+    } else if (is_unsigned(fContext, type)) {
         intrinsicId = std::get<3>(intrinsic->second);
-    } else if (is_bool(*type)) {
+    } else if (is_bool(fContext, type)) {
         intrinsicId = std::get<4>(intrinsic->second);
     } else {
         ABORT("invalid call %s, cannot operate on '%s'", c.description().c_str(),
-              type->description().c_str());
+              type.description().c_str());
     }
     switch (std::get<0>(intrinsic->second)) {
         case kGLSL_STD_450_IntrinsicKind: {
@@ -1209,7 +1210,7 @@
                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
             }
             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
-            this->writeWord(this->getType(*c.fType), out);
+            this->writeWord(this->getType(c.fType), out);
             this->writeWord(result, out);
             this->writeWord(fGLSLExtendedInstructions, out);
             this->writeWord(intrinsicId, out);
@@ -1225,7 +1226,7 @@
                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
             }
             this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
-            this->writeWord(this->getType(*c.fType), out);
+            this->writeWord(this->getType(c.fType), out);
             this->writeWord(result, out);
             for (SpvId id : arguments) {
                 this->writeWord(id, out);
@@ -1249,7 +1250,7 @@
                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
             }
             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
-            this->writeWord(this->getType(*c.fType), out);
+            this->writeWord(this->getType(c.fType), out);
             this->writeWord(result, out);
             this->writeWord(fGLSLExtendedInstructions, out);
             this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
@@ -1259,7 +1260,7 @@
             return result;            
         }
         case kTexture_SpecialIntrinsic: {
-            SpvId type = this->getType(*c.fType);
+            SpvId type = this->getType(c.fType);
             SpvId sampler = this->writeExpression(*c.fArguments[0], out);
             SpvId uv = this->writeExpression(*c.fArguments[1], out);
             if (c.fArguments.size() == 3) {
@@ -1274,7 +1275,7 @@
             break;
         }
         case kTextureProj_SpecialIntrinsic: {
-            SpvId type = this->getType(*c.fType);
+            SpvId type = this->getType(c.fType);
             SpvId sampler = this->writeExpression(*c.fArguments[0], out);
             SpvId uv = this->writeExpression(*c.fArguments[1], out);
             if (c.fArguments.size() == 3) {
@@ -1293,7 +1294,7 @@
             SpvId img = this->writeExpression(*c.fArguments[0], out);
             SpvId coords = this->writeExpression(*c.fArguments[1], out);
             this->writeInstruction(SpvOpImageSampleImplicitLod,
-                                   this->getType(*c.fType),
+                                   this->getType(c.fType),
                                    result, 
                                    img,
                                    coords,
@@ -1305,7 +1306,7 @@
 }
 
 SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out) {
-    const auto& entry = fFunctionMap.find(c.fFunction);
+    const auto& entry = fFunctionMap.find(&c.fFunction);
     if (entry == fFunctionMap.end()) {
         return this->writeIntrinsicCall(c, out);
     }
@@ -1318,7 +1319,7 @@
         SpvId tmpVar;
         // if we need a temporary var to store this argument, this is the value to store in the var
         SpvId tmpValueId;
-        if (is_out(c.fFunction->fParameters[i])) {
+        if (is_out(*c.fFunction.fParameters[i])) {
             std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
             SpvId ptr = lv->getPointer();
             if (ptr) {
@@ -1330,7 +1331,7 @@
                 // update the lvalue.
                 tmpValueId = lv->load(out);
                 tmpVar = this->nextId();
-                lvalues.push_back(std::make_tuple(tmpVar, this->getType(*c.fArguments[i]->fType),
+                lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType),
                                   std::move(lv)));
             }
         } else {
@@ -1343,13 +1344,13 @@
                                                     SpvStorageClassFunction),
                                tmpVar, 
                                SpvStorageClassFunction,
-                               out);
+                               fVariableBuffer);
         this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
         arguments.push_back(tmpVar);
     }
     SpvId result = this->nextId();
     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
-    this->writeWord(this->getType(*c.fType), out);
+    this->writeWord(this->getType(c.fType), out);
     this->writeWord(result, out);
     this->writeWord(entry->second, out);
     for (SpvId id : arguments) {
@@ -1366,19 +1367,19 @@
 }
 
 SpvId SPIRVCodeGenerator::writeConstantVector(Constructor& c) {
-    ASSERT(c.fType->kind() == Type::kVector_Kind && c.isConstant());
+    ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
     SpvId result = this->nextId();
     std::vector<SpvId> arguments;
     for (size_t i = 0; i < c.fArguments.size(); i++) {
         arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
     }
-    SpvId type = this->getType(*c.fType);
+    SpvId type = this->getType(c.fType);
     if (c.fArguments.size() == 1) {
         // with a single argument, a vector will have all of its entries equal to the argument
-        this->writeOpCode(SpvOpConstantComposite, 3 + c.fType->columns(), fConstantBuffer);
+        this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
         this->writeWord(type, fConstantBuffer);
         this->writeWord(result, fConstantBuffer);
-        for (int i = 0; i < c.fType->columns(); i++) {
+        for (int i = 0; i < c.fType.columns(); i++) {
             this->writeWord(arguments[0], fConstantBuffer);
         }
     } else {
@@ -1394,43 +1395,43 @@
 }
 
 SpvId SPIRVCodeGenerator::writeFloatConstructor(Constructor& c, std::ostream& out) {
-    ASSERT(c.fType == kFloat_Type);
+    ASSERT(c.fType == *fContext.fFloat_Type);
     ASSERT(c.fArguments.size() == 1);
-    ASSERT(c.fArguments[0]->fType->isNumber());
+    ASSERT(c.fArguments[0]->fType.isNumber());
     SpvId result = this->nextId();
     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
-    if (c.fArguments[0]->fType == kInt_Type) {
-        this->writeInstruction(SpvOpConvertSToF, this->getType(*c.fType), result, parameter, 
+    if (c.fArguments[0]->fType == *fContext.fInt_Type) {
+        this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter, 
                                out);
-    } else if (c.fArguments[0]->fType == kUInt_Type) {
-        this->writeInstruction(SpvOpConvertUToF, this->getType(*c.fType), result, parameter, 
+    } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
+        this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter, 
                                out);
-    } else if (c.fArguments[0]->fType == kFloat_Type) {
+    } else if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
         return parameter;
     }
     return result;
 }
 
 SpvId SPIRVCodeGenerator::writeIntConstructor(Constructor& c, std::ostream& out) {
-    ASSERT(c.fType == kInt_Type);
+    ASSERT(c.fType == *fContext.fInt_Type);
     ASSERT(c.fArguments.size() == 1);
-    ASSERT(c.fArguments[0]->fType->isNumber());
+    ASSERT(c.fArguments[0]->fType.isNumber());
     SpvId result = this->nextId();
     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
-    if (c.fArguments[0]->fType == kFloat_Type) {
-        this->writeInstruction(SpvOpConvertFToS, this->getType(*c.fType), result, parameter, 
+    if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
+        this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter, 
                                out);
-    } else if (c.fArguments[0]->fType == kUInt_Type) {
-        this->writeInstruction(SpvOpSatConvertUToS, this->getType(*c.fType), result, parameter, 
+    } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
+        this->writeInstruction(SpvOpSatConvertUToS, this->getType(c.fType), result, parameter, 
                                out);
-    } else if (c.fArguments[0]->fType == kInt_Type) {
+    } else if (c.fArguments[0]->fType == *fContext.fInt_Type) {
         return parameter;
     }
     return result;
 }
 
 SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& out) {
-    ASSERT(c.fType->kind() == Type::kMatrix_Kind);
+    ASSERT(c.fType.kind() == Type::kMatrix_Kind);
     // go ahead and write the arguments so we don't try to write new instructions in the middle of
     // an instruction
     std::vector<SpvId> arguments;
@@ -1438,30 +1439,31 @@
         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
     }
     SpvId result = this->nextId();
-    int rows = c.fType->rows();
-    int columns = c.fType->columns();
+    int rows = c.fType.rows();
+    int columns = c.fType.columns();
     // FIXME this won't work to create a matrix from another matrix
     if (arguments.size() == 1) {
         // with a single argument, a matrix will have all of its diagonal entries equal to the 
         // argument and its other values equal to zero
         // FIXME this won't work for int matrices
-        FloatLiteral zero(Position(), 0);
+        FloatLiteral zero(fContext, Position(), 0);
         SpvId zeroId = this->writeFloatLiteral(zero);
         std::vector<SpvId> columnIds;
         for (int column = 0; column < columns; column++) {
-            this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->rows(), 
+            this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(), 
                               out);
-            this->writeWord(this->getType(*c.fType->componentType()->toCompound(rows, 1)), out);
+            this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows, 1)), 
+                            out);
             SpvId columnId = this->nextId();
             this->writeWord(columnId, out);
             columnIds.push_back(columnId);
-            for (int row = 0; row < c.fType->columns(); row++) {
+            for (int row = 0; row < c.fType.columns(); row++) {
                 this->writeWord(row == column ? arguments[0] : zeroId, out);
             }
         }
         this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, 
                           out);
-        this->writeWord(this->getType(*c.fType), out);
+        this->writeWord(this->getType(c.fType), out);
         this->writeWord(result, out);
         for (SpvId id : columnIds) {
             this->writeWord(id, out);
@@ -1470,15 +1472,16 @@
         std::vector<SpvId> columnIds;
         int currentCount = 0;
         for (size_t i = 0; i < arguments.size(); i++) {
-            if (c.fArguments[i]->fType->kind() == Type::kVector_Kind) {
+            if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
                 ASSERT(currentCount == 0);
                 columnIds.push_back(arguments[i]);
                 currentCount = 0;
             } else {
-                ASSERT(c.fArguments[i]->fType->kind() == Type::kScalar_Kind);
+                ASSERT(c.fArguments[i]->fType.kind() == Type::kScalar_Kind);
                 if (currentCount == 0) {
-                    this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->rows(), out);
-                    this->writeWord(this->getType(*c.fType->componentType()->toCompound(rows, 1)), 
+                    this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(), out);
+                    this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows, 
+                                                                                     1)), 
                                     out);
                     SpvId id = this->nextId();
                     this->writeWord(id, out);
@@ -1490,7 +1493,7 @@
         }
         ASSERT(columnIds.size() == (size_t) columns);
         this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
-        this->writeWord(this->getType(*c.fType), out);
+        this->writeWord(this->getType(c.fType), out);
         this->writeWord(result, out);
         for (SpvId id : columnIds) {
             this->writeWord(id, out);
@@ -1500,7 +1503,7 @@
 }
 
 SpvId SPIRVCodeGenerator::writeVectorConstructor(Constructor& c, std::ostream& out) {
-    ASSERT(c.fType->kind() == Type::kVector_Kind);
+    ASSERT(c.fType.kind() == Type::kVector_Kind);
     if (c.isConstant()) {
         return this->writeConstantVector(c);
     }
@@ -1511,16 +1514,16 @@
         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
     }
     SpvId result = this->nextId();
-    if (arguments.size() == 1 && c.fArguments[0]->fType->kind() == Type::kScalar_Kind) {
-        this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->columns(), out);
-        this->writeWord(this->getType(*c.fType), out);
+    if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
+        this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
+        this->writeWord(this->getType(c.fType), out);
         this->writeWord(result, out);
-        for (int i = 0; i < c.fType->columns(); i++) {
+        for (int i = 0; i < c.fType.columns(); i++) {
             this->writeWord(arguments[0], out);
         }
     } else {
         this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
-        this->writeWord(this->getType(*c.fType), out);
+        this->writeWord(this->getType(c.fType), out);
         this->writeWord(result, out);
         for (SpvId id : arguments) {
             this->writeWord(id, out);
@@ -1530,12 +1533,12 @@
 }
 
 SpvId SPIRVCodeGenerator::writeConstructor(Constructor& c, std::ostream& out) {
-    if (c.fType == kFloat_Type) {
+    if (c.fType == *fContext.fFloat_Type) {
         return this->writeFloatConstructor(c, out);
-    } else if (c.fType == kInt_Type) {
+    } else if (c.fType == *fContext.fInt_Type) {
         return this->writeIntConstructor(c, out);
     }
-    switch (c.fType->kind()) {
+    switch (c.fType.kind()) {
         case Type::kVector_Kind:
             return this->writeVectorConstructor(c, out);
         case Type::kMatrix_Kind:
@@ -1560,7 +1563,7 @@
 SpvStorageClass_ get_storage_class(Expression& expr) {
     switch (expr.fKind) {
         case Expression::kVariableReference_Kind:
-            return get_storage_class(((VariableReference&) expr).fVariable->fModifiers);
+            return get_storage_class(((VariableReference&) expr).fVariable.fModifiers);
         case Expression::kFieldAccess_Kind:
             return get_storage_class(*((FieldAccess&) expr).fBase);
         case Expression::kIndex_Kind:
@@ -1582,7 +1585,7 @@
         case Expression::kFieldAccess_Kind: {
             FieldAccess& fieldExpr = (FieldAccess&) expr;
             chain = this->getAccessChain(*fieldExpr.fBase, out);
-            IntLiteral index(Position(), fieldExpr.fFieldIndex);
+            IntLiteral index(fContext, Position(), fieldExpr.fFieldIndex);
             chain.push_back(this->writeIntLiteral(index));
             break;
         }
@@ -1698,13 +1701,13 @@
                                                                           std::ostream& out) {
     switch (expr.fKind) {
         case Expression::kVariableReference_Kind: {
-            std::shared_ptr<Variable> var = ((VariableReference&) expr).fVariable;
-            auto entry = fVariableMap.find(var);
+            const Variable& var = ((VariableReference&) expr).fVariable;
+            auto entry = fVariableMap.find(&var);
             ASSERT(entry != fVariableMap.end());
             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
                                                                        *this,
                                                                        entry->second, 
-                                                                       this->getType(*expr.fType)));
+                                                                       this->getType(expr.fType)));
         }
         case Expression::kIndex_Kind: // fall through
         case Expression::kFieldAccess_Kind: {
@@ -1719,7 +1722,7 @@
             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
                                                                        *this,
                                                                        member, 
-                                                                       this->getType(*expr.fType)));
+                                                                       this->getType(expr.fType)));
         }
 
         case Expression::kSwizzle_Kind: {
@@ -1728,7 +1731,7 @@
             SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
             ASSERT(base);
             if (count == 1) {
-                IntLiteral index(Position(), swizzle.fComponents[0]);
+                IntLiteral index(fContext, Position(), swizzle.fComponents[0]);
                 SpvId member = this->nextId();
                 this->writeInstruction(SpvOpAccessChain,
                                        this->getPointerType(swizzle.fType, 
@@ -1740,14 +1743,14 @@
                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
                                                                        *this,
                                                                        member, 
-                                                                       this->getType(*expr.fType)));
+                                                                       this->getType(expr.fType)));
             } else {
                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
                                                                               *this, 
                                                                               base, 
                                                                               swizzle.fComponents, 
-                                                                              *swizzle.fBase->fType,
-                                                                              *expr.fType));
+                                                                              swizzle.fBase->fType,
+                                                                              expr.fType));
             }
         }
 
@@ -1758,21 +1761,22 @@
             // caught by IRGenerator
             SpvId result = this->nextId();
             SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
-            this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction, out);
+            this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
+                                   fVariableBuffer);
             this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
                                                                        *this,
                                                                        result, 
-                                                                       this->getType(*expr.fType)));
+                                                                       this->getType(expr.fType)));
     }
 }
 
 SpvId SPIRVCodeGenerator::writeVariableReference(VariableReference& ref, std::ostream& out) {
-    auto entry = fVariableMap.find(ref.fVariable);
+    auto entry = fVariableMap.find(&ref.fVariable);
     ASSERT(entry != fVariableMap.end());
     SpvId var = entry->second;
     SpvId result = this->nextId();
-    this->writeInstruction(SpvOpLoad, this->getType(*ref.fVariable->fType), result, var, out);
+    this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
     return result;
 }
 
@@ -1789,11 +1793,11 @@
     SpvId result = this->nextId();
     size_t count = swizzle.fComponents.size();
     if (count == 1) {
-        this->writeInstruction(SpvOpCompositeExtract, this->getType(*swizzle.fType), result, base, 
+        this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base, 
                                swizzle.fComponents[0], out); 
     } else {
         this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
-        this->writeWord(this->getType(*swizzle.fType), out);
+        this->writeWord(this->getType(swizzle.fType), out);
         this->writeWord(result, out);
         this->writeWord(base, out);
         this->writeWord(base, out);
@@ -1809,13 +1813,13 @@
                                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, 
                                                SpvOp_ ifUInt, SpvOp_ ifBool, std::ostream& out) {
     SpvId result = this->nextId();
-    if (is_float(operandType)) {
+    if (is_float(fContext, operandType)) {
         this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
-    } else if (is_signed(operandType)) {
+    } else if (is_signed(fContext, operandType)) {
         this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
-    } else if (is_unsigned(operandType)) {
+    } else if (is_unsigned(fContext, operandType)) {
         this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
-    } else if (operandType == *kBool_Type) {
+    } else if (operandType == *fContext.fBool_Type) {
         this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
     } else {
         ABORT("invalid operandType: %s", operandType.description().c_str());
@@ -1862,7 +1866,7 @@
     }
 
     // "normal" operators
-    const Type& resultType = *b.fType;
+    const Type& resultType = b.fType;
     std::unique_ptr<LValue> lvalue;
     SpvId lhs;
     if (is_assignment(b.fOperator)) {
@@ -1878,23 +1882,23 @@
     // IR allows mismatched types in expressions (e.g. vec2 * float), but they need special handling
     // in SPIR-V
     if (b.fLeft->fType != b.fRight->fType) {
-        if (b.fLeft->fType->kind() == Type::kVector_Kind && 
-            b.fRight->fType->isNumber()) {
+        if (b.fLeft->fType.kind() == Type::kVector_Kind && 
+            b.fRight->fType.isNumber()) {
             // promote number to vector
             SpvId vec = this->nextId();
-            this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType->columns(), out);
+            this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
             this->writeWord(this->getType(resultType), out);
             this->writeWord(vec, out);
             for (int i = 0; i < resultType.columns(); i++) {
                 this->writeWord(rhs, out);
             }
             rhs = vec;
-            operandType = b.fRight->fType.get();
-        } else if (b.fRight->fType->kind() == Type::kVector_Kind && 
-                   b.fLeft->fType->isNumber()) {
+            operandType = &b.fRight->fType;
+        } else if (b.fRight->fType.kind() == Type::kVector_Kind && 
+                   b.fLeft->fType.isNumber()) {
             // promote number to vector
             SpvId vec = this->nextId();
-            this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType->columns(), out);
+            this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
             this->writeWord(this->getType(resultType), out);
             this->writeWord(vec, out);
             for (int i = 0; i < resultType.columns(); i++) {
@@ -1902,33 +1906,33 @@
             }
             lhs = vec;
             ASSERT(!lvalue);
-            operandType = b.fLeft->fType.get();
-        } else if (b.fLeft->fType->kind() == Type::kMatrix_Kind) {
+            operandType = &b.fLeft->fType;
+        } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
             SpvOp_ op;
-            if (b.fRight->fType->kind() == Type::kMatrix_Kind) {
+            if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
                 op = SpvOpMatrixTimesMatrix;
-            } else if (b.fRight->fType->kind() == Type::kVector_Kind) {
+            } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
                 op = SpvOpMatrixTimesVector;
             } else {
-                ASSERT(b.fRight->fType->kind() == Type::kScalar_Kind);
+                ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
                 op = SpvOpMatrixTimesScalar;
             }
             SpvId result = this->nextId();
-            this->writeInstruction(op, this->getType(*b.fType), result, lhs, rhs, out);
+            this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
             if (b.fOperator == Token::STAREQ) {
                 lvalue->store(result, out);
             } else {
                 ASSERT(b.fOperator == Token::STAR);
             }
             return result;
-        } else if (b.fRight->fType->kind() == Type::kMatrix_Kind) {
+        } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
             SpvId result = this->nextId();
-            if (b.fLeft->fType->kind() == Type::kVector_Kind) {
-                this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(*b.fType), result, 
+            if (b.fLeft->fType.kind() == Type::kVector_Kind) {
+                this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result, 
                                        lhs, rhs, out);
             } else {
-                ASSERT(b.fLeft->fType->kind() == Type::kScalar_Kind);
-                this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(*b.fType), result, rhs, 
+                ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
+                this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs, 
                                        lhs, out);
             }
             if (b.fOperator == Token::STAREQ) {
@@ -1941,35 +1945,35 @@
             ABORT("unsupported binary expression: %s", b.description().c_str());
         }
     } else {
-        operandType = b.fLeft->fType.get();
-        ASSERT(*operandType == *b.fRight->fType);
+        operandType = &b.fLeft->fType;
+        ASSERT(*operandType == b.fRight->fType);
     }
     switch (b.fOperator) {
         case Token::EQEQ:
-            ASSERT(resultType == *kBool_Type);
+            ASSERT(resultType == *fContext.fBool_Type);
             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdEqual, 
                                               SpvOpIEqual, SpvOpIEqual, SpvOpLogicalEqual, out);
         case Token::NEQ:
-            ASSERT(resultType == *kBool_Type);
+            ASSERT(resultType == *fContext.fBool_Type);
             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdNotEqual, 
                                               SpvOpINotEqual, SpvOpINotEqual, SpvOpLogicalNotEqual, 
                                               out);
         case Token::GT:
-            ASSERT(resultType == *kBool_Type);
+            ASSERT(resultType == *fContext.fBool_Type);
             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 
                                               SpvOpFOrdGreaterThan, SpvOpSGreaterThan, 
                                               SpvOpUGreaterThan, SpvOpUndef, out);
         case Token::LT:
-            ASSERT(resultType == *kBool_Type);
+            ASSERT(resultType == *fContext.fBool_Type);
             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan, 
                                               SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
         case Token::GTEQ:
-            ASSERT(resultType == *kBool_Type);
+            ASSERT(resultType == *fContext.fBool_Type);
             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 
                                               SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual, 
                                               SpvOpUGreaterThanEqual, SpvOpUndef, out);
         case Token::LTEQ:
-            ASSERT(resultType == *kBool_Type);
+            ASSERT(resultType == *fContext.fBool_Type);
             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 
                                               SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual, 
                                               SpvOpULessThanEqual, SpvOpUndef, out);
@@ -1980,8 +1984,8 @@
             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub, 
                                               SpvOpISub, SpvOpISub, SpvOpUndef, out);
         case Token::STAR:
-            if (b.fLeft->fType->kind() == Type::kMatrix_Kind && 
-                b.fRight->fType->kind() == Type::kMatrix_Kind) {
+            if (b.fLeft->fType.kind() == Type::kMatrix_Kind && 
+                b.fRight->fType.kind() == Type::kMatrix_Kind) {
                 // matrix multiply
                 SpvId result = this->nextId();
                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
@@ -2008,8 +2012,8 @@
             return result;
         }
         case Token::STAREQ: {
-            if (b.fLeft->fType->kind() == Type::kMatrix_Kind && 
-                b.fRight->fType->kind() == Type::kMatrix_Kind) {
+            if (b.fLeft->fType.kind() == Type::kMatrix_Kind && 
+                b.fRight->fType.kind() == Type::kMatrix_Kind) {
                 // matrix multiply
                 SpvId result = this->nextId();
                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
@@ -2039,7 +2043,7 @@
 
 SpvId SPIRVCodeGenerator::writeLogicalAnd(BinaryExpression& a, std::ostream& out) {
     ASSERT(a.fOperator == Token::LOGICALAND);
-    BoolLiteral falseLiteral(Position(), false);
+    BoolLiteral falseLiteral(fContext, Position(), false);
     SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
     SpvId lhs = this->writeExpression(*a.fLeft, out);
     SpvId rhsLabel = this->nextId();
@@ -2053,14 +2057,14 @@
     this->writeInstruction(SpvOpBranch, end, out);
     this->writeLabel(end, out);
     SpvId result = this->nextId();
-    this->writeInstruction(SpvOpPhi, this->getType(*kBool_Type), result, falseConstant, lhsBlock,
-                           rhs, rhsBlock, out);
+    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant, 
+                           lhsBlock, rhs, rhsBlock, out);
     return result;
 }
 
 SpvId SPIRVCodeGenerator::writeLogicalOr(BinaryExpression& o, std::ostream& out) {
     ASSERT(o.fOperator == Token::LOGICALOR);
-    BoolLiteral trueLiteral(Position(), true);
+    BoolLiteral trueLiteral(fContext, Position(), true);
     SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
     SpvId lhs = this->writeExpression(*o.fLeft, out);
     SpvId rhsLabel = this->nextId();
@@ -2074,8 +2078,8 @@
     this->writeInstruction(SpvOpBranch, end, out);
     this->writeLabel(end, out);
     SpvId result = this->nextId();
-    this->writeInstruction(SpvOpPhi, this->getType(*kBool_Type), result, trueConstant, lhsBlock,
-                           rhs, rhsBlock, out);
+    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant, 
+                           lhsBlock, rhs, rhsBlock, out);
     return result;
 }
 
@@ -2086,7 +2090,7 @@
         SpvId result = this->nextId();
         SpvId trueId = this->writeExpression(*t.fIfTrue, out);
         SpvId falseId = this->writeExpression(*t.fIfFalse, out);
-        this->writeInstruction(SpvOpSelect, this->getType(*t.fType), result, test, trueId, falseId, 
+        this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId, 
                                out);
         return result;
     }
@@ -2094,7 +2098,7 @@
     // Adreno. Switched to storing the result in a temp variable as glslang does.
     SpvId var = this->nextId();
     this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction), 
-                           var, SpvStorageClassFunction, out);
+                           var, SpvStorageClassFunction, fVariableBuffer);
     SpvId trueLabel = this->nextId();
     SpvId falseLabel = this->nextId();
     SpvId end = this->nextId();
@@ -2108,18 +2112,16 @@
     this->writeInstruction(SpvOpBranch, end, out);
     this->writeLabel(end, out);
     SpvId result = this->nextId();
-    this->writeInstruction(SpvOpLoad, this->getType(*t.fType), result, var, out);
+    this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
     return result;
 }
 
-Expression* literal_1(const Type& type) {
-    static IntLiteral int1(Position(), 1);
-    static FloatLiteral float1(Position(), 1.0);
-    if (type == *kInt_Type) {
-        return &int1;
+std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
+    if (type == *context.fInt_Type) {
+        return std::unique_ptr<Expression>(new IntLiteral(context, Position(), 1));
     }
-    else if (type == *kFloat_Type) {
-        return &float1;
+    else if (type == *context.fFloat_Type) {
+        return std::unique_ptr<Expression>(new FloatLiteral(context, Position(), 1.0));
     } else {
         ABORT("math is unsupported on type '%s'")
     }
@@ -2128,11 +2130,11 @@
 SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostream& out) {
     if (p.fOperator == Token::MINUS) {
         SpvId result = this->nextId();
-        SpvId typeId = this->getType(*p.fType);
+        SpvId typeId = this->getType(p.fType);
         SpvId expr = this->writeExpression(*p.fOperand, out);
-        if (is_float(*p.fType)) {
+        if (is_float(fContext, p.fType)) {
             this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
-        } else if (is_signed(*p.fType)) {
+        } else if (is_signed(fContext, p.fType)) {
             this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
         } else {
             ABORT("unsupported prefix expression %s", p.description().c_str());
@@ -2144,8 +2146,8 @@
             return this->writeExpression(*p.fOperand, out);
         case Token::PLUSPLUS: {
             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
-            SpvId one = this->writeExpression(*literal_1(*p.fType), out);
-            SpvId result = this->writeBinaryOperation(*p.fType, *p.fType, lv->load(out), one, 
+            SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
+            SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one, 
                                                       SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef, 
                                                       out);
             lv->store(result, out);
@@ -2153,17 +2155,17 @@
         }
         case Token::MINUSMINUS: {
             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
-            SpvId one = this->writeExpression(*literal_1(*p.fType), out);
-            SpvId result = this->writeBinaryOperation(*p.fType, *p.fType, lv->load(out), one, 
+            SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
+            SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one, 
                                                       SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef, 
                                                       out);
             lv->store(result, out);
             return result;
         }
         case Token::NOT: {
-            ASSERT(p.fOperand->fType == kBool_Type);
+            ASSERT(p.fOperand->fType == *fContext.fBool_Type);
             SpvId result = this->nextId();
-            this->writeInstruction(SpvOpLogicalNot, this->getType(*p.fOperand->fType), result,
+            this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
                                    this->writeExpression(*p.fOperand, out), out);
             return result;
         }
@@ -2175,16 +2177,16 @@
 SpvId SPIRVCodeGenerator::writePostfixExpression(PostfixExpression& p, std::ostream& out) {
     std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
     SpvId result = lv->load(out);
-    SpvId one = this->writeExpression(*literal_1(*p.fType), out);
+    SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
     switch (p.fOperator) {
         case Token::PLUSPLUS: {
-            SpvId temp = this->writeBinaryOperation(*p.fType, *p.fType, result, one, SpvOpFAdd, 
+            SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd, 
                                                     SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
             lv->store(temp, out);
             return result;
         }
         case Token::MINUSMINUS: {
-            SpvId temp = this->writeBinaryOperation(*p.fType, *p.fType, result, one, SpvOpFSub, 
+            SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub, 
                                                     SpvOpISub, SpvOpISub, SpvOpUndef, out);
             lv->store(temp, out);
             return result;
@@ -2198,14 +2200,14 @@
     if (b.fValue) {
         if (fBoolTrue == 0) {
             fBoolTrue = this->nextId();
-            this->writeInstruction(SpvOpConstantTrue, this->getType(*b.fType), fBoolTrue, 
+            this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue, 
                                    fConstantBuffer);
         }
         return fBoolTrue;
     } else {
         if (fBoolFalse == 0) {
             fBoolFalse = this->nextId();
-            this->writeInstruction(SpvOpConstantFalse, this->getType(*b.fType), fBoolFalse, 
+            this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse, 
                                    fConstantBuffer);
         }
         return fBoolFalse;
@@ -2213,22 +2215,22 @@
 }
 
 SpvId SPIRVCodeGenerator::writeIntLiteral(IntLiteral& i) {
-    if (i.fType == kInt_Type) {
+    if (i.fType == *fContext.fInt_Type) {
         auto entry = fIntConstants.find(i.fValue);
         if (entry == fIntConstants.end()) {
             SpvId result = this->nextId();
-            this->writeInstruction(SpvOpConstant, this->getType(*i.fType), result, (SpvId) i.fValue, 
+            this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue, 
                                    fConstantBuffer);
             fIntConstants[i.fValue] = result;
             return result;
         }
         return entry->second;
     } else {
-        ASSERT(i.fType == kUInt_Type);
+        ASSERT(i.fType == *fContext.fUInt_Type);
         auto entry = fUIntConstants.find(i.fValue);
         if (entry == fUIntConstants.end()) {
             SpvId result = this->nextId();
-            this->writeInstruction(SpvOpConstant, this->getType(*i.fType), result, (SpvId) i.fValue, 
+            this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue, 
                                    fConstantBuffer);
             fUIntConstants[i.fValue] = result;
             return result;
@@ -2238,7 +2240,7 @@
 }
 
 SpvId SPIRVCodeGenerator::writeFloatLiteral(FloatLiteral& f) {
-    if (f.fType == kFloat_Type) {
+    if (f.fType == *fContext.fFloat_Type) {
         float value = (float) f.fValue;
         auto entry = fFloatConstants.find(value);
         if (entry == fFloatConstants.end()) {
@@ -2246,21 +2248,21 @@
             uint32_t bits;
             ASSERT(sizeof(bits) == sizeof(value));
             memcpy(&bits, &value, sizeof(bits));
-            this->writeInstruction(SpvOpConstant, this->getType(*f.fType), result, bits, 
+            this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits, 
                                    fConstantBuffer);
             fFloatConstants[value] = result;
             return result;
         }
         return entry->second;
     } else {
-        ASSERT(f.fType == kDouble_Type);
+        ASSERT(f.fType == *fContext.fDouble_Type);
         auto entry = fDoubleConstants.find(f.fValue);
         if (entry == fDoubleConstants.end()) {
             SpvId result = this->nextId();
             uint64_t bits;
             ASSERT(sizeof(bits) == sizeof(f.fValue));
             memcpy(&bits, &f.fValue, sizeof(bits));
-            this->writeInstruction(SpvOpConstant, this->getType(*f.fType), result, 
+            this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, 
                                    bits & 0xffffffff, bits >> 32, fConstantBuffer);
             fDoubleConstants[f.fValue] = result;
             return result;
@@ -2269,26 +2271,25 @@
     }
 }
 
-SpvId SPIRVCodeGenerator::writeFunctionStart(std::shared_ptr<FunctionDeclaration> f, 
-                                             std::ostream& out) {
-    SpvId result = fFunctionMap[f];
-    this->writeInstruction(SpvOpFunction, this->getType(*f->fReturnType), result, 
+SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, std::ostream& out) {
+    SpvId result = fFunctionMap[&f];
+    this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result, 
                            SpvFunctionControlMaskNone, this->getFunctionType(f), out);
-    this->writeInstruction(SpvOpName, result, f->fName.c_str(), fNameBuffer);
-    for (size_t i = 0; i < f->fParameters.size(); i++) {
+    this->writeInstruction(SpvOpName, result, f.fName.c_str(), fNameBuffer);
+    for (size_t i = 0; i < f.fParameters.size(); i++) {
         SpvId id = this->nextId();
-        fVariableMap[f->fParameters[i]] = id;
+        fVariableMap[f.fParameters[i]] = id;
         SpvId type;
-        type = this->getPointerType(f->fParameters[i]->fType, SpvStorageClassFunction);
+        type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
         this->writeInstruction(SpvOpFunctionParameter, type, id, out);
     }
     return result;
 }
 
-SpvId SPIRVCodeGenerator::writeFunction(FunctionDefinition& f, std::ostream& out) {
+SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, std::ostream& out) {
     SpvId result = this->writeFunctionStart(f.fDeclaration, out);
     this->writeLabel(this->nextId(), out);
-    if (f.fDeclaration->fName == "main") {
+    if (f.fDeclaration.fName == "main") {
         out << fGlobalInitializersBuffer.str();
     }
     std::stringstream bodyBuffer;
@@ -2350,21 +2351,26 @@
 }
 
 SpvId SPIRVCodeGenerator::writeInterfaceBlock(InterfaceBlock& intf) {
-    SpvId type = this->getType(*intf.fVariable->fType);
+    SpvId type = this->getType(intf.fVariable.fType);
     SpvId result = this->nextId();
     this->writeInstruction(SpvOpDecorate, type, SpvDecorationBlock, fDecorationBuffer);
-    SpvStorageClass_ storageClass = get_storage_class(intf.fVariable->fModifiers);
+    SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
     SpvId ptrType = this->nextId();
     this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, type, fConstantBuffer);
     this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
-    this->writeLayout(intf.fVariable->fModifiers.fLayout, result);
-    fVariableMap[intf.fVariable] = result;
+    this->writeLayout(intf.fVariable.fModifiers.fLayout, result);
+    fVariableMap[&intf.fVariable] = result;
     return result;
 }
 
 void SPIRVCodeGenerator::writeGlobalVars(VarDeclaration& decl, std::ostream& out) {
     for (size_t i = 0; i < decl.fVars.size(); i++) {
-        if (!decl.fVars[i]->fIsReadFrom && !decl.fVars[i]->fIsWrittenTo) {
+        if (!decl.fVars[i]->fIsReadFrom && !decl.fVars[i]->fIsWrittenTo &&
+                !(decl.fVars[i]->fModifiers.fFlags & (Modifiers::kIn_Flag |
+                                                      Modifiers::kOut_Flag |
+                                                      Modifiers::kUniform_Flag))) {
+            // variable is dead and not an input / output var (the Vulkan debug layers complain if
+            // we elide an interface var, even if it's dead)
             continue;
         }
         SpvStorageClass_ storageClass;
@@ -2373,7 +2379,7 @@
         } else if (decl.fVars[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
             storageClass = SpvStorageClassOutput;
         } else if (decl.fVars[i]->fModifiers.fFlags & Modifiers::kUniform_Flag) {
-            if (decl.fVars[i]->fType->kind() == Type::kSampler_Kind) {
+            if (decl.fVars[i]->fType.kind() == Type::kSampler_Kind) {
                 storageClass = SpvStorageClassUniformConstant;
             } else {
                 storageClass = SpvStorageClassUniform;
@@ -2386,11 +2392,11 @@
         SpvId type = this->getPointerType(decl.fVars[i]->fType, storageClass);
         this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
         this->writeInstruction(SpvOpName, id, decl.fVars[i]->fName.c_str(), fNameBuffer);
-        if (decl.fVars[i]->fType->kind() == Type::kMatrix_Kind) {
+        if (decl.fVars[i]->fType.kind() == Type::kMatrix_Kind) {
             this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationColMajor, 
                                    fDecorationBuffer);
             this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationMatrixStride, 
-                                   (SpvId) decl.fVars[i]->fType->stride(), fDecorationBuffer);
+                                   (SpvId) decl.fVars[i]->fType.stride(), fDecorationBuffer);
         }
         if (decl.fValues[i]) {
 			ASSERT(!fCurrentBlock);
@@ -2538,15 +2544,15 @@
     for (size_t i = 0; i < program.fElements.size(); i++) {
         if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
             FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i];
-            fFunctionMap[f.fDeclaration] = this->nextId();
+            fFunctionMap[&f.fDeclaration] = this->nextId();
         }
     }
     for (size_t i = 0; i < program.fElements.size(); i++) {
         if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) {
             InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i];
             SpvId id = this->writeInterfaceBlock(intf);
-            if ((intf.fVariable->fModifiers.fFlags & Modifiers::kIn_Flag) ||
-                (intf.fVariable->fModifiers.fFlags & Modifiers::kOut_Flag)) {
+            if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
+                (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
                 interfaceVars.push_back(id);
             }
         }
@@ -2561,7 +2567,7 @@
             this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body);
         }
     }
-    std::shared_ptr<FunctionDeclaration> main = nullptr;
+    const FunctionDeclaration* main = nullptr;
     for (auto entry : fFunctionMap) {
 		if (entry.first->fName == "main") {
             main = entry.first;
@@ -2569,7 +2575,7 @@
     }
     ASSERT(main);
     for (auto entry : fVariableMap) {
-        std::shared_ptr<Variable> var = entry.first;
+        const Variable* var = entry.first;
         if (var->fStorage == Variable::kGlobal_Storage && 
                 ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
                  (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h
index 885c6b8..a20ad9f 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.h
+++ b/src/sksl/SkSLSPIRVCodeGenerator.h
@@ -61,8 +61,9 @@
         virtual void store(SpvId value, std::ostream& out) = 0;
     };
 
-    SPIRVCodeGenerator()
-    : fCapabilities(1 << SpvCapabilityShader)
+    SPIRVCodeGenerator(const Context* context)
+    : fContext(*context)
+    , fCapabilities(1 << SpvCapabilityShader)
     , fIdCount(1)
     , fBoolTrue(0)
     , fBoolFalse(0)
@@ -92,9 +93,9 @@
 
     SpvId getType(const Type& type);
 
-    SpvId getFunctionType(std::shared_ptr<FunctionDeclaration> function);
+    SpvId getFunctionType(const FunctionDeclaration& function);
 
-    SpvId getPointerType(std::shared_ptr<Type> type, SpvStorageClass_ storageClass);
+    SpvId getPointerType(const Type& type, SpvStorageClass_ storageClass);
 
     std::vector<SpvId> getAccessChain(Expression& expr, std::ostream& out);
 
@@ -108,11 +109,11 @@
 
     SpvId writeInterfaceBlock(InterfaceBlock& intf);
 
-    SpvId writeFunctionStart(std::shared_ptr<FunctionDeclaration> f, std::ostream& out);
+    SpvId writeFunctionStart(const FunctionDeclaration& f, std::ostream& out);
     
-    SpvId writeFunctionDeclaration(std::shared_ptr<FunctionDeclaration> f, std::ostream& out);
+    SpvId writeFunctionDeclaration(const FunctionDeclaration& f, std::ostream& out);
 
-    SpvId writeFunction(FunctionDefinition& f, std::ostream& out);
+    SpvId writeFunction(const FunctionDefinition& f, std::ostream& out);
 
     void writeGlobalVars(VarDeclaration& v, std::ostream& out);
 
@@ -227,14 +228,16 @@
                           int32_t word5, int32_t word6, int32_t word7, int32_t word8, 
                           std::ostream& out);
 
+    const Context& fContext;
+
     uint64_t fCapabilities;
     SpvId fIdCount;
     SpvId fGLSLExtendedInstructions;
     typedef std::tuple<IntrinsicKind, int32_t, int32_t, int32_t, int32_t> Intrinsic;
     std::unordered_map<std::string, Intrinsic> fIntrinsicMap;
-    std::unordered_map<std::shared_ptr<FunctionDeclaration>, SpvId> fFunctionMap;
-    std::unordered_map<std::shared_ptr<Variable>, SpvId> fVariableMap;
-    std::unordered_map<std::shared_ptr<Variable>, int32_t> fInterfaceBlockMap;
+    std::unordered_map<const FunctionDeclaration*, SpvId> fFunctionMap;
+    std::unordered_map<const Variable*, SpvId> fVariableMap;
+    std::unordered_map<const Variable*, int32_t> fInterfaceBlockMap;
     std::unordered_map<std::string, SpvId> fTypeMap;
     std::stringstream fCapabilitiesBuffer;
     std::stringstream fGlobalInitializersBuffer;
diff --git a/src/sksl/ir/SkSLBinaryExpression.h b/src/sksl/ir/SkSLBinaryExpression.h
index bd89d6c..9ecdbc7 100644
--- a/src/sksl/ir/SkSLBinaryExpression.h
+++ b/src/sksl/ir/SkSLBinaryExpression.h
@@ -18,7 +18,7 @@
  */
 struct BinaryExpression : public Expression {
     BinaryExpression(Position position, std::unique_ptr<Expression> left, Token::Kind op,
-                     std::unique_ptr<Expression> right, std::shared_ptr<Type> type)
+                     std::unique_ptr<Expression> right, const Type& type)
     : INHERITED(position, kBinary_Kind, type)
     , fLeft(std::move(left))
     , fOperator(op)
diff --git a/src/sksl/ir/SkSLBlock.h b/src/sksl/ir/SkSLBlock.h
index 56ed77a..a53d13d 100644
--- a/src/sksl/ir/SkSLBlock.h
+++ b/src/sksl/ir/SkSLBlock.h
@@ -9,6 +9,7 @@
 #define SKSL_BLOCK
 
 #include "SkSLStatement.h"
+#include "SkSLSymbolTable.h"
 
 namespace SkSL {
 
@@ -16,9 +17,11 @@
  * A block of multiple statements functioning as a single statement.
  */
 struct Block : public Statement {
-    Block(Position position, std::vector<std::unique_ptr<Statement>> statements)
+    Block(Position position, std::vector<std::unique_ptr<Statement>> statements,
+          const std::shared_ptr<SymbolTable> symbols)
     : INHERITED(position, kBlock_Kind)
-    , fStatements(std::move(statements)) {}
+    , fStatements(std::move(statements))
+    , fSymbols(std::move(symbols)) {}
 
     std::string description() const override {
         std::string result = "{";
@@ -31,6 +34,7 @@
     }
 
     const std::vector<std::unique_ptr<Statement>> fStatements;
+    const std::shared_ptr<SymbolTable> fSymbols;
 
     typedef Statement INHERITED;
 };
diff --git a/src/sksl/ir/SkSLBoolLiteral.h b/src/sksl/ir/SkSLBoolLiteral.h
index 3c40e59..8f55a69 100644
--- a/src/sksl/ir/SkSLBoolLiteral.h
+++ b/src/sksl/ir/SkSLBoolLiteral.h
@@ -8,6 +8,7 @@
 #ifndef SKSL_BOOLLITERAL
 #define SKSL_BOOLLITERAL
 
+#include "SkSLContext.h"
 #include "SkSLExpression.h"
 
 namespace SkSL {
@@ -16,8 +17,8 @@
  * Represents 'true' or 'false'.
  */
 struct BoolLiteral : public Expression {
-    BoolLiteral(Position position, bool value)
-    : INHERITED(position, kBoolLiteral_Kind, kBool_Type)
+    BoolLiteral(const Context& context, Position position, bool value)
+    : INHERITED(position, kBoolLiteral_Kind, *context.fBool_Type)
     , fValue(value) {}
 
     std::string description() const override {
diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h
index c58da7e..0501b65 100644
--- a/src/sksl/ir/SkSLConstructor.h
+++ b/src/sksl/ir/SkSLConstructor.h
@@ -16,13 +16,13 @@
  * Represents the construction of a compound type, such as "vec2(x, y)".
  */
 struct Constructor : public Expression {
-    Constructor(Position position, std::shared_ptr<Type> type, 
+    Constructor(Position position, const Type& type, 
                 std::vector<std::unique_ptr<Expression>> arguments)
-    : INHERITED(position, kConstructor_Kind, std::move(type))
+    : INHERITED(position, kConstructor_Kind, type)
     , fArguments(std::move(arguments)) {}
 
     std::string description() const override {
-        std::string result = fType->description() + "(";
+        std::string result = fType.description() + "(";
         std::string separator = "";
         for (size_t i = 0; i < fArguments.size(); i++) {
             result += separator;
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index 1e42c7a..92cb37d 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -35,7 +35,7 @@
         kTypeReference_Kind,
     };
 
-    Expression(Position position, Kind kind, std::shared_ptr<Type> type)
+    Expression(Position position, Kind kind, const Type& type)
     : INHERITED(position)
     , fKind(kind)
     , fType(std::move(type)) {}
@@ -45,7 +45,7 @@
     }
 
     const Kind fKind;
-    const std::shared_ptr<Type> fType;
+    const Type& fType;
 
     typedef IRNode INHERITED;
 };
diff --git a/src/sksl/ir/SkSLField.h b/src/sksl/ir/SkSLField.h
index f2b68bc..a01df29 100644
--- a/src/sksl/ir/SkSLField.h
+++ b/src/sksl/ir/SkSLField.h
@@ -21,16 +21,16 @@
  * result of declaring anonymous interface blocks.
  */
 struct Field : public Symbol {
-    Field(Position position, std::shared_ptr<Variable> owner, int fieldIndex)
-    : INHERITED(position, kField_Kind, owner->fType->fields()[fieldIndex].fName)
+    Field(Position position, const Variable& owner, int fieldIndex)
+    : INHERITED(position, kField_Kind, owner.fType.fields()[fieldIndex].fName)
     , fOwner(owner)
     , fFieldIndex(fieldIndex) {}
 
     virtual std::string description() const override {
-        return fOwner->description() + "." + fOwner->fType->fields()[fFieldIndex].fName;
+        return fOwner.description() + "." + fOwner.fType.fields()[fFieldIndex].fName;
     }
 
-    const std::shared_ptr<Variable> fOwner;
+    const Variable& fOwner;
     const int fFieldIndex;
 
     typedef Symbol INHERITED;
diff --git a/src/sksl/ir/SkSLFieldAccess.h b/src/sksl/ir/SkSLFieldAccess.h
index 053498e..f09c3a3 100644
--- a/src/sksl/ir/SkSLFieldAccess.h
+++ b/src/sksl/ir/SkSLFieldAccess.h
@@ -18,12 +18,12 @@
  */
 struct FieldAccess : public Expression {
     FieldAccess(std::unique_ptr<Expression> base, int fieldIndex)
-    : INHERITED(base->fPosition, kFieldAccess_Kind, base->fType->fields()[fieldIndex].fType)
+    : INHERITED(base->fPosition, kFieldAccess_Kind, base->fType.fields()[fieldIndex].fType)
     , fBase(std::move(base))
     , fFieldIndex(fieldIndex) {}
 
     virtual std::string description() const override {
-        return fBase->description() + "." + fBase->fType->fields()[fFieldIndex].fName;
+        return fBase->description() + "." + fBase->fType.fields()[fFieldIndex].fName;
     }
 
     const std::unique_ptr<Expression> fBase;
diff --git a/src/sksl/ir/SkSLFloatLiteral.h b/src/sksl/ir/SkSLFloatLiteral.h
index deb5b27..d9c8b65 100644
--- a/src/sksl/ir/SkSLFloatLiteral.h
+++ b/src/sksl/ir/SkSLFloatLiteral.h
@@ -8,6 +8,7 @@
 #ifndef SKSL_FLOATLITERAL
 #define SKSL_FLOATLITERAL
 
+#include "SkSLContext.h"
 #include "SkSLExpression.h"
 
 namespace SkSL {
@@ -16,8 +17,8 @@
  * A literal floating point number.
  */
 struct FloatLiteral : public Expression {
-    FloatLiteral(Position position, double value)
-    : INHERITED(position, kFloatLiteral_Kind, kFloat_Type)
+    FloatLiteral(const Context& context, Position position, double value)
+    : INHERITED(position, kFloatLiteral_Kind, *context.fFloat_Type)
     , fValue(value) {}
 
     virtual std::string description() const override {
diff --git a/src/sksl/ir/SkSLForStatement.h b/src/sksl/ir/SkSLForStatement.h
index 70bb401..642d151 100644
--- a/src/sksl/ir/SkSLForStatement.h
+++ b/src/sksl/ir/SkSLForStatement.h
@@ -10,6 +10,7 @@
 
 #include "SkSLExpression.h"
 #include "SkSLStatement.h"
+#include "SkSLSymbolTable.h"
 
 namespace SkSL {
 
@@ -19,12 +20,13 @@
 struct ForStatement : public Statement {
     ForStatement(Position position, std::unique_ptr<Statement> initializer, 
                  std::unique_ptr<Expression> test, std::unique_ptr<Expression> next, 
-                 std::unique_ptr<Statement> statement)
+                 std::unique_ptr<Statement> statement, std::shared_ptr<SymbolTable> symbols)
     : INHERITED(position, kFor_Kind)
     , fInitializer(std::move(initializer))
     , fTest(std::move(test))
     , fNext(std::move(next))
-    , fStatement(std::move(statement)) {}
+    , fStatement(std::move(statement))
+    , fSymbols(symbols) {}
 
     std::string description() const override {
         std::string result = "for (";
@@ -47,6 +49,7 @@
     const std::unique_ptr<Expression> fTest;
     const std::unique_ptr<Expression> fNext;
     const std::unique_ptr<Statement> fStatement;
+    const std::shared_ptr<SymbolTable> fSymbols;
 
     typedef Statement INHERITED;
 };
diff --git a/src/sksl/ir/SkSLFunctionCall.h b/src/sksl/ir/SkSLFunctionCall.h
index 78d2566..85dba40 100644
--- a/src/sksl/ir/SkSLFunctionCall.h
+++ b/src/sksl/ir/SkSLFunctionCall.h
@@ -17,14 +17,14 @@
  * A function invocation.
  */
 struct FunctionCall : public Expression {
-    FunctionCall(Position position, std::shared_ptr<FunctionDeclaration> function,
+    FunctionCall(Position position, const FunctionDeclaration& function,
                  std::vector<std::unique_ptr<Expression>> arguments)
-    : INHERITED(position, kFunctionCall_Kind, function->fReturnType)
+    : INHERITED(position, kFunctionCall_Kind, function.fReturnType)
     , fFunction(std::move(function))
     , fArguments(std::move(arguments)) {}
 
     std::string description() const override {
-        std::string result = fFunction->fName + "(";
+        std::string result = fFunction.fName + "(";
         std::string separator = "";
         for (size_t i = 0; i < fArguments.size(); i++) {
             result += separator;
@@ -35,7 +35,7 @@
         return result;
     }
 
-    const std::shared_ptr<FunctionDeclaration> fFunction;
+    const FunctionDeclaration& fFunction;
     const std::vector<std::unique_ptr<Expression>> fArguments;
 
     typedef Expression INHERITED;
diff --git a/src/sksl/ir/SkSLFunctionDeclaration.h b/src/sksl/ir/SkSLFunctionDeclaration.h
index 32c23f5..16a184a 100644
--- a/src/sksl/ir/SkSLFunctionDeclaration.h
+++ b/src/sksl/ir/SkSLFunctionDeclaration.h
@@ -10,6 +10,7 @@
 
 #include "SkSLModifiers.h"
 #include "SkSLSymbol.h"
+#include "SkSLSymbolTable.h"
 #include "SkSLType.h"
 #include "SkSLVariable.h"
 
@@ -20,15 +21,14 @@
  */
 struct FunctionDeclaration : public Symbol {
     FunctionDeclaration(Position position, std::string name, 
-                        std::vector<std::shared_ptr<Variable>> parameters, 
-                        std::shared_ptr<Type> returnType)
+                        std::vector<const Variable*> parameters, const Type& returnType)
     : INHERITED(position, kFunctionDeclaration_Kind, std::move(name))
     , fDefined(false)
-    , fParameters(parameters)
+    , fParameters(std::move(parameters))
     , fReturnType(returnType) {}
 
     std::string description() const override {
-        std::string result = fReturnType->description() + " " + fName + "(";
+        std::string result = fReturnType.description() + " " + fName + "(";
         std::string separator = "";
         for (auto p : fParameters) {
             result += separator;
@@ -39,13 +39,24 @@
         return result;
     }
 
-    bool matches(FunctionDeclaration& f) {
-        return fName == f.fName && fParameters == f.fParameters;
+    bool matches(const FunctionDeclaration& f) const {
+        if (fName != f.fName) {
+            return false;
+        }
+        if (fParameters.size() != f.fParameters.size()) {
+            return false;
+        }
+        for (size_t i = 0; i < fParameters.size(); i++) {
+            if (fParameters[i]->fType != f.fParameters[i]->fType) {
+                return false;
+            }
+        }
+        return true;
     }
 
     mutable bool fDefined;
-    const std::vector<std::shared_ptr<Variable>> fParameters;
-    const std::shared_ptr<Type> fReturnType;
+    const std::vector<const Variable*> fParameters;
+    const Type& fReturnType;
 
     typedef Symbol INHERITED;
 };
diff --git a/src/sksl/ir/SkSLFunctionDefinition.h b/src/sksl/ir/SkSLFunctionDefinition.h
index fceb547..ace27a3 100644
--- a/src/sksl/ir/SkSLFunctionDefinition.h
+++ b/src/sksl/ir/SkSLFunctionDefinition.h
@@ -18,17 +18,17 @@
  * A function definition (a declaration plus an associated block of code).
  */
 struct FunctionDefinition : public ProgramElement {
-    FunctionDefinition(Position position, std::shared_ptr<FunctionDeclaration> declaration,
+    FunctionDefinition(Position position, const FunctionDeclaration& declaration, 
                        std::unique_ptr<Block> body)
     : INHERITED(position, kFunction_Kind)
-    , fDeclaration(std::move(declaration))
+    , fDeclaration(declaration)
     , fBody(std::move(body)) {}
 
     std::string description() const override {
-        return fDeclaration->description() + " " + fBody->description();
+        return fDeclaration.description() + " " + fBody->description();
     }
 
-    const std::shared_ptr<FunctionDeclaration> fDeclaration;
+    const FunctionDeclaration& fDeclaration;
     const std::unique_ptr<Block> fBody;
 
     typedef ProgramElement INHERITED;
diff --git a/src/sksl/ir/SkSLFunctionReference.h b/src/sksl/ir/SkSLFunctionReference.h
index d5cc444..5d97a58 100644
--- a/src/sksl/ir/SkSLFunctionReference.h
+++ b/src/sksl/ir/SkSLFunctionReference.h
@@ -8,6 +8,7 @@
 #ifndef SKSL_FUNCTIONREFERENCE
 #define SKSL_FUNCTIONREFERENCE
 
+#include "SkSLContext.h"
 #include "SkSLExpression.h"
 
 namespace SkSL {
@@ -17,8 +18,9 @@
  * always eventually replaced by FunctionCalls in valid programs.
  */
 struct FunctionReference : public Expression {
-    FunctionReference(Position position, std::vector<std::shared_ptr<FunctionDeclaration>> function)
-    : INHERITED(position, kFunctionReference_Kind, kInvalid_Type)
+    FunctionReference(const Context& context, Position position, 
+                      std::vector<const FunctionDeclaration*> function)
+    : INHERITED(position, kFunctionReference_Kind, *context.fInvalid_Type)
     , fFunctions(function) {}
 
     virtual std::string description() const override {
@@ -26,7 +28,7 @@
     	return "<function>";
     }
 
-    const std::vector<std::shared_ptr<FunctionDeclaration>> fFunctions;
+    const std::vector<const FunctionDeclaration*> fFunctions;
 
     typedef Expression INHERITED;
 };
diff --git a/src/sksl/ir/SkSLIndexExpression.h b/src/sksl/ir/SkSLIndexExpression.h
index 538c656..f5b0d09 100644
--- a/src/sksl/ir/SkSLIndexExpression.h
+++ b/src/sksl/ir/SkSLIndexExpression.h
@@ -16,21 +16,21 @@
 /**
  * Given a type, returns the type that will result from extracting an array value from it.
  */
-static std::shared_ptr<Type> index_type(const Type& type) {
+static const Type& index_type(const Context& context, const Type& type) {
     if (type.kind() == Type::kMatrix_Kind) {
-        if (type.componentType() == kFloat_Type) {
+        if (type.componentType() == *context.fFloat_Type) {
             switch (type.columns()) {
-                case 2: return kVec2_Type;
-                case 3: return kVec3_Type;
-                case 4: return kVec4_Type;
+                case 2: return *context.fVec2_Type;
+                case 3: return *context.fVec3_Type;
+                case 4: return *context.fVec4_Type;
                 default: ASSERT(false);
             }
         } else {
-            ASSERT(type.componentType() == kDouble_Type);
+            ASSERT(type.componentType() == *context.fDouble_Type);
             switch (type.columns()) {
-                case 2: return kDVec2_Type;
-                case 3: return kDVec3_Type;
-                case 4: return kDVec4_Type;
+                case 2: return *context.fDVec2_Type;
+                case 3: return *context.fDVec3_Type;
+                case 4: return *context.fDVec4_Type;
                 default: ASSERT(false);
             }
         }
@@ -42,11 +42,12 @@
  * An expression which extracts a value from an array or matrix, as in 'm[2]'.
  */
 struct IndexExpression : public Expression {
-    IndexExpression(std::unique_ptr<Expression> base, std::unique_ptr<Expression> index)
-    : INHERITED(base->fPosition, kIndex_Kind, index_type(*base->fType))
+    IndexExpression(const Context& context, std::unique_ptr<Expression> base, 
+                    std::unique_ptr<Expression> index)
+    : INHERITED(base->fPosition, kIndex_Kind, index_type(context, base->fType))
     , fBase(std::move(base))
     , fIndex(std::move(index)) {
-        ASSERT(fIndex->fType == kInt_Type);
+        ASSERT(fIndex->fType == *context.fInt_Type);
     }
 
     std::string description() const override {
diff --git a/src/sksl/ir/SkSLIntLiteral.h b/src/sksl/ir/SkSLIntLiteral.h
index 80b30d7..f2bf40b 100644
--- a/src/sksl/ir/SkSLIntLiteral.h
+++ b/src/sksl/ir/SkSLIntLiteral.h
@@ -18,8 +18,8 @@
 struct IntLiteral : public Expression {
     // FIXME: we will need to revisit this if/when we add full support for both signed and unsigned
     // 64-bit integers, but for right now an int64_t will hold every value we care about
-    IntLiteral(Position position, int64_t value)
-    : INHERITED(position, kIntLiteral_Kind, kInt_Type)
+    IntLiteral(const Context& context, Position position, int64_t value)
+    : INHERITED(position, kIntLiteral_Kind, *context.fInt_Type)
     , fValue(value) {}
 
     virtual std::string description() const override {
diff --git a/src/sksl/ir/SkSLInterfaceBlock.h b/src/sksl/ir/SkSLInterfaceBlock.h
index baedb58..f1121ed 100644
--- a/src/sksl/ir/SkSLInterfaceBlock.h
+++ b/src/sksl/ir/SkSLInterfaceBlock.h
@@ -24,22 +24,24 @@
  * At the IR level, this is represented by a single variable of struct type.
  */
 struct InterfaceBlock : public ProgramElement {
-    InterfaceBlock(Position position, std::shared_ptr<Variable> var)
+    InterfaceBlock(Position position, const Variable& var, std::shared_ptr<SymbolTable> typeOwner)
     : INHERITED(position, kInterfaceBlock_Kind) 
-    , fVariable(std::move(var)) {
-        ASSERT(fVariable->fType->kind() == Type::kStruct_Kind);
+    , fVariable(std::move(var))
+    , fTypeOwner(typeOwner) {
+        ASSERT(fVariable.fType.kind() == Type::kStruct_Kind);
     }
 
     std::string description() const override {
-        std::string result = fVariable->fModifiers.description() + fVariable->fName + " {\n";
-        for (size_t i = 0; i < fVariable->fType->fields().size(); i++) {
-            result += fVariable->fType->fields()[i].description() + "\n";
+        std::string result = fVariable.fModifiers.description() + fVariable.fName + " {\n";
+        for (size_t i = 0; i < fVariable.fType.fields().size(); i++) {
+            result += fVariable.fType.fields()[i].description() + "\n";
         }
         result += "};";
         return result;
     }
 
-    const std::shared_ptr<Variable> fVariable;
+    const Variable& fVariable;
+    const std::shared_ptr<SymbolTable> fTypeOwner;
 
     typedef ProgramElement INHERITED;
 };
diff --git a/src/sksl/ir/SkSLProgram.h b/src/sksl/ir/SkSLProgram.h
index 5edcfde..205db6e 100644
--- a/src/sksl/ir/SkSLProgram.h
+++ b/src/sksl/ir/SkSLProgram.h
@@ -12,6 +12,7 @@
 #include <memory>
 
 #include "SkSLProgramElement.h"
+#include "SkSLSymbolTable.h"
 
 namespace SkSL {
 
@@ -24,13 +25,16 @@
         kVertex_Kind
     };
 
-    Program(Kind kind, std::vector<std::unique_ptr<ProgramElement>> elements)
+    Program(Kind kind, std::vector<std::unique_ptr<ProgramElement>> elements, 
+            std::shared_ptr<SymbolTable> symbols)
     : fKind(kind) 
-    , fElements(std::move(elements)) {}
+    , fElements(std::move(elements))
+    , fSymbols(symbols) {}
 
     Kind fKind;
 
     std::vector<std::unique_ptr<ProgramElement>> fElements;
+    std::shared_ptr<SymbolTable> fSymbols;
 };
 
 } // namespace
diff --git a/src/sksl/ir/SkSLSwizzle.h b/src/sksl/ir/SkSLSwizzle.h
index ce360d1..0eb4a00 100644
--- a/src/sksl/ir/SkSLSwizzle.h
+++ b/src/sksl/ir/SkSLSwizzle.h
@@ -18,41 +18,40 @@
  * instance, swizzling a vec3 with two components will result in a vec2. It is possible to swizzle
  * with more components than the source vector, as in 'vec2(1).xxxx'.
  */
-static std::shared_ptr<Type> get_type(Expression& value, 
-                                      size_t count) {
-    std::shared_ptr<Type> base = value.fType->componentType();
+static const Type& get_type(const Context& context, Expression& value, size_t count) {
+    const Type& base = value.fType.componentType();
     if (count == 1) {
         return base;
     }
-    if (base == kFloat_Type) {
+    if (base == *context.fFloat_Type) {
         switch (count) {
-            case 2: return kVec2_Type;
-            case 3: return kVec3_Type;
-            case 4: return kVec4_Type;
+            case 2: return *context.fVec2_Type;
+            case 3: return *context.fVec3_Type;
+            case 4: return *context.fVec4_Type;
         }
-    } else if (base == kDouble_Type) {
+    } else if (base == *context.fDouble_Type) {
         switch (count) {
-            case 2: return kDVec2_Type;
-            case 3: return kDVec3_Type;
-            case 4: return kDVec4_Type;
+            case 2: return *context.fDVec2_Type;
+            case 3: return *context.fDVec3_Type;
+            case 4: return *context.fDVec4_Type;
         }
-    } else if (base == kInt_Type) {
+    } else if (base == *context.fInt_Type) {
         switch (count) {
-            case 2: return kIVec2_Type;
-            case 3: return kIVec3_Type;
-            case 4: return kIVec4_Type;
+            case 2: return *context.fIVec2_Type;
+            case 3: return *context.fIVec3_Type;
+            case 4: return *context.fIVec4_Type;
         }
-    } else if (base == kUInt_Type) {
+    } else if (base == *context.fUInt_Type) {
         switch (count) {
-            case 2: return kUVec2_Type;
-            case 3: return kUVec3_Type;
-            case 4: return kUVec4_Type;
+            case 2: return *context.fUVec2_Type;
+            case 3: return *context.fUVec3_Type;
+            case 4: return *context.fUVec4_Type;
         }
-    } else if (base == kBool_Type) {
+    } else if (base == *context.fBool_Type) {
         switch (count) {
-            case 2: return kBVec2_Type;
-            case 3: return kBVec3_Type;
-            case 4: return kBVec4_Type;
+            case 2: return *context.fBVec2_Type;
+            case 3: return *context.fBVec3_Type;
+            case 4: return *context.fBVec4_Type;
         }
     }
     ABORT("cannot swizzle %s\n", value.description().c_str());
@@ -62,8 +61,8 @@
  * Represents a vector swizzle operation such as 'vec2(1, 2, 3).zyx'.
  */
 struct Swizzle : public Expression {
-    Swizzle(std::unique_ptr<Expression> base, std::vector<int> components)
-    : INHERITED(base->fPosition, kSwizzle_Kind, get_type(*base, components.size()))
+    Swizzle(const Context& context, std::unique_ptr<Expression> base, std::vector<int> components)
+    : INHERITED(base->fPosition, kSwizzle_Kind, get_type(context, *base, components.size()))
     , fBase(std::move(base))
     , fComponents(std::move(components)) {
         ASSERT(fComponents.size() >= 1 && fComponents.size() <= 4);
diff --git a/src/sksl/ir/SkSLSymbolTable.cpp b/src/sksl/ir/SkSLSymbolTable.cpp
index af83f7a..80e22da 100644
--- a/src/sksl/ir/SkSLSymbolTable.cpp
+++ b/src/sksl/ir/SkSLSymbolTable.cpp
@@ -5,23 +5,23 @@
  * found in the LICENSE file.
  */
 
- #include "SkSLSymbolTable.h"
+#include "SkSLSymbolTable.h"
+#include "SkSLUnresolvedFunction.h"
 
 namespace SkSL {
 
-std::vector<std::shared_ptr<FunctionDeclaration>> SymbolTable::GetFunctions(
-                                                                 const std::shared_ptr<Symbol>& s) {
-    switch (s->fKind) {
+std::vector<const FunctionDeclaration*> SymbolTable::GetFunctions(const Symbol& s) {
+    switch (s.fKind) {
         case Symbol::kFunctionDeclaration_Kind:
-            return { std::static_pointer_cast<FunctionDeclaration>(s) };
+            return { &((FunctionDeclaration&) s) };
         case Symbol::kUnresolvedFunction_Kind:
-            return ((UnresolvedFunction&) *s).fFunctions;
+            return ((UnresolvedFunction&) s).fFunctions;
         default:
             return { };
     }
 }
 
-std::shared_ptr<Symbol> SymbolTable::operator[](const std::string& name) {
+const Symbol* SymbolTable::operator[](const std::string& name) {
     const auto& entry = fSymbols.find(name);
     if (entry == fSymbols.end()) {
         if (fParent) {
@@ -30,15 +30,15 @@
         return nullptr;
     }
     if (fParent) {
-        auto functions = GetFunctions(entry->second);
+        auto functions = GetFunctions(*entry->second);
         if (functions.size() > 0) {
             bool modified = false;
-            std::shared_ptr<Symbol> previous = (*fParent)[name];
+            const Symbol* previous = (*fParent)[name];
             if (previous) {
-                auto previousFunctions = GetFunctions(previous);
-                for (const std::shared_ptr<FunctionDeclaration>& prev : previousFunctions) {
+                auto previousFunctions = GetFunctions(*previous);
+                for (const FunctionDeclaration* prev : previousFunctions) {
                     bool found = false;
-                    for (const std::shared_ptr<FunctionDeclaration>& current : functions) {
+                    for (const FunctionDeclaration* current : functions) {
                         if (current->matches(*prev)) {
                             found = true;
                             break;
@@ -51,7 +51,7 @@
                 }
                 if (modified) {
                     ASSERT(functions.size() > 1);
-                    return std::shared_ptr<Symbol>(new UnresolvedFunction(functions));
+                    return this->takeOwnership(new UnresolvedFunction(functions));
                 }
             }
         }
@@ -59,27 +59,42 @@
     return entry->second;
 }
 
-void SymbolTable::add(const std::string& name, std::shared_ptr<Symbol> symbol) {
-        const auto& existing = fSymbols.find(name);
-        if (existing == fSymbols.end()) {
-            fSymbols[name] = symbol;
-        } else if (symbol->fKind == Symbol::kFunctionDeclaration_Kind) {
-            const std::shared_ptr<Symbol>& oldSymbol = existing->second;
-            if (oldSymbol->fKind == Symbol::kFunctionDeclaration_Kind) {
-                std::vector<std::shared_ptr<FunctionDeclaration>> functions;
-                functions.push_back(std::static_pointer_cast<FunctionDeclaration>(oldSymbol));
-                functions.push_back(std::static_pointer_cast<FunctionDeclaration>(symbol));
-                fSymbols[name].reset(new UnresolvedFunction(std::move(functions)));
-            } else if (oldSymbol->fKind == Symbol::kUnresolvedFunction_Kind) {
-                std::vector<std::shared_ptr<FunctionDeclaration>> functions;
-                for (const auto& f : ((UnresolvedFunction&) *oldSymbol).fFunctions) {
-                    functions.push_back(f);
-                }
-                functions.push_back(std::static_pointer_cast<FunctionDeclaration>(symbol));
-                fSymbols[name].reset(new UnresolvedFunction(std::move(functions)));
+Symbol* SymbolTable::takeOwnership(Symbol* s) {
+    fOwnedPointers.push_back(std::unique_ptr<Symbol>(s));
+    return s;
+}
+
+void SymbolTable::add(const std::string& name, std::unique_ptr<Symbol> symbol) {
+    this->addWithoutOwnership(name, symbol.get());
+    fOwnedPointers.push_back(std::move(symbol));
+}
+
+void SymbolTable::addWithoutOwnership(const std::string& name, const Symbol* symbol) {
+    const auto& existing = fSymbols.find(name);
+    if (existing == fSymbols.end()) {
+        fSymbols[name] = symbol;
+    } else if (symbol->fKind == Symbol::kFunctionDeclaration_Kind) {
+        const Symbol* oldSymbol = existing->second;
+        if (oldSymbol->fKind == Symbol::kFunctionDeclaration_Kind) {
+            std::vector<const FunctionDeclaration*> functions;
+            functions.push_back((const FunctionDeclaration*) oldSymbol);
+            functions.push_back((const FunctionDeclaration*) symbol);
+            UnresolvedFunction* u = new UnresolvedFunction(std::move(functions));
+            fSymbols[name] = u;
+            this->takeOwnership(u);
+        } else if (oldSymbol->fKind == Symbol::kUnresolvedFunction_Kind) {
+            std::vector<const FunctionDeclaration*> functions;
+            for (const auto* f : ((UnresolvedFunction&) *oldSymbol).fFunctions) {
+                functions.push_back(f);
             }
-        } else {
-            fErrorReporter.error(symbol->fPosition, "symbol '" + name + "' was already defined");
+            functions.push_back((const FunctionDeclaration*) symbol);
+            UnresolvedFunction* u = new UnresolvedFunction(std::move(functions));
+            fSymbols[name] = u;
+            this->takeOwnership(u);
         }
+    } else {
+        fErrorReporter.error(symbol->fPosition, "symbol '" + name + "' was already defined");
     }
+}
+
 } // namespace
diff --git a/src/sksl/ir/SkSLSymbolTable.h b/src/sksl/ir/SkSLSymbolTable.h
index 151475d..d732023 100644
--- a/src/sksl/ir/SkSLSymbolTable.h
+++ b/src/sksl/ir/SkSLSymbolTable.h
@@ -10,12 +10,14 @@
 
 #include <memory>
 #include <unordered_map>
+#include <vector>
 #include "SkSLErrorReporter.h"
 #include "SkSLSymbol.h"
-#include "SkSLUnresolvedFunction.h"
 
 namespace SkSL {
 
+struct FunctionDeclaration;
+
 /**
  * Maps identifiers to symbols. Functions, in particular, are mapped to either FunctionDeclaration
  * or UnresolvedFunction depending on whether they are overloaded or not.
@@ -29,17 +31,22 @@
     : fParent(parent)
     , fErrorReporter(errorReporter) {}
 
-    std::shared_ptr<Symbol> operator[](const std::string& name);
+    const Symbol* operator[](const std::string& name);
 
-    void add(const std::string& name, std::shared_ptr<Symbol> symbol);
+    void add(const std::string& name, std::unique_ptr<Symbol> symbol);
+
+    void addWithoutOwnership(const std::string& name, const Symbol* symbol);
+
+    Symbol* takeOwnership(Symbol* s);
 
     const std::shared_ptr<SymbolTable> fParent;
 
 private:
-    static std::vector<std::shared_ptr<FunctionDeclaration>> GetFunctions(
-                                                                  const std::shared_ptr<Symbol>& s);
+    static std::vector<const FunctionDeclaration*> GetFunctions(const Symbol& s);
 
-    std::unordered_map<std::string, std::shared_ptr<Symbol>> fSymbols;
+    std::vector<std::unique_ptr<Symbol>> fOwnedPointers;
+
+    std::unordered_map<std::string, const Symbol*> fSymbols;
 
     ErrorReporter& fErrorReporter;
 };
diff --git a/src/sksl/ir/SkSLType.cpp b/src/sksl/ir/SkSLType.cpp
index 27cbd39..d28c4f0 100644
--- a/src/sksl/ir/SkSLType.cpp
+++ b/src/sksl/ir/SkSLType.cpp
@@ -6,29 +6,30 @@
  */
  
 #include "SkSLType.h"
+#include "SkSLContext.h"
 
 namespace SkSL {
 
-bool Type::determineCoercionCost(std::shared_ptr<Type> other, int* outCost) const {
-    if (this == other.get()) {
+bool Type::determineCoercionCost(const Type& other, int* outCost) const {
+    if (*this == other) {
         *outCost = 0;
         return true;
     }
-    if (this->kind() == kVector_Kind && other->kind() == kVector_Kind) {
-        if (this->columns() == other->columns()) {
-            return this->componentType()->determineCoercionCost(other->componentType(), outCost);
+    if (this->kind() == kVector_Kind && other.kind() == kVector_Kind) {
+        if (this->columns() == other.columns()) {
+            return this->componentType().determineCoercionCost(other.componentType(), outCost);
         }
         return false;
     }
     if (this->kind() == kMatrix_Kind) {
-        if (this->columns() == other->columns() && 
-            this->rows() == other->rows()) {
-            return this->componentType()->determineCoercionCost(other->componentType(), outCost);
+        if (this->columns() == other.columns() && 
+            this->rows() == other.rows()) {
+            return this->componentType().determineCoercionCost(other.componentType(), outCost);
         }
         return false;
     }
     for (size_t i = 0; i < fCoercibleTypes.size(); i++) {
-        if (fCoercibleTypes[i] == other) {
+        if (*fCoercibleTypes[i] == other) {
             *outCost = (int) i + 1;
             return true;
         }
@@ -36,93 +37,93 @@
     return false;
 }
 
-std::shared_ptr<Type> Type::toCompound(int columns, int rows) {
+const Type& Type::toCompound(const Context& context, int columns, int rows) const {
     ASSERT(this->kind() == Type::kScalar_Kind);
     if (columns == 1 && rows == 1) {
-        return std::shared_ptr<Type>(this);
+        return *this;
     }
-    if (*this == *kFloat_Type) {
+    if (*this == *context.fFloat_Type) {
         switch (rows) {
             case 1:
                 switch (columns) {
-                    case 2: return kVec2_Type;
-                    case 3: return kVec3_Type;
-                    case 4: return kVec4_Type;
+                    case 2: return *context.fVec2_Type;
+                    case 3: return *context.fVec3_Type;
+                    case 4: return *context.fVec4_Type;
                     default: ABORT("unsupported vector column count (%d)", columns);
                 }
             case 2:
                 switch (columns) {
-                    case 2: return kMat2x2_Type;
-                    case 3: return kMat3x2_Type;
-                    case 4: return kMat4x2_Type;
+                    case 2: return *context.fMat2x2_Type;
+                    case 3: return *context.fMat3x2_Type;
+                    case 4: return *context.fMat4x2_Type;
                     default: ABORT("unsupported matrix column count (%d)", columns);
                 }
             case 3:
                 switch (columns) {
-                    case 2: return kMat2x3_Type;
-                    case 3: return kMat3x3_Type;
-                    case 4: return kMat4x3_Type;
+                    case 2: return *context.fMat2x3_Type;
+                    case 3: return *context.fMat3x3_Type;
+                    case 4: return *context.fMat4x3_Type;
                     default: ABORT("unsupported matrix column count (%d)", columns);
                 }
             case 4:
                 switch (columns) {
-                    case 2: return kMat2x4_Type;
-                    case 3: return kMat3x4_Type;
-                    case 4: return kMat4x4_Type;
+                    case 2: return *context.fMat2x4_Type;
+                    case 3: return *context.fMat3x4_Type;
+                    case 4: return *context.fMat4x4_Type;
                     default: ABORT("unsupported matrix column count (%d)", columns);
                 }
             default: ABORT("unsupported row count (%d)", rows);
         }
-    } else if (*this == *kDouble_Type) {
+    } else if (*this == *context.fDouble_Type) {
         switch (rows) {
             case 1:
                 switch (columns) {
-                    case 2: return kDVec2_Type;
-                    case 3: return kDVec3_Type;
-                    case 4: return kDVec4_Type;
+                    case 2: return *context.fDVec2_Type;
+                    case 3: return *context.fDVec3_Type;
+                    case 4: return *context.fDVec4_Type;
                     default: ABORT("unsupported vector column count (%d)", columns);
                 }
             case 2:
                 switch (columns) {
-                    case 2: return kDMat2x2_Type;
-                    case 3: return kDMat3x2_Type;
-                    case 4: return kDMat4x2_Type;
+                    case 2: return *context.fDMat2x2_Type;
+                    case 3: return *context.fDMat3x2_Type;
+                    case 4: return *context.fDMat4x2_Type;
                     default: ABORT("unsupported matrix column count (%d)", columns);
                 }
             case 3:
                 switch (columns) {
-                    case 2: return kDMat2x3_Type;
-                    case 3: return kDMat3x3_Type;
-                    case 4: return kDMat4x3_Type;
+                    case 2: return *context.fDMat2x3_Type;
+                    case 3: return *context.fDMat3x3_Type;
+                    case 4: return *context.fDMat4x3_Type;
                     default: ABORT("unsupported matrix column count (%d)", columns);
                 }
             case 4:
                 switch (columns) {
-                    case 2: return kDMat2x4_Type;
-                    case 3: return kDMat3x4_Type;
-                    case 4: return kDMat4x4_Type;
+                    case 2: return *context.fDMat2x4_Type;
+                    case 3: return *context.fDMat3x4_Type;
+                    case 4: return *context.fDMat4x4_Type;
                     default: ABORT("unsupported matrix column count (%d)", columns);
                 }
             default: ABORT("unsupported row count (%d)", rows);
         }
-    } else if (*this == *kInt_Type) {
+    } else if (*this == *context.fInt_Type) {
         switch (rows) {
             case 1:
                 switch (columns) {
-                    case 2: return kIVec2_Type;
-                    case 3: return kIVec3_Type;
-                    case 4: return kIVec4_Type;
+                    case 2: return *context.fIVec2_Type;
+                    case 3: return *context.fIVec3_Type;
+                    case 4: return *context.fIVec4_Type;
                     default: ABORT("unsupported vector column count (%d)", columns);
                 }
             default: ABORT("unsupported row count (%d)", rows);
         }
-    } else if (*this == *kUInt_Type) {
+    } else if (*this == *context.fUInt_Type) {
         switch (rows) {
             case 1:
                 switch (columns) {
-                    case 2: return kUVec2_Type;
-                    case 3: return kUVec3_Type;
-                    case 4: return kUVec4_Type;
+                    case 2: return *context.fUVec2_Type;
+                    case 3: return *context.fUVec3_Type;
+                    case 4: return *context.fUVec4_Type;
                     default: ABORT("unsupported vector column count (%d)", columns);
                 }
             default: ABORT("unsupported row count (%d)", rows);
@@ -131,128 +132,4 @@
     ABORT("unsupported scalar_to_compound type %s", this->description().c_str());
 }
 
-const std::shared_ptr<Type> kVoid_Type(new Type("void"));
-
-const std::shared_ptr<Type> kDouble_Type(new Type("double", true));
-const std::shared_ptr<Type> kDVec2_Type(new Type("dvec2", kDouble_Type, 2));
-const std::shared_ptr<Type> kDVec3_Type(new Type("dvec3", kDouble_Type, 3));
-const std::shared_ptr<Type> kDVec4_Type(new Type("dvec4", kDouble_Type, 4));
-
-const std::shared_ptr<Type> kFloat_Type(new Type("float", true, { kDouble_Type }));
-const std::shared_ptr<Type> kVec2_Type(new Type("vec2", kFloat_Type, 2));
-const std::shared_ptr<Type> kVec3_Type(new Type("vec3", kFloat_Type, 3));
-const std::shared_ptr<Type> kVec4_Type(new Type("vec4", kFloat_Type, 4));
-
-const std::shared_ptr<Type> kUInt_Type(new Type("uint", true, { kFloat_Type, kDouble_Type }));
-const std::shared_ptr<Type> kUVec2_Type(new Type("uvec2", kUInt_Type, 2));
-const std::shared_ptr<Type> kUVec3_Type(new Type("uvec3", kUInt_Type, 3));
-const std::shared_ptr<Type> kUVec4_Type(new Type("uvec4", kUInt_Type, 4));
-
-const std::shared_ptr<Type> kInt_Type(new Type("int", true, { kUInt_Type, kFloat_Type, 
-                                                              kDouble_Type }));
-const std::shared_ptr<Type> kIVec2_Type(new Type("ivec2", kInt_Type, 2));
-const std::shared_ptr<Type> kIVec3_Type(new Type("ivec3", kInt_Type, 3));
-const std::shared_ptr<Type> kIVec4_Type(new Type("ivec4", kInt_Type, 4));
-
-const std::shared_ptr<Type> kBool_Type(new Type("bool", false));
-const std::shared_ptr<Type> kBVec2_Type(new Type("bvec2", kBool_Type, 2));
-const std::shared_ptr<Type> kBVec3_Type(new Type("bvec3", kBool_Type, 3));
-const std::shared_ptr<Type> kBVec4_Type(new Type("bvec4", kBool_Type, 4));
-
-const std::shared_ptr<Type> kMat2x2_Type(new Type("mat2",   kFloat_Type, 2, 2));
-const std::shared_ptr<Type> kMat2x3_Type(new Type("mat2x3", kFloat_Type, 2, 3));
-const std::shared_ptr<Type> kMat2x4_Type(new Type("mat2x4", kFloat_Type, 2, 4));
-const std::shared_ptr<Type> kMat3x2_Type(new Type("mat3x2", kFloat_Type, 3, 2));
-const std::shared_ptr<Type> kMat3x3_Type(new Type("mat3",   kFloat_Type, 3, 3));
-const std::shared_ptr<Type> kMat3x4_Type(new Type("mat3x4", kFloat_Type, 3, 4));
-const std::shared_ptr<Type> kMat4x2_Type(new Type("mat4x2", kFloat_Type, 4, 2));
-const std::shared_ptr<Type> kMat4x3_Type(new Type("mat4x3", kFloat_Type, 4, 3));
-const std::shared_ptr<Type> kMat4x4_Type(new Type("mat4",   kFloat_Type, 4, 4));
-
-const std::shared_ptr<Type> kDMat2x2_Type(new Type("dmat2",   kFloat_Type, 2, 2));
-const std::shared_ptr<Type> kDMat2x3_Type(new Type("dmat2x3", kFloat_Type, 2, 3));
-const std::shared_ptr<Type> kDMat2x4_Type(new Type("dmat2x4", kFloat_Type, 2, 4));
-const std::shared_ptr<Type> kDMat3x2_Type(new Type("dmat3x2", kFloat_Type, 3, 2));
-const std::shared_ptr<Type> kDMat3x3_Type(new Type("dmat3",   kFloat_Type, 3, 3));
-const std::shared_ptr<Type> kDMat3x4_Type(new Type("dmat3x4", kFloat_Type, 3, 4));
-const std::shared_ptr<Type> kDMat4x2_Type(new Type("dmat4x2", kFloat_Type, 4, 2));
-const std::shared_ptr<Type> kDMat4x3_Type(new Type("dmat4x3", kFloat_Type, 4, 3));
-const std::shared_ptr<Type> kDMat4x4_Type(new Type("dmat4",   kFloat_Type, 4, 4));
-
-const std::shared_ptr<Type> kSampler1D_Type(new Type("sampler1D", SpvDim1D, false, false, false, true));
-const std::shared_ptr<Type> kSampler2D_Type(new Type("sampler2D", SpvDim2D, false, false, false, true));
-const std::shared_ptr<Type> kSampler3D_Type(new Type("sampler3D", SpvDim3D, false, false, false, true));
-const std::shared_ptr<Type> kSamplerCube_Type(new Type("samplerCube"));
-const std::shared_ptr<Type> kSampler2DRect_Type(new Type("sampler2DRect"));
-const std::shared_ptr<Type> kSampler1DArray_Type(new Type("sampler1DArray"));
-const std::shared_ptr<Type> kSampler2DArray_Type(new Type("sampler2DArray"));
-const std::shared_ptr<Type> kSamplerCubeArray_Type(new Type("samplerCubeArray"));
-const std::shared_ptr<Type> kSamplerBuffer_Type(new Type("samplerBuffer"));
-const std::shared_ptr<Type> kSampler2DMS_Type(new Type("sampler2DMS"));
-const std::shared_ptr<Type> kSampler2DMSArray_Type(new Type("sampler2DMSArray"));
-const std::shared_ptr<Type> kSampler1DShadow_Type(new Type("sampler1DShadow"));
-const std::shared_ptr<Type> kSampler2DShadow_Type(new Type("sampler2DShadow"));
-const std::shared_ptr<Type> kSamplerCubeShadow_Type(new Type("samplerCubeShadow"));
-const std::shared_ptr<Type> kSampler2DRectShadow_Type(new Type("sampler2DRectShadow"));
-const std::shared_ptr<Type> kSampler1DArrayShadow_Type(new Type("sampler1DArrayShadow"));
-const std::shared_ptr<Type> kSampler2DArrayShadow_Type(new Type("sampler2DArrayShadow"));
-const std::shared_ptr<Type> kSamplerCubeArrayShadow_Type(new Type("samplerCubeArrayShadow"));
-
-static std::vector<std::shared_ptr<Type>> type(std::shared_ptr<Type> t) {
-    return { t, t, t, t };   
-}
-
-// FIXME figure out what we're supposed to do with the gsampler et al. types
-const std::shared_ptr<Type> kGSampler1D_Type(new Type("$gsampler1D", type(kSampler1D_Type)));
-const std::shared_ptr<Type> kGSampler2D_Type(new Type("$gsampler2D", type(kSampler2D_Type)));
-const std::shared_ptr<Type> kGSampler3D_Type(new Type("$gsampler3D", type(kSampler3D_Type)));
-const std::shared_ptr<Type> kGSamplerCube_Type(new Type("$gsamplerCube", type(kSamplerCube_Type)));
-const std::shared_ptr<Type> kGSampler2DRect_Type(new Type("$gsampler2DRect", 
-                                                 type(kSampler2DRect_Type)));
-const std::shared_ptr<Type> kGSampler1DArray_Type(new Type("$gsampler1DArray", 
-                                                  type(kSampler1DArray_Type)));
-const std::shared_ptr<Type> kGSampler2DArray_Type(new Type("$gsampler2DArray", 
-                                                  type(kSampler2DArray_Type)));
-const std::shared_ptr<Type> kGSamplerCubeArray_Type(new Type("$gsamplerCubeArray", 
-                                                    type(kSamplerCubeArray_Type)));
-const std::shared_ptr<Type> kGSamplerBuffer_Type(new Type("$gsamplerBuffer", 
-                                                 type(kSamplerBuffer_Type)));
-const std::shared_ptr<Type> kGSampler2DMS_Type(new Type("$gsampler2DMS", 
-                                               type(kSampler2DMS_Type)));
-const std::shared_ptr<Type> kGSampler2DMSArray_Type(new Type("$gsampler2DMSArray", 
-                                                    type(kSampler2DMSArray_Type)));
-const std::shared_ptr<Type> kGSampler2DArrayShadow_Type(new Type("$gsampler2DArrayShadow", 
-                                                        type(kSampler2DArrayShadow_Type)));
-const std::shared_ptr<Type> kGSamplerCubeArrayShadow_Type(new Type("$gsamplerCubeArrayShadow", 
-                                                          type(kSamplerCubeArrayShadow_Type)));
-
-const std::shared_ptr<Type> kGenType_Type(new Type("$genType", { kFloat_Type, kVec2_Type,
-                                                                 kVec3_Type, kVec4_Type }));
-const std::shared_ptr<Type> kGenDType_Type(new Type("$genDType", { kDouble_Type, kDVec2_Type,
-                                                                   kDVec3_Type, kDVec4_Type }));
-const std::shared_ptr<Type> kGenIType_Type(new Type("$genIType", { kInt_Type, kIVec2_Type,
-                                                                   kIVec3_Type, kIVec4_Type }));
-const std::shared_ptr<Type> kGenUType_Type(new Type("$genUType", { kUInt_Type, kUVec2_Type,
-                                                                   kUVec3_Type, kUVec4_Type }));
-const std::shared_ptr<Type> kGenBType_Type(new Type("$genBType", { kBool_Type, kBVec2_Type,
-                                                                   kBVec3_Type, kBVec4_Type }));
-
-const std::shared_ptr<Type> kMat_Type(new Type("$mat"));
-
-const std::shared_ptr<Type> kVec_Type(new Type("$vec", { kVec2_Type, kVec2_Type, kVec3_Type, 
-                                                         kVec4_Type }));
-
-const std::shared_ptr<Type> kGVec_Type(new Type("$gvec"));
-const std::shared_ptr<Type> kGVec2_Type(new Type("$gvec2"));
-const std::shared_ptr<Type> kGVec3_Type(new Type("$gvec3"));
-const std::shared_ptr<Type> kGVec4_Type(new Type("$gvec4", type(kVec4_Type)));
-const std::shared_ptr<Type> kDVec_Type(new Type("$dvec"));
-const std::shared_ptr<Type> kIVec_Type(new Type("$ivec"));
-const std::shared_ptr<Type> kUVec_Type(new Type("$uvec"));
-
-const std::shared_ptr<Type> kBVec_Type(new Type("$bvec", { kBVec2_Type, kBVec2_Type,
-                                                           kBVec3_Type, kBVec4_Type }));
-
-const std::shared_ptr<Type> kInvalid_Type(new Type("<INVALID>"));
-
 } // namespace
diff --git a/src/sksl/ir/SkSLType.h b/src/sksl/ir/SkSLType.h
index e17bae6..a929c8e 100644
--- a/src/sksl/ir/SkSLType.h
+++ b/src/sksl/ir/SkSLType.h
@@ -18,24 +18,26 @@
 
 namespace SkSL {
 
+class Context;
+
 /**
  * Represents a type, such as int or vec4.
  */
 class Type : public Symbol {
 public:
     struct Field {
-        Field(Modifiers modifiers, std::string name, std::shared_ptr<Type> type)
+        Field(Modifiers modifiers, std::string name, const Type& type)
         : fModifiers(modifiers)
         , fName(std::move(name))
         , fType(std::move(type)) {}
 
-        const std::string description() {
-            return fType->description() + " " + fName + ";";
+        const std::string description() const {
+            return fType.description() + " " + fName + ";";
         }
 
         const Modifiers fModifiers;
         const std::string fName;
-        const std::shared_ptr<Type> fType;
+        const Type& fType;
     };
 
     enum Kind {
@@ -56,7 +58,7 @@
     , fTypeKind(kOther_Kind) {}
 
     // Create a generic type which maps to the listed types.
-    Type(std::string name, std::vector<std::shared_ptr<Type>> types)
+    Type(std::string name, std::vector<const Type*> types)
     : INHERITED(Position(), kType_Kind, std::move(name))
     , fTypeKind(kGeneric_Kind)
     , fCoercibleTypes(std::move(types)) {
@@ -78,7 +80,7 @@
     , fRows(1) {}
 
     // Create a scalar type which can be coerced to the listed types.
-    Type(std::string name, bool isNumber, std::vector<std::shared_ptr<Type>> coercibleTypes)
+    Type(std::string name, bool isNumber, std::vector<const Type*> coercibleTypes)
     : INHERITED(Position(), kType_Kind, std::move(name))
     , fTypeKind(kScalar_Kind)
     , fIsNumber(isNumber)
@@ -87,23 +89,23 @@
     , fRows(1) {}
 
     // Create a vector type.
-    Type(std::string name, std::shared_ptr<Type> componentType, int columns)
+    Type(std::string name, const Type& componentType, int columns)
     : Type(name, kVector_Kind, componentType, columns) {}
 
     // Create a vector or array type.
-    Type(std::string name, Kind kind, std::shared_ptr<Type> componentType, int columns)
+    Type(std::string name, Kind kind, const Type& componentType, int columns)
     : INHERITED(Position(), kType_Kind, std::move(name))
     , fTypeKind(kind)
-    , fComponentType(std::move(componentType))
+    , fComponentType(&componentType)
     , fColumns(columns)
     , fRows(1)    
     , fDimensions(SpvDim1D) {}
 
     // Create a matrix type.
-    Type(std::string name, std::shared_ptr<Type> componentType, int columns, int rows)
+    Type(std::string name, const Type& componentType, int columns, int rows)
     : INHERITED(Position(), kType_Kind, std::move(name))
     , fTypeKind(kMatrix_Kind)
-    , fComponentType(std::move(componentType))
+    , fComponentType(&componentType)
     , fColumns(columns)
     , fRows(rows)    
     , fDimensions(SpvDim1D) {}
@@ -153,7 +155,7 @@
      * Returns true if an instance of this type can be freely coerced (implicitly converted) to 
      * another type.
      */
-    bool canCoerceTo(std::shared_ptr<Type> other) const {
+    bool canCoerceTo(const Type& other) const {
         int cost;
         return determineCoercionCost(other, &cost);
     }
@@ -164,15 +166,15 @@
      * costs. Returns true if a conversion is possible, false otherwise. The value of the out 
      * parameter is undefined if false is returned.
      */
-    bool determineCoercionCost(std::shared_ptr<Type> other, int* outCost) const;
+    bool determineCoercionCost(const Type& other, int* outCost) const;
 
     /**
      * For matrices and vectors, returns the type of individual cells (e.g. mat2 has a component
      * type of kFloat_Type). For all other types, causes an assertion failure.
      */
-    std::shared_ptr<Type> componentType() const {
+    const Type& componentType() const {
         ASSERT(fComponentType);
-        return fComponentType;
+        return *fComponentType;
     }
 
     /**
@@ -195,7 +197,7 @@
         return fRows;
     }
 
-    std::vector<Field> fields() const {
+    const std::vector<Field>& fields() const {
         ASSERT(fTypeKind == kStruct_Kind);
         return fFields;
     }
@@ -204,7 +206,7 @@
      * For generic types, returns the types that this generic type can substitute for. For other
      * types, returns a list of other types that this type can be coerced into.
      */
-    std::vector<std::shared_ptr<Type>> coercibleTypes() const {
+    const std::vector<const Type*>& coercibleTypes() const {
         ASSERT(fCoercibleTypes.size() > 0);
         return fCoercibleTypes;
     }
@@ -257,7 +259,7 @@
             case kStruct_Kind: {
                 size_t result = 16;
                 for (size_t i = 0; i < fFields.size(); i++) {
-                    size_t alignment = fFields[i].fType->alignment();
+                    size_t alignment = fFields[i].fType.alignment();
                     if (alignment > result) {
                         result = alignment;
                     }
@@ -300,13 +302,13 @@
             case kStruct_Kind: {
                 size_t total = 0;
                 for (size_t i = 0; i < fFields.size(); i++) {
-                    size_t alignment = fFields[i].fType->alignment();
+                    size_t alignment = fFields[i].fType.alignment();
                     if (total % alignment != 0) {
                         total += alignment - total % alignment;
                     }
                     ASSERT(false);
                     ASSERT(total % alignment == 0);
-                    total += fFields[i].fType->size();
+                    total += fFields[i].fType.size();
                 }
                 return total;
             }
@@ -319,15 +321,15 @@
      * Returns the corresponding vector or matrix type with the specified number of columns and 
      * rows.
      */
-    std::shared_ptr<Type> toCompound(int columns, int rows);
+    const Type& toCompound(const Context& context, int columns, int rows) const;
 
 private:
     typedef Symbol INHERITED;
 
     const Kind fTypeKind;
     const bool fIsNumber = false;
-    const std::shared_ptr<Type> fComponentType = nullptr;
-    const std::vector<std::shared_ptr<Type>> fCoercibleTypes = { };
+    const Type* fComponentType = nullptr;
+    const std::vector<const Type*> fCoercibleTypes = { };
     const int fColumns = -1;
     const int fRows = -1;
     const std::vector<Field> fFields = { };
@@ -338,101 +340,6 @@
     const bool fIsSampled = false;
 };
 
-extern const std::shared_ptr<Type> kVoid_Type;
-
-extern const std::shared_ptr<Type> kFloat_Type;
-extern const std::shared_ptr<Type> kVec2_Type;
-extern const std::shared_ptr<Type> kVec3_Type;
-extern const std::shared_ptr<Type> kVec4_Type;
-extern const std::shared_ptr<Type> kDouble_Type;
-extern const std::shared_ptr<Type> kDVec2_Type;
-extern const std::shared_ptr<Type> kDVec3_Type;
-extern const std::shared_ptr<Type> kDVec4_Type;
-extern const std::shared_ptr<Type> kInt_Type;
-extern const std::shared_ptr<Type> kIVec2_Type;
-extern const std::shared_ptr<Type> kIVec3_Type;
-extern const std::shared_ptr<Type> kIVec4_Type;
-extern const std::shared_ptr<Type> kUInt_Type;
-extern const std::shared_ptr<Type> kUVec2_Type;
-extern const std::shared_ptr<Type> kUVec3_Type;
-extern const std::shared_ptr<Type> kUVec4_Type;
-extern const std::shared_ptr<Type> kBool_Type;
-extern const std::shared_ptr<Type> kBVec2_Type;
-extern const std::shared_ptr<Type> kBVec3_Type;
-extern const std::shared_ptr<Type> kBVec4_Type;
-
-extern const std::shared_ptr<Type> kMat2x2_Type;
-extern const std::shared_ptr<Type> kMat2x3_Type;
-extern const std::shared_ptr<Type> kMat2x4_Type;
-extern const std::shared_ptr<Type> kMat3x2_Type;
-extern const std::shared_ptr<Type> kMat3x3_Type;
-extern const std::shared_ptr<Type> kMat3x4_Type;
-extern const std::shared_ptr<Type> kMat4x2_Type;
-extern const std::shared_ptr<Type> kMat4x3_Type;
-extern const std::shared_ptr<Type> kMat4x4_Type;
-
-extern const std::shared_ptr<Type> kDMat2x2_Type;
-extern const std::shared_ptr<Type> kDMat2x3_Type;
-extern const std::shared_ptr<Type> kDMat2x4_Type;
-extern const std::shared_ptr<Type> kDMat3x2_Type;
-extern const std::shared_ptr<Type> kDMat3x3_Type;
-extern const std::shared_ptr<Type> kDMat3x4_Type;
-extern const std::shared_ptr<Type> kDMat4x2_Type;
-extern const std::shared_ptr<Type> kDMat4x3_Type;
-extern const std::shared_ptr<Type> kDMat4x4_Type;
-
-extern const std::shared_ptr<Type> kSampler1D_Type;
-extern const std::shared_ptr<Type> kSampler2D_Type;
-extern const std::shared_ptr<Type> kSampler3D_Type;
-extern const std::shared_ptr<Type> kSamplerCube_Type;
-extern const std::shared_ptr<Type> kSampler2DRect_Type;
-extern const std::shared_ptr<Type> kSampler1DArray_Type;
-extern const std::shared_ptr<Type> kSampler2DArray_Type;
-extern const std::shared_ptr<Type> kSamplerCubeArray_Type;
-extern const std::shared_ptr<Type> kSamplerBuffer_Type;
-extern const std::shared_ptr<Type> kSampler2DMS_Type;
-extern const std::shared_ptr<Type> kSampler2DMSArray_Type;
-
-extern const std::shared_ptr<Type> kGSampler1D_Type;
-extern const std::shared_ptr<Type> kGSampler2D_Type;
-extern const std::shared_ptr<Type> kGSampler3D_Type;
-extern const std::shared_ptr<Type> kGSamplerCube_Type;
-extern const std::shared_ptr<Type> kGSampler2DRect_Type;
-extern const std::shared_ptr<Type> kGSampler1DArray_Type;
-extern const std::shared_ptr<Type> kGSampler2DArray_Type;
-extern const std::shared_ptr<Type> kGSamplerCubeArray_Type;
-extern const std::shared_ptr<Type> kGSamplerBuffer_Type;
-extern const std::shared_ptr<Type> kGSampler2DMS_Type;
-extern const std::shared_ptr<Type> kGSampler2DMSArray_Type;
-
-extern const std::shared_ptr<Type> kSampler1DShadow_Type;
-extern const std::shared_ptr<Type> kSampler2DShadow_Type;
-extern const std::shared_ptr<Type> kSamplerCubeShadow_Type;
-extern const std::shared_ptr<Type> kSampler2DRectShadow_Type;
-extern const std::shared_ptr<Type> kSampler1DArrayShadow_Type;
-extern const std::shared_ptr<Type> kSampler2DArrayShadow_Type;
-extern const std::shared_ptr<Type> kSamplerCubeArrayShadow_Type;
-extern const std::shared_ptr<Type> kGSampler2DArrayShadow_Type;
-extern const std::shared_ptr<Type> kGSamplerCubeArrayShadow_Type;
-
-extern const std::shared_ptr<Type> kGenType_Type;
-extern const std::shared_ptr<Type> kGenDType_Type;
-extern const std::shared_ptr<Type> kGenIType_Type;
-extern const std::shared_ptr<Type> kGenUType_Type;
-extern const std::shared_ptr<Type> kGenBType_Type;
-extern const std::shared_ptr<Type> kMat_Type;
-extern const std::shared_ptr<Type> kVec_Type;
-extern const std::shared_ptr<Type> kGVec_Type;
-extern const std::shared_ptr<Type> kGVec2_Type;
-extern const std::shared_ptr<Type> kGVec3_Type;
-extern const std::shared_ptr<Type> kGVec4_Type;
-extern const std::shared_ptr<Type> kDVec_Type;
-extern const std::shared_ptr<Type> kIVec_Type;
-extern const std::shared_ptr<Type> kUVec_Type;
-extern const std::shared_ptr<Type> kBVec_Type;
-
-extern const std::shared_ptr<Type> kInvalid_Type;
-
 } // namespace
 
 #endif
diff --git a/src/sksl/ir/SkSLTypeReference.h b/src/sksl/ir/SkSLTypeReference.h
index 5f4990f..76923aa 100644
--- a/src/sksl/ir/SkSLTypeReference.h
+++ b/src/sksl/ir/SkSLTypeReference.h
@@ -8,6 +8,7 @@
 #ifndef SKSL_TYPEREFERENCE
 #define SKSL_TYPEREFERENCE
 
+#include "SkSLContext.h"
 #include "SkSLExpression.h"
 
 namespace SkSL {
@@ -17,16 +18,16 @@
  * always eventually replaced by Constructors in valid programs.
  */
 struct TypeReference : public Expression {
-    TypeReference(Position position, std::shared_ptr<Type> type)
-    : INHERITED(position, kTypeReference_Kind, kInvalid_Type)
-    , fValue(std::move(type)) {}
+    TypeReference(const Context& context, Position position, const Type& type)
+    : INHERITED(position, kTypeReference_Kind, *context.fInvalid_Type)
+    , fValue(type) {}
 
     std::string description() const override {
     	ASSERT(false);
     	return "<type>";
     }
 
-    const std::shared_ptr<Type> fValue;
+    const Type& fValue;
 
     typedef Expression INHERITED;
 };
diff --git a/src/sksl/ir/SkSLUnresolvedFunction.h b/src/sksl/ir/SkSLUnresolvedFunction.h
index a6cee0d..3a368ad 100644
--- a/src/sksl/ir/SkSLUnresolvedFunction.h
+++ b/src/sksl/ir/SkSLUnresolvedFunction.h
@@ -16,19 +16,21 @@
  * A symbol representing multiple functions with the same name.
  */
 struct UnresolvedFunction : public Symbol {
-    UnresolvedFunction(std::vector<std::shared_ptr<FunctionDeclaration>> funcs)
+    UnresolvedFunction(std::vector<const FunctionDeclaration*> funcs)
     : INHERITED(Position(), kUnresolvedFunction_Kind, funcs[0]->fName)
     , fFunctions(std::move(funcs)) {
+#ifdef DEBUG
     	for (auto func : funcs) {
     		ASSERT(func->fName == fName);
     	}
+#endif
     }
 
     virtual std::string description() const override {
         return fName;
     }
 
-    const std::vector<std::shared_ptr<FunctionDeclaration>> fFunctions;
+    const std::vector<const FunctionDeclaration*> fFunctions;
 
     typedef Symbol INHERITED;
 };
diff --git a/src/sksl/ir/SkSLVarDeclaration.h b/src/sksl/ir/SkSLVarDeclaration.h
index 400f430..b234231 100644
--- a/src/sksl/ir/SkSLVarDeclaration.h
+++ b/src/sksl/ir/SkSLVarDeclaration.h
@@ -20,7 +20,7 @@
  * names ['x', 'y', 'z'], sizes of [[], [], [4, 2]], and values of [null, 1, null].
  */
 struct VarDeclaration : public ProgramElement {
-    VarDeclaration(Position position, std::vector<std::shared_ptr<Variable>> vars,
+    VarDeclaration(Position position, std::vector<const Variable*> vars,
                    std::vector<std::vector<std::unique_ptr<Expression>>> sizes,
                    std::vector<std::unique_ptr<Expression>> values)
     : INHERITED(position, kVar_Kind) 
@@ -30,9 +30,9 @@
 
     std::string description() const override {
         std::string result = fVars[0]->fModifiers.description();
-        std::shared_ptr<Type> baseType = fVars[0]->fType;
+        const Type* baseType = &fVars[0]->fType;
         while (baseType->kind() == Type::kArray_Kind) {
-            baseType = baseType->componentType();
+            baseType = &baseType->componentType();
         }
         result += baseType->description();
         std::string separator = " ";
@@ -55,7 +55,7 @@
         return result;
     }
 
-    const std::vector<std::shared_ptr<Variable>> fVars;
+    const std::vector<const Variable*> fVars;
     const std::vector<std::vector<std::unique_ptr<Expression>>> fSizes;
     const std::vector<std::unique_ptr<Expression>> fValues;
 
diff --git a/src/sksl/ir/SkSLVariable.h b/src/sksl/ir/SkSLVariable.h
index d4ea2c4..39af309 100644
--- a/src/sksl/ir/SkSLVariable.h
+++ b/src/sksl/ir/SkSLVariable.h
@@ -27,7 +27,7 @@
         kParameter_Storage
     };
 
-    Variable(Position position, Modifiers modifiers, std::string name, std::shared_ptr<Type> type,
+    Variable(Position position, Modifiers modifiers, std::string name, const Type& type,
              Storage storage)
     : INHERITED(position, kVariable_Kind, std::move(name))
     , fModifiers(modifiers)
@@ -37,12 +37,11 @@
     , fIsWrittenTo(false) {}
 
     virtual std::string description() const override {
-        return fModifiers.description() + fType->fName + " " + fName;
+        return fModifiers.description() + fType.fName + " " + fName;
     }
 
     const Modifiers fModifiers;
-    const std::string fValue;
-    const std::shared_ptr<Type> fType;
+    const Type& fType;
     const Storage fStorage;
 
     mutable bool fIsReadFrom;
@@ -53,14 +52,4 @@
 
 } // namespace SkSL
 
-namespace std {
-    template <>
-        struct hash<SkSL::Variable> {
-        public :
-        size_t operator()(const SkSL::Variable &var) const{
-            return hash<std::string>()(var.fName) ^ hash<std::string>()(var.fType->description());
-        }
-    };
-} // namespace std
-
 #endif
diff --git a/src/sksl/ir/SkSLVariableReference.h b/src/sksl/ir/SkSLVariableReference.h
index 8499511..b443da1 100644
--- a/src/sksl/ir/SkSLVariableReference.h
+++ b/src/sksl/ir/SkSLVariableReference.h
@@ -20,15 +20,15 @@
  * there is only one Variable 'x', but two VariableReferences to it.
  */
 struct VariableReference : public Expression {
-    VariableReference(Position position, std::shared_ptr<Variable> variable)
-    : INHERITED(position, kVariableReference_Kind, variable->fType)
-    , fVariable(std::move(variable)) {}
+    VariableReference(Position position, const Variable& variable)
+    : INHERITED(position, kVariableReference_Kind, variable.fType)
+    , fVariable(variable) {}
 
     std::string description() const override {
-        return fVariable->fName;
+        return fVariable.fName;
     }
 
-    const std::shared_ptr<Variable> fVariable;
+    const Variable& fVariable;
 
     typedef Expression INHERITED;
 };