Refactor MaxPool and ArgMaxPool micro-kernels

- Support input_offset argument in MaxPool and ArgMaxPool micro-kernels
- Use input_offset to make the indirection buffer independent of batch size
- Simplify and auto-generate unit tests
- Use more descriptive names for micro-kernel parameters

PiperOrigin-RevId: 281447682
diff --git a/src/max-pooling-nhwc.c b/src/max-pooling-nhwc.c
index 8f13af6..46ceae4 100644
--- a/src/max-pooling-nhwc.c
+++ b/src/max-pooling-nhwc.c
@@ -305,140 +305,19 @@
   return status;
 }
 
-enum xnn_status xnn_setup_max_pooling2d_nhwc_u8(
-    xnn_operator_t max_pooling_op,
-    size_t batch_size,
-    size_t input_height,
-    size_t input_width,
-    const uint8_t* input,
-    uint8_t* output,
-    pthreadpool_t threadpool)
+static enum xnn_status setup_max_pooling2d(
+  xnn_operator_t max_pooling_op,
+  size_t batch_size,
+  size_t input_height,
+  size_t input_width,
+  const void* input,
+  void* output,
+  uint32_t log2_input_element_size,
+  uint32_t log2_output_element_size,
+  struct maxpool_parameters maxpool[restrict static 1],
+  const void* params,
+  size_t num_threads)
 {
-  if (max_pooling_op->type != xnn_operator_type_max_pooling_nhwc_u8) {
-    xnn_log_error("failed to setup Max Pooling (NHWC, U8) operator: operator type mismatch");
-    return xnn_status_invalid_parameter;
-  }
-  max_pooling_op->state = xnn_run_state_invalid;
-
-  if (!xnn_params.initialized) {
-    xnn_log_error("failed to setup Max Pooling operator: XNNPACK is not initialized");
-    return xnn_status_uninitialized;
-  }
-
-  if (input_width == 0 || input_height == 0) {
-    xnn_log_error(
-      "failed to setup Max Pooling operator with %zux%zu input: input dimensions must be non-zero",
-      input_width, input_height);
-    return xnn_status_invalid_parameter;
-  }
-
-  if (batch_size == 0) {
-    max_pooling_op->state = xnn_run_state_skip;
-    return xnn_status_success;
-  }
-
-  max_pooling_op->batch_size = batch_size;
-  max_pooling_op->input_height = input_height;
-  max_pooling_op->input_width = input_width;
-  max_pooling_op->input = input;
-
-  max_pooling_op->output_height = compute_output_dimension(
-      max_pooling_op->padding_top + input_height + max_pooling_op->padding_bottom,
-      max_pooling_op->kernel_height,
-      max_pooling_op->dilation_height,
-      max_pooling_op->stride_height);
-  max_pooling_op->output_width = compute_output_dimension(
-      max_pooling_op->padding_left + input_width + max_pooling_op->padding_right,
-      max_pooling_op->kernel_width,
-      max_pooling_op->dilation_width,
-      max_pooling_op->stride_width);
-  max_pooling_op->output = output;
-
-  size_t valid_batch_size = 0;
-  if (input == max_pooling_op->last_input &&
-      input_height == max_pooling_op->last_input_height &&
-      input_width == max_pooling_op->last_input_width)
-  {
-    valid_batch_size = max_pooling_op->valid_batch_size;
-    if (batch_size <= valid_batch_size) {
-      max_pooling_op->compute.range[0] = batch_size;
-      max_pooling_op->state = xnn_run_state_ready;
-      return xnn_status_success;
-    }
-  }
-
-  const size_t pooling_height = max_pooling_op->kernel_height;
-  const size_t pooling_width = max_pooling_op->kernel_width;
-  const size_t pooling_size = pooling_height * pooling_width;
-  const size_t output_height = max_pooling_op->output_height;
-  const size_t output_width = max_pooling_op->output_width;
-  // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
-  const uint32_t mr = xnn_params.u8.maxpool.mr;
-
-  const size_t step_width =
-    max_pooling_op->dilation_width > 1 ? pooling_width : min(max_pooling_op->stride_width, pooling_width);
-  const size_t step_height = pooling_size + (output_width * step_width - 1) * pooling_height;
-  const size_t indirection_buffer_size = sizeof(void*) * ((mr - 1) + batch_size * output_height * step_height);
-
-  const void** indirection_buffer = (const void**) xnn_reallocate_memory(max_pooling_op->indirection_buffer, indirection_buffer_size);
-  if (indirection_buffer == NULL) {
-    xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
-    return xnn_status_out_of_memory;
-  }
-  max_pooling_op->indirection_buffer = indirection_buffer;
-
-  xnn_indirection_init_maxpool2d(max_pooling_op, valid_batch_size, step_height, step_width, 0 /* log2(sizeof(uint8_t)) */);
-
-  const uint32_t qr = xnn_params.u8.maxpool.qr;
-  const size_t channels = max_pooling_op->channels;
-
-  const size_t indirect_input_height_stride = step_height * sizeof(void*);
-  const size_t output_width_stride = max_pooling_op->output_pixel_stride * sizeof(uint8_t);
-  const size_t output_height_stride = output_width * output_width_stride;
-  const size_t multipass_adjustment = round_up(doz(pooling_size, mr), qr) + mr;
-
-  max_pooling_op->context.max_pooling = (struct max_pooling_context) {
-      .indirect_input = indirection_buffer,
-      .indirect_input_batch_stride = output_height * indirect_input_height_stride,
-      .indirect_input_height_stride = indirect_input_height_stride,
-      .output = output,
-      .output_batch_stride = output_height * output_height_stride,
-      .output_height_stride = output_height_stride,
-      .output_width = output_width,
-      .pooling_size = pooling_size,
-      .channels = channels,
-      .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
-      .output_increment = output_width_stride - channels * sizeof(uint8_t),
-      .params.u8 = max_pooling_op->u8_output_params,
-      .ukernel = xnn_params.u8.maxpool.ukernel,
-  };
-  max_pooling_op->compute.type = xnn_parallelization_type_2d;
-  max_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_max_pooling;
-  max_pooling_op->compute.range[0] = batch_size;
-  max_pooling_op->compute.range[1] = output_height;
-  max_pooling_op->state = xnn_run_state_ready;
-
-  max_pooling_op->last_input = input;
-  max_pooling_op->last_input_height = input_height;
-  max_pooling_op->last_input_width = input_width;
-  max_pooling_op->valid_batch_size = max(valid_batch_size, batch_size);
-
-  return xnn_status_success;
-}
-
-enum xnn_status xnn_setup_max_pooling2d_nhwc_f32(
-    xnn_operator_t max_pooling_op,
-    size_t batch_size,
-    size_t input_height,
-    size_t input_width,
-    const float* input,
-    float* output,
-    pthreadpool_t threadpool)
-{
-  if (max_pooling_op->type != xnn_operator_type_max_pooling_nhwc_f32) {
-    xnn_log_error("failed to setup Max Pooling (NHWC, F32) operator: operator type mismatch");
-    return xnn_status_invalid_parameter;
-  }
   max_pooling_op->state = xnn_run_state_invalid;
 
   if (!xnn_params.initialized) {
@@ -459,7 +338,6 @@
     return xnn_status_success;
   }
 
-  max_pooling_op->batch_size = batch_size;
   max_pooling_op->input_height = input_height;
   max_pooling_op->input_width = input_width;
   max_pooling_op->input = input;
@@ -474,76 +352,118 @@
       max_pooling_op->kernel_width,
       max_pooling_op->dilation_width,
       max_pooling_op->stride_width);
