Interpreter: Compute max stack depth needed

Change-Id: I171a680aac554a0015d1854c46b35e9c9785fdf3
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/227061
Reviewed-by: Mike Klein <mtklein@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/SkSLByteCodeGenerator.cpp b/src/sksl/SkSLByteCodeGenerator.cpp
index acefb44..fd3b6f0 100644
--- a/src/sksl/SkSLByteCodeGenerator.cpp
+++ b/src/sksl/SkSLByteCodeGenerator.cpp
@@ -90,17 +90,22 @@
     fParameterCount = result->fParameterCount;
     fLoopCount = fMaxLoopCount = 0;
     fConditionCount = fMaxConditionCount = 0;
+    fStackCount = fMaxStackCount = 0;
     fCode = &result->fCode;
 
     this->writeStatement(*f.fBody);
-    SkASSERT(fLoopCount == 0);
-    SkASSERT(fConditionCount == 0);
-    this->write(ByteCodeInstruction::kReturn);
+    if (0 == fErrors.errorCount()) {
+        SkASSERT(fLoopCount == 0);
+        SkASSERT(fConditionCount == 0);
+        SkASSERT(fStackCount == 0);
+    }
+    this->write(ByteCodeInstruction::kReturn, 0);
     this->write8(0);
 
     result->fLocalCount     = fLocals.size();
     result->fConditionCount = fMaxConditionCount;
     result->fLoopCount      = fMaxLoopCount;
+    result->fStackCount     = fMaxStackCount;
 
     const Type& returnType = f.fDeclaration.fReturnType;
     if (returnType != *fContext.fVoid_Type) {
@@ -159,6 +164,215 @@
     return true;
 }
 
