Refactor Average Pooling setup
Extract datatype-independent setup function
PiperOrigin-RevId: 300310995
diff --git a/src/average-pooling-nhwc.c b/src/average-pooling-nhwc.c
index 69756af..e586b19 100644
--- a/src/average-pooling-nhwc.c
+++ b/src/average-pooling-nhwc.c
@@ -389,19 +389,24 @@
return status;
}
-enum xnn_status xnn_setup_average_pooling2d_nhwc_q8(
- xnn_operator_t average_pooling_op,
- size_t batch_size,
- size_t input_height,
- size_t input_width,
- const uint8_t* input,
- uint8_t* output,
- pthreadpool_t threadpool)
+static enum xnn_status setup_average_pooling2d(
+ xnn_operator_t average_pooling_op,
+ size_t batch_size,
+ size_t input_height,
+ size_t input_width,
+ const void* input,
+ void* output,
+ uint32_t log2_input_element_size,
+ uint32_t log2_output_element_size,
+ struct avgpool_parameters avgpool[restrict static 1],
+ struct pavgpool_parameters pavgpool[restrict 1],
+ const void* params,
+ size_t params_size,
+ size_t num_threads,
+ bool is_pixelwise)
{
- if (average_pooling_op->type != xnn_operator_type_average_pooling_nhwc_q8) {
- xnn_log_error("failed to setup Average Pooling (Q8) operator: operator type mismatch");
- return xnn_status_invalid_parameter;
- }
+ assert(!is_pixelwise || pavgpool != NULL);
+
average_pooling_op->state = xnn_run_state_invalid;
if (!xnn_params.initialized) {
@@ -458,8 +463,8 @@
const size_t pooling_size = pooling_height * pooling_width;
const size_t output_height = average_pooling_op->output_height;
const size_t output_width = average_pooling_op->output_width;
- // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
- const uint32_t mr = xnn_params.q8.avgpool.mr;
+
+ const uint32_t mr = is_pixelwise ? pavgpool->mr : avgpool->mr;
const size_t step_width = min(average_pooling_op->stride_width, pooling_width);
const size_t step_height = pooling_size + (output_width - 1) * step_width * pooling_height;
@@ -480,43 +485,101 @@
// Indirection buffer always setup for batch size 1, larger batch size supported through input_offset argument
average_pooling_op->batch_size = 1;
xnn_indirection_init_dwconv2d(
- average_pooling_op, 0, step_height, step_width, 0 /* log2(sizeof(uint8_t)) */);
+ average_pooling_op, 0, step_height, step_width, log2_input_element_size);
average_pooling_op->last_input = input;
average_pooling_op->last_input_height = input_height;
average_pooling_op->last_input_width = input_width;
}
- const uint32_t qr = xnn_params.q8.avgpool.qr;
const size_t channels = average_pooling_op->channels;
- const size_t output_width_stride = average_pooling_op->output_pixel_stride * sizeof(uint8_t);
+ const size_t indirect_input_height_stride = step_height * sizeof(void*);
+ const size_t output_width_stride = average_pooling_op->output_pixel_stride << log2_output_element_size;
const size_t output_height_stride = output_width * output_width_stride;
- const size_t multipass_adjustment =
- pooling_size > mr ? round_up(pooling_size - mr, qr) + mr - qr : 0;
- average_pooling_op->context.average_pooling = (struct average_pooling_context) {
- .indirect_input = average_pooling_op->indirection_buffer,
- .indirect_input_height_stride = step_height * sizeof(void*),
- .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) average_pooling_op->last_input),
- .input_batch_stride = input_height * input_width * average_pooling_op->input_pixel_stride * sizeof(uint8_t),
- .output = output,
- .output_batch_stride = output_height * output_height_stride,
- .output_height_stride = output_height_stride,
- .output_width = output_width,
- .pooling_size = pooling_size,
- .channels = channels,
- .zero = average_pooling_op->zero_buffer,
- .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
- .output_increment = output_width_stride - channels * sizeof(uint8_t),
- .params.q8 = average_pooling_op->q8_avgpool_params,
- };
- if (pooling_size <= mr) {
- average_pooling_op->context.average_pooling.unipass_ukernel = xnn_params.q8.avgpool.up;
- average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_unipass;
+ if (is_pixelwise) {
+ if (input_height != last_input_height || input_width != last_input_width) {
+ const size_t pixelwise_buffer_size = output_height * output_width * sizeof(float);
+ float* pixelwise_buffer = (float*) xnn_reallocate_memory(average_pooling_op->pixelwise_buffer, pixelwise_buffer_size);
+ if (pixelwise_buffer == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for pixelwise buffer", pixelwise_buffer_size);
+ return xnn_status_out_of_memory;
+ }
+ average_pooling_op->pixelwise_buffer = pixelwise_buffer;
+
+ float* pixelwise_pointer = pixelwise_buffer;
+ for (size_t output_y = 0; output_y < output_height; output_y++) {
+ const size_t input_y_start = doz(output_y * average_pooling_op->stride_height, average_pooling_op->padding_top);
+ const size_t input_y_end =
+ min(doz(output_y * average_pooling_op->stride_height + average_pooling_op->kernel_height, average_pooling_op->padding_top), input_height);
+ const uint32_t input_y_range = (uint32_t) (input_y_end - input_y_start);
+ for (size_t output_x = 0; output_x < output_width; output_x++) {
+ const size_t input_x_start = doz(output_x * average_pooling_op->stride_width, average_pooling_op->padding_left);
+ const size_t input_x_end =
+ min(doz(output_x * average_pooling_op->stride_width + average_pooling_op->kernel_width, average_pooling_op->padding_left), input_width);
+ const uint32_t input_x_range = (uint32_t) (input_x_end - input_x_start);
+ *pixelwise_pointer++ = 1.0f / ((float) (int32_t) (input_y_range * input_x_range));
+ }
+ }
+ }
+
+ const uint32_t qr = pavgpool->qr;
+ const size_t multipass_adjustment =
+ pooling_size > mr ? round_up(pooling_size - mr, qr) + mr - qr : 0;
+ average_pooling_op->context.pixelwise_average_pooling = (struct pixelwise_average_pooling_context) {
+ .indirect_input = average_pooling_op->indirection_buffer,
+ .indirect_input_height_stride = indirect_input_height_stride,
+ .input_batch_stride = input_height * input_width * average_pooling_op->input_pixel_stride << log2_input_element_size,
+ .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) average_pooling_op->last_input),
+ .pixelwise_buffer = average_pooling_op->pixelwise_buffer,
+ .pixelwise_buffer_height_stride = output_width * sizeof(float),
+ .output = output,
+ .output_batch_stride = output_height * output_height_stride,
+ .output_height_stride = output_height_stride,
+ .output_width = output_width,
+ .pooling_size = pooling_size,
+ .channels = channels,
+ .zero = average_pooling_op->zero_buffer,
+ .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
+ .output_increment = output_width_stride - (channels << log2_output_element_size),
+ };
+ memcpy(&average_pooling_op->context.pixelwise_average_pooling.params, params, params_size);
+ if (pooling_size <= mr) {
+ average_pooling_op->context.pixelwise_average_pooling.unipass_ukernel = pavgpool->up;
+ average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_pixelwise_average_pooling_unipass;
+ } else {
+ average_pooling_op->context.pixelwise_average_pooling.multipass_ukernel = pavgpool->mp;
+ average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_pixelwise_average_pooling_multipass;
+ }
} else {
- average_pooling_op->context.average_pooling.multipass_ukernel = xnn_params.q8.avgpool.mp;
- average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_multipass;
+ const uint32_t qr = avgpool->qr;
+ const size_t multipass_adjustment =
+ pooling_size > mr ? round_up(pooling_size - mr, qr) + mr - qr : 0;
+ average_pooling_op->context.average_pooling = (struct average_pooling_context) {
+ .indirect_input = average_pooling_op->indirection_buffer,
+ .indirect_input_height_stride = indirect_input_height_stride,
+ .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) average_pooling_op->last_input),
+ .input_batch_stride = input_height * input_width * average_pooling_op->input_pixel_stride << log2_input_element_size,
+ .output = output,
+ .output_batch_stride = output_height * output_height_stride,
+ .output_height_stride = output_height_stride,
+ .output_width = output_width,
+ .pooling_size = pooling_size,
+ .channels = channels,
+ .zero = average_pooling_op->zero_buffer,
+ .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
+ .output_increment = output_width_stride - (channels << log2_output_element_size),
+ .params.f32 = average_pooling_op->f32_avgpool_params,
+ };
+ memcpy(&average_pooling_op->context.average_pooling.params, params, params_size);
+ if (pooling_size <= mr) {
+ average_pooling_op->context.average_pooling.unipass_ukernel = avgpool->up;
+ average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_unipass;
+ } else {
+ average_pooling_op->context.average_pooling.multipass_ukernel = avgpool->mp;
+ average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_multipass;
+ }
}
average_pooling_op->compute.type = xnn_parallelization_type_2d;
average_pooling_op->compute.range[0] = batch_size;
@@ -526,6 +589,36 @@
return xnn_status_success;
}
+enum xnn_status xnn_setup_average_pooling2d_nhwc_q8(
+ xnn_operator_t average_pooling_op,
+ size_t batch_size,
+ size_t input_height,
+ size_t input_width,
+ const uint8_t* input,
+ uint8_t* output,
+ pthreadpool_t threadpool)
+{
+ if (average_pooling_op->type != xnn_operator_type_average_pooling_nhwc_q8) {
+ xnn_log_error("failed to setup Average Pooling (Q8) operator: operator type mismatch");
+ return xnn_status_invalid_parameter;
+ }
+
+ assert(average_pooling_op->ukernel.type == xnn_ukernel_type_average_pooling);
+
+ return setup_average_pooling2d(
+ average_pooling_op,
+ batch_size, input_height, input_width,
+ input, output,
+ 0 /* log2(sizeof(input element)) = log2(sizeof(uint8_t)) */,
+ 0 /* log2(sizeof(output element)) = log2(sizeof(uint8_t)) */,
+ &xnn_params.q8.avgpool,
+ NULL /* pavgpool */,
+ &average_pooling_op->q8_avgpool_params,
+ sizeof(average_pooling_op->q8_avgpool_params),
+ pthreadpool_get_threads_count(threadpool),
+ false /* pixelwise not supported */);
+}
+
enum xnn_status xnn_setup_average_pooling2d_nhwc_f32(
xnn_operator_t average_pooling_op,
size_t batch_size,
@@ -539,193 +632,20 @@
xnn_log_error("failed to setup Average Pooling (F32) operator: operator type mismatch");
return xnn_status_invalid_parameter;
}
- average_pooling_op->state = xnn_run_state_invalid;
- if (!xnn_params.initialized) {
- xnn_log_error("failed to setup Average Pooling operator: XNNPACK is not initialized");
- return xnn_status_uninitialized;
- }
+ assert(average_pooling_op->ukernel.type == xnn_ukernel_type_average_pooling ||
+ average_pooling_op->ukernel.type == xnn_ukernel_type_pixelwise_average_pooling);
- if (input_width == 0 || input_height == 0) {
- xnn_log_error(
- "failed to setup Average Pooling operator with %zux%zu input: input dimensions must be non-zero",
- input_width, input_height);
- return xnn_status_invalid_parameter;
- }
-
- if (batch_size == 0) {
- average_pooling_op->state = xnn_run_state_skip;
- return xnn_status_success;
- }
-
- average_pooling_op->input_height = input_height;
- average_pooling_op->input_width = input_width;
- average_pooling_op->input = input;
-
- if (average_pooling_op->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
- average_pooling_op->output_height = compute_output_dimension_with_tf_same_padding(
- input_height, average_pooling_op->stride_height);
- average_pooling_op->output_width = compute_output_dimension_with_tf_same_padding(
- input_width, average_pooling_op->stride_width);
-
- const uint32_t effective_kernel_height = (average_pooling_op->kernel_height - 1) * average_pooling_op->dilation_height + 1;
- const uint32_t effective_kernel_width = (average_pooling_op->kernel_width - 1) * average_pooling_op->dilation_width + 1;
- const uint32_t total_padding_height =
- (average_pooling_op->output_height - 1) * average_pooling_op->stride_height + effective_kernel_height - input_height;
- const uint32_t total_padding_width =
- (average_pooling_op->output_width - 1) * average_pooling_op->stride_width + effective_kernel_width - input_width;
- average_pooling_op->padding_top = total_padding_height / 2;
- average_pooling_op->padding_left = total_padding_width / 2;
- average_pooling_op->padding_bottom = total_padding_height - average_pooling_op->padding_top;
- average_pooling_op->padding_right = total_padding_width - average_pooling_op->padding_left;
- } else {
- average_pooling_op->output_height = compute_output_dimension(
- average_pooling_op->padding_top + input_height + average_pooling_op->padding_bottom,
- average_pooling_op->kernel_height,
- average_pooling_op->stride_height);
- average_pooling_op->output_width = compute_output_dimension(
- average_pooling_op->padding_left + input_width + average_pooling_op->padding_right,
- average_pooling_op->kernel_width,
- average_pooling_op->stride_width);
- }
- average_pooling_op->output = output;
-
- const size_t pooling_height = average_pooling_op->kernel_height;
- const size_t pooling_width = average_pooling_op->kernel_width;
- const size_t pooling_size = pooling_height * pooling_width;
- const size_t output_height = average_pooling_op->output_height;
- const size_t output_width = average_pooling_op->output_width;
- // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
- const uint32_t mr = xnn_params.f32.avgpool.mr;
- assert(mr == xnn_params.f32.pavgpool.mr);
-
- const size_t step_width = min(average_pooling_op->stride_width, pooling_width);
- const size_t step_height = pooling_size + (output_width - 1) * step_width * pooling_height;
-
- const size_t last_input_height = average_pooling_op->last_input_height;
- const size_t last_input_width = average_pooling_op->last_input_width;
- if (input_height != last_input_height || input_width != last_input_width) {
- // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
- const size_t indirection_buffer_size = sizeof(void*) * ((mr - 1) + batch_size * output_height * step_height);
-
- const void** indirection_buffer = (const void**) xnn_reallocate_memory(average_pooling_op->indirection_buffer, indirection_buffer_size);
- if (indirection_buffer == NULL) {
- xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
- return xnn_status_out_of_memory;
- }
- average_pooling_op->indirection_buffer = indirection_buffer;
-
- // Indirection buffer always setup for batch size 1, larger batch size supported through input_offset argument
- average_pooling_op->batch_size = 1;
- xnn_indirection_init_dwconv2d(
- average_pooling_op, 0, step_height, step_width, 2 /* log2(sizeof(float)) */);
-
- average_pooling_op->last_input = input;
- average_pooling_op->last_input_height = input_height;
- average_pooling_op->last_input_width = input_width;
- }
-
- const size_t channels = average_pooling_op->channels;
-
- const size_t indirect_input_height_stride = step_height * sizeof(void*);
- const size_t output_width_stride = average_pooling_op->output_pixel_stride * sizeof(float);
- const size_t output_height_stride = output_width * output_width_stride;
-
- switch (average_pooling_op->ukernel.type) {
- case xnn_ukernel_type_average_pooling:
- {
- const uint32_t qr = xnn_params.f32.avgpool.qr;
- const size_t multipass_adjustment =
- pooling_size > mr ? round_up(pooling_size - mr, qr) + mr - qr : 0;
- average_pooling_op->context.average_pooling = (struct average_pooling_context) {
- .indirect_input = average_pooling_op->indirection_buffer,
- .indirect_input_height_stride = indirect_input_height_stride,
- .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) average_pooling_op->last_input),
- .input_batch_stride = input_height * input_width * average_pooling_op->input_pixel_stride * sizeof(float),
- .output = output,
- .output_batch_stride = output_height * output_height_stride,
- .output_height_stride = output_height_stride,
- .output_width = output_width,
- .pooling_size = pooling_size,
- .channels = channels,
- .zero = average_pooling_op->zero_buffer,
- .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
- .output_increment = output_width_stride - channels * sizeof(float),
- .params.f32 = average_pooling_op->f32_avgpool_params,
- };
- if (pooling_size <= mr) {
- average_pooling_op->context.average_pooling.unipass_ukernel = xnn_params.f32.avgpool.up;
- average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_unipass;
- } else {
- average_pooling_op->context.average_pooling.multipass_ukernel = xnn_params.f32.avgpool.mp;
- average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_multipass;
- }
- break;
- }
- case xnn_ukernel_type_pixelwise_average_pooling:
- {
- if (input_height != last_input_height || input_width != last_input_width) {
- const size_t pixelwise_buffer_size = output_height * output_width * sizeof(float);
- float* pixelwise_buffer = (float*) xnn_reallocate_memory(average_pooling_op->pixelwise_buffer, pixelwise_buffer_size);
- if (pixelwise_buffer == NULL) {
- xnn_log_error("failed to allocate %zu bytes for pixelwise buffer", pixelwise_buffer_size);
- return xnn_status_out_of_memory;
- }
- average_pooling_op->pixelwise_buffer = pixelwise_buffer;
-
- float* pixelwise_pointer = pixelwise_buffer;
- for (size_t output_y = 0; output_y < output_height; output_y++) {
- const size_t input_y_start = doz(output_y * average_pooling_op->stride_height, average_pooling_op->padding_top);
- const size_t input_y_end =
- min(doz(output_y * average_pooling_op->stride_height + average_pooling_op->kernel_height, average_pooling_op->padding_top), input_height);
- const uint32_t input_y_range = (uint32_t) (input_y_end - input_y_start);
- for (size_t output_x = 0; output_x < output_width; output_x++) {
- const size_t input_x_start = doz(output_x * average_pooling_op->stride_width, average_pooling_op->padding_left);
- const size_t input_x_end =
- min(doz(output_x * average_pooling_op->stride_width + average_pooling_op->kernel_width, average_pooling_op->padding_left), input_width);
- const uint32_t input_x_range = (uint32_t) (input_x_end - input_x_start);
- *pixelwise_pointer++ = 1.0f / ((float) (int32_t) (input_y_range * input_x_range));
- }
- }
- }
-
- const uint32_t qr = xnn_params.f32.pavgpool.qr;
- const size_t multipass_adjustment =
- pooling_size > mr ? round_up(pooling_size - mr, qr) + mr - qr : 0;
- average_pooling_op->context.pixelwise_average_pooling = (struct pixelwise_average_pooling_context) {
- .indirect_input = average_pooling_op->indirection_buffer,
- .indirect_input_height_stride = indirect_input_height_stride,
- .input_batch_stride = input_height * input_width * average_pooling_op->input_pixel_stride * sizeof(float),
- .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) average_pooling_op->last_input),
- .pixelwise_buffer = average_pooling_op->pixelwise_buffer,
- .pixelwise_buffer_height_stride = output_width * sizeof(float),
- .output = output,
- .output_batch_stride = output_height * output_height_stride,
- .output_height_stride = output_height_stride,
- .output_width = output_width,
- .pooling_size = pooling_size,
- .channels = channels,
- .zero = average_pooling_op->zero_buffer,
- .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
- .output_increment = output_width_stride - channels * sizeof(float),
- .params.f32 = average_pooling_op->f32_output_params,
- };
- if (pooling_size <= mr) {
- average_pooling_op->context.pixelwise_average_pooling.unipass_ukernel = xnn_params.f32.pavgpool.up;
- average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_pixelwise_average_pooling_unipass;
- } else {
- average_pooling_op->context.pixelwise_average_pooling.multipass_ukernel = xnn_params.f32.pavgpool.mp;
- average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_pixelwise_average_pooling_multipass;
- }
- break;
- }
- default:
- XNN_UNREACHABLE;
- }
- average_pooling_op->compute.type = xnn_parallelization_type_2d;
- average_pooling_op->compute.range[0] = batch_size;
- average_pooling_op->compute.range[1] = output_height;
- average_pooling_op->state = xnn_run_state_ready;
-
- return xnn_status_success;
+ return setup_average_pooling2d(
+ average_pooling_op,
+ batch_size, input_height, input_width,
+ input, output,
+ 2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
+ 2 /* log2(sizeof(output element)) = log2(sizeof(float)) */,
+ &xnn_params.f32.avgpool,
+ &xnn_params.f32.pavgpool,
+ &average_pooling_op->f32_avgpool_params,
+ sizeof(average_pooling_op->f32_avgpool_params),
+ pthreadpool_get_threads_count(threadpool),
+ average_pooling_op->ukernel.type == xnn_ukernel_type_pixelwise_average_pooling);
}