ByteCode: Separate uniforms and globals

This requires new instructions, but means that uniforms don't need
to be copied/expanded into the globals array. It also removes any
limit on the number of uniforms (other than instruction encoding),
and simplifies the memory layout (no need for slot tracking).

To help with this, added a Location struct that encapsulates the
information returned by the two variants of getLocation.

Change-Id: I961be74ea5fdf933da6c7ad284be9fc345cfd909
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/245358
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/SkSLByteCode.cpp b/src/sksl/SkSLByteCode.cpp
index 1f0841f..b03e2ee 100644
--- a/src/sksl/SkSLByteCode.cpp
+++ b/src/sksl/SkSLByteCode.cpp
@@ -104,6 +104,10 @@
         case ByteCodeInstruction::kLoadGlobal2: printf("loadglobal2 %d", READ16() >> 8); break;
         case ByteCodeInstruction::kLoadGlobal3: printf("loadglobal3 %d", READ16() >> 8); break;
         case ByteCodeInstruction::kLoadGlobal4: printf("loadglobal4 %d", READ16() >> 8); break;
+        case ByteCodeInstruction::kLoadUniform: printf("loaduniform %d", READ16() >> 8); break;
+        case ByteCodeInstruction::kLoadUniform2: printf("loaduniform2 %d", READ16() >> 8); break;
+        case ByteCodeInstruction::kLoadUniform3: printf("loaduniform3 %d", READ16() >> 8); break;
+        case ByteCodeInstruction::kLoadUniform4: printf("loaduniform4 %d", READ16() >> 8); break;
         case ByteCodeInstruction::kLoadSwizzle: {
             int target = READ8();
             int count = READ8();
@@ -122,9 +126,20 @@
             }
             break;
         }
+        case ByteCodeInstruction::kLoadSwizzleUniform: {
+            int target = READ8();
+            int count = READ8();
+            printf("loadswizzleuniform %d %d", target, count);
+            for (int i = 0; i < count; ++i) {
+                printf(", %d", READ8());
+            }
+            break;
+        }
         case ByteCodeInstruction::kLoadExtended: printf("loadextended %d", READ8()); break;
         case ByteCodeInstruction::kLoadExtendedGlobal: printf("loadextendedglobal %d", READ8());
             break;
