SpirvShader: Implement OpSwitch

Tests: dEQP-VK.spirv_assembly.instruction.compute.*
Tests: dEQP-VK.spirv_assembly.instruction.graphics.*
Bug: b/128527271
Change-Id: I7ba31ca504a582a4d36d25ef2747fb1c1607bade
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/27775
Presubmit-Ready: Ben Clayton <bclayton@google.com>
Tested-by: Ben Clayton <bclayton@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index c226f13..cc81c9f 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -1179,6 +1179,8 @@
 			case Block::Simple:
 			case Block::StructuredBranchConditional:
 			case Block::UnstructuredBranchConditional:
+			case Block::StructuredSwitch:
+			case Block::UnstructuredSwitch:
 				if (id != mainBlockId)
 				{
 					// Emit all preceeding blocks and set the activeLaneMask.
@@ -1404,6 +1406,9 @@
 		case spv::OpBranchConditional:
 			return EmitBranchConditional(insn, state);
 
+		case spv::OpSwitch:
+			return EmitSwitch(insn, state);
+
 		case spv::OpUnreachable:
 			return EmitUnreachable(insn, state);
 
@@ -2638,6 +2643,39 @@
 		return EmitResult::Terminator;
 	}
 
+	SpirvShader::EmitResult SpirvShader::EmitSwitch(InsnIterator insn, EmitState *state) const
+	{
+		auto block = getBlock(state->currentBlock);
+		ASSERT(block.branchInstruction == insn);
+
+		auto selId = Object::ID(block.branchInstruction.word(1));
+
+		auto sel = GenericValue(this, state->routine, selId);
+		ASSERT_MSG(getType(getObject(selId).type).sizeInComponents == 1, "Selector must be a scalar");
+
+		auto numCases = (block.branchInstruction.wordCount() - 3) / 2;
+
+		// TODO: Optimize for case where all lanes take same path.
+
+		SIMD::Int defaultLaneMask = state->activeLaneMask();
+
+		// Gather up the case label matches and calculate defaultLaneMask.
+		std::vector<RValue<SIMD::Int>> caseLabelMatches;
+		caseLabelMatches.reserve(numCases);
+		for (uint32_t i = 0; i < numCases; i++)
+		{
+			auto label = block.branchInstruction.word(i * 2 + 3);
+			auto caseBlockId = Block::ID(block.branchInstruction.word(i * 2 + 4));
+			auto caseLabelMatch = CmpEQ(sel.Int(0), SIMD::Int(label));
+			state->addOutputActiveLaneMaskEdge(caseBlockId, caseLabelMatch);
+			defaultLaneMask &= ~caseLabelMatch;
+		}
+
+		auto defaultBlockId = Block::ID(block.branchInstruction.word(2));
+		state->addOutputActiveLaneMaskEdge(defaultBlockId, defaultLaneMask);
+
+		return EmitResult::Terminator;
+	}
 
 	SpirvShader::EmitResult SpirvShader::EmitUnreachable(InsnIterator insn, EmitState *state) const
 	{
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 8268c71..6d72745 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -602,6 +602,7 @@
 		EmitResult EmitAll(InsnIterator insn, EmitState *state) const;
 		EmitResult EmitBranch(InsnIterator insn, EmitState *state) const;
 		EmitResult EmitBranchConditional(InsnIterator insn, EmitState *state) const;
+		EmitResult EmitSwitch(InsnIterator insn, EmitState *state) const;
 		EmitResult EmitUnreachable(InsnIterator insn, EmitState *state) const;
 		EmitResult EmitReturn(InsnIterator insn, EmitState *state) const;
 		EmitResult EmitPhi(InsnIterator insn, EmitState *state) const;
diff --git a/tests/VulkanUnitTests/unittests.cpp b/tests/VulkanUnitTests/unittests.cpp
index 2dca716..d2bfcc5 100644
--- a/tests/VulkanUnitTests/unittests.cpp
+++ b/tests/VulkanUnitTests/unittests.cpp
@@ -917,3 +917,455 @@
     test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 1 : 2; });

 }

 

+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchEmptyCases)

