WIP bug-14815: VkShaderStage
diff --git a/include/vulkan.h b/include/vulkan.h
index 93b21bc..5f9a8b9 100644
--- a/include/vulkan.h
+++ b/include/vulkan.h
@@ -500,19 +500,6 @@
 } VkChannelSwizzle;
 
 typedef enum {
-    VK_SHADER_STAGE_VERTEX = 0,
-    VK_SHADER_STAGE_TESSELLATION_CONTROL = 1,
-    VK_SHADER_STAGE_TESSELLATION_EVALUATION = 2,
-    VK_SHADER_STAGE_GEOMETRY = 3,
-    VK_SHADER_STAGE_FRAGMENT = 4,
-    VK_SHADER_STAGE_COMPUTE = 5,
-    VK_SHADER_STAGE_BEGIN_RANGE = VK_SHADER_STAGE_VERTEX,
-    VK_SHADER_STAGE_END_RANGE = VK_SHADER_STAGE_COMPUTE,
-    VK_SHADER_STAGE_NUM = (VK_SHADER_STAGE_COMPUTE - VK_SHADER_STAGE_VERTEX + 1),
-    VK_SHADER_STAGE_MAX_ENUM = 0x7FFFFFFF
-} VkShaderStage;
-
-typedef enum {
     VK_VERTEX_INPUT_STEP_RATE_VERTEX = 0,
     VK_VERTEX_INPUT_STEP_RATE_INSTANCE = 1,
     VK_VERTEX_INPUT_STEP_RATE_BEGIN_RANGE = VK_VERTEX_INPUT_STEP_RATE_VERTEX,
@@ -973,6 +960,16 @@
 typedef VkFlags VkShaderCreateFlags;
 
 typedef enum {
+    VK_SHADER_STAGE_VERTEX_BIT = 0x00000001,
+    VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT = 0x00000002,
+    VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT = 0x00000004,
+    VK_SHADER_STAGE_GEOMETRY_BIT = 0x00000008,
+    VK_SHADER_STAGE_FRAGMENT_BIT = 0x00000010,
+    VK_SHADER_STAGE_COMPUTE_BIT = 0x00000020,
+    VK_SHADER_STAGE_ALL = 0x7FFFFFFF,
+} VkShaderStageFlagBits;
+
+typedef enum {
     VK_CHANNEL_R_BIT = 0x00000001,
     VK_CHANNEL_G_BIT = 0x00000002,
     VK_CHANNEL_B_BIT = 0x00000004,
@@ -986,16 +983,6 @@
     VK_PIPELINE_CREATE_DERIVATIVE_BIT = 0x00000004,
 } VkPipelineCreateFlagBits;
 typedef VkFlags VkPipelineCreateFlags;
-
-typedef enum {
-    VK_SHADER_STAGE_VERTEX_BIT = 0x00000001,
-    VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT = 0x00000002,
-    VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT = 0x00000004,
-    VK_SHADER_STAGE_GEOMETRY_BIT = 0x00000008,
-    VK_SHADER_STAGE_FRAGMENT_BIT = 0x00000010,
-    VK_SHADER_STAGE_COMPUTE_BIT = 0x00000020,
-    VK_SHADER_STAGE_ALL = 0x7FFFFFFF,
-} VkShaderStageFlagBits;
 typedef VkFlags VkShaderStageFlags;
 
 typedef enum {
@@ -1556,7 +1543,7 @@
     VkShaderModule                              module;
     const char*                                 pName;
     VkShaderCreateFlags                         flags;
-    VkShaderStage                               stage;
+    VkShaderStageFlagBits                       stage;
 } VkShaderCreateInfo;
 
 typedef struct {
@@ -1583,7 +1570,7 @@
 typedef struct {
     VkStructureType                             sType;
     const void*                                 pNext;
-    VkShaderStage                               stage;
+    VkShaderStageFlagBits                       stage;
     VkShader                                    shader;
     const VkSpecializationInfo*                 pSpecializationInfo;
 } VkPipelineShaderStageCreateInfo;
diff --git a/layers/draw_state.cpp b/layers/draw_state.cpp
index d36b1af..3317a43 100755
--- a/layers/draw_state.cpp
+++ b/layers/draw_state.cpp
@@ -217,7 +217,6 @@
 }
 // Block of code at start here for managing/tracking Pipeline state that this layer cares about
 // Just track 2 shaders for now
-#define VK_NUM_GRAPHICS_SHADERS VK_SHADER_STAGE_COMPUTE
 #define MAX_SLOTS 2048
 #define NUM_COMMAND_BUFFERS_TO_DISPLAY 10
 
@@ -448,27 +447,27 @@
         const VkPipelineShaderStageCreateInfo *pPSSCI = &pCreateInfo->pStages[i];
 
         switch (pPSSCI->stage) {
-            case VK_SHADER_STAGE_VERTEX:
+            case VK_SHADER_STAGE_VERTEX_BIT:
                 memcpy(&pPipeline->vsCI, pPSSCI, sizeof(VkPipelineShaderStageCreateInfo));
                 pPipeline->active_shaders |= VK_SHADER_STAGE_VERTEX_BIT;
                 break;
-            case VK_SHADER_STAGE_TESSELLATION_CONTROL:
+            case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
                 memcpy(&pPipeline->tcsCI, pPSSCI, sizeof(VkPipelineShaderStageCreateInfo));
                 pPipeline->active_shaders |= VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
                 break;
-            case VK_SHADER_STAGE_TESSELLATION_EVALUATION:
+            case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
                 memcpy(&pPipeline->tesCI, pPSSCI, sizeof(VkPipelineShaderStageCreateInfo));
                 pPipeline->active_shaders |= VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
                 break;
-            case VK_SHADER_STAGE_GEOMETRY:
+            case VK_SHADER_STAGE_GEOMETRY_BIT:
                 memcpy(&pPipeline->gsCI, pPSSCI, sizeof(VkPipelineShaderStageCreateInfo));
                 pPipeline->active_shaders |= VK_SHADER_STAGE_GEOMETRY_BIT;
                 break;
-            case VK_SHADER_STAGE_FRAGMENT:
+            case VK_SHADER_STAGE_FRAGMENT_BIT:
                 memcpy(&pPipeline->fsCI, pPSSCI, sizeof(VkPipelineShaderStageCreateInfo));
                 pPipeline->active_shaders |= VK_SHADER_STAGE_FRAGMENT_BIT;
                 break;
-            case VK_SHADER_STAGE_COMPUTE:
+            case VK_SHADER_STAGE_COMPUTE_BIT:
                 // TODO : Flag error, CS is specified through VkComputePipelineCreateInfo
                 pPipeline->active_shaders |= VK_SHADER_STAGE_COMPUTE_BIT;
                 break;
diff --git a/layers/param_checker.cpp b/layers/param_checker.cpp
index da5cf1a..8ea75ff 100644
--- a/layers/param_checker.cpp
+++ b/layers/param_checker.cpp
@@ -3978,8 +3978,10 @@
         "vkCreateGraphicsPipelines parameter, VkStructureType pCreateInfos->pStages->sType, is an invalid enumerator");
         return false;
     }
