SPV: Implement all matrix operators {+,-,*,/} for {matrix,scalar,vector}.
diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp
index 5e43286..3d589bd 100755
--- a/SPIRV/GlslangToSpv.cpp
+++ b/SPIRV/GlslangToSpv.cpp
@@ -108,6 +108,7 @@
spv::Id handleUserFunctionCall(const glslang::TIntermAggregate*);
spv::Id createBinaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right, glslang::TBasicType typeProxy, bool reduceComparison = true);
+ spv::Id createBinaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right);
spv::Id createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy);
spv::Id createConversion(glslang::TOperator op, spv::Decoration precision, spv::Id destTypeId, spv::Id operand);
spv::Id makeSmearedConstant(spv::Id constant, int vectorSize);
@@ -2122,26 +2123,17 @@
break;
case glslang::EOpVectorTimesMatrix:
case glslang::EOpVectorTimesMatrixAssign:
- assert(builder.isVector(left));
- assert(builder.isMatrix(right));
binOp = spv::OpVectorTimesMatrix;
break;
case glslang::EOpMatrixTimesVector:
- assert(builder.isMatrix(left));
- assert(builder.isVector(right));
binOp = spv::OpMatrixTimesVector;
break;
case glslang::EOpMatrixTimesScalar:
case glslang::EOpMatrixTimesScalarAssign:
- if (builder.isMatrix(right))
- std::swap(left, right);
- assert(builder.isScalar(right));
binOp = spv::OpMatrixTimesScalar;
break;
case glslang::EOpMatrixTimesMatrix:
case glslang::EOpMatrixTimesMatrixAssign:
- assert(builder.isMatrix(left));
- assert(builder.isMatrix(right));
binOp = spv::OpMatrixTimesMatrix;
break;
case glslang::EOpOuterProduct:
@@ -2220,29 +2212,8 @@
// handle mapped binary operations (should be non-comparison)
if (binOp != spv::OpNop) {
assert(comparison == false);
- if (builder.isMatrix(left) || builder.isMatrix(right)) {
- switch (binOp) {
- case spv::OpMatrixTimesScalar:
- case spv::OpVectorTimesMatrix:
- case spv::OpMatrixTimesVector:
- case spv::OpMatrixTimesMatrix:
- break;
- case spv::OpFDiv:
- // turn it into a multiply...
- assert(builder.isMatrix(left) && builder.isScalar(right));
- right = builder.createBinOp(spv::OpFDiv, builder.getTypeId(right), builder.makeFloatConstant(1.0F), right);
- binOp = spv::OpFMul;
- break;
- default:
- spv::MissingFunctionality("binary operation on matrix");
- break;
- }
-
- spv::Id id = builder.createBinOp(binOp, typeId, left, right);
- builder.setPrecision(id, precision);
-
- return id;
- }
+ if (builder.isMatrix(left) || builder.isMatrix(right))
+ return createBinaryMatrixOperation(binOp, precision, typeId, left, right);
// No matrix involved; make both operands be the same number of components, if needed
if (needMatchingVectors)
@@ -2326,6 +2297,111 @@
return 0;
}
+//
+// Translate an AST operation with at least one matrix operand to SPV; operands/types are SPV ids.
+// Returns the result id, or spv::NoResult (after MissingFunctionality) for an unsupported op.
+// First-class SPV matrix instructions:
+//   matrix * scalar, scalar * matrix    OpMatrixTimesScalar (operands ordered matrix, scalar)
+//   matrix / scalar                     OpMatrixTimesScalar by (1.0 / scalar)
+//   matrix * matrix   linear algebraic  OpMatrixTimesMatrix
+//   matrix * vector, vector * matrix    OpMatrixTimesVector, OpVectorTimesMatrix
+// Componentwise (one vector operation per column, any scalar smeared to a vector):
+//   matrix * matrix   componentwise     (op is OpFMul)
+//   matrix op matrix  op in {+, -, /}
+//   matrix op scalar  op in {+, -}
+//   scalar op matrix  op in {+, -, /}
+//
+spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right)
+{
+    bool firstClass = true;
+
+    // First, handle first-class matrix operations (* and matrix/scalar)
+    switch (op) {
+    case spv::OpFDiv:
+        if (builder.isMatrix(left) && builder.isScalar(right)) {
+            // turn matrix / scalar into matrix * (1.0 / scalar); NOTE(review): 1.0F assumes a 32-bit float scalar
+            right = builder.createBinOp(spv::OpFDiv, builder.getTypeId(right), builder.makeFloatConstant(1.0F), right);
+            op = spv::OpMatrixTimesScalar;
+        } else
+            firstClass = false;
+        break;
+    case spv::OpMatrixTimesScalar:
+        if (builder.isMatrix(right))
+            std::swap(left, right);  // OpMatrixTimesScalar takes the matrix first, then the scalar
+        assert(builder.isMatrix(left));
+        assert(builder.isScalar(right));
+        break;
+    case spv::OpVectorTimesMatrix:
+        assert(builder.isVector(left));
+        assert(builder.isMatrix(right));
+        break;
+    case spv::OpMatrixTimesVector:
+        assert(builder.isMatrix(left));
+        assert(builder.isVector(right));
+        break;
+    case spv::OpMatrixTimesMatrix:
+        assert(builder.isMatrix(left));
+        assert(builder.isMatrix(right));
+        break;
+    default:
+        firstClass = false;
+        break;
+    }
+
+    if (firstClass) {
+        spv::Id id = builder.createBinOp(op, typeId, left, right);
+        builder.setPrecision(id, precision);
+        return id;
+    }
+
+    // Handle component-wise +, -, *, and / for all remaining combinations of type.
+    // The result type is the type of the matrix operand (of either one, when both are matrices).
+    // The algorithm is to:
+    //   - break the matrix(es) into column vectors
+    //   - smear any scalar to a vector
+    //   - do one vector operation per column
+    //   - make a matrix out of the vector results
+    switch (op) {
+    case spv::OpFAdd:
+    case spv::OpFSub:
+    case spv::OpFDiv:
+    case spv::OpFMul:
+    {
+        const bool leftMat = builder.isMatrix(left);
+        const bool rightMat = builder.isMatrix(right);
+        const unsigned int numCols = leftMat ? builder.getNumColumns(left) : builder.getNumColumns(right);
+        const int numRows = leftMat ? builder.getNumRows(left) : builder.getNumRows(right);
+        const spv::Id scalarType = builder.getScalarTypeId(typeId);
+        const spv::Id vecType = builder.makeVectorType(scalarType, numRows);
+        std::vector<spv::Id> results;
+        spv::Id smearVec = spv::NoResult;
+        if (builder.isScalar(left))
+            smearVec = builder.smearScalar(precision, left, vecType);
+        else if (builder.isScalar(right))
+            smearVec = builder.smearScalar(precision, right, vecType);
+
+        // do each vector op
+        for (unsigned int c = 0; c < numCols; ++c) {
+            std::vector<unsigned int> indexes;
+            indexes.push_back(c);
+            spv::Id leftVec = leftMat ? builder.createCompositeExtract( left, vecType, indexes) : smearVec;
+            spv::Id rightVec = rightMat ? builder.createCompositeExtract(right, vecType, indexes) : smearVec;
+            results.push_back(builder.createBinOp(op, vecType, leftVec, rightVec));
+            builder.setPrecision(results.back(), precision);
+        }
+
+        // put the pieces together
+        spv::Id id = builder.createCompositeConstruct(typeId, results);
+        builder.setPrecision(id, precision);
+        return id;
+    }
+    default:
+        // report unsupported ops instead of only asserting, so release builds also diagnose them
+        spv::MissingFunctionality("binary operation on matrix");
+        return spv::NoResult;
+    }
+}
+
spv::Id TGlslangToSpvTraverser::createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id operand, glslang::TBasicType typeProxy)
{
spv::Op unaryOp = spv::OpNop;