+{

+    std::stringstream src;

+    src <<

+              "OpCapability Shader\n"

+              "OpMemoryModel Logical GLSL450\n"

+              "OpEntryPoint GLCompute %1 \"main\" %2\n"

+              "OpExecutionMode %1 LocalSize " <<

+                GetParam().localSizeX << " " <<

+                GetParam().localSizeY << " " <<

+                GetParam().localSizeZ << "\n" <<

+              "OpDecorate %3 ArrayStride 4\n"

+              "OpMemberDecorate %4 0 Offset 0\n"

+              "OpDecorate %4 BufferBlock\n"

+              "OpDecorate %5 DescriptorSet 0\n"

+              "OpDecorate %5 Binding 1\n"

+              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

+              "OpDecorate %6 DescriptorSet 0\n"

+              "OpDecorate %6 Binding 0\n"

+         "%7 = OpTypeVoid\n"

+         "%8 = OpTypeFunction %7\n"             // void()

+         "%9 = OpTypeInt 32 1\n"                // int32

+        "%10 = OpTypeInt 32 0\n"                // uint32

+        "%11 = OpTypeBool\n"

+         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

+         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

+         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

+        "%13 = OpConstant %9 0\n"               // int32(0)

+        "%14 = OpConstant %9 2\n"               // int32(2)

+        "%15 = OpConstant %10 0\n"              // uint32(0)

+        "%16 = OpTypeVector %10 3\n"            // vec4<int32>

+        "%17 = OpTypePointer Input %16\n"       // vec4<int32>*

+         "%2 = OpVariable %17 Input\n"          // gl_GlobalInvocationId

+        "%18 = OpTypePointer Input %10\n"       // uint32*

+         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

+        "%19 = OpTypePointer Uniform %9\n"      // int32*

+         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

+        "%20 = OpLabel\n"

+        "%21 = OpAccessChain %18 %2 %15\n"      // &gl_GlobalInvocationId.x

+        "%22 = OpLoad %10 %21\n"                // gl_GlobalInvocationId.x

+        "%23 = OpAccessChain %19 %6 %13 %22\n"  // &in.arr[gl_GlobalInvocationId.x]

+        "%24 = OpLoad %9 %23\n"                 // in.arr[gl_GlobalInvocationId.x]

+        "%25 = OpAccessChain %19 %5 %13 %22\n"  // &out.arr[gl_GlobalInvocationId.x]

+    // Start of branch logic

+    // %24 = in value

+        "%26 = OpSMod %9 %24 %14\n"             // in % 2

+              "OpSelectionMerge %27 None\n"

+              "OpSwitch %26 %27 0 %28 1 %29\n"

+        "%28 = OpLabel\n"                       // (in % 2) == 0

+              "OpBranch %27\n"

+        "%29 = OpLabel\n"                       // (in % 2) == 1

+              "OpBranch %27\n"

+        "%27 = OpLabel\n"

+    // %26 = out value

+    // End of branch logic

+              "OpStore %25 %26\n"               // use SSA value from previous block

+              "OpReturn\n"

+              "OpFunctionEnd\n";

+

+    test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i%2; });

+}

+

+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchStore)

+{

+    std::stringstream src;

+    src <<

+              "OpCapability Shader\n"

+              "OpMemoryModel Logical GLSL450\n"

+              "OpEntryPoint GLCompute %1 \"main\" %2\n"

+              "OpExecutionMode %1 LocalSize " <<

+                GetParam().localSizeX << " " <<

+                GetParam().localSizeY << " " <<

+                GetParam().localSizeZ << "\n" <<

+              "OpDecorate %3 ArrayStride 4\n"

+              "OpMemberDecorate %4 0 Offset 0\n"

+              "OpDecorate %4 BufferBlock\n"

+              "OpDecorate %5 DescriptorSet 0\n"

+              "OpDecorate %5 Binding 1\n"

+              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

+              "OpDecorate %6 DescriptorSet 0\n"

+              "OpDecorate %6 Binding 0\n"

+         "%7 = OpTypeVoid\n"

+         "%8 = OpTypeFunction %7\n"             // void()

+         "%9 = OpTypeInt 32 1\n"                // int32

+        "%10 = OpTypeInt 32 0\n"                // uint32

+        "%11 = OpTypeBool\n"

+         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

+         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

+         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

+        "%13 = OpConstant %9 0\n"               // int32(0)

+        "%14 = OpConstant %9 1\n"               // int32(1)

+        "%15 = OpConstant %9 2\n"               // int32(2)

+        "%16 = OpConstant %10 0\n"              // uint32(0)

+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

+         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

+        "%19 = OpTypePointer Input %10\n"       // uint32*

+         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

+        "%20 = OpTypePointer Uniform %9\n"      // int32*

+         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

+        "%21 = OpLabel\n"

+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

+    // Start of branch logic

+    // %25 = in value

+        "%27 = OpSMod %9 %25 %15\n"             // in % 2

+              "OpSelectionMerge %28 None\n"

+              "OpSwitch %27 %28 0 %29 1 %30\n"

+        "%29 = OpLabel\n"                       // (in % 2) == 0

+              "OpStore %26 %15\n"               // write 2

+              "OpBranch %28\n"

+        "%30 = OpLabel\n"                       // (in % 2) == 1

+              "OpStore %26 %14\n"               // write 1

+              "OpBranch %28\n"

+        "%28 = OpLabel\n"

+    // End of branch logic

+              "OpReturn\n"

+              "OpFunctionEnd\n";

+

+    test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 2 : 1; });

