layers: Add group count param checks to CmdDispatch
Add parameter validation that checks the workgroup count parameters
of vkCmdDispatch() and vkCmdDispatchBaseKHX() (including the base
group offsets of the latter) against the device's
maxComputeWorkGroupCount limits.
Change-Id: Ic41031d694c6d311431fc49f48b1427a2b042337
diff --git a/layers/parameter_validation_utils.cpp b/layers/parameter_validation_utils.cpp
index 598d976..79d50a4 100644
--- a/layers/parameter_validation_utils.cpp
+++ b/layers/parameter_validation_utils.cpp
@@ -2554,6 +2554,94 @@
return skip;
}
+// vkCmdDispatch() parameter checks: groupCountX/Y/Z must not exceed maxComputeWorkGroupCount[0]/[1]/[2].
+// Every offending axis is logged separately; returns the OR of the log_msg() results (the "skip" flag).
+bool pv_vkCmdDispatch(VkCommandBuffer commandBuffer, uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ) {
+    layer_data *dev_data = GetLayerDataPtr(get_dispatch_key(commandBuffer), layer_data_map);
+    const auto &max_count = dev_data->device_limits.maxComputeWorkGroupCount;
+    bool skip_call = false;
+
+    if (max_count[0] < groupCountX) {
+        skip_call |= log_msg(dev_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                             HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19c00304, LayerName,
+                             "vkCmdDispatch(): groupCountX (%" PRIu32 ") exceeds device limit maxComputeWorkGroupCount[0] (%" PRIu32 "). %s",
+                             groupCountX, max_count[0], validation_error_map[VALIDATION_ERROR_19c00304]);
+    }
+
+    if (max_count[1] < groupCountY) {
+        skip_call |= log_msg(dev_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                             HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19c00306, LayerName,
+                             "vkCmdDispatch(): groupCountY (%" PRIu32 ") exceeds device limit maxComputeWorkGroupCount[1] (%" PRIu32 "). %s",
+                             groupCountY, max_count[1], validation_error_map[VALIDATION_ERROR_19c00306]);
+    }
+
+    if (max_count[2] < groupCountZ) {
+        skip_call |= log_msg(dev_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                             HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19c00308, LayerName,
+                             "vkCmdDispatch(): groupCountZ (%" PRIu32 ") exceeds device limit maxComputeWorkGroupCount[2] (%" PRIu32 "). %s",
+                             groupCountZ, max_count[2], validation_error_map[VALIDATION_ERROR_19c00308]);
+    }
+
+    return skip_call;
+}
+
+// vkCmdDispatchBaseKHX() parameter checks. For each axis N (X/Y/Z, limit = maxComputeWorkGroupCount[0/1/2]):
+// baseGroupN must be < limit, and groupCountN must be <= limit - baseGroupN (checked only when the first holds).
+// Every offending axis is logged separately; returns the OR of the log_msg() results (the "skip" flag).
+bool pv_vkCmdDispatchBaseKHX(VkCommandBuffer commandBuffer, uint32_t baseGroupX, uint32_t baseGroupY, uint32_t baseGroupZ,
+                             uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ) {
+    bool skip = false;
+    layer_data *device_data = GetLayerDataPtr(get_dispatch_key(commandBuffer), layer_data_map);
+
+    // Paired if {} else if {} tests: (limit - baseGroupN) is only evaluated when baseGroupN < limit, so it cannot underflow.
+    uint32_t limit = device_data->device_limits.maxComputeWorkGroupCount[0];
+    if (baseGroupX >= limit) {
+        skip |= log_msg(device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                        HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19e0034a, LayerName,
+                        "vkCmdDispatchBaseKHX(): baseGroupX (%" PRIu32
+                        ") equals or exceeds device limit maxComputeWorkGroupCount[0] (%" PRIu32 "). %s",
+                        baseGroupX, limit, validation_error_map[VALIDATION_ERROR_19e0034a]);
+    } else if (groupCountX > (limit - baseGroupX)) {
+        skip |= log_msg(device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                        HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19e00350, LayerName,
+                        "vkCmdDispatchBaseKHX(): baseGroupX (%" PRIu32 ") + groupCountX (%" PRIu32
+                        ") exceeds device limit maxComputeWorkGroupCount[0] (%" PRIu32 "). %s",
+                        baseGroupX, groupCountX, limit, validation_error_map[VALIDATION_ERROR_19e00350]);
+    }
+
+    limit = device_data->device_limits.maxComputeWorkGroupCount[1];
+    if (baseGroupY >= limit) {
+        skip |= log_msg(device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                        HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19e0034c, LayerName,
+                        "vkCmdDispatchBaseKHX(): baseGroupY (%" PRIu32
+                        ") equals or exceeds device limit maxComputeWorkGroupCount[1] (%" PRIu32 "). %s",
+                        baseGroupY, limit, validation_error_map[VALIDATION_ERROR_19e0034c]);
+    } else if (groupCountY > (limit - baseGroupY)) {
+        skip |= log_msg(device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                        HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19e00352, LayerName,
+                        "vkCmdDispatchBaseKHX(): baseGroupY (%" PRIu32 ") + groupCountY (%" PRIu32
+                        ") exceeds device limit maxComputeWorkGroupCount[1] (%" PRIu32 "). %s",
+                        baseGroupY, groupCountY, limit, validation_error_map[VALIDATION_ERROR_19e00352]);
+    }
+
+    limit = device_data->device_limits.maxComputeWorkGroupCount[2];
+    if (baseGroupZ >= limit) {
+        skip |= log_msg(device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                        HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19e0034e, LayerName,
+                        "vkCmdDispatchBaseKHX(): baseGroupZ (%" PRIu32
+                        ") equals or exceeds device limit maxComputeWorkGroupCount[2] (%" PRIu32 "). %s",
+                        baseGroupZ, limit, validation_error_map[VALIDATION_ERROR_19e0034e]);
+    } else if (groupCountZ > (limit - baseGroupZ)) {
+        skip |= log_msg(device_data->report_data, VK_DEBUG_REPORT_ERROR_BIT_EXT, VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT,
+                        HandleToUint64(commandBuffer), __LINE__, VALIDATION_ERROR_19e00354, LayerName,
+                        "vkCmdDispatchBaseKHX(): baseGroupZ (%" PRIu32 ") + groupCountZ (%" PRIu32
+                        ") exceeds device limit maxComputeWorkGroupCount[2] (%" PRIu32 "). %s",
+                        baseGroupZ, groupCountZ, limit, validation_error_map[VALIDATION_ERROR_19e00354]);
+    }
+
+    return skip;
+}
+
VKAPI_ATTR PFN_vkVoidFunction VKAPI_CALL vkGetDeviceProcAddr(VkDevice device, const char *funcName) {
const auto item = name_to_funcptr_map.find(funcName);
if (item != name_to_funcptr_map.end()) {
@@ -2616,6 +2704,8 @@
custom_functions["vkCreateSwapchainKHR"] = (void*)pv_vkCreateSwapchainKHR;
custom_functions["vkQueuePresentKHR"] = (void*)pv_vkQueuePresentKHR;
custom_functions["vkCreateDescriptorPool"] = (void*)pv_vkCreateDescriptorPool;
+ custom_functions["vkCmdDispatch"] = (void*)pv_vkCmdDispatch;
+ custom_functions["vkCmdDispatchBaseKHX"] = (void*)pv_vkCmdDispatchBaseKHX;
}
} // namespace parameter_validation