Merge pull request #1474 from KhronosGroup/pure-8-16-bit-capability
SPV: only declare the pure 8/16-bit capabilities when needed.
diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp
index d8369d2..315546d 100755
--- a/SPIRV/GlslangToSpv.cpp
+++ b/SPIRV/GlslangToSpv.cpp
@@ -2612,7 +2612,6 @@
spvType = builder.makeFloatType(64);
break;
case glslang::EbtFloat16:
- builder.addCapability(spv::CapabilityFloat16);
#if AMD_EXTENSIONS
if (builder.getSpvVersion() < glslang::EShTargetSpv_1_3)
builder.addExtension(spv::E_SPV_AMD_gpu_shader_half_float);
@@ -2627,16 +2626,13 @@
else
spvType = builder.makeBoolType();
break;
- case glslang::EbtInt8:
- builder.addCapability(spv::CapabilityInt8);
+ case glslang::EbtInt8:
spvType = builder.makeIntType(8);
break;
case glslang::EbtUint8:
- builder.addCapability(spv::CapabilityInt8);
spvType = builder.makeUintType(8);
break;
- case glslang::EbtInt16:
- builder.addCapability(spv::CapabilityInt16);
+ case glslang::EbtInt16:
#ifdef AMD_EXTENSIONS
if (builder.getSpvVersion() < glslang::EShTargetSpv_1_3)
builder.addExtension(spv::E_SPV_AMD_gpu_shader_int16);
@@ -2644,7 +2640,6 @@
spvType = builder.makeIntType(16);
break;
case glslang::EbtUint16:
- builder.addCapability(spv::CapabilityInt16);
#ifdef AMD_EXTENSIONS
if (builder.getSpvVersion() < glslang::EShTargetSpv_1_3)
builder.addExtension(spv::E_SPV_AMD_gpu_shader_int16);
diff --git a/SPIRV/SpvBuilder.cpp b/SPIRV/SpvBuilder.cpp
index 5a7bb5c..093a740 100755
--- a/SPIRV/SpvBuilder.cpp
+++ b/SPIRV/SpvBuilder.cpp
@@ -194,10 +194,8 @@
// deal with capabilities
switch (width) {
case 8:
- addCapability(CapabilityInt8);
- break;
case 16:
- addCapability(CapabilityInt16);
+ // these are currently handled by storage-type declarations and post processing
break;
case 64:
addCapability(CapabilityInt64);
@@ -229,7 +227,7 @@
// deal with capabilities
switch (width) {
case 16:
- addCapability(CapabilityFloat16);
+ // currently handled by storage-type declarations and post processing
break;
case 64:
addCapability(CapabilityFloat64);
@@ -520,12 +518,6 @@
Op typeClass = instr->getOpCode();
switch (typeClass)
{
- case OpTypeVoid:
- case OpTypeBool:
- case OpTypeInt:
- case OpTypeFloat:
- case OpTypeStruct:
- return typeClass;
case OpTypeVector:
case OpTypeMatrix:
case OpTypeArray:
@@ -534,8 +526,7 @@
case OpTypePointer:
return getMostBasicTypeClass(instr->getIdOperand(1));
default:
- assert(0);
- return OpTypeFloat;
+ return typeClass;
}
}
@@ -622,6 +613,36 @@
return getContainedTypeId(typeId, 0);
}
+// Returns true if 'typeId' is or contains a scalar type declared with 'typeOp'
+// of width 'width'. The 'width' is only consumed for int and float types.
+// Returns false otherwise.
+bool Builder::containsType(Id typeId, spv::Op typeOp, int width) const
+{
+ const Instruction& instr = *module.getInstruction(typeId);
+
+ Op typeClass = instr.getOpCode();
+ switch (typeClass)
+ {
+ case OpTypeInt:
+ case OpTypeFloat:
+ return typeClass == typeOp && instr.getImmediateOperand(0) == width;
+ case OpTypeStruct:
+ for (int m = 0; m < instr.getNumOperands(); ++m) {
+ if (containsType(instr.getIdOperand(m), typeOp, width))
+ return true;
+ }
+ return false;
+ case OpTypeVector:
+ case OpTypeMatrix:
+ case OpTypeArray:
+ case OpTypeRuntimeArray:
+ case OpTypePointer:
+ return containsType(getContainedTypeId(typeId), typeOp, width);
+ default:
+ return typeClass == typeOp;
+ }
+}
+
// See if a scalar constant of this type has already been created, so it
// can be reused rather than duplicated. (Required by the specification).
Id Builder::findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned value)
diff --git a/SPIRV/SpvBuilder.h b/SPIRV/SpvBuilder.h
index 01698b3..b09ba04 100755
--- a/SPIRV/SpvBuilder.h
+++ b/SPIRV/SpvBuilder.h
@@ -167,6 +167,7 @@
bool isImageType(Id typeId) const { return getTypeClass(typeId) == OpTypeImage; }
bool isSamplerType(Id typeId) const { return getTypeClass(typeId) == OpTypeSampler; }
bool isSampledImageType(Id typeId) const { return getTypeClass(typeId) == OpTypeSampledImage; }
+ bool containsType(Id typeId, Op typeOp, int width) const;
bool isConstantOpCode(Op opcode) const;
bool isSpecConstantOpCode(Op opcode) const;
@@ -569,9 +570,11 @@
void postProcess();
// Hook to visit each instruction in a block in a function
- void postProcess(Instruction& inst);
+ void postProcess(const Instruction&);
// Hook to visit each instruction in a reachable block in a function.
- void postProcessReachable(Instruction& inst);
+ void postProcessReachable(const Instruction&);
+ // Hook to visit the result type and the type of each typed operand of an instruction in a block.
+ void postProcessType(const Instruction&, spv::Id typeId);
void dump(std::vector<unsigned int>&) const;
diff --git a/SPIRV/SpvPostProcess.cpp b/SPIRV/SpvPostProcess.cpp
index df6ba81..d9abd91 100755
--- a/SPIRV/SpvPostProcess.cpp
+++ b/SPIRV/SpvPostProcess.cpp
@@ -61,8 +61,78 @@
namespace spv {
+// Hook to visit each operand type and result type of an instruction.
+// Will be called multiple times for one instruction, once for each typed
+// operand and the result.
+void Builder::postProcessType(const Instruction& inst, Id typeId)
+{
+ // Characterize the type being examined
+ Id basicTypeOp = getMostBasicTypeClass(typeId);
+ int width = 0;
+ if (basicTypeOp == OpTypeFloat || basicTypeOp == OpTypeInt)
+ width = getScalarTypeWidth(typeId);
+
+ // Do opcode-specific checks
+ switch (inst.getOpCode()) {
+ case OpLoad:
+ case OpStore:
+ if (basicTypeOp == OpTypeStruct) {
+ if (containsType(typeId, OpTypeInt, 8))
+ addCapability(CapabilityInt8);
+ if (containsType(typeId, OpTypeInt, 16))
+ addCapability(CapabilityInt16);
+ if (containsType(typeId, OpTypeFloat, 16))
+ addCapability(CapabilityFloat16);
+ } else {
+ StorageClass storageClass = getStorageClass(inst.getIdOperand(0));
+ if (width == 8) {
+ switch (storageClass) {
+ case StorageClassUniform:
+ case StorageClassStorageBuffer:
+ case StorageClassPushConstant:
+ break;
+ default:
+ addCapability(CapabilityInt8);
+ break;
+ }
+ } else if (width == 16) {
+ switch (storageClass) {
+ case StorageClassUniform:
+ case StorageClassStorageBuffer:
+ case StorageClassPushConstant:
+ case StorageClassInput:
+ case StorageClassOutput:
+ break;
+ default:
+ if (basicTypeOp == OpTypeInt)
+ addCapability(CapabilityInt16);
+ if (basicTypeOp == OpTypeFloat)
+ addCapability(CapabilityFloat16);
+ break;
+ }
+ }
+ }
+ break;
+ case OpAccessChain:
+ case OpPtrAccessChain:
+ case OpCopyObject:
+ case OpFConvert:
+ case OpSConvert:
+ case OpUConvert:
+ break;
+ default:
+ if (basicTypeOp == OpTypeFloat && width == 16)
+ addCapability(CapabilityFloat16);
+ if (basicTypeOp == OpTypeInt && width == 16)
+ addCapability(CapabilityInt16);
+ if (basicTypeOp == OpTypeInt && width == 8)
+ addCapability(CapabilityInt8);
+ break;
+ }
+}
+
// Called for each instruction that resides in a block.
-void Builder::postProcess(Instruction& inst)
+void Builder::postProcess(const Instruction& inst)
{
// Add capabilities based simply on the opcode.
switch (inst.getOpCode()) {
@@ -104,10 +174,22 @@
default:
break;
}
+
+ // Checks based on type
+ if (inst.getTypeId() != NoType)
+ postProcessType(inst, inst.getTypeId());
+ for (int op = 0; op < inst.getNumOperands(); ++op) {
+ if (inst.isIdOperand(op)) {
+ // In blocks, these are always result ids, but we are relying on
+ // getTypeId() to return NoType for things like OpLabel.
+ if (getTypeId(inst.getIdOperand(op)) != NoType)
+ postProcessType(inst, getTypeId(inst.getIdOperand(op)));
+ }
+ }
}
// Called for each instruction in a reachable block.
-void Builder::postProcessReachable(Instruction& inst)
+void Builder::postProcessReachable(const Instruction& inst)
{
// did have code here, but questionable to do so without deleting the instructions
}
diff --git a/SPIRV/spvIR.h b/SPIRV/spvIR.h
index 14d997d..2532b17 100755
--- a/SPIRV/spvIR.h
+++ b/SPIRV/spvIR.h
@@ -127,7 +127,7 @@
addImmediateOperand(word);
}
}
- bool isIdOperand(int op) { return idOperand[op]; }
+ bool isIdOperand(int op) const { return idOperand[op]; }
void setBlock(Block* b) { block = b; }
Block* getBlock() const { return block; }
Op getOpCode() const { return opCode; }
diff --git a/Test/baseResults/spv.16bitstorage-int.frag.out b/Test/baseResults/spv.16bitstorage-int.frag.out
index 9de223c..dd7d1b1 100755
--- a/Test/baseResults/spv.16bitstorage-int.frag.out
+++ b/Test/baseResults/spv.16bitstorage-int.frag.out
@@ -4,7 +4,6 @@
// Id's are bound by 171
Capability Shader
- Capability Int16
Capability StorageUniformBufferBlock16
Capability StorageUniform16
Extension "SPV_AMD_gpu_shader_int16"
diff --git a/Test/baseResults/spv.16bitstorage-uint.frag.out b/Test/baseResults/spv.16bitstorage-uint.frag.out
index def7c57..3a13826 100755
--- a/Test/baseResults/spv.16bitstorage-uint.frag.out
+++ b/Test/baseResults/spv.16bitstorage-uint.frag.out
@@ -4,7 +4,6 @@
// Id's are bound by 173
Capability Shader
- Capability Int16
Capability StorageUniformBufferBlock16
Capability StorageUniform16
Extension "SPV_AMD_gpu_shader_int16"
diff --git a/Test/baseResults/spv.16bitstorage.frag.out b/Test/baseResults/spv.16bitstorage.frag.out
index ebf48e9..cf536e8 100755
--- a/Test/baseResults/spv.16bitstorage.frag.out
+++ b/Test/baseResults/spv.16bitstorage.frag.out
@@ -4,7 +4,6 @@
// Id's are bound by 173
Capability Shader
- Capability Float16
Capability StorageUniformBufferBlock16
Capability StorageUniform16
Extension "SPV_AMD_gpu_shader_half_float"
diff --git a/Test/baseResults/spv.8bitstorage-int.frag.out b/Test/baseResults/spv.8bitstorage-int.frag.out
index 94e7ab5..55a8b3b 100755
--- a/Test/baseResults/spv.8bitstorage-int.frag.out
+++ b/Test/baseResults/spv.8bitstorage-int.frag.out
@@ -4,7 +4,6 @@
// Id's are bound by 171
Capability Shader
- Capability Int8
Capability CapabilityStorageBuffer8BitAccess
Capability CapabilityUniformAndStorageBuffer8BitAccess
Extension "SPV_KHR_8bit_storage"
diff --git a/Test/baseResults/spv.8bitstorage-uint.frag.out b/Test/baseResults/spv.8bitstorage-uint.frag.out
index f4e7b6d..461cec4 100755
--- a/Test/baseResults/spv.8bitstorage-uint.frag.out
+++ b/Test/baseResults/spv.8bitstorage-uint.frag.out
@@ -4,7 +4,6 @@
// Id's are bound by 173
Capability Shader
- Capability Int8
Capability CapabilityStorageBuffer8BitAccess
Capability CapabilityUniformAndStorageBuffer8BitAccess
Extension "SPV_KHR_8bit_storage"