Support larger compound types in the interpreter

Field access and array indexing are supported, including
dynamic (non-constant) indices. Types larger than four slots
can now be used as lvalues and rvalues, as function parameters
and return values, and in variable declarations.

Change-Id: I9bb4ed850be4259c05c8952c6c0a17b71f813772
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/214443
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLByteCodeGenerator.cpp b/src/sksl/SkSLByteCodeGenerator.cpp
index bc3aaea..0070674 100644
--- a/src/sksl/SkSLByteCodeGenerator.cpp
+++ b/src/sksl/SkSLByteCodeGenerator.cpp
@@ -21,8 +21,23 @@
     fIntrinsics["tan"]  = ByteCodeInstruction::kTan;
 }
 
-static int slot_count(const Type& type) {
-    return type.columns() * type.rows();
+int ByteCodeGenerator::SlotCount(const Type& type) {
+    if (type.kind() == Type::kStruct_Kind) {
+        int slots = 0;
+        for (const auto& f : type.fields()) {
+            slots += SlotCount(*f.fType);
+        }
+        SkASSERT(slots <= 255);
+        return slots;
+    } else if (type.kind() == Type::kArray_Kind) {
+        int columns = type.columns();
+        SkASSERT(columns >= 0);
+        int slots = columns * SlotCount(type.componentType());
+        SkASSERT(slots <= 255);
+        return slots;
+    } else {
+        return type.columns() * type.rows();
+    }
 }
 
 bool ByteCodeGenerator::generateCode() {
@@ -44,11 +59,11 @@
                         continue;
                     }
                     if (declVar->fModifiers.fFlags & Modifiers::kIn_Flag) {
-                        for (int i = slot_count(declVar->fType); i > 0; --i) {
+                        for (int i = SlotCount(declVar->fType); i > 0; --i) {
                             fOutput->fInputSlots.push_back(fOutput->fGlobalCount++);
                         }
                     } else {
-                        fOutput->fGlobalCount += slot_count(declVar->fType);
+                        fOutput->fGlobalCount += SlotCount(declVar->fType);
                     }
                 }
                 break;
@@ -70,7 +85,7 @@
     std::unique_ptr<ByteCodeFunction> result(new ByteCodeFunction(&f.fDeclaration));
     fParameterCount = 0;
     for (const auto& p : f.fDeclaration.fParameters) {
-        fParameterCount += p->fType.columns() * p->fType.rows();
+        fParameterCount += SlotCount(p->fType);
     }
     fCode = &result->fCode;
     this->writeStatement(*f.fBody);
@@ -80,7 +95,7 @@
     result->fLocalCount = fLocals.size();
     const Type& returnType = f.fDeclaration.fReturnType;
     if (returnType != *fContext.fVoid_Type) {
-        result->fReturnCount = returnType.columns() * returnType.rows();
+        result->fReturnCount = SlotCount(returnType);
     }
     fLocals.clear();
     fFunction = nullptr;
@@ -127,7 +142,7 @@
             }
             int result = fParameterCount + fLocals.size();
             fLocals.push_back(&var);
-            for (int i = 0; i < slot_count(var.fType) - 1; ++i) {
+            for (int i = 0; i < SlotCount(var.fType) - 1; ++i) {
                 fLocals.push_back(nullptr);
             }
             SkASSERT(result <= 255);
@@ -140,7 +155,7 @@
                     SkASSERT(offset <= 255);
                     return offset;
                 }
-                offset += slot_count(p->fType);
+                offset += SlotCount(p->fType);
             }
             SkASSERT(false);
             return 0;
@@ -159,7 +174,7 @@
                             SkASSERT(offset <= 255);
                             return offset;
                         }
-                        offset += slot_count(declVar->fType);
+                        offset += SlotCount(declVar->fType);
                     }
                 }
             }
@@ -172,6 +187,62 @@
     }
 }
 
