layers: Add group size param checks to CmdDispatch

Added parameter validation checks of the group count and base group
parameters against the maxComputeWorkGroupCount device limits, for
vkCmdDispatch() and vkCmdDispatchBaseKHX().

Change-Id: Ic41031d694c6d311431fc49f48b1427a2b042337
diff --git a/layers/parameter_validation_utils.cpp b/layers/parameter_validation_utils.cpp
index 598d976..79d50a4 100644
--- a/layers/parameter_validation_utils.cpp
+++ b/layers/parameter_validation_utils.cpp
@@ -2554,6 +2554,94 @@
     return skip;
 }
 
+bool pv_vkCmdDispatch(VkCommandBuffer commandBuffer, uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ) {
+    bool skip = false;
+    layer_data *device_data = GetLayerDataPtr(get_dispatch_key(commandBuffer), layer_data_map);
+    // Log an error for each groupCount{X,Y,Z} that exceeds the matching device_limits.maxComputeWorkGroupCount entry.
+    if (groupCountX > device_data->device_limits.maxComputeWorkGroupCount[0]) {
+        skip |= log_msg(
+            device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+            HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19c00304, LayerName,
+            "vkCmdDispatch(): groupCountX (%" PRIu32 ") exceeds device limit maxComputeWorkGroupCount[0] (%" PRIu32 "). %s",
+            groupCountX, device_data->device_limits.maxComputeWorkGroupCount[0], validation_error_map[VALIDATION_ERROR_19c00304]);
+    }
+    // Y dimension: same comparison against maxComputeWorkGroupCount[1].
+    if (groupCountY > device_data->device_limits.maxComputeWorkGroupCount[1]) {
+        skip |= log_msg(
+            device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+            HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19c00306, LayerName,
+            "vkCmdDispatch(): groupCountY (%" PRIu32 ") exceeds device limit maxComputeWorkGroupCount[1] (%" PRIu32 "). %s",
+            groupCountY, device_data->device_limits.maxComputeWorkGroupCount[1], validation_error_map[VALIDATION_ERROR_19c00306]);
+    }
+    // Z dimension: same comparison against maxComputeWorkGroupCount[2].
+    if (groupCountZ > device_data->device_limits.maxComputeWorkGroupCount[2]) {
+        skip |= log_msg(
+            device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+            HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19c00308, LayerName,
+            "vkCmdDispatch(): groupCountZ (%" PRIu32 ") exceeds device limit maxComputeWorkGroupCount[2] (%" PRIu32 "). %s",
+            groupCountZ, device_data->device_limits.maxComputeWorkGroupCount[2], validation_error_map[VALIDATION_ERROR_19c00308]);
+    }
+    // skip is the OR of the log_msg() results above; it stays false when no limit was exceeded.
+    return skip;
+}
+
+bool pv_vkCmdDispatchBaseKHX(VkCommandBuffer commandBuffer, uint32_t baseGroupX, uint32_t baseGroupY, uint32_t baseGroupZ,
+                             uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ) {
+    bool skip = false;
+    layer_data *device_data = GetLayerDataPtr(get_dispatch_key(commandBuffer), layer_data_map);
+    // Per dimension: log an error if baseGroup >= limit, otherwise if groupCount > (limit - baseGroup).
+    // Paired if {} else if {} tests: (limit - baseGroup) is only evaluated when baseGroup < limit, so it cannot underflow
+    uint32_t limit = device_data->device_limits.maxComputeWorkGroupCount[0];
+    if (baseGroupX >= limit) {
+        skip |= log_msg(device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                        HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19e0034a, LayerName,
+                        "vkCmdDispatchBaseKHX(): baseGroupX (%" PRIu32
+                        ") equals or exceeds device limit maxComputeWorkGroupCount[0] (%" PRIu32 "). %s",
+                        baseGroupX, limit, validation_error_map[VALIDATION_ERROR_19e0034a]);
+    } else if (groupCountX > (limit - baseGroupX)) {
+        skip |= log_msg(device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                        HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19e00350, LayerName,
+                        "vkCmdDispatchBaseKHX(): baseGroupX (%" PRIu32 ") + groupCountX (%" PRIu32
+                        ") exceeds device limit "
+                        "maxComputeWorkGroupCount[0] (%" PRIu32 "). %s",
+                        baseGroupX, groupCountX, limit, validation_error_map[VALIDATION_ERROR_19e00350]);
+    }
+    // Y dimension: same pair of checks against maxComputeWorkGroupCount[1].
+    limit = device_data->device_limits.maxComputeWorkGroupCount[1];
+    if (baseGroupY >= limit) {
+        skip |= log_msg(device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                        HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19e0034c, LayerName,
+                        "vkCmdDispatchBaseKHX(): baseGroupY (%" PRIu32
+                        ") equals or exceeds device limit maxComputeWorkGroupCount[1] (%" PRIu32 "). %s",
+                        baseGroupY, limit, validation_error_map[VALIDATION_ERROR_19e0034c]);
+    } else if (groupCountY > (limit - baseGroupY)) {
+        skip |= log_msg(device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                        HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19e00352, LayerName,
+                        "vkCmdDispatchBaseKHX(): baseGroupY (%" PRIu32 ") + groupCountY (%" PRIu32
+                        ") exceeds device limit "
+                        "maxComputeWorkGroupCount[1] (%" PRIu32 "). %s",
+                        baseGroupY, groupCountY, limit, validation_error_map[VALIDATION_ERROR_19e00352]);
+    }
+    // Z dimension: same pair of checks against maxComputeWorkGroupCount[2].
+    limit = device_data->device_limits.maxComputeWorkGroupCount[2];
+    if (baseGroupZ >= limit) {
+        skip |= log_msg(device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                        HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19e0034e, LayerName,
+                        "vkCmdDispatchBaseKHX(): baseGroupZ (%" PRIu32
+                        ") equals or exceeds device limit maxComputeWorkGroupCount[2] (%" PRIu32 "). %s",
+                        baseGroupZ, limit, validation_error_map[VALIDATION_ERROR_19e0034e]);
+    } else if (groupCountZ > (limit - baseGroupZ)) {
+        skip |= log_msg(device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                        HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19e00354, LayerName,
+                        "vkCmdDispatchBaseKHX(): baseGroupZ (%" PRIu32 ") + groupCountZ (%" PRIu32
+                        ") exceeds device limit "
+                        "maxComputeWorkGroupCount[2] (%" PRIu32 "). %s",
+                        baseGroupZ, groupCountZ, limit, validation_error_map[VALIDATION_ERROR_19e00354]);
+    }
+    // skip is the OR of the log_msg() results above; it stays false when no check failed.
+    return skip;
+}
+
 VKAPI_ATTR PFN_vkVoidFunction VKAPI_CALL vkGetDeviceProcAddr(VkDevice device, const char *funcName) {
     const auto item = name_to_funcptr_map.find(funcName);
     if (item != name_to_funcptr_map.end()) {
@@ -2616,6 +2704,8 @@
     custom_functions["vkCreateSwapchainKHR"] = (void*)pv_vkCreateSwapchainKHR;
     custom_functions["vkQueuePresentKHR"] = (void*)pv_vkQueuePresentKHR;
     custom_functions["vkCreateDescriptorPool"] = (void*)pv_vkCreateDescriptorPool;
+    custom_functions["vkCmdDispatch"] = (void*)pv_vkCmdDispatch;
+    custom_functions["vkCmdDispatchBaseKHX"] = (void*)pv_vkCmdDispatchBaseKHX;
 }
 
 }  // namespace parameter_validation