+}

+

+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseReturn)

+{

+    std::stringstream src;

+    src <<

+              "OpCapability Shader\n"

+              "OpMemoryModel Logical GLSL450\n"

+              "OpEntryPoint GLCompute %1 \"main\" %2\n"

+              "OpExecutionMode %1 LocalSize " <<

+                GetParam().localSizeX << " " <<

+                GetParam().localSizeY << " " <<

+                GetParam().localSizeZ << "\n" <<

+              "OpDecorate %3 ArrayStride 4\n"

+              "OpMemberDecorate %4 0 Offset 0\n"

+              "OpDecorate %4 BufferBlock\n"

+              "OpDecorate %5 DescriptorSet 0\n"

+              "OpDecorate %5 Binding 1\n"

+              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

+              "OpDecorate %6 DescriptorSet 0\n"

+              "OpDecorate %6 Binding 0\n"

+         "%7 = OpTypeVoid\n"

+         "%8 = OpTypeFunction %7\n"             // void()

+         "%9 = OpTypeInt 32 1\n"                // int32

+        "%10 = OpTypeInt 32 0\n"                // uint32

+        "%11 = OpTypeBool\n"

+         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

+         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

+         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

+        "%13 = OpConstant %9 0\n"               // int32(0)

+        "%14 = OpConstant %9 1\n"               // int32(1)

+        "%15 = OpConstant %9 2\n"               // int32(2)

+        "%16 = OpConstant %10 0\n"              // uint32(0)

+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

+         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

+        "%19 = OpTypePointer Input %10\n"       // uint32*

+         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

+        "%20 = OpTypePointer Uniform %9\n"      // int32*

+         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

+        "%21 = OpLabel\n"

+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

+    // Start of branch logic

+    // %25 = in value

+        "%27 = OpSMod %9 %25 %15\n"             // in % 2

+              "OpSelectionMerge %28 None\n"

+              "OpSwitch %27 %28 0 %29 1 %30\n"

+        "%29 = OpLabel\n"                       // (in % 2) == 0

+              "OpBranch %28\n"

+        "%30 = OpLabel\n"                       // (in % 2) == 1

+              "OpReturn\n"

+        "%28 = OpLabel\n"

+              "OpStore %26 %14\n"               // write 1

+    // End of branch logic

+              "OpReturn\n"

+              "OpFunctionEnd\n";

+

+    test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 0 : 1; });

+}

+

+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultReturn)

+{

+    std::stringstream src;

+    src <<

+              "OpCapability Shader\n"

+              "OpMemoryModel Logical GLSL450\n"

+              "OpEntryPoint GLCompute %1 \"main\" %2\n"

+              "OpExecutionMode %1 LocalSize " <<

+                GetParam().localSizeX << " " <<

+                GetParam().localSizeY << " " <<

+                GetParam().localSizeZ << "\n" <<

+              "OpDecorate %3 ArrayStride 4\n"

+              "OpMemberDecorate %4 0 Offset 0\n"

+              "OpDecorate %4 BufferBlock\n"

+              "OpDecorate %5 DescriptorSet 0\n"

+              "OpDecorate %5 Binding 1\n"

+              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

+              "OpDecorate %6 DescriptorSet 0\n"

+              "OpDecorate %6 Binding 0\n"

+         "%7 = OpTypeVoid\n"

+         "%8 = OpTypeFunction %7\n"             // void()

+         "%9 = OpTypeInt 32 1\n"                // int32

+        "%10 = OpTypeInt 32 0\n"                // uint32

+        "%11 = OpTypeBool\n"

+         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

+         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

+         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

+        "%13 = OpConstant %9 0\n"               // int32(0)

+        "%14 = OpConstant %9 1\n"               // int32(1)

+        "%15 = OpConstant %9 2\n"               // int32(2)

+        "%16 = OpConstant %10 0\n"              // uint32(0)

+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

+         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

+        "%19 = OpTypePointer Input %10\n"       // uint32*

+         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

+        "%20 = OpTypePointer Uniform %9\n"      // int32*

+         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

+        "%21 = OpLabel\n"

+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

+    // Start of branch logic

+    // %25 = in value

+        "%27 = OpSMod %9 %25 %15\n"             // in % 2

+              "OpSelectionMerge %28 None\n"

+              "OpSwitch %27 %29 1 %30\n"

+        "%30 = OpLabel\n"                       // (in % 2) == 1

+              "OpBranch %28\n"

+        "%29 = OpLabel\n"                       // (in % 2) != 1

+              "OpReturn\n"

+        "%28 = OpLabel\n"                       // merge

+              "OpStore %26 %14\n"               // write 1

+    // End of branch logic

+              "OpReturn\n"

+              "OpFunctionEnd\n";

+

+    test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 0; });