+int ByteCodeGenerator::getLocation(const Expression& expr, Variable::Storage* storage) {
+    switch (expr.fKind) {
+        case Expression::kFieldAccess_Kind: {
+            const FieldAccess& f = (const FieldAccess&)expr;
+            int baseAddr = this->getLocation(*f.fBase, storage);
+            int offset = 0;
+            for (int i = 0; i < f.fFieldIndex; ++i) {
+                offset += SlotCount(*f.fBase->fType.fields()[i].fType);
+            }
+            if (baseAddr < 0) {
+                this->write(ByteCodeInstruction::kPushImmediate);
+                this->write32(offset);
+                this->write(ByteCodeInstruction::kAddI);
+                return -1;
+            } else {
+                return baseAddr + offset;
+            }
+        }
+        case Expression::kIndex_Kind: {
+            const IndexExpression& i = (const IndexExpression&)expr;
+            int stride = SlotCount(i.fType);
+            int offset = -1;
+            if (i.fIndex->isConstant()) {
+                offset = i.fIndex->getConstantInt() * stride;
+            } else {
+                this->writeExpression(*i.fIndex);
+                this->write(ByteCodeInstruction::kPushImmediate);
+                this->write32(stride);
+                this->write(ByteCodeInstruction::kMultiplyI);
+            }
+            int baseAddr = this->getLocation(*i.fBase, storage);
+            if (baseAddr >= 0 && offset >= 0) {
+                return baseAddr + offset;
+            }
+            if (baseAddr >= 0) {
+                this->write(ByteCodeInstruction::kPushImmediate);
+                this->write32(baseAddr);
+            }
+            if (offset >= 0) {
+                this->write(ByteCodeInstruction::kPushImmediate);
+                this->write32(offset);
+            }
+            this->write(ByteCodeInstruction::kAddI);
+            return -1;
+        }
+        case Expression::kVariableReference_Kind: {
+            const Variable& var = ((const VariableReference&)expr).fVariable;
+            *storage = var.fStorage;
+            return this->getLocation(var);
+        }
+        default:
+            SkASSERT(false);
+            return 0;
+    }
+}
+
 void ByteCodeGenerator::write8(uint8_t b) {
     fCode->push_back(b);
 }
@@ -193,6 +264,7 @@
 }
 
 static ByteCodeInstruction vector_instruction(ByteCodeInstruction base, int count) {
+    SkASSERT(count >= 1 && count <= 4);
     return ((ByteCodeInstruction) ((int) base + count - 1));
 }
 
@@ -244,7 +316,7 @@
             this->write(ByteCodeInstruction::kDup);
         }
     }
-    int count = slot_count(b.fType);
+    int count = SlotCount(b.fType);
     switch (op) {
         case Token::Kind::EQEQ:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareIEQ,
@@ -364,12 +436,12 @@
     int argumentCount = 0;
     for (const auto& arg : f.fArguments) {
         this->writeExpression(*arg);
-        argumentCount += slot_count(arg->fType);
+        argumentCount += SlotCount(arg->fType);
     }
     this->write(ByteCodeInstruction::kCallExternal);
     SkASSERT(argumentCount <= 255);
     this->write8(argumentCount);
-    this->write8(slot_count(f.fType));
+    this->write8(SlotCount(f.fType));
     int index = fOutput->fExternalValues.size();
     fOutput->fExternalValues.push_back(f.fFunction);
     SkASSERT(index <= 255);
@@ -378,16 +450,32 @@
 
 void ByteCodeGenerator::writeExternalValue(const ExternalValueReference& e) {
     this->write(vector_instruction(ByteCodeInstruction::kReadExternal,
-                                   slot_count(e.fValue->type())));
+                                   SlotCount(e.fValue->type())));
     int index = fOutput->fExternalValues.size();
     fOutput->fExternalValues.push_back(e.fValue);
     SkASSERT(index <= 255);
     this->write8(index);
 }
 
-void ByteCodeGenerator::writeFieldAccess(const FieldAccess& f) {
-    // not yet implemented
-    abort();
+void ByteCodeGenerator::writeVariableExpression(const Expression& expr) {
+    Variable::Storage storage;
+    int location = this->getLocation(expr, &storage);
+    bool isGlobal = storage == Variable::kGlobal_Storage;
+    int count = SlotCount(expr.fType);
+    if (location < 0 || count > 4) {
+        if (location >= 0) {
+            this->write(ByteCodeInstruction::kPushImmediate);
+            this->write32(location);
+        }
+        this->write(isGlobal ? ByteCodeInstruction::kLoadExtendedGlobal
+                             : ByteCodeInstruction::kLoadExtended);
+        this->write8(count);
+    } else {
+        this->write(vector_instruction(isGlobal ? ByteCodeInstruction::kLoadGlobal
+                                                : ByteCodeInstruction::kLoad,
+                                       count));
+        this->write8(location);
+    }
 }
 
 void ByteCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
@@ -408,7 +496,7 @@
         case ByteCodeInstruction::kTan:
             SkASSERT(c.fArguments.size() == 1);
             this->write((ByteCodeInstruction) ((int) found->second +
-                        slot_count(c.fArguments[0]->fType) - 1));
+                        SlotCount(c.fArguments[0]->fType) - 1));
             break;
         default:
             SkASSERT(false);
