SpirvShader: Implement OpSwitch
Tests: dEQP-VK.spirv_assembly.instruction.compute.*
Tests: dEQP-VK.spirv_assembly.instruction.graphics.*
Bug: b/128527271
Change-Id: I7ba31ca504a582a4d36d25ef2747fb1c1607bade
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/27775
Presubmit-Ready: Ben Clayton <bclayton@google.com>
Tested-by: Ben Clayton <bclayton@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index c226f13..cc81c9f 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -1179,6 +1179,8 @@
case Block::Simple:
case Block::StructuredBranchConditional:
case Block::UnstructuredBranchConditional:
+ case Block::StructuredSwitch:
+ case Block::UnstructuredSwitch:
if (id != mainBlockId)
{
// Emit all preceeding blocks and set the activeLaneMask.
@@ -1404,6 +1406,9 @@
case spv::OpBranchConditional:
return EmitBranchConditional(insn, state);
+ case spv::OpSwitch:
+ return EmitSwitch(insn, state);
+
case spv::OpUnreachable:
return EmitUnreachable(insn, state);
@@ -2638,6 +2643,39 @@
return EmitResult::Terminator;
}
+ SpirvShader::EmitResult SpirvShader::EmitSwitch(InsnIterator insn, EmitState *state) const
+ {
+ auto block = getBlock(state->currentBlock);
+ ASSERT(block.branchInstruction == insn);
+
+ auto selId = Object::ID(block.branchInstruction.word(1));
+
+ auto sel = GenericValue(this, state->routine, selId);
+ ASSERT_MSG(getType(getObject(selId).type).sizeInComponents == 1, "Selector must be a scalar");
+
+ auto numCases = (block.branchInstruction.wordCount() - 3) / 2;
+
+ // TODO: Optimize for case where all lanes take same path.
+
+ SIMD::Int defaultLaneMask = state->activeLaneMask();
+
+ // Gather up the case label matches and calculate defaultLaneMask.
+ std::vector<RValue<SIMD::Int>> caseLabelMatches;
+ caseLabelMatches.reserve(numCases);
+ for (uint32_t i = 0; i < numCases; i++)
+ {
+ auto label = block.branchInstruction.word(i * 2 + 3);
+ auto caseBlockId = Block::ID(block.branchInstruction.word(i * 2 + 4));
+ auto caseLabelMatch = CmpEQ(sel.Int(0), SIMD::Int(label));
+ state->addOutputActiveLaneMaskEdge(caseBlockId, caseLabelMatch);
+ defaultLaneMask &= ~caseLabelMatch;
+ }
+
+ auto defaultBlockId = Block::ID(block.branchInstruction.word(2));
+ state->addOutputActiveLaneMaskEdge(defaultBlockId, defaultLaneMask);
+
+ return EmitResult::Terminator;
+ }
SpirvShader::EmitResult SpirvShader::EmitUnreachable(InsnIterator insn, EmitState *state) const
{
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 8268c71..6d72745 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -602,6 +602,7 @@
EmitResult EmitAll(InsnIterator insn, EmitState *state) const;
EmitResult EmitBranch(InsnIterator insn, EmitState *state) const;
EmitResult EmitBranchConditional(InsnIterator insn, EmitState *state) const;
+ EmitResult EmitSwitch(InsnIterator insn, EmitState *state) const;
EmitResult EmitUnreachable(InsnIterator insn, EmitState *state) const;
EmitResult EmitReturn(InsnIterator insn, EmitState *state) const;
EmitResult EmitPhi(InsnIterator insn, EmitState *state) const;
diff --git a/tests/VulkanUnitTests/unittests.cpp b/tests/VulkanUnitTests/unittests.cpp
index 2dca716..d2bfcc5 100644
--- a/tests/VulkanUnitTests/unittests.cpp
+++ b/tests/VulkanUnitTests/unittests.cpp
@@ -917,3 +917,455 @@
test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 1 : 2; });
}
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchEmptyCases)
+{
+ std::stringstream src;
+ src <<
+ "OpCapability Shader\n"
+ "OpMemoryModel Logical GLSL450\n"
+ "OpEntryPoint GLCompute %1 \"main\" %2\n"
+ "OpExecutionMode %1 LocalSize " <<
+ GetParam().localSizeX << " " <<
+ GetParam().localSizeY << " " <<
+ GetParam().localSizeZ << "\n" <<
+ "OpDecorate %3 ArrayStride 4\n"
+ "OpMemberDecorate %4 0 Offset 0\n"
+ "OpDecorate %4 BufferBlock\n"
+ "OpDecorate %5 DescriptorSet 0\n"
+ "OpDecorate %5 Binding 1\n"
+ "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+ "OpDecorate %6 DescriptorSet 0\n"
+ "OpDecorate %6 Binding 0\n"
+ "%7 = OpTypeVoid\n"
+ "%8 = OpTypeFunction %7\n" // void()
+ "%9 = OpTypeInt 32 1\n" // int32
+ "%10 = OpTypeInt 32 0\n" // uint32
+ "%11 = OpTypeBool\n"
+ "%3 = OpTypeRuntimeArray %9\n" // int32[]
+ "%4 = OpTypeStruct %3\n" // struct{ int32[] }
+ "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }*
+ "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in
+ "%13 = OpConstant %9 0\n" // int32(0)
+ "%14 = OpConstant %9 2\n" // int32(2)
+ "%15 = OpConstant %10 0\n" // uint32(0)
+ "%16 = OpTypeVector %10 3\n" // vec4<int32>
+ "%17 = OpTypePointer Input %16\n" // vec4<int32>*
+ "%2 = OpVariable %17 Input\n" // gl_GlobalInvocationId
+ "%18 = OpTypePointer Input %10\n" // uint32*
+ "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out
+ "%19 = OpTypePointer Uniform %9\n" // int32*
+ "%1 = OpFunction %7 None %8\n" // -- Function begin --
+ "%20 = OpLabel\n"
+ "%21 = OpAccessChain %18 %2 %15\n" // &gl_GlobalInvocationId.x
+ "%22 = OpLoad %10 %21\n" // gl_GlobalInvocationId.x
+ "%23 = OpAccessChain %19 %6 %13 %22\n" // &in.arr[gl_GlobalInvocationId.x]
+ "%24 = OpLoad %9 %23\n" // in.arr[gl_GlobalInvocationId.x]
+ "%25 = OpAccessChain %19 %5 %13 %22\n" // &out.arr[gl_GlobalInvocationId.x]
+ // Start of branch logic
+ // %24 = in value
+ "%26 = OpSMod %9 %24 %14\n" // in % 2
+ "OpSelectionMerge %27 None\n"
+ "OpSwitch %26 %27 0 %28 1 %29\n"
+ "%28 = OpLabel\n" // (in % 2) == 0
+ "OpBranch %27\n"
+ "%29 = OpLabel\n" // (in % 2) == 1
+ "OpBranch %27\n"
+ "%27 = OpLabel\n"
+ // %26 = out value
+ // End of branch logic
+ "OpStore %25 %26\n" // use SSA value from previous block
+ "OpReturn\n"
+ "OpFunctionEnd\n";
+
+ test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i%2; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchStore)
+{
+ std::stringstream src;
+ src <<
+ "OpCapability Shader\n"
+ "OpMemoryModel Logical GLSL450\n"
+ "OpEntryPoint GLCompute %1 \"main\" %2\n"
+ "OpExecutionMode %1 LocalSize " <<
+ GetParam().localSizeX << " " <<
+ GetParam().localSizeY << " " <<
+ GetParam().localSizeZ << "\n" <<
+ "OpDecorate %3 ArrayStride 4\n"
+ "OpMemberDecorate %4 0 Offset 0\n"
+ "OpDecorate %4 BufferBlock\n"
+ "OpDecorate %5 DescriptorSet 0\n"
+ "OpDecorate %5 Binding 1\n"
+ "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+ "OpDecorate %6 DescriptorSet 0\n"
+ "OpDecorate %6 Binding 0\n"
+ "%7 = OpTypeVoid\n"
+ "%8 = OpTypeFunction %7\n" // void()
+ "%9 = OpTypeInt 32 1\n" // int32
+ "%10 = OpTypeInt 32 0\n" // uint32
+ "%11 = OpTypeBool\n"
+ "%3 = OpTypeRuntimeArray %9\n" // int32[]
+ "%4 = OpTypeStruct %3\n" // struct{ int32[] }
+ "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }*
+ "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in
+ "%13 = OpConstant %9 0\n" // int32(0)
+ "%14 = OpConstant %9 1\n" // int32(1)
+ "%15 = OpConstant %9 2\n" // int32(2)
+ "%16 = OpConstant %10 0\n" // uint32(0)
+ "%17 = OpTypeVector %10 3\n" // vec4<int32>
+ "%18 = OpTypePointer Input %17\n" // vec4<int32>*
+ "%2 = OpVariable %18 Input\n" // gl_GlobalInvocationId
+ "%19 = OpTypePointer Input %10\n" // uint32*
+ "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out
+ "%20 = OpTypePointer Uniform %9\n" // int32*
+ "%1 = OpFunction %7 None %8\n" // -- Function begin --
+ "%21 = OpLabel\n"
+ "%22 = OpAccessChain %19 %2 %16\n" // &gl_GlobalInvocationId.x
+ "%23 = OpLoad %10 %22\n" // gl_GlobalInvocationId.x
+ "%24 = OpAccessChain %20 %6 %13 %23\n" // &in.arr[gl_GlobalInvocationId.x]
+ "%25 = OpLoad %9 %24\n" // in.arr[gl_GlobalInvocationId.x]
+ "%26 = OpAccessChain %20 %5 %13 %23\n" // &out.arr[gl_GlobalInvocationId.x]
+ // Start of branch logic
+ // %25 = in value
+ "%27 = OpSMod %9 %25 %15\n" // in % 2
+ "OpSelectionMerge %28 None\n"
+ "OpSwitch %27 %28 0 %29 1 %30\n"
+ "%29 = OpLabel\n" // (in % 2) == 0
+ "OpStore %26 %15\n" // write 2
+ "OpBranch %28\n"
+ "%30 = OpLabel\n" // (in % 2) == 1
+ "OpStore %26 %14\n" // write 1
+ "OpBranch %28\n"
+ "%28 = OpLabel\n"
+ // End of branch logic
+ "OpReturn\n"
+ "OpFunctionEnd\n";
+
+ test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 2 : 1; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseReturn)
+{
+ std::stringstream src;
+ src <<
+ "OpCapability Shader\n"
+ "OpMemoryModel Logical GLSL450\n"
+ "OpEntryPoint GLCompute %1 \"main\" %2\n"
+ "OpExecutionMode %1 LocalSize " <<
+ GetParam().localSizeX << " " <<
+ GetParam().localSizeY << " " <<
+ GetParam().localSizeZ << "\n" <<
+ "OpDecorate %3 ArrayStride 4\n"
+ "OpMemberDecorate %4 0 Offset 0\n"
+ "OpDecorate %4 BufferBlock\n"
+ "OpDecorate %5 DescriptorSet 0\n"
+ "OpDecorate %5 Binding 1\n"
+ "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+ "OpDecorate %6 DescriptorSet 0\n"
+ "OpDecorate %6 Binding 0\n"
+ "%7 = OpTypeVoid\n"
+ "%8 = OpTypeFunction %7\n" // void()
+ "%9 = OpTypeInt 32 1\n" // int32
+ "%10 = OpTypeInt 32 0\n" // uint32
+ "%11 = OpTypeBool\n"
+ "%3 = OpTypeRuntimeArray %9\n" // int32[]
+ "%4 = OpTypeStruct %3\n" // struct{ int32[] }
+ "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }*
+ "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in
+ "%13 = OpConstant %9 0\n" // int32(0)
+ "%14 = OpConstant %9 1\n" // int32(1)
+ "%15 = OpConstant %9 2\n" // int32(2)
+ "%16 = OpConstant %10 0\n" // uint32(0)
+ "%17 = OpTypeVector %10 3\n" // vec4<int32>
+ "%18 = OpTypePointer Input %17\n" // vec4<int32>*
+ "%2 = OpVariable %18 Input\n" // gl_GlobalInvocationId
+ "%19 = OpTypePointer Input %10\n" // uint32*
+ "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out
+ "%20 = OpTypePointer Uniform %9\n" // int32*
+ "%1 = OpFunction %7 None %8\n" // -- Function begin --
+ "%21 = OpLabel\n"
+ "%22 = OpAccessChain %19 %2 %16\n" // &gl_GlobalInvocationId.x
+ "%23 = OpLoad %10 %22\n" // gl_GlobalInvocationId.x
+ "%24 = OpAccessChain %20 %6 %13 %23\n" // &in.arr[gl_GlobalInvocationId.x]
+ "%25 = OpLoad %9 %24\n" // in.arr[gl_GlobalInvocationId.x]
+ "%26 = OpAccessChain %20 %5 %13 %23\n" // &out.arr[gl_GlobalInvocationId.x]
+ // Start of branch logic
+ // %25 = in value
+ "%27 = OpSMod %9 %25 %15\n" // in % 2
+ "OpSelectionMerge %28 None\n"
+ "OpSwitch %27 %28 0 %29 1 %30\n"
+ "%29 = OpLabel\n" // (in % 2) == 0
+ "OpBranch %28\n"
+ "%30 = OpLabel\n" // (in % 2) == 1
+ "OpReturn\n"
+ "%28 = OpLabel\n"
+ "OpStore %26 %14\n" // write 1
+ // End of branch logic
+ "OpReturn\n"
+ "OpFunctionEnd\n";
+
+ test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 0 : 1; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultReturn)
+{
+ std::stringstream src;
+ src <<
+ "OpCapability Shader\n"
+ "OpMemoryModel Logical GLSL450\n"
+ "OpEntryPoint GLCompute %1 \"main\" %2\n"
+ "OpExecutionMode %1 LocalSize " <<
+ GetParam().localSizeX << " " <<
+ GetParam().localSizeY << " " <<
+ GetParam().localSizeZ << "\n" <<
+ "OpDecorate %3 ArrayStride 4\n"
+ "OpMemberDecorate %4 0 Offset 0\n"
+ "OpDecorate %4 BufferBlock\n"
+ "OpDecorate %5 DescriptorSet 0\n"
+ "OpDecorate %5 Binding 1\n"
+ "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+ "OpDecorate %6 DescriptorSet 0\n"
+ "OpDecorate %6 Binding 0\n"
+ "%7 = OpTypeVoid\n"
+ "%8 = OpTypeFunction %7\n" // void()
+ "%9 = OpTypeInt 32 1\n" // int32
+ "%10 = OpTypeInt 32 0\n" // uint32
+ "%11 = OpTypeBool\n"
+ "%3 = OpTypeRuntimeArray %9\n" // int32[]
+ "%4 = OpTypeStruct %3\n" // struct{ int32[] }
+ "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }*
+ "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in
+ "%13 = OpConstant %9 0\n" // int32(0)
+ "%14 = OpConstant %9 1\n" // int32(1)
+ "%15 = OpConstant %9 2\n" // int32(2)
+ "%16 = OpConstant %10 0\n" // uint32(0)
+ "%17 = OpTypeVector %10 3\n" // vec4<int32>
+ "%18 = OpTypePointer Input %17\n" // vec4<int32>*
+ "%2 = OpVariable %18 Input\n" // gl_GlobalInvocationId
+ "%19 = OpTypePointer Input %10\n" // uint32*
+ "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out
+ "%20 = OpTypePointer Uniform %9\n" // int32*
+ "%1 = OpFunction %7 None %8\n" // -- Function begin --
+ "%21 = OpLabel\n"
+ "%22 = OpAccessChain %19 %2 %16\n" // &gl_GlobalInvocationId.x
+ "%23 = OpLoad %10 %22\n" // gl_GlobalInvocationId.x
+ "%24 = OpAccessChain %20 %6 %13 %23\n" // &in.arr[gl_GlobalInvocationId.x]
+ "%25 = OpLoad %9 %24\n" // in.arr[gl_GlobalInvocationId.x]
+ "%26 = OpAccessChain %20 %5 %13 %23\n" // &out.arr[gl_GlobalInvocationId.x]
+ // Start of branch logic
+ // %25 = in value
+ "%27 = OpSMod %9 %25 %15\n" // in % 2
+ "OpSelectionMerge %28 None\n"
+ "OpSwitch %27 %29 1 %30\n"
+ "%30 = OpLabel\n" // (in % 2) == 1
+ "OpBranch %28\n"
+ "%29 = OpLabel\n" // (in % 2) != 1
+ "OpReturn\n"
+ "%28 = OpLabel\n" // merge
+ "OpStore %26 %14\n" // write 1
+ // End of branch logic
+ "OpReturn\n"
+ "OpFunctionEnd\n";
+
+ test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 0; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseFallthrough)
+{
+ std::stringstream src;
+ src <<
+ "OpCapability Shader\n"
+ "OpMemoryModel Logical GLSL450\n"
+ "OpEntryPoint GLCompute %1 \"main\" %2\n"
+ "OpExecutionMode %1 LocalSize " <<
+ GetParam().localSizeX << " " <<
+ GetParam().localSizeY << " " <<
+ GetParam().localSizeZ << "\n" <<
+ "OpDecorate %3 ArrayStride 4\n"
+ "OpMemberDecorate %4 0 Offset 0\n"
+ "OpDecorate %4 BufferBlock\n"
+ "OpDecorate %5 DescriptorSet 0\n"
+ "OpDecorate %5 Binding 1\n"
+ "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+ "OpDecorate %6 DescriptorSet 0\n"
+ "OpDecorate %6 Binding 0\n"
+ "%7 = OpTypeVoid\n"
+ "%8 = OpTypeFunction %7\n" // void()
+ "%9 = OpTypeInt 32 1\n" // int32
+ "%10 = OpTypeInt 32 0\n" // uint32
+ "%11 = OpTypeBool\n"
+ "%3 = OpTypeRuntimeArray %9\n" // int32[]
+ "%4 = OpTypeStruct %3\n" // struct{ int32[] }
+ "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }*
+ "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in
+ "%13 = OpConstant %9 0\n" // int32(0)
+ "%14 = OpConstant %9 1\n" // int32(1)
+ "%15 = OpConstant %9 2\n" // int32(2)
+ "%16 = OpConstant %10 0\n" // uint32(0)
+ "%17 = OpTypeVector %10 3\n" // vec4<int32>
+ "%18 = OpTypePointer Input %17\n" // vec4<int32>*
+ "%2 = OpVariable %18 Input\n" // gl_GlobalInvocationId
+ "%19 = OpTypePointer Input %10\n" // uint32*
+ "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out
+ "%20 = OpTypePointer Uniform %9\n" // int32*
+ "%1 = OpFunction %7 None %8\n" // -- Function begin --
+ "%21 = OpLabel\n"
+ "%22 = OpAccessChain %19 %2 %16\n" // &gl_GlobalInvocationId.x
+ "%23 = OpLoad %10 %22\n" // gl_GlobalInvocationId.x
+ "%24 = OpAccessChain %20 %6 %13 %23\n" // &in.arr[gl_GlobalInvocationId.x]
+ "%25 = OpLoad %9 %24\n" // in.arr[gl_GlobalInvocationId.x]
+ "%26 = OpAccessChain %20 %5 %13 %23\n" // &out.arr[gl_GlobalInvocationId.x]
+ // Start of branch logic
+ // %25 = in value
+ "%27 = OpSMod %9 %25 %15\n" // in % 2
+ "OpSelectionMerge %28 None\n"
+ "OpSwitch %27 %29 0 %30 1 %31\n"
+ "%30 = OpLabel\n" // (in % 2) == 0
+ "%32 = OpIAdd %9 %27 %14\n" // generate an intermediate
+ "OpStore %26 %32\n" // write a value (overwritten later)
+ "OpBranch %31\n" // fallthrough
+ "%31 = OpLabel\n" // (in % 2) == 1
+ "OpStore %26 %15\n" // write 2
+ "OpBranch %28\n"
+ "%29 = OpLabel\n" // unreachable
+ "OpUnreachable\n"
+ "%28 = OpLabel\n" // merge
+ // End of branch logic
+ "OpReturn\n"
+ "OpFunctionEnd\n";
+
+ test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultFallthrough)
+{
+ std::stringstream src;
+ src <<
+ "OpCapability Shader\n"
+ "OpMemoryModel Logical GLSL450\n"
+ "OpEntryPoint GLCompute %1 \"main\" %2\n"
+ "OpExecutionMode %1 LocalSize " <<
+ GetParam().localSizeX << " " <<
+ GetParam().localSizeY << " " <<
+ GetParam().localSizeZ << "\n" <<
+ "OpDecorate %3 ArrayStride 4\n"
+ "OpMemberDecorate %4 0 Offset 0\n"
+ "OpDecorate %4 BufferBlock\n"
+ "OpDecorate %5 DescriptorSet 0\n"
+ "OpDecorate %5 Binding 1\n"
+ "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+ "OpDecorate %6 DescriptorSet 0\n"
+ "OpDecorate %6 Binding 0\n"
+ "%7 = OpTypeVoid\n"
+ "%8 = OpTypeFunction %7\n" // void()
+ "%9 = OpTypeInt 32 1\n" // int32
+ "%10 = OpTypeInt 32 0\n" // uint32
+ "%11 = OpTypeBool\n"
+ "%3 = OpTypeRuntimeArray %9\n" // int32[]
+ "%4 = OpTypeStruct %3\n" // struct{ int32[] }
+ "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }*
+ "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in
+ "%13 = OpConstant %9 0\n" // int32(0)
+ "%14 = OpConstant %9 1\n" // int32(1)
+ "%15 = OpConstant %9 2\n" // int32(2)
+ "%16 = OpConstant %10 0\n" // uint32(0)
+ "%17 = OpTypeVector %10 3\n" // vec4<int32>
+ "%18 = OpTypePointer Input %17\n" // vec4<int32>*
+ "%2 = OpVariable %18 Input\n" // gl_GlobalInvocationId
+ "%19 = OpTypePointer Input %10\n" // uint32*
+ "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out
+ "%20 = OpTypePointer Uniform %9\n" // int32*
+ "%1 = OpFunction %7 None %8\n" // -- Function begin --
+ "%21 = OpLabel\n"
+ "%22 = OpAccessChain %19 %2 %16\n" // &gl_GlobalInvocationId.x
+ "%23 = OpLoad %10 %22\n" // gl_GlobalInvocationId.x
+ "%24 = OpAccessChain %20 %6 %13 %23\n" // &in.arr[gl_GlobalInvocationId.x]
+ "%25 = OpLoad %9 %24\n" // in.arr[gl_GlobalInvocationId.x]
+ "%26 = OpAccessChain %20 %5 %13 %23\n" // &out.arr[gl_GlobalInvocationId.x]
+ // Start of branch logic
+ // %25 = in value
+ "%27 = OpSMod %9 %25 %15\n" // in % 2
+ "OpSelectionMerge %28 None\n"
+ "OpSwitch %27 %29 0 %30 1 %31\n"
+ "%30 = OpLabel\n" // (in % 2) == 0
+ "%32 = OpIAdd %9 %27 %14\n" // generate an intermediate
+ "OpStore %26 %32\n" // write a value (overwritten later)
+ "OpBranch %29\n" // fallthrough
+ "%29 = OpLabel\n" // default
+ "%33 = OpIAdd %9 %27 %14\n" // generate an intermediate
+ "OpStore %26 %33\n" // write a value (overwritten later)
+ "OpBranch %31\n" // fallthrough
+ "%31 = OpLabel\n" // (in % 2) == 1
+ "OpStore %26 %15\n" // write 2
+ "OpBranch %28\n"
+ "%28 = OpLabel\n" // merge
+ // End of branch logic
+ "OpReturn\n"
+ "OpFunctionEnd\n";
+
+ test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; });
+}
+
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchPhi)
+{
+ std::stringstream src;
+ src <<
+ "OpCapability Shader\n"
+ "OpMemoryModel Logical GLSL450\n"
+ "OpEntryPoint GLCompute %1 \"main\" %2\n"
+ "OpExecutionMode %1 LocalSize " <<
+ GetParam().localSizeX << " " <<
+ GetParam().localSizeY << " " <<
+ GetParam().localSizeZ << "\n" <<
+ "OpDecorate %3 ArrayStride 4\n"
+ "OpMemberDecorate %4 0 Offset 0\n"
+ "OpDecorate %4 BufferBlock\n"
+ "OpDecorate %5 DescriptorSet 0\n"
+ "OpDecorate %5 Binding 1\n"
+ "OpDecorate %2 BuiltIn GlobalInvocationId\n"
+ "OpDecorate %6 DescriptorSet 0\n"
+ "OpDecorate %6 Binding 0\n"
+ "%7 = OpTypeVoid\n"
+ "%8 = OpTypeFunction %7\n" // void()
+ "%9 = OpTypeInt 32 1\n" // int32
+ "%10 = OpTypeInt 32 0\n" // uint32
+ "%11 = OpTypeBool\n"
+ "%3 = OpTypeRuntimeArray %9\n" // int32[]
+ "%4 = OpTypeStruct %3\n" // struct{ int32[] }
+ "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }*
+ "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in
+ "%13 = OpConstant %9 0\n" // int32(0)
+ "%14 = OpConstant %9 1\n" // int32(1)
+ "%15 = OpConstant %9 2\n" // int32(2)
+ "%16 = OpConstant %10 0\n" // uint32(0)
+ "%17 = OpTypeVector %10 3\n" // vec4<int32>
+ "%18 = OpTypePointer Input %17\n" // vec4<int32>*
+ "%2 = OpVariable %18 Input\n" // gl_GlobalInvocationId
+ "%19 = OpTypePointer Input %10\n" // uint32*
+ "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out
+ "%20 = OpTypePointer Uniform %9\n" // int32*
+ "%1 = OpFunction %7 None %8\n" // -- Function begin --
+ "%21 = OpLabel\n"
+ "%22 = OpAccessChain %19 %2 %16\n" // &gl_GlobalInvocationId.x
+ "%23 = OpLoad %10 %22\n" // gl_GlobalInvocationId.x
+ "%24 = OpAccessChain %20 %6 %13 %23\n" // &in.arr[gl_GlobalInvocationId.x]
+ "%25 = OpLoad %9 %24\n" // in.arr[gl_GlobalInvocationId.x]
+ "%26 = OpAccessChain %20 %5 %13 %23\n" // &out.arr[gl_GlobalInvocationId.x]
+ // Start of branch logic
+ // %25 = in value
+ "%27 = OpSMod %9 %25 %15\n" // in % 2
+ "OpSelectionMerge %28 None\n"
+ "OpSwitch %27 %29 1 %30\n"
+ "%30 = OpLabel\n" // (in % 2) == 1
+ "OpBranch %28\n"
+ "%29 = OpLabel\n" // (in % 2) != 1
+ "OpBranch %28\n"
+ "%28 = OpLabel\n" // merge
+ "%31 = OpPhi %9 %14 %30 %15 %29\n" // (in % 2) == 1 ? 1 : 2
+ "OpStore %26 %31\n"
+ // End of branch logic
+ "OpReturn\n"
+ "OpFunctionEnd\n";
+
+ test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 2; });
+}