layers: Split record/validation of Pipeline/Shaders

Disentangle pipeline state initialization from pipeline state
validation.  Pipeline state is now const throughout validation; however,
additional state is now retained within the pipeline.  That state could
be better cached in the shader module, or instead pruned from the
pipeline state after creation.

Change-Id: I8e71e530fe1edab36ce5c31e6bd7c3ac55cde966
diff --git a/layers/shader_validation.cpp b/layers/shader_validation.cpp
index 4b31e8e..edd5182 100644
--- a/layers/shader_validation.cpp
+++ b/layers/shader_validation.cpp
@@ -46,16 +46,6 @@
 
 typedef std::pair<unsigned, unsigned> location_t;
 
-struct interface_var {
-    uint32_t id;
-    uint32_t type_id;
-    uint32_t offset;
-    bool is_patch;
-    bool is_block_member;
-    bool is_relaxed_precision;
-    // TODO: collect the name, too? Isn't required to be present.
-};
-
 struct shader_stage_attributes {
     char const *const name;
     bool arrayed_input;
@@ -71,6 +61,8 @@
     {"fragment shader", false, false, VK_SHADER_STAGE_FRAGMENT_BIT},
 };
 
+unsigned ExecutionModelToShaderStageFlagBits(unsigned mode);
+
 // SPIRV utility functions
 void SHADER_MODULE_STATE::BuildDefIndex() {
     for (auto insn : *this) {
@@ -130,6 +122,15 @@
                 def_index[insn.word(2)] = insn.offset();
                 break;
 
+                // Entry points ... add to the entrypoint table
+            case spv::OpEntryPoint: {
+                // Entry points do not have an id (the id is the function id) and thus need their own table
+                auto entrypoint_name = (char const *)&insn.word(3);
+                auto execution_model = insn.word(1);
+                auto entrypoint_stage = ExecutionModelToShaderStageFlagBits(execution_model);
+                entry_points.emplace(entrypoint_name, EntryPoint{insn.offset(), entrypoint_stage});
+                break;
+            }
             default:
                 // We don't care about any other defs for now.
                 break;
@@ -173,18 +174,12 @@
 }
 
 static spirv_inst_iter FindEntrypoint(SHADER_MODULE_STATE const *src, char const *name, VkShaderStageFlagBits stageBits) {
-    for (auto insn : *src) {
-        if (insn.opcode() == spv::OpEntryPoint) {
-            auto entrypointName = (char const *)&insn.word(3);
-            auto executionModel = insn.word(1);
-            auto entrypointStageBits = ExecutionModelToShaderStageFlagBits(executionModel);
-
-            if (!strcmp(entrypointName, name) && (entrypointStageBits & stageBits)) {
-                return insn;
-            }
+    auto range = src->entry_points.equal_range(name);
+    for (auto it = range.first; it != range.second; ++it) {
+        if (it->second.stage == stageBits) {
+            return src->at(it->second.offset);
         }
     }
-
     return src->end();
 }
 
@@ -1853,7 +1848,7 @@
 }
 
 bool CoreChecks::ValidateShaderStageInputOutputLimits(SHADER_MODULE_STATE const *src, VkPipelineShaderStageCreateInfo const *pStage,
-                                                      PIPELINE_STATE *pipeline, spirv_inst_iter entrypoint) {
+                                                      const PIPELINE_STATE *pipeline, spirv_inst_iter entrypoint) {
     if (pStage->stage == VK_SHADER_STAGE_COMPUTE_BIT || pStage->stage == VK_SHADER_STAGE_ALL_GRAPHICS ||
         pStage->stage == VK_SHADER_STAGE_ALL) {
         return false;
@@ -2106,7 +2101,7 @@
 // Validate SPV_NV_cooperative_matrix behavior that can't be statically validated
 // in SPIRV-Tools (e.g. due to specialization constant usage).
 bool CoreChecks::ValidateCooperativeMatrix(SHADER_MODULE_STATE const *src, VkPipelineShaderStageCreateInfo const *pStage,
-                                           PIPELINE_STATE *pipeline) {
+                                           const PIPELINE_STATE *pipeline) {
     bool skip = false;
 
     // Map SPIR-V result ID to specialization constant id (SpecId decoration value)
@@ -2570,7 +2565,7 @@
     return false;
 }
 
-static void ProcessExecutionModes(SHADER_MODULE_STATE const *src, spirv_inst_iter entrypoint, PIPELINE_STATE *pipeline) {
+static void ProcessExecutionModes(SHADER_MODULE_STATE const *src, const spirv_inst_iter &entrypoint, PIPELINE_STATE *pipeline) {
     auto entrypoint_id = entrypoint.word(2);
     bool is_point_mode = false;
 
@@ -2657,32 +2652,56 @@
     }
     return skip;
 }
+void ValidationStateTracker::RecordPipelineShaderStage(VkPipelineShaderStageCreateInfo const *pStage, PIPELINE_STATE *pipeline,
+                                                       PIPELINE_STATE::StageState *stage_state) {
+    // Validation shouldn't rely on anything in stage state being valid if the spirv isn't
+    auto module = GetShaderModuleState(pStage->module);
+    if (!module->has_valid_spirv) return;
 
-bool CoreChecks::ValidatePipelineShaderStage(VkPipelineShaderStageCreateInfo const *pStage, PIPELINE_STATE *pipeline,
-                                             SHADER_MODULE_STATE const **out_module, spirv_inst_iter *out_entrypoint,
-                                             bool check_point_size) {
-    bool skip = false;
-    auto module = *out_module = GetShaderModuleState(pStage->module);
-
-    if (!module->has_valid_spirv) return false;
-
-    // Find the entrypoint
-    auto entrypoint = *out_entrypoint = FindEntrypoint(module, pStage->pName, pStage->stage);
-    if (entrypoint == module->end()) {
-        if (log_msg(report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT, 0,
-                    "VUID-VkPipelineShaderStageCreateInfo-pName-00707", "No entrypoint found named `%s` for stage %s..",
-                    pStage->pName, string_VkShaderStageFlagBits(pStage->stage))) {
-            return true;  // no point continuing beyond here, any analysis is just going to be garbage.
-        }
-    }
+    // Validation shouldn't rely on anything in stage state being valid if the entrypoint isn't present
+    auto entrypoint = FindEntrypoint(module, pStage->pName, pStage->stage);
+    if (entrypoint == module->end()) return;
 
     // Mark accessible ids
-    auto accessible_ids = MarkAccessibleIds(module, entrypoint);
+    stage_state->accessible_ids = MarkAccessibleIds(module, entrypoint);
     ProcessExecutionModes(module, entrypoint, pipeline);
 
+    stage_state->descriptor_uses =
+        CollectInterfaceByDescriptorSlot(report_data, module, stage_state->accessible_ids, &stage_state->has_writable_descriptor);
+    // Capture descriptor uses for the pipeline
+    for (auto use : stage_state->descriptor_uses) {
+        // While validating shaders capture which slots are used by the pipeline
+        auto &reqs = pipeline->active_slots[use.first.first][use.first.second];
+        reqs = descriptor_req(reqs | DescriptorTypeToReqs(module, use.second.type_id));
+    }
+}
+
+bool CoreChecks::ValidatePipelineShaderStage(VkPipelineShaderStageCreateInfo const *pStage, const PIPELINE_STATE *pipeline,
+                                             const PIPELINE_STATE::StageState &stage_state, const SHADER_MODULE_STATE *module,
+                                             const spirv_inst_iter &entrypoint, bool check_point_size) {
+    bool skip = false;
+
+    // Check the module
+    if (!module->has_valid_spirv) {
+        skip |= log_msg(report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT, 0,
+                        "VUID-VkPipelineShaderStageCreateInfo-module-parameter", "%s does not contain valid spirv for stage %s.",
+                        report_data->FormatHandle(module->vk_shader_module).c_str(), string_VkShaderStageFlagBits(pStage->stage));
+    }
+
+    // Check the entrypoint
+    if (entrypoint == module->end()) {
+        skip |= log_msg(report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT, 0,
+                        "VUID-VkPipelineShaderStageCreateInfo-pName-00707", "No entrypoint found named `%s` for stage %s..",
+                        pStage->pName, string_VkShaderStageFlagBits(pStage->stage));
+    }
+    if (skip) return true;  // no point continuing beyond here, any analysis is just going to be garbage.
+
+    // Mark accessible ids
+    auto &accessible_ids = stage_state.accessible_ids;
+
     // Validate descriptor set layout against what the entrypoint actually uses
-    bool has_writable_descriptor = false;
-    auto descriptor_uses = CollectInterfaceByDescriptorSlot(report_data, module, accessible_ids, &has_writable_descriptor);
+    bool has_writable_descriptor = stage_state.has_writable_descriptor;
+    auto &descriptor_uses = stage_state.descriptor_uses;
 
     // Validate shader capabilities against enabled device features
     skip |= ValidateShaderCapabilities(module, pStage->stage);
@@ -2700,10 +2719,6 @@
 
     // Validate descriptor use
     for (auto use : descriptor_uses) {
-        // While validating shaders capture which slots are used by the pipeline
-        auto &reqs = pipeline->active_slots[use.first.first][use.first.second];
-        reqs = descriptor_req(reqs | DescriptorTypeToReqs(module, use.second.type_id));
-
         // Verify given pipelineLayout has requested setLayout with requested binding
         const auto &binding = GetDescriptorBinding(&pipeline->pipeline_layout, use.first);
         unsigned required_descriptor_count;
@@ -2861,7 +2876,7 @@
     return skip;
 }
 
-static inline uint32_t DetermineFinalGeomStage(PIPELINE_STATE *pipeline, VkGraphicsPipelineCreateInfo *pCreateInfo) {
+static inline uint32_t DetermineFinalGeomStage(const PIPELINE_STATE *pipeline, const VkGraphicsPipelineCreateInfo *pCreateInfo) {
     uint32_t stage_mask = 0;
     if (pipeline->topology_at_rasterizer == VK_PRIMITIVE_TOPOLOGY_POINT_LIST) {
         for (uint32_t i = 0; i < pCreateInfo->stageCount; i++) {
@@ -2883,12 +2898,12 @@
 
 // Validate that the shaders used by the given pipeline and store the active_slots
 //  that are actually used by the pipeline into pPipeline->active_slots
-bool CoreChecks::ValidateAndCapturePipelineShaderState(PIPELINE_STATE *pipeline) {
+bool CoreChecks::ValidateGraphicsPipelineShaderState(const PIPELINE_STATE *pipeline) {
     auto pCreateInfo = pipeline->graphicsPipelineCI.ptr();
     int vertex_stage = GetShaderStageId(VK_SHADER_STAGE_VERTEX_BIT);
     int fragment_stage = GetShaderStageId(VK_SHADER_STAGE_FRAGMENT_BIT);
 
-    SHADER_MODULE_STATE const *shaders[32];
+    const SHADER_MODULE_STATE *shaders[32];
     memset(shaders, 0, sizeof(shaders));
     spirv_inst_iter entrypoints[32];
     memset(entrypoints, 0, sizeof(entrypoints));
@@ -2899,7 +2914,10 @@
     for (uint32_t i = 0; i < pCreateInfo->stageCount; i++) {
         auto pStage = &pCreateInfo->pStages[i];
         auto stage_id = GetShaderStageId(pStage->stage);
-        skip |= ValidatePipelineShaderStage(pStage, pipeline, &shaders[stage_id], &entrypoints[stage_id],
+        shaders[stage_id] = GetShaderModuleState(pStage->module);
+        entrypoints[stage_id] = FindEntrypoint(shaders[stage_id], pStage->pName, pStage->stage);
+        skip |= ValidatePipelineShaderStage(pStage, pipeline, pipeline->stage_state[i], shaders[stage_id], entrypoints[stage_id],
+
                                             (pointlist_stage_mask == pStage->stage));
     }
 
@@ -2946,21 +2964,21 @@
 }
 
 bool CoreChecks::ValidateComputePipeline(PIPELINE_STATE *pipeline) {
-    auto pCreateInfo = pipeline->computePipelineCI.ptr();
+    const auto &stage = *pipeline->computePipelineCI.stage.ptr();
 
-    SHADER_MODULE_STATE const *module;
-    spirv_inst_iter entrypoint;
+    const SHADER_MODULE_STATE *module = GetShaderModuleState(stage.module);
+    const spirv_inst_iter entrypoint = FindEntrypoint(module, stage.pName, stage.stage);
 
-    return ValidatePipelineShaderStage(&pCreateInfo->stage, pipeline, &module, &entrypoint, false);
+    return ValidatePipelineShaderStage(&stage, pipeline, pipeline->stage_state[0], module, entrypoint, false);
 }
 
 bool CoreChecks::ValidateRayTracingPipelineNV(PIPELINE_STATE *pipeline) {
-    auto pCreateInfo = pipeline->raytracingPipelineCI.ptr();
+    const auto &stage = pipeline->raytracingPipelineCI.ptr()->pStages[0];
 
-    SHADER_MODULE_STATE const *module;
-    spirv_inst_iter entrypoint;
+    const SHADER_MODULE_STATE *module = GetShaderModuleState(stage.module);
+    const spirv_inst_iter entrypoint = FindEntrypoint(module, stage.pName, stage.stage);
 
-    return ValidatePipelineShaderStage(pCreateInfo->pStages, pipeline, &module, &entrypoint, false);
+    return ValidatePipelineShaderStage(&stage, pipeline, pipeline->stage_state[0], module, entrypoint, false);
 }
 
 uint32_t ValidationCache::MakeShaderHash(VkShaderModuleCreateInfo const *smci) { return XXH32(smci->pCode, smci->codeSize, 0); }