-  max_pooling_op->output = output;
-
-  size_t valid_batch_size = 0;
-  if (input == max_pooling_op->last_input &&
-      input_height == max_pooling_op->last_input_height &&
-      input_width == max_pooling_op->last_input_width)
-  {
-    valid_batch_size = max_pooling_op->valid_batch_size;
-    if (batch_size <= valid_batch_size) {
-      max_pooling_op->compute.range[0] = batch_size;
-      max_pooling_op->state = xnn_run_state_ready;
-      return xnn_status_success;
-    }
-  }
 
   const size_t pooling_height = max_pooling_op->kernel_height;
   const size_t pooling_width = max_pooling_op->kernel_width;
   const size_t pooling_size = pooling_height * pooling_width;
   const size_t output_height = max_pooling_op->output_height;
   const size_t output_width = max_pooling_op->output_width;
-  // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
-  const uint32_t mr = xnn_params.f32.maxpool.mr;
+  const uint32_t mr = maxpool->mr;
 
   const size_t step_width =
     max_pooling_op->dilation_width > 1 ? pooling_width : min(max_pooling_op->stride_width, pooling_width);
   const size_t step_height = pooling_size + (output_width * step_width - 1) * pooling_height;
-  const size_t indirection_buffer_size = sizeof(void*) * ((mr - 1) + batch_size * output_height * step_height);
 
