Implement new revision of extension GL_AMD_shader_ballot
- Add support for invocation functions with "InclusiveScan" and
"ExclusiveScan" modes.
- Add support for invocation functions taking int64/uint64/double/float16
as input data types.
diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp
index 5e3dc52..af08e4b 100755
--- a/SPIRV/GlslangToSpv.cpp
+++ b/SPIRV/GlslangToSpv.cpp
@@ -161,7 +161,7 @@
spv::Id makeSmearedConstant(spv::Id constant, int vectorSize);
spv::Id createAtomicOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy);
spv::Id createInvocationsOperation(glslang::TOperator op, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy);
- spv::Id CreateInvocationsVectorOperation(spv::Op op, spv::Id typeId, std::vector<spv::Id>& operands);
+ spv::Id CreateInvocationsVectorOperation(spv::Op op, spv::GroupOperation groupOperation, spv::Id typeId, std::vector<spv::Id>& operands);
spv::Id createMiscOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy);
spv::Id createNoArgOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId);
spv::Id getSymbolId(const glslang::TIntermSymbol* node);
@@ -2015,7 +2015,6 @@
#ifdef AMD_EXTENSIONS
case glslang::EbtFloat16:
builder.addExtension(spv::E_SPV_AMD_gpu_shader_half_float);
- builder.addCapability(spv::CapabilityFloat16);
spvType = builder.makeFloatType(16);
break;
#endif
@@ -3743,6 +3742,18 @@
case glslang::EOpMinInvocationsNonUniform:
case glslang::EOpMaxInvocationsNonUniform:
case glslang::EOpAddInvocationsNonUniform:
+ case glslang::EOpMinInvocationsInclusiveScan:
+ case glslang::EOpMaxInvocationsInclusiveScan:
+ case glslang::EOpAddInvocationsInclusiveScan:
+ case glslang::EOpMinInvocationsInclusiveScanNonUniform:
+ case glslang::EOpMaxInvocationsInclusiveScanNonUniform:
+ case glslang::EOpAddInvocationsInclusiveScanNonUniform:
+ case glslang::EOpMinInvocationsExclusiveScan:
+ case glslang::EOpMaxInvocationsExclusiveScan:
+ case glslang::EOpAddInvocationsExclusiveScan:
+ case glslang::EOpMinInvocationsExclusiveScanNonUniform:
+ case glslang::EOpMaxInvocationsExclusiveScanNonUniform:
+ case glslang::EOpAddInvocationsExclusiveScanNonUniform:
#endif
{
std::vector<spv::Id> operands;
@@ -4130,8 +4141,9 @@
#endif
spv::Op opCode = spv::OpNop;
-
std::vector<spv::Id> spvGroupOperands;
+ spv::GroupOperation groupOperation = spv::GroupOperationMax;
+
if (op == glslang::EOpBallot || op == glslang::EOpReadFirstInvocation ||
op == glslang::EOpReadInvocation) {
builder.addExtension(spv::E_SPV_KHR_shader_ballot);
@@ -4141,15 +4153,47 @@
#ifdef AMD_EXTENSIONS
if (op == glslang::EOpMinInvocationsNonUniform ||
op == glslang::EOpMaxInvocationsNonUniform ||
- op == glslang::EOpAddInvocationsNonUniform)
+ op == glslang::EOpAddInvocationsNonUniform ||
+ op == glslang::EOpMinInvocationsInclusiveScanNonUniform ||
+ op == glslang::EOpMaxInvocationsInclusiveScanNonUniform ||
+ op == glslang::EOpAddInvocationsInclusiveScanNonUniform ||
+ op == glslang::EOpMinInvocationsExclusiveScanNonUniform ||
+ op == glslang::EOpMaxInvocationsExclusiveScanNonUniform ||
+ op == glslang::EOpAddInvocationsExclusiveScanNonUniform)
builder.addExtension(spv::E_SPV_AMD_shader_ballot);
#endif
spvGroupOperands.push_back(builder.makeUintConstant(spv::ScopeSubgroup));
#ifdef AMD_EXTENSIONS
- if (op == glslang::EOpMinInvocations || op == glslang::EOpMaxInvocations || op == glslang::EOpAddInvocations ||
- op == glslang::EOpMinInvocationsNonUniform || op == glslang::EOpMaxInvocationsNonUniform || op == glslang::EOpAddInvocationsNonUniform)
- spvGroupOperands.push_back(spv::GroupOperationReduce);
+ switch (op) {
+ case glslang::EOpMinInvocations:
+ case glslang::EOpMaxInvocations:
+ case glslang::EOpAddInvocations:
+ case glslang::EOpMinInvocationsNonUniform:
+ case glslang::EOpMaxInvocationsNonUniform:
+ case glslang::EOpAddInvocationsNonUniform:
+ groupOperation = spv::GroupOperationReduce;
+ spvGroupOperands.push_back(groupOperation);
+ break;
+ case glslang::EOpMinInvocationsInclusiveScan:
+ case glslang::EOpMaxInvocationsInclusiveScan:
+ case glslang::EOpAddInvocationsInclusiveScan:
+ case glslang::EOpMinInvocationsInclusiveScanNonUniform:
+ case glslang::EOpMaxInvocationsInclusiveScanNonUniform:
+ case glslang::EOpAddInvocationsInclusiveScanNonUniform:
+ groupOperation = spv::GroupOperationInclusiveScan;
+ spvGroupOperands.push_back(groupOperation);
+ break;
+ case glslang::EOpMinInvocationsExclusiveScan:
+ case glslang::EOpMaxInvocationsExclusiveScan:
+ case glslang::EOpAddInvocationsExclusiveScan:
+ case glslang::EOpMinInvocationsExclusiveScanNonUniform:
+ case glslang::EOpMaxInvocationsExclusiveScanNonUniform:
+ case glslang::EOpAddInvocationsExclusiveScanNonUniform:
+ groupOperation = spv::GroupOperationExclusiveScan;
+ spvGroupOperands.push_back(groupOperation);
+ break;
+ }
#endif
}
@@ -4175,7 +4219,7 @@
case glslang::EOpReadInvocation:
opCode = spv::OpSubgroupReadInvocationKHR;
if (builder.isVectorType(typeId))
- return CreateInvocationsVectorOperation(opCode, typeId, operands);
+ return CreateInvocationsVectorOperation(opCode, groupOperation, typeId, operands);
break;
case glslang::EOpReadFirstInvocation:
opCode = spv::OpSubgroupFirstInvocationKHR;
@@ -4205,7 +4249,15 @@
case glslang::EOpMinInvocations:
case glslang::EOpMaxInvocations:
case glslang::EOpAddInvocations:
- if (op == glslang::EOpMinInvocations) {
+ case glslang::EOpMinInvocationsInclusiveScan:
+ case glslang::EOpMaxInvocationsInclusiveScan:
+ case glslang::EOpAddInvocationsInclusiveScan:
+ case glslang::EOpMinInvocationsExclusiveScan:
+ case glslang::EOpMaxInvocationsExclusiveScan:
+ case glslang::EOpAddInvocationsExclusiveScan:
+ if (op == glslang::EOpMinInvocations ||
+ op == glslang::EOpMinInvocationsInclusiveScan ||
+ op == glslang::EOpMinInvocationsExclusiveScan) {
if (isFloat)
opCode = spv::OpGroupFMin;
else {
@@ -4214,7 +4266,9 @@
else
opCode = spv::OpGroupSMin;
}
- } else if (op == glslang::EOpMaxInvocations) {
+ } else if (op == glslang::EOpMaxInvocations ||
+ op == glslang::EOpMaxInvocationsInclusiveScan ||
+ op == glslang::EOpMaxInvocationsExclusiveScan) {
if (isFloat)
opCode = spv::OpGroupFMax;
else {
@@ -4231,13 +4285,21 @@
}
if (builder.isVectorType(typeId))
- return CreateInvocationsVectorOperation(opCode, typeId, operands);
+ return CreateInvocationsVectorOperation(opCode, groupOperation, typeId, operands);
break;
case glslang::EOpMinInvocationsNonUniform:
case glslang::EOpMaxInvocationsNonUniform:
case glslang::EOpAddInvocationsNonUniform:
- if (op == glslang::EOpMinInvocationsNonUniform) {
+ case glslang::EOpMinInvocationsInclusiveScanNonUniform:
+ case glslang::EOpMaxInvocationsInclusiveScanNonUniform:
+ case glslang::EOpAddInvocationsInclusiveScanNonUniform:
+ case glslang::EOpMinInvocationsExclusiveScanNonUniform:
+ case glslang::EOpMaxInvocationsExclusiveScanNonUniform:
+ case glslang::EOpAddInvocationsExclusiveScanNonUniform:
+ if (op == glslang::EOpMinInvocationsNonUniform ||
+ op == glslang::EOpMinInvocationsInclusiveScanNonUniform ||
+ op == glslang::EOpMinInvocationsExclusiveScanNonUniform) {
if (isFloat)
opCode = spv::OpGroupFMinNonUniformAMD;
else {
@@ -4247,7 +4309,9 @@
opCode = spv::OpGroupSMinNonUniformAMD;
}
}
- else if (op == glslang::EOpMaxInvocationsNonUniform) {
+ else if (op == glslang::EOpMaxInvocationsNonUniform ||
+ op == glslang::EOpMaxInvocationsInclusiveScanNonUniform ||
+ op == glslang::EOpMaxInvocationsExclusiveScanNonUniform) {
if (isFloat)
opCode = spv::OpGroupFMaxNonUniformAMD;
else {
@@ -4265,7 +4329,7 @@
}
if (builder.isVectorType(typeId))
- return CreateInvocationsVectorOperation(opCode, typeId, operands);
+ return CreateInvocationsVectorOperation(opCode, groupOperation, typeId, operands);
break;
#endif
@@ -4279,7 +4343,7 @@
}
// Create group invocation operations on a vector
-spv::Id TGlslangToSpvTraverser::CreateInvocationsVectorOperation(spv::Op op, spv::Id typeId, std::vector<spv::Id>& operands)
+spv::Id TGlslangToSpvTraverser::CreateInvocationsVectorOperation(spv::Op op, spv::GroupOperation groupOperation, spv::Id typeId, std::vector<spv::Id>& operands)
{
#ifdef AMD_EXTENSIONS
assert(op == spv::OpGroupFMin || op == spv::OpGroupUMin || op == spv::OpGroupSMin ||
@@ -4323,7 +4387,7 @@
spvGroupOperands.push_back(operands[1]);
} else {
spvGroupOperands.push_back(builder.makeUintConstant(spv::ScopeSubgroup));
- spvGroupOperands.push_back(spv::GroupOperationReduce);
+ spvGroupOperands.push_back(groupOperation);
spvGroupOperands.push_back(scalar);
}