SPV: Add unary-matrix operations, operating at vector level.
diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp
index 9d42e3c..c174375 100755
--- a/SPIRV/GlslangToSpv.cpp
+++ b/SPIRV/GlslangToSpv.cpp
@@ -110,6 +110,7 @@
spv::Id createBinaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right, glslang::TBasicType typeProxy, bool reduceComparison = true);
spv::Id createBinaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right);
spv::Id createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy);
+ spv::Id createUnaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy);
spv::Id createConversion(glslang::TOperator op, spv::Decoration precision, spv::Id destTypeId, spv::Id operand);
spv::Id makeSmearedConstant(spv::Id constant, int vectorSize);
spv::Id createAtomicOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy);
@@ -2601,9 +2602,11 @@
switch (op) {
case glslang::EOpNegative:
- if (isFloat)
+ if (isFloat) {
unaryOp = spv::OpFNegate;
- else
+ if (builder.isMatrixType(typeId))
+ return createUnaryMatrixOperation(unaryOp, precision, typeId, operand, typeProxy);
+ } else
unaryOp = spv::OpSNegate;
break;
@@ -2862,6 +2865,39 @@
return id;
}
+// Create a unary operation on a matrix: apply 'op' to each column vector, then rebuild the matrix.
+spv::Id TGlslangToSpvTraverser::createUnaryMatrixOperation(spv::Op op, spv::Decoration precision, spv::Id typeId, spv::Id operand, glslang::TBasicType /* typeProxy */)
+{
+ // Handle unary operations vector by vector (presumably because SPIR-V ops such as OpFNegate take only scalar/vector operands; TODO confirm against the SPIR-V spec).
+ // The result type (typeId) is the same type as the original operand type; typeProxy is unused.
+ // The algorithm is to:
+ // - break the matrix into column vectors
+ // - apply the operation to each vector, tagging each result with 'precision'
+ // - make a matrix out of the vector results
+
+ // get the types sorted out: each column is a vector of numRows elements of the matrix's scalar type
+ int numCols = builder.getNumColumns(operand);
+ int numRows = builder.getNumRows(operand);
+ spv::Id scalarType = builder.getScalarTypeId(typeId);
+ spv::Id vecType = builder.makeVectorType(scalarType, numRows);
+ std::vector<spv::Id> results;
+
+ // do each vector op: extract column c, apply 'op', and record the result
+ for (int c = 0; c < numCols; ++c) {
+ std::vector<unsigned int> indexes;
+ indexes.push_back(c);
+ spv::Id vec = builder.createCompositeExtract(operand, vecType, indexes);
+ results.push_back(builder.createUnaryOp(op, vecType, vec));
+ builder.setPrecision(results.back(), precision);
+ }
+
+ // put the pieces together into a matrix of type typeId, carrying the same precision
+ spv::Id id = builder.createCompositeConstruct(typeId, results);
+ builder.setPrecision(id, precision);
+
+ return id;
+}
+
spv::Id TGlslangToSpvTraverser::createConversion(glslang::TOperator op, spv::Decoration precision, spv::Id destType, spv::Id operand)
{
spv::Op convOp = spv::OpNop;