Refactor MaxPool and ArgMaxPool micro-kernels
- Support input_offset argument in MaxPool and ArgMaxPool micro-kernels
- Use input_offset to make indirection buffer independent of batch size
- Simplify and auto-generate unit tests
- Use more descriptive names for micro-kernel parameters
PiperOrigin-RevId: 281447682
diff --git a/src/max-pooling-nhwc.c b/src/max-pooling-nhwc.c
index 8f13af6..46ceae4 100644
--- a/src/max-pooling-nhwc.c
+++ b/src/max-pooling-nhwc.c
@@ -305,140 +305,19 @@
return status;
}
-enum xnn_status xnn_setup_max_pooling2d_nhwc_u8(
- xnn_operator_t max_pooling_op,
- size_t batch_size,
- size_t input_height,
- size_t input_width,
- const uint8_t* input,
- uint8_t* output,
- pthreadpool_t threadpool)
+static enum xnn_status setup_max_pooling2d(
+ xnn_operator_t max_pooling_op,
+ size_t batch_size,
+ size_t input_height,
+ size_t input_width,
+ const void* input,
+ void* output,
+ uint32_t log2_input_element_size,
+ uint32_t log2_output_element_size,
+ struct maxpool_parameters maxpool[restrict static 1],
+ const void* params,
+ size_t num_threads)
{
- if (max_pooling_op->type != xnn_operator_type_max_pooling_nhwc_u8) {
- xnn_log_error("failed to setup Max Pooling (NHWC, U8) operator: operator type mismatch");
- return xnn_status_invalid_parameter;
- }
- max_pooling_op->state = xnn_run_state_invalid;
-
- if (!xnn_params.initialized) {
- xnn_log_error("failed to setup Max Pooling operator: XNNPACK is not initialized");
- return xnn_status_uninitialized;
- }
-
- if (input_width == 0 || input_height == 0) {
- xnn_log_error(
- "failed to setup Max Pooling operator with %zux%zu input: input dimensions must be non-zero",
- input_width, input_height);
- return xnn_status_invalid_parameter;
- }
-
- if (batch_size == 0) {
- max_pooling_op->state = xnn_run_state_skip;
- return xnn_status_success;
- }
-
- max_pooling_op->batch_size = batch_size;
- max_pooling_op->input_height = input_height;
- max_pooling_op->input_width = input_width;
- max_pooling_op->input = input;
-
- max_pooling_op->output_height = compute_output_dimension(
- max_pooling_op->padding_top + input_height + max_pooling_op->padding_bottom,
- max_pooling_op->kernel_height,
- max_pooling_op->dilation_height,
- max_pooling_op->stride_height);
- max_pooling_op->output_width = compute_output_dimension(
- max_pooling_op->padding_left + input_width + max_pooling_op->padding_right,
- max_pooling_op->kernel_width,
- max_pooling_op->dilation_width,
- max_pooling_op->stride_width);
- max_pooling_op->output = output;
-
- size_t valid_batch_size = 0;
- if (input == max_pooling_op->last_input &&
- input_height == max_pooling_op->last_input_height &&
- input_width == max_pooling_op->last_input_width)
- {
- valid_batch_size = max_pooling_op->valid_batch_size;
- if (batch_size <= valid_batch_size) {
- max_pooling_op->compute.range[0] = batch_size;
- max_pooling_op->state = xnn_run_state_ready;
- return xnn_status_success;
- }
- }
-
- const size_t pooling_height = max_pooling_op->kernel_height;
- const size_t pooling_width = max_pooling_op->kernel_width;
- const size_t pooling_size = pooling_height * pooling_width;
- const size_t output_height = max_pooling_op->output_height;
- const size_t output_width = max_pooling_op->output_width;
- // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
- const uint32_t mr = xnn_params.u8.maxpool.mr;
-
- const size_t step_width =
- max_pooling_op->dilation_width > 1 ? pooling_width : min(max_pooling_op->stride_width, pooling_width);
- const size_t step_height = pooling_size + (output_width * step_width - 1) * pooling_height;
- const size_t indirection_buffer_size = sizeof(void*) * ((mr - 1) + batch_size * output_height * step_height);
-
- const void** indirection_buffer = (const void**) xnn_reallocate_memory(max_pooling_op->indirection_buffer, indirection_buffer_size);
- if (indirection_buffer == NULL) {
- xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
- return xnn_status_out_of_memory;
- }
- max_pooling_op->indirection_buffer = indirection_buffer;
-
- xnn_indirection_init_maxpool2d(max_pooling_op, valid_batch_size, step_height, step_width, 0 /* log2(sizeof(uint8_t)) */);
-
- const uint32_t qr = xnn_params.u8.maxpool.qr;
- const size_t channels = max_pooling_op->channels;
-
- const size_t indirect_input_height_stride = step_height * sizeof(void*);
- const size_t output_width_stride = max_pooling_op->output_pixel_stride * sizeof(uint8_t);
- const size_t output_height_stride = output_width * output_width_stride;
- const size_t multipass_adjustment = round_up(doz(pooling_size, mr), qr) + mr;
-
- max_pooling_op->context.max_pooling = (struct max_pooling_context) {
- .indirect_input = indirection_buffer,
- .indirect_input_batch_stride = output_height * indirect_input_height_stride,
- .indirect_input_height_stride = indirect_input_height_stride,
- .output = output,
- .output_batch_stride = output_height * output_height_stride,
- .output_height_stride = output_height_stride,
- .output_width = output_width,
- .pooling_size = pooling_size,
- .channels = channels,
- .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
- .output_increment = output_width_stride - channels * sizeof(uint8_t),
- .params.u8 = max_pooling_op->u8_output_params,
- .ukernel = xnn_params.u8.maxpool.ukernel,
- };
- max_pooling_op->compute.type = xnn_parallelization_type_2d;
- max_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_max_pooling;
- max_pooling_op->compute.range[0] = batch_size;
- max_pooling_op->compute.range[1] = output_height;
- max_pooling_op->state = xnn_run_state_ready;
-
- max_pooling_op->last_input = input;
- max_pooling_op->last_input_height = input_height;
- max_pooling_op->last_input_width = input_width;
- max_pooling_op->valid_batch_size = max(valid_batch_size, batch_size);
-
- return xnn_status_success;
-}
-
-enum xnn_status xnn_setup_max_pooling2d_nhwc_f32(
- xnn_operator_t max_pooling_op,
- size_t batch_size,
- size_t input_height,
- size_t input_width,
- const float* input,
- float* output,
- pthreadpool_t threadpool)
-{
- if (max_pooling_op->type != xnn_operator_type_max_pooling_nhwc_f32) {
- xnn_log_error("failed to setup Max Pooling (NHWC, F32) operator: operator type mismatch");
- return xnn_status_invalid_parameter;
- }
max_pooling_op->state = xnn_run_state_invalid;
if (!xnn_params.initialized) {
@@ -459,7 +338,6 @@
return xnn_status_success;
}
- max_pooling_op->batch_size = batch_size;
max_pooling_op->input_height = input_height;
max_pooling_op->input_width = input_width;
max_pooling_op->input = input;
@@ -474,76 +352,118 @@
max_pooling_op->kernel_width,
max_pooling_op->dilation_width,
max_pooling_op->stride_width);
- max_pooling_op->output = output;
-
- size_t valid_batch_size = 0;
- if (input == max_pooling_op->last_input &&
- input_height == max_pooling_op->last_input_height &&
- input_width == max_pooling_op->last_input_width)
- {
- valid_batch_size = max_pooling_op->valid_batch_size;
- if (batch_size <= valid_batch_size) {
- max_pooling_op->compute.range[0] = batch_size;
- max_pooling_op->state = xnn_run_state_ready;
- return xnn_status_success;
- }
- }
const size_t pooling_height = max_pooling_op->kernel_height;
const size_t pooling_width = max_pooling_op->kernel_width;
const size_t pooling_size = pooling_height * pooling_width;
const size_t output_height = max_pooling_op->output_height;
const size_t output_width = max_pooling_op->output_width;
- // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
- const uint32_t mr = xnn_params.f32.maxpool.mr;
+ const uint32_t mr = maxpool->mr;
const size_t step_width =
max_pooling_op->dilation_width > 1 ? pooling_width : min(max_pooling_op->stride_width, pooling_width);
const size_t step_height = pooling_size + (output_width * step_width - 1) * pooling_height;
- const size_t indirection_buffer_size = sizeof(void*) * ((mr - 1) + batch_size * output_height * step_height);
- const void** indirection_buffer = (const void**) xnn_reallocate_memory(max_pooling_op->indirection_buffer, indirection_buffer_size);
- if (indirection_buffer == NULL) {
- xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
- return xnn_status_out_of_memory;
+ if (input_height != max_pooling_op->last_input_height ||
+ input_width != max_pooling_op->last_input_width)
+ {
+ // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
+ const size_t indirection_buffer_size = sizeof(void*) * ((mr - 1) + output_height * step_height);
+ const void** indirection_buffer = (const void**) xnn_reallocate_memory(max_pooling_op->indirection_buffer, indirection_buffer_size);
+ if (indirection_buffer == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
+ return xnn_status_out_of_memory;
+ }
+ max_pooling_op->indirection_buffer = indirection_buffer;
+
+ xnn_indirection_init_maxpool2d(max_pooling_op, step_height, step_width, log2_input_element_size);
+
+ max_pooling_op->last_input = input;
+ max_pooling_op->last_input_height = input_height;
+ max_pooling_op->last_input_width = input_width;
}
- max_pooling_op->indirection_buffer = indirection_buffer;
- xnn_indirection_init_maxpool2d(max_pooling_op, valid_batch_size, step_height, step_width, 2 /* log2(sizeof(float)) */);
-
- const uint32_t qr = xnn_params.f32.maxpool.qr;
+ const uint32_t qr = maxpool->qr;
const size_t channels = max_pooling_op->channels;
const size_t indirect_input_height_stride = step_height * sizeof(void*);
- const size_t output_width_stride = max_pooling_op->output_pixel_stride * sizeof(float);
+ const size_t output_width_stride = max_pooling_op->output_pixel_stride << log2_output_element_size;
const size_t output_height_stride = output_width * output_width_stride;
const size_t multipass_adjustment = round_up(doz(pooling_size, mr), qr) + mr;
max_pooling_op->context.max_pooling = (struct max_pooling_context) {
- .indirect_input = indirection_buffer,
- .indirect_input_batch_stride = output_height * indirect_input_height_stride,
- .indirect_input_height_stride = indirect_input_height_stride,
- .output = output,
- .output_batch_stride = output_height * output_height_stride,
- .output_height_stride = output_height_stride,
- .output_width = output_width,
- .pooling_size = pooling_size,
- .channels = channels,
- .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
- .output_increment = output_width_stride - channels * sizeof(float),
- .params.f32 = max_pooling_op->f32_output_params,
- .ukernel = xnn_params.f32.maxpool.ukernel,
+ .indirect_input = max_pooling_op->indirection_buffer,
+ .indirect_input_height_stride = indirect_input_height_stride,
+ .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) max_pooling_op->last_input),
+ .input_batch_stride = (input_height * input_width * max_pooling_op->input_pixel_stride) << log2_input_element_size,
+ .output = output,
+ .output_batch_stride = output_height * output_height_stride,
+ .output_height_stride = output_height_stride,
+ .output_width = output_width,
+ .pooling_size = pooling_size,
+ .channels = channels,
+ .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
+ .output_increment = output_width_stride - (channels << log2_output_element_size),
+ .ukernel = maxpool->ukernel,
};
+ memcpy(&max_pooling_op->context.max_pooling.params, params, sizeof(max_pooling_op->context.max_pooling.params));
+
max_pooling_op->compute.type = xnn_parallelization_type_2d;
max_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_max_pooling;
max_pooling_op->compute.range[0] = batch_size;
max_pooling_op->compute.range[1] = output_height;
max_pooling_op->state = xnn_run_state_ready;
- max_pooling_op->last_input = input;
- max_pooling_op->last_input_height = input_height;
- max_pooling_op->last_input_width = input_width;
- max_pooling_op->valid_batch_size = max(valid_batch_size, batch_size);
-
return xnn_status_success;
}
+
+enum xnn_status xnn_setup_max_pooling2d_nhwc_u8(
+ xnn_operator_t max_pooling_op,
+ size_t batch_size,
+ size_t input_height,
+ size_t input_width,
+ const uint8_t* input,
+ uint8_t* output,
+ pthreadpool_t threadpool)
+{
+ if (max_pooling_op->type != xnn_operator_type_max_pooling_nhwc_u8) {
+ xnn_log_error("failed to setup Max Pooling (NHWC, U8) operator: operator type mismatch");
+ return xnn_status_invalid_parameter;
+ }
+
+ return setup_max_pooling2d(
+ max_pooling_op,
+ batch_size, input_height, input_width,
+ input, output,
+ 0 /* log2(sizeof(input element)) = log2(sizeof(uint8_t)) */,
+ 0 /* log2(sizeof(output element)) = log2(sizeof(uint8_t)) */,
+ &xnn_params.u8.maxpool,
+ &max_pooling_op->u8_output_params,
+ pthreadpool_get_threads_count(threadpool));
+}
+
+enum xnn_status xnn_setup_max_pooling2d_nhwc_f32(
+ xnn_operator_t max_pooling_op,
+ size_t batch_size,
+ size_t input_height,
+ size_t input_width,
+ const float* input,
+ float* output,
+ pthreadpool_t threadpool)
+{
+ if (max_pooling_op->type != xnn_operator_type_max_pooling_nhwc_f32) {
+ xnn_log_error("failed to setup Max Pooling (NHWC, F32) operator: operator type mismatch");
+ return xnn_status_invalid_parameter;
+ }
+
+ return setup_max_pooling2d(
+ max_pooling_op,
+ batch_size, input_height, input_width,
+ input, output,
+ 2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
+ 2 /* log2(sizeof(output element)) = log2(sizeof(float)) */,
+ &xnn_params.f32.maxpool,
+ &max_pooling_op->f32_output_params,
+ pthreadpool_get_threads_count(threadpool));
+}
+