Initial open-source release

PiperOrigin-RevId: 271685289
diff --git a/src/operator-run.c b/src/operator-run.c
new file mode 100644
index 0000000..0c35481
--- /dev/null
+++ b/src/operator-run.c
@@ -0,0 +1,784 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include <xnnpack.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/log.h>
+#include <xnnpack/common.h>
+#include <xnnpack/math.h>
+#include <xnnpack/params.h>
+#include <xnnpack/compute.h>
+
+
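+// Computes one MRxNR tile of a grouped GEMM: the A, packed-weight, and C
+// pointers are each offset by group_index before the micro-kernel runs, so
+// every group multiplies against its own slice of the packed weights.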
+void xnn_compute_ggemm(
+    const struct gemm_context context[restrict static 1],
+    size_t group_index,
+    size_t mr_block_start,
+    size_t nr_block_start,
+    size_t mr_block_size,
+    size_t nr_block_size)
+{
+  const size_t k_scaled  = context->k_scaled;
+  const size_t a_stride  = context->a_stride;
+  const size_t cm_stride = context->cm_stride;
+
+  context->ukernel(
+      mr_block_size,
+      nr_block_size,
+      k_scaled,
+      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
+      a_stride,
+      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
+      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
+      cm_stride,
+      context->cn_stride,
+      &context->params);
+}
+
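+// Computes one MRxNR tile of a plain (non-grouped) GEMM.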
+void xnn_compute_gemm(
+    const struct gemm_context context[restrict static 1],
+    size_t mr_block_start,
+    size_t nr_block_start,
+    size_t mr_block_size,
+    size_t nr_block_size)
+{
+  const size_t a_stride  = context->a_stride;
+  const size_t cm_stride = context->cm_stride;
+
+  context->ukernel(
+      mr_block_size,
+      nr_block_size,
+      context->k_scaled,
+      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
+      a_stride,
+      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
+      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
+      cm_stride,
+      context->cn_stride,
+      &context->params);
+}
+
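+// Computes an MR-wide block of a sparse-weight matrix multiplication (SpMM)
+// for one batch element; the sparsity structure is carried in
+// input_increments and output_channel_nonzeros.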
+void xnn_compute_spmm(
+    const struct spmm_context context[restrict static 1],
+    size_t batch_index,
+    size_t mr_block_start,
+    size_t mr_block_size)
+{
+  context->ukernel(
+      mr_block_size,
+      context->n,
+      (const void*) ((uintptr_t) context->a + batch_index * context->batched_a_stride + mr_block_start * sizeof(float)),
+      context->packed_weights,
+      context->input_increments,
+      context->output_channel_nonzeros,
+      (void*) ((uintptr_t) context->c + batch_index * context->batched_c_stride + mr_block_start * sizeof(float)),
+      &context->params);
+}
+
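+// Computes one MRxNR tile of a grouped indirect GEMM (IGEMM): rows of A are
+// read through an indirection buffer of pointers, with group and batch
+// offsets folded into the weight pointer, output pointer, and a_offset.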
+void xnn_compute_gigemm(
+    const struct igemm_context context[restrict static 1],
+    size_t batch_index,
+    size_t group_index,
+    size_t mr_block_start,
+    size_t nr_block_start,
+    size_t mr_block_size,
+    size_t nr_block_size)
+{
+  const size_t ks        = context->ks;
+  const size_t cm_stride = context->cm_stride;
+
+  context->ukernel(
+      mr_block_size,
+      nr_block_size,
+      context->kc,
+      context->ks_scaled,
+      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
+      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
+      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
+      cm_stride,
+      context->cn_stride,
+      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
+      context->zero,
+      &context->params);
+}
+
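+// Non-grouped variant of the indirect GEMM above.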
+void xnn_compute_igemm(
+    const struct igemm_context context[restrict static 1],
+    size_t batch_index,
+    size_t mr_block_start,
+    size_t nr_block_start,
+    size_t mr_block_size,
+    size_t nr_block_size)
+{
+  const size_t ks        = context->ks;
+  const size_t cm_stride = context->cm_stride;
+
+  context->ukernel(
+      mr_block_size,
+      nr_block_size,
+      context->kc,
+      context->ks_scaled,
+      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
+      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
+      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
+      cm_stride,
+      context->cn_stride,
+      context->a_offset + batch_index * context->ba_stride,
+      context->zero,
+      &context->params);
+}
+
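+// Computes one output slice of a grouped subconvolution (the decomposition
+// used for deconvolution): subkernel_index selects a subconvolution with its
+// own geometry, indirection buffer, and weights, and calls whose slice
+// coordinates fall outside that geometry return early.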
+void xnn_compute_gsubconv2d(
+      const struct subconv_context context[restrict static 1],
+      size_t batch_index,
+      size_t group_index,
+      size_t subkernel_index,
+      size_t slice_y,
+      size_t slice_x_start,
+      size_t nc_block_start,
+      size_t slice_x_max,
+      size_t nc_block_size)
+{
+  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
+
+  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
+    return;
+  }
+
+  const size_t slice_width = subconvolution_params->slice_width;
+  if XNN_UNLIKELY(slice_x_start >= slice_width) {
+    return;
+  }
+  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
+
+  const size_t cx_stride = context->cx_stride;
+  context->ukernel(
+      slice_x_size,
+      nc_block_size,
+      context->kc,
+      subconvolution_params->scaled_kernel_size,
+      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
+      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
+      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
+      cx_stride,
+      context->cn_stride,
+      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
+      context->zero,
+      &context->params);
+}
+
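+// Non-grouped variant of the subconvolution compute function above.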
+void xnn_compute_subconv2d(
+      const struct subconv_context context[restrict static 1],
+      size_t batch_index,
+      size_t subkernel_index,
+      size_t slice_y,
+      size_t slice_x_start,
+      size_t nc_block_start,
+      size_t slice_x_max,
+      size_t nc_block_size)
+{
+  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
+
+  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
+    return;
+  }
+
+  const size_t slice_width = subconvolution_params->slice_width;
+  if XNN_UNLIKELY(slice_x_start >= slice_width) {
+    return;
+  }
+  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
+
+  const size_t cx_stride = context->cx_stride;
+  context->ukernel(
+      slice_x_size,
+      nc_block_size,
+      context->kc,
+      subconvolution_params->scaled_kernel_size,
+      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
+      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
+      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
+      cx_stride,
+      context->cn_stride,
+      context->a_offset + batch_index * context->ba_stride,
+      context->zero,
+      &context->params);
+}
+
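+// Runs a direct 2D convolution micro-kernel over a range of output rows of
+// one batch element, reading HWC-layout input and writing SpCHW-layout
+// output (as the hwc2spchw name indicates).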
+void xnn_compute_dconv2d_hwc2spchw(
+      const struct dconv2d_context context[restrict static 1],
+      size_t batch_index,
+      size_t output_y_start,
+      size_t output_y_slice)
+{
+  context->hwc2spchw_ukernel(
+      context->input_height,
+      context->input_width,
+      output_y_start,
+      output_y_start + output_y_slice,
+      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
+      context->zero,
+      context->packed_weights,
+      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
+      context->input_padding_top,
+      context->output_channels,
+      context->output_height_stride,
+      context->output_channel_stride,
+      &context->params);
+}
+
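+// Computes one output row of a depthwise convolution in a single pass;
+// input pixels are gathered through the indirection buffer.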
+void xnn_compute_dwconv_unipass(
+    const struct dwconv_context context[restrict static 1],
+    size_t output_y)
+{
+  context->unipass_ukernel(
+    context->groups,
+    context->output_width,
+    context->indirection_buffer + output_y * context->indirection_buffer_row_stride,
+    context->packed_weights,
+    context->output + output_y * context->output_row_stride,
+    context->indirection_buffer_col_stride,
+    context->output_col_increment,
+    &context->params);
+}
+
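+// Computes a depthwise convolution over one channel of one batch element in
+// SpCHW layout.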
+void xnn_compute_dwconv2d_spchw(
+    const struct dwconv2d_context context[restrict static 1],
+    size_t batch_index,
+    size_t channel)
+{
+  context->spchw_ukernel(
+    context->output_height,
+    context->input_width,
+    (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
+    (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
+    (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
+    context->input_tuple_stride,
+    context->output_tuple_stride,
+    context->input_pixel_stride,
+    context->output_pixel_stride,
+    &context->params);
+}
+
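+// Computes one output row of argmax pooling (maximum values plus their
+// uint32 indices) in a single pass.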
+void xnn_compute_argmax_pooling_unipass(
+    const struct argmax_pooling_context context[restrict static 1],
+    size_t batch_index,
+    size_t output_y)
+{
+  const void** indirect_input =
+    (const void**) ((uintptr_t) context->indirect_input +
+      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
+  void* output =
+    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
+  uint32_t* index =
+    (uint32_t*) ((uintptr_t) context->index + batch_index * context->index_batch_stride + output_y * context->index_height_stride);
+
+  context->unipass_ukernel(
+    context->output_width, context->pooling_size, context->channels,
+    indirect_input, output, index,
+    context->input_increment, context->output_increment,
+    &context->params);
+}
+
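+// Multi-pass variant of argmax pooling: partial maxima and indices are
+// staged in on-stack VLA scratch buffers, padded with XNN_EXTRA_BYTES of
+// slack for micro-kernel over-read, before the final row is written.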
+void xnn_compute_argmax_pooling_multipass(
+    const struct argmax_pooling_context context[restrict static 1],
+    size_t batch_index,
+    size_t output_y)
+{
+  const void** indirect_input =
+    (const void**) ((uintptr_t) context->indirect_input +
+      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
+  void* output =
+    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
+  uint32_t* index =
+    (uint32_t*) ((uintptr_t) context->index + batch_index * context->index_batch_stride + output_y * context->index_height_stride);
+
+  XNN_ALIGN(16) float multipass_output_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(float)];
+  XNN_ALIGN(16) uint32_t multipass_index_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint32_t)];
+
+  context->multipass_ukernel(
+    context->output_width, context->pooling_size, context->channels,
+    indirect_input, multipass_output_buffer, multipass_index_buffer, output, index,
+    context->input_increment, context->output_increment,
+    &context->params);
+}
+
+void xnn_compute_max_pooling(
+    const struct max_pooling_context context[restrict static 1],
+    size_t batch_index,
+    size_t output_y)
+{
+  const void** indirect_input =
+    (const void**) ((uintptr_t) context->indirect_input +
+      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
+  void* output =
+    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
+
+  context->ukernel(
+    context->output_width, context->pooling_size, context->channels,
+    indirect_input, output,
+    context->input_increment, context->output_increment,
+    &context->params);
+}
+
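+// Computes max-unpooling for a single input pixel: each channel value is
+// scattered through its saved uint32 index to one of the output pointers in
+// indirect_output, with fill_value presumably filling the remaining positions.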
+void xnn_compute_unpooling(
+    const struct unpooling_context context[restrict static 1],
+    size_t input_y,
+    size_t input_x)
+{
+  const void* input = (const void*) ((uintptr_t) context->input +
+      input_y * context->input_height_stride + input_x * context->input_width_stride);
+  const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
+      input_y * context->index_height_stride + input_x * context->index_width_stride);
+  void** indirect_output =
+    (void**) ((uintptr_t) context->indirect_output +
+      input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);
+
+  context->ukernel(
+    context->pooling_size,
+    context->channels,
+    context->fill_value,
+    input, index, indirect_output);
+}
+
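+// Computes one output row of average pooling in a single pass; context->zero
+// points to a zero-initialized vector that stands in for padding pixels in
+// the indirection buffer.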
+void xnn_compute_average_pooling_unipass(
+    const struct average_pooling_context context[restrict static 1],
+    size_t batch_index,
+    size_t output_y)
+{
+  const void** indirect_input =
+    (const void**) ((uintptr_t) context->indirect_input +
+      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
+  void* output =
+    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
+
+  context->unipass_ukernel(
+    context->output_width, context->pooling_size, context->channels,
+    indirect_input, context->zero, output,
+    context->input_increment, context->output_increment,
+    &context->params);
+}
+
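+// Multi-pass average pooling: partial sums are accumulated in an on-stack
+// int32 scratch buffer before the final averaged row is written.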
+void xnn_compute_average_pooling_multipass(
+    const struct average_pooling_context context[restrict static 1],
+    size_t batch_index,
+    size_t output_y)
+{
+  const void** indirect_input =
+    (const void**) ((uintptr_t) context->indirect_input +
+      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
+  void* output =
+    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
+  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];
+
+  context->multipass_ukernel(
+    context->output_width, context->pooling_size, context->channels,
+    indirect_input, context->zero, multipass_buffer, output,
+    context->input_increment, context->output_increment,
+    &context->params);
+}
+
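+// Pixelwise average pooling: in addition to the usual inputs, the
+// micro-kernel reads a per-output-pixel multiplier row from pixelwise_buffer,
+// which lets the effective divisor vary across output pixels (e.g. near
+// padded borders).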
+void xnn_compute_pixelwise_average_pooling_unipass(
+    const struct pixelwise_average_pooling_context context[restrict static 1],
+    size_t batch_index,
+    size_t output_y)
+{
+  const void** indirect_input =
+    (const void**) ((uintptr_t) context->indirect_input +
+      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
+  const void* pixelwise_buffer =
+    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
+  void* output =
+    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
+
+  context->unipass_ukernel(
+    context->output_width, context->pooling_size, context->channels,
+    indirect_input, context->zero, pixelwise_buffer, output,
+    context->input_increment, context->output_increment,
+    &context->params);
+}
+
+void xnn_compute_pixelwise_average_pooling_multipass(
+    const struct pixelwise_average_pooling_context context[restrict static 1],
+    size_t batch_index,
+    size_t output_y)
+{
+  const void** indirect_input =
+    (const void**) ((uintptr_t) context->indirect_input +
+      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
+  const void* pixelwise_buffer =
+    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
+  void* output =
+    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
+  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];
+
+  context->multipass_ukernel(
+    context->output_width, context->pooling_size, context->channels,
+    indirect_input, context->zero, pixelwise_buffer, multipass_buffer, output,
+    context->input_increment, context->output_increment,
+    &context->params);
+}
+
+void xnn_compute_global_average_pooling_unipass(
+    const struct global_average_pooling_context context[restrict static 1],
+    size_t batch_index)
+{
+  const void* input =
+    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
+  void* output =
+    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
+
+  context->unipass_ukernel(
+    context->input_elements,
+    context->channels,
+    input,
+    context->input_pixel_stride,
+    context->zero,
+    output,
+    &context->params);
+}
+
+void xnn_compute_global_average_pooling_multipass(
+    const struct global_average_pooling_context context[restrict static 1],
+    size_t batch_index)
+{
+  const void* input =
+    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
+  void* output =
+    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
+  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];
+
+  context->multipass_ukernel(
+    context->input_elements,
+    context->channels,
+    input,
+    context->input_pixel_stride,
+    context->zero,
+    multipass_buffer,
+    output,
+    &context->params);
+}
+
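+// Global average pooling in SpNCHW layout: each invocation reduces all input
+// elements of a slice of channels for one batch element.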
+void xnn_compute_global_average_pooling_spnchw(
+    const struct global_average_pooling_spnchw_context context[restrict static 1],
+    size_t batch_index,
+    size_t channels_start,
+    size_t channels_slice)
+{
+  const void* input =
+    (const void*) ((uintptr_t) context->input + channels_start * context->input_channel_stride + batch_index * context->input_batch_stride);
+  void* output =
+    (void*) ((uintptr_t) context->output + channels_start * context->output_channel_stride + batch_index * context->output_batch_stride);
+
+  context->ukernel(
+    context->input_elements,
+    channels_slice,
+    input,
+    output,
+    &context->params);
+}
+
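+// Applies PReLU (y = x if x >= 0, else w * x, with per-channel slopes in
+// context->w) to a range of batch rows.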
+void xnn_compute_prelu(
+    const struct prelu_context context[restrict static 1],
+    size_t batch_start,
+    size_t batch_range)
+{
+  const size_t x_stride = context->x_stride;
+  const size_t y_stride = context->y_stride;
+  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
+  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);
+
+  context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride, &context->params);
+}
+
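+// Pads channels for a range of batch rows; context->l and context->r are
+// presumably the leading and trailing pad amounts and context->c the fill
+// value.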
+void xnn_compute_channel_pad(
+    const struct channel_pad_context context[restrict static 1],
+    size_t batch_start,
+    size_t batch_range)
+{
+  const size_t x_stride = context->x_stride;
+  const size_t y_stride = context->y_stride;
+  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
+  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);
+
+  context->ukernel(batch_range, context->n, context->l, context->r, context->c, x, x_stride, y, y_stride);
+}
+
+void xnn_compute_add_strided(
+    const struct add_strided_context context[restrict static 1],
+    size_t batch_index,
+    size_t batch_range /* always 1 */)
+{
+  assert(batch_range == 1);
+
+  const size_t n = context->n;
+  const size_t a_stride = context->a_stride;
+  const size_t b_stride = context->b_stride;
+  const size_t y_stride = context->y_stride;
+  const void* a = (const void*) ((uintptr_t) context->a + a_stride * batch_index);
+  const void* b = (const void*) ((uintptr_t) context->b + b_stride * batch_index);
+  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_index);
+
+  context->ukernel(n, a, b, y, &context->params);
+}
+
+void xnn_compute_add_contiguous(
+    const struct add_contiguous_context context[restrict static 1],
+    size_t offset,
+    size_t size)
+{
+  const void* a = (const void*) ((uintptr_t) context->a + offset);
+  const void* b = (const void*) ((uintptr_t) context->b + offset);
+  void* y = (void*) ((uintptr_t) context->y + offset);
+  context->ukernel(size, a, b, y, &context->params);
+}
+
+void xnn_compute_channel_shuffle_fixed(
+    const struct channel_shuffle_context context[restrict static 1],
+    size_t index)
+{
+  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
+  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);
+
+  context->fixed_ukernel(context->n, x, y);
+}
+
+void xnn_compute_channel_shuffle_variable(
+    const struct channel_shuffle_context context[restrict static 1],
+    size_t index)
+{
+  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
+  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);
+
+  context->variable_ukernel(context->n, context->m, x, y);
+}
+
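+// Applies the lookup table context->t to one batch row.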
+void xnn_compute_lut_strided(
+    const struct lut_strided_context context[restrict static 1],
+    size_t batch_index)
+{
+  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
+  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
+
+  context->ukernel(context->n, x, context->t, y);
+}
+
+void xnn_compute_lut_contiguous(
+    const struct lut_contiguous_context context[restrict static 1],
+    size_t offset,
+    size_t size)
+{
+  const void* x = (const void*) ((uintptr_t) context->x + offset);
+  void* y = (void*) ((uintptr_t) context->y + offset);
+
+  context->ukernel(size, x, context->t, y);
+}
+
+void xnn_compute_univector_strided(
+    const struct univector_strided_context context[restrict static 1],
+    size_t batch_index,
+    size_t batch_range /* always 1 */)
+{
+  assert(batch_range == 1);
+
+  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
+  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
+  context->ukernel(context->n, x, y, &context->params);
+}
+
+void xnn_compute_univector_contiguous(
+    const struct univector_contiguous_context context[restrict static 1],
+    size_t offset,
+    size_t size)
+{
+  const void* x = (const void*) ((uintptr_t) context->x + offset);
+  void* y = (void*) ((uintptr_t) context->y + offset);
+  context->ukernel(size, x, y, &context->params);
+}
+
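+// Computes 8-bit softargmax in two passes: an rmax kernel finds the row
+// maximum x_max, then the lookup table pointer is advanced by 255 - x_max
+// entries (x_max ^ 255 == 255 - x_max for x_max in [0, 255]) so the
+// lookup-and-normalize kernel reads entries as if inputs were shifted by the
+// row maximum, the usual overflow-avoidance rearrangement for softmax.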
+void xnn_compute_u8_softargmax(
+    const struct u8_softargmax_context context[restrict static 1],
+    size_t batch_index)
+{
+  const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
+  uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
+  const size_t n = context->n;
+
+  uint8_t x_max = 0;
+  context->rmax_ukernel(n, x, &x_max);
+  const size_t adjustment = x_max ^ 255;
+  const uint32_t* t = (const uint32_t*) context->t + adjustment;
+  context->lut_norm_ukernel(n, x, t, y);
+}
+
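+// Applies a channelwise multiply-add (y = x * scale + bias, with per-channel
+// scale and bias packed together in context->w) to a range of batch rows.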
+void xnn_compute_vmulcaddc(
+    const struct vmulcaddc_context context[restrict static 1],
+    size_t batch_start,
+    size_t batch_size)
+{
+  const size_t x_stride = context->x_stride;
+  const size_t y_stride = context->y_stride;
+
+  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
+  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);
+
+  context->ukernel(
+    batch_size,
+    context->n,
+    x, x_stride,
+    context->w,
+    y, y_stride,
+    &context->params);
+}
+
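+// Runs an operator that has been successfully set up: after validating
+// library and operator state, dispatches the operator's compute task over
+// the thread pool according to the parallelization type chosen at setup
+// time, with denormalized floats disabled in the worker threads.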
+enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
+{
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to run operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+  switch (op->state) {
+    case xnn_run_state_invalid:
+      xnn_log_error("failed to run operator: operator was not successfully setup");
+      return xnn_status_invalid_state;
+    case xnn_run_state_ready:
+      break;
+    case xnn_run_state_skip:
+      return xnn_status_success;
+  }
+
+  switch (op->compute.type) {
+    case xnn_parallelization_type_invalid:
+      break;
+    case xnn_parallelization_type_1d:
+      assert(op->compute.range[0] != 0);
+      pthreadpool_parallelize_1d(
+          threadpool,
+          op->compute.task_1d,
+          &op->context,
+          op->compute.range[0],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_1d_tile_1d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.tile[0] != 0);
+      pthreadpool_parallelize_1d_tile_1d(
+          threadpool,
+          op->compute.task_1d_tile_1d,
+          &op->context,
+          op->compute.range[0],
+          op->compute.tile[0],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_2d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      pthreadpool_parallelize_2d(
+          threadpool,
+          op->compute.task_2d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_2d_tile_1d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      assert(op->compute.tile[0] != 0);
+      pthreadpool_parallelize_2d_tile_1d(
+          threadpool,
+          op->compute.task_2d_tile_1d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1],
+          op->compute.tile[0],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_2d_tile_2d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      assert(op->compute.tile[0] != 0);
+      assert(op->compute.tile[1] != 0);
+      pthreadpool_parallelize_2d_tile_2d(
+          threadpool,
+          op->compute.task_2d_tile_2d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1],
+          op->compute.tile[0], op->compute.tile[1],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_3d_tile_2d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      assert(op->compute.range[2] != 0);
+      assert(op->compute.tile[0] != 0);
+      assert(op->compute.tile[1] != 0);
+      pthreadpool_parallelize_3d_tile_2d(
+          threadpool,
+          op->compute.task_3d_tile_2d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1], op->compute.range[2],
+          op->compute.tile[0], op->compute.tile[1],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_4d_tile_2d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      assert(op->compute.range[2] != 0);
+      assert(op->compute.range[3] != 0);
+      assert(op->compute.tile[0] != 0);
+      assert(op->compute.tile[1] != 0);
+      pthreadpool_parallelize_4d_tile_2d(
+          threadpool,
+          op->compute.task_4d_tile_2d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
+          op->compute.tile[0], op->compute.tile[1],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_5d_tile_2d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      assert(op->compute.range[2] != 0);
+      assert(op->compute.range[3] != 0);
+      assert(op->compute.range[4] != 0);
+      assert(op->compute.tile[0] != 0);
+      assert(op->compute.tile[1] != 0);
+      pthreadpool_parallelize_5d_tile_2d(
+          threadpool,
+          op->compute.task_5d_tile_2d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
+          op->compute.tile[0], op->compute.tile[1],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_6d_tile_2d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      assert(op->compute.range[2] != 0);
+      assert(op->compute.range[3] != 0);
+      assert(op->compute.range[4] != 0);
+      assert(op->compute.range[5] != 0);
+      assert(op->compute.tile[0] != 0);
+      assert(op->compute.tile[1] != 0);
+      pthreadpool_parallelize_6d_tile_2d(
+          threadpool,
+          op->compute.task_6d_tile_2d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
+          op->compute.tile[0], op->compute.tile[1],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    default:
+      XNN_UNREACHABLE;
+  }
+  return xnn_status_success;
+}