@@ -427,11 +515,6 @@
     fCallTargets.emplace_back(this, f.fFunction);
 }
 
-void ByteCodeGenerator::writeIndexExpression(const IndexExpression& i) {
-    // not yet implemented
-    abort();
-}
-
 void ByteCodeGenerator::writeIntLiteral(const IntLiteral& i) {
     this->write(ByteCodeInstruction::kPushImmediate);
     this->write32(i.fValue);
@@ -446,7 +529,7 @@
     switch (p.fOperator) {
         case Token::Kind::PLUSPLUS: // fall through
         case Token::Kind::MINUSMINUS: {
-            SkASSERT(slot_count(p.fOperand->fType) == 1);
+            SkASSERT(SlotCount(p.fOperand->fType) == 1);
             std::unique_ptr<LValue> lvalue = this->getLValue(*p.fOperand);
             lvalue->load();
             this->write(ByteCodeInstruction::kPushImmediate);
@@ -474,7 +557,7 @@
                                         ByteCodeInstruction::kNegateI,
                                         ByteCodeInstruction::kNegateI,
                                         ByteCodeInstruction::kNegateF,
-                                        slot_count(p.fOperand->fType));
+                                        SlotCount(p.fOperand->fType));
             break;
         }
         default:
@@ -486,7 +569,7 @@
     switch (p.fOperator) {
         case Token::Kind::PLUSPLUS: // fall through
         case Token::Kind::MINUSMINUS: {
-            SkASSERT(slot_count(p.fOperand->fType) == 1);
+            SkASSERT(SlotCount(p.fOperand->fType) == 1);
             std::unique_ptr<LValue> lvalue = this->getLValue(*p.fOperand);
             lvalue->load();
             this->write(ByteCodeInstruction::kDup);
@@ -540,14 +623,6 @@
     }
 }
 
