Interpreter: Support matrix/vector multiplication
Change-Id: I3dc5e5be1cf12c581cce3854d0db7e73db6e1fd9
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/216681
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLByteCodeGenerator.cpp b/src/sksl/SkSLByteCodeGenerator.cpp
index ff32bc6..b9aac93 100644
--- a/src/sksl/SkSLByteCodeGenerator.cpp
+++ b/src/sksl/SkSLByteCodeGenerator.cpp
@@ -342,6 +342,8 @@
}
const Type& lType = b.fLeft->fType;
const Type& rType = b.fRight->fType;
+ bool lVecOrMtx = (lType.kind() == Type::kVector_Kind || lType.kind() == Type::kMatrix_Kind);
+ bool rVecOrMtx = (rType.kind() == Type::kVector_Kind || rType.kind() == Type::kMatrix_Kind);
Token::Kind op;
std::unique_ptr<LValue> lvalue;
if (is_assignment(b.fOperator)) {
@@ -351,98 +353,115 @@
} else {
this->writeExpression(*b.fLeft);
op = b.fOperator;
- if (lType.kind() == Type::kScalar_Kind &&
- (rType.kind() == Type::kVector_Kind || rType.kind() == Type::kMatrix_Kind)) {
+ if (!lVecOrMtx && rVecOrMtx) {
for (int i = SlotCount(rType); i > 1; --i) {
this->write(ByteCodeInstruction::kDup);
}
}
}
this->writeExpression(*b.fRight);
- if ((lType.kind() == Type::kVector_Kind || lType.kind() == Type::kMatrix_Kind) &&
- rType.kind() == Type::kScalar_Kind) {
+ if (lVecOrMtx && !rVecOrMtx) {
for (int i = SlotCount(lType); i > 1; --i) {
this->write(ByteCodeInstruction::kDup);
}
}
- int count = std::max(SlotCount(lType), SlotCount(rType));
- switch (op) {
- case Token::Kind::EQEQ:
- this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareIEQ,
- ByteCodeInstruction::kCompareIEQ,
- ByteCodeInstruction::kCompareFEQ,
- count);
- // Collapse to a single bool
- for (int i = count; i > 1; --i) {
- this->write(ByteCodeInstruction::kAndB);
- }
- break;
- case Token::Kind::GT:
- this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSGT,
- ByteCodeInstruction::kCompareUGT,
- ByteCodeInstruction::kCompareFGT,
- count);
- break;
- case Token::Kind::GTEQ:
- this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSGTEQ,
- ByteCodeInstruction::kCompareUGTEQ,
- ByteCodeInstruction::kCompareFGTEQ,
- count);
- break;
- case Token::Kind::LT:
- this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSLT,
- ByteCodeInstruction::kCompareULT,
- ByteCodeInstruction::kCompareFLT,
- count);
- break;
- case Token::Kind::LTEQ:
- this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareSLTEQ,
- ByteCodeInstruction::kCompareULTEQ,
- ByteCodeInstruction::kCompareFLTEQ,
- count);
- break;
- case Token::Kind::MINUS:
- this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kSubtractI,
- ByteCodeInstruction::kSubtractI,
- ByteCodeInstruction::kSubtractF,
- count);
- break;
- case Token::Kind::NEQ:
- this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kCompareINEQ,
- ByteCodeInstruction::kCompareINEQ,
- ByteCodeInstruction::kCompareFNEQ,
- count);
- // Collapse to a single bool
- for (int i = count; i > 1; --i) {
- this->write(ByteCodeInstruction::kOrB);
- }
- break;
- case Token::Kind::PERCENT:
- this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kRemainderS,
- ByteCodeInstruction::kRemainderU,
- ByteCodeInstruction::kRemainderF,
- count);
- break;
- case Token::Kind::PLUS:
- this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kAddI,
- ByteCodeInstruction::kAddI,
- ByteCodeInstruction::kAddF,
- count);
- break;
- case Token::Kind::SLASH:
- this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kDivideS,
- ByteCodeInstruction::kDivideU,
- ByteCodeInstruction::kDivideF,
- count);
- break;
- case Token::Kind::STAR:
- this->writeTypedInstruction(b.fLeft->fType, ByteCodeInstruction::kMultiplyI,
- ByteCodeInstruction::kMultiplyI,
- ByteCodeInstruction::kMultiplyF,
- count);
- break;
- default:
- SkASSERT(false);
+ // Special case for M*V, V*M, M*M (but not V*V!)
+ if (op == Token::Kind::STAR && lVecOrMtx && rVecOrMtx &&
+ !(lType.kind() == Type::kVector_Kind && rType.kind() == Type::kVector_Kind)) {
+ this->write(ByteCodeInstruction::kMatrixMultiply);
+ int rCols = rType.columns(),
+ rRows = rType.rows(),
+ lCols = lType.columns(),
+ lRows = lType.rows();
+ // M*V treats the vector as a column
+ if (rType.kind() == Type::kVector_Kind) {
+ std::swap(rCols, rRows);
+ }
+ SkASSERT(lCols == rRows);
+ SkASSERT(SlotCount(b.fType) == lRows * rCols);
+ this->write8(lCols);
+ this->write8(lRows);
+ this->write8(rCols);
+ } else {
+ int count = std::max(SlotCount(lType), SlotCount(rType));
+ switch (op) {
+ case Token::Kind::EQEQ:
+ this->writeTypedInstruction(lType, ByteCodeInstruction::kCompareIEQ,
+ ByteCodeInstruction::kCompareIEQ,
+ ByteCodeInstruction::kCompareFEQ,
+ count);
+ // Collapse to a single bool
+ for (int i = count; i > 1; --i) {
+ this->write(ByteCodeInstruction::kAndB);
+ }
+ break;
+ case Token::Kind::GT:
+ this->writeTypedInstruction(lType, ByteCodeInstruction::kCompareSGT,
+ ByteCodeInstruction::kCompareUGT,
+ ByteCodeInstruction::kCompareFGT,
+ count);
+ break;
+ case Token::Kind::GTEQ:
+ this->writeTypedInstruction(lType, ByteCodeInstruction::kCompareSGTEQ,
+ ByteCodeInstruction::kCompareUGTEQ,
+ ByteCodeInstruction::kCompareFGTEQ,
+ count);
+ break;
+ case Token::Kind::LT:
+ this->writeTypedInstruction(lType, ByteCodeInstruction::kCompareSLT,
+ ByteCodeInstruction::kCompareULT,
+ ByteCodeInstruction::kCompareFLT,
+ count);
+ break;
+ case Token::Kind::LTEQ:
+ this->writeTypedInstruction(lType, ByteCodeInstruction::kCompareSLTEQ,
+ ByteCodeInstruction::kCompareULTEQ,
+ ByteCodeInstruction::kCompareFLTEQ,
+ count);
+ break;
+ case Token::Kind::MINUS:
+ this->writeTypedInstruction(lType, ByteCodeInstruction::kSubtractI,
+ ByteCodeInstruction::kSubtractI,
+ ByteCodeInstruction::kSubtractF,
+ count);
+ break;
+ case Token::Kind::NEQ:
+ this->writeTypedInstruction(lType, ByteCodeInstruction::kCompareINEQ,
+ ByteCodeInstruction::kCompareINEQ,
+ ByteCodeInstruction::kCompareFNEQ,
+ count);
+ // Collapse to a single bool
+ for (int i = count; i > 1; --i) {
+ this->write(ByteCodeInstruction::kOrB);
+ }
+ break;
+ case Token::Kind::PERCENT:
+ this->writeTypedInstruction(lType, ByteCodeInstruction::kRemainderS,
+ ByteCodeInstruction::kRemainderU,
+ ByteCodeInstruction::kRemainderF,
+ count);
+ break;
+ case Token::Kind::PLUS:
+ this->writeTypedInstruction(lType, ByteCodeInstruction::kAddI,
+ ByteCodeInstruction::kAddI,
+ ByteCodeInstruction::kAddF,
+ count);
+ break;
+ case Token::Kind::SLASH:
+ this->writeTypedInstruction(lType, ByteCodeInstruction::kDivideS,
+ ByteCodeInstruction::kDivideU,
+ ByteCodeInstruction::kDivideF,
+ count);
+ break;
+ case Token::Kind::STAR:
+ this->writeTypedInstruction(lType, ByteCodeInstruction::kMultiplyI,
+ ByteCodeInstruction::kMultiplyI,
+ ByteCodeInstruction::kMultiplyF,
+ count);
+ break;
+ default:
+ SkASSERT(false);
+ }
}
if (lvalue) {
lvalue->store();