+}

+

+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseFallthrough)

+{

+    std::stringstream src;

+    src <<

+              "OpCapability Shader\n"

+              "OpMemoryModel Logical GLSL450\n"

+              "OpEntryPoint GLCompute %1 \"main\" %2\n"

+              "OpExecutionMode %1 LocalSize " <<

+                GetParam().localSizeX << " " <<

+                GetParam().localSizeY << " " <<

+                GetParam().localSizeZ << "\n" <<

+              "OpDecorate %3 ArrayStride 4\n"

+              "OpMemberDecorate %4 0 Offset 0\n"

+              "OpDecorate %4 BufferBlock\n"

+              "OpDecorate %5 DescriptorSet 0\n"

+              "OpDecorate %5 Binding 1\n"

+              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

+              "OpDecorate %6 DescriptorSet 0\n"

+              "OpDecorate %6 Binding 0\n"

+         "%7 = OpTypeVoid\n"

+         "%8 = OpTypeFunction %7\n"             // void()

+         "%9 = OpTypeInt 32 1\n"                // int32

+        "%10 = OpTypeInt 32 0\n"                // uint32

+        "%11 = OpTypeBool\n"

+         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

+         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

+         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

+        "%13 = OpConstant %9 0\n"               // int32(0)

+        "%14 = OpConstant %9 1\n"               // int32(1)

+        "%15 = OpConstant %9 2\n"               // int32(2)

+        "%16 = OpConstant %10 0\n"              // uint32(0)

+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

+         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

+        "%19 = OpTypePointer Input %10\n"       // uint32*

+         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

+        "%20 = OpTypePointer Uniform %9\n"      // int32*

+         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

+        "%21 = OpLabel\n"

+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

+    // Start of branch logic

+    // %25 = in value

+        "%27 = OpSMod %9 %25 %15\n"             // in % 2

+              "OpSelectionMerge %28 None\n"

+              "OpSwitch %27 %29 0 %30 1 %31\n"

+        "%30 = OpLabel\n"                       // (in % 2) == 0

+        "%32 = OpIAdd %9 %27 %14\n"             // generate an intermediate

+              "OpStore %26 %32\n"               // write a value (overwritten later)

+              "OpBranch %31\n"                  // fallthrough

+        "%31 = OpLabel\n"                       // (in % 2) == 1

+              "OpStore %26 %15\n"               // write 2

+              "OpBranch %28\n"

+        "%29 = OpLabel\n"                       // unreachable

+              "OpUnreachable\n"

+        "%28 = OpLabel\n"                       // merge

+    // End of branch logic

+              "OpReturn\n"

+              "OpFunctionEnd\n";

+

+    test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; });

+}

+

+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultFallthrough)

+{

+    std::stringstream src;

+    src <<

+              "OpCapability Shader\n"

+              "OpMemoryModel Logical GLSL450\n"

+              "OpEntryPoint GLCompute %1 \"main\" %2\n"

+              "OpExecutionMode %1 LocalSize " <<

+                GetParam().localSizeX << " " <<

+                GetParam().localSizeY << " " <<

+                GetParam().localSizeZ << "\n" <<

+              "OpDecorate %3 ArrayStride 4\n"

+              "OpMemberDecorate %4 0 Offset 0\n"

+              "OpDecorate %4 BufferBlock\n"

+              "OpDecorate %5 DescriptorSet 0\n"

+              "OpDecorate %5 Binding 1\n"

+              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

+              "OpDecorate %6 DescriptorSet 0\n"

+              "OpDecorate %6 Binding 0\n"

+         "%7 = OpTypeVoid\n"

+         "%8 = OpTypeFunction %7\n"             // void()

+         "%9 = OpTypeInt 32 1\n"                // int32

+        "%10 = OpTypeInt 32 0\n"                // uint32

+        "%11 = OpTypeBool\n"

+         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

+         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

+         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

+        "%13 = OpConstant %9 0\n"               // int32(0)

+        "%14 = OpConstant %9 1\n"               // int32(1)

+        "%15 = OpConstant %9 2\n"               // int32(2)

+        "%16 = OpConstant %10 0\n"              // uint32(0)

+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

+         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

+        "%19 = OpTypePointer Input %10\n"       // uint32*

+         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

+        "%20 = OpTypePointer Uniform %9\n"      // int32*

+         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

+        "%21 = OpLabel\n"

+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

+    // Start of branch logic

+    // %25 = in value

+        "%27 = OpSMod %9 %25 %15\n"             // in % 2

+              "OpSelectionMerge %28 None\n"

+              "OpSwitch %27 %29 0 %30 1 %31\n"

+        "%30 = OpLabel\n"                       // (in % 2) == 0

+        "%32 = OpIAdd %9 %27 %14\n"             // generate an intermediate

+              "OpStore %26 %32\n"               // write a value (overwritten later)

+              "OpBranch %29\n"                  // fallthrough

+        "%29 = OpLabel\n"                       // default

+        "%33 = OpIAdd %9 %27 %14\n"             // generate an intermediate

+              "OpStore %26 %33\n"               // write a value (overwritten later)

+              "OpBranch %31\n"                  // fallthrough

+        "%31 = OpLabel\n"                       // (in % 2) == 1

+              "OpStore %26 %15\n"               // write 2

+              "OpBranch %28\n"

+        "%28 = OpLabel\n"                       // merge

+    // End of branch logic

+              "OpReturn\n"

+              "OpFunctionEnd\n";

+

+    test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; });

+}

+

+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchPhi)

+{

+    std::stringstream src;

+    src <<

+              "OpCapability Shader\n"

+              "OpMemoryModel Logical GLSL450\n"

+              "OpEntryPoint GLCompute %1 \"main\" %2\n"

+              "OpExecutionMode %1 LocalSize " <<

+                GetParam().localSizeX << " " <<

+                GetParam().localSizeY << " " <<

+                GetParam().localSizeZ << "\n" <<

+              "OpDecorate %3 ArrayStride 4\n"

+              "OpMemberDecorate %4 0 Offset 0\n"

+              "OpDecorate %4 BufferBlock\n"

+              "OpDecorate %5 DescriptorSet 0\n"

+              "OpDecorate %5 Binding 1\n"

+              "OpDecorate %2 BuiltIn GlobalInvocationId\n"

+              "OpDecorate %6 DescriptorSet 0\n"

+              "OpDecorate %6 Binding 0\n"

+         "%7 = OpTypeVoid\n"

+         "%8 = OpTypeFunction %7\n"             // void()

+         "%9 = OpTypeInt 32 1\n"                // int32

+        "%10 = OpTypeInt 32 0\n"                // uint32

+        "%11 = OpTypeBool\n"

+         "%3 = OpTypeRuntimeArray %9\n"         // int32[]

+         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }

+        "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*

+         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in

+        "%13 = OpConstant %9 0\n"               // int32(0)

+        "%14 = OpConstant %9 1\n"               // int32(1)

+        "%15 = OpConstant %9 2\n"               // int32(2)

+        "%16 = OpConstant %10 0\n"              // uint32(0)

+        "%17 = OpTypeVector %10 3\n"            // vec4<int32>

+        "%18 = OpTypePointer Input %17\n"       // vec4<int32>*

+         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId

+        "%19 = OpTypePointer Input %10\n"       // uint32*

+         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out

+        "%20 = OpTypePointer Uniform %9\n"      // int32*

+         "%1 = OpFunction %7 None %8\n"         // -- Function begin --

+        "%21 = OpLabel\n"

+        "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x

+        "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x

+        "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]

+        "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]

+        "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]

+    // Start of branch logic

+    // %25 = in value

+        "%27 = OpSMod %9 %25 %15\n"             // in % 2

+              "OpSelectionMerge %28 None\n"

+              "OpSwitch %27 %29 1 %30\n"

+        "%30 = OpLabel\n"                       // (in % 2) == 1

+              "OpBranch %28\n"

+        "%29 = OpLabel\n"                       // (in % 2) != 1

+              "OpBranch %28\n"

+        "%28 = OpLabel\n"                       // merge

+        "%31 = OpPhi %9 %14 %30 %15 %29\n"      // (in % 2) == 1 ? 1 : 2

+              "OpStore %26 %31\n"

+    // End of branch logic

+              "OpReturn\n"

+              "OpFunctionEnd\n";

+

+    test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 2; });

+}