-  const void** indirection_buffer = (const void**) xnn_reallocate_memory(max_pooling_op->indirection_buffer, indirection_buffer_size);
-  if (indirection_buffer == NULL) {
-    xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
-    return xnn_status_out_of_memory;
+  if (input_height != max_pooling_op->last_input_height ||
+      input_width != max_pooling_op->last_input_width)
+  {
+    // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
+    const size_t indirection_buffer_size = sizeof(void*) * ((mr - 1) + output_height * step_height);
+    const void** indirection_buffer = (const void**) xnn_reallocate_memory(max_pooling_op->indirection_buffer, indirection_buffer_size);
+    if (indirection_buffer == NULL) {
+      xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
+      return xnn_status_out_of_memory;
+    }
+    max_pooling_op->indirection_buffer = indirection_buffer;
+
+    xnn_indirection_init_maxpool2d(max_pooling_op, step_height, step_width, log2_input_element_size);
+
+    max_pooling_op->last_input = input;
+    max_pooling_op->last_input_height = input_height;
+    max_pooling_op->last_input_width = input_width;
   }
-  max_pooling_op->indirection_buffer = indirection_buffer;
 
-  xnn_indirection_init_maxpool2d(max_pooling_op, valid_batch_size, step_height, step_width, 2 /* log2(sizeof(float)) */);
-
-  const uint32_t qr = xnn_params.f32.maxpool.qr;
+  const uint32_t qr = maxpool->qr;
   const size_t channels = max_pooling_op->channels;
 
   const size_t indirect_input_height_stride = step_height * sizeof(void*);
-  const size_t output_width_stride = max_pooling_op->output_pixel_stride * sizeof(float);
+  const size_t output_width_stride = max_pooling_op->output_pixel_stride << log2_output_element_size;
   const size_t output_height_stride = output_width * output_width_stride;
   const size_t multipass_adjustment = round_up(doz(pooling_size, mr), qr) + mr;
 
   max_pooling_op->context.max_pooling = (struct max_pooling_context) {
-      .indirect_input = indirection_buffer,
-      .indirect_input_batch_stride = output_height * indirect_input_height_stride,
-      .indirect_input_height_stride = indirect_input_height_stride,
-      .output = output,
-      .output_batch_stride = output_height * output_height_stride,
-      .output_height_stride = output_height_stride,
-      .output_width = output_width,
-      .pooling_size = pooling_size,
-      .channels = channels,
-      .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
-      .output_increment = output_width_stride - channels * sizeof(float),
-      .params.f32 = max_pooling_op->f32_output_params,
-      .ukernel = xnn_params.f32.maxpool.ukernel,
+    .indirect_input = max_pooling_op->indirection_buffer,
+    .indirect_input_height_stride = indirect_input_height_stride,
+    .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) max_pooling_op->last_input),
+    .input_batch_stride = (input_height * input_width * max_pooling_op->input_pixel_stride) << log2_input_element_size,
+    .output = output,
+    .output_batch_stride = output_height * output_height_stride,
+    .output_height_stride = output_height_stride,
+    .output_width = output_width,
+    .pooling_size = pooling_size,
+    .channels = channels,
+    .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
+    .output_increment = output_width_stride - (channels << log2_output_element_size),
+    .ukernel = maxpool->ukernel,
   };
