SPV: Implement the extension SPV_KHR_shader_ballot
diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp
index 7286588..8372dbb 100755
--- a/SPIRV/GlslangToSpv.cpp
+++ b/SPIRV/GlslangToSpv.cpp
@@ -42,9 +42,10 @@
#include "GlslangToSpv.h"
#include "SpvBuilder.h"
namespace spv {
- #include "GLSL.std.450.h"
+ #include "GLSL.std.450.h"
+ #include "GLSL.ext.KHR.h"
#ifdef AMD_EXTENSIONS
- #include "GLSL.ext.AMD.h"
+ #include "GLSL.ext.AMD.h"
#endif
}
@@ -154,7 +155,7 @@
spv::Id createConversion(glslang::TOperator op, spv::Decoration precision, spv::Decoration noContraction, spv::Id destTypeId, spv::Id operand, glslang::TBasicType typeProxy);
spv::Id makeSmearedConstant(spv::Id constant, int vectorSize);
spv::Id createAtomicOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy);
- spv::Id createInvocationsOperation(glslang::TOperator op, spv::Id typeId, spv::Id operand, glslang::TBasicType typeProxy);
+ spv::Id createInvocationsOperation(glslang::TOperator op, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy);
#ifdef AMD_EXTENSIONS
spv::Id CreateInvocationsVectorOperation(spv::Op op, spv::Id typeId, spv::Id operand);
#endif
@@ -521,16 +522,44 @@
case glslang::EbvLocalInvocationId: return spv::BuiltInLocalInvocationId;
case glslang::EbvLocalInvocationIndex: return spv::BuiltInLocalInvocationIndex;
case glslang::EbvGlobalInvocationId: return spv::BuiltInGlobalInvocationId;
+
+ // The SubgroupBallotKHR capability comes from SPV_KHR_shader_ballot, so the
+ // extension is declared wherever the capability is requested.
case glslang::EbvSubGroupSize:
+ builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+ builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+ return spv::BuiltInSubgroupSize;
+
case glslang::EbvSubGroupInvocation:
+ builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+ builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+ return spv::BuiltInSubgroupLocalInvocationId;
+
case glslang::EbvSubGroupEqMask:
+ builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+ builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+ return spv::BuiltInSubgroupEqMaskKHR;
+
case glslang::EbvSubGroupGeMask:
+ builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+ builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+ return spv::BuiltInSubgroupGeMaskKHR;
+
case glslang::EbvSubGroupGtMask:
+ builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+ builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+ return spv::BuiltInSubgroupGtMaskKHR;
+
case glslang::EbvSubGroupLeMask:
+ builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+ builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+ return spv::BuiltInSubgroupLeMaskKHR;
+
case glslang::EbvSubGroupLtMask:
- // TODO: Add SPIR-V builtin ID.
- logger->missingFunctionality("shader ballot");
- return spv::BuiltInMax;
+ builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+ builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+ return spv::BuiltInSubgroupLtMaskKHR;
+
#ifdef AMD_EXTENSIONS
case glslang::EbvBaryCoordNoPersp: return spv::BuiltInBaryCoordNoPerspAMD;
case glslang::EbvBaryCoordNoPerspCentroid: return spv::BuiltInBaryCoordNoPerspCentroidAMD;
@@ -3610,10 +3635,6 @@
case glslang::EOpBallot:
case glslang::EOpReadFirstInvocation:
- logger->missingFunctionality("shader ballot");
- libCall = spv::GLSLstd450Bad;
- break;
-
case glslang::EOpAnyInvocation:
case glslang::EOpAllInvocations:
case glslang::EOpAllInvocationsEqual:
@@ -3625,7 +3646,11 @@
case glslang::EOpMaxInvocationsNonUniform:
case glslang::EOpAddInvocationsNonUniform:
#endif
- return createInvocationsOperation(op, typeId, operand, typeProxy);
+ {
+ std::vector<spv::Id> operands;
+ operands.push_back(operand);
+ return createInvocationsOperation(op, typeId, operands, typeProxy);
+ }
#ifdef AMD_EXTENSIONS
case glslang::EOpMbcnt:
@@ -3959,113 +3984,149 @@
}
// Create group invocation operations.
-spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op, spv::Id typeId, spv::Id operand, glslang::TBasicType typeProxy)
+spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy) // emits the instruction(s) for one invocation-group built-in; returns the result id, or spv::NoResult if op is not handled
{
bool isUnsigned = typeProxy == glslang::EbtUint || typeProxy == glslang::EbtUint64;
bool isFloat = typeProxy == glslang::EbtFloat || typeProxy == glslang::EbtDouble;
- builder.addCapability(spv::CapabilityGroups);
+ spv::Op opCode = spv::OpNop; // chosen by the switch below; OpNop means "not chosen yet"
- std::vector<spv::Id> operands;
- operands.push_back(builder.makeUintConstant(spv::ScopeSubgroup));
+ std::vector<spv::Id> spvGroupOperands; // complete operand list handed to the emitted instruction
+ if (op == glslang::EOpBallot || op == glslang::EOpReadFirstInvocation) { // these two use SPV_KHR_shader_ballot instructions; no scope operand is pushed for them
+ builder.addExtension(spv::E_SPV_KHR_shader_ballot);
+ builder.addCapability(spv::CapabilitySubgroupBallotKHR);
+ } else { // every other op gets the Groups capability and a leading Subgroup scope operand
+ builder.addCapability(spv::CapabilityGroups);
+
+ spvGroupOperands.push_back(builder.makeUintConstant(spv::ScopeSubgroup));
#ifdef AMD_EXTENSIONS
- if (op == glslang::EOpMinInvocations || op == glslang::EOpMaxInvocations || op == glslang::EOpAddInvocations ||
- op == glslang::EOpMinInvocationsNonUniform || op == glslang::EOpMaxInvocationsNonUniform || op == glslang::EOpAddInvocationsNonUniform)
- operands.push_back(spv::GroupOperationReduce);
+ if (op == glslang::EOpMinInvocations || op == glslang::EOpMaxInvocations || op == glslang::EOpAddInvocations ||
+ op == glslang::EOpMinInvocationsNonUniform || op == glslang::EOpMaxInvocationsNonUniform || op == glslang::EOpAddInvocationsNonUniform)
+ spvGroupOperands.push_back(spv::GroupOperationReduce); // min/max/add additionally carry a group-operation operand
#endif
- operands.push_back(operand);
+ }
+
+ for (auto opIt = operands.begin(); opIt != operands.end(); ++opIt) // caller's operands follow any scope/group-operation prefix, in order
+ spvGroupOperands.push_back(*opIt);
switch (op) {
case glslang::EOpAnyInvocation:
+ opCode = spv::OpGroupAny;
+ break;
case glslang::EOpAllInvocations:
- return builder.createOp(op == glslang::EOpAnyInvocation ? spv::OpGroupAny : spv::OpGroupAll, typeId, operands);
-
+ opCode = spv::OpGroupAll;
+ break;
case glslang::EOpAllInvocationsEqual:
{
- spv::Id groupAll = builder.createOp(spv::OpGroupAll, typeId, operands);
- spv::Id groupAny = builder.createOp(spv::OpGroupAny, typeId, operands);
+ spv::Id groupAll = builder.createOp(spv::OpGroupAll, typeId, spvGroupOperands); // true when the operand holds for all invocations
+ spv::Id groupAny = builder.createOp(spv::OpGroupAny, typeId, spvGroupOperands); // true when the operand holds for at least one invocation
return builder.createBinOp(spv::OpLogicalOr, typeId, groupAll,
builder.createUnaryOp(spv::OpLogicalNot, typeId, groupAny));
}
+
+ case glslang::EOpReadInvocation:
+ opCode = spv::OpGroupBroadcast; // emitted with the Subgroup scope operand pushed above
+ break;
+ case glslang::EOpReadFirstInvocation:
+ opCode = spv::OpSubgroupFirstInvocationKHR; // operand list is just the caller's operands (no scope was pushed)
+ break;
+ case glslang::EOpBallot:
+ {
+ // NOTE: According to the spec, the result type of "OpSubgroupBallotKHR" must be a 4 component vector of 32
+ // bit integer types. The GLSL built-in function "ballotARB()" assumes the maximum number of invocations in
+ // a subgroup is 64. Thus, we have to convert uvec4.xy to uint64_t as follows:
+ //
+ // result = Bitcast(SubgroupBallotKHR(Predicate).xy)
+ //
+ spv::Id uintType = builder.makeUintType(32);
+ spv::Id uvec4Type = builder.makeVectorType(uintType, 4);
+ spv::Id result = builder.createOp(spv::OpSubgroupBallotKHR, uvec4Type, spvGroupOperands); // result type is uvec4 here, not the caller's typeId
+
+ std::vector<spv::Id> components;
+ components.push_back(builder.createCompositeExtract(result, uintType, 0)); // .x
+ components.push_back(builder.createCompositeExtract(result, uintType, 1)); // .y
+
+ spv::Id uvec2Type = builder.makeVectorType(uintType, 2);
+ return builder.createUnaryOp(spv::OpBitcast, typeId, // reinterpret the uvec2 as the caller's result type
+ builder.createCompositeConstruct(uvec2Type, components));
+ }
+
#ifdef AMD_EXTENSIONS
case glslang::EOpMinInvocations:
case glslang::EOpMaxInvocations:
case glslang::EOpAddInvocations:
- {
- spv::Op spvOp = spv::OpNop;
if (op == glslang::EOpMinInvocations) {
if (isFloat)
- spvOp = spv::OpGroupFMin;
+ opCode = spv::OpGroupFMin;
else {
if (isUnsigned)
- spvOp = spv::OpGroupUMin;
+ opCode = spv::OpGroupUMin;
else
- spvOp = spv::OpGroupSMin;
+ opCode = spv::OpGroupSMin;
}
} else if (op == glslang::EOpMaxInvocations) {
if (isFloat)
- spvOp = spv::OpGroupFMax;
+ opCode = spv::OpGroupFMax;
else {
if (isUnsigned)
- spvOp = spv::OpGroupUMax;
+ opCode = spv::OpGroupUMax;
else
- spvOp = spv::OpGroupSMax;
+ opCode = spv::OpGroupSMax;
}
} else {
if (isFloat)
- spvOp = spv::OpGroupFAdd;
+ opCode = spv::OpGroupFAdd;
else
- spvOp = spv::OpGroupIAdd;
+ opCode = spv::OpGroupIAdd; // same opcode for signed and unsigned integers
}
if (builder.isVectorType(typeId))
- return CreateInvocationsVectorOperation(spvOp, typeId, operand);
- else
- return builder.createOp(spvOp, typeId, operands);
- }
+ return CreateInvocationsVectorOperation(opCode, typeId, operands[0]); // vector results go through the vector helper, which receives only the first caller operand
+
+ break;
case glslang::EOpMinInvocationsNonUniform:
case glslang::EOpMaxInvocationsNonUniform:
case glslang::EOpAddInvocationsNonUniform:
- {
- spv::Op spvOp = spv::OpNop;
if (op == glslang::EOpMinInvocationsNonUniform) {
if (isFloat)
- spvOp = spv::OpGroupFMinNonUniformAMD;
+ opCode = spv::OpGroupFMinNonUniformAMD;
else {
if (isUnsigned)
- spvOp = spv::OpGroupUMinNonUniformAMD;
+ opCode = spv::OpGroupUMinNonUniformAMD;
else
- spvOp = spv::OpGroupSMinNonUniformAMD;
+ opCode = spv::OpGroupSMinNonUniformAMD;
}
}
else if (op == glslang::EOpMaxInvocationsNonUniform) {
if (isFloat)
- spvOp = spv::OpGroupFMaxNonUniformAMD;
+ opCode = spv::OpGroupFMaxNonUniformAMD;
else {
if (isUnsigned)
- spvOp = spv::OpGroupUMaxNonUniformAMD;
+ opCode = spv::OpGroupUMaxNonUniformAMD;
else
- spvOp = spv::OpGroupSMaxNonUniformAMD;
+ opCode = spv::OpGroupSMaxNonUniformAMD;
}
}
else {
if (isFloat)
- spvOp = spv::OpGroupFAddNonUniformAMD;
+ opCode = spv::OpGroupFAddNonUniformAMD;
else
- spvOp = spv::OpGroupIAddNonUniformAMD;
+ opCode = spv::OpGroupIAddNonUniformAMD; // same opcode for signed and unsigned integers
}
if (builder.isVectorType(typeId))
- return CreateInvocationsVectorOperation(spvOp, typeId, operand);
- else
- return builder.createOp(spvOp, typeId, operands);
- }
+ return CreateInvocationsVectorOperation(opCode, typeId, operands[0]); // vector results go through the vector helper, which receives only the first caller operand
+
+ break;
#endif
default:
logger->missingFunctionality("invocation operation");
return spv::NoResult;
}
+
+ assert(opCode != spv::OpNop); // every case that breaks out of the switch must have picked an opcode
+ return builder.createOp(opCode, typeId, spvGroupOperands); // shared single-instruction path for all cases that did not return early
}
#ifdef AMD_EXTENSIONS
@@ -4256,9 +4317,7 @@
break;
case glslang::EOpReadInvocation:
- logger->missingFunctionality("shader ballot");
- libCall = spv::GLSLstd450Bad;
- break;
+ return createInvocationsOperation(op, typeId, operands, typeProxy);
#ifdef AMD_EXTENSIONS
case glslang::EOpSwizzleInvocations:
@@ -4825,7 +4884,7 @@
if (extBuiltinMap.find(name) != extBuiltinMap.end())
return extBuiltinMap[name];
else {
- builder.addExtensions(name);
+ builder.addExtension(name);
spv::Id extBuiltins = builder.import(name);
extBuiltinMap[name] = extBuiltins;
return extBuiltins;