layers: Improve DrawState write descriptor update

Validate that stageFlags are identical for every descriptor covered by a single write update.
Validate that the sampler updates from a single write update are either all immutable or all non-immutable.
Refactor the update contents check to switch once on the descriptor type and loop inside each case, instead of switching on every loop iteration.
Add two new validation errors:

DRAWSTATE_DESCRIPTOR_STAGEFLAGS_MISMATCH
DRAWSTATE_INCONSISTENT_IMMUTABLE_SAMPLER_UPDATE
diff --git a/layers/draw_state.cpp b/layers/draw_state.cpp
index c9db5bb..2bc0a09 100755
--- a/layers/draw_state.cpp
+++ b/layers/draw_state.cpp
@@ -782,12 +782,12 @@
     return skipCall;
 }
 // Verify that the descriptor type in the update struct matches what's expected by the layout
-static VkBool32 validateUpdateType(layer_data* my_data, const VkDevice device, const LAYOUT_NODE* pLayout, const GENERIC_HEADER* pUpdateStruct)
+static VkBool32 validateUpdateConsistency(layer_data* my_data, const VkDevice device, const LAYOUT_NODE* pLayout, const GENERIC_HEADER* pUpdateStruct, uint32_t startIndex, uint32_t endIndex)
 {
     // First get actual type of update
     VkBool32 skipCall = VK_FALSE;
     VkDescriptorType actualType;
-    uint32_t i = 0, startIndex = 0, endIndex = 0;
+    uint32_t i = 0;
     switch (pUpdateStruct->sType)
     {
         case VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET:
@@ -801,14 +801,19 @@
             skipCall |= log_msg(my_data->report_data, VK_DBG_REPORT_ERROR_BIT, (VkDbgObjectType) 0, 0, 0, DRAWSTATE_INVALID_UPDATE_STRUCT, "DS",
                     "Unexpected UPDATE struct of type %s (value %u) in vkUpdateDescriptors() struct tree", string_VkStructureType(pUpdateStruct->sType), pUpdateStruct->sType);
     }
-    skipCall |= getUpdateStartIndex(my_data, device, pLayout, pUpdateStruct, &startIndex);
-    skipCall |= getUpdateEndIndex(my_data, device, pLayout, pUpdateStruct, &endIndex);
     if (VK_FALSE == skipCall) {
+        // Set first stageFlags as reference and verify that all other updates match it
+        VkShaderStageFlags refStageFlags = pLayout->stageFlags[startIndex];
         for (i = startIndex; i <= endIndex; i++) {
             if (pLayout->descriptorTypes[i] != actualType) {
                 skipCall |= log_msg(my_data->report_data, VK_DBG_REPORT_ERROR_BIT, (VkDbgObjectType) 0, 0, 0, DRAWSTATE_DESCRIPTOR_TYPE_MISMATCH, "DS",
-                        "Descriptor update type of %s has descriptor type %s that does not match overlapping binding descriptor type of %s!",
-                        string_VkStructureType(pUpdateStruct->sType), string_VkDescriptorType(actualType), string_VkDescriptorType(pLayout->descriptorTypes[i]));
+                    "Write descriptor update has descriptor type %s that does not match overlapping binding descriptor type of %s!",
+                    string_VkDescriptorType(actualType), string_VkDescriptorType(pLayout->descriptorTypes[i]));
+            }
+            if (pLayout->stageFlags[i] != refStageFlags) {
+                skipCall |= log_msg(my_data->report_data, VK_DBG_REPORT_ERROR_BIT, (VkDbgObjectType) 0, 0, 0, DRAWSTATE_DESCRIPTOR_STAGEFLAGS_MISMATCH, "DS",
+                    "Write descriptor update has stageFlags %x that do not match overlapping binding descriptor stageFlags of %x!",
+                    refStageFlags, pLayout->stageFlags[i]);
             }
         }
     }
@@ -1001,75 +1006,82 @@
     VkImageLayout* pImageLayout = NULL;
     VkDescriptorBufferInfo* pBufferInfo = NULL;
     VkBool32 immutable = VK_FALSE;
