diff --git a/include/vulkan.h b/include/vulkan.h
index 560910a..8399053 100644
--- a/include/vulkan.h
+++ b/include/vulkan.h
@@ -1552,7 +1552,6 @@
 typedef struct {
     VkStructureType                             sType;
     const void*                                 pNext;
-    VkShaderStageFlagBits                       stage;
     VkShader                                    shader;
     const VkSpecializationInfo*                 pSpecializationInfo;
 } VkPipelineShaderStageCreateInfo;
diff --git a/layers/draw_state.cpp b/layers/draw_state.cpp
index 492895f..a928547 100755
--- a/layers/draw_state.cpp
+++ b/layers/draw_state.cpp
@@ -73,6 +73,7 @@
     unordered_map<uint64_t, SET_NODE*> setMap;
     unordered_map<uint64_t, LAYOUT_NODE*> layoutMap;
     unordered_map<uint64_t, PIPELINE_LAYOUT_NODE> pipelineLayoutMap;
+    unordered_map<uint64_t, VkShaderStageFlagBits> shaderStageMap;
     // Map for layout chains
     unordered_map<void*, GLOBAL_CB_NODE*> cmdBufferMap;
     unordered_map<uint64_t, VkRenderPassCreateInfo*> renderPassMap;
@@ -430,7 +431,7 @@
 }
 // Init the pipeline mapping info based on pipeline create info LL tree
 //  Threading note : Calls to this function should wrapped in mutex
