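redesigned SkSL interpreter vector instructions

Instead of prefixing an instruction with kVector and a count byte, the generator
now emits a width-specific opcode, computed by vector_instruction() as the
scalar opcode plus (count - 1). kPop likewise encodes its count in the opcode
rather than in a trailing byte. This relies on each vectorizable instruction's
wider variants being declared immediately after it in the ByteCodeInstruction
enum.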

Bug: skia:
Change-Id: I7737eacdb5acd6b19d95fce7ee76945f0f9d0d7e
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/214221
Reviewed-by: Brian Osman <brianosman@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLByteCodeGenerator.cpp b/src/sksl/SkSLByteCodeGenerator.cpp
index 74464c2..52df945 100644
--- a/src/sksl/SkSLByteCodeGenerator.cpp
+++ b/src/sksl/SkSLByteCodeGenerator.cpp
@@ -193,17 +193,27 @@
     this->write8((uint8_t) i);
 }
 
+// Maps a scalar opcode to its `count`-wide variant. Relies on each vectorizable
+// ByteCodeInstruction's variants being declared consecutively after the scalar opcode;
+// the bound of 4 assumes 1- to 4-wide variants exist (confirm against the enum). Any
+// other count would silently select an unrelated opcode, hence the assert.
+static ByteCodeInstruction vector_instruction(ByteCodeInstruction base, int count) {
+    SkASSERT(count >= 1 && count <= 4);
+    return (ByteCodeInstruction) ((int) base + count - 1);
+}
+
 void ByteCodeGenerator::writeTypedInstruction(const Type& type, ByteCodeInstruction s,
-                                              ByteCodeInstruction u, ByteCodeInstruction f) {
+                                              ByteCodeInstruction u, ByteCodeInstruction f,
+                                              int count) {
     switch (type_category(type)) {
         case TypeCategory::kSigned:
-            this->write(s);
+            this->write(vector_instruction(s, count));
             break;
         case TypeCategory::kUnsigned:
-            this->write(u);
+            this->write(vector_instruction(u, count));
             break;
         case TypeCategory::kFloat:
-            this->write(f);
+            this->write(vector_instruction(f, count));
             break;
         default:
             SkASSERT(false);
@@ -241,65 +246,72 @@
         }
     }
     int count = slot_count(b.fType);
-    if (count > 1) {
-        this->write(ByteCodeInstruction::kVector);
-        this->write8(count);
-    }
     switch (op) {
         case Token::Kind::EQEQ:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareIEQ,
                                         ByteCodeInstruction::kCompareIEQ,
-                                        ByteCodeInstruction::kCompareFEQ);
+                                        ByteCodeInstruction::kCompareFEQ,
+                                        count);
             break;
         case Token::Kind::GT:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSGT,
                                         ByteCodeInstruction::kCompareUGT,
-                                        ByteCodeInstruction::kCompareFGT);
+                                        ByteCodeInstruction::kCompareFGT,
+                                        count);
             break;
         case Token::Kind::GTEQ:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSGTEQ,
                                         ByteCodeInstruction::kCompareUGTEQ,
-                                        ByteCodeInstruction::kCompareFGTEQ);
+                                        ByteCodeInstruction::kCompareFGTEQ,
+                                        count);
             break;
         case Token::Kind::LT:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSLT,
                                         ByteCodeInstruction::kCompareULT,
-                                        ByteCodeInstruction::kCompareFLT);
+                                        ByteCodeInstruction::kCompareFLT,
+                                        count);
             break;
         case Token::Kind::LTEQ:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSLTEQ,
                                         ByteCodeInstruction::kCompareULTEQ,
-                                        ByteCodeInstruction::kCompareFLTEQ);
+                                        ByteCodeInstruction::kCompareFLTEQ,
+                                        count);
             break;
         case Token::Kind::MINUS:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kSubtractI,
                                         ByteCodeInstruction::kSubtractI,
-                                        ByteCodeInstruction::kSubtractF);
+                                        ByteCodeInstruction::kSubtractF,
+                                        count);
             break;
         case Token::Kind::NEQ:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareINEQ,
                                         ByteCodeInstruction::kCompareINEQ,
-                                        ByteCodeInstruction::kCompareFNEQ);
+                                        ByteCodeInstruction::kCompareFNEQ,
+                                        count);
             break;
         case Token::Kind::PERCENT:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kRemainderS,
                                         ByteCodeInstruction::kRemainderU,
-                                        ByteCodeInstruction::kRemainderF);
+                                        ByteCodeInstruction::kRemainderF,
+                                        count);
             break;
         case Token::Kind::PLUS:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kAddI,
                                         ByteCodeInstruction::kAddI,
-                                        ByteCodeInstruction::kAddF);
+                                        ByteCodeInstruction::kAddF,
+                                        count);
             break;
         case Token::Kind::SLASH:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kDivideS,
                                         ByteCodeInstruction::kDivideU,
-                                        ByteCodeInstruction::kDivideF);
+                                        ByteCodeInstruction::kDivideF,
+                                        count);
             break;
         case Token::Kind::STAR:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kMultiplyI,
                                         ByteCodeInstruction::kMultiplyI,
-                                        ByteCodeInstruction::kMultiplyF);
+                                        ByteCodeInstruction::kMultiplyF,
+                                        count);
             break;
         default:
             SkASSERT(false);
@@ -329,21 +341,19 @@
         TypeCategory inCategory = type_category(c.fArguments[0]->fType);
         TypeCategory outCategory = type_category(c.fType);
         if (inCategory != outCategory) {
-            int count = c.fType.columns();
-            if (count > 1) {
-                this->write(ByteCodeInstruction::kVector);
-                this->write8(count);
-            }
             if (inCategory == TypeCategory::kFloat) {
                 SkASSERT(outCategory == TypeCategory::kSigned ||
                          outCategory == TypeCategory::kUnsigned);
-                this->write(ByteCodeInstruction::kFloatToInt);
+                this->write(vector_instruction(ByteCodeInstruction::kFloatToInt,
+                                               c.fType.columns()));
             } else if (outCategory == TypeCategory::kFloat) {
                 if (inCategory == TypeCategory::kSigned) {
-                    this->write(ByteCodeInstruction::kSignedToFloat);
+                    this->write(vector_instruction(ByteCodeInstruction::kSignedToFloat,
+                                                   c.fType.columns()));
                 } else {
                     SkASSERT(inCategory == TypeCategory::kUnsigned);
-                    this->write(ByteCodeInstruction::kUnsignedToFloat);
+                    this->write(vector_instruction(ByteCodeInstruction::kUnsignedToFloat,
+                                                   c.fType.columns()));
                 }
             } else {
                 SkASSERT(false);
@@ -353,12 +363,8 @@
 }
 
 void ByteCodeGenerator::writeExternalValue(const ExternalValueReference& e) {
-    int count = slot_count(e.fValue->type());
-    if (count > 1) {
-        this->write(ByteCodeInstruction::kVector);
-        this->write8(count);
-    }
-    this->write(ByteCodeInstruction::kReadExternal);
+    this->write(vector_instruction(ByteCodeInstruction::kReadExternal,
+                                   slot_count(e.fValue->type())));
     int index = fOutput->fExternalValues.size();
     fOutput->fExternalValues.push_back(e.fValue);
     SkASSERT(index <= 255);
@@ -410,31 +416,30 @@
             this->align(4, 3);
             this->write(ByteCodeInstruction::kPushImmediate);
             this->write32(1);
+            SkASSERT(slot_count(p.fOperand->fType) == 1);
             if (p.fOperator == Token::Kind::PLUSPLUS) {
                 this->writeTypedInstruction(p.fType,
                                             ByteCodeInstruction::kAddI,
                                             ByteCodeInstruction::kAddI,
-                                            ByteCodeInstruction::kAddF);
+                                            ByteCodeInstruction::kAddF,
+                                            1);
             } else {
                 this->writeTypedInstruction(p.fType,
                                             ByteCodeInstruction::kSubtractI,
                                             ByteCodeInstruction::kSubtractI,
-                                            ByteCodeInstruction::kSubtractF);
+                                            ByteCodeInstruction::kSubtractF,
+                                            1);
             }
             lvalue->store();
             break;
         }
         case Token::Kind::MINUS: {
             this->writeExpression(*p.fOperand);
-            int count = slot_count(p.fOperand->fType);
-            if (count > 1) {
-                this->write(ByteCodeInstruction::kVector);
-                this->write8(count);
-            }
             this->writeTypedInstruction(p.fType,
                                         ByteCodeInstruction::kNegateS,
                                         ByteCodeInstruction::kInvalid,
-                                        ByteCodeInstruction::kNegateF);
+                                        ByteCodeInstruction::kNegateF,
+                                        slot_count(p.fOperand->fType));
             break;
         }
         default:
@@ -473,14 +478,11 @@
 }
 
 void ByteCodeGenerator::writeVariableReference(const VariableReference& v) {
-    int count = slot_count(v.fType);
-    if (count > 1) {
-        this->write(ByteCodeInstruction::kVector);
-        this->write8(count);
-    }
-    this->write(v.fVariable.fStorage == Variable::kGlobal_Storage
-                    ? ByteCodeInstruction::kLoadGlobal
-                    : ByteCodeInstruction::kLoad);
+    // Pick the global or local load, then widen it to the variable's slot count.
+    ByteCodeInstruction load = v.fVariable.fStorage == Variable::kGlobal_Storage
+            ? ByteCodeInstruction::kLoadGlobal
+            : ByteCodeInstruction::kLoad;
+    this->write(vector_instruction(load, slot_count(v.fType)));
     this->write8(this->getLocation(v.fVariable));
 }
 
@@ -550,25 +551,13 @@
         , fIndex(index) {}
 
     void load() override {
-        if (fCount > 1) {
-            fGenerator.write(ByteCodeInstruction::kVector);
-            fGenerator.write8(fCount);
-        }
-        fGenerator.write(ByteCodeInstruction::kReadExternal);
+        fGenerator.write(vector_instruction(ByteCodeInstruction::kReadExternal, fCount));
         fGenerator.write8(fIndex);
     }
 
     void store() override {
-        if (fCount > 1) {
-            fGenerator.write(ByteCodeInstruction::kVector);
-            fGenerator.write8(fCount);
-        }
-        fGenerator.write(ByteCodeInstruction::kDup);
-        if (fCount > 1) {
-            fGenerator.write(ByteCodeInstruction::kVector);
-            fGenerator.write8(fCount);
-        }
-        fGenerator.write(ByteCodeInstruction::kWriteExternal);
+        fGenerator.write(vector_instruction(ByteCodeInstruction::kDup, fCount));
+        fGenerator.write(vector_instruction(ByteCodeInstruction::kWriteExternal, fCount));
         fGenerator.write8(fIndex);
     }
 
@@ -594,11 +583,8 @@
 
     void store() override {
         const Variable& var = ((VariableReference&)*fSwizzle.fBase).fVariable;
-        if (fSwizzle.fComponents.size() > 1) {
-            fGenerator.write(ByteCodeInstruction::kVector);
-            fGenerator.write8(fSwizzle.fComponents.size());
-        }
-        fGenerator.write(ByteCodeInstruction::kDup);
+        fGenerator.write(vector_instruction(ByteCodeInstruction::kDup,
+                                            fSwizzle.fComponents.size()));
         fGenerator.write(var.fStorage == Variable::kGlobal_Storage
                             ? ByteCodeInstruction::kStoreSwizzleGlobal
                             : ByteCodeInstruction::kStoreSwizzle);
@@ -625,27 +611,17 @@
     }
 
     void load() override {
-        if (fCount > 1) {
-            fGenerator.write(ByteCodeInstruction::kVector);
-            fGenerator.write8(fCount);
-        }
-        fGenerator.write(fIsGlobal ? ByteCodeInstruction::kLoadGlobal
-                                   : ByteCodeInstruction::kLoad);
+        fGenerator.write(vector_instruction(fIsGlobal ? ByteCodeInstruction::kLoadGlobal
+                                                      : ByteCodeInstruction::kLoad,
+                                            fCount));
         fGenerator.write8(fLocation);
     }
 
     void store() override {
-        if (fCount > 1) {
-            fGenerator.write(ByteCodeInstruction::kVector);
-            fGenerator.write8(fCount);
-        }
-        fGenerator.write(ByteCodeInstruction::kDup);
-        if (fCount > 1) {
-            fGenerator.write(ByteCodeInstruction::kVector);
-            fGenerator.write8(fCount);
-        }
-        fGenerator.write(fIsGlobal ? ByteCodeInstruction::kStoreGlobal
-                                   : ByteCodeInstruction::kStore);
+        fGenerator.write(vector_instruction(ByteCodeInstruction::kDup, fCount));
+        fGenerator.write(vector_instruction(fIsGlobal ? ByteCodeInstruction::kStoreGlobal
+                                                      : ByteCodeInstruction::kStore,
+                                            fCount));
         fGenerator.write8(fLocation);
     }
 
@@ -745,8 +721,11 @@
         this->setContinueTargets();
         if (f.fNext) {
             this->writeExpression(*f.fNext);
-            this->write(ByteCodeInstruction::kPop);
-            this->write8(slot_count(f.fNext->fType));
+            // Pop in chunks of at most four slots (assumed widest kPop variant) and emit
+            // nothing for a zero count, so every count maps to a real kPop opcode.
+            for (int n = slot_count(f.fNext->fType); n > 0; n -= 4) {
+                this->write(vector_instruction(ByteCodeInstruction::kPop, n < 4 ? n : 4));
+            }
         }
         this->align(2, 1);
         this->write(ByteCodeInstruction::kBranch);
@@ -757,8 +732,11 @@
         this->setContinueTargets();
         if (f.fNext) {
             this->writeExpression(*f.fNext);
-            this->write(ByteCodeInstruction::kPop);
-            this->write8(slot_count(f.fNext->fType));
+            // Pop in chunks of at most four slots (assumed widest kPop variant) and emit
+            // nothing for a zero count, so every count maps to a real kPop opcode.
+            for (int n = slot_count(f.fNext->fType); n > 0; n -= 4) {
+                this->write(vector_instruction(ByteCodeInstruction::kPop, n < 4 ? n : 4));
+            }
         }
         this->align(2, 1);
         this->write(ByteCodeInstruction::kBranch);
@@ -803,12 +777,8 @@
         int location = getLocation(*decl.fVar);
         if (decl.fValue) {
             this->writeExpression(*decl.fValue);
-            int count = slot_count(decl.fValue->fType);
-            if (count > 1) {
-                this->write(ByteCodeInstruction::kVector);
-                this->write8(count);
-            }
-            this->write(ByteCodeInstruction::kStore);
+            this->write(vector_instruction(ByteCodeInstruction::kStore,
+                                           slot_count(decl.fValue->fType)));
             this->write8(location);
         }
     }
@@ -852,8 +822,11 @@
         case Statement::kExpression_Kind: {
             const Expression& expr = *((ExpressionStatement&) s).fExpression;
             this->writeExpression(expr);
-            this->write(ByteCodeInstruction::kPop);
-            this->write8(slot_count(expr.fType));
+            // Discard the result in chunks of at most four slots (assumed widest kPop
+            // variant); emit nothing for a zero count so no bogus opcode is produced.
+            for (int n = slot_count(expr.fType); n > 0; n -= 4) {
+                this->write(vector_instruction(ByteCodeInstruction::kPop, n < 4 ? n : 4));
+            }
             break;
         }
         case Statement::kFor_Kind: