Interpreter: Fix vector/matrix equality and inequality

Need to compare all elements, then fold the result to a single bool.

Change-Id: I0ebfaa9d518f29a782701246ada247cb55c01c2e
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/216607
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/SkSLByteCodeGenerator.cpp b/src/sksl/SkSLByteCodeGenerator.cpp
index 9958466..70052b6 100644
--- a/src/sksl/SkSLByteCodeGenerator.cpp
+++ b/src/sksl/SkSLByteCodeGenerator.cpp
@@ -334,6 +334,8 @@
         lvalue->store();
         return;
     }
+    const Type& lType = b.fLeft->fType;
+    const Type& rType = b.fRight->fType;
     Token::Kind op;
     std::unique_ptr<LValue> lvalue;
     if (is_assignment(b.fOperator)) {
@@ -343,27 +345,31 @@
     } else {
         this->writeExpression(*b.fLeft);
         op = b.fOperator;
-        if (b.fLeft->fType.kind() == Type::kScalar_Kind &&
-            b.fRight->fType.kind() == Type::kVector_Kind) {
-            for (int i = b.fRight->fType.columns(); i > 1; --i) {
+        if (lType.kind() == Type::kScalar_Kind &&
+            (rType.kind() == Type::kVector_Kind || rType.kind() == Type::kMatrix_Kind)) {
+            for (int i = SlotCount(rType); i > 1; --i) {
                 this->write(ByteCodeInstruction::kDup);
             }
         }
     }
     this->writeExpression(*b.fRight);
-    if (b.fLeft->fType.kind() == Type::kVector_Kind &&
-        b.fRight->fType.kind() == Type::kScalar_Kind) {
-        for (int i = b.fLeft->fType.columns(); i > 1; --i) {
+    if ((lType.kind() == Type::kVector_Kind || lType.kind() == Type::kMatrix_Kind) &&
+        rType.kind() == Type::kScalar_Kind) {
+        for (int i = SlotCount(lType); i > 1; --i) {
             this->write(ByteCodeInstruction::kDup);
         }
     }
-    int count = SlotCount(b.fType);
+    int count = SkTMax(SlotCount(lType), SlotCount(rType));
     switch (op) {
         case Token::Kind::EQEQ:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareIEQ,
                                         ByteCodeInstruction::kCompareIEQ,
                                         ByteCodeInstruction::kCompareFEQ,
                                         count);
+            // Collapse to a single bool
+            for (int i = count; i > 1; --i) {
+                this->write(ByteCodeInstruction::kAndB);
+            }
             break;
         case Token::Kind::GT:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSGT,
@@ -400,6 +406,10 @@
                                         ByteCodeInstruction::kCompareINEQ,
                                         ByteCodeInstruction::kCompareFNEQ,
                                         count);
+            // Collapse to a single bool
+            for (int i = count; i > 1; --i) {
+                this->write(ByteCodeInstruction::kOrB);
+            }
             break;
         case Token::Kind::PERCENT:
             this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kRemainderS,
diff --git a/src/sksl/SkSLInterpreter.cpp b/src/sksl/SkSLInterpreter.cpp
index a2ddb99..15d4f6e 100644
--- a/src/sksl/SkSLInterpreter.cpp
+++ b/src/sksl/SkSLInterpreter.cpp
@@ -96,8 +96,8 @@
     switch ((ByteCodeInstruction) READ16()) {
         VECTOR_DISASSEMBLE(kAddF, "addf")
         VECTOR_DISASSEMBLE(kAddI, "addi")
-        case ByteCodeInstruction::kAndB: printf("andb"); break;
-        case ByteCodeInstruction::kAndI: printf("andi"); break;
+        VECTOR_DISASSEMBLE(kAndB, "andb")
+        VECTOR_DISASSEMBLE(kAndI, "andb")
         case ByteCodeInstruction::kBranch: printf("branch %d", READ16()); break;
         case ByteCodeInstruction::kCall: printf("call %d", READ8()); break;
         case ByteCodeInstruction::kCallExternal: {
@@ -361,6 +361,7 @@
         switch (inst) {
             VECTOR_BINARY_OP(kAddI, fSigned, +)
             VECTOR_BINARY_OP(kAddF, fFloat, +)
+            VECTOR_BINARY_OP(kAndB, fBool, &&)
 
             case ByteCodeInstruction::kBranch:
                 ip = code + READ16();
@@ -597,6 +598,8 @@
             case ByteCodeInstruction::kNegateI : sp[ 0] = -sp [0].fSigned;
                                                  break;
 
+            VECTOR_BINARY_OP(kOrB, fBool, ||)
+
             case ByteCodeInstruction::kPop4: POP();
             case ByteCodeInstruction::kPop3: POP();
             case ByteCodeInstruction::kPop2: POP();
diff --git a/tests/SkSLInterpreterTest.cpp b/tests/SkSLInterpreterTest.cpp
index 18a2d61..82b72fb 100644
--- a/tests/SkSLInterpreterTest.cpp
+++ b/tests/SkSLInterpreterTest.cpp
@@ -284,6 +284,17 @@
          "color.a = 2; }", 2, -2, 0, 0, 2, -2, 0, 2);
 }
 
+DEF_TEST(SkSLInterpreterIfVector, r) {
+    test(r, "void main(inout half4 color) { if (color.rg == color.ba) color.a = 1; }",
+         1, 2, 1, 2, 1, 2, 1, 1);
+    test(r, "void main(inout half4 color) { if (color.rg == color.ba) color.a = 1; }",
+         1, 2, 3, 2, 1, 2, 3, 2);
+    test(r, "void main(inout half4 color) { if (color.rg != color.ba) color.a = 1; }",
+         1, 2, 1, 2, 1, 2, 1, 2);
+    test(r, "void main(inout half4 color) { if (color.rg != color.ba) color.a = 1; }",
+         1, 2, 3, 2, 1, 2, 3, 1);
+}
+
 DEF_TEST(SkSLInterpreterWhile, r) {
     test(r, "void main(inout half4 color) { while (color.r < 1) color.r += 0.25; }", 0, 0, 0, 0, 1,
          0, 0, 0);