-    // TODO : Can refactor this to only switch once and then have loop inside of switch case
-    for (uint32_t i=0; i<pWDS->count; ++i) {
-        switch (pWDS->descriptorType) {
-            case VK_DESCRIPTOR_TYPE_SAMPLER:
+    uint32_t i = 0;
+    // For given update type, verify that update contents are correct
+    switch (pWDS->descriptorType) {
+        case VK_DESCRIPTOR_TYPE_SAMPLER:
+            for (i=0; i<pWDS->count; ++i) {
                 skipCall |= validateSampler(my_data, &(pWDS->pImageInfo[i].sampler), immutable);
-                break;
-            case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
+            }
+            break;
+        case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
+            for (i=0; i<pWDS->count; ++i) {
                 if (NULL == pLayoutBinding->pImmutableSamplers) {
                     pSampler = &(pWDS->pImageInfo[i].sampler);
+                    if (immutable) {
+                        skipCall |= log_msg(my_data->report_data, VK_DBG_REPORT_ERROR_BIT, VK_OBJECT_TYPE_SAMPLER, pSampler->handle, 0, DRAWSTATE_INCONSISTENT_IMMUTABLE_SAMPLER_UPDATE, "DS",
+                            "vkUpdateDescriptorSets: Update #%u is not an immutable sampler %#" PRIxLEAST64 ", but previous update(s) from this "
+                            "VkWriteDescriptorSet struct used an immutable sampler. All updates from a single struct must either "
+                            "use immutable or non-immutable samplers.", i, pSampler->handle);
+                    }
                 } else {
+                    if (i>0 && !immutable) {
+                        skipCall |= log_msg(my_data->report_data, VK_DBG_REPORT_ERROR_BIT, VK_OBJECT_TYPE_SAMPLER, pSampler->handle, 0, DRAWSTATE_INCONSISTENT_IMMUTABLE_SAMPLER_UPDATE, "DS",
+                            "vkUpdateDescriptorSets: Update #%u is an immutable sampler, but previous update(s) from this "
+                            "VkWriteDescriptorSet struct used a non-immutable sampler. All updates from a single struct must either "
+                            "use immutable or non-immutable samplers.", i);
+                    }
                     immutable = VK_TRUE;
                     pSampler = &(pLayoutBinding->pImmutableSamplers[i]);
                 }
                 skipCall |= validateSampler(my_data, pSampler, immutable);
-                // Intentionally fall through here to also validate image stuff
-            case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE:
-            case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
-            case VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT:
+            }
+            // Intentionally fall through here to also validate image stuff
+        case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE:
+        case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
+        case VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT:
+            for (i=0; i<pWDS->count; ++i) {
                 skipCall |= validateImageView(my_data, &(pWDS->pImageInfo[i].imageView), pWDS->pImageInfo[i].imageLayout);
-                break;
-            case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
-            case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
+            }
+            break;
+        case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
+        case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
+            for (i=0; i<pWDS->count; ++i) {
                 skipCall |= validateBufferView(my_data, &(pWDS->pTexelBufferView[i]));
-                break;
-            case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
-            case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
-            case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:
-            case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC:
+            }
+            break;
+        case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
+        case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
+        case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:
+        case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC:
+            for (i=0; i<pWDS->count; ++i) {
                 skipCall |= validateBufferInfo(my_data, &(pWDS->pBufferInfo[i]));
-                break;
-        }
+            }
+            break;
     }
     return skipCall;
 }