+  memcpy(&max_pooling_op->context.max_pooling.params, params, sizeof(max_pooling_op->context.max_pooling.params));
+
   max_pooling_op->compute.type = xnn_parallelization_type_2d;
   max_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_max_pooling;
   max_pooling_op->compute.range[0] = batch_size;
   max_pooling_op->compute.range[1] = output_height;
   max_pooling_op->state = xnn_run_state_ready;
 
-  max_pooling_op->last_input = input;
-  max_pooling_op->last_input_height = input_height;
-  max_pooling_op->last_input_width = input_width;
-  max_pooling_op->valid_batch_size = max(valid_batch_size, batch_size);
-
   return xnn_status_success;
 }
+
+enum xnn_status xnn_setup_max_pooling2d_nhwc_u8(  // Sets up an NHWC U8 max pooling operator through the shared setup_max_pooling2d() routine.
+    xnn_operator_t max_pooling_op,
+    size_t batch_size,
+    size_t input_height,
+    size_t input_width,
+    const uint8_t* input,
+    uint8_t* output,
+    pthreadpool_t threadpool)  // only queried here for its thread count
+{
+  if (max_pooling_op->type != xnn_operator_type_max_pooling_nhwc_u8) {  // operator type must match this U8 entry point
+    xnn_log_error("failed to setup Max Pooling (NHWC, U8) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+
+  return setup_max_pooling2d(  // status of the shared routine is returned as-is
+    max_pooling_op,
+    batch_size, input_height, input_width, // forwarded unchanged
+    input, output,
+    0 /* log2(sizeof(input element)) = log2(sizeof(uint8_t)) */,
+    0 /* log2(sizeof(output element)) = log2(sizeof(uint8_t)) */,
+    &xnn_params.u8.maxpool,  // U8 micro-kernel descriptor; the shared routine reads mr, qr and ukernel from it
+    &max_pooling_op->u8_output_params,  // copied by the shared routine into the compute context's params
+    pthreadpool_get_threads_count(threadpool));  // passed as the num_threads argument
+}
+
+enum xnn_status xnn_setup_max_pooling2d_nhwc_f32(  // Sets up an NHWC F32 max pooling operator through the shared setup_max_pooling2d() routine.
+    xnn_operator_t max_pooling_op,
+    size_t batch_size,
+    size_t input_height,
+    size_t input_width,
+    const float* input,
+    float* output,
+    pthreadpool_t threadpool)  // only queried here for its thread count
+{
+  if (max_pooling_op->type != xnn_operator_type_max_pooling_nhwc_f32) {  // operator type must match this F32 entry point
+    xnn_log_error("failed to setup Max Pooling (NHWC, F32) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+
+  return setup_max_pooling2d(  // status of the shared routine is returned as-is
+    max_pooling_op,
+    batch_size, input_height, input_width, // forwarded unchanged
+    input, output,
+    2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
+    2 /* log2(sizeof(output element)) = log2(sizeof(float)) */,
+    &xnn_params.f32.maxpool,  // F32 micro-kernel descriptor; the shared routine reads mr, qr and ukernel from it
+    &max_pooling_op->f32_output_params,  // copied by the shared routine into the compute context's params
+    pthreadpool_get_threads_count(threadpool));  // passed as the num_threads argument
+}
+