-void ByteCodeGenerator::writeVariableReference(const VariableReference& v) {
-    this->write(vector_instruction(v.fVariable.fStorage == Variable::kGlobal_Storage
-                                                                  ? ByteCodeInstruction::kLoadGlobal
-                                                                  : ByteCodeInstruction::kLoad,
-                                   slot_count(v.fType)));
-    this->write8(this->getLocation(v.fVariable));
-}
-
 void ByteCodeGenerator::writeTernaryExpression(const TernaryExpression& t) {
     this->writeExpression(*t.fTest);
     this->write(ByteCodeInstruction::kConditionalBranch);
@@ -578,7 +653,9 @@
             this->writeExternalValue((ExternalValueReference&) e);
             break;
         case Expression::kFieldAccess_Kind:
-            this->writeFieldAccess((FieldAccess&) e);
+        case Expression::kIndex_Kind:
+        case Expression::kVariableReference_Kind:
+            this->writeVariableExpression(e);
             break;
         case Expression::kFloatLiteral_Kind:
             this->writeFloatLiteral((FloatLiteral&) e);
@@ -586,9 +663,6 @@
         case Expression::kFunctionCall_Kind:
             this->writeFunctionCall((FunctionCall&) e);
             break;
-        case Expression::kIndex_Kind:
-            this->writeIndexExpression((IndexExpression&) e);
-            break;
         case Expression::kIntLiteral_Kind:
             this->writeIntLiteral((IntLiteral&) e);
             break;
@@ -604,9 +678,6 @@
         case Expression::kSwizzle_Kind:
             this->writeSwizzle((Swizzle&) e);
             break;
-        case Expression::kVariableReference_Kind:
-            this->writeVariableReference((VariableReference&) e);
-            break;
         case Expression::kTernary_Kind:
             this->writeTernaryExpression((TernaryExpression&) e);
             break;
@@ -620,7 +691,7 @@
 public:
     ByteCodeExternalValueLValue(ByteCodeGenerator* generator, ExternalValue& value, int index)
         : INHERITED(*generator)
-        , fCount(slot_count(value.type()))
+        , fCount(ByteCodeGenerator::SlotCount(value.type()))
         , fIndex(index) {}
 
     void load() override {
@@ -646,22 +717,26 @@
 public:
     ByteCodeSwizzleLValue(ByteCodeGenerator* generator, const Swizzle& swizzle)
         : INHERITED(*generator)
-        , fSwizzle(swizzle) {
-        SkASSERT(fSwizzle.fBase->fKind == Expression::kVariableReference_Kind);
-    }
+        , fSwizzle(swizzle) {}
 
     void load() override {
         fGenerator.writeSwizzle(fSwizzle);
     }
 
     void store() override {
-        const Variable& var = ((VariableReference&)*fSwizzle.fBase).fVariable;
         fGenerator.write(vector_instruction(ByteCodeInstruction::kDup,
                                             fSwizzle.fComponents.size()));
-        fGenerator.write(var.fStorage == Variable::kGlobal_Storage
-                            ? ByteCodeInstruction::kStoreSwizzleGlobal
-                            : ByteCodeInstruction::kStoreSwizzle);
-        fGenerator.write8(fGenerator.getLocation(var));
+        Variable::Storage storage;
+        int location = fGenerator.getLocation(*fSwizzle.fBase, &storage);
+        bool isGlobal = storage == Variable::kGlobal_Storage;
+        if (location < 0) {
+            fGenerator.write(isGlobal ? ByteCodeInstruction::kStoreSwizzleIndirectGlobal
+                                      : ByteCodeInstruction::kStoreSwizzleIndirect);
+        } else {
+            fGenerator.write(isGlobal ? ByteCodeInstruction::kStoreSwizzleGlobal
+                                      : ByteCodeInstruction::kStoreSwizzle);
+            fGenerator.write8(location);
+        }
         fGenerator.write8(fSwizzle.fComponents.size());
         for (int c : fSwizzle.fComponents) {
             fGenerator.write8(c);
@@ -674,36 +749,47 @@
     typedef LValue INHERITED;
 };
 
-class ByteCodeVariableLValue : public ByteCodeGenerator::LValue {
+class ByteCodeExpressionLValue : public ByteCodeGenerator::LValue {
 public:
-    ByteCodeVariableLValue(ByteCodeGenerator* generator, const Variable& var)
+    ByteCodeExpressionLValue(ByteCodeGenerator* generator, const Expression& expr)
         : INHERITED(*generator)
-        , fCount(slot_count(var.fType))
-        , fLocation(generator->getLocation(var))
-        , fIsGlobal(var.fStorage == Variable::kGlobal_Storage) {
-    }
+        , fExpression(expr) {}
 
     void load() override {
-        fGenerator.write(vector_instruction(fIsGlobal ? ByteCodeInstruction::kLoadGlobal
-                                                      : ByteCodeInstruction::kLoad,
-                                            fCount));
-        fGenerator.write8(fLocation);
+        fGenerator.writeVariableExpression(fExpression);
     }
 
     void store() override {
-        fGenerator.write(vector_instruction(ByteCodeInstruction::kDup, fCount));
-        fGenerator.write(vector_instruction(fIsGlobal ? ByteCodeInstruction::kStoreGlobal
-                                                      : ByteCodeInstruction::kStore,
-                                            fCount));
-        fGenerator.write8(fLocation);
+        int count = ByteCodeGenerator::SlotCount(fExpression.fType);
+        if (count > 4) {
+            fGenerator.write(ByteCodeInstruction::kDupN);
+            fGenerator.write8(count);
+        } else {
+            fGenerator.write(vector_instruction(ByteCodeInstruction::kDup, count));
+        }
+        Variable::Storage storage;
+        int location = fGenerator.getLocation(fExpression, &storage);
+        bool isGlobal = storage == Variable::kGlobal_Storage;
+        if (location < 0 || count > 4) {
+            if (location >= 0) {
+                fGenerator.write(ByteCodeInstruction::kPushImmediate);
+                fGenerator.write32(location);
+            }
+            fGenerator.write(isGlobal ? ByteCodeInstruction::kStoreExtendedGlobal
+                                      : ByteCodeInstruction::kStoreExtended);
+            fGenerator.write8(count);
+        } else {
+            fGenerator.write(vector_instruction(isGlobal ? ByteCodeInstruction::kStoreGlobal
+                                                         : ByteCodeInstruction::kStore,
+                                                count));
+            fGenerator.write8(location);
+        }
     }
 
 private:
     typedef LValue INHERITED;
 
-    int fCount;
-    int fLocation;
-    bool fIsGlobal;
+    const Expression& fExpression;
 };
 
 std::unique_ptr<ByteCodeGenerator::LValue> ByteCodeGenerator::getLValue(const Expression& e) {
@@ -715,12 +801,10 @@
             SkASSERT(index <= 255);
             return std::unique_ptr<LValue>(new ByteCodeExternalValueLValue(this, *value, index));
         }
+        case Expression::kFieldAccess_Kind:
         case Expression::kIndex_Kind:
-            // not yet implemented
-            abort();
         case Expression::kVariableReference_Kind:
-            return std::unique_ptr<LValue>(new ByteCodeVariableLValue(this,
-                                                               ((VariableReference&) e).fVariable));
+            return std::unique_ptr<LValue>(new ByteCodeExpressionLValue(this, e));
         case Expression::kSwizzle_Kind:
             return std::unique_ptr<LValue>(new ByteCodeSwizzleLValue(this, (Swizzle&) e));
         case Expression::kTernary_Kind:
@@ -790,7 +874,7 @@
         this->setContinueTargets();
         if (f.fNext) {
             this->writeExpression(*f.fNext);
-            this->write(vector_instruction(ByteCodeInstruction::kPop, slot_count(f.fNext->fType)));
+            this->write(vector_instruction(ByteCodeInstruction::kPop, SlotCount(f.fNext->fType)));
         }
         this->write(ByteCodeInstruction::kBranch);
         this->write16(start);
@@ -800,7 +884,7 @@
         this->setContinueTargets();
         if (f.fNext) {
             this->writeExpression(*f.fNext);
-            this->write(vector_instruction(ByteCodeInstruction::kPop, slot_count(f.fNext->fType)));
+            this->write(vector_instruction(ByteCodeInstruction::kPop, SlotCount(f.fNext->fType)));
         }
         this->write(ByteCodeInstruction::kBranch);
         this->write16(start);
@@ -834,7 +918,7 @@
 void ByteCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
     this->writeExpression(*r.fExpression);
     this->write(ByteCodeInstruction::kReturn);
-    this->write8(r.fExpression->fType.columns() * r.fExpression->fType.rows());
+    this->write8(SlotCount(r.fExpression->fType));
 }
 
 void ByteCodeGenerator::writeSwitchStatement(const SwitchStatement& r) {
@@ -850,9 +934,16 @@
         int location = getLocation(*decl.fVar);
         if (decl.fValue) {
             this->writeExpression(*decl.fValue);
-            this->write(vector_instruction(ByteCodeInstruction::kStore,
-                                           slot_count(decl.fValue->fType)));
-            this->write8(location);
+            int count = SlotCount(decl.fValue->fType);
+            if (count > 4) {
+                this->write(ByteCodeInstruction::kPushImmediate);
+                this->write32(location);
+                this->write(ByteCodeInstruction::kStoreExtended);
+                this->write8(count);
+            } else {
+                this->write(vector_instruction(ByteCodeInstruction::kStore, count));
+                this->write8(location);
+            }
         }
     }
 }
@@ -893,7 +984,13 @@
         case Statement::kExpression_Kind: {
             const Expression& expr = *((ExpressionStatement&) s).fExpression;
             this->writeExpression(expr);
-            this->write(vector_instruction(ByteCodeInstruction::kPop, slot_count(expr.fType)));
+            int count = SlotCount(expr.fType);
+            if (count > 4) {
+                this->write(ByteCodeInstruction::kPopN);
+                this->write8(count);
+            } else {
+                this->write(vector_instruction(ByteCodeInstruction::kPop, count));
+            }
             break;
         }
         case Statement::kFor_Kind: