Refactor Average Pooling setup

Extract datatype-independent setup function

PiperOrigin-RevId: 300310995
diff --git a/src/average-pooling-nhwc.c b/src/average-pooling-nhwc.c
index 69756af..e586b19 100644
--- a/src/average-pooling-nhwc.c
+++ b/src/average-pooling-nhwc.c
@@ -389,19 +389,24 @@
   return status;
 }
 
-enum xnn_status xnn_setup_average_pooling2d_nhwc_q8(
-    xnn_operator_t average_pooling_op,
-    size_t batch_size,
-    size_t input_height,
-    size_t input_width,
-    const uint8_t* input,
-    uint8_t* output,
-    pthreadpool_t threadpool)
+static enum xnn_status setup_average_pooling2d(
+  xnn_operator_t average_pooling_op,
+  size_t batch_size,
+  size_t input_height,
+  size_t input_width,
+  const void* input,
+  void* output,
+  uint32_t log2_input_element_size,
+  uint32_t log2_output_element_size,
+  struct avgpool_parameters avgpool[restrict static 1],
+  struct pavgpool_parameters pavgpool[restrict 1],
+  const void* params,
+  size_t params_size,
+  size_t num_threads,
+  bool is_pixelwise)
 {
-  if (average_pooling_op->type != xnn_operator_type_average_pooling_nhwc_q8) {
-    xnn_log_error("failed to setup Average Pooling (Q8) operator: operator type mismatch");
-    return xnn_status_invalid_parameter;
-  }
+  assert(!is_pixelwise || pavgpool != NULL);
+
   average_pooling_op->state = xnn_run_state_invalid;
 
   if (!xnn_params.initialized) {
@@ -458,8 +463,8 @@
   const size_t pooling_size = pooling_height * pooling_width;
   const size_t output_height = average_pooling_op->output_height;
   const size_t output_width = average_pooling_op->output_width;
-  // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
-  const uint32_t mr = xnn_params.q8.avgpool.mr;
+
+  const uint32_t mr = is_pixelwise ? pavgpool->mr : avgpool->mr;
 
   const size_t step_width = min(average_pooling_op->stride_width, pooling_width);
   const size_t step_height = pooling_size + (output_width - 1) * step_width * pooling_height;
@@ -480,43 +485,101 @@
     // Indirection buffer always setup for batch size 1, larger batch size supported through input_offset argument
     average_pooling_op->batch_size = 1;
     xnn_indirection_init_dwconv2d(
-      average_pooling_op, 0, step_height, step_width, 0 /* log2(sizeof(uint8_t)) */);
+      average_pooling_op, 0, step_height, step_width, log2_input_element_size);
 
     average_pooling_op->last_input = input;
     average_pooling_op->last_input_height = input_height;
     average_pooling_op->last_input_width = input_width;
   }
 
-  const uint32_t qr = xnn_params.q8.avgpool.qr;
   const size_t channels = average_pooling_op->channels;
 
-  const size_t output_width_stride = average_pooling_op->output_pixel_stride * sizeof(uint8_t);
+  const size_t indirect_input_height_stride = step_height * sizeof(void*);
+  const size_t output_width_stride = average_pooling_op->output_pixel_stride << log2_output_element_size;
   const size_t output_height_stride = output_width * output_width_stride;
 