+int ByteCodeGenerator::StackUsage(ByteCodeInstruction inst, int count_) {
+    // Ensures that we use count iff we're passed a non-default value. Most instructions have an
+    // implicit count, so the caller shouldn't need to worry about it (or count makes no sense).
+    // The asserts keep callers from thinking they're supplying useful information in that
+    // scenario, or from failing to supply necessary information for the ops that need a count.
+    struct CountValue {
+        operator int() {
+            SkASSERT(val != ByteCodeGenerator::kUnusedStackCount);
+            SkDEBUGCODE(used = true);
+            return val;
+        }
+        ~CountValue() {
+            SkASSERT(used || val == ByteCodeGenerator::kUnusedStackCount);
+        }
+        int val;
+        SkDEBUGCODE(bool used = false;)
+    } count = { count_ };
+
+    switch (inst) {
+        // Unary functions/operators that don't change stack depth at all:
+#define VECTOR_UNARY_OP(base)                \
+        case ByteCodeInstruction::base:      \
+        case ByteCodeInstruction::base ## 2: \
+        case ByteCodeInstruction::base ## 3: \
+        case ByteCodeInstruction::base ## 4: \
+            return 0;
+
+        VECTOR_UNARY_OP(kConvertFtoI)
+        VECTOR_UNARY_OP(kConvertStoF)
+        VECTOR_UNARY_OP(kConvertUtoF)
+
+        VECTOR_UNARY_OP(kCos)
+        VECTOR_UNARY_OP(kSin)
+        VECTOR_UNARY_OP(kSqrt)
+        VECTOR_UNARY_OP(kTan)
+
+        VECTOR_UNARY_OP(kNegateF)
+        VECTOR_UNARY_OP(kNegateI)
+
+        case ByteCodeInstruction::kNotB: return 0;
+        case ByteCodeInstruction::kNegateFN: return 0;
+
+#undef VECTOR_UNARY_OP
+
+        // Binary functions/operators that do a 2 -> 1 reduction (possibly N times)
+#define VECTOR_BINARY_OP(base)                          \
+        case ByteCodeInstruction::base:      return -1; \
+        case ByteCodeInstruction::base ## 2: return -2; \
+        case ByteCodeInstruction::base ## 3: return -3; \
+        case ByteCodeInstruction::base ## 4: return -4;
+
+#define VECTOR_MATRIX_BINARY_OP(base)                   \
+        VECTOR_BINARY_OP(base)                          \
+        case ByteCodeInstruction::base ## N: return -count;
+
+        case ByteCodeInstruction::kAndB: return -1;
+        case ByteCodeInstruction::kOrB:  return -1;
+        case ByteCodeInstruction::kXorB: return -1;
+
+        VECTOR_BINARY_OP(kAddI)
+        VECTOR_MATRIX_BINARY_OP(kAddF)
+
+        VECTOR_BINARY_OP(kCompareIEQ)
+        VECTOR_MATRIX_BINARY_OP(kCompareFEQ)
+        VECTOR_BINARY_OP(kCompareINEQ)
+        VECTOR_MATRIX_BINARY_OP(kCompareFNEQ)
+        VECTOR_BINARY_OP(kCompareSGT)
+        VECTOR_BINARY_OP(kCompareUGT)
+        VECTOR_BINARY_OP(kCompareFGT)
+        VECTOR_BINARY_OP(kCompareSGTEQ)
+        VECTOR_BINARY_OP(kCompareUGTEQ)
+        VECTOR_BINARY_OP(kCompareFGTEQ)
+        VECTOR_BINARY_OP(kCompareSLT)
+        VECTOR_BINARY_OP(kCompareULT)
+        VECTOR_BINARY_OP(kCompareFLT)
+        VECTOR_BINARY_OP(kCompareSLTEQ)
+        VECTOR_BINARY_OP(kCompareULTEQ)
+        VECTOR_BINARY_OP(kCompareFLTEQ)
+
+        VECTOR_BINARY_OP(kDivideS)
+        VECTOR_BINARY_OP(kDivideU)
+        VECTOR_MATRIX_BINARY_OP(kDivideF)
+        VECTOR_BINARY_OP(kMultiplyI)
+        VECTOR_MATRIX_BINARY_OP(kMultiplyF)
+        VECTOR_BINARY_OP(kRemainderF)
+        VECTOR_BINARY_OP(kRemainderS)
+        VECTOR_BINARY_OP(kRemainderU)
+        VECTOR_BINARY_OP(kSubtractI)
+        VECTOR_MATRIX_BINARY_OP(kSubtractF)
+
+#undef VECTOR_BINARY_OP
+#undef VECTOR_MATRIX_BINARY_OP
+
+        // Strange math operations with other behavior:
+        case ByteCodeInstruction::kCross: return -3;
+        // Binary, but also consumes T:
+        case ByteCodeInstruction::kMix:  return -2;
+        case ByteCodeInstruction::kMix2: return -3;
+        case ByteCodeInstruction::kMix3: return -4;
+        case ByteCodeInstruction::kMix4: return -5;
+
+        // Ops that push or load data to grow the stack:
+        case ByteCodeInstruction::kDup:
+        case ByteCodeInstruction::kLoad:
+        case ByteCodeInstruction::kLoadGlobal:
+        case ByteCodeInstruction::kReadExternal:
+        case ByteCodeInstruction::kPushImmediate:
+            return 1;
+
+        case ByteCodeInstruction::kDup2:
+        case ByteCodeInstruction::kLoad2:
+        case ByteCodeInstruction::kLoadGlobal2:
+        case ByteCodeInstruction::kReadExternal2:
+            return 2;
+
+        case ByteCodeInstruction::kDup3:
+        case ByteCodeInstruction::kLoad3:
+        case ByteCodeInstruction::kLoadGlobal3:
+        case ByteCodeInstruction::kReadExternal3:
+            return 3;
+
+        case ByteCodeInstruction::kDup4:
+        case ByteCodeInstruction::kLoad4:
+        case ByteCodeInstruction::kLoadGlobal4:
+        case ByteCodeInstruction::kReadExternal4:
+            return 4;
+
+        case ByteCodeInstruction::kDupN:
+        case ByteCodeInstruction::kLoadSwizzle:
+        case ByteCodeInstruction::kLoadSwizzleGlobal:
+            return count;
+
+        // Pushes 'count' values, minus one for the 'address' that's consumed first
+        case ByteCodeInstruction::kLoadExtended:
+        case ByteCodeInstruction::kLoadExtendedGlobal:
+            return count - 1;
+
+        // Ops that pop or store data to shrink the stack:
+        case ByteCodeInstruction::kPop:
+        case ByteCodeInstruction::kStore:
+        case ByteCodeInstruction::kStoreGlobal:
+        case ByteCodeInstruction::kWriteExternal:
+            return -1;
+
+        case ByteCodeInstruction::kPop2:
+        case ByteCodeInstruction::kStore2:
+        case ByteCodeInstruction::kStoreGlobal2:
+        case ByteCodeInstruction::kWriteExternal2:
+            return -2;
+
+        case ByteCodeInstruction::kPop3:
+        case ByteCodeInstruction::kStore3:
+        case ByteCodeInstruction::kStoreGlobal3:
+        case ByteCodeInstruction::kWriteExternal3:
+            return -3;
+
+        case ByteCodeInstruction::kPop4:
+        case ByteCodeInstruction::kStore4:
+        case ByteCodeInstruction::kStoreGlobal4:
+        case ByteCodeInstruction::kWriteExternal4:
+            return -4;
+
+        case ByteCodeInstruction::kPopN:
+        case ByteCodeInstruction::kStoreSwizzle:
+        case ByteCodeInstruction::kStoreSwizzleGlobal:
+            return -count;
+
+        // Consumes 'count' values, plus one for the 'address'
+        case ByteCodeInstruction::kStoreExtended:
+        case ByteCodeInstruction::kStoreExtendedGlobal:
+        case ByteCodeInstruction::kStoreSwizzleIndirect:
+        case ByteCodeInstruction::kStoreSwizzleIndirectGlobal:
+            return -count - 1;
+
+        // Strange ops where the caller computes the delta for us:
+        case ByteCodeInstruction::kCallExternal:
+        case ByteCodeInstruction::kMatrixToMatrix:
+        case ByteCodeInstruction::kMatrixMultiply:
+        case ByteCodeInstruction::kReserve:
+        case ByteCodeInstruction::kReturn:
+        case ByteCodeInstruction::kScalarToMatrix:
+        case ByteCodeInstruction::kSwizzle:
+            return count;
+
+        // Miscellaneous
+
+        // kCall is net-zero. Max stack depth is adjusted in writeFunctionCall.
+        case ByteCodeInstruction::kCall:             return 0;
+        case ByteCodeInstruction::kBranch:           return 0;
+        case ByteCodeInstruction::kBranchIfAllFalse: return 0;
+
+        case ByteCodeInstruction::kMaskPush:         return -1;
+        case ByteCodeInstruction::kMaskPop:          return 0;
+        case ByteCodeInstruction::kMaskNegate:       return 0;
+        case ByteCodeInstruction::kMaskBlend:        return -count;
+
+        case ByteCodeInstruction::kLoopBegin:        return 0;
+        case ByteCodeInstruction::kLoopNext:         return 0;
+        case ByteCodeInstruction::kLoopMask:         return -1;
+        case ByteCodeInstruction::kLoopEnd:          return 0;
+        case ByteCodeInstruction::kLoopBreak:        return 0;
+        case ByteCodeInstruction::kLoopContinue:     return 0;
+
+        default:
+            SkDEBUGFAILF("unsupported instruction %d\n", (int)inst);
+            return 0;
+    }
+}
+
 int ByteCodeGenerator::getLocation(const Variable& var) {
     // given that we seldom have more than a couple of variables, linear search is probably the most
     // efficient way to handle lookups
@@ -327,7 +541,7 @@
     memcpy(fCode->data() + n, &i, 4);
 }
 
