Refactor MaxPool and ArgMaxPool micro-kernels
- Support input_offset argument in MaxPool and ArgMaxPool micro-kernels
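- Use input_offset to make the indirection buffer independent of batch size: the per-batch offset is now passed to the micro-kernels instead of selecting a separate set of indirection pointers for each batch element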
- Simplify and auto-generate unit tests
- Use more descriptive names for micro-kernel parameters
PiperOrigin-RevId: 281447682
diff --git a/src/operator-run.c b/src/operator-run.c
index 28c0b5a..b4f8dbe 100644
--- a/src/operator-run.c
+++ b/src/operator-run.c
@@ -275,17 +275,17 @@
size_t batch_index,
size_t output_y)
{
- const void** indirect_input =
- (const void**) ((uintptr_t) context->indirect_input +
- batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
- void* output =
- (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
- uint32_t* index =
- (uint32_t*) ((uintptr_t) context->index + batch_index * context->index_batch_stride + output_y * context->index_height_stride);
+ const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
+ output_y * context->indirect_input_height_stride);
+ const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
+ void* output = (void*) ((uintptr_t) context->output +
+ batch_index * context->output_batch_stride + output_y * context->output_height_stride);
+ uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
+ batch_index * context->index_batch_stride + output_y * context->index_height_stride);
context->unipass_ukernel(
context->output_width, context->pooling_size, context->channels,
- indirect_input, output, index,
+ indirect_input, input_offset, output, index,
context->input_increment, context->output_increment,
&context->params);
}
@@ -295,20 +295,20 @@
size_t batch_index,
size_t output_y)
{
- const void** indirect_input =
- (const void**) ((uintptr_t) context->indirect_input +
- batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
- void* output =
- (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
- uint32_t* index =
- (uint32_t*) ((uintptr_t) context->index + batch_index * context->index_batch_stride + output_y * context->index_height_stride);
+ const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
+ output_y * context->indirect_input_height_stride);
+ const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
+ void* output = (void*) ((uintptr_t) context->output +
+ batch_index * context->output_batch_stride + output_y * context->output_height_stride);
+ uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
+ batch_index * context->index_batch_stride + output_y * context->index_height_stride);
- XNN_ALIGN(16) float multipass_output_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(float)];
+ XNN_ALIGN(16) float multipass_accumulation_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(float)];
XNN_ALIGN(16) uint32_t multipass_index_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint32_t)];
context->multipass_ukernel(
context->output_width, context->pooling_size, context->channels,
- indirect_input, multipass_output_buffer, multipass_index_buffer, output, index,
+ indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
context->input_increment, context->output_increment,
&context->params);
}
@@ -318,15 +318,15 @@
size_t batch_index,
size_t output_y)
{
- const void** indirect_input =
- (const void**) ((uintptr_t) context->indirect_input +
- batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
- void* output =
- (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
+ const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
+ output_y * context->indirect_input_height_stride);
+ const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
+ void* output = (void*) ((uintptr_t) context->output +
+ batch_index * context->output_batch_stride + output_y * context->output_height_stride);
context->ukernel(
context->output_width, context->pooling_size, context->channels,
- indirect_input, output,
+ indirect_input, input_offset, output,
context->input_increment, context->output_increment,
&context->params);
}