-
-// update DS mappings based on pUpdateArray
-// TODO : Verify we validate this spec req: All consecutive bindings updated via a single
-//  VkWriteDescriptorSet structure must have identical descriptorType and stageFlags,
-//  and must all either use immutable samplers or none must use immutable samplers.
-static VkBool32 dsUpdate(layer_data* my_data, VkDevice device, VkStructureType type, uint32_t updateCount, const void* pUpdateArray)
+// update DS mappings based on write and copy update arrays
+static VkBool32 dsUpdate(layer_data* my_data, VkDevice device, uint32_t writeCount, const VkWriteDescriptorSet* pWDS, uint32_t copyCount, const VkCopyDescriptorSet* pCDS)
 {
-    const VkWriteDescriptorSet *pWDS = NULL;
     VkBool32 skipCall = VK_FALSE;
 
-    if (type == VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET) {
-        pWDS = (const VkWriteDescriptorSet *)pUpdateArray;
-    } else {
-        // TODO : copy update validation is completely broken so just skip it for now
-        //  Need to add separate loop to verify copy array
-        return VK_FALSE;
-    }
-
     loader_platform_thread_lock_mutex(&globalLock);
     LAYOUT_NODE* pLayout = NULL;
     VkDescriptorSetLayoutCreateInfo* pLayoutCI = NULL;
-    // TODO : If pCIList is NULL, flag error
-    // Perform all updates
-    for (uint32_t i = 0; i < updateCount; i++) {
+    // Validate Write updates
+    for (uint32_t i = 0; i < writeCount; i++) {
         VkDescriptorSet ds = pWDS[i].destSet;
-        SET_NODE* pSet = my_data->setMap[ds.handle]; // getSetNode() without locking
+        SET_NODE* pSet = my_data->setMap[ds.handle];
         GENERIC_HEADER* pUpdate = (GENERIC_HEADER*) &pWDS[i];
         pLayout = pSet->pLayout;
         // First verify valid update struct
         if ((skipCall = validUpdateStruct(my_data, device, pUpdate)) == VK_TRUE) {
             break;
         }
-        // Make sure that binding is within bounds
         uint32_t binding = 0, endIndex = 0;
-        skipCall |= getUpdateBinding(my_data, device, pUpdate, &binding);
+        binding = pWDS[i].destBinding;
+        // Make sure that layout being updated has the binding being updated
         if (pLayout->createInfo.count < binding) {
             skipCall |= log_msg(my_data->report_data, VK_DBG_REPORT_ERROR_BIT, VK_OBJECT_TYPE_DESCRIPTOR_SET, ds.handle, 0, DRAWSTATE_INVALID_UPDATE_INDEX, "DS",
                     "Descriptor Set %p does not have binding to match update binding %u for update type %s!", ds, binding, string_VkStructureType(pUpdate->sType));
@@ -1077,15 +1089,16 @@
             // Next verify that update falls within size of given binding
             skipCall |= getUpdateEndIndex(my_data, device, pLayout, pUpdate, &endIndex);
             if (getBindingEndIndex(pLayout, binding) < endIndex) {
-                // TODO : Keep count of layout CI structs and size this string dynamically based on that count
                 pLayoutCI = &pLayout->createInfo;
                 string DSstr = vk_print_vkdescriptorsetlayoutcreateinfo(pLayoutCI, "{DS}    ");
                 skipCall |= log_msg(my_data->report_data, VK_DBG_REPORT_ERROR_BIT, VK_OBJECT_TYPE_DESCRIPTOR_SET, ds.handle, 0, DRAWSTATE_DESCRIPTOR_UPDATE_OUT_OF_BOUNDS, "DS",
                         "Descriptor update type of %s is out of bounds for matching binding %u in Layout w/ CI:\n%s!", string_VkStructureType(pUpdate->sType), binding, DSstr.c_str());
             } else { // TODO : should we skip update on a type mismatch or force it?
-                // Layout bindings match w/ update ok, now verify that update is of the right type
-                if ((skipCall = validateUpdateType(my_data, device, pLayout, pUpdate)) == VK_FALSE) {
-                    // The update matches the layout, but need to make sure it's self-consistent as well
+                uint32_t startIndex;
+                skipCall |= getUpdateStartIndex(my_data, device, pLayout, pUpdate, &startIndex);
+                // Layout bindings match w/ update, now verify that update type & stageFlags are the same for entire update
+                if ((skipCall = validateUpdateConsistency(my_data, device, pLayout, pUpdate, startIndex, endIndex)) == VK_FALSE) {
+                    // The update is within bounds and consistent, but need to make sure contents make sense as well
                     if ((skipCall = validateUpdateContents(my_data, &pWDS[i], &pLayout->createInfo.pBinding[binding])) == VK_FALSE) {
                         // Update is good. Save the update info
                         // Create new update struct for this set's shadow copy
@@ -1099,9 +1112,6 @@
                             pNewNode->pNext = pSet->pUpdateStructs;
                             pSet->pUpdateStructs = pNewNode;
                             // Now update appropriate descriptor(s) to point to new Update node
-                            skipCall |= getUpdateEndIndex(my_data, device, pLayout, pUpdate, &endIndex);
-                            uint32_t startIndex;
-                            skipCall |= getUpdateStartIndex(my_data, device, pLayout, pUpdate, &startIndex);
                             for (uint32_t j = startIndex; j <= endIndex; j++) {
                                 assert(j<pSet->descriptorCount);
                                 pSet->ppDescriptors[j] = pNewNode;
@@ -2026,11 +2036,13 @@
         }
         if (totalCount > 0) {
             pNewNode->descriptorTypes.resize(totalCount);
+            pNewNode->stageFlags.resize(totalCount);
             uint32_t offset = 0;
             uint32_t j = 0;
             for (uint32_t i=0; i<pCreateInfo->count; i++) {
                 for (j = 0; j < pCreateInfo->pBinding[i].arraySize; j++) {
                     pNewNode->descriptorTypes[offset + j] = pCreateInfo->pBinding[i].descriptorType;
+                    pNewNode->stageFlags[offset + j] = pCreateInfo->pBinding[i].stageFlags;
                 }
                 offset += j;
             }
@@ -2195,10 +2207,9 @@
 
 VK_LAYER_EXPORT void VKAPI vkUpdateDescriptorSets(VkDevice device, uint32_t writeCount, const VkWriteDescriptorSet* pDescriptorWrites, uint32_t copyCount, const VkCopyDescriptorSet* pDescriptorCopies)
 {
-    // dsUpdate will return VK_TRUE only if a bailout error occurs, so we want to call down tree when both updates return VK_FALSE
+    // dsUpdate will return VK_TRUE only if a bailout error occurs, so we want to call down tree when update returns VK_FALSE
     layer_data* dev_data = get_my_data_ptr(get_dispatch_key(device), layer_data_map);
-    if (!dsUpdate(dev_data, device, VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET, writeCount, pDescriptorWrites) &&
-        !dsUpdate(dev_data, device, VK_STRUCTURE_TYPE_COPY_DESCRIPTOR_SET, copyCount, pDescriptorCopies)) {
+    if (!dsUpdate(dev_data, device, writeCount, pDescriptorWrites, copyCount, pDescriptorCopies)) {
         dev_data->device_dispatch_table->UpdateDescriptorSets(device, writeCount, pDescriptorWrites, copyCount, pDescriptorCopies);
     }
 }