-static PIPELINE_NODE* initPipeline(const VkGraphicsPipelineCreateInfo* pCreateInfo, PIPELINE_NODE* pBasePipeline)
+static PIPELINE_NODE* initPipeline(layer_data* dev_data, const VkGraphicsPipelineCreateInfo* pCreateInfo, PIPELINE_NODE* pBasePipeline)
 {
     PIPELINE_NODE* pPipeline = new PIPELINE_NODE;
     if (pBasePipeline) {
@@ -448,7 +449,10 @@
     for (uint32_t i = 0; i < pCreateInfo->stageCount; i++) {
         const VkPipelineShaderStageCreateInfo *pPSSCI = &pCreateInfo->pStages[i];
 
-        switch (pPSSCI->stage) {
+        if (dev_data->shaderStageMap.find(pPSSCI->shader.handle) == dev_data->shaderStageMap.end())
+            continue;
+
+        switch (dev_data->shaderStageMap[pPSSCI->shader.handle]) {
             case VK_SHADER_STAGE_VERTEX_BIT:
                 memcpy(&pPipeline->vsCI, pPSSCI, sizeof(VkPipelineShaderStageCreateInfo));
                 pPipeline->active_shaders |= VK_SHADER_STAGE_VERTEX_BIT;
@@ -1798,8 +1802,10 @@
 
 VK_LAYER_EXPORT void VKAPI vkDestroyShader(VkDevice device, VkShader shader)
 {
-    get_my_data_ptr(get_dispatch_key(device), layer_data_map)->device_dispatch_table->DestroyShader(device, shader);
-    // TODO : Clean up any internal data structures using this obj.
+    layer_data* dev_data = get_my_data_ptr(get_dispatch_key(device), layer_data_map);
+    uint64_t handle = shader.handle;
+    dev_data->device_dispatch_table->DestroyShader(device, shader);
+    dev_data->shaderStageMap.erase(handle);
 }
 
 VK_LAYER_EXPORT void VKAPI vkDestroyPipeline(VkDevice device, VkPipeline pipeline)
@@ -1899,6 +1905,23 @@
     return result;
 }
 
+VK_LAYER_EXPORT VkResult VKAPI vkCreateShader(
+        VkDevice device,
+        const VkShaderCreateInfo *pCreateInfo,
+        VkShader *pShader)
+{
+    layer_data* dev_data = get_my_data_ptr(get_dispatch_key(device), layer_data_map);
+    VkResult result = dev_data->device_dispatch_table->CreateShader(device, pCreateInfo, pShader);
+
+    if (VK_SUCCESS == result) {
+        loader_platform_thread_lock_mutex(&globalLock);
+        dev_data->shaderStageMap[pShader->handle] = pCreateInfo->stage;
+        loader_platform_thread_unlock_mutex(&globalLock);
+    }
+
+    return result;
+}
+
 //TODO handle pipeline caches
 VkResult VKAPI vkCreatePipelineCache(
     VkDevice                                    device,
@@ -1955,7 +1978,7 @@
     uint32_t i=0;
     loader_platform_thread_lock_mutex(&globalLock);
     for (i=0; i<count; i++) {
-        pPipeNode[i] = initPipeline(&pCreateInfos[i], NULL);
+        pPipeNode[i] = initPipeline(dev_data, &pCreateInfos[i], NULL);
         skipCall |= verifyPipelineCreateState(dev_data, device, pPipeNode[i]);
     }
     loader_platform_thread_unlock_mutex(&globalLock);
@@ -3701,6 +3724,8 @@
         return (PFN_vkVoidFunction) vkCreateImage;
     if (!strcmp(funcName, "vkCreateImageView"))
         return (PFN_vkVoidFunction) vkCreateImageView;
+    if (!strcmp(funcName, "vkCreateShader"))
+        return (PFN_vkVoidFunction) vkCreateShader;
     if (!strcmp(funcName, "CreatePipelineCache"))
         return (PFN_vkVoidFunction) vkCreatePipelineCache;
     if (!strcmp(funcName, "DestroyPipelineCache"))
diff --git a/layers/param_checker.cpp b/layers/param_checker.cpp
index f93ad1c..d7c13d4 100644
--- a/layers/param_checker.cpp
+++ b/layers/param_checker.cpp
@@ -3889,15 +3889,6 @@
         "vkCreateGraphicsPipelines parameter, VkStructureType pCreateInfos->pStages->sType, is an invalid enumerator");
         return false;
     }
-    if((pCreateInfos->pStages->stage & (VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT |
-                                       VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT | VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT |
-                                       VK_SHADER_STAGE_GEOMETRY_BIT |
-                                       VK_SHADER_STAGE_COMPUTE_BIT)) == 0)
-    {
-        log_msg(mdd(device), VK_DBG_REPORT_ERROR_BIT, (VkDbgObjectType)0, 0, 0, 1, "PARAMCHECK",
-        "vkCreateGraphicsPipelines parameter, VkShaderStage pCreateInfos->pStages->stage, is an unrecognized enumerator");
-        return false;
-    }
     if(pCreateInfos->pStages->pSpecializationInfo != nullptr)
     {
     if(pCreateInfos->pStages->pSpecializationInfo->pMap != nullptr)
@@ -4212,15 +4203,6 @@
         "vkCreateComputePipelines parameter, VkStructureType pCreateInfos->cs.sType, is an invalid enumerator");
         return false;
     }
-    if((pCreateInfos->stage.stage & (VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT |
-                                       VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT | VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT |
-                                       VK_SHADER_STAGE_GEOMETRY_BIT |
-                                       VK_SHADER_STAGE_COMPUTE_BIT)) == 0)
-    {
-        log_msg(mdd(device), VK_DBG_REPORT_ERROR_BIT, (VkDbgObjectType)0, 0, 0, 1, "PARAMCHECK",
-        "vkCreateComputePipelines parameter, VkShaderStage pCreateInfos->cs.stage, is an unrecognized enumerator");
-        return false;
-    }
     if(pCreateInfos->stage.pSpecializationInfo != nullptr)
     {
     if(pCreateInfos->stage.pSpecializationInfo->pMap != nullptr)
diff --git a/layers/shader_checker.cpp b/layers/shader_checker.cpp
index 64feaf0..22bb357 100644
--- a/layers/shader_checker.cpp
+++ b/layers/shader_checker.cpp
@@ -214,10 +214,12 @@
 struct shader_object {
     std::string name;
     struct shader_module *module;
+    VkShaderStageFlagBits stage;
 
     shader_object(VkShaderCreateInfo const *pCreateInfo)
     {
         module = shader_module_map[pCreateInfo->module.handle];
+        stage = pCreateInfo->stage;
         name = pCreateInfo->pName;
     }
 };
@@ -1068,16 +1070,18 @@
         VkPipelineShaderStageCreateInfo const *pStage = &pCreateInfo->pStages[i];
         if (pStage->sType == VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO) {
 
-            if ((pStage->stage & (VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_GEOMETRY_BIT | VK_SHADER_STAGE_FRAGMENT_BIT
+            // always true; pStage->stage may be revived in a later revision and
+            // this will make sense again
+            if ((VK_SHADER_STAGE_VERTEX_BIT & (VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_GEOMETRY_BIT | VK_SHADER_STAGE_FRAGMENT_BIT
                                   | VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT | VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT)) == 0) {
                 if (log_msg(mdd(dev), VK_DBG_REPORT_WARN_BIT, VK_OBJECT_TYPE_DEVICE, /*dev*/0, 0, SHADER_CHECKER_UNKNOWN_STAGE, "SC",
-                        "Unknown shader stage %d", pStage->stage)) {
+                        "Unknown shader stage %d", VK_SHADER_STAGE_VERTEX_BIT)) {
                     pass = false;
                 }
             }
             else {
                 struct shader_object *shader = shader_object_map[pStage->shader.handle];
-                shaders[get_shader_stage_id(pStage->stage)] = shader->module;
+                shaders[get_shader_stage_id(shader->stage)] = shader->module;
 
                 /* validate descriptor set layout against what the spirv module actually uses */
                 std::map<std::pair<unsigned, unsigned>, interface_var> descriptor_uses;
diff --git a/tests/layer_validation_tests.cpp b/tests/layer_validation_tests.cpp
index c9fb02c..8e120ba 100644
--- a/tests/layer_validation_tests.cpp
+++ b/tests/layer_validation_tests.cpp
@@ -2070,11 +2070,9 @@
 	VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this); // TODO - We shouldn't need a fragment shader
 	                                                                                   // but add it to be able to run on more devices
     shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-    shaderStages[0].stage = VK_SHADER_STAGE_VERTEX_BIT;
     shaderStages[0].shader = vs.handle();
 
 	shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-	shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT_BIT;
 	shaderStages[1].shader = fs.handle();
 
     VkGraphicsPipelineCreateInfo gp_ci = {};
@@ -2180,11 +2178,9 @@
 	VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this); // TODO - We shouldn't need a fragment shader
 	                                                                                   // but add it to be able to run on more devices
     shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-    shaderStages[0].stage = VK_SHADER_STAGE_VERTEX_BIT;
     shaderStages[0].shader = vs.handle();
 
 	shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-	shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT_BIT;
 	shaderStages[1].shader = fs.handle();
 
     VkGraphicsPipelineCreateInfo gp_ci = {};
@@ -2297,11 +2293,9 @@
 	VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this); // TODO - We shouldn't need a fragment shader
 	                                                                                   // but add it to be able to run on more devices
     shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-    shaderStages[0].stage = VK_SHADER_STAGE_VERTEX_BIT;
     shaderStages[0].shader = vs.handle();
 
 	shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-	shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT_BIT;
 	shaderStages[1].shader = fs.handle();
 
 	VkPipelineVertexInputStateCreateInfo vi_ci = {};
@@ -2457,11 +2451,9 @@
 	VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this); // TODO - We shouldn't need a fragment shader
 	                                                                                   // but add it to be able to run on more devices
     shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-    shaderStages[0].stage = VK_SHADER_STAGE_VERTEX_BIT;
     shaderStages[0].shader = vs.handle();
 
 	shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-	shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT_BIT;
 	shaderStages[1].shader = fs.handle();
 
     VkPipelineVertexInputStateCreateInfo vi_ci = {};
diff --git a/tests/vkrenderframework.cpp b/tests/vkrenderframework.cpp
index 0c759d9..a26b6d3 100644
--- a/tests/vkrenderframework.cpp
+++ b/tests/vkrenderframework.cpp
@@ -1052,7 +1052,6 @@
 {
     VkPipelineShaderStageCreateInfo *stageInfo = (VkPipelineShaderStageCreateInfo*) calloc( 1,sizeof(VkPipelineShaderStageCreateInfo) );
     stageInfo->sType                = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-    stageInfo->stage                = m_stage;
     stageInfo->shader               = handle();
 
     return stageInfo;
