Implement GL_NV_cooperative_matrix
diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp
old mode 100755
new mode 100644
index 9bf3704..2ddece1
--- a/SPIRV/GlslangToSpv.cpp
+++ b/SPIRV/GlslangToSpv.cpp
@@ -1330,6 +1330,10 @@
}
builder.setMemoryModel(addressingModel, memoryModel);
+ if (glslangIntermediate->usingVariablePointers()) {
+ builder.addCapability(spv::CapabilityVariablePointers);
+ }
+
shaderEntry = builder.makeEntryPoint(glslangIntermediate->getEntryPointName().c_str());
entryPoint = builder.addEntryPoint(executionModel, shaderEntry, glslangIntermediate->getEntryPointName().c_str());
@@ -1870,16 +1874,31 @@
// So, this has to be block.lastMember.length().
// SPV wants "block" and member number as the operands, go get them.
- glslang::TIntermTyped* block = node->getOperand()->getAsBinaryNode()->getLeft();
- block->traverse(this);
- unsigned int member = node->getOperand()->getAsBinaryNode()->getRight()->getAsConstantUnion()->getConstArray()[0].getUConst();
- spv::Id length = builder.createArrayLength(builder.accessChainGetLValue(), member);
+ spv::Id length;
+ if (node->getOperand()->getType().isCoopMat()) {
+ spec_constant_op_mode_setter.turnOnSpecConstantOpMode();
+
+ spv::Id typeId = convertGlslangToSpvType(node->getOperand()->getType());
+ assert(builder.isCooperativeMatrixType(typeId));
+
+ length = builder.createCooperativeMatrixLength(typeId);
+ } else {
+ glslang::TIntermTyped* block = node->getOperand()->getAsBinaryNode()->getLeft();
+ block->traverse(this);
+ unsigned int member = node->getOperand()->getAsBinaryNode()->getRight()->getAsConstantUnion()->getConstArray()[0].getUConst();
+ length = builder.createArrayLength(builder.accessChainGetLValue(), member);
+ }
// GLSL semantics say the result of .length() is an int, while SPIR-V says
// signedness must be 0. So, convert from SPIR-V unsigned back to GLSL's
// AST expectation of a signed result.
- if (glslangIntermediate->getSource() == glslang::EShSourceGlsl)
- length = builder.createUnaryOp(spv::OpBitcast, builder.makeIntType(32), length);
+ if (glslangIntermediate->getSource() == glslang::EShSourceGlsl) {
+ if (builder.isInSpecConstCodeGenMode()) {
+ length = builder.createBinOp(spv::OpIAdd, builder.makeIntType(32), length, builder.makeIntConstant(0));
+ } else {
+ length = builder.createUnaryOp(spv::OpBitcast, builder.makeIntType(32), length);
+ }
+ }
builder.clearAccessChain();
builder.setAccessChainRValue(length);
@@ -2222,6 +2241,7 @@
case glslang::EOpConstructStruct:
case glslang::EOpConstructTextureSampler:
case glslang::EOpConstructReference:
+ case glslang::EOpConstructCooperativeMatrix:
{
builder.setLine(node->getLoc().line, node->getLoc().getFilename());
std::vector<spv::Id> arguments;
@@ -2229,7 +2249,9 @@
spv::Id constructed;
if (node->getOp() == glslang::EOpConstructTextureSampler)
constructed = builder.createOp(spv::OpSampledImage, resultType(), arguments);
- else if (node->getOp() == glslang::EOpConstructStruct || node->getType().isArray()) {
+ else if (node->getOp() == glslang::EOpConstructStruct ||
+ node->getOp() == glslang::EOpConstructCooperativeMatrix ||
+ node->getType().isArray()) {
std::vector<spv::Id> constituents;
for (int c = 0; c < (int)arguments.size(); ++c)
constituents.push_back(arguments[c]);
@@ -2347,6 +2369,10 @@
noReturnValue = true;
break;
#endif
+ case glslang::EOpCooperativeMatrixLoad:
+ case glslang::EOpCooperativeMatrixStore:
+ noReturnValue = true;
+ break;
default:
break;
@@ -2389,6 +2415,7 @@
//
glslang::TIntermSequence& glslangOperands = node->getSequence();
std::vector<spv::Id> operands;
+ std::vector<spv::IdImmediate> memoryAccessOperands;
for (int arg = 0; arg < (int)glslangOperands.size(); ++arg) {
// special case l-value operands; there are just a few
bool lvalue = false;
@@ -2445,6 +2472,14 @@
if (arg >= 2)
lvalue = true;
break;
+ case glslang::EOpCooperativeMatrixLoad:
+ if (arg == 0 || arg == 1)
+ lvalue = true;
+ break;
+ case glslang::EOpCooperativeMatrixStore:
+ if (arg == 1)
+ lvalue = true;
+ break;
default:
break;
}
@@ -2453,6 +2488,50 @@
glslangOperands[0]->getAsBinaryNode()->getLeft()->traverse(this);
else
glslangOperands[arg]->traverse(this);
+
+ if (node->getOp() == glslang::EOpCooperativeMatrixLoad ||
+ node->getOp() == glslang::EOpCooperativeMatrixStore) {
+
+ if (arg == 1) {
+ // Fold the "element" parameter (operand 2) into the access chain built for operand 1 (the buffer)
+ spv::Builder::AccessChain save = builder.getAccessChain();
+ builder.clearAccessChain();
+ glslangOperands[2]->traverse(this);
+
+ spv::Id elementId = accessChainLoad(glslangOperands[2]->getAsTyped()->getType());
+
+ builder.setAccessChain(save);
+
+ // Index the buffer's access chain by "element", so it points at the array element where the matrix access starts.
+ builder.accessChainPush(elementId, TranslateCoherent(glslangOperands[arg]->getAsTyped()->getType()),
+ getBufferReferenceAlignment(glslangOperands[arg]->getAsTyped()->getType()));
+
+ spv::Builder::AccessChain::CoherentFlags coherentFlags = builder.getAccessChain().coherentFlags;
+ unsigned int alignment = builder.getAccessChain().alignment;
+
+ int memoryAccess = TranslateMemoryAccess(coherentFlags);
+ if (node->getOp() == glslang::EOpCooperativeMatrixLoad)
+ memoryAccess &= ~spv::MemoryAccessMakePointerAvailableKHRMask;
+ if (node->getOp() == glslang::EOpCooperativeMatrixStore)
+ memoryAccess &= ~spv::MemoryAccessMakePointerVisibleKHRMask;
+ if (builder.getStorageClass(builder.getAccessChain().base) == spv::StorageClassPhysicalStorageBufferEXT) {
+ memoryAccess = (spv::MemoryAccessMask)(memoryAccess | spv::MemoryAccessAlignedMask);
+ }
+
+ memoryAccessOperands.push_back(spv::IdImmediate(false, memoryAccess));
+
+ if (memoryAccess & spv::MemoryAccessAlignedMask) {
+ memoryAccessOperands.push_back(spv::IdImmediate(false, alignment));
+ }
+
+ if (memoryAccess & (spv::MemoryAccessMakePointerAvailableKHRMask | spv::MemoryAccessMakePointerVisibleKHRMask)) {
+ memoryAccessOperands.push_back(spv::IdImmediate(true, builder.makeUintConstant(TranslateMemoryScope(coherentFlags))));
+ }
+ } else if (arg == 2) {
+ continue;
+ }
+ }
+
if (lvalue)
operands.push_back(builder.accessChainGetLValue());
else {
@@ -2462,7 +2541,33 @@
}
builder.setLine(node->getLoc().line, node->getLoc().getFilename());
- if (atomic) {
+ if (node->getOp() == glslang::EOpCooperativeMatrixLoad) {
+ std::vector<spv::IdImmediate> idImmOps;
+
+ idImmOps.push_back(spv::IdImmediate(true, operands[1])); // buf
+ idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
+ idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
+ idImmOps.insert(idImmOps.end(), memoryAccessOperands.begin(), memoryAccessOperands.end());
+ // get the pointee type
+ spv::Id typeId = builder.getContainedTypeId(builder.getTypeId(operands[0]));
+ assert(builder.isCooperativeMatrixType(typeId));
+ // do the op
+ spv::Id result = builder.createOp(spv::OpCooperativeMatrixLoadNV, typeId, idImmOps);
+ // store the result to the pointer (out param 'm')
+ builder.createStore(result, operands[0]);
+ result = 0;
+ } else if (node->getOp() == glslang::EOpCooperativeMatrixStore) {
+ std::vector<spv::IdImmediate> idImmOps;
+
+ idImmOps.push_back(spv::IdImmediate(true, operands[1])); // buf
+ idImmOps.push_back(spv::IdImmediate(true, operands[0])); // object
+ idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
+ idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
+ idImmOps.insert(idImmOps.end(), memoryAccessOperands.begin(), memoryAccessOperands.end());
+
+ builder.createNoResultOp(spv::OpCooperativeMatrixStoreNV, idImmOps);
+ result = 0;
+ } else if (atomic) {
// Handle all atomics
result = createAtomicOperation(node->getOp(), precision, resultType(), operands, node->getBasicType());
} else {
@@ -3090,6 +3195,19 @@
spvType = builder.makeVectorType(spvType, type.getVectorSize());
}
+ if (type.isCoopMat()) {
+ builder.addCapability(spv::CapabilityCooperativeMatrixNV);
+ builder.addExtension(spv::E_SPV_NV_cooperative_matrix);
+ if (type.getBasicType() == glslang::EbtFloat16)
+ builder.addCapability(spv::CapabilityFloat16);
+
+ spv::Id scope = makeArraySizeId(*type.getTypeParameters(), 1);
+ spv::Id rows = makeArraySizeId(*type.getTypeParameters(), 2);
+ spv::Id cols = makeArraySizeId(*type.getTypeParameters(), 3);
+
+ spvType = builder.makeCooperativeMatrixType(spvType, scope, rows, cols);
+ }
+
if (type.isArray()) {
int stride = 0; // keep this 0 unless doing an explicit layout; 0 will mean no decoration, no stride
@@ -4847,7 +4965,8 @@
// handle mapped binary operations (should be non-comparison)
if (binOp != spv::OpNop) {
assert(comparison == false);
- if (builder.isMatrix(left) || builder.isMatrix(right))
+ if (builder.isMatrix(left) || builder.isMatrix(right) ||
+ builder.isCooperativeMatrix(left) || builder.isCooperativeMatrix(right))
return createBinaryMatrixOperation(binOp, decorations, typeId, left, right);
// No matrix involved; make both operands be the same number of components, if needed
@@ -4968,7 +5087,7 @@
firstClass = false;
break;
case spv::OpMatrixTimesScalar:
- if (builder.isMatrix(right))
+ if (builder.isMatrix(right) || builder.isCooperativeMatrix(right))
std::swap(left, right);
assert(builder.isScalar(right));
break;
@@ -4989,6 +5108,9 @@
break;
}
+ if (builder.isCooperativeMatrix(left) || builder.isCooperativeMatrix(right))
+ firstClass = true;
+
if (firstClass) {
spv::Id result = builder.createBinOp(op, typeId, left, right);
builder.addDecoration(result, decorations.noContraction);
@@ -7030,6 +7152,10 @@
builder.createNoResultOp(spv::OpWritePackedPrimitiveIndices4x8NV, operands);
return 0;
#endif
+ case glslang::EOpCooperativeMatrixMulAdd:
+ opCode = spv::OpCooperativeMatrixMulAddNV;
+ break;
+
default:
return 0;
}
@@ -7486,6 +7612,9 @@
glslang::TType vectorType(glslangType, 0);
for (int col = 0; col < glslangType.getMatrixCols(); ++col)
spvConsts.push_back(createSpvConstantFromConstUnionArray(vectorType, consts, nextConst, false));
+ } else if (glslangType.isCoopMat()) {
+ glslang::TType componentType(glslangType.getBasicType());
+ spvConsts.push_back(createSpvConstantFromConstUnionArray(componentType, consts, nextConst, false));
} else if (glslangType.isStruct()) {
glslang::TVector<glslang::TTypeLoc>::const_iterator iter;
for (iter = glslangType.getStruct()->begin(); iter != glslangType.getStruct()->end(); ++iter)