Interpreter: Matrix/Vector multiplication

Change-Id: I3dc5e5be1cf12c581cce3854d0db7e73db6e1fd9
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/216681
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLByteCodeGenerator.cpp b/src/sksl/SkSLByteCodeGenerator.cpp
index ff32bc6..b9aac93 100644
--- a/src/sksl/SkSLByteCodeGenerator.cpp
+++ b/src/sksl/SkSLByteCodeGenerator.cpp
@@ -342,6 +342,8 @@
     }
     const Type& lType = b.fLeft->fType;
     const Type& rType = b.fRight->fType;
+    bool lVecOrMtx = (lType.kind() == Type::kVector_Kind || lType.kind() == Type::kMatrix_Kind);
+    bool rVecOrMtx = (rType.kind() == Type::kVector_Kind || rType.kind() == Type::kMatrix_Kind);
     Token::Kind op;
     std::unique_ptr<LValue> lvalue;
     if (is_assignment(b.fOperator)) {
@@ -351,98 +353,115 @@
     } else {
         this->writeExpression(*b.fLeft);
         op = b.fOperator;
-        if (lType.kind() == Type::kScalar_Kind &&
-            (rType.kind() == Type::kVector_Kind || rType.kind() == Type::kMatrix_Kind)) {
+        if (!lVecOrMtx && rVecOrMtx) {
             for (int i = SlotCount(rType); i > 1; --i) {
                 this->write(ByteCodeInstruction::kDup);
             }
         }
     }
     this->writeExpression(*b.fRight);
-    if ((lType.kind() == Type::kVector_Kind || lType.kind() == Type::kMatrix_Kind) &&
-        rType.kind() == Type::kScalar_Kind) {
+    if (lVecOrMtx && !rVecOrMtx) {
         for (int i = SlotCount(lType); i > 1; --i) {
             this->write(ByteCodeInstruction::kDup);
         }
     }
-    int count = std::max(SlotCount(lType), SlotCount(rType));
-    switch (op) {
-        case Token::Kind::EQEQ:
-            this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareIEQ,
-                                        ByteCodeInstruction::kCompareIEQ,
-                                        ByteCodeInstruction::kCompareFEQ,
-                                        count);
-            // Collapse to a single bool
-            for (int i = count; i > 1; --i) {
-                this->write(ByteCodeInstruction::kAndB);
-            }
-            break;
-        case Token::Kind::GT:
-            this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSGT,
-                                        ByteCodeInstruction::kCompareUGT,
-                                        ByteCodeInstruction::kCompareFGT,
-                                        count);
-            break;
-        case Token::Kind::GTEQ:
-            this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSGTEQ,
-                                        ByteCodeInstruction::kCompareUGTEQ,
-                                        ByteCodeInstruction::kCompareFGTEQ,
-                                        count);
-            break;
-        case Token::Kind::LT:
-            this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSLT,
-                                        ByteCodeInstruction::kCompareULT,
-                                        ByteCodeInstruction::kCompareFLT,
-                                        count);
-            break;
-        case Token::Kind::LTEQ:
-            this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSLTEQ,
-                                        ByteCodeInstruction::kCompareULTEQ,
-                                        ByteCodeInstruction::kCompareFLTEQ,
-                                        count);
-            break;
-        case Token::Kind::MINUS:
-            this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kSubtractI,
-                                        ByteCodeInstruction::kSubtractI,
-                                        ByteCodeInstruction::kSubtractF,
-                                        count);
-            break;
-        case Token::Kind::NEQ:
-            this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareINEQ,
-                                        ByteCodeInstruction::kCompareINEQ,
-                                        ByteCodeInstruction::kCompareFNEQ,
-                                        count);
-            // Collapse to a single bool
-            for (int i = count; i > 1; --i) {
-                this->write(ByteCodeInstruction::kOrB);
-            }
-            break;
-        case Token::Kind::PERCENT:
-            this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kRemainderS,
-                                        ByteCodeInstruction::kRemainderU,
-                                        ByteCodeInstruction::kRemainderF,
-                                        count);
-            break;
-        case Token::Kind::PLUS:
-            this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kAddI,
-                                        ByteCodeInstruction::kAddI,
-                                        ByteCodeInstruction::kAddF,
-                                        count);
-            break;
-        case Token::Kind::SLASH:
-            this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kDivideS,
-                                        ByteCodeInstruction::kDivideU,
-                                        ByteCodeInstruction::kDivideF,
-                                        count);
-            break;
-        case Token::Kind::STAR:
-            this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kMultiplyI,
-                                        ByteCodeInstruction::kMultiplyI,
-                                        ByteCodeInstruction::kMultiplyF,
-                                        count);
-            break;
-        default:
-            SkASSERT(false);
+    // Special case for M*V, V*M, M*M (but not V*V!)
+    if (op == Token::Kind::STAR && lVecOrMtx && rVecOrMtx &&
+        !(lType.kind() == Type::kVector_Kind && rType.kind() == Type::kVector_Kind)) {
+        this->write(ByteCodeInstruction::kMatrixMultiply);
+        int rCols = rType.columns(),
+            rRows = rType.rows(),
+            lCols = lType.columns(),
+            lRows = lType.rows();
+        // M*V treats the vector as a column
+        if (rType.kind() == Type::kVector_Kind) {
+            std::swap(rCols, rRows);
+        }
+        SkASSERT(lCols == rRows);
+        SkASSERT(SlotCount(b.fType) == lRows * rCols);
+        this->write8(lCols);
+        this->write8(lRows);
+        this->write8(rCols);
+    } else {
+        int count = std::max(SlotCount(lType), SlotCount(rType));
+        switch (op) {
+            case Token::Kind::EQEQ:
+                this->writeTypedInstruction(lType, ByteCodeInstruction::kCompareIEQ,
+                                            ByteCodeInstruction::kCompareIEQ,
+                                            ByteCodeInstruction::kCompareFEQ,
+                                            count);
+                // Collapse to a single bool
+                for (int i = count; i > 1; --i) {
+                    this->write(ByteCodeInstruction::kAndB);
+                }
+                break;
+            case Token::Kind::GT:
+                this->writeTypedInstruction(lType, ByteCodeInstruction::kCompareSGT,
+                                            ByteCodeInstruction::kCompareUGT,
+                                            ByteCodeInstruction::kCompareFGT,
+                                            count);
+                break;
+            case Token::Kind::GTEQ:
+                this->writeTypedInstruction(lType, ByteCodeInstruction::kCompareSGTEQ,
+                                            ByteCodeInstruction::kCompareUGTEQ,
+                                            ByteCodeInstruction::kCompareFGTEQ,
+                                            count);
+                break;
+            case Token::Kind::LT:
+                this->writeTypedInstruction(lType, ByteCodeInstruction::kCompareSLT,
+                                            ByteCodeInstruction::kCompareULT,
+                                            ByteCodeInstruction::kCompareFLT,
+                                            count);
+                break;
+            case Token::Kind::LTEQ:
+                this->writeTypedInstruction(lType, ByteCodeInstruction::kCompareSLTEQ,
+                                            ByteCodeInstruction::kCompareULTEQ,
+                                            ByteCodeInstruction::kCompareFLTEQ,
+                                            count);
+                break;
+            case Token::Kind::MINUS:
+                this->writeTypedInstruction(lType, ByteCodeInstruction::kSubtractI,
+                                            ByteCodeInstruction::kSubtractI,
+                                            ByteCodeInstruction::kSubtractF,
+                                            count);
+                break;
+            case Token::Kind::NEQ:
+                this->writeTypedInstruction(lType, ByteCodeInstruction::kCompareINEQ,
+                                            ByteCodeInstruction::kCompareINEQ,
+                                            ByteCodeInstruction::kCompareFNEQ,
+                                            count);
+                // Collapse to a single bool
+                for (int i = count; i > 1; --i) {
+                    this->write(ByteCodeInstruction::kOrB);
+                }
+                break;
+            case Token::Kind::PERCENT:
+                this->writeTypedInstruction(lType, ByteCodeInstruction::kRemainderS,
+                                            ByteCodeInstruction::kRemainderU,
+                                            ByteCodeInstruction::kRemainderF,
+                                            count);
+                break;
+            case Token::Kind::PLUS:
+                this->writeTypedInstruction(lType, ByteCodeInstruction::kAddI,
+                                            ByteCodeInstruction::kAddI,
+                                            ByteCodeInstruction::kAddF,
+                                            count);
+                break;
+            case Token::Kind::SLASH:
+                this->writeTypedInstruction(lType, ByteCodeInstruction::kDivideS,
+                                            ByteCodeInstruction::kDivideU,
+                                            ByteCodeInstruction::kDivideF,
+                                            count);
+                break;
+            case Token::Kind::STAR:
+                this->writeTypedInstruction(lType, ByteCodeInstruction::kMultiplyI,
+                                            ByteCodeInstruction::kMultiplyI,
+                                            ByteCodeInstruction::kMultiplyF,
+                                            count);
+                break;
+            default:
+                SkASSERT(false);
+        }
     }
     if (lvalue) {
         lvalue->store();