-    if(pCreateInfos->pStages->stage < VK_SHADER_STAGE_BEGIN_RANGE ||
-        pCreateInfos->pStages->stage > VK_SHADER_STAGE_END_RANGE)
+    if((pCreateInfos->pStages->stage & (VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT |
+                                       VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT | VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT |
+                                       VK_SHADER_STAGE_GEOMETRY_BIT |
+                                       VK_SHADER_STAGE_COMPUTE_BIT)) == 0)
     {
         log_msg(mdd(device), VK_DBG_REPORT_ERROR_BIT, (VkDbgObjectType)0, 0, 0, 1, "PARAMCHECK",
         "vkCreateGraphicsPipelines parameter, VkShaderStage pCreateInfos->pStages->stage, is an unrecognized enumerator");
@@ -4299,8 +4301,10 @@
         "vkCreateComputePipelines parameter, VkStructureType pCreateInfos->cs.sType, is an invalid enumerator");
         return false;
     }
-    if(pCreateInfos->stage.stage < VK_SHADER_STAGE_BEGIN_RANGE ||
-        pCreateInfos->stage.stage > VK_SHADER_STAGE_END_RANGE)
+    if((pCreateInfos->stage.stage & (VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT |
+                                       VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT | VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT |
+                                       VK_SHADER_STAGE_GEOMETRY_BIT |
+                                       VK_SHADER_STAGE_COMPUTE_BIT)) == 0)
     {
         log_msg(mdd(device), VK_DBG_REPORT_ERROR_BIT, (VkDbgObjectType)0, 0, 0, 1, "PARAMCHECK",
         "vkCreateComputePipelines parameter, VkShaderStage pCreateInfos->cs.stage, is an unrecognized enumerator");
diff --git a/layers/shader_checker.cpp b/layers/shader_checker.cpp
index 2b39c52..aae2813 100644
--- a/layers/shader_checker.cpp
+++ b/layers/shader_checker.cpp
@@ -33,6 +33,7 @@
 #include "vk_loader_platform.h"
 #include "vk_dispatch_table_helper.h"
 #include "vk_layer.h"
+#include "vk_layer_utils.h"
 #include "vk_layer_config.h"
 #include "vk_layer_table.h"
 #include "vk_layer_logging.h"
@@ -1013,7 +1014,7 @@
 
 
 static shader_stage_attributes
-shader_stage_attribs[VK_SHADER_STAGE_FRAGMENT + 1] = {
+shader_stage_attribs[] = {
     { "vertex shader", false },
     { "tessellation control shader", true },
     { "tessellation evaluation shader", false },
@@ -1040,14 +1041,22 @@
     return &(*set)[slot.second];
 }
 
+static uint32_t get_shader_stage_id(VkShaderStageFlagBits stage)
+{
+    uint32_t bit_pos = u_ffs(stage);
+    return bit_pos-1;
+}
 
 static bool
 validate_graphics_pipeline(VkDevice dev, VkGraphicsPipelineCreateInfo const *pCreateInfo)
 {
     /* We seem to allow pipeline stages to be specified out of order, so collect and identify them
      * before trying to do anything more: */
+    int vertex_stage = get_shader_stage_id(VK_SHADER_STAGE_VERTEX_BIT);
+    int geometry_stage = get_shader_stage_id(VK_SHADER_STAGE_GEOMETRY_BIT);
+    int fragment_stage = get_shader_stage_id(VK_SHADER_STAGE_FRAGMENT_BIT);
 
-    shader_module const *shaders[VK_SHADER_STAGE_FRAGMENT + 1];  /* exclude CS */
+    shader_module const *shaders[fragment_stage + 1];  /* exclude CS */
     memset(shaders, 0, sizeof(shaders));
     render_pass const *rp = 0;
     VkPipelineVertexInputStateCreateInfo const *vi = 0;
@@ -1059,7 +1068,8 @@
         VkPipelineShaderStageCreateInfo const *pStage = &pCreateInfo->pStages[i];
         if (pStage->sType == VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO) {
 
-            if (pStage->stage < VK_SHADER_STAGE_VERTEX || pStage->stage > VK_SHADER_STAGE_FRAGMENT) {
+            if ((pStage->stage & (VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_GEOMETRY_BIT | VK_SHADER_STAGE_FRAGMENT_BIT
+                                  | VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT | VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT)) == 0) {
                 if (log_msg(mdd(dev), VK_DBG_REPORT_WARN_BIT, VK_OBJECT_TYPE_DEVICE, /*dev*/0, 0, SHADER_CHECKER_UNKNOWN_STAGE, "SC",
                         "Unknown shader stage %d", pStage->stage)) {
                     pass = false;
@@ -1067,7 +1077,7 @@
             }
             else {
                 struct shader_object *shader = shader_object_map[pStage->shader.handle];
-                shaders[pStage->stage] = shader->module;
+                shaders[get_shader_stage_id(pStage->stage)] = shader->module;
 
                 /* validate descriptor set layout against what the spirv module actually uses */
                 std::map<std::pair<unsigned, unsigned>, interface_var> descriptor_uses;
@@ -1106,20 +1116,20 @@
         pass = validate_vi_consistency(dev, vi) && pass;
     }
 
-    if (shaders[VK_SHADER_STAGE_VERTEX]) {
-        pass = validate_vi_against_vs_inputs(dev, vi, shaders[VK_SHADER_STAGE_VERTEX]) && pass;
+    if (shaders[vertex_stage]) {
+        pass = validate_vi_against_vs_inputs(dev, vi, shaders[vertex_stage]) && pass;
     }
 
     /* TODO: enforce rules about present combinations of shaders */
-    int producer = VK_SHADER_STAGE_VERTEX;
-    int consumer = VK_SHADER_STAGE_GEOMETRY;
+    int producer = get_shader_stage_id(VK_SHADER_STAGE_VERTEX_BIT);
+    int consumer = get_shader_stage_id(VK_SHADER_STAGE_GEOMETRY_BIT);
 
-    while (!shaders[producer] && producer != VK_SHADER_STAGE_FRAGMENT) {
+    while (!shaders[producer] && producer != fragment_stage) {
         producer++;
         consumer++;
     }
 
-    for (; producer != VK_SHADER_STAGE_FRAGMENT && consumer <= VK_SHADER_STAGE_FRAGMENT; consumer++) {
+    for (; producer != fragment_stage && consumer <= fragment_stage; consumer++) {
         assert(shaders[producer]);
         if (shaders[consumer]) {
             pass = validate_interface_between_stages(dev,
@@ -1131,8 +1141,8 @@
         }
     }
 
-    if (shaders[VK_SHADER_STAGE_FRAGMENT] && rp) {
-        pass = validate_fs_outputs_against_render_pass(dev, shaders[VK_SHADER_STAGE_FRAGMENT], rp, pCreateInfo->subpass) && pass;
+    if (shaders[fragment_stage] && rp) {
+        pass = validate_fs_outputs_against_render_pass(dev, shaders[fragment_stage], rp, pCreateInfo->subpass) && pass;
     }
 
     loader_platform_thread_unlock_mutex(&globalLock);
diff --git a/layers/vk_layer_utils.h b/layers/vk_layer_utils.h
index 643e1d2..bd5eac8 100644
--- a/layers/vk_layer_utils.h
+++ b/layers/vk_layer_utils.h
@@ -24,6 +24,9 @@
  **************************************************************************/
 #pragma once
 #include <stdbool.h>
+#ifndef WIN32
+#include <strings.h> /* for ffs() */
+#endif
 
 #ifdef __cplusplus
 extern "C" {
@@ -48,6 +51,14 @@
 bool   vk_format_is_compressed(VkFormat format);
 size_t vk_format_get_size(VkFormat format);
 
+static inline int u_ffs(int val)
+{
+#ifdef WIN32
+        return __lzcnt(val) + 1;
+#else
+        return ffs(val);
+#endif
+}
 
 #ifdef __cplusplus
 }
diff --git a/tests/init.cpp b/tests/init.cpp
index a8adbe8..031f759 100644
--- a/tests/init.cpp
+++ b/tests/init.cpp
@@ -72,7 +72,7 @@
     void CreateCommandBufferTest();
     void CreatePipelineTest();
     void CreateShaderTest();
-    void CreateShader(VkShader *pshader, VkShaderStage stage);
+    void CreateShader(VkShader *pshader, VkShaderStageFlagBits stage);
 
     VkDevice device() {return m_device->handle();}
 
@@ -518,7 +518,7 @@
     CreateCommandBufferTest();
 }
 
-void VkTest::CreateShader(VkShader *pshader, VkShaderStage stage)
+void VkTest::CreateShader(VkShader *pshader, VkShaderStageFlagBits stage)
 {
     void *code;
     uint32_t codeSize;
diff --git a/tests/layer_validation_tests.cpp b/tests/layer_validation_tests.cpp
index 6325f87..bdf3871 100644
--- a/tests/layer_validation_tests.cpp
+++ b/tests/layer_validation_tests.cpp
@@ -288,8 +288,8 @@
 
     VkConstantBufferObj constantBuffer(m_device, bufSize*2, sizeof(float), (const void*) &data);
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -1308,8 +1308,8 @@
     err = vkCreatePipelineLayout(m_device->device(), &pipeline_layout_ci, &pipeline_layout);
     ASSERT_VK_SUCCESS(err);
 
-    VkShaderObj vs(m_device, bindStateVertShaderText, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT, this); //  TODO - We shouldn't need a fragment shader
+    VkShaderObj vs(m_device, bindStateVertShaderText, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this); //  TODO - We shouldn't need a fragment shader
                                                                                        // but add it to be able to run on more devices
     VkPipelineObj pipe(m_device);
     pipe.AddShader(&vs);
@@ -1562,8 +1562,8 @@
     err = vkCreatePipelineLayout(m_device->device(), &pipeline_layout_ci, &pipeline_layout);
     ASSERT_VK_SUCCESS(err);
 
-    VkShaderObj vs(m_device, bindStateVertShaderText, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT, this); //  TODO - We shouldn't need a fragment shader
+    VkShaderObj vs(m_device, bindStateVertShaderText, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this); //  TODO - We shouldn't need a fragment shader
                                                                                        // but add it to be able to run on more devices
 
     VkPipelineObj pipe(m_device);
@@ -1829,19 +1829,19 @@
     VkPipelineShaderStageCreateInfo shaderStages[3];
     memset(&shaderStages, 0, 3 * sizeof(VkPipelineShaderStageCreateInfo));
 
-    VkShaderObj vs(m_device,bindStateVertShaderText,VK_SHADER_STAGE_VERTEX, this);
+    VkShaderObj vs(m_device,bindStateVertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
     // Just using VS txt for Tess shaders as we don't care about functionality
-    VkShaderObj tc(m_device,bindStateVertShaderText,VK_SHADER_STAGE_TESSELLATION_CONTROL, this);
-    VkShaderObj te(m_device,bindStateVertShaderText,VK_SHADER_STAGE_TESSELLATION_EVALUATION, this);
+    VkShaderObj tc(m_device,bindStateVertShaderText,VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT, this);
+    VkShaderObj te(m_device,bindStateVertShaderText,VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT, this);
 
     shaderStages[0].sType  = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-    shaderStages[0].stage  = VK_SHADER_STAGE_VERTEX;
+    shaderStages[0].stage  = VK_SHADER_STAGE_VERTEX_BIT;
     shaderStages[0].shader = vs.handle();
     shaderStages[1].sType  = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-    shaderStages[1].stage  = VK_SHADER_STAGE_TESSELLATION_CONTROL;
+    shaderStages[1].stage  = VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
     shaderStages[1].shader = tc.handle();
     shaderStages[2].sType  = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-    shaderStages[2].stage  = VK_SHADER_STAGE_TESSELLATION_EVALUATION;
+    shaderStages[2].stage  = VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
     shaderStages[2].shader = te.handle();
 
     VkPipelineInputAssemblyStateCreateInfo iaCI = {};
@@ -1960,15 +1960,15 @@
     VkPipelineShaderStageCreateInfo shaderStages[2];
 	memset(&shaderStages, 0, 2 * sizeof(VkPipelineShaderStageCreateInfo));
 
-    VkShaderObj vs(m_device,bindStateVertShaderText,VK_SHADER_STAGE_VERTEX, this);
-	VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT, this); // TODO - We shouldn't need a fragment shader
+    VkShaderObj vs(m_device,bindStateVertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+	VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this); // TODO - We shouldn't need a fragment shader
 	                                                                                   // but add it to be able to run on more devices
     shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-    shaderStages[0].stage = VK_SHADER_STAGE_VERTEX;
+    shaderStages[0].stage = VK_SHADER_STAGE_VERTEX_BIT;
     shaderStages[0].shader = vs.handle();
 
 	shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-	shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT;
+	shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT_BIT;
 	shaderStages[1].shader = fs.handle();
 
     VkGraphicsPipelineCreateInfo gp_ci = {};
@@ -2066,15 +2066,15 @@
     VkPipelineShaderStageCreateInfo shaderStages[2];
 	memset(&shaderStages, 0, 2 * sizeof(VkPipelineShaderStageCreateInfo));
 
-    VkShaderObj vs(m_device,bindStateVertShaderText,VK_SHADER_STAGE_VERTEX, this);
-	VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT, this); // TODO - We shouldn't need a fragment shader
+    VkShaderObj vs(m_device,bindStateVertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+	VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this); // TODO - We shouldn't need a fragment shader
 	                                                                                   // but add it to be able to run on more devices
     shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-    shaderStages[0].stage = VK_SHADER_STAGE_VERTEX;
+    shaderStages[0].stage = VK_SHADER_STAGE_VERTEX_BIT;
     shaderStages[0].shader = vs.handle();
 
 	shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-	shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT;
+	shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT_BIT;
 	shaderStages[1].shader = fs.handle();
 
     VkGraphicsPipelineCreateInfo gp_ci = {};
@@ -2179,15 +2179,15 @@
     VkPipelineShaderStageCreateInfo shaderStages[2];
 	memset(&shaderStages, 0, 2 * sizeof(VkPipelineShaderStageCreateInfo));
 
-    VkShaderObj vs(m_device,bindStateVertShaderText,VK_SHADER_STAGE_VERTEX, this);
-	VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT, this); // TODO - We shouldn't need a fragment shader
+    VkShaderObj vs(m_device,bindStateVertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+	VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this); // TODO - We shouldn't need a fragment shader
 	                                                                                   // but add it to be able to run on more devices
     shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-    shaderStages[0].stage = VK_SHADER_STAGE_VERTEX;
+    shaderStages[0].stage = VK_SHADER_STAGE_VERTEX_BIT;
     shaderStages[0].shader = vs.handle();
 
 	shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-	shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT;
+	shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT_BIT;
 	shaderStages[1].shader = fs.handle();
 
 	VkPipelineVertexInputStateCreateInfo vi_ci = {};
@@ -2335,15 +2335,15 @@
     VkPipelineShaderStageCreateInfo shaderStages[2];
 	memset(&shaderStages, 0, 2 * sizeof(VkPipelineShaderStageCreateInfo));
 
-    VkShaderObj vs(m_device,bindStateVertShaderText,VK_SHADER_STAGE_VERTEX, this);
-	VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT, this); // TODO - We shouldn't need a fragment shader
+    VkShaderObj vs(m_device,bindStateVertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+	VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this); // TODO - We shouldn't need a fragment shader
 	                                                                                   // but add it to be able to run on more devices
     shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-    shaderStages[0].stage = VK_SHADER_STAGE_VERTEX;
+    shaderStages[0].stage = VK_SHADER_STAGE_VERTEX_BIT;
     shaderStages[0].shader = vs.handle();
 
 	shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-	shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT;
+	shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT_BIT;
 	shaderStages[1].shader = fs.handle();
 
     VkPipelineVertexInputStateCreateInfo vi_ci = {};
@@ -3166,8 +3166,8 @@
     err = vkCreatePipelineLayout(m_device->device(), &pipeline_layout_ci, &pipeline_layout);
     ASSERT_VK_SUCCESS(err);
 
-    VkShaderObj vs(m_device, bindStateVertShaderText, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT, this); //  TODO - We shouldn't need a fragment shader
+    VkShaderObj vs(m_device, bindStateVertShaderText, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this); //  TODO - We shouldn't need a fragment shader
                                                                                        // but add it to be able to run on more devices
     VkPipelineObj pipe(m_device);
     pipe.AddShader(&vs);
@@ -3254,8 +3254,8 @@
     err = vkCreatePipelineLayout(m_device->device(), &pipeline_layout_ci, &pipeline_layout);
     ASSERT_VK_SUCCESS(err);
 
-    VkShaderObj vs(m_device, bindStateVertShaderText, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT, this); //  TODO - We shouldn't need a fragment shader
+    VkShaderObj vs(m_device, bindStateVertShaderText, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this); //  TODO - We shouldn't need a fragment shader
                                                                                        // but add it to be able to run on more devices
     VkPipelineObj pipe(m_device);
     pipe.AddShader(&vs);
@@ -3355,8 +3355,8 @@
     err = vkCreatePipelineLayout(m_device->device(), &pipeline_layout_ci, &pipeline_layout);
     ASSERT_VK_SUCCESS(err);
 
-    VkShaderObj vs(m_device, bindStateVertShaderText, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT, this); //  TODO - We shouldn't need a fragment shader
+    VkShaderObj vs(m_device, bindStateVertShaderText, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, bindStateFragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this); //  TODO - We shouldn't need a fragment shader
                                                                                        // but add it to be able to run on more devices
     VkPipelineObj pipe(m_device);
     pipe.AddShader(&vs);
@@ -3588,8 +3588,8 @@
         "   color = vec4(1);\n"
         "}\n";
 
-    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipe(m_device);
     pipe.AddColorAttachment();
@@ -3637,8 +3637,8 @@
         "   color = vec4(x);\n"
         "}\n";
 
-    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipe(m_device);
     pipe.AddColorAttachment();
@@ -3688,8 +3688,8 @@
         "   color = vec4(x);\n"
         "}\n";
 
-    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipe(m_device);
     pipe.AddColorAttachment();
@@ -3743,8 +3743,8 @@
         "   color = vec4(1);\n"
         "}\n";
 
-    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipe(m_device);
     pipe.AddColorAttachment();
@@ -3795,8 +3795,8 @@
         "   color = vec4(1);\n"
         "}\n";
 
-    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipe(m_device);
     pipe.AddColorAttachment();
@@ -3851,8 +3851,8 @@
         "   color = vec4(1);\n"
         "}\n";
 
-    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipe(m_device);
     pipe.AddColorAttachment();
@@ -3911,8 +3911,8 @@
         "   color = vec4(1);\n"
         "}\n";
 
-    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipe(m_device);
     pipe.AddColorAttachment();
@@ -3962,8 +3962,8 @@
         "void main(){\n"
         "}\n";
 
-    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipe(m_device);
     pipe.AddShader(&vs);
@@ -4014,8 +4014,8 @@
         "   y = vec4(1);\n"
         "}\n";
 
-    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipe(m_device);
     pipe.AddShader(&vs);
@@ -4065,8 +4065,8 @@
         "   x = ivec4(1);\n"
         "}\n";
 
-    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipe(m_device);
     pipe.AddShader(&vs);
@@ -4118,8 +4118,8 @@
 
     m_errorMonitor->ClearState();
 
-    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vsSource, VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj fs(m_device, fsSource, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
 
     VkPipelineObj pipe(m_device);
diff --git a/tests/render_tests.cpp b/tests/render_tests.cpp
index 022d1b1..59548cc 100644
--- a/tests/render_tests.cpp
+++ b/tests/render_tests.cpp
@@ -493,8 +493,8 @@
 
     VkConstantBufferObj constantBuffer(m_device, bufSize*2, sizeof(float), (const void*) &data);
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -752,8 +752,8 @@
     VkConstantBufferObj meshBuffer(m_device,sizeof(g_vbData)/sizeof(g_vbData[0]),sizeof(g_vbData[0]), g_vbData);
     meshBuffer.BufferMemoryBarrier();
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -833,8 +833,8 @@
     VkConstantBufferObj meshBuffer(m_device, sizeof(vb_data) / sizeof(vb_data[0]), sizeof(vb_data[0]), vb_data);
     meshBuffer.BufferMemoryBarrier();
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -939,8 +939,8 @@
     indexBuffer.CreateAndInitBuffer(sizeof(g_idxData)/sizeof(g_idxData[0]), VK_INDEX_TYPE_UINT16, g_idxData);
     indexBuffer.BufferMemoryBarrier();
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -1033,8 +1033,8 @@
     VkConstantBufferObj meshBuffer(m_device,sizeof(g_vbData)/sizeof(g_vbData[0]),sizeof(g_vbData[0]), g_vbData);
     meshBuffer.BufferMemoryBarrier();
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -1121,8 +1121,8 @@
     VkConstantBufferObj meshBuffer(m_device,sizeof(g_vbData)/sizeof(g_vbData[0]),sizeof(g_vbData[0]), g_vbData);
     meshBuffer.BufferMemoryBarrier();
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -1219,8 +1219,8 @@
     VkConstantBufferObj meshBuffer(m_device,sizeof(g_vbData)/sizeof(g_vbData[0]),sizeof(g_vbData[0]), g_vbData);
     meshBuffer.BufferMemoryBarrier();
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -1308,8 +1308,8 @@
     VkConstantBufferObj meshBuffer(m_device,sizeof(g_vbData)/sizeof(g_vbData[0]),sizeof(g_vbData[0]), g_vbData);
     meshBuffer.BufferMemoryBarrier();
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -1395,8 +1395,8 @@
     const int matrixSize = sizeof(MVP) / sizeof(MVP[0]);
 
     VkConstantBufferObj MVPBuffer(m_device, matrixSize, sizeof(MVP[0]), (const void*) &MVP[0][0]);
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -1476,8 +1476,8 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -1545,8 +1545,8 @@
     VkConstantBufferObj meshBuffer(m_device,sizeof(g_vbData)/sizeof(g_vbData[0]),sizeof(g_vbData[0]), g_vbData);
     meshBuffer.BufferMemoryBarrier();
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -1648,8 +1648,8 @@
     VkConstantBufferObj meshBuffer(m_device,sizeof(vData)/sizeof(vData[0]),sizeof(vData[0]), vData);
     meshBuffer.BufferMemoryBarrier();
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -1741,8 +1741,8 @@
     VkConstantBufferObj meshBuffer(m_device,sizeof(g_vbData)/sizeof(g_vbData[0]),sizeof(g_vbData[0]), g_vbData);
     meshBuffer.BufferMemoryBarrier();
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -1844,8 +1844,8 @@
     const int buf_size = sizeof(MVP) / sizeof(float);
 
     VkConstantBufferObj MVPBuffer(m_device, buf_size, sizeof(MVP[0]), (const void*) &MVP[0][0]);
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -1947,8 +1947,8 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
     VkSamplerObj sampler(m_device);
     VkTextureObj texture(m_device);
 
@@ -2018,8 +2018,8 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
     VkSamplerObj sampler(m_device);
     VkTextureObj texture(m_device);
 
@@ -2096,8 +2096,8 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
     VkSamplerObj sampler(m_device);
     VkTextureObj texture(m_device);
 
@@ -2164,8 +2164,8 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
     VkSamplerObj sampler(m_device);
     VkTextureObj texture(m_device);
 
@@ -2238,8 +2238,8 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkSamplerObj sampler1(m_device);
     VkSamplerObj sampler2(m_device);
@@ -2320,8 +2320,8 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     // Let's populate our buffer with the following:
     //     vec4 red;
@@ -2402,8 +2402,8 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     // We're going to create a number of uniform buffers, and then allow
     // the shader to select which it wants to read from with a binding
@@ -2496,8 +2496,8 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     // We're going to create a number of uniform buffers, and then allow
     // the shader to select which it wants to read from with a binding
@@ -2614,8 +2614,8 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     // We're going to create a number of uniform buffers, and then allow
     // the shader to select which it wants to read from with a binding
@@ -2729,8 +2729,8 @@
     const int buf_size = sizeof(MVP) / sizeof(float);
 
     VkConstantBufferObj mvpBuffer(m_device, buf_size, sizeof(MVP[0]), (const void*) &MVP[0][0]);
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
     VkSamplerObj sampler(m_device);
     VkTextureObj texture(m_device);
 
@@ -2845,8 +2845,8 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     const float redVals[4]   = { 1.0, 0.0, 0.0, 1.0 };
     const float greenVals[4] = { 0.0, 1.0, 0.0, 1.0 };
@@ -2955,8 +2955,8 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     const float redVals[4]   = { 1.0, 0.0, 0.0, 1.0 };
     const float greenVals[4] = { 0.0, 1.0, 0.0, 1.0 };
@@ -3237,8 +3237,8 @@
 
     const int constCount   = sizeof(mixedVals)   / sizeof(float);
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkConstantBufferObj mixedBuffer(m_device, constCount, sizeof(mixedVals[0]), (const void*) mixedVals);
 
@@ -3334,8 +3334,8 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX, this);
-    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device,vertShaderText,VK_SHADER_STAGE_VERTEX_BIT, this);
+    VkShaderObj ps(m_device,fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     uint32_t tex_colors[2] = { 0xffff0000, 0xffff0000 };
     VkSamplerObj sampler0(m_device);
@@ -3453,9 +3453,9 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device, vertShaderText, VK_SHADER_STAGE_VERTEX,   this);
-    VkShaderObj gs(m_device, geomShaderText, VK_SHADER_STAGE_GEOMETRY, this);
-    VkShaderObj ps(m_device, fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vertShaderText, VK_SHADER_STAGE_VERTEX_BIT,   this);
+    VkShaderObj gs(m_device, geomShaderText, VK_SHADER_STAGE_GEOMETRY_BIT, this);
+    VkShaderObj ps(m_device, fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -3801,9 +3801,9 @@
 
     const int constCount   = sizeof(mixedVals)   / sizeof(float);
 
-    VkShaderObj vs(m_device, vertShaderText, VK_SHADER_STAGE_VERTEX,   this);
-    VkShaderObj gs(m_device, geomShaderText, VK_SHADER_STAGE_GEOMETRY, this);
-    VkShaderObj ps(m_device, fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vertShaderText, VK_SHADER_STAGE_VERTEX_BIT,   this);
+    VkShaderObj gs(m_device, geomShaderText, VK_SHADER_STAGE_GEOMETRY_BIT, this);
+    VkShaderObj ps(m_device, fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkConstantBufferObj mixedBuffer(m_device, constCount, sizeof(mixedVals[0]), (const void*) mixedVals);
 
@@ -3929,9 +3929,9 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device, vertShaderText, VK_SHADER_STAGE_VERTEX,   this);
-    VkShaderObj gs(m_device, geomShaderText, VK_SHADER_STAGE_GEOMETRY, this);
-    VkShaderObj ps(m_device, fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vertShaderText, VK_SHADER_STAGE_VERTEX_BIT,   this);
+    VkShaderObj gs(m_device, geomShaderText, VK_SHADER_STAGE_GEOMETRY_BIT, this);
+    VkShaderObj ps(m_device, fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
@@ -4059,9 +4059,9 @@
     ASSERT_NO_FATAL_FAILURE(InitState());
     ASSERT_NO_FATAL_FAILURE(InitViewport());
 
-    VkShaderObj vs(m_device, vertShaderText, VK_SHADER_STAGE_VERTEX,   this);
-    VkShaderObj gs(m_device, geomShaderText, VK_SHADER_STAGE_GEOMETRY, this);
-    VkShaderObj ps(m_device, fragShaderText, VK_SHADER_STAGE_FRAGMENT, this);
+    VkShaderObj vs(m_device, vertShaderText, VK_SHADER_STAGE_VERTEX_BIT,   this);
+    VkShaderObj gs(m_device, geomShaderText, VK_SHADER_STAGE_GEOMETRY_BIT, this);
+    VkShaderObj ps(m_device, fragShaderText, VK_SHADER_STAGE_FRAGMENT_BIT, this);
 
     VkPipelineObj pipelineobj(m_device);
     pipelineobj.AddColorAttachment();
diff --git a/tests/vkrenderframework.cpp b/tests/vkrenderframework.cpp
index ce7f112..b435d25 100644
--- a/tests/vkrenderframework.cpp
+++ b/tests/vkrenderframework.cpp
@@ -1071,7 +1071,7 @@
     return stageInfo;
 }
 
-VkShaderObj::VkShaderObj(VkDeviceObj *device, const char * shader_code, VkShaderStage stage, VkRenderFramework *framework)
+VkShaderObj::VkShaderObj(VkDeviceObj *device, const char * shader_code, VkShaderStageFlagBits stage, VkRenderFramework *framework)
 {
     VkResult U_ASSERT_ONLY err = VK_SUCCESS;
     std::vector<unsigned int> spv;
diff --git a/tests/vkrenderframework.h b/tests/vkrenderframework.h
index 37d0e4a..c6914a2 100644
--- a/tests/vkrenderframework.h
+++ b/tests/vkrenderframework.h
@@ -405,12 +405,12 @@
 class VkShaderObj : public vk_testing::Shader
 {
 public:
-    VkShaderObj(VkDeviceObj *device, const char * shaderText, VkShaderStage stage, VkRenderFramework *framework);
+    VkShaderObj(VkDeviceObj *device, const char * shaderText, VkShaderStageFlagBits stage, VkRenderFramework *framework);
     VkPipelineShaderStageCreateInfo* GetStageCreateInfo();
 
 protected:
     VkPipelineShaderStageCreateInfo     stage_info;
-    VkShaderStage                       m_stage;
+    VkShaderStageFlagBits               m_stage;
     VkDeviceObj                        *m_device;
 
 };
diff --git a/tests/vktestbinding.h b/tests/vktestbinding.h
index 7397e59..6ca8e26 100644
--- a/tests/vktestbinding.h
+++ b/tests/vktestbinding.h
@@ -468,7 +468,7 @@
     void init(const Device &dev, const VkShaderCreateInfo &info);
     VkResult init_try(const Device &dev, const VkShaderCreateInfo &info);
 
-    static VkShaderCreateInfo create_info(VkShaderModule module, const char *pName, VkFlags flags, VkShaderStage stage);
+    static VkShaderCreateInfo create_info(VkShaderModule module, const char *pName, VkFlags flags, VkShaderStageFlagBits stage);
 };
 
 class Pipeline : public internal::NonDispHandle<VkPipeline> {
@@ -793,7 +793,7 @@
     return info;
 }
 
-inline VkShaderCreateInfo Shader::create_info(VkShaderModule module, const char *pName, VkFlags flags, VkShaderStage stage)
+inline VkShaderCreateInfo Shader::create_info(VkShaderModule module, const char *pName, VkFlags flags, VkShaderStageFlagBits stage)
 {
     VkShaderCreateInfo info = {};
     info.sType = VK_STRUCTURE_TYPE_SHADER_CREATE_INFO;
diff --git a/tests/vktestframework.cpp b/tests/vktestframework.cpp
index 0d13a5c..e61efb8 100644
--- a/tests/vktestframework.cpp
+++ b/tests/vktestframework.cpp
@@ -1674,25 +1674,25 @@
 //
 // Convert VK shader type to compiler's
 //
-EShLanguage VkTestFramework::FindLanguage(const VkShaderStage shader_type)
+EShLanguage VkTestFramework::FindLanguage(const VkShaderStageFlagBits shader_type)
 {
     switch (shader_type) {
-    case VK_SHADER_STAGE_VERTEX:
+    case VK_SHADER_STAGE_VERTEX_BIT:
         return EShLangVertex;
 
-    case VK_SHADER_STAGE_TESSELLATION_CONTROL:
+    case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
         return EShLangTessControl;
 
-    case VK_SHADER_STAGE_TESSELLATION_EVALUATION:
+    case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
         return EShLangTessEvaluation;
 
-    case VK_SHADER_STAGE_GEOMETRY:
+    case VK_SHADER_STAGE_GEOMETRY_BIT:
         return EShLangGeometry;
 
-    case VK_SHADER_STAGE_FRAGMENT:
+    case VK_SHADER_STAGE_FRAGMENT_BIT:
         return EShLangFragment;
 
-    case VK_SHADER_STAGE_COMPUTE:
+    case VK_SHADER_STAGE_COMPUTE_BIT:
         return EShLangCompute;
 
     default:
@@ -1705,7 +1705,7 @@
 // Compile a given string containing GLSL into SPV for use by VK
 // Return value of false means an error was encountered.
 //
-bool VkTestFramework::GLSLtoSPV(const VkShaderStage shader_type,
+bool VkTestFramework::GLSLtoSPV(const VkShaderStageFlagBits shader_type,
                                  const char *pshader,
                                  std::vector<unsigned int> &spirv)
 {
diff --git a/tests/vktestframework.h b/tests/vktestframework.h
index 99ca276..3b2264a 100644
--- a/tests/vktestframework.h
+++ b/tests/vktestframework.h
@@ -91,7 +91,7 @@
     void Compare(const char *comment, VkImageObj *image);
     void RecordImage(VkImageObj * image);
     void RecordImages(vector<VkImageObj *> image);
-    bool GLSLtoSPV(const VkShaderStage shader_type,
+    bool GLSLtoSPV(const VkShaderStageFlagBits shader_type,
                    const char *pshader,
                    std::vector<unsigned int> &spv);
     static bool         m_use_glsl;
@@ -109,7 +109,7 @@
     void SetMessageOptions(EShMessages& messages);
     void ProcessConfigFile();
     EShLanguage FindLanguage(const std::string& name);
-    EShLanguage FindLanguage(const VkShaderStage shader_type);
+    EShLanguage FindLanguage(const VkShaderStageFlagBits shader_type);
     std::string ConfigFile;
     bool SetConfigFile(const std::string& name);