+        case ByteCodeInstruction::kLoadExtendedUniform: printf("loadextendeduniform %d", READ8());
+            break;
         case ByteCodeInstruction::kMatrixToMatrix: {
             int srcCols = READ8();
             int srcRows = READ8();
@@ -537,8 +552,8 @@
 }
 
 static bool InnerRun(const ByteCode* byteCode, const ByteCodeFunction* f, VValue* stack,
-                     float* outReturn[], VValue globals[], bool stripedOutput, int N,
-                     int baseIndex) {
+                     float* outReturn[], VValue globals[], const float uniforms[],
+                     bool stripedOutput, int N, int baseIndex) {
 #ifdef SKSLC_THREADED_CODE
     static const void* labels[] = {
         // If you aren't familiar with it, the &&label syntax is the GCC / Clang "labels as values"
@@ -580,10 +595,13 @@
         &&kInverse4x4,
         VECTOR_LABELS(kLoad),
         VECTOR_LABELS(kLoadGlobal),
+        VECTOR_LABELS(kLoadUniform),
         &&kLoadSwizzle,
         &&kLoadSwizzleGlobal,
+        &&kLoadSwizzleUniform,
         &&kLoadExtended,
         &&kLoadExtendedGlobal,
+        &&kLoadExtendedUniform,
         &&kMatrixToMatrix,
         &&kMatrixMultiply,
         VECTOR_MATRIX_LABELS(kNegateF),
@@ -671,10 +689,13 @@
     CHECK_LABEL(kInverse4x4);
     CHECK_VECTOR_LABELS(kLoad);
     CHECK_VECTOR_LABELS(kLoadGlobal);
+    CHECK_VECTOR_LABELS(kLoadUniform);
     CHECK_LABEL(kLoadSwizzle);
     CHECK_LABEL(kLoadSwizzleGlobal);
+    CHECK_LABEL(kLoadSwizzleUniform);
     CHECK_LABEL(kLoadExtended);
     CHECK_LABEL(kLoadExtendedGlobal);
+    CHECK_LABEL(kLoadExtendedUniform);
     CHECK_LABEL(kMatrixToMatrix);
     CHECK_LABEL(kMatrixMultiply);
     CHECK_VECTOR_MATRIX_LABELS(kNegateF);
@@ -914,6 +935,14 @@
                         ip += 2;
                         NEXT();
 
+    LABEL(kLoadUniform4) sp[4].fFloat = uniforms[ip[1] + 3];
+    LABEL(kLoadUniform3) sp[3].fFloat = uniforms[ip[1] + 2];
+    LABEL(kLoadUniform2) sp[2].fFloat = uniforms[ip[1] + 1];
+    LABEL(kLoadUniform)  sp[1].fFloat = uniforms[ip[1] + 0];
+                        sp += ip[0];
+                        ip += 2;
+                        NEXT();
+
     LABEL(kLoadExtended) {
         int count = READ8();
         I32 src = POP().fSigned;
@@ -944,6 +973,21 @@
         NEXT();
     }
 
+    LABEL(kLoadExtendedUniform) {
+        int count = READ8();
+        I32 src = POP().fSigned;
+        I32 m = mask();
+        for (int i = 0; i < count; ++i) {
+            for (int j = 0; j < VecWidth; ++j) {
+                if (m[j]) {
+                    sp[i + 1].fFloat[j] = uniforms[src[j] + i];
+                }
+            }
+        }
+        sp += count;
+        NEXT();
+    }
+
     LABEL(kLoadSwizzle) {
         int src = READ8();
         int count = READ8();
@@ -964,6 +1008,16 @@
         NEXT();
     }
 
+    LABEL(kLoadSwizzleUniform) {
+        int src = READ8();
+        int count = READ8();
+        for (int i = 0; i < count; ++i) {
+            PUSH(F32(uniforms[src + *(ip + i)]));
+        }
+        ip += count;
+        NEXT();
+    }
+
     LABEL(kMatrixToMatrix) {
         int srcCols = READ8();
         int srcRows = READ8();
@@ -1447,10 +1501,15 @@
             case ByteCodeInstruction::kLoadGlobal:
             case ByteCodeInstruction::kLoadGlobal2:
             case ByteCodeInstruction::kLoadGlobal3:
-            case ByteCodeInstruction::kLoadGlobal4: READ16(); break;
+            case ByteCodeInstruction::kLoadGlobal4:
+            case ByteCodeInstruction::kLoadUniform:
+            case ByteCodeInstruction::kLoadUniform2:
+            case ByteCodeInstruction::kLoadUniform3:
+            case ByteCodeInstruction::kLoadUniform4: READ16(); break;
 
             case ByteCodeInstruction::kLoadSwizzle:
-            case ByteCodeInstruction::kLoadSwizzleGlobal: {
+            case ByteCodeInstruction::kLoadSwizzleGlobal:
+            case ByteCodeInstruction::kLoadSwizzleUniform: {
                 READ8();
                 int count = READ8();
                 ip += count;
@@ -1459,6 +1518,7 @@
 
             case ByteCodeInstruction::kLoadExtended:
             case ByteCodeInstruction::kLoadExtendedGlobal:
+            case ByteCodeInstruction::kLoadExtendedUniform:
                 READ8();
                 break;
 
@@ -1577,17 +1637,14 @@
 
     if (argCount != f->fParameterCount ||
         returnCount != f->fReturnCount ||
-        uniformCount != (int)fUniformSlots.size()) {
+        uniformCount != fUniformSlotCount) {
         return false;
     }
 
     Interpreter::VValue globals[32];
-    if (fGlobalCount > (int)SK_ARRAY_COUNT(globals)) {
+    if (fGlobalSlotCount > (int)SK_ARRAY_COUNT(globals)) {
         return false;
     }
-    for (uint8_t slot : fUniformSlots) {
-        globals[slot].fFloat = *uniforms++;
-    }
 
     // Transpose args into stack
     {
@@ -1601,7 +1658,7 @@
 
     bool stripedOutput = false;
     float** outArray = outReturn ? &outReturn : nullptr;
-    if (!Interpreter::InnerRun(this, f, stack, outArray, globals, stripedOutput, 1, 0)) {
+    if (!Interpreter::InnerRun(this, f, stack, outArray, globals, uniforms, stripedOutput, 1, 0)) {
         return false;
     }
 
@@ -1642,17 +1699,14 @@
 
     if (argCount != f->fParameterCount ||
         returnCount != f->fReturnCount ||
-        uniformCount != (int)fUniformSlots.size()) {
+        uniformCount != fUniformSlotCount) {
         return false;
     }
 
     Interpreter::VValue globals[32];
-    if (fGlobalCount > (int)SK_ARRAY_COUNT(globals)) {
+    if (fGlobalSlotCount > (int)SK_ARRAY_COUNT(globals)) {
         return false;
     }
-    for (uint8_t slot : fUniformSlots) {
-        globals[slot].fFloat = *uniforms++;
-    }
 
     // innerRun just takes outArgs, so clear it if the count is zero
     if (returnCount == 0) {
@@ -1670,7 +1724,7 @@
         }
 
         bool stripedOutput = true;
-        if (!Interpreter::InnerRun(this, f, stack, outReturn, globals, stripedOutput, w,
+        if (!Interpreter::InnerRun(this, f, stack, outReturn, globals, uniforms, stripedOutput, w,
                                    baseIndex)) {
             return false;
         }
diff --git a/src/sksl/SkSLByteCode.h b/src/sksl/SkSLByteCode.h
index 51f80e7..bb7604f 100644
--- a/src/sksl/SkSLByteCode.h
+++ b/src/sksl/SkSLByteCode.h
@@ -82,13 +82,16 @@
     // local/global slot to load
     VECTOR(kLoad),
     VECTOR(kLoadGlobal),
+    VECTOR(kLoadUniform),
     // As kLoad/kLoadGlobal, then a count byte (1-4), and then one byte per swizzle component (0-3).
     kLoadSwizzle,
     kLoadSwizzleGlobal,
+    kLoadSwizzleUniform,
     // kLoadExtended* are fallback load ops when we lack a specialization. They are followed by a
     // count byte, and get the slot to load from the top of the stack.
     kLoadExtended,
     kLoadExtendedGlobal,
+    kLoadExtendedUniform,
     // Followed by four bytes: srcCols, srcRows, dstCols, dstRows. Consumes the src matrix from the
     // stack, and replaces it with the dst matrix. Per GLSL rules, there are no restrictions on
     // dimensions. Any overlapping values are copied, and any other values are filled in with the
@@ -261,8 +264,8 @@
     friend class ByteCodeGenerator;
     friend struct Interpreter;
 
-    int fGlobalCount = 0;
-    std::vector<uint8_t> fUniformSlots;
+    int fGlobalSlotCount = 0;
+    int fUniformSlotCount = 0;
     std::vector<std::unique_ptr<ByteCodeFunction>> fFunctions;
     std::vector<ExternalValue*> fExternalValues;
 };
diff --git a/src/sksl/SkSLByteCodeGenerator.cpp b/src/sksl/SkSLByteCodeGenerator.cpp
index acfbac5..6c8dab4 100644
--- a/src/sksl/SkSLByteCodeGenerator.cpp
+++ b/src/sksl/SkSLByteCodeGenerator.cpp
@@ -47,6 +47,10 @@
     }
 }
 
+static inline bool is_uniform(const SkSL::Variable& var) {
+    return var.fModifiers.fFlags & Modifiers::kUniform_Flag;
+}
+
 bool ByteCodeGenerator::generateCode() {
     for (const auto& e : fProgram) {
         switch (e.fKind) {
@@ -71,12 +75,10 @@
                     // final values of the 'in' variables, or not use 'in' variables (maybe you
                     // meant to use 'uniform' instead?).
 //                    SkASSERT(!(declVar->fModifiers.fFlags & Modifiers::kIn_Flag));
-                    if (declVar->fModifiers.fFlags & Modifiers::kUniform_Flag) {
-                        for (int i = SlotCount(declVar->fType); i > 0; --i) {
-                            fOutput->fUniformSlots.push_back(fOutput->fGlobalCount++);
-                        }
+                    if (is_uniform(*declVar)) {
+                        fOutput->fUniformSlotCount += SlotCount(declVar->fType);
                     } else {
-                        fOutput->fGlobalCount += SlotCount(declVar->fType);
+                        fOutput->fGlobalSlotCount += SlotCount(declVar->fType);
                     }
                 }
                 break;
@@ -273,6 +275,7 @@
         case ByteCodeInstruction::kDup:
         case ByteCodeInstruction::kLoad:
         case ByteCodeInstruction::kLoadGlobal:
+        case ByteCodeInstruction::kLoadUniform:
         case ByteCodeInstruction::kReadExternal:
         case ByteCodeInstruction::kPushImmediate:
             return 1;
@@ -280,29 +283,34 @@
         case ByteCodeInstruction::kDup2:
         case ByteCodeInstruction::kLoad2:
         case ByteCodeInstruction::kLoadGlobal2:
+        case ByteCodeInstruction::kLoadUniform2:
         case ByteCodeInstruction::kReadExternal2:
             return 2;
 
         case ByteCodeInstruction::kDup3:
         case ByteCodeInstruction::kLoad3:
         case ByteCodeInstruction::kLoadGlobal3:
+        case ByteCodeInstruction::kLoadUniform3:
         case ByteCodeInstruction::kReadExternal3:
             return 3;
 
         case ByteCodeInstruction::kDup4:
         case ByteCodeInstruction::kLoad4:
         case ByteCodeInstruction::kLoadGlobal4:
+        case ByteCodeInstruction::kLoadUniform4:
         case ByteCodeInstruction::kReadExternal4:
             return 4;
 
         case ByteCodeInstruction::kDupN:
         case ByteCodeInstruction::kLoadSwizzle:
         case ByteCodeInstruction::kLoadSwizzleGlobal:
+        case ByteCodeInstruction::kLoadSwizzleUniform:
             return count;
 
         // Pushes 'count' values, minus one for the 'address' that's consumed first
         case ByteCodeInstruction::kLoadExtended:
         case ByteCodeInstruction::kLoadExtendedGlobal:
+        case ByteCodeInstruction::kLoadExtendedUniform:
             return count - 1;
 
         // Ops that pop or store data to shrink the stack:
@@ -377,7 +385,7 @@
     }
 }
 
-int ByteCodeGenerator::getLocation(const Variable& var) {
+ByteCodeGenerator::Location ByteCodeGenerator::getLocation(const Variable& var) {
     // given that we seldom have more than a couple of variables, linear search is probably the most
     // efficient way to handle lookups
     switch (var.fStorage) {
@@ -385,7 +393,7 @@
             for (int i = fLocals.size() - 1; i >= 0; --i) {
                 if (fLocals[i] == &var) {
                     SkASSERT(fParameterCount + i <= 255);
-                    return fParameterCount + i;
+                    return { fParameterCount + i, Storage::kLocal };
                 }
             }
             int result = fParameterCount + fLocals.size();
@@ -394,22 +402,23 @@
                 fLocals.push_back(nullptr);
             }
             SkASSERT(result <= 255);
-            return result;
+            return { result, Storage::kLocal };
         }
         case Variable::kParameter_Storage: {
             int offset = 0;
             for (const auto& p : fFunction->fDeclaration.fParameters) {
                 if (p == &var) {
                     SkASSERT(offset <= 255);
-                    return offset;
+                    return { offset, Storage::kLocal };
                 }
                 offset += SlotCount(p->fType);
             }
             SkASSERT(false);
-            return 0;
+            return Location::MakeInvalid();
         }
         case Variable::kGlobal_Storage: {
             int offset = 0;
+            bool isUniform = is_uniform(var);
             for (const auto& e : fProgram) {
                 if (e.fKind == ProgramElement::kVar_Kind) {
                     VarDeclarations& decl = (VarDeclarations&) e;
@@ -418,42 +427,45 @@
                         if (declVar->fModifiers.fLayout.fBuiltin >= 0) {
                             continue;
                         }
+                        if (isUniform != is_uniform(*declVar)) {
+                            continue;
+                        }
                         if (declVar == &var) {
                             SkASSERT(offset <= 255);
-                            return offset;
+                            return  { offset, isUniform ? Storage::kUniform : Storage::kGlobal };
                         }
                         offset += SlotCount(declVar->fType);
                     }
                 }
             }
             SkASSERT(false);
-            return 0;
+            return Location::MakeInvalid();
         }
         default:
             SkASSERT(false);
-            return 0;
+            return Location::MakeInvalid();
     }
 }
 
-int ByteCodeGenerator::getLocation(const Expression& expr, Variable::Storage* storage) {
+ByteCodeGenerator::Location ByteCodeGenerator::getLocation(const Expression& expr) {
     switch (expr.fKind) {
         case Expression::kFieldAccess_Kind: {
             const FieldAccess& f = (const FieldAccess&)expr;
-            int baseAddr = this->getLocation(*f.fBase, storage);
+            Location baseLoc = this->getLocation(*f.fBase);
             int offset = 0;
             for (int i = 0; i < f.fFieldIndex; ++i) {
                 offset += SlotCount(*f.fBase->fType.fields()[i].fType);
             }
-            if (baseAddr < 0) {
+            if (baseLoc.isOnStack()) {
                 if (offset != 0) {
                     this->write(ByteCodeInstruction::kPushImmediate);
                     this->write32(offset);
                     this->write(ByteCodeInstruction::kAddI);
                     this->write8(1);
                 }
-                return -1;
+                return baseLoc;
             } else {
-                return baseAddr + offset;
+                return baseLoc + offset;
             }
         }
         case Expression::kIndex_Kind: {
@@ -466,7 +478,7 @@
                 int64_t index = i.fIndex->getConstantInt();
                 if (index < 0 || index >= length) {
                     fErrors.error(i.fIndex->fOffset, "Array index out of bounds.");
-                    return 0;
+                    return Location::MakeInvalid();
                 }
                 offset = index * stride;
             } else {
@@ -475,7 +487,7 @@
                     // but with lvalues we have to evaluate the indexer twice, so make it an error.
                     fErrors.error(i.fIndex->fOffset,
                             "Index expressions with side-effects not supported in byte code.");
-                    return 0;
+                    return Location::MakeInvalid();
                 }
                 this->writeExpression(*i.fIndex);
                 this->write(ByteCodeInstruction::kClampIndex);
@@ -487,24 +499,24 @@
                     this->write8(1);
                 }
             }
-            int baseAddr = this->getLocation(*i.fBase, storage);
+            Location baseLoc = this->getLocation(*i.fBase);
 
             // Are both components known statically?
-            if (baseAddr >= 0 && offset >= 0) {
-                return baseAddr + offset;
+            if (!baseLoc.isOnStack() && offset >= 0) {
+                return baseLoc + offset;
             }
 
             // At least one component is dynamic (and on the stack).
 
             // If the other component is zero, we're done
-            if (baseAddr == 0 || offset == 0) {
-                return -1;
+            if (baseLoc.fSlot == 0 || offset == 0) {
+                return baseLoc.makeOnStack();
             }
 
             // Push the non-dynamic component (if any) to the stack, then add the two
-            if (baseAddr >= 0) {
+            if (!baseLoc.isOnStack()) {
                 this->write(ByteCodeInstruction::kPushImmediate);
-                this->write32(baseAddr);
+                this->write32(baseLoc.fSlot);
             }
             if (offset >= 0) {
                 this->write(ByteCodeInstruction::kPushImmediate);
@@ -512,33 +524,32 @@
             }
             this->write(ByteCodeInstruction::kAddI);
             this->write8(1);
-            return -1;
+            return baseLoc.makeOnStack();
         }
         case Expression::kSwizzle_Kind: {
             const Swizzle& s = (const Swizzle&)expr;
             SkASSERT(swizzle_is_simple(s));
-            int baseAddr = this->getLocation(*s.fBase, storage);
+            Location baseLoc = this->getLocation(*s.fBase);
             int offset = s.fComponents[0];
-            if (baseAddr < 0) {
+            if (baseLoc.isOnStack()) {
                 if (offset != 0) {
                     this->write(ByteCodeInstruction::kPushImmediate);
                     this->write32(offset);
                     this->write(ByteCodeInstruction::kAddI);
                     this->write8(1);
                 }
-                return -1;
+                return baseLoc;
             } else {
-                return baseAddr + offset;
+                return baseLoc + offset;
             }
         }
         case Expression::kVariableReference_Kind: {
             const Variable& var = ((const VariableReference&)expr).fVariable;
-            *storage = var.fStorage;
             return this->getLocation(var);
         }
         default:
             SkASSERT(false);
-            return 0;
+            return Location::MakeInvalid();
     }
 }
 
@@ -913,25 +924,25 @@
 }
 
 void ByteCodeGenerator::writeVariableExpression(const Expression& expr) {
-    Variable::Storage storage = Variable::kLocal_Storage;
-    int location = this->getLocation(expr, &storage);
-    bool isGlobal = storage == Variable::kGlobal_Storage;
+    Location location = this->getLocation(expr);
     int count = SlotCount(expr.fType);
-    if (location < 0 || count > 4) {
-        if (location >= 0) {
+    if (location.isOnStack() || count > 4) {
+        if (!location.isOnStack()) {
             this->write(ByteCodeInstruction::kPushImmediate);
-            this->write32(location);
+            this->write32(location.fSlot);
         }
-        this->write(isGlobal ? ByteCodeInstruction::kLoadExtendedGlobal
-                             : ByteCodeInstruction::kLoadExtended,
+        this->write(location.selectLoad(ByteCodeInstruction::kLoadExtended,
+                                        ByteCodeInstruction::kLoadExtendedGlobal,
+                                        ByteCodeInstruction::kLoadExtendedUniform),
                     count);
         this->write8(count);
     } else {
-        this->write(vector_instruction(isGlobal ? ByteCodeInstruction::kLoadGlobal
-                                                : ByteCodeInstruction::kLoad,
+        this->write(vector_instruction(location.selectLoad(ByteCodeInstruction::kLoad,
+                                                           ByteCodeInstruction::kLoadGlobal,
+                                                           ByteCodeInstruction::kLoadUniform),
                                        count));
         this->write8(count);
-        this->write8(location);
+        this->write8(location.fSlot);
     }
 }
 
@@ -1194,12 +1205,12 @@
 
     switch (s.fBase->fKind) {
         case Expression::kVariableReference_Kind: {
-            const Variable& var = ((VariableReference&) *s.fBase).fVariable;
-            this->write(var.fStorage == Variable::kGlobal_Storage
-                            ? ByteCodeInstruction::kLoadSwizzleGlobal
-                            : ByteCodeInstruction::kLoadSwizzle,
+            Location location = this->getLocation(*s.fBase);
+            this->write(location.selectLoad(ByteCodeInstruction::kLoadSwizzle,
+                                            ByteCodeInstruction::kLoadSwizzleGlobal,
+                                            ByteCodeInstruction::kLoadSwizzleUniform),
                         s.fComponents.size());
-            this->write8(this->getLocation(var));
+            this->write8(location.fSlot);
             this->write8(s.fComponents.size());
             for (int c : s.fComponents) {
                 this->write8(c);
@@ -1341,18 +1352,16 @@
             fGenerator.write(vector_instruction(ByteCodeInstruction::kDup, count));
             fGenerator.write8(count);
         }
-        Variable::Storage storage = Variable::kLocal_Storage;
-        int location = fGenerator.getLocation(*fSwizzle.fBase, &storage);
-        bool isGlobal = storage == Variable::kGlobal_Storage;
-        if (location < 0) {
-            fGenerator.write(isGlobal ? ByteCodeInstruction::kStoreSwizzleIndirectGlobal
-                                      : ByteCodeInstruction::kStoreSwizzleIndirect,
+        ByteCodeGenerator::Location location = fGenerator.getLocation(*fSwizzle.fBase);
+        if (location.isOnStack()) {
+            fGenerator.write(location.selectStore(ByteCodeInstruction::kStoreSwizzleIndirect,
+                                                  ByteCodeInstruction::kStoreSwizzleIndirectGlobal),
                              count);
         } else {
-            fGenerator.write(isGlobal ? ByteCodeInstruction::kStoreSwizzleGlobal
-                                      : ByteCodeInstruction::kStoreSwizzle,
+            fGenerator.write(location.selectStore(ByteCodeInstruction::kStoreSwizzle,
+                                                  ByteCodeInstruction::kStoreSwizzleGlobal),
                              count);
-            fGenerator.write8(location);
+            fGenerator.write8(location.fSlot);
         }
         fGenerator.write8(count);
         for (int c : fSwizzle.fComponents) {
@@ -1387,23 +1396,22 @@
                 fGenerator.write8(count);
             }
         }
-        Variable::Storage storage = Variable::kLocal_Storage;
-        int location = fGenerator.getLocation(fExpression, &storage);
-        bool isGlobal = storage == Variable::kGlobal_Storage;
-        if (location < 0 || count > 4) {
-            if (location >= 0) {
+        ByteCodeGenerator::Location location = fGenerator.getLocation(fExpression);
+        if (location.isOnStack() || count > 4) {
+            if (!location.isOnStack()) {
                 fGenerator.write(ByteCodeInstruction::kPushImmediate);
-                fGenerator.write32(location);
+                fGenerator.write32(location.fSlot);
             }
-            fGenerator.write(isGlobal ? ByteCodeInstruction::kStoreExtendedGlobal
-                                      : ByteCodeInstruction::kStoreExtended,
+            fGenerator.write(location.selectStore(ByteCodeInstruction::kStoreExtended,
+                                                  ByteCodeInstruction::kStoreExtendedGlobal),
                              count);
             fGenerator.write8(count);
         } else {
-            fGenerator.write(vector_instruction(isGlobal ? ByteCodeInstruction::kStoreGlobal
-                                                         : ByteCodeInstruction::kStore,
-                                                count));
-            fGenerator.write8(location);
+            fGenerator.write(
+                    vector_instruction(location.selectStore(ByteCodeInstruction::kStore,
+                                                            ByteCodeInstruction::kStoreGlobal),
+                                       count));
+            fGenerator.write8(location.fSlot);
         }
     }
 
@@ -1554,20 +1562,19 @@
 void ByteCodeGenerator::writeVarDeclarations(const VarDeclarations& v) {
     for (const auto& declStatement : v.fVars) {
         const VarDeclaration& decl = (VarDeclaration&) *declStatement;
-        // we need to grab the location even if we don't use it, to ensure it
-        // has been allocated
-        int location = getLocation(*decl.fVar);
+        // we need to grab the location even if we don't use it, to ensure it has been allocated
+        Location location = this->getLocation(*decl.fVar);
         if (decl.fValue) {
             this->writeExpression(*decl.fValue);
             int count = SlotCount(decl.fValue->fType);
             if (count > 4) {
                 this->write(ByteCodeInstruction::kPushImmediate);
-                this->write32(location);
+                this->write32(location.fSlot);
                 this->write(ByteCodeInstruction::kStoreExtended, count);
                 this->write8(count);
             } else {
                 this->write(vector_instruction(ByteCodeInstruction::kStore, count));
-                this->write8(location);
+                this->write8(location.fSlot);
             }
         }
     }
diff --git a/src/sksl/SkSLByteCodeGenerator.h b/src/sksl/SkSLByteCodeGenerator.h
index e562970..1546f6e 100644
--- a/src/sksl/SkSLByteCodeGenerator.h
+++ b/src/sksl/SkSLByteCodeGenerator.h
@@ -164,19 +164,65 @@
         } fValue;
     };
 
+
+    // Similar to Variable::Storage, but locals and parameters are grouped together, and globals
+    // are further subdivided into uniforms and other (writable) globals.
+    enum class Storage {
+        kLocal,    // include parameters
+        kGlobal,   // non-uniform globals
+        kUniform,  // uniform globals
+    };
+
+    struct Location {
+        int     fSlot;
+        Storage fStorage;
+
+        // Not really invalid, but a "safe" placeholder to be more explicit at call-sites
+        static Location MakeInvalid() { return { 0, Storage::kLocal }; }
+
+        Location makeOnStack() { return { -1, fStorage }; }
+        bool isOnStack() const { return fSlot < 0; }
+
+        Location operator+(int offset) {
+            SkASSERT(fSlot >= 0);
+            return { fSlot + offset, fStorage };
+        }
+
+        ByteCodeInstruction selectLoad(ByteCodeInstruction local,
+                                       ByteCodeInstruction global,
+                                       ByteCodeInstruction uniform) const {
+            switch (fStorage) {
+                case Storage::kLocal:   return local;
+                case Storage::kGlobal:  return global;
+                case Storage::kUniform: return uniform;
+            }
+            SkUNREACHABLE;
+        }
+
+        ByteCodeInstruction selectStore(ByteCodeInstruction local,
+                                        ByteCodeInstruction global) const {
+            switch (fStorage) {
+                case Storage::kLocal:   return local;
+                case Storage::kGlobal:  return global;
+                case Storage::kUniform: SK_ABORT("Trying to store to a uniform"); break;
+            }
+            return local;
+        }
+    };
+
     /**
      * Returns the local slot into which var should be stored, allocating a new slot if it has not
      * already been assigned one. Compound variables (e.g. vectors) will consume more than one local
      * slot, with the getLocation return value indicating where the first element should be stored.
      */
-    int getLocation(const Variable& var);
+    Location getLocation(const Variable& var);
 
     /**
      * As above, but computes the (possibly dynamic) address of an expression involving indexing &
      * field access. If the address is known, it's returned. If not, -1 is returned, and the
      * location will be left on the top of the stack.
      */
-    int getLocation(const Expression& expr, Variable::Storage* storage);
+    Location getLocation(const Expression& expr);
 
     std::unique_ptr<ByteCodeFunction> writeFunction(const FunctionDefinition& f);