-void ByteCodeGenerator::write(ByteCodeInstruction i) {
+void ByteCodeGenerator::write(ByteCodeInstruction i, int count) {
     switch (i) {
         case ByteCodeInstruction::kLoopBegin: this->enterLoop();      break;
         case ByteCodeInstruction::kLoopEnd:   this->exitLoop();       break;
@@ -338,6 +552,8 @@
         default: /* Do nothing */ break;
     }
     this->write16((uint16_t)i);
+    fStackCount += StackUsage(i, count);
+    fMaxStackCount = std::max(fMaxStackCount, fStackCount);
 }
 
 static ByteCodeInstruction vector_instruction(ByteCodeInstruction base, int count) {
@@ -357,7 +573,7 @@
             break;
         case TypeCategory::kFloat: {
             if (count > 4) {
-                this->write((ByteCodeInstruction)((int)f + 4));
+                this->write((ByteCodeInstruction)((int)f + 4), count);
                 this->write8(count);
             } else {
                 this->write(vector_instruction(f, count));
@@ -405,7 +621,8 @@
     // Special case for M*V, V*M, M*M (but not V*V!)
     if (op == Token::Kind::STAR && lVecOrMtx && rVecOrMtx &&
         !(lType.kind() == Type::kVector_Kind && rType.kind() == Type::kVector_Kind)) {
-        this->write(ByteCodeInstruction::kMatrixMultiply);
+        this->write(ByteCodeInstruction::kMatrixMultiply,
+                    SlotCount(b.fType) - (SlotCount(lType) + SlotCount(rType)));
         int rCols = rType.columns(),
             rRows = rType.rows(),
             lCols = lType.columns(),
@@ -559,7 +776,8 @@
             }
         }
         if (inType.kind() == Type::kMatrix_Kind && outType.kind() == Type::kMatrix_Kind) {
-            this->write(ByteCodeInstruction::kMatrixToMatrix);
+            this->write(ByteCodeInstruction::kMatrixToMatrix,
+                        SlotCount(outType) - SlotCount(inType));
             this->write8(inType.columns());
             this->write8(inType.rows());
             this->write8(outType.columns());
@@ -567,7 +785,7 @@
         } else if (inCount != outCount) {
             SkASSERT(inCount == 1);
             if (outType.kind() == Type::kMatrix_Kind) {
-                this->write(ByteCodeInstruction::kScalarToMatrix);
+                this->write(ByteCodeInstruction::kScalarToMatrix, SlotCount(outType) - 1);
                 this->write8(outType.columns());
                 this->write8(outType.rows());
             } else {
@@ -586,7 +804,7 @@
         this->writeExpression(*arg);
         argumentCount += SlotCount(arg->fType);
     }
-    this->write(ByteCodeInstruction::kCallExternal);
+    this->write(ByteCodeInstruction::kCallExternal, SlotCount(f.fType) - argumentCount);
     SkASSERT(argumentCount <= 255);
     this->write8(argumentCount);
     this->write8(SlotCount(f.fType));
@@ -616,7 +834,8 @@
             this->write32(location);
         }
         this->write(isGlobal ? ByteCodeInstruction::kLoadExtendedGlobal
-                             : ByteCodeInstruction::kLoadExtended);
+                             : ByteCodeInstruction::kLoadExtended,
+                    count);
         this->write8(count);
     } else {
         this->write(vector_instruction(isGlobal ? ByteCodeInstruction::kLoadGlobal
@@ -643,13 +862,13 @@
         fErrors.error(c.fOffset, "unsupported intrinsic function");
         return;
     }
+    int count = SlotCount(c.fArguments[0]->fType);
     if (found->second.fIsSpecial) {
         SkASSERT(found->second.fValue.fSpecial == SpecialIntrinsic::kDot);
         SkASSERT(c.fArguments.size() == 2);
-        SkASSERT(SlotCount(c.fArguments[0]->fType) == SlotCount(c.fArguments[1]->fType));
-        this->write((ByteCodeInstruction) ((int) ByteCodeInstruction::kMultiplyF +
-                    SlotCount(c.fArguments[0]->fType) - 1));
-        for (int i = SlotCount(c.fArguments[0]->fType); i > 1; --i) {
+        SkASSERT(count == SlotCount(c.fArguments[1]->fType));
+        this->write((ByteCodeInstruction)((int)ByteCodeInstruction::kMultiplyF + count - 1));
+        for (int i = count; i > 1; --i) {
             this->write(ByteCodeInstruction::kAddF);
         }
     } else {
@@ -661,7 +880,7 @@
             case ByteCodeInstruction::kTan:
                 SkASSERT(c.fArguments.size() > 0);
                 this->write((ByteCodeInstruction) ((int) found->second.fValue.fInstruction +
-                            SlotCount(c.fArguments[0]->fType) - 1));
+                            count - 1));
                 break;
             case ByteCodeInstruction::kCross:
                 this->write(found->second.fValue.fInstruction);
@@ -700,7 +919,7 @@
 
     // We may need to deal with out parameters, so the sequence is tricky
     if (int returnCount = SlotCount(f.fType)) {
-        this->write(ByteCodeInstruction::kReserve);
+        this->write(ByteCodeInstruction::kReserve, returnCount);
         this->write8(returnCount);
     }
 
@@ -717,12 +936,16 @@
         }
     }
 
+    // The space used by the call is based on the callee, but it also unwinds all of that before
+    // we continue execution. We adjust our max stack depths below.
     this->write(ByteCodeInstruction::kCall);
     this->write8(idx);
 
     const ByteCodeFunction* callee = fOutput->fFunctions[idx].get();
     fMaxLoopCount      = std::max(fMaxLoopCount,      fLoopCount      + callee->fLoopCount);
     fMaxConditionCount = std::max(fMaxConditionCount, fConditionCount + callee->fConditionCount);
+    fMaxStackCount     = std::max(fMaxStackCount,     fStackCount     + callee->fLocalCount
+                                                                      + callee->fStackCount);
 
     // After the called function returns, the stack will still contain our arguments. We have to
     // pop them (storing any out parameters back to their lvalues as we go). We glob together slot
@@ -730,7 +953,7 @@
     int popCount = 0;
     auto pop = [&]() {
         if (popCount > 4) {
-            this->write(ByteCodeInstruction::kPopN);
+            this->write(ByteCodeInstruction::kPopN, popCount);
             this->write8(popCount);
         } else if (popCount > 0) {
             this->write(vector_instruction(ByteCodeInstruction::kPop, popCount));
@@ -850,7 +1073,8 @@
             const Variable& var = ((VariableReference&) *s.fBase).fVariable;
             this->write(var.fStorage == Variable::kGlobal_Storage
                             ? ByteCodeInstruction::kLoadSwizzleGlobal
-                            : ByteCodeInstruction::kLoadSwizzle);
+                            : ByteCodeInstruction::kLoadSwizzle,
+                        s.fComponents.size());
             this->write8(this->getLocation(var));
             this->write8(s.fComponents.size());
             for (int c : s.fComponents) {
@@ -860,7 +1084,8 @@
         }
         default:
             this->writeExpression(*s.fBase);
-            this->write(ByteCodeInstruction::kSwizzle);
+            this->write(ByteCodeInstruction::kSwizzle,
+                        s.fComponents.size() - s.fBase->fType.columns());
             this->write8(s.fBase->fType.columns());
             this->write8(s.fComponents.size());
             for (int c : s.fComponents) {
@@ -870,13 +1095,17 @@
 }
 
 void ByteCodeGenerator::writeTernaryExpression(const TernaryExpression& t) {
+    int count = SlotCount(t.fType);
+    SkASSERT(count == SlotCount(t.fIfTrue->fType));
+    SkASSERT(count == SlotCount(t.fIfFalse->fType));
+
     this->writeExpression(*t.fTest);
     this->write(ByteCodeInstruction::kMaskPush);
     this->writeExpression(*t.fIfTrue);
     this->write(ByteCodeInstruction::kMaskNegate);
     this->writeExpression(*t.fIfFalse);
-    this->write(ByteCodeInstruction::kMaskBlend);
-    this->write8(SlotCount(t.fType));
+    this->write(ByteCodeInstruction::kMaskBlend, count);
+    this->write8(count);
 }
 
 void ByteCodeGenerator::writeExpression(const Expression& e, bool discard) {
@@ -932,7 +1161,7 @@
     if (discard) {
         int count = SlotCount(e.fType);
         if (count > 4) {
-            this->write(ByteCodeInstruction::kPopN);
+            this->write(ByteCodeInstruction::kPopN, count);
             this->write8(count);
         } else if (count != 0) {
             this->write(vector_instruction(ByteCodeInstruction::kPop, count));
@@ -980,22 +1209,24 @@
     }
 
     void store(bool discard) override {
+        int count = fSwizzle.fComponents.size();
         if (!discard) {
-            fGenerator.write(vector_instruction(ByteCodeInstruction::kDup,
-                                                fSwizzle.fComponents.size()));
+            fGenerator.write(vector_instruction(ByteCodeInstruction::kDup, count));
         }
         Variable::Storage storage;
         int location = fGenerator.getLocation(*fSwizzle.fBase, &storage);
         bool isGlobal = storage == Variable::kGlobal_Storage;
         if (location < 0) {
             fGenerator.write(isGlobal ? ByteCodeInstruction::kStoreSwizzleIndirectGlobal
-                                      : ByteCodeInstruction::kStoreSwizzleIndirect);
+                                      : ByteCodeInstruction::kStoreSwizzleIndirect,
+                             count);
         } else {
             fGenerator.write(isGlobal ? ByteCodeInstruction::kStoreSwizzleGlobal
-                                      : ByteCodeInstruction::kStoreSwizzle);
+                                      : ByteCodeInstruction::kStoreSwizzle,
+                             count);
             fGenerator.write8(location);
         }
-        fGenerator.write8(fSwizzle.fComponents.size());
+        fGenerator.write8(count);
         for (int c : fSwizzle.fComponents) {
             fGenerator.write8(c);
         }
@@ -1021,7 +1252,7 @@
         int count = ByteCodeGenerator::SlotCount(fExpression.fType);
         if (!discard) {
             if (count > 4) {
-                fGenerator.write(ByteCodeInstruction::kDupN);
+                fGenerator.write(ByteCodeInstruction::kDupN, count);
                 fGenerator.write8(count);
             } else {
                 fGenerator.write(vector_instruction(ByteCodeInstruction::kDup, count));
@@ -1036,7 +1267,8 @@
                 fGenerator.write32(location);
             }
             fGenerator.write(isGlobal ? ByteCodeInstruction::kStoreExtendedGlobal
-                                      : ByteCodeInstruction::kStoreExtended);
+                                      : ByteCodeInstruction::kStoreExtended,
+                             count);
             fGenerator.write8(count);
         } else {
             fGenerator.write(vector_instruction(isGlobal ? ByteCodeInstruction::kStoreGlobal
@@ -1173,9 +1405,16 @@
         fErrors.error(r.fOffset, "return not allowed inside conditional or loop");
         return;
     }
+    int count = SlotCount(r.fExpression->fType);
     this->writeExpression(*r.fExpression);
-    this->write(ByteCodeInstruction::kReturn);
-    this->write8(SlotCount(r.fExpression->fType));
+
+    // Technically, the kReturn also pops fOutput->fLocalCount values from the stack, but we
+    // haven't counted pushing those (they're outside the scope of our stack tracking). Instead,
+    // we account for those in writeFunction().
+
+    // This is all fine because we don't allow conditional returns, so we only return once anyway.
+    this->write(ByteCodeInstruction::kReturn, -count);
+    this->write8(count);
 }
 
 void ByteCodeGenerator::writeSwitchStatement(const SwitchStatement& r) {
@@ -1195,7 +1434,7 @@
             if (count > 4) {
                 this->write(ByteCodeInstruction::kPushImmediate);
                 this->write32(location);
-                this->write(ByteCodeInstruction::kStoreExtended);
+                this->write(ByteCodeInstruction::kStoreExtended, count);
                 this->write8(count);
             } else {
                 this->write(vector_instruction(ByteCodeInstruction::kStore, count));