-  const size_t multipass_adjustment =
-    pooling_size > mr ? round_up(pooling_size - mr, qr) + mr - qr : 0;
-  average_pooling_op->context.average_pooling = (struct average_pooling_context) {
-    .indirect_input = average_pooling_op->indirection_buffer,
-    .indirect_input_height_stride = step_height * sizeof(void*),
-    .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) average_pooling_op->last_input),
-    .input_batch_stride = input_height * input_width * average_pooling_op->input_pixel_stride * sizeof(uint8_t),
-    .output = output,
-    .output_batch_stride = output_height * output_height_stride,
-    .output_height_stride = output_height_stride,
-    .output_width = output_width,
-    .pooling_size = pooling_size,
-    .channels = channels,
-    .zero = average_pooling_op->zero_buffer,
-    .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
-    .output_increment = output_width_stride - channels * sizeof(uint8_t),
-    .params.q8 = average_pooling_op->q8_avgpool_params,
-  };
-  if (pooling_size <= mr) {
-    average_pooling_op->context.average_pooling.unipass_ukernel = xnn_params.q8.avgpool.up;
-    average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_unipass;
+  if (is_pixelwise) {
+    if (input_height != last_input_height || input_width != last_input_width) {
+      const size_t pixelwise_buffer_size = output_height * output_width * sizeof(float);
+      float* pixelwise_buffer = (float*) xnn_reallocate_memory(average_pooling_op->pixelwise_buffer, pixelwise_buffer_size);
+      if (pixelwise_buffer == NULL) {
+        xnn_log_error("failed to allocate %zu bytes for pixelwise buffer", pixelwise_buffer_size);
+        return xnn_status_out_of_memory;
+      }
+      average_pooling_op->pixelwise_buffer = pixelwise_buffer;
+
+      float* pixelwise_pointer = pixelwise_buffer;
+      for (size_t output_y = 0; output_y < output_height; output_y++) {
+        const size_t input_y_start = doz(output_y * average_pooling_op->stride_height, average_pooling_op->padding_top);
+        const size_t input_y_end =
+          min(doz(output_y * average_pooling_op->stride_height + average_pooling_op->kernel_height, average_pooling_op->padding_top), input_height);
+        const uint32_t input_y_range = (uint32_t) (input_y_end - input_y_start);
+        for (size_t output_x = 0; output_x < output_width; output_x++) {
+          const size_t input_x_start = doz(output_x * average_pooling_op->stride_width, average_pooling_op->padding_left);
+          const size_t input_x_end =
+            min(doz(output_x * average_pooling_op->stride_width + average_pooling_op->kernel_width, average_pooling_op->padding_left), input_width);
+          const uint32_t input_x_range = (uint32_t) (input_x_end - input_x_start);
+          *pixelwise_pointer++ = 1.0f / ((float) (int32_t) (input_y_range * input_x_range));
+        }
+      }
+    }
+
+    const uint32_t qr = pavgpool->qr;
+    const size_t multipass_adjustment =
+      pooling_size > mr ? round_up(pooling_size - mr, qr) + mr - qr : 0;
+    average_pooling_op->context.pixelwise_average_pooling = (struct pixelwise_average_pooling_context) {
+      .indirect_input = average_pooling_op->indirection_buffer,
+      .indirect_input_height_stride = indirect_input_height_stride,
+      .input_batch_stride = input_height * input_width * average_pooling_op->input_pixel_stride << log2_input_element_size,
+      .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) average_pooling_op->last_input),
+      .pixelwise_buffer = average_pooling_op->pixelwise_buffer,
+      .pixelwise_buffer_height_stride = output_width * sizeof(float),
+      .output = output,
+      .output_batch_stride = output_height * output_height_stride,
+      .output_height_stride = output_height_stride,
+      .output_width = output_width,
+      .pooling_size = pooling_size,
+      .channels = channels,
+      .zero = average_pooling_op->zero_buffer,
+      .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
+      .output_increment = output_width_stride - (channels << log2_output_element_size),
+    };
+    memcpy(&average_pooling_op->context.pixelwise_average_pooling.params, params, params_size);
+    if (pooling_size <= mr) {
+      average_pooling_op->context.pixelwise_average_pooling.unipass_ukernel = pavgpool->up;
+      average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_pixelwise_average_pooling_unipass;
+    } else {
+      average_pooling_op->context.pixelwise_average_pooling.multipass_ukernel = pavgpool->mp;
+      average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_pixelwise_average_pooling_multipass;
+    }
   } else {
-    average_pooling_op->context.average_pooling.multipass_ukernel = xnn_params.q8.avgpool.mp;
-    average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_multipass;
+    const uint32_t qr = avgpool->qr;
+    const size_t multipass_adjustment =
+      pooling_size > mr ? round_up(pooling_size - mr, qr) + mr - qr : 0;
+    average_pooling_op->context.average_pooling = (struct average_pooling_context) {
+      .indirect_input = average_pooling_op->indirection_buffer,
+      .indirect_input_height_stride = indirect_input_height_stride,
+      .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) average_pooling_op->last_input),
+      .input_batch_stride = input_height * input_width * average_pooling_op->input_pixel_stride << log2_input_element_size,
+      .output = output,
+      .output_batch_stride = output_height * output_height_stride,
+      .output_height_stride = output_height_stride,
+      .output_width = output_width,
+      .pooling_size = pooling_size,
+      .channels = channels,
+      .zero = average_pooling_op->zero_buffer,
+      .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
+      .output_increment = output_width_stride - (channels << log2_output_element_size),
+      // .params is filled below via memcpy from the caller-supplied params.
+    };
+    memcpy(&average_pooling_op->context.average_pooling.params, params, params_size);
+    if (pooling_size <= mr) {
+      average_pooling_op->context.average_pooling.unipass_ukernel = avgpool->up;
+      average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_unipass;
+    } else {
+      average_pooling_op->context.average_pooling.multipass_ukernel = avgpool->mp;
+      average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_multipass;
+    }
   }
   average_pooling_op->compute.type = xnn_parallelization_type_2d;
   average_pooling_op->compute.range[0] = batch_size;
@@ -526,6 +589,36 @@
   return xnn_status_success;
 }
 
+enum xnn_status xnn_setup_average_pooling2d_nhwc_q8(
+    xnn_operator_t average_pooling_op,
+    size_t batch_size,
+    size_t input_height,
+    size_t input_width,
+    const uint8_t* input,
+    uint8_t* output,
+    pthreadpool_t threadpool)
+{
+  if (average_pooling_op->type != xnn_operator_type_average_pooling_nhwc_q8) {
+    xnn_log_error("failed to setup Average Pooling (Q8) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+
+  assert(average_pooling_op->ukernel.type == xnn_ukernel_type_average_pooling);
+
+  return setup_average_pooling2d(
+    average_pooling_op,
+    batch_size, input_height, input_width,
+    input, output,
+    0 /* log2(sizeof(input element)) = log2(sizeof(uint8_t)) */,
+    0 /* log2(sizeof(output element)) = log2(sizeof(uint8_t)) */,
+    &xnn_params.q8.avgpool,
+    NULL /* pavgpool */,
+    &average_pooling_op->q8_avgpool_params,
+    sizeof(average_pooling_op->q8_avgpool_params),
+    pthreadpool_get_threads_count(threadpool),
+    false /* pixelwise not supported */);
+}
+
 enum xnn_status xnn_setup_average_pooling2d_nhwc_f32(
     xnn_operator_t average_pooling_op,
     size_t batch_size,
@@ -539,193 +632,20 @@
     xnn_log_error("failed to setup Average Pooling (F32) operator: operator type mismatch");
     return xnn_status_invalid_parameter;
   }
-  average_pooling_op->state = xnn_run_state_invalid;
 
-  if (!xnn_params.initialized) {
-    xnn_log_error("failed to setup Average Pooling operator: XNNPACK is not initialized");
-    return xnn_status_uninitialized;
-  }
+  assert(average_pooling_op->ukernel.type == xnn_ukernel_type_average_pooling ||
+         average_pooling_op->ukernel.type == xnn_ukernel_type_pixelwise_average_pooling);
 
-  if (input_width == 0 || input_height == 0) {
-    xnn_log_error(
-      "failed to setup Average Pooling operator with %zux%zu input: input dimensions must be non-zero",
-      input_width, input_height);
-    return xnn_status_invalid_parameter;
-  }
-
-  if (batch_size == 0) {
-    average_pooling_op->state = xnn_run_state_skip;
-    return xnn_status_success;
-  }
-
-  average_pooling_op->input_height = input_height;
-  average_pooling_op->input_width = input_width;
-  average_pooling_op->input = input;
-
-  if (average_pooling_op->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
-    average_pooling_op->output_height = compute_output_dimension_with_tf_same_padding(
-        input_height, average_pooling_op->stride_height);
-    average_pooling_op->output_width = compute_output_dimension_with_tf_same_padding(
-        input_width, average_pooling_op->stride_width);
-
-    const uint32_t effective_kernel_height = (average_pooling_op->kernel_height - 1) * average_pooling_op->dilation_height + 1;
-    const uint32_t effective_kernel_width = (average_pooling_op->kernel_width - 1) * average_pooling_op->dilation_width + 1;
-    const uint32_t total_padding_height =
-      (average_pooling_op->output_height - 1) * average_pooling_op->stride_height + effective_kernel_height - input_height;
-    const uint32_t total_padding_width =
-      (average_pooling_op->output_width - 1) * average_pooling_op->stride_width + effective_kernel_width - input_width;
-    average_pooling_op->padding_top = total_padding_height / 2;
-    average_pooling_op->padding_left = total_padding_width / 2;
-    average_pooling_op->padding_bottom = total_padding_height - average_pooling_op->padding_top;
-    average_pooling_op->padding_right = total_padding_width - average_pooling_op->padding_left;
-  } else {
-    average_pooling_op->output_height = compute_output_dimension(
-        average_pooling_op->padding_top + input_height + average_pooling_op->padding_bottom,
-        average_pooling_op->kernel_height,
-        average_pooling_op->stride_height);
-    average_pooling_op->output_width = compute_output_dimension(
-        average_pooling_op->padding_left + input_width + average_pooling_op->padding_right,
-        average_pooling_op->kernel_width,
-        average_pooling_op->stride_width);
-  }
-  average_pooling_op->output = output;
-
-  const size_t pooling_height = average_pooling_op->kernel_height;
-  const size_t pooling_width = average_pooling_op->kernel_width;
-  const size_t pooling_size = pooling_height * pooling_width;
-  const size_t output_height = average_pooling_op->output_height;
-  const size_t output_width = average_pooling_op->output_width;
-  // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
-  const uint32_t mr = xnn_params.f32.avgpool.mr;
-  assert(mr == xnn_params.f32.pavgpool.mr);
-
-  const size_t step_width = min(average_pooling_op->stride_width, pooling_width);
-  const size_t step_height = pooling_size + (output_width - 1) * step_width * pooling_height;
-
-  const size_t last_input_height = average_pooling_op->last_input_height;
-  const size_t last_input_width = average_pooling_op->last_input_width;
-  if (input_height != last_input_height || input_width != last_input_width) {
-    // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
-    const size_t indirection_buffer_size = sizeof(void*) * ((mr - 1) + batch_size * output_height * step_height);
-
-    const void** indirection_buffer = (const void**) xnn_reallocate_memory(average_pooling_op->indirection_buffer, indirection_buffer_size);
-    if (indirection_buffer == NULL) {
-      xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
-      return xnn_status_out_of_memory;
-    }
-    average_pooling_op->indirection_buffer = indirection_buffer;
-
-    // Indirection buffer always setup for batch size 1, larger batch size supported through input_offset argument
-    average_pooling_op->batch_size = 1;
-    xnn_indirection_init_dwconv2d(
-      average_pooling_op, 0, step_height, step_width, 2 /* log2(sizeof(float)) */);
-
-    average_pooling_op->last_input = input;
-    average_pooling_op->last_input_height = input_height;
-    average_pooling_op->last_input_width = input_width;
-  }
-
-  const size_t channels = average_pooling_op->channels;
-
-  const size_t indirect_input_height_stride = step_height * sizeof(void*);
-  const size_t output_width_stride = average_pooling_op->output_pixel_stride * sizeof(float);
-  const size_t output_height_stride = output_width * output_width_stride;
-
-  switch (average_pooling_op->ukernel.type) {
-    case xnn_ukernel_type_average_pooling:
-    {
-      const uint32_t qr = xnn_params.f32.avgpool.qr;
-      const size_t multipass_adjustment =
-        pooling_size > mr ? round_up(pooling_size - mr, qr) + mr - qr : 0;
-      average_pooling_op->context.average_pooling = (struct average_pooling_context) {
-        .indirect_input = average_pooling_op->indirection_buffer,
-        .indirect_input_height_stride = indirect_input_height_stride,
-        .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) average_pooling_op->last_input),
-        .input_batch_stride = input_height * input_width * average_pooling_op->input_pixel_stride * sizeof(float),
-        .output = output,
-        .output_batch_stride = output_height * output_height_stride,
-        .output_height_stride = output_height_stride,
-        .output_width = output_width,
-        .pooling_size = pooling_size,
-        .channels = channels,
-        .zero = average_pooling_op->zero_buffer,
-        .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
-        .output_increment = output_width_stride - channels * sizeof(float),
-        .params.f32 = average_pooling_op->f32_avgpool_params,
-      };
-      if (pooling_size <= mr) {
-        average_pooling_op->context.average_pooling.unipass_ukernel = xnn_params.f32.avgpool.up;
-        average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_unipass;
-      } else {
-        average_pooling_op->context.average_pooling.multipass_ukernel = xnn_params.f32.avgpool.mp;
-        average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_multipass;
-      }
-      break;
-    }
-    case xnn_ukernel_type_pixelwise_average_pooling:
-    {
-      if (input_height != last_input_height || input_width != last_input_width) {
-        const size_t pixelwise_buffer_size = output_height * output_width * sizeof(float);
-        float* pixelwise_buffer = (float*) xnn_reallocate_memory(average_pooling_op->pixelwise_buffer, pixelwise_buffer_size);
-        if (pixelwise_buffer == NULL) {
-          xnn_log_error("failed to allocate %zu bytes for pixelwise buffer", pixelwise_buffer_size);
-          return xnn_status_out_of_memory;
-        }
-        average_pooling_op->pixelwise_buffer = pixelwise_buffer;
-
-        float* pixelwise_pointer = pixelwise_buffer;
-        for (size_t output_y = 0; output_y < output_height; output_y++) {
-          const size_t input_y_start = doz(output_y * average_pooling_op->stride_height, average_pooling_op->padding_top);
-          const size_t input_y_end =
-            min(doz(output_y * average_pooling_op->stride_height + average_pooling_op->kernel_height, average_pooling_op->padding_top), input_height);
-          const uint32_t input_y_range = (uint32_t) (input_y_end - input_y_start);
-          for (size_t output_x = 0; output_x < output_width; output_x++) {
-            const size_t input_x_start = doz(output_x * average_pooling_op->stride_width, average_pooling_op->padding_left);
-            const size_t input_x_end =
-              min(doz(output_x * average_pooling_op->stride_width + average_pooling_op->kernel_width, average_pooling_op->padding_left), input_width);
-            const uint32_t input_x_range = (uint32_t) (input_x_end - input_x_start);
-            *pixelwise_pointer++ = 1.0f / ((float) (int32_t) (input_y_range * input_x_range));
-          }
-        }
-      }
-
-      const uint32_t qr = xnn_params.f32.pavgpool.qr;
-      const size_t multipass_adjustment =
-        pooling_size > mr ? round_up(pooling_size - mr, qr) + mr - qr : 0;
-      average_pooling_op->context.pixelwise_average_pooling = (struct pixelwise_average_pooling_context) {
-        .indirect_input = average_pooling_op->indirection_buffer,
-        .indirect_input_height_stride = indirect_input_height_stride,
-        .input_batch_stride = input_height * input_width * average_pooling_op->input_pixel_stride * sizeof(float),
-        .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) average_pooling_op->last_input),
-        .pixelwise_buffer = average_pooling_op->pixelwise_buffer,
-        .pixelwise_buffer_height_stride = output_width * sizeof(float),
-        .output = output,
-        .output_batch_stride = output_height * output_height_stride,
-        .output_height_stride = output_height_stride,
-        .output_width = output_width,
-        .pooling_size = pooling_size,
-        .channels = channels,
-        .zero = average_pooling_op->zero_buffer,
-        .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
-        .output_increment = output_width_stride - channels * sizeof(float),
-        .params.f32 = average_pooling_op->f32_output_params,
-      };
-      if (pooling_size <= mr) {
-        average_pooling_op->context.pixelwise_average_pooling.unipass_ukernel = xnn_params.f32.pavgpool.up;
-        average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_pixelwise_average_pooling_unipass;
-      } else {
-        average_pooling_op->context.pixelwise_average_pooling.multipass_ukernel = xnn_params.f32.pavgpool.mp;
-        average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_pixelwise_average_pooling_multipass;
-      }
-      break;
-    }
-    default:
-      XNN_UNREACHABLE;
-  }
-  average_pooling_op->compute.type = xnn_parallelization_type_2d;
-  average_pooling_op->compute.range[0] = batch_size;
-  average_pooling_op->compute.range[1] = output_height;
-  average_pooling_op->state = xnn_run_state_ready;
-
-  return xnn_status_success;
+  return setup_average_pooling2d(
+    average_pooling_op,
+    batch_size, input_height, input_width,
+    input, output,
+    2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
+    2 /* log2(sizeof(output element)) = log2(sizeof(float)) */,
+    &xnn_params.f32.avgpool,
+    &xnn_params.f32.pavgpool,
+    average_pooling_op->ukernel.type == xnn_ukernel_type_pixelwise_average_pooling ? (const void*) &average_pooling_op->f32_output_params : (const void*) &average_pooling_op->f32_avgpool_params,
+    average_pooling_op->ukernel.type == xnn_ukernel_type_pixelwise_average_pooling ? sizeof(average_pooling_op->f32_output_params) : sizeof(average_pooling_op->f32_avgpool_params),
+    pthreadpool_get_threads_count(threadpool),
+    average_pooling_op->ukernel.type == xnn_ukernel_type_pixelwise_average_pooling);
 }