Initial open-source release
PiperOrigin-RevId: 271685289
diff --git a/src/add.c b/src/add.c
new file mode 100644
index 0000000..8f41eba
--- /dev/null
+++ b/src/add.c
@@ -0,0 +1,374 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/requantization.h>
+
+
+enum xnn_status xnn_create_add_nc_q8(
+ size_t channels,
+ size_t a_stride,
+ size_t b_stride,
+ size_t sum_stride,
+ uint8_t a_zero_point,
+ float a_scale,
+ uint8_t b_zero_point,
+ float b_scale,
+ uint8_t sum_zero_point,
+ float sum_scale,
+ uint8_t sum_min,
+ uint8_t sum_max,
+ uint32_t flags,
+ xnn_operator_t* add_op_out)
+{
+ xnn_operator_t add_op = NULL;
+ enum xnn_status status = xnn_status_uninitialized;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to create Add operator: XNNPACK is not initialized");
+ goto error;
+ }
+
+ status = xnn_status_invalid_parameter;
+
+ if (channels == 0) {
+ xnn_log_error(
+ "failed to create Add operator with %zu channels: number of channels must be non-zero", channels);
+ goto error;
+ }
+
+ if (a_stride < channels) {
+ xnn_log_error(
+ "failed to create Add operator with A element stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ a_stride, channels);
+ goto error;
+ }
+
+ if (b_stride < channels) {
+ xnn_log_error(
+ "failed to create Add operator with B element stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ b_stride, channels);
+ goto error;
+ }
+
+ if (sum_stride < channels) {
+ xnn_log_error(
+ "failed to create Add operator with Sum element stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ sum_stride, channels);
+ goto error;
+ }
+
+ if (a_scale <= 0.0f || !isnormal(a_scale)) {
+ xnn_log_error(
+ "failed to create Add operator with %.7g A scale: scale must be finite, normalized, and positive", a_scale);
+ goto error;
+ }
+
+ if (b_scale <= 0.0f || !isnormal(b_scale)) {
+ xnn_log_error(
+ "failed to create Add operator with %.7g B scale: scale must be finite, normalized, and positive", b_scale);
+ goto error;
+ }
+
+ if (sum_scale <= 0.0f || !isnormal(sum_scale)) {
+ xnn_log_error(
+ "failed to create Add operator with %.7g output scale: scale must be finite, normalized, and positive",
+ sum_scale);
+ goto error;
+ }
+
+ if (sum_min >= sum_max) {
+ xnn_log_error(
+ "failed to create Add operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
+ sum_min, sum_max);
+ goto error;
+ }
+
+ status = xnn_status_unsupported_parameter;
+
+ const float a_output_scale = a_scale / sum_scale;
+ if (a_output_scale < 0x1.0p-14f || a_output_scale >= 0x1.0p+8f) {
+ xnn_log_error(
+ "failed to create Add operator with %.7g A-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
+ a_output_scale);
+ goto error;
+ }
+
+ const float b_output_scale = b_scale / sum_scale;
+ if (b_output_scale < 0x1.0p-14f || b_output_scale >= 0x1.0p+8f) {
+ xnn_log_error(
+ "failed to create Add operator with %.7g A-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
+ b_output_scale);
+ goto error;
+ }
+
+ status = xnn_status_out_of_memory;
+
+ add_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+ if (add_op == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for Add operator descriptor", sizeof(struct xnn_operator));
+ goto error;
+ }
+
+ add_op->channels = channels;
+ add_op->input_pixel_stride = a_stride;
+ add_op->input2_pixel_stride = b_stride;
+ add_op->output_pixel_stride = sum_stride;
+ add_op->q8_add_params =
+ xnn_compute_q8_add_params(
+ a_zero_point, b_zero_point, sum_zero_point,
+ a_scale / sum_scale, b_scale / sum_scale,
+ sum_min, sum_max);
+
+ add_op->type = xnn_operator_type_add_q8;
+ add_op->ukernel.type = xnn_ukernel_type_add;
+
+ add_op->state = xnn_run_state_invalid;
+
+ *add_op_out = add_op;
+ return xnn_status_success;
+
+error:
+ xnn_delete_operator(add_op);
+ return status;
+}
+
+enum xnn_status xnn_create_add_nc_f32(
+ size_t channels,
+ size_t a_stride,
+ size_t b_stride,
+ size_t sum_stride,
+ float sum_min,
+ float sum_max,
+ uint32_t flags,
+ xnn_operator_t* add_op_out)
+{
+ xnn_operator_t add_op = NULL;
+ enum xnn_status status = xnn_status_uninitialized;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to create Add operator: XNNPACK is not initialized");
+ goto error;
+ }
+
+ status = xnn_status_invalid_parameter;
+
+ if (channels == 0) {
+ xnn_log_error(
+ "failed to create add operator with %zu channels: number of channels must be non-zero", channels);
+ goto error;
+ }
+
+ if (a_stride < channels) {
+ xnn_log_error(
+ "failed to create Add operator with A element stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ a_stride, channels);
+ goto error;
+ }
+
+ if (b_stride < channels) {
+ xnn_log_error(
+ "failed to create Add operator with B element stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ b_stride, channels);
+ goto error;
+ }
+
+ if (sum_stride < channels) {
+ xnn_log_error(
+ "failed to create Add operator with Sum element stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ sum_stride, channels);
+ goto error;
+ }
+
+ if (isnan(sum_min)) {
+ xnn_log_error(
+ "failed to create Add operator with NaN output lower bound: lower bound must be non-NaN");
+ goto error;
+ }
+
+ if (isnan(sum_max)) {
+ xnn_log_error(
+ "failed to create Add operator with NaN output upper bound: upper bound must be non-NaN");
+ goto error;
+ }
+
+ if (sum_min >= sum_max) {
+ xnn_log_error(
+ "failed to create Add operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
+ sum_min, sum_max);
+ goto error;
+ }
+
+ status = xnn_status_out_of_memory;
+
+ add_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+ if (add_op == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for Add operator descriptor", sizeof(struct xnn_operator));
+ goto error;
+ }
+
+ add_op->channels = channels;
+ add_op->input_pixel_stride = a_stride;
+ add_op->input2_pixel_stride = b_stride;
+ add_op->output_pixel_stride = sum_stride;
+ add_op->f32_output_params = xnn_compute_f32_output_params(sum_min, sum_max);
+
+ add_op->type = xnn_operator_type_add_f32;
+ add_op->ukernel.type = xnn_ukernel_type_add;
+
+ add_op->state = xnn_run_state_invalid;
+
+ *add_op_out = add_op;
+ return xnn_status_success;
+
+error:
+ xnn_delete_operator(add_op);
+ return status;
+}
+
+enum xnn_status xnn_setup_add_nc_q8(
+ xnn_operator_t add_op,
+ size_t batch_size,
+ const uint8_t* a,
+ const uint8_t* b,
+ uint8_t* sum,
+ pthreadpool_t threadpool)
+{
+ if (add_op->type != xnn_operator_type_add_q8) {
+ xnn_log_error("failed to setup Add (Q8) operator: operator type mismatch");
+ return xnn_status_invalid_parameter;
+ }
+ add_op->state = xnn_run_state_invalid;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to setup Add operator: XNNPACK is not initialized");
+ return xnn_status_uninitialized;
+ }
+
+ if (batch_size == 0) {
+ add_op->state = xnn_run_state_skip;
+ return xnn_status_success;
+ }
+
+ const size_t channels = add_op->channels;
+ const size_t a_stride = add_op->input_pixel_stride;
+ const size_t b_stride = add_op->input2_pixel_stride;
+ const size_t sum_stride = add_op->output_pixel_stride;
+ if ((((a_stride ^ channels) | (b_stride ^ channels) | (sum_stride ^ channels)) == 0) || batch_size == 1) {
+ const size_t block_size = 4096;
+ add_op->context.add_contiguous = (struct add_contiguous_context) {
+ .a = a,
+ .b = b,
+ .y = sum,
+ .params.q8 = add_op->q8_add_params,
+ .ukernel = xnn_params.q8.vadd,
+ };
+ add_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+ add_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_add_contiguous;
+ add_op->compute.range[0] = batch_size * channels * sizeof(uint8_t);
+ add_op->compute.tile[0] = block_size;
+ } else {
+ add_op->context.add_strided = (struct add_strided_context) {
+ .a = a,
+ .a_stride = a_stride * sizeof(uint8_t),
+ .b = b,
+ .b_stride = b_stride * sizeof(uint8_t),
+ .y = sum,
+ .y_stride = sum_stride * sizeof(uint8_t),
+ .n = channels,
+ .params.q8 = add_op->q8_add_params,
+ .ukernel = xnn_params.q8.vadd,
+ };
+ add_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+ add_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_add_strided;
+ add_op->compute.range[0] = batch_size;
+ add_op->compute.tile[0] = 1;
+ }
+ add_op->state = xnn_run_state_ready;
+
+ return xnn_status_success;
+}
+
+enum xnn_status xnn_setup_add_nc_f32(
+ xnn_operator_t add_op,
+ size_t batch_size,
+ const float* a,
+ const float* b,
+ float* sum,
+ pthreadpool_t threadpool)
+{
+ if (add_op->type != xnn_operator_type_add_f32) {
+ xnn_log_error("failed to setup Add (F32) operator: operator type mismatch");
+ return xnn_status_invalid_parameter;
+ }
+ add_op->state = xnn_run_state_invalid;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to setup Add operator: XNNPACK is not initialized");
+ return xnn_status_uninitialized;
+ }
+
+ if (batch_size == 0) {
+ add_op->state = xnn_run_state_skip;
+ return xnn_status_success;
+ }
+
+ const size_t channels = add_op->channels;
+ const size_t a_stride = add_op->input_pixel_stride;
+ const size_t b_stride = add_op->input2_pixel_stride;
+ const size_t sum_stride = add_op->output_pixel_stride;
+ if ((((a_stride ^ channels) | (b_stride ^ channels) | (sum_stride ^ channels)) == 0) || batch_size == 1) {
+ const size_t block_size = 4096;
+ add_op->context.add_contiguous = (struct add_contiguous_context) {
+ .a = a,
+ .b = b,
+ .y = sum,
+ .params.f32 = add_op->f32_output_params,
+ .ukernel = xnn_params.f32.vadd,
+ };
+ add_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+ add_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_add_contiguous;
+ add_op->compute.range[0] = batch_size * channels * sizeof(float);
+ add_op->compute.tile[0] = block_size;
+ } else {
+ add_op->context.add_strided = (struct add_strided_context) {
+ .a = a,
+ .a_stride = a_stride * sizeof(float),
+ .b = b,
+ .b_stride = b_stride * sizeof(float),
+ .y = sum,
+ .y_stride = sum_stride * sizeof(float),
+ .n = channels * sizeof(float),
+ .params.f32 = add_op->f32_output_params,
+ .ukernel = xnn_params.f32.vadd,
+ };
+ add_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+ add_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_add_strided;
+ add_op->compute.range[0] = batch_size;
+ add_op->compute.tile[0] = 1;
+ }
+ add_op->state = xnn_run_state_ready;
+
+ return xnn_status_success;
+}
diff --git a/src/argmax-pooling.c b/src/argmax-pooling.c
new file mode 100644
index 0000000..9d2191a
--- /dev/null
+++ b/src/argmax-pooling.c
@@ -0,0 +1,292 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <math.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/log.h>
+#include <xnnpack/common.h>
+#include <xnnpack/math.h>
+#include <xnnpack/params.h>
+#include <xnnpack/indirection.h>
+
+
// Number of pooling windows that fit in `padded_input_dimension` pixels
// when the window of `kernel_dimension` pixels advances by its own size
// (non-overlapping pooling); any remainder pixels are dropped.
static inline size_t compute_output_dimension(
    size_t padded_input_dimension,
    size_t kernel_dimension)
{
  const size_t whole_windows = padded_input_dimension / kernel_dimension;
  return whole_windows;
}
+
+static const struct argmaxpool_parameters* select_ukernel(
+ size_t pooling_size,
+ const struct argmaxpool_parameters* ukernel)
+{
+ while (ukernel->qr == 0 && ukernel->mr < pooling_size) {
+ ukernel++;
+ }
+ return ukernel;
+}
+
+enum xnn_status xnn_create_argmax_pooling2d_nhwc_f32(
+ uint32_t input_padding_top,
+ uint32_t input_padding_right,
+ uint32_t input_padding_bottom,
+ uint32_t input_padding_left,
+ uint32_t pooling_height,
+ uint32_t pooling_width,
+ size_t channels,
+ size_t input_pixel_stride,
+ size_t output_pixel_stride,
+ float output_min,
+ float output_max,
+ uint32_t flags,
+ xnn_operator_t* argmax_pooling_op_out)
+{
+ xnn_operator_t argmax_pooling_op = NULL;
+ enum xnn_status status = xnn_status_uninitialized;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to create ArgMax Pooling operator: XNNPACK is not initialized");
+ goto error;
+ }
+
+ status = xnn_status_invalid_parameter;
+
+ const uint32_t pooling_size = pooling_height * pooling_width;
+ if (pooling_size == 0) {
+ xnn_log_error(
+ "failed to create Argmax Pooling operator with %" PRIu32 "x%" PRIu32 " pooling size: "
+ "pooling size dimensions must be non-zero",
+ pooling_width, pooling_height);
+ goto error;
+ }
+
+ if (pooling_size == 1) {
+ xnn_log_error(
+ "failed to create Argmax Pooling operator with 1 pooling element: "
+ "1x1 pooling is meaningless");
+ goto error;
+ }
+
+ if (channels == 0) {
+ xnn_log_error(
+ "failed to create Argmax Pooling operator with %zu channels: "
+ "number of channels must be non-zero",
+ channels);
+ goto error;
+ }
+
+ if (input_pixel_stride < channels) {
+ xnn_log_error(
+ "failed to create Argmax Pooling operator with input pixel stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ input_pixel_stride, channels);
+ goto error;
+ }
+
+ if (output_pixel_stride < channels) {
+ xnn_log_error(
+ "failed to create Argmax Pooling operator with output pixel stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ output_pixel_stride, channels);
+ goto error;
+ }
+
+ if (isnan(output_min)) {
+ xnn_log_error(
+ "failed to create Argmax Pooling operator with NaN output lower bound: "
+ "lower bound must be non-NaN");
+ goto error;
+ }
+
+ if (isnan(output_max)) {
+ xnn_log_error(
+ "failed to create Argmax Pooling operator with NaN output upper bound: "
+ "upper bound must be non-NaN");
+ goto error;
+ }
+
+ if (output_min >= output_max) {
+ xnn_log_error(
+ "failed to create Argmax Pooling operator with [%.7g, %.7g] output range: "
+ "lower bound must be below upper bound",
+ output_min, output_max);
+ goto error;
+ }
+
+ status = xnn_status_out_of_memory;
+
+ argmax_pooling_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+ if (argmax_pooling_op == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for Argmax Pooling operator descriptor", sizeof(struct xnn_operator));
+ goto error;
+ }
+
+ argmax_pooling_op->padding_top = input_padding_top;
+ argmax_pooling_op->padding_right = input_padding_right;
+ argmax_pooling_op->padding_bottom = input_padding_bottom;
+ argmax_pooling_op->padding_left = input_padding_left;
+
+ argmax_pooling_op->kernel_height = pooling_height;
+ argmax_pooling_op->kernel_width = pooling_width;
+ argmax_pooling_op->stride_height = pooling_height;
+ argmax_pooling_op->stride_width = pooling_width;
+ argmax_pooling_op->dilation_height = 1;
+ argmax_pooling_op->dilation_width = 1;
+ argmax_pooling_op->channels = channels;
+ argmax_pooling_op->input_pixel_stride = input_pixel_stride;
+ argmax_pooling_op->output_pixel_stride = output_pixel_stride;
+
+ argmax_pooling_op->f32_output_params = xnn_compute_f32_output_params(output_min, output_max);
+
+ argmax_pooling_op->type = xnn_operator_type_argmax_pooling_f32;
+ argmax_pooling_op->ukernel.type = xnn_ukernel_type_argmax_pooling;
+
+ argmax_pooling_op->state = xnn_run_state_invalid;
+
+ *argmax_pooling_op_out = argmax_pooling_op;
+ return xnn_status_success;
+
+error:
+ xnn_delete_operator(argmax_pooling_op);
+ return status;
+}
+
+enum xnn_status xnn_setup_argmax_pooling2d_nhwc_f32(
+ xnn_operator_t argmax_pooling_op,
+ size_t batch_size,
+ size_t input_height,
+ size_t input_width,
+ const float* input,
+ float* output,
+ uint32_t* index,
+ pthreadpool_t threadpool)
+{
+ if (argmax_pooling_op->type != xnn_operator_type_argmax_pooling_f32) {
+ xnn_log_error("failed to setup ArgMax Pooling (F32) operator: operator type mismatch");
+ return xnn_status_invalid_parameter;
+ }
+ argmax_pooling_op->state = xnn_run_state_invalid;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to setup ArgMax Pooling operator: XNNPACK is not initialized");
+ return xnn_status_uninitialized;
+ }
+
+ if (input_width == 0 || input_height == 0) {
+ xnn_log_error(
+ "failed to setup ArgMax Pooling operator with %zux%zu input: input dimensions must be non-zero",
+ input_width, input_height);
+ return xnn_status_invalid_parameter;
+ }
+
+ if (batch_size == 0) {
+ argmax_pooling_op->state = xnn_run_state_skip;
+ return xnn_status_success;
+ }
+
+ argmax_pooling_op->batch_size = batch_size;
+ argmax_pooling_op->input_height = input_height;
+ argmax_pooling_op->input_width = input_width;
+ argmax_pooling_op->input = input;
+
+ argmax_pooling_op->output_height = compute_output_dimension(
+ argmax_pooling_op->padding_top + input_height + argmax_pooling_op->padding_bottom,
+ argmax_pooling_op->kernel_height);
+ argmax_pooling_op->output_width = compute_output_dimension(
+ argmax_pooling_op->padding_left + input_width + argmax_pooling_op->padding_right,
+ argmax_pooling_op->kernel_width);
+ argmax_pooling_op->output = output;
+
+ size_t valid_batch_size = 0;
+ if (input == argmax_pooling_op->last_input &&
+ input_height == argmax_pooling_op->last_input_height &&
+ input_width == argmax_pooling_op->last_input_width)
+ {
+ valid_batch_size = argmax_pooling_op->valid_batch_size;
+ if (batch_size <= valid_batch_size) {
+ argmax_pooling_op->compute.range[0] = batch_size;
+ argmax_pooling_op->state = xnn_run_state_ready;
+ return xnn_status_success;
+ }
+ }
+
+ const size_t pooling_height = argmax_pooling_op->kernel_height;
+ const size_t pooling_width = argmax_pooling_op->kernel_width;
+ const size_t pooling_size = pooling_height * pooling_width;
+ const size_t output_height = argmax_pooling_op->output_height;
+ const size_t output_width = argmax_pooling_op->output_width;
+ const struct argmaxpool_parameters* ukernel = select_ukernel(pooling_size, xnn_params.f32.argmaxpool);
+ const uint32_t mr = ukernel->mr;
+
+ const size_t step_width = pooling_width;
+ const size_t step_height = pooling_size + (output_width * step_width - 1) * pooling_height;
+ // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
+ const size_t indirection_buffer_size = sizeof(void*) * ((mr - 1) + batch_size * output_height * step_height);
+
+ const void** indirection_buffer = (const void**) realloc(argmax_pooling_op->indirection_buffer, indirection_buffer_size);
+ if (indirection_buffer == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
+ return xnn_status_out_of_memory;
+ }
+ argmax_pooling_op->indirection_buffer = indirection_buffer;
+
+ xnn_indirection_init_maxpool2d(argmax_pooling_op, valid_batch_size, step_height, step_width, 2 /* log2(sizeof(float)) */);
+
+ const size_t channels = argmax_pooling_op->channels;
+
+ const size_t indirect_input_height_stride = step_height * sizeof(void*);
+ const size_t output_width_stride = argmax_pooling_op->output_pixel_stride * sizeof(float);
+ const size_t output_height_stride = output_width * output_width_stride;
+ const size_t index_height_stride = output_width * channels * sizeof(uint32_t);
+
+ const uint32_t qr = ukernel->qr;
+ const size_t multipass_adjustment = qr == 0 ? 0 : round_up(pooling_size - mr, qr) + mr - qr;
+ argmax_pooling_op->context.argmax_pooling = (struct argmax_pooling_context) {
+ .indirect_input = indirection_buffer,
+ .indirect_input_batch_stride = output_height * indirect_input_height_stride,
+ .indirect_input_height_stride = indirect_input_height_stride,
+ .output = output,
+ .output_batch_stride = output_height * output_height_stride,
+ .output_height_stride = output_height_stride,
+ .output_width = output_width,
+ .index = index,
+ .index_batch_stride = output_height * index_height_stride,
+ .index_height_stride = index_height_stride,
+ .pooling_size = pooling_size,
+ .channels = channels,
+ .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
+ .output_increment = output_width_stride - channels * sizeof(float),
+ .params.f32 = argmax_pooling_op->f32_output_params,
+ };
+ argmax_pooling_op->compute.type = xnn_parallelization_type_2d;
+ argmax_pooling_op->compute.range[0] = batch_size;
+ argmax_pooling_op->compute.range[1] = output_height;
+
+ if (pooling_size <= mr) {
+ argmax_pooling_op->context.argmax_pooling.unipass_ukernel = ukernel->up;
+ argmax_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_argmax_pooling_unipass;
+ } else {
+ argmax_pooling_op->context.argmax_pooling.multipass_ukernel = ukernel->mp;
+ argmax_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_argmax_pooling_multipass;
+ }
+ argmax_pooling_op->state = xnn_run_state_ready;
+
+ argmax_pooling_op->last_input = input;
+ argmax_pooling_op->last_input_height = input_height;
+ argmax_pooling_op->last_input_width = input_width;
+ argmax_pooling_op->valid_batch_size = max(valid_batch_size, batch_size);
+
+ return xnn_status_success;
+}
diff --git a/src/average-pooling.c b/src/average-pooling.c
new file mode 100644
index 0000000..1869b53
--- /dev/null
+++ b/src/average-pooling.c
@@ -0,0 +1,679 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <math.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/log.h>
+#include <xnnpack/common.h>
+#include <xnnpack/math.h>
+#include <xnnpack/params.h>
+#include <xnnpack/indirection.h>
+
+
// Output size of a pooling window of `pooling_dimension` pixels advanced by
// `stride_dimension` over an input of `padded_input_dimension` pixels: one
// window at offset zero plus one per full stride of the remaining pixels.
static inline size_t compute_output_dimension(
    size_t padded_input_dimension,
    size_t pooling_dimension,
    size_t stride_dimension)
{
  return 1 + (padded_input_dimension - pooling_dimension) / stride_dimension;
}
+
+enum xnn_status xnn_create_average_pooling2d_nhwc_q8(
+ uint32_t input_padding_top,
+ uint32_t input_padding_right,
+ uint32_t input_padding_bottom,
+ uint32_t input_padding_left,
+ uint32_t pooling_height,
+ uint32_t pooling_width,
+ uint32_t stride_height,
+ uint32_t stride_width,
+ size_t channels,
+ size_t input_pixel_stride,
+ size_t output_pixel_stride,
+ uint8_t input_zero_point,
+ float input_scale,
+ uint8_t output_zero_point,
+ float output_scale,
+ uint8_t output_min,
+ uint8_t output_max,
+ uint32_t flags,
+ xnn_operator_t* average_pooling_op_out)
+{
+ xnn_operator_t average_pooling_op = NULL;
+ enum xnn_status status = xnn_status_uninitialized;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to create Average Pooling operator: XNNPACK is not initialized");
+ goto error;
+ }
+
+ status = xnn_status_invalid_parameter;
+
+ const uint32_t pooling_size = pooling_height * pooling_width;
+ if (pooling_size == 0) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with %" PRIu32 "x%" PRIu32 " pooling size: "
+ "pooling size dimensions must be non-zero",
+ pooling_width, pooling_height);
+ goto error;
+ }
+
+ if (pooling_size == 1) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with 1 pooling element: 1x1 pooling is meaningless");
+ goto error;
+ }
+
+ if (stride_height == 0 || stride_width == 0) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with %" PRIu32 "x%" PRIu32 " stride: "
+ "stride dimensions must be non-zero",
+ stride_width, stride_height);
+ goto error;
+ }
+
+ if (channels == 0) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with %zu channels: number of channels must be non-zero",
+ channels);
+ goto error;
+ }
+
+ if (input_pixel_stride < channels) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with input pixel stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ input_pixel_stride, channels);
+ goto error;
+ }
+
+ if (output_pixel_stride < channels) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with output pixel stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ output_pixel_stride, channels);
+ goto error;
+ }
+
+ if (input_scale <= 0.0f || !isnormal(input_scale)) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with %.7g input scale: "
+ "scale must be finite, normalized, and positive",
+ input_scale);
+ goto error;
+ }
+
+ if (output_scale <= 0.0f || !isnormal(output_scale)) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with %.7g output scale: "
+ "scale must be finite, normalized, and positive",
+ output_scale);
+ goto error;
+ }
+
+ if (output_min >= output_max) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with [%" PRIu8 ", %" PRIu8 "] output range: "
+ "range min must be below range max",
+ output_min, output_max);
+ goto error;
+ }
+
+ status = xnn_status_unsupported_parameter;
+
+ const float input_output_scale = input_scale / output_scale;
+ if (input_output_scale < 0x1.0p-8f || input_output_scale >= 0x1.0p+8f) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with %.7g input scale and %.7g output scale: "
+ "input-to-output scale ratio (%.7f) must be in [2**-8, 2**8) range",
+ input_scale, output_scale, input_output_scale);
+ goto error;
+ }
+
+ if (pooling_size >= 16777216) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with %"PRIu32" (%" PRIu32 "x%" PRIu32 ") pooling elements: "
+ "the number of elements in the pooling area must be below 2**24",
+ pooling_size, pooling_width, pooling_height);
+ goto error;
+ }
+
+ status = xnn_status_out_of_memory;
+
+ average_pooling_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+ if (average_pooling_op == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for Average Pooling operator descriptor", sizeof(struct xnn_operator));
+ goto error;
+ }
+
+ const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
+ const uint32_t mr = xnn_params.q8.avgpool.mr;
+ const uint32_t qr = xnn_params.q8.avgpool.qr;
+ if (any_padding || pooling_size < mr || (pooling_size - mr) % qr != 0) {
+ void* zero_buffer = xnn_allocate_memory(channels * sizeof(uint8_t) + XNN_EXTRA_BYTES);
+ if (zero_buffer == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for Average Pooling zero padding",
+ channels * sizeof(uint8_t) + XNN_EXTRA_BYTES);
+ goto error;
+ }
+ memset(zero_buffer, input_zero_point, channels * sizeof(uint8_t));
+ average_pooling_op->zero_buffer = zero_buffer;
+ }
+
+ average_pooling_op->padding_top = input_padding_top;
+ average_pooling_op->padding_right = input_padding_right;
+ average_pooling_op->padding_bottom = input_padding_bottom;
+ average_pooling_op->padding_left = input_padding_left;
+
+ average_pooling_op->kernel_height = pooling_height;
+ average_pooling_op->kernel_width = pooling_width;
+ average_pooling_op->stride_height = stride_height;
+ average_pooling_op->stride_width = stride_width;
+ average_pooling_op->dilation_height = 1;
+ average_pooling_op->dilation_width = 1;
+ average_pooling_op->channels = channels;
+ average_pooling_op->input_pixel_stride = input_pixel_stride;
+ average_pooling_op->output_pixel_stride = output_pixel_stride;
+
+ // Number of rows read in the micro-kernel.
+ const size_t nrows = round_up(doz(pooling_size, mr), qr) + mr;
+ average_pooling_op->q8_avgpool_params =
+ xnn_compute_q8_avgpool_params(
+ (int32_t) -((uint32_t) input_zero_point * (uint32_t) nrows),
+ input_scale / (output_scale * (float) pooling_size),
+ output_zero_point, output_min, output_max);
+
+ average_pooling_op->type = xnn_operator_type_average_pooling_q8;
+ average_pooling_op->ukernel.type = xnn_ukernel_type_average_pooling;
+
+ *average_pooling_op_out = average_pooling_op;
+ return xnn_status_success;
+
+error:
+ xnn_delete_operator(average_pooling_op);
+ return status;
+}
+
+enum xnn_status xnn_create_average_pooling2d_nhwc_f32(
+ uint32_t input_padding_top,
+ uint32_t input_padding_right,
+ uint32_t input_padding_bottom,
+ uint32_t input_padding_left,
+ uint32_t pooling_height,
+ uint32_t pooling_width,
+ uint32_t stride_height,
+ uint32_t stride_width,
+ size_t channels,
+ size_t input_pixel_stride,
+ size_t output_pixel_stride,
+ float output_min,
+ float output_max,
+ uint32_t flags,
+ xnn_operator_t* average_pooling_op_out)
+{
+ xnn_operator_t average_pooling_op = NULL;
+ enum xnn_status status = xnn_status_uninitialized;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to create Average Pooling operator: XNNPACK is not initialized");
+ goto error;
+ }
+
+ status = xnn_status_invalid_parameter;
+
+ const uint32_t pooling_size = pooling_height * pooling_width;
+ if (pooling_size == 0) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with %" PRIu32 "x%" PRIu32 " pooling size: "
+ "pooling size dimensions must be non-zero",
+ pooling_width, pooling_height);
+ goto error;
+ }
+
+ if (pooling_size == 1) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with 1 pooling element: 1x1 pooling is meaningless");
+ goto error;
+ }
+
+ if (stride_height == 0 || stride_width == 0) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with %" PRIu32 "x%" PRIu32 " stride: "
+ "stride dimensions must be non-zero",
+ stride_width, stride_height);
+ goto error;
+ }
+
+ if (channels == 0) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with %zu channels: number of channels must be non-zero",
+ channels);
+ goto error;
+ }
+
+ if (input_pixel_stride < channels) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with input pixel stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ input_pixel_stride, channels);
+ goto error;
+ }
+
+ if (output_pixel_stride < channels) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with output pixel stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ output_pixel_stride, channels);
+ goto error;
+ }
+
+ if (isnan(output_min)) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with NaN output lower bound: lower bound must be non-NaN");
+ goto error;
+ }
+
+ if (isnan(output_max)) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with NaN output upper bound: upper bound must be non-NaN");
+ goto error;
+ }
+
+ if (output_min >= output_max) {
+ xnn_log_error(
+ "failed to create Average Pooling operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
+ output_min, output_max);
+ goto error;
+ }
+
+ status = xnn_status_out_of_memory;
+
+ average_pooling_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+ if (average_pooling_op == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for Average Pooling operator descriptor", sizeof(struct xnn_operator));
+ goto error;
+ }
+
+ const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
+ const uint32_t mr = xnn_params.f32.avgpool.mr;
+ const uint32_t qr = xnn_params.f32.avgpool.qr;
+ if (any_padding || pooling_size < mr || (pooling_size - mr) % qr != 0) {
+ void* zero_buffer = xnn_allocate_memory(channels * sizeof(float) + XNN_EXTRA_BYTES);
+ if (zero_buffer == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for Average Pooling zero padding",
+ channels * sizeof(float) + XNN_EXTRA_BYTES);
+ goto error;
+ }
+ memset(zero_buffer, 0, channels * sizeof(float));
+ average_pooling_op->zero_buffer = zero_buffer;
+ }
+
+ average_pooling_op->padding_top = input_padding_top;
+ average_pooling_op->padding_right = input_padding_right;
+ average_pooling_op->padding_bottom = input_padding_bottom;
+ average_pooling_op->padding_left = input_padding_left;
+
+ average_pooling_op->kernel_height = pooling_height;
+ average_pooling_op->kernel_width = pooling_width;
+ average_pooling_op->stride_height = stride_height;
+ average_pooling_op->stride_width = stride_width;
+ average_pooling_op->dilation_height = 1;
+ average_pooling_op->dilation_width = 1;
+ average_pooling_op->channels = channels;
+ average_pooling_op->input_pixel_stride = input_pixel_stride;
+ average_pooling_op->output_pixel_stride = output_pixel_stride;
+
+ average_pooling_op->type = xnn_operator_type_average_pooling_f32;
+ if (any_padding) {
+ average_pooling_op->f32_output_params =
+ xnn_compute_f32_output_params(output_min, output_max);
+
+ average_pooling_op->ukernel.type = xnn_ukernel_type_pixelwise_average_pooling;
+ } else {
+ average_pooling_op->f32_avgpool_params =
+ xnn_compute_f32_avgpool_params(1.0f / (float) pooling_size, output_min, output_max);
+
+ average_pooling_op->ukernel.type = xnn_ukernel_type_average_pooling;
+ }
+
+ *average_pooling_op_out = average_pooling_op;
+ return xnn_status_success;
+
+error:
+ xnn_delete_operator(average_pooling_op);
+ return status;
+}
+
// Sets up a previously created Q8 Average Pooling operator for a concrete
// batch size and input shape: computes output dimensions, (re)allocates and
// initializes the indirection buffer, and fills in the compute context so the
// operator can be run.
//
// Note: the threadpool parameter is currently unused during setup.
enum xnn_status xnn_setup_average_pooling2d_nhwc_q8(
    xnn_operator_t average_pooling_op,
    size_t batch_size,
    size_t input_height,
    size_t input_width,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  if (average_pooling_op->type != xnn_operator_type_average_pooling_q8) {
    xnn_log_error("failed to setup Average Pooling (Q8) operator: operator type mismatch");
    return xnn_status_invalid_parameter;
  }
  // Invalidate the operator state until setup completes successfully.
  average_pooling_op->state = xnn_run_state_invalid;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to setup Average Pooling operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (input_width == 0 || input_height == 0) {
    xnn_log_error(
      "failed to setup Average Pooling operator with %zux%zu input: input dimensions must be non-zero",
      input_width, input_height);
    return xnn_status_invalid_parameter;
  }

  // An empty batch is valid: mark the operator to skip execution.
  if (batch_size == 0) {
    average_pooling_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  average_pooling_op->batch_size = batch_size;
  average_pooling_op->input_height = input_height;
  average_pooling_op->input_width = input_width;
  average_pooling_op->input = input;

  // Output dimensions are derived from the padded input extent, the kernel
  // size, and the stride.
  average_pooling_op->output_height = compute_output_dimension(
    average_pooling_op->padding_top + input_height + average_pooling_op->padding_bottom,
    average_pooling_op->kernel_height,
    average_pooling_op->stride_height);
  average_pooling_op->output_width = compute_output_dimension(
    average_pooling_op->padding_left + input_width + average_pooling_op->padding_right,
    average_pooling_op->kernel_width,
    average_pooling_op->stride_width);
  average_pooling_op->output = output;

  // If the same input pointer and dimensions were used in the previous setup,
  // the indirection buffer is already valid for the first valid_batch_size
  // batch elements and may be partially or fully reused.
  size_t valid_batch_size = 0;
  if (input == average_pooling_op->last_input &&
      input_height == average_pooling_op->last_input_height &&
      input_width == average_pooling_op->last_input_width)
  {
    valid_batch_size = average_pooling_op->valid_batch_size;
    if (batch_size <= valid_batch_size) {
      // Fully covered by the previous setup: just shrink the compute range.
      average_pooling_op->compute.range[0] = batch_size;
      average_pooling_op->state = xnn_run_state_ready;
      return xnn_status_success;
    }
  }

  const size_t pooling_height = average_pooling_op->kernel_height;
  const size_t pooling_width = average_pooling_op->kernel_width;
  const size_t pooling_size = pooling_height * pooling_width;
  const size_t output_height = average_pooling_op->output_height;
  const size_t output_width = average_pooling_op->output_width;
  // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
  const uint32_t mr = xnn_params.q8.avgpool.mr;

  // step_width/step_height describe the stride, in pointer elements, of the
  // indirection buffer per output column / output row.
  const size_t step_width = min(average_pooling_op->stride_width, pooling_width);
  const size_t step_height = pooling_size + (output_width * step_width - 1) * pooling_height;
  const size_t indirection_buffer_size = sizeof(void*) * ((mr - 1) + batch_size * output_height * step_height);

  const void** indirection_buffer = (const void**) realloc(average_pooling_op->indirection_buffer, indirection_buffer_size);
  if (indirection_buffer == NULL) {
    xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
    return xnn_status_out_of_memory;
  }
  average_pooling_op->indirection_buffer = indirection_buffer;

  // Only (re)initialize indirection entries past the still-valid prefix.
  xnn_indirection_init_dwconv2d(
    average_pooling_op, valid_batch_size, step_height, step_width, 0 /* log2(sizeof(uint8_t)) */);

  const uint32_t qr = xnn_params.q8.avgpool.qr;
  const size_t channels = average_pooling_op->channels;

  const size_t indirect_input_height_stride = step_height * sizeof(void*);
  const size_t output_width_stride = average_pooling_op->output_pixel_stride * sizeof(uint8_t);
  const size_t output_height_stride = output_width * output_width_stride;

  // For multipass kernels, the number of processed elements is rounded up to
  // mr + k*qr; the adjustment compensates the input pointer advance.
  const size_t multipass_adjustment =
    pooling_size > mr ? round_up(pooling_size - mr, qr) + mr - qr : 0;
  average_pooling_op->context.average_pooling = (struct average_pooling_context) {
    .indirect_input = indirection_buffer,
    .indirect_input_batch_stride = output_height * indirect_input_height_stride,
    .indirect_input_height_stride = indirect_input_height_stride,
    .output = output,
    .output_batch_stride = output_height * output_height_stride,
    .output_height_stride = output_height_stride,
    .output_width = output_width,
    .pooling_size = pooling_size,
    .channels = channels,
    .zero = average_pooling_op->zero_buffer,
    .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
    .output_increment = output_width_stride - channels * sizeof(uint8_t),
    .params.q8 = average_pooling_op->q8_avgpool_params,
  };
  // Parallelize over (batch element, output row).
  average_pooling_op->compute.type = xnn_parallelization_type_2d;
  average_pooling_op->compute.range[0] = batch_size;
  average_pooling_op->compute.range[1] = output_height;

  // Unipass kernel handles up to mr pooling elements in one sweep; larger
  // pooling windows require the multipass kernel.
  if (pooling_size <= mr) {
    average_pooling_op->context.average_pooling.unipass_ukernel = xnn_params.q8.avgpool.up;
    average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_unipass;
  } else {
    average_pooling_op->context.average_pooling.multipass_ukernel = xnn_params.q8.avgpool.mp;
    average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_multipass;
  }
  average_pooling_op->state = xnn_run_state_ready;

  // Remember this setup so a subsequent identical setup can reuse the
  // indirection buffer.
  average_pooling_op->last_input = input;
  average_pooling_op->last_input_height = input_height;
  average_pooling_op->last_input_width = input_width;
  average_pooling_op->valid_batch_size = max(valid_batch_size, batch_size);

  return xnn_status_success;
}
+
// Sets up a previously created F32 Average Pooling operator for a concrete
// batch size and input shape. Depending on how the operator was created, it is
// configured either for regular average pooling (no padding; uniform 1/N
// scaling baked into the parameters) or pixelwise average pooling (with
// padding; per-pixel scaling factors stored in a dedicated buffer).
//
// Note: the threadpool parameter is currently unused during setup.
enum xnn_status xnn_setup_average_pooling2d_nhwc_f32(
    xnn_operator_t average_pooling_op,
    size_t batch_size,
    size_t input_height,
    size_t input_width,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (average_pooling_op->type != xnn_operator_type_average_pooling_f32) {
    xnn_log_error("failed to setup Average Pooling (F32) operator: operator type mismatch");
    return xnn_status_invalid_parameter;
  }
  // Invalidate the operator state until setup completes successfully.
  average_pooling_op->state = xnn_run_state_invalid;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to setup Average Pooling operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (input_width == 0 || input_height == 0) {
    xnn_log_error(
      "failed to setup Average Pooling operator with %zux%zu input: input dimensions must be non-zero",
      input_width, input_height);
    return xnn_status_invalid_parameter;
  }

  // An empty batch is valid: mark the operator to skip execution.
  if (batch_size == 0) {
    average_pooling_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  average_pooling_op->batch_size = batch_size;
  average_pooling_op->input_height = input_height;
  average_pooling_op->input_width = input_width;
  average_pooling_op->input = input;

  // Output dimensions are derived from the padded input extent, the kernel
  // size, and the stride.
  average_pooling_op->output_height = compute_output_dimension(
    average_pooling_op->padding_top + input_height + average_pooling_op->padding_bottom,
    average_pooling_op->kernel_height,
    average_pooling_op->stride_height);
  average_pooling_op->output_width = compute_output_dimension(
    average_pooling_op->padding_left + input_width + average_pooling_op->padding_right,
    average_pooling_op->kernel_width,
    average_pooling_op->stride_width);
  average_pooling_op->output = output;

  // If the same input pointer and dimensions were used in the previous setup,
  // the indirection buffer is already valid for the first valid_batch_size
  // batch elements and may be partially or fully reused.
  size_t valid_batch_size = 0;
  if (input == average_pooling_op->last_input &&
      input_height == average_pooling_op->last_input_height &&
      input_width == average_pooling_op->last_input_width)
  {
    valid_batch_size = average_pooling_op->valid_batch_size;
    if (batch_size <= valid_batch_size) {
      // Fully covered by the previous setup: just shrink the compute range.
      average_pooling_op->compute.range[0] = batch_size;
      average_pooling_op->state = xnn_run_state_ready;
      return xnn_status_success;
    }
  }

  const size_t pooling_height = average_pooling_op->kernel_height;
  const size_t pooling_width = average_pooling_op->kernel_width;
  const size_t pooling_size = pooling_height * pooling_width;
  const size_t output_height = average_pooling_op->output_height;
  const size_t output_width = average_pooling_op->output_width;
  // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
  const uint32_t mr = xnn_params.f32.avgpool.mr;
  // The sizing below assumes both kernel families share the same mr.
  assert(mr == xnn_params.f32.pavgpool.mr);

  // step_width/step_height describe the stride, in pointer elements, of the
  // indirection buffer per output column / output row.
  const size_t step_width = min(average_pooling_op->stride_width, pooling_width);
  const size_t step_height = pooling_size + (output_width * step_width - 1) * pooling_height;
  const size_t indirection_buffer_size = sizeof(void*) * ((mr - 1) + batch_size * output_height * step_height);

  const void** indirection_buffer = (const void**) realloc(average_pooling_op->indirection_buffer, indirection_buffer_size);
  if (indirection_buffer == NULL) {
    xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
    return xnn_status_out_of_memory;
  }
  average_pooling_op->indirection_buffer = indirection_buffer;

  // Only (re)initialize indirection entries past the still-valid prefix.
  xnn_indirection_init_dwconv2d(
    average_pooling_op, valid_batch_size, step_height, step_width, 2 /* log2(sizeof(float)) */);

  const size_t channels = average_pooling_op->channels;

  const size_t indirect_input_height_stride = step_height * sizeof(void*);
  const size_t output_width_stride = average_pooling_op->output_pixel_stride * sizeof(float);
  const size_t output_height_stride = output_width * output_width_stride;

  switch (average_pooling_op->ukernel.type) {
    case xnn_ukernel_type_average_pooling:
    {
      // Unpadded case: every output pixel averages exactly pooling_size
      // inputs, so a single precomputed 1/N multiplier suffices.
      const uint32_t qr = xnn_params.f32.avgpool.qr;
      // For multipass kernels, the number of processed elements is rounded up
      // to mr + k*qr; the adjustment compensates the input pointer advance.
      const size_t multipass_adjustment =
        pooling_size > mr ? round_up(pooling_size - mr, qr) + mr - qr : 0;
      average_pooling_op->context.average_pooling = (struct average_pooling_context) {
        .indirect_input = indirection_buffer,
        .indirect_input_batch_stride = output_height * indirect_input_height_stride,
        .indirect_input_height_stride = indirect_input_height_stride,
        .output = output,
        .output_batch_stride = output_height * output_height_stride,
        .output_height_stride = output_height_stride,
        .output_width = output_width,
        .pooling_size = pooling_size,
        .channels = channels,
        .zero = average_pooling_op->zero_buffer,
        .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
        .output_increment = output_width_stride - channels * sizeof(float),
        .params.f32 = average_pooling_op->f32_avgpool_params,
      };
      if (pooling_size <= mr) {
        average_pooling_op->context.average_pooling.unipass_ukernel = xnn_params.f32.avgpool.up;
        average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_unipass;
      } else {
        average_pooling_op->context.average_pooling.multipass_ukernel = xnn_params.f32.avgpool.mp;
        average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_average_pooling_multipass;
      }
      break;
    }
    case xnn_ukernel_type_pixelwise_average_pooling:
    {
      // Padded case: boundary output pixels average fewer than pooling_size
      // inputs, so a per-output-pixel scale (1 / count of valid inputs) is
      // precomputed into pixelwise_buffer. It only depends on the spatial
      // dimensions, so it can be reused when those are unchanged.
      if (input_height != average_pooling_op->last_input_height ||
          input_width != average_pooling_op->last_input_width)
      {
        const size_t pixelwise_buffer_size = output_height * output_width * sizeof(float);
        float* pixelwise_buffer = (float*) realloc(average_pooling_op->pixelwise_buffer, pixelwise_buffer_size);
        if (pixelwise_buffer == NULL) {
          xnn_log_error("failed to allocate %zu bytes for pixelwise buffer", pixelwise_buffer_size);
          return xnn_status_out_of_memory;
        }
        average_pooling_op->pixelwise_buffer = pixelwise_buffer;

        float* pixelwise_pointer = pixelwise_buffer;
        for (size_t output_y = 0; output_y < output_height; output_y++) {
          // doz (difference-or-zero) clips the window start at the input edge.
          const size_t input_y_start = doz(output_y * average_pooling_op->stride_height, average_pooling_op->padding_top);
          const size_t input_y_end =
            min(doz(output_y * average_pooling_op->stride_height + average_pooling_op->kernel_height, average_pooling_op->padding_top), input_height);
          const uint32_t input_y_range = (uint32_t) (input_y_end - input_y_start);
          for (size_t output_x = 0; output_x < output_width; output_x++) {
            const size_t input_x_start = doz(output_x * average_pooling_op->stride_width, average_pooling_op->padding_left);
            const size_t input_x_end =
              min(doz(output_x * average_pooling_op->stride_width + average_pooling_op->kernel_width, average_pooling_op->padding_left), input_width);
            const uint32_t input_x_range = (uint32_t) (input_x_end - input_x_start);
            *pixelwise_pointer++ = 1.0f / ((float) (int32_t) (input_y_range * input_x_range));
          }
        }
      }

      const uint32_t qr = xnn_params.f32.pavgpool.qr;
      // See the regular-pooling branch for the rationale of this adjustment.
      const size_t multipass_adjustment =
        pooling_size > mr ? round_up(pooling_size - mr, qr) + mr - qr : 0;
      average_pooling_op->context.pixelwise_average_pooling = (struct pixelwise_average_pooling_context) {
        .indirect_input = indirection_buffer,
        .indirect_input_batch_stride = output_height * indirect_input_height_stride,
        .indirect_input_height_stride = indirect_input_height_stride,
        .pixelwise_buffer = average_pooling_op->pixelwise_buffer,
        .pixelwise_buffer_height_stride = output_width * sizeof(float),
        .output = output,
        .output_batch_stride = output_height * output_height_stride,
        .output_height_stride = output_height_stride,
        .output_width = output_width,
        .pooling_size = pooling_size,
        .channels = channels,
        .zero = average_pooling_op->zero_buffer,
        .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
        .output_increment = output_width_stride - channels * sizeof(float),
        .params.f32 = average_pooling_op->f32_output_params,
      };
      if (pooling_size <= mr) {
        average_pooling_op->context.pixelwise_average_pooling.unipass_ukernel = xnn_params.f32.pavgpool.up;
        average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_pixelwise_average_pooling_unipass;
      } else {
        average_pooling_op->context.pixelwise_average_pooling.multipass_ukernel = xnn_params.f32.pavgpool.mp;
        average_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_pixelwise_average_pooling_multipass;
      }
      break;
    }
    default:
      XNN_UNREACHABLE;
  }
  // Parallelize over (batch element, output row).
  average_pooling_op->compute.type = xnn_parallelization_type_2d;
  average_pooling_op->compute.range[0] = batch_size;
  average_pooling_op->compute.range[1] = output_height;
  average_pooling_op->state = xnn_run_state_ready;

  // Remember this setup so a subsequent identical setup can reuse the
  // indirection buffer (and pixelwise buffer).
  average_pooling_op->last_input = input;
  average_pooling_op->last_input_height = input_height;
  average_pooling_op->last_input_width = input_width;
  average_pooling_op->valid_batch_size = max(valid_batch_size, batch_size);

  return xnn_status_success;
}
diff --git a/src/channel-pad.c b/src/channel-pad.c
new file mode 100644
index 0000000..f0380cd
--- /dev/null
+++ b/src/channel-pad.c
@@ -0,0 +1,137 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <math.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/log.h>
+#include <xnnpack/params.h>
+
+
+enum xnn_status xnn_create_channel_pad_nc_x32(
+ size_t input_channels,
+ size_t pad_before_channels,
+ size_t pad_after_channels,
+ size_t input_stride,
+ size_t output_stride,
+ const void* pad_value,
+ uint32_t flags,
+ xnn_operator_t* channel_pad_op_out)
+{
+ xnn_operator_t channel_pad_op = NULL;
+ enum xnn_status status = xnn_status_uninitialized;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to create Channel Pad operator: XNNPACK is not initialized");
+ goto error;
+ }
+
+ status = xnn_status_invalid_parameter;
+
+ if (input_channels == 0) {
+ xnn_log_error(
+ "failed to create Channel Pad operator with %zu input channels: number of channels must be non-zero",
+ input_channels);
+ goto error;
+ }
+
+ if (input_stride < input_channels) {
+ xnn_log_error(
+ "failed to create Channel Pad operator with input element stride of %zu: "
+ "stride must be at least as large as the number of input channels (%zu)",
+ input_stride, input_channels);
+ goto error;
+ }
+
+ const size_t output_channels = pad_before_channels + input_channels + pad_after_channels;
+ if (output_stride < output_channels) {
+ xnn_log_error(
+ "failed to create Channel Pad operator with output element stride of %zu: "
+ "stride must be at least as large as the number of output channels (%zu+%zu+%zu)",
+ output_stride, pad_before_channels, input_channels, pad_after_channels);
+ goto error;
+ }
+
+ status = xnn_status_out_of_memory;
+
+ channel_pad_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+ if (channel_pad_op == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for Channel Pad operator descriptor", sizeof(struct xnn_operator));
+ goto error;
+ }
+
+ channel_pad_op->channels = input_channels;
+ channel_pad_op->pad_before_channels = pad_before_channels;
+ channel_pad_op->pad_after_channels = pad_after_channels;
+ channel_pad_op->input_pixel_stride = input_stride;
+ channel_pad_op->output_pixel_stride = output_stride;
+ channel_pad_op->pad_value = *((const uint32_t*) pad_value);
+
+ channel_pad_op->type = xnn_operator_type_channel_pad_x32;
+ channel_pad_op->ukernel.type = xnn_ukernel_type_pad;
+
+ channel_pad_op->state = xnn_run_state_invalid;
+
+ *channel_pad_op_out = channel_pad_op;
+ return xnn_status_success;
+
+error:
+ xnn_delete_operator(channel_pad_op);
+ return status;
+}
+
+enum xnn_status xnn_setup_channel_pad_nc_x32(
+ xnn_operator_t channel_pad_op,
+ size_t batch_size,
+ const void* input,
+ void* output,
+ pthreadpool_t threadpool)
+{
+ if (channel_pad_op->type != xnn_operator_type_channel_pad_x32) {
+ xnn_log_error("failed to setup Channel Pad (X32) operator: operator type mismatch");
+ return xnn_status_invalid_parameter;
+ }
+ channel_pad_op->state = xnn_run_state_invalid;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to setup Channel Pad operator: XNNPACK is not initialized");
+ return xnn_status_uninitialized;
+ }
+
+ if (batch_size == 0) {
+ channel_pad_op->state = xnn_run_state_skip;
+ return xnn_status_success;
+ }
+
+ channel_pad_op->batch_size = batch_size;
+ channel_pad_op->input = input;
+ channel_pad_op->output = output;
+
+ channel_pad_op->context.channel_pad = (struct channel_pad_context) {
+ .x = input,
+ .x_stride = channel_pad_op->input_pixel_stride * sizeof(uint32_t),
+ .y = output,
+ .y_stride = channel_pad_op->output_pixel_stride * sizeof(uint32_t),
+ .n = channel_pad_op->channels * sizeof(uint32_t),
+ .l = channel_pad_op->pad_before_channels * sizeof(uint32_t),
+ .r = channel_pad_op->pad_after_channels * sizeof(uint32_t),
+ .c = channel_pad_op->pad_value,
+ .ukernel = xnn_params.x32.pad.ukernel,
+ };
+ channel_pad_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+ channel_pad_op->compute.task_1d_tile_1d =
+ (pthreadpool_task_1d_tile_1d_t) xnn_compute_channel_pad;
+ channel_pad_op->compute.range[0] = batch_size;
+ channel_pad_op->compute.tile[0] = xnn_params.x32.pad.mr;
+ channel_pad_op->state = xnn_run_state_ready;
+
+ return xnn_status_success;
+}
diff --git a/src/channel-shuffle.c b/src/channel-shuffle.c
new file mode 100644
index 0000000..c27d6d3
--- /dev/null
+++ b/src/channel-shuffle.c
@@ -0,0 +1,235 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <math.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/log.h>
+#include <xnnpack/params.h>
+
+
+static enum xnn_status create_channel_shuffle_nc(
+ size_t groups,
+ size_t group_channels,
+ size_t input_stride,
+ size_t output_stride,
+ uint32_t flags,
+ enum xnn_operator_type operator_type,
+ xnn_operator_t* channel_shuffle_op_out)
+{
+ xnn_operator_t channel_shuffle_op = NULL;
+ enum xnn_status status = xnn_status_uninitialized;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to create Channel Shuffle operator: XNNPACK is not initialized");
+ goto error;
+ }
+
+ status = xnn_status_invalid_parameter;
+
+ if (groups <= 1) {
+ xnn_log_error(
+ "failed to create Channel Shuffle operator with %zu groups: at least two groups required", groups);
+ goto error;
+ }
+
+ if (group_channels == 0) {
+ xnn_log_error(
+ "failed to create Channel Shuffle operator with %zu group channels: number of group channels must be non-zero",
+ group_channels);
+ goto error;
+ }
+
+ const size_t channels = groups * group_channels;
+ if (input_stride < channels) {
+ xnn_log_error(
+ "failed to create Channel Shuffle operator with input element stride of %zu: "
+ "stride must be at least as large as the number of channels (%zux%zu)",
+ input_stride, groups, group_channels);
+ goto error;
+ }
+
+ if (output_stride < channels) {
+ xnn_log_error(
+ "failed to create Channel Shuffle operator with output element stride of %zu: "
+ "stride must be at least as large as the number of channels (%zux%zu)",
+ output_stride, groups, group_channels);
+ goto error;
+ }
+
+ status = xnn_status_out_of_memory;
+
+ channel_shuffle_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+ if (channel_shuffle_op == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for Channel Shuffle operator descriptor", sizeof(struct xnn_operator));
+ goto error;
+ }
+
+ channel_shuffle_op->groups = groups;
+ channel_shuffle_op->group_channels = group_channels;
+ channel_shuffle_op->input_pixel_stride = input_stride;
+ channel_shuffle_op->output_pixel_stride = output_stride;
+
+ channel_shuffle_op->type = operator_type;
+ channel_shuffle_op->ukernel.type = xnn_ukernel_type_channel_shuffle;
+
+ channel_shuffle_op->state = xnn_run_state_invalid;
+
+ *channel_shuffle_op_out = channel_shuffle_op;
+ return xnn_status_success;
+
+error:
+ xnn_delete_operator(channel_shuffle_op);
+ return status;
+}
+
+
+enum xnn_status xnn_create_channel_shuffle_nc_x8(
+ size_t groups,
+ size_t group_channels,
+ size_t input_stride,
+ size_t output_stride,
+ uint32_t flags,
+ xnn_operator_t* channel_shuffle_op_out)
+{
+ return create_channel_shuffle_nc(
+ groups,
+ group_channels,
+ input_stride,
+ output_stride,
+ flags,
+ xnn_operator_type_channel_shuffle_x8,
+ channel_shuffle_op_out);
+}
+
+enum xnn_status xnn_create_channel_shuffle_nc_x32(
+ size_t groups,
+ size_t group_channels,
+ size_t input_stride,
+ size_t output_stride,
+ uint32_t flags,
+ xnn_operator_t* channel_shuffle_op_out)
+{
+ return create_channel_shuffle_nc(
+ groups,
+ group_channels,
+ input_stride,
+ output_stride,
+ flags,
+ xnn_operator_type_channel_shuffle_x32,
+ channel_shuffle_op_out);
+}
+
+static enum xnn_status setup_channel_shuffle_nc(
+ xnn_operator_t channel_shuffle_op,
+ size_t batch_size,
+ const void* input,
+ void* output,
+ uint32_t log2_element_size,
+ const struct zip_parameters zip[restrict static 1])
+{
+ channel_shuffle_op->state = xnn_run_state_invalid;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to setup Channel Shuffle operator: XNNPACK is not initialized");
+ return xnn_status_uninitialized;
+ }
+
+ if (batch_size == 0) {
+ channel_shuffle_op->state = xnn_run_state_skip;
+ return xnn_status_success;
+ }
+
+ channel_shuffle_op->batch_size = batch_size;
+ channel_shuffle_op->input = input;
+ channel_shuffle_op->output = output;
+
+ const size_t groups = channel_shuffle_op->groups;
+ channel_shuffle_op->context.channel_shuffle = (struct channel_shuffle_context) {
+ .x = input,
+ .x_stride = channel_shuffle_op->input_pixel_stride << log2_element_size,
+ .y = output,
+ .y_stride = channel_shuffle_op->output_pixel_stride << log2_element_size,
+ .n = channel_shuffle_op->group_channels << log2_element_size,
+ .m = groups,
+ };
+ channel_shuffle_op->compute.type = xnn_parallelization_type_1d;
+ channel_shuffle_op->compute.range[0] = batch_size;
+ switch (groups) {
+ case 2:
+ channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
+ channel_shuffle_op->context.channel_shuffle.fixed_ukernel = zip->x2;
+ break;
+ case 3:
+ channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
+ channel_shuffle_op->context.channel_shuffle.fixed_ukernel = zip->x3;
+ break;
+ case 4:
+ channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
+ channel_shuffle_op->context.channel_shuffle.fixed_ukernel = zip->x4;
+ break;
+ default:
+ channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_variable;
+ channel_shuffle_op->context.channel_shuffle.variable_ukernel = zip->xm;
+ break;
+ case 0:
+ case 1:
+ XNN_UNREACHABLE;
+ }
+ channel_shuffle_op->state = xnn_run_state_ready;
+
+ return xnn_status_success;
+}
+
+enum xnn_status xnn_setup_channel_shuffle_nc_x8(
+ xnn_operator_t channel_shuffle_op,
+ size_t batch_size,
+ const void* input,
+ void* output,
+ pthreadpool_t threadpool)
+{
+ if (channel_shuffle_op->type != xnn_operator_type_channel_shuffle_x8) {
+ xnn_log_error("failed to setup Channel Shuffle (X8) operator: operator type mismatch");
+ return xnn_status_invalid_parameter;
+ }
+
+ return setup_channel_shuffle_nc(
+ channel_shuffle_op,
+ batch_size,
+ input,
+ output,
+ 0 /* log2(sizeof(element)) = log2(sizeof(uint8_t)) */,
+ &xnn_params.x8.zip);
+}
+
+enum xnn_status xnn_setup_channel_shuffle_nc_x32(
+ xnn_operator_t channel_shuffle_op,
+ size_t batch_size,
+ const void* input,
+ void* output,
+ pthreadpool_t threadpool)
+{
+ if (channel_shuffle_op->type != xnn_operator_type_channel_shuffle_x32) {
+ xnn_log_error("failed to setup Channel Shuffle (X32) operator: operator type mismatch");
+ return xnn_status_invalid_parameter;
+ }
+
+ return setup_channel_shuffle_nc(
+ channel_shuffle_op,
+ batch_size,
+ input,
+ output,
+ 2 /* log2(sizeof(element)) = log2(sizeof(uint32_t)) */,
+ &xnn_params.x32.zip);
+}
diff --git a/src/clamp.c b/src/clamp.c
new file mode 100644
index 0000000..02cac7f
--- /dev/null
+++ b/src/clamp.c
@@ -0,0 +1,298 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/operator.h>
#include <xnnpack/log.h>
+
+
+enum xnn_status xnn_create_clamp_nc_u8(
+ size_t channels,
+ size_t input_stride,
+ size_t output_stride,
+ uint8_t output_min,
+ uint8_t output_max,
+ uint32_t flags,
+ xnn_operator_t* clamp_op_out)
+{
+ xnn_operator_t clamp_op = NULL;
+ enum xnn_status status = xnn_status_uninitialized;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to create Clamp operator: XNNPACK is not initialized");
+ goto error;
+ }
+
+ status = xnn_status_invalid_parameter;
+
+ if (channels == 0) {
+ xnn_log_error(
+ "failed to create Clamp operator with %zu channels: number of channels must be non-zero", channels);
+ goto error;
+ }
+
+ if (input_stride < channels) {
+ xnn_log_error(
+ "failed to create Clamp operator with input element stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ input_stride, channels);
+ goto error;
+ }
+
+ if (output_stride < channels) {
+ xnn_log_error(
+ "failed to create Clamp operator with output element stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ output_stride, channels);
+ goto error;
+ }
+
+ if (output_min >= output_max) {
+ xnn_log_error(
+ "failed to create Clamp operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
+ output_min, output_max);
+ goto error;
+ }
+
+ status = xnn_status_out_of_memory;
+
+ clamp_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+ if (clamp_op == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for Clamp operator descriptor", sizeof(struct xnn_operator));
+ goto error;
+ }
+
+ clamp_op->channels = channels;
+ clamp_op->input_pixel_stride = input_stride;
+ clamp_op->output_pixel_stride = output_stride;
+ clamp_op->u8_output_params = xnn_compute_u8_output_params(output_min, output_max);
+
+ clamp_op->type = xnn_operator_type_clamp_u8;
+ clamp_op->ukernel.type = xnn_ukernel_type_clamp;
+
+ clamp_op->state = xnn_run_state_invalid;
+
+ *clamp_op_out = clamp_op;
+ return xnn_status_success;
+
+error:
+ xnn_delete_operator(clamp_op);
+ return status;
+}
+
+// Creates a Clamp operator for F32 NC-layout data: clips every element of
+// each channel vector to the [output_min, output_max] range.
+//
+//   channels      - number of elements per batch item (must be non-zero)
+//   input_stride  - input row stride in elements (must be >= channels)
+//   output_stride - output row stride in elements (must be >= channels)
+//   output_min    - lower clamping bound (must be non-NaN)
+//   output_max    - upper clamping bound (must be non-NaN and > output_min)
+//   flags         - unused in this function
+//   clamp_op_out  - receives the created operator on success
+//
+// On failure, logs the reason, frees any partially created state, and returns
+// an error status without writing *clamp_op_out.
+enum xnn_status xnn_create_clamp_nc_f32(
+  size_t channels,
+  size_t input_stride,
+  size_t output_stride,
+  float output_min,
+  float output_max,
+  uint32_t flags,
+  xnn_operator_t* clamp_op_out)
+{
+  xnn_operator_t clamp_op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to create Clamp operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  status = xnn_status_invalid_parameter;
+
+  if (channels == 0) {
+    xnn_log_error(
+      "failed to create Clamp operator with %zu channels: number of channels must be non-zero", channels);
+    goto error;
+  }
+
+  if (input_stride < channels) {
+    xnn_log_error(
+      "failed to create Clamp operator with input element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      input_stride, channels);
+    goto error;
+  }
+
+  if (output_stride < channels) {
+    xnn_log_error(
+      "failed to create Clamp operator with output element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      output_stride, channels);
+    goto error;
+  }
+
+  if (isnan(output_min)) {
+    xnn_log_error(
+      "failed to create Clamp operator with NaN output lower bound: lower bound must be non-NaN");
+    goto error;
+  }
+
+  if (isnan(output_max)) {
+    xnn_log_error(
+      "failed to create Clamp operator with NaN output upper bound: upper bound must be non-NaN");
+    goto error;
+  }
+
+  // Note: equal bounds are also rejected; the range must be non-degenerate.
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create Clamp operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
+      output_min, output_max);
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  clamp_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+  if (clamp_op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for Clamp operator descriptor", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  clamp_op->channels = channels;
+  clamp_op->input_pixel_stride = input_stride;
+  clamp_op->output_pixel_stride = output_stride;
+  // Micro-kernel clamping parameters are precomputed once at creation time.
+  clamp_op->f32_output_params = xnn_compute_f32_output_params(output_min, output_max);
+
+  clamp_op->type = xnn_operator_type_clamp_f32;
+  clamp_op->ukernel.type = xnn_ukernel_type_clamp;
+
+  // The operator is not runnable until a successful setup call.
+  clamp_op->state = xnn_run_state_invalid;
+
+  *clamp_op_out = clamp_op;
+  return xnn_status_success;
+
+error:
+  xnn_delete_operator(clamp_op);
+  return status;
+}
+
+// Configures a previously created U8 Clamp operator for a batch of data.
+//
+// When both strides equal the channel count (rows are densely packed) or there
+// is only one batch item, the data is treated as a single contiguous array and
+// split into fixed-size blocks; otherwise each batch item is processed
+// separately with strided addressing. The actual work happens when the
+// operator is later run; this only fills in the compute plan.
+enum xnn_status xnn_setup_clamp_nc_u8(
+  xnn_operator_t clamp_op,
+  size_t batch_size,
+  const uint8_t* input,
+  uint8_t* output,
+  pthreadpool_t threadpool)
+{
+  if (clamp_op->type != xnn_operator_type_clamp_u8) {
+    xnn_log_error("failed to setup Clamp (U8) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+  clamp_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup Clamp operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  if (batch_size == 0) {
+    // Nothing to compute; mark this run as a no-op.
+    clamp_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  const size_t channels = clamp_op->channels;
+  const size_t input_stride = clamp_op->input_pixel_stride;
+  const size_t output_stride = clamp_op->output_pixel_stride;
+  // (stride ^ channels) == 0 iff stride == channels, i.e. rows are packed.
+  if ((((input_stride ^ channels) | (output_stride ^ channels)) == 0) || batch_size == 1) {
+    const size_t block_size = 4096;
+    clamp_op->context.univector_contiguous = (struct univector_contiguous_context) {
+      .x = input,
+      .x_stride = input_stride * sizeof(uint8_t),
+      .y = output,
+      .y_stride = output_stride * sizeof(uint8_t),
+      .ukernel = xnn_params.u8.clamp,
+      .params.u8_output = clamp_op->u8_output_params,
+    };
+    clamp_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+    clamp_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_contiguous;
+    // Range is the total byte count; tiles are block_size-byte chunks.
+    clamp_op->compute.range[0] = batch_size * channels * sizeof(uint8_t);
+    clamp_op->compute.tile[0] = block_size;
+  } else {
+    clamp_op->context.univector_strided = (struct univector_strided_context) {
+      .n = channels * sizeof(uint8_t),
+      .x = input,
+      .x_stride = input_stride * sizeof(uint8_t),
+      .y = output,
+      .y_stride = output_stride * sizeof(uint8_t),
+      .ukernel = xnn_params.u8.clamp,
+      .params.u8_output = clamp_op->u8_output_params,
+    };
+    clamp_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+    clamp_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_strided;
+    // Parallelize over batch items, one item per tile.
+    clamp_op->compute.range[0] = batch_size;
+    clamp_op->compute.tile[0] = 1;
+  }
+  clamp_op->state = xnn_run_state_ready;
+
+  return xnn_status_success;
+}
+
+// Configures a previously created F32 Clamp operator for a batch of data.
+// Mirrors xnn_setup_clamp_nc_u8, with float elements and F32 micro-kernel
+// parameters: a dense (or single-item) batch is processed as one contiguous
+// byte range in fixed-size blocks, otherwise one strided row per batch item.
+enum xnn_status xnn_setup_clamp_nc_f32(
+  xnn_operator_t clamp_op,
+  size_t batch_size,
+  const float* input,
+  float* output,
+  pthreadpool_t threadpool)
+{
+  if (clamp_op->type != xnn_operator_type_clamp_f32) {
+    xnn_log_error("failed to setup Clamp (F32) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+  clamp_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup Clamp operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  if (batch_size == 0) {
+    // Nothing to compute; mark this run as a no-op.
+    clamp_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  const size_t channels = clamp_op->channels;
+  const size_t input_stride = clamp_op->input_pixel_stride;
+  const size_t output_stride = clamp_op->output_pixel_stride;
+  // (stride ^ channels) == 0 iff stride == channels, i.e. rows are packed.
+  if ((((input_stride ^ channels) | (output_stride ^ channels)) == 0) || batch_size == 1) {
+    const size_t block_size = 4096;
+    clamp_op->context.univector_contiguous = (struct univector_contiguous_context) {
+      .x = input,
+      .x_stride = input_stride * sizeof(float),
+      .y = output,
+      .y_stride = output_stride * sizeof(float),
+      .ukernel = xnn_params.f32.clamp,
+      .params.f32_output = clamp_op->f32_output_params,
+    };
+    clamp_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+    clamp_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_contiguous;
+    // Range is the total byte count; tiles are block_size-byte chunks.
+    clamp_op->compute.range[0] = batch_size * channels * sizeof(float);
+    clamp_op->compute.tile[0] = block_size;
+  } else {
+    clamp_op->context.univector_strided = (struct univector_strided_context) {
+      .n = channels * sizeof(float),
+      .x = input,
+      .x_stride = input_stride * sizeof(float),
+      .y = output,
+      .y_stride = output_stride * sizeof(float),
+      .ukernel = xnn_params.f32.clamp,
+      .params.f32_output = clamp_op->f32_output_params,
+    };
+    clamp_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+    clamp_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_strided;
+    // Parallelize over batch items, one item per tile.
+    clamp_op->compute.range[0] = batch_size;
+    clamp_op->compute.tile[0] = 1;
+  }
+  clamp_op->state = xnn_run_state_ready;
+
+  return xnn_status_success;
+}
diff --git a/src/convolution-spnchw.c b/src/convolution-spnchw.c
new file mode 100644
index 0000000..572870a
--- /dev/null
+++ b/src/convolution-spnchw.c
@@ -0,0 +1,694 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <inttypes.h>
+#include <math.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/log.h>
+#include <xnnpack/common.h>
+#include <xnnpack/compute.h>
+#include <xnnpack/math.h>
+#include <xnnpack/pack.h>
+#include <xnnpack/params.h>
+#include <xnnpack/indirection.h>
+
+
+// Computes one spatial output dimension of a convolution:
+//   out = floor(max(padded_in - dilated_kernel, 0) / subsampling) + 1
+static inline size_t compute_output_dimension(
+  size_t padded_input_dimension,
+  size_t kernel_dimension,
+  size_t dilation_dimension,
+  size_t subsampling_dimension)
+{
+  // Extent of the dilated kernel: dilation inserts (dilation - 1) gaps between taps.
+  const size_t dilated_kernel_dimension = (kernel_dimension - 1) * dilation_dimension + 1;
+  // doz() is a difference-or-zero (saturating) subtraction.
+  const size_t subsampled_range = doz(padded_input_dimension, dilated_kernel_dimension) / subsampling_dimension;
+  return subsampled_range + 1;
+}
+
+// Creates an F32 Convolution operator for the sparse/NCHW (SpNCHW) path.
+// Only a small set of parameter combinations is supported (see the ukernel
+// selection logic below); for each, the weights are re-packed at creation:
+//  - spmm: 1x1 unit-stride convolution executed as sparse matrix * dense
+//    matrix; the kernel is converted to a compressed block-sparse form.
+//  - dconv2d_hwc2spchw: direct 3x3 stride-2 convolution taking NHWC input
+//    and producing CHW output; weights packed in OKI order.
+//  - dwconv: 3x3 depthwise convolution (stride 1 or 2) on CHW data.
+enum xnn_status xnn_create_convolution2d_spnchw_f32(
+  uint32_t input_padding_top,
+  uint32_t input_padding_right,
+  uint32_t input_padding_bottom,
+  uint32_t input_padding_left,
+  uint32_t kernel_height,
+  uint32_t kernel_width,
+  uint32_t subsampling_height,
+  uint32_t subsampling_width,
+  uint32_t dilation_height,
+  uint32_t dilation_width,
+  uint32_t groups,
+  size_t group_input_channels,
+  size_t group_output_channels,
+  const float* kernel,
+  const float* bias,
+  float output_min,
+  float output_max,
+  uint32_t flags,
+  xnn_operator_t* convolution_op_out)
+{
+  xnn_operator_t convolution_op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to create Convolution operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  status = xnn_status_invalid_parameter;
+
+  if (kernel_width == 0 || kernel_height == 0) {
+    xnn_log_error(
+      "failed to create Convolution operator with %" PRIu32 "x%" PRIu32 " kernel: kernel dimensions must be non-zero",
+      kernel_width, kernel_height);
+    goto error;
+  }
+
+  if (subsampling_width == 0 || subsampling_height == 0) {
+    xnn_log_error(
+      "failed to create Convolution operator with %" PRIu32 "x%" PRIu32 " subsampling: "
+      "subsampling dimensions must be non-zero",
+      subsampling_width, subsampling_height);
+    goto error;
+  }
+
+  if (dilation_width == 0 || dilation_height == 0) {
+    xnn_log_error(
+      "failed to create Convolution operator with %" PRIu32 "x%" PRIu32 " dilation: "
+      "dilation dimensions must be non-zero",
+      dilation_width, dilation_height);
+    goto error;
+  }
+
+  if (groups == 0) {
+    xnn_log_error(
+      "failed to create Convolution operator with %" PRIu32 " groups: number of groups must be non-zero", groups);
+    goto error;
+  }
+
+  if (group_input_channels == 0) {
+    xnn_log_error(
+      "failed to create Convolution operator with %zu input channels per group: "
+      "number of channels must be non-zero",
+      group_input_channels);
+    goto error;
+  }
+
+  if (group_output_channels == 0) {
+    xnn_log_error(
+      "failed to create Convolution operator with %zu output channels per group: "
+      "number of channels must be non-zero",
+      group_output_channels);
+    goto error;
+  }
+
+  if (isnan(output_min)) {
+    xnn_log_error(
+      "failed to create Convolution operator with NaN output lower bound: lower bound must be non-NaN");
+    goto error;
+  }
+
+  if (isnan(output_max)) {
+    xnn_log_error(
+      "failed to create Convolution operator with NaN output upper bound: upper bound must be non-NaN");
+    goto error;
+  }
+
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create Convolution operator with [%.7g, %.7g] output range: "
+      "lower bound must be below upper bound",
+      output_min, output_max);
+    goto error;
+  }
+
+  if ((flags & XNN_CONVOLUTION_FLAG_DEPTHWISE) != 0 && group_input_channels != 1) {
+    xnn_log_error(
+      "failed to create Depthwise Convolution operator with %zu input channels per group: "
+      "Depthwise Convolution must have exactly 1 input channel per group",
+      group_input_channels);
+    goto error;
+  }
+
+  status = xnn_status_unsupported_parameter;
+
+  enum xnn_ukernel_type ukernel_type;
+  struct spchw_dwconv_parameters* dwconv_parameters = NULL;
+  // Supported cases:
+  // + 1x1 convolution (no groups)
+  // + 3x3 stride-2 with 3 input channels and NHWC input layout
+  // + 3x3 stride-2 depthwise convolution with horizontal padding 1 & no vertical padding
+  // - 3x3 stride-1 depthwise convolution with horizontal padding 1 & no vertical padding
+  // - 5x5 stride-2 depthwise convolution with horizontal padding 2 & no vertical padding
+  // - 5x5 stride-1 depthwise convolution with horizontal padding 2 & no vertical padding
+  const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
+  const bool is_1x1 = kernel_width == 1 && kernel_height == 1 && subsampling_height == 1 && subsampling_width == 1;
+  const bool is_3x3 = kernel_width == 3 && kernel_height == 3 && dilation_height == 1 && dilation_width == 1;
+  const bool nhwc_input = (flags & XNN_FLAG_INPUT_NHWC) != 0;
+  if (is_1x1 && !any_padding && !nhwc_input && groups == 1 && xnn_params.f32.spmm.ukernel != NULL) {
+    ukernel_type = xnn_ukernel_type_spmm;
+  } else if (is_3x3 && subsampling_height == 2 && subsampling_width == 2 &&
+    input_padding_top == 1 && input_padding_left == 1 && input_padding_bottom == 1 && input_padding_right == 1 &&
+    nhwc_input && groups == 1 && xnn_params.f32.hwc2spchw_dconv3x3c3s2.ukernel_with_symm_padding != NULL)
+  {
+    ukernel_type = xnn_ukernel_type_dconv2d_hwc2spchw;
+  } else if (is_3x3 && subsampling_height == 1 && subsampling_width == 1 &&
+    input_padding_top == 0 && input_padding_left == 1 && input_padding_bottom == 0 && input_padding_right == 1 &&
+    !nhwc_input && group_input_channels == 1 && group_output_channels == 1 && xnn_params.f32.spchw_dwconv3x3.ukernel != NULL)
+  {
+    ukernel_type = xnn_ukernel_type_dwconv;
+    dwconv_parameters = &xnn_params.f32.spchw_dwconv3x3;
+  } else if (is_3x3 && subsampling_height == 2 && subsampling_width == 2 &&
+    input_padding_top == 0 && input_padding_left == 1 && input_padding_bottom == 0 && input_padding_right == 1 &&
+    !nhwc_input && group_input_channels == 1 && group_output_channels == 1 && xnn_params.f32.spchw_dwconv3x3s2.ukernel != NULL)
+  {
+    ukernel_type = xnn_ukernel_type_dwconv;
+    dwconv_parameters = &xnn_params.f32.spchw_dwconv3x3s2;
+  } else {
+    xnn_log_error(
+      "failed to create Convolution operator: only selected Convolution parameters are supported");
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  convolution_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+  if (convolution_op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for Convolution operator descriptor", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  switch (ukernel_type) {
+    case xnn_ukernel_type_spmm:
+    {
+      assert(kernel_height == 1);
+      assert(kernel_width == 1);
+      assert(groups == 1);
+
+      // Count non-zero weights overall, and the number of 4-row and 2-row
+      // output-channel blocks with at least one non-zero, to decide between
+      // the 1-, 2-, and 4-channel-block spmm micro-kernels below.
+      size_t num_nonzeroes = 0;
+      size_t num_nonzero_blocks2 = 0;
+      size_t num_nonzero_blocks4 = 0;
+      for (size_t oc = 0; oc < round_down_po2(group_output_channels, 4); oc += 4) {
+        for (size_t ic = 0; ic < group_input_channels; ic++) {
+          const size_t row0_nonzero = (size_t) (kernel[oc * group_input_channels + ic] != 0.0f);
+          const size_t row1_nonzero = (size_t) (kernel[(oc + 1) * group_input_channels + ic] != 0.0f);
+          const size_t row2_nonzero = (size_t) (kernel[(oc + 2) * group_input_channels + ic] != 0.0f);
+          const size_t row3_nonzero = (size_t) (kernel[(oc + 3) * group_input_channels + ic] != 0.0f);
+          num_nonzeroes += row0_nonzero + row1_nonzero + row2_nonzero + row3_nonzero;
+          num_nonzero_blocks2 += (row0_nonzero | row1_nonzero) + (row2_nonzero | row3_nonzero);
+          num_nonzero_blocks4 += (row0_nonzero | row1_nonzero | row2_nonzero | row3_nonzero);
+        }
+      }
+      const size_t num_block4_nonzeroes = num_nonzeroes;
+      for (size_t oc = round_down_po2(group_output_channels, 4); oc < round_down_po2(group_output_channels, 2); oc += 2) {
+        for (size_t ic = 0; ic < group_input_channels; ic++) {
+          const size_t row0_nonzero = (size_t) (kernel[oc * group_input_channels + ic] != 0.0f);
+          const size_t row1_nonzero = (size_t) (kernel[(oc + 1) * group_input_channels + ic] != 0.0f);
+          num_nonzeroes += row0_nonzero + row1_nonzero;
+          num_nonzero_blocks2 += (row0_nonzero | row1_nonzero);
+        }
+      }
+      const size_t num_block2_nonzeroes = num_nonzeroes;
+      for (size_t oc = round_down_po2(group_output_channels, 2); oc < group_output_channels; oc++) {
+        for (size_t ic = 0; ic < group_input_channels; ic++) {
+          num_nonzeroes += (size_t) (kernel[oc * group_input_channels + ic] != 0.0f);
+        }
+      }
+      size_t output_channels_block_size = 1;
+      size_t num_output_channel_blocks = group_output_channels;
+      size_t num_nonzero_values = num_nonzeroes;
+      size_t num_nonzero_blocks = num_nonzeroes;
+      const struct spmm_parameters* spmm_parameters = &xnn_params.f32.spmm;
+      if (num_block4_nonzeroes * 5 >= num_nonzero_blocks4 * 18 && xnn_params.f32.spmm4.ukernel != NULL) {
+        // 4-channel blocks have 90%+ non-zeroes
+
+        output_channels_block_size = 4;
+        num_output_channel_blocks = num_output_channel_blocks / 4 + num_output_channel_blocks % 4;
+        spmm_parameters = &xnn_params.f32.spmm4;
+        // Non-zeroes which don't fit into whole 4-channel blocks, processed one-by-one
+        const size_t num_remaining_nonzeroes = num_nonzeroes - num_block4_nonzeroes;
+        num_nonzero_values = num_nonzero_blocks4 * 4 + num_remaining_nonzeroes;
+        num_nonzero_blocks = num_nonzero_blocks4 + num_remaining_nonzeroes;
+      } else if (num_block2_nonzeroes * 5 >= num_nonzero_blocks2 * 9 && xnn_params.f32.spmm2.ukernel != NULL) {
+        // 2-channel blocks have 90%+ non-zeroes
+
+        output_channels_block_size = 2;
+        num_output_channel_blocks = num_output_channel_blocks / 2 + num_output_channel_blocks % 2;
+        spmm_parameters = &xnn_params.f32.spmm2;
+        // Non-zeroes which don't fit into whole 2-channel blocks, processed one-by-one
+        const size_t num_remaining_nonzeroes = num_nonzeroes - num_block2_nonzeroes;
+        num_nonzero_values = num_nonzero_blocks2 * 2 + num_remaining_nonzeroes;
+        num_nonzero_blocks = num_nonzero_blocks2 + num_remaining_nonzeroes;
+      }
+
+      // Sparse representation of weights consists of four components:
+      // 1. An array of float values storing non-zero kernel elements, and all (group_output_channels) bias elements.
+      //    All elements within non-zero block are assumed to be non-zero.
+      // 2. An array of int32_t values storing increment for input pointer after each processed tile. This array is
+      //    derived from the scaled differences in array 4 by the setup function.
+      // 3. An array of uint32_t values storing the number of non-zero kernel elements per each output channel.
+      // 4. An array of int32_t values storing scaled [by sizeof(input element)] difference between input channels
+      //    corresponding to successive non-zero blocks.
+      const size_t packed_weights_size = num_output_channel_blocks * sizeof(uint32_t) +
+        (num_nonzero_blocks * 2) * sizeof(int32_t) + (num_nonzero_values + group_output_channels) * sizeof(float);
+
+      convolution_op->packed_weights = xnn_allocate_memory(packed_weights_size);
+      if (convolution_op->packed_weights == NULL) {
+        xnn_log_error("failed to allocate %zu bytes for packed weights", packed_weights_size);
+        goto error;
+      }
+      convolution_op->num_nonzero_values = num_nonzero_values;
+      convolution_op->num_nonzero_blocks = num_nonzero_blocks;
+      convolution_op->num_output_channel_blocks = num_output_channel_blocks;
+
+      float* nonzero_values = convolution_op->packed_weights;
+      int32_t* input_increments = (int32_t*) (nonzero_values + num_nonzero_values + group_output_channels);
+      uint32_t* output_channel_nonzeros = (uint32_t*) (input_increments + num_nonzero_blocks);
+      int32_t* input_channel_diffs = (int32_t*) (output_channel_nonzeros + num_output_channel_blocks);
+      memset(output_channel_nonzeros, 0, num_output_channel_blocks * sizeof(uint32_t));
+
+      status = xnn_status_unsupported_parameter;
+
+      // Pack biases and non-zero weight blocks; for each non-zero block also
+      // record the scaled distance from the previous non-zero input channel.
+      size_t first_ic = 0, last_ic = 0;
+      bool first_nonzero = true;
+      for (size_t ocb = 0; ocb < round_down_po2(group_output_channels, output_channels_block_size); ocb += output_channels_block_size) {
+        for (size_t oco = 0; oco < output_channels_block_size; oco++) {
+          *nonzero_values++ = bias[ocb + oco];
+        }
+        for (size_t ic = 0; ic < group_input_channels; ic++) {
+          bool is_nonzero_block = false;
+          for (size_t oco = 0; oco < output_channels_block_size; oco++) {
+            is_nonzero_block |= (kernel[(ocb + oco) * group_input_channels + ic] != 0.0f);
+          }
+          if (is_nonzero_block) {
+            for (size_t oco = 0; oco < output_channels_block_size; oco++) {
+              *nonzero_values++ = kernel[(ocb + oco) * group_input_channels + ic];
+            }
+            if (first_nonzero) {
+              first_ic = ic;
+            } else {
+              const int64_t diff = (int64_t) ((uint64_t) ic - (uint64_t) last_ic) * (int64_t) sizeof(float);
+              if (diff != (int64_t) (int32_t) diff) {
+                xnn_log_error("failed to convert kernel to sparse representation: "
+                  "scaled difference in input channels exceeds int32_t range");
+                goto error;
+              }
+              *input_channel_diffs++ = (int32_t) diff;
+            }
+            first_nonzero = false;
+            last_ic = ic;
+            *output_channel_nonzeros += 1;
+          }
+        }
+        output_channel_nonzeros += 1;
+      }
+      // Remaining output channels that do not fill a whole block are packed
+      // one channel at a time.
+      for (size_t oc = round_down_po2(group_output_channels, output_channels_block_size); oc < group_output_channels; oc++) {
+        *nonzero_values++ = bias[oc];
+        for (size_t ic = 0; ic < group_input_channels; ic++) {
+          const float weight = kernel[oc * group_input_channels + ic];
+          if (weight != 0.0f) {
+            *nonzero_values++ = weight;
+            if (first_nonzero) {
+              first_ic = ic;
+            } else {
+              const int64_t diff = (int64_t) ((uint64_t) ic - (uint64_t) last_ic) * (int64_t) sizeof(float);
+              if (diff != (int64_t) (int32_t) diff) {
+                xnn_log_error("failed to convert kernel to sparse representation: "
+                  "scaled difference in input channels exceeds int32_t range");
+                goto error;
+              }
+              *input_channel_diffs++ = (int32_t) diff;
+            }
+            first_nonzero = false;
+            last_ic = ic;
+            *output_channel_nonzeros += 1;
+          }
+        }
+        output_channel_nonzeros += 1;
+      }
+      // If there are any non-zero elements, we have to return to the initial input channel.
+      if (!first_nonzero) {
+        const int64_t diff = (int64_t) ((uint64_t) first_ic - (uint64_t) last_ic) * (int64_t) sizeof(float);
+        if (diff != (int64_t) (int32_t) diff) {
+          xnn_log_error("failed to convert kernel to sparse representation: "
+            "scaled difference in input channels exceeds int32_t range");
+          goto error;
+        }
+        *input_channel_diffs++ = (int32_t) diff;
+      }
+      convolution_op->first_input_channel = first_ic;
+
+      convolution_op->ukernel.spmm = (struct xnn_ukernel_spmm) {
+        .function = spmm_parameters->ukernel,
+        .mr = spmm_parameters->mr,
+      };
+
+      break;
+    }
+    case xnn_ukernel_type_dconv2d_hwc2spchw:
+    {
+      assert(groups == 1);
+
+      const size_t packed_group_output_channels =
+        round_up(group_output_channels, xnn_params.f32.hwc2spchw_dconv3x3c3s2.output_channel_tile);
+      const size_t packed_weights_size = groups * packed_group_output_channels *
+        (group_input_channels * kernel_height * kernel_width + 1 /* bias */) * sizeof(float);
+      convolution_op->packed_weights = xnn_allocate_memory(packed_weights_size);
+      if (convolution_op->packed_weights == NULL) {
+        xnn_log_error("failed to allocate %zu bytes for packed weights", packed_weights_size);
+        goto error;
+      }
+
+      xnn_pack_f32_dconv_oki_w(
+        group_output_channels,
+        group_input_channels,
+        xnn_params.f32.hwc2spchw_dconv3x3c3s2.output_channel_tile,
+        kernel_height, kernel_width,
+        kernel, bias, convolution_op->packed_weights);
+
+      convolution_op->ukernel.dconv2d = (struct xnn_ukernel_dconv2d) {
+        .hwc2spchw_function = xnn_params.f32.hwc2spchw_dconv3x3c3s2.ukernel_with_symm_padding,
+        .output_height_tile = xnn_params.f32.hwc2spchw_dconv3x3c3s2.output_height_tile,
+        .output_channel_tile = xnn_params.f32.hwc2spchw_dconv3x3c3s2.output_channel_tile,
+      };
+
+      break;
+    }
+    case xnn_ukernel_type_dwconv:
+    {
+      assert(dwconv_parameters != NULL);
+      assert(group_input_channels == 1);
+      assert(group_output_channels == 1);
+
+      const size_t packed_weights_size = groups * (kernel_height * kernel_width + 1 /* bias */) * sizeof(float);
+      convolution_op->packed_weights = xnn_allocate_memory(packed_weights_size);
+      if (convolution_op->packed_weights == NULL) {
+        xnn_log_error("failed to allocate %zu bytes for packed weights", packed_weights_size);
+        goto error;
+      }
+
+      xnn_pack_f32_spchw_dwconv_ghw_w(
+        kernel_height * kernel_width, groups,
+        kernel, bias, convolution_op->packed_weights);
+
+      convolution_op->ukernel.dwconv2d = (struct xnn_ukernel_dwconv2d) {
+        .spchw_function = dwconv_parameters->ukernel,
+        .input_width_tile = dwconv_parameters->input_width_tile,
+        .output_width_tile = dwconv_parameters->output_width_tile,
+      };
+
+      break;
+    }
+    default:
+      XNN_UNREACHABLE;
+  }
+
+  convolution_op->padding_top = input_padding_top;
+  convolution_op->padding_right = input_padding_right;
+  convolution_op->padding_bottom = input_padding_bottom;
+  convolution_op->padding_left = input_padding_left;
+
+  convolution_op->kernel_height = kernel_height;
+  convolution_op->kernel_width = kernel_width;
+  convolution_op->stride_height = subsampling_height;
+  convolution_op->stride_width = subsampling_width;
+  convolution_op->dilation_height = dilation_height;
+  convolution_op->dilation_width = dilation_width;
+  convolution_op->groups = groups;
+  convolution_op->group_input_channels = group_input_channels;
+  convolution_op->group_output_channels = group_output_channels;
+
+  if (ukernel_type == xnn_ukernel_type_dwconv) {
+    // Width-dependent fields start at 0 here; they are filled in at setup
+    // time via xnn_update_f32_spchw_params once the input width is known.
+    convolution_op->f32_spchw_params = xnn_compute_f32_spchw_params(0, output_min, output_max);
+  } else {
+    convolution_op->f32_output_params = xnn_compute_f32_output_params(output_min, output_max);
+  }
+
+  convolution_op->type = xnn_operator_type_convolution_spnchw_f32;
+  convolution_op->ukernel.type = ukernel_type;
+
+  convolution_op->state = xnn_run_state_invalid;
+
+  *convolution_op_out = convolution_op;
+  return xnn_status_success;
+
+error:
+  xnn_delete_operator(convolution_op);
+  return status;
+}
+
+// Common setup for SpNCHW Convolution operators. Validates the input/output
+// geometry and batch strides, then fills in the per-ukernel compute context:
+//  - spmm: converts the packed input-channel differences into pointer
+//    increments for the actual input spatial size, and partitions the
+//    spatial dimension across threads.
+//  - dconv2d_hwc2spchw: (re)allocates a zeroed padding row and partitions
+//    the output height across threads.
+//  - dwconv: updates width-dependent micro-kernel parameters and
+//    parallelizes over (batch, channel) pairs.
+// Returns xnn_status_success and marks the operator ready, or an error
+// status leaving the operator in the invalid (non-runnable) state.
+static enum xnn_status setup_convolution2d_spnchw(
+  xnn_operator_t convolution_op,
+  size_t batch_size,
+  size_t input_batch_stride,
+  size_t output_batch_stride,
+  size_t input_height,
+  size_t input_width,
+  const void* input,
+  void* output,
+  uint32_t log2_input_element_size,
+  uint32_t log2_filter_element_size,
+  uint32_t bias_element_size,
+  uint32_t log2_output_element_size,
+  const void* params,
+  size_t num_threads)
+{
+  convolution_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup Convolution operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  if (input_width == 0 || input_height == 0) {
+    xnn_log_error(
+      "failed to setup Convolution operator with %zux%zu input: input dimensions must be non-zero",
+      input_width, input_height);
+    return xnn_status_invalid_parameter;
+  }
+
+  const uint32_t groups = convolution_op->groups;
+  const size_t group_input_channels = convolution_op->group_input_channels;
+  const size_t input_neurons = groups * group_input_channels * input_height * input_width;
+  if (input_batch_stride < input_neurons) {
+    xnn_log_error(
+      "failed to setup Convolution operator with input batch stride of %zu: "
+      "stride must be at least as large as the number of input neurons (%" PRIu32 "x%zux%zux%zu)",
+      input_batch_stride, groups, group_input_channels, input_height, input_width);
+    return xnn_status_invalid_parameter;
+  }
+
+  const size_t output_height = compute_output_dimension(
+    convolution_op->padding_top + input_height + convolution_op->padding_bottom,
+    convolution_op->kernel_height,
+    convolution_op->dilation_height,
+    convolution_op->stride_height);
+  const size_t output_width = compute_output_dimension(
+    convolution_op->padding_left + input_width + convolution_op->padding_right,
+    convolution_op->kernel_width,
+    convolution_op->dilation_width,
+    convolution_op->stride_width);
+
+  const size_t group_output_channels = convolution_op->group_output_channels;
+  const size_t output_neurons = groups * group_output_channels * output_height * output_width;
+  if (output_batch_stride < output_neurons) {
+    xnn_log_error(
+      "failed to setup Convolution operator with output batch stride of %zu: "
+      "stride must be at least as large as the number of output neurons (%" PRIu32 "x%zux%zux%zu)",
+      output_batch_stride, groups, group_output_channels, output_height, output_width);
+    return xnn_status_invalid_parameter;
+  }
+
+  if (batch_size == 0) {
+    // Nothing to compute; mark this run as a no-op.
+    convolution_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  convolution_op->batch_size = batch_size;
+  convolution_op->input_height = input_height;
+  convolution_op->input_width = input_width;
+  convolution_op->input = input;
+  convolution_op->output = output;
+
+  switch (convolution_op->ukernel.type) {
+    case xnn_ukernel_type_spmm:
+    {
+      const size_t num_nonzero_values = convolution_op->num_nonzero_values;
+      const size_t num_nonzero_blocks = convolution_op->num_nonzero_blocks;
+      const size_t num_output_channel_blocks = convolution_op->num_output_channel_blocks;
+
+      // Recover the layout of the packed sparse weights produced at creation
+      // time: non-zero values + biases, then input increments, then
+      // per-output-channel non-zero counts, then scaled channel differences.
+      float* nonzero_values = convolution_op->packed_weights;
+      int32_t* input_increments = (int32_t*) (nonzero_values + num_nonzero_values + convolution_op->group_output_channels);
+      uint32_t* output_channel_nonzeros = (uint32_t*) (input_increments + num_nonzero_blocks);
+      int32_t* input_channel_diffs = (int32_t*) (output_channel_nonzeros + num_output_channel_blocks);
+
+      // Scale the per-channel byte differences by the spatial size of one
+      // channel plane to get pointer increments, guarding int32_t overflow.
+      const size_t input_size = input_height * input_width;
+      for (size_t i = 0; i < num_nonzero_blocks; i++) {
+        const int32_t diff = input_channel_diffs[i];
+        const int64_t increment = (int64_t) diff * input_size;
+        if ((int64_t) (int32_t) increment != increment) {
+          xnn_log_error("failed to setup Convolution operator with sparse kernel representation: "
+            "input increment exceeds int32_t range");
+          return xnn_status_unsupported_parameter;
+        }
+        input_increments[i] = (int32_t) increment;
+      }
+
+      convolution_op->context.spmm = (struct spmm_context) {
+        .n = group_output_channels,
+        // Skip channels before the first non-zero one (F32 element size).
+        .a = input + (convolution_op->first_input_channel * input_size * sizeof(float)),
+        .packed_weights = nonzero_values,
+        .input_increments = input_increments,
+        .output_channel_nonzeros = output_channel_nonzeros,
+        .c = output,
+        .batched_a_stride = input_batch_stride << log2_input_element_size,
+        .batched_c_stride = output_batch_stride << log2_output_element_size,
+        .ukernel = convolution_op->ukernel.spmm.function,
+      };
+      memcpy(&convolution_op->context.spmm.params, params, sizeof(convolution_op->context.spmm.params));
+
+      // Shrink the per-tile spatial slice so each thread gets several tiles,
+      // keeping the slice a multiple of the micro-kernel's mr.
+      const size_t mr = convolution_op->ukernel.spmm.mr;
+      size_t mc = input_size;
+      if (num_threads > 1) {
+        const size_t target_tiles_per_thread = 5;
+        const size_t max_mc = divide_round_up(input_size, num_threads * target_tiles_per_thread);
+        if (max_mc < mc) {
+          mc = min(mc, divide_round_up(mc, max_mc * mr) * mr);
+        }
+      }
+      convolution_op->compute.type = xnn_parallelization_type_2d_tile_1d;
+      convolution_op->compute.task_2d_tile_1d = (pthreadpool_task_2d_tile_1d_t) xnn_compute_spmm;
+      convolution_op->compute.range[0] = batch_size;
+      convolution_op->compute.range[1] = input_size;
+      convolution_op->compute.tile[0] = mc;
+      convolution_op->state = xnn_run_state_ready;
+
+      return xnn_status_success;
+    }
+    case xnn_ukernel_type_dconv2d_hwc2spchw:
+    {
+      // One zeroed input row used to implement vertical padding.
+      const size_t zero_size = (input_width * convolution_op->group_input_channels << log2_input_element_size) + XNN_EXTRA_BYTES;
+      void* zero_buffer = realloc(convolution_op->zero_buffer, zero_size);
+      if (zero_buffer == NULL) {
+        // Fixed: report the requested allocation size, not the operator size.
+        xnn_log_error("failed to allocate %zu bytes for zero padding", zero_size);
+        return xnn_status_out_of_memory;
+      }
+      memset(zero_buffer, 0, zero_size);
+      convolution_op->zero_buffer = zero_buffer;
+
+      convolution_op->context.dconv2d = (struct dconv2d_context) {
+        .input_height = input_height,
+        .input_width = input_width,
+        .input = input,
+        .input_batch_stride = input_batch_stride << log2_input_element_size,
+        .zero = zero_buffer,
+        .packed_weights = convolution_op->packed_weights,
+        .output = output,
+        // Fixed: output stride is scaled by the output element size.
+        .output_batch_stride = output_batch_stride << log2_output_element_size,
+        .input_padding_top = convolution_op->padding_top,
+        .output_channels = convolution_op->group_output_channels,
+        .output_height_stride = output_width << log2_output_element_size,
+        .output_channel_stride = output_height * output_width << log2_output_element_size,
+        .hwc2spchw_ukernel = convolution_op->ukernel.dconv2d.hwc2spchw_function,
+      };
+      memcpy(&convolution_op->context.dconv2d.params, params, sizeof(convolution_op->context.dconv2d.params));
+
+      // Shrink the output-height slice so each thread gets several tiles,
+      // keeping the slice a multiple of the micro-kernel's height tile.
+      size_t output_height_slice = output_height;
+      const size_t output_height_tile = convolution_op->ukernel.dconv2d.output_height_tile;
+      if (num_threads > 1) {
+        const size_t target_tiles_per_thread = 5;
+        const size_t max_output_height_slice = divide_round_up(output_height, num_threads * target_tiles_per_thread);
+        if (max_output_height_slice < output_height_slice) {
+          output_height_slice = min(output_height_slice,
+            divide_round_up(output_height_slice, max_output_height_slice * output_height_tile) * output_height_tile);
+        }
+      }
+      convolution_op->compute.type = xnn_parallelization_type_2d_tile_1d;
+      convolution_op->compute.task_2d_tile_1d = (pthreadpool_task_2d_tile_1d_t) xnn_compute_dconv2d_hwc2spchw;
+      convolution_op->compute.range[0] = batch_size;
+      convolution_op->compute.range[1] = output_height;
+      convolution_op->compute.tile[0] = output_height_slice;
+      convolution_op->state = xnn_run_state_ready;
+
+      return xnn_status_success;
+    }
+    case xnn_ukernel_type_dwconv:
+    {
+      // Fill in the width-dependent micro-kernel parameters in place.
+      xnn_update_f32_spchw_params((union xnn_f32_spchw_params*) params, input_width);
+      convolution_op->context.dwconv2d = (struct dwconv2d_context) {
+        .output_height = output_height,
+        .input_width = input_width,
+        .input = input,
+        .input_channel_stride = input_height * input_width << log2_input_element_size,
+        .input_batch_stride = input_batch_stride << log2_input_element_size,
+        .packed_weights = convolution_op->packed_weights,
+        .weights_channel_stride = bias_element_size +
+          (convolution_op->kernel_height * convolution_op->kernel_width << log2_filter_element_size),
+        .output = output,
+        .output_channel_stride = output_height * output_width << log2_output_element_size,
+        .output_batch_stride = output_batch_stride << log2_output_element_size,
+        .input_tuple_stride = convolution_op->ukernel.dwconv2d.input_width_tile << log2_input_element_size,
+        .output_tuple_stride = convolution_op->ukernel.dwconv2d.output_width_tile << log2_output_element_size,
+        .input_pixel_stride = input_width << log2_input_element_size,
+        .output_pixel_stride = output_width << log2_output_element_size,
+        .spchw_ukernel = convolution_op->ukernel.dwconv2d.spchw_function,
+      };
+      memcpy(&convolution_op->context.dwconv2d.params, params, sizeof(convolution_op->context.dwconv2d.params));
+
+      convolution_op->compute.type = xnn_parallelization_type_2d;
+      convolution_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_dwconv2d_spchw;
+      convolution_op->compute.range[0] = batch_size;
+      convolution_op->compute.range[1] = groups;
+      convolution_op->state = xnn_run_state_ready;
+
+      return xnn_status_success;
+    }
+    default:
+      XNN_UNREACHABLE;
+  }
+}
+
+// Public F32 entry point for SpNCHW convolution setup: verifies the operator
+// type, then delegates to setup_convolution2d_spnchw with F32 element sizes.
+enum xnn_status xnn_setup_convolution2d_spnchw_f32(
+  xnn_operator_t convolution_op,
+  size_t batch_size,
+  size_t input_batch_stride,
+  size_t output_batch_stride,
+  size_t input_height,
+  size_t input_width,
+  const float* input,
+  float* output,
+  pthreadpool_t threadpool)
+{
+  if (convolution_op->type != xnn_operator_type_convolution_spnchw_f32) {
+    xnn_log_error("failed to setup Convolution (F32, SpNCHW) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+
+  return setup_convolution2d_spnchw(
+    convolution_op,
+    batch_size, input_batch_stride, output_batch_stride,
+    input_height, input_width,
+    input, output,
+    2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
+    2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
+    sizeof(float) /* sizeof(bias element) */,
+    2 /* log2(sizeof(output element)) = log2(sizeof(float)) */,
+    // NOTE(review): creation stores f32_spchw_params for the dwconv path but
+    // f32_output_params is passed here for all ukernel types — confirm these
+    // fields alias (e.g. share a union in struct xnn_operator); otherwise the
+    // dwconv path reads the wrong parameter block.
+    &convolution_op->f32_output_params,
+    pthreadpool_get_threads_count(threadpool));
+}
diff --git a/src/convolution.c b/src/convolution.c
new file mode 100644
index 0000000..a5023e9
--- /dev/null
+++ b/src/convolution.c
@@ -0,0 +1,1104 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <math.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/log.h>
+#include <xnnpack/common.h>
+#include <xnnpack/compute.h>
+#include <xnnpack/math.h>
+#include <xnnpack/pack.h>
+#include <xnnpack/params.h>
+#include <xnnpack/indirection.h>
+
+
// Computes one spatial output dimension of a convolution:
//   out = floor((padded_in - dilated_kernel) / stride) + 1
// saturated so that inputs smaller than the (dilated) kernel still produce
// a single output element.
static inline size_t compute_output_dimension(
    size_t padded_input_dimension,
    size_t kernel_dimension,
    size_t dilation_dimension,
    size_t subsampling_dimension)
{
  // Dilation spreads kernel taps apart: effective extent = (k - 1) * d + 1.
  const size_t dilated_kernel_dimension = dilation_dimension * (kernel_dimension - 1) + 1;
  // doz() is "difference or zero" (saturating subtraction), so the quotient
  // never underflows when the padded input is smaller than the kernel.
  const size_t stride_range = doz(padded_input_dimension, dilated_kernel_dimension);
  return stride_range / subsampling_dimension + 1;
}
+
+static const struct dwconv_parameters* find_dwigemm_ukernel(
+ size_t kernel_size,
+ const struct dwconv_parameters* ukernel,
+ size_t num_ukernels)
+{
+ while (num_ukernels-- != 0) {
+ if (ukernel->mr == kernel_size) {
+ return ukernel;
+ }
+ ukernel++;
+ }
+ return NULL;
+}
+
// Creates a 2D Convolution operator with 8-bit quantized (Q8) data in NHWC
// layout.
//
// The function validates every shape/stride/quantization argument, selects a
// micro-kernel strategy (depthwise, direct GEMM for unpadded 1x1, or indirect
// GEMM otherwise), packs the kernel and bias into the layout the chosen
// micro-kernel expects, and allocates a zero-padding buffer when any input
// padding is requested.
//
// On success, stores the new operator in *convolution_op_out and returns
// xnn_status_success.  On failure, releases any partially-created state via
// xnn_delete_operator and returns an error status.
enum xnn_status xnn_create_convolution2d_nhwc_q8(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t subsampling_height,
    uint32_t subsampling_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t kernel_zero_point,
    float kernel_scale,
    const uint8_t* kernel,
    const int32_t* bias,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* convolution_op_out)
{
  xnn_operator_t convolution_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to create Convolution operator: XNNPACK is not initialized");
    goto error;
  }

  // --- Parameter validation: shape and stride arguments ---
  status = xnn_status_invalid_parameter;

  if (kernel_width == 0 || kernel_height == 0) {
    xnn_log_error(
      "failed to create Convolution operator with %" PRIu32 "x%" PRIu32 " kernel: kernel dimensions must be non-zero",
      kernel_width, kernel_height);
    goto error;
  }

  if (subsampling_width == 0 || subsampling_height == 0) {
    xnn_log_error(
      "failed to create Convolution operator with %" PRIu32 "x%" PRIu32 " subsampling: "
      "subsampling dimensions must be non-zero",
      subsampling_width, subsampling_height);
    goto error;
  }

  if (dilation_width == 0 || dilation_height == 0) {
    xnn_log_error(
      "failed to create Convolution operator with %" PRIu32 "x%" PRIu32 " dilation: "
      "dilation dimensions must be non-zero",
      dilation_width, dilation_height);
    goto error;
  }

  if (groups == 0) {
    xnn_log_error(
      "failed to create Convolution operator with %" PRIu32 " groups: number of groups must be non-zero", groups);
    goto error;
  }

  if (group_input_channels == 0) {
    xnn_log_error(
      "failed to create Convolution operator with %zu input channels per group: "
      "number of channels must be non-zero",
      group_input_channels);
    goto error;
  }

  if (group_output_channels == 0) {
    xnn_log_error(
      "failed to create Convolution operator with %zu output channels per group: "
      "number of channels must be non-zero",
      group_output_channels);
    goto error;
  }

  // Pixel strides must be able to step over all channels of a pixel.
  const size_t input_channels = groups * group_input_channels;
  if (input_pixel_stride < input_channels) {
    xnn_log_error(
      "failed to create Convolution operator with input pixel stride of %zu: "
      "stride must be at least as large as the number of input channels (%" PRIu32 "x%zu)",
      input_pixel_stride, groups, group_input_channels);
    goto error;
  }

  const size_t output_channels = groups * group_output_channels;
  if (output_pixel_stride < output_channels) {
    xnn_log_error(
      "failed to create Convolution operator with output pixel stride of %zu: "
      "stride must be at least as large as the number of output channels (%" PRIu32 "x%zu)",
      output_pixel_stride, groups, group_output_channels);
    goto error;
  }

  // --- Parameter validation: quantization parameters ---
  // Scales must be positive normalized floats (no zero, subnormal, inf, NaN).
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create Convolution operator with %.7g input scale: scale must be finite, normalized, and positive",
      input_scale);
    goto error;
  }

  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
    xnn_log_error(
      "failed to create Convolution operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      kernel_scale);
    goto error;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create Convolution operator with %.7g output scale: scale must be finite, normalized, and positive",
      output_scale);
    goto error;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create Convolution operator with [%" PRIu8 ", %" PRIu8 "] output range: "
      "range min must be below range max",
      output_min, output_max);
    goto error;
  }

  // A depthwise convolution is, by definition, one input channel per group.
  if ((flags & XNN_CONVOLUTION_FLAG_DEPTHWISE) != 0 && group_input_channels != 1) {
    xnn_log_error(
      "failed to create Depthwise Convolution operator with %zu input channels per group: "
      "Depthwise Convolution must have exactly 1 input channel per group",
      group_input_channels);
    goto error;
  }

  status = xnn_status_unsupported_parameter;

  const uint32_t effective_kernel_height = (kernel_height - 1) * dilation_height + 1;
  const uint32_t effective_kernel_width = (kernel_width - 1) * dilation_width + 1;

  // Padding that exceeds the effective (dilated) kernel extent is legal but
  // wasteful: some kernel taps only ever see padding.  Log but do not fail.
  if (input_padding_top >= effective_kernel_height) {
    xnn_log_info(
      "inefficiency in Convolution operator with %" PRIu32 "x%" PRIu32 " effective kernel and %" PRIu32 "+%" PRIu32 " height padding: "
      "input top padding is greater or equal to effective kernel height",
      effective_kernel_width, effective_kernel_height, input_padding_top, input_padding_bottom);
  }

  if (input_padding_bottom >= effective_kernel_height) {
    xnn_log_info(
      "inefficiency in Convolution operator with %" PRIu32 "x%" PRIu32 " effective kernel and %" PRIu32 "+%" PRIu32 " height padding: "
      "input bottom padding is greater or equal to effective kernel height",
      effective_kernel_width, effective_kernel_height, input_padding_top, input_padding_bottom);
  }

  if (input_padding_right >= effective_kernel_width) {
    xnn_log_info(
      "inefficiency in Convolution operator with %" PRIu32 "x%" PRIu32 " effective kernel and %" PRIu32 "+%" PRIu32 " width padding: "
      "input right padding is greater or equal to effective kernel width",
      effective_kernel_width, effective_kernel_height, input_padding_left, input_padding_right);
  }

  if (input_padding_left >= effective_kernel_width) {
    xnn_log_info(
      "inefficiency in Convolution operator with %" PRIu32 "x%" PRIu32 " effective kernel and %" PRIu32 "+%" PRIu32 " width padding: "
      "input left padding is greater or equal to effective kernel width",
      effective_kernel_width, effective_kernel_height, input_padding_left, input_padding_right);
  }

  // The fixed-point requantization scheme requires the combined scale
  // (input * kernel / output) to be strictly below 1.0.
  const float convolution_scale = input_scale * kernel_scale / output_scale;
  if (convolution_scale >= 1.0f) {
    xnn_log_error(
      "failed to create Convolution operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "convolution scale %.7g is greater or equal to 1.0",
      input_scale, kernel_scale, output_scale, convolution_scale);
    goto error;
  }

  status = xnn_status_out_of_memory;

  convolution_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
  if (convolution_op == NULL) {
    xnn_log_error("failed to allocate %zu bytes for Convolution operator descriptor", sizeof(struct xnn_operator));
    goto error;
  }

  const size_t kernel_size = kernel_height * kernel_width;

  // --- Micro-kernel selection ---
  // dwconv:  channelwise (1 in / 1 out channel per group, groups > 1) with a
  //          specialized depthwise micro-kernel for this kernel size;
  // gemm:    unpadded, unstrided 1x1 convolution maps directly to GEMM;
  // igemm:   general case via indirect GEMM (indirection buffer).
  enum xnn_ukernel_type ukernel_type = xnn_ukernel_type_none;
  const struct dwconv_parameters* dwconv_parameters = NULL;
  const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
  if (group_input_channels == 1 && group_output_channels == 1 && groups > 1 &&
      (dwconv_parameters = find_dwigemm_ukernel(kernel_size, xnn_params.q8.dwconv, XNN_MAX_Q8_DWCONV_UKERNELS)) != NULL)
  {
    ukernel_type = xnn_ukernel_type_dwconv;
  } else if (kernel_size == 1 && subsampling_height == 1 && subsampling_width == 1 && !any_padding) {
    ukernel_type = xnn_ukernel_type_gemm;
  } else {
    ukernel_type = xnn_ukernel_type_igemm;
  }

  // --- Weight packing (layout depends on the selected micro-kernel) ---
  size_t zero_size = 0;
  switch (ukernel_type) {
    case xnn_ukernel_type_dwconv:
    {
      assert(dwconv_parameters != NULL);
      assert(dwconv_parameters->mr == kernel_size);

      // Channels are padded up to the micro-kernel's channel tile (cr); each
      // channel stores kernel_size uint8 weights plus one int32 bias.
      const uint32_t c_stride = round_up_po2(groups, dwconv_parameters->cr);
      const size_t packed_weights_size = (sizeof(uint8_t) * kernel_size + sizeof(int32_t)) * c_stride;
      convolution_op->packed_weights = xnn_allocate_memory(packed_weights_size);
      if (convolution_op->packed_weights == NULL) {
        xnn_log_error("failed to allocate %zu bytes for packed weights", packed_weights_size);
        goto error;
      }

      // The DEPTHWISE flag selects the HWG (height-width-group) kernel layout
      // instead of the default GHW layout.
      if (flags & XNN_CONVOLUTION_FLAG_DEPTHWISE) {
        xnn_pack_q8_dwconv_hwg_w(
          kernel_height, kernel_width,
          groups, dwconv_parameters->cr,
          input_zero_point, kernel_zero_point,
          kernel, bias, convolution_op->packed_weights);
      } else {
        xnn_pack_q8_dwconv_ghw_w(
          kernel_height, kernel_width,
          groups, dwconv_parameters->cr,
          input_zero_point, kernel_zero_point,
          kernel, bias, convolution_op->packed_weights);
      }

      convolution_op->ukernel.dwconv = (struct xnn_ukernel_dwconv) {
        .unipass_function = dwconv_parameters->up,
        .mr = dwconv_parameters->mr,
        .qr = dwconv_parameters->qr,
      };

      zero_size = sizeof(uint8_t) * c_stride + XNN_EXTRA_BYTES;
      break;
    }
    case xnn_ukernel_type_gemm:
    case xnn_ukernel_type_igemm:
    {
      const uint32_t nr = xnn_params.q8.gemm.nr;
      const uint32_t kr = UINT32_C(1) << xnn_params.q8.gemm.log2_kr;
      const uint32_t n_stride = round_up(group_output_channels, nr);
      const uint32_t k_stride = round_up_po2(group_input_channels, kr);

      const size_t packed_group_weights_size =
        (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t)) * n_stride;
      convolution_op->packed_weights = xnn_allocate_memory(packed_group_weights_size * groups);
      if (convolution_op->packed_weights == NULL) {
        xnn_log_error("failed to allocate %zu bytes for packed weights", packed_group_weights_size * groups);
        goto error;
      }
      // Pre-fill padding lanes with the kernel zero point so padded weight
      // entries contribute zero after zero-point subtraction.
      memset(convolution_op->packed_weights, kernel_zero_point, packed_group_weights_size * groups);

      switch (ukernel_type) {
        case xnn_ukernel_type_gemm:
          xnn_pack_q8_gemm_goi_w(
            groups, group_output_channels, group_input_channels,
            nr, kr,
            input_zero_point, kernel_zero_point,
            kernel, bias, convolution_op->packed_weights);
          convolution_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
            .mr = xnn_params.q8.gemm.mr,
            .nr = nr,
            .kr = kr,
            .default_function = xnn_params.q8.gemm.gemm,
          };
          break;
        case xnn_ukernel_type_igemm:
          if (flags & XNN_CONVOLUTION_FLAG_DEPTHWISE) {
            xnn_pack_q8_conv_kgo_w(
              groups, group_output_channels, kernel_size,
              nr, kr,
              input_zero_point, kernel_zero_point,
              kernel, bias, convolution_op->packed_weights);
          } else {
            xnn_pack_q8_conv_goki_w(
              groups, group_output_channels, kernel_size, group_input_channels,
              nr, kr,
              input_zero_point, kernel_zero_point,
              kernel, bias, convolution_op->packed_weights);
          }
          convolution_op->ukernel.igemm = (struct xnn_ukernel_igemm) {
            .mr = xnn_params.q8.gemm.mr,
            .nr = nr,
            .kr = kr,
            .default_function = xnn_params.q8.gemm.igemm,
          };
          break;
        default:
          XNN_UNREACHABLE;
      }

      zero_size = sizeof(uint8_t) * k_stride + XNN_EXTRA_BYTES;
      break;
    }
    default:
      XNN_UNREACHABLE;
  }

  // With padding, out-of-bounds taps read from a buffer filled with the input
  // zero point (the quantized representation of 0.0).
  if (any_padding) {
    void* zero_buffer = xnn_allocate_memory(zero_size);
    if (zero_buffer == NULL) {
      xnn_log_error("failed to allocate %zu bytes for zero padding", zero_size);
      goto error;
    }
    memset(zero_buffer, input_zero_point, zero_size);
    convolution_op->zero_buffer = zero_buffer;
  }

  // --- Record operator parameters for setup/run ---
  convolution_op->padding_top = input_padding_top;
  convolution_op->padding_right = input_padding_right;
  convolution_op->padding_bottom = input_padding_bottom;
  convolution_op->padding_left = input_padding_left;

  convolution_op->kernel_height = kernel_height;
  convolution_op->kernel_width = kernel_width;
  convolution_op->stride_height = subsampling_height;
  convolution_op->stride_width = subsampling_width;
  convolution_op->dilation_height = dilation_height;
  convolution_op->dilation_width = dilation_width;
  convolution_op->groups = groups;
  convolution_op->group_input_channels = group_input_channels;
  convolution_op->group_output_channels = group_output_channels;
  convolution_op->input_pixel_stride = input_pixel_stride;
  convolution_op->output_pixel_stride = output_pixel_stride;

  convolution_op->kernel_zero_point = kernel_zero_point;

  // Precompute requantization parameters for the micro-kernels.
  convolution_op->q8_gemm_params =
    xnn_compute_q8_gemm_params(
      input_zero_point, kernel_zero_point,
      convolution_scale, output_zero_point, output_min, output_max);

  convolution_op->type = xnn_operator_type_convolution_q8;
  convolution_op->ukernel.type = ukernel_type;

  convolution_op->state = xnn_run_state_invalid;

  *convolution_op_out = convolution_op;
  return xnn_status_success;

error:
  xnn_delete_operator(convolution_op);
  return status;
}
+
// Creates a 2D Convolution operator with single-precision (F32) data in NHWC
// layout.
//
// Mirrors the Q8 variant: validates arguments, selects a micro-kernel strategy
// (vmulcaddc for an unpadded channelwise 1x1, depthwise, direct GEMM for an
// unpadded 1x1, or indirect GEMM otherwise), packs the kernel and bias, and
// allocates a zero-padding buffer when padding is requested.
//
// On success, stores the new operator in *convolution_op_out and returns
// xnn_status_success.  On failure, releases any partially-created state via
// xnn_delete_operator and returns an error status.
enum xnn_status xnn_create_convolution2d_nhwc_f32(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t subsampling_height,
    uint32_t subsampling_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    const float* kernel,
    const float* bias,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* convolution_op_out)
{
  xnn_operator_t convolution_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to create Convolution operator: XNNPACK is not initialized");
    goto error;
  }

  // --- Parameter validation: shape and stride arguments ---
  status = xnn_status_invalid_parameter;

  if (kernel_width == 0 || kernel_height == 0) {
    xnn_log_error(
      "failed to create Convolution operator with %" PRIu32 "x%" PRIu32 " kernel: kernel dimensions must be non-zero",
      kernel_width, kernel_height);
    goto error;
  }

  if (subsampling_width == 0 || subsampling_height == 0) {
    xnn_log_error(
      "failed to create Convolution operator with %" PRIu32 "x%" PRIu32 " subsampling: "
      "subsampling dimensions must be non-zero",
      subsampling_width, subsampling_height);
    goto error;
  }

  if (dilation_width == 0 || dilation_height == 0) {
    xnn_log_error(
      "failed to create Convolution operator with %" PRIu32 "x%" PRIu32 " dilation: "
      "dilation dimensions must be non-zero",
      dilation_width, dilation_height);
    goto error;
  }

  if (groups == 0) {
    xnn_log_error(
      "failed to create Convolution operator with %" PRIu32 " groups: number of groups must be non-zero", groups);
    goto error;
  }

  if (group_input_channels == 0) {
    xnn_log_error(
      "failed to create Convolution operator with %zu input channels per group: "
      "number of channels must be non-zero",
      group_input_channels);
    goto error;
  }

  if (group_output_channels == 0) {
    xnn_log_error(
      "failed to create Convolution operator with %zu output channels per group: "
      "number of channels must be non-zero",
      group_output_channels);
    goto error;
  }

  // Pixel strides must be able to step over all channels of a pixel.
  const size_t input_channels = groups * group_input_channels;
  if (input_pixel_stride < input_channels) {
    xnn_log_error(
      "failed to create Convolution operator with input pixel stride of %zu: "
      "stride must be at least as large as the number of input channels (%" PRIu32 "x%zu)",
      input_pixel_stride, groups, group_input_channels);
    goto error;
  }

  const size_t output_channels = groups * group_output_channels;
  if (output_pixel_stride < output_channels) {
    xnn_log_error(
      "failed to create Convolution operator with output pixel stride of %zu: "
      "stride must be at least as large as the number of output channels (%" PRIu32 "x%zu)",
      output_pixel_stride, groups, group_output_channels);
    goto error;
  }

  // --- Parameter validation: output clamping range ---
  // -inf/+inf are valid bounds (meaning "no clamping"); NaN is not.
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create Convolution operator with NaN output lower bound: lower bound must be non-NaN");
    goto error;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create Convolution operator with NaN output upper bound: upper bound must be non-NaN");
    goto error;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create Convolution operator with [%.7g, %.7g] output range: "
      "lower bound must be below upper bound",
      output_min, output_max);
    goto error;
  }

  // A depthwise convolution is, by definition, one input channel per group.
  if ((flags & XNN_CONVOLUTION_FLAG_DEPTHWISE) != 0 && group_input_channels != 1) {
    xnn_log_error(
      "failed to create Depthwise Convolution operator with %zu input channels per group: "
      "Depthwise Convolution must have exactly 1 input channel per group",
      group_input_channels);
    goto error;
  }

  const uint32_t effective_kernel_height = (kernel_height - 1) * dilation_height + 1;
  const uint32_t effective_kernel_width = (kernel_width - 1) * dilation_width + 1;

  // Padding that exceeds the effective (dilated) kernel extent is legal but
  // wasteful: some kernel taps only ever see padding.  Log but do not fail.
  if (input_padding_top >= effective_kernel_height) {
    xnn_log_info(
      "inefficiency in Convolution operator with %" PRIu32 "x%" PRIu32 " effective kernel and %" PRIu32 "+%" PRIu32 " height padding: "
      "input top padding is greater or equal to effective kernel height",
      effective_kernel_width, effective_kernel_height, input_padding_top, input_padding_bottom);
  }

  if (input_padding_bottom >= effective_kernel_height) {
    xnn_log_info(
      "inefficiency in Convolution operator with %" PRIu32 "x%" PRIu32 " effective kernel and %" PRIu32 "+%" PRIu32 " height padding: "
      "input bottom padding is greater or equal to effective kernel height",
      effective_kernel_width, effective_kernel_height, input_padding_top, input_padding_bottom);
  }

  if (input_padding_right >= effective_kernel_width) {
    xnn_log_info(
      "inefficiency in Convolution operator with %" PRIu32 "x%" PRIu32 " effective kernel and %" PRIu32 "+%" PRIu32 " width padding: "
      "input right padding is greater or equal to effective kernel width",
      effective_kernel_width, effective_kernel_height, input_padding_left, input_padding_right);
  }

  if (input_padding_left >= effective_kernel_width) {
    xnn_log_info(
      "inefficiency in Convolution operator with %" PRIu32 "x%" PRIu32 " effective kernel and %" PRIu32 "+%" PRIu32 " width padding: "
      "input left padding is greater or equal to effective kernel width",
      effective_kernel_width, effective_kernel_height, input_padding_left, input_padding_right);
  }

  status = xnn_status_out_of_memory;

  convolution_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
  if (convolution_op == NULL) {
    xnn_log_error("failed to allocate %zu bytes for Convolution operator descriptor", sizeof(struct xnn_operator));
    goto error;
  }

  const size_t kernel_size = kernel_height * kernel_width;

  // --- Micro-kernel selection ---
  // vmulcaddc: unpadded, unstrided channelwise 1x1 (per-channel scale + bias);
  // dwconv:    channelwise with a specialized depthwise micro-kernel;
  // gemm:      unpadded, unstrided 1x1 convolution maps directly to GEMM;
  // igemm:     general case via indirect GEMM (indirection buffer).
  enum xnn_ukernel_type ukernel_type = xnn_ukernel_type_none;
  const struct dwconv_parameters* dwconv_parameters = NULL;
  const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
  // Both subsampling factors equal 1 iff their bitwise OR equals 1.
  const bool unit_subsampling = (subsampling_width | subsampling_height) == 1;
  if (group_input_channels == 1 && group_output_channels == 1 && kernel_size == 1 && unit_subsampling && !any_padding) {
    ukernel_type = xnn_ukernel_type_vmulcaddc;
  } else if (group_input_channels == 1 && group_output_channels == 1 && (dwconv_parameters =
               find_dwigemm_ukernel(kernel_size, xnn_params.f32.dwconv, XNN_MAX_F32_DWCONV_UKERNELS)) != NULL)
  {
    ukernel_type = xnn_ukernel_type_dwconv;
  } else if (kernel_size == 1 && unit_subsampling && !any_padding) {
    ukernel_type = xnn_ukernel_type_gemm;
  } else {
    ukernel_type = xnn_ukernel_type_igemm;
  }

  // --- Weight packing (layout depends on the selected micro-kernel) ---
  size_t zero_size = 0;
  switch (ukernel_type) {
    case xnn_ukernel_type_vmulcaddc:
    {
      // Per channel: one multiplier (the 1x1 kernel value) + one bias.
      const uint32_t c_stride = round_up_po2(groups, xnn_params.f32.vmulcaddc.cr);
      const size_t packed_weights_size = 2 * sizeof(float) * c_stride;
      convolution_op->packed_weights = xnn_allocate_memory(packed_weights_size);
      if (convolution_op->packed_weights == NULL) {
        xnn_log_error("failed to allocate %zu bytes for packed weights", packed_weights_size);
        goto error;
      }

      xnn_pack_f32_vmulcaddc_w(
        groups, xnn_params.f32.vmulcaddc.cr,
        kernel, bias, convolution_op->packed_weights);

      convolution_op->ukernel.vmulcaddc = (struct xnn_ukernel_vmulcaddc) {
        .function = xnn_params.f32.vmulcaddc.ukernel,
        .mr = xnn_params.f32.vmulcaddc.mr,
      };
      break;
    }
    case xnn_ukernel_type_dwconv:
    {
      assert(dwconv_parameters != NULL);
      assert(dwconv_parameters->mr == kernel_size);

      // Channels are padded up to the micro-kernel's channel tile (cr); each
      // channel stores kernel_size weights plus one bias element.
      const uint32_t c_stride = round_up_po2(groups, dwconv_parameters->cr);
      const size_t packed_weights_size = (kernel_size + 1) * sizeof(float) * c_stride;
      convolution_op->packed_weights = xnn_allocate_memory(packed_weights_size);
      if (convolution_op->packed_weights == NULL) {
        xnn_log_error("failed to allocate %zu bytes for packed weights", packed_weights_size);
        goto error;
      }

      // The DEPTHWISE flag selects the HWG (height-width-group) kernel layout
      // instead of the default GHW layout.
      if (flags & XNN_CONVOLUTION_FLAG_DEPTHWISE) {
        xnn_pack_f32_dwconv_hwg_w(
          kernel_height, kernel_width,
          groups, dwconv_parameters->cr,
          kernel, bias, convolution_op->packed_weights);
      } else {
        xnn_pack_f32_dwconv_ghw_w(
          kernel_height, kernel_width,
          groups, dwconv_parameters->cr,
          kernel, bias, convolution_op->packed_weights);
      }

      convolution_op->ukernel.dwconv = (struct xnn_ukernel_dwconv) {
        .unipass_function = dwconv_parameters->up,
        .mr = dwconv_parameters->mr,
        .qr = dwconv_parameters->qr,
      };

      zero_size = sizeof(float) * c_stride;
      break;
    }
    case xnn_ukernel_type_gemm:
    case xnn_ukernel_type_igemm:
    {
      const uint32_t nr = xnn_params.f32.gemm.nr;
      const uint32_t kr = UINT32_C(1) << xnn_params.f32.gemm.log2_kr;
      const uint32_t sr = UINT32_C(1) << xnn_params.f32.gemm.log2_sr;
      const uint32_t n_stride = round_up(group_output_channels, nr);
      const uint32_t k_stride = round_up_po2(group_input_channels, kr);

      const size_t packed_group_weights_size = (kernel_size * k_stride + 1) * sizeof(float) * n_stride;
      convolution_op->packed_weights = xnn_allocate_memory(packed_group_weights_size * groups);
      if (convolution_op->packed_weights == NULL) {
        xnn_log_error("failed to allocate %zu bytes for packed weights", packed_group_weights_size * groups);
        goto error;
      }
      // Zero the buffer so padding lanes contribute nothing to the output.
      memset(convolution_op->packed_weights, 0, packed_group_weights_size * groups);

      switch (ukernel_type) {
        case xnn_ukernel_type_gemm:
          xnn_pack_f32_gemm_goi_w(
            groups, group_output_channels, group_input_channels,
            nr, kr, sr,
            kernel, bias, convolution_op->packed_weights);
          convolution_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
            .mr = xnn_params.f32.gemm.mr,
            .nr = nr,
            .kr = kr,
            .default_function = xnn_params.f32.gemm.gemm,
            .mr1_function = xnn_params.f32.gemm.gemm1,
          };
          break;
        case xnn_ukernel_type_igemm:
          if (flags & XNN_CONVOLUTION_FLAG_DEPTHWISE) {
            xnn_pack_f32_conv_kgo_w(
              groups, group_output_channels, kernel_size,
              nr, kr,
              kernel, bias, convolution_op->packed_weights);
          } else {
            xnn_pack_f32_conv_goki_w(
              groups, group_output_channels, kernel_size, group_input_channels,
              nr, kr, sr,
              kernel, bias, convolution_op->packed_weights);
          }
          convolution_op->ukernel.igemm = (struct xnn_ukernel_igemm) {
            .mr = xnn_params.f32.gemm.mr,
            .nr = nr,
            .kr = kr,
            .default_function = xnn_params.f32.gemm.igemm,
            .mr1_function = xnn_params.f32.gemm.igemm1,
          };
          break;
        default:
          XNN_UNREACHABLE;
      }

      zero_size = sizeof(float) * k_stride;
      break;
    }
    default:
      XNN_UNREACHABLE;
  }

  // With padding, out-of-bounds taps read from an all-zeros buffer
  // (xnn_allocate_zero_memory returns zero-initialized memory).
  if (any_padding) {
    void* zero_buffer = xnn_allocate_zero_memory(zero_size);
    if (zero_buffer == NULL) {
      xnn_log_error("failed to allocate %zu bytes for zero padding", zero_size);
      goto error;
    }
    convolution_op->zero_buffer = zero_buffer;
  }

  // --- Record operator parameters for setup/run ---
  convolution_op->padding_top = input_padding_top;
  convolution_op->padding_right = input_padding_right;
  convolution_op->padding_bottom = input_padding_bottom;
  convolution_op->padding_left = input_padding_left;

  convolution_op->kernel_height = kernel_height;
  convolution_op->kernel_width = kernel_width;
  convolution_op->stride_height = subsampling_height;
  convolution_op->stride_width = subsampling_width;
  convolution_op->dilation_height = dilation_height;
  convolution_op->dilation_width = dilation_width;
  convolution_op->groups = groups;
  convolution_op->group_input_channels = group_input_channels;
  convolution_op->group_output_channels = group_output_channels;
  convolution_op->input_pixel_stride = input_pixel_stride;
  convolution_op->output_pixel_stride = output_pixel_stride;

  // Precompute output clamping parameters for the micro-kernels.
  convolution_op->f32_output_params = xnn_compute_f32_output_params(output_min, output_max);

  convolution_op->type = xnn_operator_type_convolution_f32;
  convolution_op->ukernel.type = ukernel_type;

  convolution_op->state = xnn_run_state_invalid;

  *convolution_op_out = convolution_op;
  return xnn_status_success;

error:
  xnn_delete_operator(convolution_op);
  return status;
}
+
+static enum xnn_status setup_convolution2d_nhwc(
+ xnn_operator_t convolution_op,
+ size_t batch_size,
+ size_t input_height,
+ size_t input_width,
+ const void* input,
+ void* output,
+ uint32_t log2_input_element_size,
+ uint32_t log2_filter_element_size,
+ uint32_t bias_element_size,
+ uint32_t log2_output_element_size,
+ const void* params,
+ size_t num_threads)
+{
+ convolution_op->state = xnn_run_state_invalid;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to setup Convolution operator: XNNPACK is not initialized");
+ return xnn_status_uninitialized;
+ }
+
+ if (input_width == 0 || input_height == 0) {
+ xnn_log_error(
+ "failed to setup Convolution operator with %zux%zu input: input dimensions must be non-zero",
+ input_width, input_height);
+ return xnn_status_invalid_parameter;
+ }
+
+ if (batch_size == 0) {
+ convolution_op->state = xnn_run_state_skip;
+ return xnn_status_success;
+ }
+
+ convolution_op->batch_size = batch_size;
+ convolution_op->input_height = input_height;
+ convolution_op->input_width = input_width;
+ convolution_op->input = input;
+
+ convolution_op->output_height = compute_output_dimension(
+ convolution_op->padding_top + input_height + convolution_op->padding_bottom,
+ convolution_op->kernel_height,
+ convolution_op->dilation_height,
+ convolution_op->stride_height);
+ convolution_op->output_width = compute_output_dimension(
+ convolution_op->padding_left + input_width + convolution_op->padding_right,
+ convolution_op->kernel_width,
+ convolution_op->dilation_width,
+ convolution_op->stride_width);
+ convolution_op->output = output;
+
+ switch (convolution_op->ukernel.type) {
+ case xnn_ukernel_type_gemm:
+ {
+ // Convolution maps directly to GEMM and doesn't use indirection buffer.
+
+ const size_t output_height = convolution_op->output_height;
+ const size_t output_width = convolution_op->output_width;
+ const size_t output_size = output_height * output_width;
+ const size_t batch_output_size = batch_size * output_size;
+
+ const size_t groups = convolution_op->groups;
+ const size_t group_input_channels = convolution_op->group_input_channels;
+ const size_t w_stride = (round_up_po2(group_input_channels, convolution_op->ukernel.gemm.kr) << log2_filter_element_size) + bias_element_size;
+ const size_t group_output_channels = convolution_op->group_output_channels;
+
+ uint32_t mr = convolution_op->ukernel.gemm.mr;
+ const uint32_t nr = convolution_op->ukernel.gemm.nr;
+ xnn_gemm_ukernel_function gemm_ukernel = convolution_op->ukernel.gemm.default_function;
+ if (batch_output_size == 1 && convolution_op->ukernel.gemm.mr1_function != NULL) {
+ mr = 1;
+ gemm_ukernel = convolution_op->ukernel.gemm.mr1_function;
+ }
+
+ convolution_op->context.gemm = (struct gemm_context) {
+ .k_scaled = group_input_channels << log2_input_element_size,
+ .a = input,
+ .a_stride = convolution_op->input_pixel_stride << log2_input_element_size,
+ .packed_w = convolution_op->packed_weights,
+ .w_stride = w_stride,
+ .wg_stride = w_stride * round_up(group_output_channels, nr),
+ .c = output,
+ .cm_stride = convolution_op->output_pixel_stride << log2_output_element_size,
+ .cn_stride = nr << log2_output_element_size,
+ .cg_stride = group_output_channels << log2_output_element_size,
+ .log2_csize = log2_output_element_size,
+ .ukernel = gemm_ukernel,
+ };
+ memcpy(&convolution_op->context.gemm.params, params, sizeof(convolution_op->context.gemm.params));
+
+ size_t nc = group_output_channels;
+ if (num_threads > 1) {
+ const size_t num_other_tiles = groups * divide_round_up(batch_output_size, mr);
+ const size_t target_tiles_per_thread = 5;
+ const size_t max_nc = divide_round_up(group_output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
+ if (max_nc < nc) {
+ nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
+ }
+ }
+ if (groups == 1) {
+ convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
+ convolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
+ convolution_op->compute.range[0] = batch_output_size;
+ convolution_op->compute.range[1] = group_output_channels;
+ convolution_op->compute.tile[0] = mr;
+ convolution_op->compute.tile[1] = nc;
+ } else {
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
+ convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_ggemm;
+ convolution_op->compute.range[0] = groups;
+ convolution_op->compute.range[1] = batch_output_size;
+ convolution_op->compute.range[2] = group_output_channels;
+ convolution_op->compute.tile[0] = mr;
+ convolution_op->compute.tile[1] = nc;
+ }
+ convolution_op->state = xnn_run_state_ready;
+
+ return xnn_status_success;
+ }
+ case xnn_ukernel_type_igemm:
+ {
+ const size_t groups = convolution_op->groups;
+ const size_t kernel_height = convolution_op->kernel_height;
+ const size_t kernel_width = convolution_op->kernel_width;
+ const size_t kernel_size = kernel_height * kernel_width;
+ const size_t output_height = convolution_op->output_height;
+ const size_t output_width = convolution_op->output_width;
+ const size_t output_size = output_height * output_width;
+
+ uint32_t mr = convolution_op->ukernel.igemm.mr;
+ const uint32_t nr = convolution_op->ukernel.igemm.nr;
+ xnn_igemm_ukernel_function igemm_ukernel = convolution_op->ukernel.igemm.default_function;
+ if (output_size == 1 && convolution_op->ukernel.igemm.mr1_function != NULL) {
+ mr = 1;
+ igemm_ukernel = convolution_op->ukernel.igemm.mr1_function;
+ }
+
+ const size_t tiled_output_size = round_up(output_size, mr);
+ const size_t indirection_buffer_size = sizeof(void*) * kernel_size * tiled_output_size;
+
+ if (input_height != convolution_op->last_input_height ||
+ input_width != convolution_op->last_input_width)
+ {
+ const void** indirection_buffer = (const void**) realloc(convolution_op->indirection_buffer, indirection_buffer_size);
+ if (indirection_buffer == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
+ return xnn_status_out_of_memory;
+ }
+ convolution_op->indirection_buffer = indirection_buffer;
+ convolution_op->last_input = input;
+ convolution_op->last_input_height = input_height;
+ convolution_op->last_input_width = input_width;
+
+ xnn_indirection_init_conv2d(convolution_op, mr, log2_input_element_size);
+ }
+
+ const size_t group_input_channels = convolution_op->group_input_channels;
+ const size_t w_stride = (round_up_po2(group_input_channels, convolution_op->ukernel.igemm.kr) * kernel_size << log2_filter_element_size) + bias_element_size;
+ const size_t group_output_channels = convolution_op->group_output_channels;
+ convolution_op->context.igemm = (struct igemm_context) {
+ .ks = kernel_size,
+ .ks_scaled = kernel_size * mr * sizeof(void*),
+ .kc = group_input_channels << log2_input_element_size,
+ .w_stride = w_stride,
+ .indirect_a = convolution_op->indirection_buffer,
+ .a_offset = (size_t) ((uintptr_t) input - (uintptr_t) convolution_op->last_input),
+ .zero = convolution_op->zero_buffer,
+ .packed_w = convolution_op->packed_weights,
+ .c = convolution_op->output,
+ .cm_stride = convolution_op->output_pixel_stride << log2_output_element_size,
+ .cn_stride = nr << log2_output_element_size,
+ .ga_stride = group_input_channels << log2_input_element_size,
+ .gw_stride = w_stride * round_up(group_output_channels, nr),
+ .gc_stride = group_output_channels << log2_output_element_size,
+ .ba_stride = input_height * input_width * convolution_op->input_pixel_stride << log2_input_element_size,
+ .bc_stride = output_size * convolution_op->output_pixel_stride << log2_output_element_size,
+ .log2_csize = log2_output_element_size,
+ .ukernel = igemm_ukernel,
+ };
+ memcpy(&convolution_op->context.igemm.params, params, sizeof(convolution_op->context.igemm.params));
+
+ size_t nc = group_output_channels;
+ if (num_threads > 1) {
+ const size_t num_other_tiles = groups * batch_size * divide_round_up(output_size, mr);
+ const size_t target_tiles_per_thread = 5;
+ const size_t max_nc = divide_round_up(group_output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
+ if (max_nc < nc) {
+ nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
+ }
+ }
+ if (groups == 1) {
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
+ convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_igemm;
+ convolution_op->compute.range[0] = batch_size;
+ convolution_op->compute.range[1] = output_size;
+ convolution_op->compute.range[2] = group_output_channels;
+ convolution_op->compute.tile[0] = mr;
+ convolution_op->compute.tile[1] = nc;
+ } else {
+ convolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
+ convolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_gigemm;
+ convolution_op->compute.range[0] = batch_size;
+ convolution_op->compute.range[1] = groups;
+ convolution_op->compute.range[2] = output_size;
+ convolution_op->compute.range[3] = group_output_channels;
+ convolution_op->compute.tile[0] = mr;
+ convolution_op->compute.tile[1] = nc;
+ }
+ convolution_op->state = xnn_run_state_ready;
+
+ return xnn_status_success;
+ }
+ case xnn_ukernel_type_dwconv:
+ {
+ size_t valid_batch_size = 0;
+ if (input == convolution_op->last_input &&
+ input_height == convolution_op->last_input_height &&
+ input_width == convolution_op->last_input_width)
+ {
+ valid_batch_size = convolution_op->valid_batch_size;
+ if (batch_size <= valid_batch_size) {
+ convolution_op->compute.range[0] = batch_size * convolution_op->output_height;
+ convolution_op->state = xnn_run_state_ready;
+ return xnn_status_success;
+ }
+ }
+
+ const size_t kernel_height = convolution_op->kernel_height;
+ const size_t kernel_width = convolution_op->kernel_width;
+ const size_t kernel_size = kernel_height * kernel_width;
+ const size_t output_height = convolution_op->output_height;
+ const size_t output_width = convolution_op->output_width;
+ const size_t step_width = convolution_op->dilation_width == 1 ? convolution_op->stride_width : kernel_width;
+ const size_t step_height = kernel_size + (output_width * step_width - 1) * kernel_height;
+ const size_t indirection_buffer_size = sizeof(void*) * batch_size * output_height * step_height;
+
+ const void** indirection_buffer =
+ (const void**) realloc(convolution_op->indirection_buffer, indirection_buffer_size);
+ if (indirection_buffer == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
+ return xnn_status_out_of_memory;
+ }
+ convolution_op->indirection_buffer = indirection_buffer;
+
+ xnn_indirection_init_dwconv2d(convolution_op, valid_batch_size, step_height, step_width, log2_input_element_size);
+
+ const size_t groups = convolution_op->groups;
+ convolution_op->context.dwconv = (struct dwconv_context) {
+ .groups = groups,
+ .indirection_buffer = convolution_op->indirection_buffer,
+ .indirection_buffer_row_stride = step_height,
+ .indirection_buffer_col_stride = kernel_height * step_width * sizeof(void*),
+ .packed_weights = convolution_op->packed_weights,
+ .output = convolution_op->output,
+ .output_width = output_width,
+ .output_row_stride = output_width * convolution_op->output_pixel_stride << log2_output_element_size,
+ .output_col_increment = (convolution_op->output_pixel_stride - groups) << log2_output_element_size,
+ .unipass_ukernel = convolution_op->ukernel.dwconv.unipass_function,
+ };
+ memcpy(&convolution_op->context.dwconv.params, params, sizeof(convolution_op->context.dwconv.params));
+
+ convolution_op->compute.type = xnn_parallelization_type_1d;
+ convolution_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_dwconv_unipass;
+ convolution_op->compute.range[0] = batch_size * output_height;
+ convolution_op->state = xnn_run_state_ready;
+
+ convolution_op->last_input = input;
+ convolution_op->last_input_height = input_height;
+ convolution_op->last_input_width = input_width;
+ convolution_op->valid_batch_size = max(valid_batch_size, batch_size);
+
+ return xnn_status_success;
+ }
+ case xnn_ukernel_type_vmulcaddc:
+ {
+ const size_t batch_output_size = batch_size * convolution_op->output_height * convolution_op->output_width;
+
+ convolution_op->context.vmulcaddc = (struct vmulcaddc_context) {
+ .n = convolution_op->groups << log2_input_element_size,
+ .x = input,
+ .x_stride = convolution_op->input_pixel_stride << log2_input_element_size,
+ .w = convolution_op->packed_weights,
+ .y = output,
+ .y_stride = convolution_op->output_pixel_stride << log2_output_element_size,
+ .ukernel = convolution_op->ukernel.vmulcaddc.function,
+ };
+ memcpy(&convolution_op->context.vmulcaddc.params, params, sizeof(convolution_op->context.vmulcaddc.params));
+
+ size_t mc = batch_output_size;
+ if (num_threads > 1) {
+ const size_t target_tiles_per_thread = 5;
+ const size_t max_mc = divide_round_up(batch_output_size, num_threads * target_tiles_per_thread);
+ if (max_mc < mc) {
+ const uint32_t mr = convolution_op->ukernel.vmulcaddc.mr;
+ mc = min(mc, divide_round_up(mc, max_mc * mr) * mr);
+ }
+ }
+ convolution_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+ convolution_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_vmulcaddc;
+ convolution_op->compute.range[0] = batch_output_size;
+ convolution_op->compute.tile[0] = mc;
+ convolution_op->state = xnn_run_state_ready;
+
+ return xnn_status_success;
+ }
+ default:
+ XNN_UNREACHABLE;
+ }
+}
+
+enum xnn_status xnn_setup_convolution2d_nhwc_q8(
+ xnn_operator_t convolution_op,
+ size_t batch_size,
+ size_t input_height,
+ size_t input_width,
+ const uint8_t* input,
+ uint8_t* output,
+ pthreadpool_t threadpool)
+{
+ if (convolution_op->type != xnn_operator_type_convolution_q8) {
+ xnn_log_error("failed to setup Convolution (Q8) operator: operator type mismatch");
+ return xnn_status_invalid_parameter;
+ }
+
+ return setup_convolution2d_nhwc(
+ convolution_op,
+ batch_size, input_height, input_width,
+ input, output,
+ 0 /* log2(sizeof(input element)) = log2(sizeof(uint8_t)) */,
+ 0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
+ sizeof(int32_t) /* sizeof(bias element) */,
+ 0 /* log2(sizeof(output element)) = log2(sizeof(uint8_t)) */,
+ &convolution_op->q8_gemm_params,
+ pthreadpool_get_threads_count(threadpool));
+}
+
+enum xnn_status xnn_setup_convolution2d_nhwc_f32(
+ xnn_operator_t convolution_op,
+ size_t batch_size,
+ size_t input_height,
+ size_t input_width,
+ const float* input,
+ float* output,
+ pthreadpool_t threadpool)
+{
+ if (convolution_op->type != xnn_operator_type_convolution_f32) {
+ xnn_log_error("failed to setup Convolution (F32) operator: operator type mismatch");
+ return xnn_status_invalid_parameter;
+ }
+
+ return setup_convolution2d_nhwc(
+ convolution_op,
+ batch_size, input_height, input_width,
+ input, output,
+ 2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
+ 2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
+ sizeof(float) /* sizeof(bias element) */,
+ 2 /* log2(sizeof(output element)) = log2(sizeof(float)) */,
+ &convolution_op->f32_output_params,
+ pthreadpool_get_threads_count(threadpool));
+}
diff --git a/src/deconvolution.c b/src/deconvolution.c
new file mode 100644
index 0000000..9bce03a
--- /dev/null
+++ b/src/deconvolution.c
@@ -0,0 +1,905 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+#include <math.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/requantization.h>
+#include <xnnpack/log.h>
+#include <xnnpack/math.h>
+#include <xnnpack/pack.h>
+#include <xnnpack/params.h>
+#include <xnnpack/indirection.h>
+
+
// Computes one spatial dimension of a deconvolution (transposed convolution)
// output: stride * (input - 1) + adjustment + dilated kernel extent, clamped
// at zero after subtracting the output padding (doz = difference-or-zero).
static inline size_t compute_output_dimension(
  size_t input_dimension,
  size_t output_padding_dimension,
  size_t adjustment_dimension,
  size_t kernel_dimension,
  size_t dilation_dimension,
  size_t stride_dimension)
{
  // Extent of the kernel once dilation gaps are accounted for.
  const size_t dilated_kernel_dimension = dilation_dimension * (kernel_dimension - 1) + 1;
  const size_t unpadded_dimension =
    stride_dimension * (input_dimension - 1) + adjustment_dimension + dilated_kernel_dimension;
  return doz(unpadded_dimension, output_padding_dimension);
}
+
+enum xnn_status xnn_create_deconvolution2d_nhwc_q8(
+ uint32_t output_padding_top,
+ uint32_t output_padding_right,
+ uint32_t output_padding_bottom,
+ uint32_t output_padding_left,
+ uint32_t adjustment_height,
+ uint32_t adjustment_width,
+ uint32_t kernel_height,
+ uint32_t kernel_width,
+ uint32_t stride_height,
+ uint32_t stride_width,
+ uint32_t dilation_height,
+ uint32_t dilation_width,
+ uint32_t groups,
+ size_t group_input_channels,
+ size_t group_output_channels,
+ size_t input_pixel_stride,
+ size_t output_pixel_stride,
+ uint8_t input_zero_point,
+ float input_scale,
+ uint8_t kernel_zero_point,
+ float kernel_scale,
+ const uint8_t* kernel,
+ const int32_t* bias,
+ uint8_t output_zero_point,
+ float output_scale,
+ uint8_t output_min,
+ uint8_t output_max,
+ uint32_t flags,
+ xnn_operator_t* deconvolution_op_out)
+{
+ xnn_operator_t deconvolution_op = NULL;
+ enum xnn_status status = xnn_status_uninitialized;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to create Deconvolution operator: XNNPACK is not initialized");
+ goto error;
+ }
+
+ status = xnn_status_invalid_parameter;
+
+ if (kernel_width == 0 || kernel_height == 0) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %" PRIu32 "x%" PRIu32 " kernel: kernel dimensions must be non-zero",
+ kernel_width, kernel_height);
+ goto error;
+ }
+
+ if (stride_width == 0 || stride_height == 0) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %" PRIu32 "x%" PRIu32 " stride: stride dimensions must be non-zero",
+ stride_width, stride_height);
+ goto error;
+ }
+
+ if (dilation_width == 0 || dilation_height == 0) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %" PRIu32 "x%" PRIu32 " dilation: "
+ "dilation dimensions must be non-zero",
+ dilation_width, dilation_height);
+ goto error;
+ }
+
+ if (groups == 0) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %" PRIu32 " groups: number of groups must be non-zero", groups);
+ goto error;
+ }
+
+ if (group_input_channels == 0) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %zu input channels per group: "
+ "number of channels must be non-zero",
+ group_input_channels);
+ goto error;
+ }
+
+ if (group_output_channels == 0) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %zu output channels per group: "
+ "number of channels must be non-zero",
+ group_output_channels);
+ goto error;
+ }
+
+ const size_t input_channels = groups * group_input_channels;
+ if (input_pixel_stride < input_channels) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with input pixel stride of %zu: "
+ "stride must be at least as large as the number of output channels (%" PRIu32 "x%zu)",
+ input_pixel_stride, groups, group_input_channels);
+ goto error;
+ }
+
+ const size_t output_channels = groups * group_output_channels;
+ if (output_pixel_stride < output_channels) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with output pixel stride of %zu: "
+ "stride must be at least as large as the number of output channels (%" PRIu32 "x%zu)",
+ output_pixel_stride, groups, group_output_channels);
+ goto error;
+ }
+
+ if (input_scale <= 0.0f || !isnormal(input_scale)) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %.7g input scale: scale must be finite, normalized, and positive",
+ input_scale);
+ goto error;
+ }
+
+ if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %.7g kernel scale: scale must be finite, normalized, and positive",
+ kernel_scale);
+ goto error;
+ }
+
+ if (output_scale <= 0.0f || !isnormal(output_scale)) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %.7g output scale: scale must be finite, normalized, and positive",
+ output_scale);
+ goto error;
+ }
+
+ if (output_min >= output_max) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with [%" PRIu8 ", %" PRIu8 "] output range: "
+ "range min must be below range max",
+ output_min, output_max);
+ goto error;
+ }
+
+ status = xnn_status_unsupported_parameter;
+
+ const float deconvolution_scale = input_scale * kernel_scale / output_scale;
+ if (deconvolution_scale >= 1.0f) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
+ "Deconvolution operator scale %.7g is greater or equal to 1.0",
+ input_scale, kernel_scale, output_scale, deconvolution_scale);
+ goto error;
+ }
+
+ status = xnn_status_out_of_memory;
+
+ deconvolution_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+ if (deconvolution_op == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for Deconvolution operator descriptor", sizeof(struct xnn_operator));
+ goto error;
+ }
+
+ const uint32_t mr = xnn_params.q8.gemm.mr;
+ const uint32_t nr = xnn_params.q8.gemm.nr;
+ const uint32_t kr = UINT32_C(1) << xnn_params.q8.gemm.log2_kr;
+ const xnn_igemm_ukernel_function ukernel_function = xnn_params.q8.gemm.igemm;
+
+ const uint32_t n_stride = round_up(group_output_channels, nr);
+ const uint32_t k_stride = round_up_po2(group_input_channels, kr);
+ const uint32_t kernel_size = kernel_height * kernel_width;
+ enum xnn_ukernel_type ukernel_type = xnn_ukernel_type_igemm;
+ size_t packed_group_weights_size = (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t)) * n_stride;
+ if (max(stride_height, stride_width) > 1 && max(dilation_height, dilation_width) == 1 && stride_width <= kernel_width && stride_height <= kernel_height) {
+ ukernel_type = xnn_ukernel_type_subconv2d;
+ const size_t subkernels = stride_height * stride_width;
+ packed_group_weights_size = n_stride *
+ (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t) * subkernels);
+
+ const size_t subconvolution_buffer_size = sizeof(struct subconvolution_params) * subkernels;
+ deconvolution_op->subconvolution_buffer = xnn_allocate_zero_memory(subconvolution_buffer_size);
+ if (deconvolution_op->subconvolution_buffer == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for subconvolution buffer", subconvolution_buffer_size);
+ goto error;
+ }
+
+ struct subconvolution_params* subconvolution_params = deconvolution_op->subconvolution_buffer;
+ for (size_t offset_y = 0; offset_y < stride_height; offset_y++) {
+ for (size_t offset_x = 0; offset_x < stride_width; offset_x++) {
+ const size_t subkernel_height = divide_round_up(kernel_height - offset_y, stride_height);
+ const size_t subkernel_width = divide_round_up(kernel_width - offset_x, stride_width);
+ const size_t subkernel_size = subkernel_height * subkernel_width;
+
+ subconvolution_params->indirection_x_stride = sizeof(void*) * subkernel_size;
+ subconvolution_params->w_stride = sizeof(int32_t) + k_stride * subkernel_size * sizeof(uint8_t);
+ subconvolution_params++;
+ }
+ }
+ }
+ deconvolution_op->packed_weights = xnn_allocate_memory(packed_group_weights_size * groups);
+ if (deconvolution_op->packed_weights == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for packed weights", packed_group_weights_size * groups);
+ goto error;
+ }
+ memset(deconvolution_op->packed_weights, kernel_zero_point, packed_group_weights_size * groups);
+
+ switch (ukernel_type) {
+ case xnn_ukernel_type_igemm:
+ xnn_pack_q8_conv_goki_w(
+ groups, group_output_channels, kernel_size, group_input_channels,
+ nr, kr,
+ input_zero_point, kernel_zero_point,
+ kernel, bias, deconvolution_op->packed_weights);
+ break;
+ case xnn_ukernel_type_subconv2d:
+ xnn_pack_q8_deconv_goki_w(
+ groups, group_output_channels, kernel_height, kernel_width, group_input_channels,
+ stride_height, stride_width,
+ nr, kr,
+ input_zero_point, kernel_zero_point,
+ kernel, bias, deconvolution_op->packed_weights, deconvolution_op->subconvolution_buffer);
+ break;
+ default:
+ XNN_UNREACHABLE;
+ }
+
+ size_t zero_size = sizeof(uint8_t) * k_stride + XNN_EXTRA_BYTES;
+ void* zero_buffer = xnn_allocate_memory(zero_size);
+ if (zero_buffer == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for zero padding", zero_size);
+ goto error;
+ }
+ memset(zero_buffer, input_zero_point, zero_size);
+ deconvolution_op->zero_buffer = zero_buffer;
+
+ deconvolution_op->padding_top = output_padding_top;
+ deconvolution_op->padding_right = output_padding_right;
+ deconvolution_op->padding_bottom = output_padding_bottom;
+ deconvolution_op->padding_left = output_padding_left;
+ deconvolution_op->adjustment_height = adjustment_height;
+ deconvolution_op->adjustment_width = adjustment_width;
+
+ deconvolution_op->kernel_height = kernel_height;
+ deconvolution_op->kernel_width = kernel_width;
+ deconvolution_op->stride_height = stride_height;
+ deconvolution_op->stride_width = stride_width;
+ deconvolution_op->dilation_height = dilation_height;
+ deconvolution_op->dilation_width = dilation_width;
+ deconvolution_op->groups = groups;
+ deconvolution_op->group_input_channels = group_input_channels;
+ deconvolution_op->group_output_channels = group_output_channels;
+ deconvolution_op->input_pixel_stride = input_pixel_stride;
+ deconvolution_op->output_pixel_stride = output_pixel_stride;
+
+ deconvolution_op->kernel_zero_point = kernel_zero_point;
+
+ deconvolution_op->q8_gemm_params =
+ xnn_compute_q8_gemm_params(
+ input_zero_point, kernel_zero_point,
+ deconvolution_scale, output_zero_point, output_min, output_max);
+
+ deconvolution_op->type = xnn_operator_type_deconvolution_q8;
+ deconvolution_op->ukernel.type = ukernel_type;
+ deconvolution_op->ukernel.igemm = (struct xnn_ukernel_igemm) {
+ .default_function = ukernel_function,
+ .mr = mr,
+ .nr = nr,
+ .kr = kr,
+ };
+
+ deconvolution_op->state = xnn_run_state_invalid;
+
+ *deconvolution_op_out = deconvolution_op;
+ return xnn_status_success;
+
+error:
+ xnn_delete_operator(deconvolution_op);
+ return status;
+}
+
+enum xnn_status xnn_create_deconvolution2d_nhwc_f32(
+ uint32_t output_padding_top,
+ uint32_t output_padding_right,
+ uint32_t output_padding_bottom,
+ uint32_t output_padding_left,
+ uint32_t adjustment_height,
+ uint32_t adjustment_width,
+ uint32_t kernel_height,
+ uint32_t kernel_width,
+ uint32_t stride_height,
+ uint32_t stride_width,
+ uint32_t dilation_height,
+ uint32_t dilation_width,
+ uint32_t groups,
+ size_t group_input_channels,
+ size_t group_output_channels,
+ size_t input_pixel_stride,
+ size_t output_pixel_stride,
+ const float* kernel,
+ const float* bias,
+ float output_min,
+ float output_max,
+ uint32_t flags,
+ xnn_operator_t* deconvolution_op_out)
+{
+ xnn_operator_t deconvolution_op = NULL;
+ enum xnn_status status = xnn_status_uninitialized;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to create Deconvolution operator: XNNPACK is not initialized");
+ goto error;
+ }
+
+ status = xnn_status_invalid_parameter;
+
+ if (kernel_width == 0 || kernel_height == 0) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %" PRIu32 "x%" PRIu32 " kernel: kernel dimensions must be non-zero",
+ kernel_width, kernel_height);
+ goto error;
+ }
+
+ if (stride_width == 0 || stride_height == 0) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %" PRIu32 "x%" PRIu32 " stride: stride dimensions must be non-zero",
+ stride_width, stride_height);
+ goto error;
+ }
+
+ if (dilation_width == 0 || dilation_height == 0) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %" PRIu32 "x%" PRIu32 " dilation: "
+ "dilation dimensions must be non-zero",
+ dilation_width, dilation_height);
+ goto error;
+ }
+
+ if (groups == 0) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %" PRIu32 " groups: number of groups must be non-zero", groups);
+ goto error;
+ }
+
+ if (group_input_channels == 0) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %zu input channels per group: "
+ "number of channels must be non-zero",
+ group_input_channels);
+ goto error;
+ }
+
+ if (group_output_channels == 0) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with %zu output channels per group: "
+ "number of channels must be non-zero",
+ group_output_channels);
+ goto error;
+ }
+
+ const size_t input_channels = groups * group_input_channels;
+ if (input_pixel_stride < input_channels) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with input pixel stride of %zu: "
+ "stride must be at least as large as the number of output channels (%" PRIu32 "x%zu)",
+ input_pixel_stride, groups, group_input_channels);
+ goto error;
+ }
+
+ const size_t output_channels = groups * group_output_channels;
+ if (output_pixel_stride < output_channels) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with output pixel stride of %zu: "
+ "stride must be at least as large as the number of output channels (%" PRIu32 "x%zu)",
+ output_pixel_stride, groups, group_output_channels);
+ goto error;
+ }
+
+ if (isnan(output_min)) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with NaN output lower bound: lower bound must be non-NaN");
+ goto error;
+ }
+
+ if (isnan(output_max)) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with NaN output upper bound: upper bound must be non-NaN");
+ goto error;
+ }
+
+ if (output_min >= output_max) {
+ xnn_log_error(
+ "failed to create Deconvolution operator with [%.7g, %.7g] output range: "
+ "lower bound must be below upper bound",
+ output_min, output_max);
+ goto error;
+ }
+
+ status = xnn_status_out_of_memory;
+
+ deconvolution_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+ if (deconvolution_op == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for Deconvolution operator descriptor", sizeof(struct xnn_operator));
+ goto error;
+ }
+
+ uint32_t mr = xnn_params.f32.gemm.mr;
+ uint32_t nr = xnn_params.f32.gemm.nr;
+ uint32_t kr = UINT32_C(1) << xnn_params.f32.gemm.log2_kr;
+ uint32_t sr = UINT32_C(1) << xnn_params.f32.gemm.log2_sr;
+ xnn_igemm_ukernel_function ukernel_function = xnn_params.f32.gemm.igemm;
+ if (nr > group_output_channels) {
+ // Default micro-kernel is suboptimal. Try to find a better micro-kernel.
+ if (xnn_params.f32.gemm2.igemm != NULL) {
+ mr = xnn_params.f32.gemm2.mr;
+ nr = xnn_params.f32.gemm2.nr;
+ kr = UINT32_C(1) << xnn_params.f32.gemm2.log2_kr;
+ sr = UINT32_C(1) << xnn_params.f32.gemm2.log2_sr;
+ ukernel_function = xnn_params.f32.gemm2.igemm;
+ }
+ }
+
+ const uint32_t n_stride = round_up(group_output_channels, nr);
+ const uint32_t k_stride = round_up_po2(group_input_channels, kr);
+ const uint32_t kernel_size = kernel_height * kernel_width;
+ enum xnn_ukernel_type ukernel_type = xnn_ukernel_type_igemm;
+ size_t packed_group_weights_size = (sizeof(float) * kernel_size * k_stride + sizeof(float)) * n_stride;
+ if (max(stride_height, stride_width) > 1 && max(dilation_height, dilation_width) == 1 && stride_width <= kernel_width && stride_height <= kernel_height) {
+ ukernel_type = xnn_ukernel_type_subconv2d;
+ const size_t subkernels = stride_height * stride_width;
+ packed_group_weights_size = n_stride *
+ (sizeof(float) * kernel_size * k_stride + sizeof(float) * subkernels);
+
+ const size_t subconvolution_buffer_size = sizeof(struct subconvolution_params) * subkernels;
+ deconvolution_op->subconvolution_buffer = xnn_allocate_zero_memory(subconvolution_buffer_size);
+ if (deconvolution_op->subconvolution_buffer == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for subconvolution buffer", subconvolution_buffer_size);
+ goto error;
+ }
+
+ struct subconvolution_params* subconvolution_params = deconvolution_op->subconvolution_buffer;
+ for (size_t offset_y = 0; offset_y < stride_height; offset_y++) {
+ for (size_t offset_x = 0; offset_x < stride_width; offset_x++) {
+ const size_t subkernel_height = divide_round_up(kernel_height - offset_y, stride_height);
+ const size_t subkernel_width = divide_round_up(kernel_width - offset_x, stride_width);
+ const size_t subkernel_size = subkernel_height * subkernel_width;
+
+ subconvolution_params->indirection_x_stride = sizeof(void*) * subkernel_size;
+ subconvolution_params->w_stride = sizeof(float) + k_stride * subkernel_size * sizeof(float);
+ subconvolution_params++;
+ }
+ }
+ }
+ deconvolution_op->packed_weights = xnn_allocate_memory(packed_group_weights_size * groups);
+ if (deconvolution_op->packed_weights == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for packed weights", packed_group_weights_size * groups);
+ goto error;
+ }
+ memset(deconvolution_op->packed_weights, 0, packed_group_weights_size * groups);
+
+ switch (ukernel_type) {
+ case xnn_ukernel_type_igemm:
+ xnn_pack_f32_conv_goki_w(
+ groups, group_output_channels, kernel_size, group_input_channels,
+ nr, kr, sr,
+ kernel, bias, deconvolution_op->packed_weights);
+ break;
+ case xnn_ukernel_type_subconv2d:
+ xnn_pack_f32_deconv_goki_w(
+ groups, group_output_channels, kernel_height, kernel_width, group_input_channels,
+ stride_height, stride_width,
+ nr, kr,
+ kernel, bias, deconvolution_op->packed_weights, deconvolution_op->subconvolution_buffer);
+ break;
+ default:
+ XNN_UNREACHABLE;
+ }
+
+ const size_t zero_size = k_stride * sizeof(float) + XNN_EXTRA_BYTES;
+ void* zero_buffer = xnn_allocate_zero_memory(zero_size);
+ if (zero_buffer == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for zero padding", zero_size);
+ goto error;
+ }
+ deconvolution_op->zero_buffer = zero_buffer;
+
+ deconvolution_op->padding_top = output_padding_top;
+ deconvolution_op->padding_right = output_padding_right;
+ deconvolution_op->padding_bottom = output_padding_bottom;
+ deconvolution_op->padding_left = output_padding_left;
+ deconvolution_op->adjustment_height = adjustment_height;
+ deconvolution_op->adjustment_width = adjustment_width;
+
+ deconvolution_op->kernel_height = kernel_height;
+ deconvolution_op->kernel_width = kernel_width;
+ deconvolution_op->stride_height = stride_height;
+ deconvolution_op->stride_width = stride_width;
+ deconvolution_op->dilation_height = dilation_height;
+ deconvolution_op->dilation_width = dilation_width;
+ deconvolution_op->groups = groups;
+ deconvolution_op->group_input_channels = group_input_channels;
+ deconvolution_op->group_output_channels = group_output_channels;
+ deconvolution_op->input_pixel_stride = input_pixel_stride;
+ deconvolution_op->output_pixel_stride = output_pixel_stride;
+
+ deconvolution_op->f32_output_params = xnn_compute_f32_output_params(output_min, output_max);
+
+ deconvolution_op->type = xnn_operator_type_deconvolution_f32;
+ deconvolution_op->ukernel.type = ukernel_type;
+ deconvolution_op->ukernel.igemm = (struct xnn_ukernel_igemm) {
+ .default_function = ukernel_function,
+ .mr = mr,
+ .nr = nr,
+ .kr = kr,
+ };
+
+ deconvolution_op->state = xnn_run_state_invalid;
+
+ *deconvolution_op_out = deconvolution_op;
+ return xnn_status_success;
+
+error:
+ xnn_delete_operator(deconvolution_op);
+ return status;
+}
+
// Sets up a Deconvolution operator for execution via the generic IGEMM
// (indirect GEMM) path: refreshes the indirection buffer if the input spatial
// dimensions changed, fills the operator's igemm context, and selects a
// parallelization strategy. Element sizes are passed as log2 so one routine
// serves both Q8 and F32 operators. Returns xnn_status_success, or
// xnn_status_out_of_memory if the indirection buffer cannot be grown.
static enum xnn_status setup_conv_path(
  xnn_operator_t deconvolution_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  const void* input,
  size_t output_height,
  size_t output_width,
  void* output,
  uint32_t log2_input_element_size,
  uint32_t log2_filter_element_size,
  uint32_t bias_element_size,
  uint32_t log2_output_element_size,
  const void* params,
  size_t num_threads)
{
  assert(deconvolution_op->ukernel.type == xnn_ukernel_type_igemm);

  const size_t kernel_height = deconvolution_op->kernel_height;
  const size_t kernel_width = deconvolution_op->kernel_width;
  const size_t kernel_size = kernel_height * kernel_width;

  const size_t groups = deconvolution_op->groups;
  const size_t output_size = output_height * output_width;
  const size_t mr = deconvolution_op->ukernel.igemm.mr;
  // Output pixels are processed in tiles of mr rows; the indirection buffer
  // holds kernel_size input pointers per (tiled) output pixel.
  const size_t tiled_output_size = round_up(output_size, mr);
  const size_t indirection_buffer_size = sizeof(void*) * kernel_size * tiled_output_size;

  // Rebuild the indirection buffer only when the input spatial dimensions
  // changed; a change of the input pointer alone is handled via a_offset.
  if (input_height != deconvolution_op->last_input_height ||
      input_width != deconvolution_op->last_input_width)
  {
    const void** indirection_buffer = (const void**) realloc(deconvolution_op->indirection_buffer, indirection_buffer_size);
    if (indirection_buffer == NULL) {
      xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
      return xnn_status_out_of_memory;
    }
    deconvolution_op->indirection_buffer = indirection_buffer;
    deconvolution_op->last_input = input;
    deconvolution_op->last_input_height = input_height;
    deconvolution_op->last_input_width = input_width;

    xnn_indirection_init_deconv2d(deconvolution_op, mr, log2_input_element_size);
  }

  const size_t group_input_channels = deconvolution_op->group_input_channels;
  const size_t group_output_channels = deconvolution_op->group_output_channels;
  const uint32_t nr = deconvolution_op->ukernel.igemm.nr;
  // Bytes per packed output channel: bias element plus kr-padded filter taps.
  const size_t w_stride = bias_element_size +
    (round_up_po2(group_input_channels, deconvolution_op->ukernel.igemm.kr) * kernel_size << log2_filter_element_size);
  deconvolution_op->context.igemm = (struct igemm_context) {
    .ks = kernel_size,
    .ks_scaled = kernel_size * mr * sizeof(void*),
    .kc = group_input_channels << log2_input_element_size,
    .w_stride = w_stride,
    .indirect_a = deconvolution_op->indirection_buffer,
    // Byte offset between the current input and the input the indirection
    // buffer was built for (zero when the buffer was just rebuilt).
    .a_offset = (size_t) ((uintptr_t) input - (uintptr_t) deconvolution_op->last_input),
    .zero = deconvolution_op->zero_buffer,
    .packed_w = deconvolution_op->packed_weights,
    .c = deconvolution_op->output,
    .cm_stride = deconvolution_op->output_pixel_stride << log2_output_element_size,
    .cn_stride = nr << log2_output_element_size,
    .ga_stride = group_input_channels << log2_input_element_size,
    .gw_stride = w_stride * round_up(group_output_channels, nr),
    .gc_stride = group_output_channels << log2_output_element_size,
    .ba_stride = input_height * input_width * deconvolution_op->input_pixel_stride << log2_input_element_size,
    .bc_stride = output_size * deconvolution_op->output_pixel_stride << log2_output_element_size,
    .log2_csize = log2_output_element_size,
    .ukernel = deconvolution_op->ukernel.igemm.default_function,
  };
  // For single-pixel outputs, prefer the specialized mr==1 micro-kernel.
  if (output_size == 1 && deconvolution_op->ukernel.igemm.mr1_function != NULL) {
    deconvolution_op->context.igemm.ukernel = deconvolution_op->ukernel.igemm.mr1_function;
  }
  memcpy(&deconvolution_op->context.igemm.params, params, sizeof(deconvolution_op->context.igemm.params));

  // Shrink the output-channel tile so each thread gets roughly
  // target_tiles_per_thread tiles; keep the tile a multiple of nr.
  size_t nc = group_output_channels;
  if (num_threads > 1) {
    const size_t num_other_tiles = groups * batch_size * divide_round_up(output_size, mr);
    const size_t target_tiles_per_thread = 5;
    const size_t max_nc = divide_round_up(group_output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
    if (max_nc < nc) {
      nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
    }
  }
  // Grouped operators need an extra parallelization dimension for the group.
  if (groups == 1) {
    deconvolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
    deconvolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_igemm;
    deconvolution_op->compute.range[0] = batch_size;
    deconvolution_op->compute.range[1] = output_size;
    deconvolution_op->compute.range[2] = group_output_channels;
    deconvolution_op->compute.tile[0] = mr;
    deconvolution_op->compute.tile[1] = nc;
  } else {
    deconvolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
    deconvolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_gigemm;
    deconvolution_op->compute.range[0] = batch_size;
    deconvolution_op->compute.range[1] = groups;
    deconvolution_op->compute.range[2] = output_size;
    deconvolution_op->compute.range[3] = group_output_channels;
    deconvolution_op->compute.tile[0] = mr;
    deconvolution_op->compute.tile[1] = nc;
  }
  deconvolution_op->state = xnn_run_state_ready;
  return xnn_status_success;
}
+
+// Sets up the subconvolution path for a strided 2D deconvolution.
+// A deconvolution with stride S_h x S_w is decomposed into S_h * S_w smaller
+// convolutions ("subconvolutions"), one per output phase (offset_y, offset_x).
+// This function (re-)builds the indirection buffer when the input geometry
+// changes, initializes the per-phase subconvolution parameters, fills the
+// subconv microkernel context, and configures the tiled parallelization.
+//
+// Element sizes are passed as log2 values; bias_element_size is in bytes.
+// Returns xnn_status_success, or xnn_status_out_of_memory when the
+// indirection buffer cannot be (re)allocated.
+static enum xnn_status setup_subconv2d_path(
+    xnn_operator_t deconvolution_op,
+    size_t batch_size,
+    size_t input_height,
+    size_t input_width,
+    const void* input,
+    size_t output_height,
+    size_t output_width,
+    void* output,
+    uint32_t log2_input_element_size,
+    uint32_t log2_filter_element_size,
+    uint32_t bias_element_size,
+    uint32_t log2_output_element_size,
+    const void* params,
+    size_t num_threads)
+{
+  assert(deconvolution_op->ukernel.type == xnn_ukernel_type_subconv2d);
+
+  const size_t kernel_height = deconvolution_op->kernel_height;
+  const size_t kernel_width = deconvolution_op->kernel_width;
+  const size_t kernel_size = kernel_height * kernel_width;
+  const size_t stride_height = deconvolution_op->stride_height;
+  const size_t stride_width = deconvolution_op->stride_width;
+
+  const size_t groups = deconvolution_op->groups;
+  const size_t output_size = output_height * output_width;
+  const size_t mr = deconvolution_op->ukernel.igemm.mr;
+  // One input pointer per kernel element per output position; the width
+  // dimension is rounded up so every slice covers whole MR-sized tiles.
+  const size_t indirection_buffer_size =
+    sizeof(void*) * kernel_size * output_height * stride_width * round_up(divide_round_up(output_width, stride_width), mr);
+
+  // Rebuild the indirection buffer only when the input geometry changed;
+  // a mere change of the input pointer is handled via a_offset below.
+  if (input_height != deconvolution_op->last_input_height ||
+      input_width != deconvolution_op->last_input_width)
+  {
+    const void** indirection_buffer = (const void**) realloc(deconvolution_op->indirection_buffer, indirection_buffer_size);
+    if (indirection_buffer == NULL) {
+      xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
+      return xnn_status_out_of_memory;
+    }
+    deconvolution_op->indirection_buffer = indirection_buffer;
+    deconvolution_op->last_input = input;
+    deconvolution_op->last_input_height = input_height;
+    deconvolution_op->last_input_width = input_width;
+
+    xnn_indirection_init_subconv2d(deconvolution_op, mr, log2_input_element_size);
+
+    // Initialize subconvolution parameters which depend on output dimensions or MR.
+    struct subconvolution_params* subconvolution_params = deconvolution_op->subconvolution_buffer;
+    const size_t modulo_padding_top = deconvolution_op->padding_top % stride_height;
+    const size_t modulo_padding_left = deconvolution_op->padding_left % stride_width;
+    const size_t output_pixel_stride = deconvolution_op->output_pixel_stride << log2_output_element_size;
+    for (size_t offset_y = 0; offset_y < stride_height; offset_y++) {
+      for (size_t offset_x = 0; offset_x < stride_width; offset_x++) {
+        // First output row/column produced by this (offset_y, offset_x) phase.
+        const size_t output_x_start = subtract_modulo(offset_x, modulo_padding_left, stride_width);
+        const size_t output_y_start = subtract_modulo(offset_y, modulo_padding_top, stride_height);
+        // NOTE(review): indirection_x_stride is presumably initialized during
+        // weight packing / indirection init — confirm before relying on it here.
+        subconvolution_params->scaled_kernel_size = mr * subconvolution_params->indirection_x_stride;
+        subconvolution_params->slice_width = divide_round_up(output_width - output_x_start, stride_width);
+        subconvolution_params->slice_height = divide_round_up(output_height - output_y_start, stride_height);
+        subconvolution_params->output =
+          (void*) ((uintptr_t) output + ((output_y_start * output_width + output_x_start) * output_pixel_stride));
+        ++subconvolution_params;
+      }
+    }
+  }
+
+  const size_t group_input_channels = deconvolution_op->group_input_channels;
+  const size_t group_output_channels = deconvolution_op->group_output_channels;
+  const uint32_t nr = deconvolution_op->ukernel.igemm.nr;
+  // Packed-weights stride per NR panel: one bias entry per phase followed by
+  // the KR-rounded filter taps for the whole kernel.
+  const size_t w_stride = stride_height * stride_width * bias_element_size +
+    (round_up_po2(group_input_channels, deconvolution_op->ukernel.igemm.kr) * kernel_size << log2_filter_element_size);
+  deconvolution_op->context.subconv = (struct subconv_context) {
+    .subconvolution_params = deconvolution_op->subconvolution_buffer,
+    .kc = group_input_channels << log2_input_element_size,
+    // Byte offset from the input the indirection buffer was built for to the
+    // current input; zero when the buffer was rebuilt above.
+    .a_offset = (size_t) ((uintptr_t) input - (uintptr_t) deconvolution_op->last_input),
+    .zero = deconvolution_op->zero_buffer,
+    .cx_stride = stride_width * deconvolution_op->output_pixel_stride << log2_output_element_size,
+    .cy_stride = stride_height * output_width * deconvolution_op->output_pixel_stride << log2_output_element_size,
+    .cn_stride = nr << log2_output_element_size,
+    .ga_stride = group_input_channels << log2_input_element_size,
+    .gw_stride = w_stride * round_up(group_output_channels, nr),
+    .gc_stride = group_output_channels << log2_output_element_size,
+    .ba_stride = input_height * input_width * deconvolution_op->input_pixel_stride << log2_input_element_size,
+    .bc_stride = output_size * deconvolution_op->output_pixel_stride << log2_output_element_size,
+    .log2_csize = log2_output_element_size,
+    .ukernel = deconvolution_op->ukernel.igemm.default_function,
+  };
+  memcpy(&deconvolution_op->context.subconv.params, params, sizeof(deconvolution_op->context.subconv.params));
+
+  const size_t output_height_positions = divide_round_up(output_height, stride_height);
+  const size_t output_width_positions = divide_round_up(output_width, stride_width);
+
+  // Shrink the NC tile when there is not enough parallelism in the other
+  // dimensions to give each thread ~target_tiles_per_thread tiles; keep the
+  // tile a multiple of NR.
+  size_t nc = group_output_channels;
+  if (num_threads > 1) {
+    const size_t num_other_tiles = groups * stride_height * stride_width *
+      output_height_positions * divide_round_up(output_width_positions, mr);
+    const size_t target_tiles_per_thread = 5;
+    const size_t max_nc = divide_round_up(group_output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
+    if (max_nc < nc) {
+      nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
+    }
+  }
+
+  // Grouped deconvolutions need an extra parallelization dimension.
+  if (groups == 1) {
+    deconvolution_op->compute.type = xnn_parallelization_type_5d_tile_2d;
+    deconvolution_op->compute.task_5d_tile_2d = (pthreadpool_task_5d_tile_2d_t) xnn_compute_subconv2d;
+    deconvolution_op->compute.range[0] = batch_size;
+    deconvolution_op->compute.range[1] = stride_height * stride_width;
+    deconvolution_op->compute.range[2] = divide_round_up(output_height, stride_height);
+    deconvolution_op->compute.range[3] = divide_round_up(output_width, stride_width);
+    deconvolution_op->compute.range[4] = group_output_channels;
+    deconvolution_op->compute.tile[0] = mr;
+    deconvolution_op->compute.tile[1] = nc;
+  } else {
+    deconvolution_op->compute.type = xnn_parallelization_type_6d_tile_2d;
+    deconvolution_op->compute.task_6d_tile_2d = (pthreadpool_task_6d_tile_2d_t) xnn_compute_gsubconv2d;
+    deconvolution_op->compute.range[0] = batch_size;
+    deconvolution_op->compute.range[1] = groups;
+    deconvolution_op->compute.range[2] = stride_height * stride_width;
+    deconvolution_op->compute.range[3] = divide_round_up(output_height, stride_height);
+    deconvolution_op->compute.range[4] = divide_round_up(output_width, stride_width);
+    deconvolution_op->compute.range[5] = group_output_channels;
+    deconvolution_op->compute.tile[0] = mr;
+    deconvolution_op->compute.tile[1] = nc;
+  }
+
+  deconvolution_op->state = xnn_run_state_ready;
+  return xnn_status_success;
+}
+
+// Common setup for NHWC deconvolution operators: validates the arguments,
+// records the input/output geometry on the operator, computes the output
+// spatial dimensions, and dispatches to the compute path selected at
+// creation time. Element sizes are passed as log2 values;
+// bias_element_size is in bytes.
+static enum xnn_status setup_deconvolution2d(
+    xnn_operator_t deconvolution_op,
+    size_t batch_size,
+    size_t input_height,
+    size_t input_width,
+    const void* input,
+    void* output,
+    uint32_t log2_input_element_size,
+    uint32_t log2_filter_element_size,
+    uint32_t bias_element_size,
+    uint32_t log2_output_element_size,
+    const void* params,
+    size_t num_threads)
+{
+  // Assume failure until one of the compute paths reports success.
+  deconvolution_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup Deconvolution operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  if (input_width == 0 || input_height == 0) {
+    xnn_log_error(
+      "failed to setup Deconvolution with %zux%zu input: input dimensions must be non-zero",
+      input_width, input_height);
+    return xnn_status_invalid_parameter;
+  }
+
+  // An empty batch is valid: there is simply nothing to compute.
+  if (batch_size == 0) {
+    deconvolution_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  deconvolution_op->batch_size = batch_size;
+  deconvolution_op->input_height = input_height;
+  deconvolution_op->input_width = input_width;
+  deconvolution_op->input = input;
+  deconvolution_op->output = output;
+
+  // Output spatial size follows the transposed-convolution formula.
+  deconvolution_op->output_height = compute_output_dimension(
+    input_height, deconvolution_op->padding_top + deconvolution_op->padding_bottom,
+    deconvolution_op->adjustment_height, deconvolution_op->kernel_height,
+    deconvolution_op->dilation_height, deconvolution_op->stride_height);
+  deconvolution_op->output_width = compute_output_dimension(
+    input_width, deconvolution_op->padding_left + deconvolution_op->padding_right,
+    deconvolution_op->adjustment_width, deconvolution_op->kernel_width,
+    deconvolution_op->dilation_width, deconvolution_op->stride_width);
+  const size_t output_height = deconvolution_op->output_height;
+  const size_t output_width = deconvolution_op->output_width;
+
+  if (deconvolution_op->ukernel.type == xnn_ukernel_type_igemm) {
+    return setup_conv_path(
+      deconvolution_op,
+      batch_size,
+      input_height, input_width, input,
+      output_height, output_width, output,
+      log2_input_element_size, log2_filter_element_size, bias_element_size, log2_output_element_size,
+      params, num_threads);
+  } else if (deconvolution_op->ukernel.type == xnn_ukernel_type_subconv2d) {
+    return setup_subconv2d_path(
+      deconvolution_op,
+      batch_size,
+      input_height, input_width, input,
+      output_height, output_width, output,
+      log2_input_element_size, log2_filter_element_size, bias_element_size, log2_output_element_size,
+      params, num_threads);
+  } else {
+    XNN_UNREACHABLE;
+  }
+}
+
+// Sets up a quantized (Q8) NHWC deconvolution operator for the given input
+// shape and buffers. uint8_t elements have log2(size) == 0; biases are
+// int32_t. Fails with xnn_status_invalid_parameter when the operator was
+// created for a different data type.
+enum xnn_status xnn_setup_deconvolution2d_nhwc_q8(
+    xnn_operator_t deconvolution_op,
+    size_t batch_size,
+    size_t input_height,
+    size_t input_width,
+    const uint8_t* input,
+    uint8_t* output,
+    pthreadpool_t threadpool)
+{
+  // Guard against operators created by a different xnn_create_* call.
+  if (deconvolution_op->type != xnn_operator_type_deconvolution_q8) {
+    xnn_log_error("failed to setup Deconvolution (Q8) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+
+  const size_t num_threads = pthreadpool_get_threads_count(threadpool);
+  return setup_deconvolution2d(
+    deconvolution_op,
+    batch_size, input_height, input_width,
+    input, output,
+    0 /* log2(sizeof(uint8_t)) */,
+    0 /* log2(sizeof(uint8_t)) */,
+    sizeof(int32_t) /* sizeof(bias element) */,
+    0 /* log2(sizeof(uint8_t)) */,
+    &deconvolution_op->q8_gemm_params,
+    num_threads);
+}
+
+// Sets up a single-precision (F32) NHWC deconvolution operator for the given
+// input shape and buffers. float elements have log2(size) == 2. Fails with
+// xnn_status_invalid_parameter when the operator was created for a different
+// data type.
+enum xnn_status xnn_setup_deconvolution2d_nhwc_f32(
+    xnn_operator_t deconvolution_op,
+    size_t batch_size,
+    size_t input_height,
+    size_t input_width,
+    const float* input,
+    float* output,
+    pthreadpool_t threadpool)
+{
+  // Guard against operators created by a different xnn_create_* call.
+  if (deconvolution_op->type != xnn_operator_type_deconvolution_f32) {
+    xnn_log_error("failed to setup Deconvolution (F32) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+
+  const size_t num_threads = pthreadpool_get_threads_count(threadpool);
+  return setup_deconvolution2d(
+    deconvolution_op,
+    batch_size, input_height, input_width,
+    input, output,
+    2 /* log2(sizeof(float)) */,
+    2 /* log2(sizeof(float)) */,
+    sizeof(float) /* sizeof(bias element) */,
+    2 /* log2(sizeof(float)) */,
+    &deconvolution_op->f32_output_params,
+    num_threads);
+}
diff --git a/src/f16-gemm/4x8-neonfp16arith-ld64.c b/src/f16-gemm/4x8-neonfp16arith-ld64.c
new file mode 100644
index 0000000..53ff89e
--- /dev/null
+++ b/src/f16-gemm/4x8-neonfp16arith-ld64.c
@@ -0,0 +1,236 @@
+// Auto-generated file. Do not edit!
+// Template: src/f16-gemm/neonfp16arith-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// Computes an mr x 8 (mr <= 4) tile of the half-precision product C = A*B
+// with NEON FP16 FMA. The main loop consumes 4 columns of A per iteration
+// (64-bit A loads, hence "ld64"); a scalar loop handles the kc % 4
+// remainder. Accumulators are scaled by params->scale and clamped to
+// [params->min, params->max] before being stored.
+//
+// Fixes vs. the previous revision: mis-encoded "&params" ("¶ms") in the
+// scale/min/max loads, and f32 lane-extraction intrinsics (vget_low_f32 /
+// vget_high_f32) applied to float16x8_t vectors, which is a type error —
+// the f16 variants are required.
+void xnn_f16_gemm_ukernel_4x8__neonfp16arith_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const void* restrict a,
+    size_t a_stride,
+    const void* restrict w,
+    void* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const struct xnn_f16_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(__fp16) == 0);
+
+  // Per-row A/C pointers; rows beyond mr alias the previous row so their
+  // redundant computation stays in bounds.
+  const __fp16* a0 = a;
+  __fp16* c0 = c;
+  const __fp16* a1 = (const __fp16*) ((uintptr_t) a0 + a_stride);
+  __fp16* c1 = (__fp16*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const __fp16* a2 = (const __fp16*) ((uintptr_t) a1 + a_stride);
+  __fp16* c2 = (__fp16*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const __fp16* a3 = (const __fp16*) ((uintptr_t) a2 + a_stride);
+  __fp16* c3 = (__fp16*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    // Initialize all accumulator rows from the packed bias.
+    float16x8_t vacc0x01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+    float16x8_t vacc1x01234567 = vacc0x01234567;
+    float16x8_t vacc2x01234567 = vacc0x01234567;
+    float16x8_t vacc3x01234567 = vacc0x01234567;
+
+    // Main loop: 4 columns of A per iteration.
+    size_t k = kc;
+    while (k >= 4 * sizeof(__fp16)) {
+      const float16x4_t va0 = vld1_f16(a0); a0 += 4;
+      const float16x4_t va1 = vld1_f16(a1); a1 += 4;
+      const float16x4_t va2 = vld1_f16(a2); a2 += 4;
+      const float16x4_t va3 = vld1_f16(a3); a3 += 4;
+
+      const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+      #if defined(__aarch64__)
+        vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0);
+        vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c0, va1, 0);
+        vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c0, va2, 0);
+        vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c0, va3, 0);
+      #else
+        const float16x8_t va0c0 = vdupq_lane_f16(va0, 0);
+        const float16x8_t va1c0 = vdupq_lane_f16(va1, 0);
+        const float16x8_t va2c0 = vdupq_lane_f16(va2, 0);
+        const float16x8_t va3c0 = vdupq_lane_f16(va3, 0);
+
+        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0);
+        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c0, vb01234567c0);
+        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c0, vb01234567c0);
+        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c0, vb01234567c0);
+      #endif
+      const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+      #if defined(__aarch64__)
+        vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1);
+        vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c1, va1, 1);
+        vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c1, va2, 1);
+        vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c1, va3, 1);
+      #else
+        const float16x8_t va0c1 = vdupq_lane_f16(va0, 1);
+        const float16x8_t va1c1 = vdupq_lane_f16(va1, 1);
+        const float16x8_t va2c1 = vdupq_lane_f16(va2, 1);
+        const float16x8_t va3c1 = vdupq_lane_f16(va3, 1);
+
+        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1);
+        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c1, vb01234567c1);
+        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c1, vb01234567c1);
+        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c1, vb01234567c1);
+      #endif
+      const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+      #if defined(__aarch64__)
+        vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2);
+        vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c2, va1, 2);
+        vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c2, va2, 2);
+        vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c2, va3, 2);
+      #else
+        const float16x8_t va0c2 = vdupq_lane_f16(va0, 2);
+        const float16x8_t va1c2 = vdupq_lane_f16(va1, 2);
+        const float16x8_t va2c2 = vdupq_lane_f16(va2, 2);
+        const float16x8_t va3c2 = vdupq_lane_f16(va3, 2);
+
+        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2);
+        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c2, vb01234567c2);
+        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c2, vb01234567c2);
+        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c2, vb01234567c2);
+      #endif
+      const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+      #if defined(__aarch64__)
+        vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3);
+        vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c3, va1, 3);
+        vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c3, va2, 3);
+        vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c3, va3, 3);
+      #else
+        const float16x8_t va0c3 = vdupq_lane_f16(va0, 3);
+        const float16x8_t va1c3 = vdupq_lane_f16(va1, 3);
+        const float16x8_t va2c3 = vdupq_lane_f16(va2, 3);
+        const float16x8_t va3c3 = vdupq_lane_f16(va3, 3);
+
+        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3);
+        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c3, vb01234567c3);
+        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c3, vb01234567c3);
+        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c3, vb01234567c3);
+      #endif
+
+      k -= 4 * sizeof(__fp16);
+    }
+    // Remainder loop: one column of A per iteration.
+    if XNN_UNLIKELY(k != 0) {
+      do {
+        const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1;
+        const float16x8_t va1 = vld1q_dup_f16(a1); a1 += 1;
+        const float16x8_t va2 = vld1q_dup_f16(a2); a2 += 1;
+        const float16x8_t va3 = vld1q_dup_f16(a3); a3 += 1;
+
+        const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567);
+        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1, vb01234567);
+        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2, vb01234567);
+        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3, vb01234567);
+
+        k -= sizeof(__fp16);
+      } while (k != 0);
+    }
+
+    // Scale, then clamp to [params->min, params->max].
+    const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+    vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale);
+    vacc1x01234567 = vmulq_f16(vacc1x01234567, vscale);
+    vacc2x01234567 = vmulq_f16(vacc2x01234567, vscale);
+    vacc3x01234567 = vmulq_f16(vacc3x01234567, vscale);
+
+    const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+    vacc0x01234567 = vminq_f16(vacc0x01234567, vmax);
+    vacc1x01234567 = vminq_f16(vacc1x01234567, vmax);
+    vacc2x01234567 = vminq_f16(vacc2x01234567, vmax);
+    vacc3x01234567 = vminq_f16(vacc3x01234567, vmax);
+
+    const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+    vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin);
+    vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin);
+    vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin);
+    vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Full 8-column store; rewind A pointers for the next NC panel.
+      vst1q_f16(c0, vacc0x01234567);
+      c0 = (__fp16*) ((uintptr_t) c0 + cn_stride);
+      vst1q_f16(c1, vacc1x01234567);
+      c1 = (__fp16*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f16(c2, vacc2x01234567);
+      c2 = (__fp16*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f16(c3, vacc3x01234567);
+      c3 = (__fp16*) ((uintptr_t) c3 + cn_stride);
+
+      a0 = (const __fp16*) ((uintptr_t) a0 - kc);
+      a1 = (const __fp16*) ((uintptr_t) a1 - kc);
+      a2 = (const __fp16*) ((uintptr_t) a2 - kc);
+      a3 = (const __fp16*) ((uintptr_t) a3 - kc);
+
+      nc -= 8;
+    } else {
+      // Partial store of the final 1-7 columns: 4, then 2, then 1.
+      // f16 lane extraction (not _f32): the accumulators are float16x8_t.
+      float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567);
+      float16x4_t vacc1x0123 = vget_low_f16(vacc1x01234567);
+      float16x4_t vacc2x0123 = vget_low_f16(vacc2x01234567);
+      float16x4_t vacc3x0123 = vget_low_f16(vacc3x01234567);
+      if (nc & 4) {
+        vst1_f16(c0, vacc0x0123); c0 += 4;
+        vst1_f16(c1, vacc1x0123); c1 += 4;
+        vst1_f16(c2, vacc2x0123); c2 += 4;
+        vst1_f16(c3, vacc3x0123); c3 += 4;
+
+        vacc0x0123 = vget_high_f16(vacc0x01234567);
+        vacc1x0123 = vget_high_f16(vacc1x01234567);
+        vacc2x0123 = vget_high_f16(vacc2x01234567);
+        vacc3x0123 = vget_high_f16(vacc3x01234567);
+      }
+      if (nc & 2) {
+        // Store two __fp16 values as one (possibly unaligned) 32-bit lane.
+        vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2;
+        vst1_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpret_u32_f16(vacc1x0123), 0); c1 += 2;
+        vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_f16(vacc2x0123), 0); c2 += 2;
+        vst1_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpret_u32_f16(vacc3x0123), 0); c3 += 2;
+
+        vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2);
+        vacc1x0123 = vext_f16(vacc1x0123, vacc1x0123, 2);
+        vacc2x0123 = vext_f16(vacc2x0123, vacc2x0123, 2);
+        vacc3x0123 = vext_f16(vacc3x0123, vacc3x0123, 2);
+      }
+      if (nc & 1) {
+        vst1_lane_f16(c0, vacc0x0123, 0);
+        vst1_lane_f16(c1, vacc1x0123, 0);
+        vst1_lane_f16(c2, vacc2x0123, 0);
+        vst1_lane_f16(c3, vacc3x0123, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f16-gemm/6x8-neonfp16arith-ld64.c b/src/f16-gemm/6x8-neonfp16arith-ld64.c
new file mode 100644
index 0000000..43aa99c
--- /dev/null
+++ b/src/f16-gemm/6x8-neonfp16arith-ld64.c
@@ -0,0 +1,304 @@
+// Auto-generated file. Do not edit!
+// Template: src/f16-gemm/neonfp16arith-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// Computes an mr x 8 (mr <= 6) tile of the half-precision product C = A*B
+// with NEON FP16 FMA. The main loop consumes 4 columns of A per iteration
+// (64-bit A loads, hence "ld64"); a scalar loop handles the kc % 4
+// remainder. Accumulators are scaled by params->scale and clamped to
+// [params->min, params->max] before being stored.
+//
+// Fixes vs. the previous revision: mis-encoded "&params" ("¶ms") in the
+// scale/min/max loads, and f32 lane-extraction intrinsics (vget_low_f32 /
+// vget_high_f32) applied to float16x8_t vectors, which is a type error —
+// the f16 variants are required.
+void xnn_f16_gemm_ukernel_6x8__neonfp16arith_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const void* restrict a,
+    size_t a_stride,
+    const void* restrict w,
+    void* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const struct xnn_f16_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 6);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(__fp16) == 0);
+
+  // Per-row A/C pointers; rows beyond mr alias the previous row so their
+  // redundant computation stays in bounds.
+  const __fp16* a0 = a;
+  __fp16* c0 = c;
+  const __fp16* a1 = (const __fp16*) ((uintptr_t) a0 + a_stride);
+  __fp16* c1 = (__fp16*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const __fp16* a2 = (const __fp16*) ((uintptr_t) a1 + a_stride);
+  __fp16* c2 = (__fp16*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const __fp16* a3 = (const __fp16*) ((uintptr_t) a2 + a_stride);
+  __fp16* c3 = (__fp16*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+  const __fp16* a4 = (const __fp16*) ((uintptr_t) a3 + a_stride);
+  __fp16* c4 = (__fp16*) ((uintptr_t) c3 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 4) {
+    a4 = a3;
+    c4 = c3;
+  }
+  const __fp16* a5 = (const __fp16*) ((uintptr_t) a4 + a_stride);
+  __fp16* c5 = (__fp16*) ((uintptr_t) c4 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 6) {
+    a5 = a4;
+    c5 = c4;
+  }
+
+  do {
+    // Initialize all accumulator rows from the packed bias.
+    float16x8_t vacc0x01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+    float16x8_t vacc1x01234567 = vacc0x01234567;
+    float16x8_t vacc2x01234567 = vacc0x01234567;
+    float16x8_t vacc3x01234567 = vacc0x01234567;
+    float16x8_t vacc4x01234567 = vacc0x01234567;
+    float16x8_t vacc5x01234567 = vacc0x01234567;
+
+    // Main loop: 4 columns of A per iteration.
+    size_t k = kc;
+    while (k >= 4 * sizeof(__fp16)) {
+      const float16x4_t va0 = vld1_f16(a0); a0 += 4;
+      const float16x4_t va1 = vld1_f16(a1); a1 += 4;
+      const float16x4_t va2 = vld1_f16(a2); a2 += 4;
+      const float16x4_t va3 = vld1_f16(a3); a3 += 4;
+      const float16x4_t va4 = vld1_f16(a4); a4 += 4;
+      const float16x4_t va5 = vld1_f16(a5); a5 += 4;
+
+      const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+      #if defined(__aarch64__)
+        vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0);
+        vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c0, va1, 0);
+        vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c0, va2, 0);
+        vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c0, va3, 0);
+        vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c0, va4, 0);
+        vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c0, va5, 0);
+      #else
+        const float16x8_t va0c0 = vdupq_lane_f16(va0, 0);
+        const float16x8_t va1c0 = vdupq_lane_f16(va1, 0);
+        const float16x8_t va2c0 = vdupq_lane_f16(va2, 0);
+        const float16x8_t va3c0 = vdupq_lane_f16(va3, 0);
+        const float16x8_t va4c0 = vdupq_lane_f16(va4, 0);
+        const float16x8_t va5c0 = vdupq_lane_f16(va5, 0);
+
+        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0);
+        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c0, vb01234567c0);
+        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c0, vb01234567c0);
+        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c0, vb01234567c0);
+        vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c0, vb01234567c0);
+        vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c0, vb01234567c0);
+      #endif
+      const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+      #if defined(__aarch64__)
+        vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1);
+        vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c1, va1, 1);
+        vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c1, va2, 1);
+        vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c1, va3, 1);
+        vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c1, va4, 1);
+        vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c1, va5, 1);
+      #else
+        const float16x8_t va0c1 = vdupq_lane_f16(va0, 1);
+        const float16x8_t va1c1 = vdupq_lane_f16(va1, 1);
+        const float16x8_t va2c1 = vdupq_lane_f16(va2, 1);
+        const float16x8_t va3c1 = vdupq_lane_f16(va3, 1);
+        const float16x8_t va4c1 = vdupq_lane_f16(va4, 1);
+        const float16x8_t va5c1 = vdupq_lane_f16(va5, 1);
+
+        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1);
+        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c1, vb01234567c1);
+        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c1, vb01234567c1);
+        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c1, vb01234567c1);
+        vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c1, vb01234567c1);
+        vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c1, vb01234567c1);
+      #endif
+      const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+      #if defined(__aarch64__)
+        vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2);
+        vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c2, va1, 2);
+        vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c2, va2, 2);
+        vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c2, va3, 2);
+        vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c2, va4, 2);
+        vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c2, va5, 2);
+      #else
+        const float16x8_t va0c2 = vdupq_lane_f16(va0, 2);
+        const float16x8_t va1c2 = vdupq_lane_f16(va1, 2);
+        const float16x8_t va2c2 = vdupq_lane_f16(va2, 2);
+        const float16x8_t va3c2 = vdupq_lane_f16(va3, 2);
+        const float16x8_t va4c2 = vdupq_lane_f16(va4, 2);
+        const float16x8_t va5c2 = vdupq_lane_f16(va5, 2);
+
+        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2);
+        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c2, vb01234567c2);
+        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c2, vb01234567c2);
+        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c2, vb01234567c2);
+        vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c2, vb01234567c2);
+        vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c2, vb01234567c2);
+      #endif
+      const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+      #if defined(__aarch64__)
+        vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3);
+        vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c3, va1, 3);
+        vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c3, va2, 3);
+        vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c3, va3, 3);
+        vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c3, va4, 3);
+        vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c3, va5, 3);
+      #else
+        const float16x8_t va0c3 = vdupq_lane_f16(va0, 3);
+        const float16x8_t va1c3 = vdupq_lane_f16(va1, 3);
+        const float16x8_t va2c3 = vdupq_lane_f16(va2, 3);
+        const float16x8_t va3c3 = vdupq_lane_f16(va3, 3);
+        const float16x8_t va4c3 = vdupq_lane_f16(va4, 3);
+        const float16x8_t va5c3 = vdupq_lane_f16(va5, 3);
+
+        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3);
+        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c3, vb01234567c3);
+        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c3, vb01234567c3);
+        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c3, vb01234567c3);
+        vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c3, vb01234567c3);
+        vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c3, vb01234567c3);
+      #endif
+
+      k -= 4 * sizeof(__fp16);
+    }
+    // Remainder loop: one column of A per iteration.
+    if XNN_UNLIKELY(k != 0) {
+      do {
+        const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1;
+        const float16x8_t va1 = vld1q_dup_f16(a1); a1 += 1;
+        const float16x8_t va2 = vld1q_dup_f16(a2); a2 += 1;
+        const float16x8_t va3 = vld1q_dup_f16(a3); a3 += 1;
+        const float16x8_t va4 = vld1q_dup_f16(a4); a4 += 1;
+        const float16x8_t va5 = vld1q_dup_f16(a5); a5 += 1;
+
+        const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567);
+        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1, vb01234567);
+        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2, vb01234567);
+        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3, vb01234567);
+        vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4, vb01234567);
+        vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5, vb01234567);
+
+        k -= sizeof(__fp16);
+      } while (k != 0);
+    }
+
+    // Scale, then clamp to [params->min, params->max].
+    const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+    vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale);
+    vacc1x01234567 = vmulq_f16(vacc1x01234567, vscale);
+    vacc2x01234567 = vmulq_f16(vacc2x01234567, vscale);
+    vacc3x01234567 = vmulq_f16(vacc3x01234567, vscale);
+    vacc4x01234567 = vmulq_f16(vacc4x01234567, vscale);
+    vacc5x01234567 = vmulq_f16(vacc5x01234567, vscale);
+
+    const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+    vacc0x01234567 = vminq_f16(vacc0x01234567, vmax);
+    vacc1x01234567 = vminq_f16(vacc1x01234567, vmax);
+    vacc2x01234567 = vminq_f16(vacc2x01234567, vmax);
+    vacc3x01234567 = vminq_f16(vacc3x01234567, vmax);
+    vacc4x01234567 = vminq_f16(vacc4x01234567, vmax);
+    vacc5x01234567 = vminq_f16(vacc5x01234567, vmax);
+
+    const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+    vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin);
+    vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin);
+    vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin);
+    vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin);
+    vacc4x01234567 = vmaxq_f16(vacc4x01234567, vmin);
+    vacc5x01234567 = vmaxq_f16(vacc5x01234567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Full 8-column store; rewind A pointers for the next NC panel.
+      vst1q_f16(c0, vacc0x01234567);
+      c0 = (__fp16*) ((uintptr_t) c0 + cn_stride);
+      vst1q_f16(c1, vacc1x01234567);
+      c1 = (__fp16*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f16(c2, vacc2x01234567);
+      c2 = (__fp16*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f16(c3, vacc3x01234567);
+      c3 = (__fp16*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f16(c4, vacc4x01234567);
+      c4 = (__fp16*) ((uintptr_t) c4 + cn_stride);
+      vst1q_f16(c5, vacc5x01234567);
+      c5 = (__fp16*) ((uintptr_t) c5 + cn_stride);
+
+      a0 = (const __fp16*) ((uintptr_t) a0 - kc);
+      a1 = (const __fp16*) ((uintptr_t) a1 - kc);
+      a2 = (const __fp16*) ((uintptr_t) a2 - kc);
+      a3 = (const __fp16*) ((uintptr_t) a3 - kc);
+      a4 = (const __fp16*) ((uintptr_t) a4 - kc);
+      a5 = (const __fp16*) ((uintptr_t) a5 - kc);
+
+      nc -= 8;
+    } else {
+      // Partial store of the final 1-7 columns: 4, then 2, then 1.
+      // f16 lane extraction (not _f32): the accumulators are float16x8_t.
+      float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567);
+      float16x4_t vacc1x0123 = vget_low_f16(vacc1x01234567);
+      float16x4_t vacc2x0123 = vget_low_f16(vacc2x01234567);
+      float16x4_t vacc3x0123 = vget_low_f16(vacc3x01234567);
+      float16x4_t vacc4x0123 = vget_low_f16(vacc4x01234567);
+      float16x4_t vacc5x0123 = vget_low_f16(vacc5x01234567);
+      if (nc & 4) {
+        vst1_f16(c0, vacc0x0123); c0 += 4;
+        vst1_f16(c1, vacc1x0123); c1 += 4;
+        vst1_f16(c2, vacc2x0123); c2 += 4;
+        vst1_f16(c3, vacc3x0123); c3 += 4;
+        vst1_f16(c4, vacc4x0123); c4 += 4;
+        vst1_f16(c5, vacc5x0123); c5 += 4;
+
+        vacc0x0123 = vget_high_f16(vacc0x01234567);
+        vacc1x0123 = vget_high_f16(vacc1x01234567);
+        vacc2x0123 = vget_high_f16(vacc2x01234567);
+        vacc3x0123 = vget_high_f16(vacc3x01234567);
+        vacc4x0123 = vget_high_f16(vacc4x01234567);
+        vacc5x0123 = vget_high_f16(vacc5x01234567);
+      }
+      if (nc & 2) {
+        // Store two __fp16 values as one (possibly unaligned) 32-bit lane.
+        vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2;
+        vst1_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpret_u32_f16(vacc1x0123), 0); c1 += 2;
+        vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_f16(vacc2x0123), 0); c2 += 2;
+        vst1_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpret_u32_f16(vacc3x0123), 0); c3 += 2;
+        vst1_lane_u32(__builtin_assume_aligned(c4, 1), vreinterpret_u32_f16(vacc4x0123), 0); c4 += 2;
+        vst1_lane_u32(__builtin_assume_aligned(c5, 1), vreinterpret_u32_f16(vacc5x0123), 0); c5 += 2;
+
+        vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2);
+        vacc1x0123 = vext_f16(vacc1x0123, vacc1x0123, 2);
+        vacc2x0123 = vext_f16(vacc2x0123, vacc2x0123, 2);
+        vacc3x0123 = vext_f16(vacc3x0123, vacc3x0123, 2);
+        vacc4x0123 = vext_f16(vacc4x0123, vacc4x0123, 2);
+        vacc5x0123 = vext_f16(vacc5x0123, vacc5x0123, 2);
+      }
+      if (nc & 1) {
+        vst1_lane_f16(c0, vacc0x0123, 0);
+        vst1_lane_f16(c1, vacc1x0123, 0);
+        vst1_lane_f16(c2, vacc2x0123, 0);
+        vst1_lane_f16(c3, vacc3x0123, 0);
+        vst1_lane_f16(c4, vacc4x0123, 0);
+        vst1_lane_f16(c5, vacc5x0123, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f16-gemm/8x8-neonfp16arith-ld64.c b/src/f16-gemm/8x8-neonfp16arith-ld64.c
new file mode 100644
index 0000000..5caa924
--- /dev/null
+++ b/src/f16-gemm/8x8-neonfp16arith-ld64.c
@@ -0,0 +1,372 @@
+// Auto-generated file. Do not edit!
+// Template: src/f16-gemm/neonfp16arith-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f16_gemm_ukernel_8x8__neonfp16arith_ld64(  // 8-row x 8-column fp16 GEMM microkernel; K unrolled by 4 halfwords (ld64)
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const void* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ void* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const struct xnn_f16_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 8);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(__fp16) == 0);
+
+ const __fp16* a0 = a;
+ __fp16* c0 = c;
+ const __fp16* a1 = (const __fp16*) ((uintptr_t) a0 + a_stride);
+ __fp16* c1 = (__fp16*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {  // rows past mr alias the previous row so their stores are harmless duplicates
+ a1 = a0;
+ c1 = c0;
+ }
+ const __fp16* a2 = (const __fp16*) ((uintptr_t) a1 + a_stride);
+ __fp16* c2 = (__fp16*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const __fp16* a3 = (const __fp16*) ((uintptr_t) a2 + a_stride);
+ __fp16* c3 = (__fp16*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+ const __fp16* a4 = (const __fp16*) ((uintptr_t) a3 + a_stride);
+ __fp16* c4 = (__fp16*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ a4 = a3;
+ c4 = c3;
+ }
+ const __fp16* a5 = (const __fp16*) ((uintptr_t) a4 + a_stride);
+ __fp16* c5 = (__fp16*) ((uintptr_t) c4 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 6) {
+ a5 = a4;
+ c5 = c4;
+ }
+ const __fp16* a6 = (const __fp16*) ((uintptr_t) a5 + a_stride);
+ __fp16* c6 = (__fp16*) ((uintptr_t) c5 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 6) {
+ a6 = a5;
+ c6 = c5;
+ }
+ const __fp16* a7 = (const __fp16*) ((uintptr_t) a6 + a_stride);
+ __fp16* c7 = (__fp16*) ((uintptr_t) c6 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 8) {
+ a7 = a6;
+ c7 = c6;
+ }
+
+ do {  // once per block of up to 8 output columns; accumulators start from the packed bias in w
+ float16x8_t vacc0x01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+ float16x8_t vacc1x01234567 = vacc0x01234567;
+ float16x8_t vacc2x01234567 = vacc0x01234567;
+ float16x8_t vacc3x01234567 = vacc0x01234567;
+ float16x8_t vacc4x01234567 = vacc0x01234567;
+ float16x8_t vacc5x01234567 = vacc0x01234567;
+ float16x8_t vacc6x01234567 = vacc0x01234567;
+ float16x8_t vacc7x01234567 = vacc0x01234567;
+
+ size_t k = kc;
+ while (k >= 4 * sizeof(__fp16)) {  // main loop: 4 elements of K per iteration
+ const float16x4_t va0 = vld1_f16(a0); a0 += 4;
+ const float16x4_t va1 = vld1_f16(a1); a1 += 4;
+ const float16x4_t va2 = vld1_f16(a2); a2 += 4;
+ const float16x4_t va3 = vld1_f16(a3); a3 += 4;
+ const float16x4_t va4 = vld1_f16(a4); a4 += 4;
+ const float16x4_t va5 = vld1_f16(a5); a5 += 4;
+ const float16x4_t va6 = vld1_f16(a6); a6 += 4;
+ const float16x4_t va7 = vld1_f16(a7); a7 += 4;
+
+ const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+ #if defined(__aarch64__)
+ vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0);
+ vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c0, va1, 0);
+ vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c0, va2, 0);
+ vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c0, va3, 0);
+ vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c0, va4, 0);
+ vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c0, va5, 0);
+ vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c0, va6, 0);
+ vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c0, va7, 0);
+ #else
+ const float16x8_t va0c0 = vdupq_lane_f16(va0, 0);
+ const float16x8_t va1c0 = vdupq_lane_f16(va1, 0);
+ const float16x8_t va2c0 = vdupq_lane_f16(va2, 0);
+ const float16x8_t va3c0 = vdupq_lane_f16(va3, 0);
+ const float16x8_t va4c0 = vdupq_lane_f16(va4, 0);
+ const float16x8_t va5c0 = vdupq_lane_f16(va5, 0);
+ const float16x8_t va6c0 = vdupq_lane_f16(va6, 0);
+ const float16x8_t va7c0 = vdupq_lane_f16(va7, 0);
+
+ vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0);
+ vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c0, vb01234567c0);
+ vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c0, vb01234567c0);
+ vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c0, vb01234567c0);
+ vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c0, vb01234567c0);
+ vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c0, vb01234567c0);
+ vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c0, vb01234567c0);
+ vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c0, vb01234567c0);
+ #endif
+ const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+ #if defined(__aarch64__)
+ vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1);
+ vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c1, va1, 1);
+ vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c1, va2, 1);
+ vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c1, va3, 1);
+ vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c1, va4, 1);
+ vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c1, va5, 1);
+ vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c1, va6, 1);
+ vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c1, va7, 1);
+ #else
+ const float16x8_t va0c1 = vdupq_lane_f16(va0, 1);
+ const float16x8_t va1c1 = vdupq_lane_f16(va1, 1);
+ const float16x8_t va2c1 = vdupq_lane_f16(va2, 1);
+ const float16x8_t va3c1 = vdupq_lane_f16(va3, 1);
+ const float16x8_t va4c1 = vdupq_lane_f16(va4, 1);
+ const float16x8_t va5c1 = vdupq_lane_f16(va5, 1);
+ const float16x8_t va6c1 = vdupq_lane_f16(va6, 1);
+ const float16x8_t va7c1 = vdupq_lane_f16(va7, 1);
+
+ vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1);
+ vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c1, vb01234567c1);
+ vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c1, vb01234567c1);
+ vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c1, vb01234567c1);
+ vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c1, vb01234567c1);
+ vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c1, vb01234567c1);
+ vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c1, vb01234567c1);
+ vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c1, vb01234567c1);
+ #endif
+ const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+ #if defined(__aarch64__)
+ vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2);
+ vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c2, va1, 2);
+ vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c2, va2, 2);
+ vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c2, va3, 2);
+ vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c2, va4, 2);
+ vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c2, va5, 2);
+ vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c2, va6, 2);
+ vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c2, va7, 2);
+ #else
+ const float16x8_t va0c2 = vdupq_lane_f16(va0, 2);
+ const float16x8_t va1c2 = vdupq_lane_f16(va1, 2);
+ const float16x8_t va2c2 = vdupq_lane_f16(va2, 2);
+ const float16x8_t va3c2 = vdupq_lane_f16(va3, 2);
+ const float16x8_t va4c2 = vdupq_lane_f16(va4, 2);
+ const float16x8_t va5c2 = vdupq_lane_f16(va5, 2);
+ const float16x8_t va6c2 = vdupq_lane_f16(va6, 2);
+ const float16x8_t va7c2 = vdupq_lane_f16(va7, 2);
+
+ vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2);
+ vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c2, vb01234567c2);
+ vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c2, vb01234567c2);
+ vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c2, vb01234567c2);
+ vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c2, vb01234567c2);
+ vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c2, vb01234567c2);
+ vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c2, vb01234567c2);
+ vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c2, vb01234567c2);
+ #endif
+ const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+ #if defined(__aarch64__)
+ vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3);
+ vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c3, va1, 3);
+ vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c3, va2, 3);
+ vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c3, va3, 3);
+ vacc4x01234567 = vfmaq_lane_f16(vacc4x01234567, vb01234567c3, va4, 3);
+ vacc5x01234567 = vfmaq_lane_f16(vacc5x01234567, vb01234567c3, va5, 3);
+ vacc6x01234567 = vfmaq_lane_f16(vacc6x01234567, vb01234567c3, va6, 3);
+ vacc7x01234567 = vfmaq_lane_f16(vacc7x01234567, vb01234567c3, va7, 3);
+ #else
+ const float16x8_t va0c3 = vdupq_lane_f16(va0, 3);
+ const float16x8_t va1c3 = vdupq_lane_f16(va1, 3);
+ const float16x8_t va2c3 = vdupq_lane_f16(va2, 3);
+ const float16x8_t va3c3 = vdupq_lane_f16(va3, 3);
+ const float16x8_t va4c3 = vdupq_lane_f16(va4, 3);
+ const float16x8_t va5c3 = vdupq_lane_f16(va5, 3);
+ const float16x8_t va6c3 = vdupq_lane_f16(va6, 3);
+ const float16x8_t va7c3 = vdupq_lane_f16(va7, 3);
+
+ vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3);
+ vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c3, vb01234567c3);
+ vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c3, vb01234567c3);
+ vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c3, vb01234567c3);
+ vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4c3, vb01234567c3);
+ vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5c3, vb01234567c3);
+ vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6c3, vb01234567c3);
+ vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7c3, vb01234567c3);
+ #endif
+
+ k -= 4 * sizeof(__fp16);
+ }
+ if XNN_UNLIKELY(k != 0) {  // remainder: one element of K per iteration
+ do {
+ const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1;
+ const float16x8_t va1 = vld1q_dup_f16(a1); a1 += 1;
+ const float16x8_t va2 = vld1q_dup_f16(a2); a2 += 1;
+ const float16x8_t va3 = vld1q_dup_f16(a3); a3 += 1;
+ const float16x8_t va4 = vld1q_dup_f16(a4); a4 += 1;
+ const float16x8_t va5 = vld1q_dup_f16(a5); a5 += 1;
+ const float16x8_t va6 = vld1q_dup_f16(a6); a6 += 1;
+ const float16x8_t va7 = vld1q_dup_f16(a7); a7 += 1;
+
+ const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+ vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567);
+ vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1, vb01234567);
+ vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2, vb01234567);
+ vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3, vb01234567);
+ vacc4x01234567 = vfmaq_f16(vacc4x01234567, va4, vb01234567);
+ vacc5x01234567 = vfmaq_f16(vacc5x01234567, va5, vb01234567);
+ vacc6x01234567 = vfmaq_f16(vacc6x01234567, va6, vb01234567);
+ vacc7x01234567 = vfmaq_f16(vacc7x01234567, va7, vb01234567);
+
+ k -= sizeof(__fp16);
+ } while (k != 0);
+ }
+
+ const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);  // post-accumulation scaling
+ vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale);
+ vacc1x01234567 = vmulq_f16(vacc1x01234567, vscale);
+ vacc2x01234567 = vmulq_f16(vacc2x01234567, vscale);
+ vacc3x01234567 = vmulq_f16(vacc3x01234567, vscale);
+ vacc4x01234567 = vmulq_f16(vacc4x01234567, vscale);
+ vacc5x01234567 = vmulq_f16(vacc5x01234567, vscale);
+ vacc6x01234567 = vmulq_f16(vacc6x01234567, vscale);
+ vacc7x01234567 = vmulq_f16(vacc7x01234567, vscale);
+
+ const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);  // clamp above: min(acc, max)
+ vacc0x01234567 = vminq_f16(vacc0x01234567, vmax);
+ vacc1x01234567 = vminq_f16(vacc1x01234567, vmax);
+ vacc2x01234567 = vminq_f16(vacc2x01234567, vmax);
+ vacc3x01234567 = vminq_f16(vacc3x01234567, vmax);
+ vacc4x01234567 = vminq_f16(vacc4x01234567, vmax);
+ vacc5x01234567 = vminq_f16(vacc5x01234567, vmax);
+ vacc6x01234567 = vminq_f16(vacc6x01234567, vmax);
+ vacc7x01234567 = vminq_f16(vacc7x01234567, vmax);
+
+ const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);  // clamp below: max(acc, min)
+ vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin);
+ vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin);
+ vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin);
+ vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin);
+ vacc4x01234567 = vmaxq_f16(vacc4x01234567, vmin);
+ vacc5x01234567 = vmaxq_f16(vacc5x01234567, vmin);
+ vacc6x01234567 = vmaxq_f16(vacc6x01234567, vmin);
+ vacc7x01234567 = vmaxq_f16(vacc7x01234567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {  // full 8-column store; rewind A pointers for the next column block
+ vst1q_f16(c0, vacc0x01234567);
+ c0 = (__fp16*) ((uintptr_t) c0 + cn_stride);
+ vst1q_f16(c1, vacc1x01234567);
+ c1 = (__fp16*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f16(c2, vacc2x01234567);
+ c2 = (__fp16*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f16(c3, vacc3x01234567);
+ c3 = (__fp16*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f16(c4, vacc4x01234567);
+ c4 = (__fp16*) ((uintptr_t) c4 + cn_stride);
+ vst1q_f16(c5, vacc5x01234567);
+ c5 = (__fp16*) ((uintptr_t) c5 + cn_stride);
+ vst1q_f16(c6, vacc6x01234567);
+ c6 = (__fp16*) ((uintptr_t) c6 + cn_stride);
+ vst1q_f16(c7, vacc7x01234567);
+ c7 = (__fp16*) ((uintptr_t) c7 + cn_stride);
+
+ a0 = (const __fp16*) ((uintptr_t) a0 - kc);
+ a1 = (const __fp16*) ((uintptr_t) a1 - kc);
+ a2 = (const __fp16*) ((uintptr_t) a2 - kc);
+ a3 = (const __fp16*) ((uintptr_t) a3 - kc);
+ a4 = (const __fp16*) ((uintptr_t) a4 - kc);
+ a5 = (const __fp16*) ((uintptr_t) a5 - kc);
+ a6 = (const __fp16*) ((uintptr_t) a6 - kc);
+ a7 = (const __fp16*) ((uintptr_t) a7 - kc);
+
+ nc -= 8;
+ } else {  // partial store of the final 1-7 columns
+ float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567);  // fixed: vget_low_f32 mismatches the float16x8_t operand
+ float16x4_t vacc1x0123 = vget_low_f16(vacc1x01234567);
+ float16x4_t vacc2x0123 = vget_low_f16(vacc2x01234567);
+ float16x4_t vacc3x0123 = vget_low_f16(vacc3x01234567);
+ float16x4_t vacc4x0123 = vget_low_f16(vacc4x01234567);
+ float16x4_t vacc5x0123 = vget_low_f16(vacc5x01234567);
+ float16x4_t vacc6x0123 = vget_low_f16(vacc6x01234567);
+ float16x4_t vacc7x0123 = vget_low_f16(vacc7x01234567);
+ if (nc & 4) {
+ vst1_f16(c0, vacc0x0123); c0 += 4;
+ vst1_f16(c1, vacc1x0123); c1 += 4;
+ vst1_f16(c2, vacc2x0123); c2 += 4;
+ vst1_f16(c3, vacc3x0123); c3 += 4;
+ vst1_f16(c4, vacc4x0123); c4 += 4;
+ vst1_f16(c5, vacc5x0123); c5 += 4;
+ vst1_f16(c6, vacc6x0123); c6 += 4;
+ vst1_f16(c7, vacc7x0123); c7 += 4;
+
+ vacc0x0123 = vget_high_f16(vacc0x01234567);  // fixed: vget_high_f32 -> vget_high_f16
+ vacc1x0123 = vget_high_f16(vacc1x01234567);
+ vacc2x0123 = vget_high_f16(vacc2x01234567);
+ vacc3x0123 = vget_high_f16(vacc3x01234567);
+ vacc4x0123 = vget_high_f16(vacc4x01234567);
+ vacc5x0123 = vget_high_f16(vacc5x01234567);
+ vacc6x0123 = vget_high_f16(vacc6x01234567);
+ vacc7x0123 = vget_high_f16(vacc7x01234567);
+ }
+ if (nc & 2) {
+ vst1_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2;
+ vst1_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpret_u32_f16(vacc1x0123), 0); c1 += 2;
+ vst1_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpret_u32_f16(vacc2x0123), 0); c2 += 2;
+ vst1_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpret_u32_f16(vacc3x0123), 0); c3 += 2;
+ vst1_lane_u32(__builtin_assume_aligned(c4, 1), vreinterpret_u32_f16(vacc4x0123), 0); c4 += 2;
+ vst1_lane_u32(__builtin_assume_aligned(c5, 1), vreinterpret_u32_f16(vacc5x0123), 0); c5 += 2;
+ vst1_lane_u32(__builtin_assume_aligned(c6, 1), vreinterpret_u32_f16(vacc6x0123), 0); c6 += 2;
+ vst1_lane_u32(__builtin_assume_aligned(c7, 1), vreinterpret_u32_f16(vacc7x0123), 0); c7 += 2;
+
+ vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2);
+ vacc1x0123 = vext_f16(vacc1x0123, vacc1x0123, 2);
+ vacc2x0123 = vext_f16(vacc2x0123, vacc2x0123, 2);
+ vacc3x0123 = vext_f16(vacc3x0123, vacc3x0123, 2);
+ vacc4x0123 = vext_f16(vacc4x0123, vacc4x0123, 2);
+ vacc5x0123 = vext_f16(vacc5x0123, vacc5x0123, 2);
+ vacc6x0123 = vext_f16(vacc6x0123, vacc6x0123, 2);
+ vacc7x0123 = vext_f16(vacc7x0123, vacc7x0123, 2);
+ }
+ if (nc & 1) {
+ vst1_lane_f16(c0, vacc0x0123, 0);
+ vst1_lane_f16(c1, vacc1x0123, 0);
+ vst1_lane_f16(c2, vacc2x0123, 0);
+ vst1_lane_f16(c3, vacc3x0123, 0);
+ vst1_lane_f16(c4, vacc4x0123, 0);
+ vst1_lane_f16(c5, vacc5x0123, 0);
+ vst1_lane_f16(c6, vacc6x0123, 0);
+ vst1_lane_f16(c7, vacc7x0123, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f16-gemm/neonfp16arith-ld64.c.in b/src/f16-gemm/neonfp16arith-ld64.c.in
new file mode 100644
index 0000000..3e8c4ea
--- /dev/null
+++ b/src/f16-gemm/neonfp16arith-ld64.c.in
@@ -0,0 +1,164 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 8 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f16_gemm_ukernel_${MR}x${NR}__neonfp16arith_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const void* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ void* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const struct xnn_f16_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= ${MR});
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(__fp16) == 0);
+
+ const __fp16* a0 = a;
+ __fp16* c0 = c;
+ $for M in range(1, MR):
+ const __fp16* a${M} = (const __fp16*) ((uintptr_t) a${M-1} + a_stride);
+ __fp16* c${M} = (__fp16*) ((uintptr_t) c${M-1} + cm_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(mr <= ${M}) {
+ a${M} = a${M-1};
+ c${M} = c${M-1};
+ }
+ $elif M + 1 == MR:
+ if XNN_UNPREDICTABLE(mr != ${M+1}) {
+ a${M} = a${M-1};
+ c${M} = c${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(mr < ${M+1}) {
+ a${M} = a${M-1};
+ c${M} = c${M-1};
+ }
+
+ do {
+ $for N in range(0, NR, 8):
+ float16x8_t vacc0x${ABC[N:N+8]} = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+ $for M in range(1, MR):
+ $for N in range(0, NR, 8):
+ float16x8_t vacc${M}x${ABC[N:N+8]} = vacc0x${ABC[N:N+8]};
+
+ size_t k = kc;
+ while (k >= 4 * sizeof(__fp16)) {
+ $for M in range(MR):
+ const float16x4_t va${M} = vld1_f16(a${M}); a${M} += 4;
+
+ $for L in range(4):
+ $for N in range(0, NR, 8):
+ const float16x8_t vb${ABC[N:N+8]}c${L} = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+ #if defined(__aarch64__)
+ $for M in range(MR):
+ vacc${M}x${ABC[N:N+8]} = vfmaq_lane_f16(vacc${M}x${ABC[N:N+8]}, vb${ABC[N:N+8]}c${L}, va${M}, ${L});
+ #else
+ $for M in range(MR):
+ const float16x8_t va${M}c${L} = vdupq_lane_f16(va${M}, ${L});
+
+ $for N in range(0, NR, 8):
+ $for M in range(MR):
+ vacc${M}x${ABC[N:N+8]} = vfmaq_f16(vacc${M}x${ABC[N:N+8]}, va${M}c${L}, vb${ABC[N:N+8]}c${L});
+ #endif
+
+ k -= 4 * sizeof(__fp16);
+ }
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ $for M in range(MR):
+ const float16x8_t va${M} = vld1q_dup_f16(a${M}); a${M} += 1;
+
+ $for N in range(0, NR, 8):
+ const float16x8_t vb${ABC[N:N+8]} = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
+
+ $for N in range(0, NR, 8):
+ $for M in range(MR):
+ vacc${M}x${ABC[N:N+8]} = vfmaq_f16(vacc${M}x${ABC[N:N+8]}, va${M}, vb${ABC[N:N+8]});
+
+ k -= sizeof(__fp16);
+ } while (k != 0);
+ }
+
+ const float16x8_t vscale = vld1q_dup_f16((const __fp16*) &params->scale);
+ $for N in range(0, NR, 8):
+ $for M in range(MR):
+ vacc${M}x${ABC[N:N+8]} = vmulq_f16(vacc${M}x${ABC[N:N+8]}, vscale);
+
+ const float16x8_t vmax = vld1q_dup_f16((const __fp16*) &params->max);
+ $for N in range(0, NR, 8):
+ $for M in range(MR):
+ vacc${M}x${ABC[N:N+8]} = vminq_f16(vacc${M}x${ABC[N:N+8]}, vmax);
+
+ const float16x8_t vmin = vld1q_dup_f16((const __fp16*) &params->min);
+ $for N in range(0, NR, 8):
+ $for M in range(MR):
+ vacc${M}x${ABC[N:N+8]} = vmaxq_f16(vacc${M}x${ABC[N:N+8]}, vmin);
+
+ if XNN_LIKELY(nc >= ${NR}) {
+ $for M in range(MR):
+ vst1q_f16(c${M}, vacc${M}x${ABC[0:8]});
+ $for N in range(8, NR, 8):
+ vst1q_f16(c${M} + ${N}, vacc${M}x${ABC[N:N+8]});
+ c${M} = (__fp16*) ((uintptr_t) c${M} + cn_stride);
+
+ $for M in range(MR):
+ a${M} = (const __fp16*) ((uintptr_t) a${M} - kc);
+
+ nc -= ${NR};
+ } else {
+ $for LOG2N in reversed(range(NR.bit_length())):
+ $if NR != 1 << LOG2N:
+ if (nc & ${1 << LOG2N}) {
+ $if LOG2N >= 3:
+ $for N in range(0, 1 << LOG2N, 8):
+ $for M in range(MR):
+ vst1q_f16(c${M}, vacc${M}x${ABC[N:N+8]}); c${M} += 8;
+
+ $for M in range(MR):
+ $for N in range(0, 1 << (LOG2N - 1), 4):
+ vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+ $elif LOG2N == 2:
+ $for M in range(MR):
+ vst1_f16(c${M}, vacc${M}x${ABC[0:4]}); c${M} += 4;
+
+ $for M in range(MR):
+ vacc${M}x${ABC[0:4]} = vget_high_f16(vacc${M}x${ABC[0:8]});
+ $elif LOG2N == 1:
+ $for M in range(MR):
+ vst1_lane_u32(__builtin_assume_aligned(c${M}, 1), vreinterpret_u32_f16(vacc${M}x${ABC[0:4]}), 0); c${M} += 2;
+
+ $for M in range(MR):
+ vacc${M}x${ABC[0:4]} = vext_f16(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]}, 2);
+ $elif LOG2N == 0:
+ $for M in range(MR):
+ vst1_lane_f16(c${M}, vacc${M}x${ABC[0:4]}, 0);
+ }
+ $if LOG2N == 3:
+ $for M in range(MR):
+ float16x4_t vacc${M}x${ABC[0:4]} = vget_low_f16(vacc${M}x${ABC[0:8]});
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-argmaxpool/mp9p8q-psimd.c b/src/f32-argmaxpool/mp9p8q-psimd.c
new file mode 100644
index 0000000..cb8a4f5
--- /dev/null
+++ b/src/f32-argmaxpool/mp9p8q-psimd.c
@@ -0,0 +1,377 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/argmaxpool.h>
+
+
+void xnn_f32_argmaxpool_ukernel_mp9p8q__psimd(  // multi-pass argmax pooling: first pass 9 taps, middle passes 8 taps, final pass up to 8 taps
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** input,
+ float* acc_buffer,
+ uint32_t* index_buffer,
+ float* output,
+ uint32_t* index,
+ size_t input_increment,
+ size_t output_increment,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(n != 0);
+ assert(ks != 0);
+ assert(ks > 9);
+ assert(kc != 0);
+
+ const psimd_f32 voutput_max = psimd_load_splat_f32(&params->scalar.max);
+ const psimd_f32 voutput_min = psimd_load_splat_f32(&params->scalar.min);
+ do {  // once per output pixel
+ {  // first pass: taps 0-8 initialize the running max/index in acc_buffer/index_buffer
+ float* ab = acc_buffer;
+ uint32_t* ib = index_buffer;
+
+ const float* i0 = *input++;
+ const float* i1 = *input++;
+ const float* i2 = *input++;
+ const float* i3 = *input++;
+ const float* i4 = *input++;
+ const float* i5 = *input++;
+ const float* i6 = *input++;
+ const float* i7 = *input++;
+ const float* i8 = *input++;
+
+ for (size_t k = 0; k < kc; k += 4) {
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ i0 += 4;
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ i1 += 4;
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ i2 += 4;
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ i3 += 4;
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ i4 += 4;
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ i5 += 4;
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ i6 += 4;
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ i7 += 4;
+ const psimd_f32 vi8 = psimd_load_f32(i8);
+ i8 += 4;
+
+ psimd_f32 vmax = vi0;
+ psimd_u32 vidx = psimd_splat_u32(0);
+
+ const psimd_s32 vm1 = vi1 > vmax;
+ vmax = psimd_blend_f32(vm1, vi1, vmax);
+ vidx = psimd_blend_u32(vm1, psimd_splat_u32(1), vidx);
+
+ const psimd_s32 vm2 = vi2 > vmax;
+ vmax = psimd_blend_f32(vm2, vi2, vmax);
+ vidx = psimd_blend_u32(vm2, psimd_splat_u32(2), vidx);
+
+ const psimd_s32 vm3 = vi3 > vmax;
+ vmax = psimd_blend_f32(vm3, vi3, vmax);
+ vidx = psimd_blend_u32(vm3, psimd_splat_u32(3), vidx);
+
+ const psimd_s32 vm4 = vi4 > vmax;
+ vmax = psimd_blend_f32(vm4, vi4, vmax);
+ vidx = psimd_blend_u32(vm4, psimd_splat_u32(4), vidx);
+
+ const psimd_s32 vm5 = vi5 > vmax;
+ vmax = psimd_blend_f32(vm5, vi5, vmax);
+ vidx = psimd_blend_u32(vm5, psimd_splat_u32(5), vidx);
+
+ const psimd_s32 vm6 = vi6 > vmax;
+ vmax = psimd_blend_f32(vm6, vi6, vmax);
+ vidx = psimd_blend_u32(vm6, psimd_splat_u32(6), vidx);
+
+ const psimd_s32 vm7 = vi7 > vmax;
+ vmax = psimd_blend_f32(vm7, vi7, vmax);
+ vidx = psimd_blend_u32(vm7, psimd_splat_u32(7), vidx);
+
+ const psimd_s32 vm8 = vi8 > vmax;
+ vmax = psimd_blend_f32(vm8, vi8, vmax);
+ vidx = psimd_blend_u32(vm8, psimd_splat_u32(8), vidx);
+
+ psimd_store_f32(ab, vmax);
+ ab += 4;
+ psimd_store_u32(ib, vidx);
+ ib += 4;
+ }
+ }
+ const psimd_u32 v1 = psimd_splat_u32(1);
+ const psimd_u32 v8 = psimd_splat_u32(8);
+ psimd_u32 vidx0 = psimd_add_u32(v1, v8);  // base tap index for the next pass = 9 (taps 0-8 consumed above)
+
+ size_t m = ks;
+ for (m -= 9; m > 8; m -= 8) {  // middle passes: merge 8 taps at a time into the accumulators
+ const float* i0 = *input++;
+ const float* i1 = *input++;
+ const float* i2 = *input++;
+ const float* i3 = *input++;
+ const float* i4 = *input++;
+ const float* i5 = *input++;
+ const float* i6 = *input++;
+ const float* i7 = *input++;
+
+ float* ab = acc_buffer;
+ uint32_t* ib = index_buffer;
+
+ for (size_t k = 0; k < kc; k += 4) {
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ i0 += 4;
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ i1 += 4;
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ i2 += 4;
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ i3 += 4;
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ i4 += 4;
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ i5 += 4;
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ i6 += 4;
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ i7 += 4;
+
+ psimd_f32 vmax = psimd_load_f32(ab);
+ psimd_u32 vidx = psimd_load_u32(ib);
+
+ const psimd_s32 vm0 = vi0 > vmax;
+ vmax = psimd_blend_f32(vm0, vi0, vmax);
+ vidx = psimd_blend_u32(vm0, vidx0, vidx);
+
+ const psimd_s32 vm1 = vi1 > vmax;
+ const psimd_u32 vidx1 = psimd_add_u32(vidx0, v1);
+ vmax = psimd_blend_f32(vm1, vi1, vmax);
+ vidx = psimd_blend_u32(vm1, vidx1, vidx);
+
+ const psimd_s32 vm2 = vi2 > vmax;
+ const psimd_u32 vidx2 = psimd_add_u32(vidx1, v1);
+ vmax = psimd_blend_f32(vm2, vi2, vmax);
+ vidx = psimd_blend_u32(vm2, vidx2, vidx);
+
+ const psimd_s32 vm3 = vi3 > vmax;
+ const psimd_u32 vidx3 = psimd_add_u32(vidx2, v1);
+ vmax = psimd_blend_f32(vm3, vi3, vmax);
+ vidx = psimd_blend_u32(vm3, vidx3, vidx);
+
+ const psimd_s32 vm4 = vi4 > vmax;
+ const psimd_u32 vidx4 = psimd_add_u32(vidx3, v1);
+ vmax = psimd_blend_f32(vm4, vi4, vmax);
+ vidx = psimd_blend_u32(vm4, vidx4, vidx);
+
+ const psimd_s32 vm5 = vi5 > vmax;
+ const psimd_u32 vidx5 = psimd_add_u32(vidx4, v1);
+ vmax = psimd_blend_f32(vm5, vi5, vmax);
+ vidx = psimd_blend_u32(vm5, vidx5, vidx);
+
+ const psimd_s32 vm6 = vi6 > vmax;
+ const psimd_u32 vidx6 = psimd_add_u32(vidx5, v1);
+ vmax = psimd_blend_f32(vm6, vi6, vmax);
+ vidx = psimd_blend_u32(vm6, vidx6, vidx);
+
+ const psimd_s32 vm7 = vi7 > vmax;
+ const psimd_u32 vidx7 = psimd_add_u32(vidx6, v1);
+ vmax = psimd_blend_f32(vm7, vi7, vmax);
+ vidx = psimd_blend_u32(vm7, vidx7, vidx);
+
+ psimd_store_f32(ab, vmax);
+ ab += 4;
+ psimd_store_u32(ib, vidx);
+ ib += 4;
+ }
+ vidx0 = psimd_add_u32(vidx0, v8);  // advance the base tap index past the 8 taps just merged
+ }
+
+ float* o = output;
+ uint32_t* i = index;
+ {  // final pass: merge the remaining m (1-8) taps, clamp, and write output + indices
+ const float* i0 = input[0];
+ const float* i1 = input[1];
+ const float* i2 = input[2];
+ const float* i3 = input[3];
+ const float* i4 = input[4];
+ const float* i5 = input[5];
+ const float* i6 = input[6];
+ const float* i7 = input[7];
+ input = (const float**) ((uintptr_t) input + input_increment);
+ if (m < 2) {  // unused tap pointers alias tap 0 (can never win: equal values keep the earlier index)
+ i1 = i0;
+ }
+ if (m <= 2) {
+ i2 = i0;
+ }
+ if (m < 4) {
+ i3 = i0;
+ }
+ if (m <= 4) {
+ i4 = i0;
+ }
+ if (m < 6) {
+ i5 = i0;
+ }
+ if (m <= 6) {
+ i6 = i0;
+ }
+ if (m != 8) {
+ i7 = i0;
+ }
+
+ size_t k = kc;
+ float* ab = acc_buffer;
+ uint32_t* ib = index_buffer;
+ for (; k >= 4; k -= 4) {
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ i0 += 4;
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ i1 += 4;
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ i2 += 4;
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ i3 += 4;
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ i4 += 4;
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ i5 += 4;
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ i6 += 4;
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ i7 += 4;
+
+ psimd_f32 vmax = psimd_load_f32(ab);
+ ab += 4;
+ psimd_u32 vidx = psimd_load_u32(ib);
+ ib += 4;
+
+ const psimd_s32 vm0 = vi0 > vmax;
+ vmax = psimd_blend_f32(vm0, vi0, vmax);
+ vidx = psimd_blend_u32(vm0, vidx0, vidx);
+
+ const psimd_s32 vm1 = vi1 > vmax;
+ const psimd_u32 vidx1 = psimd_add_u32(vidx0, v1);
+ vmax = psimd_blend_f32(vm1, vi1, vmax);
+ vidx = psimd_blend_u32(vm1, vidx1, vidx);
+
+ const psimd_s32 vm2 = vi2 > vmax;
+ const psimd_u32 vidx2 = psimd_add_u32(vidx1, v1);
+ vmax = psimd_blend_f32(vm2, vi2, vmax);
+ vidx = psimd_blend_u32(vm2, vidx2, vidx);
+
+ const psimd_s32 vm3 = vi3 > vmax;
+ const psimd_u32 vidx3 = psimd_add_u32(vidx2, v1);
+ vmax = psimd_blend_f32(vm3, vi3, vmax);
+ vidx = psimd_blend_u32(vm3, vidx3, vidx);
+
+ const psimd_s32 vm4 = vi4 > vmax;
+ const psimd_u32 vidx4 = psimd_add_u32(vidx3, v1);
+ vmax = psimd_blend_f32(vm4, vi4, vmax);
+ vidx = psimd_blend_u32(vm4, vidx4, vidx);
+
+ const psimd_s32 vm5 = vi5 > vmax;
+ const psimd_u32 vidx5 = psimd_add_u32(vidx4, v1);
+ vmax = psimd_blend_f32(vm5, vi5, vmax);
+ vidx = psimd_blend_u32(vm5, vidx5, vidx);
+
+ const psimd_s32 vm6 = vi6 > vmax;
+ const psimd_u32 vidx6 = psimd_add_u32(vidx5, v1);
+ vmax = psimd_blend_f32(vm6, vi6, vmax);
+ vidx = psimd_blend_u32(vm6, vidx6, vidx);
+
+ const psimd_s32 vm7 = vi7 > vmax;
+ const psimd_u32 vidx7 = psimd_add_u32(vidx6, v1);
+ vmax = psimd_blend_f32(vm7, vi7, vmax);
+ vidx = psimd_blend_u32(vm7, vidx7, vidx);
+
+ psimd_f32 vout = psimd_max_f32(psimd_min_f32(vmax, voutput_max), voutput_min);
+
+ psimd_store_f32(o, vout);
+ o += 4;
+ psimd_store_u32(i, vidx);
+ i += 4;
+ }
+ if (k != 0) {  // channel remainder (1-3): full-width compute, partial store below
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+
+ psimd_f32 vmax = psimd_load_f32(ab);
+ psimd_u32 vidx = psimd_load_u32(ib);
+
+ const psimd_s32 vm0 = vi0 > vmax;
+ vmax = psimd_blend_f32(vm0, vi0, vmax);
+ vidx = psimd_blend_u32(vm0, vidx0, vidx);
+
+ const psimd_s32 vm1 = vi1 > vmax;
+ const psimd_u32 vidx1 = psimd_add_u32(vidx0, v1);
+ vmax = psimd_blend_f32(vm1, vi1, vmax);
+ vidx = psimd_blend_u32(vm1, vidx1, vidx);
+
+ const psimd_s32 vm2 = vi2 > vmax;
+ const psimd_u32 vidx2 = psimd_add_u32(vidx1, v1);
+ vmax = psimd_blend_f32(vm2, vi2, vmax);
+ vidx = psimd_blend_u32(vm2, vidx2, vidx);
+
+ const psimd_s32 vm3 = vi3 > vmax;
+ const psimd_u32 vidx3 = psimd_add_u32(vidx2, v1);
+ vmax = psimd_blend_f32(vm3, vi3, vmax);
+ vidx = psimd_blend_u32(vm3, vidx3, vidx);
+
+ const psimd_s32 vm4 = vi4 > vmax;
+ const psimd_u32 vidx4 = psimd_add_u32(vidx3, v1);
+ vmax = psimd_blend_f32(vm4, vi4, vmax);
+ vidx = psimd_blend_u32(vm4, vidx4, vidx);
+
+ const psimd_s32 vm5 = vi5 > vmax;
+ const psimd_u32 vidx5 = psimd_add_u32(vidx4, v1);
+ vmax = psimd_blend_f32(vm5, vi5, vmax);
+ vidx = psimd_blend_u32(vm5, vidx5, vidx);
+
+ const psimd_s32 vm6 = vi6 > vmax;
+ const psimd_u32 vidx6 = psimd_add_u32(vidx5, v1);
+ vmax = psimd_blend_f32(vm6, vi6, vmax);
+ vidx = psimd_blend_u32(vm6, vidx6, vidx);
+
+ const psimd_s32 vm7 = vi7 > vmax;
+ const psimd_u32 vidx7 = psimd_add_u32(vidx6, v1);
+ vmax = psimd_blend_f32(vm7, vi7, vmax);
+ vidx = psimd_blend_u32(vm7, vidx7, vidx);
+
+ psimd_f32 vout = psimd_max_f32(psimd_min_f32(vmax, voutput_max), voutput_min);
+
+ if (k & 2) {  // store two lanes, then shift high half down for the final lane
+ psimd_store2_f32(o, vout);
+ psimd_store2_u32(i, vidx);
+ vout = psimd_concat_hi_f32(vout, vout);
+ vidx = psimd_concat_hi_u32(vidx, vidx);
+ o += 2;
+ i += 2;
+ }
+ if (k & 1) {
+ psimd_store1_f32(o, vout);
+ psimd_store1_u32(i, vidx);
+ o += 1;
+ i += 1;
+ }
+ }
+ }
+
+ output = (float*) ((uintptr_t) o + output_increment);  // index stays contiguous; only output is strided
+ index = (uint32_t*) i;
+ } while (--n != 0);
+}
diff --git a/src/f32-argmaxpool/mp9p8q-scalar.c b/src/f32-argmaxpool/mp9p8q-scalar.c
new file mode 100644
index 0000000..f8ae537
--- /dev/null
+++ b/src/f32-argmaxpool/mp9p8q-scalar.c
@@ -0,0 +1,283 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/argmaxpool.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_argmaxpool_ukernel_mp9p8q__scalar(  // multi-pass max pooling with argmax: 9 inputs first, then 8 per pass; scalar variant
+    size_t n,  // number of output pixels to process
+    size_t ks,  // pooling window size (total inputs per output pixel); must exceed 9
+    size_t kc,  // number of channels
+    const float** input,  // indirection buffer: pointers to input rows
+    float* acc_buffer,  // per-channel running-maximum scratch (kc elements)
+    uint32_t* index_buffer,  // per-channel running-argmax scratch (kc elements)
+    float* output,  // pooled maxima, clamped to [min, max]
+    uint32_t* index,  // argmax position within the pooling window
+    size_t input_increment,  // bytes to advance the input pointer array per output pixel
+    size_t output_increment,  // bytes to advance output past the kc values written per pixel
+    const union xnn_f32_output_params params[restrict static 1])  // clamping parameters
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks > 9);  // windows of 9 or fewer elements are handled by the unipass (up9) kernel
+  assert(kc != 0);
+
+  const float voutput_max = params->scalar.max;
+  const float voutput_min = params->scalar.min;
+  do {  // one iteration per output pixel
+    {  // first pass: inputs 0..8 initialize the per-channel max/argmax accumulators
+      float* ab = acc_buffer;
+      uint32_t* ib = index_buffer;
+
+      const float* i0 = *input++;
+      const float* i1 = *input++;
+      const float* i2 = *input++;
+      const float* i3 = *input++;
+      const float* i4 = *input++;
+      const float* i5 = *input++;
+      const float* i6 = *input++;
+      const float* i7 = *input++;
+      const float* i8 = *input++;
+
+      size_t k = kc;
+      do {  // one channel per iteration
+        const float vi0 = *i0++;
+        const float vi1 = *i1++;
+        const float vi2 = *i2++;
+        const float vi3 = *i3++;
+        const float vi4 = *i4++;
+        const float vi5 = *i5++;
+        const float vi6 = *i6++;
+        const float vi7 = *i7++;
+        const float vi8 = *i8++;
+
+        float vmax = vi0;  // running maximum, seeded with input 0
+        uint32_t vidx = 0;  // window index of the running maximum
+
+        if (vi1 > vmax) {
+          vmax = vi1;
+          vidx = 1;
+        }
+
+        if (vi2 > vmax) {
+          vmax = vi2;
+          vidx = 2;
+        }
+
+        if (vi3 > vmax) {
+          vmax = vi3;
+          vidx = 3;
+        }
+
+        if (vi4 > vmax) {
+          vmax = vi4;
+          vidx = 4;
+        }
+
+        if (vi5 > vmax) {
+          vmax = vi5;
+          vidx = 5;
+        }
+
+        if (vi6 > vmax) {
+          vmax = vi6;
+          vidx = 6;
+        }
+
+        if (vi7 > vmax) {
+          vmax = vi7;
+          vidx = 7;
+        }
+
+        if (vi8 > vmax) {
+          vmax = vi8;
+          vidx = 8;
+        }
+
+        *ab++ = vmax;
+        *ib++ = vidx;
+      } while (--k != 0);
+    }
+    uint32_t vidx0 = 9;  // window index of i0 in the upcoming pass
+    size_t m = ks;
+    for (m -= 9; m > 8; m -= 8) {  // middle passes: 8 inputs each; 1-8 inputs remain for the final pass
+      const float* i0 = *input++;
+      const float* i1 = *input++;
+      const float* i2 = *input++;
+      const float* i3 = *input++;
+      const float* i4 = *input++;
+      const float* i5 = *input++;
+      const float* i6 = *input++;
+      const float* i7 = *input++;
+
+      float* ab = acc_buffer;
+      uint32_t* ib = index_buffer;
+
+      size_t k = kc;
+      do {
+        const float vi0 = *i0++;
+        const float vi1 = *i1++;
+        const float vi2 = *i2++;
+        const float vi3 = *i3++;
+        const float vi4 = *i4++;
+        const float vi5 = *i5++;
+        const float vi6 = *i6++;
+        const float vi7 = *i7++;
+
+        float vmax = *ab;  // running maximum carried over from previous passes
+        uint32_t vidx = *ib;
+
+        if (vi0 > vmax) {
+          vmax = vi0;
+          vidx = vidx0;
+        }
+
+        if (vi1 > vmax) {
+          vmax = vi1;
+          vidx = vidx0 + 1;
+        }
+
+        if (vi2 > vmax) {
+          vmax = vi2;
+          vidx = vidx0 + 2;
+        }
+
+        if (vi3 > vmax) {
+          vmax = vi3;
+          vidx = vidx0 + 3;
+        }
+
+        if (vi4 > vmax) {
+          vmax = vi4;
+          vidx = vidx0 + 4;
+        }
+
+        if (vi5 > vmax) {
+          vmax = vi5;
+          vidx = vidx0 + 5;
+        }
+
+        if (vi6 > vmax) {
+          vmax = vi6;
+          vidx = vidx0 + 6;
+        }
+
+        if (vi7 > vmax) {
+          vmax = vi7;
+          vidx = vidx0 + 7;
+        }
+
+        *ab++ = vmax;
+        *ib++ = vidx;
+      } while (--k != 0);
+      vidx0 += 8;  // advance the window-index base for the next pass
+    }
+
+    float* o = output;
+    uint32_t* i = index;
+    {  // final pass: m inputs remain (1 <= m <= 8); clamp and write the results
+      const float* i0 = input[0];
+      const float* i1 = input[1];
+      const float* i2 = input[2];
+      const float* i3 = input[3];
+      const float* i4 = input[4];
+      const float* i5 = input[5];
+      const float* i6 = input[6];
+      const float* i7 = input[7];
+      input = (const float**) ((uintptr_t) input + input_increment);
+      if (m < 2) {  // unused tail pointers alias i0; strict > keeps them from updating the argmax
+        i1 = i0;
+      }
+      if (m <= 2) {
+        i2 = i0;
+      }
+      if (m < 4) {
+        i3 = i0;
+      }
+      if (m <= 4) {
+        i4 = i0;
+      }
+      if (m < 6) {
+        i5 = i0;
+      }
+      if (m <= 6) {
+        i6 = i0;
+      }
+      if (m != 8) {
+        i7 = i0;
+      }
+
+      size_t k = kc;
+      float* ab = acc_buffer;
+      uint32_t* ib = index_buffer;
+      do {
+        const float vi0 = *i0++;
+        const float vi1 = *i1++;
+        const float vi2 = *i2++;
+        const float vi3 = *i3++;
+        const float vi4 = *i4++;
+        const float vi5 = *i5++;
+        const float vi6 = *i6++;
+        const float vi7 = *i7++;
+
+        float vmax = *ab++;
+        uint32_t vidx = *ib++;
+
+        if (vi0 > vmax) {
+          vmax = vi0;
+          vidx = vidx0;
+        }
+
+        if (vi1 > vmax) {
+          vmax = vi1;
+          vidx = vidx0 + 1;
+        }
+
+        if (vi2 > vmax) {
+          vmax = vi2;
+          vidx = vidx0 + 2;
+        }
+
+        if (vi3 > vmax) {
+          vmax = vi3;
+          vidx = vidx0 + 3;
+        }
+
+        if (vi4 > vmax) {
+          vmax = vi4;
+          vidx = vidx0 + 4;
+        }
+
+        if (vi5 > vmax) {
+          vmax = vi5;
+          vidx = vidx0 + 5;
+        }
+
+        if (vi6 > vmax) {
+          vmax = vi6;
+          vidx = vidx0 + 6;
+        }
+
+        if (vi7 > vmax) {
+          vmax = vi7;
+          vidx = vidx0 + 7;
+        }
+
+        const float vout = math_max_f32(math_min_f32(vmax, voutput_max), voutput_min);  // clamp to [min, max]
+
+        *o++ = vout;
+        *i++ = vidx;
+      } while (--k != 0);
+    }
+
+    output = (float*) ((uintptr_t) o + output_increment);  // output_increment is in bytes
+    index = (uint32_t*) i;
+  } while (--n != 0);
+}
diff --git a/src/f32-argmaxpool/mp9p8q-sse2.c b/src/f32-argmaxpool/mp9p8q-sse2.c
new file mode 100644
index 0000000..7eddcd7
--- /dev/null
+++ b/src/f32-argmaxpool/mp9p8q-sse2.c
@@ -0,0 +1,377 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/argmaxpool.h>
+
+
+void xnn_f32_argmaxpool_ukernel_mp9p8q__sse2(  // multi-pass max pooling with argmax, SSE2 variant (4 channels per vector)
+    size_t n,  // number of output pixels to process
+    size_t ks,  // pooling window size; must exceed 9
+    size_t kc,  // number of channels
+    const float** input,  // indirection buffer: pointers to input rows
+    float* acc_buffer,  // per-channel running-maximum scratch; 16-byte aligned (see aligned stores below)
+    uint32_t* index_buffer,  // per-channel running-argmax scratch; 16-byte aligned
+    float* output,  // pooled maxima, clamped to [min, max]
+    uint32_t* index,  // argmax position within the pooling window
+    size_t input_increment,  // bytes to advance the input pointer array per output pixel
+    size_t output_increment,  // bytes to advance output past the kc values written per pixel
+    const union xnn_f32_output_params params[restrict static 1])  // clamping parameters
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks > 9);  // smaller windows use the unipass (up9) kernel
+  assert(kc != 0);
+
+  const __m128 voutput_max = _mm_load_ps(params->sse.max);
+  const __m128 voutput_min = _mm_load_ps(params->sse.min);
+  do {  // one iteration per output pixel
+    {  // first pass: inputs 0..8 initialize the accumulators
+      float* ab = acc_buffer;
+      uint32_t* ib = index_buffer;
+
+      const float* i0 = *input++;
+      const float* i1 = *input++;
+      const float* i2 = *input++;
+      const float* i3 = *input++;
+      const float* i4 = *input++;
+      const float* i5 = *input++;
+      const float* i6 = *input++;
+      const float* i7 = *input++;
+      const float* i8 = *input++;
+
+      for (size_t k = 0; k < kc; k += 4) {  // NOTE(review): full vectors are stored even for a partial tail — assumes acc/index buffers are padded to a multiple of 4 channels; confirm with caller
+        const __m128 vi0 = _mm_loadu_ps(i0);
+        i0 += 4;
+        const __m128 vi1 = _mm_loadu_ps(i1);
+        i1 += 4;
+        const __m128 vi2 = _mm_loadu_ps(i2);
+        i2 += 4;
+        const __m128 vi3 = _mm_loadu_ps(i3);
+        i3 += 4;
+        const __m128 vi4 = _mm_loadu_ps(i4);
+        i4 += 4;
+        const __m128 vi5 = _mm_loadu_ps(i5);
+        i5 += 4;
+        const __m128 vi6 = _mm_loadu_ps(i6);
+        i6 += 4;
+        const __m128 vi7 = _mm_loadu_ps(i7);
+        i7 += 4;
+        const __m128 vi8 = _mm_loadu_ps(i8);
+        i8 += 4;
+
+        __m128 vmax = vi0;  // per-lane running maximum, seeded with input 0
+        __m128i vidx = _mm_setzero_si128();  // per-lane window index of the running maximum
+
+        const __m128i vm1 = _mm_castps_si128(_mm_cmpgt_ps(vi1, vmax));  // SSE2 has no blend: select via (old & ~mask) | (new & mask)
+        vmax = _mm_max_ps(vi1, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm1, vidx), _mm_and_si128(vm1, _mm_set1_epi32(1)));
+
+        const __m128i vm2 = _mm_castps_si128(_mm_cmpgt_ps(vi2, vmax));
+        vmax = _mm_max_ps(vi2, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm2, vidx), _mm_and_si128(vm2, _mm_set1_epi32(2)));
+
+        const __m128i vm3 = _mm_castps_si128(_mm_cmpgt_ps(vi3, vmax));
+        vmax = _mm_max_ps(vi3, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm3, vidx), _mm_and_si128(vm3, _mm_set1_epi32(3)));
+
+        const __m128i vm4 = _mm_castps_si128(_mm_cmpgt_ps(vi4, vmax));
+        vmax = _mm_max_ps(vi4, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm4, vidx), _mm_and_si128(vm4, _mm_set1_epi32(4)));
+
+        const __m128i vm5 = _mm_castps_si128(_mm_cmpgt_ps(vi5, vmax));
+        vmax = _mm_max_ps(vi5, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm5, vidx), _mm_and_si128(vm5, _mm_set1_epi32(5)));
+
+        const __m128i vm6 = _mm_castps_si128(_mm_cmpgt_ps(vi6, vmax));
+        vmax = _mm_max_ps(vi6, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm6, vidx), _mm_and_si128(vm6, _mm_set1_epi32(6)));
+
+        const __m128i vm7 = _mm_castps_si128(_mm_cmpgt_ps(vi7, vmax));
+        vmax = _mm_max_ps(vi7, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm7, vidx), _mm_and_si128(vm7, _mm_set1_epi32(7)));
+
+        const __m128i vm8 = _mm_castps_si128(_mm_cmpgt_ps(vi8, vmax));
+        vmax = _mm_max_ps(vi8, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm8, vidx), _mm_and_si128(vm8, _mm_set1_epi32(8)));
+
+        _mm_store_ps(ab, vmax);  // aligned stores: scratch buffers must be 16-byte aligned
+        ab += 4;
+        _mm_store_si128((__m128i*) ib, vidx);
+        ib += 4;
+      }
+    }
+    const __m128i v1 = _mm_set1_epi32(1);
+    const __m128i v8 = _mm_set1_epi32(8);
+    __m128i vidx0 = _mm_add_epi32(v1, v8);  // window index of i0 in the next pass: starts at 9
+
+    size_t m = ks;
+    for (m -= 9; m > 8; m -= 8) {  // middle passes: 8 inputs each; 1-8 inputs remain for the final pass
+      const float* i0 = *input++;
+      const float* i1 = *input++;
+      const float* i2 = *input++;
+      const float* i3 = *input++;
+      const float* i4 = *input++;
+      const float* i5 = *input++;
+      const float* i6 = *input++;
+      const float* i7 = *input++;
+
+      float* ab = acc_buffer;
+      uint32_t* ib = index_buffer;
+
+      for (size_t k = 0; k < kc; k += 4) {
+        const __m128 vi0 = _mm_loadu_ps(i0);
+        i0 += 4;
+        const __m128 vi1 = _mm_loadu_ps(i1);
+        i1 += 4;
+        const __m128 vi2 = _mm_loadu_ps(i2);
+        i2 += 4;
+        const __m128 vi3 = _mm_loadu_ps(i3);
+        i3 += 4;
+        const __m128 vi4 = _mm_loadu_ps(i4);
+        i4 += 4;
+        const __m128 vi5 = _mm_loadu_ps(i5);
+        i5 += 4;
+        const __m128 vi6 = _mm_loadu_ps(i6);
+        i6 += 4;
+        const __m128 vi7 = _mm_loadu_ps(i7);
+        i7 += 4;
+
+        __m128 vmax = _mm_load_ps(ab);  // running maximum carried over from previous passes
+        __m128i vidx = _mm_load_si128((const __m128i*) ib);
+
+        const __m128i vm0 = _mm_castps_si128(_mm_cmpgt_ps(vi0, vmax));
+        vmax = _mm_max_ps(vi0, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm0, vidx), _mm_and_si128(vm0, vidx0));
+
+        const __m128i vm1 = _mm_castps_si128(_mm_cmpgt_ps(vi1, vmax));
+        const __m128i vidx1 = _mm_add_epi32(vidx0, v1);
+        vmax = _mm_max_ps(vi1, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm1, vidx), _mm_and_si128(vm1, vidx1));
+
+        const __m128i vm2 = _mm_castps_si128(_mm_cmpgt_ps(vi2, vmax));
+        const __m128i vidx2 = _mm_add_epi32(vidx1, v1);
+        vmax = _mm_max_ps(vi2, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm2, vidx), _mm_and_si128(vm2, vidx2));
+
+        const __m128i vm3 = _mm_castps_si128(_mm_cmpgt_ps(vi3, vmax));
+        const __m128i vidx3 = _mm_add_epi32(vidx2, v1);
+        vmax = _mm_max_ps(vi3, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm3, vidx), _mm_and_si128(vm3, vidx3));
+
+        const __m128i vm4 = _mm_castps_si128(_mm_cmpgt_ps(vi4, vmax));
+        const __m128i vidx4 = _mm_add_epi32(vidx3, v1);
+        vmax = _mm_max_ps(vi4, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm4, vidx), _mm_and_si128(vm4, vidx4));
+
+        const __m128i vm5 = _mm_castps_si128(_mm_cmpgt_ps(vi5, vmax));
+        const __m128i vidx5 = _mm_add_epi32(vidx4, v1);
+        vmax = _mm_max_ps(vi5, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm5, vidx), _mm_and_si128(vm5, vidx5));
+
+        const __m128i vm6 = _mm_castps_si128(_mm_cmpgt_ps(vi6, vmax));
+        const __m128i vidx6 = _mm_add_epi32(vidx5, v1);
+        vmax = _mm_max_ps(vi6, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm6, vidx), _mm_and_si128(vm6, vidx6));
+
+        const __m128i vm7 = _mm_castps_si128(_mm_cmpgt_ps(vi7, vmax));
+        const __m128i vidx7 = _mm_add_epi32(vidx6, v1);
+        vmax = _mm_max_ps(vi7, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm7, vidx), _mm_and_si128(vm7, vidx7));
+
+        _mm_store_ps(ab, vmax);
+        ab += 4;
+        _mm_store_si128((__m128i*) ib, vidx);
+        ib += 4;
+      }
+      vidx0 = _mm_add_epi32(vidx0, v8);  // advance the window-index base for the next pass
+    }
+
+    float* o = output;
+    uint32_t* i = index;
+    {  // final pass: m inputs remain (1 <= m <= 8); clamp and write the results
+      const float* i0 = input[0];
+      const float* i1 = input[1];
+      const float* i2 = input[2];
+      const float* i3 = input[3];
+      const float* i4 = input[4];
+      const float* i5 = input[5];
+      const float* i6 = input[6];
+      const float* i7 = input[7];
+      input = (const float**) ((uintptr_t) input + input_increment);
+      if (m < 2) {  // unused tail pointers alias i0; strict compare keeps them from updating the argmax
+        i1 = i0;
+      }
+      if (m <= 2) {
+        i2 = i0;
+      }
+      if (m < 4) {
+        i3 = i0;
+      }
+      if (m <= 4) {
+        i4 = i0;
+      }
+      if (m < 6) {
+        i5 = i0;
+      }
+      if (m <= 6) {
+        i6 = i0;
+      }
+      if (m != 8) {
+        i7 = i0;
+      }
+
+      size_t k = kc;
+      float* ab = acc_buffer;
+      uint32_t* ib = index_buffer;
+      for (; k >= 4; k -= 4) {  // main loop: full vectors of 4 channels
+        const __m128 vi0 = _mm_loadu_ps(i0);
+        i0 += 4;
+        const __m128 vi1 = _mm_loadu_ps(i1);
+        i1 += 4;
+        const __m128 vi2 = _mm_loadu_ps(i2);
+        i2 += 4;
+        const __m128 vi3 = _mm_loadu_ps(i3);
+        i3 += 4;
+        const __m128 vi4 = _mm_loadu_ps(i4);
+        i4 += 4;
+        const __m128 vi5 = _mm_loadu_ps(i5);
+        i5 += 4;
+        const __m128 vi6 = _mm_loadu_ps(i6);
+        i6 += 4;
+        const __m128 vi7 = _mm_loadu_ps(i7);
+        i7 += 4;
+
+        __m128 vmax = _mm_load_ps(ab);
+        ab += 4;
+        __m128i vidx = _mm_load_si128((const __m128i*) ib);
+        ib += 4;
+
+        const __m128i vm0 = _mm_castps_si128(_mm_cmpgt_ps(vi0, vmax));
+        vmax = _mm_max_ps(vi0, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm0, vidx), _mm_and_si128(vm0, vidx0));
+
+        const __m128i vm1 = _mm_castps_si128(_mm_cmpgt_ps(vi1, vmax));
+        const __m128i vidx1 = _mm_add_epi32(vidx0, v1);
+        vmax = _mm_max_ps(vi1, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm1, vidx), _mm_and_si128(vm1, vidx1));
+
+        const __m128i vm2 = _mm_castps_si128(_mm_cmpgt_ps(vi2, vmax));
+        const __m128i vidx2 = _mm_add_epi32(vidx1, v1);
+        vmax = _mm_max_ps(vi2, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm2, vidx), _mm_and_si128(vm2, vidx2));
+
+        const __m128i vm3 = _mm_castps_si128(_mm_cmpgt_ps(vi3, vmax));
+        const __m128i vidx3 = _mm_add_epi32(vidx2, v1);
+        vmax = _mm_max_ps(vi3, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm3, vidx), _mm_and_si128(vm3, vidx3));
+
+        const __m128i vm4 = _mm_castps_si128(_mm_cmpgt_ps(vi4, vmax));
+        const __m128i vidx4 = _mm_add_epi32(vidx3, v1);
+        vmax = _mm_max_ps(vi4, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm4, vidx), _mm_and_si128(vm4, vidx4));
+
+        const __m128i vm5 = _mm_castps_si128(_mm_cmpgt_ps(vi5, vmax));
+        const __m128i vidx5 = _mm_add_epi32(vidx4, v1);
+        vmax = _mm_max_ps(vi5, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm5, vidx), _mm_and_si128(vm5, vidx5));
+
+        const __m128i vm6 = _mm_castps_si128(_mm_cmpgt_ps(vi6, vmax));
+        const __m128i vidx6 = _mm_add_epi32(vidx5, v1);
+        vmax = _mm_max_ps(vi6, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm6, vidx), _mm_and_si128(vm6, vidx6));
+
+        const __m128i vm7 = _mm_castps_si128(_mm_cmpgt_ps(vi7, vmax));
+        const __m128i vidx7 = _mm_add_epi32(vidx6, v1);
+        vmax = _mm_max_ps(vi7, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm7, vidx), _mm_and_si128(vm7, vidx7));
+
+        __m128 vout = _mm_max_ps(_mm_min_ps(vmax, voutput_max), voutput_min);  // clamp to [min, max]
+
+        _mm_storeu_ps(o, vout);
+        o += 4;
+        _mm_storeu_si128((__m128i*) i, vidx);
+        i += 4;
+      }
+      if (k != 0) {  // tail of 1-3 channels; NOTE(review): full-width loads here read past the last channel — confirm rows are padded
+        const __m128 vi0 = _mm_loadu_ps(i0);
+        const __m128 vi1 = _mm_loadu_ps(i1);
+        const __m128 vi2 = _mm_loadu_ps(i2);
+        const __m128 vi3 = _mm_loadu_ps(i3);
+        const __m128 vi4 = _mm_loadu_ps(i4);
+        const __m128 vi5 = _mm_loadu_ps(i5);
+        const __m128 vi6 = _mm_loadu_ps(i6);
+        const __m128 vi7 = _mm_loadu_ps(i7);
+
+        __m128 vmax = _mm_load_ps(ab);
+        __m128i vidx = _mm_load_si128((const __m128i*) ib);
+
+        const __m128i vm0 = _mm_castps_si128(_mm_cmpgt_ps(vi0, vmax));
+        vmax = _mm_max_ps(vi0, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm0, vidx), _mm_and_si128(vm0, vidx0));
+
+        const __m128i vm1 = _mm_castps_si128(_mm_cmpgt_ps(vi1, vmax));
+        const __m128i vidx1 = _mm_add_epi32(vidx0, v1);
+        vmax = _mm_max_ps(vi1, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm1, vidx), _mm_and_si128(vm1, vidx1));
+
+        const __m128i vm2 = _mm_castps_si128(_mm_cmpgt_ps(vi2, vmax));
+        const __m128i vidx2 = _mm_add_epi32(vidx1, v1);
+        vmax = _mm_max_ps(vi2, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm2, vidx), _mm_and_si128(vm2, vidx2));
+
+        const __m128i vm3 = _mm_castps_si128(_mm_cmpgt_ps(vi3, vmax));
+        const __m128i vidx3 = _mm_add_epi32(vidx2, v1);
+        vmax = _mm_max_ps(vi3, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm3, vidx), _mm_and_si128(vm3, vidx3));
+
+        const __m128i vm4 = _mm_castps_si128(_mm_cmpgt_ps(vi4, vmax));
+        const __m128i vidx4 = _mm_add_epi32(vidx3, v1);
+        vmax = _mm_max_ps(vi4, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm4, vidx), _mm_and_si128(vm4, vidx4));
+
+        const __m128i vm5 = _mm_castps_si128(_mm_cmpgt_ps(vi5, vmax));
+        const __m128i vidx5 = _mm_add_epi32(vidx4, v1);
+        vmax = _mm_max_ps(vi5, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm5, vidx), _mm_and_si128(vm5, vidx5));
+
+        const __m128i vm6 = _mm_castps_si128(_mm_cmpgt_ps(vi6, vmax));
+        const __m128i vidx6 = _mm_add_epi32(vidx5, v1);
+        vmax = _mm_max_ps(vi6, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm6, vidx), _mm_and_si128(vm6, vidx6));
+
+        const __m128i vm7 = _mm_castps_si128(_mm_cmpgt_ps(vi7, vmax));
+        const __m128i vidx7 = _mm_add_epi32(vidx6, v1);
+        vmax = _mm_max_ps(vi7, vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm7, vidx), _mm_and_si128(vm7, vidx7));
+
+        __m128 vout = _mm_max_ps(_mm_min_ps(vmax, voutput_max), voutput_min);
+
+        if (k & 2) {  // store two channels, then shift the upper halves down
+          _mm_store_sd((double*) o, _mm_castps_pd(vout));
+          _mm_storel_epi64((__m128i*) i, vidx);
+          vout = _mm_movehl_ps(vout, vout);
+          vidx = _mm_unpackhi_epi64(vidx, vidx);
+          o += 2;
+          i += 2;
+        }
+        if (k & 1) {  // store the last channel
+          _mm_store_ss(o, vout);
+          *i = (uint32_t) _mm_cvtsi128_si32(vidx);
+          o += 1;
+          i += 1;
+        }
+      }
+    }
+
+    output = (float*) ((uintptr_t) o + output_increment);  // output_increment is in bytes
+    index = (uint32_t*) i;
+  } while (--n != 0);
+}
diff --git a/src/f32-argmaxpool/up4-psimd.c b/src/f32-argmaxpool/up4-psimd.c
new file mode 100644
index 0000000..5b55bfa
--- /dev/null
+++ b/src/f32-argmaxpool/up4-psimd.c
@@ -0,0 +1,126 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/argmaxpool.h>
+
+
+void xnn_f32_argmaxpool_ukernel_up4__psimd(  // unipass max pooling with argmax for windows of up to 4 elements, psimd variant
+    size_t n,  // number of output pixels to process
+    size_t ks,  // pooling window size, 1..4
+    size_t kc,  // number of channels
+    const float** input,  // indirection buffer: pointers to input rows
+    float* output,  // pooled maxima, clamped to [min, max]
+    uint32_t* index,  // argmax position within the pooling window
+    size_t input_increment,  // bytes to advance the input pointer array per output pixel
+    size_t output_increment,  // bytes to advance output past the kc values written per pixel
+    const union xnn_f32_output_params params[restrict static 1])  // clamping parameters
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks <= 4);  // larger windows use the up9 / mp9p8q kernels
+  assert(kc != 0);
+
+  const psimd_f32 voutput_max = psimd_load_splat_f32(&params->scalar.max);  // fixed mojibake: was "¶ms" (corrupted "&params")
+  const psimd_f32 voutput_min = psimd_load_splat_f32(&params->scalar.min);  // fixed mojibake: was "¶ms" (corrupted "&params")
+  do {
+    float* o = output;
+    uint32_t* i = index;
+
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    if (ks < 2) {  // unused tail pointers alias i0; strict compare keeps them from updating the argmax
+      i1 = i0;
+    }
+    if (ks <= 2) {
+      i2 = i0;
+    }
+    if (ks != 4) {
+      i3 = i0;
+    }
+
+    size_t k = kc;
+    for (; k >= 4; k -= 4) {  // main loop: 4 channels per vector
+      const psimd_f32 vi0 = psimd_load_f32(i0);
+      i0 += 4;
+      const psimd_f32 vi1 = psimd_load_f32(i1);
+      i1 += 4;
+      const psimd_f32 vi2 = psimd_load_f32(i2);
+      i2 += 4;
+      const psimd_f32 vi3 = psimd_load_f32(i3);
+      i3 += 4;
+
+      psimd_f32 vmax = vi0;  // per-lane running maximum, seeded with input 0
+      psimd_u32 vidx = psimd_splat_u32(0);  // per-lane window index of the running maximum
+
+      const psimd_s32 vm1 = vi1 > vmax;  // lane mask: all-ones where vi1 wins
+      vmax = psimd_blend_f32(vm1, vi1, vmax);
+      vidx = psimd_blend_u32(vm1, psimd_splat_u32(1), vidx);
+
+      const psimd_s32 vm2 = vi2 > vmax;
+      vmax = psimd_blend_f32(vm2, vi2, vmax);
+      vidx = psimd_blend_u32(vm2, psimd_splat_u32(2), vidx);
+
+      const psimd_s32 vm3 = vi3 > vmax;
+      vmax = psimd_blend_f32(vm3, vi3, vmax);
+      vidx = psimd_blend_u32(vm3, psimd_splat_u32(3), vidx);
+
+      const psimd_f32 vout = psimd_max_f32(psimd_min_f32(vmax, voutput_max), voutput_min);  // clamp to [min, max]
+
+      psimd_store_f32(o, vout);
+      o += 4;
+      psimd_store_u32(i, vidx);
+      i += 4;
+    }
+    if (k != 0) {  // tail of 1-3 channels; NOTE(review): full-width loads here read past the last channel — confirm rows are padded
+      const psimd_f32 vi0 = psimd_load_f32(i0);
+      const psimd_f32 vi1 = psimd_load_f32(i1);
+      const psimd_f32 vi2 = psimd_load_f32(i2);
+      const psimd_f32 vi3 = psimd_load_f32(i3);
+
+      psimd_f32 vmax = vi0;
+      psimd_u32 vidx = psimd_splat_u32(0);
+
+      const psimd_s32 vm1 = vi1 > vmax;
+      vmax = psimd_blend_f32(vm1, vi1, vmax);
+      vidx = psimd_blend_u32(vm1, psimd_splat_u32(1), vidx);
+
+      const psimd_s32 vm2 = vi2 > vmax;
+      vmax = psimd_blend_f32(vm2, vi2, vmax);
+      vidx = psimd_blend_u32(vm2, psimd_splat_u32(2), vidx);
+
+      const psimd_s32 vm3 = vi3 > vmax;
+      vmax = psimd_blend_f32(vm3, vi3, vmax);
+      vidx = psimd_blend_u32(vm3, psimd_splat_u32(3), vidx);
+
+      psimd_f32 vout = psimd_max_f32(psimd_min_f32(vmax, voutput_max), voutput_min);
+
+      if (k & 2) {  // store two channels, then shift the upper halves down
+        psimd_store2_f32(o, vout);
+        psimd_store2_u32(i, vidx);
+        vout = psimd_concat_hi_f32(vout, vout);
+        vidx = psimd_concat_hi_u32(vidx, vidx);
+        o += 2;
+        i += 2;
+      }
+      if (k & 1) {  // store the last channel
+        psimd_store1_f32(o, vout);
+        psimd_store1_u32(i, vidx);
+        o += 1;
+        i += 1;
+      }
+    }
+    input = (const float**) ((uintptr_t) input + input_increment);
+    output = (float*) ((uintptr_t) o + output_increment);  // output_increment is in bytes
+    index = (uint32_t*) i;
+  } while (--n != 0);
+}
diff --git a/src/f32-argmaxpool/up4-scalar.c b/src/f32-argmaxpool/up4-scalar.c
new file mode 100644
index 0000000..1d95c8f
--- /dev/null
+++ b/src/f32-argmaxpool/up4-scalar.c
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/argmaxpool.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_argmaxpool_ukernel_up4__scalar(  // unipass max pooling with argmax for windows of up to 4 elements, scalar variant
+    size_t n,  // number of output pixels to process
+    size_t ks,  // pooling window size, 1..4
+    size_t kc,  // number of channels
+    const float** input,  // indirection buffer: pointers to input rows
+    float* output,  // pooled maxima, clamped to [min, max]
+    uint32_t* index,  // argmax position within the pooling window
+    size_t input_increment,  // bytes to advance the input pointer array per output pixel
+    size_t output_increment,  // bytes to advance output past the kc values written per pixel
+    const union xnn_f32_output_params params[restrict static 1])  // clamping parameters
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks <= 4);  // larger windows use the up9 / mp9p8q kernels
+  assert(kc != 0);
+
+  const float voutput_max = params->scalar.max;
+  const float voutput_min = params->scalar.min;
+  do {  // one iteration per output pixel
+    float* o = output;
+    uint32_t* i = index;
+
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    if (ks < 2) {  // unused tail pointers alias i0; strict > keeps them from updating the argmax
+      i1 = i0;
+    }
+    if (ks <= 2) {
+      i2 = i0;
+    }
+    if (ks != 4) {
+      i3 = i0;
+    }
+
+    size_t k = kc;
+    do {  // one channel per iteration
+      const float vi0 = *i0++;
+      const float vi1 = *i1++;
+      const float vi2 = *i2++;
+      const float vi3 = *i3++;
+
+      float vmax = vi0;  // running maximum, seeded with input 0
+      uint32_t vidx = 0;  // window index of the running maximum
+
+      if (vi1 > vmax) {
+        vmax = vi1;
+        vidx = 1;
+      }
+
+      if (vi2 > vmax) {
+        vmax = vi2;
+        vidx = 2;
+      }
+
+      if (vi3 > vmax) {
+        vmax = vi3;
+        vidx = 3;
+      }
+
+      const float vout = math_max_f32(math_min_f32(vmax, voutput_max), voutput_min);  // clamp to [min, max]
+
+      *o++ = vout;
+      *i++ = vidx;
+    } while (--k != 0);
+    input = (const float**) ((uintptr_t) input + input_increment);
+    output = (float*) ((uintptr_t) o + output_increment);  // output_increment is in bytes
+    index = (uint32_t*) i;
+  } while (--n != 0);
+}
diff --git a/src/f32-argmaxpool/up4-sse2.c b/src/f32-argmaxpool/up4-sse2.c
new file mode 100644
index 0000000..64d1d12
--- /dev/null
+++ b/src/f32-argmaxpool/up4-sse2.c
@@ -0,0 +1,126 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/argmaxpool.h>
+
+
+void xnn_f32_argmaxpool_ukernel_up4__sse2(  // unipass max pooling with argmax for windows of up to 4 elements, SSE2 variant
+    size_t n,  // number of output pixels to process
+    size_t ks,  // pooling window size, 1..4
+    size_t kc,  // number of channels
+    const float** input,  // indirection buffer: pointers to input rows
+    float* output,  // pooled maxima, clamped to [min, max]
+    uint32_t* index,  // argmax position within the pooling window
+    size_t input_increment,  // bytes to advance the input pointer array per output pixel
+    size_t output_increment,  // bytes to advance output past the kc values written per pixel
+    const union xnn_f32_output_params params[restrict static 1])  // clamping parameters
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks <= 4);  // larger windows use the up9 / mp9p8q kernels
+  assert(kc != 0);
+
+  const __m128 voutput_max = _mm_load_ps(params->sse.max);
+  const __m128 voutput_min = _mm_load_ps(params->sse.min);
+  do {  // one iteration per output pixel
+    float* o = output;
+    uint32_t* i = index;
+
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    if (ks < 2) {  // unused tail pointers alias i0; strict compare keeps them from updating the argmax
+      i1 = i0;
+    }
+    if (ks <= 2) {
+      i2 = i0;
+    }
+    if (ks != 4) {
+      i3 = i0;
+    }
+
+    size_t k = kc;
+    for (; k >= 4; k -= 4) {  // main loop: 4 channels per vector
+      const __m128 vi0 = _mm_loadu_ps(i0);
+      i0 += 4;
+      const __m128 vi1 = _mm_loadu_ps(i1);
+      i1 += 4;
+      const __m128 vi2 = _mm_loadu_ps(i2);
+      i2 += 4;
+      const __m128 vi3 = _mm_loadu_ps(i3);
+      i3 += 4;
+
+      __m128 vmax = vi0;  // per-lane running maximum, seeded with input 0
+      __m128i vidx = _mm_setzero_si128();  // per-lane window index of the running maximum
+
+      const __m128i vm1 = _mm_castps_si128(_mm_cmpgt_ps(vi1, vmax));  // SSE2 has no blend: select via (old & ~mask) | (new & mask)
+      vmax = _mm_max_ps(vi1, vmax);
+      vidx = _mm_or_si128(_mm_andnot_si128(vm1, vidx), _mm_and_si128(vm1, _mm_set1_epi32(1)));
+
+      const __m128i vm2 = _mm_castps_si128(_mm_cmpgt_ps(vi2, vmax));
+      vmax = _mm_max_ps(vi2, vmax);
+      vidx = _mm_or_si128(_mm_andnot_si128(vm2, vidx), _mm_and_si128(vm2, _mm_set1_epi32(2)));
+
+      const __m128i vm3 = _mm_castps_si128(_mm_cmpgt_ps(vi3, vmax));
+      vmax = _mm_max_ps(vi3, vmax);
+      vidx = _mm_or_si128(_mm_andnot_si128(vm3, vidx), _mm_and_si128(vm3, _mm_set1_epi32(3)));
+
+      const __m128 vout = _mm_max_ps(_mm_min_ps(vmax, voutput_max), voutput_min);  // clamp to [min, max]
+
+      _mm_storeu_ps(o, vout);
+      o += 4;
+      _mm_storeu_si128((__m128i*) i, vidx);
+      i += 4;
+    }
+    if (k != 0) {  // tail of 1-3 channels; NOTE(review): full-width loads here read past the last channel — confirm rows are padded
+      const __m128 vi0 = _mm_loadu_ps(i0);
+      const __m128 vi1 = _mm_loadu_ps(i1);
+      const __m128 vi2 = _mm_loadu_ps(i2);
+      const __m128 vi3 = _mm_loadu_ps(i3);
+
+      __m128 vmax = vi0;
+      __m128i vidx = _mm_setzero_si128();
+
+      const __m128i vm1 = _mm_castps_si128(_mm_cmpgt_ps(vi1, vmax));
+      vmax = _mm_max_ps(vi1, vmax);
+      vidx = _mm_or_si128(_mm_andnot_si128(vm1, vidx), _mm_and_si128(vm1, _mm_set1_epi32(1)));
+
+      const __m128i vm2 = _mm_castps_si128(_mm_cmpgt_ps(vi2, vmax));
+      vmax = _mm_max_ps(vi2, vmax);
+      vidx = _mm_or_si128(_mm_andnot_si128(vm2, vidx), _mm_and_si128(vm2, _mm_set1_epi32(2)));
+
+      const __m128i vm3 = _mm_castps_si128(_mm_cmpgt_ps(vi3, vmax));
+      vmax = _mm_max_ps(vi3, vmax);
+      vidx = _mm_or_si128(_mm_andnot_si128(vm3, vidx), _mm_and_si128(vm3, _mm_set1_epi32(3)));
+
+      __m128 vout = _mm_max_ps(_mm_min_ps(vmax, voutput_max), voutput_min);
+
+      if (k & 2) {  // store two channels, then shift the upper halves down
+        _mm_store_sd((double*) o, _mm_castps_pd(vout));
+        _mm_storel_epi64((__m128i*) i, vidx);
+        vout = _mm_movehl_ps(vout, vout);
+        vidx = _mm_unpackhi_epi64(vidx, vidx);
+        o += 2;
+        i += 2;
+      }
+      if (k & 1) {  // store the last channel
+        _mm_store_ss(o, vout);
+        *i = (uint32_t) _mm_cvtsi128_si32(vidx);
+        o += 1;
+        i += 1;
+      }
+    }
+    input = (const float**) ((uintptr_t) input + input_increment);
+    output = (float*) ((uintptr_t) o + output_increment);  // output_increment is in bytes
+    index = (uint32_t*) i;
+  } while (--n != 0);
+}
diff --git a/src/f32-argmaxpool/up9-psimd.c b/src/f32-argmaxpool/up9-psimd.c
new file mode 100644
index 0000000..69ec655
--- /dev/null
+++ b/src/f32-argmaxpool/up9-psimd.c
@@ -0,0 +1,201 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/argmaxpool.h>
+
+
+void xnn_f32_argmaxpool_ukernel_up9__psimd(  // unipass max pooling with argmax for windows of up to 9 elements, psimd variant
+    size_t n,  // number of output pixels to process
+    size_t ks,  // pooling window size, 1..9
+    size_t kc,  // number of channels
+    const float** input,  // indirection buffer: pointers to input rows
+    float* output,  // pooled maxima, clamped to [min, max]
+    uint32_t* index,  // argmax position within the pooling window
+    size_t input_increment,  // bytes to advance the input pointer array per output pixel
+    size_t output_increment,  // bytes to advance output past the kc values written per pixel
+    const union xnn_f32_output_params params[restrict static 1])  // clamping parameters
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks <= 9);  // larger windows use the multi-pass (mp9p8q) kernel
+  assert(kc != 0);
+
+  const psimd_f32 voutput_max = psimd_load_splat_f32(&params->scalar.max);  // fixed mojibake: was "¶ms" (corrupted "&params")
+  const psimd_f32 voutput_min = psimd_load_splat_f32(&params->scalar.min);  // fixed mojibake: was "¶ms" (corrupted "&params")
+  do {
+    float* o = output;
+    uint32_t* i = index;
+
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    const float* i4 = input[4];
+    const float* i5 = input[5];
+    const float* i6 = input[6];
+    const float* i7 = input[7];
+    const float* i8 = input[8];
+    if (ks < 2) {  // unused tail pointers alias i0; strict compare keeps them from updating the argmax
+      i1 = i0;
+    }
+    if (ks <= 2) {
+      i2 = i0;
+    }
+    if (ks < 4) {
+      i3 = i0;
+    }
+    if (ks <= 4) {
+      i4 = i0;
+    }
+    if (ks < 6) {
+      i5 = i0;
+    }
+    if (ks <= 6) {
+      i6 = i0;
+    }
+    if (ks < 8) {
+      i7 = i0;
+    }
+    if (ks <= 8) {
+      i8 = i0;
+    }
+
+    size_t k = kc;
+    for (; k >= 4; k -= 4) {  // main loop: 4 channels per vector
+      const psimd_f32 vi0 = psimd_load_f32(i0);
+      i0 += 4;
+      const psimd_f32 vi1 = psimd_load_f32(i1);
+      i1 += 4;
+      const psimd_f32 vi2 = psimd_load_f32(i2);
+      i2 += 4;
+      const psimd_f32 vi3 = psimd_load_f32(i3);
+      i3 += 4;
+      const psimd_f32 vi4 = psimd_load_f32(i4);
+      i4 += 4;
+      const psimd_f32 vi5 = psimd_load_f32(i5);
+      i5 += 4;
+      const psimd_f32 vi6 = psimd_load_f32(i6);
+      i6 += 4;
+      const psimd_f32 vi7 = psimd_load_f32(i7);
+      i7 += 4;
+      const psimd_f32 vi8 = psimd_load_f32(i8);
+      i8 += 4;
+
+      psimd_f32 vmax = vi0;  // per-lane running maximum, seeded with input 0
+      psimd_u32 vidx = psimd_splat_u32(0);  // per-lane window index of the running maximum
+
+      const psimd_s32 vm1 = vi1 > vmax;  // lane mask: all-ones where vi1 wins
+      vmax = psimd_blend_f32(vm1, vi1, vmax);
+      vidx = psimd_blend_u32(vm1, psimd_splat_u32(1), vidx);
+
+      const psimd_s32 vm2 = vi2 > vmax;
+      vmax = psimd_blend_f32(vm2, vi2, vmax);
+      vidx = psimd_blend_u32(vm2, psimd_splat_u32(2), vidx);
+
+      const psimd_s32 vm3 = vi3 > vmax;
+      vmax = psimd_blend_f32(vm3, vi3, vmax);
+      vidx = psimd_blend_u32(vm3, psimd_splat_u32(3), vidx);
+
+      const psimd_s32 vm4 = vi4 > vmax;
+      vmax = psimd_blend_f32(vm4, vi4, vmax);
+      vidx = psimd_blend_u32(vm4, psimd_splat_u32(4), vidx);
+
+      const psimd_s32 vm5 = vi5 > vmax;
+      vmax = psimd_blend_f32(vm5, vi5, vmax);
+      vidx = psimd_blend_u32(vm5, psimd_splat_u32(5), vidx);
+
+      const psimd_s32 vm6 = vi6 > vmax;
+      vmax = psimd_blend_f32(vm6, vi6, vmax);
+      vidx = psimd_blend_u32(vm6, psimd_splat_u32(6), vidx);
+
+      const psimd_s32 vm7 = vi7 > vmax;
+      vmax = psimd_blend_f32(vm7, vi7, vmax);
+      vidx = psimd_blend_u32(vm7, psimd_splat_u32(7), vidx);
+
+      const psimd_s32 vm8 = vi8 > vmax;
+      vmax = psimd_blend_f32(vm8, vi8, vmax);
+      vidx = psimd_blend_u32(vm8, psimd_splat_u32(8), vidx);
+
+      const psimd_f32 vout = psimd_max_f32(psimd_min_f32(vmax, voutput_max), voutput_min);  // clamp to [min, max]
+
+      psimd_store_f32(o, vout);
+      o += 4;
+      psimd_store_u32(i, vidx);
+      i += 4;
+    }
+    if (k != 0) {  // tail of 1-3 channels; NOTE(review): full-width loads here read past the last channel — confirm rows are padded
+      const psimd_f32 vi0 = psimd_load_f32(i0);
+      const psimd_f32 vi1 = psimd_load_f32(i1);
+      const psimd_f32 vi2 = psimd_load_f32(i2);
+      const psimd_f32 vi3 = psimd_load_f32(i3);
+      const psimd_f32 vi4 = psimd_load_f32(i4);
+      const psimd_f32 vi5 = psimd_load_f32(i5);
+      const psimd_f32 vi6 = psimd_load_f32(i6);
+      const psimd_f32 vi7 = psimd_load_f32(i7);
+      const psimd_f32 vi8 = psimd_load_f32(i8);
+
+      psimd_f32 vmax = vi0;
+      psimd_u32 vidx = psimd_splat_u32(0);
+
+      const psimd_s32 vm1 = vi1 > vmax;
+      vmax = psimd_blend_f32(vm1, vi1, vmax);
+      vidx = psimd_blend_u32(vm1, psimd_splat_u32(1), vidx);
+
+      const psimd_s32 vm2 = vi2 > vmax;
+      vmax = psimd_blend_f32(vm2, vi2, vmax);
+      vidx = psimd_blend_u32(vm2, psimd_splat_u32(2), vidx);
+
+      const psimd_s32 vm3 = vi3 > vmax;
+      vmax = psimd_blend_f32(vm3, vi3, vmax);
+      vidx = psimd_blend_u32(vm3, psimd_splat_u32(3), vidx);
+
+      const psimd_s32 vm4 = vi4 > vmax;
+      vmax = psimd_blend_f32(vm4, vi4, vmax);
+      vidx = psimd_blend_u32(vm4, psimd_splat_u32(4), vidx);
+
+      const psimd_s32 vm5 = vi5 > vmax;
+      vmax = psimd_blend_f32(vm5, vi5, vmax);
+      vidx = psimd_blend_u32(vm5, psimd_splat_u32(5), vidx);
+
+      const psimd_s32 vm6 = vi6 > vmax;
+      vmax = psimd_blend_f32(vm6, vi6, vmax);
+      vidx = psimd_blend_u32(vm6, psimd_splat_u32(6), vidx);
+
+      const psimd_s32 vm7 = vi7 > vmax;
+      vmax = psimd_blend_f32(vm7, vi7, vmax);
+      vidx = psimd_blend_u32(vm7, psimd_splat_u32(7), vidx);
+
+      const psimd_s32 vm8 = vi8 > vmax;
+      vmax = psimd_blend_f32(vm8, vi8, vmax);
+      vidx = psimd_blend_u32(vm8, psimd_splat_u32(8), vidx);
+
+      psimd_f32 vout = psimd_max_f32(psimd_min_f32(vmax, voutput_max), voutput_min);
+
+      if (k & 2) {  // store two channels, then shift the upper halves down
+        psimd_store2_f32(o, vout);
+        psimd_store2_u32(i, vidx);
+        vout = psimd_concat_hi_f32(vout, vout);
+        vidx = psimd_concat_hi_u32(vidx, vidx);
+        o += 2;
+        i += 2;
+      }
+      if (k & 1) {  // store the last channel
+        psimd_store1_f32(o, vout);
+        psimd_store1_u32(i, vidx);
+        o += 1;
+        i += 1;
+      }
+    }
+    input = (const float**) ((uintptr_t) input + input_increment);
+    output = (float*) ((uintptr_t) o + output_increment);  // output_increment is in bytes
+    index = (uint32_t*) i;
+  } while (--n != 0);
+}
diff --git a/src/f32-argmaxpool/up9-scalar.c b/src/f32-argmaxpool/up9-scalar.c
new file mode 100644
index 0000000..0cb8a49
--- /dev/null
+++ b/src/f32-argmaxpool/up9-scalar.c
@@ -0,0 +1,134 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/argmaxpool.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_argmaxpool_ukernel_up9__scalar(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const float** input,
+    float* output,
+    uint32_t* index,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks <= 9);
+  assert(kc != 0);
+
+  // Single-pass argmax pooling over up to 9 taps per output pixel.
+  // The clamp bounds apply to the pooled value only; the index is the raw
+  // position of the (first) maximum within the window.
+  const float voutput_max = params->scalar.max;
+  const float voutput_min = params->scalar.min;
+  do {
+    float* o = output;
+    uint32_t* i = index;
+
+    // Gather the 9 tap pointers. Taps at positions >= ks alias tap 0, so the
+    // strict greater-than test below can never select them and the argmax is
+    // confined to the real window.
+    const float* tap[9];
+    for (size_t t = 0; t < 9; t++) {
+      tap[t] = input[t];
+    }
+    for (size_t t = ks; t < 9; t++) {
+      tap[t] = tap[0];
+    }
+
+    size_t k = kc;
+    do {
+      // Read one channel element from every tap, advancing each pointer.
+      float v[9];
+      for (size_t t = 0; t < 9; t++) {
+        v[t] = *tap[t]++;
+      }
+
+      // Running argmax: a later tap wins only if strictly greater, so ties
+      // keep the earliest index (matching the vector kernels).
+      float vmax = v[0];
+      uint32_t vidx = 0;
+      for (uint32_t t = 1; t < 9; t++) {
+        if (v[t] > vmax) {
+          vmax = v[t];
+          vidx = t;
+        }
+      }
+
+      *o++ = math_max_f32(math_min_f32(vmax, voutput_max), voutput_min);
+      *i++ = vidx;
+    } while (--k != 0);
+    input = (const float**) ((uintptr_t) input + input_increment);
+    output = (float*) ((uintptr_t) o + output_increment);
+    index = (uint32_t*) i;
+  } while (--n != 0);
+}
diff --git a/src/f32-argmaxpool/up9-sse2.c b/src/f32-argmaxpool/up9-sse2.c
new file mode 100644
index 0000000..f2df769
--- /dev/null
+++ b/src/f32-argmaxpool/up9-sse2.c
@@ -0,0 +1,201 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/argmaxpool.h>
+
+
+void xnn_f32_argmaxpool_ukernel_up9__sse2(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const float** input,
+    float* output,
+    uint32_t* index,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks <= 9);
+  assert(kc != 0);
+
+  // SSE2 argmax pooling: 4 channels per iteration, up to 9 taps per window.
+  const __m128 voutput_max = _mm_load_ps(params->sse.max);
+  const __m128 voutput_min = _mm_load_ps(params->sse.min);
+  do {
+    float* o = output;
+    uint32_t* i = index;
+
+    // Collect the 9 tap pointers; taps at positions >= ks alias tap 0 so the
+    // strict CMPGT below can never pick them.
+    const float* tap[9];
+    for (size_t t = 0; t < 9; t++) {
+      tap[t] = input[t];
+    }
+    for (size_t t = ks; t < 9; t++) {
+      tap[t] = tap[0];
+    }
+
+    size_t k = kc;
+    for (; k >= 4; k -= 4) {
+      __m128 vi[9];
+      for (size_t t = 0; t < 9; t++) {
+        vi[t] = _mm_loadu_ps(tap[t]);
+        tap[t] += 4;
+      }
+
+      // Lane-wise running argmax: where tap t is strictly greater, take its
+      // value (MAXPS) and overwrite the index lane with t via mask blending.
+      __m128 vmax = vi[0];
+      __m128i vidx = _mm_setzero_si128();
+      for (int t = 1; t < 9; t++) {
+        const __m128i vm = _mm_castps_si128(_mm_cmpgt_ps(vi[t], vmax));
+        vmax = _mm_max_ps(vi[t], vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm, vidx), _mm_and_si128(vm, _mm_set1_epi32(t)));
+      }
+
+      const __m128 vout = _mm_max_ps(_mm_min_ps(vmax, voutput_max), voutput_min);
+
+      _mm_storeu_ps(o, vout);
+      o += 4;
+      _mm_storeu_si128((__m128i*) i, vidx);
+      i += 4;
+    }
+    if (k != 0) {
+      // Remainder of 1-3 channels. Full 4-lane loads are issued here just as
+      // in the main loop (NOTE(review): assumes readable padding past the
+      // last channel - matches the original kernel's behavior); only k lanes
+      // are stored below.
+      __m128 vi[9];
+      for (size_t t = 0; t < 9; t++) {
+        vi[t] = _mm_loadu_ps(tap[t]);
+      }
+
+      __m128 vmax = vi[0];
+      __m128i vidx = _mm_setzero_si128();
+      for (int t = 1; t < 9; t++) {
+        const __m128i vm = _mm_castps_si128(_mm_cmpgt_ps(vi[t], vmax));
+        vmax = _mm_max_ps(vi[t], vmax);
+        vidx = _mm_or_si128(_mm_andnot_si128(vm, vidx), _mm_and_si128(vm, _mm_set1_epi32(t)));
+      }
+
+      __m128 vout = _mm_max_ps(_mm_min_ps(vmax, voutput_max), voutput_min);
+
+      if (k & 2) {
+        // Store the low two lanes, then shift the high lanes down.
+        _mm_store_sd((double*) o, _mm_castps_pd(vout));
+        _mm_storel_epi64((__m128i*) i, vidx);
+        vout = _mm_movehl_ps(vout, vout);
+        vidx = _mm_unpackhi_epi64(vidx, vidx);
+        o += 2;
+        i += 2;
+      }
+      if (k & 1) {
+        _mm_store_ss(o, vout);
+        *i = (uint32_t) _mm_cvtsi128_si32(vidx);
+        o += 1;
+        i += 1;
+      }
+    }
+    input = (const float**) ((uintptr_t) input + input_increment);
+    output = (float*) ((uintptr_t) o + output_increment);
+    index = (uint32_t*) i;
+  } while (--n != 0);
+}
diff --git a/src/f32-avgpool/mp9p8q-neon.c b/src/f32-avgpool/mp9p8q-neon.c
new file mode 100644
index 0000000..78d9456
--- /dev/null
+++ b/src/f32-avgpool/mp9p8q-neon.c
@@ -0,0 +1,206 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/avgpool.h>
+
+
+void xnn_f32_avgpool_ukernel_mp9p8q__neon(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const float** input,
+    const float* zero,
+    float* buffer,
+    float* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks > 9);  // multi-pass kernel; the single-pass (up9) variant covers ks <= 9
+  assert(kc != 0);
+
+  // Fix: restored the '&' address-of operators below - they had been
+  // mojibake-corrupted into '¶' (e.g. 'vld1q_dup_f32(¶ms->...)'), which
+  // does not compile.
+  const float32x4_t vmultiplier = vld1q_dup_f32(&params->scalar.multiplier);
+  const float32x4_t voutput_min = vld1q_dup_f32(&params->scalar.output_min);
+  const float32x4_t voutput_max = vld1q_dup_f32(&params->scalar.output_max);
+
+  do {
+    // Pass 1: sum the first 9 pooling taps into the per-channel buffer.
+    {
+      const float* i0 = *input++;
+      const float* i1 = *input++;
+      const float* i2 = *input++;
+      const float* i3 = *input++;
+      const float* i4 = *input++;
+      const float* i5 = *input++;
+      const float* i6 = *input++;
+      const float* i7 = *input++;
+      const float* i8 = *input++;
+
+      float* b = buffer;
+      // NOTE(review): full 4-lane loads/stores assume kc (and the buffer) are
+      // padded to a multiple of 4 - confirm against the operator setup code.
+      for (size_t k = 0; k < kc; k += 4) {
+        const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
+        const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
+        const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
+        const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
+        const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
+        const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
+        const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
+        const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;
+        const float32x4_t vi8 = vld1q_f32(i8); i8 += 4;
+
+        // Balanced addition tree (4 levels) for better ILP than a serial sum.
+        const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+        const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+        const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+        const float32x4_t vsum67 = vaddq_f32(vi6, vi7);
+        const float32x4_t vsum018 = vaddq_f32(vsum01, vi8);
+        const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
+        const float32x4_t vsum01678 = vaddq_f32(vsum018, vsum67);
+        const float32x4_t vsum = vaddq_f32(vsum2345, vsum01678);
+
+        vst1q_f32(b, vsum); b += 4;
+      }
+    }
+
+    // Middle passes: accumulate 8 more taps at a time into the buffer while
+    // more than 8 taps remain.
+    size_t m = ks;
+    for (m -= 9; m > 8; m -= 8) {
+      const float* i0 = *input++;
+      const float* i1 = *input++;
+      const float* i2 = *input++;
+      const float* i3 = *input++;
+      const float* i4 = *input++;
+      const float* i5 = *input++;
+      const float* i6 = *input++;
+      const float* i7 = *input++;
+
+      float* b = buffer;
+      for (size_t k = 0; k < kc; k += 4) {
+        const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
+        const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
+        const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
+        const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
+        const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
+        const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
+        const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
+        const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;
+        const float32x4_t vacc = vld1q_f32(b);
+
+        const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+        const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+        const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+        const float32x4_t vsum67 = vaddq_f32(vi6, vi7);
+        const float32x4_t vsum01a = vaddq_f32(vsum01, vacc);
+        const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
+        const float32x4_t vsum0167a = vaddq_f32(vsum01a, vsum67);
+        const float32x4_t vsum = vaddq_f32(vsum2345, vsum0167a);
+
+        vst1q_f32(b, vsum); b += 4;
+      }
+    }
+
+    // Final pass: accumulate the remaining 1-8 taps, then scale by the
+    // reciprocal of the pooling-window size, clamp, and store the result.
+    {
+      const float* i0 = input[0];
+      const float* i1 = input[1];
+      const float* i2 = input[2];
+      const float* i3 = input[3];
+      const float* i4 = input[4];
+      const float* i5 = input[5];
+      const float* i6 = input[6];
+      const float* i7 = input[7];
+      input = (const float**) ((uintptr_t) input + input_increment);
+      // Taps past the m remaining ones read the zero buffer, so they do not
+      // perturb the accumulated sum.
+      if (m < 2) {
+        i1 = zero;
+      }
+      if (m <= 2) {
+        i2 = zero;
+      }
+      if (m < 4) {
+        i3 = zero;
+      }
+      if (m <= 4) {
+        i4 = zero;
+      }
+      if (m < 6) {
+        i5 = zero;
+      }
+      if (m <= 6) {
+        i6 = zero;
+      }
+      if (m != 8) {
+        i7 = zero;
+      }
+
+      size_t k = kc;
+      float* b = buffer;
+      while (k >= 4) {
+        const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
+        const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
+        const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
+        const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
+        const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
+        const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
+        const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
+        const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;
+        const float32x4_t vacc = vld1q_f32(b); b += 4;
+
+        const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+        const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+        const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+        const float32x4_t vsum67 = vaddq_f32(vi6, vi7);
+        const float32x4_t vsum01a = vaddq_f32(vsum01, vacc);
+        const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
+        const float32x4_t vsum0167a = vaddq_f32(vsum01a, vsum67);
+        const float32x4_t vsum = vaddq_f32(vsum2345, vsum0167a);
+
+        float32x4_t vout = vmulq_f32(vsum, vmultiplier);
+        vout = vmaxq_f32(vout, voutput_min);
+        vout = vminq_f32(vout, voutput_max);
+
+        vst1q_f32(output, vout); output += 4;
+
+        k -= 4;
+      }
+      if (k != 0) {
+        // 1-3 trailing channels: full 4-lane loads (padded buffers), but only
+        // k lanes are stored below.
+        const float32x4_t vi0 = vld1q_f32(i0);
+        const float32x4_t vi1 = vld1q_f32(i1);
+        const float32x4_t vi2 = vld1q_f32(i2);
+        const float32x4_t vi3 = vld1q_f32(i3);
+        const float32x4_t vi4 = vld1q_f32(i4);
+        const float32x4_t vi5 = vld1q_f32(i5);
+        const float32x4_t vi6 = vld1q_f32(i6);
+        const float32x4_t vi7 = vld1q_f32(i7);
+        const float32x4_t vacc = vld1q_f32(b);
+
+        const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+        const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+        const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+        const float32x4_t vsum67 = vaddq_f32(vi6, vi7);
+        const float32x4_t vsum01a = vaddq_f32(vsum01, vacc);
+        const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
+        const float32x4_t vsum0167a = vaddq_f32(vsum01a, vsum67);
+        const float32x4_t vsum = vaddq_f32(vsum2345, vsum0167a);
+
+        float32x4_t vout = vmulq_f32(vsum, vmultiplier);
+        vout = vmaxq_f32(vout, voutput_min);
+        vout = vminq_f32(vout, voutput_max);
+
+        float32x2_t vout_lo = vget_low_f32(vout);
+        if (k & 2) {
+          vst1_f32(output, vout_lo); output += 2;
+          vout_lo = vget_high_f32(vout);
+        }
+        if (k & 1) {
+          vst1_lane_f32(output, vout_lo, 0); output += 1;
+        }
+      }
+    }
+    output = (float*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/f32-avgpool/mp9p8q-psimd.c b/src/f32-avgpool/mp9p8q-psimd.c
new file mode 100644
index 0000000..96e4c54
--- /dev/null
+++ b/src/f32-avgpool/mp9p8q-psimd.c
@@ -0,0 +1,236 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/avgpool.h>
+
+
+void xnn_f32_avgpool_ukernel_mp9p8q__psimd(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const float** input,
+    const float* zero,
+    float* buffer,
+    float* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks > 9);  // multi-pass kernel; the single-pass (up9) variant covers ks <= 9
+  assert(kc != 0);
+
+  // Fix: restored the '&' address-of operators below - they had been
+  // mojibake-corrupted into '¶' (e.g. 'psimd_load_splat_f32(¶ms->...)'),
+  // which does not compile.
+  const psimd_f32 vmultiplier = psimd_load_splat_f32(&params->scalar.multiplier);
+  const psimd_f32 voutput_min = psimd_load_splat_f32(&params->scalar.output_min);
+  const psimd_f32 voutput_max = psimd_load_splat_f32(&params->scalar.output_max);
+
+  do {
+    // Pass 1: sum the first 9 pooling taps into the per-channel buffer.
+    {
+      const float* i0 = *input++;
+      const float* i1 = *input++;
+      const float* i2 = *input++;
+      const float* i3 = *input++;
+      const float* i4 = *input++;
+      const float* i5 = *input++;
+      const float* i6 = *input++;
+      const float* i7 = *input++;
+      const float* i8 = *input++;
+
+      float* b = buffer;
+      // NOTE(review): full 4-lane loads/stores assume kc (and the buffer) are
+      // padded to a multiple of 4 - confirm against the operator setup code.
+      for (size_t k = 0; k < kc; k += 4) {
+        const psimd_f32 vi0 = psimd_load_f32(i0);
+        i0 += 4;
+        const psimd_f32 vi1 = psimd_load_f32(i1);
+        i1 += 4;
+        const psimd_f32 vi2 = psimd_load_f32(i2);
+        i2 += 4;
+        const psimd_f32 vi3 = psimd_load_f32(i3);
+        i3 += 4;
+        const psimd_f32 vi4 = psimd_load_f32(i4);
+        i4 += 4;
+        const psimd_f32 vi5 = psimd_load_f32(i5);
+        i5 += 4;
+        const psimd_f32 vi6 = psimd_load_f32(i6);
+        i6 += 4;
+        const psimd_f32 vi7 = psimd_load_f32(i7);
+        i7 += 4;
+        const psimd_f32 vi8 = psimd_load_f32(i8);
+        i8 += 4;
+
+        // Balanced addition tree (4 levels) for better ILP than a serial sum.
+        const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+        const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+        const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+        const psimd_f32 vsum67 = psimd_add_f32(vi6, vi7);
+        const psimd_f32 vsum018 = psimd_add_f32(vsum01, vi8);
+        const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+        const psimd_f32 vsum01678 = psimd_add_f32(vsum018, vsum67);
+        const psimd_f32 vsum = psimd_add_f32(vsum2345, vsum01678);
+
+        psimd_store_f32(b, vsum);
+        b += 4;
+      }
+    }
+
+    // Middle passes: accumulate 8 more taps at a time into the buffer while
+    // more than 8 taps remain.
+    size_t m = ks;
+    for (m -= 9; m > 8; m -= 8) {
+      const float* i0 = *input++;
+      const float* i1 = *input++;
+      const float* i2 = *input++;
+      const float* i3 = *input++;
+      const float* i4 = *input++;
+      const float* i5 = *input++;
+      const float* i6 = *input++;
+      const float* i7 = *input++;
+
+      float* b = buffer;
+      for (size_t k = 0; k < kc; k += 4) {
+        const psimd_f32 vi0 = psimd_load_f32(i0);
+        i0 += 4;
+        const psimd_f32 vi1 = psimd_load_f32(i1);
+        i1 += 4;
+        const psimd_f32 vi2 = psimd_load_f32(i2);
+        i2 += 4;
+        const psimd_f32 vi3 = psimd_load_f32(i3);
+        i3 += 4;
+        const psimd_f32 vi4 = psimd_load_f32(i4);
+        i4 += 4;
+        const psimd_f32 vi5 = psimd_load_f32(i5);
+        i5 += 4;
+        const psimd_f32 vi6 = psimd_load_f32(i6);
+        i6 += 4;
+        const psimd_f32 vi7 = psimd_load_f32(i7);
+        i7 += 4;
+        const psimd_f32 vacc = psimd_load_f32(b);
+
+        const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+        const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+        const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+        const psimd_f32 vsum67 = psimd_add_f32(vi6, vi7);
+        const psimd_f32 vsum01a = psimd_add_f32(vsum01, vacc);
+        const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+        const psimd_f32 vsum0167a = psimd_add_f32(vsum01a, vsum67);
+        const psimd_f32 vsum = psimd_add_f32(vsum2345, vsum0167a);
+
+        psimd_store_f32(b, vsum);
+        b += 4;
+      }
+    }
+
+    // Final pass: accumulate the remaining 1-8 taps, then scale by the
+    // reciprocal of the pooling-window size, clamp, and store the result.
+    {
+      const float* i0 = input[0];
+      const float* i1 = input[1];
+      const float* i2 = input[2];
+      const float* i3 = input[3];
+      const float* i4 = input[4];
+      const float* i5 = input[5];
+      const float* i6 = input[6];
+      const float* i7 = input[7];
+      input = (const float**) ((uintptr_t) input + input_increment);
+      // Taps past the m remaining ones read the zero buffer, so they do not
+      // perturb the accumulated sum.
+      if (m < 2) {
+        i1 = zero;
+      }
+      if (m <= 2) {
+        i2 = zero;
+      }
+      if (m < 4) {
+        i3 = zero;
+      }
+      if (m <= 4) {
+        i4 = zero;
+      }
+      if (m < 6) {
+        i5 = zero;
+      }
+      if (m <= 6) {
+        i6 = zero;
+      }
+      if (m != 8) {
+        i7 = zero;
+      }
+
+      size_t k = kc;
+      float* b = buffer;
+      while (k >= 4) {
+        const psimd_f32 vi0 = psimd_load_f32(i0);
+        i0 += 4;
+        const psimd_f32 vi1 = psimd_load_f32(i1);
+        i1 += 4;
+        const psimd_f32 vi2 = psimd_load_f32(i2);
+        i2 += 4;
+        const psimd_f32 vi3 = psimd_load_f32(i3);
+        i3 += 4;
+        const psimd_f32 vi4 = psimd_load_f32(i4);
+        i4 += 4;
+        const psimd_f32 vi5 = psimd_load_f32(i5);
+        i5 += 4;
+        const psimd_f32 vi6 = psimd_load_f32(i6);
+        i6 += 4;
+        const psimd_f32 vi7 = psimd_load_f32(i7);
+        i7 += 4;
+        const psimd_f32 vacc = psimd_load_f32(b);
+        b += 4;
+
+        const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+        const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+        const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+        const psimd_f32 vsum67 = psimd_add_f32(vi6, vi7);
+        const psimd_f32 vsum01a = psimd_add_f32(vsum01, vacc);
+        const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+        const psimd_f32 vsum0167a = psimd_add_f32(vsum01a, vsum67);
+        const psimd_f32 vsum = psimd_add_f32(vsum2345, vsum0167a);
+
+        psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
+        vout = psimd_max_f32(vout, voutput_min);
+        vout = psimd_min_f32(vout, voutput_max);
+
+        psimd_store_f32(output, vout);
+        output += 4;
+
+        k -= 4;
+      }
+      if (k != 0) {
+        // 1-3 trailing channels: full 4-lane loads (padded buffers), but only
+        // k lanes are stored below.
+        const psimd_f32 vi0 = psimd_load_f32(i0);
+        const psimd_f32 vi1 = psimd_load_f32(i1);
+        const psimd_f32 vi2 = psimd_load_f32(i2);
+        const psimd_f32 vi3 = psimd_load_f32(i3);
+        const psimd_f32 vi4 = psimd_load_f32(i4);
+        const psimd_f32 vi5 = psimd_load_f32(i5);
+        const psimd_f32 vi6 = psimd_load_f32(i6);
+        const psimd_f32 vi7 = psimd_load_f32(i7);
+        const psimd_f32 vacc = psimd_load_f32(b);
+
+        const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+        const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+        const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+        const psimd_f32 vsum67 = psimd_add_f32(vi6, vi7);
+        const psimd_f32 vsum01a = psimd_add_f32(vsum01, vacc);
+        const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+        const psimd_f32 vsum0167a = psimd_add_f32(vsum01a, vsum67);
+        const psimd_f32 vsum = psimd_add_f32(vsum2345, vsum0167a);
+
+        psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
+        vout = psimd_max_f32(vout, voutput_min);
+        vout = psimd_min_f32(vout, voutput_max);
+
+        if (k & 2) {
+          psimd_store2_f32(output, vout);
+          output += 2;
+          vout = psimd_concat_hi_f32(vout, vout);
+        }
+        if (k & 1) {
+          psimd_store1_f32(output, vout);
+          output += 1;
+        }
+      }
+    }
+    output = (float*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/f32-avgpool/mp9p8q-scalar.c b/src/f32-avgpool/mp9p8q-scalar.c
new file mode 100644
index 0000000..c1c5af8
--- /dev/null
+++ b/src/f32-avgpool/mp9p8q-scalar.c
@@ -0,0 +1,172 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/avgpool.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_avgpool_ukernel_mp9p8q__scalar(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const float** input,
+    const float* zero,
+    float* buffer,
+    float* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks > 9);
+  assert(kc != 0);
+
+  // Multi-pass average pooling: 9 taps first, then 8 at a time, then the
+  // final 1-8 taps with scaling and clamping. The per-step addition trees
+  // below are kept in the exact association order of the other kernels so
+  // results are bit-identical across variants.
+  const float vscale = params->scalar.multiplier;
+  const float vmin = params->scalar.output_min;
+  const float vmax = params->scalar.output_max;
+
+  do {
+    // Pass 1: sum of the first 9 taps, written into the channel buffer.
+    {
+      const float* t0 = *input++;
+      const float* t1 = *input++;
+      const float* t2 = *input++;
+      const float* t3 = *input++;
+      const float* t4 = *input++;
+      const float* t5 = *input++;
+      const float* t6 = *input++;
+      const float* t7 = *input++;
+      const float* t8 = *input++;
+
+      float* acc = buffer;
+      size_t c = kc;
+      do {
+        const float x0 = *t0++;
+        const float x1 = *t1++;
+        const float x2 = *t2++;
+        const float x3 = *t3++;
+        const float x4 = *t4++;
+        const float x5 = *t5++;
+        const float x6 = *t6++;
+        const float x7 = *t7++;
+        const float x8 = *t8++;
+
+        const float s01 = x0 + x1;
+        const float s23 = x2 + x3;
+        const float s45 = x4 + x5;
+        const float s67 = x6 + x7;
+        const float s018 = s01 + x8;
+        const float s2345 = s23 + s45;
+        const float s01678 = s018 + s67;
+
+        *acc++ = s2345 + s01678;
+      } while (--c != 0);
+    }
+
+    // Middle passes: fold 8 further taps into the buffer while more than 8
+    // taps remain.
+    size_t m = ks;
+    for (m -= 9; m > 8; m -= 8) {
+      const float* t0 = *input++;
+      const float* t1 = *input++;
+      const float* t2 = *input++;
+      const float* t3 = *input++;
+      const float* t4 = *input++;
+      const float* t5 = *input++;
+      const float* t6 = *input++;
+      const float* t7 = *input++;
+
+      float* acc = buffer;
+      size_t c = kc;
+      do {
+        const float x0 = *t0++;
+        const float x1 = *t1++;
+        const float x2 = *t2++;
+        const float x3 = *t3++;
+        const float x4 = *t4++;
+        const float x5 = *t5++;
+        const float x6 = *t6++;
+        const float x7 = *t7++;
+        const float prev = *acc;
+
+        const float s01 = x0 + x1;
+        const float s23 = x2 + x3;
+        const float s45 = x4 + x5;
+        const float s67 = x6 + x7;
+        const float s01p = s01 + prev;
+        const float s2345 = s23 + s45;
+        const float s0167p = s01p + s67;
+
+        *acc++ = s2345 + s0167p;
+      } while (--c != 0);
+    }
+
+    // Final pass: the last m (1..8) taps; unused taps read the zero buffer
+    // so they do not perturb the sum.
+    {
+      const float* t[8];
+      for (size_t j = 0; j < 8; j++) {
+        t[j] = input[j];
+      }
+      input = (const float**) ((uintptr_t) input + input_increment);
+      for (size_t j = m; j < 8; j++) {
+        t[j] = zero;
+      }
+
+      float* acc = buffer;
+      size_t c = kc;
+      do {
+        float x[8];
+        for (size_t j = 0; j < 8; j++) {
+          x[j] = *t[j]++;
+        }
+        const float prev = *acc++;
+
+        const float s01 = x[0] + x[1];
+        const float s23 = x[2] + x[3];
+        const float s45 = x[4] + x[5];
+        const float s67 = x[6] + x[7];
+        const float s01p = s01 + prev;
+        const float s2345 = s23 + s45;
+        const float s0167p = s01p + s67;
+        const float total = s2345 + s0167p;
+
+        // Scale by the reciprocal window size, then clamp.
+        float y = total * vscale;
+        y = math_max_f32(y, vmin);
+        y = math_min_f32(y, vmax);
+
+        *output++ = y;
+      } while (--c != 0);
+    }
+    output = (float*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/f32-avgpool/mp9p8q-sse.c b/src/f32-avgpool/mp9p8q-sse.c
new file mode 100644
index 0000000..7efa78d
--- /dev/null
+++ b/src/f32-avgpool/mp9p8q-sse.c
@@ -0,0 +1,234 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/avgpool.h>
+
+
+void xnn_f32_avgpool_ukernel_mp9p8q__sse(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const float** input,
+    const float* zero,
+    float* buffer,
+    float* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks > 9);
+  assert(kc != 0);
+
+  // Multi-pass SSE average pooling: 9 taps first, then 8 at a time, then the
+  // final 1-8 taps with scaling and clamping. The _mm_add_ps trees keep the
+  // exact association order of the original kernel for bit-identical results.
+  const __m128 vscale = _mm_load_ps(params->sse2.multiplier);
+  const __m128 vmin = _mm_load_ps(params->sse2.output_min);
+  const __m128 vmax = _mm_load_ps(params->sse2.output_max);
+
+  do {
+    // Pass 1: sum of the first 9 taps, written into the channel buffer.
+    {
+      const float* t[9];
+      for (size_t j = 0; j < 9; j++) {
+        t[j] = *input++;
+      }
+
+      float* acc = buffer;
+      for (size_t c = 0; c < kc; c += 4) {
+        __m128 x[9];
+        for (size_t j = 0; j < 9; j++) {
+          x[j] = _mm_loadu_ps(t[j]);
+          t[j] += 4;
+        }
+
+        const __m128 s01 = _mm_add_ps(x[0], x[1]);
+        const __m128 s23 = _mm_add_ps(x[2], x[3]);
+        const __m128 s45 = _mm_add_ps(x[4], x[5]);
+        const __m128 s67 = _mm_add_ps(x[6], x[7]);
+        const __m128 s018 = _mm_add_ps(s01, x[8]);
+        const __m128 s2345 = _mm_add_ps(s23, s45);
+        const __m128 s01678 = _mm_add_ps(s018, s67);
+
+        _mm_store_ps(acc, _mm_add_ps(s2345, s01678));
+        acc += 4;
+      }
+    }
+
+    // Middle passes: fold 8 further taps into the buffer while more than 8
+    // taps remain.
+    size_t m = ks;
+    for (m -= 9; m > 8; m -= 8) {
+      const float* t[8];
+      for (size_t j = 0; j < 8; j++) {
+        t[j] = *input++;
+      }
+
+      float* acc = buffer;
+      for (size_t c = 0; c < kc; c += 4) {
+        __m128 x[8];
+        for (size_t j = 0; j < 8; j++) {
+          x[j] = _mm_loadu_ps(t[j]);
+          t[j] += 4;
+        }
+        const __m128 prev = _mm_load_ps(acc);
+
+        const __m128 s01 = _mm_add_ps(x[0], x[1]);
+        const __m128 s23 = _mm_add_ps(x[2], x[3]);
+        const __m128 s45 = _mm_add_ps(x[4], x[5]);
+        const __m128 s67 = _mm_add_ps(x[6], x[7]);
+        const __m128 s01p = _mm_add_ps(s01, prev);
+        const __m128 s2345 = _mm_add_ps(s23, s45);
+        const __m128 s0167p = _mm_add_ps(s01p, s67);
+
+        _mm_store_ps(acc, _mm_add_ps(s2345, s0167p));
+        acc += 4;
+      }
+    }
+
+    // Final pass: the last m (1..8) taps; unused taps read the zero buffer
+    // so they do not perturb the sum.
+    {
+      const float* t[8];
+      for (size_t j = 0; j < 8; j++) {
+        t[j] = input[j];
+      }
+      input = (const float**) ((uintptr_t) input + input_increment);
+      for (size_t j = m; j < 8; j++) {
+        t[j] = zero;
+      }
+
+      size_t k = kc;
+      float* acc = buffer;
+      while (k >= 4) {
+        __m128 x[8];
+        for (size_t j = 0; j < 8; j++) {
+          x[j] = _mm_loadu_ps(t[j]);
+          t[j] += 4;
+        }
+        const __m128 prev = _mm_load_ps(acc);
+        acc += 4;
+
+        const __m128 s01 = _mm_add_ps(x[0], x[1]);
+        const __m128 s23 = _mm_add_ps(x[2], x[3]);
+        const __m128 s45 = _mm_add_ps(x[4], x[5]);
+        const __m128 s67 = _mm_add_ps(x[6], x[7]);
+        const __m128 s01p = _mm_add_ps(s01, prev);
+        const __m128 s2345 = _mm_add_ps(s23, s45);
+        const __m128 s0167p = _mm_add_ps(s01p, s67);
+        const __m128 total = _mm_add_ps(s2345, s0167p);
+
+        __m128 y = _mm_mul_ps(total, vscale);
+        y = _mm_max_ps(y, vmin);
+        y = _mm_min_ps(y, vmax);
+
+        _mm_storeu_ps(output, y);
+        output += 4;
+
+        k -= 4;
+      }
+      if (k != 0) {
+        // 1-3 trailing channels: full 4-lane loads (NOTE(review): assumes
+        // readable padding, matching the original kernel); only k lanes are
+        // stored below.
+        __m128 x[8];
+        for (size_t j = 0; j < 8; j++) {
+          x[j] = _mm_loadu_ps(t[j]);
+        }
+        const __m128 prev = _mm_load_ps(acc);
+
+        const __m128 s01 = _mm_add_ps(x[0], x[1]);
+        const __m128 s23 = _mm_add_ps(x[2], x[3]);
+        const __m128 s45 = _mm_add_ps(x[4], x[5]);
+        const __m128 s67 = _mm_add_ps(x[6], x[7]);
+        const __m128 s01p = _mm_add_ps(s01, prev);
+        const __m128 s2345 = _mm_add_ps(s23, s45);
+        const __m128 s0167p = _mm_add_ps(s01p, s67);
+        const __m128 total = _mm_add_ps(s2345, s0167p);
+
+        __m128 y = _mm_mul_ps(total, vscale);
+        y = _mm_max_ps(y, vmin);
+        y = _mm_min_ps(y, vmax);
+
+        if (k & 2) {
+          _mm_storel_pi((__m64*) output, y);
+          y = _mm_movehl_ps(y, y);
+          output += 2;
+        }
+        if (k & 1) {
+          _mm_store_ss(output, y);
+          output += 1;
+        }
+      }
+    }
+    output = (float*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/f32-avgpool/up9-neon.c b/src/f32-avgpool/up9-neon.c
new file mode 100644
index 0000000..5ae8f80
--- /dev/null
+++ b/src/f32-avgpool/up9-neon.c
@@ -0,0 +1,135 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/avgpool.h>
+
+
+// Unipass average-pooling micro-kernel for NEON: up to 9 pooling elements
+// (ks <= 9) per output pixel. For each of the n output pixels, sums ks input
+// rows of kc channels (row pointers at index >= ks are redirected to the
+// shared zero buffer), scales by the precomputed 1/pooling-size multiplier,
+// and clamps the result to [output_min, output_max].
+void xnn_f32_avgpool_ukernel_up9__neon(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const float** input,
+    const float* zero,
+    float* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks <= 9);
+  assert(kc != 0);
+
+  // Bug fix: the '&' of '&params' had been garbled into a pilcrow character
+  // (mojibake); vld1q_dup_f32 requires the address of the scalar field.
+  const float32x4_t vmultiplier = vld1q_dup_f32(&params->scalar.multiplier);
+  const float32x4_t voutput_min = vld1q_dup_f32(&params->scalar.output_min);
+  const float32x4_t voutput_max = vld1q_dup_f32(&params->scalar.output_max);
+
+  do {
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    const float* i4 = input[4];
+    const float* i5 = input[5];
+    const float* i6 = input[6];
+    const float* i7 = input[7];
+    const float* i8 = input[8];
+    input = (const float**) ((uintptr_t) input + input_increment);
+    // Pooling elements at index >= ks are unused; point them at the zero
+    // buffer so they contribute nothing to the sum.
+    if (ks < 2) {
+      i1 = zero;
+    }
+    if (ks <= 2) {
+      i2 = zero;
+    }
+    if (ks < 4) {
+      i3 = zero;
+    }
+    if (ks <= 4) {
+      i4 = zero;
+    }
+    if (ks < 6) {
+      i5 = zero;
+    }
+    if (ks <= 6) {
+      i6 = zero;
+    }
+    if (ks < 8) {
+      i7 = zero;
+    }
+    if (ks <= 8) {
+      i8 = zero;
+    }
+
+    size_t k = kc;
+    // Main loop: 4 channels per iteration.
+    while (k >= 4) {
+      const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
+      const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
+      const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
+      const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
+      const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
+      const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
+      const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
+      const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;
+      const float32x4_t vi8 = vld1q_f32(i8); i8 += 4;
+
+      // Balanced summation tree: ((0+1)+8)+(6+7) added to (2+3)+(4+5).
+      const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+      const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+      const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+      const float32x4_t vsum67 = vaddq_f32(vi6, vi7);
+      const float32x4_t vsum018 = vaddq_f32(vsum01, vi8);
+      const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
+      const float32x4_t vsum01678 = vaddq_f32(vsum018, vsum67);
+      const float32x4_t vsum = vaddq_f32(vsum2345, vsum01678);
+
+      float32x4_t vout = vmulq_f32(vsum, vmultiplier);
+      vout = vmaxq_f32(vout, voutput_min);
+      vout = vminq_f32(vout, voutput_max);
+
+      vst1q_f32(output, vout); output += 4;
+
+      k -= 4;
+    }
+    // Remainder of 1-3 channels: full-vector loads (NOTE(review): reads up
+    // to 3 floats past the valid data; presumably the buffers are padded --
+    // matches the tail handling of the sibling kernels), partial stores.
+    if (k != 0) {
+      const float32x4_t vi0 = vld1q_f32(i0);
+      const float32x4_t vi1 = vld1q_f32(i1);
+      const float32x4_t vi2 = vld1q_f32(i2);
+      const float32x4_t vi3 = vld1q_f32(i3);
+      const float32x4_t vi4 = vld1q_f32(i4);
+      const float32x4_t vi5 = vld1q_f32(i5);
+      const float32x4_t vi6 = vld1q_f32(i6);
+      const float32x4_t vi7 = vld1q_f32(i7);
+      const float32x4_t vi8 = vld1q_f32(i8);
+
+      const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+      const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+      const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+      const float32x4_t vsum67 = vaddq_f32(vi6, vi7);
+      const float32x4_t vsum018 = vaddq_f32(vsum01, vi8);
+      const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
+      const float32x4_t vsum01678 = vaddq_f32(vsum018, vsum67);
+      const float32x4_t vsum = vaddq_f32(vsum2345, vsum01678);
+
+      float32x4_t vout = vmulq_f32(vsum, vmultiplier);
+      vout = vmaxq_f32(vout, voutput_min);
+      vout = vminq_f32(vout, voutput_max);
+
+      float32x2_t vout_lo = vget_low_f32(vout);
+      if (k & 2) {
+        vst1_f32(output, vout_lo); output += 2;
+        vout_lo = vget_high_f32(vout);
+      }
+      if (k & 1) {
+        vst1_lane_f32(output, vout_lo, 0); output += 1;
+      }
+    }
+    output = (float*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/f32-avgpool/up9-psimd.c b/src/f32-avgpool/up9-psimd.c
new file mode 100644
index 0000000..0579158
--- /dev/null
+++ b/src/f32-avgpool/up9-psimd.c
@@ -0,0 +1,146 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/avgpool.h>
+
+
+// Unipass average-pooling micro-kernel (psimd): up to 9 pooling elements
+// (ks <= 9) per output pixel. Sums ks input rows of kc channels (row
+// pointers at index >= ks are redirected to the shared zero buffer), scales
+// by the precomputed 1/pooling-size multiplier, and clamps the result to
+// [output_min, output_max].
+void xnn_f32_avgpool_ukernel_up9__psimd(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const float** input,
+    const float* zero,
+    float* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks <= 9);
+  assert(kc != 0);
+
+  // Bug fix: the '&' of '&params' had been garbled into a pilcrow character
+  // (mojibake); psimd_load_splat_f32 takes the address of the scalar to
+  // broadcast.
+  const psimd_f32 vmultiplier = psimd_load_splat_f32(&params->scalar.multiplier);
+  const psimd_f32 voutput_min = psimd_load_splat_f32(&params->scalar.output_min);
+  const psimd_f32 voutput_max = psimd_load_splat_f32(&params->scalar.output_max);
+
+  do {
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    const float* i4 = input[4];
+    const float* i5 = input[5];
+    const float* i6 = input[6];
+    const float* i7 = input[7];
+    const float* i8 = input[8];
+    input = (const float**) ((uintptr_t) input + input_increment);
+    // Pooling elements at index >= ks are unused; point them at the zero
+    // buffer so they contribute nothing to the sum.
+    if (ks < 2) {
+      i1 = zero;
+    }
+    if (ks <= 2) {
+      i2 = zero;
+    }
+    if (ks < 4) {
+      i3 = zero;
+    }
+    if (ks <= 4) {
+      i4 = zero;
+    }
+    if (ks < 6) {
+      i5 = zero;
+    }
+    if (ks <= 6) {
+      i6 = zero;
+    }
+    if (ks < 8) {
+      i7 = zero;
+    }
+    if (ks <= 8) {
+      i8 = zero;
+    }
+
+    size_t k = kc;
+    // Main loop: 4 channels per iteration.
+    while (k >= 4) {
+      const psimd_f32 vi0 = psimd_load_f32(i0);
+      i0 += 4;
+      const psimd_f32 vi1 = psimd_load_f32(i1);
+      i1 += 4;
+      const psimd_f32 vi2 = psimd_load_f32(i2);
+      i2 += 4;
+      const psimd_f32 vi3 = psimd_load_f32(i3);
+      i3 += 4;
+      const psimd_f32 vi4 = psimd_load_f32(i4);
+      i4 += 4;
+      const psimd_f32 vi5 = psimd_load_f32(i5);
+      i5 += 4;
+      const psimd_f32 vi6 = psimd_load_f32(i6);
+      i6 += 4;
+      const psimd_f32 vi7 = psimd_load_f32(i7);
+      i7 += 4;
+      const psimd_f32 vi8 = psimd_load_f32(i8);
+      i8 += 4;
+
+      // Balanced summation tree: ((0+1)+8)+(6+7) added to (2+3)+(4+5).
+      const psimd_f32 vsum018 = psimd_add_f32(psimd_add_f32(vi0, vi1), vi8);
+      const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+      const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+      const psimd_f32 vsum67 = psimd_add_f32(vi6, vi7);
+
+      const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+      const psimd_f32 vsum01678 = psimd_add_f32(vsum018, vsum67);
+      const psimd_f32 vsum = psimd_add_f32(vsum2345, vsum01678);
+
+      psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
+      vout = psimd_max_f32(vout, voutput_min);
+      vout = psimd_min_f32(vout, voutput_max);
+
+      psimd_store_f32(output, vout);
+      output += 4;
+
+      k -= 4;
+    }
+    // Remainder of 1-3 channels: full-vector loads (NOTE(review): reads up
+    // to 3 floats past the valid data; presumably the buffers are padded --
+    // matches the tail handling of the sibling kernels), partial stores.
+    if (k != 0) {
+      const psimd_f32 vi0 = psimd_load_f32(i0);
+      const psimd_f32 vi1 = psimd_load_f32(i1);
+      const psimd_f32 vi2 = psimd_load_f32(i2);
+      const psimd_f32 vi3 = psimd_load_f32(i3);
+      const psimd_f32 vi4 = psimd_load_f32(i4);
+      const psimd_f32 vi5 = psimd_load_f32(i5);
+      const psimd_f32 vi6 = psimd_load_f32(i6);
+      const psimd_f32 vi7 = psimd_load_f32(i7);
+      const psimd_f32 vi8 = psimd_load_f32(i8);
+
+      const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+      const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+      const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+      const psimd_f32 vsum67 = psimd_add_f32(vi6, vi7);
+      const psimd_f32 vsum018 = psimd_add_f32(vsum01, vi8);
+      const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+      const psimd_f32 vsum01678 = psimd_add_f32(vsum018, vsum67);
+      const psimd_f32 vsum = psimd_add_f32(vsum2345, vsum01678);
+
+      psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
+      vout = psimd_max_f32(vout, voutput_min);
+      vout = psimd_min_f32(vout, voutput_max);
+
+      if (k & 2) {
+        psimd_store2_f32(output, vout);
+        output += 2;
+        vout = psimd_concat_hi_f32(vout, vout);
+      }
+      if (k & 1) {
+        psimd_store1_f32(output, vout);
+        output += 1;
+      }
+    }
+    output = (float*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/f32-avgpool/up9-scalar.c b/src/f32-avgpool/up9-scalar.c
new file mode 100644
index 0000000..0b4b253
--- /dev/null
+++ b/src/f32-avgpool/up9-scalar.c
@@ -0,0 +1,99 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/avgpool.h>
+#include <xnnpack/math.h>
+
+
+// Single-pass average-pooling micro-kernel, scalar variant: up to 9 pooling
+// elements (ks <= 9) per output pixel. Each of the n output values is the
+// clamped, scaled sum of ks input rows of kc channels; rows at index >= ks
+// read from the shared zero buffer.
+void xnn_f32_avgpool_ukernel_up9__scalar(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const float** input,
+    const float* zero,
+    float* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks <= 9);
+  assert(kc != 0);
+
+  const float vscale = params->scalar.multiplier;
+  const float vlo = params->scalar.output_min;
+  const float vhi = params->scalar.output_max;
+
+  do {
+    // Gather the 9 row pointers for this output pixel; rows beyond the
+    // pooling window (index >= ks) are redirected to the zero buffer.
+    const float* row[9];
+    for (size_t j = 0; j < 9; j++) {
+      row[j] = input[j];
+    }
+    for (size_t j = ks; j < 9; j++) {
+      row[j] = zero;
+    }
+    input = (const float**) ((uintptr_t) input + input_increment);
+
+    // One output value per channel. The summation tree matches the SIMD
+    // variants exactly: ((0+1)+8)+(6+7) added to (2+3)+(4+5).
+    for (size_t k = kc; k != 0; k--) {
+      const float t0 = *row[0]++;
+      const float t1 = *row[1]++;
+      const float t2 = *row[2]++;
+      const float t3 = *row[3]++;
+      const float t4 = *row[4]++;
+      const float t5 = *row[5]++;
+      const float t6 = *row[6]++;
+      const float t7 = *row[7]++;
+      const float t8 = *row[8]++;
+
+      const float p01 = t0 + t1;
+      const float p23 = t2 + t3;
+      const float p45 = t4 + t5;
+      const float p67 = t6 + t7;
+      const float p018 = p01 + t8;
+      const float p2345 = p23 + p45;
+      const float p01678 = p018 + p67;
+      const float psum = p2345 + p01678;
+
+      float vout = psum * vscale;
+      vout = math_max_f32(vout, vlo);
+      vout = math_min_f32(vout, vhi);
+
+      *output++ = vout;
+    }
+    output = (float*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/f32-avgpool/up9-sse.c b/src/f32-avgpool/up9-sse.c
new file mode 100644
index 0000000..de3685a
--- /dev/null
+++ b/src/f32-avgpool/up9-sse.c
@@ -0,0 +1,145 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/avgpool.h>
+
+
+// Unipass average-pooling micro-kernel for SSE: up to 9 pooling elements
+// (ks <= 9) per output pixel. Sums ks input rows of kc channels (row
+// pointers at index >= ks are redirected to the shared zero buffer), scales
+// by the precomputed 1/pooling-size multiplier, and clamps the result to
+// [output_min, output_max].
+void xnn_f32_avgpool_ukernel_up9__sse(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const float** input,
+    const float* zero,
+    float* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks <= 9);
+  assert(kc != 0);
+
+  // Aligned loads from the sse2 variant of the params union; presumably each
+  // value is replicated across a 16-byte-aligned vector -- the declaration is
+  // outside this file.
+  const __m128 vmultiplier = _mm_load_ps(params->sse2.multiplier);
+  const __m128 voutput_min = _mm_load_ps(params->sse2.output_min);
+  const __m128 voutput_max = _mm_load_ps(params->sse2.output_max);
+
+  do {
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    const float* i4 = input[4];
+    const float* i5 = input[5];
+    const float* i6 = input[6];
+    const float* i7 = input[7];
+    const float* i8 = input[8];
+    input = (const float**) ((uintptr_t) input + input_increment);
+    // Pooling elements at index >= ks are unused; point them at the zero
+    // buffer so they contribute nothing to the sum.
+    if (ks < 2) {
+      i1 = zero;
+    }
+    if (ks <= 2) {
+      i2 = zero;
+    }
+    if (ks < 4) {
+      i3 = zero;
+    }
+    if (ks <= 4) {
+      i4 = zero;
+    }
+    if (ks < 6) {
+      i5 = zero;
+    }
+    if (ks <= 6) {
+      i6 = zero;
+    }
+    if (ks < 8) {
+      i7 = zero;
+    }
+    if (ks <= 8) {
+      i8 = zero;
+    }
+
+    size_t k = kc;
+    // Main loop: 4 channels per iteration.
+    while (k >= 4) {
+      const __m128 vi0 = _mm_loadu_ps(i0);
+      i0 += 4;
+      const __m128 vi1 = _mm_loadu_ps(i1);
+      i1 += 4;
+      const __m128 vi2 = _mm_loadu_ps(i2);
+      i2 += 4;
+      const __m128 vi3 = _mm_loadu_ps(i3);
+      i3 += 4;
+      const __m128 vi4 = _mm_loadu_ps(i4);
+      i4 += 4;
+      const __m128 vi5 = _mm_loadu_ps(i5);
+      i5 += 4;
+      const __m128 vi6 = _mm_loadu_ps(i6);
+      i6 += 4;
+      const __m128 vi7 = _mm_loadu_ps(i7);
+      i7 += 4;
+      const __m128 vi8 = _mm_loadu_ps(i8);
+      i8 += 4;
+
+      // Balanced summation tree: ((0+1)+8)+(6+7) added to (2+3)+(4+5).
+      const __m128 vsum018 = _mm_add_ps(_mm_add_ps(vi0, vi1), vi8);
+      const __m128 vsum23 = _mm_add_ps(vi2, vi3);
+      const __m128 vsum45 = _mm_add_ps(vi4, vi5);
+      const __m128 vsum67 = _mm_add_ps(vi6, vi7);
+
+      const __m128 vsum2345 = _mm_add_ps(vsum23, vsum45);
+      const __m128 vsum01678 = _mm_add_ps(vsum018, vsum67);
+      const __m128 vsum = _mm_add_ps(vsum2345, vsum01678);
+
+      __m128 vout = _mm_mul_ps(vsum, vmultiplier);
+      vout = _mm_max_ps(vout, voutput_min);
+      vout = _mm_min_ps(vout, voutput_max);
+
+      _mm_storeu_ps(output, vout); output += 4;
+
+      k -= 4;
+    }
+    // Remainder of 1-3 channels: full-vector loads (NOTE(review): reads up
+    // to 3 floats past the valid data; presumably the buffers are padded --
+    // matches the tail handling of the sibling kernels), partial stores.
+    if (k != 0) {
+      const __m128 vi0 = _mm_loadu_ps(i0);
+      const __m128 vi1 = _mm_loadu_ps(i1);
+      const __m128 vi2 = _mm_loadu_ps(i2);
+      const __m128 vi3 = _mm_loadu_ps(i3);
+      const __m128 vi4 = _mm_loadu_ps(i4);
+      const __m128 vi5 = _mm_loadu_ps(i5);
+      const __m128 vi6 = _mm_loadu_ps(i6);
+      const __m128 vi7 = _mm_loadu_ps(i7);
+      const __m128 vi8 = _mm_loadu_ps(i8);
+
+      const __m128 vsum01 = _mm_add_ps(vi0, vi1);
+      const __m128 vsum23 = _mm_add_ps(vi2, vi3);
+      const __m128 vsum45 = _mm_add_ps(vi4, vi5);
+      const __m128 vsum67 = _mm_add_ps(vi6, vi7);
+      const __m128 vsum018 = _mm_add_ps(vsum01, vi8);
+      const __m128 vsum2345 = _mm_add_ps(vsum23, vsum45);
+      const __m128 vsum01678 = _mm_add_ps(vsum018, vsum67);
+      const __m128 vsum = _mm_add_ps(vsum2345, vsum01678);
+
+      __m128 vout = _mm_mul_ps(vsum, vmultiplier);
+      vout = _mm_max_ps(vout, voutput_min);
+      vout = _mm_min_ps(vout, voutput_max);
+
+      // Store 2 lanes, shift the high half down, then store 1 lane as needed.
+      if (k & 2) {
+        _mm_storel_pi((__m64*) output, vout);
+        vout = _mm_movehl_ps(vout, vout);
+        output += 2;
+      }
+      if (k & 1) {
+        _mm_store_ss(output, vout);
+        output += 1;
+      }
+    }
+    output = (float*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/f32-clamp/neon.c b/src/f32-clamp/neon.c
new file mode 100644
index 0000000..e131e3e
--- /dev/null
+++ b/src/f32-clamp/neon.c
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/clamp.h>
+
+
+// Clamps n bytes worth of floats from x into y, restricting each value to
+// [min, max] (NEON variant). n must be a non-zero multiple of sizeof(float).
+void xnn_f32_clamp_ukernel__neon(
+    size_t n,
+    const float* x,
+    float* y,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  // Bug fix: the '&' of '&params' had been garbled into a pilcrow character
+  // (mojibake). vld2q_dup_f32 broadcasts two adjacent scalars in one load:
+  // val[0] = max, val[1] = min -- this assumes min is laid out immediately
+  // after max in the scalar variant of xnn_f32_output_params
+  // (NOTE(review): the union is declared elsewhere; confirm the layout).
+  const float32x4x2_t voutput_clamp = vld2q_dup_f32(&params->scalar.max);
+
+  for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
+    const float32x4_t vx = vld1q_f32(x); x += 4;
+
+    const float32x4_t vy = vminq_f32(vmaxq_f32(vx, voutput_clamp.val[1]), voutput_clamp.val[0]);
+
+    vst1q_f32(y, vy); y += 4;
+  }
+  // Remainder of 1-3 elements: full-vector load (NOTE(review): reads past
+  // the valid data; presumably buffers are padded), partial stores.
+  if (n != 0) {
+    const float32x4_t vx = vld1q_f32(x);
+
+    const float32x4_t vy = vminq_f32(vmaxq_f32(vx, voutput_clamp.val[1]), voutput_clamp.val[0]);
+
+    float32x2_t vy_lo = vget_low_f32(vy);
+    if (n & 2 * sizeof(float)) {
+      vst1_f32(y, vy_lo); y += 2;
+      vy_lo = vget_high_f32(vy);
+    }
+    if (n & 1 * sizeof(float)) {
+      vst1_lane_f32(y, vy_lo, 0);
+    }
+  }
+}
diff --git a/src/f32-clamp/psimd.c b/src/f32-clamp/psimd.c
new file mode 100644
index 0000000..53c253c
--- /dev/null
+++ b/src/f32-clamp/psimd.c
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/clamp.h>
+
+
+// Clamps n bytes worth of floats from x into y, restricting each value to
+// [min, max] (psimd variant). n must be a non-zero multiple of sizeof(float).
+void xnn_f32_clamp_ukernel__psimd(
+    size_t n,
+    const float* x,
+    float* y,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  // Bug fix: the '&' of '&params' had been garbled into a pilcrow character
+  // (mojibake); psimd_load_splat_f32 takes the address of the scalar to
+  // broadcast.
+  const psimd_f32 voutput_max = psimd_load_splat_f32(&params->scalar.max);
+  const psimd_f32 voutput_min = psimd_load_splat_f32(&params->scalar.min);
+
+  for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
+    const psimd_f32 vx = psimd_load_f32(x);
+    x += 4;
+
+    const psimd_f32 vy = psimd_min_f32(psimd_max_f32(vx, voutput_min), voutput_max);
+
+    psimd_store_f32(y, vy);
+    y += 4;
+  }
+  // Remainder of 1-3 elements: full-vector load (NOTE(review): reads past
+  // the valid data; presumably buffers are padded), partial stores.
+  if (n != 0) {
+    const psimd_f32 vx = psimd_load_f32(x);
+
+    psimd_f32 vy = psimd_min_f32(psimd_max_f32(vx, voutput_min), voutput_max);
+
+    if (n & 2 * sizeof(float)) {
+      psimd_store2_f32(y, vy);
+      vy = psimd_concat_hi_f32(vy, vy);
+      y += 2;
+    }
+    if (n & 1 * sizeof(float)) {
+      psimd_store1_f32(y, vy);
+    }
+  }
+}
diff --git a/src/f32-clamp/scalar.c b/src/f32-clamp/scalar.c
new file mode 100644
index 0000000..4fd6ae2
--- /dev/null
+++ b/src/f32-clamp/scalar.c
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/clamp.h>
+#include <xnnpack/math.h>
+
+
+// Clamps n bytes worth of floats from x into y, restricting each value to
+// [min, max] (scalar variant, 2x unrolled). n must be a non-zero multiple of
+// sizeof(float).
+void xnn_f32_clamp_ukernel__scalar(
+    size_t n,
+    const float* x,
+    float* y,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  const float vhi = params->scalar.max;
+  const float vlo = params->scalar.min;
+
+  // Process elements in pairs.
+  while (n >= 2 * sizeof(float)) {
+    float va = x[0];
+    float vb = x[1];
+    x += 2;
+
+    va = math_min_f32(math_max_f32(va, vlo), vhi);
+    vb = math_min_f32(math_max_f32(vb, vlo), vhi);
+
+    y[0] = va;
+    y[1] = vb;
+    y += 2;
+
+    n -= 2 * sizeof(float);
+  }
+  // At most one trailing element can remain.
+  if (n != 0) {
+    *y = math_min_f32(math_max_f32(*x, vlo), vhi);
+  }
+}
diff --git a/src/f32-clamp/sse.c b/src/f32-clamp/sse.c
new file mode 100644
index 0000000..21e2976
--- /dev/null
+++ b/src/f32-clamp/sse.c
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/clamp.h>
+
+
+// Clamps n bytes worth of floats from x into y, restricting each value to
+// [min, max] (SSE variant). n must be a non-zero multiple of sizeof(float).
+void xnn_f32_clamp_ukernel__sse(
+    size_t n,
+    const float* x,
+    float* y,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  const __m128 vmax = _mm_load_ps(params->sse.max);
+  const __m128 vmin = _mm_load_ps(params->sse.min);
+
+  // Main loop: 4 elements per iteration.
+  while (n >= 4 * sizeof(float)) {
+    __m128 vacc = _mm_loadu_ps(x);
+    x += 4;
+
+    vacc = _mm_max_ps(vacc, vmin);
+    vacc = _mm_min_ps(vacc, vmax);
+
+    _mm_storeu_ps(y, vacc);
+    y += 4;
+
+    n -= 4 * sizeof(float);
+  }
+  // Remainder of 1-3 elements: full-vector load, partial stores.
+  if (n != 0) {
+    __m128 vacc = _mm_loadu_ps(x);
+    vacc = _mm_min_ps(_mm_max_ps(vacc, vmin), vmax);
+
+    if (n & 2 * sizeof(float)) {
+      _mm_storel_pi((__m64*) y, vacc);
+      vacc = _mm_movehl_ps(vacc, vacc);
+      y += 2;
+    }
+    if (n & 1 * sizeof(float)) {
+      _mm_store_ss(y, vacc);
+    }
+  }
+}
diff --git a/src/f32-conv-hwc/3x3s2p1c3x4-neonfma-2x2.c b/src/f32-conv-hwc/3x3s2p1c3x4-neonfma-2x2.c
new file mode 100644
index 0000000..298beec
--- /dev/null
+++ b/src/f32-conv-hwc/3x3s2p1c3x4-neonfma-2x2.c
@@ -0,0 +1,668 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/conv.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_conv_hwc_ukernel_3x3s2p1c3x4__neonfma_2x2(
+ size_t input_height,
+ size_t input_width,
+ size_t output_y_start,
+ size_t output_y_end,
+ const float* input,
+ const float* zero,
+ const float* weights,
+ float* output,
+ size_t input_padding_top,
+ size_t output_channels,
+ size_t output_height_stride,
+ size_t output_width_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(input_width != 0);
+ assert(output_y_end > output_y_start);
+ assert(input_padding_top <= 1);
+ assert(output_channels != 0);
+
+ const size_t input_height_stride = input_width * 3 /* channels */ * sizeof(float);
+ const size_t input_width_increment = round_down_po2(input_width, 4) * 3 /* channels */ * sizeof(float);
+ const size_t output_width = (input_width + 1) / 2;
+ const size_t output_channel_increment = 4 * sizeof(float) - output_width * output_width_stride;
+
+ // Adjustment for padding processed below
+ const float* i0 = (const float*) ((uintptr_t) input + input_height_stride * (output_y_start * 2 - input_padding_top));
+ const float* i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
+ const float* i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
+ const float* i3 = (const float*) ((uintptr_t) i2 + input_height_stride);
+ const float* i4 = (const float*) ((uintptr_t) i3 + input_height_stride);
+ float* output0 = (float*) ((uintptr_t) output + output_height_stride * output_y_start);
+ float* output1 = (float*) ((uintptr_t) output0 + output_height_stride);
+
+ if XNN_UNPREDICTABLE(output_y_start < input_padding_top) {
+ i0 = zero;
+ }
+
+ const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min);
+ const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max);
+
+ for (size_t output_y = output_y_start; output_y < output_y_end; output_y += 2) {
+ const size_t input_y2 = output_y * 2 + 2 - input_padding_top;
+ const size_t input_y4 = input_y2 + 2;
+ if XNN_UNPREDICTABLE(input_y2 >= input_height) {
+ i2 = zero;
+ }
+ if XNN_UNPREDICTABLE(input_y4 > input_height) {
+ i3 = zero;
+ }
+ if XNN_UNPREDICTABLE(input_y4 >= input_height) {
+ i4 = zero;
+ }
+ if XNN_UNPREDICTABLE(output_y + 2 > output_y_end) {
+ output1 = output0;
+ }
+
+ const float* w = weights;
+ size_t c = output_channels;
+ float* o0 = output0;
+ float* o1 = output1;
+ do {
+ // viMx0 = ( iM0c2, iM0c1, iM0c0, --- )
+ float32x4_t vi0x0 = vmovq_n_f32(0.0f);
+ float32x4_t vi1x0 = vmovq_n_f32(0.0f);
+ float32x4_t vi2x0 = vmovq_n_f32(0.0f);
+ float32x4_t vi3x0 = vmovq_n_f32(0.0f);
+ float32x4_t vi4x0 = vmovq_n_f32(0.0f);
+
+ size_t iw = input_width;
+ for (; iw >= 4; iw -= 4) {
+ float32x4_t vo0x0 = vld1q_f32(w);
+ float32x4_t vo1x0 = vo0x0;
+ float32x4_t vo0x1 = vo0x0;
+ float32x4_t vo1x1 = vo0x0;
+
+ const float32x4_t vk00c0 = vld1q_f32(w + 4);
+
+ // viMx1 = ( iM2c0, iM1c2, iM1c1, iM1c0 )
+ const float32x4_t vi0x1 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1x1 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2x1 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3x1 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4x1 = vld1q_f32(i4); i4 += 4;
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c0, vi0x0, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c0, vi2x0, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c0, vi0x1, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c0, vi2x1, 3);
+
+ const float32x4_t vk10c0 = vld1q_f32(w + 8);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c0, vi1x0, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c0, vi3x0, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c0, vi1x1, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c0, vi3x1, 3);
+
+ const float32x4_t vk20c0 = vld1q_f32(w + 12);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c0, vi2x0, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c0, vi4x0, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c0, vi2x1, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c0, vi4x1, 3);
+
+ const float32x4_t vk00c1 = vld1q_f32(w + 16);
+
+ // viMx2 = ( iM3c1, iM3c0, iM2c2, iM2c1 )
+ const float32x4_t vi0x2 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1x2 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2x2 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3x2 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4x2 = vld1q_f32(i4); i4 += 4;
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c1, vi0x0, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c1, vi2x0, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c1, vi0x2, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c1, vi2x2, 0);
+
+ const float32x4_t vk10c1 = vld1q_f32(w + 20);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c1, vi1x0, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c1, vi3x0, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c1, vi1x2, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c1, vi3x2, 0);
+
+ const float32x4_t vk20c1 = vld1q_f32(w + 24);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c1, vi2x0, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c1, vi4x0, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c1, vi2x2, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c1, vi4x2, 0);
+
+ const float32x4_t vk00c2 = vld1q_f32(w + 28);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c2, vi0x0, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c2, vi2x0, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c2, vi0x2, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c2, vi2x2, 1);
+
+ const float32x4_t vk10c2 = vld1q_f32(w + 32);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c2, vi1x0, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c2, vi3x0, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c2, vi1x2, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c2, vi3x2, 1);
+
+ const float32x4_t vk20c2 = vld1q_f32(w + 36);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c2, vi2x0, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c2, vi4x0, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c2, vi2x2, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c2, vi4x2, 1);
+
+ const float32x4_t vk01c0 = vld1q_f32(w + 40);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c0, vi0x1, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c0, vi2x1, 0);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c0, vi0x2, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c0, vi2x2, 2);
+
+ const float32x4_t vk11c0 = vld1q_f32(w + 44);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c0, vi1x1, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c0, vi3x1, 0);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c0, vi1x2, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c0, vi3x2, 2);
+
+ const float32x4_t vk21c0 = vld1q_f32(w + 48);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c0, vi2x1, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c0, vi4x1, 0);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c0, vi2x2, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c0, vi4x2, 2);
+
+ const float32x4_t vk01c1 = vld1q_f32(w + 52);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c1, vi0x1, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c1, vi2x1, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c1, vi0x2, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c1, vi2x2, 3);
+
+ const float32x4_t vk11c1 = vld1q_f32(w + 56);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c1, vi1x1, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c1, vi3x1, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c1, vi1x2, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c1, vi3x2, 3);
+
+ const float32x4_t vk21c1 = vld1q_f32(w + 60);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c1, vi2x1, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c1, vi4x1, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c1, vi2x2, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c1, vi4x2, 3);
+
+ const float32x4_t vk01c2 = vld1q_f32(w + 64);
+
+ // viMx3 = ( iM4c2, iM4c1, iM4c0, iM3c2 )
+ const float32x4_t vi0x3 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1x3 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2x3 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3x3 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4x3 = vld1q_f32(i4); i4 += 4;
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c2, vi0x1, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c2, vi2x1, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c2, vi0x3, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c2, vi2x3, 0);
+
+ const float32x4_t vk11c2 = vld1q_f32(w + 68);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c2, vi1x1, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c2, vi3x1, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c2, vi1x3, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c2, vi3x3, 0);
+
+ const float32x4_t vk21c2 = vld1q_f32(w + 72);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c2, vi2x1, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c2, vi4x1, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c2, vi2x3, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c2, vi4x3, 0);
+
+ const float32x4_t vk02c0 = vld1q_f32(w + 76);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c0, vi0x1, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c0, vi2x1, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk02c0, vi0x3, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk02c0, vi2x3, 1);
+
+ const float32x4_t vk12c0 = vld1q_f32(w + 80);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c0, vi1x1, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c0, vi3x1, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk12c0, vi1x3, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk12c0, vi3x3, 1);
+
+ const float32x4_t vk22c0 = vld1q_f32(w + 84);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c0, vi2x1, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c0, vi4x1, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk22c0, vi2x3, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk22c0, vi4x3, 1);
+
+ const float32x4_t vk02c1 = vld1q_f32(w + 88);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c1, vi0x2, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c1, vi2x2, 0);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk02c1, vi0x3, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk02c1, vi2x3, 2);
+
+ const float32x4_t vk12c1 = vld1q_f32(w + 92);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c1, vi1x2, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c1, vi3x2, 0);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk12c1, vi1x3, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk12c1, vi3x3, 2);
+
+ const float32x4_t vk22c1 = vld1q_f32(w + 96);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c1, vi2x2, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c1, vi4x2, 0);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk22c1, vi2x3, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk22c1, vi4x3, 2);
+
+ const float32x4_t vk02c2 = vld1q_f32(w + 100);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c2, vi0x2, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c2, vi2x2, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk02c2, vi0x3, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk02c2, vi2x3, 3);
+
+ const float32x4_t vk12c2 = vld1q_f32(w + 104);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c2, vi1x2, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c2, vi3x2, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk12c2, vi1x3, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk12c2, vi3x3, 3);
+
+ const float32x4_t vk22c2 = vld1q_f32(w + 108);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c2, vi2x2, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c2, vi4x2, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk22c2, vi2x3, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk22c2, vi4x3, 3);
+
+ vi0x0 = vi0x3;
+ vi1x0 = vi1x3;
+ vi2x0 = vi2x3;
+ vi3x0 = vi3x3;
+ vi4x0 = vi4x3;
+
+ vo0x0 = vmaxq_f32(vo0x0, vmin);
+ vo1x0 = vmaxq_f32(vo1x0, vmin);
+ vo0x1 = vmaxq_f32(vo0x1, vmin);
+ vo1x1 = vmaxq_f32(vo1x1, vmin);
+
+ vo0x0 = vminq_f32(vo0x0, vmax);
+ vo1x0 = vminq_f32(vo1x0, vmax);
+ vo0x1 = vminq_f32(vo0x1, vmax);
+ vo1x1 = vminq_f32(vo1x1, vmax);
+
+ if XNN_LIKELY(c >= 4) {
+ vst1q_f32(o1, vo1x0);
+ o1 = (float*) ((uintptr_t) o1 + output_width_stride);
+ vst1q_f32(o0, vo0x0);
+ o0 = (float*) ((uintptr_t) o0 + output_width_stride);
+
+ vst1q_f32(o1, vo1x1);
+ o1 = (float*) ((uintptr_t) o1 + output_width_stride);
+ vst1q_f32(o0, vo0x1);
+ o0 = (float*) ((uintptr_t) o0 + output_width_stride);
+ } else {
+ float32x2_t vo0x0_lo = vget_low_f32(vo0x0);
+ float32x2_t vo1x0_lo = vget_low_f32(vo1x0);
+ float32x2_t vo0x1_lo = vget_low_f32(vo0x1);
+ float32x2_t vo1x1_lo = vget_low_f32(vo1x1);
+ float* o0_tmp = o0;
+ float* o1_tmp = o1;
+ if (c & 2) {
+ vst1_f32((float*) ((uintptr_t) o1_tmp + output_width_stride), vo1x1_lo);
+ vo1x1_lo = vget_high_f32(vo1x1);
+ vst1_f32((float*) ((uintptr_t) o0_tmp + output_width_stride), vo0x1_lo);
+ vo0x1_lo = vget_high_f32(vo0x1);
+
+ vst1_f32(o1_tmp, vo1x0_lo); o1_tmp += 2;
+ vo1x0_lo = vget_high_f32(vo1x0);
+ vst1_f32(o0_tmp, vo0x0_lo); o0_tmp += 2;
+ vo0x0_lo = vget_high_f32(vo0x0);
+ }
+ if (c & 1) {
+ vst1_lane_f32(o1_tmp, vo1x0_lo, 0);
+ vst1_lane_f32(o0_tmp, vo0x0_lo, 0);
+
+ vst1_lane_f32((float*) ((uintptr_t) o1_tmp + output_width_stride), vo1x1_lo, 0);
+ vst1_lane_f32((float*) ((uintptr_t) o0_tmp + output_width_stride), vo0x1_lo, 0);
+ }
+ o0 = (float*) ((uintptr_t) o0 + output_width_stride * 2);
+ o1 = (float*) ((uintptr_t) o1 + output_width_stride * 2);
+ }
+ }
+ assert(iw < 4);
+ if XNN_UNLIKELY(iw != 0) {
+ float32x4_t vo0x0 = vld1q_f32(w);
+ float32x4_t vo1x0 = vo0x0;
+ float32x4_t vo0x1 = vo0x0;
+ float32x4_t vo1x1 = vo0x0;
+
+ const float32x4_t vk00c0 = vld1q_f32(w + 4);
+
+ // viMx1 = ( iM2c0, iM1c2, iM1c1, iM1c0 )
+ float32x4_t vi0x1 = vld1q_f32(i0);
+ float32x4_t vi1x1 = vld1q_f32(i1);
+ float32x4_t vi2x1 = vld1q_f32(i2);
+ float32x4_t vi3x1 = vld1q_f32(i3);
+ float32x4_t vi4x1 = vld1q_f32(i4);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c0, vi0x0, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c0, vi2x0, 1);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c0, vi0x1, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c0, vi2x1, 3);
+ }
+
+ const float32x4_t vk10c0 = vld1q_f32(w + 8);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c0, vi1x0, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c0, vi3x0, 1);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c0, vi1x1, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c0, vi3x1, 3);
+ }
+
+ const float32x4_t vk20c0 = vld1q_f32(w + 12);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c0, vi2x0, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c0, vi4x0, 1);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c0, vi2x1, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c0, vi4x1, 3);
+ }
+
+ const float32x4_t vk00c1 = vld1q_f32(w + 16);
+
+ float32x4_t vi0x2 = vmovq_n_f32(0.0f);
+ float32x4_t vi1x2 = vmovq_n_f32(0.0f);
+ float32x4_t vi2x2 = vmovq_n_f32(0.0f);
+ float32x4_t vi3x2 = vmovq_n_f32(0.0f);
+ float32x4_t vi4x2 = vmovq_n_f32(0.0f);
+ if (iw >= 2) {
+ // viMx2 = ( iM3c1, iM3c0, iM2c2, iM2c1 )
+ vi0x2 = vld1q_f32(i0 + 4);
+ vi1x2 = vld1q_f32(i1 + 4);
+ vi2x2 = vld1q_f32(i2 + 4);
+ vi3x2 = vld1q_f32(i3 + 4);
+ vi4x2 = vld1q_f32(i4 + 4);
+ }
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c1, vi0x0, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c1, vi2x0, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c1, vi0x2, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c1, vi2x2, 0);
+
+ const float32x4_t vk10c1 = vld1q_f32(w + 20);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c1, vi1x0, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c1, vi3x0, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c1, vi1x2, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c1, vi3x2, 0);
+
+ const float32x4_t vk20c1 = vld1q_f32(w + 24);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c1, vi2x0, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c1, vi4x0, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c1, vi2x2, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c1, vi4x2, 0);
+
+ const float32x4_t vk00c2 = vld1q_f32(w + 28);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c2, vi0x0, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c2, vi2x0, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c2, vi0x2, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c2, vi2x2, 1);
+
+ const float32x4_t vk10c2 = vld1q_f32(w + 32);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c2, vi1x0, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c2, vi3x0, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c2, vi1x2, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c2, vi3x2, 1);
+
+ const float32x4_t vk20c2 = vld1q_f32(w + 36);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c2, vi2x0, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c2, vi4x0, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c2, vi2x2, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c2, vi4x2, 1);
+
+ const float32x4_t vk01c0 = vld1q_f32(w + 40);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c0, vi0x1, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c0, vi2x1, 0);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c0, vi0x2, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c0, vi2x2, 2);
+ }
+
+ const float32x4_t vk11c0 = vld1q_f32(w + 44);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c0, vi1x1, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c0, vi3x1, 0);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c0, vi1x2, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c0, vi3x2, 2);
+ }
+
+ const float32x4_t vk21c0 = vld1q_f32(w + 48);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c0, vi2x1, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c0, vi4x1, 0);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c0, vi2x2, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c0, vi4x2, 2);
+ }
+
+ const float32x4_t vk01c1 = vld1q_f32(w + 52);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c1, vi0x1, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c1, vi2x1, 1);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c1, vi0x2, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c1, vi2x2, 3);
+ }
+
+ const float32x4_t vk11c1 = vld1q_f32(w + 56);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c1, vi1x1, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c1, vi3x1, 1);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c1, vi1x2, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c1, vi3x2, 3);
+ }
+
+ const float32x4_t vk21c1 = vld1q_f32(w + 60);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c1, vi2x1, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c1, vi4x1, 1);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c1, vi2x2, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c1, vi4x2, 3);
+ }
+
+ const float32x4_t vk01c2 = vld1q_f32(w + 64);
+
+ float32x4_t vi0x3 = vmovq_n_f32(0.0f);
+ float32x4_t vi1x3 = vmovq_n_f32(0.0f);
+ float32x4_t vi2x3 = vmovq_n_f32(0.0f);
+ float32x4_t vi3x3 = vmovq_n_f32(0.0f);
+ float32x4_t vi4x3 = vmovq_n_f32(0.0f);
+ if (iw > 2) {
+ // viMx3 = ( 0.0, 0.0, 0.0, iM3c2 )
+ vi0x3 = vld1q_lane_f32(i0 + 8, vi0x3, 0);
+ vi1x3 = vld1q_lane_f32(i1 + 8, vi1x3, 0);
+ vi2x3 = vld1q_lane_f32(i2 + 8, vi2x3, 0);
+ vi3x3 = vld1q_lane_f32(i3 + 8, vi3x3, 0);
+ vi4x3 = vld1q_lane_f32(i4 + 8, vi4x3, 0);
+ }
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c2, vi0x1, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c2, vi2x1, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c2, vi0x3, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c2, vi2x3, 0);
+
+ const float32x4_t vk11c2 = vld1q_f32(w + 68);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c2, vi1x1, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c2, vi3x1, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c2, vi1x3, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c2, vi3x3, 0);
+
+ const float32x4_t vk21c2 = vld1q_f32(w + 72);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c2, vi2x1, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c2, vi4x1, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c2, vi2x3, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c2, vi4x3, 0);
+
+ if (iw >= 2) {
+ const float32x4_t vk02c0 = vld1q_f32(w + 76);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c0, vi0x1, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c0, vi2x1, 3);
+
+ const float32x4_t vk12c0 = vld1q_f32(w + 80);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c0, vi1x1, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c0, vi3x1, 3);
+
+ const float32x4_t vk22c0 = vld1q_f32(w + 84);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c0, vi2x1, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c0, vi4x1, 3);
+
+ const float32x4_t vk02c1 = vld1q_f32(w + 88);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c1, vi0x2, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c1, vi2x2, 0);
+
+ const float32x4_t vk12c1 = vld1q_f32(w + 92);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c1, vi1x2, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c1, vi3x2, 0);
+
+ const float32x4_t vk22c1 = vld1q_f32(w + 96);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c1, vi2x2, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c1, vi4x2, 0);
+
+ const float32x4_t vk02c2 = vld1q_f32(w + 100);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c2, vi0x2, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c2, vi2x2, 1);
+
+ const float32x4_t vk12c2 = vld1q_f32(w + 104);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c2, vi1x2, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c2, vi3x2, 1);
+
+ const float32x4_t vk22c2 = vld1q_f32(w + 108);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c2, vi2x2, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c2, vi4x2, 1);
+ }
+
+ vo0x0 = vmaxq_f32(vo0x0, vmin);
+ vo1x0 = vmaxq_f32(vo1x0, vmin);
+ vo0x1 = vmaxq_f32(vo0x1, vmin);
+ vo1x1 = vmaxq_f32(vo1x1, vmin);
+
+ vo0x0 = vminq_f32(vo0x0, vmax);
+ vo1x0 = vminq_f32(vo1x0, vmax);
+ vo0x1 = vminq_f32(vo0x1, vmax);
+ vo1x1 = vminq_f32(vo1x1, vmax);
+
+ iw += 1;
+ if XNN_LIKELY(c >= 4) {
+ vst1q_f32(o1, vo1x0);
+ o1 = (float*) ((uintptr_t) o1 + output_width_stride);
+ vst1q_f32(o0, vo0x0);
+ o0 = (float*) ((uintptr_t) o0 + output_width_stride);
+
+ if (iw & 4) {
+ vst1q_f32(o1, vo1x1);
+ o1 = (float*) ((uintptr_t) o1 + output_width_stride);
+ vst1q_f32(o0, vo0x1);
+ o0 = (float*) ((uintptr_t) o0 + output_width_stride);
+ }
+ } else {
+ float* o0_tmp = o0;
+ float* o1_tmp = o1;
+ float32x2_t vo0x0_lo = vget_low_f32(vo0x0);
+ float32x2_t vo1x0_lo = vget_low_f32(vo1x0);
+ float32x2_t vo0x1_lo = vget_low_f32(vo0x1);
+ float32x2_t vo1x1_lo = vget_low_f32(vo1x1);
+ if (c & 2) {
+ if (iw & 4) {
+ vst1_f32((float*) ((uintptr_t) o1_tmp + output_width_stride), vo1x1_lo);
+ vo1x1_lo = vget_high_f32(vo1x1);
+ vst1_f32((float*) ((uintptr_t) o0_tmp + output_width_stride), vo0x1_lo);
+ vo0x1_lo = vget_high_f32(vo0x1);
+ }
+
+ vst1_f32(o1_tmp, vo1x0_lo); o1_tmp += 2;
+ vo1x0_lo = vget_high_f32(vo1x0);
+ vst1_f32(o0_tmp, vo0x0_lo); o0_tmp += 2;
+ vo0x0_lo = vget_high_f32(vo0x0);
+ }
+ if (c & 1) {
+ vst1_lane_f32(o1_tmp, vo1x0_lo, 0);
+ vst1_lane_f32(o0_tmp, vo0x0_lo, 0);
+
+ if (iw & 4) {
+ vst1_lane_f32((float*) ((uintptr_t) o1_tmp + output_width_stride), vo1x1_lo, 0);
+ vst1_lane_f32((float*) ((uintptr_t) o0_tmp + output_width_stride), vo0x1_lo, 0);
+ }
+ }
+ o0 = (float*) ((uintptr_t) o0 + output_width_stride * 2);
+ o1 = (float*) ((uintptr_t) o1 + output_width_stride * 2);
+ }
+ }
+ // Move output pointers back to the position of the first pixel in a row,
+ // and forward to the next block of output channels
+ o0 = (float*) ((uintptr_t) o0 + output_channel_increment);
+ o1 = (float*) ((uintptr_t) o1 + output_channel_increment);
+ // Revert input pointers to the position of the first pixel in a row
+ i0 = (const float*) ((uintptr_t) i0 - input_width_increment);
+ i1 = (const float*) ((uintptr_t) i1 - input_width_increment);
+ i2 = (const float*) ((uintptr_t) i2 - input_width_increment);
+ i3 = (const float*) ((uintptr_t) i3 - input_width_increment);
+ i4 = (const float*) ((uintptr_t) i4 - input_width_increment);
+ // Move to the block of weights for the next 4 output channels
+ w += 112;
+ c = doz(c, 4);
+ } while (c != 0);
+ // Move output pointers forward to the next two rows
+ output0 = (float*) ((uintptr_t) output1 + output_height_stride);
+ output1 = (float*) ((uintptr_t) output0 + output_height_stride);
+ // Move input pointers forward to the next four rows
+ i0 = i4;
+ i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
+ i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
+ i3 = (const float*) ((uintptr_t) i2 + input_height_stride);
+ i4 = (const float*) ((uintptr_t) i3 + input_height_stride);
+ }
+}
diff --git a/src/f32-conv-hwc/3x3s2p1c3x8-neonfma-2x2.c b/src/f32-conv-hwc/3x3s2p1c3x8-neonfma-2x2.c
new file mode 100644
index 0000000..98fa7d8
--- /dev/null
+++ b/src/f32-conv-hwc/3x3s2p1c3x8-neonfma-2x2.c
@@ -0,0 +1,1025 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/conv.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_conv_hwc_ukernel_3x3s2p1c3x8__neonfma_2x2(
+ size_t input_height,
+ size_t input_width,
+ size_t output_y_start,
+ size_t output_y_end,
+ const float* input,
+ const float* zero,
+ const float* weights,
+ float* output,
+ size_t input_padding_top,
+ size_t output_channels,
+ size_t output_height_stride,
+ size_t output_width_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(input_width != 0);
+ assert(output_y_end > output_y_start);
+ assert(input_padding_top <= 1);
+ assert(output_channels != 0);
+
+ const size_t input_height_stride = input_width * 3 /* channels */ * sizeof(float);
+ const size_t input_width_increment = round_down_po2(input_width, 4) * 3 /* channels */ * sizeof(float);
+ const size_t output_width = (input_width + 1) / 2;
+ const size_t output_channel_increment = 8 * sizeof(float) - output_width * output_width_stride;
+
+ // Adjustment for padding processed below
+ const float* i0 = (const float*) ((uintptr_t) input + input_height_stride * (output_y_start * 2 - input_padding_top));
+ const float* i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
+ const float* i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
+ const float* i3 = (const float*) ((uintptr_t) i2 + input_height_stride);
+ const float* i4 = (const float*) ((uintptr_t) i3 + input_height_stride);
+ float* output0 = (float*) ((uintptr_t) output + output_height_stride * output_y_start);
+ float* output1 = (float*) ((uintptr_t) output0 + output_height_stride);
+
+ if XNN_UNPREDICTABLE(output_y_start < input_padding_top) {
+ i0 = zero;
+ }
+
+ const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min);
+ const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max);
+
+ for (size_t output_y = output_y_start; output_y < output_y_end; output_y += 2) {
+ const size_t input_y2 = output_y * 2 + 2 - input_padding_top;
+ const size_t input_y4 = input_y2 + 2;
+ if XNN_UNPREDICTABLE(input_y2 >= input_height) {
+ i2 = zero;
+ }
+ if XNN_UNPREDICTABLE(input_y4 > input_height) {
+ i3 = zero;
+ }
+ if XNN_UNPREDICTABLE(input_y4 >= input_height) {
+ i4 = zero;
+ }
+ if XNN_UNPREDICTABLE(output_y + 2 > output_y_end) {
+ output1 = output0;
+ }
+
+ const float* w = weights;
+ size_t c = output_channels;
+ float* o0 = output0;
+ float* o1 = output1;
+ do {
+ // viMx0 = ( iM0c2, iM0c1, iM0c0, --- )
+ float32x4_t vi0x0 = vmovq_n_f32(0.0f);
+ float32x4_t vi1x0 = vmovq_n_f32(0.0f);
+ float32x4_t vi2x0 = vmovq_n_f32(0.0f);
+ float32x4_t vi3x0 = vmovq_n_f32(0.0f);
+ float32x4_t vi4x0 = vmovq_n_f32(0.0f);
+
+ size_t iw = input_width;
+ for (; iw >= 4; iw -= 4) {
+ float32x4_t vo0x0c0123 = vld1q_f32(w);
+ float32x4_t vo0x0c4567 = vld1q_f32(w + 4);
+ float32x4_t vo1x0c0123 = vo0x0c0123;
+ float32x4_t vo1x0c4567 = vo0x0c4567;
+ float32x4_t vo0x1c0123 = vo0x0c0123;
+ float32x4_t vo0x1c4567 = vo0x0c4567;
+ float32x4_t vo1x1c0123 = vo0x0c0123;
+ float32x4_t vo1x1c4567 = vo0x0c4567;
+
+ const float32x4_t vk00c0x0123 = vld1q_f32(w + 8);
+ const float32x4_t vk00c0x4567 = vld1q_f32(w + 12);
+
+ // viMx1 = ( iM2c0, iM1c2, iM1c1, iM1c0 )
+ const float32x4_t vi0x1 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1x1 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2x1 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3x1 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4x1 = vld1q_f32(i4); i4 += 4;
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk00c0x0123, vi0x0, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk00c0x0123, vi2x0, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk00c0x4567, vi0x0, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk00c0x4567, vi2x0, 1);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk00c0x0123, vi0x1, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk00c0x0123, vi2x1, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk00c0x4567, vi0x1, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk00c0x4567, vi2x1, 3);
+
+ const float32x4_t vk10c0x0123 = vld1q_f32(w + 16);
+ const float32x4_t vk10c0x4567 = vld1q_f32(w + 20);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk10c0x0123, vi1x0, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk10c0x0123, vi3x0, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk10c0x4567, vi1x0, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk10c0x4567, vi3x0, 1);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk10c0x0123, vi1x1, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk10c0x0123, vi3x1, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk10c0x4567, vi1x1, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk10c0x4567, vi3x1, 3);
+
+ const float32x4_t vk20c0x0123 = vld1q_f32(w + 24);
+ const float32x4_t vk20c0x4567 = vld1q_f32(w + 28);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk20c0x0123, vi2x0, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk20c0x0123, vi4x0, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk20c0x4567, vi2x0, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk20c0x4567, vi4x0, 1);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk20c0x0123, vi2x1, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk20c0x0123, vi4x1, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk20c0x4567, vi2x1, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk20c0x4567, vi4x1, 3);
+
+ const float32x4_t vk00c1x0123 = vld1q_f32(w + 32);
+ const float32x4_t vk00c1x4567 = vld1q_f32(w + 36);
+
+ // viMx2 = ( iM3c1, iM3c0, iM2c2, iM2c1 )
+ const float32x4_t vi0x2 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1x2 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2x2 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3x2 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4x2 = vld1q_f32(i4); i4 += 4;
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk00c1x0123, vi0x0, 2);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk00c1x0123, vi2x0, 2);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk00c1x4567, vi0x0, 2);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk00c1x4567, vi2x0, 2);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk00c1x0123, vi0x2, 0);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk00c1x0123, vi2x2, 0);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk00c1x4567, vi0x2, 0);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk00c1x4567, vi2x2, 0);
+
+ const float32x4_t vk10c1x0123 = vld1q_f32(w + 40);
+ const float32x4_t vk10c1x4567 = vld1q_f32(w + 44);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk10c1x0123, vi1x0, 2);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk10c1x0123, vi3x0, 2);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk10c1x4567, vi1x0, 2);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk10c1x4567, vi3x0, 2);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk10c1x0123, vi1x2, 0);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk10c1x0123, vi3x2, 0);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk10c1x4567, vi1x2, 0);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk10c1x4567, vi3x2, 0);
+
+ const float32x4_t vk20c1x0123 = vld1q_f32(w + 48);
+ const float32x4_t vk20c1x4567 = vld1q_f32(w + 52);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk20c1x0123, vi2x0, 2);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk20c1x0123, vi4x0, 2);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk20c1x4567, vi2x0, 2);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk20c1x4567, vi4x0, 2);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk20c1x0123, vi2x2, 0);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk20c1x0123, vi4x2, 0);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk20c1x4567, vi2x2, 0);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk20c1x4567, vi4x2, 0);
+
+ const float32x4_t vk00c2x0123 = vld1q_f32(w + 56);
+ const float32x4_t vk00c2x4567 = vld1q_f32(w + 60);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk00c2x0123, vi0x0, 3);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk00c2x0123, vi2x0, 3);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk00c2x4567, vi0x0, 3);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk00c2x4567, vi2x0, 3);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk00c2x0123, vi0x2, 1);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk00c2x0123, vi2x2, 1);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk00c2x4567, vi0x2, 1);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk00c2x4567, vi2x2, 1);
+
+ const float32x4_t vk10c2x0123 = vld1q_f32(w + 64);
+ const float32x4_t vk10c2x4567 = vld1q_f32(w + 68);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk10c2x0123, vi1x0, 3);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk10c2x0123, vi3x0, 3);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk10c2x4567, vi1x0, 3);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk10c2x4567, vi3x0, 3);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk10c2x0123, vi1x2, 1);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk10c2x0123, vi3x2, 1);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk10c2x4567, vi1x2, 1);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk10c2x4567, vi3x2, 1);
+
+ const float32x4_t vk20c2x0123 = vld1q_f32(w + 72);
+ const float32x4_t vk20c2x4567 = vld1q_f32(w + 76);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk20c2x0123, vi2x0, 3);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk20c2x0123, vi4x0, 3);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk20c2x4567, vi2x0, 3);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk20c2x4567, vi4x0, 3);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk20c2x0123, vi2x2, 1);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk20c2x0123, vi4x2, 1);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk20c2x4567, vi2x2, 1);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk20c2x4567, vi4x2, 1);
+
+ const float32x4_t vk01c0x0123 = vld1q_f32(w + 80);
+ const float32x4_t vk01c0x4567 = vld1q_f32(w + 84);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk01c0x0123, vi0x1, 0);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk01c0x0123, vi2x1, 0);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk01c0x4567, vi0x1, 0);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk01c0x4567, vi2x1, 0);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk01c0x0123, vi0x2, 2);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk01c0x0123, vi2x2, 2);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk01c0x4567, vi0x2, 2);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk01c0x4567, vi2x2, 2);
+
+ const float32x4_t vk11c0x0123 = vld1q_f32(w + 88);
+ const float32x4_t vk11c0x4567 = vld1q_f32(w + 92);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk11c0x0123, vi1x1, 0);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk11c0x0123, vi3x1, 0);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk11c0x4567, vi1x1, 0);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk11c0x4567, vi3x1, 0);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk11c0x0123, vi1x2, 2);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk11c0x0123, vi3x2, 2);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk11c0x4567, vi1x2, 2);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk11c0x4567, vi3x2, 2);
+
+ const float32x4_t vk21c0x0123 = vld1q_f32(w + 96);
+ const float32x4_t vk21c0x4567 = vld1q_f32(w + 100);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk21c0x0123, vi2x1, 0);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk21c0x0123, vi4x1, 0);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk21c0x4567, vi2x1, 0);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk21c0x4567, vi4x1, 0);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk21c0x0123, vi2x2, 2);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk21c0x0123, vi4x2, 2);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk21c0x4567, vi2x2, 2);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk21c0x4567, vi4x2, 2);
+
+ const float32x4_t vk01c1x0123 = vld1q_f32(w + 104);
+ const float32x4_t vk01c1x4567 = vld1q_f32(w + 108);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk01c1x0123, vi0x1, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk01c1x0123, vi2x1, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk01c1x4567, vi0x1, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk01c1x4567, vi2x1, 1);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk01c1x0123, vi0x2, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk01c1x0123, vi2x2, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk01c1x4567, vi0x2, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk01c1x4567, vi2x2, 3);
+
+ const float32x4_t vk11c1x0123 = vld1q_f32(w + 112);
+ const float32x4_t vk11c1x4567 = vld1q_f32(w + 116);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk11c1x0123, vi1x1, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk11c1x0123, vi3x1, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk11c1x4567, vi1x1, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk11c1x4567, vi3x1, 1);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk11c1x0123, vi1x2, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk11c1x0123, vi3x2, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk11c1x4567, vi1x2, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk11c1x4567, vi3x2, 3);
+
+ const float32x4_t vk21c1x0123 = vld1q_f32(w + 120);
+ const float32x4_t vk21c1x4567 = vld1q_f32(w + 124);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk21c1x0123, vi2x1, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk21c1x0123, vi4x1, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk21c1x4567, vi2x1, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk21c1x4567, vi4x1, 1);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk21c1x0123, vi2x2, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk21c1x0123, vi4x2, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk21c1x4567, vi2x2, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk21c1x4567, vi4x2, 3);
+
+ const float32x4_t vk01c2x0123 = vld1q_f32(w + 128);
+ const float32x4_t vk01c2x4567 = vld1q_f32(w + 132);
+
+ // viMx3 = ( iM4c2, iM4c1, iM4c0, iM3c2 )
+ const float32x4_t vi0x3 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1x3 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2x3 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3x3 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4x3 = vld1q_f32(i4); i4 += 4;
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk01c2x0123, vi0x1, 2);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk01c2x0123, vi2x1, 2);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk01c2x4567, vi0x1, 2);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk01c2x4567, vi2x1, 2);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk01c2x0123, vi0x3, 0);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk01c2x0123, vi2x3, 0);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk01c2x4567, vi0x3, 0);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk01c2x4567, vi2x3, 0);
+
+ const float32x4_t vk11c2x0123 = vld1q_f32(w + 136);
+ const float32x4_t vk11c2x4567 = vld1q_f32(w + 140);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk11c2x0123, vi1x1, 2);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk11c2x0123, vi3x1, 2);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk11c2x4567, vi1x1, 2);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk11c2x4567, vi3x1, 2);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk11c2x0123, vi1x3, 0);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk11c2x0123, vi3x3, 0);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk11c2x4567, vi1x3, 0);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk11c2x4567, vi3x3, 0);
+
+ const float32x4_t vk21c2x0123 = vld1q_f32(w + 144);
+ const float32x4_t vk21c2x4567 = vld1q_f32(w + 148);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk21c2x0123, vi2x1, 2);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk21c2x0123, vi4x1, 2);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk21c2x4567, vi2x1, 2);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk21c2x4567, vi4x1, 2);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk21c2x0123, vi2x3, 0);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk21c2x0123, vi4x3, 0);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk21c2x4567, vi2x3, 0);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk21c2x4567, vi4x3, 0);
+
+ const float32x4_t vk02c0x0123 = vld1q_f32(w + 152);
+ const float32x4_t vk02c0x4567 = vld1q_f32(w + 156);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk02c0x0123, vi0x1, 3);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk02c0x0123, vi2x1, 3);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk02c0x4567, vi0x1, 3);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk02c0x4567, vi2x1, 3);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk02c0x0123, vi0x3, 1);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk02c0x0123, vi2x3, 1);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk02c0x4567, vi0x3, 1);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk02c0x4567, vi2x3, 1);
+
+ const float32x4_t vk12c0x0123 = vld1q_f32(w + 160);
+ const float32x4_t vk12c0x4567 = vld1q_f32(w + 164);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk12c0x0123, vi1x1, 3);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk12c0x0123, vi3x1, 3);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk12c0x4567, vi1x1, 3);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk12c0x4567, vi3x1, 3);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk12c0x0123, vi1x3, 1);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk12c0x0123, vi3x3, 1);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk12c0x4567, vi1x3, 1);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk12c0x4567, vi3x3, 1);
+
+ const float32x4_t vk22c0x0123 = vld1q_f32(w + 168);
+ const float32x4_t vk22c0x4567 = vld1q_f32(w + 172);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk22c0x0123, vi2x1, 3);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk22c0x0123, vi4x1, 3);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk22c0x4567, vi2x1, 3);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk22c0x4567, vi4x1, 3);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk22c0x0123, vi2x3, 1);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk22c0x0123, vi4x3, 1);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk22c0x4567, vi2x3, 1);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk22c0x4567, vi4x3, 1);
+
+ const float32x4_t vk02c1x0123 = vld1q_f32(w + 176);
+ const float32x4_t vk02c1x4567 = vld1q_f32(w + 180);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk02c1x0123, vi0x2, 0);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk02c1x0123, vi2x2, 0);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk02c1x4567, vi0x2, 0);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk02c1x4567, vi2x2, 0);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk02c1x0123, vi0x3, 2);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk02c1x0123, vi2x3, 2);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk02c1x4567, vi0x3, 2);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk02c1x4567, vi2x3, 2);
+
+ const float32x4_t vk12c1x0123 = vld1q_f32(w + 184);
+ const float32x4_t vk12c1x4567 = vld1q_f32(w + 188);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk12c1x0123, vi1x2, 0);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk12c1x0123, vi3x2, 0);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk12c1x4567, vi1x2, 0);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk12c1x4567, vi3x2, 0);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk12c1x0123, vi1x3, 2);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk12c1x0123, vi3x3, 2);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk12c1x4567, vi1x3, 2);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk12c1x4567, vi3x3, 2);
+
+ const float32x4_t vk22c1x0123 = vld1q_f32(w + 192);
+ const float32x4_t vk22c1x4567 = vld1q_f32(w + 196);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk22c1x0123, vi2x2, 0);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk22c1x0123, vi4x2, 0);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk22c1x4567, vi2x2, 0);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk22c1x4567, vi4x2, 0);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk22c1x0123, vi2x3, 2);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk22c1x0123, vi4x3, 2);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk22c1x4567, vi2x3, 2);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk22c1x4567, vi4x3, 2);
+
+ const float32x4_t vk02c2x0123 = vld1q_f32(w + 200);
+ const float32x4_t vk02c2x4567 = vld1q_f32(w + 204);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk02c2x0123, vi0x2, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk02c2x0123, vi2x2, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk02c2x4567, vi0x2, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk02c2x4567, vi2x2, 1);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk02c2x0123, vi0x3, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk02c2x0123, vi2x3, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk02c2x4567, vi0x3, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk02c2x4567, vi2x3, 3);
+
+ const float32x4_t vk12c2x0123 = vld1q_f32(w + 208);
+ const float32x4_t vk12c2x4567 = vld1q_f32(w + 212);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk12c2x0123, vi1x2, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk12c2x0123, vi3x2, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk12c2x4567, vi1x2, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk12c2x4567, vi3x2, 1);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk12c2x0123, vi1x3, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk12c2x0123, vi3x3, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk12c2x4567, vi1x3, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk12c2x4567, vi3x3, 3);
+
+ const float32x4_t vk22c2x0123 = vld1q_f32(w + 216);
+ const float32x4_t vk22c2x4567 = vld1q_f32(w + 220);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk22c2x0123, vi2x2, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk22c2x0123, vi4x2, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk22c2x4567, vi2x2, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk22c2x4567, vi4x2, 1);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk22c2x0123, vi2x3, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk22c2x0123, vi4x3, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk22c2x4567, vi2x3, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk22c2x4567, vi4x3, 3);
+
+ vi0x0 = vi0x3;
+ vi1x0 = vi1x3;
+ vi2x0 = vi2x3;
+ vi3x0 = vi3x3;
+ vi4x0 = vi4x3;
+
+ vo0x0c0123 = vmaxq_f32(vo0x0c0123, vmin);
+ vo1x0c0123 = vmaxq_f32(vo1x0c0123, vmin);
+ vo0x0c4567 = vmaxq_f32(vo0x0c4567, vmin);
+ vo1x0c4567 = vmaxq_f32(vo1x0c4567, vmin);
+
+ vo0x1c0123 = vmaxq_f32(vo0x1c0123, vmin);
+ vo1x1c0123 = vmaxq_f32(vo1x1c0123, vmin);
+ vo0x1c4567 = vmaxq_f32(vo0x1c4567, vmin);
+ vo1x1c4567 = vmaxq_f32(vo1x1c4567, vmin);
+
+ vo0x0c0123 = vminq_f32(vo0x0c0123, vmax);
+ vo1x0c0123 = vminq_f32(vo1x0c0123, vmax);
+ vo0x0c4567 = vminq_f32(vo0x0c4567, vmax);
+ vo1x0c4567 = vminq_f32(vo1x0c4567, vmax);
+
+ vo0x1c0123 = vminq_f32(vo0x1c0123, vmax);
+ vo1x1c0123 = vminq_f32(vo1x1c0123, vmax);
+ vo0x1c4567 = vminq_f32(vo0x1c4567, vmax);
+ vo1x1c4567 = vminq_f32(vo1x1c4567, vmax);
+
+ if XNN_LIKELY(c >= 8) {
+ vst1q_f32(o1, vo1x0c0123);
+ vst1q_f32(o1 + 4, vo1x0c4567);
+ o1 = (float*) ((uintptr_t) o1 + output_width_stride);
+ vst1q_f32(o0, vo0x0c0123);
+ vst1q_f32(o0 + 4, vo0x0c4567);
+ o0 = (float*) ((uintptr_t) o0 + output_width_stride);
+
+ vst1q_f32(o1, vo1x1c0123);
+ vst1q_f32(o1 + 4, vo1x1c4567);
+ o1 = (float*) ((uintptr_t) o1 + output_width_stride);
+ vst1q_f32(o0, vo0x1c0123);
+ vst1q_f32(o0 + 4, vo0x1c4567);
+ o0 = (float*) ((uintptr_t) o0 + output_width_stride);
+ } else {
+ float* o0_tmp = o0;
+ float* o1_tmp = o1;
+ if (c & 4) {
+ vst1q_f32((float*) ((uintptr_t) o1_tmp + output_width_stride), vo1x1c0123);
+ vo1x1c0123 = vo1x1c4567;
+ vst1q_f32((float*) ((uintptr_t) o0_tmp + output_width_stride), vo0x1c0123);
+ vo0x1c0123 = vo0x1c4567;
+
+ vst1q_f32(o1_tmp, vo1x0c0123); o1_tmp += 4;
+ vo1x0c0123 = vo1x0c4567;
+ vst1q_f32(o0_tmp, vo0x0c0123); o0_tmp += 4;
+ vo0x0c0123 = vo0x0c4567;
+ }
+ float32x2_t vo0x0c01 = vget_low_f32(vo0x0c0123);
+ float32x2_t vo1x0c01 = vget_low_f32(vo1x0c0123);
+ float32x2_t vo0x1c01 = vget_low_f32(vo0x1c0123);
+ float32x2_t vo1x1c01 = vget_low_f32(vo1x1c0123);
+ if (c & 2) {
+ vst1_f32((float*) ((uintptr_t) o1_tmp + output_width_stride), vo1x1c01);
+ vo1x1c01 = vget_high_f32(vo1x1c0123);
+ vst1_f32((float*) ((uintptr_t) o0_tmp + output_width_stride), vo0x1c01);
+ vo0x1c01 = vget_high_f32(vo0x1c0123);
+
+ vst1_f32(o1_tmp, vo1x0c01); o1_tmp += 2;
+ vo1x0c01 = vget_high_f32(vo1x0c0123);
+ vst1_f32(o0_tmp, vo0x0c01); o0_tmp += 2;
+ vo0x0c01 = vget_high_f32(vo0x0c0123);
+ }
+ if (c & 1) {
+ vst1_lane_f32(o1_tmp, vo1x0c01, 0);
+ vst1_lane_f32(o0_tmp, vo0x0c01, 0);
+
+ vst1_lane_f32((float*) ((uintptr_t) o1_tmp + output_width_stride), vo1x1c01, 0);
+ vst1_lane_f32((float*) ((uintptr_t) o0_tmp + output_width_stride), vo0x1c01, 0);
+ }
+ o0 = (float*) ((uintptr_t) o0 + output_width_stride * 2);
+ o1 = (float*) ((uintptr_t) o1 + output_width_stride * 2);
+ }
+ }
+ assert(iw < 4);
+ if XNN_UNLIKELY(iw != 0) {
+ float32x4_t vo0x0c0123 = vld1q_f32(w);
+ float32x4_t vo0x0c4567 = vld1q_f32(w + 4);
+ float32x4_t vo1x0c0123 = vo0x0c0123;
+ float32x4_t vo1x0c4567 = vo0x0c4567;
+ float32x4_t vo0x1c0123 = vo0x0c0123;
+ float32x4_t vo0x1c4567 = vo0x0c4567;
+ float32x4_t vo1x1c0123 = vo0x0c0123;
+ float32x4_t vo1x1c4567 = vo0x0c4567;
+
+ const float32x4_t vk00c0x0123 = vld1q_f32(w + 8);
+ const float32x4_t vk00c0x4567 = vld1q_f32(w + 12);
+
+ // viMx1 = ( iM2c0, iM1c2, iM1c1, iM1c0 )
+ float32x4_t vi0x1 = vld1q_f32(i0);
+ float32x4_t vi1x1 = vld1q_f32(i1);
+ float32x4_t vi2x1 = vld1q_f32(i2);
+ float32x4_t vi3x1 = vld1q_f32(i3);
+ float32x4_t vi4x1 = vld1q_f32(i4);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk00c0x0123, vi0x0, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk00c0x0123, vi2x0, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk00c0x4567, vi0x0, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk00c0x4567, vi2x0, 1);
+
+ if (iw > 2) {
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk00c0x0123, vi0x1, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk00c0x0123, vi2x1, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk00c0x4567, vi0x1, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk00c0x4567, vi2x1, 3);
+ }
+
+ const float32x4_t vk10c0x0123 = vld1q_f32(w + 16);
+ const float32x4_t vk10c0x4567 = vld1q_f32(w + 20);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk10c0x0123, vi1x0, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk10c0x0123, vi3x0, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk10c0x4567, vi1x0, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk10c0x4567, vi3x0, 1);
+
+ if (iw > 2) {
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk10c0x0123, vi1x1, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk10c0x0123, vi3x1, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk10c0x4567, vi1x1, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk10c0x4567, vi3x1, 3);
+ }
+
+ const float32x4_t vk20c0x0123 = vld1q_f32(w + 24);
+ const float32x4_t vk20c0x4567 = vld1q_f32(w + 28);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk20c0x0123, vi2x0, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk20c0x0123, vi4x0, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk20c0x4567, vi2x0, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk20c0x4567, vi4x0, 1);
+
+ if (iw > 2) {
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk20c0x0123, vi2x1, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk20c0x0123, vi4x1, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk20c0x4567, vi2x1, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk20c0x4567, vi4x1, 3);
+ }
+
+ const float32x4_t vk00c1x0123 = vld1q_f32(w + 32);
+ const float32x4_t vk00c1x4567 = vld1q_f32(w + 36);
+
+ float32x4_t vi0x2 = vmovq_n_f32(0.0f);
+ float32x4_t vi1x2 = vmovq_n_f32(0.0f);
+ float32x4_t vi2x2 = vmovq_n_f32(0.0f);
+ float32x4_t vi3x2 = vmovq_n_f32(0.0f);
+ float32x4_t vi4x2 = vmovq_n_f32(0.0f);
+ if (iw >= 2) {
+ // viMx2 = ( iM3c1, iM3c0, iM2c2, iM2c1 )
+ vi0x2 = vld1q_f32(i0 + 4);
+ vi1x2 = vld1q_f32(i1 + 4);
+ vi2x2 = vld1q_f32(i2 + 4);
+ vi3x2 = vld1q_f32(i3 + 4);
+ vi4x2 = vld1q_f32(i4 + 4);
+ }
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk00c1x0123, vi0x0, 2);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk00c1x0123, vi2x0, 2);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk00c1x4567, vi0x0, 2);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk00c1x4567, vi2x0, 2);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk00c1x0123, vi0x2, 0);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk00c1x0123, vi2x2, 0);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk00c1x4567, vi0x2, 0);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk00c1x4567, vi2x2, 0);
+
+ const float32x4_t vk10c1x0123 = vld1q_f32(w + 40);
+ const float32x4_t vk10c1x4567 = vld1q_f32(w + 44);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk10c1x0123, vi1x0, 2);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk10c1x0123, vi3x0, 2);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk10c1x4567, vi1x0, 2);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk10c1x4567, vi3x0, 2);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk10c1x0123, vi1x2, 0);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk10c1x0123, vi3x2, 0);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk10c1x4567, vi1x2, 0);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk10c1x4567, vi3x2, 0);
+
+ const float32x4_t vk20c1x0123 = vld1q_f32(w + 48);
+ const float32x4_t vk20c1x4567 = vld1q_f32(w + 52);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk20c1x0123, vi2x0, 2);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk20c1x0123, vi4x0, 2);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk20c1x4567, vi2x0, 2);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk20c1x4567, vi4x0, 2);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk20c1x0123, vi2x2, 0);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk20c1x0123, vi4x2, 0);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk20c1x4567, vi2x2, 0);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk20c1x4567, vi4x2, 0);
+
+ const float32x4_t vk00c2x0123 = vld1q_f32(w + 56);
+ const float32x4_t vk00c2x4567 = vld1q_f32(w + 60);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk00c2x0123, vi0x0, 3);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk00c2x0123, vi2x0, 3);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk00c2x4567, vi0x0, 3);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk00c2x4567, vi2x0, 3);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk00c2x0123, vi0x2, 1);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk00c2x0123, vi2x2, 1);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk00c2x4567, vi0x2, 1);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk00c2x4567, vi2x2, 1);
+
+ const float32x4_t vk10c2x0123 = vld1q_f32(w + 64);
+ const float32x4_t vk10c2x4567 = vld1q_f32(w + 68);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk10c2x0123, vi1x0, 3);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk10c2x0123, vi3x0, 3);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk10c2x4567, vi1x0, 3);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk10c2x4567, vi3x0, 3);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk10c2x0123, vi1x2, 1);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk10c2x0123, vi3x2, 1);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk10c2x4567, vi1x2, 1);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk10c2x4567, vi3x2, 1);
+
+ const float32x4_t vk20c2x0123 = vld1q_f32(w + 72);
+ const float32x4_t vk20c2x4567 = vld1q_f32(w + 76);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk20c2x0123, vi2x0, 3);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk20c2x0123, vi4x0, 3);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk20c2x4567, vi2x0, 3);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk20c2x4567, vi4x0, 3);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk20c2x0123, vi2x2, 1);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk20c2x0123, vi4x2, 1);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk20c2x4567, vi2x2, 1);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk20c2x4567, vi4x2, 1);
+
+ const float32x4_t vk01c0x0123 = vld1q_f32(w + 80);
+ const float32x4_t vk01c0x4567 = vld1q_f32(w + 84);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk01c0x0123, vi0x1, 0);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk01c0x0123, vi2x1, 0);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk01c0x4567, vi0x1, 0);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk01c0x4567, vi2x1, 0);
+
+ if (iw > 2) {
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk01c0x0123, vi0x2, 2);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk01c0x0123, vi2x2, 2);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk01c0x4567, vi0x2, 2);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk01c0x4567, vi2x2, 2);
+ }
+
+ const float32x4_t vk11c0x0123 = vld1q_f32(w + 88);
+ const float32x4_t vk11c0x4567 = vld1q_f32(w + 92);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk11c0x0123, vi1x1, 0);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk11c0x0123, vi3x1, 0);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk11c0x4567, vi1x1, 0);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk11c0x4567, vi3x1, 0);
+
+ if (iw > 2) {
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk11c0x0123, vi1x2, 2);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk11c0x0123, vi3x2, 2);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk11c0x4567, vi1x2, 2);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk11c0x4567, vi3x2, 2);
+ }
+
+ const float32x4_t vk21c0x0123 = vld1q_f32(w + 96);
+ const float32x4_t vk21c0x4567 = vld1q_f32(w + 100);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk21c0x0123, vi2x1, 0);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk21c0x0123, vi4x1, 0);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk21c0x4567, vi2x1, 0);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk21c0x4567, vi4x1, 0);
+
+ if (iw > 2) {
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk21c0x0123, vi2x2, 2);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk21c0x0123, vi4x2, 2);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk21c0x4567, vi2x2, 2);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk21c0x4567, vi4x2, 2);
+ }
+
+ const float32x4_t vk01c1x0123 = vld1q_f32(w + 104);
+ const float32x4_t vk01c1x4567 = vld1q_f32(w + 108);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk01c1x0123, vi0x1, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk01c1x0123, vi2x1, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk01c1x4567, vi0x1, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk01c1x4567, vi2x1, 1);
+
+ if (iw > 2) {
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk01c1x0123, vi0x2, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk01c1x0123, vi2x2, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk01c1x4567, vi0x2, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk01c1x4567, vi2x2, 3);
+ }
+
+ const float32x4_t vk11c1x0123 = vld1q_f32(w + 112);
+ const float32x4_t vk11c1x4567 = vld1q_f32(w + 116);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk11c1x0123, vi1x1, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk11c1x0123, vi3x1, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk11c1x4567, vi1x1, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk11c1x4567, vi3x1, 1);
+
+ if (iw > 2) {
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk11c1x0123, vi1x2, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk11c1x0123, vi3x2, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk11c1x4567, vi1x2, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk11c1x4567, vi3x2, 3);
+ }
+
+ const float32x4_t vk21c1x0123 = vld1q_f32(w + 120);
+ const float32x4_t vk21c1x4567 = vld1q_f32(w + 124);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk21c1x0123, vi2x1, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk21c1x0123, vi4x1, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk21c1x4567, vi2x1, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk21c1x4567, vi4x1, 1);
+
+ if (iw > 2) {
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk21c1x0123, vi2x2, 3);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk21c1x0123, vi4x2, 3);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk21c1x4567, vi2x2, 3);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk21c1x4567, vi4x2, 3);
+ }
+
+ const float32x4_t vk01c2x0123 = vld1q_f32(w + 128);
+ const float32x4_t vk01c2x4567 = vld1q_f32(w + 132);
+
+ float32x4_t vi0x3 = vmovq_n_f32(0.0f);
+ float32x4_t vi1x3 = vmovq_n_f32(0.0f);
+ float32x4_t vi2x3 = vmovq_n_f32(0.0f);
+ float32x4_t vi3x3 = vmovq_n_f32(0.0f);
+ float32x4_t vi4x3 = vmovq_n_f32(0.0f);
+ if (iw > 2) {
+ // viMx3 = ( 0.0, 0.0, 0.0, iM3c2 )
+ vi0x3 = vld1q_lane_f32(i0 + 8, vi0x3, 0);
+ vi1x3 = vld1q_lane_f32(i1 + 8, vi1x3, 0);
+ vi2x3 = vld1q_lane_f32(i2 + 8, vi2x3, 0);
+ vi3x3 = vld1q_lane_f32(i3 + 8, vi3x3, 0);
+ vi4x3 = vld1q_lane_f32(i4 + 8, vi4x3, 0);
+ }
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk01c2x0123, vi0x1, 2);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk01c2x0123, vi2x1, 2);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk01c2x4567, vi0x1, 2);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk01c2x4567, vi2x1, 2);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk01c2x0123, vi0x3, 0);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk01c2x0123, vi2x3, 0);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk01c2x4567, vi0x3, 0);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk01c2x4567, vi2x3, 0);
+
+ const float32x4_t vk11c2x0123 = vld1q_f32(w + 136);
+ const float32x4_t vk11c2x4567 = vld1q_f32(w + 140);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk11c2x0123, vi1x1, 2);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk11c2x0123, vi3x1, 2);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk11c2x4567, vi1x1, 2);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk11c2x4567, vi3x1, 2);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk11c2x0123, vi1x3, 0);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk11c2x0123, vi3x3, 0);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk11c2x4567, vi1x3, 0);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk11c2x4567, vi3x3, 0);
+
+ const float32x4_t vk21c2x0123 = vld1q_f32(w + 144);
+ const float32x4_t vk21c2x4567 = vld1q_f32(w + 148);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk21c2x0123, vi2x1, 2);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk21c2x0123, vi4x1, 2);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk21c2x4567, vi2x1, 2);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk21c2x4567, vi4x1, 2);
+
+ vo0x1c0123 = vfmaq_laneq_f32(vo0x1c0123, vk21c2x0123, vi2x3, 0);
+ vo1x1c0123 = vfmaq_laneq_f32(vo1x1c0123, vk21c2x0123, vi4x3, 0);
+ vo0x1c4567 = vfmaq_laneq_f32(vo0x1c4567, vk21c2x4567, vi2x3, 0);
+ vo1x1c4567 = vfmaq_laneq_f32(vo1x1c4567, vk21c2x4567, vi4x3, 0);
+
+ if (iw >= 2) {
+ const float32x4_t vk02c0x0123 = vld1q_f32(w + 152);
+ const float32x4_t vk02c0x4567 = vld1q_f32(w + 156);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk02c0x0123, vi0x1, 3);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk02c0x0123, vi2x1, 3);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk02c0x4567, vi0x1, 3);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk02c0x4567, vi2x1, 3);
+
+ const float32x4_t vk12c0x0123 = vld1q_f32(w + 160);
+ const float32x4_t vk12c0x4567 = vld1q_f32(w + 164);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk12c0x0123, vi1x1, 3);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk12c0x0123, vi3x1, 3);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk12c0x4567, vi1x1, 3);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk12c0x4567, vi3x1, 3);
+
+ const float32x4_t vk22c0x0123 = vld1q_f32(w + 168);
+ const float32x4_t vk22c0x4567 = vld1q_f32(w + 172);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk22c0x0123, vi2x1, 3);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk22c0x0123, vi4x1, 3);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk22c0x4567, vi2x1, 3);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk22c0x4567, vi4x1, 3);
+
+ const float32x4_t vk02c1x0123 = vld1q_f32(w + 176);
+ const float32x4_t vk02c1x4567 = vld1q_f32(w + 180);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk02c1x0123, vi0x2, 0);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk02c1x0123, vi2x2, 0);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk02c1x4567, vi0x2, 0);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk02c1x4567, vi2x2, 0);
+
+ const float32x4_t vk12c1x0123 = vld1q_f32(w + 184);
+ const float32x4_t vk12c1x4567 = vld1q_f32(w + 188);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk12c1x0123, vi1x2, 0);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk12c1x0123, vi3x2, 0);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk12c1x4567, vi1x2, 0);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk12c1x4567, vi3x2, 0);
+
+ const float32x4_t vk22c1x0123 = vld1q_f32(w + 192);
+ const float32x4_t vk22c1x4567 = vld1q_f32(w + 196);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk22c1x0123, vi2x2, 0);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk22c1x0123, vi4x2, 0);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk22c1x4567, vi2x2, 0);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk22c1x4567, vi4x2, 0);
+
+ const float32x4_t vk02c2x0123 = vld1q_f32(w + 200);
+ const float32x4_t vk02c2x4567 = vld1q_f32(w + 204);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk02c2x0123, vi0x2, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk02c2x0123, vi2x2, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk02c2x4567, vi0x2, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk02c2x4567, vi2x2, 1);
+
+ const float32x4_t vk12c2x0123 = vld1q_f32(w + 208);
+ const float32x4_t vk12c2x4567 = vld1q_f32(w + 212);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk12c2x0123, vi1x2, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk12c2x0123, vi3x2, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk12c2x4567, vi1x2, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk12c2x4567, vi3x2, 1);
+
+ const float32x4_t vk22c2x0123 = vld1q_f32(w + 216);
+ const float32x4_t vk22c2x4567 = vld1q_f32(w + 220);
+
+ vo0x0c0123 = vfmaq_laneq_f32(vo0x0c0123, vk22c2x0123, vi2x2, 1);
+ vo1x0c0123 = vfmaq_laneq_f32(vo1x0c0123, vk22c2x0123, vi4x2, 1);
+ vo0x0c4567 = vfmaq_laneq_f32(vo0x0c4567, vk22c2x4567, vi2x2, 1);
+ vo1x0c4567 = vfmaq_laneq_f32(vo1x0c4567, vk22c2x4567, vi4x2, 1);
+ }
+
+ vo0x0c0123 = vmaxq_f32(vo0x0c0123, vmin);
+ vo1x0c0123 = vmaxq_f32(vo1x0c0123, vmin);
+ vo0x0c4567 = vmaxq_f32(vo0x0c4567, vmin);
+ vo1x0c4567 = vmaxq_f32(vo1x0c4567, vmin);
+
+ vo0x1c0123 = vmaxq_f32(vo0x1c0123, vmin);
+ vo1x1c0123 = vmaxq_f32(vo1x1c0123, vmin);
+ vo0x1c4567 = vmaxq_f32(vo0x1c4567, vmin);
+ vo1x1c4567 = vmaxq_f32(vo1x1c4567, vmin);
+
+ vo0x0c0123 = vminq_f32(vo0x0c0123, vmax);
+ vo1x0c0123 = vminq_f32(vo1x0c0123, vmax);
+ vo0x0c4567 = vminq_f32(vo0x0c4567, vmax);
+ vo1x0c4567 = vminq_f32(vo1x0c4567, vmax);
+
+ vo0x1c0123 = vminq_f32(vo0x1c0123, vmax);
+ vo1x1c0123 = vminq_f32(vo1x1c0123, vmax);
+ vo0x1c4567 = vminq_f32(vo0x1c4567, vmax);
+ vo1x1c4567 = vminq_f32(vo1x1c4567, vmax);
+
+ iw += 1;
+ if XNN_LIKELY(c >= 8) {
+ vst1q_f32(o1, vo1x0c0123);
+ vst1q_f32(o1 + 4, vo1x0c4567);
+ o1 = (float*) ((uintptr_t) o1 + output_width_stride);
+ vst1q_f32(o0, vo0x0c0123);
+ vst1q_f32(o0 + 4, vo0x0c4567);
+ o0 = (float*) ((uintptr_t) o0 + output_width_stride);
+
+ if (iw & 4) {
+ vst1q_f32(o1, vo1x1c0123);
+ vst1q_f32(o1 + 4, vo1x1c4567);
+ o1 = (float*) ((uintptr_t) o1 + output_width_stride);
+ vst1q_f32(o0, vo0x1c0123);
+ vst1q_f32(o0 + 4, vo0x1c4567);
+ o0 = (float*) ((uintptr_t) o0 + output_width_stride);
+ }
+ } else {
+ float* o0_tmp = o0;
+ float* o1_tmp = o1;
+ if (c & 4) {
+ if (iw & 4) {
+ vst1q_f32((float*) ((uintptr_t) o1_tmp + output_width_stride), vo1x1c0123);
+ vo1x1c0123 = vo1x1c4567;
+ vst1q_f32((float*) ((uintptr_t) o0_tmp + output_width_stride), vo0x1c0123);
+ vo0x1c0123 = vo0x1c4567;
+ }
+
+ vst1q_f32(o1_tmp, vo1x0c0123); o1_tmp += 4;
+ vo1x0c0123 = vo1x0c4567;
+ vst1q_f32(o0_tmp, vo0x0c0123); o0_tmp += 4;
+ vo0x0c0123 = vo0x0c4567;
+ }
+ float32x2_t vo0x0c01 = vget_low_f32(vo0x0c0123);
+ float32x2_t vo1x0c01 = vget_low_f32(vo1x0c0123);
+ float32x2_t vo0x1c01 = vget_low_f32(vo0x1c0123);
+ float32x2_t vo1x1c01 = vget_low_f32(vo1x1c0123);
+ if (c & 2) {
+ if (iw & 4) {
+ vst1_f32((float*) ((uintptr_t) o1_tmp + output_width_stride), vo1x1c01);
+ vo1x1c01 = vget_high_f32(vo1x1c0123);
+ vst1_f32((float*) ((uintptr_t) o0_tmp + output_width_stride), vo0x1c01);
+ vo0x1c01 = vget_high_f32(vo0x1c0123);
+ }
+
+ vst1_f32(o1_tmp, vo1x0c01); o1_tmp += 2;
+ vo1x0c01 = vget_high_f32(vo1x0c0123);
+ vst1_f32(o0_tmp, vo0x0c01); o0_tmp += 2;
+ vo0x0c01 = vget_high_f32(vo0x0c0123);
+ }
+ if (c & 1) {
+ vst1_lane_f32(o1_tmp, vo1x0c01, 0);
+ vst1_lane_f32(o0_tmp, vo0x0c01, 0);
+
+ if (iw & 4) {
+ vst1_lane_f32((float*) ((uintptr_t) o1_tmp + output_width_stride), vo1x1c01, 0);
+ vst1_lane_f32((float*) ((uintptr_t) o0_tmp + output_width_stride), vo0x1c01, 0);
+ }
+ }
+ o0 = (float*) ((uintptr_t) o0 + output_width_stride * 2);
+ o1 = (float*) ((uintptr_t) o1 + output_width_stride * 2);
+ }
+ }
+ // Move output pointers back to the position of the first pixel in a row,
+ // and forward to the next block of output channels
+ o0 = (float*) ((uintptr_t) o0 + output_channel_increment);
+ o1 = (float*) ((uintptr_t) o1 + output_channel_increment);
+ // Revert input pointers to the position of the first pixel in a row
+ i0 = (const float*) ((uintptr_t) i0 - input_width_increment);
+ i1 = (const float*) ((uintptr_t) i1 - input_width_increment);
+ i2 = (const float*) ((uintptr_t) i2 - input_width_increment);
+ i3 = (const float*) ((uintptr_t) i3 - input_width_increment);
+ i4 = (const float*) ((uintptr_t) i4 - input_width_increment);
+ // Move to the block of weights for the next 8 output channels
+ w += 224;
+ c = doz(c, 8);
+ } while (c != 0);
+ // Move output pointers forward to the next two rows
+ output0 = (float*) ((uintptr_t) output1 + output_height_stride);
+ output1 = (float*) ((uintptr_t) output0 + output_height_stride);
+ // Move input pointers forward to the next four rows
+ i0 = i4;
+ i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
+ i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
+ i3 = (const float*) ((uintptr_t) i2 + input_height_stride);
+ i4 = (const float*) ((uintptr_t) i3 + input_height_stride);
+ }
+}
diff --git a/src/f32-conv-hwc2spchw/3x3s2p1c3x4-neonfma-2x2.c b/src/f32-conv-hwc2spchw/3x3s2p1c3x4-neonfma-2x2.c
new file mode 100644
index 0000000..d9eeafc
--- /dev/null
+++ b/src/f32-conv-hwc2spchw/3x3s2p1c3x4-neonfma-2x2.c
@@ -0,0 +1,654 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/conv.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_conv_hwc2spchw_ukernel_3x3s2p1c3x4__neonfma_2x2(
+ size_t input_height,
+ size_t input_width,
+ size_t output_y_start,
+ size_t output_y_end,
+ const float* input,
+ const float* zero,
+ const float* weights,
+ float* output,
+ size_t input_padding_top,
+ size_t output_channels,
+ size_t output_height_stride,
+ size_t output_channel_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(input_width != 0);
+ assert(output_y_end > output_y_start);
+ assert(input_padding_top <= 1);
+ assert(output_channels != 0);
+
+ const size_t input_height_stride = input_width * 3 /* channels */ * sizeof(float);
+ const size_t input_width_increment = round_down_po2(input_width, 4) * 3 /* channels */ * sizeof(float);
+ const size_t output_width = (input_width + 1) / 2;
+ const size_t output_channel_increment = output_channel_stride * 4 - output_width * sizeof(float);
+
+ // Adjustment for padding processed below
+ const float* i0 = (const float*) ((uintptr_t) input + input_height_stride * (output_y_start * 2 - input_padding_top));
+ const float* i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
+ const float* i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
+ const float* i3 = (const float*) ((uintptr_t) i2 + input_height_stride);
+ const float* i4 = (const float*) ((uintptr_t) i3 + input_height_stride);
+ float* output0 = (float*) ((uintptr_t) output + output_height_stride * output_y_start);
+ float* output1 = (float*) ((uintptr_t) output0 + output_height_stride);
+
+ if XNN_UNPREDICTABLE(output_y_start < input_padding_top) {
+ i0 = zero;
+ }
+
+ const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min);
+ const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max);
+
+ for (size_t output_y = output_y_start; output_y < output_y_end; output_y += 2) {
+ const size_t input_y2 = output_y * 2 + 2 - input_padding_top;
+ const size_t input_y4 = input_y2 + 2;
+ if XNN_UNPREDICTABLE(input_y2 >= input_height) {
+ i2 = zero;
+ }
+ if XNN_UNPREDICTABLE(input_y4 > input_height) {
+ i3 = zero;
+ }
+ if XNN_UNPREDICTABLE(input_y4 >= input_height) {
+ i4 = zero;
+ }
+ if XNN_UNPREDICTABLE(output_y + 2 > output_y_end) {
+ output1 = output0;
+ }
+
+ const float* w = weights;
+ size_t c = output_channels;
+ float* o0c0 = output0;
+ float* o1c0 = output1;
+ float* o0c1 = (float*) ((uintptr_t) o0c0 + output_channel_stride);
+ float* o1c1 = (float*) ((uintptr_t) o1c0 + output_channel_stride);
+ float* o0c2 = (float*) ((uintptr_t) o0c1 + output_channel_stride);
+ float* o1c2 = (float*) ((uintptr_t) o1c1 + output_channel_stride);
+ float* o0c3 = (float*) ((uintptr_t) o0c2 + output_channel_stride);
+ float* o1c3 = (float*) ((uintptr_t) o1c2 + output_channel_stride);
+ do {
+ if XNN_UNPREDICTABLE(c < 2) {
+ o0c1 = o0c0;
+ o1c1 = o1c0;
+ }
+ if XNN_UNPREDICTABLE(c <= 2) {
+ o0c2 = o0c1;
+ o1c2 = o1c1;
+ }
+ if XNN_UNPREDICTABLE(c < 4) {
+ o0c3 = o0c2;
+ o1c3 = o1c2;
+ }
+
+ // viMx0 = ( iM0c2, iM0c1, iM0c0, --- )
+ float32x4_t vi0x0 = vmovq_n_f32(0.0f);
+ float32x4_t vi1x0 = vmovq_n_f32(0.0f);
+ float32x4_t vi2x0 = vmovq_n_f32(0.0f);
+ float32x4_t vi3x0 = vmovq_n_f32(0.0f);
+ float32x4_t vi4x0 = vmovq_n_f32(0.0f);
+
+ size_t iw = input_width;
+ for (; iw >= 4; iw -= 4) {
+ float32x4_t vo0x0 = vld1q_f32(w);
+ float32x4_t vo1x0 = vo0x0;
+ float32x4_t vo0x1 = vo0x0;
+ float32x4_t vo1x1 = vo0x0;
+
+ const float32x4_t vk00c0 = vld1q_f32(w + 4);
+
+ // viMx1 = ( iM2c0, iM1c2, iM1c1, iM1c0 )
+ const float32x4_t vi0x1 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1x1 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2x1 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3x1 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4x1 = vld1q_f32(i4); i4 += 4;
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c0, vi0x0, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c0, vi2x0, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c0, vi0x1, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c0, vi2x1, 3);
+
+ const float32x4_t vk10c0 = vld1q_f32(w + 8);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c0, vi1x0, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c0, vi3x0, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c0, vi1x1, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c0, vi3x1, 3);
+
+ const float32x4_t vk20c0 = vld1q_f32(w + 12);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c0, vi2x0, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c0, vi4x0, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c0, vi2x1, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c0, vi4x1, 3);
+
+ const float32x4_t vk00c1 = vld1q_f32(w + 16);
+
+ // viMx2 = ( iM3c1, iM3c0, iM2c2, iM2c1 )
+ const float32x4_t vi0x2 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1x2 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2x2 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3x2 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4x2 = vld1q_f32(i4); i4 += 4;
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c1, vi0x0, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c1, vi2x0, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c1, vi0x2, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c1, vi2x2, 0);
+
+ const float32x4_t vk10c1 = vld1q_f32(w + 20);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c1, vi1x0, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c1, vi3x0, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c1, vi1x2, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c1, vi3x2, 0);
+
+ const float32x4_t vk20c1 = vld1q_f32(w + 24);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c1, vi2x0, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c1, vi4x0, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c1, vi2x2, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c1, vi4x2, 0);
+
+ const float32x4_t vk00c2 = vld1q_f32(w + 28);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c2, vi0x0, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c2, vi2x0, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c2, vi0x2, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c2, vi2x2, 1);
+
+ const float32x4_t vk10c2 = vld1q_f32(w + 32);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c2, vi1x0, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c2, vi3x0, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c2, vi1x2, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c2, vi3x2, 1);
+
+ const float32x4_t vk20c2 = vld1q_f32(w + 36);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c2, vi2x0, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c2, vi4x0, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c2, vi2x2, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c2, vi4x2, 1);
+
+ const float32x4_t vk01c0 = vld1q_f32(w + 40);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c0, vi0x1, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c0, vi2x1, 0);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c0, vi0x2, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c0, vi2x2, 2);
+
+ const float32x4_t vk11c0 = vld1q_f32(w + 44);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c0, vi1x1, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c0, vi3x1, 0);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c0, vi1x2, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c0, vi3x2, 2);
+
+ const float32x4_t vk21c0 = vld1q_f32(w + 48);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c0, vi2x1, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c0, vi4x1, 0);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c0, vi2x2, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c0, vi4x2, 2);
+
+ const float32x4_t vk01c1 = vld1q_f32(w + 52);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c1, vi0x1, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c1, vi2x1, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c1, vi0x2, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c1, vi2x2, 3);
+
+ const float32x4_t vk11c1 = vld1q_f32(w + 56);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c1, vi1x1, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c1, vi3x1, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c1, vi1x2, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c1, vi3x2, 3);
+
+ const float32x4_t vk21c1 = vld1q_f32(w + 60);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c1, vi2x1, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c1, vi4x1, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c1, vi2x2, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c1, vi4x2, 3);
+
+ const float32x4_t vk01c2 = vld1q_f32(w + 64);
+
+ // viMx3 = ( iM4c2, iM4c1, iM4c0, iM3c2 )
+ const float32x4_t vi0x3 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1x3 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2x3 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3x3 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4x3 = vld1q_f32(i4); i4 += 4;
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c2, vi0x1, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c2, vi2x1, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c2, vi0x3, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c2, vi2x3, 0);
+
+ const float32x4_t vk11c2 = vld1q_f32(w + 68);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c2, vi1x1, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c2, vi3x1, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c2, vi1x3, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c2, vi3x3, 0);
+
+ const float32x4_t vk21c2 = vld1q_f32(w + 72);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c2, vi2x1, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c2, vi4x1, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c2, vi2x3, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c2, vi4x3, 0);
+
+ const float32x4_t vk02c0 = vld1q_f32(w + 76);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c0, vi0x1, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c0, vi2x1, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk02c0, vi0x3, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk02c0, vi2x3, 1);
+
+ const float32x4_t vk12c0 = vld1q_f32(w + 80);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c0, vi1x1, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c0, vi3x1, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk12c0, vi1x3, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk12c0, vi3x3, 1);
+
+ const float32x4_t vk22c0 = vld1q_f32(w + 84);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c0, vi2x1, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c0, vi4x1, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk22c0, vi2x3, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk22c0, vi4x3, 1);
+
+ const float32x4_t vk02c1 = vld1q_f32(w + 88);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c1, vi0x2, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c1, vi2x2, 0);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk02c1, vi0x3, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk02c1, vi2x3, 2);
+
+ const float32x4_t vk12c1 = vld1q_f32(w + 92);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c1, vi1x2, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c1, vi3x2, 0);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk12c1, vi1x3, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk12c1, vi3x3, 2);
+
+ const float32x4_t vk22c1 = vld1q_f32(w + 96);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c1, vi2x2, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c1, vi4x2, 0);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk22c1, vi2x3, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk22c1, vi4x3, 2);
+
+ const float32x4_t vk02c2 = vld1q_f32(w + 100);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c2, vi0x2, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c2, vi2x2, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk02c2, vi0x3, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk02c2, vi2x3, 3);
+
+ const float32x4_t vk12c2 = vld1q_f32(w + 104);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c2, vi1x2, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c2, vi3x2, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk12c2, vi1x3, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk12c2, vi3x3, 3);
+
+ const float32x4_t vk22c2 = vld1q_f32(w + 108);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c2, vi2x2, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c2, vi4x2, 1);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk22c2, vi2x3, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk22c2, vi4x3, 3);
+
+ vi0x0 = vi0x3;
+ vi1x0 = vi1x3;
+ vi2x0 = vi2x3;
+ vi3x0 = vi3x3;
+ vi4x0 = vi4x3;
+
+ vo0x0 = vmaxq_f32(vo0x0, vmin);
+ vo1x0 = vmaxq_f32(vo1x0, vmin);
+ vo0x1 = vmaxq_f32(vo0x1, vmin);
+ vo1x1 = vmaxq_f32(vo1x1, vmin);
+
+ vo0x0 = vminq_f32(vo0x0, vmax);
+ vo1x0 = vminq_f32(vo1x0, vmax);
+ vo0x1 = vminq_f32(vo0x1, vmax);
+ vo1x1 = vminq_f32(vo1x1, vmax);
+
+ const float32x4_t vo0c01 = vzip1q_f32(vo0x0, vo0x1);
+ const float32x4_t vo0c23 = vzip2q_f32(vo0x0, vo0x1);
+ const float32x4_t vo1c01 = vzip1q_f32(vo1x0, vo1x1);
+ const float32x4_t vo1c23 = vzip2q_f32(vo1x0, vo1x1);
+
+ // Always 2+ output width elements remaining
+ vst1_f32(o1c0, vget_low_f32(vo1c01)); o1c0 += 2;
+ vst1_f32(o1c1, vget_high_f32(vo1c01)); o1c1 += 2;
+ vst1_f32(o1c2, vget_low_f32(vo1c23)); o1c2 += 2;
+ vst1_f32(o1c3, vget_high_f32(vo1c23)); o1c3 += 2;
+
+ vst1_f32(o0c0, vget_low_f32(vo0c01)); o0c0 += 2;
+ vst1_f32(o0c1, vget_high_f32(vo0c01)); o0c1 += 2;
+ vst1_f32(o0c2, vget_low_f32(vo0c23)); o0c2 += 2;
+ vst1_f32(o0c3, vget_high_f32(vo0c23)); o0c3 += 2;
+ }
+ assert(iw < 4);
+ if XNN_UNLIKELY(iw != 0) {
+ float32x4_t vo0x0 = vld1q_f32(w);
+ float32x4_t vo1x0 = vo0x0;
+ float32x4_t vo0x1 = vo0x0;
+ float32x4_t vo1x1 = vo0x0;
+
+ const float32x4_t vk00c0 = vld1q_f32(w + 4);
+
+ // viMx1 = ( iM2c0, iM1c2, iM1c1, iM1c0 )
+ float32x4_t vi0x1 = vld1q_f32(i0);
+ float32x4_t vi1x1 = vld1q_f32(i1);
+ float32x4_t vi2x1 = vld1q_f32(i2);
+ float32x4_t vi3x1 = vld1q_f32(i3);
+ float32x4_t vi4x1 = vld1q_f32(i4);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c0, vi0x0, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c0, vi2x0, 1);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c0, vi0x1, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c0, vi2x1, 3);
+ }
+
+ const float32x4_t vk10c0 = vld1q_f32(w + 8);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c0, vi1x0, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c0, vi3x0, 1);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c0, vi1x1, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c0, vi3x1, 3);
+ }
+
+ const float32x4_t vk20c0 = vld1q_f32(w + 12);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c0, vi2x0, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c0, vi4x0, 1);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c0, vi2x1, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c0, vi4x1, 3);
+ }
+
+ const float32x4_t vk00c1 = vld1q_f32(w + 16);
+
+ float32x4_t vi0x2 = vmovq_n_f32(0.0f);
+ float32x4_t vi1x2 = vmovq_n_f32(0.0f);
+ float32x4_t vi2x2 = vmovq_n_f32(0.0f);
+ float32x4_t vi3x2 = vmovq_n_f32(0.0f);
+ float32x4_t vi4x2 = vmovq_n_f32(0.0f);
+ if (iw >= 2) {
+ // viMx2 = ( iM3c1, iM3c0, iM2c2, iM2c1 )
+ vi0x2 = vld1q_f32(i0 + 4);
+ vi1x2 = vld1q_f32(i1 + 4);
+ vi2x2 = vld1q_f32(i2 + 4);
+ vi3x2 = vld1q_f32(i3 + 4);
+ vi4x2 = vld1q_f32(i4 + 4);
+ }
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c1, vi0x0, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c1, vi2x0, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c1, vi0x2, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c1, vi2x2, 0);
+
+ const float32x4_t vk10c1 = vld1q_f32(w + 20);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c1, vi1x0, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c1, vi3x0, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c1, vi1x2, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c1, vi3x2, 0);
+
+ const float32x4_t vk20c1 = vld1q_f32(w + 24);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c1, vi2x0, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c1, vi4x0, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c1, vi2x2, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c1, vi4x2, 0);
+
+ const float32x4_t vk00c2 = vld1q_f32(w + 28);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk00c2, vi0x0, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk00c2, vi2x0, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk00c2, vi0x2, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk00c2, vi2x2, 1);
+
+ const float32x4_t vk10c2 = vld1q_f32(w + 32);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk10c2, vi1x0, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk10c2, vi3x0, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk10c2, vi1x2, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk10c2, vi3x2, 1);
+
+ const float32x4_t vk20c2 = vld1q_f32(w + 36);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk20c2, vi2x0, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk20c2, vi4x0, 3);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk20c2, vi2x2, 1);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk20c2, vi4x2, 1);
+
+ const float32x4_t vk01c0 = vld1q_f32(w + 40);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c0, vi0x1, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c0, vi2x1, 0);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c0, vi0x2, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c0, vi2x2, 2);
+ }
+
+ const float32x4_t vk11c0 = vld1q_f32(w + 44);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c0, vi1x1, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c0, vi3x1, 0);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c0, vi1x2, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c0, vi3x2, 2);
+ }
+
+ const float32x4_t vk21c0 = vld1q_f32(w + 48);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c0, vi2x1, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c0, vi4x1, 0);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c0, vi2x2, 2);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c0, vi4x2, 2);
+ }
+
+ const float32x4_t vk01c1 = vld1q_f32(w + 52);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c1, vi0x1, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c1, vi2x1, 1);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c1, vi0x2, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c1, vi2x2, 3);
+ }
+
+ const float32x4_t vk11c1 = vld1q_f32(w + 56);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c1, vi1x1, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c1, vi3x1, 1);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c1, vi1x2, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c1, vi3x2, 3);
+ }
+
+ const float32x4_t vk21c1 = vld1q_f32(w + 60);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c1, vi2x1, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c1, vi4x1, 1);
+ if (iw > 2) {
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c1, vi2x2, 3);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c1, vi4x2, 3);
+ }
+
+ const float32x4_t vk01c2 = vld1q_f32(w + 64);
+
+ float32x4_t vi0x3 = vmovq_n_f32(0.0f);
+ float32x4_t vi1x3 = vmovq_n_f32(0.0f);
+ float32x4_t vi2x3 = vmovq_n_f32(0.0f);
+ float32x4_t vi3x3 = vmovq_n_f32(0.0f);
+ float32x4_t vi4x3 = vmovq_n_f32(0.0f);
+ if (iw > 2) {
+ // viMx3 = ( 0.0, 0.0, 0.0, iM3c2 )
+ vi0x3 = vld1q_lane_f32(i0 + 8, vi0x3, 0);
+ vi1x3 = vld1q_lane_f32(i1 + 8, vi1x3, 0);
+ vi2x3 = vld1q_lane_f32(i2 + 8, vi2x3, 0);
+ vi3x3 = vld1q_lane_f32(i3 + 8, vi3x3, 0);
+ vi4x3 = vld1q_lane_f32(i4 + 8, vi4x3, 0);
+ }
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk01c2, vi0x1, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk01c2, vi2x1, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk01c2, vi0x3, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk01c2, vi2x3, 0);
+
+ const float32x4_t vk11c2 = vld1q_f32(w + 68);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk11c2, vi1x1, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk11c2, vi3x1, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk11c2, vi1x3, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk11c2, vi3x3, 0);
+
+ const float32x4_t vk21c2 = vld1q_f32(w + 72);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk21c2, vi2x1, 2);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk21c2, vi4x1, 2);
+ vo0x1 = vfmaq_laneq_f32(vo0x1, vk21c2, vi2x3, 0);
+ vo1x1 = vfmaq_laneq_f32(vo1x1, vk21c2, vi4x3, 0);
+
+ if (iw >= 2) {
+ const float32x4_t vk02c0 = vld1q_f32(w + 76);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c0, vi0x1, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c0, vi2x1, 3);
+
+ const float32x4_t vk12c0 = vld1q_f32(w + 80);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c0, vi1x1, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c0, vi3x1, 3);
+
+ const float32x4_t vk22c0 = vld1q_f32(w + 84);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c0, vi2x1, 3);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c0, vi4x1, 3);
+
+ const float32x4_t vk02c1 = vld1q_f32(w + 88);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c1, vi0x2, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c1, vi2x2, 0);
+
+ const float32x4_t vk12c1 = vld1q_f32(w + 92);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c1, vi1x2, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c1, vi3x2, 0);
+
+ const float32x4_t vk22c1 = vld1q_f32(w + 96);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c1, vi2x2, 0);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c1, vi4x2, 0);
+
+ const float32x4_t vk02c2 = vld1q_f32(w + 100);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk02c2, vi0x2, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk02c2, vi2x2, 1);
+
+ const float32x4_t vk12c2 = vld1q_f32(w + 104);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk12c2, vi1x2, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk12c2, vi3x2, 1);
+
+ const float32x4_t vk22c2 = vld1q_f32(w + 108);
+
+ vo0x0 = vfmaq_laneq_f32(vo0x0, vk22c2, vi2x2, 1);
+ vo1x0 = vfmaq_laneq_f32(vo1x0, vk22c2, vi4x2, 1);
+ }
+
+ vo0x0 = vmaxq_f32(vo0x0, vmin);
+ vo1x0 = vmaxq_f32(vo1x0, vmin);
+ vo0x1 = vmaxq_f32(vo0x1, vmin);
+ vo1x1 = vmaxq_f32(vo1x1, vmin);
+
+ vo0x0 = vminq_f32(vo0x0, vmax);
+ vo1x0 = vminq_f32(vo1x0, vmax);
+ vo0x1 = vminq_f32(vo0x1, vmax);
+ vo1x1 = vminq_f32(vo1x1, vmax);
+
+ if (iw == 3) {
+ // Exactly 2 output width elements remaining
+ const float32x4_t vo0c01 = vzip1q_f32(vo0x0, vo0x1);
+ const float32x4_t vo0c23 = vzip2q_f32(vo0x0, vo0x1);
+ const float32x4_t vo1c01 = vzip1q_f32(vo1x0, vo1x1);
+ const float32x4_t vo1c23 = vzip2q_f32(vo1x0, vo1x1);
+
+ vst1_f32(o1c0, vget_low_f32(vo1c01)); o1c0 += 2;
+ vst1_f32(o1c1, vget_high_f32(vo1c01)); o1c1 += 2;
+ vst1_f32(o1c2, vget_low_f32(vo1c23)); o1c2 += 2;
+ vst1_f32(o1c3, vget_high_f32(vo1c23)); o1c3 += 2;
+
+ vst1_f32(o0c0, vget_low_f32(vo0c01)); o0c0 += 2;
+ vst1_f32(o0c1, vget_high_f32(vo0c01)); o0c1 += 2;
+ vst1_f32(o0c2, vget_low_f32(vo0c23)); o0c2 += 2;
+ vst1_f32(o0c3, vget_high_f32(vo0c23)); o0c3 += 2;
+ } else {
+ // Exactly 1 output width element remaining
+
+ vst1q_lane_f32(o1c0, vo1x0, 0); o1c0 += 1;
+ vst1q_lane_f32(o1c1, vo1x0, 1); o1c1 += 1;
+ vst1q_lane_f32(o1c2, vo1x0, 2); o1c2 += 1;
+ vst1q_lane_f32(o1c3, vo1x0, 3); o1c3 += 1;
+
+ vst1q_lane_f32(o0c0, vo0x0, 0); o0c0 += 1;
+ vst1q_lane_f32(o0c1, vo0x0, 1); o0c1 += 1;
+ vst1q_lane_f32(o0c2, vo0x0, 2); o0c2 += 1;
+ vst1q_lane_f32(o0c3, vo0x0, 3); o0c3 += 1;
+ }
+ }
+ // Move output pointers back to the position of the first pixel in a row,
+ // and forward to the next block of output channels.
+ o0c0 = (float*) ((uintptr_t) o0c0 + output_channel_increment);
+ o0c1 = (float*) ((uintptr_t) o0c1 + output_channel_increment);
+ o0c2 = (float*) ((uintptr_t) o0c2 + output_channel_increment);
+ o0c3 = (float*) ((uintptr_t) o0c3 + output_channel_increment);
+ o1c0 = (float*) ((uintptr_t) o1c0 + output_channel_increment);
+ o1c1 = (float*) ((uintptr_t) o1c1 + output_channel_increment);
+ o1c2 = (float*) ((uintptr_t) o1c2 + output_channel_increment);
+ o1c3 = (float*) ((uintptr_t) o1c3 + output_channel_increment);
+ // Revert input pointers to the position of the first pixel in a row
+ i0 = (const float*) ((uintptr_t) i0 - input_width_increment);
+ i1 = (const float*) ((uintptr_t) i1 - input_width_increment);
+ i2 = (const float*) ((uintptr_t) i2 - input_width_increment);
+ i3 = (const float*) ((uintptr_t) i3 - input_width_increment);
+ i4 = (const float*) ((uintptr_t) i4 - input_width_increment);
+ // Move to the block of weights for the next 4 output channels
+ w += 112;
+ c = doz(c, 4);
+ } while (c != 0);
+ // Move output pointers forward to the next two rows
+ output0 = (float*) ((uintptr_t) output1 + output_height_stride);
+ output1 = (float*) ((uintptr_t) output0 + output_height_stride);
+ // Move input pointers forward to the next four rows
+ i0 = i4;
+ i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
+ i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
+ i3 = (const float*) ((uintptr_t) i2 + input_height_stride);
+ i4 = (const float*) ((uintptr_t) i3 + input_height_stride);
+ }
+}
diff --git a/src/f32-dwconv-spchw/3x3p1-neonfma.c b/src/f32-dwconv-spchw/3x3p1-neonfma.c
new file mode 100644
index 0000000..3f2c0e7
--- /dev/null
+++ b/src/f32-dwconv-spchw/3x3p1-neonfma.c
@@ -0,0 +1,375 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/dwconv.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_dwconv_spchw_ukernel_3x3p1__neonfma(
+    size_t m,
+    size_t n,
+    const float* input,
+    const float* weights,
+    float* output,
+    size_t input_tuple_stride,
+    size_t output_tuple_stride,
+    size_t input_width_stride,
+    size_t output_width_stride,
+    const union xnn_f32_spchw_params params[restrict static 1])
+{
+  assert(n != 0);
+
+  const uint32x4_t vmask = vld1q_u32(params->neon.mask);  /* zeroes input lanes past the last partial tuple of 1..4 pixels */
+  const float32x4_t vmax = vld1q_dup_f32(&params->neon.max);  /* output clamping bounds */
+  const float32x4_t vmin = vld1q_dup_f32(&params->neon.min);
+
+  const size_t input_width_increment = 3 * input_width_stride - round_up_po2(n, 4) / 4 * input_tuple_stride;  /* rewind to row start, then advance 3 rows */
+  const size_t output_width_increment = 3 * output_width_stride - (n - 1) / 4 * output_tuple_stride;
+  const size_t input_width_increment_single = input_width_stride - round_up_po2(n, 4) / 4 * input_tuple_stride;  /* same, for the 1-row tail loop */
+  const size_t output_width_increment_single = output_width_stride - (n - 1) / 4 * output_tuple_stride;
+
+  /* No vertical padding: all rows read below are valid input rows */
+  const float* i0 = input;
+  const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride);
+  const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride);
+  const float* i3 = (const float*) ((uintptr_t) i2 + input_width_stride);
+  const float* i4 = (const float*) ((uintptr_t) i3 + input_width_stride);
+
+  float* output0 = output;
+  float* output1 = (float*) ((uintptr_t) output0 + output_width_stride);
+  float* output2 = (float*) ((uintptr_t) output1 + output_width_stride);
+
+  const float32x4_t vw0123 = vld1q_f32(weights);  /* ( bias, k00, k01, k02 ) */
+  const float32x4_t vw4567 = vld1q_f32(weights + 4);  /* ( k10, k11, k12, k20 ) */
+  const float32x2_t vw89 = vld1_f32(weights + 8);  /* ( k21, k22 ) */
+
+  while (m >= 3) {  /* main loop: produce 3 output rows per iteration from 5 input rows */
+    float32x4_t vi0x0123 = vmovq_n_f32(0.0f);  /* implicit zero left padding (pad = 1) */
+    float32x4_t vi1x0123 = vmovq_n_f32(0.0f);
+    float32x4_t vi2x0123 = vmovq_n_f32(0.0f);
+    float32x4_t vi3x0123 = vmovq_n_f32(0.0f);
+    float32x4_t vi4x0123 = vmovq_n_f32(0.0f);
+    float32x4_t vi0x4567 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+    float32x4_t vi1x4567 = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+    float32x4_t vi2x4567 = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+    float32x4_t vi3x4567 = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride);
+    float32x4_t vi4x4567 = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride);
+
+    size_t k = n;
+    for (; k > 4; k -= 4) {
+      float32x4_t vo4567p00 = vdupq_laneq_f32(vw0123, 0);  /* start each accumulator at the bias */
+      float32x4_t vo4567p01 = vdupq_laneq_f32(vw0123, 0);
+      float32x4_t vo4567p02 = vdupq_laneq_f32(vw0123, 0);
+
+      const float32x4_t vi0x89AB = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+      const float32x4_t vi1x89AB = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+      const float32x4_t vi2x89AB = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+      const float32x4_t vi3x89AB = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride);
+      const float32x4_t vi4x89AB = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride);
+
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x4567, vw0123, 2);  /* center kernel column: k01, k11, k21 */
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi1x4567, vw4567, 1);
+      vo4567p00 = vfmaq_lane_f32(vo4567p00, vi2x4567, vw89, 0);
+
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x4567, vw0123, 2);
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi2x4567, vw4567, 1);
+      vo4567p01 = vfmaq_lane_f32(vo4567p01, vi3x4567, vw89, 0);
+
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi2x4567, vw0123, 2);
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi3x4567, vw4567, 1);
+      vo4567p02 = vfmaq_lane_f32(vo4567p02, vi4x4567, vw89, 0);
+
+      const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3);  /* shift in one pixel from the left */
+      const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3);
+      const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3);
+      const float32x4_t vi3x3456 = vextq_f32(vi3x0123, vi3x4567, 3);
+      const float32x4_t vi4x3456 = vextq_f32(vi4x0123, vi4x4567, 3);
+
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x3456, vw0123, 1);  /* left kernel column: k00, k10, k20 */
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi1x3456, vw4567, 0);
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x3456, vw4567, 3);
+
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x3456, vw0123, 1);
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi2x3456, vw4567, 0);
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x3456, vw4567, 3);
+
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi2x3456, vw0123, 1);
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi3x3456, vw4567, 0);
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi4x3456, vw4567, 3);
+
+      vi0x0123 = vi0x4567;
+      vi1x0123 = vi1x4567;
+      vi2x0123 = vi2x4567;
+      vi3x0123 = vi3x4567;
+      vi4x0123 = vi4x4567;
+
+      const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vi0x89AB, 1);  /* shift in one pixel from the right */
+      const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vi1x89AB, 1);
+      const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vi2x89AB, 1);
+      const float32x4_t vi3x5678 = vextq_f32(vi3x4567, vi3x89AB, 1);
+      const float32x4_t vi4x5678 = vextq_f32(vi4x4567, vi4x89AB, 1);
+
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x5678, vw0123, 3);  /* right kernel column: k02, k12, k22 */
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi1x5678, vw4567, 2);
+      vo4567p00 = vfmaq_lane_f32(vo4567p00, vi2x5678, vw89, 1);
+
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x5678, vw0123, 3);
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi2x5678, vw4567, 2);
+      vo4567p01 = vfmaq_lane_f32(vo4567p01, vi3x5678, vw89, 1);
+
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi2x5678, vw0123, 3);
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi3x5678, vw4567, 2);
+      vo4567p02 = vfmaq_lane_f32(vo4567p02, vi4x5678, vw89, 1);
+
+      vi0x4567 = vi0x89AB;
+      vi1x4567 = vi1x89AB;
+      vi2x4567 = vi2x89AB;
+      vi3x4567 = vi3x89AB;
+      vi4x4567 = vi4x89AB;
+
+      float32x4_t vo0 = vo4567p00;
+      float32x4_t vo1 = vo4567p01;
+      float32x4_t vo2 = vo4567p02;
+
+      vo0 = vmaxq_f32(vo0, vmin);  /* clamp to [min, max] */
+      vo0 = vminq_f32(vo0, vmax);
+      vo1 = vmaxq_f32(vo1, vmin);
+      vo1 = vminq_f32(vo1, vmax);
+      vo2 = vmaxq_f32(vo2, vmin);
+      vo2 = vminq_f32(vo2, vmax);
+
+      vst1q_f32(output0, vo0); output0 = (float*) ((uintptr_t) output0 + output_tuple_stride);
+      vst1q_f32(output1, vo1); output1 = (float*) ((uintptr_t) output1 + output_tuple_stride);
+      vst1q_f32(output2, vo2); output2 = (float*) ((uintptr_t) output2 + output_tuple_stride);
+    }
+    /* Always process the last block of 1..4 pixels */
+    assert(k >= 1);
+    assert(k <= 4);
+    {
+      float32x4_t vo4567p00 = vdupq_laneq_f32(vw0123, 0);
+      float32x4_t vo4567p01 = vdupq_laneq_f32(vw0123, 0);
+      float32x4_t vo4567p02 = vdupq_laneq_f32(vw0123, 0);
+
+      vi0x4567 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi0x4567)));  /* zero lanes past column n */
+      vi1x4567 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi1x4567)));
+      vi2x4567 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi2x4567)));
+      vi3x4567 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi3x4567)));
+      vi4x4567 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi4x4567)));
+
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x4567, vw0123, 2);
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi1x4567, vw4567, 1);
+      vo4567p00 = vfmaq_lane_f32(vo4567p00, vi2x4567, vw89, 0);
+
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x4567, vw0123, 2);
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi2x4567, vw4567, 1);
+      vo4567p01 = vfmaq_lane_f32(vo4567p01, vi3x4567, vw89, 0);
+
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi2x4567, vw0123, 2);
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi3x4567, vw4567, 1);
+      vo4567p02 = vfmaq_lane_f32(vo4567p02, vi4x4567, vw89, 0);
+
+      const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3);
+      const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3);
+      const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3);
+      const float32x4_t vi3x3456 = vextq_f32(vi3x0123, vi3x4567, 3);
+      const float32x4_t vi4x3456 = vextq_f32(vi4x0123, vi4x4567, 3);
+
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x3456, vw0123, 1);
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi1x3456, vw4567, 0);
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x3456, vw4567, 3);
+
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x3456, vw0123, 1);
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi2x3456, vw4567, 0);
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x3456, vw4567, 3);
+
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi2x3456, vw0123, 1);
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi3x3456, vw4567, 0);
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi4x3456, vw4567, 3);
+
+      const float32x4_t vzero = vmovq_n_f32(0.0f);  /* implicit zero right padding */
+      const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vzero, 1);
+      const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vzero, 1);
+      const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vzero, 1);
+      const float32x4_t vi3x5678 = vextq_f32(vi3x4567, vzero, 1);
+      const float32x4_t vi4x5678 = vextq_f32(vi4x4567, vzero, 1);
+
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x5678, vw0123, 3);
+      vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi1x5678, vw4567, 2);
+      vo4567p00 = vfmaq_lane_f32(vo4567p00, vi2x5678, vw89, 1);
+
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x5678, vw0123, 3);
+      vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi2x5678, vw4567, 2);
+      vo4567p01 = vfmaq_lane_f32(vo4567p01, vi3x5678, vw89, 1);
+
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi2x5678, vw0123, 3);
+      vo4567p02 = vfmaq_laneq_f32(vo4567p02, vi3x5678, vw4567, 2);
+      vo4567p02 = vfmaq_lane_f32(vo4567p02, vi4x5678, vw89, 1);
+
+      float32x4_t vo0 = vo4567p00;
+      float32x4_t vo1 = vo4567p01;
+      float32x4_t vo2 = vo4567p02;
+
+      vo0 = vmaxq_f32(vo0, vmin);
+      vo0 = vminq_f32(vo0, vmax);
+      vo1 = vmaxq_f32(vo1, vmin);
+      vo1 = vminq_f32(vo1, vmax);
+      vo2 = vmaxq_f32(vo2, vmin);
+      vo2 = vminq_f32(vo2, vmax);
+
+      if XNN_LIKELY(k & 4) {  /* k == 4: store the full tuple */
+        vst1q_f32(output0, vo0);
+        vst1q_f32(output1, vo1);
+        vst1q_f32(output2, vo2);
+      } else {  /* k == 1..3: store in 2- and 1-element pieces */
+        float* output0_lo = output0;
+        float* output1_lo = output1;
+        float* output2_lo = output2;
+        float32x2_t vo0_lo = vget_low_f32(vo0);
+        float32x2_t vo1_lo = vget_low_f32(vo1);
+        float32x2_t vo2_lo = vget_low_f32(vo2);
+        if (k & 2) {
+          vst1_f32(output0_lo, vo0_lo); output0_lo += 2;
+          vst1_f32(output1_lo, vo1_lo); output1_lo += 2;
+          vst1_f32(output2_lo, vo2_lo); output2_lo += 2;
+          vo0_lo = vget_high_f32(vo0);
+          vo1_lo = vget_high_f32(vo1);
+          vo2_lo = vget_high_f32(vo2);
+        }
+        if (k & 1) {
+          vst1_lane_f32(output0_lo, vo0_lo, 0);
+          vst1_lane_f32(output1_lo, vo1_lo, 0);
+          vst1_lane_f32(output2_lo, vo2_lo, 0);
+        }
+      }
+    }
+
+    i0 = (const float*) ((uintptr_t) i0 + input_width_increment);
+    i1 = (const float*) ((uintptr_t) i1 + input_width_increment);
+    i2 = (const float*) ((uintptr_t) i2 + input_width_increment);
+    i3 = (const float*) ((uintptr_t) i3 + input_width_increment);
+    i4 = (const float*) ((uintptr_t) i4 + input_width_increment);
+    output0 = (float*) ((uintptr_t) output0 + output_width_increment);
+    output1 = (float*) ((uintptr_t) output1 + output_width_increment);
+    output2 = (float*) ((uintptr_t) output2 + output_width_increment);
+    m -= 3;
+  }
+
+  while (m != 0) {  /* tail: process remaining 1..2 output rows, one at a time */
+    float32x4_t vi0x0123 = vmovq_n_f32(0.0f);
+    float32x4_t vi1x0123 = vmovq_n_f32(0.0f);
+    float32x4_t vi2x0123 = vmovq_n_f32(0.0f);
+    float32x4_t vi0x4567 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+    float32x4_t vi1x4567 = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+    float32x4_t vi2x4567 = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+
+    size_t k = n;
+    for (; k > 4; k -= 4) {
+      float32x4_t vo4567p0 = vdupq_laneq_f32(vw0123, 0);  /* bias carried in partial sum 0 only */
+
+      const float32x4_t vi0x89AB = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+      const float32x4_t vi1x89AB = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+      const float32x4_t vi2x89AB = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+
+      vo4567p0 = vfmaq_laneq_f32(vo4567p0, vi0x4567, vw0123, 2);
+      float32x4_t vo4567p1 = vmulq_laneq_f32(vi1x4567, vw4567, 1);
+      float32x4_t vo4567p2 = vmulq_lane_f32(vi2x4567, vw89, 0);
+
+      const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3);
+      const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3);
+      const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3);
+
+      vo4567p0 = vfmaq_laneq_f32(vo4567p0, vi0x3456, vw0123, 1);
+      vo4567p1 = vfmaq_laneq_f32(vo4567p1, vi1x3456, vw4567, 0);
+      vo4567p2 = vfmaq_laneq_f32(vo4567p2, vi2x3456, vw4567, 3);
+
+      vi0x0123 = vi0x4567;
+      vi1x0123 = vi1x4567;
+      vi2x0123 = vi2x4567;
+
+      const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vi0x89AB, 1);
+      const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vi1x89AB, 1);
+      const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vi2x89AB, 1);
+
+      vo4567p0 = vfmaq_laneq_f32(vo4567p0, vi0x5678, vw0123, 3);
+      vo4567p1 = vfmaq_laneq_f32(vo4567p1, vi1x5678, vw4567, 2);
+      vo4567p2 = vfmaq_lane_f32(vo4567p2, vi2x5678, vw89, 1);
+
+      vi0x4567 = vi0x89AB;
+      vi1x4567 = vi1x89AB;
+      vi2x4567 = vi2x89AB;
+
+      float32x4_t vo = vaddq_f32(vo4567p0, vo4567p1);  /* combine the three per-row partial sums */
+      vo = vaddq_f32(vo, vo4567p2);
+
+      vo = vmaxq_f32(vo, vmin);
+      vo = vminq_f32(vo, vmax);
+
+      vst1q_f32(output0, vo); output0 = (float*) ((uintptr_t) output0 + output_tuple_stride);
+    }
+    /* Always process the last block of 1..4 pixels */
+    assert(k >= 1);
+    assert(k <= 4);
+    {
+      float32x4_t vo4567p0 = vdupq_laneq_f32(vw0123, 0);
+
+      vi0x4567 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi0x4567)));  /* zero lanes past column n */
+      vi1x4567 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi1x4567)));
+      vi2x4567 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi2x4567)));
+
+      vo4567p0 = vfmaq_laneq_f32(vo4567p0, vi0x4567, vw0123, 2);
+      float32x4_t vo4567p1 = vmulq_laneq_f32(vi1x4567, vw4567, 1);
+      float32x4_t vo4567p2 = vmulq_lane_f32(vi2x4567, vw89, 0);
+
+      const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3);
+      const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3);
+      const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3);
+
+      vo4567p0 = vfmaq_laneq_f32(vo4567p0, vi0x3456, vw0123, 1);
+      vo4567p1 = vfmaq_laneq_f32(vo4567p1, vi1x3456, vw4567, 0);
+      vo4567p2 = vfmaq_laneq_f32(vo4567p2, vi2x3456, vw4567, 3);
+
+      const float32x4_t vzero = vmovq_n_f32(0.0f);
+      const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vzero, 1);
+      const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vzero, 1);
+      const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vzero, 1);
+
+      vo4567p0 = vfmaq_laneq_f32(vo4567p0, vi0x5678, vw0123, 3);
+      vo4567p1 = vfmaq_laneq_f32(vo4567p1, vi1x5678, vw4567, 2);
+      vo4567p2 = vfmaq_lane_f32(vo4567p2, vi2x5678, vw89, 1);
+
+      float32x4_t vo = vaddq_f32(vo4567p0, vo4567p1);
+      vo = vaddq_f32(vo, vo4567p2);
+
+      vo = vmaxq_f32(vo, vmin);
+      vo = vminq_f32(vo, vmax);
+
+      if XNN_LIKELY(k & 4) {
+        vst1q_f32(output0, vo);
+      } else {
+        float* output0_lo = output0;
+        float32x2_t vo_lo = vget_low_f32(vo);
+        if (k & 2) {
+          vst1_f32(output0_lo, vo_lo); output0_lo += 2;
+          vo_lo = vget_high_f32(vo);
+        }
+        if (k & 1) {
+          vst1_lane_f32(output0_lo, vo_lo, 0);
+        }
+      }
+    }
+
+    i0 = (const float*) ((uintptr_t) i0 + input_width_increment_single);
+    i1 = (const float*) ((uintptr_t) i1 + input_width_increment_single);
+    i2 = (const float*) ((uintptr_t) i2 + input_width_increment_single);
+    output0 = (float*) ((uintptr_t) output0 + output_width_increment_single);
+    m -= 1;
+  }
+}
diff --git a/src/f32-dwconv-spchw/3x3p1-sse.c b/src/f32-dwconv-spchw/3x3p1-sse.c
new file mode 100644
index 0000000..6507fce
--- /dev/null
+++ b/src/f32-dwconv-spchw/3x3p1-sse.c
@@ -0,0 +1,216 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/dwconv.h>
+#include <xnnpack/math.h>
+
+
+/*
+ * 3x3 depthwise convolution for CHW ("spchw") layout, SSE version.
+ * Processes one channel per call: m output rows of n pixels each, reading
+ * input rows i0..i2 per output row. Per the file name, stride is 1 with one
+ * pixel of implicit horizontal padding: the left edge is covered by the
+ * zero-initialized "x3012" registers, the right edge by masking the last
+ * (partial) tuple of the row.
+ */
+void xnn_f32_dwconv_spchw_ukernel_3x3p1__sse(
+ size_t m,
+ size_t n,
+ const float* input,
+ const float* weights,
+ float* output,
+ size_t input_tuple_stride,
+ size_t output_tuple_stride,
+ size_t input_width_stride,
+ size_t output_width_stride,
+ const union xnn_f32_spchw_params params[restrict static 1])
+{
+ assert(n != 0);
+
+ /* vmask zeroes the lanes of the final tuple that lie past the row end */
+ const __m128 vmask = _mm_load_ps((const float*) params->sse.mask);
+ const __m128 vmax = _mm_load_ps(params->sse.max);
+ const __m128 vmin = _mm_load_ps(params->sse.min);
+
+ /*
+ * Per-row pointer compensation: the loops below advance the input pointers
+ * by ceil(n/4) tuples and the output pointer by (n-1)/4 tuples (the final
+ * partial store does not advance it), so add the difference to reach the
+ * next row.
+ */
+ const size_t input_width_increment = input_width_stride - round_up_po2(n, 4) / 4 * input_tuple_stride;
+ const size_t output_width_increment = output_width_stride - (n - 1) / 4 * output_tuple_stride;
+
+ /* No vertical padding */
+ const float* i0 = input;
+ const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride);
+ const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride);
+
+ /*
+ * Weights layout: bias first, then the 3x3 taps in row-major order
+ * (vkRC = filter row R, column C), each broadcast across all four lanes.
+ */
+ const __m128 vbias = _mm_load1_ps(weights);
+ const __m128 vk00 = _mm_load1_ps(weights + 1);
+ const __m128 vk01 = _mm_load1_ps(weights + 2);
+ const __m128 vk02 = _mm_load1_ps(weights + 3);
+ const __m128 vk10 = _mm_load1_ps(weights + 4);
+ const __m128 vk11 = _mm_load1_ps(weights + 5);
+ const __m128 vk12 = _mm_load1_ps(weights + 6);
+ const __m128 vk20 = _mm_load1_ps(weights + 7);
+ const __m128 vk21 = _mm_load1_ps(weights + 8);
+ const __m128 vk22 = _mm_load1_ps(weights + 9);
+
+ do {
+ /*
+ * Register diagrams list lanes high-to-low: ( lane3, lane2, lane1, lane0 ).
+ * The "x3012" registers carry the last pixel of the previous tuple; they
+ * start at zero, which provides the implicit left padding.
+ */
+ /* vi0x3012 = ( vi02, vi01, vi00, vi03 ) */
+ __m128 vi0x3012 = _mm_setzero_ps();
+ /* vi1x3012 = ( vi12, vi11, vi10, vi13 ) */
+ __m128 vi1x3012 = _mm_setzero_ps();
+ /* vi2x3012 = ( vi22, vi21, vi20, vi23 ) */
+ __m128 vi2x3012 = _mm_setzero_ps();
+ /* vi0x4567 = ( vi07, vi06, vi05, vi04 ) */
+ __m128 vi0x4567 = _mm_loadu_ps(i0);
+ i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+ /* vi1x4567 = ( vi17, vi16, vi15, vi14 ) */
+ __m128 vi1x4567 = _mm_loadu_ps(i1);
+ i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+ /* vi2x4567 = ( vi27, vi26, vi25, vi24 ) */
+ __m128 vi2x4567 = _mm_loadu_ps(i2);
+ i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+
+ size_t k = n;
+ /*
+ * Main loop: 4 output pixels per iteration while more than 4 remain, so
+ * the last 1..4 pixels always go through the masked tail below.
+ */
+ for (; k > 4; k -= 4) {
+ /* three partial sums (one per input row) shorten the dependency chains */
+ __m128 vo4567p0 = vbias;
+
+ /* vi0x89AB = ( vi0B, vi0A, vi09, vi08 ) */
+ const __m128 vi0x89AB = _mm_loadu_ps(i0);
+ i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+ /* vi1x89AB = ( vi1B, vi1A, vi19, vi18 ) */
+ const __m128 vi1x89AB = _mm_loadu_ps(i1);
+ i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+ /* vi2x89AB = ( vi2B, vi2A, vi29, vi28 ) */
+ const __m128 vi2x89AB = _mm_loadu_ps(i2);
+ i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+
+ /* rotate right by one lane; stepping stone towards "x3456" below */
+ /* vi0x7456 = ( vi06, vi05, vi04, vi07 ) */
+ const __m128 vi0x7456 = _mm_shuffle_ps(vi0x4567, vi0x4567, _MM_SHUFFLE(2, 1, 0, 3));
+ /* vi1x7456 = ( vi16, vi15, vi14, vi17 ) */
+ const __m128 vi1x7456 = _mm_shuffle_ps(vi1x4567, vi1x4567, _MM_SHUFFLE(2, 1, 0, 3));
+ /* vi2x7456 = ( vi26, vi25, vi24, vi27 ) */
+ const __m128 vi2x7456 = _mm_shuffle_ps(vi2x4567, vi2x4567, _MM_SHUFFLE(2, 1, 0, 3));
+
+ /* center filter column */
+ vo4567p0 = _mm_add_ps(vo4567p0, _mm_mul_ps(vi0x4567, vk01));
+ __m128 vo4567p1 = _mm_mul_ps(vi1x4567, vk11);
+ __m128 vo4567p2 = _mm_mul_ps(vi2x4567, vk21);
+
+ /* replace lane 0 with pixel 3 carried over from the previous tuple */
+ /* vi0x3456 = ( vi06, vi05, vi04, vi03 ) */
+ const __m128 vi0x3456 = _mm_move_ss(vi0x7456, vi0x3012);
+ /* vi1x3456 = ( vi16, vi15, vi14, vi13 ) */
+ const __m128 vi1x3456 = _mm_move_ss(vi1x7456, vi1x3012);
+ /* vi2x3456 = ( vi26, vi25, vi24, vi23 ) */
+ const __m128 vi2x3456 = _mm_move_ss(vi2x7456, vi2x3012);
+
+ /* left filter column */
+ vo4567p0 = _mm_add_ps(vo4567p0, _mm_mul_ps(vi0x3456, vk00));
+ vo4567p1 = _mm_add_ps(vo4567p1, _mm_mul_ps(vi1x3456, vk10));
+ vo4567p2 = _mm_add_ps(vo4567p2, _mm_mul_ps(vi2x3456, vk20));
+
+ /* carry the current tuple's trailing pixel into the next iteration */
+ vi0x3012 = vi0x7456;
+ vi1x3012 = vi1x7456;
+ vi2x3012 = vi2x7456;
+
+ /* splice pixel 8 from the next tuple into lane 0, then rotate left */
+ /* vi0x8567 = ( vi07, vi06, vi05, vi08 ) */
+ const __m128 vi0x8567 = _mm_move_ss(vi0x4567, vi0x89AB);
+ /* vi1x8567 = ( vi17, vi16, vi15, vi18 ) */
+ const __m128 vi1x8567 = _mm_move_ss(vi1x4567, vi1x89AB);
+ /* vi2x8567 = ( vi27, vi26, vi25, vi28 ) */
+ const __m128 vi2x8567 = _mm_move_ss(vi2x4567, vi2x89AB);
+
+ /* vi0x5678 = ( vi08, vi07, vi06, vi05 ) */
+ const __m128 vi0x5678 = _mm_shuffle_ps(vi0x8567, vi0x8567, _MM_SHUFFLE(0, 3, 2, 1));
+ /* vi1x5678 = ( vi18, vi17, vi16, vi15 ) */
+ const __m128 vi1x5678 = _mm_shuffle_ps(vi1x8567, vi1x8567, _MM_SHUFFLE(0, 3, 2, 1));
+ /* vi2x5678 = ( vi28, vi27, vi26, vi25 ) */
+ const __m128 vi2x5678 = _mm_shuffle_ps(vi2x8567, vi2x8567, _MM_SHUFFLE(0, 3, 2, 1));
+
+ /* right filter column */
+ vo4567p0 = _mm_add_ps(vo4567p0, _mm_mul_ps(vi0x5678, vk02));
+ vo4567p1 = _mm_add_ps(vo4567p1, _mm_mul_ps(vi1x5678, vk12));
+ vo4567p2 = _mm_add_ps(vo4567p2, _mm_mul_ps(vi2x5678, vk22));
+
+ vi0x4567 = vi0x89AB;
+ vi1x4567 = vi1x89AB;
+ vi2x4567 = vi2x89AB;
+
+ /* combine the partial sums and clamp to [min, max] */
+ __m128 vo = _mm_add_ps(vo4567p0, vo4567p1);
+ vo = _mm_add_ps(vo, vo4567p2);
+
+ vo = _mm_max_ps(vo, vmin);
+ vo = _mm_min_ps(vo, vmax);
+
+ _mm_storeu_ps(output, vo);
+ output = (float*) ((uintptr_t) output + output_tuple_stride);
+ }
+ /* Always process the last block of 1..4 pixels */
+ assert(k >= 1);
+ assert(k <= 4);
+ {
+ __m128 vo4567p0 = vbias;
+
+ /* zero lanes past the row end; they then act as the right padding */
+ vi0x4567 = _mm_and_ps(vmask, vi0x4567);
+ vi1x4567 = _mm_and_ps(vmask, vi1x4567);
+ vi2x4567 = _mm_and_ps(vmask, vi2x4567);
+
+ /* vi0x7456 = ( vi06, vi05, vi04, vi07 ) */
+ const __m128 vi0x7456 = _mm_shuffle_ps(vi0x4567, vi0x4567, _MM_SHUFFLE(2, 1, 0, 3));
+ /* vi1x7456 = ( vi16, vi15, vi14, vi17 ) */
+ const __m128 vi1x7456 = _mm_shuffle_ps(vi1x4567, vi1x4567, _MM_SHUFFLE(2, 1, 0, 3));
+ /* vi2x7456 = ( vi26, vi25, vi24, vi27 ) */
+ const __m128 vi2x7456 = _mm_shuffle_ps(vi2x4567, vi2x4567, _MM_SHUFFLE(2, 1, 0, 3));
+
+ /* center filter column */
+ vo4567p0 = _mm_add_ps(vo4567p0, _mm_mul_ps(vi0x4567, vk01));
+ __m128 vo4567p1 = _mm_mul_ps(vi1x4567, vk11);
+ __m128 vo4567p2 = _mm_mul_ps(vi2x4567, vk21);
+
+ /* vi0x3456 = ( vi06, vi05, vi04, vi03 ) */
+ const __m128 vi0x3456 = _mm_move_ss(vi0x7456, vi0x3012);
+ /* vi1x3456 = ( vi16, vi15, vi14, vi13 ) */
+ const __m128 vi1x3456 = _mm_move_ss(vi1x7456, vi1x3012);
+ /* vi2x3456 = ( vi26, vi25, vi24, vi23 ) */
+ const __m128 vi2x3456 = _mm_move_ss(vi2x7456, vi2x3012);
+
+ /* left filter column */
+ vo4567p0 = _mm_add_ps(vo4567p0, _mm_mul_ps(vi0x3456, vk00));
+ vo4567p1 = _mm_add_ps(vo4567p1, _mm_mul_ps(vi1x3456, vk10));
+ vo4567p2 = _mm_add_ps(vo4567p2, _mm_mul_ps(vi2x3456, vk20));
+
+ /* the tuple past the row end is all zeroes (right padding) */
+ const __m128 vzero = _mm_setzero_ps();
+ /* vi0x8567 = ( vi07, vi06, vi05, 0.0 ) */
+ const __m128 vi0x8567 = _mm_move_ss(vi0x4567, vzero);
+ /* vi1x8567 = ( vi17, vi16, vi15, 0.0 ) */
+ const __m128 vi1x8567 = _mm_move_ss(vi1x4567, vzero);
+ /* vi2x8567 = ( vi27, vi26, vi25, 0.0 ) */
+ const __m128 vi2x8567 = _mm_move_ss(vi2x4567, vzero);
+
+ /* vi0x5678 = ( vi08, vi07, vi06, vi05 ) */
+ const __m128 vi0x5678 = _mm_shuffle_ps(vi0x8567, vi0x8567, _MM_SHUFFLE(0, 3, 2, 1));
+ /* vi1x5678 = ( vi18, vi17, vi16, vi15 ) */
+ const __m128 vi1x5678 = _mm_shuffle_ps(vi1x8567, vi1x8567, _MM_SHUFFLE(0, 3, 2, 1));
+ /* vi2x5678 = ( vi28, vi27, vi26, vi25 ) */
+ const __m128 vi2x5678 = _mm_shuffle_ps(vi2x8567, vi2x8567, _MM_SHUFFLE(0, 3, 2, 1));
+
+ /* right filter column */
+ vo4567p0 = _mm_add_ps(vo4567p0, _mm_mul_ps(vi0x5678, vk02));
+ vo4567p1 = _mm_add_ps(vo4567p1, _mm_mul_ps(vi1x5678, vk12));
+ vo4567p2 = _mm_add_ps(vo4567p2, _mm_mul_ps(vi2x5678, vk22));
+
+ __m128 vo = _mm_add_ps(vo4567p0, vo4567p1);
+ vo = _mm_add_ps(vo, vo4567p2);
+
+ vo = _mm_max_ps(vo, vmin);
+ vo = _mm_min_ps(vo, vmax);
+
+ /*
+ * Partial store of the last 1..4 pixels; the output pointer is not
+ * advanced here (accounted for in output_width_increment).
+ */
+ if XNN_LIKELY(k == 4) {
+ _mm_storeu_ps(output, vo);
+ } else {
+ float* output_lo = output;
+ if (k & 2) {
+ _mm_storel_pi((__m64*) output_lo, vo);
+ output_lo += 2;
+ vo = _mm_movehl_ps(vo, vo);
+ }
+ if (k & 1) {
+ _mm_store_ss(output_lo, vo);
+ }
+ }
+ }
+
+ i0 = (const float*) ((uintptr_t) i0 + input_width_increment);
+ i1 = (const float*) ((uintptr_t) i1 + input_width_increment);
+ i2 = (const float*) ((uintptr_t) i2 + input_width_increment);
+ output = (float*) ((uintptr_t) output + output_width_increment);
+ } while (--m != 0);
+}
diff --git a/src/f32-dwconv-spchw/3x3s2p1-neonfma.c b/src/f32-dwconv-spchw/3x3s2p1-neonfma.c
new file mode 100644
index 0000000..008b5fb
--- /dev/null
+++ b/src/f32-dwconv-spchw/3x3s2p1-neonfma.c
@@ -0,0 +1,172 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/dwconv.h>
+#include <xnnpack/math.h>
+
+
+// 3x3 depthwise convolution, stride 2, for CHW ("spchw") layout; NEON-FMA
+// (AArch64) version. Processes one channel per call: m output rows; n is the
+// number of input pixels per row, so each output row produces about n/2
+// pixels from input rows i0..i2. Per the file name there is 1 pixel of
+// implicit left padding, provided by the zero-initialized "x0123" registers.
+void xnn_f32_dwconv_spchw_ukernel_3x3s2p1__neonfma(
+ size_t m,
+ size_t n,
+ const float* input,
+ const float* weights,
+ float* output,
+ size_t input_tuple_stride,
+ size_t output_tuple_stride,
+ size_t input_width_stride,
+ size_t output_width_stride,
+ const union xnn_f32_spchw_params params[restrict static 1])
+{
+ assert(n != 0);
+
+ // masks zero the even-/odd-indexed pixels that lie past the row end in the
+ // final partial block
+ const uint32x4_t vmask_even = vld1q_u32(params->neon.mask_even);
+ const uint32x4_t vmask_odd = vld1q_u32(params->neon.mask_odd);
+ // Fixed: "&params" was garbled to "¶ms" (HTML-entity mojibake), which
+ // does not compile.
+ const float32x4_t vmax = vld1q_dup_f32(&params->neon.max);
+ const float32x4_t vmin = vld1q_dup_f32(&params->neon.min);
+
+ // Stride 2: advance two input rows per output row; the main loop consumes
+ // two input tuples (8 pixels) and stores one output tuple per iteration.
+ const size_t input_width_increment = input_width_stride * 2 - n / 8 * input_tuple_stride * 2;
+ const size_t output_width_increment = output_width_stride - n / 8 * output_tuple_stride;
+
+ /* No vertical padding */
+ const float* i0 = input;
+ const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride);
+ const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride);
+
+ // weights layout: bias in vw0123 lane 0, then the 3x3 taps in row-major
+ // order across the remaining lanes of vw0123, vw4567 and vw89
+ const float32x4_t vw0123 = vld1q_f32(weights);
+ const float32x4_t vw4567 = vld1q_f32(weights + 4);
+ const float32x2_t vw89 = vld1_f32(weights + 8);
+
+ do {
+ // trailing pixels of the previous block; start at zero, which provides
+ // the implicit left padding
+ float32x4_t vi0x0123 = vmovq_n_f32(0.0f);
+ float32x4_t vi1x0123 = vmovq_n_f32(0.0f);
+ float32x4_t vi2x0123 = vmovq_n_f32(0.0f);
+
+ size_t k = n;
+ // main loop: consume 8 input pixels, produce 4 output pixels per iteration
+ for (; k >= 8; k -= 8) {
+ // bias
+ float32x4_t vo468Ap0 = vdupq_laneq_f32(vw0123, 0);
+
+ const float32x4_t vi0x4567 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+ const float32x4_t vi1x4567 = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+ const float32x4_t vi2x4567 = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+
+ const float32x4_t vi0x89AB = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+ const float32x4_t vi1x89AB = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+ const float32x4_t vi2x89AB = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+
+ // de-interleave each row into even (468A) and odd (579B) pixels
+ const float32x4_t vi0x468A = vuzp1q_f32(vi0x4567, vi0x89AB);
+ const float32x4_t vi0x579B = vuzp2q_f32(vi0x4567, vi0x89AB);
+ const float32x4_t vi1x468A = vuzp1q_f32(vi1x4567, vi1x89AB);
+ const float32x4_t vi1x579B = vuzp2q_f32(vi1x4567, vi1x89AB);
+ const float32x4_t vi2x468A = vuzp1q_f32(vi2x4567, vi2x89AB);
+ const float32x4_t vi2x579B = vuzp2q_f32(vi2x4567, vi2x89AB);
+ // add bias only to first row, it will then get added
+ // to the final result
+ // multiply each row by corresponding row of center column of filter
+ vo468Ap0 = vfmaq_laneq_f32(vo468Ap0, vi0x468A, vw0123, 2);
+ float32x4_t vo468Ap1 = vmulq_laneq_f32(vi1x468A, vw4567, 1);
+ float32x4_t vo468Ap2 = vmulq_lane_f32(vi2x468A, vw89, 0);
+
+ // grab the values corresponding the left filter tap
+ const float32x4_t vi0x3579 = vextq_f32(vi0x0123, vi0x579B, 3);
+ const float32x4_t vi1x3579 = vextq_f32(vi1x0123, vi1x579B, 3);
+ const float32x4_t vi2x3579 = vextq_f32(vi2x0123, vi2x579B, 3);
+
+ vi0x0123 = vi0x89AB;
+ vi1x0123 = vi1x89AB;
+ vi2x0123 = vi2x89AB;
+
+ vo468Ap0 = vfmaq_laneq_f32(vo468Ap0, vi0x3579, vw0123, 1);
+ vo468Ap1 = vfmaq_laneq_f32(vo468Ap1, vi1x3579, vw4567, 0);
+ vo468Ap2 = vfmaq_laneq_f32(vo468Ap2, vi2x3579, vw4567, 3);
+
+ // do multiplication by right filter tap
+ vo468Ap0 = vfmaq_laneq_f32(vo468Ap0, vi0x579B, vw0123, 3);
+ vo468Ap1 = vfmaq_laneq_f32(vo468Ap1, vi1x579B, vw4567, 2);
+ vo468Ap2 = vfmaq_lane_f32 (vo468Ap2, vi2x579B, vw89, 1);
+
+ // add up across rows to get the final outputs
+ float32x4_t vo = vaddq_f32(vo468Ap0, vo468Ap1);
+ vo = vaddq_f32(vo, vo468Ap2);
+
+ vo = vmaxq_f32(vo, vmin);
+ vo = vminq_f32(vo, vmax);
+
+ vst1q_f32(output, vo); output = (float*) ((uintptr_t) output + output_tuple_stride);
+ }
+ /* Last block has 0-7 pixels to process */
+ assert(k < 8);
+ if XNN_LIKELY(k != 0) {
+ // bias
+ float32x4_t vo468Ap0 = vdupq_laneq_f32(vw0123, 0);
+
+ // NOTE(review): these loads always read two full tuples even for a
+ // partial block; assumes the row buffer is readable past the row end --
+ // confirm against the operator's buffer allocation.
+ const float32x4_t vi0x4567 = vld1q_f32(i0);
+ const float32x4_t vi1x4567 = vld1q_f32(i1);
+ const float32x4_t vi2x4567 = vld1q_f32(i2);
+
+ const float32x4_t vi0x89AB = vld1q_f32((const float*) ((uintptr_t) i0 + input_tuple_stride));
+ const float32x4_t vi1x89AB = vld1q_f32((const float*) ((uintptr_t) i1 + input_tuple_stride));
+ const float32x4_t vi2x89AB = vld1q_f32((const float*) ((uintptr_t) i2 + input_tuple_stride));
+
+ // De-interleave into even/odd pixels and zero the lanes past the row
+ // end. Fixed: the vreinterpret directions were swapped (the AND is done
+ // in the u32 domain, so the inner cast is f32->u32 and the outer cast
+ // u32->f32); as written before, the operand and result types mismatched
+ // and the code did not compile.
+ const float32x4_t vi0x468A = vreinterpretq_f32_u32(vandq_u32(vmask_even, vreinterpretq_u32_f32(vuzp1q_f32(vi0x4567, vi0x89AB))));
+ const float32x4_t vi0x579B = vreinterpretq_f32_u32(vandq_u32(vmask_odd, vreinterpretq_u32_f32(vuzp2q_f32(vi0x4567, vi0x89AB))));
+ const float32x4_t vi1x468A = vreinterpretq_f32_u32(vandq_u32(vmask_even, vreinterpretq_u32_f32(vuzp1q_f32(vi1x4567, vi1x89AB))));
+ const float32x4_t vi1x579B = vreinterpretq_f32_u32(vandq_u32(vmask_odd, vreinterpretq_u32_f32(vuzp2q_f32(vi1x4567, vi1x89AB))));
+ const float32x4_t vi2x468A = vreinterpretq_f32_u32(vandq_u32(vmask_even, vreinterpretq_u32_f32(vuzp1q_f32(vi2x4567, vi2x89AB))));
+ const float32x4_t vi2x579B = vreinterpretq_f32_u32(vandq_u32(vmask_odd, vreinterpretq_u32_f32(vuzp2q_f32(vi2x4567, vi2x89AB))));
+ // add bias only to first row, it will then get added
+ // to the final result
+ // multiply each row by corresponding row of center column of filter
+ vo468Ap0 = vfmaq_laneq_f32(vo468Ap0, vi0x468A, vw0123, 2);
+ float32x4_t vo468Ap1 = vmulq_laneq_f32(vi1x468A, vw4567, 1);
+ float32x4_t vo468Ap2 = vmulq_lane_f32(vi2x468A, vw89, 0);
+
+ // grab the values corresponding the left filter tap
+ const float32x4_t vi0x3579 = vextq_f32(vi0x0123, vi0x579B, 3);
+ const float32x4_t vi1x3579 = vextq_f32(vi1x0123, vi1x579B, 3);
+ const float32x4_t vi2x3579 = vextq_f32(vi2x0123, vi2x579B, 3);
+
+ vo468Ap0 = vfmaq_laneq_f32(vo468Ap0, vi0x3579, vw0123, 1);
+ vo468Ap1 = vfmaq_laneq_f32(vo468Ap1, vi1x3579, vw4567, 0);
+ vo468Ap2 = vfmaq_laneq_f32(vo468Ap2, vi2x3579, vw4567, 3);
+
+ // do multiplication by right filter tap
+ vo468Ap0 = vfmaq_laneq_f32(vo468Ap0, vi0x579B, vw0123, 3);
+ vo468Ap1 = vfmaq_laneq_f32(vo468Ap1, vi1x579B, vw4567, 2);
+ vo468Ap2 = vfmaq_lane_f32 (vo468Ap2, vi2x579B, vw89, 1);
+
+ // add up across rows to get the final outputs
+ float32x4_t vo = vaddq_f32(vo468Ap0, vo468Ap1);
+ vo = vaddq_f32(vo, vo468Ap2);
+
+ vo = vmaxq_f32(vo, vmin);
+ vo = vminq_f32(vo, vmax);
+
+ // store ceil(k / 2) output pixels; k+1 makes the bit tests below
+ // implement the rounding-up (k was 1..7 here)
+ k += 1;
+ if (k & 8) {
+ vst1q_f32(output, vo);
+ } else {
+ float* output_lo = output;
+ float32x2_t vo_lo = vget_low_f32(vo);
+ if (k & 4) {
+ vst1_f32(output_lo, vo_lo); output_lo += 2;
+ vo_lo = vget_high_f32(vo);
+ }
+ if (k & 2) {
+ vst1_lane_f32(output_lo, vo_lo, 0);
+ }
+ }
+ }
+
+ i0 = (const float*) ((uintptr_t) i0 + input_width_increment);
+ i1 = (const float*) ((uintptr_t) i1 + input_width_increment);
+ i2 = (const float*) ((uintptr_t) i2 + input_width_increment);
+ output = (float*) ((uintptr_t) output + output_width_increment);
+ } while (--m != 0);
+}
diff --git a/src/f32-dwconv-spchw/3x3s2p1-sse.c b/src/f32-dwconv-spchw/3x3s2p1-sse.c
new file mode 100644
index 0000000..204dc52
--- /dev/null
+++ b/src/f32-dwconv-spchw/3x3s2p1-sse.c
@@ -0,0 +1,182 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/dwconv.h>
+#include <xnnpack/math.h>
+
+
+/*
+ * 3x3 depthwise convolution, stride 2, for CHW ("spchw") layout, SSE
+ * version. Processes one channel per call: m output rows; n is the number of
+ * input pixels per row, so each output row produces about n/2 pixels from
+ * input rows i0..i2. Per the file name there is 1 pixel of implicit left
+ * padding, provided by the zero-initialized "x7531" registers.
+ */
+void xnn_f32_dwconv_spchw_ukernel_3x3s2p1__sse(
+ size_t m,
+ size_t n,
+ const float* input,
+ const float* weights,
+ float* output,
+ size_t input_tuple_stride,
+ size_t output_tuple_stride,
+ size_t input_width_stride,
+ size_t output_width_stride,
+ const union xnn_f32_spchw_params params[restrict static 1])
+{
+ assert(n != 0);
+
+ /* masks zero the even-/odd-indexed pixels past the row end in the final
+ * partial block */
+ const __m128 vmask_even = _mm_load_ps((const float*) params->sse.mask_even);
+ const __m128 vmask_odd = _mm_load_ps((const float*) params->sse.mask_odd);
+ const __m128 vmax = _mm_load_ps(params->sse.max);
+ const __m128 vmin = _mm_load_ps(params->sse.min);
+
+ /*
+ * Stride 2: advance two input rows per output row; the main loop consumes
+ * two input tuples (8 pixels) and stores one output tuple per iteration.
+ */
+ const size_t input_width_increment = input_width_stride * 2 - n / 8 * input_tuple_stride * 2;
+ const size_t output_width_increment = output_width_stride - n / 8 * output_tuple_stride;
+
+ /* No vertical padding */
+ const float* i0 = input;
+ const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride);
+ const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride);
+
+ /*
+ * Weights layout: bias first, then the 3x3 taps in row-major order
+ * (vkRC = filter row R, column C), each broadcast across all four lanes.
+ */
+ const __m128 vbias = _mm_load1_ps(weights);
+ const __m128 vk00 = _mm_load1_ps(weights + 1);
+ const __m128 vk01 = _mm_load1_ps(weights + 2);
+ const __m128 vk02 = _mm_load1_ps(weights + 3);
+ const __m128 vk10 = _mm_load1_ps(weights + 4);
+ const __m128 vk11 = _mm_load1_ps(weights + 5);
+ const __m128 vk12 = _mm_load1_ps(weights + 6);
+ const __m128 vk20 = _mm_load1_ps(weights + 7);
+ const __m128 vk21 = _mm_load1_ps(weights + 8);
+ const __m128 vk22 = _mm_load1_ps(weights + 9);
+
+ do {
+ /* trailing odd pixels of the previous block; start at zero, which
+ * provides the implicit left padding */
+ __m128 vi0x7531 = _mm_setzero_ps();
+ __m128 vi1x7531 = _mm_setzero_ps();
+ __m128 vi2x7531 = _mm_setzero_ps();
+
+ size_t k = n;
+ /* main loop: consume 8 input pixels, produce 4 output pixels */
+ for (; k >= 8; k -= 8) {
+ __m128 vo8ACEp0 = vbias;
+
+ const __m128 vi0x89AB = _mm_loadu_ps(i0);
+ i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+ const __m128 vi1x89AB = _mm_loadu_ps(i1);
+ i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+ const __m128 vi2x89AB = _mm_loadu_ps(i2);
+ i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+
+ const __m128 vi0xCDEF = _mm_loadu_ps(i0);
+ i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+ const __m128 vi1xCDEF = _mm_loadu_ps(i1);
+ i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+ const __m128 vi2xCDEF = _mm_loadu_ps(i2);
+ i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+
+ /* de-interleave each row into even (8ACE) and odd (9BDF) pixels */
+ const __m128 vi0x8ACE = _mm_shuffle_ps(vi0x89AB, vi0xCDEF, _MM_SHUFFLE(2, 0, 2, 0));
+ const __m128 vi0x9BDF = _mm_shuffle_ps(vi0x89AB, vi0xCDEF, _MM_SHUFFLE(3, 1, 3, 1));
+ const __m128 vi1x8ACE = _mm_shuffle_ps(vi1x89AB, vi1xCDEF, _MM_SHUFFLE(2, 0, 2, 0));
+ const __m128 vi1x9BDF = _mm_shuffle_ps(vi1x89AB, vi1xCDEF, _MM_SHUFFLE(3, 1, 3, 1));
+ const __m128 vi2x8ACE = _mm_shuffle_ps(vi2x89AB, vi2xCDEF, _MM_SHUFFLE(2, 0, 2, 0));
+ const __m128 vi2x9BDF = _mm_shuffle_ps(vi2x89AB, vi2xCDEF, _MM_SHUFFLE(3, 1, 3, 1));
+
+ /* center filter column; one partial sum per input row */
+ vo8ACEp0 = _mm_add_ps(vo8ACEp0, _mm_mul_ps(vi0x8ACE, vk01));
+ __m128 vo8ACEp1 = _mm_mul_ps(vi1x8ACE, vk11);
+ __m128 vo8ACEp2 = _mm_mul_ps(vi2x8ACE, vk21);
+
+ /* rotate the odd pixels right by one lane towards "x7BDF" below */
+ const __m128 vi0xF9BD = _mm_shuffle_ps(vi0x9BDF, vi0x9BDF, _MM_SHUFFLE(2, 1, 0, 3));
+ const __m128 vi1xF9BD = _mm_shuffle_ps(vi1x9BDF, vi1x9BDF, _MM_SHUFFLE(2, 1, 0, 3));
+ const __m128 vi2xF9BD = _mm_shuffle_ps(vi2x9BDF, vi2x9BDF, _MM_SHUFFLE(2, 1, 0, 3));
+
+ /* right filter column (the odd pixels) */
+ vo8ACEp0 = _mm_add_ps(vo8ACEp0, _mm_mul_ps(vi0x9BDF, vk02));
+ vo8ACEp1 = _mm_add_ps(vo8ACEp1, _mm_mul_ps(vi1x9BDF, vk12));
+ vo8ACEp2 = _mm_add_ps(vo8ACEp2, _mm_mul_ps(vi2x9BDF, vk22));
+
+ /* replace lane 0 with pixel 7 carried over from the previous block */
+ const __m128 vi0x7BDF = _mm_move_ss(vi0xF9BD, vi0x7531);
+ const __m128 vi1x7BDF = _mm_move_ss(vi1xF9BD, vi1x7531);
+ const __m128 vi2x7BDF = _mm_move_ss(vi2xF9BD, vi2x7531);
+
+ vi0x7531 = vi0xF9BD;
+ vi1x7531 = vi1xF9BD;
+ vi2x7531 = vi2xF9BD;
+
+ /* left filter column */
+ vo8ACEp0 = _mm_add_ps(vo8ACEp0, _mm_mul_ps(vi0x7BDF, vk00));
+ vo8ACEp1 = _mm_add_ps(vo8ACEp1, _mm_mul_ps(vi1x7BDF, vk10));
+ vo8ACEp2 = _mm_add_ps(vo8ACEp2, _mm_mul_ps(vi2x7BDF, vk20));
+
+ /* combine the partial sums and clamp to [min, max] */
+ __m128 vo = _mm_add_ps(vo8ACEp0, vo8ACEp1);
+ vo = _mm_add_ps(vo, vo8ACEp2);
+
+ vo = _mm_max_ps(vo, vmin);
+ vo = _mm_min_ps(vo, vmax);
+
+ _mm_storeu_ps(output, vo);
+ output = (float*) ((uintptr_t) output + output_tuple_stride);
+ }
+ /* Last block has 0-7 pixels to process */
+ assert(k < 8);
+ if XNN_LIKELY(k != 0) {
+ __m128 vo8ACEp0 = vbias;
+
+ /*
+ * NOTE(review): these loads always read two full tuples even for a
+ * partial block; assumes the row buffer is readable past the row end --
+ * confirm against the operator's buffer allocation.
+ */
+ const __m128 vi0x89AB = _mm_loadu_ps(i0);
+ const __m128 vi1x89AB = _mm_loadu_ps(i1);
+ const __m128 vi2x89AB = _mm_loadu_ps(i2);
+
+ const __m128 vi0xCDEF = _mm_loadu_ps((const float*) ((uintptr_t) i0 + input_tuple_stride));
+ const __m128 vi1xCDEF = _mm_loadu_ps((const float*) ((uintptr_t) i1 + input_tuple_stride));
+ const __m128 vi2xCDEF = _mm_loadu_ps((const float*) ((uintptr_t) i2 + input_tuple_stride));
+
+ /* de-interleave into even/odd pixels, zeroing lanes past the row end */
+ const __m128 vi0x8ACE = _mm_and_ps(vmask_even, _mm_shuffle_ps(vi0x89AB, vi0xCDEF, _MM_SHUFFLE(2, 0, 2, 0)));
+ const __m128 vi0x9BDF = _mm_and_ps(vmask_odd, _mm_shuffle_ps(vi0x89AB, vi0xCDEF, _MM_SHUFFLE(3, 1, 3, 1)));
+ const __m128 vi1x8ACE = _mm_and_ps(vmask_even, _mm_shuffle_ps(vi1x89AB, vi1xCDEF, _MM_SHUFFLE(2, 0, 2, 0)));
+ const __m128 vi1x9BDF = _mm_and_ps(vmask_odd, _mm_shuffle_ps(vi1x89AB, vi1xCDEF, _MM_SHUFFLE(3, 1, 3, 1)));
+ const __m128 vi2x8ACE = _mm_and_ps(vmask_even, _mm_shuffle_ps(vi2x89AB, vi2xCDEF, _MM_SHUFFLE(2, 0, 2, 0)));
+ const __m128 vi2x9BDF = _mm_and_ps(vmask_odd, _mm_shuffle_ps(vi2x89AB, vi2xCDEF, _MM_SHUFFLE(3, 1, 3, 1)));
+
+ /* center filter column */
+ vo8ACEp0 = _mm_add_ps(vo8ACEp0, _mm_mul_ps(vi0x8ACE, vk01));
+ __m128 vo8ACEp1 = _mm_mul_ps(vi1x8ACE, vk11);
+ __m128 vo8ACEp2 = _mm_mul_ps(vi2x8ACE, vk21);
+
+ const __m128 vi0xF9BD = _mm_shuffle_ps(vi0x9BDF, vi0x9BDF, _MM_SHUFFLE(2, 1, 0, 3));
+ const __m128 vi1xF9BD = _mm_shuffle_ps(vi1x9BDF, vi1x9BDF, _MM_SHUFFLE(2, 1, 0, 3));
+ const __m128 vi2xF9BD = _mm_shuffle_ps(vi2x9BDF, vi2x9BDF, _MM_SHUFFLE(2, 1, 0, 3));
+
+ /* right filter column */
+ vo8ACEp0 = _mm_add_ps(vo8ACEp0, _mm_mul_ps(vi0x9BDF, vk02));
+ vo8ACEp1 = _mm_add_ps(vo8ACEp1, _mm_mul_ps(vi1x9BDF, vk12));
+ vo8ACEp2 = _mm_add_ps(vo8ACEp2, _mm_mul_ps(vi2x9BDF, vk22));
+
+ const __m128 vi0x7BDF = _mm_move_ss(vi0xF9BD, vi0x7531);
+ const __m128 vi1x7BDF = _mm_move_ss(vi1xF9BD, vi1x7531);
+ const __m128 vi2x7BDF = _mm_move_ss(vi2xF9BD, vi2x7531);
+
+ /* left filter column */
+ vo8ACEp0 = _mm_add_ps(vo8ACEp0, _mm_mul_ps(vi0x7BDF, vk00));
+ vo8ACEp1 = _mm_add_ps(vo8ACEp1, _mm_mul_ps(vi1x7BDF, vk10));
+ vo8ACEp2 = _mm_add_ps(vo8ACEp2, _mm_mul_ps(vi2x7BDF, vk20));
+
+ __m128 vo = _mm_add_ps(vo8ACEp0, vo8ACEp1);
+ vo = _mm_add_ps(vo, vo8ACEp2);
+
+ vo = _mm_max_ps(vo, vmin);
+ vo = _mm_min_ps(vo, vmax);
+
+ /*
+ * Store ceil(k / 2) output pixels; incrementing k makes the bit tests
+ * below implement the rounding-up (k was 1..6 in the else branch).
+ */
+ if (k == 7) {
+ _mm_storeu_ps(output, vo);
+ } else {
+ float* output_lo = output;
+ k += 1;
+ if (k & 4) {
+ _mm_storel_pi((__m64*) output_lo, vo);
+ output_lo += 2;
+ vo = _mm_movehl_ps(vo, vo);
+ }
+ if (k & 2) {
+ _mm_store_ss(output_lo, vo);
+ }
+ }
+ }
+
+ i0 = (const float*) ((uintptr_t) i0 + input_width_increment);
+ i1 = (const float*) ((uintptr_t) i1 + input_width_increment);
+ i2 = (const float*) ((uintptr_t) i2 + input_width_increment);
+ output = (float*) ((uintptr_t) output + output_width_increment);
+ } while (--m != 0);
+}
diff --git a/src/f32-dwconv-spchw/5x5p2-neonfma.c b/src/f32-dwconv-spchw/5x5p2-neonfma.c
new file mode 100644
index 0000000..ed60827
--- /dev/null
+++ b/src/f32-dwconv-spchw/5x5p2-neonfma.c
@@ -0,0 +1,338 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/dwconv.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_dwconv_spchw_ukernel_5x5p2__neonfma(
+ size_t m,
+ size_t n,
+ const float* input,
+ const float* weights,
+ float* output,
+ size_t input_tuple_stride,
+ size_t output_tuple_stride,
+ size_t input_width_stride,
+ size_t output_width_stride,
+ const union xnn_f32_spchw_params params[restrict static 1])
+{
+ assert(n != 0);
+
+ const uint32x4_t vmask = vld1q_u32(params->neon.mask);
+ const float32x4_t vmax = vld1q_dup_f32(¶ms->neon.max);
+ const float32x4_t vmin = vld1q_dup_f32(¶ms->neon.min);
+
+ const size_t input_width_increment_single = input_width_stride - round_up_po2(n, 4) / 4 * input_tuple_stride;
+ const size_t output_width_increment_single = output_width_stride - (n - 1) / 4 * output_tuple_stride;
+
+ /* No vertical padding */
+ const float* i0 = input;
+ const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride);
+ const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride);
+ const float* i3 = (const float*) ((uintptr_t) i2 + input_width_stride);
+ const float* i4 = (const float*) ((uintptr_t) i3 + input_width_stride);
+
+ float* output0 = output;
+
+ const float32x4_t vw0123 = vld1q_f32(weights);
+ const float32x4_t vw4567 = vld1q_f32(weights + 4);
+ const float32x4_t vw89AB = vld1q_f32(weights + 8);
+ const float32x4_t vwCDEF = vld1q_f32(weights + 12);
+ const float32x4_t vwGHIJ = vld1q_f32(weights + 16);
+ const float32x4_t vwKLMN = vld1q_f32(weights + 20);
+ const float32x2_t vwOP = vld1_f32( weights + 24);
+
+ do {
+ float32x4_t vi0x0123 = vmovq_n_f32(0.0f);
+ float32x4_t vi1x0123 = vmovq_n_f32(0.0f);
+ float32x4_t vi2x0123 = vmovq_n_f32(0.0f);
+ float32x4_t vi3x0123 = vmovq_n_f32(0.0f);
+ float32x4_t vi4x0123 = vmovq_n_f32(0.0f);
+ float32x4_t vi0x4567 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+ float32x4_t vi1x4567 = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+ float32x4_t vi2x4567 = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+ float32x4_t vi3x4567 = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride);
+ float32x4_t vi4x4567 = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride);
+
+ size_t k = n;
+ for (; k > 8; k -= 4) {
+ float32x4_t vo4567p00 = vdupq_laneq_f32(vw0123, 0);
+
+ const float32x4_t vi0x89AB = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+ const float32x4_t vi1x89AB = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+ const float32x4_t vi2x89AB = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+ const float32x4_t vi3x89AB = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride);
+ const float32x4_t vi4x89AB = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride);
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x4567, vw0123, 3);
+ float32x4_t vo4567p01 = vmulq_laneq_f32(vi1x4567, vw89AB, 0);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x4567, vwCDEF, 1);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x4567, vwGHIJ, 2);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x4567, vwKLMN, 3);
+
+ const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3);
+ const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3);
+ const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3);
+ const float32x4_t vi3x3456 = vextq_f32(vi3x0123, vi3x4567, 3);
+ const float32x4_t vi4x3456 = vextq_f32(vi4x0123, vi4x4567, 3);
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x3456, vw0123, 2);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x3456, vw4567, 3);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x3456, vwCDEF, 0);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x3456, vwGHIJ, 1);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x3456, vwKLMN, 2);
+
+ const float32x4_t vi0x2345 = vextq_f32(vi0x0123, vi0x4567, 2);
+ const float32x4_t vi1x2345 = vextq_f32(vi1x0123, vi1x4567, 2);
+ const float32x4_t vi2x2345 = vextq_f32(vi2x0123, vi2x4567, 2);
+ const float32x4_t vi3x2345 = vextq_f32(vi3x0123, vi3x4567, 2);
+ const float32x4_t vi4x2345 = vextq_f32(vi4x0123, vi4x4567, 2);
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x2345, vw0123, 1);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x2345, vw4567, 2);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x2345, vw89AB, 3);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x2345, vwGHIJ, 0);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x2345, vwKLMN, 1);
+
+ vi0x0123 = vi0x4567;
+ vi1x0123 = vi1x4567;
+ vi2x0123 = vi2x4567;
+ vi3x0123 = vi3x4567;
+ vi4x0123 = vi4x4567;
+
+ const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vi0x89AB, 1);
+ const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vi1x89AB, 1);
+ const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vi2x89AB, 1);
+ const float32x4_t vi3x5678 = vextq_f32(vi3x4567, vi3x89AB, 1);
+ const float32x4_t vi4x5678 = vextq_f32(vi4x4567, vi4x89AB, 1);
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x5678, vw4567, 0);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x5678, vw89AB, 1);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x5678, vwCDEF, 2);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x5678, vwGHIJ, 3);
+ vo4567p00 = vfmaq_lane_f32( vo4567p00, vi4x5678, vwOP, 0);
+
+ const float32x4_t vi0x6789 = vextq_f32(vi0x4567, vi0x89AB, 2);
+ const float32x4_t vi1x6789 = vextq_f32(vi1x4567, vi1x89AB, 2);
+ const float32x4_t vi2x6789 = vextq_f32(vi2x4567, vi2x89AB, 2);
+ const float32x4_t vi3x6789 = vextq_f32(vi3x4567, vi3x89AB, 2);
+ const float32x4_t vi4x6789 = vextq_f32(vi4x4567, vi4x89AB, 2);
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x6789, vw4567, 1);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x6789, vw89AB, 2);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x6789, vwCDEF, 3);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x6789, vwKLMN, 0);
+ vo4567p00 = vfmaq_lane_f32( vo4567p00, vi4x6789, vwOP, 1);
+
+ vi0x4567 = vi0x89AB;
+ vi1x4567 = vi1x89AB;
+ vi2x4567 = vi2x89AB;
+ vi3x4567 = vi3x89AB;
+ vi4x4567 = vi4x89AB;
+
+ vo4567p00 = vaddq_f32(vo4567p00, vo4567p01);
+
+ float32x4_t vo0 = vo4567p00;
+
+ vo0 = vmaxq_f32(vo0, vmin);
+ vo0 = vminq_f32(vo0, vmax);
+
+ vst1q_f32(output0, vo0); output0 = (float*) ((uintptr_t) output0 + output_tuple_stride);
+ }
+ /* Always process the last block of 5..8 pixels */
+ if XNN_LIKELY(k > 4)
+ {
+ float32x4_t vo4567p00 = vdupq_laneq_f32(vw0123, 0);
+
+ float32x4_t vi0x89AB = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+ float32x4_t vi1x89AB = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+ float32x4_t vi2x89AB = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+ float32x4_t vi3x89AB = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride);
+ float32x4_t vi4x89AB = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride);
+
+ vi0x89AB = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi0x89AB)));
+ vi1x89AB = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi1x89AB)));
+ vi2x89AB = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi2x89AB)));
+ vi3x89AB = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi3x89AB)));
+ vi4x89AB = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi4x89AB)));
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x4567, vw0123, 3);
+ float32x4_t vo4567p01 = vmulq_laneq_f32(vi1x4567, vw89AB, 0);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x4567, vwCDEF, 1);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x4567, vwGHIJ, 2);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x4567, vwKLMN, 3);
+
+ const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3);
+ const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3);
+ const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3);
+ const float32x4_t vi3x3456 = vextq_f32(vi3x0123, vi3x4567, 3);
+ const float32x4_t vi4x3456 = vextq_f32(vi4x0123, vi4x4567, 3);
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x3456, vw0123, 2);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x3456, vw4567, 3);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x3456, vwCDEF, 0);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x3456, vwGHIJ, 1);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x3456, vwKLMN, 2);
+
+ const float32x4_t vi0x2345 = vextq_f32(vi0x0123, vi0x4567, 2);
+ const float32x4_t vi1x2345 = vextq_f32(vi1x0123, vi1x4567, 2);
+ const float32x4_t vi2x2345 = vextq_f32(vi2x0123, vi2x4567, 2);
+ const float32x4_t vi3x2345 = vextq_f32(vi3x0123, vi3x4567, 2);
+ const float32x4_t vi4x2345 = vextq_f32(vi4x0123, vi4x4567, 2);
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x2345, vw0123, 1);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x2345, vw4567, 2);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x2345, vw89AB, 3);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x2345, vwGHIJ, 0);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x2345, vwKLMN, 1);
+
+ vi0x0123 = vi0x4567;
+ vi1x0123 = vi1x4567;
+ vi2x0123 = vi2x4567;
+ vi3x0123 = vi3x4567;
+ vi4x0123 = vi4x4567;
+
+ const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vi0x89AB, 1);
+ const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vi1x89AB, 1);
+ const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vi2x89AB, 1);
+ const float32x4_t vi3x5678 = vextq_f32(vi3x4567, vi3x89AB, 1);
+ const float32x4_t vi4x5678 = vextq_f32(vi4x4567, vi4x89AB, 1);
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x5678, vw4567, 0);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x5678, vw89AB, 1);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x5678, vwCDEF, 2);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x5678, vwGHIJ, 3);
+ vo4567p00 = vfmaq_lane_f32( vo4567p00, vi4x5678, vwOP, 0);
+
+ const float32x4_t vi0x6789 = vextq_f32(vi0x4567, vi0x89AB, 2);
+ const float32x4_t vi1x6789 = vextq_f32(vi1x4567, vi1x89AB, 2);
+ const float32x4_t vi2x6789 = vextq_f32(vi2x4567, vi2x89AB, 2);
+ const float32x4_t vi3x6789 = vextq_f32(vi3x4567, vi3x89AB, 2);
+ const float32x4_t vi4x6789 = vextq_f32(vi4x4567, vi4x89AB, 2);
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x6789, vw4567, 1);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x6789, vw89AB, 2);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x6789, vwCDEF, 3);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x6789, vwKLMN, 0);
+ vo4567p00 = vfmaq_lane_f32( vo4567p00, vi4x6789, vwOP, 1);
+
+ vi0x4567 = vi0x89AB;
+ vi1x4567 = vi1x89AB;
+ vi2x4567 = vi2x89AB;
+ vi3x4567 = vi3x89AB;
+ vi4x4567 = vi4x89AB;
+
+ vo4567p00 = vaddq_f32(vo4567p00, vo4567p01);
+ float32x4_t vo0 = vo4567p00;
+
+ vo0 = vmaxq_f32(vo0, vmin);
+ vo0 = vminq_f32(vo0, vmax);
+
+ vst1q_f32(output0, vo0); output0 = (float*) ((uintptr_t) output0 + output_tuple_stride);
+ k -= 4;
+ }
+ assert(k >= 1);
+ assert(k <= 4);
+ {
+ float32x4_t vo4567p00 = vdupq_laneq_f32(vw0123, 0);
+
+ // This might have already happened if there are more than 4 pixels, but
+ // we can't count on it.
+ vi0x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi0x4567)));
+ vi1x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi1x4567)));
+ vi2x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi2x4567)));
+ vi3x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi3x4567)));
+ vi4x4567 = vreinterpretq_u32_f32(vandq_u32(vmask, vreinterpretq_f32_u32(vi4x4567)));
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x4567, vw0123, 3);
+ float32x4_t vo4567p01 = vmulq_laneq_f32(vi1x4567, vw89AB, 0);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x4567, vwCDEF, 1);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x4567, vwGHIJ, 2);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x4567, vwKLMN, 3);
+
+ const float32x4_t vi0x3456 = vextq_f32(vi0x0123, vi0x4567, 3);
+ const float32x4_t vi1x3456 = vextq_f32(vi1x0123, vi1x4567, 3);
+ const float32x4_t vi2x3456 = vextq_f32(vi2x0123, vi2x4567, 3);
+ const float32x4_t vi3x3456 = vextq_f32(vi3x0123, vi3x4567, 3);
+ const float32x4_t vi4x3456 = vextq_f32(vi4x0123, vi4x4567, 3);
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x3456, vw0123, 2);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x3456, vw4567, 3);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x3456, vwCDEF, 0);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x3456, vwGHIJ, 1);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x3456, vwKLMN, 2);
+
+ const float32x4_t vi0x2345 = vextq_f32(vi0x0123, vi0x4567, 2);
+ const float32x4_t vi1x2345 = vextq_f32(vi1x0123, vi1x4567, 2);
+ const float32x4_t vi2x2345 = vextq_f32(vi2x0123, vi2x4567, 2);
+ const float32x4_t vi3x2345 = vextq_f32(vi3x0123, vi3x4567, 2);
+ const float32x4_t vi4x2345 = vextq_f32(vi4x0123, vi4x4567, 2);
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x2345, vw0123, 1);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x2345, vw4567, 2);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x2345, vw89AB, 3);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x2345, vwGHIJ, 0);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi4x2345, vwKLMN, 1);
+
+ const float32x4_t vzero = vmovq_n_f32(0.0f);
+ const float32x4_t vi0x5678 = vextq_f32(vi0x4567, vzero, 1);
+ const float32x4_t vi1x5678 = vextq_f32(vi1x4567, vzero, 1);
+ const float32x4_t vi2x5678 = vextq_f32(vi2x4567, vzero, 1);
+ const float32x4_t vi3x5678 = vextq_f32(vi3x4567, vzero, 1);
+ const float32x4_t vi4x5678 = vextq_f32(vi4x4567, vzero, 1);
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x5678, vw4567, 0);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x5678, vw89AB, 1);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x5678, vwCDEF, 2);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x5678, vwGHIJ, 3);
+ vo4567p00 = vfmaq_lane_f32( vo4567p00, vi4x5678, vwOP, 0);
+
+ const float32x4_t vi0x6789 = vextq_f32(vi0x4567, vzero, 2);
+ const float32x4_t vi1x6789 = vextq_f32(vi1x4567, vzero, 2);
+ const float32x4_t vi2x6789 = vextq_f32(vi2x4567, vzero, 2);
+ const float32x4_t vi3x6789 = vextq_f32(vi3x4567, vzero, 2);
+ const float32x4_t vi4x6789 = vextq_f32(vi4x4567, vzero, 2);
+
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi0x6789, vw4567, 1);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi1x6789, vw89AB, 2);
+ vo4567p00 = vfmaq_laneq_f32(vo4567p00, vi2x6789, vwCDEF, 3);
+ vo4567p01 = vfmaq_laneq_f32(vo4567p01, vi3x6789, vwKLMN, 0);
+ vo4567p00 = vfmaq_lane_f32( vo4567p00, vi4x6789, vwOP, 1);
+
+ vo4567p00 = vaddq_f32(vo4567p00, vo4567p01);
+ float32x4_t vo0 = vo4567p00;
+
+ vo0 = vmaxq_f32(vo0, vmin);
+ vo0 = vminq_f32(vo0, vmax);
+
+ if XNN_LIKELY(k & 4) {
+ vst1q_f32(output0, vo0);
+ } else {
+ float* output0_lo = output0;
+ float32x2_t vo0_lo = vget_low_f32(vo0);
+ if (k & 2) {
+ vst1_f32(output0_lo, vo0_lo); output0_lo += 2;
+ vo0_lo = vget_high_f32(vo0);
+ }
+ if (k & 1) {
+ vst1_lane_f32(output0_lo, vo0_lo, 0);
+ }
+ }
+ }
+
+ i0 = (const float*) ((uintptr_t) i0 + input_width_increment_single);
+ i1 = (const float*) ((uintptr_t) i1 + input_width_increment_single);
+ i2 = (const float*) ((uintptr_t) i2 + input_width_increment_single);
+ i3 = (const float*) ((uintptr_t) i3 + input_width_increment_single);
+ i4 = (const float*) ((uintptr_t) i4 + input_width_increment_single);
+ output0 = (float*) ((uintptr_t) output0 + output_width_increment_single);
+ m -= 1;
+ } while (m > 0);
+}
diff --git a/src/f32-dwconv-spchw/5x5s2p2-neonfma.c b/src/f32-dwconv-spchw/5x5s2p2-neonfma.c
new file mode 100644
index 0000000..18fb0a2
--- /dev/null
+++ b/src/f32-dwconv-spchw/5x5s2p2-neonfma.c
@@ -0,0 +1,240 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/dwconv.h>
+#include <xnnpack/math.h>
+
+
+// Depthwise 5x5 convolution, stride 2, horizontal padding 2, for the
+// single-channel "spchw" layout. m counts output rows and n counts input
+// columns; each inner-loop iteration consumes up to 8 input columns and
+// produces up to 4 output columns. There is no vertical padding: callers
+// must supply all 5 input rows for every output row.
+void xnn_f32_dwconv_spchw_ukernel_5x5s2p2__neonfma(
+    size_t m,
+    size_t n,
+    const float* input,
+    const float* weights,
+    float* output,
+    size_t input_tuple_stride,
+    size_t output_tuple_stride,
+    size_t input_width_stride,
+    size_t output_width_stride,
+    const union xnn_f32_spchw_params params[restrict static 1])
+{
+  assert(n != 0);
+
+  // Masks that zero the out-of-range even/odd de-interleaved columns of the
+  // last (partial) tuple of a row.
+  const uint32x4_t vmask_even = vld1q_u32(params->neon.mask_even);
+  const uint32x4_t vmask_odd = vld1q_u32(params->neon.mask_odd);
+  const float32x4_t vmax = vld1q_dup_f32(&params->neon.max);
+  const float32x4_t vmin = vld1q_dup_f32(&params->neon.min);
+
+  // Byte increments from the end of one row traversal to the start of the
+  // next: input advances two rows (stride 2), output advances one row.
+  const size_t input_width_increment_single = input_width_stride * 2 - input_tuple_stride * ((n - 1) / 4 + 1);
+  const size_t output_width_increment_single = output_width_stride - (n + 1) / 8 * output_tuple_stride;
+
+  /* No vertical padding */
+  const float* i0 = input;
+  const float* i1 = (const float*) ((uintptr_t) i0 + input_width_stride);
+  const float* i2 = (const float*) ((uintptr_t) i1 + input_width_stride);
+  const float* i3 = (const float*) ((uintptr_t) i2 + input_width_stride);
+  const float* i4 = (const float*) ((uintptr_t) i3 + input_width_stride);
+
+  float* output0 = output;
+
+  // Bias plus 25 taps packed as 26 consecutive floats:
+  // weights[0] is the bias, weights[1 + 5*r + c] is tap (row r, column c).
+  const float32x4_t vw0123 = vld1q_f32(weights);
+  const float32x4_t vw4567 = vld1q_f32(weights + 4);
+  const float32x4_t vw89AB = vld1q_f32(weights + 8);
+  const float32x4_t vwCDEF = vld1q_f32(weights + 12);
+  const float32x4_t vwGHIJ = vld1q_f32(weights + 16);
+  const float32x4_t vwKLMN = vld1q_f32(weights + 20);
+  const float32x2_t vwOP = vld1_f32(weights + 24);
+
+  do {
+    // Columns to the left of the first block are the zero padding (pad = 2).
+    float32x4_t vi0x0123 = vmovq_n_f32(0.0f);
+    float32x4_t vi1x0123 = vmovq_n_f32(0.0f);
+    float32x4_t vi2x0123 = vmovq_n_f32(0.0f);
+    float32x4_t vi3x0123 = vmovq_n_f32(0.0f);
+    float32x4_t vi4x0123 = vmovq_n_f32(0.0f);
+    float32x4_t vi0x4567 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+    float32x4_t vi1x4567 = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+    float32x4_t vi2x4567 = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+    float32x4_t vi3x4567 = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride);
+    float32x4_t vi4x4567 = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride);
+
+    // Signed so the k -= 8 decrement may go negative after the last block.
+    long long k = n;
+    for (; k > 0; k -= 8) {
+      // Start the accumulator from the bias (weights[0]).
+      float32x4_t vo468Ap00 = vdupq_laneq_f32(vw0123, 0);
+
+      float32x4_t vi0x89AB;
+      float32x4_t vi1x89AB;
+      float32x4_t vi2x89AB;
+      float32x4_t vi3x89AB;
+      float32x4_t vi4x89AB;
+
+      if XNN_LIKELY(k > 4) {
+        vi0x89AB = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+        vi1x89AB = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+        vi2x89AB = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+        vi3x89AB = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride);
+        vi4x89AB = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride);
+      } else {
+        vi0x89AB = vmovq_n_f32(0.f);
+        vi1x89AB = vmovq_n_f32(0.f);
+        vi2x89AB = vmovq_n_f32(0.f);
+        vi3x89AB = vmovq_n_f32(0.f);
+        vi4x89AB = vmovq_n_f32(0.f);
+      }
+
+      float32x4_t vi0xCDEF;
+      float32x4_t vi1xCDEF;
+      float32x4_t vi2xCDEF;
+      float32x4_t vi3xCDEF;
+      float32x4_t vi4xCDEF;
+
+      if XNN_LIKELY(k > 8) {
+        vi0xCDEF = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + input_tuple_stride);
+        vi1xCDEF = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + input_tuple_stride);
+        vi2xCDEF = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + input_tuple_stride);
+        vi3xCDEF = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + input_tuple_stride);
+        vi4xCDEF = vld1q_f32(i4); i4 = (const float*) ((uintptr_t) i4 + input_tuple_stride);
+      } else {
+        vi0xCDEF = vmovq_n_f32(0.f);
+        vi1xCDEF = vmovq_n_f32(0.f);
+        vi2xCDEF = vmovq_n_f32(0.f);
+        vi3xCDEF = vmovq_n_f32(0.f);
+        vi4xCDEF = vmovq_n_f32(0.f);
+      }
+      // De-interleave into even (4,6,8,A) and odd (5,7,9,B) columns.
+      float32x4_t vi0x468A = vuzp1q_f32(vi0x4567, vi0x89AB);
+      float32x4_t vi0x579B = vuzp2q_f32(vi0x4567, vi0x89AB);
+      float32x4_t vi1x468A = vuzp1q_f32(vi1x4567, vi1x89AB);
+      float32x4_t vi1x579B = vuzp2q_f32(vi1x4567, vi1x89AB);
+      float32x4_t vi2x468A = vuzp1q_f32(vi2x4567, vi2x89AB);
+      float32x4_t vi2x579B = vuzp2q_f32(vi2x4567, vi2x89AB);
+      float32x4_t vi3x468A = vuzp1q_f32(vi3x4567, vi3x89AB);
+      float32x4_t vi3x579B = vuzp2q_f32(vi3x4567, vi3x89AB);
+      float32x4_t vi4x468A = vuzp1q_f32(vi4x4567, vi4x89AB);
+      float32x4_t vi4x579B = vuzp2q_f32(vi4x4567, vi4x89AB);
+
+      if XNN_UNLIKELY(k <= 8) {
+        // Last block of the row: zero out columns past the end.
+        // Fixed: reinterpret f32->u32 going into vandq_u32 and u32->f32
+        // coming out (the two casts were previously swapped, which is a type
+        // error for these strictly-typed ACLE intrinsics).
+        vi0x468A = vreinterpretq_f32_u32(vandq_u32(vmask_even, vreinterpretq_u32_f32(vi0x468A)));
+        vi1x468A = vreinterpretq_f32_u32(vandq_u32(vmask_even, vreinterpretq_u32_f32(vi1x468A)));
+        vi2x468A = vreinterpretq_f32_u32(vandq_u32(vmask_even, vreinterpretq_u32_f32(vi2x468A)));
+        vi3x468A = vreinterpretq_f32_u32(vandq_u32(vmask_even, vreinterpretq_u32_f32(vi3x468A)));
+        vi4x468A = vreinterpretq_f32_u32(vandq_u32(vmask_even, vreinterpretq_u32_f32(vi4x468A)));
+
+        vi0x579B = vreinterpretq_f32_u32(vandq_u32(vmask_odd, vreinterpretq_u32_f32(vi0x579B)));
+        vi1x579B = vreinterpretq_f32_u32(vandq_u32(vmask_odd, vreinterpretq_u32_f32(vi1x579B)));
+        vi2x579B = vreinterpretq_f32_u32(vandq_u32(vmask_odd, vreinterpretq_u32_f32(vi2x579B)));
+        vi3x579B = vreinterpretq_f32_u32(vandq_u32(vmask_odd, vreinterpretq_u32_f32(vi3x579B)));
+        vi4x579B = vreinterpretq_f32_u32(vandq_u32(vmask_odd, vreinterpretq_u32_f32(vi4x579B)));
+      }
+
+      // middle tap (filter column 2); two partial sums (p00/p01) shorten the
+      // FMA dependency chain.
+      vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi0x468A, vw0123, 3);
+      float32x4_t vo468Ap01 = vmulq_laneq_f32(vi1x468A, vw89AB, 0);
+      vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi2x468A, vwCDEF, 1);
+      vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi3x468A, vwGHIJ, 2);
+      vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi4x468A, vwKLMN, 3);
+
+      // one left (filter column 1)
+      const float32x4_t vi0x3579 = vextq_f32(vi0x0123, vi0x579B, 3);
+      const float32x4_t vi1x3579 = vextq_f32(vi1x0123, vi1x579B, 3);
+      const float32x4_t vi2x3579 = vextq_f32(vi2x0123, vi2x579B, 3);
+      const float32x4_t vi3x3579 = vextq_f32(vi3x0123, vi3x579B, 3);
+      const float32x4_t vi4x3579 = vextq_f32(vi4x0123, vi4x579B, 3);
+
+      vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi0x3579, vw0123, 2);
+      vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi1x3579, vw4567, 3);
+      vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi2x3579, vwCDEF, 0);
+      vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi3x3579, vwGHIJ, 1);
+      vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi4x3579, vwKLMN, 2);
+
+      // two left (filter column 0)
+      // getting the vector to use for the far left tap is annoying
+      // as we can't ext anything we currently have to get it.
+      // To do this, we get a bit ugly. Interpret the float 32x4
+      // vector as int 64x2. Then left shift by 32. Interpret
+      // again as float 32x4. Now the right most bits are what we
+      // want them to be for the following ext.
+      const float32x4_t vi0x0012 = vreinterpretq_f32_u64(vshlq_n_u64(vreinterpretq_u64_f32(vi0x0123), 32));
+      const float32x4_t vi1x0012 = vreinterpretq_f32_u64(vshlq_n_u64(vreinterpretq_u64_f32(vi1x0123), 32));
+      const float32x4_t vi2x0012 = vreinterpretq_f32_u64(vshlq_n_u64(vreinterpretq_u64_f32(vi2x0123), 32));
+      const float32x4_t vi3x0012 = vreinterpretq_f32_u64(vshlq_n_u64(vreinterpretq_u64_f32(vi3x0123), 32));
+      const float32x4_t vi4x0012 = vreinterpretq_f32_u64(vshlq_n_u64(vreinterpretq_u64_f32(vi4x0123), 32));
+
+      const float32x4_t vi0x2468 = vextq_f32(vi0x0012, vi0x468A, 3);
+      const float32x4_t vi1x2468 = vextq_f32(vi1x0012, vi1x468A, 3);
+      const float32x4_t vi2x2468 = vextq_f32(vi2x0012, vi2x468A, 3);
+      const float32x4_t vi3x2468 = vextq_f32(vi3x0012, vi3x468A, 3);
+      const float32x4_t vi4x2468 = vextq_f32(vi4x0012, vi4x468A, 3);
+
+      vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi0x2468, vw0123, 1);
+      vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi1x2468, vw4567, 2);
+      vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi2x2468, vw89AB, 3);
+      vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi3x2468, vwGHIJ, 0);
+      vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi4x2468, vwKLMN, 1);
+
+      vi0x0123 = vi0x89AB;
+      vi1x0123 = vi1x89AB;
+      vi2x0123 = vi2x89AB;
+      vi3x0123 = vi3x89AB;
+      vi4x0123 = vi4x89AB;
+
+      // one right (filter column 3)
+      vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi0x579B, vw4567, 0);
+      vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi1x579B, vw89AB, 1);
+      vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi2x579B, vwCDEF, 2);
+      vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi3x579B, vwGHIJ, 3);
+      vo468Ap00 = vfmaq_lane_f32( vo468Ap00, vi4x579B, vwOP, 0);
+
+      // two right (filter column 4)
+      const float32x4_t vi0x68AC = vextq_f32(vi0x468A, vi0xCDEF, 1);
+      const float32x4_t vi1x68AC = vextq_f32(vi1x468A, vi1xCDEF, 1);
+      const float32x4_t vi2x68AC = vextq_f32(vi2x468A, vi2xCDEF, 1);
+      const float32x4_t vi3x68AC = vextq_f32(vi3x468A, vi3xCDEF, 1);
+      const float32x4_t vi4x68AC = vextq_f32(vi4x468A, vi4xCDEF, 1);
+
+      vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi0x68AC, vw4567, 1);
+      vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi1x68AC, vw89AB, 2);
+      vo468Ap00 = vfmaq_laneq_f32(vo468Ap00, vi2x68AC, vwCDEF, 3);
+      vo468Ap01 = vfmaq_laneq_f32(vo468Ap01, vi3x68AC, vwKLMN, 0);
+      vo468Ap00 = vfmaq_lane_f32( vo468Ap00, vi4x68AC, vwOP, 1);
+
+      vi0x4567 = vi0xCDEF;
+      vi1x4567 = vi1xCDEF;
+      vi2x4567 = vi2xCDEF;
+      vi3x4567 = vi3xCDEF;
+      vi4x4567 = vi4xCDEF;
+
+      float32x4_t vo0 = vaddq_f32(vo468Ap00, vo468Ap01);
+
+      vo0 = vmaxq_f32(vo0, vmin);
+      vo0 = vminq_f32(vo0, vmax);
+
+      // This iteration produced min((k + 1) / 2, 4) output columns.
+      size_t k_tmp = (k + 1) / 2;
+      if XNN_LIKELY(k_tmp >= 4) {
+        vst1q_f32(output0, vo0);
+        output0 = (float*) ((uintptr_t) output0 + output_tuple_stride);
+      } else {
+        // Partial store of 1-3 output columns.
+        float* output0_lo = output0;
+        float32x2_t vo0_lo = vget_low_f32(vo0);
+        if (k_tmp & 2) {
+          vst1_f32(output0_lo, vo0_lo); output0_lo += 2;
+          vo0_lo = vget_high_f32(vo0);
+        }
+        if (k_tmp & 1) {
+          vst1_lane_f32(output0_lo, vo0_lo, 0);
+        }
+      }
+    }
+
+    i0 = (const float*) ((uintptr_t) i0 + input_width_increment_single);
+    i1 = (const float*) ((uintptr_t) i1 + input_width_increment_single);
+    i2 = (const float*) ((uintptr_t) i2 + input_width_increment_single);
+    i3 = (const float*) ((uintptr_t) i3 + input_width_increment_single);
+    i4 = (const float*) ((uintptr_t) i4 + input_width_increment_single);
+    output0 = (float*) ((uintptr_t) output0 + output_width_increment_single);
+    m -= 1;
+  } while (m > 0);
+}
diff --git a/src/f32-dwconv/up-neon.c.in b/src/f32-dwconv/up-neon.c.in
new file mode 100644
index 0000000..df1ecb7
--- /dev/null
+++ b/src/f32-dwconv/up-neon.c.in
@@ -0,0 +1,110 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert CR % 4 == 0
+$assert MR >= 2
+$assert AR >= 1
+$ABC = "0123456789ABCDEF"
+$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/dwconv.h>
+
+
+// xngen template for the f32 unipass ("up") depthwise-convolution NEON
+// micro-kernel. CR = channel tile (multiple of 4), MR = number of filter
+// taps, AR = number of parallel accumulators used to break up the FMA
+// dependency chain. Weights are packed as [CR biases, CR taps x MR].
+void xnn_f32_dwconv_ukernel_up${CR}x${MR}__${"neonfma" if FMA else "neon"}(
+    size_t channels,
+    size_t output_width,
+    const float** input,
+    const float* weights,
+    float* output,
+    size_t input_stride,
+    size_t output_increment,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(channels != 0);
+  assert(output_width != 0);
+
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  do {
+    $for M in range(MR):
+      const float* i${M} = input[${M}];
+    input = (const float**) ((uintptr_t) input + input_stride);
+
+    size_t c = channels;
+    const float* w = weights;
+    for (; c >= ${CR}; c -= ${CR}) {
+      $for C in range(0, CR, 4):
+        float32x4_t vacc${ABC[C:C+4]}p0 = vld1q_f32(w); w += 4;
+
+      $for M in range(MR):
+
+        $for C in range(0, CR, 4):
+          const float32x4_t vi${M}x${ABC[C:C+4]} = vld1q_f32(i${M}); i${M} += 4;
+        $for C in range(0, CR, 4):
+          const float32x4_t vk${M}x${ABC[C:C+4]} = vld1q_f32(w); w += 4;
+        $for C in range(0, CR, 4):
+          $if 1 <= M < AR:
+            float32x4_t vacc${ABC[C:C+4]}p${M} = vmulq_f32(vi${M}x${ABC[C:C+4]}, vk${M}x${ABC[C:C+4]});
+          $else:
+            vacc${ABC[C:C+4]}p${M % AR} = ${VMULADDQ_F32}(vacc${ABC[C:C+4]}p${M % AR}, vi${M}x${ABC[C:C+4]}, vk${M}x${ABC[C:C+4]});
+
+      // Tree-reduce the AR partial accumulators into vacc...p0.
+      // (Fixed: the inner channel loop below was missing its leading '$',
+      // so the generator would have emitted the raw Python 'for' line into
+      // the C output instead of expanding it.)
+      $STEPA = 1
+      $while STEPA < AR:
+        $for A in range(0, AR, STEPA * 2):
+          $if A + STEPA < AR:
+            $for C in range(0, CR, 4):
+              vacc${ABC[C:C+4]}p${A} = vaddq_f32(vacc${ABC[C:C+4]}p${A}, vacc${ABC[C:C+4]}p${A + STEPA});
+        $STEPA *= 2
+
+      $for C in range(0, CR, 4):
+        float32x4_t vacc${ABC[C:C+4]} = vmaxq_f32(vacc${ABC[C:C+4]}p0, vmin);
+      $for C in range(0, CR, 4):
+        vacc${ABC[C:C+4]} = vminq_f32(vacc${ABC[C:C+4]}, vmax);
+
+      $for C in range(0, CR, 4):
+        vst1q_f32(output, vacc${ABC[C:C+4]}); output += 4;
+    }
+    if XNN_UNLIKELY(c != 0) {
+      // Remainder of 1..CR-1 channels: compute a full channel tile, then
+      // store progressively smaller pieces of it.
+      $for C in range(0, CR, 4):
+        float32x4_t vacc${ABC[C:C+4]} = vld1q_f32(w); w += 4;
+
+      $for M in range(MR):
+
+        $for C in range(0, CR, 4):
+          const float32x4_t vi${M}x${ABC[C:C+4]} = vld1q_f32(i${M}); i${M} += 4;
+        $for C in range(0, CR, 4):
+          const float32x4_t vk${M}x${ABC[C:C+4]} = vld1q_f32(w); w += 4;
+        $for C in range(0, CR, 4):
+          vacc${ABC[C:C+4]} = ${VMULADDQ_F32}(vacc${ABC[C:C+4]}, vi${M}x${ABC[C:C+4]}, vk${M}x${ABC[C:C+4]});
+
+      $for C in range(0, CR, 4):
+        vacc${ABC[C:C+4]} = vmaxq_f32(vacc${ABC[C:C+4]}, vmin);
+      $for C in range(0, CR, 4):
+        vacc${ABC[C:C+4]} = vminq_f32(vacc${ABC[C:C+4]}, vmax);
+
+      $for LOG2C in reversed(range(CR.bit_length())):
+        $if CR != 1 << LOG2C:
+          if (c & ${1 << LOG2C}) {
+            $if LOG2C >= 2:
+              $for C in range(0, 1 << LOG2C, 4):
+                vst1q_f32(output, vacc${ABC[C:C+4]}); output += 4;
+              $for C in range(0, 1 << (LOG2C - 1), 4):
+                vacc${ABC[C:C+4]} = vacc${ABC[C + (1 << LOG2C):C + (1 << LOG2C)+4]};
+            $elif LOG2C == 1:
+              vst1_f32(output, vacc${ABC[0:2]}); output += 2;
+              vacc${ABC[0:2]} = vget_high_f32(vacc${ABC[0:4]});
+            $elif LOG2C == 0:
+              vst1_lane_f32(output, vacc${ABC[0:2]}, 0); output += 1;
+          }
+        $if LOG2C == 2:
+          float32x2_t vacc${ABC[0:2]} = vget_low_f32(vacc${ABC[0:4]});
+    }
+
+    output = (float*) ((uintptr_t) output + output_increment);
+  } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up-psimd.c.in b/src/f32-dwconv/up-psimd.c.in
new file mode 100644
index 0000000..ed4b8ff
--- /dev/null
+++ b/src/f32-dwconv/up-psimd.c.in
@@ -0,0 +1,91 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert CR == 4
+$assert MR >= 2
+$assert AR >= 1
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/dwconv.h>
+
+
+// xngen template for the f32 unipass depthwise-convolution micro-kernel
+// using the portable psimd vector API. CR = channel tile (asserted == 4),
+// MR = number of filter taps, AR = number of parallel accumulators.
+void xnn_f32_dwconv_ukernel_up${CR}x${MR}__psimd(
+    size_t channels,
+    size_t output_width,
+    const float** input,
+    const float* weights,
+    float* output,
+    size_t input_stride,
+    size_t output_increment,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(channels != 0);
+  assert(output_width != 0);
+
+  const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+  const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+  do {
+    // One input row pointer per filter tap; 'input' is an indirection buffer
+    // advanced by input_stride bytes per output pixel.
+    $for M in range(MR):
+      const float* i${M} = input[${M}];
+    input = (const float**) ((uintptr_t) input + input_stride);
+
+    size_t c = channels;
+    const float* w = weights;
+    for (; c >= 4; c -= 4) {
+      // Packed weights: [bias, k0, k1, ..., k${MR-1}] per channel tile; the
+      // bias seeds accumulator 0.
+      psimd_f32 vacc0 = psimd_load_f32(w);
+      $for M in range(MR):
+
+        const psimd_f32 vi${M} = psimd_load_f32(i${M});
+        const psimd_f32 vk${M} = psimd_load_f32(w + ${(M+1) * CR});
+        $if 1 <= M < AR:
+          psimd_f32 vacc${M} = psimd_mul_f32(vi${M}, vk${M});
+        $else:
+          vacc${M % AR} = psimd_qfma_f32(vacc${M % AR}, vi${M}, vk${M});
+        i${M} += ${CR};
+
+      w += ${(MR + 1) * CR};
+
+      // Tree-reduce the AR partial accumulators into vacc0.
+      $STEPA = 1
+      $while STEPA < AR:
+        $for A in range(0, AR, STEPA * 2):
+          $if A + STEPA < AR:
+            vacc${A} = psimd_add_f32(vacc${A}, vacc${A + STEPA});
+        $STEPA *= 2
+
+      vacc0 = psimd_max_f32(vacc0, vmin);
+      vacc0 = psimd_min_f32(vacc0, vmax);
+
+      psimd_store_f32(output, vacc0);
+      output += ${CR};
+    }
+    if XNN_UNLIKELY(c != 0) {
+      // Remainder of 1-3 channels: compute a full vector, store a subset.
+      psimd_f32 vacc = psimd_load_f32(w);
+      $for M in range(MR):
+
+        const psimd_f32 vi${M} = psimd_load_f32(i${M});
+        const psimd_f32 vk${M} = psimd_load_f32(w + ${(M+1) * CR});
+        vacc = psimd_qfma_f32(vacc, vi${M}, vk${M});
+
+      w += ${(MR + 1) * CR};
+
+      vacc = psimd_max_f32(vacc, vmin);
+      vacc = psimd_min_f32(vacc, vmax);
+
+      if (c & 2) {
+        psimd_store2_f32(output, vacc);
+        vacc = psimd_concat_hi_f32(vacc, vacc);
+        output += 2;
+      }
+      if (c & 1) {
+        psimd_store1_f32(output, vacc);
+        output += 1;
+      }
+    }
+
+    output = (float*) ((uintptr_t) output + output_increment);
+  } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up-scalar.c.in b/src/f32-dwconv/up-scalar.c.in
new file mode 100644
index 0000000..b05e545
--- /dev/null
+++ b/src/f32-dwconv/up-scalar.c.in
@@ -0,0 +1,65 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert CR == 1
+$assert MR >= 2
+$assert AR >= 1
+#include <assert.h>
+
+#include <xnnpack/dwconv.h>
+#include <xnnpack/math.h>
+
+
+// xngen template for the f32 unipass depthwise-convolution micro-kernel,
+// portable scalar version. CR = channel tile (asserted == 1), MR = number
+// of filter taps, AR = number of parallel accumulators.
+void xnn_f32_dwconv_ukernel_up${CR}x${MR}__scalar(
+    size_t channels,
+    size_t output_width,
+    const float** input,
+    const float* weights,
+    float* output,
+    size_t input_stride,
+    size_t output_increment,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(channels != 0);
+  assert(output_width != 0);
+
+  const float vmin = params->scalar.min;
+  const float vmax = params->scalar.max;
+  do {
+    // One input row pointer per filter tap; 'input' is an indirection buffer
+    // advanced by input_stride bytes per output pixel.
+    $for M in range(MR):
+      const float* i${M} = input[${M}];
+    input = (const float**) ((uintptr_t) input + input_stride);
+
+    size_t c = channels;
+    const float* w = weights;
+    do {
+      // Packed weights per channel: [bias, k0, ..., k${MR-1}]; the bias
+      // seeds accumulator 0.
+      float vacc0 = w[0];
+      $for M in range(MR):
+
+        const float vi${M} = *i${M}++;
+        const float vk${M} = w[${M+1}];
+        $if 1 <= M < AR:
+          float vacc${M} = vi${M} * vk${M};
+        $else:
+          vacc${M % AR} += vi${M} * vk${M};
+
+      w += ${MR + 1};
+
+      // Tree-reduce the AR partial accumulators into vacc0.
+      $STEPA = 1
+      $while STEPA < AR:
+        $for A in range(0, AR, STEPA * 2):
+          $if A + STEPA < AR:
+            vacc${A} += vacc${A + STEPA};
+        $STEPA *= 2
+
+      vacc0 = math_max_f32(vacc0, vmin);
+      vacc0 = math_min_f32(vacc0, vmax);
+
+      *output++ = vacc0;
+    } while (--c != 0);
+
+    output = (float*) ((uintptr_t) output + output_increment);
+  } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up-sse.c.in b/src/f32-dwconv/up-sse.c.in
new file mode 100644
index 0000000..37a1257
--- /dev/null
+++ b/src/f32-dwconv/up-sse.c.in
@@ -0,0 +1,91 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert CR == 4
+$assert MR >= 2
+$assert AR >= 1
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/dwconv.h>
+
+
+// xngen template for the f32 unipass depthwise-convolution micro-kernel,
+// SSE version. CR = channel tile (asserted == 4), MR = number of filter
+// taps, AR = number of parallel accumulators.
+void xnn_f32_dwconv_ukernel_up${CR}x${MR}__sse(
+    size_t channels,
+    size_t output_width,
+    const float** input,
+    const float* weights,
+    float* output,
+    size_t input_stride,
+    size_t output_increment,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(channels != 0);
+  assert(output_width != 0);
+
+  const __m128 vmax = _mm_load_ps(params->sse.max);
+  const __m128 vmin = _mm_load_ps(params->sse.min);
+  do {
+    // One input row pointer per filter tap; 'input' is an indirection buffer
+    // advanced by input_stride bytes per output pixel.
+    $for M in range(MR):
+      const float* i${M} = input[${M}];
+    input = (const float**) ((uintptr_t) input + input_stride);
+
+    size_t c = channels;
+    const float* w = weights;
+    for (; c >= 4; c -= 4) {
+      // NOTE(review): aligned loads (_mm_load_ps) assume the packed weights
+      // are 16-byte aligned — confirm against the weight-packing routine.
+      __m128 vacc0 = _mm_load_ps(w);
+      $for M in range(MR):
+
+        const __m128 vi${M} = _mm_loadu_ps(i${M});
+        const __m128 vk${M} = _mm_load_ps(w + ${(M+1) * CR});
+        $if 1 <= M < AR:
+          __m128 vacc${M} = _mm_mul_ps(vi${M}, vk${M});
+        $else:
+          vacc${M % AR} = _mm_add_ps(vacc${M % AR}, _mm_mul_ps(vi${M}, vk${M}));
+        i${M} += ${CR};
+
+      w += ${(MR + 1) * CR};
+
+      // Tree-reduce the AR partial accumulators into vacc0.
+      $STEPA = 1
+      $while STEPA < AR:
+        $for A in range(0, AR, STEPA * 2):
+          $if A + STEPA < AR:
+            vacc${A} = _mm_add_ps(vacc${A}, vacc${A + STEPA});
+        $STEPA *= 2
+
+      vacc0 = _mm_max_ps(vacc0, vmin);
+      vacc0 = _mm_min_ps(vacc0, vmax);
+
+      _mm_storeu_ps(output, vacc0);
+      output += ${CR};
+    }
+    if XNN_UNLIKELY(c != 0) {
+      // Remainder of 1-3 channels: compute a full vector, store a subset.
+      __m128 vacc = _mm_load_ps(w);
+      $for M in range(MR):
+
+        const __m128 vi${M} = _mm_loadu_ps(i${M});
+        const __m128 vk${M} = _mm_load_ps(w + ${(M+1) * CR});
+        vacc = _mm_add_ps(vacc, _mm_mul_ps(vi${M}, vk${M}));
+
+      w += ${(MR + 1) * CR};
+
+      vacc = _mm_max_ps(vacc, vmin);
+      vacc = _mm_min_ps(vacc, vmax);
+
+      if (c & 2) {
+        _mm_storel_pi((__m64*) output, vacc);
+        vacc = _mm_movehl_ps(vacc, vacc);
+        output += 2;
+      }
+      if (c & 1) {
+        _mm_store_ss(output, vacc);
+        output += 1;
+      }
+    }
+
+    output = (float*) ((uintptr_t) output + output_increment);
+  } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up1x25-scalar.c b/src/f32-dwconv/up1x25-scalar.c
new file mode 100644
index 0000000..c0912af
--- /dev/null
+++ b/src/f32-dwconv/up1x25-scalar.c
@@ -0,0 +1,176 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-dwconv/up-scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/dwconv.h>
+#include <xnnpack/math.h>
+
+
+// f32 depthwise convolution for a 25-tap (5x5) filter, one channel at a
+// time, portable scalar implementation. Generated from
+// src/f32-dwconv/up-scalar.c.in (CR=1, MR=25, AR=2 — two accumulators,
+// alternating even/odd taps, shorten the floating-point dependency chain).
+void xnn_f32_dwconv_ukernel_up1x25__scalar(
+    size_t channels,
+    size_t output_width,
+    const float** input,
+    const float* weights,
+    float* output,
+    size_t input_stride,
+    size_t output_increment,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(channels != 0);
+  assert(output_width != 0);
+
+  const float vmin = params->scalar.min;
+  const float vmax = params->scalar.max;
+  do {
+    // One input row pointer per filter tap, read from the indirection buffer.
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    const float* i4 = input[4];
+    const float* i5 = input[5];
+    const float* i6 = input[6];
+    const float* i7 = input[7];
+    const float* i8 = input[8];
+    const float* i9 = input[9];
+    const float* i10 = input[10];
+    const float* i11 = input[11];
+    const float* i12 = input[12];
+    const float* i13 = input[13];
+    const float* i14 = input[14];
+    const float* i15 = input[15];
+    const float* i16 = input[16];
+    const float* i17 = input[17];
+    const float* i18 = input[18];
+    const float* i19 = input[19];
+    const float* i20 = input[20];
+    const float* i21 = input[21];
+    const float* i22 = input[22];
+    const float* i23 = input[23];
+    const float* i24 = input[24];
+    input = (const float**) ((uintptr_t) input + input_stride);
+
+    size_t c = channels;
+    const float* w = weights;
+    do {
+      // Packed weights per channel: w[0] = bias, w[1..25] = taps 0..24.
+      // Even-numbered taps accumulate into vacc0 (seeded with the bias),
+      // odd-numbered taps into vacc1; the two are summed at the end.
+      float vacc0 = w[0];
+
+      const float vi0 = *i0++;
+      const float vk0 = w[1];
+      vacc0 += vi0 * vk0;
+
+      const float vi1 = *i1++;
+      const float vk1 = w[2];
+      float vacc1 = vi1 * vk1;
+
+      const float vi2 = *i2++;
+      const float vk2 = w[3];
+      vacc0 += vi2 * vk2;
+
+      const float vi3 = *i3++;
+      const float vk3 = w[4];
+      vacc1 += vi3 * vk3;
+
+      const float vi4 = *i4++;
+      const float vk4 = w[5];
+      vacc0 += vi4 * vk4;
+
+      const float vi5 = *i5++;
+      const float vk5 = w[6];
+      vacc1 += vi5 * vk5;
+
+      const float vi6 = *i6++;
+      const float vk6 = w[7];
+      vacc0 += vi6 * vk6;
+
+      const float vi7 = *i7++;
+      const float vk7 = w[8];
+      vacc1 += vi7 * vk7;
+
+      const float vi8 = *i8++;
+      const float vk8 = w[9];
+      vacc0 += vi8 * vk8;
+
+      const float vi9 = *i9++;
+      const float vk9 = w[10];
+      vacc1 += vi9 * vk9;
+
+      const float vi10 = *i10++;
+      const float vk10 = w[11];
+      vacc0 += vi10 * vk10;
+
+      const float vi11 = *i11++;
+      const float vk11 = w[12];
+      vacc1 += vi11 * vk11;
+
+      const float vi12 = *i12++;
+      const float vk12 = w[13];
+      vacc0 += vi12 * vk12;
+
+      const float vi13 = *i13++;
+      const float vk13 = w[14];
+      vacc1 += vi13 * vk13;
+
+      const float vi14 = *i14++;
+      const float vk14 = w[15];
+      vacc0 += vi14 * vk14;
+
+      const float vi15 = *i15++;
+      const float vk15 = w[16];
+      vacc1 += vi15 * vk15;
+
+      const float vi16 = *i16++;
+      const float vk16 = w[17];
+      vacc0 += vi16 * vk16;
+
+      const float vi17 = *i17++;
+      const float vk17 = w[18];
+      vacc1 += vi17 * vk17;
+
+      const float vi18 = *i18++;
+      const float vk18 = w[19];
+      vacc0 += vi18 * vk18;
+
+      const float vi19 = *i19++;
+      const float vk19 = w[20];
+      vacc1 += vi19 * vk19;
+
+      const float vi20 = *i20++;
+      const float vk20 = w[21];
+      vacc0 += vi20 * vk20;
+
+      const float vi21 = *i21++;
+      const float vk21 = w[22];
+      vacc1 += vi21 * vk21;
+
+      const float vi22 = *i22++;
+      const float vk22 = w[23];
+      vacc0 += vi22 * vk22;
+
+      const float vi23 = *i23++;
+      const float vk23 = w[24];
+      vacc1 += vi23 * vk23;
+
+      const float vi24 = *i24++;
+      const float vk24 = w[25];
+      vacc0 += vi24 * vk24;
+
+      w += 26;
+
+      // Merge the two partial sums, then clamp to [vmin, vmax].
+      vacc0 += vacc1;
+
+      vacc0 = math_max_f32(vacc0, vmin);
+      vacc0 = math_min_f32(vacc0, vmax);
+
+      *output++ = vacc0;
+    } while (--c != 0);
+
+    output = (float*) ((uintptr_t) output + output_increment);
+  } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up1x4-scalar.c b/src/f32-dwconv/up1x4-scalar.c
new file mode 100644
index 0000000..6f89579
--- /dev/null
+++ b/src/f32-dwconv/up1x4-scalar.c
@@ -0,0 +1,71 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-dwconv/up-scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/dwconv.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_dwconv_ukernel_up1x4__scalar(  // scalar depthwise conv micro-kernel: 1 channel per step, 4-tap kernel
+    size_t channels,  // number of channels to process per output pixel; must be non-zero
+    size_t output_width,  // number of output pixels; must be non-zero
+    const float** input,  // table of 4 per-tap input row pointers for the current pixel
+    const float* weights,  // packed per-channel as [bias, k0, k1, k2, k3]
+    float* output,  // destination, written channel-by-channel
+    size_t input_stride,  // byte stride between successive pixels' pointer tables in `input`
+    size_t output_increment,  // bytes added to `output` after each pixel's channels are written
+    const union xnn_f32_output_params params[restrict static 1])  // output clamping bounds
+{
+  assert(channels != 0);
+  assert(output_width != 0);
+
+  const float vmin = params->scalar.min;  // lower clamp bound
+  const float vmax = params->scalar.max;  // upper clamp bound
+  do {  // per output pixel
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    input = (const float**) ((uintptr_t) input + input_stride);  // advance to next pixel's pointer table
+
+    size_t c = channels;
+    const float* w = weights;  // packed weights restart for every pixel
+    do {  // per channel
+      float vacc0 = w[0];  // accumulator starts from the bias
+
+      const float vi0 = *i0++;
+      const float vk0 = w[1];
+      vacc0 += vi0 * vk0;
+
+      const float vi1 = *i1++;
+      const float vk1 = w[2];
+      float vacc1 = vi1 * vk1;  // second accumulator chain for instruction-level parallelism
+
+      const float vi2 = *i2++;
+      const float vk2 = w[3];
+      vacc0 += vi2 * vk2;
+
+      const float vi3 = *i3++;
+      const float vk3 = w[4];
+      vacc1 += vi3 * vk3;
+
+      w += 5;  // consumed bias + 4 taps for this channel
+
+      vacc0 += vacc1;  // merge the two accumulator chains
+
+      vacc0 = math_max_f32(vacc0, vmin);  // clamp result to [vmin, vmax]
+      vacc0 = math_min_f32(vacc0, vmax);
+
+      *output++ = vacc0;
+    } while (--c != 0);
+
+    output = (float*) ((uintptr_t) output + output_increment);  // skip to next pixel's output position
+  } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up1x9-scalar.c b/src/f32-dwconv/up1x9-scalar.c
new file mode 100644
index 0000000..69b10c1
--- /dev/null
+++ b/src/f32-dwconv/up1x9-scalar.c
@@ -0,0 +1,96 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-dwconv/up-scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/dwconv.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_dwconv_ukernel_up1x9__scalar(  // scalar depthwise conv micro-kernel: 1 channel per step, 9-tap (e.g. 3x3) kernel
+    size_t channels,  // number of channels to process per output pixel; must be non-zero
+    size_t output_width,  // number of output pixels; must be non-zero
+    const float** input,  // table of 9 per-tap input row pointers for the current pixel
+    const float* weights,  // packed per-channel as [bias, k0 .. k8]
+    float* output,  // destination, written channel-by-channel
+    size_t input_stride,  // byte stride between successive pixels' pointer tables in `input`
+    size_t output_increment,  // bytes added to `output` after each pixel's channels are written
+    const union xnn_f32_output_params params[restrict static 1])  // output clamping bounds
+{
+  assert(channels != 0);
+  assert(output_width != 0);
+
+  const float vmin = params->scalar.min;  // lower clamp bound
+  const float vmax = params->scalar.max;  // upper clamp bound
+  do {  // per output pixel
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    const float* i4 = input[4];
+    const float* i5 = input[5];
+    const float* i6 = input[6];
+    const float* i7 = input[7];
+    const float* i8 = input[8];
+    input = (const float**) ((uintptr_t) input + input_stride);  // advance to next pixel's pointer table
+
+    size_t c = channels;
+    const float* w = weights;  // packed weights restart for every pixel
+    do {  // per channel
+      float vacc0 = w[0];  // accumulator starts from the bias
+
+      const float vi0 = *i0++;
+      const float vk0 = w[1];
+      vacc0 += vi0 * vk0;
+
+      const float vi1 = *i1++;
+      const float vk1 = w[2];
+      float vacc1 = vi1 * vk1;  // second accumulator chain for instruction-level parallelism
+
+      const float vi2 = *i2++;
+      const float vk2 = w[3];
+      vacc0 += vi2 * vk2;
+
+      const float vi3 = *i3++;
+      const float vk3 = w[4];
+      vacc1 += vi3 * vk3;
+
+      const float vi4 = *i4++;
+      const float vk4 = w[5];
+      vacc0 += vi4 * vk4;
+
+      const float vi5 = *i5++;
+      const float vk5 = w[6];
+      vacc1 += vi5 * vk5;
+
+      const float vi6 = *i6++;
+      const float vk6 = w[7];
+      vacc0 += vi6 * vk6;
+
+      const float vi7 = *i7++;
+      const float vk7 = w[8];
+      vacc1 += vi7 * vk7;
+
+      const float vi8 = *i8++;
+      const float vk8 = w[9];
+      vacc0 += vi8 * vk8;
+
+      w += 10;  // consumed bias + 9 taps for this channel
+
+      vacc0 += vacc1;  // merge the two accumulator chains
+
+      vacc0 = math_max_f32(vacc0, vmin);  // clamp result to [vmin, vmax]
+      vacc0 = math_min_f32(vacc0, vmax);
+
+      *output++ = vacc0;
+    } while (--c != 0);
+
+    output = (float*) ((uintptr_t) output + output_increment);  // skip to next pixel's output position
+  } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up4x25-psimd.c b/src/f32-dwconv/up4x25-psimd.c
new file mode 100644
index 0000000..27ea94f
--- /dev/null
+++ b/src/f32-dwconv/up4x25-psimd.c
@@ -0,0 +1,321 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-dwconv/up-psimd.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/dwconv.h>
+
+
+void xnn_f32_dwconv_ukernel_up4x25__psimd(  // psimd depthwise conv micro-kernel: 4 channels per step, 25-tap (e.g. 5x5) kernel
+    size_t channels,  // number of channels per output pixel; must be non-zero
+    size_t output_width,  // number of output pixels; must be non-zero
+    const float** input,  // table of 25 per-tap input row pointers for the current pixel
+    const float* weights,  // packed per 4-channel group as [bias x4, k0 x4, ..., k24 x4]
+    float* output,  // destination, written 4 channels at a time (partials at the tail)
+    size_t input_stride,  // byte stride between successive pixels' pointer tables in `input`
+    size_t output_increment,  // bytes added to `output` after each pixel's channels are written
+    const union xnn_f32_output_params params[restrict static 1])  // output clamping bounds
+{
+  assert(channels != 0);
+  assert(output_width != 0);
+
+  const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);  // fixed: '&para' of '&params' was mangled into the pilcrow character by entity corruption
+  const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);  // fixed: same mojibake as above
+  do {
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    const float* i4 = input[4];
+    const float* i5 = input[5];
+    const float* i6 = input[6];
+    const float* i7 = input[7];
+    const float* i8 = input[8];
+    const float* i9 = input[9];
+    const float* i10 = input[10];
+    const float* i11 = input[11];
+    const float* i12 = input[12];
+    const float* i13 = input[13];
+    const float* i14 = input[14];
+    const float* i15 = input[15];
+    const float* i16 = input[16];
+    const float* i17 = input[17];
+    const float* i18 = input[18];
+    const float* i19 = input[19];
+    const float* i20 = input[20];
+    const float* i21 = input[21];
+    const float* i22 = input[22];
+    const float* i23 = input[23];
+    const float* i24 = input[24];
+    input = (const float**) ((uintptr_t) input + input_stride);  // advance to next pixel's pointer table
+
+    size_t c = channels;
+    const float* w = weights;
+    for (; c >= 4; c -= 4) {  // main loop: full groups of 4 channels
+      psimd_f32 vacc0 = psimd_load_f32(w);  // accumulator starts from the 4 biases
+
+      const psimd_f32 vi0 = psimd_load_f32(i0);
+      const psimd_f32 vk0 = psimd_load_f32(w + 4);
+      vacc0 = psimd_qfma_f32(vacc0, vi0, vk0);
+      i0 += 4;
+
+      const psimd_f32 vi1 = psimd_load_f32(i1);
+      const psimd_f32 vk1 = psimd_load_f32(w + 8);
+      psimd_f32 vacc1 = psimd_mul_f32(vi1, vk1);  // second accumulator chain for instruction-level parallelism
+      i1 += 4;
+
+      const psimd_f32 vi2 = psimd_load_f32(i2);
+      const psimd_f32 vk2 = psimd_load_f32(w + 12);
+      vacc0 = psimd_qfma_f32(vacc0, vi2, vk2);
+      i2 += 4;
+
+      const psimd_f32 vi3 = psimd_load_f32(i3);
+      const psimd_f32 vk3 = psimd_load_f32(w + 16);
+      vacc1 = psimd_qfma_f32(vacc1, vi3, vk3);
+      i3 += 4;
+
+      const psimd_f32 vi4 = psimd_load_f32(i4);
+      const psimd_f32 vk4 = psimd_load_f32(w + 20);
+      vacc0 = psimd_qfma_f32(vacc0, vi4, vk4);
+      i4 += 4;
+
+      const psimd_f32 vi5 = psimd_load_f32(i5);
+      const psimd_f32 vk5 = psimd_load_f32(w + 24);
+      vacc1 = psimd_qfma_f32(vacc1, vi5, vk5);
+      i5 += 4;
+
+      const psimd_f32 vi6 = psimd_load_f32(i6);
+      const psimd_f32 vk6 = psimd_load_f32(w + 28);
+      vacc0 = psimd_qfma_f32(vacc0, vi6, vk6);
+      i6 += 4;
+
+      const psimd_f32 vi7 = psimd_load_f32(i7);
+      const psimd_f32 vk7 = psimd_load_f32(w + 32);
+      vacc1 = psimd_qfma_f32(vacc1, vi7, vk7);
+      i7 += 4;
+
+      const psimd_f32 vi8 = psimd_load_f32(i8);
+      const psimd_f32 vk8 = psimd_load_f32(w + 36);
+      vacc0 = psimd_qfma_f32(vacc0, vi8, vk8);
+      i8 += 4;
+
+      const psimd_f32 vi9 = psimd_load_f32(i9);
+      const psimd_f32 vk9 = psimd_load_f32(w + 40);
+      vacc1 = psimd_qfma_f32(vacc1, vi9, vk9);
+      i9 += 4;
+
+      const psimd_f32 vi10 = psimd_load_f32(i10);
+      const psimd_f32 vk10 = psimd_load_f32(w + 44);
+      vacc0 = psimd_qfma_f32(vacc0, vi10, vk10);
+      i10 += 4;
+
+      const psimd_f32 vi11 = psimd_load_f32(i11);
+      const psimd_f32 vk11 = psimd_load_f32(w + 48);
+      vacc1 = psimd_qfma_f32(vacc1, vi11, vk11);
+      i11 += 4;
+
+      const psimd_f32 vi12 = psimd_load_f32(i12);
+      const psimd_f32 vk12 = psimd_load_f32(w + 52);
+      vacc0 = psimd_qfma_f32(vacc0, vi12, vk12);
+      i12 += 4;
+
+      const psimd_f32 vi13 = psimd_load_f32(i13);
+      const psimd_f32 vk13 = psimd_load_f32(w + 56);
+      vacc1 = psimd_qfma_f32(vacc1, vi13, vk13);
+      i13 += 4;
+
+      const psimd_f32 vi14 = psimd_load_f32(i14);
+      const psimd_f32 vk14 = psimd_load_f32(w + 60);
+      vacc0 = psimd_qfma_f32(vacc0, vi14, vk14);
+      i14 += 4;
+
+      const psimd_f32 vi15 = psimd_load_f32(i15);
+      const psimd_f32 vk15 = psimd_load_f32(w + 64);
+      vacc1 = psimd_qfma_f32(vacc1, vi15, vk15);
+      i15 += 4;
+
+      const psimd_f32 vi16 = psimd_load_f32(i16);
+      const psimd_f32 vk16 = psimd_load_f32(w + 68);
+      vacc0 = psimd_qfma_f32(vacc0, vi16, vk16);
+      i16 += 4;
+
+      const psimd_f32 vi17 = psimd_load_f32(i17);
+      const psimd_f32 vk17 = psimd_load_f32(w + 72);
+      vacc1 = psimd_qfma_f32(vacc1, vi17, vk17);
+      i17 += 4;
+
+      const psimd_f32 vi18 = psimd_load_f32(i18);
+      const psimd_f32 vk18 = psimd_load_f32(w + 76);
+      vacc0 = psimd_qfma_f32(vacc0, vi18, vk18);
+      i18 += 4;
+
+      const psimd_f32 vi19 = psimd_load_f32(i19);
+      const psimd_f32 vk19 = psimd_load_f32(w + 80);
+      vacc1 = psimd_qfma_f32(vacc1, vi19, vk19);
+      i19 += 4;
+
+      const psimd_f32 vi20 = psimd_load_f32(i20);
+      const psimd_f32 vk20 = psimd_load_f32(w + 84);
+      vacc0 = psimd_qfma_f32(vacc0, vi20, vk20);
+      i20 += 4;
+
+      const psimd_f32 vi21 = psimd_load_f32(i21);
+      const psimd_f32 vk21 = psimd_load_f32(w + 88);
+      vacc1 = psimd_qfma_f32(vacc1, vi21, vk21);
+      i21 += 4;
+
+      const psimd_f32 vi22 = psimd_load_f32(i22);
+      const psimd_f32 vk22 = psimd_load_f32(w + 92);
+      vacc0 = psimd_qfma_f32(vacc0, vi22, vk22);
+      i22 += 4;
+
+      const psimd_f32 vi23 = psimd_load_f32(i23);
+      const psimd_f32 vk23 = psimd_load_f32(w + 96);
+      vacc1 = psimd_qfma_f32(vacc1, vi23, vk23);
+      i23 += 4;
+
+      const psimd_f32 vi24 = psimd_load_f32(i24);
+      const psimd_f32 vk24 = psimd_load_f32(w + 100);
+      vacc0 = psimd_qfma_f32(vacc0, vi24, vk24);
+      i24 += 4;
+
+      w += 104;  // consumed 4 biases + 25 taps x 4 channels
+
+      vacc0 = psimd_add_f32(vacc0, vacc1);  // merge the two accumulator chains
+
+      vacc0 = psimd_max_f32(vacc0, vmin);  // clamp to [vmin, vmax]
+      vacc0 = psimd_min_f32(vacc0, vmax);
+
+      psimd_store_f32(output, vacc0);
+      output += 4;
+    }
+    if XNN_UNLIKELY(c != 0) {  // tail: 1-3 remaining channels; full 4-wide loads, partial store (inputs presumably padded - inherited from template)
+      psimd_f32 vacc = psimd_load_f32(w);
+
+      const psimd_f32 vi0 = psimd_load_f32(i0);
+      const psimd_f32 vk0 = psimd_load_f32(w + 4);
+      vacc = psimd_qfma_f32(vacc, vi0, vk0);
+
+      const psimd_f32 vi1 = psimd_load_f32(i1);
+      const psimd_f32 vk1 = psimd_load_f32(w + 8);
+      vacc = psimd_qfma_f32(vacc, vi1, vk1);
+
+      const psimd_f32 vi2 = psimd_load_f32(i2);
+      const psimd_f32 vk2 = psimd_load_f32(w + 12);
+      vacc = psimd_qfma_f32(vacc, vi2, vk2);
+
+      const psimd_f32 vi3 = psimd_load_f32(i3);
+      const psimd_f32 vk3 = psimd_load_f32(w + 16);
+      vacc = psimd_qfma_f32(vacc, vi3, vk3);
+
+      const psimd_f32 vi4 = psimd_load_f32(i4);
+      const psimd_f32 vk4 = psimd_load_f32(w + 20);
+      vacc = psimd_qfma_f32(vacc, vi4, vk4);
+
+      const psimd_f32 vi5 = psimd_load_f32(i5);
+      const psimd_f32 vk5 = psimd_load_f32(w + 24);
+      vacc = psimd_qfma_f32(vacc, vi5, vk5);
+
+      const psimd_f32 vi6 = psimd_load_f32(i6);
+      const psimd_f32 vk6 = psimd_load_f32(w + 28);
+      vacc = psimd_qfma_f32(vacc, vi6, vk6);
+
+      const psimd_f32 vi7 = psimd_load_f32(i7);
+      const psimd_f32 vk7 = psimd_load_f32(w + 32);
+      vacc = psimd_qfma_f32(vacc, vi7, vk7);
+
+      const psimd_f32 vi8 = psimd_load_f32(i8);
+      const psimd_f32 vk8 = psimd_load_f32(w + 36);
+      vacc = psimd_qfma_f32(vacc, vi8, vk8);
+
+      const psimd_f32 vi9 = psimd_load_f32(i9);
+      const psimd_f32 vk9 = psimd_load_f32(w + 40);
+      vacc = psimd_qfma_f32(vacc, vi9, vk9);
+
+      const psimd_f32 vi10 = psimd_load_f32(i10);
+      const psimd_f32 vk10 = psimd_load_f32(w + 44);
+      vacc = psimd_qfma_f32(vacc, vi10, vk10);
+
+      const psimd_f32 vi11 = psimd_load_f32(i11);
+      const psimd_f32 vk11 = psimd_load_f32(w + 48);
+      vacc = psimd_qfma_f32(vacc, vi11, vk11);
+
+      const psimd_f32 vi12 = psimd_load_f32(i12);
+      const psimd_f32 vk12 = psimd_load_f32(w + 52);
+      vacc = psimd_qfma_f32(vacc, vi12, vk12);
+
+      const psimd_f32 vi13 = psimd_load_f32(i13);
+      const psimd_f32 vk13 = psimd_load_f32(w + 56);
+      vacc = psimd_qfma_f32(vacc, vi13, vk13);
+
+      const psimd_f32 vi14 = psimd_load_f32(i14);
+      const psimd_f32 vk14 = psimd_load_f32(w + 60);
+      vacc = psimd_qfma_f32(vacc, vi14, vk14);
+
+      const psimd_f32 vi15 = psimd_load_f32(i15);
+      const psimd_f32 vk15 = psimd_load_f32(w + 64);
+      vacc = psimd_qfma_f32(vacc, vi15, vk15);
+
+      const psimd_f32 vi16 = psimd_load_f32(i16);
+      const psimd_f32 vk16 = psimd_load_f32(w + 68);
+      vacc = psimd_qfma_f32(vacc, vi16, vk16);
+
+      const psimd_f32 vi17 = psimd_load_f32(i17);
+      const psimd_f32 vk17 = psimd_load_f32(w + 72);
+      vacc = psimd_qfma_f32(vacc, vi17, vk17);
+
+      const psimd_f32 vi18 = psimd_load_f32(i18);
+      const psimd_f32 vk18 = psimd_load_f32(w + 76);
+      vacc = psimd_qfma_f32(vacc, vi18, vk18);
+
+      const psimd_f32 vi19 = psimd_load_f32(i19);
+      const psimd_f32 vk19 = psimd_load_f32(w + 80);
+      vacc = psimd_qfma_f32(vacc, vi19, vk19);
+
+      const psimd_f32 vi20 = psimd_load_f32(i20);
+      const psimd_f32 vk20 = psimd_load_f32(w + 84);
+      vacc = psimd_qfma_f32(vacc, vi20, vk20);
+
+      const psimd_f32 vi21 = psimd_load_f32(i21);
+      const psimd_f32 vk21 = psimd_load_f32(w + 88);
+      vacc = psimd_qfma_f32(vacc, vi21, vk21);
+
+      const psimd_f32 vi22 = psimd_load_f32(i22);
+      const psimd_f32 vk22 = psimd_load_f32(w + 92);
+      vacc = psimd_qfma_f32(vacc, vi22, vk22);
+
+      const psimd_f32 vi23 = psimd_load_f32(i23);
+      const psimd_f32 vk23 = psimd_load_f32(w + 96);
+      vacc = psimd_qfma_f32(vacc, vi23, vk23);
+
+      const psimd_f32 vi24 = psimd_load_f32(i24);
+      const psimd_f32 vk24 = psimd_load_f32(w + 100);
+      vacc = psimd_qfma_f32(vacc, vi24, vk24);
+
+      w += 104;
+
+      vacc = psimd_max_f32(vacc, vmin);
+      vacc = psimd_min_f32(vacc, vmax);
+
+      if (c & 2) {  // store low 2 lanes, then shift high lanes down
+        psimd_store2_f32(output, vacc);
+        vacc = psimd_concat_hi_f32(vacc, vacc);
+        output += 2;
+      }
+      if (c & 1) {  // store 1 remaining lane
+        psimd_store1_f32(output, vacc);
+        output += 1;
+      }
+    }
+
+    output = (float*) ((uintptr_t) output + output_increment);  // skip to next pixel's output position
+  } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up4x25-sse.c b/src/f32-dwconv/up4x25-sse.c
new file mode 100644
index 0000000..be8b1e9
--- /dev/null
+++ b/src/f32-dwconv/up4x25-sse.c
@@ -0,0 +1,321 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-dwconv/up-sse.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/dwconv.h>
+
+
+void xnn_f32_dwconv_ukernel_up4x25__sse(  // SSE depthwise conv micro-kernel: 4 channels per step, 25-tap (e.g. 5x5) kernel
+    size_t channels,  // number of channels per output pixel; must be non-zero
+    size_t output_width,  // number of output pixels; must be non-zero
+    const float** input,  // table of 25 per-tap input row pointers for the current pixel
+    const float* weights,  // packed per 4-channel group as [bias x4, k0 x4, ..., k24 x4]; 16-byte aligned (loaded with _mm_load_ps)
+    float* output,  // destination, written 4 channels at a time (partials at the tail)
+    size_t input_stride,  // byte stride between successive pixels' pointer tables in `input`
+    size_t output_increment,  // bytes added to `output` after each pixel's channels are written
+    const union xnn_f32_output_params params[restrict static 1])  // output clamping bounds (pre-broadcast float[4], aligned)
+{
+  assert(channels != 0);
+  assert(output_width != 0);
+
+  const __m128 vmax = _mm_load_ps(params->sse.max);  // upper clamp bound, one copy per lane
+  const __m128 vmin = _mm_load_ps(params->sse.min);  // lower clamp bound, one copy per lane
+  do {
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    const float* i4 = input[4];
+    const float* i5 = input[5];
+    const float* i6 = input[6];
+    const float* i7 = input[7];
+    const float* i8 = input[8];
+    const float* i9 = input[9];
+    const float* i10 = input[10];
+    const float* i11 = input[11];
+    const float* i12 = input[12];
+    const float* i13 = input[13];
+    const float* i14 = input[14];
+    const float* i15 = input[15];
+    const float* i16 = input[16];
+    const float* i17 = input[17];
+    const float* i18 = input[18];
+    const float* i19 = input[19];
+    const float* i20 = input[20];
+    const float* i21 = input[21];
+    const float* i22 = input[22];
+    const float* i23 = input[23];
+    const float* i24 = input[24];
+    input = (const float**) ((uintptr_t) input + input_stride);  // advance to next pixel's pointer table
+
+    size_t c = channels;
+    const float* w = weights;
+    for (; c >= 4; c -= 4) {  // main loop: full groups of 4 channels
+      __m128 vacc0 = _mm_load_ps(w);  // accumulator starts from the 4 biases
+
+      const __m128 vi0 = _mm_loadu_ps(i0);
+      const __m128 vk0 = _mm_load_ps(w + 4);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi0, vk0));
+      i0 += 4;
+
+      const __m128 vi1 = _mm_loadu_ps(i1);
+      const __m128 vk1 = _mm_load_ps(w + 8);
+      __m128 vacc1 = _mm_mul_ps(vi1, vk1);  // second accumulator chain for instruction-level parallelism
+      i1 += 4;
+
+      const __m128 vi2 = _mm_loadu_ps(i2);
+      const __m128 vk2 = _mm_load_ps(w + 12);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi2, vk2));
+      i2 += 4;
+
+      const __m128 vi3 = _mm_loadu_ps(i3);
+      const __m128 vk3 = _mm_load_ps(w + 16);
+      vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi3, vk3));
+      i3 += 4;
+
+      const __m128 vi4 = _mm_loadu_ps(i4);
+      const __m128 vk4 = _mm_load_ps(w + 20);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi4, vk4));
+      i4 += 4;
+
+      const __m128 vi5 = _mm_loadu_ps(i5);
+      const __m128 vk5 = _mm_load_ps(w + 24);
+      vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi5, vk5));
+      i5 += 4;
+
+      const __m128 vi6 = _mm_loadu_ps(i6);
+      const __m128 vk6 = _mm_load_ps(w + 28);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi6, vk6));
+      i6 += 4;
+
+      const __m128 vi7 = _mm_loadu_ps(i7);
+      const __m128 vk7 = _mm_load_ps(w + 32);
+      vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi7, vk7));
+      i7 += 4;
+
+      const __m128 vi8 = _mm_loadu_ps(i8);
+      const __m128 vk8 = _mm_load_ps(w + 36);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi8, vk8));
+      i8 += 4;
+
+      const __m128 vi9 = _mm_loadu_ps(i9);
+      const __m128 vk9 = _mm_load_ps(w + 40);
+      vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi9, vk9));
+      i9 += 4;
+
+      const __m128 vi10 = _mm_loadu_ps(i10);
+      const __m128 vk10 = _mm_load_ps(w + 44);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi10, vk10));
+      i10 += 4;
+
+      const __m128 vi11 = _mm_loadu_ps(i11);
+      const __m128 vk11 = _mm_load_ps(w + 48);
+      vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi11, vk11));
+      i11 += 4;
+
+      const __m128 vi12 = _mm_loadu_ps(i12);
+      const __m128 vk12 = _mm_load_ps(w + 52);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi12, vk12));
+      i12 += 4;
+
+      const __m128 vi13 = _mm_loadu_ps(i13);
+      const __m128 vk13 = _mm_load_ps(w + 56);
+      vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi13, vk13));
+      i13 += 4;
+
+      const __m128 vi14 = _mm_loadu_ps(i14);
+      const __m128 vk14 = _mm_load_ps(w + 60);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi14, vk14));
+      i14 += 4;
+
+      const __m128 vi15 = _mm_loadu_ps(i15);
+      const __m128 vk15 = _mm_load_ps(w + 64);
+      vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi15, vk15));
+      i15 += 4;
+
+      const __m128 vi16 = _mm_loadu_ps(i16);
+      const __m128 vk16 = _mm_load_ps(w + 68);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi16, vk16));
+      i16 += 4;
+
+      const __m128 vi17 = _mm_loadu_ps(i17);
+      const __m128 vk17 = _mm_load_ps(w + 72);
+      vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi17, vk17));
+      i17 += 4;
+
+      const __m128 vi18 = _mm_loadu_ps(i18);
+      const __m128 vk18 = _mm_load_ps(w + 76);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi18, vk18));
+      i18 += 4;
+
+      const __m128 vi19 = _mm_loadu_ps(i19);
+      const __m128 vk19 = _mm_load_ps(w + 80);
+      vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi19, vk19));
+      i19 += 4;
+
+      const __m128 vi20 = _mm_loadu_ps(i20);
+      const __m128 vk20 = _mm_load_ps(w + 84);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi20, vk20));
+      i20 += 4;
+
+      const __m128 vi21 = _mm_loadu_ps(i21);
+      const __m128 vk21 = _mm_load_ps(w + 88);
+      vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi21, vk21));
+      i21 += 4;
+
+      const __m128 vi22 = _mm_loadu_ps(i22);
+      const __m128 vk22 = _mm_load_ps(w + 92);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi22, vk22));
+      i22 += 4;
+
+      const __m128 vi23 = _mm_loadu_ps(i23);
+      const __m128 vk23 = _mm_load_ps(w + 96);
+      vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi23, vk23));
+      i23 += 4;
+
+      const __m128 vi24 = _mm_loadu_ps(i24);
+      const __m128 vk24 = _mm_load_ps(w + 100);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi24, vk24));
+      i24 += 4;
+
+      w += 104;  // consumed 4 biases + 25 taps x 4 channels
+
+      vacc0 = _mm_add_ps(vacc0, vacc1);  // merge the two accumulator chains
+
+      vacc0 = _mm_max_ps(vacc0, vmin);  // clamp to [vmin, vmax]
+      vacc0 = _mm_min_ps(vacc0, vmax);
+
+      _mm_storeu_ps(output, vacc0);
+      output += 4;
+    }
+    if XNN_UNLIKELY(c != 0) {  // tail: 1-3 remaining channels; full 4-wide loads, partial store (inputs presumably padded - inherited from template)
+      __m128 vacc = _mm_load_ps(w);
+
+      const __m128 vi0 = _mm_loadu_ps(i0);
+      const __m128 vk0 = _mm_load_ps(w + 4);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi0, vk0));
+
+      const __m128 vi1 = _mm_loadu_ps(i1);
+      const __m128 vk1 = _mm_load_ps(w + 8);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi1, vk1));
+
+      const __m128 vi2 = _mm_loadu_ps(i2);
+      const __m128 vk2 = _mm_load_ps(w + 12);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi2, vk2));
+
+      const __m128 vi3 = _mm_loadu_ps(i3);
+      const __m128 vk3 = _mm_load_ps(w + 16);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi3, vk3));
+
+      const __m128 vi4 = _mm_loadu_ps(i4);
+      const __m128 vk4 = _mm_load_ps(w + 20);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi4, vk4));
+
+      const __m128 vi5 = _mm_loadu_ps(i5);
+      const __m128 vk5 = _mm_load_ps(w + 24);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi5, vk5));
+
+      const __m128 vi6 = _mm_loadu_ps(i6);
+      const __m128 vk6 = _mm_load_ps(w + 28);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi6, vk6));
+
+      const __m128 vi7 = _mm_loadu_ps(i7);
+      const __m128 vk7 = _mm_load_ps(w + 32);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi7, vk7));
+
+      const __m128 vi8 = _mm_loadu_ps(i8);
+      const __m128 vk8 = _mm_load_ps(w + 36);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi8, vk8));
+
+      const __m128 vi9 = _mm_loadu_ps(i9);
+      const __m128 vk9 = _mm_load_ps(w + 40);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi9, vk9));
+
+      const __m128 vi10 = _mm_loadu_ps(i10);
+      const __m128 vk10 = _mm_load_ps(w + 44);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi10, vk10));
+
+      const __m128 vi11 = _mm_loadu_ps(i11);
+      const __m128 vk11 = _mm_load_ps(w + 48);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi11, vk11));
+
+      const __m128 vi12 = _mm_loadu_ps(i12);
+      const __m128 vk12 = _mm_load_ps(w + 52);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi12, vk12));
+
+      const __m128 vi13 = _mm_loadu_ps(i13);
+      const __m128 vk13 = _mm_load_ps(w + 56);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi13, vk13));
+
+      const __m128 vi14 = _mm_loadu_ps(i14);
+      const __m128 vk14 = _mm_load_ps(w + 60);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi14, vk14));
+
+      const __m128 vi15 = _mm_loadu_ps(i15);
+      const __m128 vk15 = _mm_load_ps(w + 64);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi15, vk15));
+
+      const __m128 vi16 = _mm_loadu_ps(i16);
+      const __m128 vk16 = _mm_load_ps(w + 68);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi16, vk16));
+
+      const __m128 vi17 = _mm_loadu_ps(i17);
+      const __m128 vk17 = _mm_load_ps(w + 72);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi17, vk17));
+
+      const __m128 vi18 = _mm_loadu_ps(i18);
+      const __m128 vk18 = _mm_load_ps(w + 76);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi18, vk18));
+
+      const __m128 vi19 = _mm_loadu_ps(i19);
+      const __m128 vk19 = _mm_load_ps(w + 80);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi19, vk19));
+
+      const __m128 vi20 = _mm_loadu_ps(i20);
+      const __m128 vk20 = _mm_load_ps(w + 84);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi20, vk20));
+
+      const __m128 vi21 = _mm_loadu_ps(i21);
+      const __m128 vk21 = _mm_load_ps(w + 88);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi21, vk21));
+
+      const __m128 vi22 = _mm_loadu_ps(i22);
+      const __m128 vk22 = _mm_load_ps(w + 92);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi22, vk22));
+
+      const __m128 vi23 = _mm_loadu_ps(i23);
+      const __m128 vk23 = _mm_load_ps(w + 96);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi23, vk23));
+
+      const __m128 vi24 = _mm_loadu_ps(i24);
+      const __m128 vk24 = _mm_load_ps(w + 100);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi24, vk24));
+
+      w += 104;
+
+      vacc = _mm_max_ps(vacc, vmin);
+      vacc = _mm_min_ps(vacc, vmax);
+
+      if (c & 2) {  // store low 2 lanes, then move high lanes down
+        _mm_storel_pi((__m64*) output, vacc);
+        vacc = _mm_movehl_ps(vacc, vacc);
+        output += 2;
+      }
+      if (c & 1) {  // store 1 remaining lane
+        _mm_store_ss(output, vacc);
+        output += 1;
+      }
+    }
+
+    output = (float*) ((uintptr_t) output + output_increment);  // skip to next pixel's output position
+  } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up4x4-psimd.c b/src/f32-dwconv/up4x4-psimd.c
new file mode 100644
index 0000000..204c00f
--- /dev/null
+++ b/src/f32-dwconv/up4x4-psimd.c
@@ -0,0 +1,111 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-dwconv/up-psimd.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/dwconv.h>
+
+
+void xnn_f32_dwconv_ukernel_up4x4__psimd(  // psimd depthwise conv micro-kernel: 4 channels per step, 4-tap kernel
+    size_t channels,  // number of channels per output pixel; must be non-zero
+    size_t output_width,  // number of output pixels; must be non-zero
+    const float** input,  // table of 4 per-tap input row pointers for the current pixel
+    const float* weights,  // packed per 4-channel group as [bias x4, k0 x4, ..., k3 x4]
+    float* output,  // destination, written 4 channels at a time (partials at the tail)
+    size_t input_stride,  // byte stride between successive pixels' pointer tables in `input`
+    size_t output_increment,  // bytes added to `output` after each pixel's channels are written
+    const union xnn_f32_output_params params[restrict static 1])  // output clamping bounds
+{
+  assert(channels != 0);
+  assert(output_width != 0);
+
+  const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);  // fixed: '&para' of '&params' was mangled into the pilcrow character by entity corruption
+  const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);  // fixed: same mojibake as above
+  do {
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    input = (const float**) ((uintptr_t) input + input_stride);  // advance to next pixel's pointer table
+
+    size_t c = channels;
+    const float* w = weights;
+    for (; c >= 4; c -= 4) {  // main loop: full groups of 4 channels
+      psimd_f32 vacc0 = psimd_load_f32(w);  // accumulator starts from the 4 biases
+
+      const psimd_f32 vi0 = psimd_load_f32(i0);
+      const psimd_f32 vk0 = psimd_load_f32(w + 4);
+      vacc0 = psimd_qfma_f32(vacc0, vi0, vk0);
+      i0 += 4;
+
+      const psimd_f32 vi1 = psimd_load_f32(i1);
+      const psimd_f32 vk1 = psimd_load_f32(w + 8);
+      psimd_f32 vacc1 = psimd_mul_f32(vi1, vk1);  // second accumulator chain for instruction-level parallelism
+      i1 += 4;
+
+      const psimd_f32 vi2 = psimd_load_f32(i2);
+      const psimd_f32 vk2 = psimd_load_f32(w + 12);
+      vacc0 = psimd_qfma_f32(vacc0, vi2, vk2);
+      i2 += 4;
+
+      const psimd_f32 vi3 = psimd_load_f32(i3);
+      const psimd_f32 vk3 = psimd_load_f32(w + 16);
+      vacc1 = psimd_qfma_f32(vacc1, vi3, vk3);
+      i3 += 4;
+
+      w += 20;  // consumed 4 biases + 4 taps x 4 channels
+
+      vacc0 = psimd_add_f32(vacc0, vacc1);  // merge the two accumulator chains
+
+      vacc0 = psimd_max_f32(vacc0, vmin);  // clamp to [vmin, vmax]
+      vacc0 = psimd_min_f32(vacc0, vmax);
+
+      psimd_store_f32(output, vacc0);
+      output += 4;
+    }
+    if XNN_UNLIKELY(c != 0) {  // tail: 1-3 remaining channels; full 4-wide loads, partial store (inputs presumably padded - inherited from template)
+      psimd_f32 vacc = psimd_load_f32(w);
+
+      const psimd_f32 vi0 = psimd_load_f32(i0);
+      const psimd_f32 vk0 = psimd_load_f32(w + 4);
+      vacc = psimd_qfma_f32(vacc, vi0, vk0);
+
+      const psimd_f32 vi1 = psimd_load_f32(i1);
+      const psimd_f32 vk1 = psimd_load_f32(w + 8);
+      vacc = psimd_qfma_f32(vacc, vi1, vk1);
+
+      const psimd_f32 vi2 = psimd_load_f32(i2);
+      const psimd_f32 vk2 = psimd_load_f32(w + 12);
+      vacc = psimd_qfma_f32(vacc, vi2, vk2);
+
+      const psimd_f32 vi3 = psimd_load_f32(i3);
+      const psimd_f32 vk3 = psimd_load_f32(w + 16);
+      vacc = psimd_qfma_f32(vacc, vi3, vk3);
+
+      w += 20;
+
+      vacc = psimd_max_f32(vacc, vmin);
+      vacc = psimd_min_f32(vacc, vmax);
+
+      if (c & 2) {  // store low 2 lanes, then shift high lanes down
+        psimd_store2_f32(output, vacc);
+        vacc = psimd_concat_hi_f32(vacc, vacc);
+        output += 2;
+      }
+      if (c & 1) {  // store 1 remaining lane
+        psimd_store1_f32(output, vacc);
+        output += 1;
+      }
+    }
+
+    output = (float*) ((uintptr_t) output + output_increment);  // skip to next pixel's output position
+  } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up4x4-sse.c b/src/f32-dwconv/up4x4-sse.c
new file mode 100644
index 0000000..e2353b1
--- /dev/null
+++ b/src/f32-dwconv/up4x4-sse.c
@@ -0,0 +1,111 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-dwconv/up-sse.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/dwconv.h>
+
+
+void xnn_f32_dwconv_ukernel_up4x4__sse(  // SSE depthwise conv micro-kernel: 4 channels per step, 4-tap kernel
+    size_t channels,  // number of channels per output pixel; must be non-zero
+    size_t output_width,  // number of output pixels; must be non-zero
+    const float** input,  // table of 4 per-tap input row pointers for the current pixel
+    const float* weights,  // packed per 4-channel group as [bias x4, k0 x4, ..., k3 x4]; 16-byte aligned (loaded with _mm_load_ps)
+    float* output,  // destination, written 4 channels at a time (partials at the tail)
+    size_t input_stride,  // byte stride between successive pixels' pointer tables in `input`
+    size_t output_increment,  // bytes added to `output` after each pixel's channels are written
+    const union xnn_f32_output_params params[restrict static 1])  // output clamping bounds (pre-broadcast float[4], aligned)
+{
+  assert(channels != 0);
+  assert(output_width != 0);
+
+  const __m128 vmax = _mm_load_ps(params->sse.max);  // upper clamp bound, one copy per lane
+  const __m128 vmin = _mm_load_ps(params->sse.min);  // lower clamp bound, one copy per lane
+  do {
+    const float* i0 = input[0];
+    const float* i1 = input[1];
+    const float* i2 = input[2];
+    const float* i3 = input[3];
+    input = (const float**) ((uintptr_t) input + input_stride);  // advance to next pixel's pointer table
+
+    size_t c = channels;
+    const float* w = weights;
+    for (; c >= 4; c -= 4) {  // main loop: full groups of 4 channels
+      __m128 vacc0 = _mm_load_ps(w);  // accumulator starts from the 4 biases
+
+      const __m128 vi0 = _mm_loadu_ps(i0);
+      const __m128 vk0 = _mm_load_ps(w + 4);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi0, vk0));
+      i0 += 4;
+
+      const __m128 vi1 = _mm_loadu_ps(i1);
+      const __m128 vk1 = _mm_load_ps(w + 8);
+      __m128 vacc1 = _mm_mul_ps(vi1, vk1);  // second accumulator chain for instruction-level parallelism
+      i1 += 4;
+
+      const __m128 vi2 = _mm_loadu_ps(i2);
+      const __m128 vk2 = _mm_load_ps(w + 12);
+      vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi2, vk2));
+      i2 += 4;
+
+      const __m128 vi3 = _mm_loadu_ps(i3);
+      const __m128 vk3 = _mm_load_ps(w + 16);
+      vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi3, vk3));
+      i3 += 4;
+
+      w += 20;  // consumed 4 biases + 4 taps x 4 channels
+
+      vacc0 = _mm_add_ps(vacc0, vacc1);  // merge the two accumulator chains
+
+      vacc0 = _mm_max_ps(vacc0, vmin);  // clamp to [vmin, vmax]
+      vacc0 = _mm_min_ps(vacc0, vmax);
+
+      _mm_storeu_ps(output, vacc0);
+      output += 4;
+    }
+    if XNN_UNLIKELY(c != 0) {  // tail: 1-3 remaining channels; full 4-wide loads, partial store (inputs presumably padded - inherited from template)
+      __m128 vacc = _mm_load_ps(w);
+
+      const __m128 vi0 = _mm_loadu_ps(i0);
+      const __m128 vk0 = _mm_load_ps(w + 4);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi0, vk0));
+
+      const __m128 vi1 = _mm_loadu_ps(i1);
+      const __m128 vk1 = _mm_load_ps(w + 8);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi1, vk1));
+
+      const __m128 vi2 = _mm_loadu_ps(i2);
+      const __m128 vk2 = _mm_load_ps(w + 12);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi2, vk2));
+
+      const __m128 vi3 = _mm_loadu_ps(i3);
+      const __m128 vk3 = _mm_load_ps(w + 16);
+      vacc = _mm_add_ps(vacc, _mm_mul_ps(vi3, vk3));
+
+      w += 20;
+
+      vacc = _mm_max_ps(vacc, vmin);
+      vacc = _mm_min_ps(vacc, vmax);
+
+      if (c & 2) {  // store low 2 lanes, then move high lanes down
+        _mm_storel_pi((__m64*) output, vacc);
+        vacc = _mm_movehl_ps(vacc, vacc);
+        output += 2;
+      }
+      if (c & 1) {  // store 1 remaining lane
+        _mm_store_ss(output, vacc);
+        output += 1;
+      }
+    }
+
+    output = (float*) ((uintptr_t) output + output_increment);  // skip to next pixel's output position
+  } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up4x9-aarch64-neonfma-cortex-a55.S b/src/f32-dwconv/up4x9-aarch64-neonfma-cortex-a55.S
new file mode 100644
index 0000000..01cd48d
--- /dev/null
+++ b/src/f32-dwconv/up4x9-aarch64-neonfma-cortex-a55.S
@@ -0,0 +1,832 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_dwconv_ukernel_up4x9__aarch64_neonfma_cortex_a55(
+# size_t channels,
+# size_t output_width,
+# const float** input,
+# const float* weights,
+# float* output,
+# size_t input_stride,
+# size_t output_increment,
+# const union xnn_f32_output_params params[restrict static 1])
+BEGIN_FUNCTION xnn_f32_dwconv_ukernel_up4x9__aarch64_neonfma_cortex_a55
+
+ # Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ STP d10, d11, [sp, 16]
+ STP d12, d13, [sp, 32]
+ STP d14, d15, [sp, 48]
+
+ # v30.4S = vmax
+ LD1R {v30.4S}, [x7], 4
+ # v31.4S = vmin
+ LD1R {v31.4S}, [x7]
+
+0:
+ # x7 := i0
+ # x8 := i1
+ LDP x7, x8, [x2]
+ # x9 := i2
+ # x10 := i3
+ LDP x9, x10, [x2, 16]
+ # x11 := i4
+ # x12 := i5
+ LDP x11, x12, [x2, 32]
+ # x13 := i6
+ # x14 := i7
+ LDP x13, x14, [x2, 48]
+ # x15 := i8
+ LDR x15, [x2, 64]
+ # input += input_stride
+ ADD x2, x2, x5
+
+ # x16 := c = channels
+ # c -= 8
+ SUBS x16, x0, 8
+ # x17 := w = weights
+ MOV x17, x3
+
+ # skip main loop if c < 8
+ B.LO 3f
+
+ # SWP prologue
+
+ # Load vbias.lo
+ LD1 {v0.2S}, [x17], 8
+
+ # Load vbias.hi
+ LD1 {v1.2S}, [x17], 8
+
+ # Load vi0.lo
+ LD1 {v4.2S}, [x7], 8
+
+ # Load vk0.lo
+ LD1 {v5.2S}, [x17], 8
+
+ # Load vi0.hi
+ LD1 {v6.2S}, [x7], 8
+
+ # Load vk0.hi
+ LD1 {v7.2S}, [x17], 8
+
+ # Load vi1.lo
+ LD1 {v8.2S}, [x8], 8
+
+ # Load vk1.lo
+ LD1 {v9.2S}, [x17], 8
+
+ # Load vi1.hi
+ LD1 {v10.2S}, [x8], 8
+
+ # Load vk1.hi
+ LD1 {v11.2S}, [x17], 8
+
+ # Load vi2.lo
+ LD1 {v12.2S}, [x9], 8
+
+ # Load vk2.lo
+ LD1 {v13.2S}, [x17], 8
+
+ # Load vi2.hi
+ LD1 {v14.2S}, [x9], 8
+
+ # Load vk2.hi
+ LD1 {v15.2S}, [x17], 8
+
+ # Load vi3.lo
+ LD1 {v16.2S}, [x10], 8
+
+ # Load vk3.lo
+ LD1 {v17.2S}, [x17], 8
+
+ # Load vi3.hi
+ LD1 {v18.2S}, [x10], 8
+
+ # Load vk3.hi
+ LD1 {v19.2S}, [x17], 8
+
+ # Load vi4.lo
+ LD1 {v20.2S}, [x11], 8
+
+ # Load vk4.lo
+ LD1 {v21.2S}, [x17], 8
+
+ # Load vi4.hi
+ LD1 {v22.2S}, [x11], 8
+
+ # Load vk4.hi
+ LD1 {v23.2S}, [x17], 8
+
+ # Load vi5.lo
+ LD1 {v24.2S}, [x12], 8
+
+ # Load vk5.lo
+ LD1 {v25.2S}, [x17], 8
+
+ # Load vi5.hi
+ LD1 {v26.2S}, [x12], 8
+
+ # Load vk5.hi
+ LD1 {v27.2S}, [x17], 8
+
+ # vacc.lo += vi0.lo * vk0.lo
+ FMLA v0.2S, v4.2S, v5.2S
+ # Load vi6.lo
+ LD1 {v4.2S}, [x13], 8
+
+ # Load vk6.lo
+ LD1 {v5.2S}, [x17], 8
+
+ # vacc.hi += vi0.hi * vk0.hi
+ FMLA v1.2S, v6.2S, v7.2S
+ # Load vi6.hi
+ LD1 {v6.2S}, [x13], 8
+
+ # Load vk6.hi
+ LD1 {v7.2S}, [x17], 8
+
+ # vacc.lo += vi1.lo * vk1.lo
+ FMLA v0.2S, v8.2S, v9.2S
+ # Load vi7.lo
+ LD1 {v8.2S}, [x14], 8
+
+ # Load vk7.lo
+ LD1 {v9.2S}, [x17], 8
+
+ # vacc.hi += vi1.hi * vk1.hi
+ FMLA v1.2S, v10.2S, v11.2S
+ # Load vi7.hi
+ LD1 {v10.2S}, [x14], 8
+
+ # Load vk7.hi
+ LD1 {v11.2S}, [x17], 8
+
+ # vacc.lo += vi2.lo * vk2.lo
+ FMLA v0.2S, v12.2S, v13.2S
+ # Load vi8.lo
+ LD1 {v12.2S}, [x15], 8
+
+ # Load vk8.lo
+ LD1 {v13.2S}, [x17], 8
+
+ # vacc.hi += vi2.hi * vk2.hi
+ FMLA v1.2S, v14.2S, v15.2S
+ # Load vi8.hi
+ LD1 {v14.2S}, [x15], 8
+
+ # Load vk8.hi
+ LD1 {v15.2S}, [x17], 8
+
+ # Load vbias_next.lo
+ LD1 {v2.2S}, [x17], 8
+
+ # Load vbias_next.hi
+ LD1 {v3.2S}, [x17], 8
+
+ # vacc.lo += vi3.lo * vk3.lo
+ FMLA v0.2S, v16.2S, v17.2S
+ # Load vi0_next.lo
+ LD1 {v16.2S}, [x7], 8
+
+ # Load vk0_next.lo
+ LD1 {v17.2S}, [x17], 8
+
+ # vacc.hi += vi3.hi * vk3.hi
+ FMLA v1.2S, v18.2S, v19.2S
+ # Load vi0_next.hi
+ LD1 {v18.2S}, [x7], 8
+
+ # Load vk0_next.hi
+ LD1 {v19.2S}, [x17], 8
+
+ # vacc.lo += vi4.lo * vk4.lo
+ FMLA v0.2S, v20.2S, v21.2S
+ # Load vi1_next.lo
+ LD1 {v20.2S}, [x8], 8
+
+ # Load vk1_next.lo
+ LD1 {v21.2S}, [x17], 8
+
+ # vacc.hi += vi4.hi * vk4.hi
+ FMLA v1.2S, v22.2S, v23.2S
+ # Load vi1_next.hi
+ LD1 {v22.2S}, [x8], 8
+
+ # Load vk1_next.hi
+ LD1 {v23.2S}, [x17], 8
+
+ # vacc.lo += vi5.lo * vk5.lo
+ FMLA v0.2S, v24.2S, v25.2S
+ # Load vi2_next.lo
+ LD1 {v24.2S}, [x9], 8
+
+ # Load vk2_next.lo
+ LD1 {v25.2S}, [x17], 8
+
+ # vacc.hi += vi5.hi * vk5.hi
+ FMLA v1.2S, v26.2S, v27.2S
+ # Load vi2_next.hi
+ LD1 {v26.2S}, [x9], 8
+
+ # Load vk2_next.hi
+ LD1 {v27.2S}, [x17], 8
+
+ # vacc.lo += vi6.lo * vk6.lo
+ FMLA v0.2S, v4.2S, v5.2S
+ # Load vi3_next.lo
+ LD1 {v4.2S}, [x10], 8
+
+ # Load vk3_next.lo
+ LD1 {v5.2S}, [x17], 8
+
+ # vacc.hi += vi6.hi * vk6.hi
+ FMLA v1.2S, v6.2S, v7.2S
+ # Load vi3_next.hi
+ LD1 {v6.2S}, [x10], 8
+
+ # Load vk3_next.hi
+ LD1 {v7.2S}, [x17], 8
+
+ # vacc.lo += vi7.lo * vk7.lo
+ FMLA v0.2S, v8.2S, v9.2S
+ # Load vi4_next.lo
+ LD1 {v8.2S}, [x11], 8
+
+ # Load vk4_next.lo
+ LD1 {v9.2S}, [x17], 8
+
+ # vacc.hi += vi7.hi * vk7.hi
+ FMLA v1.2S, v10.2S, v11.2S
+ # Load vi4_next.hi
+ LD1 {v10.2S}, [x11], 8
+
+ # Load vk4_next.hi
+ LD1 {v11.2S}, [x17], 8
+
+ # vacc.lo += vi8.lo * vk8.lo
+ FMLA v0.2S, v12.2S, v13.2S
+ # Load vi5_next.lo
+ LD1 {v12.2S}, [x12], 8
+
+ # Load vk5_next.lo
+ LD1 {v13.2S}, [x17], 8
+
+ # vacc.hi += vi8.hi * vk8.hi
+ FMLA v1.2S, v14.2S, v15.2S
+ # Load vi5_next.hi
+ LD1 {v14.2S}, [x12], 8
+
+ # Load vk5_next.hi
+ LD1 {v15.2S}, [x17], 8
+
+ # vacc_next.lo += vi0_next.lo * vk0_next.lo
+ FMLA v2.2S, v16.2S, v17.2S
+ # Load vi6_next.lo
+ LD1 {v16.2S}, [x13], 8
+
+ # vacc.lo = min(vacc.lo, vmax)
+ FMIN v0.2S, v0.2S, v30.2S
+ # Load vk6_next.lo
+ LD1 {v17.2S}, [x17], 8
+
+ # vacc_next.hi += vi0_next.hi * vk0_next.hi
+ FMLA v3.2S, v18.2S, v19.2S
+ # Load vi6_next.hi
+ LD1 {v18.2S}, [x13], 8
+
+ # vacc.hi = min(vacc.hi, vmax)
+ FMIN v1.2S, v1.2S, v30.2S
+ # Load vk6_next.hi
+ LD1 {v19.2S}, [x17], 8
+
+ # vacc_next.lo += vi1_next.lo * vk1_next.lo
+ FMLA v2.2S, v20.2S, v21.2S
+ # Load vi7_next.lo
+ LD1 {v20.2S}, [x14], 8
+
+ # vacc.lo = max(vacc.lo, vmin)
+ FMAX v0.2S, v0.2S, v31.2S
+ # Load vk7_next.lo
+ LD1 {v21.2S}, [x17], 8
+
+ # vacc_next.hi += vi1_next.hi * vk1_next.hi
+ FMLA v3.2S, v22.2S, v23.2S
+ # Load vi7_next.hi
+ LD1 {v22.2S}, [x14], 8
+
+ # vacc.hi = max(vacc.hi, vmin)
+ FMAX v1.2S, v1.2S, v31.2S
+ # Load vk7_next.hi
+ LD1 {v23.2S}, [x17], 8
+
+ # vacc_next.lo += vi2_next.lo * vk2_next.lo
+ FMLA v2.2S, v24.2S, v25.2S
+ # Load vi8_next.lo
+ LD1 {v24.2S}, [x15], 8
+
+ # Load vk8_next.lo
+ LD1 {v25.2S}, [x17], 8
+
+ # vacc_next.hi += vi2_next.hi * vk2_next.hi
+ FMLA v3.2S, v26.2S, v27.2S
+ # Load vi8_next.hi
+ LD1 {v26.2S}, [x15], 8
+
+ # Store vacc
+ STP d0, d1, [x4], 16
+
+ # c -= 8
+ SUBS x16, x16, 8
+ # Load vk8_next.hi
+ LD1 {v27.2S}, [x17], 8
+
+ B.LO 2f
+
+1:
+ # SWP iteration
+
+ # Load vbias.lo
+ LD1 {v0.2S}, [x17], 8
+
+ # Load vbias.hi
+ LD1 {v1.2S}, [x17], 8
+
+ # vacc_prev.lo += vi3_prev.lo * vk3_prev.lo
+ FMLA v2.2S, v4.2S, v5.2S
+ # Load vi0.lo
+ LD1 {v4.2S}, [x7], 8
+
+ # Load vk0.lo
+ LD1 {v5.2S}, [x17], 8
+
+ # vacc_prev.hi += vi3_prev.hi * vk3_prev.hi
+ FMLA v3.2S, v6.2S, v7.2S
+ # Load vi0.hi
+ LD1 {v6.2S}, [x7], 8
+
+ # Load vk0.hi
+ LD1 {v7.2S}, [x17], 8
+
+ # vacc_prev.lo += vi4_prev.lo * vk4_prev.lo
+ FMLA v2.2S, v8.2S, v9.2S
+ # Load vi1.lo
+ LD1 {v8.2S}, [x8], 8
+
+ # Load vk1.lo
+ LD1 {v9.2S}, [x17], 8
+
+ # vacc_prev.hi += vi4_prev.hi * vk4_prev.hi
+ FMLA v3.2S, v10.2S, v11.2S
+ # Load vi1.hi
+ LD1 {v10.2S}, [x8], 8
+
+ # Load vk1.hi
+ LD1 {v11.2S}, [x17], 8
+
+ # vacc_prev.lo += vi5_prev.lo * vk5_prev.lo
+ FMLA v2.2S, v12.2S, v13.2S
+ # Load vi2.lo
+ LD1 {v12.2S}, [x9], 8
+
+ # Load vk2.lo
+ LD1 {v13.2S}, [x17], 8
+
+ # vacc_prev.hi += vi5_prev.hi * vk5_prev.hi
+ FMLA v3.2S, v14.2S, v15.2S
+ # Load vi2.hi
+ LD1 {v14.2S}, [x9], 8
+
+ # Load vk2.hi
+ LD1 {v15.2S}, [x17], 8
+
+ # vacc_prev.lo += vi6_prev.lo * vk6_prev.lo
+ FMLA v2.2S, v16.2S, v17.2S
+ # Load vi3.lo
+ LD1 {v16.2S}, [x10], 8
+
+ # Load vk3.lo
+ LD1 {v17.2S}, [x17], 8
+
+ # vacc_prev.hi += vi6_prev.hi * vk6_prev.hi
+ FMLA v3.2S, v18.2S, v19.2S
+ # Load vi3.hi
+ LD1 {v18.2S}, [x10], 8
+
+ # Load vk3.hi
+ LD1 {v19.2S}, [x17], 8
+
+ # vacc_prev.lo += vi7_prev.lo * vk7_prev.lo
+ FMLA v2.2S, v20.2S, v21.2S
+ # Load vi4.lo
+ LD1 {v20.2S}, [x11], 8
+
+ # Load vk4.lo
+ LD1 {v21.2S}, [x17], 8
+
+ # vacc_prev.hi += vi7_prev.hi * vk7_prev.hi
+ FMLA v3.2S, v22.2S, v23.2S
+ # Load vi4.hi
+ LD1 {v22.2S}, [x11], 8
+
+ # Load vk4.hi
+ LD1 {v23.2S}, [x17], 8
+
+ # vacc_prev.lo += vi8_prev.lo * vk8_prev.lo
+ FMLA v2.2S, v24.2S, v25.2S
+ # Load vi5.lo
+ LD1 {v24.2S}, [x12], 8
+
+ # Load vk5.lo
+ LD1 {v25.2S}, [x17], 8
+
+ # vacc_prev.hi += vi8_prev.hi * vk8_prev.hi
+ FMLA v3.2S, v26.2S, v27.2S
+ # Load vi5.hi
+ LD1 {v26.2S}, [x12], 8
+
+ # Load vk5.hi
+ LD1 {v27.2S}, [x17], 8
+
+ # vacc.lo += vi0.lo * vk0.lo
+ FMLA v0.2S, v4.2S, v5.2S
+ # Load vi6.lo
+ LD1 {v4.2S}, [x13], 8
+
+ # vacc_prev.lo = min(vacc_prev.lo, vmax)
+ FMIN v2.2S, v2.2S, v30.2S
+ # Load vk6.lo
+ LD1 {v5.2S}, [x17], 8
+
+ # vacc.hi += vi0.hi * vk0.hi
+ FMLA v1.2S, v6.2S, v7.2S
+ # Load vi6.hi
+ LD1 {v6.2S}, [x13], 8
+
+ # vacc_prev.hi = min(vacc_prev.hi, vmax)
+ FMIN v3.2S, v3.2S, v30.2S
+ # Load vk6.hi
+ LD1 {v7.2S}, [x17], 8
+
+ # vacc.lo += vi1.lo * vk1.lo
+ FMLA v0.2S, v8.2S, v9.2S
+ # Load vi7.lo
+ LD1 {v8.2S}, [x14], 8
+
+ # vacc_prev.lo = max(vacc_prev.lo, vmin)
+ FMAX v2.2S, v2.2S, v31.2S
+ # Load vk7.lo
+ LD1 {v9.2S}, [x17], 8
+
+ # vacc.hi += vi1.hi * vk1.hi
+ FMLA v1.2S, v10.2S, v11.2S
+ # Load vi7.hi
+ LD1 {v10.2S}, [x14], 8
+
+ # vacc_prev.hi = max(vacc_prev.hi, vmin)
+ FMAX v3.2S, v3.2S, v31.2S
+ # Load vk7.hi
+ LD1 {v11.2S}, [x17], 8
+
+ # vacc.lo += vi2.lo * vk2.lo
+ FMLA v0.2S, v12.2S, v13.2S
+ # Load vi8.lo
+ LD1 {v12.2S}, [x15], 8
+
+ # Load vk8.lo
+ LD1 {v13.2S}, [x17], 8
+
+ # vacc.hi += vi2.hi * vk2.hi
+ FMLA v1.2S, v14.2S, v15.2S
+ # Load vi8.hi
+ LD1 {v14.2S}, [x15], 8
+
+ # Store vacc_prev
+ STP d2, d3, [x4], 16
+
+ # Load vk8.hi
+ LD1 {v15.2S}, [x17], 8
+
+ # Load vbias_next.lo
+ LD1 {v2.2S}, [x17], 8
+
+ # Load vbias_next.hi
+ LD1 {v3.2S}, [x17], 8
+
+ # vacc.lo += vi3.lo * vk3.lo
+ FMLA v0.2S, v16.2S, v17.2S
+ # Load vi0_next.lo
+ LD1 {v16.2S}, [x7], 8
+
+ # Load vk0_next.lo
+ LD1 {v17.2S}, [x17], 8
+
+ # vacc.hi += vi3.hi * vk3.hi
+ FMLA v1.2S, v18.2S, v19.2S
+ # Load vi0_next.hi
+ LD1 {v18.2S}, [x7], 8
+
+ # Load vk0_next.hi
+ LD1 {v19.2S}, [x17], 8
+
+ # vacc.lo += vi4.lo * vk4.lo
+ FMLA v0.2S, v20.2S, v21.2S
+ # Load vi1_next.lo
+ LD1 {v20.2S}, [x8], 8
+
+ # Load vk1_next.lo
+ LD1 {v21.2S}, [x17], 8
+
+ # vacc.hi += vi4.hi * vk4.hi
+ FMLA v1.2S, v22.2S, v23.2S
+ # Load vi1_next.hi
+ LD1 {v22.2S}, [x8], 8
+
+ # Load vk1_next.hi
+ LD1 {v23.2S}, [x17], 8
+
+ # vacc.lo += vi5.lo * vk5.lo
+ FMLA v0.2S, v24.2S, v25.2S
+ # Load vi2_next.lo
+ LD1 {v24.2S}, [x9], 8
+
+ # Load vk2_next.lo
+ LD1 {v25.2S}, [x17], 8
+
+ # vacc.hi += vi5.hi * vk5.hi
+ FMLA v1.2S, v26.2S, v27.2S
+ # Load vi2_next.hi
+ LD1 {v26.2S}, [x9], 8
+
+ # Load vk2_next.hi
+ LD1 {v27.2S}, [x17], 8
+
+ # vacc.lo += vi6.lo * vk6.lo
+ FMLA v0.2S, v4.2S, v5.2S
+ # Load vi3_next.lo
+ LD1 {v4.2S}, [x10], 8
+
+ # Load vk3_next.lo
+ LD1 {v5.2S}, [x17], 8
+
+ # vacc.hi += vi6.hi * vk6.hi
+ FMLA v1.2S, v6.2S, v7.2S
+ # Load vi3_next.hi
+ LD1 {v6.2S}, [x10], 8
+
+ # Load vk3_next.hi
+ LD1 {v7.2S}, [x17], 8
+
+ # vacc.lo += vi7.lo * vk7.lo
+ FMLA v0.2S, v8.2S, v9.2S
+ # Load vi4_next.lo
+ LD1 {v8.2S}, [x11], 8
+
+ # Load vk4_next.lo
+ LD1 {v9.2S}, [x17], 8
+
+ # vacc.hi += vi7.hi * vk7.hi
+ FMLA v1.2S, v10.2S, v11.2S
+ # Load vi4_next.hi
+ LD1 {v10.2S}, [x11], 8
+
+ # Load vk4_next.hi
+ LD1 {v11.2S}, [x17], 8
+
+ # vacc.lo += vi8.lo * vk8.lo
+ FMLA v0.2S, v12.2S, v13.2S
+ # Load vi5_next.lo
+ LD1 {v12.2S}, [x12], 8
+
+ # Load vk5_next.lo
+ LD1 {v13.2S}, [x17], 8
+
+ # vacc.hi += vi8.hi * vk8.hi
+ FMLA v1.2S, v14.2S, v15.2S
+ # Load vi5_next.hi
+ LD1 {v14.2S}, [x12], 8
+
+ # Load vk5_next.hi
+ LD1 {v15.2S}, [x17], 8
+
+ # vacc_next.lo += vi0_next.lo * vk0_next.lo
+ FMLA v2.2S, v16.2S, v17.2S
+ # Load vi6_next.lo
+ LD1 {v16.2S}, [x13], 8
+
+ # vacc.lo = min(vacc.lo, vmax)
+ FMIN v0.2S, v0.2S, v30.2S
+ # Load vk6_next.lo
+ LD1 {v17.2S}, [x17], 8
+
+ # vacc_next.hi += vi0_next.hi * vk0_next.hi
+ FMLA v3.2S, v18.2S, v19.2S
+ # Load vi6_next.hi
+ LD1 {v18.2S}, [x13], 8
+
+ # vacc.hi = min(vacc.hi, vmax)
+ FMIN v1.2S, v1.2S, v30.2S
+ # Load vk6_next.hi
+ LD1 {v19.2S}, [x17], 8
+
+ # vacc_next.lo += vi1_next.lo * vk1_next.lo
+ FMLA v2.2S, v20.2S, v21.2S
+ # Load vi7_next.lo
+ LD1 {v20.2S}, [x14], 8
+
+ # vacc.lo = max(vacc.lo, vmin)
+ FMAX v0.2S, v0.2S, v31.2S
+ # Load vk7_next.lo
+ LD1 {v21.2S}, [x17], 8
+
+ # vacc_next.hi += vi1_next.hi * vk1_next.hi
+ FMLA v3.2S, v22.2S, v23.2S
+ # Load vi7_next.hi
+ LD1 {v22.2S}, [x14], 8
+
+ # vacc.hi = max(vacc.hi, vmin)
+ FMAX v1.2S, v1.2S, v31.2S
+ # Load vk7_next.hi
+ LD1 {v23.2S}, [x17], 8
+
+ # vacc_next.lo += vi2_next.lo * vk2_next.lo
+ FMLA v2.2S, v24.2S, v25.2S
+ # Load vi8_next.lo
+ LD1 {v24.2S}, [x15], 8
+
+ # Load vk8_next.lo
+ LD1 {v25.2S}, [x17], 8
+
+ # vacc_next.hi += vi2_next.hi * vk2_next.hi
+ FMLA v3.2S, v26.2S, v27.2S
+ # Load vi8_next.hi
+ LD1 {v26.2S}, [x15], 8
+
+ # Store vacc
+ STP d0, d1, [x4], 16
+
+ # c -= 8
+ SUBS x16, x16, 8
+ # Load vk8_next.hi
+ LD1 {v27.2S}, [x17], 8
+
+ B.HS 1b
+
+2:
+ # SWP epilogue
+
+ # vacc_prev.lo += vi3_prev.lo * vk3_prev.lo
+ FMLA v2.2S, v4.2S, v5.2S
+
+ # vacc_prev.hi += vi3_prev.hi * vk3_prev.hi
+ FMLA v3.2S, v6.2S, v7.2S
+
+ # vacc_prev.lo += vi4_prev.lo * vk4_prev.lo
+ FMLA v2.2S, v8.2S, v9.2S
+
+ # vacc_prev.hi += vi4_prev.hi * vk4_prev.hi
+ FMLA v3.2S, v10.2S, v11.2S
+
+ # vacc_prev.lo += vi5_prev.lo * vk5_prev.lo
+ FMLA v2.2S, v12.2S, v13.2S
+
+ # vacc_prev.hi += vi5_prev.hi * vk5_prev.hi
+ FMLA v3.2S, v14.2S, v15.2S
+
+ # vacc_prev.lo += vi6_prev.lo * vk6_prev.lo
+ FMLA v2.2S, v16.2S, v17.2S
+
+ # vacc_prev.hi += vi6_prev.hi * vk6_prev.hi
+ FMLA v3.2S, v18.2S, v19.2S
+
+ # vacc_prev.lo += vi7_prev.lo * vk7_prev.lo
+ FMLA v2.2S, v20.2S, v21.2S
+
+ # vacc_prev.hi += vi7_prev.hi * vk7_prev.hi
+ FMLA v3.2S, v22.2S, v23.2S
+
+ # vacc_prev.lo += vi8_prev.lo * vk8_prev.lo
+ FMLA v2.2S, v24.2S, v25.2S
+
+ # vacc_prev.hi += vi8_prev.hi * vk8_prev.hi
+ FMLA v3.2S, v26.2S, v27.2S
+
+ # vacc_prev.lo = min(vacc_prev.lo, vmax)
+ FMIN v2.2S, v2.2S, v30.2S
+
+ # vacc_prev.hi = min(vacc_prev.hi, vmax)
+ FMIN v3.2S, v3.2S, v30.2S
+
+ # vacc_prev.lo = max(vacc_prev.lo, vmin)
+ FMAX v2.2S, v2.2S, v31.2S
+
+ # vacc_prev.hi = max(vacc_prev.hi, vmin)
+ FMAX v3.2S, v3.2S, v31.2S
+
+ # Store vacc_prev
+ STP d2, d3, [x4], 16
+
+3:
+ # skip processing 4 channels if ((c - 8) & 4) = (c & 4) == 0 (TBZ branches when the bit is zero)
+ TBZ x16, 2, 4f
+
+ LDP q0, q1, [x17], 32
+ LDP q2, q3, [x17], 32
+ LDP q4, q5, [x17], 32
+ LDP q6, q7, [x17], 32
+ LDP q8, q9, [x17], 32
+ LDR q10, [x7], 16
+ LDR q11, [x8], 16
+ LDR q12, [x9], 16
+ LDR q13, [x10], 16
+ LDR q14, [x11], 16
+ LDR q15, [x12], 16
+ LDR q16, [x13], 16
+ LDR q17, [x14], 16
+ LDR q18, [x15], 16
+
+ FMLA v0.4S, v1.4S, v10.4S
+ FMLA v0.4S, v2.4S, v11.4S
+ FMLA v0.4S, v3.4S, v12.4S
+ FMLA v0.4S, v4.4S, v13.4S
+ FMLA v0.4S, v5.4S, v14.4S
+ FMLA v0.4S, v6.4S, v15.4S
+ FMLA v0.4S, v7.4S, v16.4S
+ FMLA v0.4S, v8.4S, v17.4S
+ FMLA v0.4S, v9.4S, v18.4S
+
+ FMIN v0.4S, v0.4S, v30.4S
+ FMAX v0.4S, v0.4S, v31.4S
+
+ STR q0, [x4], 16
+
+4:
+ # restore actual c value
+ ADD x16, x16, 8
+ # skip processing remainder channels if c == 0
+ CBZ x16, 6f
+
+ LDP q0, q1, [x17], 32
+ LDP q2, q3, [x17], 32
+ LDP q4, q5, [x17], 32
+ LDP q6, q7, [x17], 32
+ LDP q8, q9, [x17], 32
+ LDR q10, [x7], 16
+ LDR q11, [x8], 16
+ LDR q12, [x9], 16
+ LDR q13, [x10], 16
+ LDR q14, [x11], 16
+ LDR q15, [x12], 16
+ LDR q16, [x13], 16
+ LDR q17, [x14], 16
+ LDR q18, [x15], 16
+
+ FMLA v0.4S, v1.4S, v10.4S
+ FMLA v0.4S, v2.4S, v11.4S
+ FMLA v0.4S, v3.4S, v12.4S
+ FMLA v0.4S, v4.4S, v13.4S
+ FMLA v0.4S, v5.4S, v14.4S
+ FMLA v0.4S, v6.4S, v15.4S
+ FMLA v0.4S, v7.4S, v16.4S
+ FMLA v0.4S, v8.4S, v17.4S
+ FMLA v0.4S, v9.4S, v18.4S
+
+ FMIN v0.4S, v0.4S, v30.4S
+ FMAX v0.4S, v0.4S, v31.4S
+
+ TBZ x16, 1, 5f
+
+ ST1 {v0.2S}, [x4], 8
+ DUP d0, v0.D[1]
+
+5:
+ TBZ x16, 0, 6f
+
+ ST1 {v0.S}[0], [x4], 4
+
+6:
+ # output_width -= 1
+ SUBS x1, x1, 1
+ # output += output_increment
+ ADD x4, x4, x6
+ # process next pixel if output_width != 0
+ B.NE 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_dwconv_ukernel_up4x9__aarch64_neonfma_cortex_a55
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-dwconv/up4x9-aarch64-neonfma.S b/src/f32-dwconv/up4x9-aarch64-neonfma.S
new file mode 100644
index 0000000..1ffc421
--- /dev/null
+++ b/src/f32-dwconv/up4x9-aarch64-neonfma.S
@@ -0,0 +1,152 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_dwconv_ukernel_up4x9__aarch64_neonfma(
+# size_t channels,
+# size_t output_width,
+# const float** input,
+# const float* weights,
+# float* output,
+# size_t input_stride,
+# size_t output_increment,
+# const union xnn_f32_output_params params[restrict static 1])
+BEGIN_FUNCTION xnn_f32_dwconv_ukernel_up4x9__aarch64_neonfma
+
+ # Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ STP d10, d11, [sp, 16]
+ STP d12, d13, [sp, 32]
+ STP d14, d15, [sp, 48]
+
+ # v30.4S = vmax
+ LD1R {v30.4S}, [x7], 4
+ # v31.4S = vmin
+ LD1R {v31.4S}, [x7]
+
+0:
+ # x7 := i0
+ # x8 := i1
+ LDP x7, x8, [x2]
+ # x9 := i2
+ # x10 := i3
+ LDP x9, x10, [x2, 16]
+ # x11 := i4
+ # x12 := i5
+ LDP x11, x12, [x2, 32]
+ # x13 := i6
+ # x14 := i7
+ LDP x13, x14, [x2, 48]
+ # x15 := i8
+ LDR x15, [x2, 64]
+ # input += input_stride
+ ADD x2, x2, x5
+
+ # x16 := c = channels
+ # c -= 4
+ SUBS x16, x0, 4
+ # x17 := w = weights
+ MOV x17, x3
+
+ # skip main loop if c < 4
+ B.LO 2f
+1:
+ LDP q0, q1, [x17], 32
+ LDP q2, q3, [x17], 32
+ LDP q4, q5, [x17], 32
+ LDP q6, q7, [x17], 32
+ LDP q8, q9, [x17], 32
+ LDR q10, [x7], 16
+ LDR q11, [x8], 16
+ LDR q12, [x9], 16
+ LDR q13, [x10], 16
+ LDR q14, [x11], 16
+ LDR q15, [x12], 16
+ LDR q16, [x13], 16
+ LDR q17, [x14], 16
+ LDR q18, [x15], 16
+
+ FMLA v0.4S, v1.4S, v10.4S
+ FMLA v0.4S, v2.4S, v11.4S
+ FMLA v0.4S, v3.4S, v12.4S
+ FMLA v0.4S, v4.4S, v13.4S
+ FMLA v0.4S, v5.4S, v14.4S
+ FMLA v0.4S, v6.4S, v15.4S
+ FMLA v0.4S, v7.4S, v16.4S
+ FMLA v0.4S, v8.4S, v17.4S
+ FMLA v0.4S, v9.4S, v18.4S
+
+ FMIN v0.4S, v0.4S, v30.4S
+ FMAX v0.4S, v0.4S, v31.4S
+
+ STR q0, [x4], 16
+ SUBS x16, x16, 4
+ B.HS 1b
+
+2:
+ # restore actual c value
+ ADD x16, x16, 4
+ # skip processing remainder channels if c == 0
+ CBZ x16, 4f
+
+ LDP q0, q1, [x17], 32
+ LDP q2, q3, [x17], 32
+ LDP q4, q5, [x17], 32
+ LDP q6, q7, [x17], 32
+ LDP q8, q9, [x17], 32
+ LDR q10, [x7], 16
+ LDR q11, [x8], 16
+ LDR q12, [x9], 16
+ LDR q13, [x10], 16
+ LDR q14, [x11], 16
+ LDR q15, [x12], 16
+ LDR q16, [x13], 16
+ LDR q17, [x14], 16
+ LDR q18, [x15], 16
+
+ FMLA v0.4S, v1.4S, v10.4S
+ FMLA v0.4S, v2.4S, v11.4S
+ FMLA v0.4S, v3.4S, v12.4S
+ FMLA v0.4S, v4.4S, v13.4S
+ FMLA v0.4S, v5.4S, v14.4S
+ FMLA v0.4S, v6.4S, v15.4S
+ FMLA v0.4S, v7.4S, v16.4S
+ FMLA v0.4S, v8.4S, v17.4S
+ FMLA v0.4S, v9.4S, v18.4S
+
+ FMIN v0.4S, v0.4S, v30.4S
+ FMAX v0.4S, v0.4S, v31.4S
+
+ TBZ x16, 1, 3f
+
+ ST1 {v0.2S}, [x4], 8
+ DUP d0, v0.D[1]
+
+3:
+ TBZ x16, 0, 4f
+
+ ST1 {v0.S}[0], [x4], 4
+
+4:
+ # output_width -= 1
+ SUBS x1, x1, 1
+ # output += output_increment
+ ADD x4, x4, x6
+ # process next pixel if output_width != 0
+ B.NE 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_dwconv_ukernel_up4x9__aarch64_neonfma
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-dwconv/up4x9-neon.c b/src/f32-dwconv/up4x9-neon.c
new file mode 100644
index 0000000..037d2bf
--- /dev/null
+++ b/src/f32-dwconv/up4x9-neon.c
@@ -0,0 +1,147 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-dwconv/up-neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/dwconv.h>
+
+
+void xnn_f32_dwconv_ukernel_up4x9__neon(
+ size_t channels,
+ size_t output_width,
+ const float** input,
+ const float* weights,
+ float* output,
+ size_t input_stride,
+ size_t output_increment,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(channels != 0);
+ assert(output_width != 0);
+ // Clamp limits are splatted once and reused for every pixel and channel.
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ do {
+ const float* i0 = input[0];
+ const float* i1 = input[1];
+ const float* i2 = input[2];
+ const float* i3 = input[3];
+ const float* i4 = input[4];
+ const float* i5 = input[5];
+ const float* i6 = input[6];
+ const float* i7 = input[7];
+ const float* i8 = input[8];
+ // Advance to the next pixel's row of input pointers (input_stride is in bytes).
+ input = (const float**) ((uintptr_t) input + input_stride);
+
+ size_t c = channels;
+ const float* w = weights;
+ for (; c >= 4; c -= 4) {
+ float32x4_t vacc0123p0 = vld1q_f32(w); w += 4;
+
+ // Accumulate all 9 taps into the bias with vmlaq_f32 (multiply-accumulate).
+ const float32x4_t vi0x0123 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vk0x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vmlaq_f32(vacc0123p0, vi0x0123, vk0x0123);
+
+ const float32x4_t vi1x0123 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vk1x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vmlaq_f32(vacc0123p0, vi1x0123, vk1x0123);
+
+ const float32x4_t vi2x0123 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vk2x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vmlaq_f32(vacc0123p0, vi2x0123, vk2x0123);
+
+ const float32x4_t vi3x0123 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vk3x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vmlaq_f32(vacc0123p0, vi3x0123, vk3x0123);
+
+ const float32x4_t vi4x0123 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vk4x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vmlaq_f32(vacc0123p0, vi4x0123, vk4x0123);
+
+ const float32x4_t vi5x0123 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vk5x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vmlaq_f32(vacc0123p0, vi5x0123, vk5x0123);
+
+ const float32x4_t vi6x0123 = vld1q_f32(i6); i6 += 4;
+ const float32x4_t vk6x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vmlaq_f32(vacc0123p0, vi6x0123, vk6x0123);
+
+ const float32x4_t vi7x0123 = vld1q_f32(i7); i7 += 4;
+ const float32x4_t vk7x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vmlaq_f32(vacc0123p0, vi7x0123, vk7x0123);
+
+ const float32x4_t vi8x0123 = vld1q_f32(i8); i8 += 4;
+ const float32x4_t vk8x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vmlaq_f32(vacc0123p0, vi8x0123, vk8x0123);
+
+ // Clamp to [min, max]: max() against vmin first, then min() against vmax.
+ float32x4_t vacc0123 = vmaxq_f32(vacc0123p0, vmin);
+ vacc0123 = vminq_f32(vacc0123, vmax);
+
+ vst1q_f32(output, vacc0123); output += 4;
+ }
+ if XNN_UNLIKELY(c != 0) {
+ float32x4_t vacc0123 = vld1q_f32(w); w += 4;
+
+ // Remainder of 1-3 channels: full 4-wide loads are still issued; presumably
+ // weights/inputs are padded to be readable past the last channel -- TODO confirm.
+ const float32x4_t vi0x0123 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vk0x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vmlaq_f32(vacc0123, vi0x0123, vk0x0123);
+
+ const float32x4_t vi1x0123 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vk1x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vmlaq_f32(vacc0123, vi1x0123, vk1x0123);
+
+ const float32x4_t vi2x0123 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vk2x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vmlaq_f32(vacc0123, vi2x0123, vk2x0123);
+
+ const float32x4_t vi3x0123 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vk3x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vmlaq_f32(vacc0123, vi3x0123, vk3x0123);
+
+ const float32x4_t vi4x0123 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vk4x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vmlaq_f32(vacc0123, vi4x0123, vk4x0123);
+
+ const float32x4_t vi5x0123 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vk5x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vmlaq_f32(vacc0123, vi5x0123, vk5x0123);
+
+ const float32x4_t vi6x0123 = vld1q_f32(i6); i6 += 4;
+ const float32x4_t vk6x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vmlaq_f32(vacc0123, vi6x0123, vk6x0123);
+
+ const float32x4_t vi7x0123 = vld1q_f32(i7); i7 += 4;
+ const float32x4_t vk7x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vmlaq_f32(vacc0123, vi7x0123, vk7x0123);
+
+ const float32x4_t vi8x0123 = vld1q_f32(i8); i8 += 4;
+ const float32x4_t vk8x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vmlaq_f32(vacc0123, vi8x0123, vk8x0123);
+
+ vacc0123 = vmaxq_f32(vacc0123, vmin);
+ vacc0123 = vminq_f32(vacc0123, vmax);
+
+ float32x2_t vacc01 = vget_low_f32(vacc0123);
+ if (c & 2) {
+ vst1_f32(output, vacc01); output += 2;
+ vacc01 = vget_high_f32(vacc0123);
+ }
+ if (c & 1) {
+ vst1_lane_f32(output, vacc01, 0); output += 1;
+ }
+ }
+
+ output = (float*) ((uintptr_t) output + output_increment);
+ } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up4x9-neonfma.c b/src/f32-dwconv/up4x9-neonfma.c
new file mode 100644
index 0000000..ccbb077
--- /dev/null
+++ b/src/f32-dwconv/up4x9-neonfma.c
@@ -0,0 +1,147 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-dwconv/up-neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/dwconv.h>
+
+
+void xnn_f32_dwconv_ukernel_up4x9__neonfma(
+ size_t channels,
+ size_t output_width,
+ const float** input,
+ const float* weights,
+ float* output,
+ size_t input_stride,
+ size_t output_increment,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(channels != 0);
+ assert(output_width != 0);
+ // Clamp limits are splatted once and reused for every pixel and channel.
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ do {
+ const float* i0 = input[0];
+ const float* i1 = input[1];
+ const float* i2 = input[2];
+ const float* i3 = input[3];
+ const float* i4 = input[4];
+ const float* i5 = input[5];
+ const float* i6 = input[6];
+ const float* i7 = input[7];
+ const float* i8 = input[8];
+ // Advance to the next pixel's row of input pointers (input_stride is in bytes).
+ input = (const float**) ((uintptr_t) input + input_stride);
+
+ size_t c = channels;
+ const float* w = weights;
+ for (; c >= 4; c -= 4) {
+ float32x4_t vacc0123p0 = vld1q_f32(w); w += 4;
+
+ // Accumulate all 9 taps into the bias with vfmaq_f32 (fused multiply-add).
+ const float32x4_t vi0x0123 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vk0x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi0x0123, vk0x0123);
+
+ const float32x4_t vi1x0123 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vk1x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi1x0123, vk1x0123);
+
+ const float32x4_t vi2x0123 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vk2x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi2x0123, vk2x0123);
+
+ const float32x4_t vi3x0123 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vk3x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi3x0123, vk3x0123);
+
+ const float32x4_t vi4x0123 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vk4x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi4x0123, vk4x0123);
+
+ const float32x4_t vi5x0123 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vk5x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi5x0123, vk5x0123);
+
+ const float32x4_t vi6x0123 = vld1q_f32(i6); i6 += 4;
+ const float32x4_t vk6x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi6x0123, vk6x0123);
+
+ const float32x4_t vi7x0123 = vld1q_f32(i7); i7 += 4;
+ const float32x4_t vk7x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi7x0123, vk7x0123);
+
+ const float32x4_t vi8x0123 = vld1q_f32(i8); i8 += 4;
+ const float32x4_t vk8x0123 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi8x0123, vk8x0123);
+
+ // Clamp to [min, max]: max() against vmin first, then min() against vmax.
+ float32x4_t vacc0123 = vmaxq_f32(vacc0123p0, vmin);
+ vacc0123 = vminq_f32(vacc0123, vmax);
+
+ vst1q_f32(output, vacc0123); output += 4;
+ }
+ if XNN_UNLIKELY(c != 0) {
+ float32x4_t vacc0123 = vld1q_f32(w); w += 4;
+
+ // Remainder of 1-3 channels: full 4-wide loads are still issued; presumably
+ // weights/inputs are padded to be readable past the last channel -- TODO confirm.
+ const float32x4_t vi0x0123 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vk0x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi0x0123, vk0x0123);
+
+ const float32x4_t vi1x0123 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vk1x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi1x0123, vk1x0123);
+
+ const float32x4_t vi2x0123 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vk2x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi2x0123, vk2x0123);
+
+ const float32x4_t vi3x0123 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vk3x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi3x0123, vk3x0123);
+
+ const float32x4_t vi4x0123 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vk4x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi4x0123, vk4x0123);
+
+ const float32x4_t vi5x0123 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vk5x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi5x0123, vk5x0123);
+
+ const float32x4_t vi6x0123 = vld1q_f32(i6); i6 += 4;
+ const float32x4_t vk6x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi6x0123, vk6x0123);
+
+ const float32x4_t vi7x0123 = vld1q_f32(i7); i7 += 4;
+ const float32x4_t vk7x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi7x0123, vk7x0123);
+
+ const float32x4_t vi8x0123 = vld1q_f32(i8); i8 += 4;
+ const float32x4_t vk8x0123 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi8x0123, vk8x0123);
+
+ vacc0123 = vmaxq_f32(vacc0123, vmin);
+ vacc0123 = vminq_f32(vacc0123, vmax);
+
+ float32x2_t vacc01 = vget_low_f32(vacc0123);
+ if (c & 2) {
+ vst1_f32(output, vacc01); output += 2;
+ vacc01 = vget_high_f32(vacc0123);
+ }
+ if (c & 1) {
+ vst1_lane_f32(output, vacc01, 0); output += 1;
+ }
+ }
+
+ output = (float*) ((uintptr_t) output + output_increment);
+ } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up4x9-psimd.c b/src/f32-dwconv/up4x9-psimd.c
new file mode 100644
index 0000000..8f820b7
--- /dev/null
+++ b/src/f32-dwconv/up4x9-psimd.c
@@ -0,0 +1,161 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-dwconv/up-psimd.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/dwconv.h>
+
+
+void xnn_f32_dwconv_ukernel_up4x9__psimd(
+ size_t channels,
+ size_t output_width,
+ const float** input,
+ const float* weights,
+ float* output,
+ size_t input_stride,
+ size_t output_increment,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(channels != 0);
+ assert(output_width != 0);
+
+  const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+  const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+ do {
+ const float* i0 = input[0];
+ const float* i1 = input[1];
+ const float* i2 = input[2];
+ const float* i3 = input[3];
+ const float* i4 = input[4];
+ const float* i5 = input[5];
+ const float* i6 = input[6];
+ const float* i7 = input[7];
+ const float* i8 = input[8];
+ input = (const float**) ((uintptr_t) input + input_stride);
+
+ size_t c = channels;
+ const float* w = weights;
+ for (; c >= 4; c -= 4) {
+ psimd_f32 vacc0 = psimd_load_f32(w);
+
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ const psimd_f32 vk0 = psimd_load_f32(w + 4);
+ vacc0 = psimd_qfma_f32(vacc0, vi0, vk0);
+ i0 += 4;
+
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ const psimd_f32 vk1 = psimd_load_f32(w + 8);
+ psimd_f32 vacc1 = psimd_mul_f32(vi1, vk1);
+ i1 += 4;
+
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ const psimd_f32 vk2 = psimd_load_f32(w + 12);
+ vacc0 = psimd_qfma_f32(vacc0, vi2, vk2);
+ i2 += 4;
+
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ const psimd_f32 vk3 = psimd_load_f32(w + 16);
+ vacc1 = psimd_qfma_f32(vacc1, vi3, vk3);
+ i3 += 4;
+
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ const psimd_f32 vk4 = psimd_load_f32(w + 20);
+ vacc0 = psimd_qfma_f32(vacc0, vi4, vk4);
+ i4 += 4;
+
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ const psimd_f32 vk5 = psimd_load_f32(w + 24);
+ vacc1 = psimd_qfma_f32(vacc1, vi5, vk5);
+ i5 += 4;
+
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ const psimd_f32 vk6 = psimd_load_f32(w + 28);
+ vacc0 = psimd_qfma_f32(vacc0, vi6, vk6);
+ i6 += 4;
+
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ const psimd_f32 vk7 = psimd_load_f32(w + 32);
+ vacc1 = psimd_qfma_f32(vacc1, vi7, vk7);
+ i7 += 4;
+
+ const psimd_f32 vi8 = psimd_load_f32(i8);
+ const psimd_f32 vk8 = psimd_load_f32(w + 36);
+ vacc0 = psimd_qfma_f32(vacc0, vi8, vk8);
+ i8 += 4;
+
+ w += 40;
+
+ vacc0 = psimd_add_f32(vacc0, vacc1);
+
+ vacc0 = psimd_max_f32(vacc0, vmin);
+ vacc0 = psimd_min_f32(vacc0, vmax);
+
+ psimd_store_f32(output, vacc0);
+ output += 4;
+ }
+ if XNN_UNLIKELY(c != 0) {
+ psimd_f32 vacc = psimd_load_f32(w);
+
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ const psimd_f32 vk0 = psimd_load_f32(w + 4);
+ vacc = psimd_qfma_f32(vacc, vi0, vk0);
+
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ const psimd_f32 vk1 = psimd_load_f32(w + 8);
+ vacc = psimd_qfma_f32(vacc, vi1, vk1);
+
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ const psimd_f32 vk2 = psimd_load_f32(w + 12);
+ vacc = psimd_qfma_f32(vacc, vi2, vk2);
+
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ const psimd_f32 vk3 = psimd_load_f32(w + 16);
+ vacc = psimd_qfma_f32(vacc, vi3, vk3);
+
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ const psimd_f32 vk4 = psimd_load_f32(w + 20);
+ vacc = psimd_qfma_f32(vacc, vi4, vk4);
+
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ const psimd_f32 vk5 = psimd_load_f32(w + 24);
+ vacc = psimd_qfma_f32(vacc, vi5, vk5);
+
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ const psimd_f32 vk6 = psimd_load_f32(w + 28);
+ vacc = psimd_qfma_f32(vacc, vi6, vk6);
+
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ const psimd_f32 vk7 = psimd_load_f32(w + 32);
+ vacc = psimd_qfma_f32(vacc, vi7, vk7);
+
+ const psimd_f32 vi8 = psimd_load_f32(i8);
+ const psimd_f32 vk8 = psimd_load_f32(w + 36);
+ vacc = psimd_qfma_f32(vacc, vi8, vk8);
+
+ w += 40;
+
+ vacc = psimd_max_f32(vacc, vmin);
+ vacc = psimd_min_f32(vacc, vmax);
+
+ if (c & 2) {
+ psimd_store2_f32(output, vacc);
+ vacc = psimd_concat_hi_f32(vacc, vacc);
+ output += 2;
+ }
+ if (c & 1) {
+ psimd_store1_f32(output, vacc);
+ output += 1;
+ }
+ }
+
+ output = (float*) ((uintptr_t) output + output_increment);
+ } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up4x9-sse.c b/src/f32-dwconv/up4x9-sse.c
new file mode 100644
index 0000000..6779891
--- /dev/null
+++ b/src/f32-dwconv/up4x9-sse.c
@@ -0,0 +1,161 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-dwconv/up-sse.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/dwconv.h>
+
+
+void xnn_f32_dwconv_ukernel_up4x9__sse(
+ size_t channels,
+ size_t output_width,
+ const float** input,
+ const float* weights,
+ float* output,
+ size_t input_stride,
+ size_t output_increment,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(channels != 0);
+ assert(output_width != 0);
+
+ const __m128 vmax = _mm_load_ps(params->sse.max);
+ const __m128 vmin = _mm_load_ps(params->sse.min);
+ do {
+ const float* i0 = input[0];
+ const float* i1 = input[1];
+ const float* i2 = input[2];
+ const float* i3 = input[3];
+ const float* i4 = input[4];
+ const float* i5 = input[5];
+ const float* i6 = input[6];
+ const float* i7 = input[7];
+ const float* i8 = input[8];
+ input = (const float**) ((uintptr_t) input + input_stride);
+
+ size_t c = channels;
+ const float* w = weights;
+ for (; c >= 4; c -= 4) {
+ __m128 vacc0 = _mm_load_ps(w);
+
+ const __m128 vi0 = _mm_loadu_ps(i0);
+ const __m128 vk0 = _mm_load_ps(w + 4);
+ vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi0, vk0));
+ i0 += 4;
+
+ const __m128 vi1 = _mm_loadu_ps(i1);
+ const __m128 vk1 = _mm_load_ps(w + 8);
+ __m128 vacc1 = _mm_mul_ps(vi1, vk1);
+ i1 += 4;
+
+ const __m128 vi2 = _mm_loadu_ps(i2);
+ const __m128 vk2 = _mm_load_ps(w + 12);
+ vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi2, vk2));
+ i2 += 4;
+
+ const __m128 vi3 = _mm_loadu_ps(i3);
+ const __m128 vk3 = _mm_load_ps(w + 16);
+ vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi3, vk3));
+ i3 += 4;
+
+ const __m128 vi4 = _mm_loadu_ps(i4);
+ const __m128 vk4 = _mm_load_ps(w + 20);
+ vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi4, vk4));
+ i4 += 4;
+
+ const __m128 vi5 = _mm_loadu_ps(i5);
+ const __m128 vk5 = _mm_load_ps(w + 24);
+ vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi5, vk5));
+ i5 += 4;
+
+ const __m128 vi6 = _mm_loadu_ps(i6);
+ const __m128 vk6 = _mm_load_ps(w + 28);
+ vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi6, vk6));
+ i6 += 4;
+
+ const __m128 vi7 = _mm_loadu_ps(i7);
+ const __m128 vk7 = _mm_load_ps(w + 32);
+ vacc1 = _mm_add_ps(vacc1, _mm_mul_ps(vi7, vk7));
+ i7 += 4;
+
+ const __m128 vi8 = _mm_loadu_ps(i8);
+ const __m128 vk8 = _mm_load_ps(w + 36);
+ vacc0 = _mm_add_ps(vacc0, _mm_mul_ps(vi8, vk8));
+ i8 += 4;
+
+ w += 40;
+
+ vacc0 = _mm_add_ps(vacc0, vacc1);
+
+ vacc0 = _mm_max_ps(vacc0, vmin);
+ vacc0 = _mm_min_ps(vacc0, vmax);
+
+ _mm_storeu_ps(output, vacc0);
+ output += 4;
+ }
+ if XNN_UNLIKELY(c != 0) {
+ __m128 vacc = _mm_load_ps(w);
+
+ const __m128 vi0 = _mm_loadu_ps(i0);
+ const __m128 vk0 = _mm_load_ps(w + 4);
+ vacc = _mm_add_ps(vacc, _mm_mul_ps(vi0, vk0));
+
+ const __m128 vi1 = _mm_loadu_ps(i1);
+ const __m128 vk1 = _mm_load_ps(w + 8);
+ vacc = _mm_add_ps(vacc, _mm_mul_ps(vi1, vk1));
+
+ const __m128 vi2 = _mm_loadu_ps(i2);
+ const __m128 vk2 = _mm_load_ps(w + 12);
+ vacc = _mm_add_ps(vacc, _mm_mul_ps(vi2, vk2));
+
+ const __m128 vi3 = _mm_loadu_ps(i3);
+ const __m128 vk3 = _mm_load_ps(w + 16);
+ vacc = _mm_add_ps(vacc, _mm_mul_ps(vi3, vk3));
+
+ const __m128 vi4 = _mm_loadu_ps(i4);
+ const __m128 vk4 = _mm_load_ps(w + 20);
+ vacc = _mm_add_ps(vacc, _mm_mul_ps(vi4, vk4));
+
+ const __m128 vi5 = _mm_loadu_ps(i5);
+ const __m128 vk5 = _mm_load_ps(w + 24);
+ vacc = _mm_add_ps(vacc, _mm_mul_ps(vi5, vk5));
+
+ const __m128 vi6 = _mm_loadu_ps(i6);
+ const __m128 vk6 = _mm_load_ps(w + 28);
+ vacc = _mm_add_ps(vacc, _mm_mul_ps(vi6, vk6));
+
+ const __m128 vi7 = _mm_loadu_ps(i7);
+ const __m128 vk7 = _mm_load_ps(w + 32);
+ vacc = _mm_add_ps(vacc, _mm_mul_ps(vi7, vk7));
+
+ const __m128 vi8 = _mm_loadu_ps(i8);
+ const __m128 vk8 = _mm_load_ps(w + 36);
+ vacc = _mm_add_ps(vacc, _mm_mul_ps(vi8, vk8));
+
+ w += 40;
+
+ vacc = _mm_max_ps(vacc, vmin);
+ vacc = _mm_min_ps(vacc, vmax);
+
+ if (c & 2) {
+ _mm_storel_pi((__m64*) output, vacc);
+ vacc = _mm_movehl_ps(vacc, vacc);
+ output += 2;
+ }
+ if (c & 1) {
+ _mm_store_ss(output, vacc);
+ output += 1;
+ }
+ }
+
+ output = (float*) ((uintptr_t) output + output_increment);
+ } while (--output_width != 0);
+}
diff --git a/src/f32-dwconv/up8x9-neonfma.c b/src/f32-dwconv/up8x9-neonfma.c
new file mode 100644
index 0000000..bf2bbbb
--- /dev/null
+++ b/src/f32-dwconv/up8x9-neonfma.c
@@ -0,0 +1,212 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-dwconv/up-neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/dwconv.h>
+
+
+void xnn_f32_dwconv_ukernel_up8x9__neonfma(
+ size_t channels,
+ size_t output_width,
+ const float** input,
+ const float* weights,
+ float* output,
+ size_t input_stride,
+ size_t output_increment,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(channels != 0);
+ assert(output_width != 0);
+
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ do {
+ const float* i0 = input[0];
+ const float* i1 = input[1];
+ const float* i2 = input[2];
+ const float* i3 = input[3];
+ const float* i4 = input[4];
+ const float* i5 = input[5];
+ const float* i6 = input[6];
+ const float* i7 = input[7];
+ const float* i8 = input[8];
+ input = (const float**) ((uintptr_t) input + input_stride);
+
+ size_t c = channels;
+ const float* w = weights;
+ for (; c >= 8; c -= 8) {
+ float32x4_t vacc0123p0 = vld1q_f32(w); w += 4;
+ float32x4_t vacc4567p0 = vld1q_f32(w); w += 4;
+
+
+ const float32x4_t vi0x0123 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi0x4567 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vk0x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk0x4567 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi0x0123, vk0x0123);
+ vacc4567p0 = vfmaq_f32(vacc4567p0, vi0x4567, vk0x4567);
+
+ const float32x4_t vi1x0123 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi1x4567 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vk1x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk1x4567 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi1x0123, vk1x0123);
+ vacc4567p0 = vfmaq_f32(vacc4567p0, vi1x4567, vk1x4567);
+
+ const float32x4_t vi2x0123 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi2x4567 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vk2x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk2x4567 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi2x0123, vk2x0123);
+ vacc4567p0 = vfmaq_f32(vacc4567p0, vi2x4567, vk2x4567);
+
+ const float32x4_t vi3x0123 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi3x4567 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vk3x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk3x4567 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi3x0123, vk3x0123);
+ vacc4567p0 = vfmaq_f32(vacc4567p0, vi3x4567, vk3x4567);
+
+ const float32x4_t vi4x0123 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vi4x4567 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vk4x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk4x4567 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi4x0123, vk4x0123);
+ vacc4567p0 = vfmaq_f32(vacc4567p0, vi4x4567, vk4x4567);
+
+ const float32x4_t vi5x0123 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vi5x4567 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vk5x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk5x4567 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi5x0123, vk5x0123);
+ vacc4567p0 = vfmaq_f32(vacc4567p0, vi5x4567, vk5x4567);
+
+ const float32x4_t vi6x0123 = vld1q_f32(i6); i6 += 4;
+ const float32x4_t vi6x4567 = vld1q_f32(i6); i6 += 4;
+ const float32x4_t vk6x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk6x4567 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi6x0123, vk6x0123);
+ vacc4567p0 = vfmaq_f32(vacc4567p0, vi6x4567, vk6x4567);
+
+ const float32x4_t vi7x0123 = vld1q_f32(i7); i7 += 4;
+ const float32x4_t vi7x4567 = vld1q_f32(i7); i7 += 4;
+ const float32x4_t vk7x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk7x4567 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi7x0123, vk7x0123);
+ vacc4567p0 = vfmaq_f32(vacc4567p0, vi7x4567, vk7x4567);
+
+ const float32x4_t vi8x0123 = vld1q_f32(i8); i8 += 4;
+ const float32x4_t vi8x4567 = vld1q_f32(i8); i8 += 4;
+ const float32x4_t vk8x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk8x4567 = vld1q_f32(w); w += 4;
+ vacc0123p0 = vfmaq_f32(vacc0123p0, vi8x0123, vk8x0123);
+ vacc4567p0 = vfmaq_f32(vacc4567p0, vi8x4567, vk8x4567);
+
+
+ float32x4_t vacc0123 = vmaxq_f32(vacc0123p0, vmin);
+ float32x4_t vacc4567 = vmaxq_f32(vacc4567p0, vmin);
+ vacc0123 = vminq_f32(vacc0123, vmax);
+ vacc4567 = vminq_f32(vacc4567, vmax);
+
+ vst1q_f32(output, vacc0123); output += 4;
+ vst1q_f32(output, vacc4567); output += 4;
+ }
+ if XNN_UNLIKELY(c != 0) {
+ float32x4_t vacc0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc4567 = vld1q_f32(w); w += 4;
+
+
+ const float32x4_t vi0x0123 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi0x4567 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vk0x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk0x4567 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi0x0123, vk0x0123);
+ vacc4567 = vfmaq_f32(vacc4567, vi0x4567, vk0x4567);
+
+ const float32x4_t vi1x0123 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi1x4567 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vk1x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk1x4567 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi1x0123, vk1x0123);
+ vacc4567 = vfmaq_f32(vacc4567, vi1x4567, vk1x4567);
+
+ const float32x4_t vi2x0123 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi2x4567 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vk2x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk2x4567 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi2x0123, vk2x0123);
+ vacc4567 = vfmaq_f32(vacc4567, vi2x4567, vk2x4567);
+
+ const float32x4_t vi3x0123 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi3x4567 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vk3x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk3x4567 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi3x0123, vk3x0123);
+ vacc4567 = vfmaq_f32(vacc4567, vi3x4567, vk3x4567);
+
+ const float32x4_t vi4x0123 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vi4x4567 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vk4x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk4x4567 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi4x0123, vk4x0123);
+ vacc4567 = vfmaq_f32(vacc4567, vi4x4567, vk4x4567);
+
+ const float32x4_t vi5x0123 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vi5x4567 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vk5x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk5x4567 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi5x0123, vk5x0123);
+ vacc4567 = vfmaq_f32(vacc4567, vi5x4567, vk5x4567);
+
+ const float32x4_t vi6x0123 = vld1q_f32(i6); i6 += 4;
+ const float32x4_t vi6x4567 = vld1q_f32(i6); i6 += 4;
+ const float32x4_t vk6x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk6x4567 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi6x0123, vk6x0123);
+ vacc4567 = vfmaq_f32(vacc4567, vi6x4567, vk6x4567);
+
+ const float32x4_t vi7x0123 = vld1q_f32(i7); i7 += 4;
+ const float32x4_t vi7x4567 = vld1q_f32(i7); i7 += 4;
+ const float32x4_t vk7x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk7x4567 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi7x0123, vk7x0123);
+ vacc4567 = vfmaq_f32(vacc4567, vi7x4567, vk7x4567);
+
+ const float32x4_t vi8x0123 = vld1q_f32(i8); i8 += 4;
+ const float32x4_t vi8x4567 = vld1q_f32(i8); i8 += 4;
+ const float32x4_t vk8x0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vk8x4567 = vld1q_f32(w); w += 4;
+ vacc0123 = vfmaq_f32(vacc0123, vi8x0123, vk8x0123);
+ vacc4567 = vfmaq_f32(vacc4567, vi8x4567, vk8x4567);
+
+ vacc0123 = vmaxq_f32(vacc0123, vmin);
+ vacc4567 = vmaxq_f32(vacc4567, vmin);
+ vacc0123 = vminq_f32(vacc0123, vmax);
+ vacc4567 = vminq_f32(vacc4567, vmax);
+
+ if (c & 4) {
+ vst1q_f32(output, vacc0123); output += 4;
+ vacc0123 = vacc4567;
+ }
+ float32x2_t vacc01 = vget_low_f32(vacc0123);
+ if (c & 2) {
+ vst1_f32(output, vacc01); output += 2;
+ vacc01 = vget_high_f32(vacc0123);
+ }
+ if (c & 1) {
+ vst1_lane_f32(output, vacc01, 0); output += 1;
+ }
+ }
+
+ output = (float*) ((uintptr_t) output + output_increment);
+ } while (--output_width != 0);
+}
diff --git a/src/f32-gavgpool-spchw/neon-x4.c b/src/f32-gavgpool-spchw/neon-x4.c
new file mode 100644
index 0000000..a9514e7
--- /dev/null
+++ b/src/f32-gavgpool-spchw/neon-x4.c
@@ -0,0 +1,125 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gavgpool.h>
+#include <xnnpack/math.h>
+
+
+// Global average pooling (spatial-packed, channels-height-width layout):
+// averages `elements` bytes of floats per channel, scales by the precomputed
+// 1/count multiplier, clamps, and writes one float per channel.
+void xnn_f32_gavgpool_spchw_ukernel__neon_x4(
+    size_t elements,
+    size_t channels,
+    const float* input,
+    float* output,
+    const union xnn_f32_gavgpool_params params[restrict static 1])
+{
+  assert(elements != 0);
+  assert(elements % sizeof(float) == 0);
+  assert(channels != 0);
+
+  const float* i0 = input;
+  const float* i1 = (const float*) ((uintptr_t) i0 + elements);
+  const float* i2 = (const float*) ((uintptr_t) i1 + elements);
+  const float* i3 = (const float*) ((uintptr_t) i2 + elements);
+
+  const uint32x4_t vmask = vld1q_u32(params->neon.mask);
+  const float32x4_t vmultiplier = vld1q_dup_f32(&params->neon.multiplier);
+  const float32x4_t voutput_min = vld1q_dup_f32(&params->neon.output_min);
+  const float32x4_t voutput_max = vld1q_dup_f32(&params->neon.output_max);
+
+  while (channels >= 4) {
+    float32x4_t vsum0 = vmovq_n_f32(0.0f);
+    float32x4_t vsum1 = vmovq_n_f32(0.0f);
+    float32x4_t vsum2 = vmovq_n_f32(0.0f);
+    float32x4_t vsum3 = vmovq_n_f32(0.0f);
+    size_t n = elements;
+    while (n >= 4 * sizeof(float)) {
+      const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
+      const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
+      const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
+      const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
+
+      vsum0 = vaddq_f32(vsum0, vi0);
+      vsum1 = vaddq_f32(vsum1, vi1);
+      vsum2 = vaddq_f32(vsum2, vi2);
+      vsum3 = vaddq_f32(vsum3, vi3);
+      n -= 4 * sizeof(float);
+    }
+
+    if XNN_UNLIKELY(n != 0) {
+      float32x4_t vi0 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + n);
+      float32x4_t vi1 = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + n);
+      float32x4_t vi2 = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + n);
+      float32x4_t vi3 = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + n);
+
+      vi0 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi0)));  // zero out-of-bounds lanes
+      vi1 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi1)));
+      vi2 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi2)));
+      vi3 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi3)));
+
+      vsum0 = vaddq_f32(vsum0, vi0);
+      vsum1 = vaddq_f32(vsum1, vi1);
+      vsum2 = vaddq_f32(vsum2, vi2);
+      vsum3 = vaddq_f32(vsum3, vi3);
+    }
+
+    // Having exactly 4 rows makes this work out nicely as we end up with
+    // the 4 totals in 4 different lanes of the same vector.
+#ifdef __aarch64__
+    const float32x4_t vsum01 = vpaddq_f32(vsum0, vsum1);
+    const float32x4_t vsum23 = vpaddq_f32(vsum2, vsum3);
+    const float32x4_t vsum = vpaddq_f32(vsum01, vsum23);
+#else
+    const float32x4_t vsum01 = vcombine_f32(vadd_f32(vget_low_f32(vsum0), vget_high_f32(vsum0)),
+                                            vadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)));
+    const float32x4_t vsum23 = vcombine_f32(vadd_f32(vget_low_f32(vsum2), vget_high_f32(vsum2)),
+                                            vadd_f32(vget_low_f32(vsum3), vget_high_f32(vsum3)));
+    const float32x4_t vsum = vcombine_f32(vpadd_f32(vget_low_f32(vsum01), vget_high_f32(vsum01)),
+                                          vpadd_f32(vget_low_f32(vsum23), vget_high_f32(vsum23)));
+#endif
+
+    float32x4_t vout = vmulq_f32(vsum, vmultiplier);
+
+    vout = vmaxq_f32(vout, voutput_min);
+    vout = vminq_f32(vout, voutput_max);
+
+    vst1q_f32(output, vout); output += 4;
+    i0 = i3;
+    i1 = (const float*) ((uintptr_t) i0 + elements);
+    i2 = (const float*) ((uintptr_t) i1 + elements);
+    i3 = (const float*) ((uintptr_t) i2 + elements);
+    channels -= 4;
+  }
+
+  while (channels != 0) {
+    float32x4_t vsum0 = vmovq_n_f32(0.0f);
+    size_t n = elements;
+    while (n >= 4 * sizeof(float)) {
+      const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
+      vsum0 = vaddq_f32(vsum0, vi0);
+      n -= 4 * sizeof(float);
+    }
+
+    if XNN_UNLIKELY(n != 0) {
+      float32x4_t vi0 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + n);
+      vi0 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi0)));  // zero out-of-bounds lanes
+      vsum0 = vaddq_f32(vsum0, vi0);
+    }
+
+    float32x2_t vsum = vadd_f32(vget_low_f32(vsum0), vget_high_f32(vsum0));
+    vsum = vpadd_f32(vsum, vsum);
+
+    float32x2_t vout = vmul_f32(vsum, vget_low_f32(vmultiplier));
+
+    vout = vmax_f32(vout, vget_low_f32(voutput_min));
+    vout = vmin_f32(vout, vget_low_f32(voutput_max));
+
+    vst1_lane_f32(output, vout, 0); output += 1;
+    channels -= 1;
+  }
+}
diff --git a/src/f32-gavgpool-spchw/sse-x4.c b/src/f32-gavgpool-spchw/sse-x4.c
new file mode 100644
index 0000000..e294cfb
--- /dev/null
+++ b/src/f32-gavgpool-spchw/sse-x4.c
@@ -0,0 +1,121 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gavgpool.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_gavgpool_spchw_ukernel__sse_x4(
+ size_t elements,
+ size_t channels,
+ const float* input,
+ float* output,
+ const union xnn_f32_gavgpool_params params[restrict static 1])
+{
+ assert(elements != 0);
+ assert(elements % sizeof(float) == 0);
+ assert(channels != 0);
+
+ const float* i0 = input;
+ const float* i1 = (const float*) ((uintptr_t) i0 + elements);
+ const float* i2 = (const float*) ((uintptr_t) i1 + elements);
+ const float* i3 = (const float*) ((uintptr_t) i2 + elements);
+
+ const __m128 vmask = _mm_load_ps((const float*) params->sse.mask);
+ const __m128 vmultiplier = _mm_load_ps(params->sse.multiplier);
+ const __m128 voutput_min = _mm_load_ps(params->sse.output_min);
+ const __m128 voutput_max = _mm_load_ps(params->sse.output_max);
+
+ while (channels >= 4) {
+ __m128 vsum0 = _mm_setzero_ps();
+ __m128 vsum1 = _mm_setzero_ps();
+ __m128 vsum2 = _mm_setzero_ps();
+ __m128 vsum3 = _mm_setzero_ps();
+ size_t n = elements;
+ while (n >= 4 * sizeof(float)) {
+ const __m128 vi0 = _mm_loadu_ps(i0);
+ i0 += 4;
+ const __m128 vi1 = _mm_loadu_ps(i1);
+ i1 += 4;
+ const __m128 vi2 = _mm_loadu_ps(i2);
+ i2 += 4;
+ const __m128 vi3 = _mm_loadu_ps(i3);
+ i3 += 4;
+
+ vsum0 = _mm_add_ps(vsum0, vi0);
+ vsum1 = _mm_add_ps(vsum1, vi1);
+ vsum2 = _mm_add_ps(vsum2, vi2);
+ vsum3 = _mm_add_ps(vsum3, vi3);
+ n -= 4 * sizeof(float);
+ }
+
+ if XNN_UNLIKELY(n != 0) {
+ const __m128 vi0 = _mm_and_ps(_mm_loadu_ps(i0), vmask);
+ i0 = (const float*) ((uintptr_t) i0 + n);
+ const __m128 vi1 = _mm_and_ps(_mm_loadu_ps(i1), vmask);
+ i1 = (const float*) ((uintptr_t) i1 + n);
+ const __m128 vi2 = _mm_and_ps(_mm_loadu_ps(i2), vmask);
+ i2 = (const float*) ((uintptr_t) i2 + n);
+ const __m128 vi3 = _mm_and_ps(_mm_loadu_ps(i3), vmask);
+ i3 = (const float*) ((uintptr_t) i3 + n);
+
+ vsum0 = _mm_add_ps(vsum0, vi0);
+ vsum1 = _mm_add_ps(vsum1, vi1);
+ vsum2 = _mm_add_ps(vsum2, vi2);
+ vsum3 = _mm_add_ps(vsum3, vi3);
+ }
+
+    // Having exactly 4 rows makes this work out nicely as we end up with
+    // the 4 totals in 4 different lanes of the same vector.
+ const __m128 vsum01 = _mm_add_ps(_mm_unpacklo_ps(vsum0, vsum1), _mm_unpackhi_ps(vsum0, vsum1));
+ const __m128 vsum23 = _mm_add_ps(_mm_unpacklo_ps(vsum2, vsum3), _mm_unpackhi_ps(vsum2, vsum3));
+ const __m128 vsum = _mm_add_ps(_mm_movelh_ps(vsum01, vsum23), _mm_movehl_ps(vsum23, vsum01));
+ __m128 vout = _mm_mul_ps(vsum, vmultiplier);
+
+ vout = _mm_max_ps(vout, voutput_min);
+ vout = _mm_min_ps(vout, voutput_max);
+
+ _mm_storeu_ps(output, vout);
+ output += 4;
+ i0 = i3;
+ i1 = (const float*) ((uintptr_t) i0 + elements);
+ i2 = (const float*) ((uintptr_t) i1 + elements);
+ i3 = (const float*) ((uintptr_t) i2 + elements);
+ channels -= 4;
+ }
+
+ while (channels != 0) {
+ __m128 vsum = _mm_setzero_ps();
+ size_t n = elements;
+ while (n >= 4 * sizeof(float)) {
+ const __m128 vi0 = _mm_loadu_ps(i0);
+ i0 += 4;
+ vsum = _mm_add_ps(vsum, vi0);
+ n -= 4 * sizeof(float);
+ }
+
+ if XNN_UNLIKELY(n != 0) {
+ __m128 vi0 = _mm_and_ps(_mm_loadu_ps(i0), vmask);
+ i0 = (const float*) ((uintptr_t) i0 + n);
+ vsum = _mm_add_ps(vsum, vi0);
+ }
+
+ vsum = _mm_add_ps(vsum, _mm_movehl_ps(vsum, vsum));
+ vsum = _mm_add_ss(vsum, _mm_shuffle_ps(vsum, vsum, _MM_SHUFFLE(3, 2, 1, 1)));
+
+ __m128 vout = _mm_mul_ss(vsum, vmultiplier);
+
+ vout = _mm_max_ss(vout, voutput_min);
+ vout = _mm_min_ss(vout, voutput_max);
+
+ _mm_store_ss(output, vout);
+ output += 1;
+ channels -= 1;
+ }
+}
diff --git a/src/f32-gavgpool/mp7p7q-neon.c b/src/f32-gavgpool/mp7p7q-neon.c
new file mode 100644
index 0000000..ec7224f
--- /dev/null
+++ b/src/f32-gavgpool/mp7p7q-neon.c
@@ -0,0 +1,186 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gavgpool.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_gavgpool_ukernel_mp7p7q__neon(
+ size_t m,
+ size_t n,
+ const float* input,
+ size_t input_stride,
+ const float* zero,
+ float* buffer,
+ float* output,
+ const union xnn_f32_avgpool_params params[restrict static 1])
+{
+ assert(m > 7);
+ assert(n != 0);
+
+ const float* i0 = input;
+ const float* i1 = (const float*) ((uintptr_t) i0 + input_stride);
+ const float* i2 = (const float*) ((uintptr_t) i1 + input_stride);
+ const float* i3 = (const float*) ((uintptr_t) i2 + input_stride);
+ const float* i4 = (const float*) ((uintptr_t) i3 + input_stride);
+ const float* i5 = (const float*) ((uintptr_t) i4 + input_stride);
+ const float* i6 = (const float*) ((uintptr_t) i5 + input_stride);
+ const size_t packed_n = round_up_po2(n, 4);
+ const size_t input_increment = 7 * input_stride - packed_n * sizeof(float);
+
+ float* b = buffer;
+ for (size_t k = 0; k < n; k += 4) {
+ const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
+
+ const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+ const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+ const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+
+ const float32x4_t vsum016 = vaddq_f32(vsum01, vi6);
+ const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
+
+ const float32x4_t vsum = vaddq_f32(vsum016, vsum2345);
+
+ vst1q_f32(b, vsum); b += 4;
+ }
+ for (m -= 7; m > 7; m -= 7) {
+ b = buffer;
+
+ i0 = (const float*) ((uintptr_t) i0 + input_increment);
+ i1 = (const float*) ((uintptr_t) i1 + input_increment);
+ i2 = (const float*) ((uintptr_t) i2 + input_increment);
+ i3 = (const float*) ((uintptr_t) i3 + input_increment);
+ i4 = (const float*) ((uintptr_t) i4 + input_increment);
+ i5 = (const float*) ((uintptr_t) i5 + input_increment);
+ i6 = (const float*) ((uintptr_t) i6 + input_increment);
+
+ for (size_t k = 0; k < n; k += 4) {
+ const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
+ const float32x4_t vacc = vld1q_f32(b);
+
+ const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+ const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+ const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+ const float32x4_t vsum6a = vaddq_f32(vi6, vacc);
+
+ const float32x4_t vsum0123 = vaddq_f32(vsum01, vsum23);
+ const float32x4_t vsum456a = vaddq_f32(vsum45, vsum6a);
+
+ const float32x4_t vsum = vaddq_f32(vsum0123, vsum456a);
+
+ vst1q_f32(b, vsum); b += 4;
+ }
+ }
+
+ i0 = (const float*) ((uintptr_t) i0 + input_increment);
+ i1 = (const float*) ((uintptr_t) i1 + input_increment);
+ if (m < 2) {
+ i1 = zero;
+ }
+ i2 = (const float*) ((uintptr_t) i2 + input_increment);
+ if (m <= 2) {
+ i2 = zero;
+ }
+ i3 = (const float*) ((uintptr_t) i3 + input_increment);
+ if (m < 4) {
+ i3 = zero;
+ }
+ i4 = (const float*) ((uintptr_t) i4 + input_increment);
+ if (m <= 4) {
+ i4 = zero;
+ }
+ i5 = (const float*) ((uintptr_t) i5 + input_increment);
+ if (m < 6) {
+ i5 = zero;
+ }
+ i6 = (const float*) ((uintptr_t) i6 + input_increment);
+ if (m <= 6) {
+ i6 = zero;
+ }
+  const float32x4_t vmultiplier = vld1q_dup_f32(&params->scalar.multiplier);
+  const float32x4_t voutput_min = vld1q_dup_f32(&params->scalar.output_min);
+  const float32x4_t voutput_max = vld1q_dup_f32(&params->scalar.output_max);
+
+ b = buffer;
+ while (n >= 4) {
+ const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
+ const float32x4_t vacc = vld1q_f32(b); b += 4;
+
+ const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+ const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+ const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+ const float32x4_t vsum6a = vaddq_f32(vi6, vacc);
+
+ const float32x4_t vsum0123 = vaddq_f32(vsum01, vsum23);
+ const float32x4_t vsum456a = vaddq_f32(vsum45, vsum6a);
+
+ const float32x4_t vsum = vaddq_f32(vsum0123, vsum456a);
+
+ float32x4_t vout = vmulq_f32(vsum, vmultiplier);
+ vout = vmaxq_f32(vout, voutput_min);
+ vout = vminq_f32(vout, voutput_max);
+
+ vst1q_f32(output, vout); output += 4;
+
+ n -= 4;
+ }
+ if (n != 0) {
+ const float32x4_t vi0 = vld1q_f32(i0);
+ const float32x4_t vi1 = vld1q_f32(i1);
+ const float32x4_t vi2 = vld1q_f32(i2);
+ const float32x4_t vi3 = vld1q_f32(i3);
+ const float32x4_t vi4 = vld1q_f32(i4);
+ const float32x4_t vi5 = vld1q_f32(i5);
+ const float32x4_t vi6 = vld1q_f32(i6);
+ const float32x4_t vacc = vld1q_f32(b);
+
+ const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+ const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+ const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+ const float32x4_t vsum6a = vaddq_f32(vi6, vacc);
+
+ const float32x4_t vsum0123 = vaddq_f32(vsum01, vsum23);
+ const float32x4_t vsum456a = vaddq_f32(vsum45, vsum6a);
+
+ const float32x4_t vsum = vaddq_f32(vsum0123, vsum456a);
+
+ float32x4_t vout = vmulq_f32(vsum, vmultiplier);
+ vout = vmaxq_f32(vout, voutput_min);
+ vout = vminq_f32(vout, voutput_max);
+
+ float32x2_t vout_lo = vget_low_f32(vout);
+ if (n & 2) {
+ vst1_f32(output, vout_lo); output += 2;
+ vout_lo = vget_high_f32(vout);
+ }
+ if (n & 1) {
+ vst1_lane_f32(output, vout_lo, 0);
+ }
+ }
+}
diff --git a/src/f32-gavgpool/mp7p7q-psimd.c b/src/f32-gavgpool/mp7p7q-psimd.c
new file mode 100644
index 0000000..be1a8dc
--- /dev/null
+++ b/src/f32-gavgpool/mp7p7q-psimd.c
@@ -0,0 +1,209 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gavgpool.h>
+#include <xnnpack/math.h>
+
+
+// f32 global average pooling, multi-pass microkernel (PSIMD variant).
+// For m > 7 pooling rows: the first pass writes raw 7-row sums into `buffer`
+// (n padded up to a multiple of 4 channels), middle passes accumulate 7 more
+// rows into `buffer`, and the final pass (1..7 leftover rows, with absent rows
+// redirected to `zero`) adds the buffered sums, multiplies by the precomputed
+// multiplier and clamps to [output_min, output_max].
+void xnn_f32_gavgpool_ukernel_mp7p7q__psimd(
+    size_t m,
+    size_t n,
+    const float* input,
+    size_t input_stride,
+    const float* zero,
+    float* buffer,
+    float* output,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(m > 7);  // m <= 7 is served by the single-pass up7 kernel
+  assert(n != 0);
+
+  const float* i0 = input;
+  const float* i1 = (const float*) ((uintptr_t) i0 + input_stride);
+  const float* i2 = (const float*) ((uintptr_t) i1 + input_stride);
+  const float* i3 = (const float*) ((uintptr_t) i2 + input_stride);
+  const float* i4 = (const float*) ((uintptr_t) i3 + input_stride);
+  const float* i5 = (const float*) ((uintptr_t) i4 + input_stride);
+  const float* i6 = (const float*) ((uintptr_t) i5 + input_stride);
+  // Channels are consumed 4 at a time, so each pass advances by packed_n floats.
+  const size_t packed_n = round_up_po2(n, 4);
+  const size_t input_increment = 7 * input_stride - packed_n * sizeof(float);
+
+  // First pass: store plain sums of the first 7 rows into the buffer.
+  float* b = buffer;
+  for (size_t k = 0; k < n; k += 4) {
+    const psimd_f32 vi0 = psimd_load_f32(i0);
+    i0 += 4;
+    const psimd_f32 vi1 = psimd_load_f32(i1);
+    i1 += 4;
+    const psimd_f32 vi2 = psimd_load_f32(i2);
+    i2 += 4;
+    const psimd_f32 vi3 = psimd_load_f32(i3);
+    i3 += 4;
+    const psimd_f32 vi4 = psimd_load_f32(i4);
+    i4 += 4;
+    const psimd_f32 vi5 = psimd_load_f32(i5);
+    i5 += 4;
+    const psimd_f32 vi6 = psimd_load_f32(i6);
+    i6 += 4;
+
+    // Tree-shaped reduction of the 7 row values.
+    const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+    const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+    const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+
+    const psimd_f32 vsum016 = psimd_add_f32(vsum01, vi6);
+    const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+
+    const psimd_f32 vsum = psimd_add_f32(vsum016, vsum2345);
+
+    psimd_store_f32(b, vsum); b += 4;
+  }
+  // Middle passes: accumulate 7 further rows on top of the buffered sums.
+  for (m -= 7; m > 7; m -= 7) {
+    b = buffer;
+
+    i0 = (const float*) ((uintptr_t) i0 + input_increment);
+    i1 = (const float*) ((uintptr_t) i1 + input_increment);
+    i2 = (const float*) ((uintptr_t) i2 + input_increment);
+    i3 = (const float*) ((uintptr_t) i3 + input_increment);
+    i4 = (const float*) ((uintptr_t) i4 + input_increment);
+    i5 = (const float*) ((uintptr_t) i5 + input_increment);
+    i6 = (const float*) ((uintptr_t) i6 + input_increment);
+
+    for (size_t k = 0; k < n; k += 4) {
+      const psimd_f32 vi0 = psimd_load_f32(i0);
+      i0 += 4;
+      const psimd_f32 vi1 = psimd_load_f32(i1);
+      i1 += 4;
+      const psimd_f32 vi2 = psimd_load_f32(i2);
+      i2 += 4;
+      const psimd_f32 vi3 = psimd_load_f32(i3);
+      i3 += 4;
+      const psimd_f32 vi4 = psimd_load_f32(i4);
+      i4 += 4;
+      const psimd_f32 vi5 = psimd_load_f32(i5);
+      i5 += 4;
+      const psimd_f32 vi6 = psimd_load_f32(i6);
+      i6 += 4;
+      const psimd_f32 vacc = psimd_load_f32(b);  // running sum from earlier passes
+
+      const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+      const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+      const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+      const psimd_f32 vsum6a = psimd_add_f32(vi6, vacc);
+
+      const psimd_f32 vsum0123 = psimd_add_f32(vsum01, vsum23);
+      const psimd_f32 vsum456a = psimd_add_f32(vsum45, vsum6a);
+
+      const psimd_f32 vsum = psimd_add_f32(vsum0123, vsum456a);
+
+      psimd_store_f32(b, vsum); b += 4;
+    }
+  }
+
+  // Final pass: 1..7 rows remain (m in [1, 7]); rows past the last valid one
+  // read from `zero` so they contribute nothing to the sum.
+  i0 = (const float*) ((uintptr_t) i0 + input_increment);
+  i1 = (const float*) ((uintptr_t) i1 + input_increment);
+  if (m < 2) {
+    i1 = zero;
+  }
+  i2 = (const float*) ((uintptr_t) i2 + input_increment);
+  if (m <= 2) {
+    i2 = zero;
+  }
+  i3 = (const float*) ((uintptr_t) i3 + input_increment);
+  if (m < 4) {
+    i3 = zero;
+  }
+  i4 = (const float*) ((uintptr_t) i4 + input_increment);
+  if (m <= 4) {
+    i4 = zero;
+  }
+  i5 = (const float*) ((uintptr_t) i5 + input_increment);
+  if (m < 6) {
+    i5 = zero;
+  }
+  i6 = (const float*) ((uintptr_t) i6 + input_increment);
+  if (m <= 6) {
+    i6 = zero;
+  }
+  // Note: was corrupted as "¶ms" (mis-encoded "&params") in the original.
+  const psimd_f32 vmultiplier = psimd_load_splat_f32(&params->scalar.multiplier);
+  const psimd_f32 voutput_min = psimd_load_splat_f32(&params->scalar.output_min);
+  const psimd_f32 voutput_max = psimd_load_splat_f32(&params->scalar.output_max);
+
+  b = buffer;
+  while (n >= 4) {
+    const psimd_f32 vi0 = psimd_load_f32(i0);
+    i0 += 4;
+    const psimd_f32 vi1 = psimd_load_f32(i1);
+    i1 += 4;
+    const psimd_f32 vi2 = psimd_load_f32(i2);
+    i2 += 4;
+    const psimd_f32 vi3 = psimd_load_f32(i3);
+    i3 += 4;
+    const psimd_f32 vi4 = psimd_load_f32(i4);
+    i4 += 4;
+    const psimd_f32 vi5 = psimd_load_f32(i5);
+    i5 += 4;
+    const psimd_f32 vi6 = psimd_load_f32(i6);
+    i6 += 4;
+    const psimd_f32 vacc = psimd_load_f32(b);
+    b += 4;
+
+    const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+    const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+    const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+    const psimd_f32 vsum6a = psimd_add_f32(vi6, vacc);
+
+    const psimd_f32 vsum0123 = psimd_add_f32(vsum01, vsum23);
+    const psimd_f32 vsum456a = psimd_add_f32(vsum45, vsum6a);
+
+    const psimd_f32 vsum = psimd_add_f32(vsum0123, vsum456a);
+
+    // Scale by the precomputed multiplier, then clamp to the output range.
+    psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
+    vout = psimd_max_f32(vout, voutput_min);
+    vout = psimd_min_f32(vout, voutput_max);
+
+    psimd_store_f32(output, vout);
+    output += 4;
+
+    n -= 4;
+  }
+  if (n != 0) {
+    // Remainder of 1-3 channels: compute a full vector, store piecewise.
+    const psimd_f32 vi0 = psimd_load_f32(i0);
+    const psimd_f32 vi1 = psimd_load_f32(i1);
+    const psimd_f32 vi2 = psimd_load_f32(i2);
+    const psimd_f32 vi3 = psimd_load_f32(i3);
+    const psimd_f32 vi4 = psimd_load_f32(i4);
+    const psimd_f32 vi5 = psimd_load_f32(i5);
+    const psimd_f32 vi6 = psimd_load_f32(i6);
+    const psimd_f32 vacc = psimd_load_f32(b);
+
+    const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+    const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+    const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+    const psimd_f32 vsum6a = psimd_add_f32(vi6, vacc);
+
+    const psimd_f32 vsum0123 = psimd_add_f32(vsum01, vsum23);
+    const psimd_f32 vsum456a = psimd_add_f32(vsum45, vsum6a);
+
+    const psimd_f32 vsum = psimd_add_f32(vsum0123, vsum456a);
+
+    psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
+    vout = psimd_max_f32(vout, voutput_min);
+    vout = psimd_min_f32(vout, voutput_max);
+
+    if (n & 2) {
+      psimd_store2_f32(output, vout);
+      output += 2;
+      vout = psimd_concat_hi_f32(vout, vout);  // shift high lanes down
+    }
+    if (n & 1) {
+      psimd_store1_f32(output, vout);
+    }
+  }
+}
diff --git a/src/f32-gavgpool/mp7p7q-scalar.c b/src/f32-gavgpool/mp7p7q-scalar.c
new file mode 100644
index 0000000..ac721b8
--- /dev/null
+++ b/src/f32-gavgpool/mp7p7q-scalar.c
@@ -0,0 +1,150 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/gavgpool.h>
+#include <xnnpack/math.h>
+
+
+// f32 global average pooling, multi-pass microkernel (scalar variant).
+// For m > 7 pooling rows: the first pass writes raw 7-row sums into `buffer`,
+// middle passes accumulate 7 more rows into `buffer`, and the final pass
+// (1..7 leftover rows, with absent rows redirected to `zero`) adds the
+// buffered sums, multiplies by the precomputed multiplier and clamps the
+// result to [output_min, output_max].
+void xnn_f32_gavgpool_ukernel_mp7p7q__scalar(
+    size_t m,
+    size_t n,
+    const float* input,
+    size_t input_stride,
+    const float* zero,
+    float* buffer,
+    float* output,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(m > 7);  // m <= 7 is served by the single-pass up7 kernel
+  assert(n != 0);
+
+  const float* i0 = input;
+  const float* i1 = (const float*) ((uintptr_t) i0 + input_stride);
+  const float* i2 = (const float*) ((uintptr_t) i1 + input_stride);
+  const float* i3 = (const float*) ((uintptr_t) i2 + input_stride);
+  const float* i4 = (const float*) ((uintptr_t) i3 + input_stride);
+  const float* i5 = (const float*) ((uintptr_t) i4 + input_stride);
+  const float* i6 = (const float*) ((uintptr_t) i5 + input_stride);
+  // After a pass each pointer has advanced by n floats; step to the next
+  // group of 7 rows.
+  const size_t input_increment = 7 * input_stride - n * sizeof(float);
+
+  // First pass: store plain sums of the first 7 rows into the buffer.
+  float* b = buffer;
+  size_t k = n;
+  do {
+    const float vi0 = *i0++;
+    const float vi1 = *i1++;
+    const float vi2 = *i2++;
+    const float vi3 = *i3++;
+    const float vi4 = *i4++;
+    const float vi5 = *i5++;
+    const float vi6 = *i6++;
+
+    // Tree-shaped reduction of the 7 row values.
+    const float vsum01 = vi0 + vi1;
+    const float vsum23 = vi2 + vi3;
+    const float vsum45 = vi4 + vi5;
+
+    const float vsum016 = vsum01 + vi6;
+    const float vsum2345 = vsum23 + vsum45;
+
+    const float vsum = vsum016 + vsum2345;
+
+    *b++ = vsum;
+  } while (--k != 0);
+  // Middle passes: accumulate 7 further rows on top of the buffered sums.
+  for (m -= 7; m > 7; m -= 7) {
+    b = buffer;
+
+    i0 = (const float*) ((uintptr_t) i0 + input_increment);
+    i1 = (const float*) ((uintptr_t) i1 + input_increment);
+    i2 = (const float*) ((uintptr_t) i2 + input_increment);
+    i3 = (const float*) ((uintptr_t) i3 + input_increment);
+    i4 = (const float*) ((uintptr_t) i4 + input_increment);
+    i5 = (const float*) ((uintptr_t) i5 + input_increment);
+    i6 = (const float*) ((uintptr_t) i6 + input_increment);
+
+    size_t k = n;
+    do {
+      const float vi0 = *i0++;
+      const float vi1 = *i1++;
+      const float vi2 = *i2++;
+      const float vi3 = *i3++;
+      const float vi4 = *i4++;
+      const float vi5 = *i5++;
+      const float vi6 = *i6++;
+      const float vacc = *b;  // running sum from earlier passes
+
+      const float vsum01 = vi0 + vi1;
+      const float vsum23 = vi2 + vi3;
+      const float vsum45 = vi4 + vi5;
+      const float vsum6a = vi6 + vacc;
+
+      const float vsum0123 = vsum01 + vsum23;
+      const float vsum456a = vsum45 + vsum6a;
+
+      const float vsum = vsum0123 + vsum456a;
+
+      *b++ = vsum;
+    } while (--k != 0);
+  }
+
+  // Final pass: 1..7 rows remain (m in [1, 7]); rows past the last valid one
+  // read from `zero` so they contribute nothing to the sum.
+  i0 = (const float*) ((uintptr_t) i0 + input_increment);
+  i1 = (const float*) ((uintptr_t) i1 + input_increment);
+  if (m < 2) {
+    i1 = zero;
+  }
+  i2 = (const float*) ((uintptr_t) i2 + input_increment);
+  if (m <= 2) {
+    i2 = zero;
+  }
+  i3 = (const float*) ((uintptr_t) i3 + input_increment);
+  if (m < 4) {
+    i3 = zero;
+  }
+  i4 = (const float*) ((uintptr_t) i4 + input_increment);
+  if (m <= 4) {
+    i4 = zero;
+  }
+  i5 = (const float*) ((uintptr_t) i5 + input_increment);
+  if (m < 6) {
+    i5 = zero;
+  }
+  i6 = (const float*) ((uintptr_t) i6 + input_increment);
+  if (m <= 6) {
+    i6 = zero;
+  }
+  const float vmultiplier = params->scalar.multiplier;
+  const float voutput_min = params->scalar.output_min;
+  const float voutput_max = params->scalar.output_max;
+
+  b = buffer;
+  do {
+    const float vi0 = *i0++;
+    const float vi1 = *i1++;
+    const float vi2 = *i2++;
+    const float vi3 = *i3++;
+    const float vi4 = *i4++;
+    const float vi5 = *i5++;
+    const float vi6 = *i6++;
+    const float vacc = *b++;
+
+    const float vsum01 = vi0 + vi1;
+    const float vsum23 = vi2 + vi3;
+    const float vsum45 = vi4 + vi5;
+    const float vsum6a = vi6 + vacc;
+
+    const float vsum0123 = vsum01 + vsum23;
+    const float vsum456a = vsum45 + vsum6a;
+
+    const float vsum = vsum0123 + vsum456a;
+
+    // Scale by the precomputed multiplier, then clamp to the output range.
+    float vout = vsum * vmultiplier;
+    vout = math_max_f32(vout, voutput_min);
+    vout = math_min_f32(vout, voutput_max);
+
+    *output++ = vout;
+  } while (--n != 0);
+}
diff --git a/src/f32-gavgpool/mp7p7q-sse.c b/src/f32-gavgpool/mp7p7q-sse.c
new file mode 100644
index 0000000..a7a8891
--- /dev/null
+++ b/src/f32-gavgpool/mp7p7q-sse.c
@@ -0,0 +1,209 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gavgpool.h>
+#include <xnnpack/math.h>
+
+
+// f32 global average pooling, multi-pass microkernel (SSE variant).
+// For m > 7 pooling rows: the first pass writes raw 7-row sums into `buffer`
+// (n padded up to a multiple of 4 channels), middle passes accumulate 7 more
+// rows into `buffer`, and the final pass (1..7 leftover rows, with absent rows
+// redirected to `zero`) adds the buffered sums, multiplies by the precomputed
+// multiplier and clamps to [output_min, output_max].
+// Input rows are loaded unaligned; `buffer` is accessed with aligned
+// load/store (it is written here with _mm_store_ps, so it must be
+// 16-byte aligned).
+void xnn_f32_gavgpool_ukernel_mp7p7q__sse(
+    size_t m,
+    size_t n,
+    const float* input,
+    size_t input_stride,
+    const float* zero,
+    float* buffer,
+    float* output,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(m > 7);  // m <= 7 is served by the single-pass up7 kernel
+  assert(n != 0);
+
+  const float* i0 = input;
+  const float* i1 = (const float*) ((uintptr_t) i0 + input_stride);
+  const float* i2 = (const float*) ((uintptr_t) i1 + input_stride);
+  const float* i3 = (const float*) ((uintptr_t) i2 + input_stride);
+  const float* i4 = (const float*) ((uintptr_t) i3 + input_stride);
+  const float* i5 = (const float*) ((uintptr_t) i4 + input_stride);
+  const float* i6 = (const float*) ((uintptr_t) i5 + input_stride);
+  // Channels are consumed 4 at a time, so each pass advances by packed_n floats.
+  const size_t packed_n = round_up_po2(n, 4);
+  const size_t input_increment = 7 * input_stride - packed_n * sizeof(float);
+
+  // First pass: store plain sums of the first 7 rows into the buffer.
+  float* b = buffer;
+  for (size_t k = 0; k < n; k += 4) {
+    const __m128 vi0 = _mm_loadu_ps(i0);
+    i0 += 4;
+    const __m128 vi1 = _mm_loadu_ps(i1);
+    i1 += 4;
+    const __m128 vi2 = _mm_loadu_ps(i2);
+    i2 += 4;
+    const __m128 vi3 = _mm_loadu_ps(i3);
+    i3 += 4;
+    const __m128 vi4 = _mm_loadu_ps(i4);
+    i4 += 4;
+    const __m128 vi5 = _mm_loadu_ps(i5);
+    i5 += 4;
+    const __m128 vi6 = _mm_loadu_ps(i6);
+    i6 += 4;
+
+    // Tree-shaped reduction of the 7 row values.
+    const __m128 vsum01 = _mm_add_ps(vi0, vi1);
+    const __m128 vsum23 = _mm_add_ps(vi2, vi3);
+    const __m128 vsum45 = _mm_add_ps(vi4, vi5);
+
+    const __m128 vsum016 = _mm_add_ps(vsum01, vi6);
+    const __m128 vsum2345 = _mm_add_ps(vsum23, vsum45);
+
+    const __m128 vsum = _mm_add_ps(vsum016, vsum2345);
+
+    _mm_store_ps(b, vsum); b += 4;
+  }
+  // Middle passes: accumulate 7 further rows on top of the buffered sums.
+  for (m -= 7; m > 7; m -= 7) {
+    b = buffer;
+
+    i0 = (const float*) ((uintptr_t) i0 + input_increment);
+    i1 = (const float*) ((uintptr_t) i1 + input_increment);
+    i2 = (const float*) ((uintptr_t) i2 + input_increment);
+    i3 = (const float*) ((uintptr_t) i3 + input_increment);
+    i4 = (const float*) ((uintptr_t) i4 + input_increment);
+    i5 = (const float*) ((uintptr_t) i5 + input_increment);
+    i6 = (const float*) ((uintptr_t) i6 + input_increment);
+
+    for (size_t k = 0; k < n; k += 4) {
+      const __m128 vi0 = _mm_loadu_ps(i0);
+      i0 += 4;
+      const __m128 vi1 = _mm_loadu_ps(i1);
+      i1 += 4;
+      const __m128 vi2 = _mm_loadu_ps(i2);
+      i2 += 4;
+      const __m128 vi3 = _mm_loadu_ps(i3);
+      i3 += 4;
+      const __m128 vi4 = _mm_loadu_ps(i4);
+      i4 += 4;
+      const __m128 vi5 = _mm_loadu_ps(i5);
+      i5 += 4;
+      const __m128 vi6 = _mm_loadu_ps(i6);
+      i6 += 4;
+      const __m128 vacc = _mm_load_ps(b);  // running sum from earlier passes
+
+      const __m128 vsum01 = _mm_add_ps(vi0, vi1);
+      const __m128 vsum23 = _mm_add_ps(vi2, vi3);
+      const __m128 vsum45 = _mm_add_ps(vi4, vi5);
+      const __m128 vsum6a = _mm_add_ps(vi6, vacc);
+
+      const __m128 vsum0123 = _mm_add_ps(vsum01, vsum23);
+      const __m128 vsum456a = _mm_add_ps(vsum45, vsum6a);
+
+      const __m128 vsum = _mm_add_ps(vsum0123, vsum456a);
+
+      _mm_store_ps(b, vsum); b += 4;
+    }
+  }
+
+  // Final pass: 1..7 rows remain (m in [1, 7]); rows past the last valid one
+  // read from `zero` so they contribute nothing to the sum.
+  i0 = (const float*) ((uintptr_t) i0 + input_increment);
+  i1 = (const float*) ((uintptr_t) i1 + input_increment);
+  if (m < 2) {
+    i1 = zero;
+  }
+  i2 = (const float*) ((uintptr_t) i2 + input_increment);
+  if (m <= 2) {
+    i2 = zero;
+  }
+  i3 = (const float*) ((uintptr_t) i3 + input_increment);
+  if (m < 4) {
+    i3 = zero;
+  }
+  i4 = (const float*) ((uintptr_t) i4 + input_increment);
+  if (m <= 4) {
+    i4 = zero;
+  }
+  i5 = (const float*) ((uintptr_t) i5 + input_increment);
+  if (m < 6) {
+    i5 = zero;
+  }
+  i6 = (const float*) ((uintptr_t) i6 + input_increment);
+  if (m <= 6) {
+    i6 = zero;
+  }
+  const __m128 vmultiplier = _mm_load_ps(params->sse2.multiplier);
+  const __m128 voutput_min = _mm_load_ps(params->sse2.output_min);
+  const __m128 voutput_max = _mm_load_ps(params->sse2.output_max);
+
+  b = buffer;
+  while (n >= 4) {
+    const __m128 vi0 = _mm_loadu_ps(i0);
+    i0 += 4;
+    const __m128 vi1 = _mm_loadu_ps(i1);
+    i1 += 4;
+    const __m128 vi2 = _mm_loadu_ps(i2);
+    i2 += 4;
+    const __m128 vi3 = _mm_loadu_ps(i3);
+    i3 += 4;
+    const __m128 vi4 = _mm_loadu_ps(i4);
+    i4 += 4;
+    const __m128 vi5 = _mm_loadu_ps(i5);
+    i5 += 4;
+    const __m128 vi6 = _mm_loadu_ps(i6);
+    i6 += 4;
+    const __m128 vacc = _mm_load_ps(b);
+    b += 4;
+
+    const __m128 vsum01 = _mm_add_ps(vi0, vi1);
+    const __m128 vsum23 = _mm_add_ps(vi2, vi3);
+    const __m128 vsum45 = _mm_add_ps(vi4, vi5);
+    const __m128 vsum6a = _mm_add_ps(vi6, vacc);
+
+    const __m128 vsum0123 = _mm_add_ps(vsum01, vsum23);
+    const __m128 vsum456a = _mm_add_ps(vsum45, vsum6a);
+
+    const __m128 vsum = _mm_add_ps(vsum0123, vsum456a);
+
+    // Scale by the precomputed multiplier, then clamp to the output range.
+    __m128 vout = _mm_mul_ps(vsum, vmultiplier);
+    vout = _mm_max_ps(vout, voutput_min);
+    vout = _mm_min_ps(vout, voutput_max);
+
+    _mm_storeu_ps(output, vout);
+    output += 4;
+
+    n -= 4;
+  }
+  if (n != 0) {
+    // Remainder of 1-3 channels: compute a full vector, store piecewise.
+    const __m128 vi0 = _mm_loadu_ps(i0);
+    const __m128 vi1 = _mm_loadu_ps(i1);
+    const __m128 vi2 = _mm_loadu_ps(i2);
+    const __m128 vi3 = _mm_loadu_ps(i3);
+    const __m128 vi4 = _mm_loadu_ps(i4);
+    const __m128 vi5 = _mm_loadu_ps(i5);
+    const __m128 vi6 = _mm_loadu_ps(i6);
+    // Aligned load for consistency with the other buffer accesses above
+    // (was _mm_loadu_ps; buffer is always accessed aligned elsewhere).
+    const __m128 vacc = _mm_load_ps(b);
+
+    const __m128 vsum01 = _mm_add_ps(vi0, vi1);
+    const __m128 vsum23 = _mm_add_ps(vi2, vi3);
+    const __m128 vsum45 = _mm_add_ps(vi4, vi5);
+    const __m128 vsum6a = _mm_add_ps(vi6, vacc);
+
+    const __m128 vsum0123 = _mm_add_ps(vsum01, vsum23);
+    const __m128 vsum456a = _mm_add_ps(vsum45, vsum6a);
+
+    const __m128 vsum = _mm_add_ps(vsum0123, vsum456a);
+
+    __m128 vout = _mm_mul_ps(vsum, vmultiplier);
+    vout = _mm_max_ps(vout, voutput_min);
+    vout = _mm_min_ps(vout, voutput_max);
+
+    if (n & 2) {
+      _mm_storel_pi((__m64*) output, vout);   // store low 2 lanes
+      vout = _mm_movehl_ps(vout, vout);       // shift high lanes down
+      output += 2;
+    }
+    if (n & 1) {
+      _mm_store_ss(output, vout);
+    }
+  }
+}
diff --git a/src/f32-gavgpool/up7-neon.c b/src/f32-gavgpool/up7-neon.c
new file mode 100644
index 0000000..e102996
--- /dev/null
+++ b/src/f32-gavgpool/up7-neon.c
@@ -0,0 +1,114 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gavgpool.h>
+
+
+// f32 global average pooling, single-pass ("unipass") microkernel for up to
+// 7 pooling rows, NEON variant. Sums the m rows (absent rows redirected to
+// `zero`), multiplies by the precomputed multiplier and clamps to
+// [output_min, output_max].
+void xnn_f32_gavgpool_ukernel_up7__neon(
+    size_t m,
+    size_t n,
+    const float* input,
+    size_t input_stride,
+    const float* zero,
+    float* output,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(m != 0);
+  assert(m <= 7);  // m > 7 is served by the multi-pass mp7p7q kernel
+  assert(n != 0);
+
+  // Set up 7 row pointers; rows past the last valid one read from `zero`
+  // so they contribute nothing to the sum.
+  const float* i0 = input;
+  const float* i1 = (const float*) ((uintptr_t) i0 + input_stride);
+  if (m < 2) {
+    i1 = zero;
+  }
+  const float* i2 = (const float*) ((uintptr_t) i1 + input_stride);
+  if (m <= 2) {
+    i2 = zero;
+  }
+  const float* i3 = (const float*) ((uintptr_t) i2 + input_stride);
+  if (m < 4) {
+    i3 = zero;
+  }
+  const float* i4 = (const float*) ((uintptr_t) i3 + input_stride);
+  if (m <= 4) {
+    i4 = zero;
+  }
+  const float* i5 = (const float*) ((uintptr_t) i4 + input_stride);
+  if (m < 6) {
+    i5 = zero;
+  }
+  const float* i6 = (const float*) ((uintptr_t) i5 + input_stride);
+  if (m <= 6) {
+    i6 = zero;
+  }
+  // Note: was corrupted as "¶ms" (mis-encoded "&params") in the original.
+  const float32x4_t vmultiplier = vld1q_dup_f32(&params->scalar.multiplier);
+  const float32x4_t voutput_min = vld1q_dup_f32(&params->scalar.output_min);
+  const float32x4_t voutput_max = vld1q_dup_f32(&params->scalar.output_max);
+
+  while (n >= 4) {
+    const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
+    const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
+    const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
+    const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
+    const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
+    const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
+    const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
+
+    // Tree-shaped reduction of the 7 row values.
+    const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+    const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+    const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+
+    const float32x4_t vsum016 = vaddq_f32(vsum01, vi6);
+    const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
+
+    const float32x4_t vsum = vaddq_f32(vsum016, vsum2345);
+
+    // Scale by the precomputed multiplier, then clamp to the output range.
+    float32x4_t vout = vmulq_f32(vsum, vmultiplier);
+    vout = vmaxq_f32(vout, voutput_min);
+    vout = vminq_f32(vout, voutput_max);
+
+    vst1q_f32(output, vout); output += 4;
+
+    n -= 4;
+  }
+  if (n != 0) {
+    // Remainder of 1-3 channels: compute a full vector, store piecewise.
+    const float32x4_t vi0 = vld1q_f32(i0);
+    const float32x4_t vi1 = vld1q_f32(i1);
+    const float32x4_t vi2 = vld1q_f32(i2);
+    const float32x4_t vi3 = vld1q_f32(i3);
+    const float32x4_t vi4 = vld1q_f32(i4);
+    const float32x4_t vi5 = vld1q_f32(i5);
+    const float32x4_t vi6 = vld1q_f32(i6);
+
+    const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+    const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+    const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+
+    const float32x4_t vsum016 = vaddq_f32(vsum01, vi6);
+    const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
+
+    const float32x4_t vsum = vaddq_f32(vsum016, vsum2345);
+
+    float32x4_t vout = vmulq_f32(vsum, vmultiplier);
+    vout = vmaxq_f32(vout, voutput_min);
+    vout = vminq_f32(vout, voutput_max);
+
+    float32x2_t vout_lo = vget_low_f32(vout);
+    if (n & 2) {
+      vst1_f32(output, vout_lo); output += 2;
+      vout_lo = vget_high_f32(vout);  // shift high lanes down
+    }
+    if (n & 1) {
+      vst1_lane_f32(output, vout_lo, 0);
+    }
+  }
+}
diff --git a/src/f32-gavgpool/up7-psimd.c b/src/f32-gavgpool/up7-psimd.c
new file mode 100644
index 0000000..3c69d53
--- /dev/null
+++ b/src/f32-gavgpool/up7-psimd.c
@@ -0,0 +1,122 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gavgpool.h>
+
+
+// f32 global average pooling, single-pass ("unipass") microkernel for up to
+// 7 pooling rows, PSIMD variant. Sums the m rows (absent rows redirected to
+// `zero`), multiplies by the precomputed multiplier and clamps to
+// [output_min, output_max].
+void xnn_f32_gavgpool_ukernel_up7__psimd(
+    size_t m,
+    size_t n,
+    const float* input,
+    size_t input_stride,
+    const float* zero,
+    float* output,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(m != 0);
+  assert(m <= 7);  // m > 7 is served by the multi-pass mp7p7q kernel
+  assert(n != 0);
+
+  // Set up 7 row pointers; rows past the last valid one read from `zero`
+  // so they contribute nothing to the sum.
+  const float* i0 = input;
+  const float* i1 = (const float*) ((uintptr_t) i0 + input_stride);
+  if (m < 2) {
+    i1 = zero;
+  }
+  const float* i2 = (const float*) ((uintptr_t) i1 + input_stride);
+  if (m <= 2) {
+    i2 = zero;
+  }
+  const float* i3 = (const float*) ((uintptr_t) i2 + input_stride);
+  if (m < 4) {
+    i3 = zero;
+  }
+  const float* i4 = (const float*) ((uintptr_t) i3 + input_stride);
+  if (m <= 4) {
+    i4 = zero;
+  }
+  const float* i5 = (const float*) ((uintptr_t) i4 + input_stride);
+  if (m < 6) {
+    i5 = zero;
+  }
+  const float* i6 = (const float*) ((uintptr_t) i5 + input_stride);
+  if (m <= 6) {
+    i6 = zero;
+  }
+  // Note: was corrupted as "¶ms" (mis-encoded "&params") in the original.
+  const psimd_f32 vmultiplier = psimd_load_splat_f32(&params->scalar.multiplier);
+  const psimd_f32 voutput_min = psimd_load_splat_f32(&params->scalar.output_min);
+  const psimd_f32 voutput_max = psimd_load_splat_f32(&params->scalar.output_max);
+
+  while (n >= 4) {
+    const psimd_f32 vi0 = psimd_load_f32(i0);
+    i0 += 4;
+    const psimd_f32 vi1 = psimd_load_f32(i1);
+    i1 += 4;
+    const psimd_f32 vi2 = psimd_load_f32(i2);
+    i2 += 4;
+    const psimd_f32 vi3 = psimd_load_f32(i3);
+    i3 += 4;
+    const psimd_f32 vi4 = psimd_load_f32(i4);
+    i4 += 4;
+    const psimd_f32 vi5 = psimd_load_f32(i5);
+    i5 += 4;
+    const psimd_f32 vi6 = psimd_load_f32(i6);
+    i6 += 4;
+
+    // Tree-shaped reduction of the 7 row values.
+    const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+    const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+    const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+
+    const psimd_f32 vsum016 = psimd_add_f32(vsum01, vi6);
+    const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+
+    const psimd_f32 vsum = psimd_add_f32(vsum016, vsum2345);
+
+    // Scale by the precomputed multiplier, then clamp to the output range.
+    psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
+    vout = psimd_max_f32(vout, voutput_min);
+    vout = psimd_min_f32(vout, voutput_max);
+
+    psimd_store_f32(output, vout);
+    output += 4;
+
+    n -= 4;
+  }
+  if (n != 0) {
+    // Remainder of 1-3 channels: compute a full vector, store piecewise.
+    const psimd_f32 vi0 = psimd_load_f32(i0);
+    const psimd_f32 vi1 = psimd_load_f32(i1);
+    const psimd_f32 vi2 = psimd_load_f32(i2);
+    const psimd_f32 vi3 = psimd_load_f32(i3);
+    const psimd_f32 vi4 = psimd_load_f32(i4);
+    const psimd_f32 vi5 = psimd_load_f32(i5);
+    const psimd_f32 vi6 = psimd_load_f32(i6);
+
+    const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+    const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+    const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+
+    const psimd_f32 vsum016 = psimd_add_f32(vsum01, vi6);
+    const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+
+    const psimd_f32 vsum = psimd_add_f32(vsum016, vsum2345);
+
+    psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
+    vout = psimd_max_f32(vout, voutput_min);
+    vout = psimd_min_f32(vout, voutput_max);
+
+    if (n & 2) {
+      psimd_store2_f32(output, vout);
+      output += 2;
+      vout = psimd_concat_hi_f32(vout, vout);  // shift high lanes down
+    }
+    if (n & 1) {
+      psimd_store1_f32(output, vout);
+    }
+  }
+}
diff --git a/src/f32-gavgpool/up7-scalar.c b/src/f32-gavgpool/up7-scalar.c
new file mode 100644
index 0000000..34788d6
--- /dev/null
+++ b/src/f32-gavgpool/up7-scalar.c
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/gavgpool.h>
+#include <xnnpack/math.h>
+
+
+// f32 global average pooling, single-pass ("unipass") microkernel for up to
+// 7 pooling rows, scalar variant. Sums the m rows (absent rows redirected to
+// `zero`), multiplies by the precomputed multiplier and clamps the result to
+// [output_min, output_max].
+void xnn_f32_gavgpool_ukernel_up7__scalar(
+    size_t m,
+    size_t n,
+    const float* input,
+    size_t input_stride,
+    const float* zero,
+    float* output,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(m != 0);
+  assert(m <= 7);  // m > 7 is served by the multi-pass mp7p7q kernel
+  assert(n != 0);
+
+  // Set up 7 row pointers; rows past the last valid one read from `zero`
+  // so they contribute nothing to the sum.
+  const float* i0 = input;
+  const float* i1 = (const float*) ((uintptr_t) i0 + input_stride);
+  if (m < 2) {
+    i1 = zero;
+  }
+  const float* i2 = (const float*) ((uintptr_t) i1 + input_stride);
+  if (m <= 2) {
+    i2 = zero;
+  }
+  const float* i3 = (const float*) ((uintptr_t) i2 + input_stride);
+  if (m < 4) {
+    i3 = zero;
+  }
+  const float* i4 = (const float*) ((uintptr_t) i3 + input_stride);
+  if (m <= 4) {
+    i4 = zero;
+  }
+  const float* i5 = (const float*) ((uintptr_t) i4 + input_stride);
+  if (m < 6) {
+    i5 = zero;
+  }
+  const float* i6 = (const float*) ((uintptr_t) i5 + input_stride);
+  if (m <= 6) {
+    i6 = zero;
+  }
+
+  const float vmultiplier = params->scalar.multiplier;
+  const float voutput_min = params->scalar.output_min;
+  const float voutput_max = params->scalar.output_max;
+  do {
+    const float vi0 = *i0++;
+    const float vi1 = *i1++;
+    const float vi2 = *i2++;
+    const float vi3 = *i3++;
+    const float vi4 = *i4++;
+    const float vi5 = *i5++;
+    const float vi6 = *i6++;
+
+    // Tree-shaped reduction of the 7 row values.
+    const float vsum01 = vi0 + vi1;
+    const float vsum23 = vi2 + vi3;
+    const float vsum45 = vi4 + vi5;
+
+    const float vsum016 = vsum01 + vi6;
+    const float vsum2345 = vsum23 + vsum45;
+
+    const float vsum = vsum016 + vsum2345;
+
+    // Scale by the precomputed multiplier, then clamp to the output range.
+    float vout = vsum * vmultiplier;
+    vout = math_max_f32(vout, voutput_min);
+    vout = math_min_f32(vout, voutput_max);
+
+    *output++ = vout;
+  } while (--n != 0);
+}
diff --git a/src/f32-gavgpool/up7-sse.c b/src/f32-gavgpool/up7-sse.c
new file mode 100644
index 0000000..ec23f2e
--- /dev/null
+++ b/src/f32-gavgpool/up7-sse.c
@@ -0,0 +1,122 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gavgpool.h>
+
+
+// f32 global average pooling, single-pass ("unipass") microkernel for up to
+// 7 pooling rows, SSE variant. Sums the m rows (absent rows redirected to
+// `zero`), multiplies by the precomputed multiplier and clamps to
+// [output_min, output_max]. Input and output are accessed unaligned;
+// the params vectors are loaded aligned.
+void xnn_f32_gavgpool_ukernel_up7__sse(
+    size_t m,
+    size_t n,
+    const float* input,
+    size_t input_stride,
+    const float* zero,
+    float* output,
+    const union xnn_f32_avgpool_params params[restrict static 1])
+{
+  assert(m != 0);
+  assert(m <= 7);  // m > 7 is served by the multi-pass mp7p7q kernel
+  assert(n != 0);
+
+  // Set up 7 row pointers; rows past the last valid one read from `zero`
+  // so they contribute nothing to the sum.
+  const float* i0 = input;
+  const float* i1 = (const float*) ((uintptr_t) i0 + input_stride);
+  if (m < 2) {
+    i1 = zero;
+  }
+  const float* i2 = (const float*) ((uintptr_t) i1 + input_stride);
+  if (m <= 2) {
+    i2 = zero;
+  }
+  const float* i3 = (const float*) ((uintptr_t) i2 + input_stride);
+  if (m < 4) {
+    i3 = zero;
+  }
+  const float* i4 = (const float*) ((uintptr_t) i3 + input_stride);
+  if (m <= 4) {
+    i4 = zero;
+  }
+  const float* i5 = (const float*) ((uintptr_t) i4 + input_stride);
+  if (m < 6) {
+    i5 = zero;
+  }
+  const float* i6 = (const float*) ((uintptr_t) i5 + input_stride);
+  if (m <= 6) {
+    i6 = zero;
+  }
+  const __m128 vmultiplier = _mm_load_ps(params->sse2.multiplier);
+  const __m128 voutput_min = _mm_load_ps(params->sse2.output_min);
+  const __m128 voutput_max = _mm_load_ps(params->sse2.output_max);
+
+  while (n >= 4) {
+    const __m128 vi0 = _mm_loadu_ps(i0);
+    i0 += 4;
+    const __m128 vi1 = _mm_loadu_ps(i1);
+    i1 += 4;
+    const __m128 vi2 = _mm_loadu_ps(i2);
+    i2 += 4;
+    const __m128 vi3 = _mm_loadu_ps(i3);
+    i3 += 4;
+    const __m128 vi4 = _mm_loadu_ps(i4);
+    i4 += 4;
+    const __m128 vi5 = _mm_loadu_ps(i5);
+    i5 += 4;
+    const __m128 vi6 = _mm_loadu_ps(i6);
+    i6 += 4;
+
+    // Tree-shaped reduction of the 7 row values.
+    const __m128 vsum01 = _mm_add_ps(vi0, vi1);
+    const __m128 vsum23 = _mm_add_ps(vi2, vi3);
+    const __m128 vsum45 = _mm_add_ps(vi4, vi5);
+
+    const __m128 vsum016 = _mm_add_ps(vsum01, vi6);
+    const __m128 vsum2345 = _mm_add_ps(vsum23, vsum45);
+
+    const __m128 vsum = _mm_add_ps(vsum016, vsum2345);
+
+    // Scale by the precomputed multiplier, then clamp to the output range.
+    __m128 vout = _mm_mul_ps(vsum, vmultiplier);
+    vout = _mm_max_ps(vout, voutput_min);
+    vout = _mm_min_ps(vout, voutput_max);
+
+    _mm_storeu_ps(output, vout);
+    output += 4;
+
+    n -= 4;
+  }
+  if (n != 0) {
+    // Remainder of 1-3 channels: compute a full vector, store piecewise.
+    const __m128 vi0 = _mm_loadu_ps(i0);
+    const __m128 vi1 = _mm_loadu_ps(i1);
+    const __m128 vi2 = _mm_loadu_ps(i2);
+    const __m128 vi3 = _mm_loadu_ps(i3);
+    const __m128 vi4 = _mm_loadu_ps(i4);
+    const __m128 vi5 = _mm_loadu_ps(i5);
+    const __m128 vi6 = _mm_loadu_ps(i6);
+
+    const __m128 vsum01 = _mm_add_ps(vi0, vi1);
+    const __m128 vsum23 = _mm_add_ps(vi2, vi3);
+    const __m128 vsum45 = _mm_add_ps(vi4, vi5);
+
+    const __m128 vsum016 = _mm_add_ps(vsum01, vi6);
+    const __m128 vsum2345 = _mm_add_ps(vsum23, vsum45);
+
+    const __m128 vsum = _mm_add_ps(vsum016, vsum2345);
+
+    __m128 vout = _mm_mul_ps(vsum, vmultiplier);
+    vout = _mm_max_ps(vout, voutput_min);
+    vout = _mm_min_ps(vout, voutput_max);
+
+    if (n & 2) {
+      _mm_storel_pi((__m64*) output, vout);   // store low 2 lanes
+      vout = _mm_movehl_ps(vout, vout);       // shift high lanes down
+      output += 2;
+    }
+    if (n & 1) {
+      _mm_store_ss(output, vout);
+    }
+  }
+}
diff --git a/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S
new file mode 100644
index 0000000..3bca67f
--- /dev/null
+++ b/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S
@@ -0,0 +1,350 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53(
+# size_t mr, (x0) - unused. mr = 1
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, (x4) - unused
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, (x7) - unused
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+
+# C pointers
+# x6 c0
+
+# Vector register usage and GPR shadows
+# a0 v0 first set of A
+# a0 v1 second set of A
+# B v2 v3 v4 x7 x10 x16 first set of B
+# B v5 v6 v7 x17 x18 x9
+# B v23 v24 v25 x7 x10 x16 second set of B (same x as first set)
+# B v17 v18 v19 x17 x18 x9
+# C v20 v21 v22
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53
+
+ # NOTE(review): this kernel clobbers x18, which is the reserved platform
+ # register on some AArch64 ABIs (e.g. Windows, Darwin) — confirm the set of
+ # target ABIs before relying on it as a scratch register.
+
+ # Load cn_stride, params pointer
+ LDP x14, x8, [sp]
+
+ # Load clamping_params values
+ LD2R {v30.4s, v31.4s}, [x8]
+
+0:
+ # Load initial bias from w into accumulators
+ LD1 {v20.16b, v21.16b, v22.16b}, [x5], 48
+
+ PRFM PLDL1KEEP, [x5]
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x5, 256]
+ PRFM PLDL1KEEP, [x5, 320]
+
+ # Is there at least 4 floats (16 bytes) for prologue + epilogue?
+ SUBS x0, x2, 16 // k = kc - 16
+ B.LO 3f
+
+ # Prologue - loads for first group of 6 fma
+
+ # Read first block of 4 A.
+ LDR d0, [x3], 8 // a0
+
+ LDR d2, [x5] // vb0x0123
+ LDR x7, [x5, 8]
+
+ LDR d3, [x5, 16] // vb0x4567
+ LDR x10, [x5, 24]
+
+ LDR d4, [x5, 32] // vb0x89AB
+ LDR x16, [x5, 40]
+
+ LDR d5, [x5, 48] // vb1x0123
+ LDR x17, [x5, 56]
+
+ LDR d6, [x5, 64] // vb1x4567
+ LDR x18, [x5, 72]
+
+ LDR d7, [x5, 80] // vb1x89AB
+ LDR x9, [x5, 88]
+ INS v2.d[1], x7
+ ADD x5, x5, 96
+
+ # Is there at least 4 floats (16 bytes) for main loop?
+ SUBS x0, x0, 16
+ B.LO 2f
+
+ # Main loop - 4 floats of A (16 bytes)
+1:
+ # First group of 6 fma.
+ # A is loaded for 2nd group into v1
+
+ # BLOCK 0
+ LDR d1, [x3], 8 // a0
+ INS v3.d[1], x10
+ FMLA v20.4s, v2.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 192]
+
+ # BLOCK 1
+ INS v4.d[1], x16
+ FMLA v21.4s, v3.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+
+ # BLOCK 2
+ LDR d23, [x5] // vb0x0123
+ INS v5.d[1], x17
+ LDR x7, [x5, 8]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v22.4s, v4.4s, v0.s[0]
+
+ # BLOCK 3
+ LDR d24, [x5, 16] // vb0x4567
+ INS v6.d[1], x18
+ LDR x10, [x5, 24]
+
+ # BLOCK 4
+ LDR d25, [x5, 32] // vb0x89AB
+ INS v7.d[1], x9
+ FMLA v20.4s, v5.4s, v0.s[1]
+ LDR x16, [x5, 40]
+
+ # BLOCK 5
+ LDR d17, [x5, 48] // vb1x0123
+ LDR x17, [x5, 56]
+ FMLA v21.4s, v6.4s, v0.s[1]
+
+ # BLOCK 6
+ LDR d18, [x5, 64] // vb1x4567
+ LDR x18, [x5, 72]
+ FMLA v22.4s, v7.4s, v0.s[1]
+
+ # BLOCK 7
+ LDR d19, [x5, 80] // vb1x89AB
+ INS v23.d[1], x7 // v23 was loaded in block 2
+ LDR x9, [x5, 88]
+
+ # Second group of 6 fma.
+ # A is loaded for 1st group into v0
+
+ # BLOCK 0
+ LDR d0, [x3], 8 // a0
+ INS v24.d[1], x10
+ FMLA v20.4s, v23.4s, v1.s[0]
+
+ # BLOCK 1
+ INS v25.d[1], x16
+ FMLA v21.4s, v24.4s, v1.s[0]
+
+ # BLOCK 2
+ LDR d2, [x5, 96] // vb0x0123
+ INS v17.d[1], x17
+ LDR x7, [x5, 104]
+ FMLA v22.4s, v25.4s, v1.s[0]
+
+ # BLOCK 3
+ LDR d3, [x5, 112] // vb0x4567
+ INS v18.d[1], x18
+ LDR x10, [x5, 120]
+
+ # BLOCK 4
+ LDR d4, [x5, 128] // vb0x89AB
+ INS v19.d[1], x9
+ FMLA v20.4s, v17.4s, v1.s[1]
+ LDR x16, [x5, 136]
+
+ # BLOCK 5
+ LDR d5, [x5, 144] // vb1x0123
+ LDR x17, [x5, 152]
+ FMLA v21.4s, v18.4s, v1.s[1]
+
+ # BLOCK 6
+ LDR d6, [x5, 160] // vb1x4567
+ LDR x18, [x5, 168]
+ SUBS x0, x0, 16
+ FMLA v22.4s, v19.4s, v1.s[1]
+
+ # BLOCK 7
+ LDR d7, [x5, 176] // vb1x89AB
+ INS v2.d[1], x7
+ LDR x9, [x5, 184]
+ ADD x5, x5, 192
+ B.HS 1b
+
+ # Epilogue
+ # First block same as main loop. Second block has no loads.
+2:
+ # BLOCK 0
+ LDR d1, [x3], 8 // a0
+ INS v3.d[1], x10
+ FMLA v20.4s, v2.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 192]
+
+ # BLOCK 1
+ INS v4.d[1], x16
+ FMLA v21.4s, v3.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+
+ # BLOCK 2
+ LDR d23, [x5] // vb0x0123
+ INS v5.d[1], x17
+ LDR x7, [x5, 8]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v22.4s, v4.4s, v0.s[0]
+
+ # BLOCK 3
+ LDR d24, [x5, 16] // vb0x4567
+ INS v6.d[1], x18
+ LDR x10, [x5, 24]
+
+ # BLOCK 4
+ LDR d25, [x5, 32] // vb0x89AB
+ INS v7.d[1], x9
+ FMLA v20.4s, v5.4s, v0.s[1]
+ LDR x16, [x5, 40]
+
+ # BLOCK 5
+ LDR d17, [x5, 48] // vb1x0123
+ LDR x17, [x5, 56]
+ FMLA v21.4s, v6.4s, v0.s[1]
+
+ # BLOCK 6
+ LDR d18, [x5, 64] // vb1x4567
+ LDR x18, [x5, 72]
+ FMLA v22.4s, v7.4s, v0.s[1]
+
+ # BLOCK 7
+ LDR d19, [x5, 80] // vb1x89AB
+ INS v23.d[1], x7 // v23 was loaded in block 2
+ LDR x9, [x5, 88]
+ ADD x5, x5, 96
+
+ # Second group of 6 fma. 8 blocks of 4 cycles.
+ # Epilogue version does no loads
+
+ # BLOCK 0
+ INS v24.d[1], x10
+ FMLA v20.4s, v23.4s, v1.s[0]
+
+ # BLOCK 1
+ INS v25.d[1], x16
+ FMLA v21.4s, v24.4s, v1.s[0]
+
+ # BLOCK 2
+ INS v17.d[1], x17
+ FMLA v22.4s, v25.4s, v1.s[0]
+
+ # BLOCK 3
+ INS v18.d[1], x18
+
+ # BLOCK 4
+ INS v19.d[1], x9
+ FMLA v20.4s, v17.4s, v1.s[1]
+
+ # BLOCK 5
+ FMLA v21.4s, v18.4s, v1.s[1]
+
+ # BLOCK 6
+ FMLA v22.4s, v19.4s, v1.s[1]
+
+ # BLOCK 7
+3:
+ # Is there a remainder? - 2 floats of A (8 bytes)
+ TBNZ x0, 3, 5f
+ # Is there a remainder? - 1 float of A (4 bytes)
+ TBNZ x0, 2, 6f
+
+4:
+ # Clamp
+ FMIN v20.4s, v20.4s, v30.4s
+ FMIN v21.4s, v21.4s, v30.4s
+ FMIN v22.4s, v22.4s, v30.4s
+ FMAX v20.4s, v20.4s, v31.4s
+ FMAX v21.4s, v21.4s, v31.4s
+ FMAX v22.4s, v22.4s, v31.4s
+
+ # Store full 1 x 12
+ CMP x1, 12
+ B.LO 7f
+
+ ST1 {v20.16b, v21.16b, v22.16b}, [x6], x14
+
+ SUB x3, x3, x2 // a0 -= kc
+
+ SUBS x1, x1, 12
+ B.HI 0b
+
+ RET
+
+5:
+ # Remainder - 2 floats of A (8 bytes)
+ # Read first block of 4 A.
+ LDR d0, [x3], 8 // a0
+ LD1 {v2.16b, v3.16b, v4.16b}, [x5], 48
+ LD1 {v5.16b, v6.16b, v7.16b}, [x5], 48
+
+ # First block of 3 B
+ FMLA v20.4s, v2.4s, v0.s[0]
+ FMLA v21.4s, v3.4s, v0.s[0]
+ FMLA v22.4s, v4.4s, v0.s[0]
+
+ # Second block of 3 B
+ FMLA v20.4s, v5.4s, v0.s[1]
+ FMLA v21.4s, v6.4s, v0.s[1]
+ FMLA v22.4s, v7.4s, v0.s[1]
+
+ TBZ x0, 2, 4b
+6:
+ # Remainder - 1 float of A (4 bytes)
+ LDR s0, [x3], 4 // a0
+ LD1 {v2.16b, v3.16b, v4.16b}, [x5], 48
+
+ FMLA v20.4s, v2.4s, v0.s[0]
+ FMLA v21.4s, v3.4s, v0.s[0]
+ FMLA v22.4s, v4.4s, v0.s[0]
+ B 4b
+
+7:
+ # Store odd channels
+ TBZ x1, 3, 8f
+ STP q20, q21, [x6]
+ ADD x6, x6, 32
+ MOV v20.16b, v22.16b
+
+8:
+ TBZ x1, 2, 9f
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+
+9:
+ TBZ x1, 1, 10f
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+
+10:
+ TBZ x1, 0, 11f
+ STR s20, [x6]
+11:
+ RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in b/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in
new file mode 100644
index 0000000..c79b02e
--- /dev/null
+++ b/src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in
@@ -0,0 +1,360 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x12__aarch64_neonfma_cortex_a53(
+# size_t mr, (x0) - unused. mr = 1
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, (x4) - unused
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, (x7) - unused
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+
+# C pointers
+# x6 c0
+
+# Vector register usage and GPR shadows
+# a0 v0 first set of A
+# a0 v1 second set of A
+# B v2 v3 v4 x7 x10 x16 first set of B
+# B v5 v6 v7 x17 x18 x9
+# B v23 v24 v25 x7 x10 x16 second set of B (same x as first set)
+# B v17 v18 v19 x17 x18 x9
+# C v20 v21 v22
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x12__aarch64_neonfma_cortex_a53
+
+ # NOTE(review): this kernel clobbers x18, which is the reserved platform
+ # register on some AArch64 ABIs (e.g. Windows, Darwin) — confirm the set of
+ # target ABIs before relying on it as a scratch register.
+
+ $if INC:
+ # Load cn_stride, acc
+ LDP x14, x15, [sp]
+ # Load params pointer
+ LDR x8, [sp, 16]
+ $else:
+ # Load cn_stride, params pointer
+ LDP x14, x8, [sp]
+
+ # Load clamping_params values
+ LD2R {v30.4s, v31.4s}, [x8]
+
+0:
+ $if INC:
+ # Load initial accumulators
+ LD1 {v20.16b, v21.16b, v22.16b}, [x15], 48
+ $else:
+ # Load initial bias from w into accumulators
+ LD1 {v20.16b, v21.16b, v22.16b}, [x5], 48
+
+ PRFM PLDL1KEEP, [x5]
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x5, 256]
+ PRFM PLDL1KEEP, [x5, 320]
+
+ # Is there at least 4 floats (16 bytes) for prologue + epilogue?
+ SUBS x0, x2, 16 // k = kc - 16
+ B.LO 3f
+
+ # Prologue - loads for first group of 6 fma
+
+ # Read first block of 4 A.
+ LDR d0, [x3], 8 // a0
+
+ LDR d2, [x5] // vb0x0123
+ LDR x7, [x5, 8]
+
+ LDR d3, [x5, 16] // vb0x4567
+ LDR x10, [x5, 24]
+
+ LDR d4, [x5, 32] // vb0x89AB
+ LDR x16, [x5, 40]
+
+ LDR d5, [x5, 48] // vb1x0123
+ LDR x17, [x5, 56]
+
+ LDR d6, [x5, 64] // vb1x4567
+ LDR x18, [x5, 72]
+
+ LDR d7, [x5, 80] // vb1x89AB
+ LDR x9, [x5, 88]
+ INS v2.d[1], x7
+ ADD x5, x5, 96
+
+ # Is there at least 4 floats (16 bytes) for main loop?
+ SUBS x0, x0, 16
+ B.LO 2f
+
+ # Main loop - 4 floats of A (16 bytes)
+1:
+ # First group of 6 fma.
+ # A is loaded for 2nd group into v1
+
+ # BLOCK 0
+ LDR d1, [x3], 8 // a0
+ INS v3.d[1], x10
+ FMLA v20.4s, v2.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 192]
+
+ # BLOCK 1
+ INS v4.d[1], x16
+ FMLA v21.4s, v3.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+
+ # BLOCK 2
+ LDR d23, [x5] // vb0x0123
+ INS v5.d[1], x17
+ LDR x7, [x5, 8]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v22.4s, v4.4s, v0.s[0]
+
+ # BLOCK 3
+ LDR d24, [x5, 16] // vb0x4567
+ INS v6.d[1], x18
+ LDR x10, [x5, 24]
+
+ # BLOCK 4
+ LDR d25, [x5, 32] // vb0x89AB
+ INS v7.d[1], x9
+ FMLA v20.4s, v5.4s, v0.s[1]
+ LDR x16, [x5, 40]
+
+ # BLOCK 5
+ LDR d17, [x5, 48] // vb1x0123
+ LDR x17, [x5, 56]
+ FMLA v21.4s, v6.4s, v0.s[1]
+
+ # BLOCK 6
+ LDR d18, [x5, 64] // vb1x4567
+ LDR x18, [x5, 72]
+ FMLA v22.4s, v7.4s, v0.s[1]
+
+ # BLOCK 7
+ LDR d19, [x5, 80] // vb1x89AB
+ INS v23.d[1], x7 // v23 was loaded in block 2
+ LDR x9, [x5, 88]
+
+ # Second group of 6 fma.
+ # A is loaded for 1st group into v0
+
+ # BLOCK 0
+ LDR d0, [x3], 8 // a0
+ INS v24.d[1], x10
+ FMLA v20.4s, v23.4s, v1.s[0]
+
+ # BLOCK 1
+ INS v25.d[1], x16
+ FMLA v21.4s, v24.4s, v1.s[0]
+
+ # BLOCK 2
+ LDR d2, [x5, 96] // vb0x0123
+ INS v17.d[1], x17
+ LDR x7, [x5, 104]
+ FMLA v22.4s, v25.4s, v1.s[0]
+
+ # BLOCK 3
+ LDR d3, [x5, 112] // vb0x4567
+ INS v18.d[1], x18
+ LDR x10, [x5, 120]
+
+ # BLOCK 4
+ LDR d4, [x5, 128] // vb0x89AB
+ INS v19.d[1], x9
+ FMLA v20.4s, v17.4s, v1.s[1]
+ LDR x16, [x5, 136]
+
+ # BLOCK 5
+ LDR d5, [x5, 144] // vb1x0123
+ LDR x17, [x5, 152]
+ FMLA v21.4s, v18.4s, v1.s[1]
+
+ # BLOCK 6
+ LDR d6, [x5, 160] // vb1x4567
+ LDR x18, [x5, 168]
+ SUBS x0, x0, 16
+ FMLA v22.4s, v19.4s, v1.s[1]
+
+ # BLOCK 7
+ LDR d7, [x5, 176] // vb1x89AB
+ INS v2.d[1], x7
+ LDR x9, [x5, 184]
+ ADD x5, x5, 192
+ B.HS 1b
+
+ # Epilogue
+ # First block same as main loop. Second block has no loads.
+2:
+ # BLOCK 0
+ LDR d1, [x3], 8 // a0
+ INS v3.d[1], x10
+ FMLA v20.4s, v2.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 192]
+
+ # BLOCK 1
+ INS v4.d[1], x16
+ FMLA v21.4s, v3.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+
+ # BLOCK 2
+ LDR d23, [x5] // vb0x0123
+ INS v5.d[1], x17
+ LDR x7, [x5, 8]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v22.4s, v4.4s, v0.s[0]
+
+ # BLOCK 3
+ LDR d24, [x5, 16] // vb0x4567
+ INS v6.d[1], x18
+ LDR x10, [x5, 24]
+
+ # BLOCK 4
+ LDR d25, [x5, 32] // vb0x89AB
+ INS v7.d[1], x9
+ FMLA v20.4s, v5.4s, v0.s[1]
+ LDR x16, [x5, 40]
+
+ # BLOCK 5
+ LDR d17, [x5, 48] // vb1x0123
+ LDR x17, [x5, 56]
+ FMLA v21.4s, v6.4s, v0.s[1]
+
+ # BLOCK 6
+ LDR d18, [x5, 64] // vb1x4567
+ LDR x18, [x5, 72]
+ FMLA v22.4s, v7.4s, v0.s[1]
+
+ # BLOCK 7
+ LDR d19, [x5, 80] // vb1x89AB
+ INS v23.d[1], x7 // v23 was loaded in block 2
+ LDR x9, [x5, 88]
+ ADD x5, x5, 96
+
+ # Second group of 6 fma. 8 blocks of 4 cycles.
+ # Epilogue version does no loads
+
+ # BLOCK 0
+ INS v24.d[1], x10
+ FMLA v20.4s, v23.4s, v1.s[0]
+
+ # BLOCK 1
+ INS v25.d[1], x16
+ FMLA v21.4s, v24.4s, v1.s[0]
+
+ # BLOCK 2
+ INS v17.d[1], x17
+ FMLA v22.4s, v25.4s, v1.s[0]
+
+ # BLOCK 3
+ INS v18.d[1], x18
+
+ # BLOCK 4
+ INS v19.d[1], x9
+ FMLA v20.4s, v17.4s, v1.s[1]
+
+ # BLOCK 5
+ FMLA v21.4s, v18.4s, v1.s[1]
+
+ # BLOCK 6
+ FMLA v22.4s, v19.4s, v1.s[1]
+
+ # BLOCK 7
+3:
+ # Is there a remainder? - 2 floats of A (8 bytes)
+ TBNZ x0, 3, 5f
+ # Is there a remainder? - 1 float of A (4 bytes)
+ TBNZ x0, 2, 6f
+
+4:
+ # Clamp
+ FMIN v20.4s, v20.4s, v30.4s
+ FMIN v21.4s, v21.4s, v30.4s
+ FMIN v22.4s, v22.4s, v30.4s
+ FMAX v20.4s, v20.4s, v31.4s
+ FMAX v21.4s, v21.4s, v31.4s
+ FMAX v22.4s, v22.4s, v31.4s
+
+ # Store full 1 x 12
+ CMP x1, 12
+ B.LO 7f
+
+ ST1 {v20.16b, v21.16b, v22.16b}, [x6], x14
+
+ SUB x3, x3, x2 // a0 -= kc
+
+ SUBS x1, x1, 12
+ B.HI 0b
+
+ RET
+
+5:
+ # Remainder - 2 floats of A (8 bytes)
+ # Read first block of 4 A.
+ LDR d0, [x3], 8 // a0
+ LD1 {v2.16b, v3.16b, v4.16b}, [x5], 48
+ LD1 {v5.16b, v6.16b, v7.16b}, [x5], 48
+
+ # First block of 3 B
+ FMLA v20.4s, v2.4s, v0.s[0]
+ FMLA v21.4s, v3.4s, v0.s[0]
+ FMLA v22.4s, v4.4s, v0.s[0]
+
+ # Second block of 3 B
+ FMLA v20.4s, v5.4s, v0.s[1]
+ FMLA v21.4s, v6.4s, v0.s[1]
+ FMLA v22.4s, v7.4s, v0.s[1]
+
+ TBZ x0, 2, 4b
+6:
+ # Remainder - 1 float of A (4 bytes)
+ LDR s0, [x3], 4 // a0
+ LD1 {v2.16b, v3.16b, v4.16b}, [x5], 48
+
+ FMLA v20.4s, v2.4s, v0.s[0]
+ FMLA v21.4s, v3.4s, v0.s[0]
+ FMLA v22.4s, v4.4s, v0.s[0]
+ B 4b
+
+7:
+ # Store odd channels
+ TBZ x1, 3, 8f
+ STP q20, q21, [x6]
+ ADD x6, x6, 32
+ MOV v20.16b, v22.16b
+
+8:
+ TBZ x1, 2, 9f
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+
+9:
+ TBZ x1, 1, 10f
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+
+10:
+ TBZ x1, 0, 11f
+ STR s20, [x6]
+11:
+ RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x12__aarch64_neonfma_cortex_a53
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/1x4-scalar.c b/src/f32-gemm/1x4-scalar.c
new file mode 100644
index 0000000..db5a3d2
--- /dev/null
+++ b/src/f32-gemm/1x4-scalar.c
@@ -0,0 +1,101 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+// Scalar GEMM micro-kernel computing a 1x4 tile of the output matrix C.
+//
+// Accumulators start from the 4 bias values at the head of the packed
+// weights w; each step of the reduction loop consumes one element of A
+// and one group of 4 packed B values.  Results are clamped to
+// [params->scalar.min, params->scalar.max] before being stored.
+//
+//   mr        - rows of A/C handled per call; must be 1 for this kernel.
+//   nc        - remaining output channels (columns of C).
+//   kc        - reduction length in BYTES; must be a non-zero multiple of
+//               sizeof(float).
+//   a         - pointer to the single row of A.
+//   a_stride  - unused (mr == 1).
+//   w         - packed weights: 4 biases, then kc/sizeof(float) groups of 4.
+//   c         - output pointer.
+//   cm_stride - unused (mr == 1).
+//   cn_stride - byte stride between consecutive groups of 4 output channels.
+void xnn_f32_gemm_ukernel_1x4__scalar(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+
+  do {
+    // Initialize the 1x4 tile of accumulators from the packed biases.
+    float vacc00 = w[0];
+    float vacc01 = w[1];
+    float vacc02 = w[2];
+    float vacc03 = w[3];
+    w += 4;
+
+    // Reduction over kc bytes of A, 1 float per iteration.
+    size_t k = kc;
+    do {
+      const float va0 = *a0++;
+
+      const float vb0 = w[0];
+      const float vb1 = w[1];
+      const float vb2 = w[2];
+      const float vb3 = w[3];
+      w += 4;
+
+      vacc00 += va0 * vb0;
+      vacc01 += va0 * vb1;
+      vacc02 += va0 * vb2;
+      vacc03 += va0 * vb3;
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp the accumulators to the requested output range.
+    const float vmin = params->scalar.min;
+    vacc00 = math_max_f32(vacc00, vmin);
+    vacc01 = math_max_f32(vacc01, vmin);
+    vacc02 = math_max_f32(vacc02, vmin);
+    vacc03 = math_max_f32(vacc03, vmin);
+
+    const float vmax = params->scalar.max;
+    vacc00 = math_min_f32(vacc00, vmax);
+    vacc01 = math_min_f32(vacc01, vmax);
+    vacc02 = math_min_f32(vacc02, vmax);
+    vacc03 = math_min_f32(vacc03, vmax);
+
+    if XNN_LIKELY(nc >= 4) {
+      c0[0] = vacc00;
+      c0[1] = vacc01;
+      c0[2] = vacc02;
+      c0[3] = vacc03;
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind A to its start for the next group of output channels.
+      // Cast to the destination pointer type (the original cast went
+      // through const void* for no reason).
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 4;
+    } else {
+      // Tail: store 2 and/or 1 leftover channels, shifting the surviving
+      // accumulators down as channels are written.
+      if (nc & 2) {
+        c0[0] = vacc00;
+        c0[1] = vacc01;
+        vacc00 = vacc02;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        c0[0] = vacc00;
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/1x8-aarch64-neonfma-cortex-a57.S b/src/f32-gemm/1x8-aarch64-neonfma-cortex-a57.S
new file mode 100644
index 0000000..8073c03
--- /dev/null
+++ b/src/f32-gemm/1x8-aarch64-neonfma-cortex-a57.S
@@ -0,0 +1,219 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/1x8-aarch64-neonfma-cortex-a57.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a57(
+# size_t mr, (x0) - unused. mr = 1
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, (x4) - unused
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, (x7) - unused
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+
+# C pointers
+# x6 c0
+
+# Clamp v4 v5
+
+# A57 based on A75 but with PRFM removed from main loop
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a57
+
+ # Load cn_stride, params pointer
+ LDP x14, x8, [sp]
+
+ # Load clamping_params values
+ LD2R {v4.4s, v5.4s}, [x8]
+0:
+ # Load initial bias from w into accumulators
+ LDP q16, q17, [x5], 32
+
+ MOVI v18.4s, 0 // second set of C for pipelining FMLA
+ MOVI v19.4s, 0
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+
+ B.LO 3f
+
+ # 16 prologue
+ # Read first block of 4 A and B.
+ LDP q20, q21, [x5], 32
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+ LDR q0, [x3], 16
+
+ # Are there at least 32 more bytes? If so, run the main loop.
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+1:
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDR q1, [x3], 16
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ LDP q24, q25, [x5], 32
+ FMLA v18.4s, v26.4s, v0.s[3]
+ FMLA v19.4s, v27.4s, v0.s[3]
+ LDP q26, q27, [x5], 32
+
+ # Second block of 4. FMA for second 4, loads for 1st block of 4.
+ FMLA v16.4s, v20.4s, v1.s[0]
+ LDR q0, [x3], 16
+ FMLA v17.4s, v21.4s, v1.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v1.s[2]
+ FMLA v17.4s, v25.4s, v1.s[2]
+ LDP q24, q25, [x5], 32
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ SUBS x0, x0, 32
+ LDP q26, q27, [x5], 32
+ B.HS 1b
+
+2:
+ # Epilogue
+
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDR q1, [x3], 16
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ LDP q24, q25, [x5], 32
+ FMLA v18.4s, v26.4s, v0.s[3]
+ FMLA v19.4s, v27.4s, v0.s[3]
+ LDP q26, q27, [x5], 32
+
+ # Second block of 4. no loads
+ FMLA v16.4s, v20.4s, v1.s[0]
+ FMLA v17.4s, v21.4s, v1.s[0]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v16.4s, v24.4s, v1.s[2]
+ FMLA v17.4s, v25.4s, v1.s[2]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+
+3:
+ # Is there a remainder? - 4 floats of A (16 bytes)
+ TBNZ x0, 4, 5f
+ # Is there a remainder? - 2 floats of A (8 bytes)
+ TBNZ x0, 3, 6f
+ # Is there a remainder? - 1 float of A (4 bytes)
+ TBNZ x0, 2, 8f
+
+4:
+ # Reduce the two pipelined accumulator sets into one.
+ FADD v16.4s, v16.4s, v18.4s
+ FADD v17.4s, v17.4s, v19.4s
+
+ # Clamp
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+
+ # Store full 1 x 8
+ CMP x1, 8
+ B.LO 9f
+
+ STP q16, q17, [x6]
+ ADD x6, x6, x14
+
+ SUB x3, x3, x2 // a0 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ RET
+
+5:
+ # Remainder- 4 floats of A (16 bytes)
+ LDP q20, q21, [x5], 32
+ LDR q0, [x3], 16
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v26.4s, v0.s[3]
+ FMLA v19.4s, v27.4s, v0.s[3]
+
+ TBZ x0, 3, 7f
+6:
+ # Remainder- 2 floats of A (8 bytes)
+ LDP q20, q21, [x5], 32
+ LDR d0, [x3], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+7:
+ TBZ x0, 2, 4b
+8:
+ # Remainder- 1 float of A (4 bytes)
+ LDP q20, q21, [x5], 32
+ LDR s0, [x3], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ B 4b
+
+ # Store odd channels
+9:
+ TBZ x1, 2, 10f
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+10:
+ TBZ x1, 1, 11f
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+11:
+ TBZ x1, 0, 12f
+ STR s16, [x6]
+12:
+ RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a57
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/1x8-aarch64-neonfma-cortex-a57.S.in b/src/f32-gemm/1x8-aarch64-neonfma-cortex-a57.S.in
new file mode 100644
index 0000000..ca50a32
--- /dev/null
+++ b/src/f32-gemm/1x8-aarch64-neonfma-cortex-a57.S.in
@@ -0,0 +1,229 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x8__aarch64_neonfma_cortex_a57(
+# size_t mr, (x0) - unused. mr = 1
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, (x4) - unused
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, (x7) - unused
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+
+# C pointers
+# x6 c0
+
+# Clamp v4 v5
+
+# A57 based on A75 but with PRFM removed from main loop
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x8__aarch64_neonfma_cortex_a57
+
+ $if INC:
+ # Load cn_stride, acc
+ LDP x14, x15, [sp]
+ # Load params pointer
+ LDR x8, [sp, 16]
+ $else:
+ # Load cn_stride, params pointer
+ LDP x14, x8, [sp]
+
+ # Load clamping_params values
+ LD2R {v4.4s, v5.4s}, [x8]
+0:
+ $if INC:
+ # Load initial accumulators
+ LDP q16, q17, [x15], 32
+ $else:
+ # Load initial bias from w into accumulators
+ LDP q16, q17, [x5], 32
+
+ MOVI v18.4s, 0 // second set of C for pipelining FMLA
+ MOVI v19.4s, 0
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+
+ B.LO 3f
+
+ # 16 prologue
+ # Read first block of 4 A and B.
+ LDP q20, q21, [x5], 32
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+ LDR q0, [x3], 16
+
+ # Are there at least 32 more bytes? If so, run the main loop.
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+1:
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDR q1, [x3], 16
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ LDP q24, q25, [x5], 32
+ FMLA v18.4s, v26.4s, v0.s[3]
+ FMLA v19.4s, v27.4s, v0.s[3]
+ LDP q26, q27, [x5], 32
+
+ # Second block of 4. FMA for second 4, loads for 1st block of 4.
+ FMLA v16.4s, v20.4s, v1.s[0]
+ LDR q0, [x3], 16
+ FMLA v17.4s, v21.4s, v1.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v1.s[2]
+ FMLA v17.4s, v25.4s, v1.s[2]
+ LDP q24, q25, [x5], 32
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ SUBS x0, x0, 32
+ LDP q26, q27, [x5], 32
+ B.HS 1b
+
+2:
+ # Epilogue
+
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDR q1, [x3], 16
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ LDP q24, q25, [x5], 32
+ FMLA v18.4s, v26.4s, v0.s[3]
+ FMLA v19.4s, v27.4s, v0.s[3]
+ LDP q26, q27, [x5], 32
+
+ # Second block of 4. no loads
+ FMLA v16.4s, v20.4s, v1.s[0]
+ FMLA v17.4s, v21.4s, v1.s[0]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v16.4s, v24.4s, v1.s[2]
+ FMLA v17.4s, v25.4s, v1.s[2]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+
+3:
+ # Is there a remainder? - 4 floats of A (16 bytes)
+ TBNZ x0, 4, 5f
+ # Is there a remainder? - 2 floats of A (8 bytes)
+ TBNZ x0, 3, 6f
+ # Is there a remainder? - 1 float of A (4 bytes)
+ TBNZ x0, 2, 8f
+
+4:
+ # Reduce the two pipelined accumulator sets into one.
+ FADD v16.4s, v16.4s, v18.4s
+ FADD v17.4s, v17.4s, v19.4s
+
+ # Clamp
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+
+ # Store full 1 x 8
+ CMP x1, 8
+ B.LO 9f
+
+ STP q16, q17, [x6]
+ ADD x6, x6, x14
+
+ SUB x3, x3, x2 // a0 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ RET
+
+5:
+ # Remainder- 4 floats of A (16 bytes)
+ LDP q20, q21, [x5], 32
+ LDR q0, [x3], 16
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v26.4s, v0.s[3]
+ FMLA v19.4s, v27.4s, v0.s[3]
+
+ TBZ x0, 3, 7f
+6:
+ # Remainder- 2 floats of A (8 bytes)
+ LDP q20, q21, [x5], 32
+ LDR d0, [x3], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+7:
+ TBZ x0, 2, 4b
+8:
+ # Remainder- 1 float of A (4 bytes)
+ LDP q20, q21, [x5], 32
+ LDR s0, [x3], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ B 4b
+
+ # Store odd channels
+9:
+ TBZ x1, 2, 10f
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+10:
+ TBZ x1, 1, 11f
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+11:
+ TBZ x1, 0, 12f
+ STR s16, [x6]
+12:
+ RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x8__aarch64_neonfma_cortex_a57
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/1x8-aarch64-neonfma-cortex-a75.S b/src/f32-gemm/1x8-aarch64-neonfma-cortex-a75.S
new file mode 100644
index 0000000..8e0f557
--- /dev/null
+++ b/src/f32-gemm/1x8-aarch64-neonfma-cortex-a75.S
@@ -0,0 +1,223 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/1x8-aarch64-neonfma-cortex-a75.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75(
+# size_t mr, (x0) - unused. mr = 1
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, (x4) - unused
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, (x7) - unused
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+
+# C pointers
+# x6 c0
+
+# Clamp v4 v5
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75
+
+ # Load cn_stride, params pointer
+ LDP x14, x8, [sp]
+
+ # Load clamping_params values
+ LD2R {v4.4s, v5.4s}, [x8]
+0:
+ # Load initial bias from w into accumulators
+ LDP q16, q17, [x5], 32
+
+ MOVI v18.4s, 0 // second set of C for pipelining FMLA
+ MOVI v19.4s, 0
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+
+ B.LO 3f
+
+ # 16 prologue
+ # Read first block of 4 A and B.
+ LDP q20, q21, [x5], 32
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+ LDR q0, [x3], 16
+
+ # Are there at least 32 more bytes? If so, run the main loop.
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+1:
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDR q1, [x3], 16
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ LDP q24, q25, [x5], 32
+ PRFM PLDL1KEEP, [x5, 128]
+ FMLA v18.4s, v26.4s, v0.s[3]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v19.4s, v27.4s, v0.s[3]
+ LDP q26, q27, [x5], 32
+
+ # Second block of 4. FMA for second 4, loads for 1st block of 4.
+ FMLA v16.4s, v20.4s, v1.s[0]
+ LDR q0, [x3], 16
+ FMLA v17.4s, v21.4s, v1.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v1.s[2]
+ FMLA v17.4s, v25.4s, v1.s[2]
+ LDP q24, q25, [x5], 32
+ PRFM PLDL1KEEP, [x5, 128]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ SUBS x0, x0, 32
+ LDP q26, q27, [x5], 32
+ B.HS 1b
+
+2:
+ # Epilogue
+
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDR q1, [x3], 16
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ LDP q24, q25, [x5], 32
+ PRFM PLDL1KEEP, [x5, 128]
+ FMLA v18.4s, v26.4s, v0.s[3]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v19.4s, v27.4s, v0.s[3]
+ LDP q26, q27, [x5], 32
+
+ # Second block of 4. no loads
+ FMLA v16.4s, v20.4s, v1.s[0]
+ FMLA v17.4s, v21.4s, v1.s[0]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v16.4s, v24.4s, v1.s[2]
+ FMLA v17.4s, v25.4s, v1.s[2]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+
+3:
+ # Is there a remainder? - 4 floats of A (16 bytes)
+ TBNZ x0, 4, 5f
+ # Is there a remainder? - 2 floats of A (8 bytes)
+ TBNZ x0, 3, 6f
+ # Is there a remainder? - 1 float of A (4 bytes)
+ TBNZ x0, 2, 8f
+
+4:
+ # Reduce the two pipelined accumulator sets into one.
+ FADD v16.4s, v16.4s, v18.4s
+ FADD v17.4s, v17.4s, v19.4s
+
+ # Clamp
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+
+ # Store full 1 x 8
+ CMP x1, 8
+ B.LO 9f
+
+ STP q16, q17, [x6]
+ ADD x6, x6, x14
+
+ SUB x3, x3, x2 // a0 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ RET
+
+5:
+ # Remainder- 4 floats of A (16 bytes)
+ LDP q20, q21, [x5], 32
+ LDR q0, [x3], 16
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v26.4s, v0.s[3]
+ FMLA v19.4s, v27.4s, v0.s[3]
+
+ TBZ x0, 3, 7f
+6:
+ # Remainder- 2 floats of A (8 bytes)
+ LDP q20, q21, [x5], 32
+ LDR d0, [x3], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+7:
+ TBZ x0, 2, 4b
+8:
+ # Remainder- 1 float of A (4 bytes)
+ LDP q20, q21, [x5], 32
+ LDR s0, [x3], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ B 4b
+
+ # Store odd channels
+9:
+ TBZ x1, 2, 10f
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+10:
+ TBZ x1, 1, 11f
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+11:
+ TBZ x1, 0, 12f
+ STR s16, [x6]
+12:
+ RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/1x8-aarch64-neonfma-cortex-a75.S.in b/src/f32-gemm/1x8-aarch64-neonfma-cortex-a75.S.in
new file mode 100644
index 0000000..9c3751e
--- /dev/null
+++ b/src/f32-gemm/1x8-aarch64-neonfma-cortex-a75.S.in
@@ -0,0 +1,233 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x8__aarch64_neonfma_cortex_a75(
+# size_t mr, (x0) - unused. mr = 1
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, (x4) - unused
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, (x7) - unused
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+
+# C pointers
+# x6 c0
+
+# Clamp v4 v5
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x8__aarch64_neonfma_cortex_a75
+
+        $if INC:
+          # Load cn_stride, acc
+          LDP x14, x15, [sp]
+          # Load params pointer
+          LDR x8, [sp, 16]
+        $else:
+          # Load cn_stride, params pointer
+          LDP x14, x8, [sp]
+
+        # Load clamping_params values: LD2R splats params->max into v4 and params->min into v5
+        LD2R {v4.4s, v5.4s}, [x8]
+0:
+        $if INC:
+          # Load initial accumulators
+          LDP q16, q17, [x15], 32
+        $else:
+          # Load initial bias from w into accumulators
+          LDP q16, q17, [x5], 32
+
+        MOVI v18.4s, 0  // second set of C for pipelining FMLA
+        MOVI v19.4s, 0
+
+        # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+        SUBS x0, x2, 32  // k = kc - 32
+
+        B.LO 3f
+
+        # Prologue: first 4 floats (16 bytes) of A and 8 quads of B
+        # Read first block of 4 A and B.
+        LDP q20, q21, [x5], 32
+        LDP q22, q23, [x5], 32
+        LDP q24, q25, [x5], 32
+        LDP q26, q27, [x5], 32
+        LDR q0, [x3], 16
+
+        # Are there at least 32 more bytes (8 floats) of A?  If so, run the main loop.
+        SUBS x0, x0, 32
+        B.LO 2f
+
+        # Main loop - 8 floats of A (32 bytes)
+1:
+        # First block of 4. FMA for first 4, loads for 2nd block of 4.
+        FMLA v16.4s, v20.4s, v0.s[0]
+        LDR q1, [x3], 16
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q20, q21, [x5], 32
+        FMLA v18.4s, v22.4s, v0.s[1]
+        FMLA v19.4s, v23.4s, v0.s[1]
+        LDP q22, q23, [x5], 32
+        FMLA v16.4s, v24.4s, v0.s[2]
+        FMLA v17.4s, v25.4s, v0.s[2]
+        LDP q24, q25, [x5], 32
+        PRFM PLDL1KEEP, [x5, 128]
+        FMLA v18.4s, v26.4s, v0.s[3]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v19.4s, v27.4s, v0.s[3]
+        LDP q26, q27, [x5], 32
+
+        # Second block of 4. FMA for second 4, loads for 1st block of 4.
+        FMLA v16.4s, v20.4s, v1.s[0]
+        LDR q0, [x3], 16
+        FMLA v17.4s, v21.4s, v1.s[0]
+        LDP q20, q21, [x5], 32
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        LDP q22, q23, [x5], 32
+        FMLA v16.4s, v24.4s, v1.s[2]
+        FMLA v17.4s, v25.4s, v1.s[2]
+        LDP q24, q25, [x5], 32
+        PRFM PLDL1KEEP, [x5, 128]
+        FMLA v18.4s, v26.4s, v1.s[3]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v19.4s, v27.4s, v1.s[3]
+        SUBS x0, x0, 32
+        LDP q26, q27, [x5], 32
+        B.HS 1b
+
+2:
+        # Epilogue - same schedule as the main loop, but the second block does no loads
+
+        # First block of 4. FMA for first 4, loads for 2nd block of 4.
+        FMLA v16.4s, v20.4s, v0.s[0]
+        LDR q1, [x3], 16
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q20, q21, [x5], 32
+        FMLA v18.4s, v22.4s, v0.s[1]
+        FMLA v19.4s, v23.4s, v0.s[1]
+        LDP q22, q23, [x5], 32
+        FMLA v16.4s, v24.4s, v0.s[2]
+        FMLA v17.4s, v25.4s, v0.s[2]
+        LDP q24, q25, [x5], 32
+        PRFM PLDL1KEEP, [x5, 128]
+        FMLA v18.4s, v26.4s, v0.s[3]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v19.4s, v27.4s, v0.s[3]
+        LDP q26, q27, [x5], 32
+
+        # Second block of 4. no loads
+        FMLA v16.4s, v20.4s, v1.s[0]
+        FMLA v17.4s, v21.4s, v1.s[0]
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        FMLA v16.4s, v24.4s, v1.s[2]
+        FMLA v17.4s, v25.4s, v1.s[2]
+        FMLA v18.4s, v26.4s, v1.s[3]
+        FMLA v19.4s, v27.4s, v1.s[3]
+
+3:
+        # Is there a remainder? - 4 floats of A (16 bytes)
+        TBNZ x0, 4, 5f
+        # Is there a remainder? - 2 floats of A (8 bytes)
+        TBNZ x0, 3, 6f
+        # Is there a remainder? - 1 float of A (4 bytes)
+        TBNZ x0, 2, 8f
+
+4:
+        FADD v16.4s, v16.4s, v18.4s  // fold pipelined accumulator set into main set
+        FADD v17.4s, v17.4s, v19.4s
+
+        # Clamp
+        FMIN v16.4s, v16.4s, v4.4s
+        FMIN v17.4s, v17.4s, v4.4s
+        FMAX v16.4s, v16.4s, v5.4s
+        FMAX v17.4s, v17.4s, v5.4s
+
+        # Store full 1 x 8
+        CMP x1, 8  // nc >= 8 columns remaining?
+        B.LO 9f
+
+        STP q16, q17, [x6]
+        ADD x6, x6, x14  // c0 += cn_stride
+
+        SUB x3, x3, x2  // a0 -= kc
+
+        SUBS x1, x1, 8  // nc -= 8
+        B.HI 0b
+
+        RET
+
+5:
+        # Remainder - 4 floats of A (16 bytes)
+        LDP q20, q21, [x5], 32
+        LDR q0, [x3], 16
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q22, q23, [x5], 32
+        LDP q24, q25, [x5], 32
+        LDP q26, q27, [x5], 32
+        FMLA v18.4s, v22.4s, v0.s[1]
+        FMLA v19.4s, v23.4s, v0.s[1]
+        FMLA v16.4s, v24.4s, v0.s[2]
+        FMLA v17.4s, v25.4s, v0.s[2]
+        FMLA v18.4s, v26.4s, v0.s[3]
+        FMLA v19.4s, v27.4s, v0.s[3]
+
+        TBZ x0, 3, 7f
+6:
+        # Remainder - 2 floats of A (8 bytes)
+        LDP q20, q21, [x5], 32
+        LDR d0, [x3], 8
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q22, q23, [x5], 32
+        FMLA v18.4s, v22.4s, v0.s[1]
+        FMLA v19.4s, v23.4s, v0.s[1]
+7:
+        TBZ x0, 2, 4b
+8:
+        # Remainder - 1 float of A (4 bytes)
+        LDP q20, q21, [x5], 32
+        LDR s0, [x3], 4
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        B 4b
+
+        # Store odd channels
+9:
+        TBZ x1, 2, 10f
+        STR q16, [x6], 16
+        MOV v16.16b, v17.16b  // shift remaining columns into v16
+
+10:
+        TBZ x1, 1, 11f
+        STR d16, [x6], 8
+        DUP d16, v16.d[1]  // move upper half down for the final lane store
+
+11:
+        TBZ x1, 0, 12f
+        STR s16, [x6]
+12:
+        RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_1x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/1x8-neon-ld64.c b/src/f32-gemm/1x8-neon-ld64.c
new file mode 100644
index 0000000..9c6f8b6
--- /dev/null
+++ b/src/f32-gemm/1x8-neon-ld64.c
@@ -0,0 +1,105 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_1x8__neon_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  // 1x8 f32 GEMM microkernel (NEON, 64-bit A loads): one output row, 8 columns per pass.
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is counted in bytes and must be float-aligned
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  const float* a0 = a;  // mr == 1: a_stride and cm_stride are never read
+  float* c0 = c;
+
+  do {
+    // Accumulators are seeded from the first 8 packed values of w (the bias).
+    float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {  // 2 floats of A per iteration
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+    }
+    if XNN_UNLIKELY(k != 0) {  // kc is a multiple of sizeof(float), so exactly 1 float remains
+      const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+      vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+    }
+    // Clamp to [min, max].  Fix: restored '&params' (source had mojibake '¶ms').
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);  // rewind A to the row start for the next 8 columns
+
+      nc -= 8;
+
+    } else {
+      // Partial store of the last nc (< 8) columns.
+      if (nc & 4) {
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/1x8-neonfma-ld64.c b/src/f32-gemm/1x8-neonfma-ld64.c
new file mode 100644
index 0000000..4df05f9
--- /dev/null
+++ b/src/f32-gemm/1x8-neonfma-ld64.c
@@ -0,0 +1,117 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_1x8__neonfma_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  // 1x8 f32 GEMM microkernel (NEON with fused multiply-add, 64-bit A loads).
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is counted in bytes and must be float-aligned
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  const float* a0 = a;  // mr == 1: a_stride and cm_stride are never read
+  float* c0 = c;
+
+  do {
+    // Accumulators are seeded from the first 8 packed values of w (the bias).
+    float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {  // 2 floats of A per iteration
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+      // AArch64 has lane-indexed FMA; AArch32 NEONFMA needs an explicit dup.
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+      #else
+        const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+      #endif
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+      #else
+        const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+      #endif
+    }
+    if XNN_UNLIKELY(k != 0) {  // kc is a multiple of sizeof(float), so exactly 1 float remains
+      const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+      vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+    }
+    // Clamp to [min, max].  Fix: restored '&params' (source had mojibake '¶ms').
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);  // rewind A to the row start for the next 8 columns
+
+      nc -= 8;
+
+    } else {
+      // Partial store of the last nc (< 8) columns.
+      if (nc & 4) {
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/1x8-psimd-loadsplat.c b/src/f32-gemm/1x8-psimd-loadsplat.c
new file mode 100644
index 0000000..6df6dd3
--- /dev/null
+++ b/src/f32-gemm/1x8-psimd-loadsplat.c
@@ -0,0 +1,99 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-loadsplat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_1x8__psimd_loadsplat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  // 1x8 f32 GEMM microkernel (psimd, load-splat): one float of A per inner step.
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is counted in bytes and must be float-aligned
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  const float* a0 = a;  // mr == 1: a_stride and cm_stride are never read
+  float* c0 = c;
+
+  do {
+    // Accumulators are seeded from the first 8 packed values of w (the bias).
+    psimd_f32 vacc0x0123 = psimd_load_f32(w + 0);
+    psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+    w += 8;
+
+    size_t k = kc;
+    do {  // one float of A broadcast against 8 packed B values per step
+      const psimd_f32 va0 = psimd_load_splat_f32(a0);
+      a0 += 1;
+
+      const psimd_f32 vb0123 = psimd_load_f32(w);
+      const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+      w += 8;
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp to [min, max].  Fix: restored '&params' (source had mojibake '¶ms').
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);  // rewind A to the row start for the next 8 columns
+
+      nc -= 8;
+    } else {
+      // Partial store of the last nc (< 8) columns.
+      if (nc & 4) {
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/1x8-psimd-splat.c b/src/f32-gemm/1x8-psimd-splat.c
new file mode 100644
index 0000000..485bac7
--- /dev/null
+++ b/src/f32-gemm/1x8-psimd-splat.c
@@ -0,0 +1,137 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-splat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_1x8__psimd_splat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  // 1x8 f32 GEMM microkernel (psimd, splat): loads 4 floats of A at once and
+  // splats each lane against 8 packed B values.
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is counted in bytes and must be float-aligned
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  const float* a0 = a;  // mr == 1: a_stride and cm_stride are never read
+  float* c0 = c;
+
+  do {
+    // Accumulators are seeded from the first 8 packed values of w (the bias).
+    psimd_f32 vacc0x0123 = psimd_load_f32(w + 0);
+    psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+    w += 8;
+
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {  // main loop: 4 floats of A per iteration
+      const psimd_f32 va0 = psimd_load_f32(a0);
+      a0 += 4;
+
+      const psimd_f32 va0c0 = psimd_splat0_f32(va0);
+
+      const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+      const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c0, vb0123c0);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c0, vb4567c0);
+      const psimd_f32 va0c1 = psimd_splat1_f32(va0);
+
+      const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+      const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c1, vb0123c1);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c1, vb4567c1);
+      const psimd_f32 va0c2 = psimd_splat2_f32(va0);
+
+      const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+      const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c2, vb0123c2);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c2, vb4567c2);
+      const psimd_f32 va0c3 = psimd_splat3_f32(va0);
+
+      const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+      const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c3, vb0123c3);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c3, vb4567c3);
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    if XNN_UNLIKELY(k != 0) {  // remainder: 1-3 floats of A, one splat per float
+      do {
+        const psimd_f32 va0 = psimd_load_splat_f32(a0);
+        a0 += 1;
+
+        const psimd_f32 vb0123 = psimd_load_f32(w);
+        const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+        w += 8;
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp to [min, max].  Fix: restored '&params' (source had mojibake '¶ms').
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);  // rewind A to the row start for the next 8 columns
+
+      nc -= 8;
+    } else {
+      // Partial store of the last nc (< 8) columns.
+      if (nc & 4) {
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/1x8-sse-dup.c b/src/f32-gemm/1x8-sse-dup.c
new file mode 100644
index 0000000..4ac9205
--- /dev/null
+++ b/src/f32-gemm/1x8-sse-dup.c
@@ -0,0 +1,141 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/sse-dup.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_1x8__sse_dup(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  // 1x8 f32 GEMM microkernel (SSE, shuffle-dup): loads 4 floats of A at once and
+  // broadcasts each lane via _mm_shuffle_ps.
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is counted in bytes and must be float-aligned
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  const float* a0 = a;  // mr == 1: a_stride and cm_stride are never read
+  float* c0 = c;
+
+  do {
+    // Accumulators are seeded from the first 8 packed values of w (the bias).
+    // Aligned loads: presumably the weight-packing code guarantees 16-byte
+    // alignment of w — not visible here, TODO confirm.
+    __m128 vacc0x0123 = _mm_load_ps(w + 0);
+    __m128 vacc0x4567 = _mm_load_ps(w + 4);
+    w += 8;
+
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {  // main loop: 4 floats of A per iteration
+      const __m128 va0 = _mm_loadu_ps(a0);
+      a0 += 4;
+
+
+      const __m128 va0c0000 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 0, 0, 0));
+
+      const __m128 vb0123c0 = _mm_load_ps(w + 0);
+      const __m128 vb4567c0 = _mm_load_ps(w + 4);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c0000, vb0123c0));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c0000, vb4567c0));
+
+      const __m128 va0c1111 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(1, 1, 1, 1));
+
+      const __m128 vb0123c1 = _mm_load_ps(w + 8);
+      const __m128 vb4567c1 = _mm_load_ps(w + 12);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c1111, vb0123c1));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c1111, vb4567c1));
+
+      const __m128 va0c2222 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(2, 2, 2, 2));
+
+      const __m128 vb0123c2 = _mm_load_ps(w + 16);
+      const __m128 vb4567c2 = _mm_load_ps(w + 20);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c2222, vb0123c2));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c2222, vb4567c2));
+
+      const __m128 va0c3333 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(3, 3, 3, 3));
+
+      const __m128 vb0123c3 = _mm_load_ps(w + 24);
+      const __m128 vb4567c3 = _mm_load_ps(w + 28);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c3333, vb0123c3));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c3333, vb4567c3));
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    if XNN_UNLIKELY(k != 0) {  // remainder: 1-3 floats of A, one broadcast per float
+      do {
+        const __m128 va0 = _mm_load1_ps(a0);
+        a0 += 1;
+
+        const __m128 vb0123 = _mm_load_ps(w);
+        const __m128 vb4567 = _mm_load_ps(w + 4);
+        w += 8;
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp to [min, max]; params->sse.max/min are pre-broadcast vectors.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);  // rewind A to the row start for the next 8 columns
+
+      nc -= 8;
+    } else {
+      // Partial store of the last nc (< 8) columns.
+      if (nc & 4) {
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/1x8-sse-load1.c b/src/f32-gemm/1x8-sse-load1.c
new file mode 100644
index 0000000..7842261
--- /dev/null
+++ b/src/f32-gemm/1x8-sse-load1.c
@@ -0,0 +1,99 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/sse-load1.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_1x8__sse_load1(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  // 1x8 f32 GEMM microkernel (SSE, load1): one float of A broadcast per inner step.
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is counted in bytes and must be float-aligned
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  const float* a0 = a;  // mr == 1: a_stride and cm_stride are never read
+  float* c0 = c;
+
+  do {
+    // Accumulators are seeded from the first 8 packed values of w (the bias).
+    // Aligned loads: presumably the weight-packing code guarantees 16-byte
+    // alignment of w — not visible here, TODO confirm.
+    __m128 vacc0x0123 = _mm_load_ps(w + 0);
+    __m128 vacc0x4567 = _mm_load_ps(w + 4);
+    w += 8;
+
+    size_t k = kc;
+    do {  // one float of A broadcast against 8 packed B values per step
+      const __m128 va0 = _mm_load1_ps(a0);
+      a0 += 1;
+
+      const __m128 vb0123 = _mm_load_ps(w);
+      const __m128 vb4567 = _mm_load_ps(w + 4);
+      w += 8;
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp to [min, max]; params->sse.max/min are pre-broadcast vectors.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);  // rewind A to the row start for the next 8 columns
+
+      nc -= 8;
+    } else {
+      // Partial store of the last nc (< 8) columns.
+      if (nc & 4) {
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/1x8s4-psimd.c b/src/f32-gemm/1x8s4-psimd.c
new file mode 100644
index 0000000..acd2964
--- /dev/null
+++ b/src/f32-gemm/1x8s4-psimd.c
@@ -0,0 +1,140 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-s4.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_1x8s4__psimd(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  // 1x8 f32 GEMM microkernel (psimd, s4 shuffle variant): B is packed for a
+  // rotating-A scheme, so A is rotated one lane per sub-step instead of splatted.
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is counted in bytes and must be float-aligned
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  const float* a0 = a;  // mr == 1: a_stride and cm_stride are never read
+  float* c0 = c;
+
+  do {
+    // Accumulators are seeded from the first 8 packed values of w (the bias).
+    psimd_f32 vacc0x0123 = psimd_load_f32(w + 0);
+    psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+    w += 8;
+
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {  // main loop: 4 floats of A per iteration
+      psimd_f32 va0 = psimd_load_f32(a0);
+      a0 += 4;
+
+
+      const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+      const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c0);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c0);
+
+      // Rotate A left by one lane for the next pre-shuffled block of B.
+      va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+
+      const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+      const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c1);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c1);
+
+      va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+
+      const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+      const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c2);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c2);
+
+      va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+
+      const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+      const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c3);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c3);
+
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    if XNN_UNLIKELY(k != 0) {  // remainder: 1-3 floats of A, one splat per float
+      do {
+        const psimd_f32 va0 = psimd_load_splat_f32(a0);
+        a0 += 1;
+
+        const psimd_f32 vb0123 = psimd_load_f32(w);
+        const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+        w += 8;
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp to [min, max].  Fix: restored '&params' (source had mojibake '¶ms').
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);  // rewind A to the row start for the next 8 columns
+
+      nc -= 8;
+    } else {
+      // Partial store of the last nc (< 8) columns.
+      if (nc & 4) {
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/1x8s4-sse.c b/src/f32-gemm/1x8s4-sse.c
new file mode 100644
index 0000000..dff0537
--- /dev/null
+++ b/src/f32-gemm/1x8s4-sse.c
@@ -0,0 +1,140 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/sse-shuffle.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_1x8s4__sse(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  // 1x8 f32 GEMM microkernel (SSE, s4 shuffle variant): B is packed for a
+  // rotating-A scheme, so A is rotated one lane per sub-step instead of splatted.
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is counted in bytes and must be float-aligned
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  const float* a0 = a;  // mr == 1: a_stride and cm_stride are never read
+  float* c0 = c;
+
+  do {
+    // Accumulators are seeded from the first 8 packed values of w (the bias).
+    // Aligned loads: presumably the weight-packing code guarantees 16-byte
+    // alignment of w — not visible here, TODO confirm.
+    __m128 vacc0x0123 = _mm_load_ps(w + 0);
+    __m128 vacc0x4567 = _mm_load_ps(w + 4);
+    w += 8;
+
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {  // main loop: 4 floats of A per iteration
+      __m128 va0 = _mm_loadu_ps(a0);
+      a0 += 4;
+
+
+      const __m128 vb0123c0 = _mm_load_ps(w + 0);
+      const __m128 vb4567c0 = _mm_load_ps(w + 4);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c0));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c0));
+
+      // Rotate A left by one lane for the next pre-shuffled block of B.
+      va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+
+      const __m128 vb0123c1 = _mm_load_ps(w + 8);
+      const __m128 vb4567c1 = _mm_load_ps(w + 12);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c1));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c1));
+
+      va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+
+      const __m128 vb0123c2 = _mm_load_ps(w + 16);
+      const __m128 vb4567c2 = _mm_load_ps(w + 20);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c2));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c2));
+
+      va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+
+      const __m128 vb0123c3 = _mm_load_ps(w + 24);
+      const __m128 vb4567c3 = _mm_load_ps(w + 28);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c3));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c3));
+
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    if XNN_UNLIKELY(k != 0) {  // remainder: 1-3 floats of A, one broadcast per float
+      do {
+        const __m128 va0 = _mm_load1_ps(a0);
+        a0 += 1;
+
+        const __m128 vb0123 = _mm_load_ps(w);
+        const __m128 vb4567 = _mm_load_ps(w + 4);
+        w += 8;
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp to [min, max]; params->sse.max/min are pre-broadcast vectors.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);  // rewind A to the row start for the next 8 columns
+
+      nc -= 8;
+    } else {
+      // Partial store of the last nc (< 8) columns.
+      if (nc & 4) {
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/2x4-scalar.c b/src/f32-gemm/2x4-scalar.c
new file mode 100644
index 0000000..dacfd5f
--- /dev/null
+++ b/src/f32-gemm/2x4-scalar.c
@@ -0,0 +1,135 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_gemm_ukernel_2x4__scalar(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float* restrict a,
+ size_t a_stride,
+ const float* restrict w,
+ float* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 2);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+
+ do {
+ float vacc00 = w[0];
+ float vacc01 = w[1];
+ float vacc02 = w[2];
+ float vacc03 = w[3];
+ w += 4;
+ float vacc10 = vacc00;
+ float vacc11 = vacc01;
+ float vacc12 = vacc02;
+ float vacc13 = vacc03;
+
+ size_t k = kc;
+ do {
+ const float va0 = *a0++;
+ const float va1 = *a1++;
+
+ const float vb0 = w[0];
+ const float vb1 = w[1];
+ const float vb2 = w[2];
+ const float vb3 = w[3];
+ w += 4;
+
+ vacc00 += va0 * vb0;
+ vacc01 += va0 * vb1;
+ vacc02 += va0 * vb2;
+ vacc03 += va0 * vb3;
+ vacc10 += va1 * vb0;
+ vacc11 += va1 * vb1;
+ vacc12 += va1 * vb2;
+ vacc13 += va1 * vb3;
+
+ k -= sizeof(float);
+ } while (k != 0);
+
+ const float vmin = params->scalar.min;
+ vacc00 = math_max_f32(vacc00, vmin);
+ vacc01 = math_max_f32(vacc01, vmin);
+ vacc02 = math_max_f32(vacc02, vmin);
+ vacc03 = math_max_f32(vacc03, vmin);
+ vacc10 = math_max_f32(vacc10, vmin);
+ vacc11 = math_max_f32(vacc11, vmin);
+ vacc12 = math_max_f32(vacc12, vmin);
+ vacc13 = math_max_f32(vacc13, vmin);
+
+ const float vmax = params->scalar.max;
+ vacc00 = math_min_f32(vacc00, vmax);
+ vacc01 = math_min_f32(vacc01, vmax);
+ vacc02 = math_min_f32(vacc02, vmax);
+ vacc03 = math_min_f32(vacc03, vmax);
+ vacc10 = math_min_f32(vacc10, vmax);
+ vacc11 = math_min_f32(vacc11, vmax);
+ vacc12 = math_min_f32(vacc12, vmax);
+ vacc13 = math_min_f32(vacc13, vmax);
+
+ if XNN_LIKELY(nc >= 4) {
+ c1[0] = vacc10;
+ c1[1] = vacc11;
+ c1[2] = vacc12;
+ c1[3] = vacc13;
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ c0[0] = vacc00;
+ c0[1] = vacc01;
+ c0[2] = vacc02;
+ c0[3] = vacc03;
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a1 = (const void*) ((uintptr_t) a1 - kc);
+ a0 = (const void*) ((uintptr_t) a0 - kc);
+
+ nc -= 4;
+ } else {
+ if (nc & 2) {
+ c1[0] = vacc10;
+ c1[1] = vacc11;
+ vacc10 = vacc12;
+ c1 += 2;
+ c0[0] = vacc00;
+ c0[1] = vacc01;
+ vacc00 = vacc02;
+ c0 += 2;
+ }
+ if (nc & 1) {
+ c1[0] = vacc10;
+ c0[0] = vacc00;
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S b/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S
new file mode 100644
index 0000000..6e36926
--- /dev/null
+++ b/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S
@@ -0,0 +1,594 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+# Vector register usage and GPR shadows
+# a0 v0 first set of A
+# a1 v0[1] x13
+# a2 v1
+# a3 v1[1] x8
+# a0 v2 second set of A
+# a1 v2[1] x13
+# a2 v3
+# a3 v3[1] x8
+# B v6 v7 v8 x20 x21 x16 first set of B
+# B v9 v10 v11 x17 x18 x19
+# B v14 v15 v16 x20 x21 x16 second set of B (same x as first set)
+# B v17 v18 v19 x17 x18 x19
+# C v20 v21 v22
+# C v23 v24 v25
+# C v26 v27 v28
+# C v29 v30 v31
+# Clamp v4 v5
+# v12 to v13 unused.
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53
+
+ # Load cn_stride, params pointer
+ LDP x14, x8, [sp]
+
+ # Load clamping_params values
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ # Save x19-21 on stack
+ STR x21, [sp, -80]!
+ STP x19, x20, [sp, 16]
+
+ # Save d8-d11,d14,d15 on stack
+ STP d8, d9, [sp, 32]
+ STP d10, d11, [sp, 48]
+ STP d14, d15, [sp, 64]
+
+ # Clamp A and C pointers
+ ADD x11, x3, x4 // a1 = a0 + a_stride
+ ADD x9, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x11, x3, x11, LO // a1 = a0
+ CSEL x9, x6, x9, LO // c1 = c0
+ ADD x12, x11, x4 // a2 = a1 + a_stride
+ ADD x10, x9, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x12, x11, x12, LS // a2 = a1
+ CSEL x10, x9, x10, LS // c2 = c1
+ ADD x4, x12, x4 // a3 = a2 + a_stride
+ ADD x7, x10, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x4, x12, x4, LO // a3 = a2
+ CSEL x7, x10, x7, LO // c3 = c2
+
+0:
+ # Load initial bias from w into accumulators
+ LD1 {v20.16b, v21.16b, v22.16b}, [x5], 48
+ MOV v23.16b, v20.16b
+ MOV v24.16b, v21.16b
+ MOV v25.16b, v22.16b
+ MOV v26.16b, v20.16b
+ MOV v27.16b, v21.16b
+ MOV v28.16b, v22.16b
+ MOV v29.16b, v20.16b
+ MOV v30.16b, v21.16b
+ MOV v31.16b, v22.16b
+
+ PRFM PLDL1KEEP, [x5]
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x5, 256]
+ PRFM PLDL1KEEP, [x5, 320]
+
+ # Is there at least 4 floats (16 bytes)?
+ SUBS x0, x2, 16 // k = kc - 16
+ B.LO 3f
+
+ SUBS x0, x0, 16
+
+ # Prologue - loads for first group of 24 FMA
+
+ # Read first block of 4 A.
+ LDR d0, [x3], 8 // a0
+ LDR x13, [x11], 8 // a1
+ LDR d1, [x12], 8 // a2
+ LDR x8, [x4], 8 // a3
+
+ LDR d6, [x5] // vb0x0123
+ LDR x20, [x5, 8]
+
+ LDR d7, [x5, 16] // vb0x4567
+ LDR x21, [x5, 24]
+
+ LDR d8, [x5, 32] // vb0x89AB
+ LDR x16, [x5, 40]
+
+ LDR d9, [x5, 48] // vb1x0123
+ INS v0.d[1], x13
+ LDR x17, [x5, 56]
+
+ LDR d10, [x5, 64] // vb1x4567
+ INS v1.d[1], x8
+ LDR x18, [x5, 72]
+
+ LDR d11, [x5, 80] // vb1x89AB
+ LDR x19, [x5, 88]
+ INS v6.d[1], x20
+ ADD x5, x5, 96
+
+ # Is there at least 4 floats (16 bytes) for main loop?
+ B.LO 2f
+
+ # Main loop - 4 floats of A (16 bytes)
+1:
+ # First group of 24 fma. 8 blocks of 4 cycles. LDR + 3 FMA
+ # A is loaded for 2nd group into v2/v3
+ # INS is 4 blocks (16 cycles) after load
+
+ # BLOCK 0
+ LDR d2, [x3], 8 // a0
+ INS v7.d[1], x21
+ FMLA v20.4s, v6.4s, v0.s[0]
+ LDR x13, [x11], 8 // a1
+ FMLA v23.4s, v6.4s, v0.s[2]
+ PRFM PLDL1KEEP, [x5, 192]
+ FMLA v26.4s, v6.4s, v1.s[0]
+
+ # BLOCK 1
+ LDR d3, [x12], 8 // a2
+ INS v8.d[1], x16
+ FMLA v29.4s, v6.4s, v1.s[2]
+ LDR x8, [x4], 8 // a3
+ FMLA v21.4s, v7.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v24.4s, v7.4s, v0.s[2]
+
+ # BLOCK 2
+ LDR d14, [x5] // vb0x0123
+ INS v9.d[1], x17
+ FMLA v27.4s, v7.4s, v1.s[0]
+ LDR x20, [x5, 8]
+ FMLA v30.4s, v7.4s, v1.s[2]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v22.4s, v8.4s, v0.s[0]
+
+ # BLOCK 3
+ LDR d15, [x5, 16] // vb0x4567
+ INS v10.d[1], x18
+ FMLA v25.4s, v8.4s, v0.s[2]
+ LDR x21, [x5, 24]
+ FMLA v28.4s, v8.4s, v1.s[0]
+ FMLA v31.4s, v8.4s, v1.s[2]
+
+ # BLOCK 4
+ LDR d16, [x5, 32] // vb0x89AB
+ INS v11.d[1], x19
+ FMLA v20.4s, v9.4s, v0.s[1]
+ LDR x16, [x5, 40]
+ FMLA v23.4s, v9.4s, v0.s[3]
+ FMLA v26.4s, v9.4s, v1.s[1]
+
+ # BLOCK 5
+ LDR d17, [x5, 48] // vb1x0123
+ INS v2.d[1], x13 // a1 was loaded in block 0
+ FMLA v29.4s, v9.4s, v1.s[3]
+ LDR x17, [x5, 56]
+ FMLA v21.4s, v10.4s, v0.s[1]
+ FMLA v24.4s, v10.4s, v0.s[3]
+
+ # BLOCK 6
+ LDR d18, [x5, 64] // vb1x4567
+ INS v3.d[1], x8 // a3 was loaded in block 1
+ FMLA v27.4s, v10.4s, v1.s[1]
+ LDR x18, [x5, 72]
+ FMLA v30.4s, v10.4s, v1.s[3]
+ FMLA v22.4s, v11.4s, v0.s[1]
+
+ # BLOCK 7
+ LDR d19, [x5, 80] // vb1x89AB
+ INS v14.d[1], x20 // v14 was loaded in block 2
+ FMLA v25.4s, v11.4s, v0.s[3]
+ LDR x19, [x5, 88]
+ FMLA v28.4s, v11.4s, v1.s[1]
+ FMLA v31.4s, v11.4s, v1.s[3]
+
+ # Second group of 24 fma. 8 blocks of 4 cycles. LDR + 3 FMA
+ # A is loaded for 1st group into v0/v1
+
+ # BLOCK 0
+ LDR d0, [x3], 8 // a0
+ INS v15.d[1], x21
+ FMLA v20.4s, v14.4s, v2.s[0]
+ LDR x13, [x11], 8 // a1
+ FMLA v23.4s, v14.4s, v2.s[2]
+ FMLA v26.4s, v14.4s, v3.s[0]
+
+ # BLOCK 1
+ LDR d1, [x12], 8 // a2
+ INS v16.d[1], x16
+ FMLA v29.4s, v14.4s, v3.s[2]
+ LDR x8, [x4], 8 // a3
+ FMLA v21.4s, v15.4s, v2.s[0]
+ FMLA v24.4s, v15.4s, v2.s[2]
+
+ # BLOCK 2
+ LDR d6, [x5, 96] // vb0x0123
+ INS v17.d[1], x17
+ FMLA v27.4s, v15.4s, v3.s[0]
+ LDR x20, [x5, 104]
+ FMLA v30.4s, v15.4s, v3.s[2]
+ FMLA v22.4s, v16.4s, v2.s[0]
+
+ # BLOCK 3
+ LDR d7, [x5, 112] // vb0x4567
+ INS v18.d[1], x18
+ FMLA v25.4s, v16.4s, v2.s[2]
+ LDR x21, [x5, 120]
+ FMLA v28.4s, v16.4s, v3.s[0]
+ FMLA v31.4s, v16.4s, v3.s[2]
+
+ # BLOCK 4
+ LDR d8, [x5, 128] // vb0x89AB
+ INS v19.d[1], x19
+ FMLA v20.4s, v17.4s, v2.s[1]
+ LDR x16, [x5, 136]
+ FMLA v23.4s, v17.4s, v2.s[3]
+ FMLA v26.4s, v17.4s, v3.s[1]
+
+ # BLOCK 5
+ LDR d9, [x5, 144] // vb1x0123
+ INS v0.d[1], x13 // a1
+ FMLA v29.4s, v17.4s, v3.s[3]
+ LDR x17, [x5, 152]
+ FMLA v21.4s, v18.4s, v2.s[1]
+ FMLA v24.4s, v18.4s, v2.s[3]
+
+ # BLOCK 6
+ LDR d10, [x5, 160] // vb1x4567
+ INS v1.d[1], x8 // a3
+ FMLA v27.4s, v18.4s, v3.s[1]
+ LDR x18, [x5, 168]
+ FMLA v30.4s, v18.4s, v3.s[3]
+ SUBS x0, x0, 16
+ FMLA v22.4s, v19.4s, v2.s[1]
+
+ # BLOCK 7
+ LDR d11, [x5, 176] // vb1x89AB
+ INS v6.d[1], x20
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDR x19, [x5, 184]
+ FMLA v28.4s, v19.4s, v3.s[1]
+ ADD x5, x5, 192
+ FMLA v31.4s, v19.4s, v3.s[3]
+ B.HS 1b
+
+ # Epilogue
+ # First block same as main loop. Second block has no loads.
+2:
+ # BLOCK 0
+ LDR d2, [x3], 8 // a0
+ INS v7.d[1], x21
+ FMLA v20.4s, v6.4s, v0.s[0]
+ LDR x13, [x11], 8 // a1
+ FMLA v23.4s, v6.4s, v0.s[2]
+ PRFM PLDL1KEEP, [x5, 192]
+ FMLA v26.4s, v6.4s, v1.s[0]
+
+ # BLOCK 1
+ LDR d3, [x12], 8 // a2
+ INS v8.d[1], x16
+ FMLA v29.4s, v6.4s, v1.s[2]
+ LDR x8, [x4], 8 // a3
+ FMLA v21.4s, v7.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v24.4s, v7.4s, v0.s[2]
+
+ # BLOCK 2
+ LDR d14, [x5] // vb0x0123
+ INS v9.d[1], x17
+ FMLA v27.4s, v7.4s, v1.s[0]
+ LDR x20, [x5, 8]
+ FMLA v30.4s, v7.4s, v1.s[2]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v22.4s, v8.4s, v0.s[0]
+
+ # BLOCK 3
+ LDR d15, [x5, 16] // vb0x4567
+ INS v10.d[1], x18
+ FMLA v25.4s, v8.4s, v0.s[2]
+ LDR x21, [x5, 24]
+ FMLA v28.4s, v8.4s, v1.s[0]
+ FMLA v31.4s, v8.4s, v1.s[2]
+
+ # BLOCK 4
+ LDR d16, [x5, 32] // vb0x89AB
+ INS v11.d[1], x19
+ FMLA v20.4s, v9.4s, v0.s[1]
+ LDR x16, [x5, 40]
+ FMLA v23.4s, v9.4s, v0.s[3]
+ FMLA v26.4s, v9.4s, v1.s[1]
+
+ # BLOCK 5
+ LDR d17, [x5, 48] // vb1x0123
+ INS v2.d[1], x13 // a1 was loaded in block 0
+ FMLA v29.4s, v9.4s, v1.s[3]
+ LDR x17, [x5, 56]
+ FMLA v21.4s, v10.4s, v0.s[1]
+ FMLA v24.4s, v10.4s, v0.s[3]
+
+ # BLOCK 6
+ LDR d18, [x5, 64] // vb1x4567
+ INS v3.d[1], x8 // a3 was loaded in block 1
+ FMLA v27.4s, v10.4s, v1.s[1]
+ LDR x18, [x5, 72]
+ FMLA v30.4s, v10.4s, v1.s[3]
+ FMLA v22.4s, v11.4s, v0.s[1]
+
+ # BLOCK 7
+ LDR d19, [x5, 80] // vb1x89AB
+ INS v14.d[1], x20 // v14 was loaded in block 2
+ FMLA v25.4s, v11.4s, v0.s[3]
+ LDR x19, [x5, 88]
+ FMLA v28.4s, v11.4s, v1.s[1]
+ ADD x5, x5, 96
+ FMLA v31.4s, v11.4s, v1.s[3]
+
+ # Second group of 24 fma. 8 blocks of 4 cycles.
+ # Epilogue version does no loads
+
+ # BLOCK 0
+ INS v15.d[1], x21
+ FMLA v20.4s, v14.4s, v2.s[0]
+ FMLA v23.4s, v14.4s, v2.s[2]
+ FMLA v26.4s, v14.4s, v3.s[0]
+
+ # BLOCK 1
+ INS v16.d[1], x16
+ FMLA v29.4s, v14.4s, v3.s[2]
+ FMLA v21.4s, v15.4s, v2.s[0]
+ FMLA v24.4s, v15.4s, v2.s[2]
+
+ # BLOCK 2
+ INS v17.d[1], x17
+ FMLA v27.4s, v15.4s, v3.s[0]
+ FMLA v30.4s, v15.4s, v3.s[2]
+ FMLA v22.4s, v16.4s, v2.s[0]
+
+ # BLOCK 3
+ INS v18.d[1], x18
+ FMLA v25.4s, v16.4s, v2.s[2]
+ FMLA v28.4s, v16.4s, v3.s[0]
+ FMLA v31.4s, v16.4s, v3.s[2]
+
+ # BLOCK 4
+ INS v19.d[1], x19
+ FMLA v20.4s, v17.4s, v2.s[1]
+ FMLA v23.4s, v17.4s, v2.s[3]
+ FMLA v26.4s, v17.4s, v3.s[1]
+
+ # BLOCK 5
+ FMLA v29.4s, v17.4s, v3.s[3]
+ FMLA v21.4s, v18.4s, v2.s[1]
+ FMLA v24.4s, v18.4s, v2.s[3]
+
+ # BLOCK 6
+ FMLA v27.4s, v18.4s, v3.s[1]
+ FMLA v30.4s, v18.4s, v3.s[3]
+ FMLA v22.4s, v19.4s, v2.s[1]
+
+ # BLOCK 7
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v28.4s, v19.4s, v3.s[1]
+ FMLA v31.4s, v19.4s, v3.s[3]
+
+3:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBNZ x0, 3, 5f
+  # Is there a remainder? - 1 float of A (4 bytes)
+ TBNZ x0, 2, 6f
+
+4:
+ # Clamp
+ FMIN v20.4s, v20.4s, v4.4s
+ FMIN v21.4s, v21.4s, v4.4s
+ FMIN v22.4s, v22.4s, v4.4s
+ FMIN v23.4s, v23.4s, v4.4s
+ FMIN v24.4s, v24.4s, v4.4s
+ FMIN v25.4s, v25.4s, v4.4s
+ FMIN v26.4s, v26.4s, v4.4s
+ FMIN v27.4s, v27.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v20.4s, v20.4s, v5.4s
+ FMAX v21.4s, v21.4s, v5.4s
+ FMAX v22.4s, v22.4s, v5.4s
+ FMAX v23.4s, v23.4s, v5.4s
+ FMAX v24.4s, v24.4s, v5.4s
+ FMAX v25.4s, v25.4s, v5.4s
+ FMAX v26.4s, v26.4s, v5.4s
+ FMAX v27.4s, v27.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+ # Store full 4 x 12
+ CMP x1, 12
+ B.LO 7f
+
+ ST1 {v29.16b, v30.16b, v31.16b}, [x7], x14
+ ST1 {v26.16b, v27.16b, v28.16b}, [x10], x14
+ ST1 {v23.16b, v24.16b, v25.16b}, [x9], x14
+ ST1 {v20.16b, v21.16b, v22.16b}, [x6], x14
+
+ SUB x3, x3, x2 // a0 -= kc
+ SUB x11, x11, x2 // a1 -= kc
+ SUB x12, x12, x2 // a2 -= kc
+ SUB x4, x4, x2 // a3 -= kc
+
+ SUBS x1, x1, 12
+ B.HI 0b
+
+ # Restore d8-d11,d14,d15 from stack
+ LDP d14, d15, [sp, 64]
+ LDP d10, d11, [sp, 48]
+ LDP d8, d9, [sp, 32]
+
+ # Restore x19-21 from stack
+ LDP x19, x20, [sp, 16]
+ LDR x21, [sp], 80
+ RET
+
+5:
+ # Remainder - 2 floats of A (8 bytes)
+ # Read first block of 4 A.
+ LDR d0, [x3], 8 // a0
+ LDR d1, [x11], 8 // a1
+ LDR d2, [x12], 8 // a2
+ LDR d3, [x4], 8 // a3
+ LD1 {v6.16b, v7.16b, v8.16b}, [x5], 48
+ LD1 {v9.16b, v10.16b, v11.16b}, [x5], 48
+
+ # First block of 3 B
+ FMLA v20.4s, v6.4s, v0.s[0]
+ FMLA v23.4s, v6.4s, v1.s[0]
+ FMLA v26.4s, v6.4s, v2.s[0]
+ FMLA v29.4s, v6.4s, v3.s[0]
+ FMLA v21.4s, v7.4s, v0.s[0]
+ FMLA v24.4s, v7.4s, v1.s[0]
+ FMLA v27.4s, v7.4s, v2.s[0]
+ FMLA v30.4s, v7.4s, v3.s[0]
+ FMLA v22.4s, v8.4s, v0.s[0]
+ FMLA v25.4s, v8.4s, v1.s[0]
+ FMLA v28.4s, v8.4s, v2.s[0]
+ FMLA v31.4s, v8.4s, v3.s[0]
+
+ # Second block of 3 B
+ FMLA v20.4s, v9.4s, v0.s[1]
+ FMLA v23.4s, v9.4s, v1.s[1]
+ FMLA v26.4s, v9.4s, v2.s[1]
+ FMLA v29.4s, v9.4s, v3.s[1]
+ FMLA v21.4s, v10.4s, v0.s[1]
+ FMLA v24.4s, v10.4s, v1.s[1]
+ FMLA v27.4s, v10.4s, v2.s[1]
+ FMLA v30.4s, v10.4s, v3.s[1]
+ FMLA v22.4s, v11.4s, v0.s[1]
+ FMLA v25.4s, v11.4s, v1.s[1]
+ FMLA v28.4s, v11.4s, v2.s[1]
+ FMLA v31.4s, v11.4s, v3.s[1]
+
+ TBZ x0, 2, 4b
+6:
+ # Remainder - 1 float of A (4 bytes)
+ LDR s0, [x3], 4 // a0
+ LDR s1, [x11], 4 // a1
+ LDR s2, [x12], 4 // a2
+ LDR s3, [x4], 4 // a3
+ LD1 {v6.16b, v7.16b, v8.16b}, [x5], 48
+
+ FMLA v20.4s, v6.4s, v0.s[0]
+ FMLA v23.4s, v6.4s, v1.s[0]
+ FMLA v26.4s, v6.4s, v2.s[0]
+ FMLA v29.4s, v6.4s, v3.s[0]
+ FMLA v21.4s, v7.4s, v0.s[0]
+ FMLA v24.4s, v7.4s, v1.s[0]
+ FMLA v27.4s, v7.4s, v2.s[0]
+ FMLA v30.4s, v7.4s, v3.s[0]
+ FMLA v22.4s, v8.4s, v0.s[0]
+ FMLA v25.4s, v8.4s, v1.s[0]
+ FMLA v28.4s, v8.4s, v2.s[0]
+ FMLA v31.4s, v8.4s, v3.s[0]
+ B 4b
+
+7:
+ # Store odd channels
+ TBZ x1, 3, 8f
+ STP q29, q30, [x7]
+ ADD x7, x7, 32
+ MOV v29.16b, v31.16b
+ STP q26, q27, [x10]
+ ADD x10, x10, 32
+ MOV v26.16b, v28.16b
+ STP q23, q24, [x9]
+ ADD x9, x9, 32
+ MOV v23.16b, v25.16b
+ STP q20, q21, [x6]
+ ADD x6, x6, 32
+ MOV v20.16b, v22.16b
+
+8:
+ TBZ x1, 2, 9f
+ STR q29, [x7], 16
+ MOV v29.16b, v30.16b
+ STR q26, [x10], 16
+ MOV v26.16b, v27.16b
+ STR q23, [x9], 16
+ MOV v23.16b, v24.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+
+9:
+ TBZ x1, 1, 10f
+ STR d29, [x7], 8
+ DUP d29, v29.d[1]
+ STR d26, [x10], 8
+ DUP d26, v26.d[1]
+ STR d23, [x9], 8
+ DUP d23, v23.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+
+10:
+ TBZ x1, 0, 11f
+ STR s29, [x7]
+ STR s26, [x10]
+ STR s23, [x9]
+ STR s20, [x6]
+11:
+ # Restore d8-d11,d14,d15 from stack
+ LDP d14, d15, [sp, 64]
+ LDP d10, d11, [sp, 48]
+ LDP d8, d9, [sp, 32]
+
+ # Restore x19-21 from stack
+ LDP x19, x20, [sp, 16]
+ LDR x21, [sp], 80
+ RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in b/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in
new file mode 100644
index 0000000..fe07b51
--- /dev/null
+++ b/src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in
@@ -0,0 +1,607 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x12__aarch64_neonfma_cortex_a53(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+# Vector register usage and GPR shadows
+# a0 v0 first set of A
+# a1 v0[1] x13
+# a2 v1
+# a3 v1[1] x8
+# a0 v2 second set of A
+# a1 v2[1] x13
+# a2 v3
+# a3 v3[1] x8
+# B v6 v7 v8 x20 x21 x16 first set of B
+# B v9 v10 v11 x17 x18 x19
+# B v14 v15 v16 x20 x21 x16 second set of B (same x as first set)
+# B v17 v18 v19 x17 x18 x19
+# C v20 v21 v22
+# C v23 v24 v25
+# C v26 v27 v28
+# C v29 v30 v31
+# Clamp v4 v5
+# v12 to v13 unused.
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x12__aarch64_neonfma_cortex_a53
+
+ $if INC:
+ # Load cn_stride, acc
+ LDP x14, x15, [sp]
+ # Load params pointer
+ LDR x8, [sp, 16]
+ $else:
+ # Load cn_stride, params pointer
+ LDP x14, x8, [sp]
+
+ # Load clamping_params values
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ # Save x19-21 on stack
+ STR x21, [sp, -80]!
+ STP x19, x20, [sp, 16]
+
+ # Save d8-d11,d14,d15 on stack
+ STP d8, d9, [sp, 32]
+ STP d10, d11, [sp, 48]
+ STP d14, d15, [sp, 64]
+
+ # Clamp A and C pointers
+ ADD x11, x3, x4 // a1 = a0 + a_stride
+ ADD x9, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x11, x3, x11, LO // a1 = a0
+ CSEL x9, x6, x9, LO // c1 = c0
+ ADD x12, x11, x4 // a2 = a1 + a_stride
+ ADD x10, x9, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x12, x11, x12, LS // a2 = a1
+ CSEL x10, x9, x10, LS // c2 = c1
+ ADD x4, x12, x4 // a3 = a2 + a_stride
+ ADD x7, x10, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x4, x12, x4, LO // a3 = a2
+ CSEL x7, x10, x7, LO // c3 = c2
+
+0:
+ $if INC:
+ # Load initial accumulators
+ LD1 {v20.16b, v21.16b, v22.16b}, [x15], 48
+ LD1 {v23.16b, v24.16b, v25.16b}, [x15], 48
+ LD1 {v26.16b, v27.16b, v28.16b}, [x15], 48
+ LD1 {v29.16b, v30.16b, v31.16b}, [x15], 48
+ $else:
+ # Load initial bias from w into accumulators
+ LD1 {v20.16b, v21.16b, v22.16b}, [x5], 48
+ MOV v23.16b, v20.16b
+ MOV v24.16b, v21.16b
+ MOV v25.16b, v22.16b
+ MOV v26.16b, v20.16b
+ MOV v27.16b, v21.16b
+ MOV v28.16b, v22.16b
+ MOV v29.16b, v20.16b
+ MOV v30.16b, v21.16b
+ MOV v31.16b, v22.16b
+
+ PRFM PLDL1KEEP, [x5]
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x5, 256]
+ PRFM PLDL1KEEP, [x5, 320]
+
+ # Is there at least 4 floats (16 bytes)?
+ SUBS x0, x2, 16 // k = kc - 16
+ B.LO 3f
+
+ SUBS x0, x0, 16
+
+ # Prologue - loads for first group of 24 FMA
+
+ # Read first block of 4 A.
+ LDR d0, [x3], 8 // a0
+ LDR x13, [x11], 8 // a1
+ LDR d1, [x12], 8 // a2
+ LDR x8, [x4], 8 // a3
+
+ LDR d6, [x5] // vb0x0123
+ LDR x20, [x5, 8]
+
+ LDR d7, [x5, 16] // vb0x4567
+ LDR x21, [x5, 24]
+
+ LDR d8, [x5, 32] // vb0x89AB
+ LDR x16, [x5, 40]
+
+ LDR d9, [x5, 48] // vb1x0123
+ INS v0.d[1], x13
+ LDR x17, [x5, 56]
+
+ LDR d10, [x5, 64] // vb1x4567
+ INS v1.d[1], x8
+ LDR x18, [x5, 72]
+
+ LDR d11, [x5, 80] // vb1x89AB
+ LDR x19, [x5, 88]
+ INS v6.d[1], x20
+ ADD x5, x5, 96
+
+ # Is there at least 4 floats (16 bytes) for main loop?
+ B.LO 2f
+
+ # Main loop - 4 floats of A (16 bytes)
+1:
+ # First group of 24 fma. 8 blocks of 4 cycles. LDR + 3 FMA
+ # A is loaded for 2nd group into v2/v3
+ # INS is 4 blocks (16 cycles) after load
+
+ # BLOCK 0
+ LDR d2, [x3], 8 // a0
+ INS v7.d[1], x21
+ FMLA v20.4s, v6.4s, v0.s[0]
+ LDR x13, [x11], 8 // a1
+ FMLA v23.4s, v6.4s, v0.s[2]
+ PRFM PLDL1KEEP, [x5, 192]
+ FMLA v26.4s, v6.4s, v1.s[0]
+
+ # BLOCK 1
+ LDR d3, [x12], 8 // a2
+ INS v8.d[1], x16
+ FMLA v29.4s, v6.4s, v1.s[2]
+ LDR x8, [x4], 8 // a3
+ FMLA v21.4s, v7.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v24.4s, v7.4s, v0.s[2]
+
+ # BLOCK 2
+ LDR d14, [x5] // vb0x0123
+ INS v9.d[1], x17
+ FMLA v27.4s, v7.4s, v1.s[0]
+ LDR x20, [x5, 8]
+ FMLA v30.4s, v7.4s, v1.s[2]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v22.4s, v8.4s, v0.s[0]
+
+ # BLOCK 3
+ LDR d15, [x5, 16] // vb0x4567
+ INS v10.d[1], x18
+ FMLA v25.4s, v8.4s, v0.s[2]
+ LDR x21, [x5, 24]
+ FMLA v28.4s, v8.4s, v1.s[0]
+ FMLA v31.4s, v8.4s, v1.s[2]
+
+ # BLOCK 4
+ LDR d16, [x5, 32] // vb0x89AB
+ INS v11.d[1], x19
+ FMLA v20.4s, v9.4s, v0.s[1]
+ LDR x16, [x5, 40]
+ FMLA v23.4s, v9.4s, v0.s[3]
+ FMLA v26.4s, v9.4s, v1.s[1]
+
+ # BLOCK 5
+ LDR d17, [x5, 48] // vb1x0123
+ INS v2.d[1], x13 // a1 was loaded in block 0
+ FMLA v29.4s, v9.4s, v1.s[3]
+ LDR x17, [x5, 56]
+ FMLA v21.4s, v10.4s, v0.s[1]
+ FMLA v24.4s, v10.4s, v0.s[3]
+
+ # BLOCK 6
+ LDR d18, [x5, 64] // vb1x4567
+ INS v3.d[1], x8 // a3 was loaded in block 1
+ FMLA v27.4s, v10.4s, v1.s[1]
+ LDR x18, [x5, 72]
+ FMLA v30.4s, v10.4s, v1.s[3]
+ FMLA v22.4s, v11.4s, v0.s[1]
+
+ # BLOCK 7
+ LDR d19, [x5, 80] // vb1x89AB
+ INS v14.d[1], x20 // v14 was loaded in block 2
+ FMLA v25.4s, v11.4s, v0.s[3]
+ LDR x19, [x5, 88]
+ FMLA v28.4s, v11.4s, v1.s[1]
+ FMLA v31.4s, v11.4s, v1.s[3]
+
+ # Second group of 24 fma. 8 blocks of 4 cycles. LDR + 3 FMA
+ # A is loaded for 1st group into v0/v1
+
+ # BLOCK 0
+ LDR d0, [x3], 8 // a0
+ INS v15.d[1], x21
+ FMLA v20.4s, v14.4s, v2.s[0]
+ LDR x13, [x11], 8 // a1
+ FMLA v23.4s, v14.4s, v2.s[2]
+ FMLA v26.4s, v14.4s, v3.s[0]
+
+ # BLOCK 1
+ LDR d1, [x12], 8 // a2
+ INS v16.d[1], x16
+ FMLA v29.4s, v14.4s, v3.s[2]
+ LDR x8, [x4], 8 // a3
+ FMLA v21.4s, v15.4s, v2.s[0]
+ FMLA v24.4s, v15.4s, v2.s[2]
+
+ # BLOCK 2
+ LDR d6, [x5, 96] // vb0x0123
+ INS v17.d[1], x17
+ FMLA v27.4s, v15.4s, v3.s[0]
+ LDR x20, [x5, 104]
+ FMLA v30.4s, v15.4s, v3.s[2]
+ FMLA v22.4s, v16.4s, v2.s[0]
+
+ # BLOCK 3
+ LDR d7, [x5, 112] // vb0x4567
+ INS v18.d[1], x18
+ FMLA v25.4s, v16.4s, v2.s[2]
+ LDR x21, [x5, 120]
+ FMLA v28.4s, v16.4s, v3.s[0]
+ FMLA v31.4s, v16.4s, v3.s[2]
+
+ # BLOCK 4
+ LDR d8, [x5, 128] // vb0x89AB
+ INS v19.d[1], x19
+ FMLA v20.4s, v17.4s, v2.s[1]
+ LDR x16, [x5, 136]
+ FMLA v23.4s, v17.4s, v2.s[3]
+ FMLA v26.4s, v17.4s, v3.s[1]
+
+ # BLOCK 5
+ LDR d9, [x5, 144] // vb1x0123
+ INS v0.d[1], x13 // a1
+ FMLA v29.4s, v17.4s, v3.s[3]
+ LDR x17, [x5, 152]
+ FMLA v21.4s, v18.4s, v2.s[1]
+ FMLA v24.4s, v18.4s, v2.s[3]
+
+ # BLOCK 6
+ LDR d10, [x5, 160] // vb1x4567
+ INS v1.d[1], x8 // a3
+ FMLA v27.4s, v18.4s, v3.s[1]
+ LDR x18, [x5, 168]
+ FMLA v30.4s, v18.4s, v3.s[3]
+ SUBS x0, x0, 16
+ FMLA v22.4s, v19.4s, v2.s[1]
+
+ # BLOCK 7
+ LDR d11, [x5, 176] // vb1x89AB
+ INS v6.d[1], x20
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDR x19, [x5, 184]
+ FMLA v28.4s, v19.4s, v3.s[1]
+ ADD x5, x5, 192
+ FMLA v31.4s, v19.4s, v3.s[3]
+ B.HS 1b
+
+ # Epilogue
+ # First block same as main loop. Second block has no loads.
+2:
+ # BLOCK 0
+ LDR d2, [x3], 8 // a0
+ INS v7.d[1], x21
+ FMLA v20.4s, v6.4s, v0.s[0]
+ LDR x13, [x11], 8 // a1
+ FMLA v23.4s, v6.4s, v0.s[2]
+ PRFM PLDL1KEEP, [x5, 192]
+ FMLA v26.4s, v6.4s, v1.s[0]
+
+ # BLOCK 1
+ LDR d3, [x12], 8 // a2
+ INS v8.d[1], x16
+ FMLA v29.4s, v6.4s, v1.s[2]
+ LDR x8, [x4], 8 // a3
+ FMLA v21.4s, v7.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v24.4s, v7.4s, v0.s[2]
+
+ # BLOCK 2
+ LDR d14, [x5] // vb0x0123
+ INS v9.d[1], x17
+ FMLA v27.4s, v7.4s, v1.s[0]
+ LDR x20, [x5, 8]
+ FMLA v30.4s, v7.4s, v1.s[2]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v22.4s, v8.4s, v0.s[0]
+
+ # BLOCK 3
+ LDR d15, [x5, 16] // vb0x4567
+ INS v10.d[1], x18
+ FMLA v25.4s, v8.4s, v0.s[2]
+ LDR x21, [x5, 24]
+ FMLA v28.4s, v8.4s, v1.s[0]
+ FMLA v31.4s, v8.4s, v1.s[2]
+
+ # BLOCK 4
+ LDR d16, [x5, 32] // vb0x89AB
+ INS v11.d[1], x19
+ FMLA v20.4s, v9.4s, v0.s[1]
+ LDR x16, [x5, 40]
+ FMLA v23.4s, v9.4s, v0.s[3]
+ FMLA v26.4s, v9.4s, v1.s[1]
+
+ # BLOCK 5
+ LDR d17, [x5, 48] // vb1x0123
+ INS v2.d[1], x13 // a1 was loaded in block 0
+ FMLA v29.4s, v9.4s, v1.s[3]
+ LDR x17, [x5, 56]
+ FMLA v21.4s, v10.4s, v0.s[1]
+ FMLA v24.4s, v10.4s, v0.s[3]
+
+ # BLOCK 6
+ LDR d18, [x5, 64] // vb1x4567
+ INS v3.d[1], x8 // a3 was loaded in block 1
+ FMLA v27.4s, v10.4s, v1.s[1]
+ LDR x18, [x5, 72]
+ FMLA v30.4s, v10.4s, v1.s[3]
+ FMLA v22.4s, v11.4s, v0.s[1]
+
+ # BLOCK 7
+ LDR d19, [x5, 80] // vb1x89AB
+ INS v14.d[1], x20 // v14 was loaded in block 2
+ FMLA v25.4s, v11.4s, v0.s[3]
+ LDR x19, [x5, 88]
+ FMLA v28.4s, v11.4s, v1.s[1]
+ ADD x5, x5, 96
+ FMLA v31.4s, v11.4s, v1.s[3]
+
+ # Second group of 24 fma. 8 blocks of 4 cycles.
+ # Epilogue version does no loads
+
+ # BLOCK 0
+ INS v15.d[1], x21
+ FMLA v20.4s, v14.4s, v2.s[0]
+ FMLA v23.4s, v14.4s, v2.s[2]
+ FMLA v26.4s, v14.4s, v3.s[0]
+
+ # BLOCK 1
+ INS v16.d[1], x16
+ FMLA v29.4s, v14.4s, v3.s[2]
+ FMLA v21.4s, v15.4s, v2.s[0]
+ FMLA v24.4s, v15.4s, v2.s[2]
+
+ # BLOCK 2
+ INS v17.d[1], x17
+ FMLA v27.4s, v15.4s, v3.s[0]
+ FMLA v30.4s, v15.4s, v3.s[2]
+ FMLA v22.4s, v16.4s, v2.s[0]
+
+ # BLOCK 3
+ INS v18.d[1], x18
+ FMLA v25.4s, v16.4s, v2.s[2]
+ FMLA v28.4s, v16.4s, v3.s[0]
+ FMLA v31.4s, v16.4s, v3.s[2]
+
+ # BLOCK 4
+ INS v19.d[1], x19
+ FMLA v20.4s, v17.4s, v2.s[1]
+ FMLA v23.4s, v17.4s, v2.s[3]
+ FMLA v26.4s, v17.4s, v3.s[1]
+
+ # BLOCK 5
+ FMLA v29.4s, v17.4s, v3.s[3]
+ FMLA v21.4s, v18.4s, v2.s[1]
+ FMLA v24.4s, v18.4s, v2.s[3]
+
+ # BLOCK 6
+ FMLA v27.4s, v18.4s, v3.s[1]
+ FMLA v30.4s, v18.4s, v3.s[3]
+ FMLA v22.4s, v19.4s, v2.s[1]
+
+ # BLOCK 7
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v28.4s, v19.4s, v3.s[1]
+ FMLA v31.4s, v19.4s, v3.s[3]
+
+3:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBNZ x0, 3, 5f
+  # Is there a remainder? - 1 float of A (4 bytes)
+ TBNZ x0, 2, 6f
+
+4:
+ # Clamp
+ FMIN v20.4s, v20.4s, v4.4s
+ FMIN v21.4s, v21.4s, v4.4s
+ FMIN v22.4s, v22.4s, v4.4s
+ FMIN v23.4s, v23.4s, v4.4s
+ FMIN v24.4s, v24.4s, v4.4s
+ FMIN v25.4s, v25.4s, v4.4s
+ FMIN v26.4s, v26.4s, v4.4s
+ FMIN v27.4s, v27.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v20.4s, v20.4s, v5.4s
+ FMAX v21.4s, v21.4s, v5.4s
+ FMAX v22.4s, v22.4s, v5.4s
+ FMAX v23.4s, v23.4s, v5.4s
+ FMAX v24.4s, v24.4s, v5.4s
+ FMAX v25.4s, v25.4s, v5.4s
+ FMAX v26.4s, v26.4s, v5.4s
+ FMAX v27.4s, v27.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+ # Store full 4 x 12
+ CMP x1, 12
+ B.LO 7f
+
+ ST1 {v29.16b, v30.16b, v31.16b}, [x7], x14
+ ST1 {v26.16b, v27.16b, v28.16b}, [x10], x14
+ ST1 {v23.16b, v24.16b, v25.16b}, [x9], x14
+ ST1 {v20.16b, v21.16b, v22.16b}, [x6], x14
+
+ SUB x3, x3, x2 // a0 -= kc
+ SUB x11, x11, x2 // a1 -= kc
+ SUB x12, x12, x2 // a2 -= kc
+ SUB x4, x4, x2 // a3 -= kc
+
+ SUBS x1, x1, 12
+ B.HI 0b
+
+ # Restore d8-d11,d14,d15 from stack
+ LDP d14, d15, [sp, 64]
+ LDP d10, d11, [sp, 48]
+ LDP d8, d9, [sp, 32]
+
+ # Restore x19-21 from stack
+ LDP x19, x20, [sp, 16]
+ LDR x21, [sp], 80
+ RET
+
+5:
+ # Remainder - 2 floats of A (8 bytes)
+ # Read first block of 4 A.
+ LDR d0, [x3], 8 // a0
+ LDR d1, [x11], 8 // a1
+ LDR d2, [x12], 8 // a2
+ LDR d3, [x4], 8 // a3
+ LD1 {v6.16b, v7.16b, v8.16b}, [x5], 48
+ LD1 {v9.16b, v10.16b, v11.16b}, [x5], 48
+
+ # First block of 3 B
+ FMLA v20.4s, v6.4s, v0.s[0]
+ FMLA v23.4s, v6.4s, v1.s[0]
+ FMLA v26.4s, v6.4s, v2.s[0]
+ FMLA v29.4s, v6.4s, v3.s[0]
+ FMLA v21.4s, v7.4s, v0.s[0]
+ FMLA v24.4s, v7.4s, v1.s[0]
+ FMLA v27.4s, v7.4s, v2.s[0]
+ FMLA v30.4s, v7.4s, v3.s[0]
+ FMLA v22.4s, v8.4s, v0.s[0]
+ FMLA v25.4s, v8.4s, v1.s[0]
+ FMLA v28.4s, v8.4s, v2.s[0]
+ FMLA v31.4s, v8.4s, v3.s[0]
+
+ # Second block of 3 B
+ FMLA v20.4s, v9.4s, v0.s[1]
+ FMLA v23.4s, v9.4s, v1.s[1]
+ FMLA v26.4s, v9.4s, v2.s[1]
+ FMLA v29.4s, v9.4s, v3.s[1]
+ FMLA v21.4s, v10.4s, v0.s[1]
+ FMLA v24.4s, v10.4s, v1.s[1]
+ FMLA v27.4s, v10.4s, v2.s[1]
+ FMLA v30.4s, v10.4s, v3.s[1]
+ FMLA v22.4s, v11.4s, v0.s[1]
+ FMLA v25.4s, v11.4s, v1.s[1]
+ FMLA v28.4s, v11.4s, v2.s[1]
+ FMLA v31.4s, v11.4s, v3.s[1]
+
+ TBZ x0, 2, 4b
+6:
+ # Remainder - 1 float of A (4 bytes)
+ LDR s0, [x3], 4 // a0
+ LDR s1, [x11], 4 // a1
+ LDR s2, [x12], 4 // a2
+ LDR s3, [x4], 4 // a3
+ LD1 {v6.16b, v7.16b, v8.16b}, [x5], 48
+
+ FMLA v20.4s, v6.4s, v0.s[0]
+ FMLA v23.4s, v6.4s, v1.s[0]
+ FMLA v26.4s, v6.4s, v2.s[0]
+ FMLA v29.4s, v6.4s, v3.s[0]
+ FMLA v21.4s, v7.4s, v0.s[0]
+ FMLA v24.4s, v7.4s, v1.s[0]
+ FMLA v27.4s, v7.4s, v2.s[0]
+ FMLA v30.4s, v7.4s, v3.s[0]
+ FMLA v22.4s, v8.4s, v0.s[0]
+ FMLA v25.4s, v8.4s, v1.s[0]
+ FMLA v28.4s, v8.4s, v2.s[0]
+ FMLA v31.4s, v8.4s, v3.s[0]
+ B 4b
+
+7:
+ # Store odd channels
+ TBZ x1, 3, 8f
+ STP q29, q30, [x7]
+ ADD x7, x7, 32
+ MOV v29.16b, v31.16b
+ STP q26, q27, [x10]
+ ADD x10, x10, 32
+ MOV v26.16b, v28.16b
+ STP q23, q24, [x9]
+ ADD x9, x9, 32
+ MOV v23.16b, v25.16b
+ STP q20, q21, [x6]
+ ADD x6, x6, 32
+ MOV v20.16b, v22.16b
+
+8:
+ TBZ x1, 2, 9f
+ STR q29, [x7], 16
+ MOV v29.16b, v30.16b
+ STR q26, [x10], 16
+ MOV v26.16b, v27.16b
+ STR q23, [x9], 16
+ MOV v23.16b, v24.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+
+9:
+ TBZ x1, 1, 10f
+ STR d29, [x7], 8
+ DUP d29, v29.d[1]
+ STR d26, [x10], 8
+ DUP d26, v26.d[1]
+ STR d23, [x9], 8
+ DUP d23, v23.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+
+10:
+ TBZ x1, 0, 11f
+ STR s29, [x7]
+ STR s26, [x10]
+ STR s23, [x9]
+ STR s20, [x6]
+11:
+ # Restore d8-d11,d14,d15 from stack
+ LDP d14, d15, [sp, 64]
+ LDP d10, d11, [sp, 48]
+ LDP d8, d9, [sp, 32]
+
+ # Restore x19-21 from stack
+ LDP x19, x20, [sp, 16]
+ LDR x21, [sp], 80
+ RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x12__aarch64_neonfma_cortex_a53
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/4x12-neon-ld64.c b/src/f32-gemm/4x12-neon-ld64.c
new file mode 100644
index 0000000..9fae962
--- /dev/null
+++ b/src/f32-gemm/4x12-neon-ld64.c
@@ -0,0 +1,241 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// GEMM micro-kernel: computes a 4-row x 12-column tile of C = A * B + bias
+// and clamps each element to [params->scalar.min, params->scalar.max].
+// "ld64" = A is consumed 2 floats (64 bits) per row per main-loop iteration.
+// NOTE(review): restored '&params' on the vld1q_dup_f32 lines (the text had
+// the mojibake '¶ms', i.e. '&para' mis-encoded as the pilcrow character).
+void xnn_f32_gemm_ukernel_4x12__neon_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  // Per-row input/output pointers. For partial tiles (mr < 4) the excess
+  // rows alias the previous row, so their loads/stores stay in bounds.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    // Initialize accumulators from the bias (first 12 floats of the packed
+    // weights); every row starts from the same bias values.
+    float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+    float32x4_t vacc0x89AB = vld1q_f32(w); w += 4;
+    float32x4_t vacc1x0123 = vacc0x0123;
+    float32x4_t vacc1x4567 = vacc0x4567;
+    float32x4_t vacc1x89AB = vacc0x89AB;
+    float32x4_t vacc2x0123 = vacc0x0123;
+    float32x4_t vacc2x4567 = vacc0x4567;
+    float32x4_t vacc2x89AB = vacc0x89AB;
+    float32x4_t vacc3x0123 = vacc0x0123;
+    float32x4_t vacc3x4567 = vacc0x4567;
+    float32x4_t vacc3x89AB = vacc0x89AB;
+
+    // Main loop: 2 k-elements per row per iteration.
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+      const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+      const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+      const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb89ABc0 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+      vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+      vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+      vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+      vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+      vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+      vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+      vacc0x89AB = vmlaq_lane_f32(vacc0x89AB, vb89ABc0, va0, 0);
+      vacc1x89AB = vmlaq_lane_f32(vacc1x89AB, vb89ABc0, va1, 0);
+      vacc2x89AB = vmlaq_lane_f32(vacc2x89AB, vb89ABc0, va2, 0);
+      vacc3x89AB = vmlaq_lane_f32(vacc3x89AB, vb89ABc0, va3, 0);
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb89ABc1 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+      vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+      vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+      vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+      vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+      vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+      vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+      vacc0x89AB = vmlaq_lane_f32(vacc0x89AB, vb89ABc1, va0, 1);
+      vacc1x89AB = vmlaq_lane_f32(vacc1x89AB, vb89ABc1, va1, 1);
+      vacc2x89AB = vmlaq_lane_f32(vacc2x89AB, vb89ABc1, va2, 1);
+      vacc3x89AB = vmlaq_lane_f32(vacc3x89AB, vb89ABc1, va3, 1);
+    }
+    // Remainder: one odd k-element (kc % 2 floats != 0).
+    if XNN_UNLIKELY(k != 0) {
+      const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+      const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+      const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+      const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+      const float32x4_t vb89AB = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+      vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+      vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+      vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+      vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+      vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+      vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+      vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+      vacc0x89AB = vmlaq_f32(vacc0x89AB, va0, vb89AB);
+      vacc1x89AB = vmlaq_f32(vacc1x89AB, va1, vb89AB);
+      vacc2x89AB = vmlaq_f32(vacc2x89AB, va2, vb89AB);
+      vacc3x89AB = vmlaq_f32(vacc3x89AB, va3, vb89AB);
+    }
+    // Clamp the accumulators to the caller-specified output range.
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+    vacc0x89AB = vminq_f32(vacc0x89AB, vmax);
+    vacc1x89AB = vminq_f32(vacc1x89AB, vmax);
+    vacc2x89AB = vminq_f32(vacc2x89AB, vmax);
+    vacc3x89AB = vminq_f32(vacc3x89AB, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+    vacc0x89AB = vmaxq_f32(vacc0x89AB, vmin);
+    vacc1x89AB = vmaxq_f32(vacc1x89AB, vmin);
+    vacc2x89AB = vmaxq_f32(vacc2x89AB, vmin);
+    vacc3x89AB = vmaxq_f32(vacc3x89AB, vmin);
+
+    // Common case: store a full 12-column tile, then rewind A for the next
+    // 12 output columns.
+    if XNN_LIKELY(nc >= 12) {
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      vst1q_f32(c3 + 8, vacc3x89AB);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      vst1q_f32(c2 + 8, vacc2x89AB);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      vst1q_f32(c1 + 8, vacc1x89AB);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      vst1q_f32(c0 + 8, vacc0x89AB);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 12;
+
+    } else {
+      // Final partial tile: store 8/4/2/1 columns based on the bits of nc,
+      // shifting the remaining values down into the low accumulators.
+      if (nc & 8) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+        vst1q_f32(c3, vacc3x4567); c3 += 4;
+        vst1q_f32(c2, vacc2x4567); c2 += 4;
+        vst1q_f32(c1, vacc1x4567); c1 += 4;
+        vst1q_f32(c0, vacc0x4567); c0 += 4;
+
+        vacc3x0123 = vacc3x89AB;
+        vacc2x0123 = vacc2x89AB;
+        vacc1x0123 = vacc1x89AB;
+        vacc0x0123 = vacc0x89AB;
+      }
+      if (nc & 4) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x12-neonfma-ld64.c b/src/f32-gemm/4x12-neonfma-ld64.c
new file mode 100644
index 0000000..ba414ce
--- /dev/null
+++ b/src/f32-gemm/4x12-neonfma-ld64.c
@@ -0,0 +1,279 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// GEMM micro-kernel: 4-row x 12-column tile of C = A * B + bias using fused
+// multiply-add (vfmaq), clamped to [params->scalar.min, params->scalar.max].
+// On AArch64 the lane-indexed FMA form is used directly; on AArch32 the A
+// lane is broadcast first because vfmaq_lane_f32 is AArch64-only.
+// NOTE(review): restored '&params' on the vld1q_dup_f32 lines (the text had
+// the mojibake '¶ms', i.e. '&para' mis-encoded as the pilcrow character).
+void xnn_f32_gemm_ukernel_4x12__neonfma_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  // Per-row input/output pointers. For partial tiles (mr < 4) the excess
+  // rows alias the previous row, so their loads/stores stay in bounds.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    // Initialize accumulators from the bias; every row starts the same.
+    float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+    float32x4_t vacc0x89AB = vld1q_f32(w); w += 4;
+    float32x4_t vacc1x0123 = vacc0x0123;
+    float32x4_t vacc1x4567 = vacc0x4567;
+    float32x4_t vacc1x89AB = vacc0x89AB;
+    float32x4_t vacc2x0123 = vacc0x0123;
+    float32x4_t vacc2x4567 = vacc0x4567;
+    float32x4_t vacc2x89AB = vacc0x89AB;
+    float32x4_t vacc3x0123 = vacc0x0123;
+    float32x4_t vacc3x4567 = vacc0x4567;
+    float32x4_t vacc3x89AB = vacc0x89AB;
+
+    // Main loop: 2 k-elements per row per iteration.
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+      const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+      const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+      const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb89ABc0 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+        vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+        vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+        vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+        vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+        vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+        vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+        vacc0x89AB = vfmaq_lane_f32(vacc0x89AB, vb89ABc0, va0, 0);
+        vacc1x89AB = vfmaq_lane_f32(vacc1x89AB, vb89ABc0, va1, 0);
+        vacc2x89AB = vfmaq_lane_f32(vacc2x89AB, vb89ABc0, va2, 0);
+        vacc3x89AB = vfmaq_lane_f32(vacc3x89AB, vb89ABc0, va3, 0);
+      #else
+        const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+        const float32x4_t va1c0 = vdupq_lane_f32(va1, 0);
+        const float32x4_t va2c0 = vdupq_lane_f32(va2, 0);
+        const float32x4_t va3c0 = vdupq_lane_f32(va3, 0);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+        vacc0x89AB = vfmaq_f32(vacc0x89AB, va0c0, vb89ABc0);
+        vacc1x89AB = vfmaq_f32(vacc1x89AB, va1c0, vb89ABc0);
+        vacc2x89AB = vfmaq_f32(vacc2x89AB, va2c0, vb89ABc0);
+        vacc3x89AB = vfmaq_f32(vacc3x89AB, va3c0, vb89ABc0);
+      #endif
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb89ABc1 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+        vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+        vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+        vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+        vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+        vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+        vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+        vacc0x89AB = vfmaq_lane_f32(vacc0x89AB, vb89ABc1, va0, 1);
+        vacc1x89AB = vfmaq_lane_f32(vacc1x89AB, vb89ABc1, va1, 1);
+        vacc2x89AB = vfmaq_lane_f32(vacc2x89AB, vb89ABc1, va2, 1);
+        vacc3x89AB = vfmaq_lane_f32(vacc3x89AB, vb89ABc1, va3, 1);
+      #else
+        const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+        const float32x4_t va1c1 = vdupq_lane_f32(va1, 1);
+        const float32x4_t va2c1 = vdupq_lane_f32(va2, 1);
+        const float32x4_t va3c1 = vdupq_lane_f32(va3, 1);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+        vacc0x89AB = vfmaq_f32(vacc0x89AB, va0c1, vb89ABc1);
+        vacc1x89AB = vfmaq_f32(vacc1x89AB, va1c1, vb89ABc1);
+        vacc2x89AB = vfmaq_f32(vacc2x89AB, va2c1, vb89ABc1);
+        vacc3x89AB = vfmaq_f32(vacc3x89AB, va3c1, vb89ABc1);
+      #endif
+    }
+    // Remainder: one odd k-element.
+    if XNN_UNLIKELY(k != 0) {
+      const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+      const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+      const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+      const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+      const float32x4_t vb89AB = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+      vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+      vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+      vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+      vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+      vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+      vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+      vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+      vacc0x89AB = vfmaq_f32(vacc0x89AB, va0, vb89AB);
+      vacc1x89AB = vfmaq_f32(vacc1x89AB, va1, vb89AB);
+      vacc2x89AB = vfmaq_f32(vacc2x89AB, va2, vb89AB);
+      vacc3x89AB = vfmaq_f32(vacc3x89AB, va3, vb89AB);
+    }
+    // Clamp the accumulators to the caller-specified output range.
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+    vacc0x89AB = vminq_f32(vacc0x89AB, vmax);
+    vacc1x89AB = vminq_f32(vacc1x89AB, vmax);
+    vacc2x89AB = vminq_f32(vacc2x89AB, vmax);
+    vacc3x89AB = vminq_f32(vacc3x89AB, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+    vacc0x89AB = vmaxq_f32(vacc0x89AB, vmin);
+    vacc1x89AB = vmaxq_f32(vacc1x89AB, vmin);
+    vacc2x89AB = vmaxq_f32(vacc2x89AB, vmin);
+    vacc3x89AB = vmaxq_f32(vacc3x89AB, vmin);
+
+    // Common case: store a full 12-column tile, then rewind A.
+    if XNN_LIKELY(nc >= 12) {
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      vst1q_f32(c3 + 8, vacc3x89AB);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      vst1q_f32(c2 + 8, vacc2x89AB);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      vst1q_f32(c1 + 8, vacc1x89AB);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      vst1q_f32(c0 + 8, vacc0x89AB);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 12;
+
+    } else {
+      // Final partial tile: store 8/4/2/1 columns based on the bits of nc.
+      if (nc & 8) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+        vst1q_f32(c3, vacc3x4567); c3 += 4;
+        vst1q_f32(c2, vacc2x4567); c2 += 4;
+        vst1q_f32(c1, vacc1x4567); c1 += 4;
+        vst1q_f32(c0, vacc0x4567); c0 += 4;
+
+        vacc3x0123 = vacc3x89AB;
+        vacc2x0123 = vacc2x89AB;
+        vacc1x0123 = vacc1x89AB;
+        vacc0x0123 = vacc0x89AB;
+      }
+      if (nc & 4) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x2-neon-ld64.c b/src/f32-gemm/4x2-neon-ld64.c
new file mode 100644
index 0000000..872a842
--- /dev/null
+++ b/src/f32-gemm/4x2-neon-ld64.c
@@ -0,0 +1,137 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/MRx2-neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// GEMM micro-kernel: 4-row x 2-column tile of C = A * B + bias, clamped to
+// [params->scalar.min, params->scalar.max]. A is read 2 floats at a time.
+// NOTE(review): restored '&params' on the vld1_dup_f32 lines (the text had
+// the mojibake '¶ms'); also normalized 'float*restrict' spacing to match
+// the sibling kernels in this family.
+void xnn_f32_gemm_ukernel_4x2__neon_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  // Per-row input/output pointers; rows past mr alias the previous row.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    // Initialize per-row accumulators from the shared 2-element bias.
+    float32x2_t vacc0x01 = vld1_f32(w); w += 2;
+    float32x2_t vacc1x01 = vacc0x01;
+    float32x2_t vacc2x01 = vacc0x01;
+    float32x2_t vacc3x01 = vacc0x01;
+
+    // Main loop: 2 k-elements per row per iteration.
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+      const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+      const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+      const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+      const float32x2_t vb01c0 = vld1_f32(w); w += 2;
+
+      vacc0x01 = vmla_lane_f32(vacc0x01, vb01c0, va0, 0);
+      vacc1x01 = vmla_lane_f32(vacc1x01, vb01c0, va1, 0);
+      vacc2x01 = vmla_lane_f32(vacc2x01, vb01c0, va2, 0);
+      vacc3x01 = vmla_lane_f32(vacc3x01, vb01c0, va3, 0);
+      const float32x2_t vb01c1 = vld1_f32(w); w += 2;
+
+      vacc0x01 = vmla_lane_f32(vacc0x01, vb01c1, va0, 1);
+      vacc1x01 = vmla_lane_f32(vacc1x01, vb01c1, va1, 1);
+      vacc2x01 = vmla_lane_f32(vacc2x01, vb01c1, va2, 1);
+      vacc3x01 = vmla_lane_f32(vacc3x01, vb01c1, va3, 1);
+    }
+    // Remainder: one odd k-element.
+    if XNN_UNLIKELY(k != 0) {
+      const float32x2_t va0 = vld1_dup_f32(a0); a0 += 1;
+      const float32x2_t va1 = vld1_dup_f32(a1); a1 += 1;
+      const float32x2_t va2 = vld1_dup_f32(a2); a2 += 1;
+      const float32x2_t va3 = vld1_dup_f32(a3); a3 += 1;
+
+      const float32x2_t vb01 = vld1_f32(w); w += 2;
+
+      vacc0x01 = vmla_f32(vacc0x01, va0, vb01);
+      vacc1x01 = vmla_f32(vacc1x01, va1, vb01);
+      vacc2x01 = vmla_f32(vacc2x01, va2, vb01);
+      vacc3x01 = vmla_f32(vacc3x01, va3, vb01);
+    }
+
+    // Clamp to the caller-specified output range.
+    const float32x2_t vmax = vld1_dup_f32(&params->scalar.max);
+    vacc0x01 = vmin_f32(vacc0x01, vmax);
+    vacc1x01 = vmin_f32(vacc1x01, vmax);
+    vacc2x01 = vmin_f32(vacc2x01, vmax);
+    vacc3x01 = vmin_f32(vacc3x01, vmax);
+
+    const float32x2_t vmin = vld1_dup_f32(&params->scalar.min);
+    vacc0x01 = vmax_f32(vacc0x01, vmin);
+    vacc1x01 = vmax_f32(vacc1x01, vmin);
+    vacc2x01 = vmax_f32(vacc2x01, vmin);
+    vacc3x01 = vmax_f32(vacc3x01, vmin);
+
+    // Store a full 2-column tile and rewind A, or a single trailing column.
+    if XNN_LIKELY(nc >= 2) {
+      vst1_f32(c0, vacc0x01);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+      vst1_f32(c1, vacc1x01);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1_f32(c2, vacc2x01);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1_f32(c3, vacc3x01);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+
+      nc -= 2;
+    } else {
+      assert(nc == 1);
+      vst1_lane_f32(c0, vacc0x01, 0);
+      vst1_lane_f32(c1, vacc1x01, 0);
+      vst1_lane_f32(c2, vacc2x01, 0);
+      vst1_lane_f32(c3, vacc3x01, 0);
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x2-neonfma-ld64.c b/src/f32-gemm/4x2-neonfma-ld64.c
new file mode 100644
index 0000000..4014c0e
--- /dev/null
+++ b/src/f32-gemm/4x2-neonfma-ld64.c
@@ -0,0 +1,159 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/MRx2-neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// GEMM micro-kernel: 4-row x 2-column tile of C = A * B + bias using fused
+// multiply-add, clamped to [params->scalar.min, params->scalar.max]. On
+// AArch64 the lane-indexed FMA form is used; elsewhere the lane is broadcast
+// first because vfma_lane_f32 is AArch64-only.
+// NOTE(review): restored '&params' on the vld1_dup_f32 lines (the text had
+// the mojibake '¶ms'); also normalized 'float*restrict' spacing to match
+// the sibling kernels in this family.
+void xnn_f32_gemm_ukernel_4x2__neonfma_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  // Per-row input/output pointers; rows past mr alias the previous row.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    // Initialize per-row accumulators from the shared 2-element bias.
+    float32x2_t vacc0x01 = vld1_f32(w); w += 2;
+    float32x2_t vacc1x01 = vacc0x01;
+    float32x2_t vacc2x01 = vacc0x01;
+    float32x2_t vacc3x01 = vacc0x01;
+
+    // Main loop: 2 k-elements per row per iteration.
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+      const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+      const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+      const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+      const float32x2_t vb01c0 = vld1_f32(w); w += 2;
+
+      #if defined(__aarch64__)
+        vacc0x01 = vfma_lane_f32(vacc0x01, vb01c0, va0, 0);
+        vacc1x01 = vfma_lane_f32(vacc1x01, vb01c0, va1, 0);
+        vacc2x01 = vfma_lane_f32(vacc2x01, vb01c0, va2, 0);
+        vacc3x01 = vfma_lane_f32(vacc3x01, vb01c0, va3, 0);
+      #else
+        const float32x2_t va0c0 = vdup_lane_f32(va0, 0);
+        const float32x2_t va1c0 = vdup_lane_f32(va1, 0);
+        const float32x2_t va2c0 = vdup_lane_f32(va2, 0);
+        const float32x2_t va3c0 = vdup_lane_f32(va3, 0);
+        vacc0x01 = vfma_f32(vacc0x01, va0c0, vb01c0);
+        vacc1x01 = vfma_f32(vacc1x01, va1c0, vb01c0);
+        vacc2x01 = vfma_f32(vacc2x01, va2c0, vb01c0);
+        vacc3x01 = vfma_f32(vacc3x01, va3c0, vb01c0);
+      #endif
+      const float32x2_t vb01c1 = vld1_f32(w); w += 2;
+
+      #if defined(__aarch64__)
+        vacc0x01 = vfma_lane_f32(vacc0x01, vb01c1, va0, 1);
+        vacc1x01 = vfma_lane_f32(vacc1x01, vb01c1, va1, 1);
+        vacc2x01 = vfma_lane_f32(vacc2x01, vb01c1, va2, 1);
+        vacc3x01 = vfma_lane_f32(vacc3x01, vb01c1, va3, 1);
+      #else
+        const float32x2_t va0c1 = vdup_lane_f32(va0, 1);
+        const float32x2_t va1c1 = vdup_lane_f32(va1, 1);
+        const float32x2_t va2c1 = vdup_lane_f32(va2, 1);
+        const float32x2_t va3c1 = vdup_lane_f32(va3, 1);
+        vacc0x01 = vfma_f32(vacc0x01, va0c1, vb01c1);
+        vacc1x01 = vfma_f32(vacc1x01, va1c1, vb01c1);
+        vacc2x01 = vfma_f32(vacc2x01, va2c1, vb01c1);
+        vacc3x01 = vfma_f32(vacc3x01, va3c1, vb01c1);
+      #endif
+    }
+    // Remainder: one odd k-element.
+    if XNN_UNLIKELY(k != 0) {
+      const float32x2_t va0 = vld1_dup_f32(a0); a0 += 1;
+      const float32x2_t va1 = vld1_dup_f32(a1); a1 += 1;
+      const float32x2_t va2 = vld1_dup_f32(a2); a2 += 1;
+      const float32x2_t va3 = vld1_dup_f32(a3); a3 += 1;
+
+      const float32x2_t vb01 = vld1_f32(w); w += 2;
+
+      vacc0x01 = vfma_f32(vacc0x01, va0, vb01);
+      vacc1x01 = vfma_f32(vacc1x01, va1, vb01);
+      vacc2x01 = vfma_f32(vacc2x01, va2, vb01);
+      vacc3x01 = vfma_f32(vacc3x01, va3, vb01);
+    }
+
+    // Clamp to the caller-specified output range.
+    const float32x2_t vmax = vld1_dup_f32(&params->scalar.max);
+    vacc0x01 = vmin_f32(vacc0x01, vmax);
+    vacc1x01 = vmin_f32(vacc1x01, vmax);
+    vacc2x01 = vmin_f32(vacc2x01, vmax);
+    vacc3x01 = vmin_f32(vacc3x01, vmax);
+
+    const float32x2_t vmin = vld1_dup_f32(&params->scalar.min);
+    vacc0x01 = vmax_f32(vacc0x01, vmin);
+    vacc1x01 = vmax_f32(vacc1x01, vmin);
+    vacc2x01 = vmax_f32(vacc2x01, vmin);
+    vacc3x01 = vmax_f32(vacc3x01, vmin);
+
+    // Store a full 2-column tile and rewind A, or a single trailing column.
+    if XNN_LIKELY(nc >= 2) {
+      vst1_f32(c0, vacc0x01);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+      vst1_f32(c1, vacc1x01);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1_f32(c2, vacc2x01);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1_f32(c3, vacc3x01);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+
+      nc -= 2;
+    } else {
+      assert(nc == 1);
+      vst1_lane_f32(c0, vacc0x01, 0);
+      vst1_lane_f32(c1, vacc1x01, 0);
+      vst1_lane_f32(c2, vacc2x01, 0);
+      vst1_lane_f32(c3, vacc3x01, 0);
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x2-scalar.c b/src/f32-gemm/4x2-scalar.c
new file mode 100644
index 0000000..cd2f650
--- /dev/null
+++ b/src/f32-gemm/4x2-scalar.c
@@ -0,0 +1,143 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+// Portable scalar GEMM micro-kernel: computes a 4-row x 2-column tile of
+// C = A * B + bias and clamps each value to [params->scalar.min,
+// params->scalar.max].
+void xnn_f32_gemm_ukernel_4x2__scalar(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  // Per-row input/output pointers; when mr < 4, rows beyond mr alias the
+  // previous row so all loads/stores stay in bounds.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    // Seed every row's accumulator pair with the shared 2-element bias.
+    float acc[4][2];
+    acc[0][0] = w[0];
+    acc[0][1] = w[1];
+    w += 2;
+    for (size_t i = 1; i < 4; i++) {
+      acc[i][0] = acc[0][0];
+      acc[i][1] = acc[0][1];
+    }
+
+    // Accumulate one k-element per row per iteration (kc != 0 is asserted,
+    // so the loop always runs at least once, like the original do/while).
+    for (size_t k = kc; k != 0; k -= sizeof(float)) {
+      const float x0 = *a0++;
+      const float x1 = *a1++;
+      const float x2 = *a2++;
+      const float x3 = *a3++;
+
+      const float b0 = w[0];
+      const float b1 = w[1];
+      w += 2;
+
+      acc[0][0] += x0 * b0;
+      acc[0][1] += x0 * b1;
+      acc[1][0] += x1 * b0;
+      acc[1][1] += x1 * b1;
+      acc[2][0] += x2 * b0;
+      acc[2][1] += x2 * b1;
+      acc[3][0] += x3 * b0;
+      acc[3][1] += x3 * b1;
+    }
+
+    // Clamp each element to [min, max].
+    const float lo = params->scalar.min;
+    const float hi = params->scalar.max;
+    for (size_t i = 0; i < 4; i++) {
+      acc[i][0] = math_min_f32(math_max_f32(acc[i][0], lo), hi);
+      acc[i][1] = math_min_f32(math_max_f32(acc[i][1], lo), hi);
+    }
+
+    if XNN_LIKELY(nc >= 2) {
+      // Full 2-column tile: store rows 3..0 (order matters when rows alias),
+      // advance C to the next column block, and rewind A.
+      c3[0] = acc[3][0];
+      c3[1] = acc[3][1];
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      c2[0] = acc[2][0];
+      c2[1] = acc[2][1];
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      c1[0] = acc[1][0];
+      c1[1] = acc[1][1];
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      c0[0] = acc[0][0];
+      c0[1] = acc[0][1];
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 2;
+    } else {
+      // Trailing single column.
+      if (nc & 1) {
+        c3[0] = acc[3][0];
+        c2[0] = acc[2][0];
+        c1[0] = acc[1][0];
+        c0[0] = acc[0][0];
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x4-scalar.c b/src/f32-gemm/4x4-scalar.c
new file mode 100644
index 0000000..350d5ba
--- /dev/null
+++ b/src/f32-gemm/4x4-scalar.c
@@ -0,0 +1,203 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+// Scalar GEMM microkernel producing a 4-row by 4-column tile of the output:
+// C[0..3][0..3] = clamp(bias + A * B), with the bias and B pre-packed in w.
+//
+//   mr        number of valid rows in A/C (1 <= mr <= 4)
+//   nc        number of output columns remaining (produced 4 at a time)
+//   kc        length of the reduction dimension in BYTES
+//             (asserted non-zero and a multiple of sizeof(float))
+//   a         input matrix A; rows separated by a_stride BYTES
+//   w         packed weights: 4 bias floats, then 4 floats of B per k step
+//   c         output matrix; cm_stride BYTES between rows, cn_stride BYTES
+//             between consecutive 4-column tiles of the same row
+//   params    scalar output clamping bounds (params->scalar.min / .max)
+void xnn_f32_gemm_ukernel_4x4__scalar(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  // Set up one A/C pointer pair per row. Rows beyond mr alias the previous
+  // row's pointers: their loads stay in bounds, and because stores are issued
+  // highest row first (c3, c2, c1, c0), the real row's store lands last and
+  // overwrites any duplicate.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    // Initialize all four rows of accumulators with the 4 bias values that
+    // lead the packed weights.
+    float vacc00 = w[0];
+    float vacc01 = w[1];
+    float vacc02 = w[2];
+    float vacc03 = w[3];
+    w += 4;
+    float vacc10 = vacc00;
+    float vacc11 = vacc01;
+    float vacc12 = vacc02;
+    float vacc13 = vacc03;
+    float vacc20 = vacc00;
+    float vacc21 = vacc01;
+    float vacc22 = vacc02;
+    float vacc23 = vacc03;
+    float vacc30 = vacc00;
+    float vacc31 = vacc01;
+    float vacc32 = vacc02;
+    float vacc33 = vacc03;
+
+    // Multiply-accumulate across the whole reduction dimension: each
+    // iteration consumes one float from every A row and 4 packed B values.
+    // Note k counts BYTES, hence the sizeof(float) decrement.
+    size_t k = kc;
+    do {
+      const float va0 = *a0++;
+      const float va1 = *a1++;
+      const float va2 = *a2++;
+      const float va3 = *a3++;
+
+      const float vb0 = w[0];
+      const float vb1 = w[1];
+      const float vb2 = w[2];
+      const float vb3 = w[3];
+      w += 4;
+
+      vacc00 += va0 * vb0;
+      vacc01 += va0 * vb1;
+      vacc02 += va0 * vb2;
+      vacc03 += va0 * vb3;
+      vacc10 += va1 * vb0;
+      vacc11 += va1 * vb1;
+      vacc12 += va1 * vb2;
+      vacc13 += va1 * vb3;
+      vacc20 += va2 * vb0;
+      vacc21 += va2 * vb1;
+      vacc22 += va2 * vb2;
+      vacc23 += va2 * vb3;
+      vacc30 += va3 * vb0;
+      vacc31 += va3 * vb1;
+      vacc32 += va3 * vb2;
+      vacc33 += va3 * vb3;
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp all 16 accumulators into [min, max].
+    const float vmin = params->scalar.min;
+    vacc00 = math_max_f32(vacc00, vmin);
+    vacc01 = math_max_f32(vacc01, vmin);
+    vacc02 = math_max_f32(vacc02, vmin);
+    vacc03 = math_max_f32(vacc03, vmin);
+    vacc10 = math_max_f32(vacc10, vmin);
+    vacc11 = math_max_f32(vacc11, vmin);
+    vacc12 = math_max_f32(vacc12, vmin);
+    vacc13 = math_max_f32(vacc13, vmin);
+    vacc20 = math_max_f32(vacc20, vmin);
+    vacc21 = math_max_f32(vacc21, vmin);
+    vacc22 = math_max_f32(vacc22, vmin);
+    vacc23 = math_max_f32(vacc23, vmin);
+    vacc30 = math_max_f32(vacc30, vmin);
+    vacc31 = math_max_f32(vacc31, vmin);
+    vacc32 = math_max_f32(vacc32, vmin);
+    vacc33 = math_max_f32(vacc33, vmin);
+
+    const float vmax = params->scalar.max;
+    vacc00 = math_min_f32(vacc00, vmax);
+    vacc01 = math_min_f32(vacc01, vmax);
+    vacc02 = math_min_f32(vacc02, vmax);
+    vacc03 = math_min_f32(vacc03, vmax);
+    vacc10 = math_min_f32(vacc10, vmax);
+    vacc11 = math_min_f32(vacc11, vmax);
+    vacc12 = math_min_f32(vacc12, vmax);
+    vacc13 = math_min_f32(vacc13, vmax);
+    vacc20 = math_min_f32(vacc20, vmax);
+    vacc21 = math_min_f32(vacc21, vmax);
+    vacc22 = math_min_f32(vacc22, vmax);
+    vacc23 = math_min_f32(vacc23, vmax);
+    vacc30 = math_min_f32(vacc30, vmax);
+    vacc31 = math_min_f32(vacc31, vmax);
+    vacc32 = math_min_f32(vacc32, vmax);
+    vacc33 = math_min_f32(vacc33, vmax);
+
+    if XNN_LIKELY(nc >= 4) {
+      // Full 4-column tile: store every row, advance each C pointer to the
+      // next tile, and rewind the A pointers by kc bytes so the next pass
+      // re-reads the same rows of A.
+      c3[0] = vacc30;
+      c3[1] = vacc31;
+      c3[2] = vacc32;
+      c3[3] = vacc33;
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      c2[0] = vacc20;
+      c2[1] = vacc21;
+      c2[2] = vacc22;
+      c2[3] = vacc23;
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      c1[0] = vacc10;
+      c1[1] = vacc11;
+      c1[2] = vacc12;
+      c1[3] = vacc13;
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      c0[0] = vacc00;
+      c0[1] = vacc01;
+      c0[2] = vacc02;
+      c0[3] = vacc03;
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const void*) ((uintptr_t) a3 - kc);
+      a2 = (const void*) ((uintptr_t) a2 - kc);
+      a1 = (const void*) ((uintptr_t) a1 - kc);
+      a0 = (const void*) ((uintptr_t) a0 - kc);
+
+      nc -= 4;
+    } else {
+      // Partial final tile (nc < 4): store 2 columns, then 1, as selected by
+      // the low bits of nc; surviving accumulators are shifted down two
+      // columns between the steps.
+      if (nc & 2) {
+        c3[0] = vacc30;
+        c3[1] = vacc31;
+        vacc30 = vacc32;
+        c3 += 2;
+        c2[0] = vacc20;
+        c2[1] = vacc21;
+        vacc20 = vacc22;
+        c2 += 2;
+        c1[0] = vacc10;
+        c1[1] = vacc11;
+        vacc10 = vacc12;
+        c1 += 2;
+        c0[0] = vacc00;
+        c0[1] = vacc01;
+        vacc00 = vacc02;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        c3[0] = vacc30;
+        c2[0] = vacc20;
+        c1[0] = vacc10;
+        c0[0] = vacc00;
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x8-aarch64-neonfma-cortex-a57.S b/src/f32-gemm/4x8-aarch64-neonfma-cortex-a57.S
new file mode 100644
index 0000000..deebc47
--- /dev/null
+++ b/src/f32-gemm/4x8-aarch64-neonfma-cortex-a57.S
@@ -0,0 +1,468 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/4x8-aarch64-neonfma-cortex-a57.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a57(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+# Vector register usage
+# A0 v0 v4
+# A1 v1 v5
+# A2 v2 v6
+# A3 v3 v7
+# B v8 v9 v10 v11
+# B v12 v13 v14 v15
+# B v20 v21 v22 v23
+# B v24 v25 v26 v27
+# C v16 v17
+# C v18 v19
+# C v28 v29
+# C v30 v31
+# Clamp v4 v5
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a57
+
+        # Load cn_stride, params pointer
+        LDP x14, x8, [sp]
+
+        # Load params values: v4 = upper clamp bound (used with FMIN below),
+        # v5 = lower clamp bound (used with FMAX). x8 is rewound to the start
+        # of params so the epilogue can reload them after v4/v5 are reused as
+        # A data in the main loop.
+        LD1R {v4.4s}, [x8], 4
+        LD1R {v5.4s}, [x8]
+        SUB x8, x8, 4
+
+        # Save d8-d15 on stack
+        STP d8, d9, [sp, -64]!
+        STP d10, d11, [sp, 16]
+        STP d12, d13, [sp, 32]
+        STP d14, d15, [sp, 48]
+
+        # Clamp A and C pointers: rows beyond mr alias the previous row, so
+        # loads stay in bounds and the duplicate stores are overwritten by
+        # the real row, which is stored later.
+        ADD x11, x3, x4          // a1 = a0 + a_stride
+        ADD x9, x6, x7           // c1 = c0 + cm_stride
+        CMP x0, 2                // if mr < 2
+        CSEL x11, x3, x11, LO    //   a1 = a0
+        CSEL x9, x6, x9, LO      //   c1 = c0
+
+        ADD x12, x11, x4         // a2 = a1 + a_stride
+        ADD x10, x9, x7          // c2 = c1 + cm_stride
+                                 // if mr <= 2
+        CSEL x12, x11, x12, LS   //   a2 = a1
+        CSEL x10, x9, x10, LS    //   c2 = c1
+
+        ADD x4, x12, x4          // a3 = a2 + a_stride
+        ADD x7, x10, x7          // c3 = c2 + cm_stride
+        CMP x0, 4                // if mr < 4
+        CSEL x4, x12, x4, LO     //   a3 = a2
+        CSEL x7, x10, x7, LO     //   c3 = c2
+
+        # Outer loop: one iteration per 8-column tile of the output (nc).
+0:
+        # Load initial bias from w into accumulators
+        LDP q16, q17, [x5], 32
+        MOV v18.16b, v16.16b
+        MOV v19.16b, v17.16b
+        MOV v28.16b, v16.16b
+        MOV v29.16b, v17.16b
+        MOV v30.16b, v16.16b
+        MOV v31.16b, v17.16b
+
+        # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+        SUBS x0, x2, 32          // k = kc - 32
+        B.LO 3f
+
+        # 16 prologue
+        # Read first block of 4 A and B.
+        LDR q0, [x3], 16
+        LDP q20, q21, [x5], 32
+        LDR q1, [x11], 16
+        LDR q2, [x12], 16
+        LDR q3, [x4], 16
+        LDP q22, q23, [x5], 32
+        LDP q24, q25, [x5], 32
+        LDP q26, q27, [x5], 32
+
+        # Is there at least 32.  yes do main loop
+        SUBS x0, x0, 32
+        B.LO 2f
+
+        # Main loop - 8 floats of A (32 bytes)
+        # Software-pipelined: FMAs for one half-block are interleaved with
+        # the loads feeding the next half-block.
+1:
+        # First block of 4.  FMA for first 4, loads for 2nd block of 4.
+        FMLA v16.4s, v20.4s, v0.s[0]
+        LDP q8, q9, [x5], 32
+        FMLA v17.4s, v21.4s, v0.s[0]
+        FMLA v18.4s, v20.4s, v1.s[0]
+        FMLA v19.4s, v21.4s, v1.s[0]
+        FMLA v28.4s, v20.4s, v2.s[0]
+        LDP q10, q11, [x5], 32
+        FMLA v29.4s, v21.4s, v2.s[0]
+        FMLA v30.4s, v20.4s, v3.s[0]
+        FMLA v31.4s, v21.4s, v3.s[0]
+        FMLA v16.4s, v22.4s, v0.s[1]
+        LDP q12, q13, [x5], 32
+        FMLA v17.4s, v23.4s, v0.s[1]
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        FMLA v28.4s, v22.4s, v2.s[1]
+        LDP q14, q15, [x5], 32
+        FMLA v29.4s, v23.4s, v2.s[1]
+        FMLA v30.4s, v22.4s, v3.s[1]
+        FMLA v31.4s, v23.4s, v3.s[1]
+        FMLA v16.4s, v24.4s, v0.s[2]
+        LDR q4, [x3], 16
+        FMLA v17.4s, v25.4s, v0.s[2]
+        FMLA v18.4s, v24.4s, v1.s[2]
+        LDR q5, [x11], 16
+        FMLA v19.4s, v25.4s, v1.s[2]
+        FMLA v28.4s, v24.4s, v2.s[2]
+        LDR q6, [x12], 16
+        FMLA v29.4s, v25.4s, v2.s[2]
+        FMLA v30.4s, v24.4s, v3.s[2]
+        LDR q7, [x4], 16
+        FMLA v31.4s, v25.4s, v3.s[2]
+        FMLA v16.4s, v26.4s, v0.s[3]
+        FMLA v17.4s, v27.4s, v0.s[3]
+        FMLA v18.4s, v26.4s, v1.s[3]
+        FMLA v19.4s, v27.4s, v1.s[3]
+        FMLA v28.4s, v26.4s, v2.s[3]
+        FMLA v29.4s, v27.4s, v2.s[3]
+        FMLA v30.4s, v26.4s, v3.s[3]
+        FMLA v31.4s, v27.4s, v3.s[3]
+
+        # Second block of 4.  FMA for second 4, loads for 1nd block of 4.
+        FMLA v16.4s, v8.4s, v4.s[0]
+        LDP q20, q21, [x5], 32
+        FMLA v17.4s, v9.4s, v4.s[0]
+        FMLA v18.4s, v8.4s, v5.s[0]
+        FMLA v19.4s, v9.4s, v5.s[0]
+        FMLA v28.4s, v8.4s, v6.s[0]
+        LDP q22, q23, [x5], 32
+        FMLA v29.4s, v9.4s, v6.s[0]
+        FMLA v30.4s, v8.4s, v7.s[0]
+        FMLA v31.4s, v9.4s, v7.s[0]
+        FMLA v16.4s, v10.4s, v4.s[1]
+        LDP q24, q25, [x5], 32
+        FMLA v17.4s, v11.4s, v4.s[1]
+        FMLA v18.4s, v10.4s, v5.s[1]
+        FMLA v19.4s, v11.4s, v5.s[1]
+        FMLA v28.4s, v10.4s, v6.s[1]
+        LDP q26, q27, [x5], 32
+        FMLA v29.4s, v11.4s, v6.s[1]
+        FMLA v30.4s, v10.4s, v7.s[1]
+        FMLA v31.4s, v11.4s, v7.s[1]
+        FMLA v16.4s, v12.4s, v4.s[2]
+        LDR q0, [x3], 16
+        FMLA v17.4s, v13.4s, v4.s[2]
+        FMLA v18.4s, v12.4s, v5.s[2]
+        LDR q1, [x11], 16
+        FMLA v19.4s, v13.4s, v5.s[2]
+        FMLA v28.4s, v12.4s, v6.s[2]
+        LDR q2, [x12], 16
+        FMLA v29.4s, v13.4s, v6.s[2]
+        FMLA v30.4s, v12.4s, v7.s[2]
+        LDR q3, [x4], 16
+        FMLA v31.4s, v13.4s, v7.s[2]
+        FMLA v16.4s, v14.4s, v4.s[3]
+        FMLA v17.4s, v15.4s, v4.s[3]
+        FMLA v18.4s, v14.4s, v5.s[3]
+        FMLA v19.4s, v15.4s, v5.s[3]
+        FMLA v28.4s, v14.4s, v6.s[3]
+        FMLA v29.4s, v15.4s, v6.s[3]
+        SUBS x0, x0, 32
+        FMLA v30.4s, v14.4s, v7.s[3]
+        FMLA v31.4s, v15.4s, v7.s[3]
+        B.HS 1b
+
+2:
+        # Epilogue
+        # First block of 4.  FMA for first 4, loads for 2nd block of 4.
+        FMLA v16.4s, v20.4s, v0.s[0]
+        LDP q8, q9, [x5], 32
+        FMLA v17.4s, v21.4s, v0.s[0]
+        FMLA v18.4s, v20.4s, v1.s[0]
+        FMLA v19.4s, v21.4s, v1.s[0]
+        FMLA v28.4s, v20.4s, v2.s[0]
+        LDP q10, q11, [x5], 32
+        FMLA v29.4s, v21.4s, v2.s[0]
+        FMLA v30.4s, v20.4s, v3.s[0]
+        FMLA v31.4s, v21.4s, v3.s[0]
+        FMLA v16.4s, v22.4s, v0.s[1]
+        LDP q12, q13, [x5], 32
+        FMLA v17.4s, v23.4s, v0.s[1]
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        FMLA v28.4s, v22.4s, v2.s[1]
+        LDP q14, q15, [x5], 32
+        FMLA v29.4s, v23.4s, v2.s[1]
+        FMLA v30.4s, v22.4s, v3.s[1]
+        FMLA v31.4s, v23.4s, v3.s[1]
+        FMLA v16.4s, v24.4s, v0.s[2]
+        LDR q4, [x3], 16
+        FMLA v17.4s, v25.4s, v0.s[2]
+        FMLA v18.4s, v24.4s, v1.s[2]
+        LDR q5, [x11], 16
+        FMLA v19.4s, v25.4s, v1.s[2]
+        FMLA v28.4s, v24.4s, v2.s[2]
+        LDR q6, [x12], 16
+        FMLA v29.4s, v25.4s, v2.s[2]
+        FMLA v30.4s, v24.4s, v3.s[2]
+        LDR q7, [x4], 16
+        FMLA v31.4s, v25.4s, v3.s[2]
+        FMLA v16.4s, v26.4s, v0.s[3]
+        FMLA v17.4s, v27.4s, v0.s[3]
+        FMLA v18.4s, v26.4s, v1.s[3]
+        FMLA v19.4s, v27.4s, v1.s[3]
+        FMLA v28.4s, v26.4s, v2.s[3]
+        FMLA v29.4s, v27.4s, v2.s[3]
+        FMLA v30.4s, v26.4s, v3.s[3]
+        FMLA v31.4s, v27.4s, v3.s[3]
+
+        # Second block of 4.  FMA for second 4, noloads
+        FMLA v16.4s, v8.4s, v4.s[0]
+        FMLA v17.4s, v9.4s, v4.s[0]
+        FMLA v18.4s, v8.4s, v5.s[0]
+        FMLA v19.4s, v9.4s, v5.s[0]
+        FMLA v28.4s, v8.4s, v6.s[0]
+        FMLA v29.4s, v9.4s, v6.s[0]
+        FMLA v30.4s, v8.4s, v7.s[0]
+        FMLA v31.4s, v9.4s, v7.s[0]
+
+        FMLA v16.4s, v10.4s, v4.s[1]
+        FMLA v17.4s, v11.4s, v4.s[1]
+        FMLA v18.4s, v10.4s, v5.s[1]
+        FMLA v19.4s, v11.4s, v5.s[1]
+        FMLA v28.4s, v10.4s, v6.s[1]
+        FMLA v29.4s, v11.4s, v6.s[1]
+        FMLA v30.4s, v10.4s, v7.s[1]
+        FMLA v31.4s, v11.4s, v7.s[1]
+
+        FMLA v16.4s, v12.4s, v4.s[2]
+        FMLA v17.4s, v13.4s, v4.s[2]
+        FMLA v18.4s, v12.4s, v5.s[2]
+        FMLA v19.4s, v13.4s, v5.s[2]
+        FMLA v28.4s, v12.4s, v6.s[2]
+        FMLA v29.4s, v13.4s, v6.s[2]
+        FMLA v30.4s, v12.4s, v7.s[2]
+        FMLA v31.4s, v13.4s, v7.s[2]
+
+        # Reload the clamp bounds into v4/v5 here: they were clobbered as A
+        # data (q4..q7) by the pipelined loop above.
+        FMLA v16.4s, v14.4s, v4.s[3]
+        FMLA v17.4s, v15.4s, v4.s[3]
+        LD1R {v4.4s}, [x8], 4
+        FMLA v18.4s, v14.4s, v5.s[3]
+        FMLA v19.4s, v15.4s, v5.s[3]
+        LD1R {v5.4s}, [x8]
+        FMLA v28.4s, v14.4s, v6.s[3]
+        FMLA v29.4s, v15.4s, v6.s[3]
+        SUB x8, x8, 4
+        FMLA v30.4s, v14.4s, v7.s[3]
+        FMLA v31.4s, v15.4s, v7.s[3]
+
+3:
+        # Remainder- 4 floats of A (16 bytes)
+        # The low 5 bits of x0 (= kc minus a multiple of 32) equal kc mod 32,
+        # so TBZ on bits 4/3/2 selects the 4-, 2- and 1-float remainders.
+        TBZ x0, 4, 4f
+
+        LDR q0, [x3], 16
+        LDP q20, q21, [x5], 32
+        LDR q1, [x11], 16
+        LDR q2, [x12], 16
+        LDR q3, [x4], 16
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q22, q23, [x5], 32
+        FMLA v18.4s, v20.4s, v1.s[0]
+        FMLA v19.4s, v21.4s, v1.s[0]
+        LDP q24, q25, [x5], 32
+        FMLA v28.4s, v20.4s, v2.s[0]
+        FMLA v29.4s, v21.4s, v2.s[0]
+        LDP q26, q27, [x5], 32
+        FMLA v30.4s, v20.4s, v3.s[0]
+        FMLA v31.4s, v21.4s, v3.s[0]
+        FMLA v16.4s, v22.4s, v0.s[1]
+        FMLA v17.4s, v23.4s, v0.s[1]
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        FMLA v28.4s, v22.4s, v2.s[1]
+        FMLA v29.4s, v23.4s, v2.s[1]
+        FMLA v30.4s, v22.4s, v3.s[1]
+        FMLA v31.4s, v23.4s, v3.s[1]
+        FMLA v16.4s, v24.4s, v0.s[2]
+        FMLA v17.4s, v25.4s, v0.s[2]
+        FMLA v18.4s, v24.4s, v1.s[2]
+        FMLA v19.4s, v25.4s, v1.s[2]
+        FMLA v28.4s, v24.4s, v2.s[2]
+        FMLA v29.4s, v25.4s, v2.s[2]
+        FMLA v30.4s, v24.4s, v3.s[2]
+        FMLA v31.4s, v25.4s, v3.s[2]
+        FMLA v16.4s, v26.4s, v0.s[3]
+        FMLA v17.4s, v27.4s, v0.s[3]
+        FMLA v18.4s, v26.4s, v1.s[3]
+        FMLA v19.4s, v27.4s, v1.s[3]
+        FMLA v28.4s, v26.4s, v2.s[3]
+        FMLA v29.4s, v27.4s, v2.s[3]
+        FMLA v30.4s, v26.4s, v3.s[3]
+        FMLA v31.4s, v27.4s, v3.s[3]
+
+4:
+        # Remainder- 2 floats of A (8 bytes)
+        TBZ x0, 3, 5f
+
+        LDR d0, [x3], 8
+        LDP q20, q21, [x5], 32
+        LDR d1, [x11], 8
+        LDR d2, [x12], 8
+        LDR d3, [x4], 8
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q22, q23, [x5], 32
+        FMLA v18.4s, v20.4s, v1.s[0]
+        FMLA v19.4s, v21.4s, v1.s[0]
+        FMLA v28.4s, v20.4s, v2.s[0]
+        FMLA v29.4s, v21.4s, v2.s[0]
+        FMLA v30.4s, v20.4s, v3.s[0]
+        FMLA v31.4s, v21.4s, v3.s[0]
+        FMLA v16.4s, v22.4s, v0.s[1]
+        FMLA v17.4s, v23.4s, v0.s[1]
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        FMLA v28.4s, v22.4s, v2.s[1]
+        FMLA v29.4s, v23.4s, v2.s[1]
+        FMLA v30.4s, v22.4s, v3.s[1]
+        FMLA v31.4s, v23.4s, v3.s[1]
+
+5:
+        # Remainder- 1 float of A (4 bytes)
+        TBZ x0, 2, 6f
+
+        LDR s0, [x3], 4
+        LDP q20, q21, [x5], 32
+        LDR s1, [x11], 4
+        LDR s2, [x12], 4
+        LDR s3, [x4], 4
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        FMLA v18.4s, v20.4s, v1.s[0]
+        FMLA v19.4s, v21.4s, v1.s[0]
+        FMLA v28.4s, v20.4s, v2.s[0]
+        FMLA v29.4s, v21.4s, v2.s[0]
+        FMLA v30.4s, v20.4s, v3.s[0]
+        FMLA v31.4s, v21.4s, v3.s[0]
+
+6:
+        # Clamp: FMIN against the upper bound (v4), FMAX against the lower
+        # bound (v5).
+        FMIN v16.4s, v16.4s, v4.4s
+        FMIN v17.4s, v17.4s, v4.4s
+        FMIN v18.4s, v18.4s, v4.4s
+        FMIN v19.4s, v19.4s, v4.4s
+        FMIN v28.4s, v28.4s, v4.4s
+        FMIN v29.4s, v29.4s, v4.4s
+        FMIN v30.4s, v30.4s, v4.4s
+        FMIN v31.4s, v31.4s, v4.4s
+        FMAX v16.4s, v16.4s, v5.4s
+        FMAX v17.4s, v17.4s, v5.4s
+        FMAX v18.4s, v18.4s, v5.4s
+        FMAX v19.4s, v19.4s, v5.4s
+        FMAX v28.4s, v28.4s, v5.4s
+        FMAX v29.4s, v29.4s, v5.4s
+        FMAX v30.4s, v30.4s, v5.4s
+        FMAX v31.4s, v31.4s, v5.4s
+
+        # Store full 4 x 8
+        CMP x1, 8
+        B.LO 7f
+
+        STP q30, q31, [x7]
+        ADD x7, x7, x14
+        STP q28, q29, [x10]
+        ADD x10, x10, x14
+        STP q18, q19, [x9]
+        ADD x9, x9, x14
+        STP q16, q17, [x6]
+        ADD x6, x6, x14
+
+        SUB x3, x3, x2           // a0 -= kc
+        SUB x11, x11, x2         // a1 -= kc
+        SUB x12, x12, x2         // a2 -= kc
+        SUB x4, x4, x2           // a3 -= kc
+
+        SUBS x1, x1, 8
+        B.HI 0b
+
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 48]
+        LDP d12, d13, [sp, 32]
+        LDP d10, d11, [sp, 16]
+        LDP d8, d9, [sp], 64
+        RET
+
+        # Store odd width: write 4, 2, then 1 column(s) per the bits of nc,
+        # shifting the surviving high lanes of each row down after each
+        # partial store.
+7:
+        TBZ x1, 2, 8f
+        STR q30, [x7], 16
+        MOV v30.16b, v31.16b
+        STR q28, [x10], 16
+        MOV v28.16b, v29.16b
+        STR q18, [x9], 16
+        MOV v18.16b, v19.16b
+        STR q16, [x6], 16
+        MOV v16.16b, v17.16b
+
+8:
+        TBZ x1, 1, 9f
+        STR d30, [x7], 8
+        DUP d30, v30.d[1]
+        STR d28, [x10], 8
+        DUP d28, v28.d[1]
+        STR d18, [x9], 8
+        DUP d18, v18.d[1]
+        STR d16, [x6], 8
+        DUP d16, v16.d[1]
+
+9:
+        TBZ x1, 0, 10f
+        STR s30, [x7]
+        STR s28, [x10]
+        STR s18, [x9]
+        STR s16, [x6]
+10:
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 48]
+        LDP d12, d13, [sp, 32]
+        LDP d10, d11, [sp, 16]
+        LDP d8, d9, [sp], 64
+        RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a57
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/4x8-aarch64-neonfma-cortex-a57.S.in b/src/f32-gemm/4x8-aarch64-neonfma-cortex-a57.S.in
new file mode 100644
index 0000000..ace480d
--- /dev/null
+++ b/src/f32-gemm/4x8-aarch64-neonfma-cortex-a57.S.in
@@ -0,0 +1,481 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_cortex_a57(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+# Vector register usage
+# A0 v0 v4
+# A1 v1 v5
+# A2 v2 v6
+# A3 v3 v7
+# B v8 v9 v10 v11
+# B v12 v13 v14 v15
+# B v20 v21 v22 v23
+# B v24 v25 v26 v27
+# C v16 v17
+# C v18 v19
+# C v28 v29
+# C v30 v31
+# Clamp v4 v5
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_cortex_a57
+
+        # Template note: the INC variant accumulates into existing output.
+        # It takes an extra `acc` argument (x15) and seeds the accumulators
+        # from it instead of from the bias in w.
+        $if INC:
+          # Load cn_stride, acc
+          LDP x14, x15, [sp]
+          # Load params pointer
+          LDR x8, [sp, 16]
+        $else:
+          # Load cn_stride, params pointer
+          LDP x14, x8, [sp]
+
+        # Load params values: v4 = upper clamp bound (used with FMIN below),
+        # v5 = lower clamp bound (used with FMAX). x8 is rewound to the start
+        # of params so the epilogue can reload them after v4/v5 are reused as
+        # A data in the main loop.
+        LD1R {v4.4s}, [x8], 4
+        LD1R {v5.4s}, [x8]
+        SUB x8, x8, 4
+
+        # Save d8-d15 on stack
+        STP d8, d9, [sp, -64]!
+        STP d10, d11, [sp, 16]
+        STP d12, d13, [sp, 32]
+        STP d14, d15, [sp, 48]
+
+        # Clamp A and C pointers: rows beyond mr alias the previous row, so
+        # loads stay in bounds and the duplicate stores are overwritten by
+        # the real row, which is stored later.
+        ADD x11, x3, x4          // a1 = a0 + a_stride
+        ADD x9, x6, x7           // c1 = c0 + cm_stride
+        CMP x0, 2                // if mr < 2
+        CSEL x11, x3, x11, LO    //   a1 = a0
+        CSEL x9, x6, x9, LO      //   c1 = c0
+
+        ADD x12, x11, x4         // a2 = a1 + a_stride
+        ADD x10, x9, x7          // c2 = c1 + cm_stride
+                                 // if mr <= 2
+        CSEL x12, x11, x12, LS   //   a2 = a1
+        CSEL x10, x9, x10, LS    //   c2 = c1
+
+        ADD x4, x12, x4          // a3 = a2 + a_stride
+        ADD x7, x10, x7          // c3 = c2 + cm_stride
+        CMP x0, 4                // if mr < 4
+        CSEL x4, x12, x4, LO     //   a3 = a2
+        CSEL x7, x10, x7, LO     //   c3 = c2
+
+        # Outer loop: one iteration per 8-column tile of the output (nc).
+0:
+        $if INC:
+          # Load initial accumulators
+          LDP q16, q17, [x15], 32
+          LDP q18, q19, [x15], 32
+          LDP q28, q29, [x15], 32
+          LDP q30, q31, [x15], 32
+        $else:
+          # Load initial bias from w into accumulators
+          LDP q16, q17, [x5], 32
+          MOV v18.16b, v16.16b
+          MOV v19.16b, v17.16b
+          MOV v28.16b, v16.16b
+          MOV v29.16b, v17.16b
+          MOV v30.16b, v16.16b
+          MOV v31.16b, v17.16b
+
+        # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+        SUBS x0, x2, 32          // k = kc - 32
+        B.LO 3f
+
+        # 16 prologue
+        # Read first block of 4 A and B.
+        LDR q0, [x3], 16
+        LDP q20, q21, [x5], 32
+        LDR q1, [x11], 16
+        LDR q2, [x12], 16
+        LDR q3, [x4], 16
+        LDP q22, q23, [x5], 32
+        LDP q24, q25, [x5], 32
+        LDP q26, q27, [x5], 32
+
+        # Is there at least 32.  yes do main loop
+        SUBS x0, x0, 32
+        B.LO 2f
+
+        # Main loop - 8 floats of A (32 bytes)
+        # Software-pipelined: FMAs for one half-block are interleaved with
+        # the loads feeding the next half-block.
+1:
+        # First block of 4.  FMA for first 4, loads for 2nd block of 4.
+        FMLA v16.4s, v20.4s, v0.s[0]
+        LDP q8, q9, [x5], 32
+        FMLA v17.4s, v21.4s, v0.s[0]
+        FMLA v18.4s, v20.4s, v1.s[0]
+        FMLA v19.4s, v21.4s, v1.s[0]
+        FMLA v28.4s, v20.4s, v2.s[0]
+        LDP q10, q11, [x5], 32
+        FMLA v29.4s, v21.4s, v2.s[0]
+        FMLA v30.4s, v20.4s, v3.s[0]
+        FMLA v31.4s, v21.4s, v3.s[0]
+        FMLA v16.4s, v22.4s, v0.s[1]
+        LDP q12, q13, [x5], 32
+        FMLA v17.4s, v23.4s, v0.s[1]
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        FMLA v28.4s, v22.4s, v2.s[1]
+        LDP q14, q15, [x5], 32
+        FMLA v29.4s, v23.4s, v2.s[1]
+        FMLA v30.4s, v22.4s, v3.s[1]
+        FMLA v31.4s, v23.4s, v3.s[1]
+        FMLA v16.4s, v24.4s, v0.s[2]
+        LDR q4, [x3], 16
+        FMLA v17.4s, v25.4s, v0.s[2]
+        FMLA v18.4s, v24.4s, v1.s[2]
+        LDR q5, [x11], 16
+        FMLA v19.4s, v25.4s, v1.s[2]
+        FMLA v28.4s, v24.4s, v2.s[2]
+        LDR q6, [x12], 16
+        FMLA v29.4s, v25.4s, v2.s[2]
+        FMLA v30.4s, v24.4s, v3.s[2]
+        LDR q7, [x4], 16
+        FMLA v31.4s, v25.4s, v3.s[2]
+        FMLA v16.4s, v26.4s, v0.s[3]
+        FMLA v17.4s, v27.4s, v0.s[3]
+        FMLA v18.4s, v26.4s, v1.s[3]
+        FMLA v19.4s, v27.4s, v1.s[3]
+        FMLA v28.4s, v26.4s, v2.s[3]
+        FMLA v29.4s, v27.4s, v2.s[3]
+        FMLA v30.4s, v26.4s, v3.s[3]
+        FMLA v31.4s, v27.4s, v3.s[3]
+
+        # Second block of 4.  FMA for second 4, loads for 1nd block of 4.
+        FMLA v16.4s, v8.4s, v4.s[0]
+        LDP q20, q21, [x5], 32
+        FMLA v17.4s, v9.4s, v4.s[0]
+        FMLA v18.4s, v8.4s, v5.s[0]
+        FMLA v19.4s, v9.4s, v5.s[0]
+        FMLA v28.4s, v8.4s, v6.s[0]
+        LDP q22, q23, [x5], 32
+        FMLA v29.4s, v9.4s, v6.s[0]
+        FMLA v30.4s, v8.4s, v7.s[0]
+        FMLA v31.4s, v9.4s, v7.s[0]
+        FMLA v16.4s, v10.4s, v4.s[1]
+        LDP q24, q25, [x5], 32
+        FMLA v17.4s, v11.4s, v4.s[1]
+        FMLA v18.4s, v10.4s, v5.s[1]
+        FMLA v19.4s, v11.4s, v5.s[1]
+        FMLA v28.4s, v10.4s, v6.s[1]
+        LDP q26, q27, [x5], 32
+        FMLA v29.4s, v11.4s, v6.s[1]
+        FMLA v30.4s, v10.4s, v7.s[1]
+        FMLA v31.4s, v11.4s, v7.s[1]
+        FMLA v16.4s, v12.4s, v4.s[2]
+        LDR q0, [x3], 16
+        FMLA v17.4s, v13.4s, v4.s[2]
+        FMLA v18.4s, v12.4s, v5.s[2]
+        LDR q1, [x11], 16
+        FMLA v19.4s, v13.4s, v5.s[2]
+        FMLA v28.4s, v12.4s, v6.s[2]
+        LDR q2, [x12], 16
+        FMLA v29.4s, v13.4s, v6.s[2]
+        FMLA v30.4s, v12.4s, v7.s[2]
+        LDR q3, [x4], 16
+        FMLA v31.4s, v13.4s, v7.s[2]
+        FMLA v16.4s, v14.4s, v4.s[3]
+        FMLA v17.4s, v15.4s, v4.s[3]
+        FMLA v18.4s, v14.4s, v5.s[3]
+        FMLA v19.4s, v15.4s, v5.s[3]
+        FMLA v28.4s, v14.4s, v6.s[3]
+        FMLA v29.4s, v15.4s, v6.s[3]
+        SUBS x0, x0, 32
+        FMLA v30.4s, v14.4s, v7.s[3]
+        FMLA v31.4s, v15.4s, v7.s[3]
+        B.HS 1b
+
+2:
+        # Epilogue
+        # First block of 4.  FMA for first 4, loads for 2nd block of 4.
+        FMLA v16.4s, v20.4s, v0.s[0]
+        LDP q8, q9, [x5], 32
+        FMLA v17.4s, v21.4s, v0.s[0]
+        FMLA v18.4s, v20.4s, v1.s[0]
+        FMLA v19.4s, v21.4s, v1.s[0]
+        FMLA v28.4s, v20.4s, v2.s[0]
+        LDP q10, q11, [x5], 32
+        FMLA v29.4s, v21.4s, v2.s[0]
+        FMLA v30.4s, v20.4s, v3.s[0]
+        FMLA v31.4s, v21.4s, v3.s[0]
+        FMLA v16.4s, v22.4s, v0.s[1]
+        LDP q12, q13, [x5], 32
+        FMLA v17.4s, v23.4s, v0.s[1]
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        FMLA v28.4s, v22.4s, v2.s[1]
+        LDP q14, q15, [x5], 32
+        FMLA v29.4s, v23.4s, v2.s[1]
+        FMLA v30.4s, v22.4s, v3.s[1]
+        FMLA v31.4s, v23.4s, v3.s[1]
+        FMLA v16.4s, v24.4s, v0.s[2]
+        LDR q4, [x3], 16
+        FMLA v17.4s, v25.4s, v0.s[2]
+        FMLA v18.4s, v24.4s, v1.s[2]
+        LDR q5, [x11], 16
+        FMLA v19.4s, v25.4s, v1.s[2]
+        FMLA v28.4s, v24.4s, v2.s[2]
+        LDR q6, [x12], 16
+        FMLA v29.4s, v25.4s, v2.s[2]
+        FMLA v30.4s, v24.4s, v3.s[2]
+        LDR q7, [x4], 16
+        FMLA v31.4s, v25.4s, v3.s[2]
+        FMLA v16.4s, v26.4s, v0.s[3]
+        FMLA v17.4s, v27.4s, v0.s[3]
+        FMLA v18.4s, v26.4s, v1.s[3]
+        FMLA v19.4s, v27.4s, v1.s[3]
+        FMLA v28.4s, v26.4s, v2.s[3]
+        FMLA v29.4s, v27.4s, v2.s[3]
+        FMLA v30.4s, v26.4s, v3.s[3]
+        FMLA v31.4s, v27.4s, v3.s[3]
+
+        # Second block of 4.  FMA for second 4, noloads
+        FMLA v16.4s, v8.4s, v4.s[0]
+        FMLA v17.4s, v9.4s, v4.s[0]
+        FMLA v18.4s, v8.4s, v5.s[0]
+        FMLA v19.4s, v9.4s, v5.s[0]
+        FMLA v28.4s, v8.4s, v6.s[0]
+        FMLA v29.4s, v9.4s, v6.s[0]
+        FMLA v30.4s, v8.4s, v7.s[0]
+        FMLA v31.4s, v9.4s, v7.s[0]
+
+        FMLA v16.4s, v10.4s, v4.s[1]
+        FMLA v17.4s, v11.4s, v4.s[1]
+        FMLA v18.4s, v10.4s, v5.s[1]
+        FMLA v19.4s, v11.4s, v5.s[1]
+        FMLA v28.4s, v10.4s, v6.s[1]
+        FMLA v29.4s, v11.4s, v6.s[1]
+        FMLA v30.4s, v10.4s, v7.s[1]
+        FMLA v31.4s, v11.4s, v7.s[1]
+
+        FMLA v16.4s, v12.4s, v4.s[2]
+        FMLA v17.4s, v13.4s, v4.s[2]
+        FMLA v18.4s, v12.4s, v5.s[2]
+        FMLA v19.4s, v13.4s, v5.s[2]
+        FMLA v28.4s, v12.4s, v6.s[2]
+        FMLA v29.4s, v13.4s, v6.s[2]
+        FMLA v30.4s, v12.4s, v7.s[2]
+        FMLA v31.4s, v13.4s, v7.s[2]
+
+        # Reload the clamp bounds into v4/v5 here: they were clobbered as A
+        # data (q4..q7) by the pipelined loop above.
+        FMLA v16.4s, v14.4s, v4.s[3]
+        FMLA v17.4s, v15.4s, v4.s[3]
+        LD1R {v4.4s}, [x8], 4
+        FMLA v18.4s, v14.4s, v5.s[3]
+        FMLA v19.4s, v15.4s, v5.s[3]
+        LD1R {v5.4s}, [x8]
+        FMLA v28.4s, v14.4s, v6.s[3]
+        FMLA v29.4s, v15.4s, v6.s[3]
+        SUB x8, x8, 4
+        FMLA v30.4s, v14.4s, v7.s[3]
+        FMLA v31.4s, v15.4s, v7.s[3]
+
+3:
+        # Remainder- 4 floats of A (16 bytes)
+        # The low 5 bits of x0 (= kc minus a multiple of 32) equal kc mod 32,
+        # so TBZ on bits 4/3/2 selects the 4-, 2- and 1-float remainders.
+        TBZ x0, 4, 4f
+
+        LDR q0, [x3], 16
+        LDP q20, q21, [x5], 32
+        LDR q1, [x11], 16
+        LDR q2, [x12], 16
+        LDR q3, [x4], 16
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q22, q23, [x5], 32
+        FMLA v18.4s, v20.4s, v1.s[0]
+        FMLA v19.4s, v21.4s, v1.s[0]
+        LDP q24, q25, [x5], 32
+        FMLA v28.4s, v20.4s, v2.s[0]
+        FMLA v29.4s, v21.4s, v2.s[0]
+        LDP q26, q27, [x5], 32
+        FMLA v30.4s, v20.4s, v3.s[0]
+        FMLA v31.4s, v21.4s, v3.s[0]
+        FMLA v16.4s, v22.4s, v0.s[1]
+        FMLA v17.4s, v23.4s, v0.s[1]
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        FMLA v28.4s, v22.4s, v2.s[1]
+        FMLA v29.4s, v23.4s, v2.s[1]
+        FMLA v30.4s, v22.4s, v3.s[1]
+        FMLA v31.4s, v23.4s, v3.s[1]
+        FMLA v16.4s, v24.4s, v0.s[2]
+        FMLA v17.4s, v25.4s, v0.s[2]
+        FMLA v18.4s, v24.4s, v1.s[2]
+        FMLA v19.4s, v25.4s, v1.s[2]
+        FMLA v28.4s, v24.4s, v2.s[2]
+        FMLA v29.4s, v25.4s, v2.s[2]
+        FMLA v30.4s, v24.4s, v3.s[2]
+        FMLA v31.4s, v25.4s, v3.s[2]
+        FMLA v16.4s, v26.4s, v0.s[3]
+        FMLA v17.4s, v27.4s, v0.s[3]
+        FMLA v18.4s, v26.4s, v1.s[3]
+        FMLA v19.4s, v27.4s, v1.s[3]
+        FMLA v28.4s, v26.4s, v2.s[3]
+        FMLA v29.4s, v27.4s, v2.s[3]
+        FMLA v30.4s, v26.4s, v3.s[3]
+        FMLA v31.4s, v27.4s, v3.s[3]
+
+4:
+        # Remainder- 2 floats of A (8 bytes)
+        TBZ x0, 3, 5f
+
+        LDR d0, [x3], 8
+        LDP q20, q21, [x5], 32
+        LDR d1, [x11], 8
+        LDR d2, [x12], 8
+        LDR d3, [x4], 8
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q22, q23, [x5], 32
+        FMLA v18.4s, v20.4s, v1.s[0]
+        FMLA v19.4s, v21.4s, v1.s[0]
+        FMLA v28.4s, v20.4s, v2.s[0]
+        FMLA v29.4s, v21.4s, v2.s[0]
+        FMLA v30.4s, v20.4s, v3.s[0]
+        FMLA v31.4s, v21.4s, v3.s[0]
+        FMLA v16.4s, v22.4s, v0.s[1]
+        FMLA v17.4s, v23.4s, v0.s[1]
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        FMLA v28.4s, v22.4s, v2.s[1]
+        FMLA v29.4s, v23.4s, v2.s[1]
+        FMLA v30.4s, v22.4s, v3.s[1]
+        FMLA v31.4s, v23.4s, v3.s[1]
+
+5:
+        # Remainder- 1 float of A (4 bytes)
+        TBZ x0, 2, 6f
+
+        LDR s0, [x3], 4
+        LDP q20, q21, [x5], 32
+        LDR s1, [x11], 4
+        LDR s2, [x12], 4
+        LDR s3, [x4], 4
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        FMLA v18.4s, v20.4s, v1.s[0]
+        FMLA v19.4s, v21.4s, v1.s[0]
+        FMLA v28.4s, v20.4s, v2.s[0]
+        FMLA v29.4s, v21.4s, v2.s[0]
+        FMLA v30.4s, v20.4s, v3.s[0]
+        FMLA v31.4s, v21.4s, v3.s[0]
+
+6:
+        # Clamp: FMIN against the upper bound (v4), FMAX against the lower
+        # bound (v5).
+        FMIN v16.4s, v16.4s, v4.4s
+        FMIN v17.4s, v17.4s, v4.4s
+        FMIN v18.4s, v18.4s, v4.4s
+        FMIN v19.4s, v19.4s, v4.4s
+        FMIN v28.4s, v28.4s, v4.4s
+        FMIN v29.4s, v29.4s, v4.4s
+        FMIN v30.4s, v30.4s, v4.4s
+        FMIN v31.4s, v31.4s, v4.4s
+        FMAX v16.4s, v16.4s, v5.4s
+        FMAX v17.4s, v17.4s, v5.4s
+        FMAX v18.4s, v18.4s, v5.4s
+        FMAX v19.4s, v19.4s, v5.4s
+        FMAX v28.4s, v28.4s, v5.4s
+        FMAX v29.4s, v29.4s, v5.4s
+        FMAX v30.4s, v30.4s, v5.4s
+        FMAX v31.4s, v31.4s, v5.4s
+
+        # Store full 4 x 8
+        CMP x1, 8
+        B.LO 7f
+
+        STP q30, q31, [x7]
+        ADD x7, x7, x14
+        STP q28, q29, [x10]
+        ADD x10, x10, x14
+        STP q18, q19, [x9]
+        ADD x9, x9, x14
+        STP q16, q17, [x6]
+        ADD x6, x6, x14
+
+        SUB x3, x3, x2           // a0 -= kc
+        SUB x11, x11, x2         // a1 -= kc
+        SUB x12, x12, x2         // a2 -= kc
+        SUB x4, x4, x2           // a3 -= kc
+
+        SUBS x1, x1, 8
+        B.HI 0b
+
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 48]
+        LDP d12, d13, [sp, 32]
+        LDP d10, d11, [sp, 16]
+        LDP d8, d9, [sp], 64
+        RET
+
+        # Store odd width: write 4, 2, then 1 column(s) per the bits of nc,
+        # shifting the surviving high lanes of each row down after each
+        # partial store.
+7:
+        TBZ x1, 2, 8f
+        STR q30, [x7], 16
+        MOV v30.16b, v31.16b
+        STR q28, [x10], 16
+        MOV v28.16b, v29.16b
+        STR q18, [x9], 16
+        MOV v18.16b, v19.16b
+        STR q16, [x6], 16
+        MOV v16.16b, v17.16b
+
+8:
+        TBZ x1, 1, 9f
+        STR d30, [x7], 8
+        DUP d30, v30.d[1]
+        STR d28, [x10], 8
+        DUP d28, v28.d[1]
+        STR d18, [x9], 8
+        DUP d18, v18.d[1]
+        STR d16, [x6], 8
+        DUP d16, v16.d[1]
+
+9:
+        TBZ x1, 0, 10f
+        STR s30, [x7]
+        STR s28, [x10]
+        STR s18, [x9]
+        STR s16, [x6]
+10:
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 48]
+        LDP d12, d13, [sp, 32]
+        LDP d10, d11, [sp, 16]
+        LDP d8, d9, [sp], 64
+        RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_cortex_a57
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/4x8-aarch64-neonfma-cortex-a75.S b/src/f32-gemm/4x8-aarch64-neonfma-cortex-a75.S
new file mode 100644
index 0000000..3b14672
--- /dev/null
+++ b/src/f32-gemm/4x8-aarch64-neonfma-cortex-a75.S
@@ -0,0 +1,471 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/4x8-aarch64-neonfma-cortex-a75.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a75(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+# Vector register usage
+# A0 v0 v4
+# A1 v1 v5
+# A2 v2 v6
+# A3 v3 v7
+# B v8 v9 v10 v11
+# B v12 v13 v14 v15
+# B v20 v21 v22 v23
+# B v24 v25 v26 v27
+# C v16 v17
+# C v18 v19
+# C v28 v29
+# C v30 v31
+# Clamp v4 v5
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a75
+
+ # Load cn_stride, params pointer
+ LDP x14, x8, [sp]
+
+ # Load clamping_params values
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ # Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ STP d10, d11, [sp, 16]
+ STP d12, d13, [sp, 32]
+ STP d14, d15, [sp, 48]
+
+ # Clamp A and C pointers
+ ADD x11, x3, x4 // a1 = a0 + a_stride
+ ADD x9, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x11, x3, x11, LO // a1 = a0
+ CSEL x9, x6, x9, LO // c1 = c0
+
+ ADD x12, x11, x4 // a2 = a1 + a_stride
+ ADD x10, x9, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x12, x11, x12, LS // a2 = a1
+ CSEL x10, x9, x10, LS // c2 = c1
+
+ ADD x4, x12, x4 // a3 = a2 + a_stride
+ ADD x7, x10, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x4, x12, x4, LO // a3 = a2
+ CSEL x7, x10, x7, LO // c3 = c2
+
+0:
+ # Load initial bias from w into accumulators
+ LDP q16, q17, [x5], 32
+ MOV v18.16b, v16.16b
+ MOV v19.16b, v17.16b
+ MOV v28.16b, v16.16b
+ MOV v29.16b, v17.16b
+ MOV v30.16b, v16.16b
+ MOV v31.16b, v17.16b
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 3f
+
+ # Prologue - 4 floats of A (16 bytes)
+ # Read first block of 4 A and B.
+ LDR q0, [x3], 16
+ LDP q20, q21, [x5], 32
+ LDR q1, [x11], 16
+ LDR q2, [x12], 16
+ LDR q3, [x4], 16
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+
+ # Is there at least 8 more floats (32 bytes)?  If so, enter main loop.
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+1:
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDP q8, q9, [x5], 32
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ LDP q10, q11, [x5], 32
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ LDP q12, q13, [x5], 32
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ LDR q4, [x3], 16
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ LDR q5, [x11], 16
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ LDR q6, [x12], 16
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ LDR q7, [x4], 16
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ PRFM PLDL1KEEP, [x5, 128]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ PRFM PLDL1KEEP, [x5, 192]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+ # Second block of 4. FMA for second 4, loads for 1st block of 4.
+ FMLA v16.4s, v8.4s, v4.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v17.4s, v9.4s, v4.s[0]
+ FMLA v18.4s, v8.4s, v5.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v19.4s, v9.4s, v5.s[0]
+ FMLA v28.4s, v8.4s, v6.s[0]
+ LDP q24, q25, [x5], 32
+ FMLA v29.4s, v9.4s, v6.s[0]
+ FMLA v30.4s, v8.4s, v7.s[0]
+ LDP q26, q27, [x5], 32
+ FMLA v31.4s, v9.4s, v7.s[0]
+ FMLA v16.4s, v10.4s, v4.s[1]
+ LDR q0, [x3], 16
+ FMLA v17.4s, v11.4s, v4.s[1]
+ FMLA v18.4s, v10.4s, v5.s[1]
+ LDR q1, [x11], 16
+ FMLA v19.4s, v11.4s, v5.s[1]
+ FMLA v28.4s, v10.4s, v6.s[1]
+ LDR q2, [x12], 16
+ FMLA v29.4s, v11.4s, v6.s[1]
+ FMLA v30.4s, v10.4s, v7.s[1]
+ LDR q3, [x4], 16
+ FMLA v31.4s, v11.4s, v7.s[1]
+ FMLA v16.4s, v12.4s, v4.s[2]
+ FMLA v17.4s, v13.4s, v4.s[2]
+ FMLA v18.4s, v12.4s, v5.s[2]
+ FMLA v19.4s, v13.4s, v5.s[2]
+ FMLA v28.4s, v12.4s, v6.s[2]
+ FMLA v29.4s, v13.4s, v6.s[2]
+ FMLA v30.4s, v12.4s, v7.s[2]
+ FMLA v31.4s, v13.4s, v7.s[2]
+ FMLA v16.4s, v14.4s, v4.s[3]
+ FMLA v17.4s, v15.4s, v4.s[3]
+ FMLA v18.4s, v14.4s, v5.s[3]
+ FMLA v19.4s, v15.4s, v5.s[3]
+ FMLA v28.4s, v14.4s, v6.s[3]
+ FMLA v29.4s, v15.4s, v6.s[3]
+ SUBS x0, x0, 32
+ FMLA v30.4s, v14.4s, v7.s[3]
+ FMLA v31.4s, v15.4s, v7.s[3]
+ B.HS 1b
+
+2:
+ # Epilogue
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDP q8, q9, [x5], 32
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ LDP q10, q11, [x5], 32
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ LDP q12, q13, [x5], 32
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ LDR q4, [x3], 16
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ LDR q5, [x11], 16
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ LDR q6, [x12], 16
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ LDR q7, [x4], 16
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+ # Second block of 4. FMA for second 4, no loads
+ FMLA v16.4s, v8.4s, v4.s[0]
+ FMLA v17.4s, v9.4s, v4.s[0]
+ FMLA v18.4s, v8.4s, v5.s[0]
+ FMLA v19.4s, v9.4s, v5.s[0]
+ FMLA v28.4s, v8.4s, v6.s[0]
+ FMLA v29.4s, v9.4s, v6.s[0]
+ FMLA v30.4s, v8.4s, v7.s[0]
+ FMLA v31.4s, v9.4s, v7.s[0]
+
+ FMLA v16.4s, v10.4s, v4.s[1]
+ FMLA v17.4s, v11.4s, v4.s[1]
+ FMLA v18.4s, v10.4s, v5.s[1]
+ FMLA v19.4s, v11.4s, v5.s[1]
+ FMLA v28.4s, v10.4s, v6.s[1]
+ FMLA v29.4s, v11.4s, v6.s[1]
+ FMLA v30.4s, v10.4s, v7.s[1]
+ FMLA v31.4s, v11.4s, v7.s[1]
+
+ FMLA v16.4s, v12.4s, v4.s[2]
+ FMLA v17.4s, v13.4s, v4.s[2]
+ FMLA v18.4s, v12.4s, v5.s[2]
+ FMLA v19.4s, v13.4s, v5.s[2]
+ FMLA v28.4s, v12.4s, v6.s[2]
+ FMLA v29.4s, v13.4s, v6.s[2]
+ FMLA v30.4s, v12.4s, v7.s[2]
+ FMLA v31.4s, v13.4s, v7.s[2]
+
+ FMLA v16.4s, v14.4s, v4.s[3]
+ FMLA v17.4s, v15.4s, v4.s[3]
+ FMLA v18.4s, v14.4s, v5.s[3]
+ FMLA v19.4s, v15.4s, v5.s[3]
+
+ # Reload clamping_params values (v4-v7 were overwritten with A data above)
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ FMLA v28.4s, v14.4s, v6.s[3]
+ FMLA v29.4s, v15.4s, v6.s[3]
+ FMLA v30.4s, v14.4s, v7.s[3]
+ FMLA v31.4s, v15.4s, v7.s[3]
+
+3:
+ # Remainder - 4 floats of A (16 bytes)
+ TBZ x0, 4, 4f
+
+ LDR q0, [x3], 16
+ LDP q20, q21, [x5], 32
+ LDR q1, [x11], 16
+ LDR q2, [x12], 16
+ LDR q3, [x4], 16
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q24, q25, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ LDP q26, q27, [x5], 32
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+4:
+ # Remainder - 2 floats of A (8 bytes)
+ TBZ x0, 3, 5f
+
+ LDR d0, [x3], 8
+ LDP q20, q21, [x5], 32
+ LDR d1, [x11], 8
+ LDR d2, [x12], 8
+ LDR d3, [x4], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+
+5:
+ # Remainder - 1 float of A (4 bytes)
+ TBZ x0, 2, 6f
+
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3, [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+
+6:
+ # Clamp
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMIN v18.4s, v18.4s, v4.4s
+ FMIN v19.4s, v19.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+ FMAX v18.4s, v18.4s, v5.4s
+ FMAX v19.4s, v19.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+ # Store full 4 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ STP q28, q29, [x10]
+ ADD x10, x10, x14
+ STP q18, q19, [x9]
+ ADD x9, x9, x14
+ STP q16, q17, [x6]
+ ADD x6, x6, x14
+
+ SUB x3, x3, x2 // a0 -= kc
+ SUB x11, x11, x2 // a1 -= kc
+ SUB x12, x12, x2 // a2 -= kc
+ SUB x4, x4, x2 // a3 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x10], 16
+ MOV v28.16b, v29.16b
+ STR q18, [x9], 16
+ MOV v18.16b, v19.16b
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+8:
+ TBZ x1, 1, 9f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x10], 8
+ DUP d28, v28.d[1]
+ STR d18, [x9], 8
+ DUP d18, v18.d[1]
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ STR s30, [x7]
+ STR s28, [x10]
+ STR s18, [x9]
+ STR s16, [x6]
+10:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/4x8-aarch64-neonfma-cortex-a75.S.in b/src/f32-gemm/4x8-aarch64-neonfma-cortex-a75.S.in
new file mode 100644
index 0000000..9760ea5
--- /dev/null
+++ b/src/f32-gemm/4x8-aarch64-neonfma-cortex-a75.S.in
@@ -0,0 +1,484 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_cortex_a75(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+# Vector register usage
+# A0 v0 v4
+# A1 v1 v5
+# A2 v2 v6
+# A3 v3 v7
+# B v8 v9 v10 v11
+# B v12 v13 v14 v15
+# B v20 v21 v22 v23
+# B v24 v25 v26 v27
+# C v16 v17
+# C v18 v19
+# C v28 v29
+# C v30 v31
+# Clamp v4 v5
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_cortex_a75
+
+ $if INC:
+ # Load cn_stride, acc
+ LDP x14, x15, [sp]
+ # Load params pointer
+ LDR x8, [sp, 16]
+ $else:
+ # Load cn_stride, params pointer
+ LDP x14, x8, [sp]
+
+ # Load clamping_params values
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ # Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ STP d10, d11, [sp, 16]
+ STP d12, d13, [sp, 32]
+ STP d14, d15, [sp, 48]
+
+ # Clamp A and C pointers
+ ADD x11, x3, x4 // a1 = a0 + a_stride
+ ADD x9, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x11, x3, x11, LO // a1 = a0
+ CSEL x9, x6, x9, LO // c1 = c0
+
+ ADD x12, x11, x4 // a2 = a1 + a_stride
+ ADD x10, x9, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x12, x11, x12, LS // a2 = a1
+ CSEL x10, x9, x10, LS // c2 = c1
+
+ ADD x4, x12, x4 // a3 = a2 + a_stride
+ ADD x7, x10, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x4, x12, x4, LO // a3 = a2
+ CSEL x7, x10, x7, LO // c3 = c2
+
+0:
+ $if INC:
+ # Load initial accumulators
+ LDP q16, q17, [x15], 32
+ LDP q18, q19, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+ $else:
+ # Load initial bias from w into accumulators
+ LDP q16, q17, [x5], 32
+ MOV v18.16b, v16.16b
+ MOV v19.16b, v17.16b
+ MOV v28.16b, v16.16b
+ MOV v29.16b, v17.16b
+ MOV v30.16b, v16.16b
+ MOV v31.16b, v17.16b
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 3f
+
+ # Prologue - 4 floats of A (16 bytes)
+ # Read first block of 4 A and B.
+ LDR q0, [x3], 16
+ LDP q20, q21, [x5], 32
+ LDR q1, [x11], 16
+ LDR q2, [x12], 16
+ LDR q3, [x4], 16
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+
+ # Is there at least 8 more floats (32 bytes)?  If so, enter main loop.
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+1:
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDP q8, q9, [x5], 32
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ LDP q10, q11, [x5], 32
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ LDP q12, q13, [x5], 32
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ LDR q4, [x3], 16
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ LDR q5, [x11], 16
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ LDR q6, [x12], 16
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ LDR q7, [x4], 16
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ PRFM PLDL1KEEP, [x5, 128]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ PRFM PLDL1KEEP, [x5, 192]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+ # Second block of 4. FMA for second 4, loads for 1st block of 4.
+ FMLA v16.4s, v8.4s, v4.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v17.4s, v9.4s, v4.s[0]
+ FMLA v18.4s, v8.4s, v5.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v19.4s, v9.4s, v5.s[0]
+ FMLA v28.4s, v8.4s, v6.s[0]
+ LDP q24, q25, [x5], 32
+ FMLA v29.4s, v9.4s, v6.s[0]
+ FMLA v30.4s, v8.4s, v7.s[0]
+ LDP q26, q27, [x5], 32
+ FMLA v31.4s, v9.4s, v7.s[0]
+ FMLA v16.4s, v10.4s, v4.s[1]
+ LDR q0, [x3], 16
+ FMLA v17.4s, v11.4s, v4.s[1]
+ FMLA v18.4s, v10.4s, v5.s[1]
+ LDR q1, [x11], 16
+ FMLA v19.4s, v11.4s, v5.s[1]
+ FMLA v28.4s, v10.4s, v6.s[1]
+ LDR q2, [x12], 16
+ FMLA v29.4s, v11.4s, v6.s[1]
+ FMLA v30.4s, v10.4s, v7.s[1]
+ LDR q3, [x4], 16
+ FMLA v31.4s, v11.4s, v7.s[1]
+ FMLA v16.4s, v12.4s, v4.s[2]
+ FMLA v17.4s, v13.4s, v4.s[2]
+ FMLA v18.4s, v12.4s, v5.s[2]
+ FMLA v19.4s, v13.4s, v5.s[2]
+ FMLA v28.4s, v12.4s, v6.s[2]
+ FMLA v29.4s, v13.4s, v6.s[2]
+ FMLA v30.4s, v12.4s, v7.s[2]
+ FMLA v31.4s, v13.4s, v7.s[2]
+ FMLA v16.4s, v14.4s, v4.s[3]
+ FMLA v17.4s, v15.4s, v4.s[3]
+ FMLA v18.4s, v14.4s, v5.s[3]
+ FMLA v19.4s, v15.4s, v5.s[3]
+ FMLA v28.4s, v14.4s, v6.s[3]
+ FMLA v29.4s, v15.4s, v6.s[3]
+ SUBS x0, x0, 32
+ FMLA v30.4s, v14.4s, v7.s[3]
+ FMLA v31.4s, v15.4s, v7.s[3]
+ B.HS 1b
+
+2:
+ # Epilogue
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDP q8, q9, [x5], 32
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ LDP q10, q11, [x5], 32
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ LDP q12, q13, [x5], 32
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ LDR q4, [x3], 16
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ LDR q5, [x11], 16
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ LDR q6, [x12], 16
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ LDR q7, [x4], 16
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+ # Second block of 4. FMA for second 4, no loads
+ FMLA v16.4s, v8.4s, v4.s[0]
+ FMLA v17.4s, v9.4s, v4.s[0]
+ FMLA v18.4s, v8.4s, v5.s[0]
+ FMLA v19.4s, v9.4s, v5.s[0]
+ FMLA v28.4s, v8.4s, v6.s[0]
+ FMLA v29.4s, v9.4s, v6.s[0]
+ FMLA v30.4s, v8.4s, v7.s[0]
+ FMLA v31.4s, v9.4s, v7.s[0]
+
+ FMLA v16.4s, v10.4s, v4.s[1]
+ FMLA v17.4s, v11.4s, v4.s[1]
+ FMLA v18.4s, v10.4s, v5.s[1]
+ FMLA v19.4s, v11.4s, v5.s[1]
+ FMLA v28.4s, v10.4s, v6.s[1]
+ FMLA v29.4s, v11.4s, v6.s[1]
+ FMLA v30.4s, v10.4s, v7.s[1]
+ FMLA v31.4s, v11.4s, v7.s[1]
+
+ FMLA v16.4s, v12.4s, v4.s[2]
+ FMLA v17.4s, v13.4s, v4.s[2]
+ FMLA v18.4s, v12.4s, v5.s[2]
+ FMLA v19.4s, v13.4s, v5.s[2]
+ FMLA v28.4s, v12.4s, v6.s[2]
+ FMLA v29.4s, v13.4s, v6.s[2]
+ FMLA v30.4s, v12.4s, v7.s[2]
+ FMLA v31.4s, v13.4s, v7.s[2]
+
+ FMLA v16.4s, v14.4s, v4.s[3]
+ FMLA v17.4s, v15.4s, v4.s[3]
+ FMLA v18.4s, v14.4s, v5.s[3]
+ FMLA v19.4s, v15.4s, v5.s[3]
+
+ # Reload clamping_params values (v4-v7 were overwritten with A data above)
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ FMLA v28.4s, v14.4s, v6.s[3]
+ FMLA v29.4s, v15.4s, v6.s[3]
+ FMLA v30.4s, v14.4s, v7.s[3]
+ FMLA v31.4s, v15.4s, v7.s[3]
+
+3:
+ # Remainder - 4 floats of A (16 bytes)
+ TBZ x0, 4, 4f
+
+ LDR q0, [x3], 16
+ LDP q20, q21, [x5], 32
+ LDR q1, [x11], 16
+ LDR q2, [x12], 16
+ LDR q3, [x4], 16
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q24, q25, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ LDP q26, q27, [x5], 32
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+4:
+ # Remainder - 2 floats of A (8 bytes)
+ TBZ x0, 3, 5f
+
+ LDR d0, [x3], 8
+ LDP q20, q21, [x5], 32
+ LDR d1, [x11], 8
+ LDR d2, [x12], 8
+ LDR d3, [x4], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+
+5:
+ # Remainder - 1 float of A (4 bytes)
+ TBZ x0, 2, 6f
+
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3, [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+
+6:
+ # Clamp
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMIN v18.4s, v18.4s, v4.4s
+ FMIN v19.4s, v19.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+ FMAX v18.4s, v18.4s, v5.4s
+ FMAX v19.4s, v19.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+ # Store full 4 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ STP q28, q29, [x10]
+ ADD x10, x10, x14
+ STP q18, q19, [x9]
+ ADD x9, x9, x14
+ STP q16, q17, [x6]
+ ADD x6, x6, x14
+
+ SUB x3, x3, x2 // a0 -= kc
+ SUB x11, x11, x2 // a1 -= kc
+ SUB x12, x12, x2 // a2 -= kc
+ SUB x4, x4, x2 // a3 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x10], 16
+ MOV v28.16b, v29.16b
+ STR q18, [x9], 16
+ MOV v18.16b, v19.16b
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+8:
+ TBZ x1, 1, 9f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x10], 8
+ DUP d28, v28.d[1]
+ STR d18, [x9], 8
+ DUP d18, v18.d[1]
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ STR s30, [x7]
+ STR s28, [x10]
+ STR s18, [x9]
+ STR s16, [x6]
+10:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/4x8-aarch64-neonfma-ld128.S b/src/f32-gemm/4x8-aarch64-neonfma-ld128.S
new file mode 100644
index 0000000..7e3f099
--- /dev/null
+++ b/src/f32-gemm/4x8-aarch64-neonfma-ld128.S
@@ -0,0 +1,249 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128
+
+ # Load cn_stride, params pointer
+ LDP x14, x8, [sp]
+
+ # Load clamping_params values
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ # Clamp A and C pointers
+ ADD x11, x3, x4 // a1 = a0 + a_stride
+ ADD x9, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x11, x3, x11, LO // a1 = a0
+ CSEL x9, x6, x9, LO // c1 = c0
+
+ ADD x12, x11, x4 // a2 = a1 + a_stride
+ ADD x10, x9, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x12, x11, x12, LS // a2 = a1
+ CSEL x10, x9, x10, LS // c2 = c1
+
+ ADD x4, x12, x4 // a3 = a2 + a_stride
+ ADD x7, x10, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x4, x12, x4, LO // a3 = a2
+ CSEL x7, x10, x7, LO // c3 = c2
+
+0:
+ # Load initial bias from w into accumulators
+ LDP q16, q17, [x5], 32
+ MOV v18.16b, v16.16b
+ MOV v19.16b, v17.16b
+ MOV v28.16b, v16.16b
+ MOV v29.16b, v17.16b
+ MOV v30.16b, v16.16b
+ MOV v31.16b, v17.16b
+
+ # Is there at least 4 floats (16 bytes)?
+ SUBS x0, x2, 16 // k = kc - 16
+ B.LO 2f
+
+ # Main loop - 4 floats of A (16 bytes)
+1:
+ LDR q0, [x3], 16
+ LDP q20, q21, [x5], 32
+ LDR q1, [x11], 16
+ LDR q2, [x12], 16
+ LDR q3, [x4], 16
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ LDP q24, q25, [x5], 32
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ LDP q26, q27, [x5], 32
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ SUBS x0, x0, 16
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+ B.HS 1b
+
+ # Remainder - 2 floats of A (8 bytes)
+2:
+ TBZ x0, 3, 3f
+
+ LDR d0, [x3], 8
+ LDP q20, q21, [x5], 32
+ LDR d1, [x11], 8
+ LDR d2, [x12], 8
+ LDR d3, [x4], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+
+ # Remainder - 1 float of A (4 bytes)
+3:
+ TBZ x0, 2, 6f
+
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3, [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+
+6:
+ # Clamp
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMIN v18.4s, v18.4s, v4.4s
+ FMIN v19.4s, v19.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+ FMAX v18.4s, v18.4s, v5.4s
+ FMAX v19.4s, v19.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+
+ # Store full 4 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ STP q28, q29, [x10]
+ ADD x10, x10, x14
+ STP q18, q19, [x9]
+ ADD x9, x9, x14
+ STP q16, q17, [x6]
+ ADD x6, x6, x14
+
+ SUB x3, x3, x2 // a0 -= kc
+ SUB x11, x11, x2 // a1 -= kc
+ SUB x12, x12, x2 // a2 -= kc
+ SUB x4, x4, x2 // a3 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ RET
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x10], 16
+ MOV v28.16b, v29.16b
+ STR q18, [x9], 16
+ MOV v18.16b, v19.16b
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+8:
+ TBZ x1, 1, 9f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x10], 8
+ DUP d28, v28.d[1]
+ STR d18, [x9], 8
+ DUP d18, v18.d[1]
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ STR s30, [x7]
+ STR s28, [x10]
+ STR s18, [x9]
+ STR s16, [x6]
+10:
+ RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in b/src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in
new file mode 100644
index 0000000..74e8922
--- /dev/null
+++ b/src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in
@@ -0,0 +1,262 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ld128(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ld128
+
+ $if INC:
+ # Load cn_stride, acc
+ LDP x14, x15, [sp]
+ # Load params pointer
+ LDR x8, [sp, 16]
+ $else:
+ # Load cn_stride, params pointer
+ LDP x14, x8, [sp]
+
+ # Load clamping_params values (v4 = output max, v5 = output min; see FMIN/FMAX below)
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ # Clamp A and C pointers
+ ADD x11, x3, x4 // a1 = a0 + a_stride
+ ADD x9, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x11, x3, x11, LO // a1 = a0
+ CSEL x9, x6, x9, LO // c1 = c0
+
+ ADD x12, x11, x4 // a2 = a1 + a_stride
+ ADD x10, x9, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x12, x11, x12, LS // a2 = a1
+ CSEL x10, x9, x10, LS // c2 = c1
+
+ ADD x4, x12, x4 // a3 = a2 + a_stride
+ ADD x7, x10, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x4, x12, x4, LO // a3 = a2
+ CSEL x7, x10, x7, LO // c3 = c2
+
+0: // outer loop: one iteration per 8 columns of C (nc)
+ $if INC:
+ # Load initial accumulators
+ LDP q16, q17, [x15], 32
+ LDP q18, q19, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+ $else:
+ # Load initial bias from w into accumulators
+ LDP q16, q17, [x5], 32
+ MOV v18.16b, v16.16b
+ MOV v19.16b, v17.16b
+ MOV v28.16b, v16.16b
+ MOV v29.16b, v17.16b
+ MOV v30.16b, v16.16b
+ MOV v31.16b, v17.16b
+
+ # Is there at least 4 floats (16 bytes)?
+ SUBS x0, x2, 16 // k = kc - 16
+ B.LO 2f
+
+ # Main loop - 4 floats of A (16 bytes)
+1:
+ LDR q0, [x3], 16
+ LDP q20, q21, [x5], 32
+ LDR q1, [x11], 16
+ LDR q2, [x12], 16
+ LDR q3, [x4], 16
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ LDP q24, q25, [x5], 32
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ LDP q26, q27, [x5], 32
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ SUBS x0, x0, 16 // interleaved with FMLAs to hide the flag-setting latency
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+ B.HS 1b
+
+ # Remainder- 2 floats of A (8 bytes)
+2:
+ TBZ x0, 3, 3f // kc & 8? (low bits of x0 are unaffected by the -16 borrow)
+
+ LDR d0, [x3], 8
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3, [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+
+ # Remainder- 1 float of A (4 bytes)
+3:
+ TBZ x0, 2, 6f // kc & 4?
+
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3, [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+
+6:
+ # Clamp
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMIN v18.4s, v18.4s, v4.4s
+ FMIN v19.4s, v19.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+ FMAX v18.4s, v18.4s, v5.4s
+ FMAX v19.4s, v19.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+
+ # Store full 4 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ STP q28, q29, [x10]
+ ADD x10, x10, x14
+ STP q18, q19, [x9]
+ ADD x9, x9, x14
+ STP q16, q17, [x6]
+ ADD x6, x6, x14
+
+ SUB x3, x3, x2 // a0 -= kc
+ SUB x11, x11, x2 // a1 -= kc
+ SUB x12, x12, x2 // a2 -= kc
+ SUB x4, x4, x2 // a3 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ RET
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f // store 4 columns if nc & 4
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x10], 16
+ MOV v28.16b, v29.16b
+ STR q18, [x9], 16
+ MOV v18.16b, v19.16b
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+8:
+ TBZ x1, 1, 9f // store 2 columns if nc & 2
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x10], 8
+ DUP d28, v28.d[1]
+ STR d18, [x9], 8
+ DUP d18, v18.d[1]
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+9:
+ TBZ x1, 0, 10f // store 1 column if nc & 1
+ STR s30, [x7]
+ STR s28, [x10]
+ STR s18, [x9]
+ STR s16, [x6]
+10:
+ RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ld128
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/4x8-aarch64-neonfma-ld64.S b/src/f32-gemm/4x8-aarch64-neonfma-ld64.S
new file mode 100644
index 0000000..2dc1a6c
--- /dev/null
+++ b/src/f32-gemm/4x8-aarch64-neonfma-ld64.S
@@ -0,0 +1,203 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld64(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld64
+
+ # Load cn_stride, params pointer
+ LDP x14, x8, [sp]
+
+ # Load clamping_params values (v4 = output max, v5 = output min; see FMIN/FMAX below)
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ # Clamp A and C pointers
+ ADD x11, x3, x4 // a1 = a0 + a_stride
+ ADD x9, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x11, x3, x11, LO // a1 = a0
+ CSEL x9, x6, x9, LO // c1 = c0
+
+ ADD x12, x11, x4 // a2 = a1 + a_stride
+ ADD x10, x9, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x12, x11, x12, LS // a2 = a1
+ CSEL x10, x9, x10, LS // c2 = c1
+
+ ADD x4, x12, x4 // a3 = a2 + a_stride
+ ADD x7, x10, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x4, x12, x4, LO // a3 = a2
+ CSEL x7, x10, x7, LO // c3 = c2
+
+0: // outer loop: one iteration per 8 columns of C (nc)
+ # Load initial bias from w into accumulators
+ LDP q16, q17, [x5], 32
+ MOV v18.16b, v16.16b
+ MOV v19.16b, v17.16b
+ MOV v28.16b, v16.16b
+ MOV v29.16b, v17.16b
+ MOV v30.16b, v16.16b
+ MOV v31.16b, v17.16b
+
+ # Is there at least 2 floats (8 bytes)?
+ SUBS x0, x2, 8 // k = kc - 8
+ B.LO 2f
+
+ # Main loop - 2 floats of A (8 bytes)
+
+1:
+ LDR d0, [x3], 8
+ LDP q20, q21, [x5], 32
+ LDR d1, [x11], 8
+ LDR d2, [x12], 8
+ LDR d3, [x4], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ SUBS x0, x0, 8 // interleaved with FMLAs to hide the flag-setting latency
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+ B.HS 1b
+2:
+ # Remainder- 1 floats of A (4 bytes)
+ TBZ x0, 2, 6f // kc & 4? (low bits of x0 are unaffected by the -8 borrow)
+
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3 , [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+
+6:
+ # Clamp
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMIN v18.4s, v18.4s, v4.4s
+ FMIN v19.4s, v19.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+ FMAX v18.4s, v18.4s, v5.4s
+ FMAX v19.4s, v19.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+ # Store full 4 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ STP q28, q29, [x10]
+ ADD x10, x10, x14
+ STP q18, q19, [x9]
+ ADD x9, x9, x14
+ STP q16, q17, [x6]
+ ADD x6, x6, x14
+
+ SUB x3, x3, x2 // a0 -= kc
+ SUB x11, x11, x2 // a1 -= kc
+ SUB x12, x12, x2 // a2 -= kc
+ SUB x4, x4, x2 // a3 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ RET
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f // store 4 columns if nc & 4
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x10], 16
+ MOV v28.16b, v29.16b
+ STR q18, [x9], 16
+ MOV v18.16b, v19.16b
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+8:
+ TBZ x1, 1, 9f // store 2 columns if nc & 2
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x10], 8
+ DUP d28, v28.d[1]
+ STR d18, [x9], 8
+ DUP d18, v18.d[1]
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+9:
+ TBZ x1, 0, 10f // store 1 column if nc & 1
+ STR s30, [x7]
+ STR s28, [x10]
+ STR s18, [x9]
+ STR s16, [x6]
+10:
+ RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld64
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in b/src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in
new file mode 100644
index 0000000..cb0cfde
--- /dev/null
+++ b/src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in
@@ -0,0 +1,216 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ld64(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ld64
+
+ $if INC:
+ # Load cn_stride, acc
+ LDP x14, x15, [sp]
+ # Load params pointer
+ LDR x8, [sp, 16]
+ $else:
+ # Load cn_stride, params pointer
+ LDP x14, x8, [sp]
+
+ # Load clamping_params values (v4 = output max, v5 = output min; see FMIN/FMAX below)
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ # Clamp A and C pointers
+ ADD x11, x3, x4 // a1 = a0 + a_stride
+ ADD x9, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x11, x3, x11, LO // a1 = a0
+ CSEL x9, x6, x9, LO // c1 = c0
+
+ ADD x12, x11, x4 // a2 = a1 + a_stride
+ ADD x10, x9, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x12, x11, x12, LS // a2 = a1
+ CSEL x10, x9, x10, LS // c2 = c1
+
+ ADD x4, x12, x4 // a3 = a2 + a_stride
+ ADD x7, x10, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x4, x12, x4, LO // a3 = a2
+ CSEL x7, x10, x7, LO // c3 = c2
+
+0: // outer loop: one iteration per 8 columns of C (nc)
+ $if INC:
+ # Load initial accumulators
+ LDP q16, q17, [x15], 32
+ LDP q18, q19, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+ $else:
+ # Load initial bias from w into accumulators
+ LDP q16, q17, [x5], 32
+ MOV v18.16b, v16.16b
+ MOV v19.16b, v17.16b
+ MOV v28.16b, v16.16b
+ MOV v29.16b, v17.16b
+ MOV v30.16b, v16.16b
+ MOV v31.16b, v17.16b
+
+ # Is there at least 2 floats (8 bytes)?
+ SUBS x0, x2, 8 // k = kc - 8
+ B.LO 2f
+
+ # Main loop - 2 floats of A (8 bytes)
+
+1:
+ LDR d0, [x3], 8
+ LDP q20, q21, [x5], 32
+ LDR d1, [x11], 8
+ LDR d2, [x12], 8
+ LDR d3, [x4], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ SUBS x0, x0, 8 // interleaved with FMLAs to hide the flag-setting latency
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+ B.HS 1b
+2:
+ # Remainder- 1 floats of A (4 bytes)
+ TBZ x0, 2, 6f // kc & 4? (low bits of x0 are unaffected by the -8 borrow)
+
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3 , [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+
+6:
+ # Clamp
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMIN v18.4s, v18.4s, v4.4s
+ FMIN v19.4s, v19.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+ FMAX v18.4s, v18.4s, v5.4s
+ FMAX v19.4s, v19.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+ # Store full 4 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ STP q28, q29, [x10]
+ ADD x10, x10, x14
+ STP q18, q19, [x9]
+ ADD x9, x9, x14
+ STP q16, q17, [x6]
+ ADD x6, x6, x14
+
+ SUB x3, x3, x2 // a0 -= kc
+ SUB x11, x11, x2 // a1 -= kc
+ SUB x12, x12, x2 // a2 -= kc
+ SUB x4, x4, x2 // a3 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ RET
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f // store 4 columns if nc & 4
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x10], 16
+ MOV v28.16b, v29.16b
+ STR q18, [x9], 16
+ MOV v18.16b, v19.16b
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+8:
+ TBZ x1, 1, 9f // store 2 columns if nc & 2
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x10], 8
+ DUP d28, v28.d[1]
+ STR d18, [x9], 8
+ DUP d18, v18.d[1]
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+9:
+ TBZ x1, 0, 10f // store 1 column if nc & 1
+ STR s30, [x7]
+ STR s28, [x10]
+ STR s18, [x9]
+ STR s16, [x6]
+10:
+ RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_4x8__aarch64_neonfma_ld64
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/4x8-neon-ld128.c b/src/f32-gemm/4x8-neon-ld128.c
new file mode 100644
index 0000000..3ebede2
--- /dev/null
+++ b/src/f32-gemm/4x8-neon-ld128.c
@@ -0,0 +1,225 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld128.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_4x8__neon_ld128( // 4-row x 8-column f32 GEMM tile, 4 k-steps per loop iteration
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float* restrict a,
+ size_t a_stride,
+ const float* restrict w,
+ float* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ // Row pointers; rows beyond mr alias the previous row so their stores are redundant but safe.
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ do { // each iteration produces (up to) 8 columns of the output tile
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc1x4567 = vacc0x4567;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc2x4567 = vacc0x4567;
+ float32x4_t vacc3x0123 = vacc0x0123;
+ float32x4_t vacc3x4567 = vacc0x4567;
+
+ size_t k = kc;
+ for (; k >= 4 * sizeof(float); k -= 4 * sizeof(float)) {
+ const float32x4_t va0 = vld1q_f32(a0); a0 += 4;
+ const float32x4_t va1 = vld1q_f32(a1); a1 += 4;
+ const float32x4_t va2 = vld1q_f32(a2); a2 += 4;
+ const float32x4_t va3 = vld1q_f32(a3); a3 += 4;
+
+ // w is packed 8 floats per k-step; the 4 "c" sub-steps below unroll 4 k-steps.
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, vget_low_f32(va0), 0);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, vget_low_f32(va1), 0);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, vget_low_f32(va2), 0);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, vget_low_f32(va3), 0);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, vget_low_f32(va0), 0);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, vget_low_f32(va1), 0);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, vget_low_f32(va2), 0);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, vget_low_f32(va3), 0);
+
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, vget_low_f32(va0), 1);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, vget_low_f32(va1), 1);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, vget_low_f32(va2), 1);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, vget_low_f32(va3), 1);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, vget_low_f32(va0), 1);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, vget_low_f32(va1), 1);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, vget_low_f32(va2), 1);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, vget_low_f32(va3), 1);
+
+ const float32x4_t vb0123c2 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c2 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c2, vget_high_f32(va0), 0);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c2, vget_high_f32(va1), 0);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c2, vget_high_f32(va2), 0);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c2, vget_high_f32(va3), 0);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c2, vget_high_f32(va0), 0);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c2, vget_high_f32(va1), 0);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c2, vget_high_f32(va2), 0);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c2, vget_high_f32(va3), 0);
+
+ const float32x4_t vb0123c3 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c3 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c3, vget_high_f32(va0), 1);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c3, vget_high_f32(va1), 1);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c3, vget_high_f32(va2), 1);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c3, vget_high_f32(va3), 1);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c3, vget_high_f32(va0), 1);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c3, vget_high_f32(va1), 1);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c3, vget_high_f32(va2), 1);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c3, vget_high_f32(va3), 1);
+ }
+ if XNN_UNLIKELY(k != 0) { // leftover 1-3 k-steps, one float of A at a time
+ do {
+ const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+ const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+ const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+ const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+ vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+
+ k -= sizeof(float);
+ } while (k != 0);
+ }
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); // was mojibake "¶ms" - broke the build
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); // was mojibake "¶ms" - broke the build
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+
+ } else { // partial tile: store 4 / 2 / 1 columns per the low bits of nc
+ if (nc & 4) {
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x8-neon-ld64.c b/src/f32-gemm/4x8-neon-ld64.c
new file mode 100644
index 0000000..6c19c67
--- /dev/null
+++ b/src/f32-gemm/4x8-neon-ld64.c
@@ -0,0 +1,195 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_4x8__neon_ld64( // 4-row x 8-column f32 GEMM tile, 2 k-steps per loop iteration
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float* restrict a,
+ size_t a_stride,
+ const float* restrict w,
+ float* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ // Row pointers; rows beyond mr alias the previous row so their stores are redundant but safe.
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ do { // each iteration produces (up to) 8 columns of the output tile
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc1x4567 = vacc0x4567;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc2x4567 = vacc0x4567;
+ float32x4_t vacc3x0123 = vacc0x0123;
+ float32x4_t vacc3x4567 = vacc0x4567;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4; // w is packed 8 floats per k-step
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+ }
+ if XNN_UNLIKELY(k != 0) { // leftover single k-step (kc odd)
+ const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+ const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+ const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+ const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+ vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+ }
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); // was mojibake "¶ms" - broke the build
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); // was mojibake "¶ms" - broke the build
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+
+ } else { // partial tile: store 4 / 2 / 1 columns per the low bits of nc
+ if (nc & 4) {
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x8-neonfma-ld128.c b/src/f32-gemm/4x8-neonfma-ld128.c
new file mode 100644
index 0000000..08888d6
--- /dev/null
+++ b/src/f32-gemm/4x8-neonfma-ld128.c
@@ -0,0 +1,285 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld128.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_4x8__neonfma_ld128(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float* restrict a,
+ size_t a_stride,
+ const float* restrict w,
+ float* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc1x4567 = vacc0x4567;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc2x4567 = vacc0x4567;
+ float32x4_t vacc3x0123 = vacc0x0123;
+ float32x4_t vacc3x4567 = vacc0x4567;
+
+ size_t k = kc;
+ for (; k >= 4 * sizeof(float); k -= 4 * sizeof(float)) {
+ const float32x4_t va0 = vld1q_f32(a0); a0 += 4;
+ const float32x4_t va1 = vld1q_f32(a1); a1 += 4;
+ const float32x4_t va2 = vld1q_f32(a2); a2 += 4;
+ const float32x4_t va3 = vld1q_f32(a3); a3 += 4;
+
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123c0, va3, 0);
+ vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567c0, va0, 0);
+ vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567c0, va1, 0);
+ vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567c0, va2, 0);
+ vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567c0, va3, 0);
+ #else
+ const float32x4_t va0c0 = vdupq_lane_f32(vget_low_f32(va0), 0);
+ const float32x4_t va1c0 = vdupq_lane_f32(vget_low_f32(va1), 0);
+ const float32x4_t va2c0 = vdupq_lane_f32(vget_low_f32(va2), 0);
+ const float32x4_t va3c0 = vdupq_lane_f32(vget_low_f32(va3), 0);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+ #endif
+
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123c1, va3, 1);
+ vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567c1, va0, 1);
+ vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567c1, va1, 1);
+ vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567c1, va2, 1);
+ vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567c1, va3, 1);
+ #else
+ const float32x4_t va0c1 = vdupq_lane_f32(vget_low_f32(va0), 1);
+ const float32x4_t va1c1 = vdupq_lane_f32(vget_low_f32(va1), 1);
+ const float32x4_t va2c1 = vdupq_lane_f32(vget_low_f32(va2), 1);
+ const float32x4_t va3c1 = vdupq_lane_f32(vget_low_f32(va3), 1);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+ #endif
+
+ const float32x4_t vb0123c2 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c2 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123c2, va0, 2);
+ vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123c2, va1, 2);
+ vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123c2, va2, 2);
+ vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123c2, va3, 2);
+ vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567c2, va0, 2);
+ vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567c2, va1, 2);
+ vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567c2, va2, 2);
+ vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567c2, va3, 2);
+ #else
+ const float32x4_t va0c2 = vdupq_lane_f32(vget_high_f32(va0), 0);
+ const float32x4_t va1c2 = vdupq_lane_f32(vget_high_f32(va1), 0);
+ const float32x4_t va2c2 = vdupq_lane_f32(vget_high_f32(va2), 0);
+ const float32x4_t va3c2 = vdupq_lane_f32(vget_high_f32(va3), 0);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c2, vb0123c2);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c2, vb0123c2);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c2, vb0123c2);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c2, vb0123c2);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c2, vb4567c2);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c2, vb4567c2);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c2, vb4567c2);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c2, vb4567c2);
+ #endif
+
+ const float32x4_t vb0123c3 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c3 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123c3, va0, 3);
+ vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123c3, va1, 3);
+ vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123c3, va2, 3);
+ vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123c3, va3, 3);
+ vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567c3, va0, 3);
+ vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567c3, va1, 3);
+ vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567c3, va2, 3);
+ vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567c3, va3, 3);
+ #else
+ const float32x4_t va0c3 = vdupq_lane_f32(vget_high_f32(va0), 1);
+ const float32x4_t va1c3 = vdupq_lane_f32(vget_high_f32(va1), 1);
+ const float32x4_t va2c3 = vdupq_lane_f32(vget_high_f32(va2), 1);
+ const float32x4_t va3c3 = vdupq_lane_f32(vget_high_f32(va3), 1);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c3, vb0123c3);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c3, vb0123c3);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c3, vb0123c3);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c3, vb0123c3);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c3, vb4567c3);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c3, vb4567c3);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c3, vb4567c3);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c3, vb4567c3);
+ #endif
+ }
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+ const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+ const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+ const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+
+ k -= sizeof(float);
+ } while (k != 0);
+ }
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+
+ } else {
+ if (nc & 4) {
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x8-neonfma-ld64.c b/src/f32-gemm/4x8-neonfma-ld64.c
new file mode 100644
index 0000000..c0f9dd5
--- /dev/null
+++ b/src/f32-gemm/4x8-neonfma-ld64.c
@@ -0,0 +1,225 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_4x8__neonfma_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float* restrict a,
+ size_t a_stride,
+ const float* restrict w,
+ float* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc1x4567 = vacc0x4567;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc2x4567 = vacc0x4567;
+ float32x4_t vacc3x0123 = vacc0x0123;
+ float32x4_t vacc3x4567 = vacc0x4567;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+ vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+ vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+ vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+ vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+ #else
+ const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+ const float32x4_t va1c0 = vdupq_lane_f32(va1, 0);
+ const float32x4_t va2c0 = vdupq_lane_f32(va2, 0);
+ const float32x4_t va3c0 = vdupq_lane_f32(va3, 0);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+ #endif
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+ vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+ vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+ vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+ vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+ #else
+ const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+ const float32x4_t va1c1 = vdupq_lane_f32(va1, 1);
+ const float32x4_t va2c1 = vdupq_lane_f32(va2, 1);
+ const float32x4_t va3c1 = vdupq_lane_f32(va3, 1);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+ #endif
+ }
+ if XNN_UNLIKELY(k != 0) {
+ const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+ const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+ const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+ const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+ }
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+
+ } else {
+ if (nc & 4) {
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x8-psimd-loadsplat.c b/src/f32-gemm/4x8-psimd-loadsplat.c
new file mode 100644
index 0000000..247744d
--- /dev/null
+++ b/src/f32-gemm/4x8-psimd-loadsplat.c
@@ -0,0 +1,180 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-loadsplat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_4x8__psimd_loadsplat(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float*restrict a,
+ size_t a_stride,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ do {
+ psimd_f32 vacc0x0123 = psimd_load_f32(w + 0);
+ psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+ psimd_f32 vacc1x0123 = vacc0x0123;
+ psimd_f32 vacc1x4567 = vacc0x4567;
+ psimd_f32 vacc2x0123 = vacc0x0123;
+ psimd_f32 vacc2x4567 = vacc0x4567;
+ psimd_f32 vacc3x0123 = vacc0x0123;
+ psimd_f32 vacc3x4567 = vacc0x4567;
+ w += 8;
+
+ size_t k = kc;
+ do {
+ const psimd_f32 va0 = psimd_load_splat_f32(a0);
+ a0 += 1;
+ const psimd_f32 va1 = psimd_load_splat_f32(a1);
+ a1 += 1;
+ const psimd_f32 va2 = psimd_load_splat_f32(a2);
+ a2 += 1;
+ const psimd_f32 va3 = psimd_load_splat_f32(a3);
+ a3 += 1;
+
+ const psimd_f32 vb0123 = psimd_load_f32(w);
+ const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+ w += 8;
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+
+ k -= sizeof(float);
+ } while (k != 0);
+
+ const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+ vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+ vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+ vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+ vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+ vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+ vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+ vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+ vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+
+ const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+ vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+ vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+ vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+ vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+ vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+ vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+ vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+ vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ psimd_store_f32(c0, vacc0x0123);
+ psimd_store_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c0, vacc0x0123);
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+
+ c3 += 4;
+ c2 += 4;
+ c1 += 4;
+ c0 += 4;
+ }
+ if (nc & 2) {
+ psimd_store2_f32(c3, vacc3x0123);
+ psimd_store2_f32(c2, vacc2x0123);
+ psimd_store2_f32(c1, vacc1x0123);
+ psimd_store2_f32(c0, vacc0x0123);
+
+ vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+ vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+ vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+ vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+ c3 += 2;
+ c2 += 2;
+ c1 += 2;
+ c0 += 2;
+ }
+ if (nc & 1) {
+ psimd_store1_f32(c3, vacc3x0123);
+ psimd_store1_f32(c2, vacc2x0123);
+ psimd_store1_f32(c1, vacc1x0123);
+ psimd_store1_f32(c0, vacc0x0123);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x8-psimd-splat.c b/src/f32-gemm/4x8-psimd-splat.c
new file mode 100644
index 0000000..2c20455
--- /dev/null
+++ b/src/f32-gemm/4x8-psimd-splat.c
@@ -0,0 +1,260 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-splat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_4x8__psimd_splat(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float*restrict a,
+ size_t a_stride,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ do {
+ psimd_f32 vacc0x0123 = psimd_load_f32(w + 0);
+ psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+ psimd_f32 vacc1x0123 = vacc0x0123;
+ psimd_f32 vacc1x4567 = vacc0x4567;
+ psimd_f32 vacc2x0123 = vacc0x0123;
+ psimd_f32 vacc2x4567 = vacc0x4567;
+ psimd_f32 vacc3x0123 = vacc0x0123;
+ psimd_f32 vacc3x4567 = vacc0x4567;
+ w += 8;
+
+ size_t k = kc;
+ while (k >= 4 * sizeof(float)) {
+ const psimd_f32 va0 = psimd_load_f32(a0);
+ a0 += 4;
+ const psimd_f32 va1 = psimd_load_f32(a1);
+ a1 += 4;
+ const psimd_f32 va2 = psimd_load_f32(a2);
+ a2 += 4;
+ const psimd_f32 va3 = psimd_load_f32(a3);
+ a3 += 4;
+
+ const psimd_f32 va0c0 = psimd_splat0_f32(va0);
+ const psimd_f32 va1c0 = psimd_splat0_f32(va1);
+ const psimd_f32 va2c0 = psimd_splat0_f32(va2);
+ const psimd_f32 va3c0 = psimd_splat0_f32(va3);
+
+ const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+ const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c0, vb0123c0);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c0, vb0123c0);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c0, vb0123c0);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c0, vb0123c0);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c0, vb4567c0);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c0, vb4567c0);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c0, vb4567c0);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c0, vb4567c0);
+ const psimd_f32 va0c1 = psimd_splat1_f32(va0);
+ const psimd_f32 va1c1 = psimd_splat1_f32(va1);
+ const psimd_f32 va2c1 = psimd_splat1_f32(va2);
+ const psimd_f32 va3c1 = psimd_splat1_f32(va3);
+
+ const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+ const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c1, vb0123c1);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c1, vb0123c1);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c1, vb0123c1);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c1, vb0123c1);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c1, vb4567c1);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c1, vb4567c1);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c1, vb4567c1);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c1, vb4567c1);
+ const psimd_f32 va0c2 = psimd_splat2_f32(va0);
+ const psimd_f32 va1c2 = psimd_splat2_f32(va1);
+ const psimd_f32 va2c2 = psimd_splat2_f32(va2);
+ const psimd_f32 va3c2 = psimd_splat2_f32(va3);
+
+ const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+ const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c2, vb0123c2);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c2, vb0123c2);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c2, vb0123c2);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c2, vb0123c2);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c2, vb4567c2);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c2, vb4567c2);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c2, vb4567c2);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c2, vb4567c2);
+ const psimd_f32 va0c3 = psimd_splat3_f32(va0);
+ const psimd_f32 va1c3 = psimd_splat3_f32(va1);
+ const psimd_f32 va2c3 = psimd_splat3_f32(va2);
+ const psimd_f32 va3c3 = psimd_splat3_f32(va3);
+
+ const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+ const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c3, vb0123c3);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c3, vb0123c3);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c3, vb0123c3);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c3, vb0123c3);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c3, vb4567c3);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c3, vb4567c3);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c3, vb4567c3);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c3, vb4567c3);
+
+ w += 32;
+ k -= 4 * sizeof(float);
+ }
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ const psimd_f32 va0 = psimd_load_splat_f32(a0);
+ a0 += 1;
+ const psimd_f32 va1 = psimd_load_splat_f32(a1);
+ a1 += 1;
+ const psimd_f32 va2 = psimd_load_splat_f32(a2);
+ a2 += 1;
+ const psimd_f32 va3 = psimd_load_splat_f32(a3);
+ a3 += 1;
+
+ const psimd_f32 vb0123 = psimd_load_f32(w);
+ const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+ w += 8;
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+
+ k -= sizeof(float);
+ } while (k != 0);
+ }
+
+ const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+ vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+ vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+ vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+ vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+ vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+ vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+ vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+ vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+
+ const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+ vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+ vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+ vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+ vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+ vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+ vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+ vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+ vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ psimd_store_f32(c0, vacc0x0123);
+ psimd_store_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c0, vacc0x0123);
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+
+ c3 += 4;
+ c2 += 4;
+ c1 += 4;
+ c0 += 4;
+ }
+ if (nc & 2) {
+ psimd_store2_f32(c3, vacc3x0123);
+ psimd_store2_f32(c2, vacc2x0123);
+ psimd_store2_f32(c1, vacc1x0123);
+ psimd_store2_f32(c0, vacc0x0123);
+
+ vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+ vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+ vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+ vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+ c3 += 2;
+ c2 += 2;
+ c1 += 2;
+ c0 += 2;
+ }
+ if (nc & 1) {
+ psimd_store1_f32(c3, vacc3x0123);
+ psimd_store1_f32(c2, vacc2x0123);
+ psimd_store1_f32(c1, vacc1x0123);
+ psimd_store1_f32(c0, vacc0x0123);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x8-sse-dup.c b/src/f32-gemm/4x8-sse-dup.c
new file mode 100644
index 0000000..6b1fcd0
--- /dev/null
+++ b/src/f32-gemm/4x8-sse-dup.c
@@ -0,0 +1,264 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/sse-dup.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// f32 GEMM micro-kernel computing a 4-row by 8-column tile of C using SSE.
+// "dup" variant: each of the 4 lanes of a loaded A vector is broadcast with
+// _mm_shuffle_ps and multiplied against two packed-B vectors per k step.
+// w is the packed-weights stream: 8 floats consumed before the k loop
+// (presumably the bias — TODO confirm against the weight-packing routine),
+// then 8 floats of B per k step. kc, a_stride, cm_stride, cn_stride are in
+// bytes (kc % sizeof(float) == 0 is asserted; a pointers step in floats).
+void xnn_f32_gemm_ukernel_4x8__sse_dup(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float*restrict a,
+ size_t a_stride,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ // Per-row A/C pointer setup. Rows beyond mr alias the previous row's
+ // pointers, so the "extra" rows recompute a valid row; the stores below
+ // run in order c3, c2, c1, c0, so the lowest (valid) row is written last.
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ // Outer loop over 8-column blocks of the output.
+ do {
+ // All 4 rows start from the same 8 packed floats at the head of w.
+ __m128 vacc0x0123 = _mm_load_ps(w + 0);
+ __m128 vacc0x4567 = _mm_load_ps(w + 4);
+ __m128 vacc1x0123 = vacc0x0123;
+ __m128 vacc1x4567 = vacc0x4567;
+ __m128 vacc2x0123 = vacc0x0123;
+ __m128 vacc2x4567 = vacc0x4567;
+ __m128 vacc3x0123 = vacc0x0123;
+ __m128 vacc3x4567 = vacc0x4567;
+ w += 8;
+
+ // Main loop: 4 k steps per iteration; one 4-float A load per row, then
+ // each lane is broadcast and FMA'd (mul+add) against 8 B floats.
+ size_t k = kc;
+ while (k >= 4 * sizeof(float)) {
+ const __m128 va0 = _mm_loadu_ps(a0);
+ a0 += 4;
+ const __m128 va1 = _mm_loadu_ps(a1);
+ a1 += 4;
+ const __m128 va2 = _mm_loadu_ps(a2);
+ a2 += 4;
+ const __m128 va3 = _mm_loadu_ps(a3);
+ a3 += 4;
+
+
+ // k step 0: broadcast lane 0 of each A vector.
+ const __m128 va0c0000 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 0, 0, 0));
+ const __m128 va1c0000 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(0, 0, 0, 0));
+ const __m128 va2c0000 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(0, 0, 0, 0));
+ const __m128 va3c0000 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(0, 0, 0, 0));
+
+ const __m128 vb0123c0 = _mm_load_ps(w + 0);
+ const __m128 vb4567c0 = _mm_load_ps(w + 4);
+
+ vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c0000, vb0123c0));
+ vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1c0000, vb0123c0));
+ vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2c0000, vb0123c0));
+ vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3c0000, vb0123c0));
+ vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c0000, vb4567c0));
+ vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1c0000, vb4567c0));
+ vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2c0000, vb4567c0));
+ vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3c0000, vb4567c0));
+
+ // k step 1: broadcast lane 1.
+ const __m128 va0c1111 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(1, 1, 1, 1))
;
+ const __m128 va1c1111 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(1, 1, 1, 1));
+ const __m128 va2c1111 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(1, 1, 1, 1));
+ const __m128 va3c1111 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(1, 1, 1, 1));
+
+ const __m128 vb0123c1 = _mm_load_ps(w + 8);
+ const __m128 vb4567c1 = _mm_load_ps(w + 12);
+
+ vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c1111, vb0123c1));
+ vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1c1111, vb0123c1));
+ vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2c1111, vb0123c1));
+ vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3c1111, vb0123c1));
+ vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c1111, vb4567c1));
+ vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1c1111, vb4567c1));
+ vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2c1111, vb4567c1));
+ vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3c1111, vb4567c1));
+
+ // k step 2: broadcast lane 2.
+ const __m128 va0c2222 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(2, 2, 2, 2));
+ const __m128 va1c2222 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(2, 2, 2, 2));
+ const __m128 va2c2222 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(2, 2, 2, 2));
+ const __m128 va3c2222 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(2, 2, 2, 2));
+
+ const __m128 vb0123c2 = _mm_load_ps(w + 16);
+ const __m128 vb4567c2 = _mm_load_ps(w + 20);
+
+ vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c2222, vb0123c2));
+ vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1c2222, vb0123c2));
+ vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2c2222, vb0123c2));
+ vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3c2222, vb0123c2));
+ vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c2222, vb4567c2));
+ vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1c2222, vb4567c2));
+ vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2c2222, vb4567c2));
+ vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3c2222, vb4567c2));
+
+ // k step 3: broadcast lane 3.
+ const __m128 va0c3333 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(3, 3, 3, 3));
+ const __m128 va1c3333 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(3, 3, 3, 3));
+ const __m128 va2c3333 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(3, 3, 3, 3));
+ const __m128 va3c3333 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(3, 3, 3, 3));
+
+ const __m128 vb0123c3 = _mm_load_ps(w + 24);
+ const __m128 vb4567c3 = _mm_load_ps(w + 28);
+
+ vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c3333, vb0123c3));
+ vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1c3333, vb0123c3));
+ vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2c3333, vb0123c3));
+ vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3c3333, vb0123c3));
+ vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c3333, vb4567c3));
+ vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1c3333, vb4567c3));
+ vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2c3333, vb4567c3));
+ vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3c3333, vb4567c3));
+
+ w += 32;
+ k -= 4 * sizeof(float);
+ }
+ // Remainder loop: one k step at a time for kc % 4 leftover floats.
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ const __m128 va0 = _mm_load1_ps(a0);
+ a0 += 1;
+ const __m128 va1 = _mm_load1_ps(a1);
+ a1 += 1;
+ const __m128 va2 = _mm_load1_ps(a2);
+ a2 += 1;
+ const __m128 va3 = _mm_load1_ps(a3);
+ a3 += 1;
+
+ const __m128 vb0123 = _mm_load_ps(w);
+ const __m128 vb4567 = _mm_load_ps(w + 4);
+ w += 8;
+
+ vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+ vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123));
+ vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123));
+ vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123));
+ vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+ vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567));
+ vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567));
+ vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567));
+
+ k -= sizeof(float);
+ } while (k != 0);
+ }
+
+ // Clamp the accumulators to [params->sse.min, params->sse.max].
+ const __m128 vmax = _mm_load_ps(params->sse.max);
+ vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+ vacc1x0123 = _mm_min_ps(vacc1x0123, vmax);
+ vacc2x0123 = _mm_min_ps(vacc2x0123, vmax);
+ vacc3x0123 = _mm_min_ps(vacc3x0123, vmax);
+ vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+ vacc1x4567 = _mm_min_ps(vacc1x4567, vmax);
+ vacc2x4567 = _mm_min_ps(vacc2x4567, vmax);
+ vacc3x4567 = _mm_min_ps(vacc3x4567, vmax);
+
+ const __m128 vmin = _mm_load_ps(params->sse.min);
+ vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+ vacc1x0123 = _mm_max_ps(vacc1x0123, vmin);
+ vacc2x0123 = _mm_max_ps(vacc2x0123, vmin);
+ vacc3x0123 = _mm_max_ps(vacc3x0123, vmin);
+ vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+ vacc1x4567 = _mm_max_ps(vacc1x4567, vmin);
+ vacc2x4567 = _mm_max_ps(vacc2x4567, vmin);
+ vacc3x4567 = _mm_max_ps(vacc3x4567, vmin);
+
+ // Full 8-column store: advance c by cn_stride (bytes) and rewind the a
+ // pointers by kc bytes for the next column block over the same A rows.
+ if XNN_LIKELY(nc >= 8) {
+ _mm_storeu_ps(c3, vacc3x0123);
+ _mm_storeu_ps(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ _mm_storeu_ps(c2, vacc2x0123);
+ _mm_storeu_ps(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ _mm_storeu_ps(c1, vacc1x0123);
+ _mm_storeu_ps(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ _mm_storeu_ps(c0, vacc0x0123);
+ _mm_storeu_ps(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+ } else {
+ // Tail: store 4, then 2, then 1 columns per the bits of nc, shifting
+ // the remaining lanes down after each partial store.
+ if (nc & 4) {
+ _mm_storeu_ps(c3, vacc3x0123);
+ _mm_storeu_ps(c2, vacc2x0123);
+ _mm_storeu_ps(c1, vacc1x0123);
+ _mm_storeu_ps(c0, vacc0x0123);
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+
+ c3 += 4;
+ c2 += 4;
+ c1 += 4;
+ c0 += 4;
+ }
+ if (nc & 2) {
+ _mm_storel_pi((__m64*) c3, vacc3x0123);
+ _mm_storel_pi((__m64*) c2, vacc2x0123);
+ _mm_storel_pi((__m64*) c1, vacc1x0123);
+ _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+ // Move the high pair of lanes down for the possible 1-column store.
+ vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
+ vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
+ vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
+ vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+ c3 += 2;
+ c2 += 2;
+ c1 += 2;
+ c0 += 2;
+ }
+ if (nc & 1) {
+ _mm_store_ss(c3, vacc3x0123);
+ _mm_store_ss(c2, vacc2x0123);
+ _mm_store_ss(c1, vacc1x0123);
+ _mm_store_ss(c0, vacc0x0123);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x8-sse-load1.c b/src/f32-gemm/4x8-sse-load1.c
new file mode 100644
index 0000000..4afb941
--- /dev/null
+++ b/src/f32-gemm/4x8-sse-load1.c
@@ -0,0 +1,180 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/sse-load1.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// f32 GEMM micro-kernel computing a 4-row by 8-column tile of C using SSE.
+// "load1" variant: each k step broadcasts a single A float per row with
+// _mm_load1_ps (no unrolling over k). w is the packed-weights stream: 8
+// floats consumed before the k loop (presumably the bias — TODO confirm
+// against the weight-packing routine), then 8 floats of B per k step.
+// kc, a_stride, cm_stride, cn_stride are byte counts.
+void xnn_f32_gemm_ukernel_4x8__sse_load1(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float*restrict a,
+ size_t a_stride,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ // Per-row A/C pointers; rows beyond mr alias the previous row so the
+ // extra work duplicates a valid row (stores run c3..c0, lowest row last).
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ // Outer loop over 8-column blocks of the output.
+ do {
+ // All 4 rows start from the same 8 packed floats at the head of w.
+ __m128 vacc0x0123 = _mm_load_ps(w + 0);
+ __m128 vacc0x4567 = _mm_load_ps(w + 4);
+ __m128 vacc1x0123 = vacc0x0123;
+ __m128 vacc1x4567 = vacc0x4567;
+ __m128 vacc2x0123 = vacc0x0123;
+ __m128 vacc2x4567 = vacc0x4567;
+ __m128 vacc3x0123 = vacc0x0123;
+ __m128 vacc3x4567 = vacc0x4567;
+ w += 8;
+
+ // One k step per iteration: broadcast one A float per row, multiply
+ // against the 8 packed B floats, accumulate.
+ size_t k = kc;
+ do {
+ const __m128 va0 = _mm_load1_ps(a0);
+ a0 += 1;
+ const __m128 va1 = _mm_load1_ps(a1);
+ a1 += 1;
+ const __m128 va2 = _mm_load1_ps(a2);
+ a2 += 1;
+ const __m128 va3 = _mm_load1_ps(a3);
+ a3 += 1;
+
+ const __m128 vb0123 = _mm_load_ps(w);
+ const __m128 vb4567 = _mm_load_ps(w + 4);
+ w += 8;
+
+ vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+ vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123));
+ vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123));
+ vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123));
+ vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+ vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567));
+ vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567));
+ vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567));
+
+ k -= sizeof(float);
+ } while (k != 0);
+
+ // Clamp the accumulators to [params->sse.min, params->sse.max].
+ const __m128 vmax = _mm_load_ps(params->sse.max);
+ vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+ vacc1x0123 = _mm_min_ps(vacc1x0123, vmax);
+ vacc2x0123 = _mm_min_ps(vacc2x0123, vmax);
+ vacc3x0123 = _mm_min_ps(vacc3x0123, vmax);
+ vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+ vacc1x4567 = _mm_min_ps(vacc1x4567, vmax);
+ vacc2x4567 = _mm_min_ps(vacc2x4567, vmax);
+ vacc3x4567 = _mm_min_ps(vacc3x4567, vmax);
+
+ const __m128 vmin = _mm_load_ps(params->sse.min);
+ vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+ vacc1x0123 = _mm_max_ps(vacc1x0123, vmin);
+ vacc2x0123 = _mm_max_ps(vacc2x0123, vmin);
+ vacc3x0123 = _mm_max_ps(vacc3x0123, vmin);
+ vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+ vacc1x4567 = _mm_max_ps(vacc1x4567, vmin);
+ vacc2x4567 = _mm_max_ps(vacc2x4567, vmin);
+ vacc3x4567 = _mm_max_ps(vacc3x4567, vmin);
+
+ // Full 8-column store: advance c by cn_stride and rewind a by kc bytes
+ // so the next column block reuses the same A rows.
+ if XNN_LIKELY(nc >= 8) {
+ _mm_storeu_ps(c3, vacc3x0123);
+ _mm_storeu_ps(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ _mm_storeu_ps(c2, vacc2x0123);
+ _mm_storeu_ps(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ _mm_storeu_ps(c1, vacc1x0123);
+ _mm_storeu_ps(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ _mm_storeu_ps(c0, vacc0x0123);
+ _mm_storeu_ps(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+ } else {
+ // Tail: store 4/2/1 columns per the bits of nc, shifting remaining
+ // lanes down after each partial store.
+ if (nc & 4) {
+ _mm_storeu_ps(c3, vacc3x0123);
+ _mm_storeu_ps(c2, vacc2x0123);
+ _mm_storeu_ps(c1, vacc1x0123);
+ _mm_storeu_ps(c0, vacc0x0123);
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+
+ c3 += 4;
+ c2 += 4;
+ c1 += 4;
+ c0 += 4;
+ }
+ if (nc & 2) {
+ _mm_storel_pi((__m64*) c3, vacc3x0123);
+ _mm_storel_pi((__m64*) c2, vacc2x0123);
+ _mm_storel_pi((__m64*) c1, vacc1x0123);
+ _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+ vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
+ vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
+ vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
+ vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+ c3 += 2;
+ c2 += 2;
+ c1 += 2;
+ c0 += 2;
+ }
+ if (nc & 1) {
+ _mm_store_ss(c3, vacc3x0123);
+ _mm_store_ss(c2, vacc2x0123);
+ _mm_store_ss(c1, vacc1x0123);
+ _mm_store_ss(c0, vacc0x0123);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x8s4-psimd.c b/src/f32-gemm/4x8s4-psimd.c
new file mode 100644
index 0000000..da2fb9f
--- /dev/null
+++ b/src/f32-gemm/4x8s4-psimd.c
@@ -0,0 +1,260 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-s4.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+// f32 GEMM micro-kernel computing a 4-row by 8-column tile of C with psimd.
+// "s4" variant: instead of broadcasting A lanes, the 4-float A vectors are
+// rotated one lane per k step (via __builtin_shufflevector) against B data
+// that was pre-shuffled to match at packing time. w is the packed-weights
+// stream: 8 floats consumed before the k loop (presumably the bias — TODO
+// confirm against the weight-packing routine), then 8 floats of B per
+// k step. kc, a_stride, cm_stride, cn_stride are byte counts.
+//
+// FIX(review): the clamp loads below read `&params->scalar.max/min`; the
+// previous text contained the mojibake `¶ms->...` (HTML-entity-mangled
+// `&params`), which does not compile.
+void xnn_f32_gemm_ukernel_4x8s4__psimd(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float*restrict a,
+ size_t a_stride,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ // Per-row A/C pointers; rows beyond mr alias the previous row so the
+ // extra work duplicates a valid row (stores run c3..c0, lowest row last).
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ // Outer loop over 8-column blocks of the output.
+ do {
+ // All 4 rows start from the same 8 packed floats at the head of w.
+ psimd_f32 vacc0x0123 = psimd_load_f32(w + 0);
+ psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+ psimd_f32 vacc1x0123 = vacc0x0123;
+ psimd_f32 vacc1x4567 = vacc0x4567;
+ psimd_f32 vacc2x0123 = vacc0x0123;
+ psimd_f32 vacc2x4567 = vacc0x4567;
+ psimd_f32 vacc3x0123 = vacc0x0123;
+ psimd_f32 vacc3x4567 = vacc0x4567;
+ w += 8;
+
+ // Main loop: 4 k steps per iteration; the A vectors are rotated one
+ // lane left between steps to line up with the pre-shuffled B data.
+ size_t k = kc;
+ while (k >= 4 * sizeof(float)) {
+ psimd_f32 va0 = psimd_load_f32(a0);
+ a0 += 4;
+ psimd_f32 va1 = psimd_load_f32(a1);
+ a1 += 4;
+ psimd_f32 va2 = psimd_load_f32(a2);
+ a2 += 4;
+ psimd_f32 va3 = psimd_load_f32(a3);
+ a3 += 4;
+
+
+ const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+ const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c0);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c0);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c0);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c0);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c0);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c0);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c0);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c0);
+
+ // Rotate A lanes left by one: (a1, a2, a3, a0).
+ va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+ va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+ va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+ va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+
+ const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+ const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c1);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c1);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c1);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c1);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c1);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c1);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c1);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c1);
+
+ va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+ va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+ va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+ va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+
+ const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+ const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c2);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c2);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c2);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c2);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c2);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c2);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c2);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c2);
+
+ va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+ va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+ va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+ va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+
+ const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+ const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c3);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c3);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c3);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c3);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c3);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c3);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c3);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c3);
+
+
+ w += 32;
+ k -= 4 * sizeof(float);
+ }
+ // Remainder loop: one k step at a time, broadcasting single A floats.
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ const psimd_f32 va0 = psimd_load_splat_f32(a0);
+ a0 += 1;
+ const psimd_f32 va1 = psimd_load_splat_f32(a1);
+ a1 += 1;
+ const psimd_f32 va2 = psimd_load_splat_f32(a2);
+ a2 += 1;
+ const psimd_f32 va3 = psimd_load_splat_f32(a3);
+ a3 += 1;
+
+ const psimd_f32 vb0123 = psimd_load_f32(w);
+ const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+ w += 8;
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+
+ k -= sizeof(float);
+ } while (k != 0);
+ }
+
+ // Clamp accumulators to [params->scalar.min, params->scalar.max].
+ const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+ vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+ vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+ vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+ vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+ vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+ vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+ vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+ vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+
+ const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+ vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+ vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+ vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+ vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+ vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+ vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+ vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+ vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+
+ // Full 8-column store: advance c by cn_stride and rewind a by kc bytes
+ // so the next column block reuses the same A rows.
+ if XNN_LIKELY(nc >= 8) {
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ psimd_store_f32(c0, vacc0x0123);
+ psimd_store_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+ } else {
+ // Tail: store 4/2/1 columns per the bits of nc, shifting remaining
+ // lanes down after each partial store.
+ if (nc & 4) {
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c0, vacc0x0123);
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+
+ c3 += 4;
+ c2 += 4;
+ c1 += 4;
+ c0 += 4;
+ }
+ if (nc & 2) {
+ psimd_store2_f32(c3, vacc3x0123);
+ psimd_store2_f32(c2, vacc2x0123);
+ psimd_store2_f32(c1, vacc1x0123);
+ psimd_store2_f32(c0, vacc0x0123);
+
+ vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+ vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+ vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+ vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+ c3 += 2;
+ c2 += 2;
+ c1 += 2;
+ c0 += 2;
+ }
+ if (nc & 1) {
+ psimd_store1_f32(c3, vacc3x0123);
+ psimd_store1_f32(c2, vacc2x0123);
+ psimd_store1_f32(c1, vacc1x0123);
+ psimd_store1_f32(c0, vacc0x0123);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/4x8s4-sse.c b/src/f32-gemm/4x8s4-sse.c
new file mode 100644
index 0000000..9f09cea
--- /dev/null
+++ b/src/f32-gemm/4x8s4-sse.c
@@ -0,0 +1,260 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/sse-shuffle.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// f32 GEMM micro-kernel computing a 4-row by 8-column tile of C using SSE.
+// "s4" variant: the 4-float A vectors are rotated one lane per k step with
+// _mm_shuffle_ps against B data that was pre-shuffled to match at packing
+// time (no per-lane broadcast in the main loop). w is the packed-weights
+// stream: 8 floats consumed before the k loop (presumably the bias — TODO
+// confirm against the weight-packing routine), then 8 floats of B per
+// k step. kc, a_stride, cm_stride, cn_stride are byte counts.
+void xnn_f32_gemm_ukernel_4x8s4__sse(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float*restrict a,
+ size_t a_stride,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ // Per-row A/C pointers; rows beyond mr alias the previous row so the
+ // extra work duplicates a valid row (stores run c3..c0, lowest row last).
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ // Outer loop over 8-column blocks of the output.
+ do {
+ // All 4 rows start from the same 8 packed floats at the head of w.
+ __m128 vacc0x0123 = _mm_load_ps(w + 0);
+ __m128 vacc0x4567 = _mm_load_ps(w + 4);
+ __m128 vacc1x0123 = vacc0x0123;
+ __m128 vacc1x4567 = vacc0x4567;
+ __m128 vacc2x0123 = vacc0x0123;
+ __m128 vacc2x4567 = vacc0x4567;
+ __m128 vacc3x0123 = vacc0x0123;
+ __m128 vacc3x4567 = vacc0x4567;
+ w += 8;
+
+ // Main loop: 4 k steps per iteration; A vectors are rotated one lane
+ // left between steps (shuffle 0,3,2,1) to align with pre-shuffled B.
+ size_t k = kc;
+ while (k >= 4 * sizeof(float)) {
+ __m128 va0 = _mm_loadu_ps(a0);
+ a0 += 4;
+ __m128 va1 = _mm_loadu_ps(a1);
+ a1 += 4;
+ __m128 va2 = _mm_loadu_ps(a2);
+ a2 += 4;
+ __m128 va3 = _mm_loadu_ps(a3);
+ a3 += 4;
+
+
+ const __m128 vb0123c0 = _mm_load_ps(w + 0);
+ const __m128 vb4567c0 = _mm_load_ps(w + 4);
+
+ vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c0));
+ vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123c0));
+ vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123c0));
+ vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123c0));
+ vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c0));
+ vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567c0));
+ vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567c0));
+ vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567c0));
+
+ // Rotate A lanes left by one: (a1, a2, a3, a0).
+ va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+ va1 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(0, 3, 2, 1));
+ va2 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(0, 3, 2, 1));
+ va3 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(0, 3, 2, 1));
+
+ const __m128 vb0123c1 = _mm_load_ps(w + 8);
+ const __m128 vb4567c1 = _mm_load_ps(w + 12);
+
+ vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c1));
+ vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123c1));
+ vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123c1));
+ vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123c1));
+ vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c1));
+ vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567c1));
+ vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567c1));
+ vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567c1));
+
+ va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+ va1 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(0, 3, 2, 1));
+ va2 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(0, 3, 2, 1));
+ va3 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(0, 3, 2, 1));
+
+ const __m128 vb0123c2 = _mm_load_ps(w + 16);
+ const __m128 vb4567c2 = _mm_load_ps(w + 20);
+
+ vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c2));
+ vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123c2));
+ vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123c2));
+ vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123c2));
+ vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c2));
+ vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567c2));
+ vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567c2));
+ vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567c2));
+
+ va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+ va1 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(0, 3, 2, 1));
+ va2 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(0, 3, 2, 1));
+ va3 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(0, 3, 2, 1));
+
+ const __m128 vb0123c3 = _mm_load_ps(w + 24);
+ const __m128 vb4567c3 = _mm_load_ps(w + 28);
+
+ vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c3));
+ vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123c3));
+ vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123c3));
+ vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123c3));
+ vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c3));
+ vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567c3));
+ vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567c3));
+ vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567c3));
+
+
+ w += 32;
+ k -= 4 * sizeof(float);
+ }
+ // Remainder loop: one k step at a time, broadcasting single A floats.
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ const __m128 va0 = _mm_load1_ps(a0);
+ a0 += 1;
+ const __m128 va1 = _mm_load1_ps(a1);
+ a1 += 1;
+ const __m128 va2 = _mm_load1_ps(a2);
+ a2 += 1;
+ const __m128 va3 = _mm_load1_ps(a3);
+ a3 += 1;
+
+ const __m128 vb0123 = _mm_load_ps(w);
+ const __m128 vb4567 = _mm_load_ps(w + 4);
+ w += 8;
+
+ vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+ vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123));
+ vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123));
+ vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123));
+ vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+ vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567));
+ vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567));
+ vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567));
+
+ k -= sizeof(float);
+ } while (k != 0);
+ }
+
+ // Clamp the accumulators to [params->sse.min, params->sse.max].
+ const __m128 vmax = _mm_load_ps(params->sse.max);
+ vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+ vacc1x0123 = _mm_min_ps(vacc1x0123, vmax);
+ vacc2x0123 = _mm_min_ps(vacc2x0123, vmax);
+ vacc3x0123 = _mm_min_ps(vacc3x0123, vmax);
+ vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+ vacc1x4567 = _mm_min_ps(vacc1x4567, vmax);
+ vacc2x4567 = _mm_min_ps(vacc2x4567, vmax);
+ vacc3x4567 = _mm_min_ps(vacc3x4567, vmax);
+
+ const __m128 vmin = _mm_load_ps(params->sse.min);
+ vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+ vacc1x0123 = _mm_max_ps(vacc1x0123, vmin);
+ vacc2x0123 = _mm_max_ps(vacc2x0123, vmin);
+ vacc3x0123 = _mm_max_ps(vacc3x0123, vmin);
+ vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+ vacc1x4567 = _mm_max_ps(vacc1x4567, vmin);
+ vacc2x4567 = _mm_max_ps(vacc2x4567, vmin);
+ vacc3x4567 = _mm_max_ps(vacc3x4567, vmin);
+
+ // Full 8-column store: advance c by cn_stride and rewind a by kc bytes
+ // so the next column block reuses the same A rows.
+ if XNN_LIKELY(nc >= 8) {
+ _mm_storeu_ps(c3, vacc3x0123);
+ _mm_storeu_ps(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ _mm_storeu_ps(c2, vacc2x0123);
+ _mm_storeu_ps(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ _mm_storeu_ps(c1, vacc1x0123);
+ _mm_storeu_ps(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ _mm_storeu_ps(c0, vacc0x0123);
+ _mm_storeu_ps(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+ } else {
+ // Tail: store 4/2/1 columns per the bits of nc, shifting remaining
+ // lanes down after each partial store.
+ if (nc & 4) {
+ _mm_storeu_ps(c3, vacc3x0123);
+ _mm_storeu_ps(c2, vacc2x0123);
+ _mm_storeu_ps(c1, vacc1x0123);
+ _mm_storeu_ps(c0, vacc0x0123);
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+
+ c3 += 4;
+ c2 += 4;
+ c1 += 4;
+ c0 += 4;
+ }
+ if (nc & 2) {
+ _mm_storel_pi((__m64*) c3, vacc3x0123);
+ _mm_storel_pi((__m64*) c2, vacc2x0123);
+ _mm_storel_pi((__m64*) c1, vacc1x0123);
+ _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+ vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
+ vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
+ vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
+ vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+ c3 += 2;
+ c2 += 2;
+ c1 += 2;
+ c0 += 2;
+ }
+ if (nc & 1) {
+ _mm_store_ss(c3, vacc3x0123);
+ _mm_store_ss(c2, vacc2x0123);
+ _mm_store_ss(c1, vacc1x0123);
+ _mm_store_ss(c0, vacc0x0123);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/5x8-aarch64-neonfma-cortex-a75.S b/src/f32-gemm/5x8-aarch64-neonfma-cortex-a75.S
new file mode 100644
index 0000000..d6014d9
--- /dev/null
+++ b/src/f32-gemm/5x8-aarch64-neonfma-cortex-a75.S
@@ -0,0 +1,584 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/5x8-aarch64-neonfma-cortex-a75.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_5x8__aarch64_neonfma_cortex_a75(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# unused compared to 5x8
+# x4 a5
+# x7 c5
+# A5 v10 v11
+# C v30 v31
+
+# d8-d15 need to be preserved if used.
+# x19-x30 need to be preserved if used. x18 is reserved for OS.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x13 c3
+# x7 c4
+
+# Vector register usage
+# A0 v0 v1
+# A1 v2 v3
+# A2 v4 v5
+# A3 v6 v7
+# A4 v8 v9
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# Clamp v30 v31
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_5x8__aarch64_neonfma_cortex_a75
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -48]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d12, d13, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d14, d15, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x13, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x13, x17, x13, LO // c3 = c2
+
+ # Load params pointer
+ LDR x8, [sp, 56]
+
+        ADD x12, x11, x4 // a4 = a3 + a_stride
+        ADD x7, x13, x7 // c4 = c3 + cm_stride
+        // if mr <= 4 (flags still set from CMP x0, 4 above; LS = mr <= 4)
+        CSEL x12, x11, x12, LS //   a4 = a3
+        CSEL x7, x13, x7, LS //   c4 = c3
+
+ # Load clamp values
+ LD2R {v30.4s, v31.4s}, [x8]
+
+ # Load cn_stride
+ LDR x14, [sp, 48]
+
+0:
+ # Load initial bias from w into accumulators
+ LDP q20, q21, [x5], 32
+ MOV v22.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v24.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v26.16b, v20.16b
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x9]
+ MOV v28.16b, v20.16b
+ PRFM PLDL1KEEP, [x10]
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x11]
+ PRFM PLDL1KEEP, [x12]
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 4f
+
+ # Prologue - loads for main loop of 80 FMA
+ LDR q0, [x3], 16
+ LDR q2, [x9], 16
+ LDR q4, [x10], 16
+ LDR q6, [x11], 16
+ LDR q8, [x12], 16
+ LDP q12, q13, [x5], 32 // Fetch 3 B (4th deferred)
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+
+ # Is there at least 8 floats (32 bytes) for main loop?
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+ # 80 FMA + 5 LDP A + 8 LDP B
+1:
+ # First group of 4 A. 40 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v2.s[0]
+ FMLA v24.4s, v12.4s, v4.s[0]
+ FMLA v26.4s, v12.4s, v6.s[0]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v28.4s, v12.4s, v8.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v2.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v25.4s, v13.4s, v4.s[0]
+ FMLA v27.4s, v13.4s, v6.s[0]
+ FMLA v29.4s, v13.4s, v8.s[0]
+ LDR q1, [x3], 16 // Load next 5 A
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v2.s[1]
+ FMLA v24.4s, v14.4s, v4.s[1]
+ LDR q3, [x9], 16
+ FMLA v26.4s, v14.4s, v6.s[1]
+ FMLA v28.4s, v14.4s, v8.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ LDR q5, [x10], 16
+ FMLA v23.4s, v15.4s, v2.s[1]
+ FMLA v25.4s, v15.4s, v4.s[1]
+ FMLA v27.4s, v15.4s, v6.s[1]
+ LDR q7, [x11], 16
+ FMLA v29.4s, v15.4s, v8.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v2.s[2]
+ LDR q9, [x12], 16
+ FMLA v24.4s, v16.4s, v4.s[2]
+ FMLA v26.4s, v16.4s, v6.s[2]
+ FMLA v28.4s, v16.4s, v8.s[2]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v2.s[2]
+ FMLA v25.4s, v17.4s, v4.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v27.4s, v17.4s, v6.s[2]
+ FMLA v29.4s, v17.4s, v8.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v22.4s, v18.4s, v2.s[3]
+ FMLA v24.4s, v18.4s, v4.s[3]
+ FMLA v26.4s, v18.4s, v6.s[3]
+ FMLA v28.4s, v18.4s, v8.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v2.s[3]
+ FMLA v25.4s, v19.4s, v4.s[3]
+ FMLA v27.4s, v19.4s, v6.s[3]
+ FMLA v29.4s, v19.4s, v8.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 40 FMA.
+ FMLA v20.4s, v12.4s, v1.s[0]
+ FMLA v22.4s, v12.4s, v3.s[0]
+ FMLA v24.4s, v12.4s, v5.s[0]
+ LDR q0, [x3], 16 // Load next 5 A
+ FMLA v26.4s, v12.4s, v7.s[0]
+ FMLA v28.4s, v12.4s, v9.s[0]
+ FMLA v21.4s, v13.4s, v1.s[0]
+ LDR q2, [x9], 16
+ FMLA v23.4s, v13.4s, v3.s[0]
+ FMLA v25.4s, v13.4s, v5.s[0]
+ FMLA v27.4s, v13.4s, v7.s[0]
+ LDR q4, [x10], 16
+ FMLA v29.4s, v13.4s, v9.s[0]
+
+ FMLA v20.4s, v14.4s, v1.s[1]
+ FMLA v22.4s, v14.4s, v3.s[1]
+ LDR q6, [x11], 16
+ FMLA v24.4s, v14.4s, v5.s[1]
+ FMLA v26.4s, v14.4s, v7.s[1]
+ FMLA v28.4s, v14.4s, v9.s[1]
+ LDR q8, [x12], 16
+ FMLA v21.4s, v15.4s, v1.s[1]
+ FMLA v23.4s, v15.4s, v3.s[1]
+ FMLA v25.4s, v15.4s, v5.s[1]
+ LDP q12, q13, [x5], 32 // Load next 3 B (not last)
+ FMLA v27.4s, v15.4s, v7.s[1]
+ FMLA v29.4s, v15.4s, v9.s[1]
+
+ FMLA v20.4s, v16.4s, v1.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v22.4s, v16.4s, v3.s[2]
+ FMLA v24.4s, v16.4s, v5.s[2]
+ FMLA v26.4s, v16.4s, v7.s[2]
+ FMLA v28.4s, v16.4s, v9.s[2]
+ FMLA v21.4s, v17.4s, v1.s[2]
+ FMLA v23.4s, v17.4s, v3.s[2]
+ FMLA v25.4s, v17.4s, v5.s[2]
+ FMLA v27.4s, v17.4s, v7.s[2]
+ FMLA v29.4s, v17.4s, v9.s[2]
+ LDP q16, q17, [x5], 32
+
+ FMLA v20.4s, v18.4s, v1.s[3]
+ FMLA v22.4s, v18.4s, v3.s[3]
+ SUBS x0, x0, 32
+ FMLA v24.4s, v18.4s, v5.s[3]
+ FMLA v26.4s, v18.4s, v7.s[3]
+ FMLA v28.4s, v18.4s, v9.s[3]
+ FMLA v21.4s, v19.4s, v1.s[3]
+ FMLA v23.4s, v19.4s, v3.s[3]
+ FMLA v25.4s, v19.4s, v5.s[3]
+ FMLA v27.4s, v19.4s, v7.s[3]
+ FMLA v29.4s, v19.4s, v9.s[3]
+ B.HS 1b
+
+ # Epilogue - 8 floats of A (32 bytes)
+ # 80 FMA + 5 LDP A + 8 LDP B
+ # First block same as main loop. Second block has no preloads.
+2:
+ # First group of 4 A. 40 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v2.s[0]
+ FMLA v24.4s, v12.4s, v4.s[0]
+ FMLA v26.4s, v12.4s, v6.s[0]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v28.4s, v12.4s, v8.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v2.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v25.4s, v13.4s, v4.s[0]
+ FMLA v27.4s, v13.4s, v6.s[0]
+ FMLA v29.4s, v13.4s, v8.s[0]
+ LDR q1, [x3], 16 // Load next 5 A
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v2.s[1]
+ FMLA v24.4s, v14.4s, v4.s[1]
+ LDR q3, [x9], 16
+ FMLA v26.4s, v14.4s, v6.s[1]
+ FMLA v28.4s, v14.4s, v8.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ LDR q5, [x10], 16
+ FMLA v23.4s, v15.4s, v2.s[1]
+ FMLA v25.4s, v15.4s, v4.s[1]
+ FMLA v27.4s, v15.4s, v6.s[1]
+ LDR q7, [x11], 16
+ FMLA v29.4s, v15.4s, v8.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v2.s[2]
+ LDR q9, [x12], 16
+ FMLA v24.4s, v16.4s, v4.s[2]
+ FMLA v26.4s, v16.4s, v6.s[2]
+ FMLA v28.4s, v16.4s, v8.s[2]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v2.s[2]
+ FMLA v25.4s, v17.4s, v4.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v27.4s, v17.4s, v6.s[2]
+ FMLA v29.4s, v17.4s, v8.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v22.4s, v18.4s, v2.s[3]
+ FMLA v24.4s, v18.4s, v4.s[3]
+ FMLA v26.4s, v18.4s, v6.s[3]
+ FMLA v28.4s, v18.4s, v8.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v2.s[3]
+ FMLA v25.4s, v19.4s, v4.s[3]
+ FMLA v27.4s, v19.4s, v6.s[3]
+ FMLA v29.4s, v19.4s, v8.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 40 FMA.
+ FMLA v20.4s, v12.4s, v1.s[0]
+ FMLA v22.4s, v12.4s, v3.s[0]
+ FMLA v24.4s, v12.4s, v5.s[0]
+ FMLA v26.4s, v12.4s, v7.s[0]
+ FMLA v28.4s, v12.4s, v9.s[0]
+ FMLA v21.4s, v13.4s, v1.s[0]
+ FMLA v23.4s, v13.4s, v3.s[0]
+ FMLA v25.4s, v13.4s, v5.s[0]
+ FMLA v27.4s, v13.4s, v7.s[0]
+ FMLA v29.4s, v13.4s, v9.s[0]
+
+ FMLA v20.4s, v14.4s, v1.s[1]
+ FMLA v22.4s, v14.4s, v3.s[1]
+ FMLA v24.4s, v14.4s, v5.s[1]
+ FMLA v26.4s, v14.4s, v7.s[1]
+ FMLA v28.4s, v14.4s, v9.s[1]
+ FMLA v21.4s, v15.4s, v1.s[1]
+ FMLA v23.4s, v15.4s, v3.s[1]
+ FMLA v25.4s, v15.4s, v5.s[1]
+ FMLA v27.4s, v15.4s, v7.s[1]
+ FMLA v29.4s, v15.4s, v9.s[1]
+
+ FMLA v20.4s, v16.4s, v1.s[2]
+ FMLA v22.4s, v16.4s, v3.s[2]
+ FMLA v24.4s, v16.4s, v5.s[2]
+ FMLA v26.4s, v16.4s, v7.s[2]
+ FMLA v28.4s, v16.4s, v9.s[2]
+ FMLA v21.4s, v17.4s, v1.s[2]
+ FMLA v23.4s, v17.4s, v3.s[2]
+ FMLA v25.4s, v17.4s, v5.s[2]
+ FMLA v27.4s, v17.4s, v7.s[2]
+ FMLA v29.4s, v17.4s, v9.s[2]
+ TST x0, 31
+
+ FMLA v20.4s, v18.4s, v1.s[3]
+ FMLA v22.4s, v18.4s, v3.s[3]
+ FMLA v24.4s, v18.4s, v5.s[3]
+ FMLA v26.4s, v18.4s, v7.s[3]
+ FMLA v28.4s, v18.4s, v9.s[3]
+ FMLA v21.4s, v19.4s, v1.s[3]
+ FMLA v23.4s, v19.4s, v3.s[3]
+ FMLA v25.4s, v19.4s, v5.s[3]
+ FMLA v27.4s, v19.4s, v7.s[3]
+ FMLA v29.4s, v19.4s, v9.s[3]
+ B.NE 4f
+
+ # Clamp
+3:
+ FMIN v20.4s, v20.4s, v30.4s
+ FMIN v21.4s, v21.4s, v30.4s
+ FMIN v22.4s, v22.4s, v30.4s
+ FMIN v23.4s, v23.4s, v30.4s
+ FMIN v24.4s, v24.4s, v30.4s
+ FMIN v25.4s, v25.4s, v30.4s
+ FMIN v26.4s, v26.4s, v30.4s
+ FMIN v27.4s, v27.4s, v30.4s
+ FMIN v28.4s, v28.4s, v30.4s
+ FMIN v29.4s, v29.4s, v30.4s
+ FMAX v20.4s, v20.4s, v31.4s
+ FMAX v21.4s, v21.4s, v31.4s
+ FMAX v22.4s, v22.4s, v31.4s
+ FMAX v23.4s, v23.4s, v31.4s
+ FMAX v24.4s, v24.4s, v31.4s
+ FMAX v25.4s, v25.4s, v31.4s
+ FMAX v26.4s, v26.4s, v31.4s
+ FMAX v27.4s, v27.4s, v31.4s
+ FMAX v28.4s, v28.4s, v31.4s
+ FMAX v29.4s, v29.4s, v31.4s
+
+ # Store full 5 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q26, q27, [x13]
+ ADD x13, x13, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q28, q29, [x7]
+ ADD x7, x7, x14
+ SUB x12, x12, x2 // a4 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 32]
+ LDP d12, d13, [sp, 16]
+ LDP d8, d9, [sp], 48
+ RET
+
+        # Remainder handling: kc is not a multiple of 8 floats (32 bytes)
+4:
+        # Is there a remainder?- 4 floats of A (16 bytes)
+        TBZ x0, 4, 5f
+
+ # Remainder- 4 floats of A (16 bytes)
+ # Load A
+ LDR q0, [x3], 16
+ LDR q2, [x9], 16
+ LDR q4, [x10], 16
+ LDR q6, [x11], 16
+ LDR q8, [x12], 16
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+ LDP q18, q19, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v2.s[0]
+ FMLA v24.4s, v12.4s, v4.s[0]
+ FMLA v26.4s, v12.4s, v6.s[0]
+ FMLA v28.4s, v12.4s, v8.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v2.s[0]
+ FMLA v25.4s, v13.4s, v4.s[0]
+ FMLA v27.4s, v13.4s, v6.s[0]
+ FMLA v29.4s, v13.4s, v8.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v2.s[1]
+ FMLA v24.4s, v14.4s, v4.s[1]
+ FMLA v26.4s, v14.4s, v6.s[1]
+ FMLA v28.4s, v14.4s, v8.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v2.s[1]
+ FMLA v25.4s, v15.4s, v4.s[1]
+ FMLA v27.4s, v15.4s, v6.s[1]
+ FMLA v29.4s, v15.4s, v8.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v2.s[2]
+ FMLA v24.4s, v16.4s, v4.s[2]
+ FMLA v26.4s, v16.4s, v6.s[2]
+ FMLA v28.4s, v16.4s, v8.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v2.s[2]
+ FMLA v25.4s, v17.4s, v4.s[2]
+ FMLA v27.4s, v17.4s, v6.s[2]
+ FMLA v29.4s, v17.4s, v8.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v2.s[3]
+ FMLA v24.4s, v18.4s, v4.s[3]
+ FMLA v26.4s, v18.4s, v6.s[3]
+ FMLA v28.4s, v18.4s, v8.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v2.s[3]
+ FMLA v25.4s, v19.4s, v4.s[3]
+ FMLA v27.4s, v19.4s, v6.s[3]
+ FMLA v29.4s, v19.4s, v8.s[3]
+
+ # Is there a remainder?- 2 floats of A (8 bytes)
+5:
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
+ # Load A
+ LDR d0, [x3], 8
+ LDR d2, [x9], 8
+ LDR d4, [x10], 8
+ LDR d6, [x11], 8
+ LDR d8, [x12], 8
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v2.s[0]
+ FMLA v24.4s, v12.4s, v4.s[0]
+ FMLA v26.4s, v12.4s, v6.s[0]
+ FMLA v28.4s, v12.4s, v8.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v2.s[0]
+ FMLA v25.4s, v13.4s, v4.s[0]
+ FMLA v27.4s, v13.4s, v6.s[0]
+ FMLA v29.4s, v13.4s, v8.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v2.s[1]
+ FMLA v24.4s, v14.4s, v4.s[1]
+ FMLA v26.4s, v14.4s, v6.s[1]
+ FMLA v28.4s, v14.4s, v8.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v2.s[1]
+ FMLA v25.4s, v15.4s, v4.s[1]
+ FMLA v27.4s, v15.4s, v6.s[1]
+ FMLA v29.4s, v15.4s, v8.s[1]
+
+ # Is there a remainder?- 1 float of A (4 bytes)
+6:
+ TBZ x0, 2, 3b
+
+ # Remainder- 1 float of A (4 bytes)
+ # Load A
+ LDR s0, [x3], 4
+ LDR s2, [x9], 4
+ LDR s4, [x10], 4
+ LDR s6, [x11], 4
+ LDR s8, [x12], 4
+ # Load B
+ LDP q12, q13, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v2.s[0]
+ FMLA v24.4s, v12.4s, v4.s[0]
+ FMLA v26.4s, v12.4s, v6.s[0]
+ FMLA v28.4s, v12.4s, v8.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v2.s[0]
+ FMLA v25.4s, v13.4s, v4.s[0]
+ FMLA v27.4s, v13.4s, v6.s[0]
+ FMLA v29.4s, v13.4s, v8.s[0]
+ B 3b
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q26, [x13], 16
+ MOV v26.16b, v27.16b
+ STR q28, [x7], 16
+ MOV v28.16b, v29.16b
+8:
+ TBZ x1, 1, 9f
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d26, [x13], 8
+ DUP d26, v26.d[1]
+ STR d28, [x7], 8
+ DUP d28, v28.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ STR s20, [x6]
+ STR s22, [x16]
+ STR s24, [x17]
+ STR s26, [x13]
+ STR s28, [x7]
+10:
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 32]
+        LDP d12, d13, [sp, 16]
+        LDP d8, d9, [sp], 48
+        RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_5x8__aarch64_neonfma_cortex_a75 // must match BEGIN_FUNCTION symbol for .size
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/5x8-aarch64-neonfma-cortex-a75.S.in b/src/f32-gemm/5x8-aarch64-neonfma-cortex-a75.S.in
new file mode 100644
index 0000000..43785d6
--- /dev/null
+++ b/src/f32-gemm/5x8-aarch64-neonfma-cortex-a75.S.in
@@ -0,0 +1,653 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_5x8__aarch64_neonfma_cortex_a75(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# unused compared to 5x8
+# x4 a5
+# x7 c5
+# A5 v10 v11
+# C v30 v31
+
+# d8-d15 need to be preserved if used.
+# x19-x30 need to be preserved if used. x18 is reserved for OS.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x13 c3
+# x7 c4
+
+# Vector register usage
+# A0 v0 v1
+# A1 v2 v3
+# A2 v4 v5
+# A3 v6 v7
+# A4 v8 v9
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# Clamp v30 v31
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_5x8__aarch64_neonfma_cortex_a75
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -48]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d12, d13, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d14, d15, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x13, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x13, x17, x13, LO // c3 = c2
+
+ $if INC:
+ # Load acc, params pointer
+ LDP x15, x8, [sp, 56]
+ $else:
+ # Load params pointer
+ LDR x8, [sp, 56]
+
+        ADD x12, x11, x4 // a4 = a3 + a_stride
+        ADD x7, x13, x7 // c4 = c3 + cm_stride
+        // if mr <= 4 (flags still set from CMP x0, 4 above; LS = mr <= 4)
+        CSEL x12, x11, x12, LS //   a4 = a3
+        CSEL x7, x13, x7, LS //   c4 = c3
+
+ # Load clamp values
+ LD2R {v30.4s, v31.4s}, [x8]
+
+ # Load cn_stride
+ LDR x14, [sp, 48]
+
+0:
+ $if INC:
+ # Load initial accumulators
+ LDP q20, q21, [x15], 32
+ LDP q22, q23, [x15], 32
+ LDP q24, q25, [x15], 32
+ LDP q26, q27, [x15], 32
+ LDP q28, q29, [x15], 32
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ PRFM PLDL1KEEP, [x9]
+ PRFM PLDL1KEEP, [x10]
+ PRFM PLDL1KEEP, [x11]
+ PRFM PLDL1KEEP, [x12]
+ $else:
+ # Load initial bias from w into accumulators
+ LDP q20, q21, [x5], 32
+ MOV v22.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v24.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v26.16b, v20.16b
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x9]
+ MOV v28.16b, v20.16b
+ PRFM PLDL1KEEP, [x10]
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x11]
+ PRFM PLDL1KEEP, [x12]
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 4f
+
+ # Prologue - loads for main loop of 80 FMA
+ LDR q0, [x3], 16
+ LDR q2, [x9], 16
+ LDR q4, [x10], 16
+ LDR q6, [x11], 16
+ LDR q8, [x12], 16
+ LDP q12, q13, [x5], 32 // Fetch 3 B (4th deferred)
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+
+ # Is there at least 8 floats (32 bytes) for main loop?
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+ # 80 FMA + 5 LDP A + 8 LDP B
+1:
+ # First group of 4 A. 40 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v2.s[0]
+ FMLA v24.4s, v12.4s, v4.s[0]
+ FMLA v26.4s, v12.4s, v6.s[0]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v28.4s, v12.4s, v8.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v2.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v25.4s, v13.4s, v4.s[0]
+ FMLA v27.4s, v13.4s, v6.s[0]
+ FMLA v29.4s, v13.4s, v8.s[0]
+ LDR q1, [x3], 16 // Load next 5 A
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v2.s[1]
+ FMLA v24.4s, v14.4s, v4.s[1]
+ LDR q3, [x9], 16
+ FMLA v26.4s, v14.4s, v6.s[1]
+ FMLA v28.4s, v14.4s, v8.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ LDR q5, [x10], 16
+ FMLA v23.4s, v15.4s, v2.s[1]
+ FMLA v25.4s, v15.4s, v4.s[1]
+ FMLA v27.4s, v15.4s, v6.s[1]
+ LDR q7, [x11], 16
+ FMLA v29.4s, v15.4s, v8.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v2.s[2]
+ LDR q9, [x12], 16
+ FMLA v24.4s, v16.4s, v4.s[2]
+ FMLA v26.4s, v16.4s, v6.s[2]
+ FMLA v28.4s, v16.4s, v8.s[2]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v2.s[2]
+ FMLA v25.4s, v17.4s, v4.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v27.4s, v17.4s, v6.s[2]
+ FMLA v29.4s, v17.4s, v8.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v22.4s, v18.4s, v2.s[3]
+ FMLA v24.4s, v18.4s, v4.s[3]
+ FMLA v26.4s, v18.4s, v6.s[3]
+ FMLA v28.4s, v18.4s, v8.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v2.s[3]
+ FMLA v25.4s, v19.4s, v4.s[3]
+ FMLA v27.4s, v19.4s, v6.s[3]
+ FMLA v29.4s, v19.4s, v8.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 40 FMA.
+ FMLA v20.4s, v12.4s, v1.s[0]
+ FMLA v22.4s, v12.4s, v3.s[0]
+ FMLA v24.4s, v12.4s, v5.s[0]
+ LDR q0, [x3], 16 // Load next 5 A
+ FMLA v26.4s, v12.4s, v7.s[0]
+ FMLA v28.4s, v12.4s, v9.s[0]
+ FMLA v21.4s, v13.4s, v1.s[0]
+ LDR q2, [x9], 16
+ FMLA v23.4s, v13.4s, v3.s[0]
+ FMLA v25.4s, v13.4s, v5.s[0]
+ FMLA v27.4s, v13.4s, v7.s[0]
+ LDR q4, [x10], 16
+ FMLA v29.4s, v13.4s, v9.s[0]
+
+ FMLA v20.4s, v14.4s, v1.s[1]
+ FMLA v22.4s, v14.4s, v3.s[1]
+ LDR q6, [x11], 16
+ FMLA v24.4s, v14.4s, v5.s[1]
+ FMLA v26.4s, v14.4s, v7.s[1]
+ FMLA v28.4s, v14.4s, v9.s[1]
+ LDR q8, [x12], 16
+ FMLA v21.4s, v15.4s, v1.s[1]
+ FMLA v23.4s, v15.4s, v3.s[1]
+ FMLA v25.4s, v15.4s, v5.s[1]
+ LDP q12, q13, [x5], 32 // Load next 3 B (not last)
+ FMLA v27.4s, v15.4s, v7.s[1]
+ FMLA v29.4s, v15.4s, v9.s[1]
+
+ FMLA v20.4s, v16.4s, v1.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v22.4s, v16.4s, v3.s[2]
+ FMLA v24.4s, v16.4s, v5.s[2]
+ FMLA v26.4s, v16.4s, v7.s[2]
+ FMLA v28.4s, v16.4s, v9.s[2]
+ FMLA v21.4s, v17.4s, v1.s[2]
+ FMLA v23.4s, v17.4s, v3.s[2]
+ FMLA v25.4s, v17.4s, v5.s[2]
+ FMLA v27.4s, v17.4s, v7.s[2]
+ FMLA v29.4s, v17.4s, v9.s[2]
+ LDP q16, q17, [x5], 32
+
+ FMLA v20.4s, v18.4s, v1.s[3]
+ FMLA v22.4s, v18.4s, v3.s[3]
+ SUBS x0, x0, 32
+ FMLA v24.4s, v18.4s, v5.s[3]
+ FMLA v26.4s, v18.4s, v7.s[3]
+ FMLA v28.4s, v18.4s, v9.s[3]
+ FMLA v21.4s, v19.4s, v1.s[3]
+ FMLA v23.4s, v19.4s, v3.s[3]
+ FMLA v25.4s, v19.4s, v5.s[3]
+ FMLA v27.4s, v19.4s, v7.s[3]
+ FMLA v29.4s, v19.4s, v9.s[3]
+ B.HS 1b
+
+ # Epilogue - 8 floats of A (32 bytes)
+ # 80 FMA + 5 LDP A + 8 LDP B
+ # First block same as main loop. Second block has no preloads.
+2:
+ # First group of 4 A. 40 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v2.s[0]
+ FMLA v24.4s, v12.4s, v4.s[0]
+ FMLA v26.4s, v12.4s, v6.s[0]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v28.4s, v12.4s, v8.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v2.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v25.4s, v13.4s, v4.s[0]
+ FMLA v27.4s, v13.4s, v6.s[0]
+ FMLA v29.4s, v13.4s, v8.s[0]
+ LDR q1, [x3], 16 // Load next 5 A
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v2.s[1]
+ FMLA v24.4s, v14.4s, v4.s[1]
+ LDR q3, [x9], 16
+ FMLA v26.4s, v14.4s, v6.s[1]
+ FMLA v28.4s, v14.4s, v8.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ LDR q5, [x10], 16
+ FMLA v23.4s, v15.4s, v2.s[1]
+ FMLA v25.4s, v15.4s, v4.s[1]
+ FMLA v27.4s, v15.4s, v6.s[1]
+ LDR q7, [x11], 16
+ FMLA v29.4s, v15.4s, v8.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v2.s[2]
+ LDR q9, [x12], 16
+ FMLA v24.4s, v16.4s, v4.s[2]
+ FMLA v26.4s, v16.4s, v6.s[2]
+ FMLA v28.4s, v16.4s, v8.s[2]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v2.s[2]
+ FMLA v25.4s, v17.4s, v4.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v27.4s, v17.4s, v6.s[2]
+ FMLA v29.4s, v17.4s, v8.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v22.4s, v18.4s, v2.s[3]
+ FMLA v24.4s, v18.4s, v4.s[3]
+ FMLA v26.4s, v18.4s, v6.s[3]
+ FMLA v28.4s, v18.4s, v8.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v2.s[3]
+ FMLA v25.4s, v19.4s, v4.s[3]
+ FMLA v27.4s, v19.4s, v6.s[3]
+ FMLA v29.4s, v19.4s, v8.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 40 FMA.
+ FMLA v20.4s, v12.4s, v1.s[0]
+ FMLA v22.4s, v12.4s, v3.s[0]
+ FMLA v24.4s, v12.4s, v5.s[0]
+ FMLA v26.4s, v12.4s, v7.s[0]
+ FMLA v28.4s, v12.4s, v9.s[0]
+ FMLA v21.4s, v13.4s, v1.s[0]
+ FMLA v23.4s, v13.4s, v3.s[0]
+ FMLA v25.4s, v13.4s, v5.s[0]
+ FMLA v27.4s, v13.4s, v7.s[0]
+ FMLA v29.4s, v13.4s, v9.s[0]
+
+ FMLA v20.4s, v14.4s, v1.s[1]
+ FMLA v22.4s, v14.4s, v3.s[1]
+ FMLA v24.4s, v14.4s, v5.s[1]
+ FMLA v26.4s, v14.4s, v7.s[1]
+ FMLA v28.4s, v14.4s, v9.s[1]
+ FMLA v21.4s, v15.4s, v1.s[1]
+ FMLA v23.4s, v15.4s, v3.s[1]
+ FMLA v25.4s, v15.4s, v5.s[1]
+ FMLA v27.4s, v15.4s, v7.s[1]
+ FMLA v29.4s, v15.4s, v9.s[1]
+
+ FMLA v20.4s, v16.4s, v1.s[2]
+ FMLA v22.4s, v16.4s, v3.s[2]
+ FMLA v24.4s, v16.4s, v5.s[2]
+ FMLA v26.4s, v16.4s, v7.s[2]
+ FMLA v28.4s, v16.4s, v9.s[2]
+ FMLA v21.4s, v17.4s, v1.s[2]
+ FMLA v23.4s, v17.4s, v3.s[2]
+ FMLA v25.4s, v17.4s, v5.s[2]
+ FMLA v27.4s, v17.4s, v7.s[2]
+ FMLA v29.4s, v17.4s, v9.s[2]
+ TST x0, 31
+
+ FMLA v20.4s, v18.4s, v1.s[3]
+ FMLA v22.4s, v18.4s, v3.s[3]
+ FMLA v24.4s, v18.4s, v5.s[3]
+ FMLA v26.4s, v18.4s, v7.s[3]
+ FMLA v28.4s, v18.4s, v9.s[3]
+ FMLA v21.4s, v19.4s, v1.s[3]
+ FMLA v23.4s, v19.4s, v3.s[3]
+ FMLA v25.4s, v19.4s, v5.s[3]
+ FMLA v27.4s, v19.4s, v7.s[3]
+ FMLA v29.4s, v19.4s, v9.s[3]
+ B.NE 4f
+
+ # Clamp
+3:
+ FMIN v20.4s, v20.4s, v30.4s
+ FMIN v21.4s, v21.4s, v30.4s
+ FMIN v22.4s, v22.4s, v30.4s
+ FMIN v23.4s, v23.4s, v30.4s
+ FMIN v24.4s, v24.4s, v30.4s
+ FMIN v25.4s, v25.4s, v30.4s
+ FMIN v26.4s, v26.4s, v30.4s
+ FMIN v27.4s, v27.4s, v30.4s
+ FMIN v28.4s, v28.4s, v30.4s
+ FMIN v29.4s, v29.4s, v30.4s
+ FMAX v20.4s, v20.4s, v31.4s
+ FMAX v21.4s, v21.4s, v31.4s
+ FMAX v22.4s, v22.4s, v31.4s
+ FMAX v23.4s, v23.4s, v31.4s
+ FMAX v24.4s, v24.4s, v31.4s
+ FMAX v25.4s, v25.4s, v31.4s
+ FMAX v26.4s, v26.4s, v31.4s
+ FMAX v27.4s, v27.4s, v31.4s
+ FMAX v28.4s, v28.4s, v31.4s
+ FMAX v29.4s, v29.4s, v31.4s
+
+ # Store full 5 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ $if INC:
+ SUB x3, x3, x2 // a0 -= kc
+ STP q28, q29, [x7]
+ ADD x7, x7, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q26, q27, [x13]
+ ADD x13, x13, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ $else:
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q26, q27, [x13]
+ ADD x13, x13, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q28, q29, [x7]
+ ADD x7, x7, x14
+ SUB x12, x12, x2 // a4 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 32]
+ LDP d12, d13, [sp, 16]
+ LDP d8, d9, [sp], 48
+ RET
+
+        # Remainder handling: kc is not a multiple of 8 floats (32 bytes)
+4:
+        # Is there a remainder?- 4 floats of A (16 bytes)
+        TBZ x0, 4, 5f
+
+ # Remainder- 4 floats of A (16 bytes)
+ # Load A
+ LDR q0, [x3], 16
+ LDR q2, [x9], 16
+ LDR q4, [x10], 16
+ LDR q6, [x11], 16
+ LDR q8, [x12], 16
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+ LDP q18, q19, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v2.s[0]
+ FMLA v24.4s, v12.4s, v4.s[0]
+ FMLA v26.4s, v12.4s, v6.s[0]
+ FMLA v28.4s, v12.4s, v8.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v2.s[0]
+ FMLA v25.4s, v13.4s, v4.s[0]
+ FMLA v27.4s, v13.4s, v6.s[0]
+ FMLA v29.4s, v13.4s, v8.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v2.s[1]
+ FMLA v24.4s, v14.4s, v4.s[1]
+ FMLA v26.4s, v14.4s, v6.s[1]
+ FMLA v28.4s, v14.4s, v8.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v2.s[1]
+ FMLA v25.4s, v15.4s, v4.s[1]
+ FMLA v27.4s, v15.4s, v6.s[1]
+ FMLA v29.4s, v15.4s, v8.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v2.s[2]
+ FMLA v24.4s, v16.4s, v4.s[2]
+ FMLA v26.4s, v16.4s, v6.s[2]
+ FMLA v28.4s, v16.4s, v8.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v2.s[2]
+ FMLA v25.4s, v17.4s, v4.s[2]
+ FMLA v27.4s, v17.4s, v6.s[2]
+ FMLA v29.4s, v17.4s, v8.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v2.s[3]
+ FMLA v24.4s, v18.4s, v4.s[3]
+ FMLA v26.4s, v18.4s, v6.s[3]
+ FMLA v28.4s, v18.4s, v8.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v2.s[3]
+ FMLA v25.4s, v19.4s, v4.s[3]
+ FMLA v27.4s, v19.4s, v6.s[3]
+ FMLA v29.4s, v19.4s, v8.s[3]
+
+ # Is there a remainder?- 2 floats of A (8 bytes)
+5:
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
+ # Load A
+ LDR d0, [x3], 8
+ LDR d2, [x9], 8
+ LDR d4, [x10], 8
+ LDR d6, [x11], 8
+ LDR d8, [x12], 8
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v2.s[0]
+ FMLA v24.4s, v12.4s, v4.s[0]
+ FMLA v26.4s, v12.4s, v6.s[0]
+ FMLA v28.4s, v12.4s, v8.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v2.s[0]
+ FMLA v25.4s, v13.4s, v4.s[0]
+ FMLA v27.4s, v13.4s, v6.s[0]
+ FMLA v29.4s, v13.4s, v8.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v2.s[1]
+ FMLA v24.4s, v14.4s, v4.s[1]
+ FMLA v26.4s, v14.4s, v6.s[1]
+ FMLA v28.4s, v14.4s, v8.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v2.s[1]
+ FMLA v25.4s, v15.4s, v4.s[1]
+ FMLA v27.4s, v15.4s, v6.s[1]
+ FMLA v29.4s, v15.4s, v8.s[1]
+
+ # Is there a remainder?- 1 float of A (4 bytes)
+6:
+ TBZ x0, 2, 3b
+
+ # Remainder- 1 float of A (4 bytes)
+ # Load A
+ LDR s0, [x3], 4
+ LDR s2, [x9], 4
+ LDR s4, [x10], 4
+ LDR s6, [x11], 4
+ LDR s8, [x12], 4
+ # Load B
+ LDP q12, q13, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v2.s[0]
+ FMLA v24.4s, v12.4s, v4.s[0]
+ FMLA v26.4s, v12.4s, v6.s[0]
+ FMLA v28.4s, v12.4s, v8.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v2.s[0]
+ FMLA v25.4s, v13.4s, v4.s[0]
+ FMLA v27.4s, v13.4s, v6.s[0]
+ FMLA v29.4s, v13.4s, v8.s[0]
+ B 3b
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ $if INC:
+ STR q28, [x7], 16
+ MOV v28.16b, v29.16b
+ STR q26, [x13], 16
+ MOV v26.16b, v27.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ $else:
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q26, [x13], 16
+ MOV v26.16b, v27.16b
+ STR q28, [x7], 16
+ MOV v28.16b, v29.16b
+8:
+ TBZ x1, 1, 9f
+ $if INC:
+ STR d28, [x7], 8
+ DUP d28, v28.d[1]
+ STR d26, [x13], 8
+ DUP d26, v26.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ $else:
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d26, [x13], 8
+ DUP d26, v26.d[1]
+ STR d28, [x7], 8
+ DUP d28, v28.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ $if INC:
+ STR s28, [x7]
+ STR s26, [x13]
+ STR s24, [x17]
+ STR s22, [x16]
+ STR s20, [x6]
+ $else:
+ STR s20, [x6]
+ STR s22, [x16]
+ STR s24, [x17]
+ STR s26, [x13]
+ STR s28, [x7]
+10:
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 32]
+        LDP d12, d13, [sp, 16]
+        LDP d8, d9, [sp], 48
+        RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_5x8__aarch64_neonfma_cortex_a75 // must match BEGIN_FUNCTION symbol for .size
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/5x8-neon-ld64.c b/src/f32-gemm/5x8-neon-ld64.c
new file mode 100644
index 0000000..4ca0a20
--- /dev/null
+++ b/src/f32-gemm/5x8-neon-ld64.c
@@ -0,0 +1,225 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_5x8__neon_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float* restrict a,
+ size_t a_stride,
+ const float* restrict w,
+ float* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 5);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+ const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+ float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ a4 = a3;
+ c4 = c3;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc1x4567 = vacc0x4567;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc2x4567 = vacc0x4567;
+ float32x4_t vacc3x0123 = vacc0x0123;
+ float32x4_t vacc3x4567 = vacc0x4567;
+ float32x4_t vacc4x0123 = vacc0x0123;
+ float32x4_t vacc4x4567 = vacc0x4567;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+ const float32x2_t va4 = vld1_f32(a4); a4 += 2;
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+ vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123c0, va4, 0);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+ vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567c0, va4, 0);
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+ vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123c1, va4, 1);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+ vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567c1, va4, 1);
+ }
+ if XNN_UNLIKELY(k != 0) {
+ const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+ const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+ const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+ const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+ const float32x4_t va4 = vld1q_dup_f32(a4); a4 += 1;
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+ vacc4x0123 = vmlaq_f32(vacc4x0123, va4, vb0123);
+ vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+ vacc4x4567 = vmlaq_f32(vacc4x4567, va4, vb4567);
+ }
+ const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max);
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc4x0123 = vminq_f32(vacc4x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+ vacc4x4567 = vminq_f32(vacc4x4567, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min);
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc4x0123 = vmaxq_f32(vacc4x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+ vacc4x4567 = vmaxq_f32(vacc4x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c4, vacc4x0123);
+ vst1q_f32(c4 + 4, vacc4x4567);
+ c4 = (float*) ((uintptr_t) c4 + cn_stride);
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a4 = (const float*) ((uintptr_t) a4 - kc);
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+
+ } else {
+ if (nc & 4) {
+ vst1q_f32(c4, vacc4x0123); c4 += 4;
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc4x0123 = vacc4x4567;
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc4x01 = vget_low_f32(vacc4x0123);
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c4, vacc4x01); c4 += 2;
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc4x01 = vget_high_f32(vacc4x0123);
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c4, vacc4x01, 0);
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/5x8-neonfma-ld64.c b/src/f32-gemm/5x8-neonfma-ld64.c
new file mode 100644
index 0000000..515db66
--- /dev/null
+++ b/src/f32-gemm/5x8-neonfma-ld64.c
@@ -0,0 +1,261 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_5x8__neonfma_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float* restrict a,
+ size_t a_stride,
+ const float* restrict w,
+ float* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 5);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+ const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+ float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ a4 = a3;
+ c4 = c3;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc1x4567 = vacc0x4567;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc2x4567 = vacc0x4567;
+ float32x4_t vacc3x0123 = vacc0x0123;
+ float32x4_t vacc3x4567 = vacc0x4567;
+ float32x4_t vacc4x0123 = vacc0x0123;
+ float32x4_t vacc4x4567 = vacc0x4567;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+ const float32x2_t va4 = vld1_f32(a4); a4 += 2;
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+ vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123c0, va4, 0);
+ vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+ vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+ vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+ vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+ vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567c0, va4, 0);
+ #else
+ const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+ const float32x4_t va1c0 = vdupq_lane_f32(va1, 0);
+ const float32x4_t va2c0 = vdupq_lane_f32(va2, 0);
+ const float32x4_t va3c0 = vdupq_lane_f32(va3, 0);
+ const float32x4_t va4c0 = vdupq_lane_f32(va4, 0);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+ vacc4x0123 = vfmaq_f32(vacc4x0123, va4c0, vb0123c0);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+ vacc4x4567 = vfmaq_f32(vacc4x4567, va4c0, vb4567c0);
+ #endif
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+ vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123c1, va4, 1);
+ vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+ vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+ vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+ vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+ vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567c1, va4, 1);
+ #else
+ const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+ const float32x4_t va1c1 = vdupq_lane_f32(va1, 1);
+ const float32x4_t va2c1 = vdupq_lane_f32(va2, 1);
+ const float32x4_t va3c1 = vdupq_lane_f32(va3, 1);
+ const float32x4_t va4c1 = vdupq_lane_f32(va4, 1);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+ vacc4x0123 = vfmaq_f32(vacc4x0123, va4c1, vb0123c1);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+ vacc4x4567 = vfmaq_f32(vacc4x4567, va4c1, vb4567c1);
+ #endif
+ }
+ if XNN_UNLIKELY(k != 0) {
+ const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+ const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+ const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+ const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+ const float32x4_t va4 = vld1q_dup_f32(a4); a4 += 1;
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+ vacc4x0123 = vfmaq_f32(vacc4x0123, va4, vb0123);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+ vacc4x4567 = vfmaq_f32(vacc4x4567, va4, vb4567);
+ }
+ const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max);
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc4x0123 = vminq_f32(vacc4x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+ vacc4x4567 = vminq_f32(vacc4x4567, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min);
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc4x0123 = vmaxq_f32(vacc4x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+ vacc4x4567 = vmaxq_f32(vacc4x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c4, vacc4x0123);
+ vst1q_f32(c4 + 4, vacc4x4567);
+ c4 = (float*) ((uintptr_t) c4 + cn_stride);
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a4 = (const float*) ((uintptr_t) a4 - kc);
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+
+ } else {
+ if (nc & 4) {
+ vst1q_f32(c4, vacc4x0123); c4 += 4;
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc4x0123 = vacc4x4567;
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc4x01 = vget_low_f32(vacc4x0123);
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c4, vacc4x01); c4 += 2;
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc4x01 = vget_high_f32(vacc4x0123);
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c4, vacc4x01, 0);
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/6x8-aarch64-neonfma-cortex-a57.S b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a57.S
new file mode 100644
index 0000000..82f3b9d
--- /dev/null
+++ b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a57.S
@@ -0,0 +1,661 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/6x8-aarch64-neonfma-cortex-a57.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a57(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0 v6
+# A1 v1 v7
+# A2 v2 v8
+# A3 v3 v9
+# A4 v4 v10
+# A5 v5 v11
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+
+# A57 kernel based on A75 but with PRFM removed from main loop
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a57
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x12, x11, x4 // a4 = a3 + a_stride
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+        // if mr <= 4
+ CSEL x12, x11, x12, LS // a4 = a3
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ # Load params pointer
+ LDR x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+0:
+ # Load initial bias from w into accumulators
+ LDP q20, q21, [x5], 32
+ MOV v22.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v24.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v26.16b, v20.16b
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x9]
+ MOV v28.16b, v20.16b
+ PRFM PLDL1KEEP, [x10]
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x11]
+ MOV v30.16b, v20.16b
+ PRFM PLDL1KEEP, [x12]
+ MOV v31.16b, v21.16b
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 4f
+
+ # Prologue - loads for main loop of 96 FMA
+ LDR q0, [x3], 16
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ LDP q12, q13, [x5], 32 // Fetch 3 B (4th deferred)
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+
+ # Is there at least 8 floats (32 bytes) for main loop?
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+1:
+ # First group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+
+ FMLA v31.4s, v13.4s, v5.s[0]
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ LDR q6, [x3], 16 // Load next 6 A
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ LDR q7, [x9], 16
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ LDR q8, [x10], 16
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ LDR q9, [x11], 16
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ LDR q10, [x12], 16
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+ LDR q11, [x4], 16
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ FMLA v24.4s, v12.4s, v8.s[0]
+ LDR q0, [x3], 16 // Load next 6 A
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ LDR q1, [x9], 16
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ LDR q2, [x10], 16
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+ LDR q3, [x11], 16
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ LDR q4, [x12], 16
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ LDR q5, [x4], 16
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ LDP q12, q13, [x5], 32 // Load next 3 B (not last)
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+ LDP q16, q17, [x5], 32
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ SUBS x0, x0, 32
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.HS 1b
+
+ # Epilogue - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+ # First block same as main loop. Second block has no preloads.
+2:
+ # First group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+
+ FMLA v31.4s, v13.4s, v5.s[0]
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ LDR q6, [x3], 16 // Load next 6 A
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ LDR q7, [x9], 16
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ LDR q8, [x10], 16
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ LDR q9, [x11], 16
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ LDR q10, [x12], 16
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+ LDR q11, [x4], 16
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ # Is there a remainder?- 4 floats of A (16 bytes) or less
+ TST x0, 31
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.NE 4f
+
+ # Clamp
+3:
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+4:
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Is there a remainder?- 4 floats of A (16 bytes)
+ TBZ x0, 4, 5f
+
+ # Remainder- 4 floats of A (16 bytes)
+ # Load A
+ LDR q0, [x3], 16
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+ LDP q18, q19, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Is there a remainder?- 2 floats of A (8 bytes)
+5:
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
+ # Load A
+ LDR d0, [x3], 8
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ # Is there a remainder?- 1 float of A (4 bytes)
+6:
+ TBZ x0, 2, 3b
+
+ # Remainder- 1 float of A (4 bytes)
+ # Load A
+ LDR s0, [x3], 4
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ # Load B
+ LDP q12, q13, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+8:
+ TBZ x1, 1, 9f
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ STR s20, [x6]
+ STR s22, [x16]
+ STR s24, [x17]
+ STR s26, [x18]
+ STR s28, [x13]
+ STR s30, [x7]
+10:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a57
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/6x8-aarch64-neonfma-cortex-a57.S.in b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a57.S.in
new file mode 100644
index 0000000..307a148
--- /dev/null
+++ b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a57.S.in
@@ -0,0 +1,740 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_cortex_a57(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0 v6
+# A1 v1 v7
+# A2 v2 v8
+# A3 v3 v9
+# A4 v4 v10
+# A5 v5 v11
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+
+# A57 kernel based on A75 but with PRFM removed from main loop
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_cortex_a57
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x12, x11, x4 // a4 = a3 + a_stride
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+ // if mr < 5
+ CSEL x12, x11, x12, LS // a4 = a3
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ $if INC:
+ # Load acc, params pointer
+ LDP x15, x8, [sp, 72]
+ $else:
+ # Load params pointer
+ LDR x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+0:
+ $if INC:
+ # Load initial accumulators
+ LDP q20, q21, [x15], 32
+ LDP q22, q23, [x15], 32
+ LDP q24, q25, [x15], 32
+ LDP q26, q27, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ PRFM PLDL1KEEP, [x9]
+ PRFM PLDL1KEEP, [x10]
+ PRFM PLDL1KEEP, [x11]
+ PRFM PLDL1KEEP, [x12]
+ PRFM PLDL1KEEP, [x4]
+ $else:
+ # Load initial bias from w into accumulators
+ LDP q20, q21, [x5], 32
+ MOV v22.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v24.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v26.16b, v20.16b
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x9]
+ MOV v28.16b, v20.16b
+ PRFM PLDL1KEEP, [x10]
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x11]
+ MOV v30.16b, v20.16b
+ PRFM PLDL1KEEP, [x12]
+ MOV v31.16b, v21.16b
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 4f
+
+ # Prologue - loads for main loop of 96 FMA
+ LDR q0, [x3], 16
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ LDP q12, q13, [x5], 32 // Fetch 3 B (4th deferred)
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+
+ # Is there at least 8 floats (32 bytes) for main loop?
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+1:
+ # First group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+
+ FMLA v31.4s, v13.4s, v5.s[0]
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ LDR q6, [x3], 16 // Load next 6 A
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ LDR q7, [x9], 16
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ LDR q8, [x10], 16
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ LDR q9, [x11], 16
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ LDR q10, [x12], 16
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+ LDR q11, [x4], 16
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ FMLA v24.4s, v12.4s, v8.s[0]
+ LDR q0, [x3], 16 // Load next 6 A
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ LDR q1, [x9], 16
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ LDR q2, [x10], 16
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+ LDR q3, [x11], 16
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ LDR q4, [x12], 16
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ LDR q5, [x4], 16
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ LDP q12, q13, [x5], 32 // Load next 3 B (not last)
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+ LDP q16, q17, [x5], 32
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ SUBS x0, x0, 32
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.HS 1b
+
+ # Epilogue - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+ # First block same as main loop. Second block has no preloads.
+2:
+ # First group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+
+ FMLA v31.4s, v13.4s, v5.s[0]
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ LDR q6, [x3], 16 // Load next 6 A
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ LDR q7, [x9], 16
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ LDR q8, [x10], 16
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ LDR q9, [x11], 16
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ LDR q10, [x12], 16
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+ LDR q11, [x4], 16
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ # Is there a remainder?- 4 floats of A (16 bytes) or less
+ TST x0, 31
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.NE 4f
+
+ # Clamp
+3:
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ $if INC:
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x4, x4, x2 // a5 -= kc
+ $else:
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+4:
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Is there a remainder?- 4 floats of A (16 bytes)
+ TBZ x0, 4, 5f
+
+ # Remainder- 4 floats of A (16 bytes)
+ # Load A
+ LDR q0, [x3], 16
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+ LDP q18, q19, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Is there a remainder?- 2 floats of A (8 bytes)
+5:
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
+ # Load A
+ LDR d0, [x3], 8
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ # Is there a remainder?- 1 float of A (4 bytes)
+6:
+ TBZ x0, 2, 3b
+
+ # Remainder- 1 float of A (4 bytes)
+ # Load A
+ LDR s0, [x3], 4
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ # Load B
+ LDP q12, q13, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ $if INC:
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ $else:
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+8:
+ TBZ x1, 1, 9f
+ $if INC:
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ $else:
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ $if INC:
+ STR s30, [x7]
+ STR s28, [x13]
+ STR s26, [x18]
+ STR s24, [x17]
+ STR s22, [x16]
+ STR s20, [x6]
+ $else:
+ STR s20, [x6]
+ STR s22, [x16]
+ STR s24, [x17]
+ STR s26, [x18]
+ STR s28, [x13]
+ STR s30, [x7]
+10:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_cortex_a57
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/6x8-aarch64-neonfma-cortex-a73.S b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a73.S
new file mode 100644
index 0000000..dc608f0
--- /dev/null
+++ b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a73.S
@@ -0,0 +1,662 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/6x8-aarch64-neonfma-cortex-a73.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a73(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0 v6
+# A1 v1 v7
+# A2 v2 v8
+# A3 v3 v9
+# A4 v4 v10
+# A5 v5 v11
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a73
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x12, x11, x4 // a4 = a3 + a_stride
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+ // if mr < 5
+ CSEL x12, x11, x12, LS // a4 = a3
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ # Load params pointer
+ LDR x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+ .p2align 3
+0:
+ # Load initial bias from w into accumulators
+ LDP q20, q21, [x5], 32
+ MOV v22.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v24.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v26.16b, v20.16b
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x9]
+ MOV v28.16b, v20.16b
+ PRFM PLDL1KEEP, [x10]
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x11]
+ MOV v30.16b, v20.16b
+ PRFM PLDL1KEEP, [x12]
+ MOV v31.16b, v21.16b
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 4f
+
+ # Prologue - loads for main loop of 96 FMA
+ # load A0 to A4 but not A5
+ LDP q0, q6, [x3], 32
+ LDP q1, q7, [x9], 32
+ LDP q2, q8, [x10], 32
+ LDP q3, q9, [x11], 32
+ LDP q4, q10, [x12], 32
+ # load first set of B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ # Is there at least 8 floats (32 bytes) for main loop?
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+ .p2align 3
+1:
+ # First group of 4 A. 48 FMA. Loads A5
+
+ LDP q5, q11, [x4], 32
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ LDP q16, q17, [x5], 32
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ LDP q12, q13, [x5], 32
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Second group of 4 A. 48 FMA. Loads A0 - A4
+
+ LDP q16, q17, [x5], 32
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ LDP q12, q13, [x5], 32
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v20.4s, v18.4s, v6.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ LDP q0, q6, [x3], 32
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v23.4s, v19.4s, v7.s[3]
+ LDP q1, q7, [x9], 32
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v25.4s, v19.4s, v8.s[3]
+ LDP q2, q8, [x10], 32
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ LDP q3, q9, [x11], 32
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v29.4s, v19.4s, v10.s[3]
+ LDP q4, q10, [x12], 32
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ SUBS x0, x0, 32
+ FMLA v31.4s, v17.4s, v11.s[2]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.HS 1b
+
+ # Epilogue - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+ # First block same as main loop. Second block has no preloads.
+2:
+ # First group of 4 A. 48 FMA. Loads A5
+
+ LDP q5, q11, [x4], 32
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ LDP q16, q17, [x5], 32
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ LDP q12, q13, [x5], 32
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Second group of 4 A. 48 FMA. No A Loads, No last B load
+
+ LDP q16, q17, [x5], 32
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ # Last part of epilogue has loads removed.
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ # Is there a remainder?- 4 floats of A (16 bytes) or less
+ TST x0, 31
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.NE 4f
+
+ .p2align 3
+
+ # Clamp
+3:
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ NOP
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+ .p2align 3
+4:
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Is there a remainder?- 4 floats of A (16 bytes)
+ TBZ x0, 4, 5f
+
+ # Remainder- 4 floats of A (16 bytes)
+ # Load A
+ LDR q0, [x3], 16
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+ LDP q18, q19, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Is there a remainder?- 2 floats of A (8 bytes)
+5:
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
+ # Load A
+ LDR d0, [x3], 8
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ # Is there a remainder?- 1 float of A (4 bytes)
+6:
+ TBZ x0, 2, 3b
+
+ # Remainder- 1 float of A (4 bytes)
+ # Load A
+ LDR s0, [x3], 4
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ # Load B
+ LDP q12, q13, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+8:
+ TBZ x1, 1, 9f
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ STR s20, [x6]
+ STR s22, [x16]
+ STR s24, [x17]
+ STR s26, [x18]
+ STR s28, [x13]
+ STR s30, [x7]
+10:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a73
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/6x8-aarch64-neonfma-cortex-a73.S.in b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a73.S.in
new file mode 100644
index 0000000..a2f714e
--- /dev/null
+++ b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a73.S.in
@@ -0,0 +1,741 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_cortex_a73(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0 v6
+# A1 v1 v7
+# A2 v2 v8
+# A3 v3 v9
+# A4 v4 v10
+# A5 v5 v11
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_cortex_a73
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x12, x11, x4 // a4 = a3 + a_stride
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+ // if mr <= 4
+ CSEL x12, x11, x12, LS // a4 = a3
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ $if INC:
+ # Load acc, params pointer
+ LDP x15, x8, [sp, 72]
+ $else:
+ # Load params pointer
+ LDR x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+ .p2align 3
+0:
+ $if INC:
+ # Load initial accumulators
+ LDP q20, q21, [x15], 32
+ LDP q22, q23, [x15], 32
+ LDP q24, q25, [x15], 32
+ LDP q26, q27, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ PRFM PLDL1KEEP, [x9]
+ PRFM PLDL1KEEP, [x10]
+ PRFM PLDL1KEEP, [x11]
+ PRFM PLDL1KEEP, [x12]
+ PRFM PLDL1KEEP, [x4]
+ $else:
+ # Load initial bias from w into accumulators
+ LDP q20, q21, [x5], 32
+ MOV v22.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v24.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v26.16b, v20.16b
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x9]
+ MOV v28.16b, v20.16b
+ PRFM PLDL1KEEP, [x10]
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x11]
+ MOV v30.16b, v20.16b
+ PRFM PLDL1KEEP, [x12]
+ MOV v31.16b, v21.16b
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 4f
+
+ # Prologue - loads for main loop of 96 FMA
+ # load A0 to A4 but not A5
+ LDP q0, q6, [x3], 32
+ LDP q1, q7, [x9], 32
+ LDP q2, q8, [x10], 32
+ LDP q3, q9, [x11], 32
+ LDP q4, q10, [x12], 32
+ # load first set of B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ # Is there at least 8 floats (32 bytes) for main loop?
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+ .p2align 3
+1:
+ # First group of 4 A. 48 FMA. Loads A5
+
+ LDP q5, q11, [x4], 32
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ LDP q16, q17, [x5], 32
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ LDP q12, q13, [x5], 32
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Second group of 4 A. 48 FMA. Loads A0 - A4
+
+ LDP q16, q17, [x5], 32
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ LDP q12, q13, [x5], 32
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v20.4s, v18.4s, v6.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ LDP q0, q6, [x3], 32
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v23.4s, v19.4s, v7.s[3]
+ LDP q1, q7, [x9], 32
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v25.4s, v19.4s, v8.s[3]
+ LDP q2, q8, [x10], 32
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ LDP q3, q9, [x11], 32
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v29.4s, v19.4s, v10.s[3]
+ LDP q4, q10, [x12], 32
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ SUBS x0, x0, 32
+ FMLA v31.4s, v17.4s, v11.s[2]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.HS 1b
+
+ # Epilogue - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+ # First block same as main loop. Second block has no preloads.
+2:
+ # First group of 4 A. 48 FMA. Loads A5
+
+ LDP q5, q11, [x4], 32
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ LDP q16, q17, [x5], 32
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ LDP q12, q13, [x5], 32
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Second group of 4 A. 48 FMA. No A Loads, No last B load
+
+ LDP q16, q17, [x5], 32
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ # Last part of epilogue has loads removed.
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ # Is there a remainder?- 4 floats of A (16 bytes) or less
+ TST x0, 31
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.NE 4f
+
+ .p2align 3
+
+ # Clamp
+3:
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ $if INC:
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x4, x4, x2 // a5 -= kc
+ $else:
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ NOP
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+ .p2align 3
+4:
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Is there a remainder?- 4 floats of A (16 bytes)
+ TBZ x0, 4, 5f
+
+ # Remainder- 4 floats of A (16 bytes)
+ # Load A
+ LDR q0, [x3], 16
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+ LDP q18, q19, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Is there a remainder?- 2 floats of A (8 bytes)
+5:
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
+ # Load A
+ LDR d0, [x3], 8
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ # Is there a remainder?- 1 float of A (4 bytes)
+6:
+ TBZ x0, 2, 3b
+
+ # Remainder- 1 float of A (4 bytes)
+ # Load A
+ LDR s0, [x3], 4
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ # Load B
+ LDP q12, q13, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ $if INC:
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ $else:
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+8:
+ TBZ x1, 1, 9f
+ $if INC:
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ $else:
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ $if INC:
+ STR s30, [x7]
+ STR s28, [x13]
+ STR s26, [x18]
+ STR s24, [x17]
+ STR s22, [x16]
+ STR s20, [x6]
+ $else:
+ STR s20, [x6]
+ STR s22, [x16]
+ STR s24, [x17]
+ STR s26, [x18]
+ STR s28, [x13]
+ STR s30, [x7]
+10:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_cortex_a73
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/6x8-aarch64-neonfma-cortex-a75.S b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a75.S
new file mode 100644
index 0000000..f1e277d
--- /dev/null
+++ b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a75.S
@@ -0,0 +1,663 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/6x8-aarch64-neonfma-cortex-a75.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a75(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0 v6
+# A1 v1 v7
+# A2 v2 v8
+# A3 v3 v9
+# A4 v4 v10
+# A5 v5 v11
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a75
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x12, x11, x4 // a4 = a3 + a_stride
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+ // if mr <= 4
+ CSEL x12, x11, x12, LS // a4 = a3
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ # Load params pointer
+ LDR x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+0:
+ # Load initial bias from w into accumulators
+ LDP q20, q21, [x5], 32
+ MOV v22.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v24.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v26.16b, v20.16b
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x9]
+ MOV v28.16b, v20.16b
+ PRFM PLDL1KEEP, [x10]
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x11]
+ MOV v30.16b, v20.16b
+ PRFM PLDL1KEEP, [x12]
+ MOV v31.16b, v21.16b
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 4f
+
+ # Prologue - loads for main loop of 96 FMA
+ LDR q0, [x3], 16
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ LDP q12, q13, [x5], 32 // Fetch 3 B (4th deferred)
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+
+ # Is there at least 8 floats (32 bytes) for main loop?
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+1:
+ # First group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+
+ FMLA v31.4s, v13.4s, v5.s[0]
+ FMLA v20.4s, v14.4s, v0.s[1]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ LDR q6, [x3], 16 // Load next 6 A
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ LDR q7, [x9], 16
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ LDR q8, [x10], 16
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ LDR q9, [x11], 16
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ LDR q10, [x12], 16
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+ LDR q11, [x4], 16
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ FMLA v24.4s, v12.4s, v8.s[0]
+ LDR q0, [x3], 16 // Load next 6 A
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ LDR q1, [x9], 16
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ LDR q2, [x10], 16
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+ LDR q3, [x11], 16
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ LDR q4, [x12], 16
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ LDR q5, [x4], 16
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ LDP q12, q13, [x5], 32 // Load next 3 B (not last)
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+ LDP q16, q17, [x5], 32
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ SUBS x0, x0, 32
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.HS 1b
+
+ # Epilogue - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+ # First block same as main loop. Second block has no preloads.
+2:
+ # First group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+
+ FMLA v31.4s, v13.4s, v5.s[0]
+ FMLA v20.4s, v14.4s, v0.s[1]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ LDR q6, [x3], 16 // Load next 6 A
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ LDR q7, [x9], 16
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ LDR q8, [x10], 16
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ LDR q9, [x11], 16
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ LDR q10, [x12], 16
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+ LDR q11, [x4], 16
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ # Is there a remainder?- 4 floats of A (16 bytes) or less
+ TST x0, 31
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.NE 4f
+
+ # Clamp
+3:
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+4:
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Is there a remainder?- 4 floats of A (16 bytes)
+ TBZ x0, 4, 5f
+
+ # Remainder- 4 floats of A (16 bytes)
+ # Load A
+ LDR q0, [x3], 16
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+ LDP q18, q19, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Is there a remainder?- 2 floats of A (8 bytes)
+5:
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
+ # Load A
+ LDR d0, [x3], 8
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ # Is there a remainder?- 1 float of A (4 bytes)
+6:
+ TBZ x0, 2, 3b
+
+ # Remainder- 1 float of A (4 bytes)
+ # Load A
+ LDR s0, [x3], 4
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ # Load B
+ LDP q12, q13, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+8:
+ TBZ x1, 1, 9f
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ STR s20, [x6]
+ STR s22, [x16]
+ STR s24, [x17]
+ STR s26, [x18]
+ STR s28, [x13]
+ STR s30, [x7]
+10:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/6x8-aarch64-neonfma-cortex-a75.S.in b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a75.S.in
new file mode 100644
index 0000000..82cd1a3
--- /dev/null
+++ b/src/f32-gemm/6x8-aarch64-neonfma-cortex-a75.S.in
@@ -0,0 +1,742 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_cortex_a75(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0 v6
+# A1 v1 v7
+# A2 v2 v8
+# A3 v3 v9
+# A4 v4 v10
+# A5 v5 v11
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_cortex_a75
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+        STP d14, d15, [sp, 48]
+        ADD x12, x11, x4 // a4 = a3 + a_stride
+        ADD x13, x18, x7 // c4 = c3 + cm_stride
+        // if mr <= 4 (LS: flags are still set from CMP x0, 4 above)
+        CSEL x12, x11, x12, LS // a4 = a3
+        CSEL x13, x18, x13, LS // c4 = c3
+
+ $if INC:
+ # Load acc, params pointer
+ LDP x15, x8, [sp, 72]
+ $else:
+ # Load params pointer
+ LDR x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+0:
+ $if INC:
+ # Load initial accumulators
+ LDP q20, q21, [x15], 32
+ LDP q22, q23, [x15], 32
+ LDP q24, q25, [x15], 32
+ LDP q26, q27, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ PRFM PLDL1KEEP, [x9]
+ PRFM PLDL1KEEP, [x10]
+ PRFM PLDL1KEEP, [x11]
+ PRFM PLDL1KEEP, [x12]
+ PRFM PLDL1KEEP, [x4]
+ $else:
+ # Load initial bias from w into accumulators
+ LDP q20, q21, [x5], 32
+ MOV v22.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v24.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v26.16b, v20.16b
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x9]
+ MOV v28.16b, v20.16b
+ PRFM PLDL1KEEP, [x10]
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x11]
+ MOV v30.16b, v20.16b
+ PRFM PLDL1KEEP, [x12]
+ MOV v31.16b, v21.16b
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 4f
+
+ # Prologue - loads for main loop of 96 FMA
+ LDR q0, [x3], 16
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ LDP q12, q13, [x5], 32 // Fetch 3 B (4th deferred)
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+
+ # Is there at least 8 floats (32 bytes) for main loop?
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+1:
+ # First group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+
+ FMLA v31.4s, v13.4s, v5.s[0]
+ FMLA v20.4s, v14.4s, v0.s[1]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ LDR q6, [x3], 16 // Load next 6 A
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ LDR q7, [x9], 16
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ LDR q8, [x10], 16
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ LDR q9, [x11], 16
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ LDR q10, [x12], 16
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+ LDR q11, [x4], 16
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ FMLA v24.4s, v12.4s, v8.s[0]
+ LDR q0, [x3], 16 // Load next 6 A
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ LDR q1, [x9], 16
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ LDR q2, [x10], 16
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+ LDR q3, [x11], 16
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ LDR q4, [x12], 16
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ LDR q5, [x4], 16
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ LDP q12, q13, [x5], 32 // Load next 3 B (not last)
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+ LDP q16, q17, [x5], 32
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ SUBS x0, x0, 32
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.HS 1b
+
+ # Epilogue - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+ # First block same as main loop. Second block has no preloads.
+2:
+ # First group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+
+ FMLA v31.4s, v13.4s, v5.s[0]
+ FMLA v20.4s, v14.4s, v0.s[1]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ LDR q6, [x3], 16 // Load next 6 A
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ LDR q7, [x9], 16
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ LDR q8, [x10], 16
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ LDR q9, [x11], 16
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ LDR q10, [x12], 16
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+ LDR q11, [x4], 16
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ # Is there a remainder?- 4 floats of A (16 bytes) or less
+ TST x0, 31
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.NE 4f
+
+ # Clamp
+3:
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ $if INC:
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x4, x4, x2 // a5 -= kc
+ $else:
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+4:
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Is there a remainder?- 4 floats of A (16 bytes)
+ TBZ x0, 4, 5f
+
+ # Remainder- 4 floats of A (16 bytes)
+ # Load A
+ LDR q0, [x3], 16
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+ LDP q18, q19, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Is there a remainder?- 2 floats of A (8 bytes)
+5:
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
+ # Load A
+ LDR d0, [x3], 8
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ # Is there a remainder?- 1 float of A (4 bytes)
+6:
+ TBZ x0, 2, 3b
+
+ # Remainder- 1 float of A (4 bytes)
+ # Load A
+ LDR s0, [x3], 4
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ # Load B
+ LDP q12, q13, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ $if INC:
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ $else:
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+8:
+ TBZ x1, 1, 9f
+ $if INC:
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ $else:
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ $if INC:
+ STR s30, [x7]
+ STR s28, [x13]
+ STR s26, [x18]
+ STR s24, [x17]
+ STR s22, [x16]
+ STR s20, [x6]
+ $else:
+ STR s20, [x6]
+ STR s22, [x16]
+ STR s24, [x17]
+ STR s26, [x18]
+ STR s28, [x13]
+ STR s30, [x7]
+10:
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 48]
+        LDP d12, d13, [sp, 32]
+        LDP d10, d11, [sp, 16]
+        LDP d8, d9, [sp], 64
+        RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/6x8-aarch64-neonfma-ld128.S b/src/f32-gemm/6x8-aarch64-neonfma-ld128.S
new file mode 100644
index 0000000..c2234f1
--- /dev/null
+++ b/src/f32-gemm/6x8-aarch64-neonfma-ld128.S
@@ -0,0 +1,377 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0
+# A1 v1
+# A2 v2
+# A3 v3
+# A4 v4
+# A5 v5
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+# unused A v8 v9 v10 v11
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+        STP d14, d15, [sp, 48]
+        ADD x12, x11, x4 // a4 = a3 + a_stride
+        ADD x13, x18, x7 // c4 = c3 + cm_stride
+        // if mr <= 4 (LS: flags are still set from CMP x0, 4 above)
+        CSEL x12, x11, x12, LS // a4 = a3
+        CSEL x13, x18, x13, LS // c4 = c3
+
+ # Load params pointer
+ LDR x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+0:
+ # Load initial bias from w into accumulators
+ LDP q20, q21, [x5], 32
+ MOV v22.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v24.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v26.16b, v20.16b
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x9]
+ MOV v28.16b, v20.16b
+ PRFM PLDL1KEEP, [x10]
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x11]
+ MOV v30.16b, v20.16b
+ PRFM PLDL1KEEP, [x12]
+ MOV v31.16b, v21.16b
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 4 floats (16 bytes)?
+ SUBS x0, x2, 16 // k = kc - 16
+ B.LO 2f
+
+ # Main loop - 4 floats of A (16 bytes)
+1:
+ LDR q0, [x3], 16
+ LDP q12, q13, [x5], 32
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ LDP q16, q17, [x5], 32
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ SUBS x0, x0, 16
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ B.HS 1b
+
+2:
+        # Is there a remainder?- 2 floats of A (8 bytes)
+        TBNZ x0, 3, 4f
+        # Is there a remainder?- 1 float of A (4 bytes)
+        TBNZ x0, 2, 5f
+3:
+ # Clamp
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 6f
+
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+4:
+ # Remainder- 2 floats of A (8 bytes)
+ LDR d0, [x3], 8
+ LDP q12, q13, [x5], 32
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ TBZ x0, 2, 3b
+
+5:
+        # Remainder- 1 float of A (4 bytes)
+        LDR s0, [x3], 4
+ LDP q12, q13, [x5], 32
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+6:
+ TBZ x1, 2, 7f
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+
+7:
+ TBZ x1, 1, 8f
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+
+8:
+ TBZ x1, 0, 9f
+ STR s20, [x6]
+ STR s22, [x16]
+ STR s24, [x17]
+ STR s26, [x18]
+ STR s28, [x13]
+ STR s30, [x7]
+9:
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 48]
+        LDP d12, d13, [sp, 32]
+        LDP d10, d11, [sp, 16]
+        LDP d8, d9, [sp], 64
+        RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in b/src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in
new file mode 100644
index 0000000..ab70862
--- /dev/null
+++ b/src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in
@@ -0,0 +1,456 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ld128(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0
+# A1 v1
+# A2 v2
+# A3 v3
+# A4 v4
+# A5 v5
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+# unused A v8 v9 v10 v11
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ld128
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x12, x11, x4 // a4 = a3 + a_stride
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+ // if mr <= 4
+ CSEL x12, x11, x12, LS // a4 = a3
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ $if INC:
+ # Load acc, params pointer
+ LDP x15, x8, [sp, 72]
+ $else:
+ # Load params pointer
+ LDR x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+0:
+ $if INC:
+ # Load initial accumulators
+ LDP q20, q21, [x15], 32
+ LDP q22, q23, [x15], 32
+ LDP q24, q25, [x15], 32
+ LDP q26, q27, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ PRFM PLDL1KEEP, [x9]
+ PRFM PLDL1KEEP, [x10]
+ PRFM PLDL1KEEP, [x11]
+ PRFM PLDL1KEEP, [x12]
+ PRFM PLDL1KEEP, [x4]
+ $else:
+ # Load initial bias from w into accumulators
+ LDP q20, q21, [x5], 32
+ MOV v22.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v24.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v26.16b, v20.16b
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x9]
+ MOV v28.16b, v20.16b
+ PRFM PLDL1KEEP, [x10]
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x11]
+ MOV v30.16b, v20.16b
+ PRFM PLDL1KEEP, [x12]
+ MOV v31.16b, v21.16b
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 4 floats (16 bytes)?
+ SUBS x0, x2, 16 // k = kc - 16
+ B.LO 2f
+
+ # Main loop - 4 floats of A (16 bytes)
+1:
+ LDR q0, [x3], 16
+ LDP q12, q13, [x5], 32
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ LDP q16, q17, [x5], 32
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ SUBS x0, x0, 16
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ B.HS 1b
+
+2:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBNZ x0, 3, 4f
+ # Is there a remainder?- 1 float of A (4 bytes)
+ TBNZ x0, 2, 5f
+3:
+ # Clamp
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 6f
+
+ $if INC:
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x4, x4, x2 // a5 -= kc
+ $else:
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+4:
+ # Remainder- 2 floats of A (8 bytes)
+ LDR d0, [x3], 8
+ LDP q12, q13, [x5], 32
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ TBZ x0, 2, 3b
+
+5:
+ # Remainder- 1 float of A (4 bytes)
+ LDR s0, [x3], 4
+ LDP q12, q13, [x5], 32
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+6:
+ TBZ x1, 2, 7f
+ $if INC:
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ $else:
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+
+7:
+ TBZ x1, 1, 8f
+ $if INC:
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ $else:
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+
+8:
+ TBZ x1, 0, 9f
+ $if INC:
+ STR s30, [x7]
+ STR s28, [x13]
+ STR s26, [x18]
+ STR s24, [x17]
+ STR s22, [x16]
+ STR s20, [x6]
+ $else:
+ STR s20, [x6]
+ STR s22, [x16]
+ STR s24, [x17]
+ STR s26, [x18]
+ STR s28, [x13]
+ STR s30, [x7]
+9:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ld128
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/6x8-aarch64-neonfma-ld64.S b/src/f32-gemm/6x8-aarch64-neonfma-ld64.S
new file mode 100644
index 0000000..cc413a9
--- /dev/null
+++ b/src/f32-gemm/6x8-aarch64-neonfma-ld64.S
@@ -0,0 +1,311 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld64(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0
+# A1 v1
+# A2 v2
+# A3 v3
+# A4 v4
+# A5 v5
+# B v12 v13 v14 v15
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+# unused A v8 v9 v10 v11
+# unused B v16 v17 v18 v19
+
+BEGIN_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld64
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x12, x11, x4 // a4 = a3 + a_stride
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+ // if mr <= 4
+ CSEL x12, x11, x12, LS // a4 = a3
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ # Load params pointer
+ LDR x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+0:
+ # Load initial bias from w into accumulators
+ LDP q20, q21, [x5], 32
+ MOV v22.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v24.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v26.16b, v20.16b
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x9]
+ MOV v28.16b, v20.16b
+ PRFM PLDL1KEEP, [x10]
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x11]
+ MOV v30.16b, v20.16b
+ PRFM PLDL1KEEP, [x12]
+ MOV v31.16b, v21.16b
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 2 floats (8 bytes) for main loop?
+ SUBS x0, x2, 8 // k = kc - 8
+ B.LO 2f
+
+ # Main loop - 2 floats of A (8 bytes)
+ # 24 FMA + 6 LD64 A + 2 LDP B
+1:
+ LDR d0, [x3], 8
+ LDP q12, q13, [x5], 32
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ SUBS x0, x0, 8
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ B.HS 1b
+
+2:
+ # Is there a remainder?- 1 float of A (4 bytes)
+ TBNZ x0, 2, 4f
+3:
+ # Clamp
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 5f
+
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+4:
+ # Remainder- 1 float of A (4 bytes)
+ LDR s0, [x3], 4
+ LDP q12, q13, [x5], 32
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+5:
+ TBZ x1, 2, 6f
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+
+6:
+ TBZ x1, 1, 7f
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+
+7:
+ TBZ x1, 0, 8f
+ STR s20, [x6]
+ STR s22, [x16]
+ STR s24, [x17]
+ STR s26, [x18]
+ STR s28, [x13]
+ STR s30, [x7]
+8:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld64
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in b/src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in
new file mode 100644
index 0000000..a40a4c5
--- /dev/null
+++ b/src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in
@@ -0,0 +1,390 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ld64(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+$if INC:
+ # const float*restrict acc, [sp + 8] -> x15
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+$else:
+ # const union xnn_f32_output_params params[restrict static 1]) [sp + 8] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0
+# A1 v1
+# A2 v2
+# A3 v3
+# A4 v4
+# A5 v5
+# B v12 v13 v14 v15
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+# unused A v8 v9 v10 v11
+# unused B v16 v17 v18 v19
+
+BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ld64
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x12, x11, x4 // a4 = a3 + a_stride
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+ // if mr <= 4
+ CSEL x12, x11, x12, LS // a4 = a3
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ $if INC:
+ # Load acc, params pointer
+ LDP x15, x8, [sp, 72]
+ $else:
+ # Load params pointer
+ LDR x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+0:
+ $if INC:
+ # Load initial accumulators
+ LDP q20, q21, [x15], 32
+ LDP q22, q23, [x15], 32
+ LDP q24, q25, [x15], 32
+ LDP q26, q27, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ PRFM PLDL1KEEP, [x9]
+ PRFM PLDL1KEEP, [x10]
+ PRFM PLDL1KEEP, [x11]
+ PRFM PLDL1KEEP, [x12]
+ PRFM PLDL1KEEP, [x4]
+ $else:
+ # Load initial bias from w into accumulators
+ LDP q20, q21, [x5], 32
+ MOV v22.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v24.16b, v20.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v26.16b, v20.16b
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x9]
+ MOV v28.16b, v20.16b
+ PRFM PLDL1KEEP, [x10]
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x11]
+ MOV v30.16b, v20.16b
+ PRFM PLDL1KEEP, [x12]
+ MOV v31.16b, v21.16b
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 2 floats (8 bytes) for main loop?
+ SUBS x0, x2, 8 // k = kc - 8
+ B.LO 2f
+
+ # Main loop - 2 floats of A (8 bytes)
+ # 24 FMA + 6 LD64 A + 2 LDP B
+1:
+ LDR d0, [x3], 8
+ LDP q12, q13, [x5], 32
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ SUBS x0, x0, 8
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ B.HS 1b
+
+2:
+ # Is there a remainder?- 1 float of A (4 bytes)
+ TBNZ x0, 2, 4f
+3:
+ # Clamp
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 5f
+
+ $if INC:
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x4, x4, x2 // a5 -= kc
+ $else:
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+4:
+ # Remainder- 1 float of A (4 bytes)
+ LDR s0, [x3], 4
+ LDP q12, q13, [x5], 32
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+5:
+ TBZ x1, 2, 6f
+ $if INC:
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ $else:
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+
+6:
+ TBZ x1, 1, 7f
+ $if INC:
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ $else:
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+
+7:
+ TBZ x1, 0, 8f
+ $if INC:
+ STR s30, [x7]
+ STR s28, [x13]
+ STR s26, [x18]
+ STR s24, [x17]
+ STR s22, [x16]
+ STR s20, [x6]
+ $else:
+ STR s20, [x6]
+ STR s22, [x16]
+ STR s24, [x17]
+ STR s26, [x18]
+ STR s28, [x13]
+ STR s30, [x7]
+8:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_ukernel_6x8__aarch64_neonfma_ld64
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemm/6x8-neon-ld64.c b/src/f32-gemm/6x8-neon-ld64.c
new file mode 100644
index 0000000..b2322d8
--- /dev/null
+++ b/src/f32-gemm/6x8-neon-ld64.c
@@ -0,0 +1,255 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_6x8__neon_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float* restrict a,
+ size_t a_stride,
+ const float* restrict w,
+ float* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 6);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+ const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+ float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ a4 = a3;
+ c4 = c3;
+ }
+ const float* a5 = (const float*) ((uintptr_t) a4 + a_stride);
+ float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 6) {
+ a5 = a4;
+ c5 = c4;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc1x4567 = vacc0x4567;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc2x4567 = vacc0x4567;
+ float32x4_t vacc3x0123 = vacc0x0123;
+ float32x4_t vacc3x4567 = vacc0x4567;
+ float32x4_t vacc4x0123 = vacc0x0123;
+ float32x4_t vacc4x4567 = vacc0x4567;
+ float32x4_t vacc5x0123 = vacc0x0123;
+ float32x4_t vacc5x4567 = vacc0x4567;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+ const float32x2_t va4 = vld1_f32(a4); a4 += 2;
+ const float32x2_t va5 = vld1_f32(a5); a5 += 2;
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+ vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123c0, va4, 0);
+ vacc5x0123 = vmlaq_lane_f32(vacc5x0123, vb0123c0, va5, 0);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+ vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567c0, va4, 0);
+ vacc5x4567 = vmlaq_lane_f32(vacc5x4567, vb4567c0, va5, 0);
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+ vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123c1, va4, 1);
+ vacc5x0123 = vmlaq_lane_f32(vacc5x0123, vb0123c1, va5, 1);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+ vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567c1, va4, 1);
+ vacc5x4567 = vmlaq_lane_f32(vacc5x4567, vb4567c1, va5, 1);
+ }
+ if XNN_UNLIKELY(k != 0) {
+ const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+ const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+ const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+ const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+ const float32x4_t va4 = vld1q_dup_f32(a4); a4 += 1;
+ const float32x4_t va5 = vld1q_dup_f32(a5); a5 += 1;
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+ vacc4x0123 = vmlaq_f32(vacc4x0123, va4, vb0123);
+ vacc5x0123 = vmlaq_f32(vacc5x0123, va5, vb0123);
+ vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+ vacc4x4567 = vmlaq_f32(vacc4x4567, va4, vb4567);
+ vacc5x4567 = vmlaq_f32(vacc5x4567, va5, vb4567);
+ }
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc4x0123 = vminq_f32(vacc4x0123, vmax);
+ vacc5x0123 = vminq_f32(vacc5x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+ vacc4x4567 = vminq_f32(vacc4x4567, vmax);
+ vacc5x4567 = vminq_f32(vacc5x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc4x0123 = vmaxq_f32(vacc4x0123, vmin);
+ vacc5x0123 = vmaxq_f32(vacc5x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+ vacc4x4567 = vmaxq_f32(vacc4x4567, vmin);
+ vacc5x4567 = vmaxq_f32(vacc5x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c5, vacc5x0123);
+ vst1q_f32(c5 + 4, vacc5x4567);
+ c5 = (float*) ((uintptr_t) c5 + cn_stride);
+ vst1q_f32(c4, vacc4x0123);
+ vst1q_f32(c4 + 4, vacc4x4567);
+ c4 = (float*) ((uintptr_t) c4 + cn_stride);
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a5 = (const float*) ((uintptr_t) a5 - kc);
+ a4 = (const float*) ((uintptr_t) a4 - kc);
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+
+ } else {
+ if (nc & 4) {
+ vst1q_f32(c5, vacc5x0123); c5 += 4;
+ vst1q_f32(c4, vacc4x0123); c4 += 4;
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc5x0123 = vacc5x4567;
+ vacc4x0123 = vacc4x4567;
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc5x01 = vget_low_f32(vacc5x0123);
+ float32x2_t vacc4x01 = vget_low_f32(vacc4x0123);
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c5, vacc5x01); c5 += 2;
+ vst1_f32(c4, vacc4x01); c4 += 2;
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc5x01 = vget_high_f32(vacc5x0123);
+ vacc4x01 = vget_high_f32(vacc4x0123);
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c5, vacc5x01, 0);
+ vst1_lane_f32(c4, vacc4x01, 0);
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemm/6x8-neonfma-ld64.c b/src/f32-gemm/6x8-neonfma-ld64.c
new file mode 100644
index 0000000..682dd07
--- /dev/null
+++ b/src/f32-gemm/6x8-neonfma-ld64.c
@@ -0,0 +1,297 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_6x8__neonfma_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 6);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+  const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+  float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 4) {
+    a4 = a3;
+    c4 = c3;
+  }
+  const float* a5 = (const float*) ((uintptr_t) a4 + a_stride);
+  float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 6) {
+    a5 = a4;
+    c5 = c4;
+  }
+
+  do {
+    float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+    float32x4_t vacc1x0123 = vacc0x0123;
+    float32x4_t vacc1x4567 = vacc0x4567;
+    float32x4_t vacc2x0123 = vacc0x0123;
+    float32x4_t vacc2x4567 = vacc0x4567;
+    float32x4_t vacc3x0123 = vacc0x0123;
+    float32x4_t vacc3x4567 = vacc0x4567;
+    float32x4_t vacc4x0123 = vacc0x0123;
+    float32x4_t vacc4x4567 = vacc0x4567;
+    float32x4_t vacc5x0123 = vacc0x0123;
+    float32x4_t vacc5x4567 = vacc0x4567;
+
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+      const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+      const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+      const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+      const float32x2_t va4 = vld1_f32(a4); a4 += 2;
+      const float32x2_t va5 = vld1_f32(a5); a5 += 2;
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+        vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+        vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+        vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+        vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123c0, va4, 0);
+        vacc5x0123 = vfmaq_lane_f32(vacc5x0123, vb0123c0, va5, 0);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+        vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+        vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+        vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+        vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567c0, va4, 0);
+        vacc5x4567 = vfmaq_lane_f32(vacc5x4567, vb4567c0, va5, 0);
+      #else
+        const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+        const float32x4_t va1c0 = vdupq_lane_f32(va1, 0);
+        const float32x4_t va2c0 = vdupq_lane_f32(va2, 0);
+        const float32x4_t va3c0 = vdupq_lane_f32(va3, 0);
+        const float32x4_t va4c0 = vdupq_lane_f32(va4, 0);
+        const float32x4_t va5c0 = vdupq_lane_f32(va5, 0);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+        vacc4x0123 = vfmaq_f32(vacc4x0123, va4c0, vb0123c0);
+        vacc5x0123 = vfmaq_f32(vacc5x0123, va5c0, vb0123c0);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+        vacc4x4567 = vfmaq_f32(vacc4x4567, va4c0, vb4567c0);
+        vacc5x4567 = vfmaq_f32(vacc5x4567, va5c0, vb4567c0);
+      #endif
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+        vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+        vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+        vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+        vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123c1, va4, 1);
+        vacc5x0123 = vfmaq_lane_f32(vacc5x0123, vb0123c1, va5, 1);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+        vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+        vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+        vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+        vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567c1, va4, 1);
+        vacc5x4567 = vfmaq_lane_f32(vacc5x4567, vb4567c1, va5, 1);
+      #else
+        const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+        const float32x4_t va1c1 = vdupq_lane_f32(va1, 1);
+        const float32x4_t va2c1 = vdupq_lane_f32(va2, 1);
+        const float32x4_t va3c1 = vdupq_lane_f32(va3, 1);
+        const float32x4_t va4c1 = vdupq_lane_f32(va4, 1);
+        const float32x4_t va5c1 = vdupq_lane_f32(va5, 1);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+        vacc4x0123 = vfmaq_f32(vacc4x0123, va4c1, vb0123c1);
+        vacc5x0123 = vfmaq_f32(vacc5x0123, va5c1, vb0123c1);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+        vacc4x4567 = vfmaq_f32(vacc4x4567, va4c1, vb4567c1);
+        vacc5x4567 = vfmaq_f32(vacc5x4567, va5c1, vb4567c1);
+      #endif
+    }
+    if XNN_UNLIKELY(k != 0) {
+      const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+      const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+      const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+      const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+      const float32x4_t va4 = vld1q_dup_f32(a4); a4 += 1;
+      const float32x4_t va5 = vld1q_dup_f32(a5); a5 += 1;
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+      vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+      vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+      vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+      vacc4x0123 = vfmaq_f32(vacc4x0123, va4, vb0123);
+      vacc5x0123 = vfmaq_f32(vacc5x0123, va5, vb0123);
+      vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+      vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+      vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+      vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+      vacc4x4567 = vfmaq_f32(vacc4x4567, va4, vb4567);
+      vacc5x4567 = vfmaq_f32(vacc5x4567, va5, vb4567);
+    }
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc4x0123 = vminq_f32(vacc4x0123, vmax);
+    vacc5x0123 = vminq_f32(vacc5x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+    vacc4x4567 = vminq_f32(vacc4x4567, vmax);
+    vacc5x4567 = vminq_f32(vacc5x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc4x0123 = vmaxq_f32(vacc4x0123, vmin);
+    vacc5x0123 = vmaxq_f32(vacc5x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+    vacc4x4567 = vmaxq_f32(vacc4x4567, vmin);
+    vacc5x4567 = vmaxq_f32(vacc5x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      vst1q_f32(c5, vacc5x0123);
+      vst1q_f32(c5 + 4, vacc5x4567);
+      c5 = (float*) ((uintptr_t) c5 + cn_stride);
+      vst1q_f32(c4, vacc4x0123);
+      vst1q_f32(c4 + 4, vacc4x4567);
+      c4 = (float*) ((uintptr_t) c4 + cn_stride);
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a5 = (const float*) ((uintptr_t) a5 - kc);
+      a4 = (const float*) ((uintptr_t) a4 - kc);
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+
+    } else {
+      if (nc & 4) {
+        vst1q_f32(c5, vacc5x0123); c5 += 4;
+        vst1q_f32(c4, vacc4x0123); c4 += 4;
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc5x0123 = vacc5x4567;
+        vacc4x0123 = vacc4x4567;
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc5x01 = vget_low_f32(vacc5x0123);
+      float32x2_t vacc4x01 = vget_low_f32(vacc4x0123);
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c5, vacc5x01); c5 += 2;
+        vst1_f32(c4, vacc4x01); c4 += 2;
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc5x01 = vget_high_f32(vacc5x0123);
+        vacc4x01 = vget_high_f32(vacc4x0123);
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c5, vacc5x01, 0);
+        vst1_lane_f32(c4, vacc4x01, 0);
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/6x8-psimd-loadsplat.c b/src/f32-gemm/6x8-psimd-loadsplat.c
new file mode 100644
index 0000000..feba9f8
--- /dev/null
+++ b/src/f32-gemm/6x8-psimd-loadsplat.c
@@ -0,0 +1,234 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-loadsplat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_6x8__psimd_loadsplat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 6);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+  const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+  float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 4) {
+    a4 = a3;
+    c4 = c3;
+  }
+  const float* a5 = (const float*) ((uintptr_t) a4 + a_stride);
+  float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 6) {
+    a5 = a4;
+    c5 = c4;
+  }
+
+  do {
+    psimd_f32 vacc0x0123 = psimd_load_f32(w + 0);
+    psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+    psimd_f32 vacc1x0123 = vacc0x0123;
+    psimd_f32 vacc1x4567 = vacc0x4567;
+    psimd_f32 vacc2x0123 = vacc0x0123;
+    psimd_f32 vacc2x4567 = vacc0x4567;
+    psimd_f32 vacc3x0123 = vacc0x0123;
+    psimd_f32 vacc3x4567 = vacc0x4567;
+    psimd_f32 vacc4x0123 = vacc0x0123;
+    psimd_f32 vacc4x4567 = vacc0x4567;
+    psimd_f32 vacc5x0123 = vacc0x0123;
+    psimd_f32 vacc5x4567 = vacc0x4567;
+    w += 8;
+
+    size_t k = kc;
+    do {
+      const psimd_f32 va0 = psimd_load_splat_f32(a0);
+      a0 += 1;
+      const psimd_f32 va1 = psimd_load_splat_f32(a1);
+      a1 += 1;
+      const psimd_f32 va2 = psimd_load_splat_f32(a2);
+      a2 += 1;
+      const psimd_f32 va3 = psimd_load_splat_f32(a3);
+      a3 += 1;
+      const psimd_f32 va4 = psimd_load_splat_f32(a4);
+      a4 += 1;
+      const psimd_f32 va5 = psimd_load_splat_f32(a5);
+      a5 += 1;
+
+      const psimd_f32 vb0123 = psimd_load_f32(w);
+      const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+      w += 8;
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567);
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+    vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+    vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+    vacc4x0123 = psimd_min_f32(vacc4x0123, vmax);
+    vacc5x0123 = psimd_min_f32(vacc5x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+    vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+    vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+    vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+    vacc4x4567 = psimd_min_f32(vacc4x4567, vmax);
+    vacc5x4567 = psimd_min_f32(vacc5x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+    vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+    vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+    vacc4x0123 = psimd_max_f32(vacc4x0123, vmin);
+    vacc5x0123 = psimd_max_f32(vacc5x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+    vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+    vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+    vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+    vacc4x4567 = psimd_max_f32(vacc4x4567, vmin);
+    vacc5x4567 = psimd_max_f32(vacc5x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      psimd_store_f32(c5, vacc5x0123);
+      psimd_store_f32(c5 + 4, vacc5x4567);
+      c5 = (float*) ((uintptr_t) c5 + cn_stride);
+      psimd_store_f32(c4, vacc4x0123);
+      psimd_store_f32(c4 + 4, vacc4x4567);
+      c4 = (float*) ((uintptr_t) c4 + cn_stride);
+      psimd_store_f32(c3, vacc3x0123);
+      psimd_store_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      psimd_store_f32(c2, vacc2x0123);
+      psimd_store_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      psimd_store_f32(c1, vacc1x0123);
+      psimd_store_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a5 = (const float*) ((uintptr_t) a5 - kc);
+      a4 = (const float*) ((uintptr_t) a4 - kc);
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        psimd_store_f32(c5, vacc5x0123);
+        psimd_store_f32(c4, vacc4x0123);
+        psimd_store_f32(c3, vacc3x0123);
+        psimd_store_f32(c2, vacc2x0123);
+        psimd_store_f32(c1, vacc1x0123);
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc5x0123 = vacc5x4567;
+        vacc4x0123 = vacc4x4567;
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c5 += 4;
+        c4 += 4;
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c5, vacc5x0123);
+        psimd_store2_f32(c4, vacc4x0123);
+        psimd_store2_f32(c3, vacc3x0123);
+        psimd_store2_f32(c2, vacc2x0123);
+        psimd_store2_f32(c1, vacc1x0123);
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc5x0123 = psimd_concat_hi_f32(vacc5x0123, vacc5x0123);
+        vacc4x0123 = psimd_concat_hi_f32(vacc4x0123, vacc4x0123);
+        vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+        vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+        vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c5 += 2;
+        c4 += 2;
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c5, vacc5x0123);
+        psimd_store1_f32(c4, vacc4x0123);
+        psimd_store1_f32(c3, vacc3x0123);
+        psimd_store1_f32(c2, vacc2x0123);
+        psimd_store1_f32(c1, vacc1x0123);
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/6x8-psimd-splat.c b/src/f32-gemm/6x8-psimd-splat.c
new file mode 100644
index 0000000..2aac13e
--- /dev/null
+++ b/src/f32-gemm/6x8-psimd-splat.c
@@ -0,0 +1,342 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-splat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemm_ukernel_6x8__psimd_splat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 6);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+  const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+  float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 4) {
+    a4 = a3;
+    c4 = c3;
+  }
+  const float* a5 = (const float*) ((uintptr_t) a4 + a_stride);
+  float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 6) {
+    a5 = a4;
+    c5 = c4;
+  }
+
+  do {
+    psimd_f32 vacc0x0123 = psimd_load_f32(w + 0);
+    psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+    psimd_f32 vacc1x0123 = vacc0x0123;
+    psimd_f32 vacc1x4567 = vacc0x4567;
+    psimd_f32 vacc2x0123 = vacc0x0123;
+    psimd_f32 vacc2x4567 = vacc0x4567;
+    psimd_f32 vacc3x0123 = vacc0x0123;
+    psimd_f32 vacc3x4567 = vacc0x4567;
+    psimd_f32 vacc4x0123 = vacc0x0123;
+    psimd_f32 vacc4x4567 = vacc0x4567;
+    psimd_f32 vacc5x0123 = vacc0x0123;
+    psimd_f32 vacc5x4567 = vacc0x4567;
+    w += 8;
+
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {
+      const psimd_f32 va0 = psimd_load_f32(a0);
+      a0 += 4;
+      const psimd_f32 va1 = psimd_load_f32(a1);
+      a1 += 4;
+      const psimd_f32 va2 = psimd_load_f32(a2);
+      a2 += 4;
+      const psimd_f32 va3 = psimd_load_f32(a3);
+      a3 += 4;
+      const psimd_f32 va4 = psimd_load_f32(a4);
+      a4 += 4;
+      const psimd_f32 va5 = psimd_load_f32(a5);
+      a5 += 4;
+
+      const psimd_f32 va0c0 = psimd_splat0_f32(va0);
+      const psimd_f32 va1c0 = psimd_splat0_f32(va1);
+      const psimd_f32 va2c0 = psimd_splat0_f32(va2);
+      const psimd_f32 va3c0 = psimd_splat0_f32(va3);
+      const psimd_f32 va4c0 = psimd_splat0_f32(va4);
+      const psimd_f32 va5c0 = psimd_splat0_f32(va5);
+
+      const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+      const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c0, vb0123c0);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c0, vb0123c0);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c0, vb0123c0);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c0, vb0123c0);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4c0, vb0123c0);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5c0, vb0123c0);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c0, vb4567c0);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c0, vb4567c0);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c0, vb4567c0);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c0, vb4567c0);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4c0, vb4567c0);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5c0, vb4567c0);
+      const psimd_f32 va0c1 = psimd_splat1_f32(va0);
+      const psimd_f32 va1c1 = psimd_splat1_f32(va1);
+      const psimd_f32 va2c1 = psimd_splat1_f32(va2);
+      const psimd_f32 va3c1 = psimd_splat1_f32(va3);
+      const psimd_f32 va4c1 = psimd_splat1_f32(va4);
+      const psimd_f32 va5c1 = psimd_splat1_f32(va5);
+
+      const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+      const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c1, vb0123c1);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c1, vb0123c1);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c1, vb0123c1);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c1, vb0123c1);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4c1, vb0123c1);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5c1, vb0123c1);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c1, vb4567c1);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c1, vb4567c1);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c1, vb4567c1);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c1, vb4567c1);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4c1, vb4567c1);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5c1, vb4567c1);
+      const psimd_f32 va0c2 = psimd_splat2_f32(va0);
+      const psimd_f32 va1c2 = psimd_splat2_f32(va1);
+      const psimd_f32 va2c2 = psimd_splat2_f32(va2);
+      const psimd_f32 va3c2 = psimd_splat2_f32(va3);
+      const psimd_f32 va4c2 = psimd_splat2_f32(va4);
+      const psimd_f32 va5c2 = psimd_splat2_f32(va5);
+
+      const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+      const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c2, vb0123c2);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c2, vb0123c2);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c2, vb0123c2);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c2, vb0123c2);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4c2, vb0123c2);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5c2, vb0123c2);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c2, vb4567c2);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c2, vb4567c2);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c2, vb4567c2);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c2, vb4567c2);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4c2, vb4567c2);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5c2, vb4567c2);
+      const psimd_f32 va0c3 = psimd_splat3_f32(va0);
+      const psimd_f32 va1c3 = psimd_splat3_f32(va1);
+      const psimd_f32 va2c3 = psimd_splat3_f32(va2);
+      const psimd_f32 va3c3 = psimd_splat3_f32(va3);
+      const psimd_f32 va4c3 = psimd_splat3_f32(va4);
+      const psimd_f32 va5c3 = psimd_splat3_f32(va5);
+
+      const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+      const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c3, vb0123c3);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c3, vb0123c3);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c3, vb0123c3);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c3, vb0123c3);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4c3, vb0123c3);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5c3, vb0123c3);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c3, vb4567c3);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c3, vb4567c3);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c3, vb4567c3);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c3, vb4567c3);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4c3, vb4567c3);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5c3, vb4567c3);
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    if XNN_UNLIKELY(k != 0) {
+      do {
+        const psimd_f32 va0 = psimd_load_splat_f32(a0);
+        a0 += 1;
+        const psimd_f32 va1 = psimd_load_splat_f32(a1);
+        a1 += 1;
+        const psimd_f32 va2 = psimd_load_splat_f32(a2);
+        a2 += 1;
+        const psimd_f32 va3 = psimd_load_splat_f32(a3);
+        a3 += 1;
+        const psimd_f32 va4 = psimd_load_splat_f32(a4);
+        a4 += 1;
+        const psimd_f32 va5 = psimd_load_splat_f32(a5);
+        a5 += 1;
+
+        const psimd_f32 vb0123 = psimd_load_f32(w);
+        const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+        w += 8;
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+        vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+        vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+        vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+        vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123);
+        vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+        vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+        vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+        vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+        vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567);
+        vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567);
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+    vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+    vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+    vacc4x0123 = psimd_min_f32(vacc4x0123, vmax);
+    vacc5x0123 = psimd_min_f32(vacc5x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+    vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+    vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+    vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+    vacc4x4567 = psimd_min_f32(vacc4x4567, vmax);
+    vacc5x4567 = psimd_min_f32(vacc5x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+    vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+    vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+    vacc4x0123 = psimd_max_f32(vacc4x0123, vmin);
+    vacc5x0123 = psimd_max_f32(vacc5x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+    vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+    vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+    vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+    vacc4x4567 = psimd_max_f32(vacc4x4567, vmin);
+    vacc5x4567 = psimd_max_f32(vacc5x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      psimd_store_f32(c5, vacc5x0123);
+      psimd_store_f32(c5 + 4, vacc5x4567);
+      c5 = (float*) ((uintptr_t) c5 + cn_stride);
+      psimd_store_f32(c4, vacc4x0123);
+      psimd_store_f32(c4 + 4, vacc4x4567);
+      c4 = (float*) ((uintptr_t) c4 + cn_stride);
+      psimd_store_f32(c3, vacc3x0123);
+      psimd_store_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      psimd_store_f32(c2, vacc2x0123);
+      psimd_store_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      psimd_store_f32(c1, vacc1x0123);
+      psimd_store_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a5 = (const float*) ((uintptr_t) a5 - kc);
+      a4 = (const float*) ((uintptr_t) a4 - kc);
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        psimd_store_f32(c5, vacc5x0123);
+        psimd_store_f32(c4, vacc4x0123);
+        psimd_store_f32(c3, vacc3x0123);
+        psimd_store_f32(c2, vacc2x0123);
+        psimd_store_f32(c1, vacc1x0123);
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc5x0123 = vacc5x4567;
+        vacc4x0123 = vacc4x4567;
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c5 += 4;
+        c4 += 4;
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c5, vacc5x0123);
+        psimd_store2_f32(c4, vacc4x0123);
+        psimd_store2_f32(c3, vacc3x0123);
+        psimd_store2_f32(c2, vacc2x0123);
+        psimd_store2_f32(c1, vacc1x0123);
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc5x0123 = psimd_concat_hi_f32(vacc5x0123, vacc5x0123);
+        vacc4x0123 = psimd_concat_hi_f32(vacc4x0123, vacc4x0123);
+        vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+        vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+        vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c5 += 2;
+        c4 += 2;
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c5, vacc5x0123);
+        psimd_store1_f32(c4, vacc4x0123);
+        psimd_store1_f32(c3, vacc3x0123);
+        psimd_store1_f32(c2, vacc2x0123);
+        psimd_store1_f32(c1, vacc1x0123);
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/6x8s4-psimd.c b/src/f32-gemm/6x8s4-psimd.c
new file mode 100644
index 0000000..e8010ce
--- /dev/null
+++ b/src/f32-gemm/6x8s4-psimd.c
@@ -0,0 +1,340 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-s4.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+// GEMM microkernel: computes an up-to-6x8 tile of C = A * B with PSIMD
+// vectors.  The packed weights w hold 8 bias values followed by 8-column
+// panels of B pre-shuffled in groups of 4 ("s4" layout): each main-loop
+// step rotates the A vectors one lane left instead of broadcasting
+// individual A elements.
+//
+//   mr                      - number of A/C rows to process (1..6)
+//   nc                      - number of C columns remaining
+//   kc                      - K dimension size in bytes (multiple of sizeof(float))
+//   a, a_stride             - input matrix and byte stride between its rows
+//   w                       - packed bias + shuffled weight panels
+//   c, cm_stride, cn_stride - output matrix, row stride, 8-column tile stride (bytes)
+//   params                  - scalar.min / scalar.max output clamping bounds
+void xnn_f32_gemm_ukernel_6x8s4__psimd(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 6);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  // Per-row input/output pointers.  Rows beyond mr alias the previous row,
+  // so out-of-range rows redundantly recompute a valid row instead of
+  // touching memory outside the caller's buffers.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+  const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+  float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 4) {
+    a4 = a3;
+    c4 = c3;
+  }
+  const float* a5 = (const float*) ((uintptr_t) a4 + a_stride);
+  float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 6) {
+    a5 = a4;
+    c5 = c4;
+  }
+
+  do {
+    // Initialize all 6 rows of accumulators from the 8 bias values in w.
+    psimd_f32 vacc0x0123 = psimd_load_f32(w + 0);
+    psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+    psimd_f32 vacc1x0123 = vacc0x0123;
+    psimd_f32 vacc1x4567 = vacc0x4567;
+    psimd_f32 vacc2x0123 = vacc0x0123;
+    psimd_f32 vacc2x4567 = vacc0x4567;
+    psimd_f32 vacc3x0123 = vacc0x0123;
+    psimd_f32 vacc3x4567 = vacc0x4567;
+    psimd_f32 vacc4x0123 = vacc0x0123;
+    psimd_f32 vacc4x4567 = vacc0x4567;
+    psimd_f32 vacc5x0123 = vacc0x0123;
+    psimd_f32 vacc5x4567 = vacc0x4567;
+    w += 8;
+
+    // Main loop: 4 K-steps per iteration.  After each of the first 3
+    // sub-steps the A vectors are rotated one lane left, matching the
+    // shuffled-by-4 weight layout.
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {
+      psimd_f32 va0 = psimd_load_f32(a0);
+      a0 += 4;
+      psimd_f32 va1 = psimd_load_f32(a1);
+      a1 += 4;
+      psimd_f32 va2 = psimd_load_f32(a2);
+      a2 += 4;
+      psimd_f32 va3 = psimd_load_f32(a3);
+      a3 += 4;
+      psimd_f32 va4 = psimd_load_f32(a4);
+      a4 += 4;
+      psimd_f32 va5 = psimd_load_f32(a5);
+      a5 += 4;
+
+
+      const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+      const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c0);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c0);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c0);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c0);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123c0);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123c0);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c0);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c0);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c0);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c0);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567c0);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567c0);
+
+      // Rotate each A vector one lane left for the next shuffled panel.
+      va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+      va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+      va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+      va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+      va4 = __builtin_shufflevector(va4, va4, 1, 2, 3, 0);
+      va5 = __builtin_shufflevector(va5, va5, 1, 2, 3, 0);
+
+      const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+      const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c1);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c1);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c1);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c1);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123c1);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123c1);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c1);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c1);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c1);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c1);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567c1);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567c1);
+
+      va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+      va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+      va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+      va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+      va4 = __builtin_shufflevector(va4, va4, 1, 2, 3, 0);
+      va5 = __builtin_shufflevector(va5, va5, 1, 2, 3, 0);
+
+      const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+      const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c2);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c2);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c2);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c2);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123c2);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123c2);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c2);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c2);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c2);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c2);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567c2);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567c2);
+
+      va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+      va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+      va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+      va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+      va4 = __builtin_shufflevector(va4, va4, 1, 2, 3, 0);
+      va5 = __builtin_shufflevector(va5, va5, 1, 2, 3, 0);
+
+      const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+      const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c3);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c3);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c3);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c3);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123c3);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123c3);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c3);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c3);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c3);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c3);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567c3);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567c3);
+
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    // Remainder loop: handle kc % 4 leftover K elements one at a time,
+    // broadcasting a single A element per row.
+    if XNN_UNLIKELY(k != 0) {
+      do {
+        const psimd_f32 va0 = psimd_load_splat_f32(a0);
+        a0 += 1;
+        const psimd_f32 va1 = psimd_load_splat_f32(a1);
+        a1 += 1;
+        const psimd_f32 va2 = psimd_load_splat_f32(a2);
+        a2 += 1;
+        const psimd_f32 va3 = psimd_load_splat_f32(a3);
+        a3 += 1;
+        const psimd_f32 va4 = psimd_load_splat_f32(a4);
+        a4 += 1;
+        const psimd_f32 va5 = psimd_load_splat_f32(a5);
+        a5 += 1;
+
+        const psimd_f32 vb0123 = psimd_load_f32(w);
+        const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+        w += 8;
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+        vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+        vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+        vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+        vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123);
+        vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+        vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+        vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+        vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+        vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567);
+        vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567);
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp the accumulators to [params->scalar.min, params->scalar.max].
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+    vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+    vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+    vacc4x0123 = psimd_min_f32(vacc4x0123, vmax);
+    vacc5x0123 = psimd_min_f32(vacc5x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+    vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+    vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+    vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+    vacc4x4567 = psimd_min_f32(vacc4x4567, vmax);
+    vacc5x4567 = psimd_min_f32(vacc5x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+    vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+    vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+    vacc4x0123 = psimd_max_f32(vacc4x0123, vmin);
+    vacc5x0123 = psimd_max_f32(vacc5x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+    vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+    vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+    vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+    vacc4x4567 = psimd_max_f32(vacc4x4567, vmin);
+    vacc5x4567 = psimd_max_f32(vacc5x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Full 8-column store; rewind A pointers for the next column tile.
+      psimd_store_f32(c5, vacc5x0123);
+      psimd_store_f32(c5 + 4, vacc5x4567);
+      c5 = (float*) ((uintptr_t) c5 + cn_stride);
+      psimd_store_f32(c4, vacc4x0123);
+      psimd_store_f32(c4 + 4, vacc4x4567);
+      c4 = (float*) ((uintptr_t) c4 + cn_stride);
+      psimd_store_f32(c3, vacc3x0123);
+      psimd_store_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      psimd_store_f32(c2, vacc2x0123);
+      psimd_store_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      psimd_store_f32(c1, vacc1x0123);
+      psimd_store_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a5 = (const float*) ((uintptr_t) a5 - kc);
+      a4 = (const float*) ((uintptr_t) a4 - kc);
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      // Partial store of the final tile: write 4, 2, then 1 column(s),
+      // shifting the remaining lanes down after each step.
+      if (nc & 4) {
+        psimd_store_f32(c5, vacc5x0123);
+        psimd_store_f32(c4, vacc4x0123);
+        psimd_store_f32(c3, vacc3x0123);
+        psimd_store_f32(c2, vacc2x0123);
+        psimd_store_f32(c1, vacc1x0123);
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc5x0123 = vacc5x4567;
+        vacc4x0123 = vacc4x4567;
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c5 += 4;
+        c4 += 4;
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c5, vacc5x0123);
+        psimd_store2_f32(c4, vacc4x0123);
+        psimd_store2_f32(c3, vacc3x0123);
+        psimd_store2_f32(c2, vacc2x0123);
+        psimd_store2_f32(c1, vacc1x0123);
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc5x0123 = psimd_concat_hi_f32(vacc5x0123, vacc5x0123);
+        vacc4x0123 = psimd_concat_hi_f32(vacc4x0123, vacc4x0123);
+        vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+        vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+        vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c5 += 2;
+        c4 += 2;
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c5, vacc5x0123);
+        psimd_store1_f32(c4, vacc4x0123);
+        psimd_store1_f32(c3, vacc3x0123);
+        psimd_store1_f32(c2, vacc2x0123);
+        psimd_store1_f32(c1, vacc1x0123);
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/MRx2-neon-ld64.c.in b/src/f32-gemm/MRx2-neon-ld64.c.in
new file mode 100644
index 0000000..0b94074
--- /dev/null
+++ b/src/f32-gemm/MRx2-neon-ld64.c.in
@@ -0,0 +1,131 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR == 2
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// Template for MRx2 f32 GEMM/GEMMINC microkernels using 64-bit NEON loads.
+// For the GEMMINC (INC) variant the accumulators are initialized from the
+// acc buffer; for plain GEMM they are initialized from the bias row in w.
+void xnn_f32_gemm${"inc" if INC else ""}_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    $if INC:
+      const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  $if INC:
+    assert(acc != NULL);
+
+  // Per-row pointers; rows beyond mr alias the previous row.
+  const float* a0 = a;
+  float* c0 = c;
+  $for M in range(1, MR):
+    const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+
+  do {
+    $if INC:
+      $for M in range(MR):
+        float32x2_t vacc${M}x01 = vld1_f32(acc); acc += 2;
+    $else:
+      float32x2_t vacc0x01 = vld1_f32(w); w += 2;
+      $for M in range(1, MR):
+        float32x2_t vacc${M}x01 = vacc0x01;
+
+    // Main loop: two K elements per iteration.
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      $for M in range(MR):
+        const float32x2_t va${M} = vld1_f32(a${M}); a${M} += 2;
+
+      $for L in range(2):
+        const float32x2_t vb01c${L} = vld1_f32(w); w += 2;
+
+        $if FMA:
+          #if defined(__aarch64__)
+            $for M in range(MR):
+              vacc${M}x01 = vfma_lane_f32(vacc${M}x01, vb01c${L}, va${M}, ${L});
+          #else
+            $for M in range(MR):
+              const float32x2_t va${M}c${L} = vdup_lane_f32(va${M}, ${L});
+            $for M in range(MR):
+              vacc${M}x01 = vfma_f32(vacc${M}x01, va${M}c${L}, vb01c${L});
+          #endif
+        $else:
+          $for M in range(MR):
+            vacc${M}x01 = vmla_lane_f32(vacc${M}x01, vb01c${L}, va${M}, ${L});
+    }
+    // Remainder: at most one leftover K element (kc % 2).
+    if XNN_UNLIKELY(k != 0) {
+      $for M in range(MR):
+        const float32x2_t va${M} = vld1_dup_f32(a${M}); a${M} += 1;
+
+      const float32x2_t vb01 = vld1_f32(w); w += 2;
+
+      $for M in range(MR):
+        $if FMA:
+          vacc${M}x01 = vfma_f32(vacc${M}x01, va${M}, vb01);
+        $else:
+          vacc${M}x01 = vmla_f32(vacc${M}x01, va${M}, vb01);
+    }
+
+    // Clamp to [params->scalar.min, params->scalar.max].
+    const float32x2_t vmax = vld1_dup_f32(&params->scalar.max);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x01 = vmin_f32(vacc${M}x01, vmax);
+
+    const float32x2_t vmin = vld1_dup_f32(&params->scalar.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x01 = vmax_f32(vacc${M}x01, vmin);
+
+    if XNN_LIKELY(nc >= ${NR}) {
+      $for M in range(MR):
+        vst1_f32(c${M}, vacc${M}x01);
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      $for M in range(MR):
+        a${M} = (const float*) ((uintptr_t) a${M} - kc);
+
+      nc -= ${NR};
+    } else {
+      assert(nc == 1);
+      $for M in range(MR):
+        vst1_lane_f32(c${M}, vacc${M}x01, 0);
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/neon-ld128.c.in b/src/f32-gemm/neon-ld128.c.in
new file mode 100644
index 0000000..b6cad3c
--- /dev/null
+++ b/src/f32-gemm/neon-ld128.c.in
@@ -0,0 +1,170 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// Template for MRxNR f32 GEMM/GEMMINC microkernels using 128-bit NEON loads
+// (4 K elements per main-loop iteration).  INC variants initialize the
+// accumulators from acc; plain GEMM variants from the bias rows in w.
+void xnn_f32_gemm${"inc" if INC else ""}_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}_ld128(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    $if INC:
+      const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  $if INC:
+    assert(acc != NULL);
+
+  // Per-row pointers; rows beyond mr alias the previous row.
+  const float* a0 = a;
+  float* c0 = c;
+  $for M in range(1, MR):
+    const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+
+  do {
+    $if INC:
+      $for M in range(MR):
+        $for N in range(0, NR, 4):
+          float32x4_t vacc${M}x${ABC[N:N+4]} = vld1q_f32(acc); acc += 4;
+    $else:
+      $for N in range(0, NR, 4):
+        float32x4_t vacc0x${ABC[N:N+4]} = vld1q_f32(w); w += 4;
+      $for M in range(1, MR):
+        $for N in range(0, NR, 4):
+          float32x4_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
+
+    // Main loop: four K elements per iteration, one lane per sub-step.
+    size_t k = kc;
+    for (; k >= 4 * sizeof(float); k -= 4 * sizeof(float)) {
+      $for M in range(MR):
+        const float32x4_t va${M} = vld1q_f32(a${M}); a${M} += 4;
+
+      $for L in range(4):
+        $VGET_PART_F32 = "vget_low_f32" if L < 2 else "vget_high_f32"
+
+        $for N in range(0, NR, 4):
+          const float32x4_t vb${ABC[N:N+4]}c${L} = vld1q_f32(w); w += 4;
+
+        $if FMA:
+          #if defined(__aarch64__)
+            $for N in range(0, NR, 4):
+              $for M in range(MR):
+                vacc${M}x${ABC[N:N+4]} = vfmaq_laneq_f32(vacc${M}x${ABC[N:N+4]}, vb${ABC[N:N+4]}c${L}, va${M}, ${L});
+          #else
+            $for M in range(MR):
+              const float32x4_t va${M}c${L} = vdupq_lane_f32(${VGET_PART_F32}(va${M}), ${L % 2});
+            $for N in range(0, NR, 4):
+              $for M in range(MR):
+                vacc${M}x${ABC[N:N+4]} = vfmaq_f32(vacc${M}x${ABC[N:N+4]}, va${M}c${L}, vb${ABC[N:N+4]}c${L});
+          #endif
+        $else:
+          $for N in range(0, NR, 4):
+            $for M in range(MR):
+              vacc${M}x${ABC[N:N+4]} = vmlaq_lane_f32(vacc${M}x${ABC[N:N+4]}, vb${ABC[N:N+4]}c${L}, ${VGET_PART_F32}(va${M}), ${L % 2});
+    }
+    // Remainder: up to 3 leftover K elements, one at a time.
+    if XNN_UNLIKELY(k != 0) {
+      do {
+        $for M in range(MR):
+          const float32x4_t va${M} = vld1q_dup_f32(a${M}); a${M} += 1;
+
+        $for N in range(0, NR, 4):
+          const float32x4_t vb${ABC[N:N+4]} = vld1q_f32(w); w += 4;
+
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            $if FMA:
+              vacc${M}x${ABC[N:N+4]} = vfmaq_f32(vacc${M}x${ABC[N:N+4]}, va${M}, vb${ABC[N:N+4]});
+            $else:
+              vacc${M}x${ABC[N:N+4]} = vmlaq_f32(vacc${M}x${ABC[N:N+4]}, va${M}, vb${ABC[N:N+4]});
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+    // Clamp to [params->scalar.min, params->scalar.max].
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = vminq_f32(vacc${M}x${ABC[N:N+4]}, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = vmaxq_f32(vacc${M}x${ABC[N:N+4]}, vmin);
+
+    if XNN_LIKELY(nc >= ${NR}) {
+      $for M in reversed(range(MR)):
+        vst1q_f32(c${M}, vacc${M}x${ABC[0:4]});
+        $for N in range(4, NR, 4):
+          vst1q_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      $for M in reversed(range(MR)):
+        a${M} = (const float*) ((uintptr_t) a${M} - kc);
+
+      nc -= ${NR};
+
+    } else {
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for N in range(0, 1 << LOG2N, 4):
+                $for M in reversed(range(MR)):
+                  vst1q_f32(c${M}, vacc${M}x${ABC[N:N+4]}); c${M} += 4;
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                vst1_f32(c${M}, vacc${M}x${ABC[0:2]}); c${M} += 2;
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${ABC[0:2]} = vget_high_f32(vacc${M}x${ABC[0:4]});
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                vst1_lane_f32(c${M}, vacc${M}x${ABC[0:2]}, 0);
+          }
+        $if LOG2N == 2:
+          $for M in reversed(range(MR)):
+            float32x2_t vacc${M}x${ABC[0:2]} = vget_low_f32(vacc${M}x${ABC[0:4]});
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
\ No newline at end of file
diff --git a/src/f32-gemm/neon-ld64.c.in b/src/f32-gemm/neon-ld64.c.in
new file mode 100644
index 0000000..5a89619
--- /dev/null
+++ b/src/f32-gemm/neon-ld64.c.in
@@ -0,0 +1,164 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// Template for MRxNR f32 GEMM/GEMMINC microkernels using 64-bit NEON loads
+// (2 K elements per main-loop iteration).  INC variants initialize the
+// accumulators from acc; plain GEMM variants from the bias rows in w.
+void xnn_f32_gemm${"inc" if INC else ""}_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    $if INC:
+      const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  $if INC:
+    assert(acc != NULL);
+
+  // Per-row pointers; rows beyond mr alias the previous row.
+  const float* a0 = a;
+  float* c0 = c;
+  $for M in range(1, MR):
+    const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+
+  do {
+    $if INC:
+      $for M in range(MR):
+        $for N in range(0, NR, 4):
+          float32x4_t vacc${M}x${ABC[N:N+4]} = vld1q_f32(acc); acc += 4;
+    $else:
+      $for N in range(0, NR, 4):
+        float32x4_t vacc0x${ABC[N:N+4]} = vld1q_f32(w); w += 4;
+      $for M in range(1, MR):
+        $for N in range(0, NR, 4):
+          float32x4_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
+
+    // Main loop: two K elements per iteration, one lane per sub-step.
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      $for M in range(MR):
+        const float32x2_t va${M} = vld1_f32(a${M}); a${M} += 2;
+
+      $for L in range(2):
+        $for N in range(0, NR, 4):
+          const float32x4_t vb${ABC[N:N+4]}c${L} = vld1q_f32(w); w += 4;
+
+        $if FMA:
+          #if defined(__aarch64__)
+            $for N in range(0, NR, 4):
+              $for M in range(MR):
+                vacc${M}x${ABC[N:N+4]} = vfmaq_lane_f32(vacc${M}x${ABC[N:N+4]}, vb${ABC[N:N+4]}c${L}, va${M}, ${L});
+          #else
+            $for M in range(MR):
+              const float32x4_t va${M}c${L} = vdupq_lane_f32(va${M}, ${L});
+            $for N in range(0, NR, 4):
+              $for M in range(MR):
+                vacc${M}x${ABC[N:N+4]} = vfmaq_f32(vacc${M}x${ABC[N:N+4]}, va${M}c${L}, vb${ABC[N:N+4]}c${L});
+          #endif
+        $else:
+          $for N in range(0, NR, 4):
+            $for M in range(MR):
+              vacc${M}x${ABC[N:N+4]} = vmlaq_lane_f32(vacc${M}x${ABC[N:N+4]}, vb${ABC[N:N+4]}c${L}, va${M}, ${L});
+    }
+    // Remainder: at most one leftover K element (kc % 2).
+    if XNN_UNLIKELY(k != 0) {
+      $for M in range(MR):
+        const float32x4_t va${M} = vld1q_dup_f32(a${M}); a${M} += 1;
+
+      $for N in range(0, NR, 4):
+        const float32x4_t vb${ABC[N:N+4]} = vld1q_f32(w); w += 4;
+
+      $for N in range(0, NR, 4):
+        $for M in range(MR):
+          $if FMA:
+            vacc${M}x${ABC[N:N+4]} = vfmaq_f32(vacc${M}x${ABC[N:N+4]}, va${M}, vb${ABC[N:N+4]});
+          $else:
+            vacc${M}x${ABC[N:N+4]} = vmlaq_f32(vacc${M}x${ABC[N:N+4]}, va${M}, vb${ABC[N:N+4]});
+    }
+    // Clamp to [params->scalar.min, params->scalar.max].
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = vminq_f32(vacc${M}x${ABC[N:N+4]}, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = vmaxq_f32(vacc${M}x${ABC[N:N+4]}, vmin);
+
+    if XNN_LIKELY(nc >= ${NR}) {
+      $for M in reversed(range(MR)):
+        vst1q_f32(c${M}, vacc${M}x${ABC[0:4]});
+        $for N in range(4, NR, 4):
+          vst1q_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      $for M in reversed(range(MR)):
+        a${M} = (const float*) ((uintptr_t) a${M} - kc);
+
+      nc -= ${NR};
+
+    } else {
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for N in range(0, 1 << LOG2N, 4):
+                $for M in reversed(range(MR)):
+                  vst1q_f32(c${M}, vacc${M}x${ABC[N:N+4]}); c${M} += 4;
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                vst1_f32(c${M}, vacc${M}x${ABC[0:2]}); c${M} += 2;
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${ABC[0:2]} = vget_high_f32(vacc${M}x${ABC[0:4]});
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                vst1_lane_f32(c${M}, vacc${M}x${ABC[0:2]}, 0);
+          }
+        $if LOG2N == 2:
+          $for M in reversed(range(MR)):
+            float32x2_t vacc${M}x${ABC[0:2]} = vget_low_f32(vacc${M}x${ABC[0:4]});
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/psimd-loadsplat.c.in b/src/f32-gemm/psimd-loadsplat.c.in
new file mode 100644
index 0000000..6ac9db2
--- /dev/null
+++ b/src/f32-gemm/psimd-loadsplat.c.in
@@ -0,0 +1,147 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+// Template for MRxNR f32 GEMM/GEMMINC microkernels with PSIMD vectors, one
+// broadcast ("load-splat") A element per K step.  INC variants initialize
+// the accumulators from acc; plain GEMM variants from the bias rows in w.
+void xnn_f32_gemm${"inc" if INC else ""}_ukernel_${MR}x${NR}__psimd_loadsplat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    $if INC:
+      const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  $if INC:
+    assert(acc != NULL);
+
+  // Per-row pointers; rows beyond mr alias the previous row.
+  const float* a0 = a;
+  float* c0 = c;
+  $for M in range(1, MR):
+    const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+
+  do {
+    $if INC:
+      $for M in range(MR):
+        $for N in range(0, NR, 4):
+          psimd_f32 vacc${M}x${ABC[N:N+4]} = psimd_load_f32(acc + ${M*NR+N});
+      acc += ${MR*NR};
+    $else:
+      $for N in range(0, NR, 4):
+        psimd_f32 vacc0x${ABC[N:N+4]} = psimd_load_f32(w + ${N});
+      $for M in range(1, MR):
+        $for N in range(0, NR, 4):
+          psimd_f32 vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
+      w += ${NR};
+
+    // One K element per iteration: splat A, multiply by an NR-wide B row.
+    size_t k = kc;
+    do {
+      $for M in range(MR):
+        const psimd_f32 va${M} = psimd_load_splat_f32(a${M});
+        a${M} += 1;
+
+      const psimd_f32 vb${ABC[0:4]} = psimd_load_f32(w);
+      $for N in range(4, NR, 4):
+        const psimd_f32 vb${ABC[N:N+4]} = psimd_load_f32(w + ${N});
+      w += ${NR};
+
+      $for N in range(0, NR, 4):
+        $for M in range(MR):
+          vacc${M}x${ABC[N:N+4]} = psimd_qfma_f32(vacc${M}x${ABC[N:N+4]}, va${M}, vb${ABC[N:N+4]});
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp to [params->scalar.min, params->scalar.max].
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = psimd_min_f32(vacc${M}x${ABC[N:N+4]}, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = psimd_max_f32(vacc${M}x${ABC[N:N+4]}, vmin);
+
+    if XNN_LIKELY(nc >= ${NR}) {
+      $for M in reversed(range(MR)):
+        psimd_store_f32(c${M}, vacc${M}x${ABC[0:4]});
+        $for N in range(4, NR, 4):
+          psimd_store_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      $for M in reversed(range(MR)):
+        a${M} = (const float*) ((uintptr_t) a${M} - kc);
+
+      nc -= ${NR};
+    } else {
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for M in reversed(range(MR)):
+                psimd_store_f32(c${M}, vacc${M}x${ABC[0:4]});
+                $for N in range(4, 1 << LOG2N, 4):
+                  psimd_store_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+              $for M in reversed(range(MR)):
+                c${M} += ${1 << LOG2N};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                psimd_store2_f32(c${M}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${ABC[0:4]} = psimd_concat_hi_f32(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                c${M} += 2;
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                psimd_store1_f32(c${M}, vacc${M}x${ABC[0:4]});
+          }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/psimd-s4.c.in b/src/f32-gemm/psimd-s4.c.in
new file mode 100644
index 0000000..e922ebb
--- /dev/null
+++ b/src/f32-gemm/psimd-s4.c.in
@@ -0,0 +1,170 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+// Template for MRxNR "s4" f32 GEMM/GEMMINC microkernels with PSIMD vectors:
+// weights are pre-shuffled in groups of 4, so the main loop rotates the A
+// vectors one lane between sub-steps instead of splatting each A element.
+// INC variants initialize the accumulators from acc; plain GEMM from the
+// bias rows in w.
+void xnn_f32_gemm${"inc" if INC else ""}_ukernel_${MR}x${NR}s4__psimd(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    $if INC:
+      const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  $if INC:
+    assert(acc != NULL);
+
+  // Per-row pointers; rows beyond mr alias the previous row.
+  const float* a0 = a;
+  float* c0 = c;
+  $for M in range(1, MR):
+    const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+
+  do {
+    $if INC:
+      $for M in range(MR):
+        $for N in range(0, NR, 4):
+          psimd_f32 vacc${M}x${ABC[N:N+4]} = psimd_load_f32(acc + ${M*NR+N});
+      acc += ${MR*NR};
+    $else:
+      $for N in range(0, NR, 4):
+        psimd_f32 vacc0x${ABC[N:N+4]} = psimd_load_f32(w + ${N});
+      $for M in range(1, MR):
+        $for N in range(0, NR, 4):
+          psimd_f32 vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
+      w += ${NR};
+
+    // Main loop: 4 K elements per iteration; rotate A one lane left after
+    // each of the first 3 sub-steps to match the shuffled weight layout.
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {
+      $for M in range(MR):
+        psimd_f32 va${M} = psimd_load_f32(a${M});
+        a${M} += 4;
+
+      $for L in range(4):
+
+        $for N in range(0, NR, 4):
+          const psimd_f32 vb${ABC[N:N+4]}c${L} = psimd_load_f32(w + ${L * NR + N});
+
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = psimd_qfma_f32(vacc${M}x${ABC[N:N+4]}, va${M}, vb${ABC[N:N+4]}c${L});
+
+        $if L + 1 != 4:
+          $for M in range(MR):
+            va${M} = __builtin_shufflevector(va${M}, va${M}, 1, 2, 3, 0);
+
+      w += ${4 * NR};
+      k -= 4 * sizeof(float);
+    }
+    // Remainder: up to 3 leftover K elements handled with splats.
+    if XNN_UNLIKELY(k != 0) {
+      do {
+        $for M in range(MR):
+          const psimd_f32 va${M} = psimd_load_splat_f32(a${M});
+          a${M} += 1;
+
+        const psimd_f32 vb${ABC[0:4]} = psimd_load_f32(w);
+        $for N in range(4, NR, 4):
+          const psimd_f32 vb${ABC[N:N+4]} = psimd_load_f32(w + ${N});
+        w += ${NR};
+
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = psimd_qfma_f32(vacc${M}x${ABC[N:N+4]}, va${M}, vb${ABC[N:N+4]});
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp to [params->scalar.min, params->scalar.max].
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = psimd_min_f32(vacc${M}x${ABC[N:N+4]}, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = psimd_max_f32(vacc${M}x${ABC[N:N+4]}, vmin);
+
+    if XNN_LIKELY(nc >= ${NR}) {
+      $for M in reversed(range(MR)):
+        psimd_store_f32(c${M}, vacc${M}x${ABC[0:4]});
+        $for N in range(4, NR, 4):
+          psimd_store_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      $for M in reversed(range(MR)):
+        a${M} = (const float*) ((uintptr_t) a${M} - kc);
+
+      nc -= ${NR};
+    } else {
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for M in reversed(range(MR)):
+                psimd_store_f32(c${M}, vacc${M}x${ABC[0:4]});
+                $for N in range(4, 1 << LOG2N, 4):
+                  psimd_store_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+              $for M in reversed(range(MR)):
+                c${M} += ${1 << LOG2N};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                psimd_store2_f32(c${M}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${ABC[0:4]} = psimd_concat_hi_f32(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                c${M} += 2;
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                psimd_store1_f32(c${M}, vacc${M}x${ABC[0:4]});
+          }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/psimd-splat.c.in b/src/f32-gemm/psimd-splat.c.in
new file mode 100644
index 0000000..2653b27
--- /dev/null
+++ b/src/f32-gemm/psimd-splat.c.in
@@ -0,0 +1,168 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+// f32 GEMM (GEMMINC when INC is set) micro-kernel template using PSIMD
+// intrinsics: computes an MR x NR tile of C = A * B by broadcasting
+// ("splatting") individual A elements across a vector. Accumulators start
+// from the acc buffer (INC variant) or from the first NR packed-weight
+// entries (presumably the bias — confirm against the packing routine).
+// Results are clamped to [params->scalar.min, params->scalar.max].
+void xnn_f32_gemm${"inc" if INC else ""}_ukernel_${MR}x${NR}__psimd_splat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    $if INC:
+      const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  $if INC:
+    assert(acc != NULL);
+
+  // Per-row A/C pointers. Rows beyond mr alias the previous row, so their
+  // loads and (duplicate) stores stay within valid memory.
+  const float* a0 = a;
+  float* c0 = c;
+  $for M in range(1, MR):
+    const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+
+  // Outer loop: one iteration per block of up to NR output columns.
+  do {
+    // Initialize accumulators from acc (INC) or from the leading packed
+    // weights (non-INC).
+    $if INC:
+      $for M in range(MR):
+        $for N in range(0, NR, 4):
+          psimd_f32 vacc${M}x${ABC[N:N+4]} = psimd_load_f32(acc + ${M*NR+N});
+      acc += ${MR*NR};
+    $else:
+      $for N in range(0, NR, 4):
+        psimd_f32 vacc0x${ABC[N:N+4]} = psimd_load_f32(w + ${N});
+      $for M in range(1, MR):
+        $for N in range(0, NR, 4):
+          psimd_f32 vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
+      w += ${NR};
+
+    // Main K loop: load 4 A elements per row, then splat each lane in turn
+    // and multiply-accumulate it against the matching NR-wide row of B.
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {
+      $for M in range(MR):
+        const psimd_f32 va${M} = psimd_load_f32(a${M});
+        a${M} += 4;
+
+      $for L in range(4):
+        $for M in range(MR):
+          const psimd_f32 va${M}c${L} = psimd_splat${L}_f32(va${M});
+
+        $for N in range(0, NR, 4):
+          const psimd_f32 vb${ABC[N:N+4]}c${L} = psimd_load_f32(w + ${L * NR + N});
+
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = psimd_qfma_f32(vacc${M}x${ABC[N:N+4]}, va${M}c${L}, vb${ABC[N:N+4]}c${L});
+
+      w += ${4 * NR};
+      k -= 4 * sizeof(float);
+    }
+    // Remainder K loop: one broadcast A element per iteration.
+    if XNN_UNLIKELY(k != 0) {
+      do {
+        $for M in range(MR):
+          const psimd_f32 va${M} = psimd_load_splat_f32(a${M});
+          a${M} += 1;
+
+        const psimd_f32 vb${ABC[0:4]} = psimd_load_f32(w);
+        $for N in range(4, NR, 4):
+          const psimd_f32 vb${ABC[N:N+4]} = psimd_load_f32(w + ${N});
+        w += ${NR};
+
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = psimd_qfma_f32(vacc${M}x${ABC[N:N+4]}, va${M}, vb${ABC[N:N+4]});
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp the accumulators to the requested output range.
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = psimd_min_f32(vacc${M}x${ABC[N:N+4]}, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = psimd_max_f32(vacc${M}x${ABC[N:N+4]}, vmin);
+
+    // Full NR-column store; then rewind A pointers for the next column block.
+    if XNN_LIKELY(nc >= ${NR}) {
+      $for M in reversed(range(MR)):
+        psimd_store_f32(c${M}, vacc${M}x${ABC[0:4]});
+        $for N in range(4, NR, 4):
+          psimd_store_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      $for M in reversed(range(MR)):
+        a${M} = (const float*) ((uintptr_t) a${M} - kc);
+
+      nc -= ${NR};
+    } else {
+      // Partial store: write the remaining nc (< NR) columns in descending
+      // power-of-2 chunks, shifting surviving lanes down after each chunk.
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for M in reversed(range(MR)):
+                psimd_store_f32(c${M}, vacc${M}x${ABC[0:4]});
+                $for N in range(4, 1 << LOG2N, 4):
+                  psimd_store_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+              $for M in reversed(range(MR)):
+                c${M} += ${1 << LOG2N};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                psimd_store2_f32(c${M}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${ABC[0:4]} = psimd_concat_hi_f32(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                c${M} += 2;
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                psimd_store1_f32(c${M}, vacc${M}x${ABC[0:4]});
+          }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/scalar.c.in b/src/f32-gemm/scalar.c.in
new file mode 100644
index 0000000..622f106
--- /dev/null
+++ b/src/f32-gemm/scalar.c.in
@@ -0,0 +1,123 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+// f32 GEMM (GEMMINC when INC is set) scalar micro-kernel template:
+// computes an MR x NR tile of C = A * B one float at a time, with no SIMD.
+// Accumulators start from the acc buffer (INC variant) or from the first
+// NR packed-weight entries (presumably the bias — confirm against the
+// packing routine). Outputs are clamped to [params->scalar.min,
+// params->scalar.max].
+void xnn_f32_gemm${"inc" if INC else ""}_ukernel_${MR}x${NR}__scalar(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    $if INC:
+      const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  $if INC:
+    assert(acc != NULL);
+
+  // Per-row A/C pointers. Rows beyond mr alias the previous row, so their
+  // loads and (duplicate) stores stay within valid memory.
+  const float* a0 = a;
+  float* c0 = c;
+  $for M in range(1, MR):
+    const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+
+  // Outer loop: one iteration per block of up to NR output columns.
+  do {
+    // Initialize accumulators from acc (INC) or from the leading packed
+    // weights (non-INC).
+    $if INC:
+      $for M in range(MR):
+        $for N in range(NR):
+          float vacc${M}${N} = acc[${M*NR+N}];
+      acc += ${MR*NR};
+    $else:
+      $for N in range(NR):
+        float vacc0${N} = w[${N}];
+      w += ${NR};
+      $for M in range(1, MR):
+        $for N in range(NR):
+          float vacc${M}${N} = vacc0${N};
+
+    // K loop: multiply one element from each A row against one NR-wide row
+    // of packed B per iteration.
+    size_t k = kc;
+    do {
+      $for M in range(MR):
+        const float va${M} = *a${M}++;
+
+      $for N in range(NR):
+        const float vb${N} = w[${N}];
+      w += ${NR};
+
+      $for M in range(MR):
+        $for N in range(NR):
+          vacc${M}${N} += va${M} * vb${N};
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp the accumulators to the requested output range.
+    const float vmin = params->scalar.min;
+    $for M in range(MR):
+      $for N in range(NR):
+        vacc${M}${N} = math_max_f32(vacc${M}${N}, vmin);
+
+    const float vmax = params->scalar.max;
+    $for M in range(MR):
+      $for N in range(NR):
+        vacc${M}${N} = math_min_f32(vacc${M}${N}, vmax);
+
+    // Full NR-column store; then rewind A pointers for the next column block.
+    if XNN_LIKELY(nc >= ${NR}) {
+      $for M in reversed(range(MR)):
+        $for N in range(NR):
+          c${M}[${N}] = vacc${M}${N};
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      $for M in reversed(range(MR)):
+        a${M} = (const float*) ((uintptr_t) a${M} - kc);
+
+      nc -= ${NR};
+    } else {
+      // Partial store: write the remaining nc (< NR) columns in descending
+      // power-of-2 chunks, shifting surviving values down after each chunk.
+      $for LOG2N in reversed(range(NR.bit_length() - 1)):
+        if (nc & ${1 << LOG2N}) {
+          $for M in reversed(range(MR)):
+            $for N in range(1 << LOG2N):
+              c${M}[${N}] = vacc${M}${N};
+            $if LOG2N != 0:
+              $for N in range(1 << (LOG2N - 1)):
+                vacc${M}${N} = vacc${M}${N + (1 << LOG2N)};
+              c${M} += ${1 << LOG2N};
+        }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/sse-dup.c.in b/src/f32-gemm/sse-dup.c.in
new file mode 100644
index 0000000..8e26ac1
--- /dev/null
+++ b/src/f32-gemm/sse-dup.c.in
@@ -0,0 +1,170 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// f32 GEMM (GEMMINC when INC is set) micro-kernel template using SSE
+// intrinsics: computes an MR x NR tile of C = A * B by duplicating each A
+// lane across a vector with _mm_shuffle_ps. Accumulators start from the
+// acc buffer (INC variant) or from the first NR packed-weight entries
+// (presumably the bias — confirm against the packing routine). Results are
+// clamped to [params->sse.min, params->sse.max].
+void xnn_f32_gemm${"inc" if INC else ""}_ukernel_${MR}x${NR}__sse_dup(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    $if INC:
+      const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  $if INC:
+    assert(acc != NULL);
+
+  // Per-row A/C pointers. Rows beyond mr alias the previous row, so their
+  // loads and (duplicate) stores stay within valid memory.
+  const float* a0 = a;
+  float* c0 = c;
+  $for M in range(1, MR):
+    const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+
+  // Outer loop: one iteration per block of up to NR output columns.
+  do {
+    // Initialize accumulators from acc (INC) or from the leading packed
+    // weights (non-INC). Packed weights are expected 16-byte aligned
+    // (_mm_load_ps).
+    $if INC:
+      $for M in range(MR):
+        $for N in range(0, NR, 4):
+          __m128 vacc${M}x${ABC[N:N+4]} = _mm_load_ps(acc + ${M*NR+N});
+      acc += ${MR*NR};
+    $else:
+      $for N in range(0, NR, 4):
+        __m128 vacc0x${ABC[N:N+4]} = _mm_load_ps(w + ${N});
+      $for M in range(1, MR):
+        $for N in range(0, NR, 4):
+          __m128 vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
+      w += ${NR};
+
+    // Main K loop: load 4 A elements per row, then broadcast each lane in
+    // turn (shuffle with an all-L mask) and multiply-accumulate it against
+    // the matching NR-wide row of B.
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {
+      $for M in range(MR):
+        const __m128 va${M} = _mm_loadu_ps(a${M});
+        a${M} += 4;
+
+      $for L in range(4):
+        $LLLL = str(L) * 4
+
+        $for M in range(MR):
+          const __m128 va${M}c${LLLL} = _mm_shuffle_ps(va${M}, va${M}, _MM_SHUFFLE(${L}, ${L}, ${L}, ${L}));
+
+        $for N in range(0, NR, 4):
+          const __m128 vb${ABC[N:N+4]}c${L} = _mm_load_ps(w + ${L * NR + N});
+
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = _mm_add_ps(vacc${M}x${ABC[N:N+4]}, _mm_mul_ps(va${M}c${LLLL}, vb${ABC[N:N+4]}c${L}));
+
+      w += ${4 * NR};
+      k -= 4 * sizeof(float);
+    }
+    // Remainder K loop: one broadcast A element per iteration.
+    if XNN_UNLIKELY(k != 0) {
+      do {
+        $for M in range(MR):
+          const __m128 va${M} = _mm_load1_ps(a${M});
+          a${M} += 1;
+
+        const __m128 vb${ABC[0:4]} = _mm_load_ps(w);
+        $for N in range(4, NR, 4):
+          const __m128 vb${ABC[N:N+4]} = _mm_load_ps(w + ${N});
+        w += ${NR};
+
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = _mm_add_ps(vacc${M}x${ABC[N:N+4]}, _mm_mul_ps(va${M}, vb${ABC[N:N+4]}));
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp the accumulators to the requested output range.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = _mm_min_ps(vacc${M}x${ABC[N:N+4]}, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = _mm_max_ps(vacc${M}x${ABC[N:N+4]}, vmin);
+
+    // Full NR-column store; then rewind A pointers for the next column block.
+    if XNN_LIKELY(nc >= ${NR}) {
+      $for M in reversed(range(MR)):
+        _mm_storeu_ps(c${M}, vacc${M}x${ABC[0:4]});
+        $for N in range(4, NR, 4):
+          _mm_storeu_ps(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      $for M in reversed(range(MR)):
+        a${M} = (const float*) ((uintptr_t) a${M} - kc);
+
+      nc -= ${NR};
+    } else {
+      // Partial store: write the remaining nc (< NR) columns in descending
+      // power-of-2 chunks, shifting surviving lanes down after each chunk.
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for M in reversed(range(MR)):
+                _mm_storeu_ps(c${M}, vacc${M}x${ABC[0:4]});
+                $for N in range(4, 1 << LOG2N, 4):
+                  _mm_storeu_ps(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+              $for M in reversed(range(MR)):
+                c${M} += ${1 << LOG2N};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                _mm_storel_pi((__m64*) c${M}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${ABC[0:4]} = _mm_movehl_ps(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                c${M} += 2;
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                _mm_store_ss(c${M}, vacc${M}x${ABC[0:4]});
+          }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/sse-load1.c.in b/src/f32-gemm/sse-load1.c.in
new file mode 100644
index 0000000..d8223ed
--- /dev/null
+++ b/src/f32-gemm/sse-load1.c.in
@@ -0,0 +1,147 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// f32 GEMM (GEMMINC when INC is set) micro-kernel template using SSE
+// intrinsics: computes an MR x NR tile of C = A * B, broadcasting one A
+// element per K step with _mm_load1_ps. Accumulators start from the acc
+// buffer (INC variant) or from the first NR packed-weight entries
+// (presumably the bias — confirm against the packing routine). Results are
+// clamped to [params->sse.min, params->sse.max].
+void xnn_f32_gemm${"inc" if INC else ""}_ukernel_${MR}x${NR}__sse_load1(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    $if INC:
+      const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  $if INC:
+    assert(acc != NULL);
+
+  // Per-row A/C pointers. Rows beyond mr alias the previous row, so their
+  // loads and (duplicate) stores stay within valid memory.
+  const float* a0 = a;
+  float* c0 = c;
+  $for M in range(1, MR):
+    const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+
+  // Outer loop: one iteration per block of up to NR output columns.
+  do {
+    // Initialize accumulators from acc (INC) or from the leading packed
+    // weights (non-INC). Packed weights are expected 16-byte aligned
+    // (_mm_load_ps).
+    $if INC:
+      $for M in range(MR):
+        $for N in range(0, NR, 4):
+          __m128 vacc${M}x${ABC[N:N+4]} = _mm_load_ps(acc + ${M*NR+N});
+      acc += ${MR*NR};
+    $else:
+      $for N in range(0, NR, 4):
+        __m128 vacc0x${ABC[N:N+4]} = _mm_load_ps(w + ${N});
+      $for M in range(1, MR):
+        $for N in range(0, NR, 4):
+          __m128 vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
+      w += ${NR};
+
+    // K loop: broadcast one A element per row and multiply-accumulate it
+    // against one NR-wide row of packed B per iteration.
+    size_t k = kc;
+    do {
+      $for M in range(MR):
+        const __m128 va${M} = _mm_load1_ps(a${M});
+        a${M} += 1;
+
+      const __m128 vb${ABC[0:4]} = _mm_load_ps(w);
+      $for N in range(4, NR, 4):
+        const __m128 vb${ABC[N:N+4]} = _mm_load_ps(w + ${N});
+      w += ${NR};
+
+      $for N in range(0, NR, 4):
+        $for M in range(MR):
+          vacc${M}x${ABC[N:N+4]} = _mm_add_ps(vacc${M}x${ABC[N:N+4]}, _mm_mul_ps(va${M}, vb${ABC[N:N+4]}));
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp the accumulators to the requested output range.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = _mm_min_ps(vacc${M}x${ABC[N:N+4]}, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = _mm_max_ps(vacc${M}x${ABC[N:N+4]}, vmin);
+
+    // Full NR-column store; then rewind A pointers for the next column block.
+    if XNN_LIKELY(nc >= ${NR}) {
+      $for M in reversed(range(MR)):
+        _mm_storeu_ps(c${M}, vacc${M}x${ABC[0:4]});
+        $for N in range(4, NR, 4):
+          _mm_storeu_ps(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      $for M in reversed(range(MR)):
+        a${M} = (const float*) ((uintptr_t) a${M} - kc);
+
+      nc -= ${NR};
+    } else {
+      // Partial store: write the remaining nc (< NR) columns in descending
+      // power-of-2 chunks, shifting surviving lanes down after each chunk.
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for M in reversed(range(MR)):
+                _mm_storeu_ps(c${M}, vacc${M}x${ABC[0:4]});
+                $for N in range(4, 1 << LOG2N, 4):
+                  _mm_storeu_ps(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+              $for M in reversed(range(MR)):
+                c${M} += ${1 << LOG2N};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                _mm_storel_pi((__m64*) c${M}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${ABC[0:4]} = _mm_movehl_ps(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                c${M} += 2;
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                _mm_store_ss(c${M}, vacc${M}x${ABC[0:4]});
+          }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemm/sse-shuffle.c.in b/src/f32-gemm/sse-shuffle.c.in
new file mode 100644
index 0000000..e882cd7
--- /dev/null
+++ b/src/f32-gemm/sse-shuffle.c.in
@@ -0,0 +1,170 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// f32 GEMM (GEMMINC when INC is set) micro-kernel template using SSE
+// intrinsics with a "shuffle" schedule (s4 variant): instead of splatting
+// each A lane, the whole A vector is rotated one lane between the 4
+// sub-steps, and the packed B panels are laid out to match that rotation.
+// Accumulators start from the acc buffer (INC variant) or from the first
+// NR packed-weight entries (presumably the bias — confirm against the
+// packing routine). Results are clamped to [params->sse.min, params->sse.max].
+void xnn_f32_gemm${"inc" if INC else ""}_ukernel_${MR}x${NR}s4__sse(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    $if INC:
+      const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  $if INC:
+    assert(acc != NULL);
+
+  // Per-row A/C pointers. Rows beyond mr alias the previous row, so their
+  // loads and (duplicate) stores stay within valid memory.
+  const float* a0 = a;
+  float* c0 = c;
+  $for M in range(1, MR):
+    const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        a${M} = a${M-1};
+        c${M} = c${M-1};
+      }
+
+  // Outer loop: one iteration per block of up to NR output columns.
+  do {
+    // Initialize accumulators from acc (INC) or from the leading packed
+    // weights (non-INC). Packed weights are expected 16-byte aligned
+    // (_mm_load_ps).
+    $if INC:
+      $for M in range(MR):
+        $for N in range(0, NR, 4):
+          __m128 vacc${M}x${ABC[N:N+4]} = _mm_load_ps(acc + ${M*NR+N});
+      acc += ${MR*NR};
+    $else:
+      $for N in range(0, NR, 4):
+        __m128 vacc0x${ABC[N:N+4]} = _mm_load_ps(w + ${N});
+      $for M in range(1, MR):
+        $for N in range(0, NR, 4):
+          __m128 vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
+      w += ${NR};
+
+    // Main K loop: load 4 A elements per row, multiply the (rotating) A
+    // vector against 4 successive NR-wide B panels, rotating A one lane
+    // left between panels (no rotation needed after the last panel).
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {
+      $for M in range(MR):
+        __m128 va${M} = _mm_loadu_ps(a${M});
+        a${M} += 4;
+
+      $for L in range(4):
+
+        $for N in range(0, NR, 4):
+          const __m128 vb${ABC[N:N+4]}c${L} = _mm_load_ps(w + ${L * NR + N});
+
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = _mm_add_ps(vacc${M}x${ABC[N:N+4]}, _mm_mul_ps(va${M}, vb${ABC[N:N+4]}c${L}));
+
+        $if L + 1 != 4:
+          $for M in range(MR):
+            va${M} = _mm_shuffle_ps(va${M}, va${M}, _MM_SHUFFLE(0, 3, 2, 1));
+
+      w += ${4 * NR};
+      k -= 4 * sizeof(float);
+    }
+    // Remainder K loop: one broadcast A element per iteration (the packed
+    // remainder panels are not rotated).
+    if XNN_UNLIKELY(k != 0) {
+      do {
+        $for M in range(MR):
+          const __m128 va${M} = _mm_load1_ps(a${M});
+          a${M} += 1;
+
+        const __m128 vb${ABC[0:4]} = _mm_load_ps(w);
+        $for N in range(4, NR, 4):
+          const __m128 vb${ABC[N:N+4]} = _mm_load_ps(w + ${N});
+        w += ${NR};
+
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = _mm_add_ps(vacc${M}x${ABC[N:N+4]}, _mm_mul_ps(va${M}, vb${ABC[N:N+4]}));
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp the accumulators to the requested output range.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = _mm_min_ps(vacc${M}x${ABC[N:N+4]}, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = _mm_max_ps(vacc${M}x${ABC[N:N+4]}, vmin);
+
+    // Full NR-column store; then rewind A pointers for the next column block.
+    if XNN_LIKELY(nc >= ${NR}) {
+      $for M in reversed(range(MR)):
+        _mm_storeu_ps(c${M}, vacc${M}x${ABC[0:4]});
+        $for N in range(4, NR, 4):
+          _mm_storeu_ps(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      $for M in reversed(range(MR)):
+        a${M} = (const float*) ((uintptr_t) a${M} - kc);
+
+      nc -= ${NR};
+    } else {
+      // Partial store: write the remaining nc (< NR) columns in descending
+      // power-of-2 chunks, shifting surviving lanes down after each chunk.
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for M in reversed(range(MR)):
+                _mm_storeu_ps(c${M}, vacc${M}x${ABC[0:4]});
+                $for N in range(4, 1 << LOG2N, 4):
+                  _mm_storeu_ps(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+              $for M in reversed(range(MR)):
+                c${M} += ${1 << LOG2N};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                _mm_storel_pi((__m64*) c${M}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${ABC[0:4]} = _mm_movehl_ps(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                c${M} += 2;
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                _mm_store_ss(c${M}, vacc${M}x${ABC[0:4]});
+          }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/1x12-aarch64-neonfma-cortex-a53.S b/src/f32-gemminc/1x12-aarch64-neonfma-cortex-a53.S
new file mode 100644
index 0000000..873ead4
--- /dev/null
+++ b/src/f32-gemminc/1x12-aarch64-neonfma-cortex-a53.S
@@ -0,0 +1,353 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/1x12-aarch64-neonfma-cortex-a53.S.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_1x12__aarch64_neonfma_cortex_a53(
+# size_t mr, (x0) - unused. mr = 1
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, (x4) - unused
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, (x7) - unused
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+
+# C pointers
+# x6 c0
+
+# Vector register usage and GPR shadows
+# a0 v0 first set of A
+# a0 v1 second set of A
+# B v2 v3 v4 x7 x10 x16 first set of B
+# B v5 v6 v7 x17 x18 x9
+# B v23 v24 v25 x7 x10 x16 second set of B (same x as first set)
+# B v17 v18 v19 x17 x18 x9
+# C v20 v21 v22
+
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_1x12__aarch64_neonfma_cortex_a53
+
+        # Load cn_stride, acc
+        LDP x14, x15, [sp]
+        # Load params pointer
+        LDR x8, [sp, 16]
+
+        # Load clamping_params values
+        # (v30 = max, v31 = min, replicated from two consecutive floats;
+        #  the FMIN/FMAX in the clamp section below rely on this order)
+        LD2R {v30.4s, v31.4s}, [x8]
+
+        # Outer loop: one iteration per block of up to 12 output columns (nc)
+0:
+        # Load initial accumulators
+        LD1 {v20.16b, v21.16b, v22.16b}, [x15], 48
+
+        PRFM PLDL1KEEP, [x5]
+        PRFM PLDL1KEEP, [x5, 64]
+        PRFM PLDL1KEEP, [x5, 128]
+        PRFM PLDL1KEEP, [x5, 192]
+        PRFM PLDL1KEEP, [x5, 256]
+        PRFM PLDL1KEEP, [x5, 320]
+
+        # Is there at least 4 floats (16 bytes) for prologue + epilogue?
+        SUBS x0, x2, 16  // k = kc - 16
+        B.LO 3f
+
+        # Prologue - loads for first group of 6 fma
+        # (B vectors are loaded in halves - LDR d + LDR x + INS - so each
+        #  access is a 64-bit load, which suits the A53's load pipeline)
+
+        # Read first block of 4 A.
+        LDR d0, [x3], 8  // a0
+
+        LDR d2, [x5]  // vb0x0123
+        LDR x7, [x5, 8]
+
+        LDR d3, [x5, 16]  // vb0x4567
+        LDR x10, [x5, 24]
+
+        LDR d4, [x5, 32]  // vb0x89AB
+        LDR x16, [x5, 40]
+
+        LDR d5, [x5, 48]  // vb1x0123
+        LDR x17, [x5, 56]
+
+        LDR d6, [x5, 64]  // vb1x4567
+        LDR x18, [x5, 72]
+
+        LDR d7, [x5, 80]  // vb1x89AB
+        LDR x9, [x5, 88]
+        INS v2.d[1], x7
+        ADD x5, x5, 96
+
+        # Is there at least 4 floats (16 bytes) for main loop?
+        SUBS x0, x0, 16
+        B.LO 2f
+
+        # Main loop - 4 floats of A (16 bytes)
+1:
+        # First group of 6 fma.
+        # A is loaded for 2nd group into v1
+
+        # BLOCK 0
+        LDR d1, [x3], 8  // a0
+        INS v3.d[1], x10
+        FMLA v20.4s, v2.4s, v0.s[0]
+        PRFM PLDL1KEEP, [x5, 192]
+
+        # BLOCK 1
+        INS v4.d[1], x16
+        FMLA v21.4s, v3.4s, v0.s[0]
+        PRFM PLDL1KEEP, [x5, 256]
+
+        # BLOCK 2
+        LDR d23, [x5]  // vb0x0123
+        INS v5.d[1], x17
+        LDR x7, [x5, 8]
+        PRFM PLDL1KEEP, [x5, 320]
+        FMLA v22.4s, v4.4s, v0.s[0]
+
+        # BLOCK 3
+        LDR d24, [x5, 16]  // vb0x4567
+        INS v6.d[1], x18
+        LDR x10, [x5, 24]
+
+        # BLOCK 4
+        LDR d25, [x5, 32]  // vb0x89AB
+        INS v7.d[1], x9
+        FMLA v20.4s, v5.4s, v0.s[1]
+        LDR x16, [x5, 40]
+
+        # BLOCK 5
+        LDR d17, [x5, 48]  // vb1x0123
+        LDR x17, [x5, 56]
+        FMLA v21.4s, v6.4s, v0.s[1]
+
+        # BLOCK 6
+        LDR d18, [x5, 64]  // vb1x4567
+        LDR x18, [x5, 72]
+        FMLA v22.4s, v7.4s, v0.s[1]
+
+        # BLOCK 7
+        LDR d19, [x5, 80]  // vb1x89AB
+        INS v23.d[1], x7  // v23 was loaded in block 2
+        LDR x9, [x5, 88]
+
+        # Second group of 6 fma.
+        # A is loaded for 1st group into v0
+
+        # BLOCK 0
+        LDR d0, [x3], 8  // a0
+        INS v24.d[1], x10
+        FMLA v20.4s, v23.4s, v1.s[0]
+
+        # BLOCK 1
+        INS v25.d[1], x16
+        FMLA v21.4s, v24.4s, v1.s[0]
+
+        # BLOCK 2
+        LDR d2, [x5, 96]  // vb0x0123
+        INS v17.d[1], x17
+        LDR x7, [x5, 104]
+        FMLA v22.4s, v25.4s, v1.s[0]
+
+        # BLOCK 3
+        LDR d3, [x5, 112]  // vb0x4567
+        INS v18.d[1], x18
+        LDR x10, [x5, 120]
+
+        # BLOCK 4
+        LDR d4, [x5, 128]  // vb0x89AB
+        INS v19.d[1], x9
+        FMLA v20.4s, v17.4s, v1.s[1]
+        LDR x16, [x5, 136]
+
+        # BLOCK 5
+        LDR d5, [x5, 144]  // vb1x0123
+        LDR x17, [x5, 152]
+        FMLA v21.4s, v18.4s, v1.s[1]
+
+        # BLOCK 6
+        LDR d6, [x5, 160]  // vb1x4567
+        LDR x18, [x5, 168]
+        SUBS x0, x0, 16
+        FMLA v22.4s, v19.4s, v1.s[1]
+
+        # BLOCK 7
+        LDR d7, [x5, 176]  // vb1x89AB
+        INS v2.d[1], x7
+        LDR x9, [x5, 184]
+        ADD x5, x5, 192
+        B.HS 1b
+
+        # Epilogue
+        # First block same as main loop. Second block has no loads.
+2:
+        # BLOCK 0
+        LDR d1, [x3], 8  // a0
+        INS v3.d[1], x10
+        FMLA v20.4s, v2.4s, v0.s[0]
+        PRFM PLDL1KEEP, [x5, 192]
+
+        # BLOCK 1
+        INS v4.d[1], x16
+        FMLA v21.4s, v3.4s, v0.s[0]
+        PRFM PLDL1KEEP, [x5, 256]
+
+        # BLOCK 2
+        LDR d23, [x5]  // vb0x0123
+        INS v5.d[1], x17
+        LDR x7, [x5, 8]
+        PRFM PLDL1KEEP, [x5, 320]
+        FMLA v22.4s, v4.4s, v0.s[0]
+
+        # BLOCK 3
+        LDR d24, [x5, 16]  // vb0x4567
+        INS v6.d[1], x18
+        LDR x10, [x5, 24]
+
+        # BLOCK 4
+        LDR d25, [x5, 32]  // vb0x89AB
+        INS v7.d[1], x9
+        FMLA v20.4s, v5.4s, v0.s[1]
+        LDR x16, [x5, 40]
+
+        # BLOCK 5
+        LDR d17, [x5, 48]  // vb1x0123
+        LDR x17, [x5, 56]
+        FMLA v21.4s, v6.4s, v0.s[1]
+
+        # BLOCK 6
+        LDR d18, [x5, 64]  // vb1x4567
+        LDR x18, [x5, 72]
+        FMLA v22.4s, v7.4s, v0.s[1]
+
+        # BLOCK 7
+        LDR d19, [x5, 80]  // vb1x89AB
+        INS v23.d[1], x7  // v23 was loaded in block 2
+        LDR x9, [x5, 88]
+        ADD x5, x5, 96
+
+        # Second group of 6 fma. 8 blocks of 4 cycles.
+        # Epilogue version does no loads
+
+        # BLOCK 0
+        INS v24.d[1], x10
+        FMLA v20.4s, v23.4s, v1.s[0]
+
+        # BLOCK 1
+        INS v25.d[1], x16
+        FMLA v21.4s, v24.4s, v1.s[0]
+
+        # BLOCK 2
+        INS v17.d[1], x17
+        FMLA v22.4s, v25.4s, v1.s[0]
+
+        # BLOCK 3
+        INS v18.d[1], x18
+
+        # BLOCK 4
+        INS v19.d[1], x9
+        FMLA v20.4s, v17.4s, v1.s[1]
+
+        # BLOCK 5
+        FMLA v21.4s, v18.4s, v1.s[1]
+
+        # BLOCK 6
+        FMLA v22.4s, v19.4s, v1.s[1]
+
+        # BLOCK 7
+3:
+        # Is there a remainder? - 2 floats of A (8 bytes)
+        # (x0 = k; bit tests also handle the kc < 16 entry from above, where
+        #  x0 wrapped negative but keeps the low remainder bits intact)
+        TBNZ x0, 3, 5f
+        # Is there a remainder? - 1 float of A (4 bytes)
+        TBNZ x0, 2, 6f
+
+4:
+        # Clamp
+        FMIN v20.4s, v20.4s, v30.4s
+        FMIN v21.4s, v21.4s, v30.4s
+        FMIN v22.4s, v22.4s, v30.4s
+        FMAX v20.4s, v20.4s, v31.4s
+        FMAX v21.4s, v21.4s, v31.4s
+        FMAX v22.4s, v22.4s, v31.4s
+
+        # Store full 1 x 12
+        CMP x1, 12
+        B.LO 7f
+
+        ST1 {v20.16b, v21.16b, v22.16b}, [x6], x14
+
+        SUB x3, x3, x2  // a0 -= kc
+
+        SUBS x1, x1, 12
+        B.HI 0b
+
+        RET
+
+5:
+        # Remainder - 2 floats of A (8 bytes)
+        # Read first block of 4 A.
+        LDR d0, [x3], 8  // a0
+        LD1 {v2.16b, v3.16b, v4.16b}, [x5], 48
+        LD1 {v5.16b, v6.16b, v7.16b}, [x5], 48
+
+        # First block of 3 B
+        FMLA v20.4s, v2.4s, v0.s[0]
+        FMLA v21.4s, v3.4s, v0.s[0]
+        FMLA v22.4s, v4.4s, v0.s[0]
+
+        # Second block of 3 B
+        FMLA v20.4s, v5.4s, v0.s[1]
+        FMLA v21.4s, v6.4s, v0.s[1]
+        FMLA v22.4s, v7.4s, v0.s[1]
+
+        TBZ x0, 2, 4b
+6:
+        # Remainder - 1 float of A (4 bytes)
+        LDR s0, [x3], 4  // a0
+        LD1 {v2.16b, v3.16b, v4.16b}, [x5], 48
+
+        FMLA v20.4s, v2.4s, v0.s[0]
+        FMLA v21.4s, v3.4s, v0.s[0]
+        FMLA v22.4s, v4.4s, v0.s[0]
+        B 4b
+
+7:
+        # Store odd channels (nc < 12), in descending power-of-2 chunks,
+        # shifting surviving lanes down after each store
+        TBZ x1, 3, 8f
+        STP q20, q21, [x6]
+        ADD x6, x6, 32
+        MOV v20.16b, v22.16b
+
+8:
+        TBZ x1, 2, 9f
+        STR q20, [x6], 16
+        MOV v20.16b, v21.16b
+
+9:
+        TBZ x1, 1, 10f
+        STR d20, [x6], 8
+        DUP d20, v20.d[1]
+
+10:
+        TBZ x1, 0, 11f
+        STR s20, [x6]
+11:
+        RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_1x12__aarch64_neonfma_cortex_a53
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/1x4-scalar.c b/src/f32-gemminc/1x4-scalar.c
new file mode 100644
index 0000000..16862f5
--- /dev/null
+++ b/src/f32-gemminc/1x4-scalar.c
@@ -0,0 +1,103 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+// f32 GEMMINC scalar micro-kernel: computes a 1x4 tile of C = A * B without
+// SIMD, starting from externally supplied accumulators (acc) and clamping
+// the output to [params->scalar.min, params->scalar.max].
+void xnn_f32_gemminc_ukernel_1x4__scalar(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+
+  // Outer loop: one iteration per block of up to 4 output columns.
+  do {
+    // Initial accumulators for this column block.
+    float vacc00 = acc[0];
+    float vacc01 = acc[1];
+    float vacc02 = acc[2];
+    float vacc03 = acc[3];
+    acc += 4;
+
+    // K loop: one A element against one 4-wide row of packed B per step.
+    size_t k = kc;
+    do {
+      const float va0 = *a0++;
+
+      const float vb0 = w[0];
+      const float vb1 = w[1];
+      const float vb2 = w[2];
+      const float vb3 = w[3];
+      w += 4;
+
+      vacc00 += va0 * vb0;
+      vacc01 += va0 * vb1;
+      vacc02 += va0 * vb2;
+      vacc03 += va0 * vb3;
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp to the requested output range.
+    const float vmin = params->scalar.min;
+    vacc00 = math_max_f32(vacc00, vmin);
+    vacc01 = math_max_f32(vacc01, vmin);
+    vacc02 = math_max_f32(vacc02, vmin);
+    vacc03 = math_max_f32(vacc03, vmin);
+
+    const float vmax = params->scalar.max;
+    vacc00 = math_min_f32(vacc00, vmax);
+    vacc01 = math_min_f32(vacc01, vmax);
+    vacc02 = math_min_f32(vacc02, vmax);
+    vacc03 = math_min_f32(vacc03, vmax);
+
+    if XNN_LIKELY(nc >= 4) {
+      c0[0] = vacc00;
+      c0[1] = vacc01;
+      c0[2] = vacc02;
+      c0[3] = vacc03;
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind A for the next column block (const float*, matching the
+      // pointer's declared type and the other kernels in this family).
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 4;
+    } else {
+      // Partial store of the remaining nc (< 4) columns.
+      if (nc & 2) {
+        c0[0] = vacc00;
+        c0[1] = vacc01;
+        vacc00 = vacc02;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        c0[0] = vacc00;
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/1x8-aarch64-neonfma-cortex-a57.S b/src/f32-gemminc/1x8-aarch64-neonfma-cortex-a57.S
new file mode 100644
index 0000000..22d5420
--- /dev/null
+++ b/src/f32-gemminc/1x8-aarch64-neonfma-cortex-a57.S
@@ -0,0 +1,222 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/1x8-aarch64-neonfma-cortex-a57.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_1x8__aarch64_neonfma_cortex_a57(
+# size_t mr, (x0) - unused. mr = 1
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, (x4) - unused
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, (x7) - unused
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+
+# C pointers
+# x6 c0
+
+# Clamp v4 v5
+
+# A57 based on A75 but with PRFM removed from main loop
+
+# Computes one 1x8 tile of C = clamp(acc + A * B); accumulators are
+# initialized from the acc argument ("gemminc") rather than zero.
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_1x8__aarch64_neonfma_cortex_a57
+
+        # Load cn_stride, acc
+        LDP x14, x15, [sp]
+        # Load params pointer
+        LDR x8, [sp, 16]
+
+        # Load clamping_params values
+        # (v4 is applied with FMIN below, v5 with FMAX)
+        LD2R {v4.4s, v5.4s}, [x8]
+
+        # Outer loop: one iteration per 8 output columns (nc in x1).
+0:
+        # Load initial accumulators
+        LDP q16, q17, [x15], 32
+
+        MOVI v18.4s, 0  // second set of C for pipelining FMLA
+        MOVI v19.4s, 0
+
+        # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+        SUBS x0, x2, 32  // k = kc - 32
+
+        B.LO 3f
+
+        # 16 prologue
+        # Read first block of 4 A and B.
+        LDP q20, q21, [x5], 32
+        LDP q22, q23, [x5], 32
+        LDP q24, q25, [x5], 32
+        LDP q26, q27, [x5], 32
+        LDR q0, [x3], 16
+
+        # Is there at least 32. yes do main loop
+        SUBS x0, x0, 32
+        B.LO 2f
+
+        # Main loop - 8 floats of A (32 bytes)
+1:
+        # First block of 4. FMA for first 4, loads for 2nd block of 4.
+        FMLA v16.4s, v20.4s, v0.s[0]
+        LDR q1, [x3], 16
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q20, q21, [x5], 32
+        FMLA v18.4s, v22.4s, v0.s[1]
+        FMLA v19.4s, v23.4s, v0.s[1]
+        LDP q22, q23, [x5], 32
+        FMLA v16.4s, v24.4s, v0.s[2]
+        FMLA v17.4s, v25.4s, v0.s[2]
+        LDP q24, q25, [x5], 32
+        FMLA v18.4s, v26.4s, v0.s[3]
+        FMLA v19.4s, v27.4s, v0.s[3]
+        LDP q26, q27, [x5], 32
+
+        # Second block of 4. FMA for second 4, loads for 1st block of 4.
+        FMLA v16.4s, v20.4s, v1.s[0]
+        LDR q0, [x3], 16
+        FMLA v17.4s, v21.4s, v1.s[0]
+        LDP q20, q21, [x5], 32
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        LDP q22, q23, [x5], 32
+        FMLA v16.4s, v24.4s, v1.s[2]
+        FMLA v17.4s, v25.4s, v1.s[2]
+        LDP q24, q25, [x5], 32
+        FMLA v18.4s, v26.4s, v1.s[3]
+        FMLA v19.4s, v27.4s, v1.s[3]
+        SUBS x0, x0, 32
+        LDP q26, q27, [x5], 32
+        B.HS 1b
+
+2:
+        # Epilogue
+
+        # First block of 4. FMA for first 4, loads for 2nd block of 4.
+        FMLA v16.4s, v20.4s, v0.s[0]
+        LDR q1, [x3], 16
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q20, q21, [x5], 32
+        FMLA v18.4s, v22.4s, v0.s[1]
+        FMLA v19.4s, v23.4s, v0.s[1]
+        LDP q22, q23, [x5], 32
+        FMLA v16.4s, v24.4s, v0.s[2]
+        FMLA v17.4s, v25.4s, v0.s[2]
+        LDP q24, q25, [x5], 32
+        FMLA v18.4s, v26.4s, v0.s[3]
+        FMLA v19.4s, v27.4s, v0.s[3]
+        LDP q26, q27, [x5], 32
+
+        # Second block of 4. no loads
+        FMLA v16.4s, v20.4s, v1.s[0]
+        FMLA v17.4s, v21.4s, v1.s[0]
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        FMLA v16.4s, v24.4s, v1.s[2]
+        FMLA v17.4s, v25.4s, v1.s[2]
+        FMLA v18.4s, v26.4s, v1.s[3]
+        FMLA v19.4s, v27.4s, v1.s[3]
+
+3:
+        # Is there a remainder?- 4 floats of A (16 bytes)
+        TBNZ x0, 4, 5f
+        # Is there a remainder?- 2 floats of A (8 bytes)
+        TBNZ x0, 3, 6f
+        # Is there a remainder?- 1 floats of A (4 bytes)
+        TBNZ x0, 2, 8f
+
+        # Combine the two pipelined accumulator sets, clamp, and store.
+4:
+        FADD v16.4s, v16.4s, v18.4s
+        FADD v17.4s, v17.4s, v19.4s
+
+        # Clamp
+        FMIN v16.4s, v16.4s, v4.4s
+        FMIN v17.4s, v17.4s, v4.4s
+        FMAX v16.4s, v16.4s, v5.4s
+        FMAX v17.4s, v17.4s, v5.4s
+
+        # Store full 1 x 8
+        CMP x1, 8
+        B.LO 9f
+
+        STP q16, q17, [x6]
+        ADD x6, x6, x14
+
+        SUB x3, x3, x2  // a0 -= kc
+
+        SUBS x1, x1, 8
+        B.HI 0b
+
+        RET
+
+5:
+        # Remainder- 4 floats of A (16 bytes)
+        LDP q20, q21, [x5], 32
+        LDR q0, [x3], 16
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q22, q23, [x5], 32
+        LDP q24, q25, [x5], 32
+        LDP q26, q27, [x5], 32
+        FMLA v18.4s, v22.4s, v0.s[1]
+        FMLA v19.4s, v23.4s, v0.s[1]
+        FMLA v16.4s, v24.4s, v0.s[2]
+        FMLA v17.4s, v25.4s, v0.s[2]
+        FMLA v18.4s, v26.4s, v0.s[3]
+        FMLA v19.4s, v27.4s, v0.s[3]
+
+        TBZ x0, 3, 7f
+6:
+        # Remainder- 2 floats of A (8 bytes)
+        LDP q20, q21, [x5], 32
+        LDR d0, [x3], 8
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q22, q23, [x5], 32
+        FMLA v18.4s, v22.4s, v0.s[1]
+        FMLA v19.4s, v23.4s, v0.s[1]
+7:
+        TBZ x0, 2, 4b
+8:
+        # Remainder- 1 float of A (4 bytes)
+        LDP q20, q21, [x5], 32
+        LDR s0, [x3], 4
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        B 4b
+
+        # Store odd channels
+        # (4, then 2, then 1 leftover columns, shifting the live lanes down)
+9:
+        TBZ x1, 2, 10f
+        STR q16, [x6], 16
+        MOV v16.16b, v17.16b
+
+10:
+        TBZ x1, 1, 11f
+        STR d16, [x6], 8
+        DUP d16, v16.d[1]
+
+11:
+        TBZ x1, 0, 12f
+        STR s16, [x6]
+12:
+        RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_1x8__aarch64_neonfma_cortex_a57
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/1x8-aarch64-neonfma-cortex-a75.S b/src/f32-gemminc/1x8-aarch64-neonfma-cortex-a75.S
new file mode 100644
index 0000000..47fbb84
--- /dev/null
+++ b/src/f32-gemminc/1x8-aarch64-neonfma-cortex-a75.S
@@ -0,0 +1,226 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/1x8-aarch64-neonfma-cortex-a75.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_1x8__aarch64_neonfma_cortex_a75(
+# size_t mr, (x0) - unused. mr = 1
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, (x4) - unused
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, (x7) - unused
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+
+# C pointers
+# x6 c0
+
+# Clamp v4 v5
+
+# Computes one 1x8 tile of C = clamp(acc + A * B); accumulators are
+# initialized from the acc argument ("gemminc") rather than zero.
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_1x8__aarch64_neonfma_cortex_a75
+
+        # Load cn_stride, acc
+        LDP x14, x15, [sp]
+        # Load params pointer
+        LDR x8, [sp, 16]
+
+        # Load clamping_params values
+        # (v4 is applied with FMIN below, v5 with FMAX)
+        LD2R {v4.4s, v5.4s}, [x8]
+
+        # Outer loop: one iteration per 8 output columns (nc in x1).
+0:
+        # Load initial accumulators
+        LDP q16, q17, [x15], 32
+
+        MOVI v18.4s, 0  // second set of C for pipelining FMLA
+        MOVI v19.4s, 0
+
+        # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+        SUBS x0, x2, 32  // k = kc - 32
+
+        B.LO 3f
+
+        # 16 prologue
+        # Read first block of 4 A and B.
+        LDP q20, q21, [x5], 32
+        LDP q22, q23, [x5], 32
+        LDP q24, q25, [x5], 32
+        LDP q26, q27, [x5], 32
+        LDR q0, [x3], 16
+
+        # Is there at least 32. yes do main loop
+        SUBS x0, x0, 32
+        B.LO 2f
+
+        # Main loop - 8 floats of A (32 bytes)
+1:
+        # First block of 4. FMA for first 4, loads for 2nd block of 4.
+        FMLA v16.4s, v20.4s, v0.s[0]
+        LDR q1, [x3], 16
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q20, q21, [x5], 32
+        FMLA v18.4s, v22.4s, v0.s[1]
+        FMLA v19.4s, v23.4s, v0.s[1]
+        LDP q22, q23, [x5], 32
+        FMLA v16.4s, v24.4s, v0.s[2]
+        FMLA v17.4s, v25.4s, v0.s[2]
+        LDP q24, q25, [x5], 32
+        PRFM PLDL1KEEP, [x5, 128]
+        FMLA v18.4s, v26.4s, v0.s[3]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v19.4s, v27.4s, v0.s[3]
+        LDP q26, q27, [x5], 32
+
+        # Second block of 4. FMA for second 4, loads for 1st block of 4.
+        FMLA v16.4s, v20.4s, v1.s[0]
+        LDR q0, [x3], 16
+        FMLA v17.4s, v21.4s, v1.s[0]
+        LDP q20, q21, [x5], 32
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        LDP q22, q23, [x5], 32
+        FMLA v16.4s, v24.4s, v1.s[2]
+        FMLA v17.4s, v25.4s, v1.s[2]
+        LDP q24, q25, [x5], 32
+        PRFM PLDL1KEEP, [x5, 128]
+        FMLA v18.4s, v26.4s, v1.s[3]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v19.4s, v27.4s, v1.s[3]
+        SUBS x0, x0, 32
+        LDP q26, q27, [x5], 32
+        B.HS 1b
+
+2:
+        # Epilogue
+
+        # First block of 4. FMA for first 4, loads for 2nd block of 4.
+        FMLA v16.4s, v20.4s, v0.s[0]
+        LDR q1, [x3], 16
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q20, q21, [x5], 32
+        FMLA v18.4s, v22.4s, v0.s[1]
+        FMLA v19.4s, v23.4s, v0.s[1]
+        LDP q22, q23, [x5], 32
+        FMLA v16.4s, v24.4s, v0.s[2]
+        FMLA v17.4s, v25.4s, v0.s[2]
+        LDP q24, q25, [x5], 32
+        PRFM PLDL1KEEP, [x5, 128]
+        FMLA v18.4s, v26.4s, v0.s[3]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v19.4s, v27.4s, v0.s[3]
+        LDP q26, q27, [x5], 32
+
+        # Second block of 4. no loads
+        FMLA v16.4s, v20.4s, v1.s[0]
+        FMLA v17.4s, v21.4s, v1.s[0]
+        FMLA v18.4s, v22.4s, v1.s[1]
+        FMLA v19.4s, v23.4s, v1.s[1]
+        FMLA v16.4s, v24.4s, v1.s[2]
+        FMLA v17.4s, v25.4s, v1.s[2]
+        FMLA v18.4s, v26.4s, v1.s[3]
+        FMLA v19.4s, v27.4s, v1.s[3]
+
+3:
+        # Is there a remainder?- 4 floats of A (16 bytes)
+        TBNZ x0, 4, 5f
+        # Is there a remainder?- 2 floats of A (8 bytes)
+        TBNZ x0, 3, 6f
+        # Is there a remainder?- 1 floats of A (4 bytes)
+        TBNZ x0, 2, 8f
+
+        # Combine the two pipelined accumulator sets, clamp, and store.
+4:
+        FADD v16.4s, v16.4s, v18.4s
+        FADD v17.4s, v17.4s, v19.4s
+
+        # Clamp
+        FMIN v16.4s, v16.4s, v4.4s
+        FMIN v17.4s, v17.4s, v4.4s
+        FMAX v16.4s, v16.4s, v5.4s
+        FMAX v17.4s, v17.4s, v5.4s
+
+        # Store full 1 x 8
+        CMP x1, 8
+        B.LO 9f
+
+        STP q16, q17, [x6]
+        ADD x6, x6, x14
+
+        SUB x3, x3, x2  // a0 -= kc
+
+        SUBS x1, x1, 8
+        B.HI 0b
+
+        RET
+
+5:
+        # Remainder- 4 floats of A (16 bytes)
+        LDP q20, q21, [x5], 32
+        LDR q0, [x3], 16
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q22, q23, [x5], 32
+        LDP q24, q25, [x5], 32
+        LDP q26, q27, [x5], 32
+        FMLA v18.4s, v22.4s, v0.s[1]
+        FMLA v19.4s, v23.4s, v0.s[1]
+        FMLA v16.4s, v24.4s, v0.s[2]
+        FMLA v17.4s, v25.4s, v0.s[2]
+        FMLA v18.4s, v26.4s, v0.s[3]
+        FMLA v19.4s, v27.4s, v0.s[3]
+
+        TBZ x0, 3, 7f
+6:
+        # Remainder- 2 floats of A (8 bytes)
+        LDP q20, q21, [x5], 32
+        LDR d0, [x3], 8
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        LDP q22, q23, [x5], 32
+        FMLA v18.4s, v22.4s, v0.s[1]
+        FMLA v19.4s, v23.4s, v0.s[1]
+7:
+        TBZ x0, 2, 4b
+8:
+        # Remainder- 1 float of A (4 bytes)
+        LDP q20, q21, [x5], 32
+        LDR s0, [x3], 4
+        FMLA v16.4s, v20.4s, v0.s[0]
+        FMLA v17.4s, v21.4s, v0.s[0]
+        B 4b
+
+        # Store odd channels
+        # (4, then 2, then 1 leftover columns, shifting the live lanes down)
+9:
+        TBZ x1, 2, 10f
+        STR q16, [x6], 16
+        MOV v16.16b, v17.16b
+
+10:
+        TBZ x1, 1, 11f
+        STR d16, [x6], 8
+        DUP d16, v16.d[1]
+
+11:
+        TBZ x1, 0, 12f
+        STR s16, [x6]
+12:
+        RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_1x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/1x8-neon-ld64.c b/src/f32-gemminc/1x8-neon-ld64.c
new file mode 100644
index 0000000..a5d04dd
--- /dev/null
+++ b/src/f32-gemminc/1x8-neon-ld64.c
@@ -0,0 +1,107 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// 1x8 f32 GEMM with accumulator initialization ("gemminc") for NEON,
+// ld64 variant: A is consumed 2 floats (64 bits) at a time.
+// Computes C = clamp(acc + A * B); mr must be 1.
+void xnn_f32_gemminc_ukernel_1x8__neon_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+
+  do {
+    // Start from the caller-provided accumulators.
+    float32x4_t vacc0x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(acc); acc += 4;
+
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+    }
+    if XNN_UNLIKELY(k != 0) {
+      // Odd k: one trailing scalar of A.
+      const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+      vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+    }
+    // Fix: '&params' had been garbled to '¶ms' by a broken encoding pass
+    // ('&para' -> U+00B6 pilcrow); restore the address-of expression.
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);  // rewind A for the next tile
+
+      nc -= 8;
+
+    } else {
+      // Store the 1-7 leftover columns.
+      if (nc & 4) {
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/1x8-neonfma-ld64.c b/src/f32-gemminc/1x8-neonfma-ld64.c
new file mode 100644
index 0000000..d67a419
--- /dev/null
+++ b/src/f32-gemminc/1x8-neonfma-ld64.c
@@ -0,0 +1,119 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// 1x8 f32 GEMM with accumulator initialization ("gemminc") for NEON with FMA,
+// ld64 variant: A is consumed 2 floats (64 bits) at a time.
+// Computes C = clamp(acc + A * B); mr must be 1.
+void xnn_f32_gemminc_ukernel_1x8__neonfma_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+
+  do {
+    // Start from the caller-provided accumulators.
+    float32x4_t vacc0x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(acc); acc += 4;
+
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+      // AArch64 has FMA-by-lane; AArch32 must splat the lane first.
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+      #else
+        const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+      #endif
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+      #else
+        const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+      #endif
+    }
+    if XNN_UNLIKELY(k != 0) {
+      // Odd k: one trailing scalar of A.
+      const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+      vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+    }
+    // Fix: '&params' had been garbled to '¶ms' by a broken encoding pass
+    // ('&para' -> U+00B6 pilcrow); restore the address-of expression.
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);  // rewind A for the next tile
+
+      nc -= 8;
+
+    } else {
+      // Store the 1-7 leftover columns.
+      if (nc & 4) {
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/1x8-psimd-loadsplat.c b/src/f32-gemminc/1x8-psimd-loadsplat.c
new file mode 100644
index 0000000..27870c5
--- /dev/null
+++ b/src/f32-gemminc/1x8-psimd-loadsplat.c
@@ -0,0 +1,101 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-loadsplat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+// 1x8 f32 GEMM with accumulator initialization ("gemminc") for psimd,
+// loadsplat variant: each scalar of A is broadcast on load.
+// Computes C = clamp(acc + A * B); mr must be 1.
+void xnn_f32_gemminc_ukernel_1x8__psimd_loadsplat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+
+  do {
+    // Start from the caller-provided accumulators.
+    psimd_f32 vacc0x0123 = psimd_load_f32(acc + 0);
+    psimd_f32 vacc0x4567 = psimd_load_f32(acc + 4);
+    acc += 8;
+
+    size_t k = kc;
+    do {
+      const psimd_f32 va0 = psimd_load_splat_f32(a0);
+      a0 += 1;
+
+      const psimd_f32 vb0123 = psimd_load_f32(w);
+      const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+      w += 8;
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Fix: '&params' had been garbled to '¶ms' by a broken encoding pass
+    // ('&para' -> U+00B6 pilcrow); restore the address-of expression.
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);  // rewind A for the next tile
+
+      nc -= 8;
+    } else {
+      // Store the 1-7 leftover columns.
+      if (nc & 4) {
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/1x8-psimd-splat.c b/src/f32-gemminc/1x8-psimd-splat.c
new file mode 100644
index 0000000..23a16f9
--- /dev/null
+++ b/src/f32-gemminc/1x8-psimd-splat.c
@@ -0,0 +1,139 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-splat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+// 1x8 f32 GEMM with accumulator initialization ("gemminc") for psimd,
+// splat variant: 4 scalars of A are loaded at once and broadcast lane-by-lane.
+// Computes C = clamp(acc + A * B); mr must be 1.
+void xnn_f32_gemminc_ukernel_1x8__psimd_splat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+
+  do {
+    // Start from the caller-provided accumulators.
+    psimd_f32 vacc0x0123 = psimd_load_f32(acc + 0);
+    psimd_f32 vacc0x4567 = psimd_load_f32(acc + 4);
+    acc += 8;
+
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {
+      const psimd_f32 va0 = psimd_load_f32(a0);
+      a0 += 4;
+
+      const psimd_f32 va0c0 = psimd_splat0_f32(va0);
+
+      const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+      const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c0, vb0123c0);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c0, vb4567c0);
+      const psimd_f32 va0c1 = psimd_splat1_f32(va0);
+
+      const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+      const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c1, vb0123c1);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c1, vb4567c1);
+      const psimd_f32 va0c2 = psimd_splat2_f32(va0);
+
+      const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+      const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c2, vb0123c2);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c2, vb4567c2);
+      const psimd_f32 va0c3 = psimd_splat3_f32(va0);
+
+      const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+      const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c3, vb0123c3);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c3, vb4567c3);
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    if XNN_UNLIKELY(k != 0) {
+      // 1-3 trailing scalars of A.
+      do {
+        const psimd_f32 va0 = psimd_load_splat_f32(a0);
+        a0 += 1;
+
+        const psimd_f32 vb0123 = psimd_load_f32(w);
+        const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+        w += 8;
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Fix: '&params' had been garbled to '¶ms' by a broken encoding pass
+    // ('&para' -> U+00B6 pilcrow); restore the address-of expression.
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);  // rewind A for the next tile
+
+      nc -= 8;
+    } else {
+      // Store the 1-7 leftover columns.
+      if (nc & 4) {
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/1x8-sse-dup.c b/src/f32-gemminc/1x8-sse-dup.c
new file mode 100644
index 0000000..5175c11
--- /dev/null
+++ b/src/f32-gemminc/1x8-sse-dup.c
@@ -0,0 +1,143 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/sse-dup.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// 1x8 f32 GEMM with accumulator initialization ("gemminc") for SSE,
+// "dup" variant: 4 scalars of A are loaded at once and broadcast
+// lane-by-lane via _mm_shuffle_ps.  Computes C = clamp(acc + A * B);
+// mr must be 1.
+void xnn_f32_gemminc_ukernel_1x8__sse_dup(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+
+  do {
+    // Start from the caller-provided accumulators.  Aligned loads here:
+    // acc and w are assumed 16-byte aligned (presumably guaranteed by the
+    // packing code — TODO confirm); a and c use unaligned accesses below.
+    __m128 vacc0x0123 = _mm_load_ps(acc + 0);
+    __m128 vacc0x4567 = _mm_load_ps(acc + 4);
+    acc += 8;
+
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {
+      const __m128 va0 = _mm_loadu_ps(a0);
+      a0 += 4;
+
+
+      const __m128 va0c0000 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 0, 0, 0));
+
+      const __m128 vb0123c0 = _mm_load_ps(w + 0);
+      const __m128 vb4567c0 = _mm_load_ps(w + 4);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c0000, vb0123c0));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c0000, vb4567c0));
+
+      const __m128 va0c1111 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(1, 1, 1, 1));
+
+      const __m128 vb0123c1 = _mm_load_ps(w + 8);
+      const __m128 vb4567c1 = _mm_load_ps(w + 12);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c1111, vb0123c1));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c1111, vb4567c1));
+
+      const __m128 va0c2222 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(2, 2, 2, 2));
+
+      const __m128 vb0123c2 = _mm_load_ps(w + 16);
+      const __m128 vb4567c2 = _mm_load_ps(w + 20);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c2222, vb0123c2));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c2222, vb4567c2));
+
+      const __m128 va0c3333 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(3, 3, 3, 3));
+
+      const __m128 vb0123c3 = _mm_load_ps(w + 24);
+      const __m128 vb4567c3 = _mm_load_ps(w + 28);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c3333, vb0123c3));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c3333, vb4567c3));
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    if XNN_UNLIKELY(k != 0) {
+      // 1-3 trailing scalars of A.
+      do {
+        const __m128 va0 = _mm_load1_ps(a0);
+        a0 += 1;
+
+        const __m128 vb0123 = _mm_load_ps(w);
+        const __m128 vb4567 = _mm_load_ps(w + 4);
+        w += 8;
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp the results to [min, max].
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind A to the row start for the next column tile.
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      // Store the 1-7 leftover columns, shifting the live lanes down.
+      if (nc & 4) {
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/1x8-sse-load1.c b/src/f32-gemminc/1x8-sse-load1.c
new file mode 100644
index 0000000..0575946
--- /dev/null
+++ b/src/f32-gemminc/1x8-sse-load1.c
@@ -0,0 +1,101 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/sse-load1.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// 1x8 f32 GEMM with accumulator initialization ("gemminc") for SSE,
+// "load1" variant: each scalar of A is broadcast on load with _mm_load1_ps.
+// Computes C = clamp(acc + A * B); mr must be 1.
+void xnn_f32_gemminc_ukernel_1x8__sse_load1(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+
+  do {
+    // Start from the caller-provided accumulators.  Aligned loads: acc and w
+    // are assumed 16-byte aligned (presumably guaranteed by the packing
+    // code — TODO confirm); c uses unaligned stores below.
+    __m128 vacc0x0123 = _mm_load_ps(acc + 0);
+    __m128 vacc0x4567 = _mm_load_ps(acc + 4);
+    acc += 8;
+
+    size_t k = kc;
+    do {
+      const __m128 va0 = _mm_load1_ps(a0);
+      a0 += 1;
+
+      const __m128 vb0123 = _mm_load_ps(w);
+      const __m128 vb4567 = _mm_load_ps(w + 4);
+      w += 8;
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp the results to [min, max].
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind A to the row start for the next column tile.
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      // Store the 1-7 leftover columns, shifting the live lanes down.
+      if (nc & 4) {
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/1x8-sse.c b/src/f32-gemminc/1x8-sse.c
new file mode 100644
index 0000000..0bde2f9
--- /dev/null
+++ b/src/f32-gemminc/1x8-sse.c
@@ -0,0 +1,95 @@
+/*
+ * Auto-generated file. Do not edit!
+ * Template: src/f32-gemm/sse.c.in
+ * Generator: tools/xngen
+ */
+
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// 1x8 f32 GEMM with accumulator initialization ("gemminc") for SSE.
+// Each scalar of A is broadcast on load with _mm_load1_ps.
+// Computes C = clamp(acc + A * B); mr must be 1.
+void xnn_f32_gemminc_ukernel_1x8__sse(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float* restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  // Pointer checks added for consistency with the sibling gemminc kernels
+  // (e.g. 1x8-sse-load1.c), which validate all pointer arguments.
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+
+  do {
+    // Start from the caller-provided accumulators.
+    __m128 vacc0x0123 = _mm_load_ps(acc + 0);
+    __m128 vacc0x4567 = _mm_load_ps(acc + 4);
+    acc += 8;
+
+    size_t k = kc;
+    do {
+      const __m128 va0 = _mm_load1_ps(a0);
+      a0 += 1;
+
+      const __m128 vb0123 = _mm_load_ps(w);
+      const __m128 vb4567 = _mm_load_ps(w + 4);
+      w += 8;
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp the results to [min, max].
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind A to the row start for the next column tile.
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      // Store the 1-7 leftover columns, shifting the live lanes down.
+      if (nc & 4) {
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/1x8s4-psimd.c b/src/f32-gemminc/1x8s4-psimd.c
new file mode 100644
index 0000000..851643f
--- /dev/null
+++ b/src/f32-gemminc/1x8s4-psimd.c
@@ -0,0 +1,142 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-s4.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+// GEMM microkernel with initial accumulators: computes a 1x8 tile of
+// C = clamp(A * B + acc) using the s4 ("shift by 4") schedule: 4 A elements
+// are loaded as a vector and rotated between the 4 sub-steps instead of
+// being broadcast one at a time. psimd (portable SIMD) variant.
+void xnn_f32_gemminc_ukernel_1x8s4__psimd(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+
+  do {
+    psimd_f32 vacc0x0123 = psimd_load_f32(acc + 0);
+    psimd_f32 vacc0x4567 = psimd_load_f32(acc + 4);
+    acc += 8;
+
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {
+      psimd_f32 va0 = psimd_load_f32(a0);
+      a0 += 4;
+
+
+      const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+      const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c0);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c0);
+
+      // Rotate A left by one lane for the next sub-step.
+      va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+
+      const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+      const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c1);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c1);
+
+      va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+
+      const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+      const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c2);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c2);
+
+      va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+
+      const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+      const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c3);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c3);
+
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    if XNN_UNLIKELY(k != 0) {
+      // Remainder: process the last 1-3 A elements one at a time.
+      do {
+        const psimd_f32 va0 = psimd_load_splat_f32(a0);
+        a0 += 1;
+
+        const psimd_f32 vb0123 = psimd_load_f32(w);
+        const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+        w += 8;
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Fixed: "&params" was mis-encoded as the mojibake "¶ms"
+    // (HTML-entity corruption of "&para"), which does not compile.
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/1x8s4-sse.c b/src/f32-gemminc/1x8s4-sse.c
new file mode 100644
index 0000000..80bc536
--- /dev/null
+++ b/src/f32-gemminc/1x8s4-sse.c
@@ -0,0 +1,142 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/sse-shuffle.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// GEMM microkernel with initial accumulators: computes a 1x8 tile of
+// C = clamp(A * B + acc) using the s4 ("shift by 4") schedule: 4 A elements
+// are loaded as one vector and rotated one lane between the 4 sub-steps,
+// matching B that was packed in the corresponding rotated order. SSE variant.
+void xnn_f32_gemminc_ukernel_1x8s4__sse(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+
+  do {
+    // Start from the externally-supplied accumulators (16-byte aligned).
+    __m128 vacc0x0123 = _mm_load_ps(acc + 0);
+    __m128 vacc0x4567 = _mm_load_ps(acc + 4);
+    acc += 8;
+
+    size_t k = kc;
+    // Main loop: 4 reduction elements per iteration, 4 rotated sub-steps.
+    while (k >= 4 * sizeof(float)) {
+      __m128 va0 = _mm_loadu_ps(a0);
+      a0 += 4;
+
+
+      const __m128 vb0123c0 = _mm_load_ps(w + 0);
+      const __m128 vb4567c0 = _mm_load_ps(w + 4);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c0));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c0));
+
+      // Rotate A left by one lane for the next sub-step.
+      va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+
+      const __m128 vb0123c1 = _mm_load_ps(w + 8);
+      const __m128 vb4567c1 = _mm_load_ps(w + 12);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c1));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c1));
+
+      va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+
+      const __m128 vb0123c2 = _mm_load_ps(w + 16);
+      const __m128 vb4567c2 = _mm_load_ps(w + 20);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c2));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c2));
+
+      va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+
+      const __m128 vb0123c3 = _mm_load_ps(w + 24);
+      const __m128 vb4567c3 = _mm_load_ps(w + 28);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c3));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c3));
+
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    if XNN_UNLIKELY(k != 0) {
+      // Remainder: process the last 1-3 A elements one at a time (broadcast).
+      do {
+        const __m128 va0 = _mm_load1_ps(a0);
+        a0 += 1;
+
+        const __m128 vb0123 = _mm_load_ps(w);
+        const __m128 vb4567 = _mm_load_ps(w + 4);
+        w += 8;
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp the results to [min, max] from the output params.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind A to the start of the row for the next 8-column block.
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      // Tail: store the remaining 1-7 columns in 4/2/1 chunks.
+      if (nc & 4) {
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/2x4-scalar.c b/src/f32-gemminc/2x4-scalar.c
new file mode 100644
index 0000000..6506865
--- /dev/null
+++ b/src/f32-gemminc/2x4-scalar.c
@@ -0,0 +1,137 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+// GEMM microkernel with initial accumulators: computes a 2x4 tile of
+// C = clamp(A * B + acc). Portable scalar variant; handles mr < 2 by
+// aliasing the second row onto the first.
+void xnn_f32_gemminc_ukernel_2x4__scalar(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 2);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 2) {
+    // Only one row: alias row 1 onto row 0 (stores overlap harmlessly).
+    a1 = a0;
+    c1 = c0;
+  }
+
+  do {
+    // Start from the externally-supplied accumulators.
+    float vacc00 = acc[0];
+    float vacc01 = acc[1];
+    float vacc02 = acc[2];
+    float vacc03 = acc[3];
+    float vacc10 = acc[4];
+    float vacc11 = acc[5];
+    float vacc12 = acc[6];
+    float vacc13 = acc[7];
+    acc += 8;
+
+    size_t k = kc;
+    do {
+      const float va0 = *a0++;
+      const float va1 = *a1++;
+
+      const float vb0 = w[0];
+      const float vb1 = w[1];
+      const float vb2 = w[2];
+      const float vb3 = w[3];
+      w += 4;
+
+      vacc00 += va0 * vb0;
+      vacc01 += va0 * vb1;
+      vacc02 += va0 * vb2;
+      vacc03 += va0 * vb3;
+      vacc10 += va1 * vb0;
+      vacc11 += va1 * vb1;
+      vacc12 += va1 * vb2;
+      vacc13 += va1 * vb3;
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp the results to [min, max] from the output params.
+    const float vmin = params->scalar.min;
+    vacc00 = math_max_f32(vacc00, vmin);
+    vacc01 = math_max_f32(vacc01, vmin);
+    vacc02 = math_max_f32(vacc02, vmin);
+    vacc03 = math_max_f32(vacc03, vmin);
+    vacc10 = math_max_f32(vacc10, vmin);
+    vacc11 = math_max_f32(vacc11, vmin);
+    vacc12 = math_max_f32(vacc12, vmin);
+    vacc13 = math_max_f32(vacc13, vmin);
+
+    const float vmax = params->scalar.max;
+    vacc00 = math_min_f32(vacc00, vmax);
+    vacc01 = math_min_f32(vacc01, vmax);
+    vacc02 = math_min_f32(vacc02, vmax);
+    vacc03 = math_min_f32(vacc03, vmax);
+    vacc10 = math_min_f32(vacc10, vmax);
+    vacc11 = math_min_f32(vacc11, vmax);
+    vacc12 = math_min_f32(vacc12, vmax);
+    vacc13 = math_min_f32(vacc13, vmax);
+
+    if XNN_LIKELY(nc >= 4) {
+      c1[0] = vacc10;
+      c1[1] = vacc11;
+      c1[2] = vacc12;
+      c1[3] = vacc13;
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      c0[0] = vacc00;
+      c0[1] = vacc01;
+      c0[2] = vacc02;
+      c0[3] = vacc03;
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind A to the start of each row for the next 4-column block.
+      // Cast to (const float*) for consistency with the other microkernels
+      // (was (const void*), which compiled only via implicit conversion).
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 4;
+    } else {
+      // Tail: store the remaining 1-3 columns in 2/1 chunks.
+      if (nc & 2) {
+        c1[0] = vacc10;
+        c1[1] = vacc11;
+        vacc10 = vacc12;
+        c1 += 2;
+        c0[0] = vacc00;
+        c0[1] = vacc01;
+        vacc00 = vacc02;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        c1[0] = vacc10;
+        c0[0] = vacc00;
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x12-aarch64-neonfma-cortex-a53.S b/src/f32-gemminc/4x12-aarch64-neonfma-cortex-a53.S
new file mode 100644
index 0000000..1c4a964
--- /dev/null
+++ b/src/f32-gemminc/4x12-aarch64-neonfma-cortex-a53.S
@@ -0,0 +1,591 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/4x12-aarch64-neonfma-cortex-a53.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_4x12__aarch64_neonfma_cortex_a53(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+# Vector register usage and GPR shadows
+# a0 v0 first set of A
+# a1 v0[1] x13
+# a2 v1
+# a3 v1[1] x8
+# a0 v2 second set of A
+# a1 v2[1] x13
+# a2 v3
+# a3 v3[1] x8
+# B v6 v7 v8 x20 x21 x16 first set of B
+# B v9 v10 v11 x17 x18 x19
+# B v14 v15 v16 x20 x21 x16 second set of B (same x as first set)
+# B v17 v18 v19 x17 x18 x19
+# C v20 v21 v22
+# C v23 v24 v25
+# C v26 v27 v28
+# C v29 v30 v31
+# Clamp v4 v5
+# v12 to v13 unused.
+
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x12__aarch64_neonfma_cortex_a53
+
+        # Load cn_stride, acc
+        LDP x14, x15, [sp]
+        # Load params pointer
+        LDR x8, [sp, 16]
+
+        # Load clamping_params values
+        # LD2R broadcasts max into v4 and min into v5 (all lanes).
+        LD2R {v4.4s, v5.4s}, [x8]
+
+        # Save x19-21 on stack
+        STR x21, [sp, -80]!
+        STP x19, x20, [sp, 16]
+
+        # Save d8-d11,d14,d15 on stack
+        STP d8, d9, [sp, 32]
+        STP d10, d11, [sp, 48]
+        STP d14, d15, [sp, 64]
+
+        # Clamp A and C pointers
+        ADD x11, x3, x4          // a1 = a0 + a_stride
+        ADD x9, x6, x7           // c1 = c0 + cm_stride
+        CMP x0, 2                // if mr < 2
+        CSEL x11, x3, x11, LO    //   a1 = a0
+        CSEL x9, x6, x9, LO      //   c1 = c0
+        ADD x12, x11, x4         // a2 = a1 + a_stride
+        ADD x10, x9, x7          // c2 = c1 + cm_stride
+                                 // if mr <= 2
+        CSEL x12, x11, x12, LS   //   a2 = a1
+        CSEL x10, x9, x10, LS    //   c2 = c1
+        ADD x4, x12, x4          // a3 = a2 + a_stride
+        ADD x7, x10, x7          // c3 = c2 + cm_stride
+        CMP x0, 4                // if mr < 4
+        CSEL x4, x12, x4, LO     //   a3 = a2
+        CSEL x7, x10, x7, LO     //   c3 = c2
+
+        # Outer loop: each iteration produces one 4x12 block of output.
+0:
+        # Load initial accumulators
+        LD1 {v20.16b, v21.16b, v22.16b}, [x15], 48
+        LD1 {v23.16b, v24.16b, v25.16b}, [x15], 48
+        LD1 {v26.16b, v27.16b, v28.16b}, [x15], 48
+        LD1 {v29.16b, v30.16b, v31.16b}, [x15], 48
+
+        PRFM PLDL1KEEP, [x5]
+        PRFM PLDL1KEEP, [x5, 64]
+        PRFM PLDL1KEEP, [x5, 128]
+        PRFM PLDL1KEEP, [x5, 192]
+        PRFM PLDL1KEEP, [x5, 256]
+        PRFM PLDL1KEEP, [x5, 320]
+
+        # Is there at least 4 floats (16 bytes)?
+        SUBS x0, x2, 16  // k = kc - 16
+        B.LO 3f
+
+        SUBS x0, x0, 16
+
+        # Prologue - loads for first group of 24 FMA
+
+        # Read first block of 4 A.
+        LDR d0, [x3], 8   // a0
+        LDR x13, [x11], 8 // a1
+        LDR d1, [x12], 8  // a2
+        LDR x8, [x4], 8   // a3
+
+        LDR d6, [x5]      // vb0x0123
+        LDR x20, [x5, 8]
+
+        LDR d7, [x5, 16]  // vb0x4567
+        LDR x21, [x5, 24]
+
+        LDR d8, [x5, 32]  // vb0x89AB
+        LDR x16, [x5, 40]
+
+        LDR d9, [x5, 48]  // vb1x0123
+        INS v0.d[1], x13
+        LDR x17, [x5, 56]
+
+        LDR d10, [x5, 64] // vb1x4567
+        INS v1.d[1], x8
+        LDR x18, [x5, 72]
+
+        LDR d11, [x5, 80] // vb1x89AB
+        LDR x19, [x5, 88]
+        INS v6.d[1], x20
+        ADD x5, x5, 96
+
+        # Is there at least 4 floats (16 bytes) for main loop?
+        B.LO 2f
+
+        # Main loop - 4 floats of A (16 bytes)
+1:
+        # First group of 24 fma.  8 blocks of 4 cycles.  LDR + 3 FMA
+        # A is loaded for 2nd group into v2/v3
+        # INS is 4 blocks (16 cycles) after load
+
+        # BLOCK 0
+        LDR d2, [x3], 8        // a0
+        INS v7.d[1], x21
+        FMLA v20.4s, v6.4s, v0.s[0]
+        LDR x13, [x11], 8      // a1
+        FMLA v23.4s, v6.4s, v0.s[2]
+        PRFM PLDL1KEEP, [x5, 192]
+        FMLA v26.4s, v6.4s, v1.s[0]
+
+        # BLOCK 1
+        LDR d3, [x12], 8       // a2
+        INS v8.d[1], x16
+        FMLA v29.4s, v6.4s, v1.s[2]
+        LDR x8, [x4], 8        // a3
+        FMLA v21.4s, v7.4s, v0.s[0]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v24.4s, v7.4s, v0.s[2]
+
+        # BLOCK 2
+        LDR d14, [x5]          // vb0x0123
+        INS v9.d[1], x17
+        FMLA v27.4s, v7.4s, v1.s[0]
+        LDR x20, [x5, 8]
+        FMLA v30.4s, v7.4s, v1.s[2]
+        PRFM PLDL1KEEP, [x5, 320]
+        FMLA v22.4s, v8.4s, v0.s[0]
+
+        # BLOCK 3
+        LDR d15, [x5, 16]      // vb0x4567
+        INS v10.d[1], x18
+        FMLA v25.4s, v8.4s, v0.s[2]
+        LDR x21, [x5, 24]
+        FMLA v28.4s, v8.4s, v1.s[0]
+        FMLA v31.4s, v8.4s, v1.s[2]
+
+        # BLOCK 4
+        LDR d16, [x5, 32]      // vb0x89AB
+        INS v11.d[1], x19
+        FMLA v20.4s, v9.4s, v0.s[1]
+        LDR x16, [x5, 40]
+        FMLA v23.4s, v9.4s, v0.s[3]
+        FMLA v26.4s, v9.4s, v1.s[1]
+
+        # BLOCK 5
+        LDR d17, [x5, 48]      // vb1x0123
+        INS v2.d[1], x13       // a1 was loaded in block 0
+        FMLA v29.4s, v9.4s, v1.s[3]
+        LDR x17, [x5, 56]
+        FMLA v21.4s, v10.4s, v0.s[1]
+        FMLA v24.4s, v10.4s, v0.s[3]
+
+        # BLOCK 6
+        LDR d18, [x5, 64]      // vb1x4567
+        INS v3.d[1], x8        // a3 was loaded in block 1
+        FMLA v27.4s, v10.4s, v1.s[1]
+        LDR x18, [x5, 72]
+        FMLA v30.4s, v10.4s, v1.s[3]
+        FMLA v22.4s, v11.4s, v0.s[1]
+
+        # BLOCK 7
+        LDR d19, [x5, 80]      // vb1x89AB
+        INS v14.d[1], x20      // v14 was loaded in block 2
+        FMLA v25.4s, v11.4s, v0.s[3]
+        LDR x19, [x5, 88]
+        FMLA v28.4s, v11.4s, v1.s[1]
+        FMLA v31.4s, v11.4s, v1.s[3]
+
+        # Second group of 24 fma.  8 blocks of 4 cycles.  LDR + 3 FMA
+        # A is loaded for 1st group into v0/v1
+
+        # BLOCK 0
+        LDR d0, [x3], 8        // a0
+        INS v15.d[1], x21
+        FMLA v20.4s, v14.4s, v2.s[0]
+        LDR x13, [x11], 8      // a1
+        FMLA v23.4s, v14.4s, v2.s[2]
+        FMLA v26.4s, v14.4s, v3.s[0]
+
+        # BLOCK 1
+        LDR d1, [x12], 8       // a2
+        INS v16.d[1], x16
+        FMLA v29.4s, v14.4s, v3.s[2]
+        LDR x8, [x4], 8        // a3
+        FMLA v21.4s, v15.4s, v2.s[0]
+        FMLA v24.4s, v15.4s, v2.s[2]
+
+        # BLOCK 2
+        LDR d6, [x5, 96]       // vb0x0123
+        INS v17.d[1], x17
+        FMLA v27.4s, v15.4s, v3.s[0]
+        LDR x20, [x5, 104]
+        FMLA v30.4s, v15.4s, v3.s[2]
+        FMLA v22.4s, v16.4s, v2.s[0]
+
+        # BLOCK 3
+        LDR d7, [x5, 112]      // vb0x4567
+        INS v18.d[1], x18
+        FMLA v25.4s, v16.4s, v2.s[2]
+        LDR x21, [x5, 120]
+        FMLA v28.4s, v16.4s, v3.s[0]
+        FMLA v31.4s, v16.4s, v3.s[2]
+
+        # BLOCK 4
+        LDR d8, [x5, 128]      // vb0x89AB
+        INS v19.d[1], x19
+        FMLA v20.4s, v17.4s, v2.s[1]
+        LDR x16, [x5, 136]
+        FMLA v23.4s, v17.4s, v2.s[3]
+        FMLA v26.4s, v17.4s, v3.s[1]
+
+        # BLOCK 5
+        LDR d9, [x5, 144]      // vb1x0123
+        INS v0.d[1], x13       // a1
+        FMLA v29.4s, v17.4s, v3.s[3]
+        LDR x17, [x5, 152]
+        FMLA v21.4s, v18.4s, v2.s[1]
+        FMLA v24.4s, v18.4s, v2.s[3]
+
+        # BLOCK 6
+        LDR d10, [x5, 160]     // vb1x4567
+        INS v1.d[1], x8        // a3
+        FMLA v27.4s, v18.4s, v3.s[1]
+        LDR x18, [x5, 168]
+        FMLA v30.4s, v18.4s, v3.s[3]
+        SUBS x0, x0, 16
+        FMLA v22.4s, v19.4s, v2.s[1]
+
+        # BLOCK 7
+        LDR d11, [x5, 176]     // vb1x89AB
+        INS v6.d[1], x20
+        FMLA v25.4s, v19.4s, v2.s[3]
+        LDR x19, [x5, 184]
+        FMLA v28.4s, v19.4s, v3.s[1]
+        ADD x5, x5, 192
+        FMLA v31.4s, v19.4s, v3.s[3]
+        B.HS 1b
+
+        # Epilogue
+        # First block same as main loop.  Second block has no loads.
+2:
+        # BLOCK 0
+        LDR d2, [x3], 8        // a0
+        INS v7.d[1], x21
+        FMLA v20.4s, v6.4s, v0.s[0]
+        LDR x13, [x11], 8      // a1
+        FMLA v23.4s, v6.4s, v0.s[2]
+        PRFM PLDL1KEEP, [x5, 192]
+        FMLA v26.4s, v6.4s, v1.s[0]
+
+        # BLOCK 1
+        LDR d3, [x12], 8       // a2
+        INS v8.d[1], x16
+        FMLA v29.4s, v6.4s, v1.s[2]
+        LDR x8, [x4], 8        // a3
+        FMLA v21.4s, v7.4s, v0.s[0]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v24.4s, v7.4s, v0.s[2]
+
+        # BLOCK 2
+        LDR d14, [x5]          // vb0x0123
+        INS v9.d[1], x17
+        FMLA v27.4s, v7.4s, v1.s[0]
+        LDR x20, [x5, 8]
+        FMLA v30.4s, v7.4s, v1.s[2]
+        PRFM PLDL1KEEP, [x5, 320]
+        FMLA v22.4s, v8.4s, v0.s[0]
+
+        # BLOCK 3
+        LDR d15, [x5, 16]      // vb0x4567
+        INS v10.d[1], x18
+        FMLA v25.4s, v8.4s, v0.s[2]
+        LDR x21, [x5, 24]
+        FMLA v28.4s, v8.4s, v1.s[0]
+        FMLA v31.4s, v8.4s, v1.s[2]
+
+        # BLOCK 4
+        LDR d16, [x5, 32]      // vb0x89AB
+        INS v11.d[1], x19
+        FMLA v20.4s, v9.4s, v0.s[1]
+        LDR x16, [x5, 40]
+        FMLA v23.4s, v9.4s, v0.s[3]
+        FMLA v26.4s, v9.4s, v1.s[1]
+
+        # BLOCK 5
+        LDR d17, [x5, 48]      // vb1x0123
+        INS v2.d[1], x13       // a1 was loaded in block 0
+        FMLA v29.4s, v9.4s, v1.s[3]
+        LDR x17, [x5, 56]
+        FMLA v21.4s, v10.4s, v0.s[1]
+        FMLA v24.4s, v10.4s, v0.s[3]
+
+        # BLOCK 6
+        LDR d18, [x5, 64]      // vb1x4567
+        INS v3.d[1], x8        // a3 was loaded in block 1
+        FMLA v27.4s, v10.4s, v1.s[1]
+        LDR x18, [x5, 72]
+        FMLA v30.4s, v10.4s, v1.s[3]
+        FMLA v22.4s, v11.4s, v0.s[1]
+
+        # BLOCK 7
+        LDR d19, [x5, 80]      // vb1x89AB
+        INS v14.d[1], x20      // v14 was loaded in block 2
+        FMLA v25.4s, v11.4s, v0.s[3]
+        LDR x19, [x5, 88]
+        FMLA v28.4s, v11.4s, v1.s[1]
+        ADD x5, x5, 96
+        FMLA v31.4s, v11.4s, v1.s[3]
+
+        # Second group of 24 fma.  8 blocks of 4 cycles.
+        # Epilogue version does no loads
+
+        # BLOCK 0
+        INS v15.d[1], x21
+        FMLA v20.4s, v14.4s, v2.s[0]
+        FMLA v23.4s, v14.4s, v2.s[2]
+        FMLA v26.4s, v14.4s, v3.s[0]
+
+        # BLOCK 1
+        INS v16.d[1], x16
+        FMLA v29.4s, v14.4s, v3.s[2]
+        FMLA v21.4s, v15.4s, v2.s[0]
+        FMLA v24.4s, v15.4s, v2.s[2]
+
+        # BLOCK 2
+        INS v17.d[1], x17
+        FMLA v27.4s, v15.4s, v3.s[0]
+        FMLA v30.4s, v15.4s, v3.s[2]
+        FMLA v22.4s, v16.4s, v2.s[0]
+
+        # BLOCK 3
+        INS v18.d[1], x18
+        FMLA v25.4s, v16.4s, v2.s[2]
+        FMLA v28.4s, v16.4s, v3.s[0]
+        FMLA v31.4s, v16.4s, v3.s[2]
+
+        # BLOCK 4
+        INS v19.d[1], x19
+        FMLA v20.4s, v17.4s, v2.s[1]
+        FMLA v23.4s, v17.4s, v2.s[3]
+        FMLA v26.4s, v17.4s, v3.s[1]
+
+        # BLOCK 5
+        FMLA v29.4s, v17.4s, v3.s[3]
+        FMLA v21.4s, v18.4s, v2.s[1]
+        FMLA v24.4s, v18.4s, v2.s[3]
+
+        # BLOCK 6
+        FMLA v27.4s, v18.4s, v3.s[1]
+        FMLA v30.4s, v18.4s, v3.s[3]
+        FMLA v22.4s, v19.4s, v2.s[1]
+
+        # BLOCK 7
+        FMLA v25.4s, v19.4s, v2.s[3]
+        FMLA v28.4s, v19.4s, v3.s[1]
+        FMLA v31.4s, v19.4s, v3.s[3]
+
+3:
+        # Is there a remainder?- 2 floats of A (8 bytes)
+        TBNZ x0, 3, 5f
+        # Is there a remainder?- 1 floats of A (4 bytes)
+        TBNZ x0, 2, 6f
+
+4:
+        # Clamp
+        # (max in v4, min in v5 — loaded once by LD2R before the outer loop)
+        FMIN v20.4s, v20.4s, v4.4s
+        FMIN v21.4s, v21.4s, v4.4s
+        FMIN v22.4s, v22.4s, v4.4s
+        FMIN v23.4s, v23.4s, v4.4s
+        FMIN v24.4s, v24.4s, v4.4s
+        FMIN v25.4s, v25.4s, v4.4s
+        FMIN v26.4s, v26.4s, v4.4s
+        FMIN v27.4s, v27.4s, v4.4s
+        FMIN v28.4s, v28.4s, v4.4s
+        FMIN v29.4s, v29.4s, v4.4s
+        FMIN v30.4s, v30.4s, v4.4s
+        FMIN v31.4s, v31.4s, v4.4s
+        FMAX v20.4s, v20.4s, v5.4s
+        FMAX v21.4s, v21.4s, v5.4s
+        FMAX v22.4s, v22.4s, v5.4s
+        FMAX v23.4s, v23.4s, v5.4s
+        FMAX v24.4s, v24.4s, v5.4s
+        FMAX v25.4s, v25.4s, v5.4s
+        FMAX v26.4s, v26.4s, v5.4s
+        FMAX v27.4s, v27.4s, v5.4s
+        FMAX v28.4s, v28.4s, v5.4s
+        FMAX v29.4s, v29.4s, v5.4s
+        FMAX v30.4s, v30.4s, v5.4s
+        FMAX v31.4s, v31.4s, v5.4s
+
+        # Store full 4 x 12
+        CMP x1, 12
+        B.LO 7f
+
+        ST1 {v29.16b, v30.16b, v31.16b}, [x7], x14
+        ST1 {v26.16b, v27.16b, v28.16b}, [x10], x14
+        ST1 {v23.16b, v24.16b, v25.16b}, [x9], x14
+        ST1 {v20.16b, v21.16b, v22.16b}, [x6], x14
+
+        SUB x3, x3, x2   // a0 -= kc
+        SUB x11, x11, x2 // a1 -= kc
+        SUB x12, x12, x2 // a2 -= kc
+        SUB x4, x4, x2   // a3 -= kc
+
+        SUBS x1, x1, 12
+        B.HI 0b
+
+        # Restore d8-d11,d14,d15 from stack
+        LDP d14, d15, [sp, 64]
+        LDP d10, d11, [sp, 48]
+        LDP d8, d9, [sp, 32]
+
+        # Restore x19-21 from stack
+        LDP x19, x20, [sp, 16]
+        LDR x21, [sp], 80
+        RET
+
+5:
+        # Remainder - 2 floats of A (8 bytes)
+        # Read first block of 4 A.
+        LDR d0, [x3], 8  // a0
+        LDR d1, [x11], 8 // a1
+        LDR d2, [x12], 8 // a2
+        LDR d3, [x4], 8  // a3
+        LD1 {v6.16b, v7.16b, v8.16b}, [x5], 48
+        LD1 {v9.16b, v10.16b, v11.16b}, [x5], 48
+
+        # First block of 3 B
+        FMLA v20.4s, v6.4s, v0.s[0]
+        FMLA v23.4s, v6.4s, v1.s[0]
+        FMLA v26.4s, v6.4s, v2.s[0]
+        FMLA v29.4s, v6.4s, v3.s[0]
+        FMLA v21.4s, v7.4s, v0.s[0]
+        FMLA v24.4s, v7.4s, v1.s[0]
+        FMLA v27.4s, v7.4s, v2.s[0]
+        FMLA v30.4s, v7.4s, v3.s[0]
+        FMLA v22.4s, v8.4s, v0.s[0]
+        FMLA v25.4s, v8.4s, v1.s[0]
+        FMLA v28.4s, v8.4s, v2.s[0]
+        FMLA v31.4s, v8.4s, v3.s[0]
+
+        # Second block of 3 B
+        FMLA v20.4s, v9.4s, v0.s[1]
+        FMLA v23.4s, v9.4s, v1.s[1]
+        FMLA v26.4s, v9.4s, v2.s[1]
+        FMLA v29.4s, v9.4s, v3.s[1]
+        FMLA v21.4s, v10.4s, v0.s[1]
+        FMLA v24.4s, v10.4s, v1.s[1]
+        FMLA v27.4s, v10.4s, v2.s[1]
+        FMLA v30.4s, v10.4s, v3.s[1]
+        FMLA v22.4s, v11.4s, v0.s[1]
+        FMLA v25.4s, v11.4s, v1.s[1]
+        FMLA v28.4s, v11.4s, v2.s[1]
+        FMLA v31.4s, v11.4s, v3.s[1]
+
+        TBZ x0, 2, 4b
+6:
+        # Remainder - 1 float of A (4 bytes)
+        LDR s0, [x3], 4  // a0
+        LDR s1, [x11], 4 // a1
+        LDR s2, [x12], 4 // a2
+        LDR s3, [x4], 4  // a3
+        LD1 {v6.16b, v7.16b, v8.16b}, [x5], 48
+
+        FMLA v20.4s, v6.4s, v0.s[0]
+        FMLA v23.4s, v6.4s, v1.s[0]
+        FMLA v26.4s, v6.4s, v2.s[0]
+        FMLA v29.4s, v6.4s, v3.s[0]
+        FMLA v21.4s, v7.4s, v0.s[0]
+        FMLA v24.4s, v7.4s, v1.s[0]
+        FMLA v27.4s, v7.4s, v2.s[0]
+        FMLA v30.4s, v7.4s, v3.s[0]
+        FMLA v22.4s, v8.4s, v0.s[0]
+        FMLA v25.4s, v8.4s, v1.s[0]
+        FMLA v28.4s, v8.4s, v2.s[0]
+        FMLA v31.4s, v8.4s, v3.s[0]
+        B 4b
+
+7:
+        # Store odd channels (nc < 12): 8/4/2/1-column chunks.
+        TBZ x1, 3, 8f
+        STP q29, q30, [x7]
+        ADD x7, x7, 32
+        MOV v29.16b, v31.16b
+        STP q26, q27, [x10]
+        ADD x10, x10, 32
+        MOV v26.16b, v28.16b
+        STP q23, q24, [x9]
+        ADD x9, x9, 32
+        MOV v23.16b, v25.16b
+        STP q20, q21, [x6]
+        ADD x6, x6, 32
+        MOV v20.16b, v22.16b
+
+8:
+        TBZ x1, 2, 9f
+        STR q29, [x7], 16
+        MOV v29.16b, v30.16b
+        STR q26, [x10], 16
+        MOV v26.16b, v27.16b
+        STR q23, [x9], 16
+        MOV v23.16b, v24.16b
+        STR q20, [x6], 16
+        MOV v20.16b, v21.16b
+
+9:
+        TBZ x1, 1, 10f
+        STR d29, [x7], 8
+        DUP d29, v29.d[1]
+        STR d26, [x10], 8
+        DUP d26, v26.d[1]
+        STR d23, [x9], 8
+        DUP d23, v23.d[1]
+        STR d20, [x6], 8
+        DUP d20, v20.d[1]
+
+10:
+        TBZ x1, 0, 11f
+        STR s29, [x7]
+        STR s26, [x10]
+        STR s23, [x9]
+        STR s20, [x6]
+11:
+        # Restore d8-d11,d14,d15 from stack
+        LDP d14, d15, [sp, 64]
+        LDP d10, d11, [sp, 48]
+        LDP d8, d9, [sp, 32]
+
+        # Restore x19-21 from stack
+        LDP x19, x20, [sp, 16]
+        LDR x21, [sp], 80
+        RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_4x12__aarch64_neonfma_cortex_a53
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/4x12-neon-ld64.c b/src/f32-gemminc/4x12-neon-ld64.c
new file mode 100644
index 0000000..0773dba
--- /dev/null
+++ b/src/f32-gemminc/4x12-neon-ld64.c
@@ -0,0 +1,243 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// GEMM microkernel with initial accumulators: computes a 4x12 tile of
+// C = clamp(A * B + acc). NEON variant with 64-bit (2-float) A loads;
+// rows beyond mr are aliased onto earlier rows so stores stay in bounds.
+void xnn_f32_gemminc_ukernel_4x12__neon_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    // Start from the externally-supplied accumulators.
+    float32x4_t vacc0x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc0x89AB = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x89AB = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x89AB = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x89AB = vld1q_f32(acc); acc += 4;
+
+    size_t k = kc;
+    // Main loop: 2 reduction elements per iteration, per-lane FMA.
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+      const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+      const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+      const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb89ABc0 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+      vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+      vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+      vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+      vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+      vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+      vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+      vacc0x89AB = vmlaq_lane_f32(vacc0x89AB, vb89ABc0, va0, 0);
+      vacc1x89AB = vmlaq_lane_f32(vacc1x89AB, vb89ABc0, va1, 0);
+      vacc2x89AB = vmlaq_lane_f32(vacc2x89AB, vb89ABc0, va2, 0);
+      vacc3x89AB = vmlaq_lane_f32(vacc3x89AB, vb89ABc0, va3, 0);
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb89ABc1 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+      vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+      vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+      vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+      vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+      vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+      vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+      vacc0x89AB = vmlaq_lane_f32(vacc0x89AB, vb89ABc1, va0, 1);
+      vacc1x89AB = vmlaq_lane_f32(vacc1x89AB, vb89ABc1, va1, 1);
+      vacc2x89AB = vmlaq_lane_f32(vacc2x89AB, vb89ABc1, va2, 1);
+      vacc3x89AB = vmlaq_lane_f32(vacc3x89AB, vb89ABc1, va3, 1);
+    }
+    if XNN_UNLIKELY(k != 0) {
+      // Remainder: one broadcast A element per row.
+      const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+      const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+      const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+      const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+      const float32x4_t vb89AB = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+      vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+      vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+      vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+      vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+      vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+      vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+      vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+      vacc0x89AB = vmlaq_f32(vacc0x89AB, va0, vb89AB);
+      vacc1x89AB = vmlaq_f32(vacc1x89AB, va1, vb89AB);
+      vacc2x89AB = vmlaq_f32(vacc2x89AB, va2, vb89AB);
+      vacc3x89AB = vmlaq_f32(vacc3x89AB, va3, vb89AB);
+    }
+    // Fixed: "&params" was mis-encoded as the mojibake "¶ms"
+    // (HTML-entity corruption of "&para"), which does not compile.
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+    vacc0x89AB = vminq_f32(vacc0x89AB, vmax);
+    vacc1x89AB = vminq_f32(vacc1x89AB, vmax);
+    vacc2x89AB = vminq_f32(vacc2x89AB, vmax);
+    vacc3x89AB = vminq_f32(vacc3x89AB, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+    vacc0x89AB = vmaxq_f32(vacc0x89AB, vmin);
+    vacc1x89AB = vmaxq_f32(vacc1x89AB, vmin);
+    vacc2x89AB = vmaxq_f32(vacc2x89AB, vmin);
+    vacc3x89AB = vmaxq_f32(vacc3x89AB, vmin);
+
+    if XNN_LIKELY(nc >= 12) {
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      vst1q_f32(c3 + 8, vacc3x89AB);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      vst1q_f32(c2 + 8, vacc2x89AB);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      vst1q_f32(c1 + 8, vacc1x89AB);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      vst1q_f32(c0 + 8, vacc0x89AB);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind A to the start of each row for the next 12-column block.
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 12;
+
+    } else {
+      // Tail: store the remaining 1-11 columns in 8/4/2/1 chunks.
+      if (nc & 8) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+        vst1q_f32(c3, vacc3x4567); c3 += 4;
+        vst1q_f32(c2, vacc2x4567); c2 += 4;
+        vst1q_f32(c1, vacc1x4567); c1 += 4;
+        vst1q_f32(c0, vacc0x4567); c0 += 4;
+
+        vacc3x0123 = vacc3x89AB;
+        vacc2x0123 = vacc2x89AB;
+        vacc1x0123 = vacc1x89AB;
+        vacc0x0123 = vacc0x89AB;
+      }
+      if (nc & 4) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x12-neonfma-ld64.c b/src/f32-gemminc/4x12-neonfma-ld64.c
new file mode 100644
index 0000000..7300418
--- /dev/null
+++ b/src/f32-gemminc/4x12-neonfma-ld64.c
@@ -0,0 +1,281 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// 4x12 incremental GEMM (GEMMINC) microkernel using NEON FMA.
+// Computes C = clamp(acc + A * B) over an up-to-4x12 output tile.
+// Unlike the plain GEMM variant, the initial accumulators are read from
+// 'acc' (12 floats per output row) instead of packed bias values.
+// "ld64" means each main-loop step loads 64 bits (2 floats) from every A row.
+void xnn_f32_gemminc_ukernel_4x12__neonfma_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float* restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is specified in bytes
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  // Set up per-row A and C pointers. When mr < 4, pointers for the unused
+  // rows alias the previous row, so their loads/stores stay in bounds.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    // Initial accumulators come from 'acc': 3 vectors of 4 per row.
+    float32x4_t vacc0x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc0x89AB = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x89AB = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x89AB = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x89AB = vld1q_f32(acc); acc += 4;
+
+    // Main loop: consume 2 floats of each A row per iteration.
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+      const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+      const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+      const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb89ABc0 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        // AArch64 has lane-indexed FMA; on AArch32 the lane must be
+        // broadcast with VDUP first (the #else branch below).
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+        vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+        vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+        vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+        vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+        vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+        vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+        vacc0x89AB = vfmaq_lane_f32(vacc0x89AB, vb89ABc0, va0, 0);
+        vacc1x89AB = vfmaq_lane_f32(vacc1x89AB, vb89ABc0, va1, 0);
+        vacc2x89AB = vfmaq_lane_f32(vacc2x89AB, vb89ABc0, va2, 0);
+        vacc3x89AB = vfmaq_lane_f32(vacc3x89AB, vb89ABc0, va3, 0);
+      #else
+        const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+        const float32x4_t va1c0 = vdupq_lane_f32(va1, 0);
+        const float32x4_t va2c0 = vdupq_lane_f32(va2, 0);
+        const float32x4_t va3c0 = vdupq_lane_f32(va3, 0);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+        vacc0x89AB = vfmaq_f32(vacc0x89AB, va0c0, vb89ABc0);
+        vacc1x89AB = vfmaq_f32(vacc1x89AB, va1c0, vb89ABc0);
+        vacc2x89AB = vfmaq_f32(vacc2x89AB, va2c0, vb89ABc0);
+        vacc3x89AB = vfmaq_f32(vacc3x89AB, va3c0, vb89ABc0);
+      #endif
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb89ABc1 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+        vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+        vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+        vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+        vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+        vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+        vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+        vacc0x89AB = vfmaq_lane_f32(vacc0x89AB, vb89ABc1, va0, 1);
+        vacc1x89AB = vfmaq_lane_f32(vacc1x89AB, vb89ABc1, va1, 1);
+        vacc2x89AB = vfmaq_lane_f32(vacc2x89AB, vb89ABc1, va2, 1);
+        vacc3x89AB = vfmaq_lane_f32(vacc3x89AB, vb89ABc1, va3, 1);
+      #else
+        const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+        const float32x4_t va1c1 = vdupq_lane_f32(va1, 1);
+        const float32x4_t va2c1 = vdupq_lane_f32(va2, 1);
+        const float32x4_t va3c1 = vdupq_lane_f32(va3, 1);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+        vacc0x89AB = vfmaq_f32(vacc0x89AB, va0c1, vb89ABc1);
+        vacc1x89AB = vfmaq_f32(vacc1x89AB, va1c1, vb89ABc1);
+        vacc2x89AB = vfmaq_f32(vacc2x89AB, va2c1, vb89ABc1);
+        vacc3x89AB = vfmaq_f32(vacc3x89AB, va3c1, vb89ABc1);
+      #endif
+    }
+    // Remainder: kc was odd (in floats) - process the final single float.
+    if XNN_UNLIKELY(k != 0) {
+      const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+      const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+      const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+      const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+      const float32x4_t vb89AB = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+      vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+      vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+      vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+      vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+      vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+      vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+      vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+      vacc0x89AB = vfmaq_f32(vacc0x89AB, va0, vb89AB);
+      vacc1x89AB = vfmaq_f32(vacc1x89AB, va1, vb89AB);
+      vacc2x89AB = vfmaq_f32(vacc2x89AB, va2, vb89AB);
+      vacc3x89AB = vfmaq_f32(vacc3x89AB, va3, vb89AB);
+    }
+    // Clamp the accumulators to [params->scalar.min, params->scalar.max].
+    // Fixed mis-encoded '&para;ms' -> '&params' in the two loads below.
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+    vacc0x89AB = vminq_f32(vacc0x89AB, vmax);
+    vacc1x89AB = vminq_f32(vacc1x89AB, vmax);
+    vacc2x89AB = vminq_f32(vacc2x89AB, vmax);
+    vacc3x89AB = vminq_f32(vacc3x89AB, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+    vacc0x89AB = vmaxq_f32(vacc0x89AB, vmin);
+    vacc1x89AB = vmaxq_f32(vacc1x89AB, vmin);
+    vacc2x89AB = vmaxq_f32(vacc2x89AB, vmin);
+    vacc3x89AB = vmaxq_f32(vacc3x89AB, vmin);
+
+    if XNN_LIKELY(nc >= 12) {
+      // Full 12-wide store; rewind A pointers by kc for the next tile.
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      vst1q_f32(c3 + 8, vacc3x89AB);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      vst1q_f32(c2 + 8, vacc2x89AB);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      vst1q_f32(c1 + 8, vacc1x89AB);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      vst1q_f32(c0 + 8, vacc0x89AB);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 12;
+
+    } else {
+      // Partial tail: peel off 8/4/2/1 columns, shifting the surviving
+      // accumulator vectors down after each store.
+      if (nc & 8) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+        vst1q_f32(c3, vacc3x4567); c3 += 4;
+        vst1q_f32(c2, vacc2x4567); c2 += 4;
+        vst1q_f32(c1, vacc1x4567); c1 += 4;
+        vst1q_f32(c0, vacc0x4567); c0 += 4;
+
+        vacc3x0123 = vacc3x89AB;
+        vacc2x0123 = vacc2x89AB;
+        vacc1x0123 = vacc1x89AB;
+        vacc0x0123 = vacc0x89AB;
+      }
+      if (nc & 4) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x4-scalar.c b/src/f32-gemminc/4x4-scalar.c
new file mode 100644
index 0000000..ae48344
--- /dev/null
+++ b/src/f32-gemminc/4x4-scalar.c
@@ -0,0 +1,205 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+// 4x4 incremental GEMM (GEMMINC) microkernel, portable scalar version.
+// Computes C = clamp(acc + A * B) over an up-to-4x4 output tile; initial
+// accumulators are read from 'acc' (16 floats per tile) instead of bias.
+void xnn_f32_gemminc_ukernel_4x4__scalar(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float* restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is specified in bytes
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  // Set up per-row A and C pointers. When mr < 4, pointers for the unused
+  // rows alias the previous row, so their loads/stores stay in bounds.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    // vaccRC holds the accumulator for row R, column C; seeded from 'acc'.
+    float vacc00 = acc[0];
+    float vacc01 = acc[1];
+    float vacc02 = acc[2];
+    float vacc03 = acc[3];
+    float vacc10 = acc[4];
+    float vacc11 = acc[5];
+    float vacc12 = acc[6];
+    float vacc13 = acc[7];
+    float vacc20 = acc[8];
+    float vacc21 = acc[9];
+    float vacc22 = acc[10];
+    float vacc23 = acc[11];
+    float vacc30 = acc[12];
+    float vacc31 = acc[13];
+    float vacc32 = acc[14];
+    float vacc33 = acc[15];
+    acc += 16;
+
+    // Reduction loop: one float of each A row and 4 packed weights per step.
+    size_t k = kc;
+    do {
+      const float va0 = *a0++;
+      const float va1 = *a1++;
+      const float va2 = *a2++;
+      const float va3 = *a3++;
+
+      const float vb0 = w[0];
+      const float vb1 = w[1];
+      const float vb2 = w[2];
+      const float vb3 = w[3];
+      w += 4;
+
+      vacc00 += va0 * vb0;
+      vacc01 += va0 * vb1;
+      vacc02 += va0 * vb2;
+      vacc03 += va0 * vb3;
+      vacc10 += va1 * vb0;
+      vacc11 += va1 * vb1;
+      vacc12 += va1 * vb2;
+      vacc13 += va1 * vb3;
+      vacc20 += va2 * vb0;
+      vacc21 += va2 * vb1;
+      vacc22 += va2 * vb2;
+      vacc23 += va2 * vb3;
+      vacc30 += va3 * vb0;
+      vacc31 += va3 * vb1;
+      vacc32 += va3 * vb2;
+      vacc33 += va3 * vb3;
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp the accumulators to [params->scalar.min, params->scalar.max].
+    const float vmin = params->scalar.min;
+    vacc00 = math_max_f32(vacc00, vmin);
+    vacc01 = math_max_f32(vacc01, vmin);
+    vacc02 = math_max_f32(vacc02, vmin);
+    vacc03 = math_max_f32(vacc03, vmin);
+    vacc10 = math_max_f32(vacc10, vmin);
+    vacc11 = math_max_f32(vacc11, vmin);
+    vacc12 = math_max_f32(vacc12, vmin);
+    vacc13 = math_max_f32(vacc13, vmin);
+    vacc20 = math_max_f32(vacc20, vmin);
+    vacc21 = math_max_f32(vacc21, vmin);
+    vacc22 = math_max_f32(vacc22, vmin);
+    vacc23 = math_max_f32(vacc23, vmin);
+    vacc30 = math_max_f32(vacc30, vmin);
+    vacc31 = math_max_f32(vacc31, vmin);
+    vacc32 = math_max_f32(vacc32, vmin);
+    vacc33 = math_max_f32(vacc33, vmin);
+
+    const float vmax = params->scalar.max;
+    vacc00 = math_min_f32(vacc00, vmax);
+    vacc01 = math_min_f32(vacc01, vmax);
+    vacc02 = math_min_f32(vacc02, vmax);
+    vacc03 = math_min_f32(vacc03, vmax);
+    vacc10 = math_min_f32(vacc10, vmax);
+    vacc11 = math_min_f32(vacc11, vmax);
+    vacc12 = math_min_f32(vacc12, vmax);
+    vacc13 = math_min_f32(vacc13, vmax);
+    vacc20 = math_min_f32(vacc20, vmax);
+    vacc21 = math_min_f32(vacc21, vmax);
+    vacc22 = math_min_f32(vacc22, vmax);
+    vacc23 = math_min_f32(vacc23, vmax);
+    vacc30 = math_min_f32(vacc30, vmax);
+    vacc31 = math_min_f32(vacc31, vmax);
+    vacc32 = math_min_f32(vacc32, vmax);
+    vacc33 = math_min_f32(vacc33, vmax);
+
+    if XNN_LIKELY(nc >= 4) {
+      // Full 4-wide store; rewind A pointers by kc for the next tile.
+      c3[0] = vacc30;
+      c3[1] = vacc31;
+      c3[2] = vacc32;
+      c3[3] = vacc33;
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      c2[0] = vacc20;
+      c2[1] = vacc21;
+      c2[2] = vacc22;
+      c2[3] = vacc23;
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      c1[0] = vacc10;
+      c1[1] = vacc11;
+      c1[2] = vacc12;
+      c1[3] = vacc13;
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      c0[0] = vacc00;
+      c0[1] = vacc01;
+      c0[2] = vacc02;
+      c0[3] = vacc03;
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Cast to (const float*) for consistency with the other kernels in
+      // this file (the original cast to (const void*) here).
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 4;
+    } else {
+      // Partial tail: peel off 2 then 1 columns, shifting surviving values.
+      if (nc & 2) {
+        c3[0] = vacc30;
+        c3[1] = vacc31;
+        vacc30 = vacc32;
+        c3 += 2;
+        c2[0] = vacc20;
+        c2[1] = vacc21;
+        vacc20 = vacc22;
+        c2 += 2;
+        c1[0] = vacc10;
+        c1[1] = vacc11;
+        vacc10 = vacc12;
+        c1 += 2;
+        c0[0] = vacc00;
+        c0[1] = vacc01;
+        vacc00 = vacc02;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        c3[0] = vacc30;
+        c2[0] = vacc20;
+        c1[0] = vacc10;
+        c0[0] = vacc00;
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x8-aarch64-neonfma-cortex-a57.S b/src/f32-gemminc/4x8-aarch64-neonfma-cortex-a57.S
new file mode 100644
index 0000000..45022cc
--- /dev/null
+++ b/src/f32-gemminc/4x8-aarch64-neonfma-cortex-a57.S
@@ -0,0 +1,476 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/4x8-aarch64-neonfma-cortex-a57.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a57(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+# Vector register usage
+# A0 v0 v4
+# A1 v1 v5
+# A2 v2 v6
+# A3 v3 v7
+# B v8 v9 v10 v11
+# B v12 v13 v14 v15
+# B v20 v21 v22 v23
+# B v24 v25 v26 v27
+# C v16 v17
+# C v18 v19
+# C v28 v29
+# C v30 v31
+# Clamp v4 v5
+
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a57
+
+ # Load cn_stride, acc
+ LDP x14, x15, [sp]
+ # Load params pointer
+ LDR x8, [sp, 16]
+
+ # Load params values: v4 = upper clamp bound, v5 = lower clamp bound
+ # (see the FMIN v4 / FMAX v5 pair in the Clamp section below).
+ # x8 is rewound so the bounds can be reloaded after the main loop
+ # clobbers v4/v5 with A data.
+ LD1R {v4.4s}, [x8], 4
+ LD1R {v5.4s}, [x8]
+ SUB x8, x8, 4
+
+ # Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ STP d10, d11, [sp, 16]
+ STP d12, d13, [sp, 32]
+ STP d14, d15, [sp, 48]
+
+ # Clamp A and C pointers
+ ADD x11, x3, x4 // a1 = a0 + a_stride
+ ADD x9, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x11, x3, x11, LO // a1 = a0
+ CSEL x9, x6, x9, LO // c1 = c0
+
+ ADD x12, x11, x4 // a2 = a1 + a_stride
+ ADD x10, x9, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x12, x11, x12, LS // a2 = a1
+ CSEL x10, x9, x10, LS // c2 = c1
+
+ ADD x4, x12, x4 // a3 = a2 + a_stride
+ ADD x7, x10, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x4, x12, x4, LO // a3 = a2
+ CSEL x7, x10, x7, LO // c3 = c2
+
+0:
+ # Load initial accumulators
+ LDP q16, q17, [x15], 32
+ LDP q18, q19, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 3f
+
+ # 16 prologue
+ # Read first block of 4 A and B.
+ LDR q0, [x3], 16
+ LDP q20, q21, [x5], 32
+ LDR q1, [x11], 16
+ LDR q2, [x12], 16
+ LDR q3, [x4], 16
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+
+ # Is there at least 32. yes do main loop
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+1:
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDP q8, q9, [x5], 32
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ LDP q10, q11, [x5], 32
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ LDP q12, q13, [x5], 32
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ LDP q14, q15, [x5], 32
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ LDR q4, [x3], 16
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ LDR q5, [x11], 16
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ LDR q6, [x12], 16
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ LDR q7, [x4], 16
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ PRFM PLDL1KEEP, [x5, 128]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ PRFM PLDL1KEEP, [x5, 192]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+ # Second block of 4. FMA for second 4, loads for 1st block of 4.
+ FMLA v16.4s, v8.4s, v4.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v17.4s, v9.4s, v4.s[0]
+ FMLA v18.4s, v8.4s, v5.s[0]
+ FMLA v19.4s, v9.4s, v5.s[0]
+ FMLA v28.4s, v8.4s, v6.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v29.4s, v9.4s, v6.s[0]
+ FMLA v30.4s, v8.4s, v7.s[0]
+ FMLA v31.4s, v9.4s, v7.s[0]
+ FMLA v16.4s, v10.4s, v4.s[1]
+ LDP q24, q25, [x5], 32
+ FMLA v17.4s, v11.4s, v4.s[1]
+ FMLA v18.4s, v10.4s, v5.s[1]
+ FMLA v19.4s, v11.4s, v5.s[1]
+ FMLA v28.4s, v10.4s, v6.s[1]
+ LDP q26, q27, [x5], 32
+ FMLA v29.4s, v11.4s, v6.s[1]
+ FMLA v30.4s, v10.4s, v7.s[1]
+ FMLA v31.4s, v11.4s, v7.s[1]
+ FMLA v16.4s, v12.4s, v4.s[2]
+ LDR q0, [x3], 16
+ FMLA v17.4s, v13.4s, v4.s[2]
+ FMLA v18.4s, v12.4s, v5.s[2]
+ LDR q1, [x11], 16
+ FMLA v19.4s, v13.4s, v5.s[2]
+ FMLA v28.4s, v12.4s, v6.s[2]
+ LDR q2, [x12], 16
+ FMLA v29.4s, v13.4s, v6.s[2]
+ FMLA v30.4s, v12.4s, v7.s[2]
+ LDR q3, [x4], 16
+ FMLA v31.4s, v13.4s, v7.s[2]
+ FMLA v16.4s, v14.4s, v4.s[3]
+ FMLA v17.4s, v15.4s, v4.s[3]
+ FMLA v18.4s, v14.4s, v5.s[3]
+ FMLA v19.4s, v15.4s, v5.s[3]
+ FMLA v28.4s, v14.4s, v6.s[3]
+ FMLA v29.4s, v15.4s, v6.s[3]
+ SUBS x0, x0, 32
+ FMLA v30.4s, v14.4s, v7.s[3]
+ FMLA v31.4s, v15.4s, v7.s[3]
+ B.HS 1b
+
+2:
+ # Epilogue
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDP q8, q9, [x5], 32
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ LDP q10, q11, [x5], 32
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ LDP q12, q13, [x5], 32
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ LDP q14, q15, [x5], 32
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ LDR q4, [x3], 16
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ LDR q5, [x11], 16
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ LDR q6, [x12], 16
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ LDR q7, [x4], 16
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ PRFM PLDL1KEEP, [x5, 128]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ PRFM PLDL1KEEP, [x5, 192]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+ # Second block of 4. FMA for second 4, noloads
+ FMLA v16.4s, v8.4s, v4.s[0]
+ FMLA v17.4s, v9.4s, v4.s[0]
+ FMLA v18.4s, v8.4s, v5.s[0]
+ FMLA v19.4s, v9.4s, v5.s[0]
+ FMLA v28.4s, v8.4s, v6.s[0]
+ FMLA v29.4s, v9.4s, v6.s[0]
+ FMLA v30.4s, v8.4s, v7.s[0]
+ FMLA v31.4s, v9.4s, v7.s[0]
+
+ FMLA v16.4s, v10.4s, v4.s[1]
+ FMLA v17.4s, v11.4s, v4.s[1]
+ FMLA v18.4s, v10.4s, v5.s[1]
+ FMLA v19.4s, v11.4s, v5.s[1]
+ FMLA v28.4s, v10.4s, v6.s[1]
+ FMLA v29.4s, v11.4s, v6.s[1]
+ FMLA v30.4s, v10.4s, v7.s[1]
+ FMLA v31.4s, v11.4s, v7.s[1]
+
+ FMLA v16.4s, v12.4s, v4.s[2]
+ FMLA v17.4s, v13.4s, v4.s[2]
+ FMLA v18.4s, v12.4s, v5.s[2]
+ FMLA v19.4s, v13.4s, v5.s[2]
+ FMLA v28.4s, v12.4s, v6.s[2]
+ FMLA v29.4s, v13.4s, v6.s[2]
+ FMLA v30.4s, v12.4s, v7.s[2]
+ FMLA v31.4s, v13.4s, v7.s[2]
+
+ # Reload clamp bounds into v4/v5 (they were reused as A data above),
+ # interleaved with the last FMAs that still consume v4/v5 lanes.
+ FMLA v16.4s, v14.4s, v4.s[3]
+ FMLA v17.4s, v15.4s, v4.s[3]
+ LD1R {v4.4s}, [x8], 4
+ FMLA v18.4s, v14.4s, v5.s[3]
+ FMLA v19.4s, v15.4s, v5.s[3]
+ LD1R {v5.4s}, [x8]
+ FMLA v28.4s, v14.4s, v6.s[3]
+ FMLA v29.4s, v15.4s, v6.s[3]
+ SUB x8, x8, 4
+ FMLA v30.4s, v14.4s, v7.s[3]
+ FMLA v31.4s, v15.4s, v7.s[3]
+
+3:
+ # Remainder- 4 floats of A (16 bytes)
+ TBZ x0, 4, 4f
+
+ LDR q0, [x3], 16
+ LDP q20, q21, [x5], 32
+ LDR q1, [x11], 16
+ LDR q2, [x12], 16
+ LDR q3, [x4], 16
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q24, q25, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ LDP q26, q27, [x5], 32
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+4:
+ # Remainder- 2 floats of A (8 bytes)
+ TBZ x0, 3, 5f
+
+ LDR d0, [x3], 8
+ LDP q20, q21, [x5], 32
+ LDR d1, [x11], 8
+ LDR d2, [x12], 8
+ LDR d3, [x4], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+
+5:
+ # Remainder- 1 float of A (4 bytes)
+ TBZ x0, 2, 6f
+
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3, [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+
+6:
+ # Clamp: FMIN against the upper bound (v4), FMAX against the lower (v5)
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMIN v18.4s, v18.4s, v4.4s
+ FMIN v19.4s, v19.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+ FMAX v18.4s, v18.4s, v5.4s
+ FMAX v19.4s, v19.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+ # Store full 4 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ STP q28, q29, [x10]
+ ADD x10, x10, x14
+ STP q18, q19, [x9]
+ ADD x9, x9, x14
+ STP q16, q17, [x6]
+ ADD x6, x6, x14
+
+ SUB x3, x3, x2 // a0 -= kc
+ SUB x11, x11, x2 // a1 -= kc
+ SUB x12, x12, x2 // a2 -= kc
+ SUB x4, x4, x2 // a3 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+ # Store odd width (nc < 8): peel 4, 2, then 1 columns
+7:
+ TBZ x1, 2, 8f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x10], 16
+ MOV v28.16b, v29.16b
+ STR q18, [x9], 16
+ MOV v18.16b, v19.16b
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+8:
+ TBZ x1, 1, 9f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x10], 8
+ DUP d28, v28.d[1]
+ STR d18, [x9], 8
+ DUP d18, v18.d[1]
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ STR s30, [x7]
+ STR s28, [x10]
+ STR s18, [x9]
+ STR s16, [x6]
+10:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a57
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/4x8-aarch64-neonfma-cortex-a75.S b/src/f32-gemminc/4x8-aarch64-neonfma-cortex-a75.S
new file mode 100644
index 0000000..2ac6fa8
--- /dev/null
+++ b/src/f32-gemminc/4x8-aarch64-neonfma-cortex-a75.S
@@ -0,0 +1,471 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/4x8-aarch64-neonfma-cortex-a75.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a75(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+# Vector register usage
+# A0 v0 v4
+# A1 v1 v5
+# A2 v2 v6
+# A3 v3 v7
+# B v8 v9 v10 v11
+# B v12 v13 v14 v15
+# B v20 v21 v22 v23
+# B v24 v25 v26 v27
+# C v16 v17
+# C v18 v19
+# C v28 v29
+# C v30 v31
+# Clamp v4 v5
+
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a75
+
+ # Load cn_stride, acc
+ LDP x14, x15, [sp]
+ # Load params pointer
+ LDR x8, [sp, 16]
+
+ # Load clamping_params values (v4 = max, v5 = min)
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ # Save d8-d15 on stack (callee-saved per AAPCS64; clobbered below)
+ STP d8, d9, [sp, -64]!
+ STP d10, d11, [sp, 16]
+ STP d12, d13, [sp, 32]
+ STP d14, d15, [sp, 48]
+
+ # Clamp A and C pointers: rows beyond mr alias the previous row
+ ADD x11, x3, x4 // a1 = a0 + a_stride
+ ADD x9, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x11, x3, x11, LO // a1 = a0
+ CSEL x9, x6, x9, LO // c1 = c0
+
+ ADD x12, x11, x4 // a2 = a1 + a_stride
+ ADD x10, x9, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x12, x11, x12, LS // a2 = a1
+ CSEL x10, x9, x10, LS // c2 = c1
+
+ ADD x4, x12, x4 // a3 = a2 + a_stride
+ ADD x7, x10, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x4, x12, x4, LO // a3 = a2
+ CSEL x7, x10, x7, LO // c3 = c2
+
+0:
+ # Load initial accumulators
+ LDP q16, q17, [x15], 32
+ LDP q18, q19, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 3f
+
+ # Prologue: load 16 bytes of each A row
+ # Read first block of 4 A and B.
+ LDR q0, [x3], 16
+ LDP q20, q21, [x5], 32
+ LDR q1, [x11], 16
+ LDR q2, [x12], 16
+ LDR q3, [x4], 16
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+
+ # Are there at least 32 more bytes of A? If so, run the main loop.
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+1:
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDP q8, q9, [x5], 32
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ LDP q10, q11, [x5], 32
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ LDP q12, q13, [x5], 32
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ LDR q4, [x3], 16
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ LDR q5, [x11], 16
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ LDR q6, [x12], 16
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ LDR q7, [x4], 16
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ PRFM PLDL1KEEP, [x5, 128]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ PRFM PLDL1KEEP, [x5, 192]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+ # Second block of 4. FMA for second 4, loads for 1st block of 4.
+ FMLA v16.4s, v8.4s, v4.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v17.4s, v9.4s, v4.s[0]
+ FMLA v18.4s, v8.4s, v5.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v19.4s, v9.4s, v5.s[0]
+ FMLA v28.4s, v8.4s, v6.s[0]
+ LDP q24, q25, [x5], 32
+ FMLA v29.4s, v9.4s, v6.s[0]
+ FMLA v30.4s, v8.4s, v7.s[0]
+ LDP q26, q27, [x5], 32
+ FMLA v31.4s, v9.4s, v7.s[0]
+ FMLA v16.4s, v10.4s, v4.s[1]
+ LDR q0, [x3], 16
+ FMLA v17.4s, v11.4s, v4.s[1]
+ FMLA v18.4s, v10.4s, v5.s[1]
+ LDR q1, [x11], 16
+ FMLA v19.4s, v11.4s, v5.s[1]
+ FMLA v28.4s, v10.4s, v6.s[1]
+ LDR q2, [x12], 16
+ FMLA v29.4s, v11.4s, v6.s[1]
+ FMLA v30.4s, v10.4s, v7.s[1]
+ LDR q3, [x4], 16
+ FMLA v31.4s, v11.4s, v7.s[1]
+ FMLA v16.4s, v12.4s, v4.s[2]
+ FMLA v17.4s, v13.4s, v4.s[2]
+ FMLA v18.4s, v12.4s, v5.s[2]
+ FMLA v19.4s, v13.4s, v5.s[2]
+ FMLA v28.4s, v12.4s, v6.s[2]
+ FMLA v29.4s, v13.4s, v6.s[2]
+ FMLA v30.4s, v12.4s, v7.s[2]
+ FMLA v31.4s, v13.4s, v7.s[2]
+ FMLA v16.4s, v14.4s, v4.s[3]
+ FMLA v17.4s, v15.4s, v4.s[3]
+ FMLA v18.4s, v14.4s, v5.s[3]
+ FMLA v19.4s, v15.4s, v5.s[3]
+ FMLA v28.4s, v14.4s, v6.s[3]
+ FMLA v29.4s, v15.4s, v6.s[3]
+ SUBS x0, x0, 32
+ FMLA v30.4s, v14.4s, v7.s[3]
+ FMLA v31.4s, v15.4s, v7.s[3]
+ B.HS 1b
+
+2:
+ # Epilogue
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDP q8, q9, [x5], 32
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ LDP q10, q11, [x5], 32
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ LDP q12, q13, [x5], 32
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ LDR q4, [x3], 16
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ LDR q5, [x11], 16
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ LDR q6, [x12], 16
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ LDR q7, [x4], 16
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+ # Second block of 4. FMA for second 4, no loads.
+ FMLA v16.4s, v8.4s, v4.s[0]
+ FMLA v17.4s, v9.4s, v4.s[0]
+ FMLA v18.4s, v8.4s, v5.s[0]
+ FMLA v19.4s, v9.4s, v5.s[0]
+ FMLA v28.4s, v8.4s, v6.s[0]
+ FMLA v29.4s, v9.4s, v6.s[0]
+ FMLA v30.4s, v8.4s, v7.s[0]
+ FMLA v31.4s, v9.4s, v7.s[0]
+
+ FMLA v16.4s, v10.4s, v4.s[1]
+ FMLA v17.4s, v11.4s, v4.s[1]
+ FMLA v18.4s, v10.4s, v5.s[1]
+ FMLA v19.4s, v11.4s, v5.s[1]
+ FMLA v28.4s, v10.4s, v6.s[1]
+ FMLA v29.4s, v11.4s, v6.s[1]
+ FMLA v30.4s, v10.4s, v7.s[1]
+ FMLA v31.4s, v11.4s, v7.s[1]
+
+ FMLA v16.4s, v12.4s, v4.s[2]
+ FMLA v17.4s, v13.4s, v4.s[2]
+ FMLA v18.4s, v12.4s, v5.s[2]
+ FMLA v19.4s, v13.4s, v5.s[2]
+ FMLA v28.4s, v12.4s, v6.s[2]
+ FMLA v29.4s, v13.4s, v6.s[2]
+ FMLA v30.4s, v12.4s, v7.s[2]
+ FMLA v31.4s, v13.4s, v7.s[2]
+
+ FMLA v16.4s, v14.4s, v4.s[3]
+ FMLA v17.4s, v15.4s, v4.s[3]
+ FMLA v18.4s, v14.4s, v5.s[3]
+ FMLA v19.4s, v15.4s, v5.s[3]
+
+ # Re-load clamping_params values: v4/v5 were clobbered as A registers above
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ FMLA v28.4s, v14.4s, v6.s[3]
+ FMLA v29.4s, v15.4s, v6.s[3]
+ FMLA v30.4s, v14.4s, v7.s[3]
+ FMLA v31.4s, v15.4s, v7.s[3]
+
+3:
+ # Remainder- 4 floats of A (16 bytes)
+ TBZ x0, 4, 4f
+
+ LDR q0, [x3], 16
+ LDP q20, q21, [x5], 32
+ LDR q1, [x11], 16
+ LDR q2, [x12], 16
+ LDR q3, [x4], 16
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q24, q25, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ LDP q26, q27, [x5], 32
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+4:
+ # Remainder- 2 floats of A (8 bytes)
+ TBZ x0, 3, 5f
+
+ LDR d0, [x3], 8
+ LDP q20, q21, [x5], 32
+ LDR d1, [x11], 8
+ LDR d2, [x12], 8
+ LDR d3, [x4], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+
+5:
+ # Remainder- 1 float of A (4 bytes)
+ TBZ x0, 2, 6f
+
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3, [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+
+6:
+ # Clamp (v4 = max bound, v5 = min bound)
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMIN v18.4s, v18.4s, v4.4s
+ FMIN v19.4s, v19.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+ FMAX v18.4s, v18.4s, v5.4s
+ FMAX v19.4s, v19.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+ # Store full 4 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ STP q28, q29, [x10]
+ ADD x10, x10, x14
+ STP q18, q19, [x9]
+ ADD x9, x9, x14
+ STP q16, q17, [x6]
+ ADD x6, x6, x14
+
+ SUB x3, x3, x2 // a0 -= kc
+ SUB x11, x11, x2 // a1 -= kc
+ SUB x12, x12, x2 // a2 -= kc
+ SUB x4, x4, x2 // a3 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x10], 16
+ MOV v28.16b, v29.16b
+ STR q18, [x9], 16
+ MOV v18.16b, v19.16b
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+8:
+ TBZ x1, 1, 9f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x10], 8
+ DUP d28, v28.d[1]
+ STR d18, [x9], 8
+ DUP d18, v18.d[1]
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ STR s30, [x7]
+ STR s28, [x10]
+ STR s18, [x9]
+ STR s16, [x6]
+10:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/4x8-aarch64-neonfma-ld128.S b/src/f32-gemminc/4x8-aarch64-neonfma-ld128.S
new file mode 100644
index 0000000..6c72770
--- /dev/null
+++ b/src/f32-gemminc/4x8-aarch64-neonfma-ld128.S
@@ -0,0 +1,249 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/4x8-aarch64-neonfma-ld128.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld128(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld128
+
+ # Load cn_stride, acc
+ LDP x14, x15, [sp]
+ # Load params pointer
+ LDR x8, [sp, 16]
+
+ # Load clamping_params values (v4 = max, v5 = min)
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ # Clamp A and C pointers: rows beyond mr alias the previous row
+ ADD x11, x3, x4 // a1 = a0 + a_stride
+ ADD x9, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x11, x3, x11, LO // a1 = a0
+ CSEL x9, x6, x9, LO // c1 = c0
+
+ ADD x12, x11, x4 // a2 = a1 + a_stride
+ ADD x10, x9, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x12, x11, x12, LS // a2 = a1
+ CSEL x10, x9, x10, LS // c2 = c1
+
+ ADD x4, x12, x4 // a3 = a2 + a_stride
+ ADD x7, x10, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x4, x12, x4, LO // a3 = a2
+ CSEL x7, x10, x7, LO // c3 = c2
+
+0:
+ # Load initial accumulators
+ LDP q16, q17, [x15], 32
+ LDP q18, q19, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+
+ # Is there at least 4 floats (16 bytes)?
+ SUBS x0, x2, 16 // k = kc - 16
+ B.LO 2f
+
+ # Main loop - 4 floats of A (16 bytes)
+1:
+ LDR q0, [x3], 16
+ LDP q20, q21, [x5], 32
+ LDR q1, [x11], 16
+ LDR q2, [x12], 16
+ LDR q3, [x4], 16
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ LDP q24, q25, [x5], 32
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ LDP q26, q27, [x5], 32
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ SUBS x0, x0, 16
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+ B.HS 1b
+
+ # Remainder- 2 floats of A (8 bytes)
+2:
+ TBZ x0, 3, 3f
+
+ LDR d0, [x3], 8
+ LDP q20, q21, [x5], 32
+ LDR d1, [x11], 8
+ LDR d2, [x12], 8
+ LDR d3, [x4], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+
+ # Remainder- 1 float of A (4 bytes)
+3:
+ TBZ x0, 2, 6f
+
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3, [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+
+6:
+ # Clamp (v4 = max bound, v5 = min bound)
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMIN v18.4s, v18.4s, v4.4s
+ FMIN v19.4s, v19.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+ FMAX v18.4s, v18.4s, v5.4s
+ FMAX v19.4s, v19.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+
+ # Store full 4 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ STP q28, q29, [x10]
+ ADD x10, x10, x14
+ STP q18, q19, [x9]
+ ADD x9, x9, x14
+ STP q16, q17, [x6]
+ ADD x6, x6, x14
+
+ SUB x3, x3, x2 // a0 -= kc
+ SUB x11, x11, x2 // a1 -= kc
+ SUB x12, x12, x2 // a2 -= kc
+ SUB x4, x4, x2 // a3 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ RET
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x10], 16
+ MOV v28.16b, v29.16b
+ STR q18, [x9], 16
+ MOV v18.16b, v19.16b
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+8:
+ TBZ x1, 1, 9f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x10], 8
+ DUP d28, v28.d[1]
+ STR d18, [x9], 8
+ DUP d18, v18.d[1]
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ STR s30, [x7]
+ STR s28, [x10]
+ STR s18, [x9]
+ STR s16, [x6]
+10:
+ RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld128
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/4x8-aarch64-neonfma-ld64.S b/src/f32-gemminc/4x8-aarch64-neonfma-ld64.S
new file mode 100644
index 0000000..0a4f6d9
--- /dev/null
+++ b/src/f32-gemminc/4x8-aarch64-neonfma-ld64.S
@@ -0,0 +1,203 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/4x8-aarch64-neonfma-ld64.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld64(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x11 a1
+# x12 a2
+# x4 a3 / a_stride
+
+# C pointers
+# x6 c0
+# x9 c1
+# x10 c2
+# x7 c3 / cm_stride
+
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld64
+
+ # Load cn_stride, acc
+ LDP x14, x15, [sp]
+ # Load params pointer
+ LDR x8, [sp, 16]
+
+ # Load clamping_params values (v4 = max, v5 = min)
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ # Clamp A and C pointers: rows beyond mr alias the previous row
+ ADD x11, x3, x4 // a1 = a0 + a_stride
+ ADD x9, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x11, x3, x11, LO // a1 = a0
+ CSEL x9, x6, x9, LO // c1 = c0
+
+ ADD x12, x11, x4 // a2 = a1 + a_stride
+ ADD x10, x9, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x12, x11, x12, LS // a2 = a1
+ CSEL x10, x9, x10, LS // c2 = c1
+
+ ADD x4, x12, x4 // a3 = a2 + a_stride
+ ADD x7, x10, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x4, x12, x4, LO // a3 = a2
+ CSEL x7, x10, x7, LO // c3 = c2
+
+0:
+ # Load initial accumulators
+ LDP q16, q17, [x15], 32
+ LDP q18, q19, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+
+ # Is there at least 2 floats (8 bytes)?
+ SUBS x0, x2, 8 // k = kc - 8
+ B.LO 2f
+
+ # Main loop - 2 floats of A (8 bytes)
+
+1:
+ LDR d0, [x3], 8
+ LDP q20, q21, [x5], 32
+ LDR d1, [x11], 8
+ LDR d2, [x12], 8
+ LDR d3, [x4], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ SUBS x0, x0, 8
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+ B.HS 1b
+2:
+ # Remainder- 1 float of A (4 bytes)
+ TBZ x0, 2, 6f
+
+ LDR s0, [x3], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x11], 4
+ LDR s2, [x12], 4
+ LDR s3, [x4], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+
+6:
+ # Clamp (v4 = max bound, v5 = min bound)
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMIN v18.4s, v18.4s, v4.4s
+ FMIN v19.4s, v19.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+ FMAX v18.4s, v18.4s, v5.4s
+ FMAX v19.4s, v19.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+ # Store full 4 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ STP q28, q29, [x10]
+ ADD x10, x10, x14
+ STP q18, q19, [x9]
+ ADD x9, x9, x14
+ STP q16, q17, [x6]
+ ADD x6, x6, x14
+
+ SUB x3, x3, x2 // a0 -= kc
+ SUB x11, x11, x2 // a1 -= kc
+ SUB x12, x12, x2 // a2 -= kc
+ SUB x4, x4, x2 // a3 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ RET
+
+ # Store odd width
+7:
+ TBZ x1, 2, 8f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x10], 16
+ MOV v28.16b, v29.16b
+ STR q18, [x9], 16
+ MOV v18.16b, v19.16b
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+8:
+ TBZ x1, 1, 9f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x10], 8
+ DUP d28, v28.d[1]
+ STR d18, [x9], 8
+ DUP d18, v18.d[1]
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ STR s30, [x7]
+ STR s28, [x10]
+ STR s18, [x9]
+ STR s16, [x6]
+10:
+ RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld64
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/4x8-neon-ld128.c b/src/f32-gemminc/4x8-neon-ld128.c
new file mode 100644
index 0000000..07d7562
--- /dev/null
+++ b/src/f32-gemminc/4x8-neon-ld128.c
@@ -0,0 +1,227 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld128.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemminc_ukernel_4x8__neon_ld128(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float* restrict a,
+ size_t a_stride,
+ const float* restrict w,
+ float* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const float*restrict acc,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0); // kc is expressed in bytes
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ assert(acc != NULL);
+
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) { // rows beyond mr alias the previous row
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ do { // one iteration per group of up to 8 output columns
+ float32x4_t vacc0x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc1x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc1x4567 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc2x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc2x4567 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc3x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc3x4567 = vld1q_f32(acc); acc += 4;
+
+ size_t k = kc; // remaining bytes of K
+ for (; k >= 4 * sizeof(float); k -= 4 * sizeof(float)) {
+ const float32x4_t va0 = vld1q_f32(a0); a0 += 4;
+ const float32x4_t va1 = vld1q_f32(a1); a1 += 4;
+ const float32x4_t va2 = vld1q_f32(a2); a2 += 4;
+ const float32x4_t va3 = vld1q_f32(a3); a3 += 4;
+
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, vget_low_f32(va0), 0);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, vget_low_f32(va1), 0);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, vget_low_f32(va2), 0);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, vget_low_f32(va3), 0);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, vget_low_f32(va0), 0);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, vget_low_f32(va1), 0);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, vget_low_f32(va2), 0);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, vget_low_f32(va3), 0);
+
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, vget_low_f32(va0), 1);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, vget_low_f32(va1), 1);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, vget_low_f32(va2), 1);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, vget_low_f32(va3), 1);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, vget_low_f32(va0), 1);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, vget_low_f32(va1), 1);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, vget_low_f32(va2), 1);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, vget_low_f32(va3), 1);
+
+ const float32x4_t vb0123c2 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c2 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c2, vget_high_f32(va0), 0);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c2, vget_high_f32(va1), 0);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c2, vget_high_f32(va2), 0);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c2, vget_high_f32(va3), 0);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c2, vget_high_f32(va0), 0);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c2, vget_high_f32(va1), 0);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c2, vget_high_f32(va2), 0);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c2, vget_high_f32(va3), 0);
+
+ const float32x4_t vb0123c3 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c3 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c3, vget_high_f32(va0), 1);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c3, vget_high_f32(va1), 1);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c3, vget_high_f32(va2), 1);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c3, vget_high_f32(va3), 1);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c3, vget_high_f32(va0), 1);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c3, vget_high_f32(va1), 1);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c3, vget_high_f32(va2), 1);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c3, vget_high_f32(va3), 1);
+ }
+ if XNN_UNLIKELY(k != 0) { // 1-3 leftover floats of K
+ do {
+ const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+ const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+ const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+ const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+ vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+
+ k -= sizeof(float);
+ } while (k != 0);
+ }
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); // fixed mis-encoded '&params'
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); // fixed mis-encoded '&params'
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) { // full 8-wide store
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a3 = (const float*) ((uintptr_t) a3 - kc); // rewind A for next column group
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+
+ } else { // partial store of 1-7 columns
+ if (nc & 4) {
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x8-neon-ld64.c b/src/f32-gemminc/4x8-neon-ld64.c
new file mode 100644
index 0000000..f7677c3
--- /dev/null
+++ b/src/f32-gemminc/4x8-neon-ld64.c
@@ -0,0 +1,197 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemminc_ukernel_4x8__neon_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    float32x4_t vacc0x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x4567 = vld1q_f32(acc); acc += 4;
+
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+      const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+      const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+      const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+      vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+      vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+      vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+      vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+      vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+      vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+      vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+      vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+      vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+      vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+      vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+      vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+    }
+    if XNN_UNLIKELY(k != 0) {  // remainder iteration for odd K
+      const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+      const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+      const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+      const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+      vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+      vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+      vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+      vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+      vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+      vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+      vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+    }
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);  // upper output clamp
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min)  // lower output clamp
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {  // full 8-column tile
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+
+    } else {
+      if (nc & 4) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x8-neonfma-ld128.c b/src/f32-gemminc/4x8-neonfma-ld128.c
new file mode 100644
index 0000000..49c074c
--- /dev/null
+++ b/src/f32-gemminc/4x8-neonfma-ld128.c
@@ -0,0 +1,287 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld128.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemminc_ukernel_4x8__neonfma_ld128(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    float32x4_t vacc0x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x4567 = vld1q_f32(acc); acc += 4;
+
+    size_t k = kc;
+    for (; k >= 4 * sizeof(float); k -= 4 * sizeof(float)) {
+      const float32x4_t va0 = vld1q_f32(a0); a0 += 4;
+      const float32x4_t va1 = vld1q_f32(a1); a1 += 4;
+      const float32x4_t va2 = vld1q_f32(a2); a2 += 4;
+      const float32x4_t va3 = vld1q_f32(a3); a3 += 4;
+
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123c0, va0, 0);
+        vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123c0, va1, 0);
+        vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123c0, va2, 0);
+        vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123c0, va3, 0);
+        vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567c0, va0, 0);
+        vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567c0, va1, 0);
+        vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567c0, va2, 0);
+        vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567c0, va3, 0);
+      #else
+        const float32x4_t va0c0 = vdupq_lane_f32(vget_low_f32(va0), 0);
+        const float32x4_t va1c0 = vdupq_lane_f32(vget_low_f32(va1), 0);
+        const float32x4_t va2c0 = vdupq_lane_f32(vget_low_f32(va2), 0);
+        const float32x4_t va3c0 = vdupq_lane_f32(vget_low_f32(va3), 0);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+      #endif
+
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123c1, va0, 1);
+        vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123c1, va1, 1);
+        vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123c1, va2, 1);
+        vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123c1, va3, 1);
+        vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567c1, va0, 1);
+        vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567c1, va1, 1);
+        vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567c1, va2, 1);
+        vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567c1, va3, 1);
+      #else
+        const float32x4_t va0c1 = vdupq_lane_f32(vget_low_f32(va0), 1);
+        const float32x4_t va1c1 = vdupq_lane_f32(vget_low_f32(va1), 1);
+        const float32x4_t va2c1 = vdupq_lane_f32(vget_low_f32(va2), 1);
+        const float32x4_t va3c1 = vdupq_lane_f32(vget_low_f32(va3), 1);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+      #endif
+
+      const float32x4_t vb0123c2 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c2 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123c2, va0, 2);
+        vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123c2, va1, 2);
+        vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123c2, va2, 2);
+        vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123c2, va3, 2);
+        vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567c2, va0, 2);
+        vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567c2, va1, 2);
+        vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567c2, va2, 2);
+        vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567c2, va3, 2);
+      #else
+        const float32x4_t va0c2 = vdupq_lane_f32(vget_high_f32(va0), 0);
+        const float32x4_t va1c2 = vdupq_lane_f32(vget_high_f32(va1), 0);
+        const float32x4_t va2c2 = vdupq_lane_f32(vget_high_f32(va2), 0);
+        const float32x4_t va3c2 = vdupq_lane_f32(vget_high_f32(va3), 0);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c2, vb0123c2);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c2, vb0123c2);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c2, vb0123c2);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c2, vb0123c2);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c2, vb4567c2);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c2, vb4567c2);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c2, vb4567c2);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c2, vb4567c2);
+      #endif
+
+      const float32x4_t vb0123c3 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c3 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123c3, va0, 3);
+        vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123c3, va1, 3);
+        vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123c3, va2, 3);
+        vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123c3, va3, 3);
+        vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567c3, va0, 3);
+        vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567c3, va1, 3);
+        vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567c3, va2, 3);
+        vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567c3, va3, 3);
+      #else
+        const float32x4_t va0c3 = vdupq_lane_f32(vget_high_f32(va0), 1);
+        const float32x4_t va1c3 = vdupq_lane_f32(vget_high_f32(va1), 1);
+        const float32x4_t va2c3 = vdupq_lane_f32(vget_high_f32(va2), 1);
+        const float32x4_t va3c3 = vdupq_lane_f32(vget_high_f32(va3), 1);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c3, vb0123c3);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c3, vb0123c3);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c3, vb0123c3);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c3, vb0123c3);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c3, vb4567c3);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c3, vb4567c3);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c3, vb4567c3);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c3, vb4567c3);
+      #endif
+    }
+    if XNN_UNLIKELY(k != 0) {  // remainder loop for K not divisible by 4
+      do {
+        const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+        const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+        const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+        const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+        const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+        const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);  // upper output clamp
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);  // lower output clamp
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {  // full 8-column tile
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+
+    } else {
+      if (nc & 4) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x8-neonfma-ld64.c b/src/f32-gemminc/4x8-neonfma-ld64.c
new file mode 100644
index 0000000..f0eefcf
--- /dev/null
+++ b/src/f32-gemminc/4x8-neonfma-ld64.c
@@ -0,0 +1,227 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemminc_ukernel_4x8__neonfma_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    float32x4_t vacc0x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x4567 = vld1q_f32(acc); acc += 4;
+
+    size_t k = kc;
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+      const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+      const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+      const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+        vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+        vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+        vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+        vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+        vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+        vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+      #else
+        const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+        const float32x4_t va1c0 = vdupq_lane_f32(va1, 0);
+        const float32x4_t va2c0 = vdupq_lane_f32(va2, 0);
+        const float32x4_t va3c0 = vdupq_lane_f32(va3, 0);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+      #endif
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+        vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+        vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+        vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+        vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+        vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+        vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+      #else
+        const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+        const float32x4_t va1c1 = vdupq_lane_f32(va1, 1);
+        const float32x4_t va2c1 = vdupq_lane_f32(va2, 1);
+        const float32x4_t va3c1 = vdupq_lane_f32(va3, 1);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+      #endif
+    }
+    if XNN_UNLIKELY(k != 0) {  // remainder iteration for odd K
+      const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+      const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+      const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+      const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+      vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+      vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+      vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+      vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+      vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+      vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+      vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+    }
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);  // upper output clamp
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);  // lower output clamp
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {  // full 8-column tile
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+
+    } else {
+      if (nc & 4) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x8-psimd-loadsplat.c b/src/f32-gemminc/4x8-psimd-loadsplat.c
new file mode 100644
index 0000000..9e678ad
--- /dev/null
+++ b/src/f32-gemminc/4x8-psimd-loadsplat.c
@@ -0,0 +1,182 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-loadsplat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemminc_ukernel_4x8__psimd_loadsplat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    psimd_f32 vacc0x0123 = psimd_load_f32(acc + 0);
+    psimd_f32 vacc0x4567 = psimd_load_f32(acc + 4);
+    psimd_f32 vacc1x0123 = psimd_load_f32(acc + 8);
+    psimd_f32 vacc1x4567 = psimd_load_f32(acc + 12);
+    psimd_f32 vacc2x0123 = psimd_load_f32(acc + 16);
+    psimd_f32 vacc2x4567 = psimd_load_f32(acc + 20);
+    psimd_f32 vacc3x0123 = psimd_load_f32(acc + 24);
+    psimd_f32 vacc3x4567 = psimd_load_f32(acc + 28);
+    acc += 32;
+
+    size_t k = kc;
+    do {
+      const psimd_f32 va0 = psimd_load_splat_f32(a0);
+      a0 += 1;
+      const psimd_f32 va1 = psimd_load_splat_f32(a1);
+      a1 += 1;
+      const psimd_f32 va2 = psimd_load_splat_f32(a2);
+      a2 += 1;
+      const psimd_f32 va3 = psimd_load_splat_f32(a3);
+      a3 += 1;
+
+      const psimd_f32 vb0123 = psimd_load_f32(w);
+      const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+      w += 8;
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);  // upper output clamp
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+    vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+    vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+    vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+    vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+    vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);  // lower output clamp
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+    vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+    vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+    vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+    vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+    vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {  // full 8-column tile
+      psimd_store_f32(c3, vacc3x0123);
+      psimd_store_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      psimd_store_f32(c2, vacc2x0123);
+      psimd_store_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      psimd_store_f32(c1, vacc1x0123);
+      psimd_store_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        psimd_store_f32(c3, vacc3x0123);
+        psimd_store_f32(c2, vacc2x0123);
+        psimd_store_f32(c1, vacc1x0123);
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c3, vacc3x0123);
+        psimd_store2_f32(c2, vacc2x0123);
+        psimd_store2_f32(c1, vacc1x0123);
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+        vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+        vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c3, vacc3x0123);
+        psimd_store1_f32(c2, vacc2x0123);
+        psimd_store1_f32(c1, vacc1x0123);
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x8-psimd-splat.c b/src/f32-gemminc/4x8-psimd-splat.c
new file mode 100644
index 0000000..48127c9
--- /dev/null
+++ b/src/f32-gemminc/4x8-psimd-splat.c
@@ -0,0 +1,262 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-splat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+// Computes a 4-row x 8-column tile of C = clamp(acc + A * B) using PSIMD
+// 4-lane vectors. GEMMINC variant: accumulators are seeded from the extra
+// `acc` buffer (partial sums), so a matrix multiplication can be split along
+// the k dimension across multiple kernel invocations. The weights `w` are
+// pre-packed so that each k step reads 8 contiguous output-channel values.
+void xnn_f32_gemminc_ukernel_4x8__psimd_splat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  // Per-row A/C pointers. Rows beyond mr alias the previous row, so all four
+  // rows can be computed unconditionally without out-of-bounds accesses; the
+  // aliased rows simply recompute/overwrite the same memory.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  // Outer loop: one iteration per group of (up to) 8 output columns.
+  do {
+    // Seed the 4x8 accumulator tile from the partial-sum buffer.
+    psimd_f32 vacc0x0123 = psimd_load_f32(acc + 0);
+    psimd_f32 vacc0x4567 = psimd_load_f32(acc + 4);
+    psimd_f32 vacc1x0123 = psimd_load_f32(acc + 8);
+    psimd_f32 vacc1x4567 = psimd_load_f32(acc + 12);
+    psimd_f32 vacc2x0123 = psimd_load_f32(acc + 16);
+    psimd_f32 vacc2x4567 = psimd_load_f32(acc + 20);
+    psimd_f32 vacc3x0123 = psimd_load_f32(acc + 24);
+    psimd_f32 vacc3x4567 = psimd_load_f32(acc + 28);
+    acc += 32;
+
+    size_t k = kc;
+    // Main loop: consume 4 k-elements per iteration. Each lane of the loaded
+    // A vector is broadcast ("splatted") and multiplied against the matching
+    // pre-packed 8-wide weight row via fused multiply-add.
+    while (k >= 4 * sizeof(float)) {
+      const psimd_f32 va0 = psimd_load_f32(a0);
+      a0 += 4;
+      const psimd_f32 va1 = psimd_load_f32(a1);
+      a1 += 4;
+      const psimd_f32 va2 = psimd_load_f32(a2);
+      a2 += 4;
+      const psimd_f32 va3 = psimd_load_f32(a3);
+      a3 += 4;
+
+      // k-step 0 of 4: broadcast lane 0 of each A vector.
+      const psimd_f32 va0c0 = psimd_splat0_f32(va0);
+      const psimd_f32 va1c0 = psimd_splat0_f32(va1);
+      const psimd_f32 va2c0 = psimd_splat0_f32(va2);
+      const psimd_f32 va3c0 = psimd_splat0_f32(va3);
+
+      const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+      const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c0, vb0123c0);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c0, vb0123c0);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c0, vb0123c0);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c0, vb0123c0);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c0, vb4567c0);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c0, vb4567c0);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c0, vb4567c0);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c0, vb4567c0);
+      // k-step 1 of 4: broadcast lane 1.
+      const psimd_f32 va0c1 = psimd_splat1_f32(va0);
+      const psimd_f32 va1c1 = psimd_splat1_f32(va1);
+      const psimd_f32 va2c1 = psimd_splat1_f32(va2);
+      const psimd_f32 va3c1 = psimd_splat1_f32(va3);
+
+      const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+      const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c1, vb0123c1);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c1, vb0123c1);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c1, vb0123c1);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c1, vb0123c1);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c1, vb4567c1);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c1, vb4567c1);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c1, vb4567c1);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c1, vb4567c1);
+      // k-step 2 of 4: broadcast lane 2.
+      const psimd_f32 va0c2 = psimd_splat2_f32(va0);
+      const psimd_f32 va1c2 = psimd_splat2_f32(va1);
+      const psimd_f32 va2c2 = psimd_splat2_f32(va2);
+      const psimd_f32 va3c2 = psimd_splat2_f32(va3);
+
+      const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+      const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c2, vb0123c2);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c2, vb0123c2);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c2, vb0123c2);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c2, vb0123c2);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c2, vb4567c2);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c2, vb4567c2);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c2, vb4567c2);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c2, vb4567c2);
+      // k-step 3 of 4: broadcast lane 3.
+      const psimd_f32 va0c3 = psimd_splat3_f32(va0);
+      const psimd_f32 va1c3 = psimd_splat3_f32(va1);
+      const psimd_f32 va2c3 = psimd_splat3_f32(va2);
+      const psimd_f32 va3c3 = psimd_splat3_f32(va3);
+
+      const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+      const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c3, vb0123c3);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c3, vb0123c3);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c3, vb0123c3);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c3, vb0123c3);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c3, vb4567c3);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c3, vb4567c3);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c3, vb4567c3);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c3, vb4567c3);
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    // Remainder loop: process the kc % 4 leftover k-elements one at a time,
+    // broadcasting a single A element per row.
+    if XNN_UNLIKELY(k != 0) {
+      do {
+        const psimd_f32 va0 = psimd_load_splat_f32(a0);
+        a0 += 1;
+        const psimd_f32 va1 = psimd_load_splat_f32(a1);
+        a1 += 1;
+        const psimd_f32 va2 = psimd_load_splat_f32(a2);
+        a2 += 1;
+        const psimd_f32 va3 = psimd_load_splat_f32(a3);
+        a3 += 1;
+
+        const psimd_f32 vb0123 = psimd_load_f32(w);
+        const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+        w += 8;
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+        vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+        vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+        vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+        vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+        vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+        vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp the accumulators to the [min, max] output range from params.
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+    vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+    vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+    vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+    vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+    vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+    vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+    vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+    vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+    vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+    vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+
+    // Full-width path: store all 8 columns per row, advance C pointers to the
+    // next column group, and rewind the A pointers to the start of each row.
+    if XNN_LIKELY(nc >= 8) {
+      psimd_store_f32(c3, vacc3x0123);
+      psimd_store_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      psimd_store_f32(c2, vacc2x0123);
+      psimd_store_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      psimd_store_f32(c1, vacc1x0123);
+      psimd_store_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      // Tail path: store the final nc (< 8) columns in 4/2/1-wide pieces,
+      // shifting the remaining lanes down after each partial store.
+      if (nc & 4) {
+        psimd_store_f32(c3, vacc3x0123);
+        psimd_store_f32(c2, vacc2x0123);
+        psimd_store_f32(c1, vacc1x0123);
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c3, vacc3x0123);
+        psimd_store2_f32(c2, vacc2x0123);
+        psimd_store2_f32(c1, vacc1x0123);
+        psimd_store2_f32(c0, vacc0x0123);
+
+        // Move the upper two lanes into the lower half for the final store.
+        vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+        vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+        vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c3, vacc3x0123);
+        psimd_store1_f32(c2, vacc2x0123);
+        psimd_store1_f32(c1, vacc1x0123);
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x8-sse-dup.c b/src/f32-gemminc/4x8-sse-dup.c
new file mode 100644
index 0000000..9f7b942
--- /dev/null
+++ b/src/f32-gemminc/4x8-sse-dup.c
@@ -0,0 +1,266 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/sse-dup.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// Computes a 4-row x 8-column tile of C = clamp(acc + A * B) with SSE
+// intrinsics, broadcasting A elements via _mm_shuffle_ps ("dup" variant).
+// GEMMINC variant: accumulators are seeded from the extra `acc` buffer
+// (partial sums), so the multiplication can be split along the k dimension.
+void xnn_f32_gemminc_ukernel_4x8__sse_dup(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  // Per-row A/C pointers. Rows beyond mr alias the previous row so all four
+  // rows can be computed unconditionally without out-of-bounds accesses.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  // Outer loop: one iteration per group of (up to) 8 output columns.
+  do {
+    // Seed the 4x8 accumulator tile from the partial-sum buffer.
+    __m128 vacc0x0123 = _mm_load_ps(acc + 0);
+    __m128 vacc0x4567 = _mm_load_ps(acc + 4);
+    __m128 vacc1x0123 = _mm_load_ps(acc + 8);
+    __m128 vacc1x4567 = _mm_load_ps(acc + 12);
+    __m128 vacc2x0123 = _mm_load_ps(acc + 16);
+    __m128 vacc2x4567 = _mm_load_ps(acc + 20);
+    __m128 vacc3x0123 = _mm_load_ps(acc + 24);
+    __m128 vacc3x4567 = _mm_load_ps(acc + 28);
+    acc += 32;
+
+    size_t k = kc;
+    // Main loop: consume 4 k-elements per iteration. Each element of the
+    // loaded A vector is broadcast across all lanes with a self-shuffle and
+    // multiply-accumulated against the matching pre-packed 8-wide weight row.
+    while (k >= 4 * sizeof(float)) {
+      const __m128 va0 = _mm_loadu_ps(a0);
+      a0 += 4;
+      const __m128 va1 = _mm_loadu_ps(a1);
+      a1 += 4;
+      const __m128 va2 = _mm_loadu_ps(a2);
+      a2 += 4;
+      const __m128 va3 = _mm_loadu_ps(a3);
+      a3 += 4;
+
+
+      // k-step 0 of 4: broadcast element 0 of each A vector.
+      const __m128 va0c0000 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 0, 0, 0));
+      const __m128 va1c0000 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(0, 0, 0, 0));
+      const __m128 va2c0000 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(0, 0, 0, 0));
+      const __m128 va3c0000 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(0, 0, 0, 0));
+
+      const __m128 vb0123c0 = _mm_load_ps(w + 0);
+      const __m128 vb4567c0 = _mm_load_ps(w + 4);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c0000, vb0123c0));
+      vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1c0000, vb0123c0));
+      vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2c0000, vb0123c0));
+      vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3c0000, vb0123c0));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c0000, vb4567c0));
+      vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1c0000, vb4567c0));
+      vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2c0000, vb4567c0));
+      vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3c0000, vb4567c0));
+
+      // k-step 1 of 4: broadcast element 1.
+      const __m128 va0c1111 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(1, 1, 1, 1));
+      const __m128 va1c1111 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(1, 1, 1, 1));
+      const __m128 va2c1111 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(1, 1, 1, 1));
+      const __m128 va3c1111 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(1, 1, 1, 1));
+
+      const __m128 vb0123c1 = _mm_load_ps(w + 8);
+      const __m128 vb4567c1 = _mm_load_ps(w + 12);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c1111, vb0123c1));
+      vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1c1111, vb0123c1));
+      vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2c1111, vb0123c1));
+      vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3c1111, vb0123c1));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c1111, vb4567c1));
+      vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1c1111, vb4567c1));
+      vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2c1111, vb4567c1));
+      vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3c1111, vb4567c1));
+
+      // k-step 2 of 4: broadcast element 2.
+      const __m128 va0c2222 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(2, 2, 2, 2));
+      const __m128 va1c2222 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(2, 2, 2, 2));
+      const __m128 va2c2222 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(2, 2, 2, 2));
+      const __m128 va3c2222 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(2, 2, 2, 2));
+
+      const __m128 vb0123c2 = _mm_load_ps(w + 16);
+      const __m128 vb4567c2 = _mm_load_ps(w + 20);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c2222, vb0123c2));
+      vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1c2222, vb0123c2));
+      vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2c2222, vb0123c2));
+      vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3c2222, vb0123c2));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c2222, vb4567c2));
+      vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1c2222, vb4567c2));
+      vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2c2222, vb4567c2));
+      vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3c2222, vb4567c2));
+
+      // k-step 3 of 4: broadcast element 3.
+      const __m128 va0c3333 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(3, 3, 3, 3));
+      const __m128 va1c3333 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(3, 3, 3, 3));
+      const __m128 va2c3333 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(3, 3, 3, 3));
+      const __m128 va3c3333 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(3, 3, 3, 3));
+
+      const __m128 vb0123c3 = _mm_load_ps(w + 24);
+      const __m128 vb4567c3 = _mm_load_ps(w + 28);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c3333, vb0123c3));
+      vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1c3333, vb0123c3));
+      vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2c3333, vb0123c3));
+      vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3c3333, vb0123c3));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c3333, vb4567c3));
+      vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1c3333, vb4567c3));
+      vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2c3333, vb4567c3));
+      vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3c3333, vb4567c3));
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    // Remainder loop: process the kc % 4 leftover k-elements one at a time,
+    // broadcasting a single A element per row.
+    if XNN_UNLIKELY(k != 0) {
+      do {
+        const __m128 va0 = _mm_load1_ps(a0);
+        a0 += 1;
+        const __m128 va1 = _mm_load1_ps(a1);
+        a1 += 1;
+        const __m128 va2 = _mm_load1_ps(a2);
+        a2 += 1;
+        const __m128 va3 = _mm_load1_ps(a3);
+        a3 += 1;
+
+        const __m128 vb0123 = _mm_load_ps(w);
+        const __m128 vb4567 = _mm_load_ps(w + 4);
+        w += 8;
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+        vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123));
+        vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123));
+        vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+        vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567));
+        vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567));
+        vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567));
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp the accumulators to the [min, max] output range from params.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc1x0123 = _mm_min_ps(vacc1x0123, vmax);
+    vacc2x0123 = _mm_min_ps(vacc2x0123, vmax);
+    vacc3x0123 = _mm_min_ps(vacc3x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+    vacc1x4567 = _mm_min_ps(vacc1x4567, vmax);
+    vacc2x4567 = _mm_min_ps(vacc2x4567, vmax);
+    vacc3x4567 = _mm_min_ps(vacc3x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc1x0123 = _mm_max_ps(vacc1x0123, vmin);
+    vacc2x0123 = _mm_max_ps(vacc2x0123, vmin);
+    vacc3x0123 = _mm_max_ps(vacc3x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+    vacc1x4567 = _mm_max_ps(vacc1x4567, vmin);
+    vacc2x4567 = _mm_max_ps(vacc2x4567, vmin);
+    vacc3x4567 = _mm_max_ps(vacc3x4567, vmin);
+
+    // Full-width path: store all 8 columns per row, advance C pointers to the
+    // next column group, and rewind the A pointers to the start of each row.
+    if XNN_LIKELY(nc >= 8) {
+      _mm_storeu_ps(c3, vacc3x0123);
+      _mm_storeu_ps(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      _mm_storeu_ps(c2, vacc2x0123);
+      _mm_storeu_ps(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      _mm_storeu_ps(c1, vacc1x0123);
+      _mm_storeu_ps(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      // Tail path: store the final nc (< 8) columns in 4/2/1-wide pieces,
+      // shifting the remaining lanes down after each partial store.
+      if (nc & 4) {
+        _mm_storeu_ps(c3, vacc3x0123);
+        _mm_storeu_ps(c2, vacc2x0123);
+        _mm_storeu_ps(c1, vacc1x0123);
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c3, vacc3x0123);
+        _mm_storel_pi((__m64*) c2, vacc2x0123);
+        _mm_storel_pi((__m64*) c1, vacc1x0123);
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        // Move the upper two lanes into the lower half for the final store.
+        vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
+        vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
+        vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c3, vacc3x0123);
+        _mm_store_ss(c2, vacc2x0123);
+        _mm_store_ss(c1, vacc1x0123);
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x8-sse-load1.c b/src/f32-gemminc/4x8-sse-load1.c
new file mode 100644
index 0000000..8f1e8b7
--- /dev/null
+++ b/src/f32-gemminc/4x8-sse-load1.c
@@ -0,0 +1,182 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/sse-load1.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// Computes a 4-row x 8-column tile of C = clamp(acc + A * B) with SSE
+// intrinsics, loading and broadcasting one A element per k step ("load1"
+// variant — no k unrolling). GEMMINC variant: accumulators are seeded from
+// the extra `acc` buffer (partial sums), so the multiplication can be split
+// along the k dimension across multiple invocations.
+void xnn_f32_gemminc_ukernel_4x8__sse_load1(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  // Per-row A/C pointers. Rows beyond mr alias the previous row so all four
+  // rows can be computed unconditionally without out-of-bounds accesses.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  // Outer loop: one iteration per group of (up to) 8 output columns.
+  do {
+    // Seed the 4x8 accumulator tile from the partial-sum buffer.
+    __m128 vacc0x0123 = _mm_load_ps(acc + 0);
+    __m128 vacc0x4567 = _mm_load_ps(acc + 4);
+    __m128 vacc1x0123 = _mm_load_ps(acc + 8);
+    __m128 vacc1x4567 = _mm_load_ps(acc + 12);
+    __m128 vacc2x0123 = _mm_load_ps(acc + 16);
+    __m128 vacc2x4567 = _mm_load_ps(acc + 20);
+    __m128 vacc3x0123 = _mm_load_ps(acc + 24);
+    __m128 vacc3x4567 = _mm_load_ps(acc + 28);
+    acc += 32;
+
+    size_t k = kc;
+    // Inner loop: one k-element per iteration. Each row's A element is
+    // broadcast to all lanes and multiply-accumulated against the matching
+    // pre-packed 8-wide weight row.
+    do {
+      const __m128 va0 = _mm_load1_ps(a0);
+      a0 += 1;
+      const __m128 va1 = _mm_load1_ps(a1);
+      a1 += 1;
+      const __m128 va2 = _mm_load1_ps(a2);
+      a2 += 1;
+      const __m128 va3 = _mm_load1_ps(a3);
+      a3 += 1;
+
+      const __m128 vb0123 = _mm_load_ps(w);
+      const __m128 vb4567 = _mm_load_ps(w + 4);
+      w += 8;
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+      vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123));
+      vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123));
+      vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+      vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567));
+      vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567));
+      vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567));
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp the accumulators to the [min, max] output range from params.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc1x0123 = _mm_min_ps(vacc1x0123, vmax);
+    vacc2x0123 = _mm_min_ps(vacc2x0123, vmax);
+    vacc3x0123 = _mm_min_ps(vacc3x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+    vacc1x4567 = _mm_min_ps(vacc1x4567, vmax);
+    vacc2x4567 = _mm_min_ps(vacc2x4567, vmax);
+    vacc3x4567 = _mm_min_ps(vacc3x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc1x0123 = _mm_max_ps(vacc1x0123, vmin);
+    vacc2x0123 = _mm_max_ps(vacc2x0123, vmin);
+    vacc3x0123 = _mm_max_ps(vacc3x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+    vacc1x4567 = _mm_max_ps(vacc1x4567, vmin);
+    vacc2x4567 = _mm_max_ps(vacc2x4567, vmin);
+    vacc3x4567 = _mm_max_ps(vacc3x4567, vmin);
+
+    // Full-width path: store all 8 columns per row, advance C pointers to the
+    // next column group, and rewind the A pointers to the start of each row.
+    if XNN_LIKELY(nc >= 8) {
+      _mm_storeu_ps(c3, vacc3x0123);
+      _mm_storeu_ps(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      _mm_storeu_ps(c2, vacc2x0123);
+      _mm_storeu_ps(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      _mm_storeu_ps(c1, vacc1x0123);
+      _mm_storeu_ps(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      // Tail path: store the final nc (< 8) columns in 4/2/1-wide pieces,
+      // shifting the remaining lanes down after each partial store.
+      if (nc & 4) {
+        _mm_storeu_ps(c3, vacc3x0123);
+        _mm_storeu_ps(c2, vacc2x0123);
+        _mm_storeu_ps(c1, vacc1x0123);
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c3, vacc3x0123);
+        _mm_storel_pi((__m64*) c2, vacc2x0123);
+        _mm_storel_pi((__m64*) c1, vacc1x0123);
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        // Move the upper two lanes into the lower half for the final store.
+        vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
+        vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
+        vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c3, vacc3x0123);
+        _mm_store_ss(c2, vacc2x0123);
+        _mm_store_ss(c1, vacc1x0123);
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x8-sse.c b/src/f32-gemminc/4x8-sse.c
new file mode 100644
index 0000000..b16e2fe
--- /dev/null
+++ b/src/f32-gemminc/4x8-sse.c
@@ -0,0 +1,176 @@
+/*
+ * Auto-generated file. Do not edit!
+ * Template: src/f32-gemm/sse.c.in
+ * Generator: tools/xngen
+ */
+
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// Computes a 4-row x 8-column tile of C = clamp(acc + A * B) with SSE
+// intrinsics, broadcasting one A element per k step. GEMMINC variant:
+// accumulators are seeded from the extra `acc` buffer (partial sums), so the
+// multiplication can be split along the k dimension across invocations.
+void xnn_f32_gemminc_ukernel_4x8__sse(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float* restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  // Pointer-argument checks added for consistency with the sibling GEMMINC
+  // kernels (4x8-sse-load1, 4x8-sse-dup, 4x8-psimd-splat), which all verify
+  // these; this kernel previously omitted them.
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  // Per-row A/C pointers. Rows beyond mr alias the previous row so all four
+  // rows can be computed unconditionally without out-of-bounds accesses.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  // Outer loop: one iteration per group of (up to) 8 output columns.
+  do {
+    // Seed the 4x8 accumulator tile from the partial-sum buffer.
+    __m128 vacc0x0123 = _mm_load_ps(acc + 0);
+    __m128 vacc0x4567 = _mm_load_ps(acc + 4);
+    __m128 vacc1x0123 = _mm_load_ps(acc + 8);
+    __m128 vacc1x4567 = _mm_load_ps(acc + 12);
+    __m128 vacc2x0123 = _mm_load_ps(acc + 16);
+    __m128 vacc2x4567 = _mm_load_ps(acc + 20);
+    __m128 vacc3x0123 = _mm_load_ps(acc + 24);
+    __m128 vacc3x4567 = _mm_load_ps(acc + 28);
+    acc += 32;
+
+    size_t k = kc;
+    // Inner loop: one k-element per iteration. Each row's A element is
+    // broadcast to all lanes and multiply-accumulated against the matching
+    // pre-packed 8-wide weight row.
+    do {
+      const __m128 va0 = _mm_load1_ps(a0);
+      a0 += 1;
+      const __m128 va1 = _mm_load1_ps(a1);
+      a1 += 1;
+      const __m128 va2 = _mm_load1_ps(a2);
+      a2 += 1;
+      const __m128 va3 = _mm_load1_ps(a3);
+      a3 += 1;
+
+      const __m128 vb0123 = _mm_load_ps(w);
+      const __m128 vb4567 = _mm_load_ps(w + 4);
+      w += 8;
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+      vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123));
+      vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123));
+      vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+      vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567));
+      vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567));
+      vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567));
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    // Clamp the accumulators to the [min, max] output range from params.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc1x0123 = _mm_min_ps(vacc1x0123, vmax);
+    vacc2x0123 = _mm_min_ps(vacc2x0123, vmax);
+    vacc3x0123 = _mm_min_ps(vacc3x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+    vacc1x4567 = _mm_min_ps(vacc1x4567, vmax);
+    vacc2x4567 = _mm_min_ps(vacc2x4567, vmax);
+    vacc3x4567 = _mm_min_ps(vacc3x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc1x0123 = _mm_max_ps(vacc1x0123, vmin);
+    vacc2x0123 = _mm_max_ps(vacc2x0123, vmin);
+    vacc3x0123 = _mm_max_ps(vacc3x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+    vacc1x4567 = _mm_max_ps(vacc1x4567, vmin);
+    vacc2x4567 = _mm_max_ps(vacc2x4567, vmin);
+    vacc3x4567 = _mm_max_ps(vacc3x4567, vmin);
+
+    // Full-width path: store all 8 columns per row, advance C pointers to the
+    // next column group, and rewind the A pointers to the start of each row.
+    if XNN_LIKELY(nc >= 8) {
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+      _mm_storeu_ps(c1, vacc1x0123);
+      _mm_storeu_ps(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      _mm_storeu_ps(c2, vacc2x0123);
+      _mm_storeu_ps(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      _mm_storeu_ps(c3, vacc3x0123);
+      _mm_storeu_ps(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+
+      nc -= 8;
+    } else {
+      // Tail path: store the final nc (< 8) columns in 4/2/1-wide pieces,
+      // shifting the remaining lanes down after each partial store.
+      if (nc & 4) {
+        _mm_storeu_ps(c0, vacc0x0123);
+        _mm_storeu_ps(c1, vacc1x0123);
+        _mm_storeu_ps(c2, vacc2x0123);
+        _mm_storeu_ps(c3, vacc3x0123);
+
+        vacc0x0123 = vacc0x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc3x0123 = vacc3x4567;
+
+        c0 += 4;
+        c1 += 4;
+        c2 += 4;
+        c3 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+        _mm_storel_pi((__m64*) c1, vacc1x0123);
+        _mm_storel_pi((__m64*) c2, vacc2x0123);
+        _mm_storel_pi((__m64*) c3, vacc3x0123);
+
+        // Move the upper two lanes into the lower half for the final store.
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+        vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
+        vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
+        vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
+
+        c0 += 2;
+        c1 += 2;
+        c2 += 2;
+        c3 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c0, vacc0x0123);
+        _mm_store_ss(c1, vacc1x0123);
+        _mm_store_ss(c2, vacc2x0123);
+        _mm_store_ss(c3, vacc3x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x8s4-psimd.c b/src/f32-gemminc/4x8s4-psimd.c
new file mode 100644
index 0000000..ae9c36f
--- /dev/null
+++ b/src/f32-gemminc/4x8s4-psimd.c
@@ -0,0 +1,262 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-s4.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+// f32 GEMM micro-kernel with accumulator initialization (GEMMINC): computes a
+// 4-row x 8-column tile of C = clamp(acc + A * B) using portable psimd
+// intrinsics.  The "s4" scheme loads 4 consecutive K-elements of each A row
+// into one vector and rotates its lanes between the four FMA rounds instead
+// of broadcasting one scalar per round; the packed weights 'w' are laid out
+// to match (8 floats per round, 32 floats per 4-element K-step).
+//
+// Arguments (all strides in bytes):
+//   mr     - number of valid A/C rows, 1..4
+//   nc     - number of C columns to produce
+//   kc     - reduction length in bytes, a multiple of sizeof(float)
+//   acc    - initial accumulator values, 32 floats per 4x8 tile
+//   params - scalar min/max used to clamp the output
+void xnn_f32_gemminc_ukernel_4x8s4__psimd(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  // Set up per-row A and C pointers.  Rows beyond 'mr' alias the previous
+  // row, so their loads/stores hit valid memory and the redundant results
+  // are simply rewritten over the same row.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    // Start each 8-column tile from caller-provided accumulators
+    // (32 floats = 4 rows x 8 columns).
+    psimd_f32 vacc0x0123 = psimd_load_f32(acc + 0);
+    psimd_f32 vacc0x4567 = psimd_load_f32(acc + 4);
+    psimd_f32 vacc1x0123 = psimd_load_f32(acc + 8);
+    psimd_f32 vacc1x4567 = psimd_load_f32(acc + 12);
+    psimd_f32 vacc2x0123 = psimd_load_f32(acc + 16);
+    psimd_f32 vacc2x4567 = psimd_load_f32(acc + 20);
+    psimd_f32 vacc3x0123 = psimd_load_f32(acc + 24);
+    psimd_f32 vacc3x4567 = psimd_load_f32(acc + 28);
+    acc += 32;
+
+    size_t k = kc;
+    // Main loop: consume 4 K-elements per row per iteration.
+    while (k >= 4 * sizeof(float)) {
+      psimd_f32 va0 = psimd_load_f32(a0);
+      a0 += 4;
+      psimd_f32 va1 = psimd_load_f32(a1);
+      a1 += 4;
+      psimd_f32 va2 = psimd_load_f32(a2);
+      a2 += 4;
+      psimd_f32 va3 = psimd_load_f32(a3);
+      a3 += 4;
+
+      // Round 0: lane 0 of each va* multiplies the round-0 weights.
+      const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+      const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c0);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c0);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c0);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c0);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c0);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c0);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c0);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c0);
+
+      // Rotate A lanes left by one (1,2,3,0) so the next K-element reaches
+      // lane 0 for the next round.
+      va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+      va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+      va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+      va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+
+      const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+      const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c1);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c1);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c1);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c1);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c1);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c1);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c1);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c1);
+
+      va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+      va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+      va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+      va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+
+      const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+      const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c2);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c2);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c2);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c2);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c2);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c2);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c2);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c2);
+
+      va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+      va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+      va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+      va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+
+      const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+      const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c3);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c3);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c3);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c3);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c3);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c3);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c3);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c3);
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    // Remainder loop: 1 K-element at a time, splatting the scalar across
+    // all lanes (no rotation needed).
+    if XNN_UNLIKELY(k != 0) {
+      do {
+        const psimd_f32 va0 = psimd_load_splat_f32(a0);
+        a0 += 1;
+        const psimd_f32 va1 = psimd_load_splat_f32(a1);
+        a1 += 1;
+        const psimd_f32 va2 = psimd_load_splat_f32(a2);
+        a2 += 1;
+        const psimd_f32 va3 = psimd_load_splat_f32(a3);
+        a3 += 1;
+
+        const psimd_f32 vb0123 = psimd_load_f32(w);
+        const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+        w += 8;
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+        vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+        vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+        vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+        vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+        vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+        vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp the accumulators to [min, max] from the output params.
+    // Note: was corrupted to "¶ms" (mojibake of "&params") - fixed.
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+    vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+    vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+    vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+    vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+    vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+    vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+    vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+    vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+    vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+    vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Full 8-column store; rewind A pointers by kc (bytes) for the next
+      // column tile.
+      psimd_store_f32(c3, vacc3x0123);
+      psimd_store_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      psimd_store_f32(c2, vacc2x0123);
+      psimd_store_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      psimd_store_f32(c1, vacc1x0123);
+      psimd_store_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      // Final partial tile: store 4, 2, then 1 columns according to the
+      // bits of nc, shifting stored lanes out of the accumulators.
+      if (nc & 4) {
+        psimd_store_f32(c3, vacc3x0123);
+        psimd_store_f32(c2, vacc2x0123);
+        psimd_store_f32(c1, vacc1x0123);
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c3, vacc3x0123);
+        psimd_store2_f32(c2, vacc2x0123);
+        psimd_store2_f32(c1, vacc1x0123);
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+        vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+        vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c3, vacc3x0123);
+        psimd_store1_f32(c2, vacc2x0123);
+        psimd_store1_f32(c1, vacc1x0123);
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/4x8s4-sse.c b/src/f32-gemminc/4x8s4-sse.c
new file mode 100644
index 0000000..b82b752
--- /dev/null
+++ b/src/f32-gemminc/4x8s4-sse.c
@@ -0,0 +1,262 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/sse-shuffle.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// f32 GEMM micro-kernel with accumulator initialization (GEMMINC): computes a
+// 4-row x 8-column tile of C = clamp(acc + A * B) using SSE intrinsics.
+// The "s4" scheme loads 4 consecutive K-elements of each A row into one
+// vector and rotates its lanes between the four multiply rounds; the packed
+// weights 'w' are laid out to match (8 floats per round, 32 per K-step) and
+// are 16-byte aligned, hence the aligned _mm_load_ps on w.
+//
+// Arguments (all strides in bytes):
+//   mr     - number of valid A/C rows, 1..4
+//   nc     - number of C columns to produce
+//   kc     - reduction length in bytes, a multiple of sizeof(float)
+//   acc    - initial accumulator values, 32 floats per 4x8 tile
+//   params - SSE min/max vectors used to clamp the output
+void xnn_f32_gemminc_ukernel_4x8s4__sse(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  // Set up per-row A and C pointers.  Rows beyond 'mr' alias the previous
+  // row so their loads/stores stay within valid memory.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  do {
+    // Start each 8-column tile from caller-provided accumulators
+    // (32 floats = 4 rows x 8 columns).
+    __m128 vacc0x0123 = _mm_load_ps(acc + 0);
+    __m128 vacc0x4567 = _mm_load_ps(acc + 4);
+    __m128 vacc1x0123 = _mm_load_ps(acc + 8);
+    __m128 vacc1x4567 = _mm_load_ps(acc + 12);
+    __m128 vacc2x0123 = _mm_load_ps(acc + 16);
+    __m128 vacc2x4567 = _mm_load_ps(acc + 20);
+    __m128 vacc3x0123 = _mm_load_ps(acc + 24);
+    __m128 vacc3x4567 = _mm_load_ps(acc + 28);
+    acc += 32;
+
+    size_t k = kc;
+    // Main loop: consume 4 K-elements per row per iteration (4 rounds of
+    // 8 multiply-adds, rotating A lanes between rounds).
+    while (k >= 4 * sizeof(float)) {
+      __m128 va0 = _mm_loadu_ps(a0);
+      a0 += 4;
+      __m128 va1 = _mm_loadu_ps(a1);
+      a1 += 4;
+      __m128 va2 = _mm_loadu_ps(a2);
+      a2 += 4;
+      __m128 va3 = _mm_loadu_ps(a3);
+      a3 += 4;
+
+      // Round 0.
+      const __m128 vb0123c0 = _mm_load_ps(w + 0);
+      const __m128 vb4567c0 = _mm_load_ps(w + 4);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c0));
+      vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123c0));
+      vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123c0));
+      vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123c0));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c0));
+      vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567c0));
+      vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567c0));
+      vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567c0));
+
+      // Rotate A lanes left by one for the next round.
+      va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+      va1 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(0, 3, 2, 1));
+      va2 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(0, 3, 2, 1));
+      va3 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(0, 3, 2, 1));
+
+      const __m128 vb0123c1 = _mm_load_ps(w + 8);
+      const __m128 vb4567c1 = _mm_load_ps(w + 12);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c1));
+      vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123c1));
+      vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123c1));
+      vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123c1));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c1));
+      vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567c1));
+      vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567c1));
+      vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567c1));
+
+      va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+      va1 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(0, 3, 2, 1));
+      va2 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(0, 3, 2, 1));
+      va3 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(0, 3, 2, 1));
+
+      const __m128 vb0123c2 = _mm_load_ps(w + 16);
+      const __m128 vb4567c2 = _mm_load_ps(w + 20);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c2));
+      vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123c2));
+      vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123c2));
+      vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123c2));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c2));
+      vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567c2));
+      vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567c2));
+      vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567c2));
+
+      va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+      va1 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(0, 3, 2, 1));
+      va2 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(0, 3, 2, 1));
+      va3 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(0, 3, 2, 1));
+
+      const __m128 vb0123c3 = _mm_load_ps(w + 24);
+      const __m128 vb4567c3 = _mm_load_ps(w + 28);
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c3));
+      vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123c3));
+      vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123c3));
+      vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123c3));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c3));
+      vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567c3));
+      vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567c3));
+      vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567c3));
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    // Remainder loop: 1 K-element at a time, broadcast across all lanes.
+    if XNN_UNLIKELY(k != 0) {
+      do {
+        const __m128 va0 = _mm_load1_ps(a0);
+        a0 += 1;
+        const __m128 va1 = _mm_load1_ps(a1);
+        a1 += 1;
+        const __m128 va2 = _mm_load1_ps(a2);
+        a2 += 1;
+        const __m128 va3 = _mm_load1_ps(a3);
+        a3 += 1;
+
+        const __m128 vb0123 = _mm_load_ps(w);
+        const __m128 vb4567 = _mm_load_ps(w + 4);
+        w += 8;
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+        vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123));
+        vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123));
+        vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+        vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567));
+        vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567));
+        vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567));
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    // Clamp the accumulators to [min, max]; params->sse.max/min are
+    // pre-broadcast 4-element arrays.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc1x0123 = _mm_min_ps(vacc1x0123, vmax);
+    vacc2x0123 = _mm_min_ps(vacc2x0123, vmax);
+    vacc3x0123 = _mm_min_ps(vacc3x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+    vacc1x4567 = _mm_min_ps(vacc1x4567, vmax);
+    vacc2x4567 = _mm_min_ps(vacc2x4567, vmax);
+    vacc3x4567 = _mm_min_ps(vacc3x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc1x0123 = _mm_max_ps(vacc1x0123, vmin);
+    vacc2x0123 = _mm_max_ps(vacc2x0123, vmin);
+    vacc3x0123 = _mm_max_ps(vacc3x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+    vacc1x4567 = _mm_max_ps(vacc1x4567, vmin);
+    vacc2x4567 = _mm_max_ps(vacc2x4567, vmin);
+    vacc3x4567 = _mm_max_ps(vacc3x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Full 8-column store; rewind A pointers by kc (bytes) for the next
+      // column tile.
+      _mm_storeu_ps(c3, vacc3x0123);
+      _mm_storeu_ps(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      _mm_storeu_ps(c2, vacc2x0123);
+      _mm_storeu_ps(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      _mm_storeu_ps(c1, vacc1x0123);
+      _mm_storeu_ps(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      // Final partial tile: store 4, 2, then 1 columns according to the
+      // bits of nc, shifting stored lanes out of the accumulators.
+      if (nc & 4) {
+        _mm_storeu_ps(c3, vacc3x0123);
+        _mm_storeu_ps(c2, vacc2x0123);
+        _mm_storeu_ps(c1, vacc1x0123);
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c3, vacc3x0123);
+        _mm_storel_pi((__m64*) c2, vacc2x0123);
+        _mm_storel_pi((__m64*) c1, vacc1x0123);
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
+        vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
+        vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c3, vacc3x0123);
+        _mm_store_ss(c2, vacc2x0123);
+        _mm_store_ss(c1, vacc1x0123);
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/5x8-aarch64-neonfma-cortex-a75.S b/src/f32-gemminc/5x8-aarch64-neonfma-cortex-a75.S
new file mode 100644
index 0000000..f9964d7
--- /dev/null
+++ b/src/f32-gemminc/5x8-aarch64-neonfma-cortex-a75.S
@@ -0,0 +1,581 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/5x8-aarch64-neonfma-cortex-a75.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_5x8__aarch64_neonfma_cortex_a75(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# unused compared to 5x8
+# x4 a5
+# x7 c5
+# A5 v10 v11
+# C v30 v31
+
+# d8-d15 need to be preserved if used.
+# x19-x30 need to be preserved if used. x18 is reserved for OS.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x13 c3
+# x7 c4
+
+# Vector register usage
+# A0 v0 v1
+# A1 v2 v3
+# A2 v4 v5
+# A3 v6 v7
+# A4 v8 v9
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# Clamp v30 v31
+
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_5x8__aarch64_neonfma_cortex_a75
+
+        # Clamp A and C pointers / Save d8-d15 on stack
+        # (d10/d11 are unused by this kernel, so only d8/d9, d12-d15 are saved)
+        STP d8, d9, [sp, -48]!
+        ADD x9, x3, x4 // a1 = a0 + a_stride
+        ADD x16, x6, x7 // c1 = c0 + cm_stride
+        CMP x0, 2 // if mr < 2
+        CSEL x9, x3, x9, LO // a1 = a0
+        CSEL x16, x6, x16, LO // c1 = c0
+
+        STP d12, d13, [sp, 16]
+        ADD x10, x9, x4 // a2 = a1 + a_stride
+        ADD x17, x16, x7 // c2 = c1 + cm_stride
+        // if mr <= 2
+        CSEL x10, x9, x10, LS // a2 = a1
+        CSEL x17, x16, x17, LS // c2 = c1
+
+        STP d14, d15, [sp, 32]
+        ADD x11, x10, x4 // a3 = a2 + a_stride
+        ADD x13, x17, x7 // c3 = c2 + cm_stride
+        CMP x0, 4 // if mr < 4
+        CSEL x11, x10, x11, LO // a3 = a2
+        CSEL x13, x17, x13, LO // c3 = c2
+
+        # Load acc, params pointer
+        LDP x15, x8, [sp, 56]
+
+        ADD x12, x11, x4 // a4 = a3 + a_stride
+        ADD x7, x13, x7 // c4 = c3 + cm_stride
+        // if mr <= 4
+        CSEL x12, x11, x12, LS // a4 = a3
+        CSEL x7, x13, x7, LS // c4 = c3
+
+        # Load clamp values
+        LD2R {v30.4s, v31.4s}, [x8]
+
+        # Load cn_stride
+        LDR x14, [sp, 48]
+
+0:
+        # Load initial accumulators
+        LDP q20, q21, [x15], 32
+        LDP q22, q23, [x15], 32
+        LDP q24, q25, [x15], 32
+        LDP q26, q27, [x15], 32
+        LDP q28, q29, [x15], 32
+        PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+        PRFM PLDL1KEEP, [x5, 64]
+        PRFM PLDL1KEEP, [x5, 128]
+        PRFM PLDL1KEEP, [x5, 192]
+        PRFM PLDL1KEEP, [x3] // Prefetch A
+        PRFM PLDL1KEEP, [x9]
+        PRFM PLDL1KEEP, [x10]
+        PRFM PLDL1KEEP, [x11]
+        PRFM PLDL1KEEP, [x12]
+
+        # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+        SUBS x0, x2, 32 // k = kc - 32
+        B.LO 4f
+
+        # Prologue - loads for main loop of 80 FMA
+        LDR q0, [x3], 16
+        LDR q2, [x9], 16
+        LDR q4, [x10], 16
+        LDR q6, [x11], 16
+        LDR q8, [x12], 16
+        LDP q12, q13, [x5], 32 // Fetch 3 B (4th deferred)
+        LDP q14, q15, [x5], 32
+        LDP q16, q17, [x5], 32
+
+        # Is there at least 8 floats (32 bytes) for main loop?
+        SUBS x0, x0, 32
+        B.LO 2f
+
+        # Main loop - 8 floats of A (32 bytes)
+        # 80 FMA + 5 LDP A + 8 LDP B
+1:
+        # First group of 4 A. 40 FMA.
+        FMLA v20.4s, v12.4s, v0.s[0]
+        LDP q18, q19, [x5], 32 // Load last B
+        FMLA v22.4s, v12.4s, v2.s[0]
+        FMLA v24.4s, v12.4s, v4.s[0]
+        FMLA v26.4s, v12.4s, v6.s[0]
+        PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+        FMLA v28.4s, v12.4s, v8.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v2.s[0]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v25.4s, v13.4s, v4.s[0]
+        FMLA v27.4s, v13.4s, v6.s[0]
+        FMLA v29.4s, v13.4s, v8.s[0]
+        LDR q1, [x3], 16 // Load next 5 A
+
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v2.s[1]
+        FMLA v24.4s, v14.4s, v4.s[1]
+        LDR q3, [x9], 16
+        FMLA v26.4s, v14.4s, v6.s[1]
+        FMLA v28.4s, v14.4s, v8.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        LDR q5, [x10], 16
+        FMLA v23.4s, v15.4s, v2.s[1]
+        FMLA v25.4s, v15.4s, v4.s[1]
+        FMLA v27.4s, v15.4s, v6.s[1]
+        LDR q7, [x11], 16
+        FMLA v29.4s, v15.4s, v8.s[1]
+
+        FMLA v20.4s, v16.4s, v0.s[2]
+        FMLA v22.4s, v16.4s, v2.s[2]
+        LDR q9, [x12], 16
+        FMLA v24.4s, v16.4s, v4.s[2]
+        FMLA v26.4s, v16.4s, v6.s[2]
+        FMLA v28.4s, v16.4s, v8.s[2]
+        LDP q12, q13, [x5], 32 // Load 4 B
+        FMLA v21.4s, v17.4s, v0.s[2]
+        FMLA v23.4s, v17.4s, v2.s[2]
+        FMLA v25.4s, v17.4s, v4.s[2]
+        LDP q14, q15, [x5], 32
+        FMLA v27.4s, v17.4s, v6.s[2]
+        FMLA v29.4s, v17.4s, v8.s[2]
+
+        FMLA v20.4s, v18.4s, v0.s[3]
+        LDP q16, q17, [x5], 32
+        FMLA v22.4s, v18.4s, v2.s[3]
+        FMLA v24.4s, v18.4s, v4.s[3]
+        FMLA v26.4s, v18.4s, v6.s[3]
+        FMLA v28.4s, v18.4s, v8.s[3]
+        FMLA v21.4s, v19.4s, v0.s[3]
+        FMLA v23.4s, v19.4s, v2.s[3]
+        FMLA v25.4s, v19.4s, v4.s[3]
+        FMLA v27.4s, v19.4s, v6.s[3]
+        FMLA v29.4s, v19.4s, v8.s[3]
+        LDP q18, q19, [x5], 32
+
+        # Second group of 4 A. 40 FMA.
+        FMLA v20.4s, v12.4s, v1.s[0]
+        FMLA v22.4s, v12.4s, v3.s[0]
+        FMLA v24.4s, v12.4s, v5.s[0]
+        LDR q0, [x3], 16 // Load next 5 A
+        FMLA v26.4s, v12.4s, v7.s[0]
+        FMLA v28.4s, v12.4s, v9.s[0]
+        FMLA v21.4s, v13.4s, v1.s[0]
+        LDR q2, [x9], 16
+        FMLA v23.4s, v13.4s, v3.s[0]
+        FMLA v25.4s, v13.4s, v5.s[0]
+        FMLA v27.4s, v13.4s, v7.s[0]
+        LDR q4, [x10], 16
+        FMLA v29.4s, v13.4s, v9.s[0]
+
+        FMLA v20.4s, v14.4s, v1.s[1]
+        FMLA v22.4s, v14.4s, v3.s[1]
+        LDR q6, [x11], 16
+        FMLA v24.4s, v14.4s, v5.s[1]
+        FMLA v26.4s, v14.4s, v7.s[1]
+        FMLA v28.4s, v14.4s, v9.s[1]
+        LDR q8, [x12], 16
+        FMLA v21.4s, v15.4s, v1.s[1]
+        FMLA v23.4s, v15.4s, v3.s[1]
+        FMLA v25.4s, v15.4s, v5.s[1]
+        LDP q12, q13, [x5], 32 // Load next 3 B (not last)
+        FMLA v27.4s, v15.4s, v7.s[1]
+        FMLA v29.4s, v15.4s, v9.s[1]
+
+        FMLA v20.4s, v16.4s, v1.s[2]
+        LDP q14, q15, [x5], 32
+        FMLA v22.4s, v16.4s, v3.s[2]
+        FMLA v24.4s, v16.4s, v5.s[2]
+        FMLA v26.4s, v16.4s, v7.s[2]
+        FMLA v28.4s, v16.4s, v9.s[2]
+        FMLA v21.4s, v17.4s, v1.s[2]
+        FMLA v23.4s, v17.4s, v3.s[2]
+        FMLA v25.4s, v17.4s, v5.s[2]
+        FMLA v27.4s, v17.4s, v7.s[2]
+        FMLA v29.4s, v17.4s, v9.s[2]
+        LDP q16, q17, [x5], 32
+
+        FMLA v20.4s, v18.4s, v1.s[3]
+        FMLA v22.4s, v18.4s, v3.s[3]
+        SUBS x0, x0, 32
+        FMLA v24.4s, v18.4s, v5.s[3]
+        FMLA v26.4s, v18.4s, v7.s[3]
+        FMLA v28.4s, v18.4s, v9.s[3]
+        FMLA v21.4s, v19.4s, v1.s[3]
+        FMLA v23.4s, v19.4s, v3.s[3]
+        FMLA v25.4s, v19.4s, v5.s[3]
+        FMLA v27.4s, v19.4s, v7.s[3]
+        FMLA v29.4s, v19.4s, v9.s[3]
+        B.HS 1b
+
+        # Epilogue - 8 floats of A (32 bytes)
+        # 80 FMA + 5 LDP A + 8 LDP B
+        # First block same as main loop. Second block has no preloads.
+2:
+        # First group of 4 A. 40 FMA.
+        FMLA v20.4s, v12.4s, v0.s[0]
+        LDP q18, q19, [x5], 32 // Load last B
+        FMLA v22.4s, v12.4s, v2.s[0]
+        FMLA v24.4s, v12.4s, v4.s[0]
+        FMLA v26.4s, v12.4s, v6.s[0]
+        PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+        FMLA v28.4s, v12.4s, v8.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v2.s[0]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v25.4s, v13.4s, v4.s[0]
+        FMLA v27.4s, v13.4s, v6.s[0]
+        FMLA v29.4s, v13.4s, v8.s[0]
+        LDR q1, [x3], 16 // Load next 5 A
+
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v2.s[1]
+        FMLA v24.4s, v14.4s, v4.s[1]
+        LDR q3, [x9], 16
+        FMLA v26.4s, v14.4s, v6.s[1]
+        FMLA v28.4s, v14.4s, v8.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        LDR q5, [x10], 16
+        FMLA v23.4s, v15.4s, v2.s[1]
+        FMLA v25.4s, v15.4s, v4.s[1]
+        FMLA v27.4s, v15.4s, v6.s[1]
+        LDR q7, [x11], 16
+        FMLA v29.4s, v15.4s, v8.s[1]
+
+        FMLA v20.4s, v16.4s, v0.s[2]
+        FMLA v22.4s, v16.4s, v2.s[2]
+        LDR q9, [x12], 16
+        FMLA v24.4s, v16.4s, v4.s[2]
+        FMLA v26.4s, v16.4s, v6.s[2]
+        FMLA v28.4s, v16.4s, v8.s[2]
+        LDP q12, q13, [x5], 32 // Load 4 B
+        FMLA v21.4s, v17.4s, v0.s[2]
+        FMLA v23.4s, v17.4s, v2.s[2]
+        FMLA v25.4s, v17.4s, v4.s[2]
+        LDP q14, q15, [x5], 32
+        FMLA v27.4s, v17.4s, v6.s[2]
+        FMLA v29.4s, v17.4s, v8.s[2]
+
+        FMLA v20.4s, v18.4s, v0.s[3]
+        LDP q16, q17, [x5], 32
+        FMLA v22.4s, v18.4s, v2.s[3]
+        FMLA v24.4s, v18.4s, v4.s[3]
+        FMLA v26.4s, v18.4s, v6.s[3]
+        FMLA v28.4s, v18.4s, v8.s[3]
+        FMLA v21.4s, v19.4s, v0.s[3]
+        FMLA v23.4s, v19.4s, v2.s[3]
+        FMLA v25.4s, v19.4s, v4.s[3]
+        FMLA v27.4s, v19.4s, v6.s[3]
+        FMLA v29.4s, v19.4s, v8.s[3]
+        LDP q18, q19, [x5], 32
+
+        # Second group of 4 A. 40 FMA.
+        FMLA v20.4s, v12.4s, v1.s[0]
+        FMLA v22.4s, v12.4s, v3.s[0]
+        FMLA v24.4s, v12.4s, v5.s[0]
+        FMLA v26.4s, v12.4s, v7.s[0]
+        FMLA v28.4s, v12.4s, v9.s[0]
+        FMLA v21.4s, v13.4s, v1.s[0]
+        FMLA v23.4s, v13.4s, v3.s[0]
+        FMLA v25.4s, v13.4s, v5.s[0]
+        FMLA v27.4s, v13.4s, v7.s[0]
+        FMLA v29.4s, v13.4s, v9.s[0]
+
+        FMLA v20.4s, v14.4s, v1.s[1]
+        FMLA v22.4s, v14.4s, v3.s[1]
+        FMLA v24.4s, v14.4s, v5.s[1]
+        FMLA v26.4s, v14.4s, v7.s[1]
+        FMLA v28.4s, v14.4s, v9.s[1]
+        FMLA v21.4s, v15.4s, v1.s[1]
+        FMLA v23.4s, v15.4s, v3.s[1]
+        FMLA v25.4s, v15.4s, v5.s[1]
+        FMLA v27.4s, v15.4s, v7.s[1]
+        FMLA v29.4s, v15.4s, v9.s[1]
+
+        FMLA v20.4s, v16.4s, v1.s[2]
+        FMLA v22.4s, v16.4s, v3.s[2]
+        FMLA v24.4s, v16.4s, v5.s[2]
+        FMLA v26.4s, v16.4s, v7.s[2]
+        FMLA v28.4s, v16.4s, v9.s[2]
+        FMLA v21.4s, v17.4s, v1.s[2]
+        FMLA v23.4s, v17.4s, v3.s[2]
+        FMLA v25.4s, v17.4s, v5.s[2]
+        FMLA v27.4s, v17.4s, v7.s[2]
+        FMLA v29.4s, v17.4s, v9.s[2]
+        TST x0, 31
+
+        FMLA v20.4s, v18.4s, v1.s[3]
+        FMLA v22.4s, v18.4s, v3.s[3]
+        FMLA v24.4s, v18.4s, v5.s[3]
+        FMLA v26.4s, v18.4s, v7.s[3]
+        FMLA v28.4s, v18.4s, v9.s[3]
+        FMLA v21.4s, v19.4s, v1.s[3]
+        FMLA v23.4s, v19.4s, v3.s[3]
+        FMLA v25.4s, v19.4s, v5.s[3]
+        FMLA v27.4s, v19.4s, v7.s[3]
+        FMLA v29.4s, v19.4s, v9.s[3]
+        B.NE 4f
+
+        # Clamp
+3:
+        FMIN v20.4s, v20.4s, v30.4s
+        FMIN v21.4s, v21.4s, v30.4s
+        FMIN v22.4s, v22.4s, v30.4s
+        FMIN v23.4s, v23.4s, v30.4s
+        FMIN v24.4s, v24.4s, v30.4s
+        FMIN v25.4s, v25.4s, v30.4s
+        FMIN v26.4s, v26.4s, v30.4s
+        FMIN v27.4s, v27.4s, v30.4s
+        FMIN v28.4s, v28.4s, v30.4s
+        FMIN v29.4s, v29.4s, v30.4s
+        FMAX v20.4s, v20.4s, v31.4s
+        FMAX v21.4s, v21.4s, v31.4s
+        FMAX v22.4s, v22.4s, v31.4s
+        FMAX v23.4s, v23.4s, v31.4s
+        FMAX v24.4s, v24.4s, v31.4s
+        FMAX v25.4s, v25.4s, v31.4s
+        FMAX v26.4s, v26.4s, v31.4s
+        FMAX v27.4s, v27.4s, v31.4s
+        FMAX v28.4s, v28.4s, v31.4s
+        FMAX v29.4s, v29.4s, v31.4s
+
+        # Store full 5 x 8
+        CMP x1, 8
+        B.LO 7f
+
+        SUB x3, x3, x2 // a0 -= kc
+        STP q28, q29, [x7]
+        ADD x7, x7, x14
+        SUB x9, x9, x2 // a1 -= kc
+        STP q26, q27, [x13]
+        ADD x13, x13, x14
+        SUB x10, x10, x2 // a2 -= kc
+        STP q24, q25, [x17]
+        ADD x17, x17, x14
+        SUB x11, x11, x2 // a3 -= kc
+        STP q22, q23, [x16]
+        ADD x16, x16, x14
+        SUB x12, x12, x2 // a4 -= kc
+        STP q20, q21, [x6]
+        ADD x6, x6, x14
+
+        SUBS x1, x1, 8
+        B.HI 0b
+
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 32]
+        LDP d12, d13, [sp, 16]
+        LDP d8, d9, [sp], 48
+        RET
+
+4:
+        # Is there a remainder?- 4 floats of A (16 bytes)
+        TBZ x0, 4, 5f
+
+        # Remainder- 4 floats of A (16 bytes)
+        # Load A
+        LDR q0, [x3], 16
+        LDR q2, [x9], 16
+        LDR q4, [x10], 16
+        LDR q6, [x11], 16
+        LDR q8, [x12], 16
+        # Load B
+        LDP q12, q13, [x5], 32
+        LDP q14, q15, [x5], 32
+        LDP q16, q17, [x5], 32
+        LDP q18, q19, [x5], 32
+
+        FMLA v20.4s, v12.4s, v0.s[0]
+        FMLA v22.4s, v12.4s, v2.s[0]
+        FMLA v24.4s, v12.4s, v4.s[0]
+        FMLA v26.4s, v12.4s, v6.s[0]
+        FMLA v28.4s, v12.4s, v8.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v2.s[0]
+        FMLA v25.4s, v13.4s, v4.s[0]
+        FMLA v27.4s, v13.4s, v6.s[0]
+        FMLA v29.4s, v13.4s, v8.s[0]
+
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v2.s[1]
+        FMLA v24.4s, v14.4s, v4.s[1]
+        FMLA v26.4s, v14.4s, v6.s[1]
+        FMLA v28.4s, v14.4s, v8.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        FMLA v23.4s, v15.4s, v2.s[1]
+        FMLA v25.4s, v15.4s, v4.s[1]
+        FMLA v27.4s, v15.4s, v6.s[1]
+        FMLA v29.4s, v15.4s, v8.s[1]
+
+        FMLA v20.4s, v16.4s, v0.s[2]
+        FMLA v22.4s, v16.4s, v2.s[2]
+        FMLA v24.4s, v16.4s, v4.s[2]
+        FMLA v26.4s, v16.4s, v6.s[2]
+        FMLA v28.4s, v16.4s, v8.s[2]
+        FMLA v21.4s, v17.4s, v0.s[2]
+        FMLA v23.4s, v17.4s, v2.s[2]
+        FMLA v25.4s, v17.4s, v4.s[2]
+        FMLA v27.4s, v17.4s, v6.s[2]
+        FMLA v29.4s, v17.4s, v8.s[2]
+
+        FMLA v20.4s, v18.4s, v0.s[3]
+        FMLA v22.4s, v18.4s, v2.s[3]
+        FMLA v24.4s, v18.4s, v4.s[3]
+        FMLA v26.4s, v18.4s, v6.s[3]
+        FMLA v28.4s, v18.4s, v8.s[3]
+        FMLA v21.4s, v19.4s, v0.s[3]
+        FMLA v23.4s, v19.4s, v2.s[3]
+        FMLA v25.4s, v19.4s, v4.s[3]
+        FMLA v27.4s, v19.4s, v6.s[3]
+        FMLA v29.4s, v19.4s, v8.s[3]
+
+        # Is there a remainder?- 2 floats of A (8 bytes)
+5:
+        TBZ x0, 3, 6f
+
+        # Remainder- 2 floats of A (8 bytes)
+        # Load A
+        LDR d0, [x3], 8
+        LDR d2, [x9], 8
+        LDR d4, [x10], 8
+        LDR d6, [x11], 8
+        LDR d8, [x12], 8
+        # Load B
+        LDP q12, q13, [x5], 32
+        LDP q14, q15, [x5], 32
+
+        FMLA v20.4s, v12.4s, v0.s[0]
+        FMLA v22.4s, v12.4s, v2.s[0]
+        FMLA v24.4s, v12.4s, v4.s[0]
+        FMLA v26.4s, v12.4s, v6.s[0]
+        FMLA v28.4s, v12.4s, v8.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v2.s[0]
+        FMLA v25.4s, v13.4s, v4.s[0]
+        FMLA v27.4s, v13.4s, v6.s[0]
+        FMLA v29.4s, v13.4s, v8.s[0]
+
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v2.s[1]
+        FMLA v24.4s, v14.4s, v4.s[1]
+        FMLA v26.4s, v14.4s, v6.s[1]
+        FMLA v28.4s, v14.4s, v8.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        FMLA v23.4s, v15.4s, v2.s[1]
+        FMLA v25.4s, v15.4s, v4.s[1]
+        FMLA v27.4s, v15.4s, v6.s[1]
+        FMLA v29.4s, v15.4s, v8.s[1]
+
+        # Is there a remainder?- 1 float of A (4 bytes)
+6:
+        TBZ x0, 2, 3b
+
+        # Remainder- 1 float of A (4 bytes)
+        # Load A
+        LDR s0, [x3], 4
+        LDR s2, [x9], 4
+        LDR s4, [x10], 4
+        LDR s6, [x11], 4
+        LDR s8, [x12], 4
+        # Load B
+        LDP q12, q13, [x5], 32
+
+        FMLA v20.4s, v12.4s, v0.s[0]
+        FMLA v22.4s, v12.4s, v2.s[0]
+        FMLA v24.4s, v12.4s, v4.s[0]
+        FMLA v26.4s, v12.4s, v6.s[0]
+        FMLA v28.4s, v12.4s, v8.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v2.s[0]
+        FMLA v25.4s, v13.4s, v4.s[0]
+        FMLA v27.4s, v13.4s, v6.s[0]
+        FMLA v29.4s, v13.4s, v8.s[0]
+        B 3b
+
+        # Store odd width
+7:
+        TBZ x1, 2, 8f
+        STR q28, [x7], 16
+        MOV v28.16b, v29.16b
+        STR q26, [x13], 16
+        MOV v26.16b, v27.16b
+        STR q24, [x17], 16
+        MOV v24.16b, v25.16b
+        STR q22, [x16], 16
+        MOV v22.16b, v23.16b
+        STR q20, [x6], 16
+        MOV v20.16b, v21.16b
+8:
+        TBZ x1, 1, 9f
+        STR d28, [x7], 8
+        DUP d28, v28.d[1]
+        STR d26, [x13], 8
+        DUP d26, v26.d[1]
+        STR d24, [x17], 8
+        DUP d24, v24.d[1]
+        STR d22, [x16], 8
+        DUP d22, v22.d[1]
+        STR d20, [x6], 8
+        DUP d20, v20.d[1]
+
+9:
+        TBZ x1, 0, 10f
+        STR s28, [x7]
+        STR s26, [x13]
+        STR s24, [x17]
+        STR s22, [x16]
+        STR s20, [x6]
+10:
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 32]
+        LDP d12, d13, [sp, 16]
+        LDP d8, d9, [sp], 48
+        RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_5x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/5x8-neon-ld64.c b/src/f32-gemminc/5x8-neon-ld64.c
new file mode 100644
index 0000000..f2fe644
--- /dev/null
+++ b/src/f32-gemminc/5x8-neon-ld64.c
@@ -0,0 +1,227 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// f32 GEMM-with-accumulation (GEMMINC) microkernel: computes C = acc + A * B
+// over a tile of up to 5 rows (MR) x 8 columns (NR), then clamps every
+// element to [params->scalar.min, params->scalar.max].
+// "ld64" variant: the main loop consumes 64 bits (2 floats) of each A row
+// per iteration.
+void xnn_f32_gemminc_ukernel_5x8__neon_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float* restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 5);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  // Set up per-row A and C pointers.  Rows beyond mr alias the previous
+  // valid row, so the kernel can unconditionally process 5 rows: the extra
+  // loads/stores are redundant duplicates rather than out-of-bounds accesses.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+  const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+  float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 4) {
+    a4 = a3;
+    c4 = c3;
+  }
+
+  do {
+    // Initialize accumulators from the acc buffer: 10 vectors = 5 rows x 8 columns.
+    float32x4_t vacc0x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc4x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc4x4567 = vld1q_f32(acc); acc += 4;
+
+    size_t k = kc;
+    // Main loop: two K steps per iteration (2 floats of each A row, 4 B vectors).
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+      const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+      const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+      const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+      const float32x2_t va4 = vld1_f32(a4); a4 += 2;
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+      vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+      vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+      vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+      vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123c0, va4, 0);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+      vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+      vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+      vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+      vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567c0, va4, 0);
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+      vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+      vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+      vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+      vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123c1, va4, 1);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+      vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+      vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+      vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+      vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567c1, va4, 1);
+    }
+    // Remainder: one leftover K step (kc is an odd number of floats).
+    if XNN_UNLIKELY(k != 0) {
+      const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+      const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+      const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+      const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+      const float32x4_t va4 = vld1q_dup_f32(a4); a4 += 1;
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+      vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+      vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+      vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+      vacc4x0123 = vmlaq_f32(vacc4x0123, va4, vb0123);
+      vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+      vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+      vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+      vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+      vacc4x4567 = vmlaq_f32(vacc4x4567, va4, vb4567);
+    }
+    // Clamp results to the requested output range.
+    // (Fixed: "&params" was mis-encoded as a pilcrow character here.)
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc4x0123 = vminq_f32(vacc4x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+    vacc4x4567 = vminq_f32(vacc4x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc4x0123 = vmaxq_f32(vacc4x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+    vacc4x4567 = vmaxq_f32(vacc4x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Full 8-column store; rewind A pointers for the next column block.
+      vst1q_f32(c4, vacc4x0123);
+      vst1q_f32(c4 + 4, vacc4x4567);
+      c4 = (float*) ((uintptr_t) c4 + cn_stride);
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a4 = (const float*) ((uintptr_t) a4 - kc);
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+
+    } else {
+      // Partial store for the final tile: 4, then 2, then 1 column(s),
+      // shifting the remaining lanes down after each store.
+      if (nc & 4) {
+        vst1q_f32(c4, vacc4x0123); c4 += 4;
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc4x0123 = vacc4x4567;
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc4x01 = vget_low_f32(vacc4x0123);
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c4, vacc4x01); c4 += 2;
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc4x01 = vget_high_f32(vacc4x0123);
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c4, vacc4x01, 0);
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/5x8-neonfma-ld64.c b/src/f32-gemminc/5x8-neonfma-ld64.c
new file mode 100644
index 0000000..6bbada0
--- /dev/null
+++ b/src/f32-gemminc/5x8-neonfma-ld64.c
@@ -0,0 +1,263 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+// f32 GEMM-with-accumulation (GEMMINC) microkernel: computes C = acc + A * B
+// over a tile of up to 5 rows (MR) x 8 columns (NR), then clamps every
+// element to [params->scalar.min, params->scalar.max].
+// NEON-FMA "ld64" variant: the main loop consumes 64 bits (2 floats) of each
+// A row per iteration, using fused multiply-add.
+void xnn_f32_gemminc_ukernel_5x8__neonfma_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float* restrict a,
+    size_t a_stride,
+    const float* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float* restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 5);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  // Set up per-row A and C pointers.  Rows beyond mr alias the previous
+  // valid row, so the kernel can unconditionally process 5 rows: the extra
+  // loads/stores are redundant duplicates rather than out-of-bounds accesses.
+  const float* a0 = a;
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+  const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+  float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 4) {
+    a4 = a3;
+    c4 = c3;
+  }
+
+  do {
+    // Initialize accumulators from the acc buffer: 10 vectors = 5 rows x 8 columns.
+    float32x4_t vacc0x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc1x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc2x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc3x4567 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc4x0123 = vld1q_f32(acc); acc += 4;
+    float32x4_t vacc4x4567 = vld1q_f32(acc); acc += 4;
+
+    size_t k = kc;
+    // Main loop: two K steps per iteration (2 floats of each A row, 4 B vectors).
+    for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+      const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+      const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+      const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+      const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+      const float32x2_t va4 = vld1_f32(a4); a4 += 2;
+
+      const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+      // AArch64 FMA can index a lane of a 64-bit register directly;
+      // on AArch32 the lane must be broadcast into a full vector first.
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+        vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+        vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+        vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+        vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123c0, va4, 0);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+        vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+        vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+        vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+        vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567c0, va4, 0);
+      #else
+        const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+        const float32x4_t va1c0 = vdupq_lane_f32(va1, 0);
+        const float32x4_t va2c0 = vdupq_lane_f32(va2, 0);
+        const float32x4_t va3c0 = vdupq_lane_f32(va3, 0);
+        const float32x4_t va4c0 = vdupq_lane_f32(va4, 0);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+        vacc4x0123 = vfmaq_f32(vacc4x0123, va4c0, vb0123c0);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+        vacc4x4567 = vfmaq_f32(vacc4x4567, va4c0, vb4567c0);
+      #endif
+      const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+      #if defined(__aarch64__)
+        vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+        vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+        vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+        vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+        vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123c1, va4, 1);
+        vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+        vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+        vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+        vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+        vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567c1, va4, 1);
+      #else
+        const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+        const float32x4_t va1c1 = vdupq_lane_f32(va1, 1);
+        const float32x4_t va2c1 = vdupq_lane_f32(va2, 1);
+        const float32x4_t va3c1 = vdupq_lane_f32(va3, 1);
+        const float32x4_t va4c1 = vdupq_lane_f32(va4, 1);
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+        vacc4x0123 = vfmaq_f32(vacc4x0123, va4c1, vb0123c1);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+        vacc4x4567 = vfmaq_f32(vacc4x4567, va4c1, vb4567c1);
+      #endif
+    }
+    // Remainder: one leftover K step (kc is an odd number of floats).
+    if XNN_UNLIKELY(k != 0) {
+      const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+      const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+      const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+      const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+      const float32x4_t va4 = vld1q_dup_f32(a4); a4 += 1;
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+      vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+      vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+      vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+      vacc4x0123 = vfmaq_f32(vacc4x0123, va4, vb0123);
+      vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+      vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+      vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+      vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+      vacc4x4567 = vfmaq_f32(vacc4x4567, va4, vb4567);
+    }
+    // Clamp results to the requested output range.
+    // (Fixed: "&params" was mis-encoded as a pilcrow character here.)
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc4x0123 = vminq_f32(vacc4x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+    vacc4x4567 = vminq_f32(vacc4x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc4x0123 = vmaxq_f32(vacc4x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+    vacc4x4567 = vmaxq_f32(vacc4x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Full 8-column store; rewind A pointers for the next column block.
+      vst1q_f32(c4, vacc4x0123);
+      vst1q_f32(c4 + 4, vacc4x4567);
+      c4 = (float*) ((uintptr_t) c4 + cn_stride);
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a4 = (const float*) ((uintptr_t) a4 - kc);
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+
+    } else {
+      // Partial store for the final tile: 4, then 2, then 1 column(s),
+      // shifting the remaining lanes down after each store.
+      if (nc & 4) {
+        vst1q_f32(c4, vacc4x0123); c4 += 4;
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc4x0123 = vacc4x4567;
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc4x01 = vget_low_f32(vacc4x0123);
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c4, vacc4x01); c4 += 2;
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc4x01 = vget_high_f32(vacc4x0123);
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c4, vacc4x01, 0);
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/6x8-aarch64-neonfma-cortex-a57.S b/src/f32-gemminc/6x8-aarch64-neonfma-cortex-a57.S
new file mode 100644
index 0000000..dd0c830
--- /dev/null
+++ b/src/f32-gemminc/6x8-aarch64-neonfma-cortex-a57.S
@@ -0,0 +1,657 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/6x8-aarch64-neonfma-cortex-a57.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a57(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0 v6
+# A1 v1 v7
+# A2 v2 v8
+# A3 v3 v9
+# A4 v4 v10
+# A5 v5 v11
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+
+# A57 kernel based on A75 but with PRFM removed from main loop
+
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a57
+
+        # Clamp A and C pointers / Save d8-d15 on stack
+        STP d8, d9, [sp, -64]!
+        ADD x9, x3, x4           // a1 = a0 + a_stride
+        ADD x16, x6, x7          // c1 = c0 + cm_stride
+        CMP x0, 2                // if mr < 2
+        CSEL x9, x3, x9, LO      //   a1 = a0
+        CSEL x16, x6, x16, LO    //   c1 = c0
+
+        STP d10, d11, [sp, 16]
+        ADD x10, x9, x4          // a2 = a1 + a_stride
+        ADD x17, x16, x7         // c2 = c1 + cm_stride
+                                 // if mr <= 2
+        CSEL x10, x9, x10, LS    //   a2 = a1
+        CSEL x17, x16, x17, LS   //   c2 = c1
+
+        STP d12, d13, [sp, 32]
+        ADD x11, x10, x4         // a3 = a2 + a_stride
+        ADD x18, x17, x7         // c3 = c2 + cm_stride
+        CMP x0, 4                // if mr < 4
+        CSEL x11, x10, x11, LO   //   a3 = a2
+        CSEL x18, x17, x18, LO   //   c3 = c2
+
+        STP d14, d15, [sp, 48]
+        ADD x12, x11, x4         // a4 = a3 + a_stride
+        ADD x13, x18, x7         // c4 = c3 + cm_stride
+                                 // if mr <= 4 (LS after CMP x0, 4)
+        CSEL x12, x11, x12, LS   //   a4 = a3
+        CSEL x13, x18, x13, LS   //   c4 = c3
+
+        # Load acc, params pointer
+        LDP x15, x8, [sp, 72]
+
+        ADD x4, x12, x4          // a5 = a4 + a_stride
+        ADD x7, x13, x7          // c5 = c4 + cm_stride
+        CMP x0, 6                // if mr < 6
+        CSEL x4, x12, x4, LO     //   a5 = a4
+        CSEL x7, x13, x7, LO     //   c5 = c4
+
+        # Load cn_stride
+        LDR x14, [sp, 64]
+
+0:
+        # Load initial accumulators
+        LDP q20, q21, [x15], 32
+        LDP q22, q23, [x15], 32
+        LDP q24, q25, [x15], 32
+        LDP q26, q27, [x15], 32
+        LDP q28, q29, [x15], 32
+        LDP q30, q31, [x15], 32
+        PRFM PLDL1KEEP, [x5, 0]  // Prefetch B
+        PRFM PLDL1KEEP, [x5, 64]
+        PRFM PLDL1KEEP, [x5, 128]
+        PRFM PLDL1KEEP, [x5, 192]
+        PRFM PLDL1KEEP, [x3]     // Prefetch A
+        PRFM PLDL1KEEP, [x9]
+        PRFM PLDL1KEEP, [x10]
+        PRFM PLDL1KEEP, [x11]
+        PRFM PLDL1KEEP, [x12]
+        PRFM PLDL1KEEP, [x4]
+
+        # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+        SUBS x0, x2, 32  // k = kc - 32
+        B.LO 4f
+
+        # Prologue - loads for main loop of 96 FMA
+        LDR q0, [x3], 16
+        LDR q1, [x9], 16
+        LDR q2, [x10], 16
+        LDR q3, [x11], 16
+        LDR q4, [x12], 16
+        LDR q5, [x4], 16
+        LDP q12, q13, [x5], 32  // Fetch 3 B (4th deferred)
+        LDP q14, q15, [x5], 32
+        LDP q16, q17, [x5], 32
+
+        # Is there at least 8 floats (32 bytes) for main loop?
+        SUBS x0, x0, 32
+        B.LO 2f
+
+        # Main loop - 8 floats of A (32 bytes)
+        # 96 FMA + 6 LDP A + 8 LDP B
+1:
+        # First group of 4 A. 48 FMA.
+        FMLA v20.4s, v12.4s, v0.s[0]
+        LDP q18, q19, [x5], 32  // Load last B
+        FMLA v22.4s, v12.4s, v1.s[0]
+        FMLA v24.4s, v12.4s, v2.s[0]
+        FMLA v26.4s, v12.4s, v3.s[0]
+        FMLA v28.4s, v12.4s, v4.s[0]
+        FMLA v30.4s, v12.4s, v5.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v1.s[0]
+        FMLA v25.4s, v13.4s, v2.s[0]
+        FMLA v27.4s, v13.4s, v3.s[0]
+        FMLA v29.4s, v13.4s, v4.s[0]
+
+        FMLA v31.4s, v13.4s, v5.s[0]
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v1.s[1]
+        FMLA v24.4s, v14.4s, v2.s[1]
+        FMLA v26.4s, v14.4s, v3.s[1]
+        FMLA v28.4s, v14.4s, v4.s[1]
+        FMLA v30.4s, v14.4s, v5.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        FMLA v23.4s, v15.4s, v1.s[1]
+        FMLA v25.4s, v15.4s, v2.s[1]
+        LDR q6, [x3], 16  // Load next 6 A
+        FMLA v27.4s, v15.4s, v3.s[1]
+        FMLA v29.4s, v15.4s, v4.s[1]
+        FMLA v31.4s, v15.4s, v5.s[1]
+        LDR q7, [x9], 16
+
+        FMLA v20.4s, v16.4s, v0.s[2]
+        FMLA v22.4s, v16.4s, v1.s[2]
+        FMLA v24.4s, v16.4s, v2.s[2]
+        LDR q8, [x10], 16
+        FMLA v26.4s, v16.4s, v3.s[2]
+        FMLA v28.4s, v16.4s, v4.s[2]
+        FMLA v30.4s, v16.4s, v5.s[2]
+        LDR q9, [x11], 16
+        FMLA v21.4s, v17.4s, v0.s[2]
+        FMLA v23.4s, v17.4s, v1.s[2]
+        FMLA v25.4s, v17.4s, v2.s[2]
+        LDR q10, [x12], 16
+        FMLA v27.4s, v17.4s, v3.s[2]
+        FMLA v29.4s, v17.4s, v4.s[2]
+        FMLA v31.4s, v17.4s, v5.s[2]
+        LDR q11, [x4], 16
+
+        FMLA v20.4s, v18.4s, v0.s[3]
+        FMLA v22.4s, v18.4s, v1.s[3]
+        FMLA v24.4s, v18.4s, v2.s[3]
+        LDP q12, q13, [x5], 32  // Load 4 B
+        FMLA v26.4s, v18.4s, v3.s[3]
+        FMLA v28.4s, v18.4s, v4.s[3]
+        FMLA v30.4s, v18.4s, v5.s[3]
+        LDP q14, q15, [x5], 32
+        FMLA v21.4s, v19.4s, v0.s[3]
+        FMLA v23.4s, v19.4s, v1.s[3]
+        FMLA v25.4s, v19.4s, v2.s[3]
+        LDP q16, q17, [x5], 32
+        FMLA v27.4s, v19.4s, v3.s[3]
+        FMLA v29.4s, v19.4s, v4.s[3]
+        FMLA v31.4s, v19.4s, v5.s[3]
+        LDP q18, q19, [x5], 32
+
+        # Second group of 4 A. 48 FMA.
+        FMLA v20.4s, v12.4s, v6.s[0]
+        FMLA v22.4s, v12.4s, v7.s[0]
+        FMLA v24.4s, v12.4s, v8.s[0]
+        LDR q0, [x3], 16  // Load next 6 A
+        FMLA v26.4s, v12.4s, v9.s[0]
+        FMLA v28.4s, v12.4s, v10.s[0]
+        FMLA v30.4s, v12.4s, v11.s[0]
+        LDR q1, [x9], 16
+        FMLA v21.4s, v13.4s, v6.s[0]
+        FMLA v23.4s, v13.4s, v7.s[0]
+        FMLA v25.4s, v13.4s, v8.s[0]
+        LDR q2, [x10], 16
+        FMLA v27.4s, v13.4s, v9.s[0]
+        FMLA v29.4s, v13.4s, v10.s[0]
+        FMLA v31.4s, v13.4s, v11.s[0]
+        LDR q3, [x11], 16
+
+        FMLA v20.4s, v14.4s, v6.s[1]
+        FMLA v22.4s, v14.4s, v7.s[1]
+        FMLA v24.4s, v14.4s, v8.s[1]
+        LDR q4, [x12], 16
+        FMLA v26.4s, v14.4s, v9.s[1]
+        FMLA v28.4s, v14.4s, v10.s[1]
+        FMLA v30.4s, v14.4s, v11.s[1]
+        LDR q5, [x4], 16
+        FMLA v21.4s, v15.4s, v6.s[1]
+        FMLA v23.4s, v15.4s, v7.s[1]
+        FMLA v25.4s, v15.4s, v8.s[1]
+        LDP q12, q13, [x5], 32  // Load next 3 B (not last)
+        FMLA v27.4s, v15.4s, v9.s[1]
+        FMLA v29.4s, v15.4s, v10.s[1]
+        FMLA v31.4s, v15.4s, v11.s[1]
+        LDP q14, q15, [x5], 32
+
+        FMLA v20.4s, v16.4s, v6.s[2]
+        FMLA v22.4s, v16.4s, v7.s[2]
+        FMLA v24.4s, v16.4s, v8.s[2]
+        FMLA v26.4s, v16.4s, v9.s[2]
+        FMLA v28.4s, v16.4s, v10.s[2]
+        FMLA v30.4s, v16.4s, v11.s[2]
+        FMLA v21.4s, v17.4s, v6.s[2]
+        FMLA v23.4s, v17.4s, v7.s[2]
+        FMLA v25.4s, v17.4s, v8.s[2]
+        FMLA v27.4s, v17.4s, v9.s[2]
+        FMLA v29.4s, v17.4s, v10.s[2]
+        FMLA v31.4s, v17.4s, v11.s[2]
+        LDP q16, q17, [x5], 32
+
+        FMLA v20.4s, v18.4s, v6.s[3]
+        FMLA v22.4s, v18.4s, v7.s[3]
+        SUBS x0, x0, 32
+        FMLA v24.4s, v18.4s, v8.s[3]
+        FMLA v26.4s, v18.4s, v9.s[3]
+        FMLA v28.4s, v18.4s, v10.s[3]
+        FMLA v30.4s, v18.4s, v11.s[3]
+        FMLA v21.4s, v19.4s, v6.s[3]
+        FMLA v23.4s, v19.4s, v7.s[3]
+        FMLA v25.4s, v19.4s, v8.s[3]
+        FMLA v27.4s, v19.4s, v9.s[3]
+        FMLA v29.4s, v19.4s, v10.s[3]
+        FMLA v31.4s, v19.4s, v11.s[3]
+        B.HS 1b
+
+        # Epilogue - 8 floats of A (32 bytes)
+        # 96 FMA + 6 LDP A + 8 LDP B
+        # First block same as main loop.  Second block has no preloads.
+2:
+        # First group of 4 A. 48 FMA.
+        FMLA v20.4s, v12.4s, v0.s[0]
+        LDP q18, q19, [x5], 32  // Load last B
+        FMLA v22.4s, v12.4s, v1.s[0]
+        FMLA v24.4s, v12.4s, v2.s[0]
+        FMLA v26.4s, v12.4s, v3.s[0]
+        FMLA v28.4s, v12.4s, v4.s[0]
+        FMLA v30.4s, v12.4s, v5.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v1.s[0]
+        FMLA v25.4s, v13.4s, v2.s[0]
+        FMLA v27.4s, v13.4s, v3.s[0]
+        FMLA v29.4s, v13.4s, v4.s[0]
+
+        FMLA v31.4s, v13.4s, v5.s[0]
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v1.s[1]
+        FMLA v24.4s, v14.4s, v2.s[1]
+        FMLA v26.4s, v14.4s, v3.s[1]
+        FMLA v28.4s, v14.4s, v4.s[1]
+        FMLA v30.4s, v14.4s, v5.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        FMLA v23.4s, v15.4s, v1.s[1]
+        FMLA v25.4s, v15.4s, v2.s[1]
+        LDR q6, [x3], 16  // Load next 6 A
+        FMLA v27.4s, v15.4s, v3.s[1]
+        FMLA v29.4s, v15.4s, v4.s[1]
+        FMLA v31.4s, v15.4s, v5.s[1]
+        LDR q7, [x9], 16
+
+        FMLA v20.4s, v16.4s, v0.s[2]
+        FMLA v22.4s, v16.4s, v1.s[2]
+        FMLA v24.4s, v16.4s, v2.s[2]
+        LDR q8, [x10], 16
+        FMLA v26.4s, v16.4s, v3.s[2]
+        FMLA v28.4s, v16.4s, v4.s[2]
+        FMLA v30.4s, v16.4s, v5.s[2]
+        LDR q9, [x11], 16
+        FMLA v21.4s, v17.4s, v0.s[2]
+        FMLA v23.4s, v17.4s, v1.s[2]
+        FMLA v25.4s, v17.4s, v2.s[2]
+        LDR q10, [x12], 16
+        FMLA v27.4s, v17.4s, v3.s[2]
+        FMLA v29.4s, v17.4s, v4.s[2]
+        FMLA v31.4s, v17.4s, v5.s[2]
+        LDR q11, [x4], 16
+
+        FMLA v20.4s, v18.4s, v0.s[3]
+        FMLA v22.4s, v18.4s, v1.s[3]
+        FMLA v24.4s, v18.4s, v2.s[3]
+        LDP q12, q13, [x5], 32  // Load 4 B
+        FMLA v26.4s, v18.4s, v3.s[3]
+        FMLA v28.4s, v18.4s, v4.s[3]
+        FMLA v30.4s, v18.4s, v5.s[3]
+        LDP q14, q15, [x5], 32
+        FMLA v21.4s, v19.4s, v0.s[3]
+        FMLA v23.4s, v19.4s, v1.s[3]
+        FMLA v25.4s, v19.4s, v2.s[3]
+        LDP q16, q17, [x5], 32
+        FMLA v27.4s, v19.4s, v3.s[3]
+        FMLA v29.4s, v19.4s, v4.s[3]
+        FMLA v31.4s, v19.4s, v5.s[3]
+        LDP q18, q19, [x5], 32
+
+        # Second group of 4 A. 48 FMA.
+        FMLA v20.4s, v12.4s, v6.s[0]
+        FMLA v22.4s, v12.4s, v7.s[0]
+        FMLA v24.4s, v12.4s, v8.s[0]
+        FMLA v26.4s, v12.4s, v9.s[0]
+        FMLA v28.4s, v12.4s, v10.s[0]
+        FMLA v30.4s, v12.4s, v11.s[0]
+        FMLA v21.4s, v13.4s, v6.s[0]
+        FMLA v23.4s, v13.4s, v7.s[0]
+        FMLA v25.4s, v13.4s, v8.s[0]
+        FMLA v27.4s, v13.4s, v9.s[0]
+        FMLA v29.4s, v13.4s, v10.s[0]
+        FMLA v31.4s, v13.4s, v11.s[0]
+
+        FMLA v20.4s, v14.4s, v6.s[1]
+        FMLA v22.4s, v14.4s, v7.s[1]
+        FMLA v24.4s, v14.4s, v8.s[1]
+        FMLA v26.4s, v14.4s, v9.s[1]
+        FMLA v28.4s, v14.4s, v10.s[1]
+        FMLA v30.4s, v14.4s, v11.s[1]
+        FMLA v21.4s, v15.4s, v6.s[1]
+        FMLA v23.4s, v15.4s, v7.s[1]
+        FMLA v25.4s, v15.4s, v8.s[1]
+        FMLA v27.4s, v15.4s, v9.s[1]
+        FMLA v29.4s, v15.4s, v10.s[1]
+        FMLA v31.4s, v15.4s, v11.s[1]
+
+        FMLA v20.4s, v16.4s, v6.s[2]
+        FMLA v22.4s, v16.4s, v7.s[2]
+        FMLA v24.4s, v16.4s, v8.s[2]
+        FMLA v26.4s, v16.4s, v9.s[2]
+        FMLA v28.4s, v16.4s, v10.s[2]
+        FMLA v30.4s, v16.4s, v11.s[2]
+        FMLA v21.4s, v17.4s, v6.s[2]
+        FMLA v23.4s, v17.4s, v7.s[2]
+        FMLA v25.4s, v17.4s, v8.s[2]
+        FMLA v27.4s, v17.4s, v9.s[2]
+        FMLA v29.4s, v17.4s, v10.s[2]
+        FMLA v31.4s, v17.4s, v11.s[2]
+
+        FMLA v20.4s, v18.4s, v6.s[3]
+        FMLA v22.4s, v18.4s, v7.s[3]
+        FMLA v24.4s, v18.4s, v8.s[3]
+        FMLA v26.4s, v18.4s, v9.s[3]
+        FMLA v28.4s, v18.4s, v10.s[3]
+        FMLA v30.4s, v18.4s, v11.s[3]
+        FMLA v21.4s, v19.4s, v6.s[3]
+        FMLA v23.4s, v19.4s, v7.s[3]
+
+        # Load clamping_params values
+        LD2R {v6.4s, v7.4s}, [x8]
+
+        FMLA v25.4s, v19.4s, v8.s[3]
+        FMLA v27.4s, v19.4s, v9.s[3]
+        # Is there a remainder?- 4 floats of A (16 bytes) or less
+        TST x0, 31
+        FMLA v29.4s, v19.4s, v10.s[3]
+        FMLA v31.4s, v19.4s, v11.s[3]
+        B.NE 4f
+
+        # Clamp
+3:
+        FMIN v20.4s, v20.4s, v6.4s
+        FMIN v21.4s, v21.4s, v6.4s
+        FMIN v22.4s, v22.4s, v6.4s
+        FMIN v23.4s, v23.4s, v6.4s
+        FMIN v24.4s, v24.4s, v6.4s
+        FMIN v25.4s, v25.4s, v6.4s
+        FMIN v26.4s, v26.4s, v6.4s
+        FMIN v27.4s, v27.4s, v6.4s
+        FMIN v28.4s, v28.4s, v6.4s
+        FMIN v29.4s, v29.4s, v6.4s
+        FMIN v30.4s, v30.4s, v6.4s
+        FMIN v31.4s, v31.4s, v6.4s
+        FMAX v20.4s, v20.4s, v7.4s
+        FMAX v21.4s, v21.4s, v7.4s
+        FMAX v22.4s, v22.4s, v7.4s
+        FMAX v23.4s, v23.4s, v7.4s
+        FMAX v24.4s, v24.4s, v7.4s
+        FMAX v25.4s, v25.4s, v7.4s
+        FMAX v26.4s, v26.4s, v7.4s
+        FMAX v27.4s, v27.4s, v7.4s
+        FMAX v28.4s, v28.4s, v7.4s
+        FMAX v29.4s, v29.4s, v7.4s
+        FMAX v30.4s, v30.4s, v7.4s
+        FMAX v31.4s, v31.4s, v7.4s
+
+        # Store full 6 x 8
+        CMP x1, 8
+        B.LO 7f
+
+        STP q30, q31, [x7]
+        ADD x7, x7, x14
+        SUB x3, x3, x2  // a0 -= kc
+        STP q28, q29, [x13]
+        ADD x13, x13, x14
+        SUB x9, x9, x2  // a1 -= kc
+        STP q26, q27, [x18]
+        ADD x18, x18, x14
+        SUB x10, x10, x2  // a2 -= kc
+        STP q24, q25, [x17]
+        ADD x17, x17, x14
+        SUB x11, x11, x2  // a3 -= kc
+        STP q22, q23, [x16]
+        ADD x16, x16, x14
+        SUB x12, x12, x2  // a4 -= kc
+        STP q20, q21, [x6]
+        ADD x6, x6, x14
+        SUB x4, x4, x2  // a5 -= kc
+
+        SUBS x1, x1, 8
+        B.HI 0b
+
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 48]
+        LDP d12, d13, [sp, 32]
+        LDP d10, d11, [sp, 16]
+        LDP d8, d9, [sp], 64
+        RET
+
+4:
+        # Load clamping_params values
+        LD2R {v6.4s, v7.4s}, [x8]
+
+        # Is there a remainder?- 4 floats of A (16 bytes)
+        TBZ x0, 4, 5f
+
+        # Remainder- 4 floats of A (16 bytes)
+        # Load A
+        LDR q0, [x3], 16
+        LDR q1, [x9], 16
+        LDR q2, [x10], 16
+        LDR q3, [x11], 16
+        LDR q4, [x12], 16
+        LDR q5, [x4], 16
+        # Load B
+        LDP q12, q13, [x5], 32
+        LDP q14, q15, [x5], 32
+        LDP q16, q17, [x5], 32
+        LDP q18, q19, [x5], 32
+
+        FMLA v20.4s, v12.4s, v0.s[0]
+        FMLA v22.4s, v12.4s, v1.s[0]
+        FMLA v24.4s, v12.4s, v2.s[0]
+        FMLA v26.4s, v12.4s, v3.s[0]
+        FMLA v28.4s, v12.4s, v4.s[0]
+        FMLA v30.4s, v12.4s, v5.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v1.s[0]
+        FMLA v25.4s, v13.4s, v2.s[0]
+        FMLA v27.4s, v13.4s, v3.s[0]
+        FMLA v29.4s, v13.4s, v4.s[0]
+        FMLA v31.4s, v13.4s, v5.s[0]
+
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v1.s[1]
+        FMLA v24.4s, v14.4s, v2.s[1]
+        FMLA v26.4s, v14.4s, v3.s[1]
+        FMLA v28.4s, v14.4s, v4.s[1]
+        FMLA v30.4s, v14.4s, v5.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        FMLA v23.4s, v15.4s, v1.s[1]
+        FMLA v25.4s, v15.4s, v2.s[1]
+        FMLA v27.4s, v15.4s, v3.s[1]
+        FMLA v29.4s, v15.4s, v4.s[1]
+        FMLA v31.4s, v15.4s, v5.s[1]
+
+        FMLA v20.4s, v16.4s, v0.s[2]
+        FMLA v22.4s, v16.4s, v1.s[2]
+        FMLA v24.4s, v16.4s, v2.s[2]
+        FMLA v26.4s, v16.4s, v3.s[2]
+        FMLA v28.4s, v16.4s, v4.s[2]
+        FMLA v30.4s, v16.4s, v5.s[2]
+        FMLA v21.4s, v17.4s, v0.s[2]
+        FMLA v23.4s, v17.4s, v1.s[2]
+        FMLA v25.4s, v17.4s, v2.s[2]
+        FMLA v27.4s, v17.4s, v3.s[2]
+        FMLA v29.4s, v17.4s, v4.s[2]
+        FMLA v31.4s, v17.4s, v5.s[2]
+
+        FMLA v20.4s, v18.4s, v0.s[3]
+        FMLA v22.4s, v18.4s, v1.s[3]
+        FMLA v24.4s, v18.4s, v2.s[3]
+        FMLA v26.4s, v18.4s, v3.s[3]
+        FMLA v28.4s, v18.4s, v4.s[3]
+        FMLA v30.4s, v18.4s, v5.s[3]
+        FMLA v21.4s, v19.4s, v0.s[3]
+        FMLA v23.4s, v19.4s, v1.s[3]
+        FMLA v25.4s, v19.4s, v2.s[3]
+        FMLA v27.4s, v19.4s, v3.s[3]
+        FMLA v29.4s, v19.4s, v4.s[3]
+        FMLA v31.4s, v19.4s, v5.s[3]
+
+        # Is there a remainder?- 2 floats of A (8 bytes)
+5:
+        TBZ x0, 3, 6f
+
+        # Remainder- 2 floats of A (8 bytes)
+        # Load A
+        LDR d0, [x3], 8
+        LDR d1, [x9], 8
+        LDR d2, [x10], 8
+        LDR d3, [x11], 8
+        LDR d4, [x12], 8
+        LDR d5, [x4], 8
+        # Load B
+        LDP q12, q13, [x5], 32
+        LDP q14, q15, [x5], 32
+
+        FMLA v20.4s, v12.4s, v0.s[0]
+        FMLA v22.4s, v12.4s, v1.s[0]
+        FMLA v24.4s, v12.4s, v2.s[0]
+        FMLA v26.4s, v12.4s, v3.s[0]
+        FMLA v28.4s, v12.4s, v4.s[0]
+        FMLA v30.4s, v12.4s, v5.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v1.s[0]
+        FMLA v25.4s, v13.4s, v2.s[0]
+        FMLA v27.4s, v13.4s, v3.s[0]
+        FMLA v29.4s, v13.4s, v4.s[0]
+        FMLA v31.4s, v13.4s, v5.s[0]
+
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v1.s[1]
+        FMLA v24.4s, v14.4s, v2.s[1]
+        FMLA v26.4s, v14.4s, v3.s[1]
+        FMLA v28.4s, v14.4s, v4.s[1]
+        FMLA v30.4s, v14.4s, v5.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        FMLA v23.4s, v15.4s, v1.s[1]
+        FMLA v25.4s, v15.4s, v2.s[1]
+        FMLA v27.4s, v15.4s, v3.s[1]
+        FMLA v29.4s, v15.4s, v4.s[1]
+        FMLA v31.4s, v15.4s, v5.s[1]
+
+        # Is there a remainder?- 1 float of A (4 bytes)
+6:
+        TBZ x0, 2, 3b
+
+        # Remainder- 1 float of A (4 bytes)
+        # Load A
+        LDR s0, [x3], 4
+        LDR s1, [x9], 4
+        LDR s2, [x10], 4
+        LDR s3, [x11], 4
+        LDR s4, [x12], 4
+        LDR s5, [x4], 4
+        # Load B
+        LDP q12, q13, [x5], 32
+
+        FMLA v20.4s, v12.4s, v0.s[0]
+        FMLA v22.4s, v12.4s, v1.s[0]
+        FMLA v24.4s, v12.4s, v2.s[0]
+        FMLA v26.4s, v12.4s, v3.s[0]
+        FMLA v28.4s, v12.4s, v4.s[0]
+        FMLA v30.4s, v12.4s, v5.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v1.s[0]
+        FMLA v25.4s, v13.4s, v2.s[0]
+        FMLA v27.4s, v13.4s, v3.s[0]
+        FMLA v29.4s, v13.4s, v4.s[0]
+        FMLA v31.4s, v13.4s, v5.s[0]
+        B 3b
+
+        # Store odd width
+7:
+        TBZ x1, 2, 8f
+        STR q30, [x7], 16
+        MOV v30.16b, v31.16b
+        STR q28, [x13], 16
+        MOV v28.16b, v29.16b
+        STR q26, [x18], 16
+        MOV v26.16b, v27.16b
+        STR q24, [x17], 16
+        MOV v24.16b, v25.16b
+        STR q22, [x16], 16
+        MOV v22.16b, v23.16b
+        STR q20, [x6], 16
+        MOV v20.16b, v21.16b
+8:
+        TBZ x1, 1, 9f
+        STR d30, [x7], 8
+        DUP d30, v30.d[1]
+        STR d28, [x13], 8
+        DUP d28, v28.d[1]
+        STR d26, [x18], 8
+        DUP d26, v26.d[1]
+        STR d24, [x17], 8
+        DUP d24, v24.d[1]
+        STR d22, [x16], 8
+        DUP d22, v22.d[1]
+        STR d20, [x6], 8
+        DUP d20, v20.d[1]
+
+9:
+        TBZ x1, 0, 10f
+        STR s30, [x7]
+        STR s28, [x13]
+        STR s26, [x18]
+        STR s24, [x17]
+        STR s22, [x16]
+        STR s20, [x6]
+10:
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 48]
+        LDP d12, d13, [sp, 32]
+        LDP d10, d11, [sp, 16]
+        LDP d8, d9, [sp], 64
+        RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a57
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/6x8-aarch64-neonfma-cortex-a73.S b/src/f32-gemminc/6x8-aarch64-neonfma-cortex-a73.S
new file mode 100644
index 0000000..9cf2201
--- /dev/null
+++ b/src/f32-gemminc/6x8-aarch64-neonfma-cortex-a73.S
@@ -0,0 +1,658 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/6x8-aarch64-neonfma-cortex-a73.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a73(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0 v6
+# A1 v1 v7
+# A2 v2 v8
+# A3 v3 v9
+# A4 v4 v10
+# A5 v5 v11
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a73
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ # Rows beyond mr alias the previous row's pointers so their loads stay
+ # in bounds; results for clamped rows are computed but never stored.
+ # NOTE(review): x18 is the reserved platform register on some ABIs
+ # (Darwin/Windows); presumably only ELF/Linux targets are intended - confirm.
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x12, x11, x4 // a4 = a3 + a_stride
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+ // if mr <= 4 (flags still set from CMP x0, 4 above)
+ CSEL x12, x11, x12, LS // a4 = a3
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ # Load acc, params pointer
+ LDP x15, x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+ .p2align 3
+0:
+ # Outer loop over output columns (nc, x1); accumulators are seeded from
+ # the acc buffer (x15) rather than from zero (gemminc variant).
+ # Load initial accumulators
+ LDP q20, q21, [x15], 32
+ LDP q22, q23, [x15], 32
+ LDP q24, q25, [x15], 32
+ LDP q26, q27, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ PRFM PLDL1KEEP, [x9]
+ PRFM PLDL1KEEP, [x10]
+ PRFM PLDL1KEEP, [x11]
+ PRFM PLDL1KEEP, [x12]
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 4f
+
+ # Prologue - loads for main loop of 96 FMA
+ # load A0 to A4 but not A5
+ LDP q0, q6, [x3], 32
+ LDP q1, q7, [x9], 32
+ LDP q2, q8, [x10], 32
+ LDP q3, q9, [x11], 32
+ LDP q4, q10, [x12], 32
+ # load first set of B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ # Is there at least 8 floats (32 bytes) for main loop?
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+ .p2align 3
+1:
+ # First group of 4 A. 48 FMA. Loads A5
+
+ LDP q5, q11, [x4], 32
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ LDP q16, q17, [x5], 32
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ LDP q12, q13, [x5], 32
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Second group of 4 A. 48 FMA. Loads A0 - A4
+
+ LDP q16, q17, [x5], 32
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ LDP q12, q13, [x5], 32
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v20.4s, v18.4s, v6.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ LDP q0, q6, [x3], 32
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v23.4s, v19.4s, v7.s[3]
+ LDP q1, q7, [x9], 32
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v25.4s, v19.4s, v8.s[3]
+ LDP q2, q8, [x10], 32
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ LDP q3, q9, [x11], 32
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v29.4s, v19.4s, v10.s[3]
+ LDP q4, q10, [x12], 32
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ SUBS x0, x0, 32
+ FMLA v31.4s, v17.4s, v11.s[2]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.HS 1b
+
+ # Epilogue - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+ # First block same as main loop. Second block has no preloads.
+2:
+ # First group of 4 A. 48 FMA. Loads A5
+
+ LDP q5, q11, [x4], 32
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ LDP q16, q17, [x5], 32
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ LDP q12, q13, [x5], 32
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Second group of 4 A. 48 FMA. No A Loads, No last B load
+
+ LDP q16, q17, [x5], 32
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ # Last part of epilogue has loads removed.
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ # Is there a remainder?- 4 floats of A (16 bytes) or less
+ TST x0, 31
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.NE 4f
+
+ .p2align 3
+
+ # Clamp
+3:
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ NOP
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+ .p2align 3
+4:
+ # k remainder handling: bits 4/3/2 of x0 select 4/2/1 leftover floats of A.
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Is there a remainder?- 4 floats of A (16 bytes)
+ TBZ x0, 4, 5f
+
+ # Remainder- 4 floats of A (16 bytes)
+ # Load A
+ LDR q0, [x3], 16
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+ LDP q18, q19, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Is there a remainder?- 2 floats of A (8 bytes)
+5:
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
+ # Load A
+ LDR d0, [x3], 8
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ # Is there a remainder?- 1 float of A (4 bytes)
+6:
+ TBZ x0, 2, 3b
+
+ # Remainder- 1 float of A (4 bytes)
+ # Load A
+ LDR s0, [x3], 4
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ # Load B
+ LDP q12, q13, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+ # nc remainder: bits 2/1/0 of x1 select 4/2/1 leftover output columns.
+7:
+ TBZ x1, 2, 8f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+8:
+ TBZ x1, 1, 9f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ STR s30, [x7]
+ STR s28, [x13]
+ STR s26, [x18]
+ STR s24, [x17]
+ STR s22, [x16]
+ STR s20, [x6]
+10:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a73
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/6x8-aarch64-neonfma-cortex-a75.S b/src/f32-gemminc/6x8-aarch64-neonfma-cortex-a75.S
new file mode 100644
index 0000000..ba9dba5
--- /dev/null
+++ b/src/f32-gemminc/6x8-aarch64-neonfma-cortex-a75.S
@@ -0,0 +1,659 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/6x8-aarch64-neonfma-cortex-a75.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a75(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0 v6
+# A1 v1 v7
+# A2 v2 v8
+# A3 v3 v9
+# A4 v4 v10
+# A5 v5 v11
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a75
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ # Rows beyond mr alias the previous row's pointers so their loads stay
+ # in bounds; results for clamped rows are computed but never stored.
+ # NOTE(review): x18 is the reserved platform register on some ABIs
+ # (Darwin/Windows); presumably only ELF/Linux targets are intended - confirm.
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x12, x11, x4 // a4 = a3 + a_stride
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+ // if mr <= 4 (flags still set from CMP x0, 4 above)
+ CSEL x12, x11, x12, LS // a4 = a3
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ # Load acc, params pointer
+ LDP x15, x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+0:
+ # Outer loop over output columns (nc, x1); accumulators are seeded from
+ # the acc buffer (x15) rather than from zero (gemminc variant).
+ # Load initial accumulators
+ LDP q20, q21, [x15], 32
+ LDP q22, q23, [x15], 32
+ LDP q24, q25, [x15], 32
+ LDP q26, q27, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ PRFM PLDL1KEEP, [x9]
+ PRFM PLDL1KEEP, [x10]
+ PRFM PLDL1KEEP, [x11]
+ PRFM PLDL1KEEP, [x12]
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 4f
+
+ # Prologue - loads for main loop of 96 FMA
+ LDR q0, [x3], 16
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ LDP q12, q13, [x5], 32 // Fetch 3 B (4th deferred)
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+
+ # Is there at least 8 floats (32 bytes) for main loop?
+ SUBS x0, x0, 32
+ B.LO 2f
+
+ # Main loop - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+1:
+ # First group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+
+ FMLA v31.4s, v13.4s, v5.s[0]
+ FMLA v20.4s, v14.4s, v0.s[1]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ LDR q6, [x3], 16 // Load next 6 A
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ LDR q7, [x9], 16
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ LDR q8, [x10], 16
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ LDR q9, [x11], 16
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ LDR q10, [x12], 16
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+ LDR q11, [x4], 16
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ FMLA v24.4s, v12.4s, v8.s[0]
+ LDR q0, [x3], 16 // Load next 6 A
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ LDR q1, [x9], 16
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ LDR q2, [x10], 16
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+ LDR q3, [x11], 16
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ LDR q4, [x12], 16
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ LDR q5, [x4], 16
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ LDP q12, q13, [x5], 32 // Load next 3 B (not last)
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+ LDP q16, q17, [x5], 32
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ SUBS x0, x0, 32
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.HS 1b
+
+ # Epilogue - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+ # First block same as main loop. Second block has no preloads.
+2:
+ # First group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+
+ FMLA v31.4s, v13.4s, v5.s[0]
+ FMLA v20.4s, v14.4s, v0.s[1]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ LDR q6, [x3], 16 // Load next 6 A
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ LDR q7, [x9], 16
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ LDR q8, [x10], 16
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ LDR q9, [x11], 16
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ LDR q10, [x12], 16
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+ LDR q11, [x4], 16
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ # Is there a remainder?- 4 floats of A (16 bytes) or less
+ TST x0, 31
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.NE 4f
+
+ # Clamp
+3:
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 7f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+4:
+ # k remainder handling: bits 4/3/2 of x0 select 4/2/1 leftover floats of A.
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Is there a remainder?- 4 floats of A (16 bytes)
+ TBZ x0, 4, 5f
+
+ # Remainder- 4 floats of A (16 bytes)
+ # Load A
+ LDR q0, [x3], 16
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+ LDP q18, q19, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Is there a remainder?- 2 floats of A (8 bytes)
+5:
+ TBZ x0, 3, 6f
+
+ # Remainder- 2 floats of A (8 bytes)
+ # Load A
+ LDR d0, [x3], 8
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ # Is there a remainder?- 1 float of A (4 bytes)
+6:
+ TBZ x0, 2, 3b
+
+ # Remainder- 1 float of A (4 bytes)
+ # Load A
+ LDR s0, [x3], 4
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ # Load B
+ LDP q12, q13, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+ # nc remainder: bits 2/1/0 of x1 select 4/2/1 leftover output columns.
+7:
+ TBZ x1, 2, 8f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+8:
+ TBZ x1, 1, 9f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+
+9:
+ TBZ x1, 0, 10f
+ STR s30, [x7]
+ STR s28, [x13]
+ STR s26, [x18]
+ STR s24, [x17]
+ STR s22, [x16]
+ STR s20, [x6]
+10:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/6x8-aarch64-neonfma-ld128.S b/src/f32-gemminc/6x8-aarch64-neonfma-ld128.S
new file mode 100644
index 0000000..0ad2073
--- /dev/null
+++ b/src/f32-gemminc/6x8-aarch64-neonfma-ld128.S
@@ -0,0 +1,373 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/6x8-aarch64-neonfma-ld128.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0
+# A1 v1
+# A2 v2
+# A3 v3
+# A4 v4
+# A5 v5
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+# unused A v8 v9 v10 v11
+
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x12, x11, x4 // a4 = a3 + a_stride
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+ // if mr <= 4 (LS reuses flags of the CMP x0, 4 above)
+ CSEL x12, x11, x12, LS // a4 = a3
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ # Load acc, params pointer
+ LDP x15, x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+0:
+ # Load initial accumulators
+ LDP q20, q21, [x15], 32
+ LDP q22, q23, [x15], 32
+ LDP q24, q25, [x15], 32
+ LDP q26, q27, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ PRFM PLDL1KEEP, [x9]
+ PRFM PLDL1KEEP, [x10]
+ PRFM PLDL1KEEP, [x11]
+ PRFM PLDL1KEEP, [x12]
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 4 floats (16 bytes)?
+ SUBS x0, x2, 16 // k = kc - 16
+ B.LO 2f
+
+ # Main loop - 4 floats of A (16 bytes)
+1:
+ LDR q0, [x3], 16
+ LDP q12, q13, [x5], 32
+ LDR q1, [x9], 16
+ LDR q2, [x10], 16
+ LDR q3, [x11], 16
+ LDR q4, [x12], 16
+ LDR q5, [x4], 16
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ LDP q16, q17, [x5], 32
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ SUBS x0, x0, 16
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ B.HS 1b
+
+2:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBNZ x0, 3, 4f
+ # Is there a remainder?- 1 floats of A (4 bytes)
+ TBNZ x0, 2, 5f
+3:
+ # Clamp
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 6f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+4:
+ # Remainder- 2 floats of A (8 bytes)
+ LDR d0, [x3], 8
+ LDP q12, q13, [x5], 32
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ TBZ x0, 2, 3b
+
+5:
+ # Remainder- 1 floats of A (4 bytes)
+ LDR s0, [x3], 4
+ LDP q12, q13, [x5], 32
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+6:
+ TBZ x1, 2, 7f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+
+7:
+ TBZ x1, 1, 8f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+
+8:
+ TBZ x1, 0, 9f
+ STR s30, [x7]
+ STR s28, [x13]
+ STR s26, [x18]
+ STR s24, [x17]
+ STR s22, [x16]
+ STR s20, [x6]
+9:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/6x8-aarch64-neonfma-ld64.S b/src/f32-gemminc/6x8-aarch64-neonfma-ld64.S
new file mode 100644
index 0000000..83f6e71
--- /dev/null
+++ b/src/f32-gemminc/6x8-aarch64-neonfma-ld64.S
@@ -0,0 +1,307 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/6x8-aarch64-neonfma-ld64.S.in
+// Generator: tools/xngen
+//
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld64(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# const uint8_t*restrict a, x3
+# size_t a_stride, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x14
+# const float*restrict acc, [sp + 8] -> x15
+# const union xnn_f32_output_params params[restrict static 1]) [sp + 16] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x3 a0
+# x9 a1
+# x10 a2
+# x11 a3
+# x12 a4
+# x4 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0
+# A1 v1
+# A2 v2
+# A3 v3
+# A4 v4
+# A5 v5
+# B v12 v13 v14 v15
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+# unused A v8 v9 v10 v11
+# unused B v16 v17 v18 v19
+
+BEGIN_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld64
+
+ # Clamp A and C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -64]!
+ ADD x9, x3, x4 // a1 = a0 + a_stride
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x9, x3, x9, LO // a1 = a0
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x10, x9, x4 // a2 = a1 + a_stride
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x10, x9, x10, LS // a2 = a1
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x11, x10, x4 // a3 = a2 + a_stride
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x11, x10, x11, LO // a3 = a2
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x12, x11, x4 // a4 = a3 + a_stride
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+ // if mr <= 4 (LS reuses flags of the CMP x0, 4 above)
+ CSEL x12, x11, x12, LS // a4 = a3
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ # Load acc, params pointer
+ LDP x15, x8, [sp, 72]
+
+ ADD x4, x12, x4 // a5 = a4 + a_stride
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x4, x12, x4, LO // a5 = a4
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Load cn_stride
+ LDR x14, [sp, 64]
+
+0:
+ # Load initial accumulators
+ LDP q20, q21, [x15], 32
+ LDP q22, q23, [x15], 32
+ LDP q24, q25, [x15], 32
+ LDP q26, q27, [x15], 32
+ LDP q28, q29, [x15], 32
+ LDP q30, q31, [x15], 32
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x3] // Prefetch A
+ PRFM PLDL1KEEP, [x9]
+ PRFM PLDL1KEEP, [x10]
+ PRFM PLDL1KEEP, [x11]
+ PRFM PLDL1KEEP, [x12]
+ PRFM PLDL1KEEP, [x4]
+
+ # Is there at least 2 floats (8 bytes) for main loop?
+ SUBS x0, x2, 8 // k = kc - 8
+ B.LO 2f
+
+ # Main loop - 2 floats of A (8 bytes)
+ # 24 FMA + 6 LD64 A + 2 LDP B
+1:
+ LDR d0, [x3], 8
+ LDP q12, q13, [x5], 32
+ LDR d1, [x9], 8
+ LDR d2, [x10], 8
+ LDR d3, [x11], 8
+ LDR d4, [x12], 8
+ LDR d5, [x4], 8
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ SUBS x0, x0, 8
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ B.HS 1b
+
+2:
+ # Is there a remainder?- 1 floats of A (4 bytes)
+ TBNZ x0, 2, 4f
+3:
+ # Clamp
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 5f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x14
+ SUB x3, x3, x2 // a0 -= kc
+ STP q28, q29, [x13]
+ ADD x13, x13, x14
+ SUB x9, x9, x2 // a1 -= kc
+ STP q26, q27, [x18]
+ ADD x18, x18, x14
+ SUB x10, x10, x2 // a2 -= kc
+ STP q24, q25, [x17]
+ ADD x17, x17, x14
+ SUB x11, x11, x2 // a3 -= kc
+ STP q22, q23, [x16]
+ ADD x16, x16, x14
+ SUB x12, x12, x2 // a4 -= kc
+ STP q20, q21, [x6]
+ ADD x6, x6, x14
+ SUB x4, x4, x2 // a5 -= kc
+
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+4:
+ # Remainder- 1 floats of A (4 bytes)
+ LDR s0, [x3], 4
+ LDP q12, q13, [x5], 32
+ LDR s1, [x9], 4
+ LDR s2, [x10], 4
+ LDR s3, [x11], 4
+ LDR s4, [x12], 4
+ LDR s5, [x4], 4
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 3b
+
+ # Store odd width
+5:
+ TBZ x1, 2, 6f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+
+6:
+ TBZ x1, 1, 7f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+
+7:
+ TBZ x1, 0, 8f
+ STR s30, [x7]
+ STR s28, [x13]
+ STR s26, [x18]
+ STR s24, [x17]
+ STR s22, [x16]
+ STR s20, [x6]
+8:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 64
+ RET
+
+END_FUNCTION xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld64
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-gemminc/6x8-neon-ld64.c b/src/f32-gemminc/6x8-neon-ld64.c
new file mode 100644
index 0000000..f657dc4
--- /dev/null
+++ b/src/f32-gemminc/6x8-neon-ld64.c
@@ -0,0 +1,257 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemminc_ukernel_6x8__neon_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float* restrict a,
+ size_t a_stride,
+ const float* restrict w,
+ float* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const float*restrict acc,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 6);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ assert(acc != NULL);
+
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+ const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+ float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ a4 = a3;
+ c4 = c3;
+ }
+ const float* a5 = (const float*) ((uintptr_t) a4 + a_stride);
+ float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 6) {
+ a5 = a4;
+ c5 = c4;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc1x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc1x4567 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc2x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc2x4567 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc3x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc3x4567 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc4x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc4x4567 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc5x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc5x4567 = vld1q_f32(acc); acc += 4;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+ const float32x2_t va4 = vld1_f32(a4); a4 += 2;
+ const float32x2_t va5 = vld1_f32(a5); a5 += 2;
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+ vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123c0, va4, 0);
+ vacc5x0123 = vmlaq_lane_f32(vacc5x0123, vb0123c0, va5, 0);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+ vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567c0, va4, 0);
+ vacc5x4567 = vmlaq_lane_f32(vacc5x4567, vb4567c0, va5, 0);
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+ vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123c1, va4, 1);
+ vacc5x0123 = vmlaq_lane_f32(vacc5x0123, vb0123c1, va5, 1);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+ vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567c1, va4, 1);
+ vacc5x4567 = vmlaq_lane_f32(vacc5x4567, vb4567c1, va5, 1);
+ }
+ if XNN_UNLIKELY(k != 0) {
+ const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+ const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+ const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+ const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+ const float32x4_t va4 = vld1q_dup_f32(a4); a4 += 1;
+ const float32x4_t va5 = vld1q_dup_f32(a5); a5 += 1;
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+ vacc4x0123 = vmlaq_f32(vacc4x0123, va4, vb0123);
+ vacc5x0123 = vmlaq_f32(vacc5x0123, va5, vb0123);
+ vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+ vacc4x4567 = vmlaq_f32(vacc4x4567, va4, vb4567);
+ vacc5x4567 = vmlaq_f32(vacc5x4567, va5, vb4567);
+ }
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc4x0123 = vminq_f32(vacc4x0123, vmax);
+ vacc5x0123 = vminq_f32(vacc5x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+ vacc4x4567 = vminq_f32(vacc4x4567, vmax);
+ vacc5x4567 = vminq_f32(vacc5x4567, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc4x0123 = vmaxq_f32(vacc4x0123, vmin);
+ vacc5x0123 = vmaxq_f32(vacc5x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+ vacc4x4567 = vmaxq_f32(vacc4x4567, vmin);
+ vacc5x4567 = vmaxq_f32(vacc5x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c5, vacc5x0123);
+ vst1q_f32(c5 + 4, vacc5x4567);
+ c5 = (float*) ((uintptr_t) c5 + cn_stride);
+ vst1q_f32(c4, vacc4x0123);
+ vst1q_f32(c4 + 4, vacc4x4567);
+ c4 = (float*) ((uintptr_t) c4 + cn_stride);
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a5 = (const float*) ((uintptr_t) a5 - kc);
+ a4 = (const float*) ((uintptr_t) a4 - kc);
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+
+ } else {
+ if (nc & 4) {
+ vst1q_f32(c5, vacc5x0123); c5 += 4;
+ vst1q_f32(c4, vacc4x0123); c4 += 4;
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc5x0123 = vacc5x4567;
+ vacc4x0123 = vacc4x4567;
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc5x01 = vget_low_f32(vacc5x0123);
+ float32x2_t vacc4x01 = vget_low_f32(vacc4x0123);
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c5, vacc5x01); c5 += 2;
+ vst1_f32(c4, vacc4x01); c4 += 2;
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc5x01 = vget_high_f32(vacc5x0123);
+ vacc4x01 = vget_high_f32(vacc4x0123);
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c5, vacc5x01, 0);
+ vst1_lane_f32(c4, vacc4x01, 0);
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemminc/6x8-neonfma-ld64.c b/src/f32-gemminc/6x8-neonfma-ld64.c
new file mode 100644
index 0000000..22519ef
--- /dev/null
+++ b/src/f32-gemminc/6x8-neonfma-ld64.c
@@ -0,0 +1,299 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemminc_ukernel_6x8__neonfma_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float* restrict a,
+ size_t a_stride,
+ const float* restrict w,
+ float* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const float*restrict acc,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 6);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ assert(acc != NULL);
+
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+ const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+ float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ a4 = a3;
+ c4 = c3;
+ }
+ const float* a5 = (const float*) ((uintptr_t) a4 + a_stride);
+ float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 6) {
+ a5 = a4;
+ c5 = c4;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc1x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc1x4567 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc2x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc2x4567 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc3x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc3x4567 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc4x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc4x4567 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc5x0123 = vld1q_f32(acc); acc += 4;
+ float32x4_t vacc5x4567 = vld1q_f32(acc); acc += 4;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+ const float32x2_t va4 = vld1_f32(a4); a4 += 2;
+ const float32x2_t va5 = vld1_f32(a5); a5 += 2;
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+ vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123c0, va4, 0);
+ vacc5x0123 = vfmaq_lane_f32(vacc5x0123, vb0123c0, va5, 0);
+ vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+ vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+ vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+ vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+ vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567c0, va4, 0);
+ vacc5x4567 = vfmaq_lane_f32(vacc5x4567, vb4567c0, va5, 0);
+ #else
+ const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+ const float32x4_t va1c0 = vdupq_lane_f32(va1, 0);
+ const float32x4_t va2c0 = vdupq_lane_f32(va2, 0);
+ const float32x4_t va3c0 = vdupq_lane_f32(va3, 0);
+ const float32x4_t va4c0 = vdupq_lane_f32(va4, 0);
+ const float32x4_t va5c0 = vdupq_lane_f32(va5, 0);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+ vacc4x0123 = vfmaq_f32(vacc4x0123, va4c0, vb0123c0);
+ vacc5x0123 = vfmaq_f32(vacc5x0123, va5c0, vb0123c0);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+ vacc4x4567 = vfmaq_f32(vacc4x4567, va4c0, vb4567c0);
+ vacc5x4567 = vfmaq_f32(vacc5x4567, va5c0, vb4567c0);
+ #endif
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+ vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123c1, va4, 1);
+ vacc5x0123 = vfmaq_lane_f32(vacc5x0123, vb0123c1, va5, 1);
+ vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+ vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+ vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+ vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+ vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567c1, va4, 1);
+ vacc5x4567 = vfmaq_lane_f32(vacc5x4567, vb4567c1, va5, 1);
+ #else
+ const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+ const float32x4_t va1c1 = vdupq_lane_f32(va1, 1);
+ const float32x4_t va2c1 = vdupq_lane_f32(va2, 1);
+ const float32x4_t va3c1 = vdupq_lane_f32(va3, 1);
+ const float32x4_t va4c1 = vdupq_lane_f32(va4, 1);
+ const float32x4_t va5c1 = vdupq_lane_f32(va5, 1);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+ vacc4x0123 = vfmaq_f32(vacc4x0123, va4c1, vb0123c1);
+ vacc5x0123 = vfmaq_f32(vacc5x0123, va5c1, vb0123c1);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+ vacc4x4567 = vfmaq_f32(vacc4x4567, va4c1, vb4567c1);
+ vacc5x4567 = vfmaq_f32(vacc5x4567, va5c1, vb4567c1);
+ #endif
+ }
+ if XNN_UNLIKELY(k != 0) {
+ const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+ const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+ const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+ const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+ const float32x4_t va4 = vld1q_dup_f32(a4); a4 += 1;
+ const float32x4_t va5 = vld1q_dup_f32(a5); a5 += 1;
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+ vacc4x0123 = vfmaq_f32(vacc4x0123, va4, vb0123);
+ vacc5x0123 = vfmaq_f32(vacc5x0123, va5, vb0123);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+ vacc4x4567 = vfmaq_f32(vacc4x4567, va4, vb4567);
+ vacc5x4567 = vfmaq_f32(vacc5x4567, va5, vb4567);
+ }
+ const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max);
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc4x0123 = vminq_f32(vacc4x0123, vmax);
+ vacc5x0123 = vminq_f32(vacc5x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+ vacc4x4567 = vminq_f32(vacc4x4567, vmax);
+ vacc5x4567 = vminq_f32(vacc5x4567, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min);
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc4x0123 = vmaxq_f32(vacc4x0123, vmin);
+ vacc5x0123 = vmaxq_f32(vacc5x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+ vacc4x4567 = vmaxq_f32(vacc4x4567, vmin);
+ vacc5x4567 = vmaxq_f32(vacc5x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c5, vacc5x0123);
+ vst1q_f32(c5 + 4, vacc5x4567);
+ c5 = (float*) ((uintptr_t) c5 + cn_stride);
+ vst1q_f32(c4, vacc4x0123);
+ vst1q_f32(c4 + 4, vacc4x4567);
+ c4 = (float*) ((uintptr_t) c4 + cn_stride);
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a5 = (const float*) ((uintptr_t) a5 - kc);
+ a4 = (const float*) ((uintptr_t) a4 - kc);
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+
+ } else {
+ if (nc & 4) {
+ vst1q_f32(c5, vacc5x0123); c5 += 4;
+ vst1q_f32(c4, vacc4x0123); c4 += 4;
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc5x0123 = vacc5x4567;
+ vacc4x0123 = vacc4x4567;
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc5x01 = vget_low_f32(vacc5x0123);
+ float32x2_t vacc4x01 = vget_low_f32(vacc4x0123);
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c5, vacc5x01); c5 += 2;
+ vst1_f32(c4, vacc4x01); c4 += 2;
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc5x01 = vget_high_f32(vacc5x0123);
+ vacc4x01 = vget_high_f32(vacc4x0123);
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c5, vacc5x01, 0);
+ vst1_lane_f32(c4, vacc4x01, 0);
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-gemminc/6x8-psimd-loadsplat.c b/src/f32-gemminc/6x8-psimd-loadsplat.c
new file mode 100644
index 0000000..68cd252
--- /dev/null
+++ b/src/f32-gemminc/6x8-psimd-loadsplat.c
@@ -0,0 +1,236 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-loadsplat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemminc_ukernel_6x8__psimd_loadsplat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 6);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;  // row pointers; out-of-range rows alias the previous row
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+  const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+  float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 4) {
+    a4 = a3;
+    c4 = c3;
+  }
+  const float* a5 = (const float*) ((uintptr_t) a4 + a_stride);
+  float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 6) {
+    a5 = a4;
+    c5 = c4;
+  }
+
+  do {
+    psimd_f32 vacc0x0123 = psimd_load_f32(acc + 0);  // accumulators seeded from acc buffer
+    psimd_f32 vacc0x4567 = psimd_load_f32(acc + 4);
+    psimd_f32 vacc1x0123 = psimd_load_f32(acc + 8);
+    psimd_f32 vacc1x4567 = psimd_load_f32(acc + 12);
+    psimd_f32 vacc2x0123 = psimd_load_f32(acc + 16);
+    psimd_f32 vacc2x4567 = psimd_load_f32(acc + 20);
+    psimd_f32 vacc3x0123 = psimd_load_f32(acc + 24);
+    psimd_f32 vacc3x4567 = psimd_load_f32(acc + 28);
+    psimd_f32 vacc4x0123 = psimd_load_f32(acc + 32);
+    psimd_f32 vacc4x4567 = psimd_load_f32(acc + 36);
+    psimd_f32 vacc5x0123 = psimd_load_f32(acc + 40);
+    psimd_f32 vacc5x4567 = psimd_load_f32(acc + 44);
+    acc += 48;  // 6 rows x 8 columns of initial accumulators consumed
+
+    size_t k = kc;
+    do {
+      const psimd_f32 va0 = psimd_load_splat_f32(a0);  // broadcast one A element per row
+      a0 += 1;
+      const psimd_f32 va1 = psimd_load_splat_f32(a1);
+      a1 += 1;
+      const psimd_f32 va2 = psimd_load_splat_f32(a2);
+      a2 += 1;
+      const psimd_f32 va3 = psimd_load_splat_f32(a3);
+      a3 += 1;
+      const psimd_f32 va4 = psimd_load_splat_f32(a4);
+      a4 += 1;
+      const psimd_f32 va5 = psimd_load_splat_f32(a5);
+      a5 += 1;
+
+      const psimd_f32 vb0123 = psimd_load_f32(w);
+      const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+      w += 8;
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567);
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);  // was garbled "&para;ms" mojibake
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+    vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+    vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+    vacc4x0123 = psimd_min_f32(vacc4x0123, vmax);
+    vacc5x0123 = psimd_min_f32(vacc5x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+    vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+    vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+    vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+    vacc4x4567 = psimd_min_f32(vacc4x4567, vmax);
+    vacc5x4567 = psimd_min_f32(vacc5x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);  // was garbled "&para;ms" mojibake
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+    vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+    vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+    vacc4x0123 = psimd_max_f32(vacc4x0123, vmin);
+    vacc5x0123 = psimd_max_f32(vacc5x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+    vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+    vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+    vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+    vacc4x4567 = psimd_max_f32(vacc4x4567, vmin);
+    vacc5x4567 = psimd_max_f32(vacc5x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      psimd_store_f32(c5, vacc5x0123);
+      psimd_store_f32(c5 + 4, vacc5x4567);
+      c5 = (float*) ((uintptr_t) c5 + cn_stride);
+      psimd_store_f32(c4, vacc4x0123);
+      psimd_store_f32(c4 + 4, vacc4x4567);
+      c4 = (float*) ((uintptr_t) c4 + cn_stride);
+      psimd_store_f32(c3, vacc3x0123);
+      psimd_store_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      psimd_store_f32(c2, vacc2x0123);
+      psimd_store_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      psimd_store_f32(c1, vacc1x0123);
+      psimd_store_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a5 = (const float*) ((uintptr_t) a5 - kc);  // rewind A pointers for the next column tile
+      a4 = (const float*) ((uintptr_t) a4 - kc);
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        psimd_store_f32(c5, vacc5x0123);
+        psimd_store_f32(c4, vacc4x0123);
+        psimd_store_f32(c3, vacc3x0123);
+        psimd_store_f32(c2, vacc2x0123);
+        psimd_store_f32(c1, vacc1x0123);
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc5x0123 = vacc5x4567;
+        vacc4x0123 = vacc4x4567;
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c5 += 4;
+        c4 += 4;
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c5, vacc5x0123);
+        psimd_store2_f32(c4, vacc4x0123);
+        psimd_store2_f32(c3, vacc3x0123);
+        psimd_store2_f32(c2, vacc2x0123);
+        psimd_store2_f32(c1, vacc1x0123);
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc5x0123 = psimd_concat_hi_f32(vacc5x0123, vacc5x0123);
+        vacc4x0123 = psimd_concat_hi_f32(vacc4x0123, vacc4x0123);
+        vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+        vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+        vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c5 += 2;
+        c4 += 2;
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c5, vacc5x0123);
+        psimd_store1_f32(c4, vacc4x0123);
+        psimd_store1_f32(c3, vacc3x0123);
+        psimd_store1_f32(c2, vacc2x0123);
+        psimd_store1_f32(c1, vacc1x0123);
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/6x8-psimd-splat.c b/src/f32-gemminc/6x8-psimd-splat.c
new file mode 100644
index 0000000..e2230cf
--- /dev/null
+++ b/src/f32-gemminc/6x8-psimd-splat.c
@@ -0,0 +1,344 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-splat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemminc_ukernel_6x8__psimd_splat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    size_t a_stride,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const float*restrict acc,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 6);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+  assert(acc != NULL);
+
+  const float* a0 = a;  // row pointers; out-of-range rows alias the previous row
+  float* c0 = c;
+  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+  const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+  float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 4) {
+    a4 = a3;
+    c4 = c3;
+  }
+  const float* a5 = (const float*) ((uintptr_t) a4 + a_stride);
+  float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 6) {
+    a5 = a4;
+    c5 = c4;
+  }
+
+  do {
+    psimd_f32 vacc0x0123 = psimd_load_f32(acc + 0);  // accumulators seeded from acc buffer
+    psimd_f32 vacc0x4567 = psimd_load_f32(acc + 4);
+    psimd_f32 vacc1x0123 = psimd_load_f32(acc + 8);
+    psimd_f32 vacc1x4567 = psimd_load_f32(acc + 12);
+    psimd_f32 vacc2x0123 = psimd_load_f32(acc + 16);
+    psimd_f32 vacc2x4567 = psimd_load_f32(acc + 20);
+    psimd_f32 vacc3x0123 = psimd_load_f32(acc + 24);
+    psimd_f32 vacc3x4567 = psimd_load_f32(acc + 28);
+    psimd_f32 vacc4x0123 = psimd_load_f32(acc + 32);
+    psimd_f32 vacc4x4567 = psimd_load_f32(acc + 36);
+    psimd_f32 vacc5x0123 = psimd_load_f32(acc + 40);
+    psimd_f32 vacc5x4567 = psimd_load_f32(acc + 44);
+    acc += 48;  // 6 rows x 8 columns of initial accumulators consumed
+
+    size_t k = kc;
+    while (k >= 4 * sizeof(float)) {  // main loop: 4 K elements per iteration
+      const psimd_f32 va0 = psimd_load_f32(a0);
+      a0 += 4;
+      const psimd_f32 va1 = psimd_load_f32(a1);
+      a1 += 4;
+      const psimd_f32 va2 = psimd_load_f32(a2);
+      a2 += 4;
+      const psimd_f32 va3 = psimd_load_f32(a3);
+      a3 += 4;
+      const psimd_f32 va4 = psimd_load_f32(a4);
+      a4 += 4;
+      const psimd_f32 va5 = psimd_load_f32(a5);
+      a5 += 4;
+
+      const psimd_f32 va0c0 = psimd_splat0_f32(va0);
+      const psimd_f32 va1c0 = psimd_splat0_f32(va1);
+      const psimd_f32 va2c0 = psimd_splat0_f32(va2);
+      const psimd_f32 va3c0 = psimd_splat0_f32(va3);
+      const psimd_f32 va4c0 = psimd_splat0_f32(va4);
+      const psimd_f32 va5c0 = psimd_splat0_f32(va5);
+
+      const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+      const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c0, vb0123c0);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c0, vb0123c0);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c0, vb0123c0);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c0, vb0123c0);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4c0, vb0123c0);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5c0, vb0123c0);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c0, vb4567c0);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c0, vb4567c0);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c0, vb4567c0);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c0, vb4567c0);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4c0, vb4567c0);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5c0, vb4567c0);
+      const psimd_f32 va0c1 = psimd_splat1_f32(va0);
+      const psimd_f32 va1c1 = psimd_splat1_f32(va1);
+      const psimd_f32 va2c1 = psimd_splat1_f32(va2);
+      const psimd_f32 va3c1 = psimd_splat1_f32(va3);
+      const psimd_f32 va4c1 = psimd_splat1_f32(va4);
+      const psimd_f32 va5c1 = psimd_splat1_f32(va5);
+
+      const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+      const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c1, vb0123c1);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c1, vb0123c1);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c1, vb0123c1);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c1, vb0123c1);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4c1, vb0123c1);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5c1, vb0123c1);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c1, vb4567c1);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c1, vb4567c1);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c1, vb4567c1);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c1, vb4567c1);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4c1, vb4567c1);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5c1, vb4567c1);
+      const psimd_f32 va0c2 = psimd_splat2_f32(va0);
+      const psimd_f32 va1c2 = psimd_splat2_f32(va1);
+      const psimd_f32 va2c2 = psimd_splat2_f32(va2);
+      const psimd_f32 va3c2 = psimd_splat2_f32(va3);
+      const psimd_f32 va4c2 = psimd_splat2_f32(va4);
+      const psimd_f32 va5c2 = psimd_splat2_f32(va5);
+
+      const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+      const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c2, vb0123c2);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c2, vb0123c2);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c2, vb0123c2);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c2, vb0123c2);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4c2, vb0123c2);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5c2, vb0123c2);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c2, vb4567c2);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c2, vb4567c2);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c2, vb4567c2);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c2, vb4567c2);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4c2, vb4567c2);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5c2, vb4567c2);
+      const psimd_f32 va0c3 = psimd_splat3_f32(va0);
+      const psimd_f32 va1c3 = psimd_splat3_f32(va1);
+      const psimd_f32 va2c3 = psimd_splat3_f32(va2);
+      const psimd_f32 va3c3 = psimd_splat3_f32(va3);
+      const psimd_f32 va4c3 = psimd_splat3_f32(va4);
+      const psimd_f32 va5c3 = psimd_splat3_f32(va5);
+
+      const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+      const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c3, vb0123c3);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c3, vb0123c3);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c3, vb0123c3);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c3, vb0123c3);
+      vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4c3, vb0123c3);
+      vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5c3, vb0123c3);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c3, vb4567c3);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c3, vb4567c3);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c3, vb4567c3);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c3, vb4567c3);
+      vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4c3, vb4567c3);
+      vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5c3, vb4567c3);
+
+      w += 32;
+      k -= 4 * sizeof(float);
+    }
+    if XNN_UNLIKELY(k != 0) {  // remainder loop: one K element per iteration
+      do {
+        const psimd_f32 va0 = psimd_load_splat_f32(a0);
+        a0 += 1;
+        const psimd_f32 va1 = psimd_load_splat_f32(a1);
+        a1 += 1;
+        const psimd_f32 va2 = psimd_load_splat_f32(a2);
+        a2 += 1;
+        const psimd_f32 va3 = psimd_load_splat_f32(a3);
+        a3 += 1;
+        const psimd_f32 va4 = psimd_load_splat_f32(a4);
+        a4 += 1;
+        const psimd_f32 va5 = psimd_load_splat_f32(a5);
+        a5 += 1;
+
+        const psimd_f32 vb0123 = psimd_load_f32(w);
+        const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+        w += 8;
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+        vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+        vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+        vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+        vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123);
+        vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+        vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+        vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+        vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+        vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567);
+        vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567);
+
+        k -= sizeof(float);
+      } while (k != 0);
+    }
+
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);  // was garbled "&para;ms" mojibake
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+    vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+    vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+    vacc4x0123 = psimd_min_f32(vacc4x0123, vmax);
+    vacc5x0123 = psimd_min_f32(vacc5x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+    vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+    vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+    vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+    vacc4x4567 = psimd_min_f32(vacc4x4567, vmax);
+    vacc5x4567 = psimd_min_f32(vacc5x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);  // was garbled "&para;ms" mojibake
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+    vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+    vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+    vacc4x0123 = psimd_max_f32(vacc4x0123, vmin);
+    vacc5x0123 = psimd_max_f32(vacc5x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+    vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+    vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+    vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+    vacc4x4567 = psimd_max_f32(vacc4x4567, vmin);
+    vacc5x4567 = psimd_max_f32(vacc5x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      psimd_store_f32(c5, vacc5x0123);
+      psimd_store_f32(c5 + 4, vacc5x4567);
+      c5 = (float*) ((uintptr_t) c5 + cn_stride);
+      psimd_store_f32(c4, vacc4x0123);
+      psimd_store_f32(c4 + 4, vacc4x4567);
+      c4 = (float*) ((uintptr_t) c4 + cn_stride);
+      psimd_store_f32(c3, vacc3x0123);
+      psimd_store_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      psimd_store_f32(c2, vacc2x0123);
+      psimd_store_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      psimd_store_f32(c1, vacc1x0123);
+      psimd_store_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a5 = (const float*) ((uintptr_t) a5 - kc);  // rewind A pointers for the next column tile
+      a4 = (const float*) ((uintptr_t) a4 - kc);
+      a3 = (const float*) ((uintptr_t) a3 - kc);
+      a2 = (const float*) ((uintptr_t) a2 - kc);
+      a1 = (const float*) ((uintptr_t) a1 - kc);
+      a0 = (const float*) ((uintptr_t) a0 - kc);
+
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        psimd_store_f32(c5, vacc5x0123);
+        psimd_store_f32(c4, vacc4x0123);
+        psimd_store_f32(c3, vacc3x0123);
+        psimd_store_f32(c2, vacc2x0123);
+        psimd_store_f32(c1, vacc1x0123);
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc5x0123 = vacc5x4567;
+        vacc4x0123 = vacc4x4567;
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c5 += 4;
+        c4 += 4;
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c5, vacc5x0123);
+        psimd_store2_f32(c4, vacc4x0123);
+        psimd_store2_f32(c3, vacc3x0123);
+        psimd_store2_f32(c2, vacc2x0123);
+        psimd_store2_f32(c1, vacc1x0123);
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc5x0123 = psimd_concat_hi_f32(vacc5x0123, vacc5x0123);
+        vacc4x0123 = psimd_concat_hi_f32(vacc4x0123, vacc4x0123);
+        vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+        vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+        vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c5 += 2;
+        c4 += 2;
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c5, vacc5x0123);
+        psimd_store1_f32(c4, vacc4x0123);
+        psimd_store1_f32(c3, vacc3x0123);
+        psimd_store1_f32(c2, vacc2x0123);
+        psimd_store1_f32(c1, vacc1x0123);
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-gemminc/6x8s4-psimd.c b/src/f32-gemminc/6x8s4-psimd.c
new file mode 100644
index 0000000..78bac5d
--- /dev/null
+++ b/src/f32-gemminc/6x8s4-psimd.c
@@ -0,0 +1,342 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-gemm/psimd-s4.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_f32_gemminc_ukernel_6x8s4__psimd(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float*restrict a,
+ size_t a_stride,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const float*restrict acc,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 6);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+ assert(acc != NULL);
+
+ const float* a0 = a;
+ float* c0 = c;
+ const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+ const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
+ float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ a4 = a3;
+ c4 = c3;
+ }
+ const float* a5 = (const float*) ((uintptr_t) a4 + a_stride);
+ float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 6) {
+ a5 = a4;
+ c5 = c4;
+ }
+
+ do {
+ psimd_f32 vacc0x0123 = psimd_load_f32(acc + 0);
+ psimd_f32 vacc0x4567 = psimd_load_f32(acc + 4);
+ psimd_f32 vacc1x0123 = psimd_load_f32(acc + 8);
+ psimd_f32 vacc1x4567 = psimd_load_f32(acc + 12);
+ psimd_f32 vacc2x0123 = psimd_load_f32(acc + 16);
+ psimd_f32 vacc2x4567 = psimd_load_f32(acc + 20);
+ psimd_f32 vacc3x0123 = psimd_load_f32(acc + 24);
+ psimd_f32 vacc3x4567 = psimd_load_f32(acc + 28);
+ psimd_f32 vacc4x0123 = psimd_load_f32(acc + 32);
+ psimd_f32 vacc4x4567 = psimd_load_f32(acc + 36);
+ psimd_f32 vacc5x0123 = psimd_load_f32(acc + 40);
+ psimd_f32 vacc5x4567 = psimd_load_f32(acc + 44);
+ acc += 48;
+
+ size_t k = kc;
+ while (k >= 4 * sizeof(float)) {
+ psimd_f32 va0 = psimd_load_f32(a0);
+ a0 += 4;
+ psimd_f32 va1 = psimd_load_f32(a1);
+ a1 += 4;
+ psimd_f32 va2 = psimd_load_f32(a2);
+ a2 += 4;
+ psimd_f32 va3 = psimd_load_f32(a3);
+ a3 += 4;
+ psimd_f32 va4 = psimd_load_f32(a4);
+ a4 += 4;
+ psimd_f32 va5 = psimd_load_f32(a5);
+ a5 += 4;
+
+
+ const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+ const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c0);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c0);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c0);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c0);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123c0);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123c0);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c0);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c0);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c0);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c0);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567c0);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567c0);
+
+ va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+ va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+ va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+ va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+ va4 = __builtin_shufflevector(va4, va4, 1, 2, 3, 0);
+ va5 = __builtin_shufflevector(va5, va5, 1, 2, 3, 0);
+
+ const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+ const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c1);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c1);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c1);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c1);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123c1);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123c1);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c1);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c1);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c1);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c1);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567c1);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567c1);
+
+ va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+ va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+ va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+ va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+ va4 = __builtin_shufflevector(va4, va4, 1, 2, 3, 0);
+ va5 = __builtin_shufflevector(va5, va5, 1, 2, 3, 0);
+
+ const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+ const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c2);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c2);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c2);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c2);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123c2);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123c2);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c2);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c2);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c2);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c2);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567c2);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567c2);
+
+ va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+ va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+ va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+ va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+ va4 = __builtin_shufflevector(va4, va4, 1, 2, 3, 0);
+ va5 = __builtin_shufflevector(va5, va5, 1, 2, 3, 0);
+
+ const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+ const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c3);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c3);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c3);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c3);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123c3);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123c3);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c3);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c3);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c3);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c3);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567c3);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567c3);
+
+
+ w += 32;
+ k -= 4 * sizeof(float);
+ }
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ const psimd_f32 va0 = psimd_load_splat_f32(a0);
+ a0 += 1;
+ const psimd_f32 va1 = psimd_load_splat_f32(a1);
+ a1 += 1;
+ const psimd_f32 va2 = psimd_load_splat_f32(a2);
+ a2 += 1;
+ const psimd_f32 va3 = psimd_load_splat_f32(a3);
+ a3 += 1;
+ const psimd_f32 va4 = psimd_load_splat_f32(a4);
+ a4 += 1;
+ const psimd_f32 va5 = psimd_load_splat_f32(a5);
+ a5 += 1;
+
+ const psimd_f32 vb0123 = psimd_load_f32(w);
+ const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+ w += 8;
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567);
+
+ k -= sizeof(float);
+ } while (k != 0);
+ }
+
+  const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+ vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+ vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+ vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+ vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+ vacc4x0123 = psimd_min_f32(vacc4x0123, vmax);
+ vacc5x0123 = psimd_min_f32(vacc5x0123, vmax);
+ vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+ vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+ vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+ vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+ vacc4x4567 = psimd_min_f32(vacc4x4567, vmax);
+ vacc5x4567 = psimd_min_f32(vacc5x4567, vmax);
+
+  const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+ vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+ vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+ vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+ vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+ vacc4x0123 = psimd_max_f32(vacc4x0123, vmin);
+ vacc5x0123 = psimd_max_f32(vacc5x0123, vmin);
+ vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+ vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+ vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+ vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+ vacc4x4567 = psimd_max_f32(vacc4x4567, vmin);
+ vacc5x4567 = psimd_max_f32(vacc5x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ psimd_store_f32(c5, vacc5x0123);
+ psimd_store_f32(c5 + 4, vacc5x4567);
+ c5 = (float*) ((uintptr_t) c5 + cn_stride);
+ psimd_store_f32(c4, vacc4x0123);
+ psimd_store_f32(c4 + 4, vacc4x4567);
+ c4 = (float*) ((uintptr_t) c4 + cn_stride);
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ psimd_store_f32(c0, vacc0x0123);
+ psimd_store_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a5 = (const float*) ((uintptr_t) a5 - kc);
+ a4 = (const float*) ((uintptr_t) a4 - kc);
+ a3 = (const float*) ((uintptr_t) a3 - kc);
+ a2 = (const float*) ((uintptr_t) a2 - kc);
+ a1 = (const float*) ((uintptr_t) a1 - kc);
+ a0 = (const float*) ((uintptr_t) a0 - kc);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ psimd_store_f32(c5, vacc5x0123);
+ psimd_store_f32(c4, vacc4x0123);
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c0, vacc0x0123);
+
+ vacc5x0123 = vacc5x4567;
+ vacc4x0123 = vacc4x4567;
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+
+ c5 += 4;
+ c4 += 4;
+ c3 += 4;
+ c2 += 4;
+ c1 += 4;
+ c0 += 4;
+ }
+ if (nc & 2) {
+ psimd_store2_f32(c5, vacc5x0123);
+ psimd_store2_f32(c4, vacc4x0123);
+ psimd_store2_f32(c3, vacc3x0123);
+ psimd_store2_f32(c2, vacc2x0123);
+ psimd_store2_f32(c1, vacc1x0123);
+ psimd_store2_f32(c0, vacc0x0123);
+
+ vacc5x0123 = psimd_concat_hi_f32(vacc5x0123, vacc5x0123);
+ vacc4x0123 = psimd_concat_hi_f32(vacc4x0123, vacc4x0123);
+ vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+ vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+ vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+ vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+ c5 += 2;
+ c4 += 2;
+ c3 += 2;
+ c2 += 2;
+ c1 += 2;
+ c0 += 2;
+ }
+ if (nc & 1) {
+ psimd_store1_f32(c5, vacc5x0123);
+ psimd_store1_f32(c4, vacc4x0123);
+ psimd_store1_f32(c3, vacc3x0123);
+ psimd_store1_f32(c2, vacc2x0123);
+ psimd_store1_f32(c1, vacc1x0123);
+ psimd_store1_f32(c0, vacc0x0123);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-hswish/neon.c b/src/f32-hswish/neon.c
new file mode 100644
index 0000000..d29a206
--- /dev/null
+++ b/src/f32-hswish/neon.c
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/hswish.h>
+
+
+void xnn_f32_hswish_ukernel__neon(
+    size_t n,
+    const float* x,
+    float* y,
+    const union xnn_f32_hswish_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  // HardSwish: y = x * min(max(x * sixth + half, 0), one).
+  // With sixth = 1/6, half = 0.5, one = 1.0 (the values the scalar variant
+  // asserts) this is the canonical h-swish activation.
+  // (Fixed a mis-encoded '&' in the three params loads below.)
+  const float32x4_t vsixth = vld1q_dup_f32(&params->scalar.sixth);
+  const float32x4_t vhalf = vld1q_dup_f32(&params->scalar.half);
+  const float32x4_t vone = vld1q_dup_f32(&params->scalar.one);
+  const float32x4_t vzero = vdupq_n_f32(0.0f);
+
+  // n is a byte count: each iteration consumes 4 floats (16 bytes).
+  for (; n >= 16; n -= 16) {
+    const float32x4_t vx = vld1q_f32(x); x += 4;
+
+    const float32x4_t vt = vminq_f32(vmaxq_f32(vmlaq_f32(vhalf, vx, vsixth), vzero), vone);
+    const float32x4_t vy = vmulq_f32(vt, vx);
+
+    vst1q_f32(y, vy); y += 4;
+  }
+  if XNN_UNLIKELY(n != 0) {
+    // Remainder of 1-3 floats. A full 4-lane vector is loaded, reading past
+    // the last valid element -- presumably the input buffer is padded;
+    // NOTE(review): confirm against the caller's allocation contract.
+    const float32x4_t vx = vld1q_f32(x); x += 4;
+
+    const float32x4_t vt = vminq_f32(vmaxq_f32(vmlaq_f32(vhalf, vx, vsixth), vzero), vone);
+    const float32x4_t vy = vmulq_f32(vt, vx);
+
+    // Store only the valid lanes: two at a time, then one.
+    float32x2_t vy_lo = vget_low_f32(vy);
+    if (n & 8) {
+      vst1_f32(y, vy_lo); y += 2;
+      vy_lo = vget_high_f32(vy);
+    }
+    if (n & 4) {
+      vst1_lane_f32(y, vy_lo, 0);
+    }
+  }
+}
diff --git a/src/f32-hswish/neonfma.c b/src/f32-hswish/neonfma.c
new file mode 100644
index 0000000..828545a
--- /dev/null
+++ b/src/f32-hswish/neonfma.c
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/hswish.h>
+
+
+void xnn_f32_hswish_ukernel__neonfma(
+    size_t n,
+    const float* x,
+    float* y,
+    const union xnn_f32_hswish_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  // HardSwish: y = x * min(max(x * sixth + half, 0), one), using fused
+  // multiply-add (vfmaq_f32) for the x * sixth + half step.
+  // (Fixed a mis-encoded '&' in the three params loads below.)
+  const float32x4_t vsixth = vld1q_dup_f32(&params->scalar.sixth);
+  const float32x4_t vhalf = vld1q_dup_f32(&params->scalar.half);
+  const float32x4_t vone = vld1q_dup_f32(&params->scalar.one);
+  const float32x4_t vzero = vdupq_n_f32(0.0f);
+
+  // n is a byte count: each iteration consumes 4 floats (16 bytes).
+  for (; n >= 16; n -= 16) {
+    const float32x4_t vx = vld1q_f32(x); x += 4;
+
+    const float32x4_t vt = vminq_f32(vmaxq_f32(vfmaq_f32(vhalf, vx, vsixth), vzero), vone);
+    const float32x4_t vy = vmulq_f32(vt, vx);
+
+    vst1q_f32(y, vy); y += 4;
+  }
+  if XNN_UNLIKELY(n != 0) {
+    // Remainder of 1-3 floats. A full 4-lane vector is loaded, reading past
+    // the last valid element -- presumably the input buffer is padded;
+    // NOTE(review): confirm against the caller's allocation contract.
+    const float32x4_t vx = vld1q_f32(x); x += 4;
+
+    const float32x4_t vt = vminq_f32(vmaxq_f32(vfmaq_f32(vhalf, vx, vsixth), vzero), vone);
+    const float32x4_t vy = vmulq_f32(vt, vx);
+
+    // Store only the valid lanes: two at a time, then one.
+    float32x2_t vy_lo = vget_low_f32(vy);
+    if (n & 8) {
+      vst1_f32(y, vy_lo); y += 2;
+      vy_lo = vget_high_f32(vy);
+    }
+    if (n & 4) {
+      vst1_lane_f32(y, vy_lo, 0);
+    }
+  }
+}
diff --git a/src/f32-hswish/psimd.c b/src/f32-hswish/psimd.c
new file mode 100644
index 0000000..43758bf
--- /dev/null
+++ b/src/f32-hswish/psimd.c
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/hswish.h>
+
+
+void xnn_f32_hswish_ukernel__psimd(
+    size_t n,
+    const float* x,
+    float* y,
+    const union xnn_f32_hswish_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  // HardSwish: y = x * min(max(x * sixth + half, 0), one).
+  // (Fixed a mis-encoded '&' in the three params loads below.)
+  const psimd_f32 vsixth = psimd_load_splat_f32(&params->scalar.sixth);
+  const psimd_f32 vhalf = psimd_load_splat_f32(&params->scalar.half);
+  const psimd_f32 vone = psimd_load_splat_f32(&params->scalar.one);
+  const psimd_f32 vzero = psimd_splat_f32(0.0f);
+
+  // n is a byte count: each iteration consumes 4 floats (16 bytes).
+  for (; n >= 16; n -= 16) {
+    const psimd_f32 vx = psimd_load_f32(x);
+    x += 4;
+
+    const psimd_f32 vt = psimd_min_f32(psimd_max_f32(psimd_add_f32(psimd_mul_f32(vx, vsixth), vhalf), vzero), vone);
+    const psimd_f32 vy = psimd_mul_f32(vt, vx);
+
+    psimd_store_f32(y, vy);
+    y += 4;
+  }
+  if XNN_UNLIKELY(n != 0) {
+    // Remainder of 1-3 floats. A full vector is loaded, reading past the
+    // last valid element -- presumably the input buffer is padded;
+    // NOTE(review): confirm against the caller's allocation contract.
+    const psimd_f32 vx = psimd_load_f32(x);
+    x += 4;
+
+    const psimd_f32 vt = psimd_min_f32(psimd_max_f32(psimd_add_f32(psimd_mul_f32(vx, vsixth), vhalf), vzero), vone);
+    psimd_f32 vy = psimd_mul_f32(vt, vx);
+
+    // Store only the valid lanes: two at a time, then one.
+    if (n & 8) {
+      psimd_store2_f32(y, vy);
+      vy = psimd_concat_hi_f32(vy, vy);
+      y += 2;
+    }
+    if (n & 4) {
+      psimd_store1_f32(y, vy);
+    }
+  }
+}
diff --git a/src/f32-hswish/scalar.c b/src/f32-hswish/scalar.c
new file mode 100644
index 0000000..b4d251a
--- /dev/null
+++ b/src/f32-hswish/scalar.c
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/hswish.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_hswish_ukernel__scalar(
+    size_t n,
+    const float* x,
+    float* y,
+    const union xnn_f32_hswish_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  // HardSwish: y = x * min(max(x * sixth + half, 0), one), with the
+  // constants supplied through params.
+  const float scale = params->scalar.sixth;
+  const float offset = params->scalar.half;
+  const float upper = params->scalar.one;
+  assert(offset == 0.5f);
+  assert(upper == 1.0f);
+
+  // n counts bytes; one float (4 bytes) is consumed per iteration.
+  for (; n != 0; n -= sizeof(float)) {
+    const float value = *x++;
+
+    float gate = value * scale + offset;
+    gate = math_max_f32(gate, 0.0f);
+    gate = math_min_f32(gate, upper);
+
+    *y++ = gate * value;
+  }
+}
diff --git a/src/f32-hswish/sse.c b/src/f32-hswish/sse.c
new file mode 100644
index 0000000..d0f0e5f
--- /dev/null
+++ b/src/f32-hswish/sse.c
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/hswish.h>
+
+
+// HardSwish: y = x * min(max(x * sixth + half, 0), one), 4 floats per step.
+void xnn_f32_hswish_ukernel__sse(
+    size_t n,
+    const float* x,
+    float* y,
+    const union xnn_f32_hswish_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  // params->sse.* presumably hold each constant broadcast to 4 lanes and
+  // 16-byte aligned (required by _mm_load_ps) -- confirm in params setup.
+  const __m128 vsixth = _mm_load_ps(params->sse.sixth);
+  const __m128 vhalf = _mm_load_ps(params->sse.half);
+  const __m128 vone = _mm_load_ps(params->sse.one);
+  const __m128 vzero = _mm_setzero_ps();
+
+  // n is a byte count: each iteration consumes 4 floats (16 bytes).
+  for (; n >= 16; n -= 16) {
+    const __m128 vx = _mm_loadu_ps(x);
+    x += 4;
+
+    const __m128 vt = _mm_min_ps(_mm_max_ps(_mm_add_ps(_mm_mul_ps(vx, vsixth), vhalf), vzero), vone);
+    const __m128 vy = _mm_mul_ps(vt, vx);
+
+    _mm_storeu_ps(y, vy);
+    y += 4;
+  }
+  if (n != 0) {
+    // Remainder of 1-3 floats. A full 4-lane vector is loaded, reading past
+    // the last valid element -- presumably the buffer is padded;
+    // NOTE(review): confirm against the caller's allocation contract.
+    const __m128 vx = _mm_loadu_ps(x);
+    x += 4;
+
+    const __m128 vt = _mm_min_ps(_mm_max_ps(_mm_add_ps(_mm_mul_ps(vx, vsixth), vhalf), vzero), vone);
+    __m128 vy = _mm_mul_ps(vt, vx);
+
+    // Store only the valid lanes: low pair first, then a single lane.
+    if (n & 8) {
+      _mm_storel_pi((__m64*) y, vy);
+      vy = _mm_movehl_ps(vy, vy);
+      y += 2;
+    }
+    if (n & 4) {
+      _mm_store_ss(y, vy);
+    }
+  }
+}
diff --git a/src/f32-igemm/1x12-aarch64-neonfma-cortex-a53.S b/src/f32-igemm/1x12-aarch64-neonfma-cortex-a53.S
new file mode 100644
index 0000000..4761e86
--- /dev/null
+++ b/src/f32-igemm/1x12-aarch64-neonfma-cortex-a53.S
@@ -0,0 +1,373 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53(
+# size_t mr, (x0) - unused. mr = 1
+# size_t nc, x1
+# size_t kc, x2 / x0
+# size_t ks, x3 / x9
+# const float**restrict a, x4
+# const float*restrict w, x5
+# float*restrict c, x6
+# size_t cm_stride, (x7) - unused
+# size_t cn_stride, [sp] -> x10
+# size_t a_offset, [sp + 8] -> x11
+# const float* zero, [sp + 16] -> x12
+# const xnn_f32_output_params params [sp + 24] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x20 a0
+
+# C pointers
+# x6 c0
+
+# Vector register usage and GPR shadows
+# a0 v0 first set of A
+# a0 v1 second set of A
+# B v2 v3 v4 x14 x15 x16 first set of B
+# B v5 v6 v7 x17 x18 x7
+# B v23 v24 v25 x14 x15 x16 second set of B (same x as first set)
+# B v17 v18 v19 x17 x18 x7
+# C v20 v21 v22
+
+BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53
+
+ # Load cn_stride, a_offset
+ LDP x10, x11, [sp]
+
+ # Load zero, clamping params pointer
+ LDP x12, x8, [sp, 16]
+
+ # Load clamping_params values
+ LD2R {v30.4s, v31.4s}, [x8]
+
+ # Save x20,x21 on stack
+ STP x20, x21, [sp, -16]!
+0:
+ # Load initial bias from w into accumulators
+ LD1 {v20.16b, v21.16b, v22.16b}, [x5], 48
+
+ PRFM PLDL1KEEP, [x5]
+ PRFM PLDL1KEEP, [x5, 64]
+ PRFM PLDL1KEEP, [x5, 128]
+ PRFM PLDL1KEEP, [x5, 192]
+ PRFM PLDL1KEEP, [x5, 256]
+ PRFM PLDL1KEEP, [x5, 320]
+
+ MOV x9, x3 // p = ks
+
+1:
+ # Load next A pointer
+ LDR x20, [x4], 8
+
+ CMP x20, x12 // if a0 == zero
+ ADD x20, x20, x11 // a0 += a_offset
+ CSEL x20, x12, x20, EQ // a0 = zero, else += a0 + a_offset
+
+ # Is there at least 4 floats (16 bytes) for prologue + epilogue?
+ SUBS x0, x2, 16 // k = kc - 16
+ B.LO 4f
+
+ # Prologue - loads for first group of 6 fma
+
+ # Read first block of 4 A.
+ LDR d0, [x20], 8 // a0
+
+ LDR d2, [x5] // vb0x0123
+ LDR x14, [x5, 8]
+
+        LDR d3, [x5, 16]  // vb0x4567
+ LDR x15, [x5, 24]
+
+ LDR d4, [x5, 32] // vb0x89AB
+ LDR x16, [x5, 40]
+
+ LDR d5, [x5, 48] // vb1x0123
+ LDR x17, [x5, 56]
+
+        LDR d6, [x5, 64]  // vb1x4567
+ LDR x18, [x5, 72]
+
+ LDR d7, [x5, 80] // vb1x89AB
+ LDR x7, [x5, 88]
+ INS v2.d[1], x14
+ ADD x5, x5, 96
+
+ # Is there at least 4 floats (16 bytes) for main loop?
+ SUBS x0, x0, 16 // 4 floats for main loop
+ B.LO 3f
+
+ # Main loop - 4 floats of A (16 bytes)
+2:
+ # First group of 6 fma.
+ # A is loaded for 2nd group into v1
+
+ # BLOCK 0
+ LDR d1, [x20], 8 // a0
+ INS v3.d[1], x15
+ FMLA v20.4s, v2.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 192]
+
+ # BLOCK 1
+ INS v4.d[1], x16
+ FMLA v21.4s, v3.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+
+ # BLOCK 2
+ LDR d23, [x5] // vb0x0123
+ INS v5.d[1], x17
+ LDR x14, [x5, 8]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v22.4s, v4.4s, v0.s[0]
+
+ # BLOCK 3
+        LDR d24, [x5, 16]  // vb0x4567
+ INS v6.d[1], x18
+ LDR x15, [x5, 24]
+
+ # BLOCK 4
+ LDR d25, [x5, 32] // vb0x89AB
+ INS v7.d[1], x7
+ FMLA v20.4s, v5.4s, v0.s[1]
+ LDR x16, [x5, 40]
+
+ # BLOCK 5
+ LDR d17, [x5, 48] // vb1x0123
+ LDR x17, [x5, 56]
+ FMLA v21.4s, v6.4s, v0.s[1]
+
+ # BLOCK 6
+        LDR d18, [x5, 64]  // vb1x4567
+ LDR x18, [x5, 72]
+ FMLA v22.4s, v7.4s, v0.s[1]
+
+ # BLOCK 7
+ LDR d19, [x5, 80] // vb1x89AB
+ INS v23.d[1], x14 // v23 was loaded in block 2
+ LDR x7, [x5, 88]
+
+ # Second group of 6 fma.
+ # A is loaded for 1st group into v0
+
+ # BLOCK 0
+ LDR d0, [x20], 8 // a0
+ INS v24.d[1], x15
+ FMLA v20.4s, v23.4s, v1.s[0]
+
+ # BLOCK 1
+ INS v25.d[1], x16
+ FMLA v21.4s, v24.4s, v1.s[0]
+
+ # BLOCK 2
+ LDR d2, [x5, 96] // vb0x0123
+ INS v17.d[1], x17
+ LDR x14, [x5, 104]
+ FMLA v22.4s, v25.4s, v1.s[0]
+
+ # BLOCK 3
+        LDR d3, [x5, 112]  // vb0x4567
+ INS v18.d[1], x18
+ LDR x15, [x5, 120]
+
+ # BLOCK 4
+ LDR d4, [x5, 128] // vb0x89AB
+ INS v19.d[1], x7
+ FMLA v20.4s, v17.4s, v1.s[1]
+ LDR x16, [x5, 136]
+
+ # BLOCK 5
+ LDR d5, [x5, 144] // vb1x0123
+ LDR x17, [x5, 152]
+ FMLA v21.4s, v18.4s, v1.s[1]
+
+ # BLOCK 6
+        LDR d6, [x5, 160]  // vb1x4567
+ LDR x18, [x5, 168]
+ SUBS x0, x0, 16
+ FMLA v22.4s, v19.4s, v1.s[1]
+
+ # BLOCK 7
+ LDR d7, [x5, 176] // vb1x89AB
+ INS v2.d[1], x14
+ LDR x7, [x5, 184]
+ ADD x5, x5, 192
+ B.HS 2b
+
+ # Epilogue
+ # First block same as main loop. Second block has no loads.
+3:
+ # BLOCK 0
+ LDR d1, [x20], 8 // a0
+ INS v3.d[1], x15
+ FMLA v20.4s, v2.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 192]
+
+ # BLOCK 1
+ INS v4.d[1], x16
+ FMLA v21.4s, v3.4s, v0.s[0]
+ PRFM PLDL1KEEP, [x5, 256]
+
+ # BLOCK 2
+ LDR d23, [x5] // vb0x0123
+ INS v5.d[1], x17
+ LDR x14, [x5, 8]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v22.4s, v4.4s, v0.s[0]
+
+ # BLOCK 3
+        LDR d24, [x5, 16]  // vb0x4567
+ INS v6.d[1], x18
+ LDR x15, [x5, 24]
+
+ # BLOCK 4
+ LDR d25, [x5, 32] // vb0x89AB
+ INS v7.d[1], x7
+ FMLA v20.4s, v5.4s, v0.s[1]
+ LDR x16, [x5, 40]
+
+ # BLOCK 5
+ LDR d17, [x5, 48] // vb1x0123
+ LDR x17, [x5, 56]
+ FMLA v21.4s, v6.4s, v0.s[1]
+
+ # BLOCK 6
+        LDR d18, [x5, 64]  // vb1x4567
+ LDR x18, [x5, 72]
+ FMLA v22.4s, v7.4s, v0.s[1]
+
+ # BLOCK 7
+ LDR d19, [x5, 80] // vb1x89AB
+ INS v23.d[1], x14 // v23 was loaded in block 2
+ LDR x7, [x5, 88]
+ ADD x5, x5, 96
+
+ # Second group of 6 fma. 8 blocks of 4 cycles.
+ # Epilogue version does no loads
+
+ # BLOCK 0
+ INS v24.d[1], x15
+ FMLA v20.4s, v23.4s, v1.s[0]
+
+ # BLOCK 1
+ INS v25.d[1], x16
+ FMLA v21.4s, v24.4s, v1.s[0]
+
+ # BLOCK 2
+ INS v17.d[1], x17
+ FMLA v22.4s, v25.4s, v1.s[0]
+
+ # BLOCK 3
+ INS v18.d[1], x18
+
+ # BLOCK 4
+ INS v19.d[1], x7
+ FMLA v20.4s, v17.4s, v1.s[1]
+
+ # BLOCK 5
+ FMLA v21.4s, v18.4s, v1.s[1]
+
+ # BLOCK 6
+ FMLA v22.4s, v19.4s, v1.s[1]
+
+ # BLOCK 7
+
+4:
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBNZ x0, 3, 6f
+ # Is there a remainder?- 1 floats of A (4 bytes)
+ TBNZ x0, 2, 7f
+
+5:
+ # ks loop
+ SUBS x9, x9, 8 // ks -= MR * sizeof(void*)
+ B.NE 1b
+
+ # Clamp
+ FMIN v20.4s, v20.4s, v30.4s
+ FMIN v21.4s, v21.4s, v30.4s
+ FMIN v22.4s, v22.4s, v30.4s
+ FMAX v20.4s, v20.4s, v31.4s
+ FMAX v21.4s, v21.4s, v31.4s
+ FMAX v22.4s, v22.4s, v31.4s
+
+ # Store full 1 x 12
+ CMP x1, 12
+ B.LO 8f
+
+ ST1 {v20.16b, v21.16b, v22.16b}, [x6], x10
+
+ SUB x4, x4, x3 // a -= ks
+
+ # nc loop
+ SUBS x1, x1, 12
+ B.HI 0b
+
+ # Restore x20,x21 from stack
+ LDP x20, x21, [sp], 16
+ RET
+
+6:
+ # Remainder - 2 floats of A (8 bytes)
+ # Read first block of 4 A.
+ LDR d0, [x20], 8 // a0
+ LD1 {v2.16b, v3.16b, v4.16b}, [x5], 48
+ LD1 {v5.16b, v6.16b, v7.16b}, [x5], 48
+
+ # First block of 3 B
+ FMLA v20.4s, v2.4s, v0.s[0]
+ FMLA v21.4s, v3.4s, v0.s[0]
+ FMLA v22.4s, v4.4s, v0.s[0]
+
+ # Second block of 3 B
+ FMLA v20.4s, v5.4s, v0.s[1]
+ FMLA v21.4s, v6.4s, v0.s[1]
+ FMLA v22.4s, v7.4s, v0.s[1]
+
+ TBZ x0, 2, 5b
+7:
+ # Remainder - 1 float of A (4 bytes)
+ LDR s0, [x20], 4 // a0
+ LD1 {v2.16b, v3.16b, v4.16b}, [x5], 48
+
+ FMLA v20.4s, v2.4s, v0.s[0]
+ FMLA v21.4s, v3.4s, v0.s[0]
+ FMLA v22.4s, v4.4s, v0.s[0]
+ B 5b
+
+8:
+ # Store odd channels
+ TBZ x1, 3, 9f
+ STP q20, q21, [x6]
+ ADD x6, x6, 32
+ MOV v20.16b, v22.16b
+
+9:
+ TBZ x1, 2, 10f
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+
+10:
+ TBZ x1, 1, 11f
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+
+11:
+ TBZ x1, 0, 12f
+ STR s20, [x6]
+12:
+ # Restore x20,x21 from stack
+ LDP x20, x21, [sp], 16
+ RET
+
+END_FUNCTION xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-igemm/1x4-scalar.c b/src/f32-igemm/1x4-scalar.c
new file mode 100644
index 0000000..daf8757
--- /dev/null
+++ b/src/f32-igemm/1x4-scalar.c
@@ -0,0 +1,111 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_igemm_ukernel_1x4__scalar(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (1 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;
+
+  do {
+    // Seed the 1x4 output tile with the bias stored at the head of w.
+    float vacc0 = w[0];
+    float vacc1 = w[1];
+    float vacc2 = w[2];
+    float vacc3 = w[3];
+    w += 4;
+
+    size_t remaining = ks;
+    do {
+      // The sentinel 'zero' pointer is used verbatim; real row pointers
+      // are rebased by a_offset.
+      const float* restrict a0 = a[0];
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      a += 1;
+
+      // Accumulate one A element against a 4-wide slice of B per step.
+      for (size_t k = kc; k != 0; k -= sizeof(float)) {
+        const float va = *a0++;
+
+        vacc0 += va * w[0];
+        vacc1 += va * w[1];
+        vacc2 += va * w[2];
+        vacc3 += va * w[3];
+        w += 4;
+      }
+      remaining -= 1 * sizeof(void*);
+    } while (remaining != 0);
+
+    // Clamp the tile to [min, max] from the output params.
+    const float vmin = params->scalar.min;
+    vacc0 = math_max_f32(vacc0, vmin);
+    vacc1 = math_max_f32(vacc1, vmin);
+    vacc2 = math_max_f32(vacc2, vmin);
+    vacc3 = math_max_f32(vacc3, vmin);
+
+    const float vmax = params->scalar.max;
+    vacc0 = math_min_f32(vacc0, vmax);
+    vacc1 = math_min_f32(vacc1, vmax);
+    vacc2 = math_min_f32(vacc2, vmax);
+    vacc3 = math_min_f32(vacc3, vmax);
+
+    if XNN_LIKELY(nc >= 4) {
+      // Full tile: store all 4 columns and rewind the A pointer list.
+      c0[0] = vacc0;
+      c0[1] = vacc1;
+      c0[2] = vacc2;
+      c0[3] = vacc3;
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 4;
+    } else {
+      // Partial tile: store 2 columns, then possibly 1 leftover.
+      if (nc & 2) {
+        c0[0] = vacc0;
+        c0[1] = vacc1;
+        vacc0 = vacc2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        c0[0] = vacc0;
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/1x8-aarch64-neonfma-cortex-a57.S b/src/f32-igemm/1x8-aarch64-neonfma-cortex-a57.S
new file mode 100644
index 0000000..a589fbf
--- /dev/null
+++ b/src/f32-igemm/1x8-aarch64-neonfma-cortex-a57.S
@@ -0,0 +1,239 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a57(
+# size_t mr, (x0) - unused. mr = 1
+# size_t nc, x1
+# size_t kc, x2 / x0
+# size_t ks, x3 / x9
+# const float**restrict a, x4
+# const float*restrict w, x5
+# float*restrict c, x6
+# size_t cm_stride, (x7) - unused
+# size_t cn_stride, [sp] -> x10
+# size_t a_offset, [sp + 8] -> x11
+# const float* zero, [sp + 16] -> x12
+# const xnn_f32_output_params params [sp + 24] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x20 a0
+
+# C pointers
+# x6 c0
+
+# A57 based on A75 but with PRFM removed from main loop
+
+BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a57
+
+ # Load cn_stride, a_offset
+ LDP x10, x11, [sp]
+
+ # Load zero, clamping params pointer
+ LDP x12, x8, [sp, 16]
+
+ # Load clamping_params values
+ LD2R {v30.4s, v31.4s}, [x8]
+
+ # Save x20,x21 on stack
+ STP x20, x21, [sp, -16]!
+
+0:
+ # Load initial bias from w into accumulators
+ LDP q16, q17, [x5], 32
+ MOVI v18.4s, 0 // second set of C for pipelining FMLA
+ MOVI v19.4s, 0
+
+ MOV x9, x3 // p = ks
+
+1:
+ # Load next A pointer
+ LDR x20, [x4], 8
+
+ CMP x20, x12 // if a0 == zero
+ ADD x20, x20, x11 // a0 += a_offset
+ CSEL x20, x12, x20, EQ // a0 = zero, else += a0 + a_offset
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32 // k = kc
+ B.LO 4f
+
+ # 16 prologue
+ # Read first block of A and B.
+ LDP q20, q21, [x5], 32
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+ LDR q0, [x20], 16
+
+ # Is there at least 8. yes do main loop
+ SUBS x0, x0, 32
+ B.LO 3f
+
+ # Main loop - 8 floats of A (32 bytes)
+2:
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDR q1, [x20], 16
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ LDP q24, q25, [x5], 32
+ FMLA v18.4s, v26.4s, v0.s[3]
+ FMLA v19.4s, v27.4s, v0.s[3]
+ LDP q26, q27, [x5], 32
+
+ # Second block of 4. FMA for second 4, loads for 1st block of 4.
+ FMLA v16.4s, v20.4s, v1.s[0]
+ LDR q0, [x20], 16
+ FMLA v17.4s, v21.4s, v1.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v1.s[2]
+ FMLA v17.4s, v25.4s, v1.s[2]
+ LDP q24, q25, [x5], 32
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ SUBS x0, x0, 32
+ LDP q26, q27, [x5], 32
+ B.HS 2b
+
+3:
+ # Epilogue
+
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDR q1, [x20], 16
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ LDP q24, q25, [x5], 32
+ FMLA v18.4s, v26.4s, v0.s[3]
+ FMLA v19.4s, v27.4s, v0.s[3]
+ LDP q26, q27, [x5], 32
+
+ # Second block of 4. no loads
+ FMLA v16.4s, v20.4s, v1.s[0]
+ FMLA v17.4s, v21.4s, v1.s[0]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v16.4s, v24.4s, v1.s[2]
+ FMLA v17.4s, v25.4s, v1.s[2]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+
+4:
+ # Is there a remainder?- 4 floats of A (16 bytes)
+ TBNZ x0, 4, 6f
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBNZ x0, 3, 7f
+ # Is there a remainder?- 1 floats of A (4 bytes)
+ TBNZ x0, 2, 9f
+
+5:
+ # ks loop
+ SUBS x9, x9, 8 // ks -= MR * sizeof(void*)
+ B.NE 1b
+
+ FADD v16.4s, v16.4s, v18.4s
+ FADD v17.4s, v17.4s, v19.4s
+
+ # Clamp
+ FMIN v16.4s, v16.4s, v30.4s
+ FMIN v17.4s, v17.4s, v30.4s
+ FMAX v16.4s, v16.4s, v31.4s
+ FMAX v17.4s, v17.4s, v31.4s
+
+ # Store full 1 x 8
+ CMP x1, 8
+ B.LO 10f
+
+ STP q16, q17, [x6]
+ ADD x6, x6, x10
+
+ SUB x4, x4, x3 // a -= ks
+
+ # nc loop
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore x20,x21 from stack
+ LDP x20, x21, [sp], 16
+ RET
+
+6:
+ # Remainder- 4 floats of A (16 bytes)
+ LDP q20, q21, [x5], 32
+ LDR q0, [x20], 16
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v26.4s, v0.s[3]
+ FMLA v19.4s, v27.4s, v0.s[3]
+
+ TBZ x0, 3, 8f
+7:
+ # Remainder- 2 floats of A (8 bytes)
+ LDP q20, q21, [x5], 32
+ LDR d0, [x20], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+8:
+ TBZ x0, 2, 5b
+9:
+ # Remainder- 1 float of A (4 bytes)
+ LDP q20, q21, [x5], 32
+ LDR s0, [x20], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ B 5b
+
+10:
+ # Store odd channels
+ TBZ x1, 2, 11f
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+11:
+ TBZ x1, 1, 12f
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+12:
+ TBZ x1, 0, 13f
+ STR s16, [x6], 4
+13:
+ # Restore x20,x21 from stack
+ LDP x20, x21, [sp], 16
+ RET
+
+END_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a57
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-igemm/1x8-aarch64-neonfma-cortex-a75.S b/src/f32-igemm/1x8-aarch64-neonfma-cortex-a75.S
new file mode 100644
index 0000000..ec33ee3
--- /dev/null
+++ b/src/f32-igemm/1x8-aarch64-neonfma-cortex-a75.S
@@ -0,0 +1,243 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75(
+# size_t mr, (x0) - unused. mr = 1
+# size_t nc, x1
+# size_t kc, x2 / x0
+# size_t ks, x3 / x9
+# const float**restrict a, x4
+# const float*restrict w, x5
+# float*restrict c, x6
+# size_t cm_stride, (x7) - unused
+# size_t cn_stride, [sp] -> x10
+# size_t a_offset, [sp + 8] -> x11
+# const float* zero, [sp + 16] -> x12
+# const xnn_f32_output_params params [sp + 24] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x20 a0
+
+# C pointers
+# x6 c0
+
+BEGIN_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75
+
+ # Load cn_stride, a_offset
+ LDP x10, x11, [sp]
+
+ # Load zero, clamping params pointer
+ LDP x12, x8, [sp, 16]
+
+ # Load clamping_params values
+ LD2R {v30.4s, v31.4s}, [x8]
+
+ # Save x20,x21 on stack
+ STP x20, x21, [sp, -16]!
+
+0:
+ # Load initial bias from w into accumulators
+ LDP q16, q17, [x5], 32
+ MOVI v18.4s, 0 // second set of C for pipelining FMLA
+ MOVI v19.4s, 0
+
+ MOV x9, x3 // p = ks
+
+1:
+ # Load next A pointer
+ LDR x20, [x4], 8
+
+ CMP x20, x12 // if a0 == zero
+ ADD x20, x20, x11 // a0 += a_offset
+ CSEL x20, x12, x20, EQ // a0 = zero, else += a0 + a_offset
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32 // k = kc
+ B.LO 4f
+
+ # 16 prologue
+ # Read first block of A and B.
+ LDP q20, q21, [x5], 32
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+ LDR q0, [x20], 16
+
+ # Is there at least 8. yes do main loop
+ SUBS x0, x0, 32
+ B.LO 3f
+
+ # Main loop - 8 floats of A (32 bytes)
+2:
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDR q1, [x20], 16
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ LDP q24, q25, [x5], 32
+ PRFM PLDL1KEEP, [x5, 128]
+ FMLA v18.4s, v26.4s, v0.s[3]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v19.4s, v27.4s, v0.s[3]
+ LDP q26, q27, [x5], 32
+
+ # Second block of 4. FMA for second 4, loads for 1st block of 4.
+ FMLA v16.4s, v20.4s, v1.s[0]
+ LDR q0, [x20], 16
+ FMLA v17.4s, v21.4s, v1.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v1.s[2]
+ FMLA v17.4s, v25.4s, v1.s[2]
+ LDP q24, q25, [x5], 32
+ PRFM PLDL1KEEP, [x5, 128]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ SUBS x0, x0, 32
+ LDP q26, q27, [x5], 32
+ B.HS 2b
+
+3:
+ # Epilogue
+
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDR q1, [x20], 16
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ LDP q22, q23, [x5], 32
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ LDP q24, q25, [x5], 32
+ PRFM PLDL1KEEP, [x5, 128]
+ FMLA v18.4s, v26.4s, v0.s[3]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v19.4s, v27.4s, v0.s[3]
+ LDP q26, q27, [x5], 32
+
+ # Second block of 4. no loads
+ FMLA v16.4s, v20.4s, v1.s[0]
+ FMLA v17.4s, v21.4s, v1.s[0]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v16.4s, v24.4s, v1.s[2]
+ FMLA v17.4s, v25.4s, v1.s[2]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+
+4:
+ # Is there a remainder?- 4 floats of A (16 bytes)
+ TBNZ x0, 4, 6f
+ # Is there a remainder?- 2 floats of A (8 bytes)
+ TBNZ x0, 3, 7f
+ # Is there a remainder?- 1 floats of A (4 bytes)
+ TBNZ x0, 2, 9f
+
+5:
+ # ks loop
+ SUBS x9, x9, 8 // ks -= MR * sizeof(void*)
+ B.NE 1b
+
+ FADD v16.4s, v16.4s, v18.4s
+ FADD v17.4s, v17.4s, v19.4s
+
+ # Clamp
+ FMIN v16.4s, v16.4s, v30.4s
+ FMIN v17.4s, v17.4s, v30.4s
+ FMAX v16.4s, v16.4s, v31.4s
+ FMAX v17.4s, v17.4s, v31.4s
+
+ # Store full 1 x 8
+ CMP x1, 8
+ B.LO 10f
+
+ STP q16, q17, [x6]
+ ADD x6, x6, x10
+
+ SUB x4, x4, x3 // a -= ks
+
+ # nc loop
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore x20,x21 from stack
+ LDP x20, x21, [sp], 16
+ RET
+
+6:
+ # Remainder- 4 floats of A (16 bytes)
+ LDP q20, q21, [x5], 32
+ LDR q0, [x20], 16
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v26.4s, v0.s[3]
+ FMLA v19.4s, v27.4s, v0.s[3]
+
+ TBZ x0, 3, 8f
+7:
+ # Remainder- 2 floats of A (8 bytes)
+ LDP q20, q21, [x5], 32
+ LDR d0, [x20], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v22.4s, v0.s[1]
+ FMLA v19.4s, v23.4s, v0.s[1]
+8:
+ TBZ x0, 2, 5b
+9:
+ # Remainder- 1 float of A (4 bytes)
+ LDP q20, q21, [x5], 32
+ LDR s0, [x20], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ B 5b
+
+10:
+ # Store odd channels
+ TBZ x1, 2, 11f
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+11:
+ TBZ x1, 1, 12f
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+12:
+ TBZ x1, 0, 13f
+ STR s16, [x6], 4
+13:
+ # Restore x20,x21 from stack
+ LDP x20, x21, [sp], 16
+ RET
+
+END_FUNCTION xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-igemm/1x8-neon-ld64.c b/src/f32-igemm/1x8-neon-ld64.c
new file mode 100644
index 0000000..5ada84d
--- /dev/null
+++ b/src/f32-igemm/1x8-neon-ld64.c
@@ -0,0 +1,115 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+// 1x8 f32 indirect-GEMM (IGEMM) micro-kernel using NEON with 64-bit (2-float)
+// loads of A: produces a 1-row x 8-column output tile, accumulating over kc
+// bytes of K for each of the ks indirection pointers, then clamping to
+// [params->scalar.min, params->scalar.max].
+// `a` is an indirection buffer of pointers to A slices; entries equal to
+// `zero` act as padding rows and are NOT rebased by a_offset. `w` holds
+// packed weights: 8 initial accumulator values (presumably the bias --
+// confirm against the packing code) followed by 8-wide panels of B.
+void xnn_f32_igemm_ukernel_1x8__neon_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (1 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;
+
+  do {  // nc loop: one 8-column tile per iteration
+    // Initialize accumulators from the packed weights.
+    float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+
+    size_t p = ks;
+    do {  // ks loop: one indirection pointer per iteration
+      const float* restrict a0 = a[0];
+      // Rebase real A pointers by a_offset; the shared `zero` row stays as-is.
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      a += 1;
+
+      size_t k = kc;
+      // Main loop: 2 k-elements per iteration, broadcasting each A lane
+      // against an 8-wide B panel.
+      for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+        const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+
+        const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+        const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+        vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+        vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+        const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+        const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+        vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+        vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+      }
+      if XNN_UNLIKELY(k != 0) {
+        // Remainder: kc had an odd number of floats; process the last one.
+        const float32x4_t va0 = vld1q_dup_f32(a0);
+
+        const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+        const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+        vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+        vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+      }
+      p -= 1 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp the accumulators to the requested output range.
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Common case: store the full 8-wide row and rewind `a` for the next tile.
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 8;
+    } else {
+      // Tail: store 4, 2, then 1 column(s) according to the bits of nc.
+      if (nc & 4) {
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/1x8-psimd-loadsplat.c b/src/f32-igemm/1x8-psimd-loadsplat.c
new file mode 100644
index 0000000..e07f53f
--- /dev/null
+++ b/src/f32-igemm/1x8-psimd-loadsplat.c
@@ -0,0 +1,108 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/psimd-loadsplat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+// 1x8 f32 IGEMM micro-kernel, portable-SIMD (psimd), load-splat variant:
+// each iteration broadcasts a single A element and multiplies it against an
+// 8-wide panel of packed B, accumulating into two 4-lane vectors; the result
+// is clamped to [params->scalar.min, params->scalar.max].
+// Entries of the indirection buffer `a` equal to `zero` are padding rows and
+// are not rebased by a_offset.
+void xnn_f32_igemm_ukernel_1x8__psimd_loadsplat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (1 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;
+
+  do {  // nc loop: one 8-column tile per iteration
+    // Initialize accumulators from the packed weights.
+    psimd_f32 vacc0x0123 = psimd_load_f32(w);
+    psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+    w += 8;
+
+    size_t p = ks;
+    do {  // ks loop: one indirection pointer per iteration
+      const float* restrict a0 = a[0];
+      // Rebase real A pointers by a_offset; the shared `zero` row stays as-is.
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      a += 1;
+
+      size_t k = kc;
+      do {  // k loop: one element of A per iteration (load-splat)
+        const psimd_f32 vb0123 = psimd_load_f32(w);
+        const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+        w += 8;
+
+        const psimd_f32 va0 = psimd_load_splat_f32(a0);
+        a0 += 1;
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+        k -= sizeof(float);
+      } while (k != 0);
+      p -= 1 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp the accumulators to the requested output range.
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Common case: store the full 8-wide row and rewind `a` for the next tile.
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 8;
+    } else {
+      // Tail: store 4, 2, then 1 column(s) according to the bits of nc.
+      if (nc & 4) {
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/1x8-psimd-splat.c b/src/f32-igemm/1x8-psimd-splat.c
new file mode 100644
index 0000000..bd9d727
--- /dev/null
+++ b/src/f32-igemm/1x8-psimd-splat.c
@@ -0,0 +1,146 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/psimd-splat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+// 1x8 f32 IGEMM micro-kernel, portable-SIMD (psimd), splat variant:
+// the main loop is unrolled by 4 -- it loads 4 A elements as one vector and
+// splats each lane in turn against four 8-wide B panels; leftover k elements
+// fall back to a one-at-a-time load-splat loop. Results are clamped to
+// [params->scalar.min, params->scalar.max]. Entries of `a` equal to `zero`
+// are padding rows and are not rebased by a_offset.
+void xnn_f32_igemm_ukernel_1x8__psimd_splat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (1 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;
+
+  do {  // nc loop: one 8-column tile per iteration
+    // Initialize accumulators from the packed weights.
+    psimd_f32 vacc0x0123 = psimd_load_f32(w);
+    psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+    w += 8;
+
+    size_t p = ks;
+    do {  // ks loop: one indirection pointer per iteration
+      const float* restrict a0 = a[0];
+      // Rebase real A pointers by a_offset; the shared `zero` row stays as-is.
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      a += 1;
+
+      size_t k = kc;
+      // Main loop: 4 k-elements per iteration; splat each of the 4 lanes of
+      // va0 against its own 8-wide B panel.
+      while (k >= 4 * sizeof(float)) {
+        const psimd_f32 va0 = psimd_load_f32(a0);
+        a0 += 4;
+
+        const psimd_f32 va0c0 = psimd_splat0_f32(va0);
+
+        const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+        const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c0, vb0123c0);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c0, vb4567c0);
+        const psimd_f32 va0c1 = psimd_splat1_f32(va0);
+
+        const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+        const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c1, vb0123c1);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c1, vb4567c1);
+        const psimd_f32 va0c2 = psimd_splat2_f32(va0);
+
+        const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+        const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c2, vb0123c2);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c2, vb4567c2);
+        const psimd_f32 va0c3 = psimd_splat3_f32(va0);
+
+        const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+        const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c3, vb0123c3);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c3, vb4567c3);
+
+        w += 32;
+        k -= 4 * sizeof(float);
+      }
+      if XNN_UNLIKELY(k != 0) {
+        // Remainder: process the trailing 1-3 k-elements one at a time.
+        do {
+          const psimd_f32 vb0123 = psimd_load_f32(w);
+          const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+          w += 8;
+
+          const psimd_f32 va0 = psimd_load_splat_f32(a0);
+          a0 += 1;
+
+          vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+          vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+          k -= sizeof(float);
+        } while (k != 0);
+      }
+      p -= 1 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp the accumulators to the requested output range.
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Common case: store the full 8-wide row and rewind `a` for the next tile.
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 8;
+    } else {
+      // Tail: store 4, 2, then 1 column(s) according to the bits of nc.
+      if (nc & 4) {
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/1x8-sse-dup.c b/src/f32-igemm/1x8-sse-dup.c
new file mode 100644
index 0000000..93b0bdd
--- /dev/null
+++ b/src/f32-igemm/1x8-sse-dup.c
@@ -0,0 +1,150 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/sse-dup.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/igemm.h>
+
+
+// 1x8 f32 IGEMM micro-kernel, SSE, shuffle-dup variant: the main loop is
+// unrolled by 4 -- one unaligned load of 4 A elements, each lane broadcast
+// with _mm_shuffle_ps against its own 8-wide B panel; leftover k elements use
+// _mm_load1_ps one at a time. Results are clamped to [params->sse.min,
+// params->sse.max]. B loads use _mm_load_ps, so the packed weights `w` are
+// expected to be 16-byte aligned; C stores are unaligned. Entries of `a`
+// equal to `zero` are padding rows and are not rebased by a_offset.
+void xnn_f32_igemm_ukernel_1x8__sse_dup(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (1 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;
+
+  do {  // nc loop: one 8-column tile per iteration
+    // Initialize accumulators from the packed weights.
+    __m128 vacc0x0123 = _mm_load_ps(w);
+    __m128 vacc0x4567 = _mm_load_ps(w + 4);
+    w += 8;
+
+    size_t p = ks;
+    do {  // ks loop: one indirection pointer per iteration
+      const float* restrict a0 = a[0];
+      // Rebase real A pointers by a_offset; the shared `zero` row stays as-is.
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      a += 1;
+
+      size_t k = kc;
+      // Main loop: 4 k-elements per iteration; broadcast each lane of va0.
+      while (k >= 4 * sizeof(float)) {
+        const __m128 va0 = _mm_loadu_ps(a0);
+        a0 += 4;
+
+
+        const __m128 va0c0000 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 0, 0, 0));
+
+        const __m128 vb0123c0 = _mm_load_ps(w + 0);
+        const __m128 vb4567c0 = _mm_load_ps(w + 4);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c0000, vb0123c0));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c0000, vb4567c0));
+
+        const __m128 va0c1111 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(1, 1, 1, 1));
+
+        const __m128 vb0123c1 = _mm_load_ps(w + 8);
+        const __m128 vb4567c1 = _mm_load_ps(w + 12);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c1111, vb0123c1));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c1111, vb4567c1));
+
+        const __m128 va0c2222 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(2, 2, 2, 2));
+
+        const __m128 vb0123c2 = _mm_load_ps(w + 16);
+        const __m128 vb4567c2 = _mm_load_ps(w + 20);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c2222, vb0123c2));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c2222, vb4567c2));
+
+        const __m128 va0c3333 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(3, 3, 3, 3));
+
+        const __m128 vb0123c3 = _mm_load_ps(w + 24);
+        const __m128 vb4567c3 = _mm_load_ps(w + 28);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c3333, vb0123c3));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c3333, vb4567c3));
+
+        w += 32;
+        k -= 4 * sizeof(float);
+      }
+      if XNN_UNLIKELY(k != 0) {
+        // Remainder: process the trailing 1-3 k-elements one at a time.
+        do {
+          const __m128 vb0123 = _mm_load_ps(w);
+          const __m128 vb4567 = _mm_load_ps(w + 4);
+          w += 8;
+
+          const __m128 va0 = _mm_load1_ps(a0);
+          a0 += 1;
+
+          vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+          vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+          k -= sizeof(float);
+        } while (k != 0);
+      }
+      p -= 1 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp the accumulators to the requested output range.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Common case: store the full 8-wide row and rewind `a` for the next tile.
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 8;
+    } else {
+      // Tail: store 4, 2, then 1 column(s) according to the bits of nc.
+      if (nc & 4) {
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/1x8-sse-load1.c b/src/f32-igemm/1x8-sse-load1.c
new file mode 100644
index 0000000..f410378
--- /dev/null
+++ b/src/f32-igemm/1x8-sse-load1.c
@@ -0,0 +1,108 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/sse-load1.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/igemm.h>
+
+
+// 1x8 f32 IGEMM micro-kernel, SSE, load1 variant: each k iteration broadcasts
+// a single A element with _mm_load1_ps against an 8-wide panel of packed B.
+// Results are clamped to [params->sse.min, params->sse.max]. B loads use
+// _mm_load_ps, so the packed weights `w` are expected to be 16-byte aligned;
+// C stores are unaligned. Entries of `a` equal to `zero` are padding rows and
+// are not rebased by a_offset.
+void xnn_f32_igemm_ukernel_1x8__sse_load1(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (1 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;
+
+  do {  // nc loop: one 8-column tile per iteration
+    // Initialize accumulators from the packed weights.
+    __m128 vacc0x0123 = _mm_load_ps(w);
+    __m128 vacc0x4567 = _mm_load_ps(w + 4);
+    w += 8;
+
+    size_t p = ks;
+    do {  // ks loop: one indirection pointer per iteration
+      const float* restrict a0 = a[0];
+      // Rebase real A pointers by a_offset; the shared `zero` row stays as-is.
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      a += 1;
+
+      size_t k = kc;
+      do {  // k loop: one element of A per iteration
+        const __m128 vb0123 = _mm_load_ps(w);
+        const __m128 vb4567 = _mm_load_ps(w + 4);
+        w += 8;
+
+        const __m128 va0 = _mm_load1_ps(a0);
+        a0 += 1;
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+        k -= sizeof(float);
+      } while (k != 0);
+      p -= 1 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp the accumulators to the requested output range.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Common case: store the full 8-wide row and rewind `a` for the next tile.
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 8;
+    } else {
+      // Tail: store 4, 2, then 1 column(s) according to the bits of nc.
+      if (nc & 4) {
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/1x8s4-psimd.c b/src/f32-igemm/1x8s4-psimd.c
new file mode 100644
index 0000000..6e5007e
--- /dev/null
+++ b/src/f32-igemm/1x8s4-psimd.c
@@ -0,0 +1,149 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/psimd-s4.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+// 1x8s4 f32 IGEMM micro-kernel, portable-SIMD (psimd), lane-shift variant:
+// instead of splatting A lanes, the main loop multiplies the whole 4-lane A
+// vector against each B panel and rotates A left by one lane between panels
+// (the "s4" scheme) -- this only produces correct results because the weights
+// were packed to match this rotation (presumably by the s4 packing routine;
+// confirm against the packing code). Leftover k elements use a scalar
+// load-splat loop. Results are clamped to [params->scalar.min,
+// params->scalar.max]. Entries of `a` equal to `zero` are padding rows and
+// are not rebased by a_offset.
+void xnn_f32_igemm_ukernel_1x8s4__psimd(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (1 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;
+
+  do {  // nc loop: one 8-column tile per iteration
+    // Initialize accumulators from the packed weights.
+    psimd_f32 vacc0x0123 = psimd_load_f32(w);
+    psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+    w += 8;
+
+    size_t p = ks;
+    do {  // ks loop: one indirection pointer per iteration
+      const float* restrict a0 = a[0];
+      // Rebase real A pointers by a_offset; the shared `zero` row stays as-is.
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      a += 1;
+
+      size_t k = kc;
+      // Main loop: 4 k-elements per iteration; rotate va0 one lane left
+      // between the four B panels.
+      while (k >= 4 * sizeof(float)) {
+        psimd_f32 va0 = psimd_load_f32(a0);
+        a0 += 4;
+
+
+        const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+        const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c0);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c0);
+
+        va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+
+        const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+        const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c1);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c1);
+
+        va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+
+        const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+        const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c2);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c2);
+
+        va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+
+        const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+        const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c3);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c3);
+
+
+        w += 32;
+        k -= 4 * sizeof(float);
+      }
+      if XNN_UNLIKELY(k != 0) {
+        // Remainder: process the trailing 1-3 k-elements one at a time.
+        do {
+          const psimd_f32 vb0123 = psimd_load_f32(w);
+          const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+          w += 8;
+
+          const psimd_f32 va0 = psimd_load_splat_f32(a0);
+          a0 += 1;
+
+          vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+          vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+          k -= sizeof(float);
+        } while (k != 0);
+      }
+      p -= 1 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp the accumulators to the requested output range.
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Common case: store the full 8-wide row and rewind `a` for the next tile.
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 8;
+    } else {
+      // Tail: store 4, 2, then 1 column(s) according to the bits of nc.
+      if (nc & 4) {
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/1x8s4-sse.c b/src/f32-igemm/1x8s4-sse.c
new file mode 100644
index 0000000..a42f014
--- /dev/null
+++ b/src/f32-igemm/1x8s4-sse.c
@@ -0,0 +1,149 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/sse-shuffle.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/igemm.h>
+
+
+// 1x8s4 f32 IGEMM micro-kernel, SSE, lane-shift variant: the main loop
+// multiplies the whole 4-lane A vector against each B panel and rotates A one
+// lane with _mm_shuffle_ps(..., _MM_SHUFFLE(0, 3, 2, 1)) between panels (the
+// "s4" scheme) -- correct only because the weights were packed to match this
+// rotation (presumably by the s4 packing routine; confirm against the packing
+// code). Leftover k elements use _mm_load1_ps one at a time. Results are
+// clamped to [params->sse.min, params->sse.max]; `w` is expected to be
+// 16-byte aligned for _mm_load_ps. Entries of `a` equal to `zero` are
+// padding rows and are not rebased by a_offset.
+void xnn_f32_igemm_ukernel_1x8s4__sse(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (1 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;
+
+  do {  // nc loop: one 8-column tile per iteration
+    // Initialize accumulators from the packed weights.
+    __m128 vacc0x0123 = _mm_load_ps(w);
+    __m128 vacc0x4567 = _mm_load_ps(w + 4);
+    w += 8;
+
+    size_t p = ks;
+    do {  // ks loop: one indirection pointer per iteration
+      const float* restrict a0 = a[0];
+      // Rebase real A pointers by a_offset; the shared `zero` row stays as-is.
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      a += 1;
+
+      size_t k = kc;
+      // Main loop: 4 k-elements per iteration; rotate va0 one lane between
+      // the four B panels.
+      while (k >= 4 * sizeof(float)) {
+        __m128 va0 = _mm_loadu_ps(a0);
+        a0 += 4;
+
+
+        const __m128 vb0123c0 = _mm_load_ps(w + 0);
+        const __m128 vb4567c0 = _mm_load_ps(w + 4);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c0));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c0));
+
+        va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+
+        const __m128 vb0123c1 = _mm_load_ps(w + 8);
+        const __m128 vb4567c1 = _mm_load_ps(w + 12);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c1));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c1));
+
+        va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+
+        const __m128 vb0123c2 = _mm_load_ps(w + 16);
+        const __m128 vb4567c2 = _mm_load_ps(w + 20);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c2));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c2));
+
+        va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+
+        const __m128 vb0123c3 = _mm_load_ps(w + 24);
+        const __m128 vb4567c3 = _mm_load_ps(w + 28);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c3));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c3));
+
+
+        w += 32;
+        k -= 4 * sizeof(float);
+      }
+      if XNN_UNLIKELY(k != 0) {
+        // Remainder: process the trailing 1-3 k-elements one at a time.
+        do {
+          const __m128 vb0123 = _mm_load_ps(w);
+          const __m128 vb4567 = _mm_load_ps(w + 4);
+          w += 8;
+
+          const __m128 va0 = _mm_load1_ps(a0);
+          a0 += 1;
+
+          vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+          vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+          k -= sizeof(float);
+        } while (k != 0);
+      }
+      p -= 1 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp the accumulators to the requested output range.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {
+      // Common case: store the full 8-wide row and rewind `a` for the next tile.
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 8;
+    } else {
+      // Tail: store 4, 2, then 1 column(s) according to the bits of nc.
+      if (nc & 4) {
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc0x0123 = vacc0x4567;
+
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/2x4-scalar.c b/src/f32-igemm/2x4-scalar.c
new file mode 100644
index 0000000..9f94be8
--- /dev/null
+++ b/src/f32-igemm/2x4-scalar.c
@@ -0,0 +1,146 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+
+
+// 2x4 f32 IGEMM micro-kernel, portable scalar variant: computes a 2-row x
+// 4-column output tile. The indirection buffer `a` supplies 2 pointers per
+// ks step; entries equal to `zero` are padding rows and are not rebased by
+// a_offset. Row-1 accumulators start as copies of row 0's initial values
+// (both rows share the same 4 packed initial values -- presumably the bias).
+// When mr < 2, c1 is aliased to c0, so the later c0 stores overwrite the
+// redundant row-1 stores and only the real row survives.
+void xnn_f32_igemm_ukernel_2x4__scalar(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 2);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (2 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  // With mr == 1, alias c1 to c0 so row-1 stores are harmlessly overwritten.
+  if XNN_UNPREDICTABLE(mr != 2) {
+    c1 = c0;
+  }
+
+  do {  // nc loop: one 4-column tile per iteration
+    // Initialize both rows' accumulators from the 4 packed initial values.
+    float vacc00 = w[0];
+    float vacc01 = w[1];
+    float vacc02 = w[2];
+    float vacc03 = w[3];
+    float vacc10 = vacc00;
+    float vacc11 = vacc01;
+    float vacc12 = vacc02;
+    float vacc13 = vacc03;
+    w += 4;
+
+    size_t p = ks;
+    do {  // ks loop: two indirection pointers (one per row) per iteration
+      const float* restrict a0 = a[0];
+      // Rebase real A pointers by a_offset; the shared `zero` row stays as-is.
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      const float* restrict a1 = a[1];
+      if XNN_UNPREDICTABLE(a1 != zero) {
+        a1 = (const float*) ((uintptr_t) a1 + a_offset);
+      }
+      a += 2;
+
+      size_t k = kc;
+      do {  // k loop: one multiply-accumulate column per iteration
+        const float va0 = *a0++;
+        const float va1 = *a1++;
+
+        const float vb0 = w[0];
+        const float vb1 = w[1];
+        const float vb2 = w[2];
+        const float vb3 = w[3];
+        w += 4;
+
+        vacc00 += va0 * vb0;
+        vacc01 += va0 * vb1;
+        vacc02 += va0 * vb2;
+        vacc03 += va0 * vb3;
+        vacc10 += va1 * vb0;
+        vacc11 += va1 * vb1;
+        vacc12 += va1 * vb2;
+        vacc13 += va1 * vb3;
+
+        k -= sizeof(float);
+      } while (k != 0);
+      p -= 2 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp all 8 accumulators to [params->scalar.min, params->scalar.max].
+    const float vmin = params->scalar.min;
+    vacc00 = math_max_f32(vacc00, vmin);
+    vacc01 = math_max_f32(vacc01, vmin);
+    vacc02 = math_max_f32(vacc02, vmin);
+    vacc03 = math_max_f32(vacc03, vmin);
+    vacc10 = math_max_f32(vacc10, vmin);
+    vacc11 = math_max_f32(vacc11, vmin);
+    vacc12 = math_max_f32(vacc12, vmin);
+    vacc13 = math_max_f32(vacc13, vmin);
+
+    const float vmax = params->scalar.max;
+    vacc00 = math_min_f32(vacc00, vmax);
+    vacc01 = math_min_f32(vacc01, vmax);
+    vacc02 = math_min_f32(vacc02, vmax);
+    vacc03 = math_min_f32(vacc03, vmax);
+    vacc10 = math_min_f32(vacc10, vmax);
+    vacc11 = math_min_f32(vacc11, vmax);
+    vacc12 = math_min_f32(vacc12, vmax);
+    vacc13 = math_min_f32(vacc13, vmax);
+
+    if XNN_LIKELY(nc >= 4) {
+      // Common case: store full 4-wide rows (row 1 first, so an aliased c1 is
+      // overwritten by c0 when mr == 1) and rewind `a` for the next tile.
+      c1[0] = vacc10;
+      c1[1] = vacc11;
+      c1[2] = vacc12;
+      c1[3] = vacc13;
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      c0[0] = vacc00;
+      c0[1] = vacc01;
+      c0[2] = vacc02;
+      c0[3] = vacc03;
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 4;
+    } else {
+      // Tail: store 2 then 1 column(s) according to the bits of nc.
+      if (nc & 2) {
+        c1[0] = vacc10;
+        c1[1] = vacc11;
+        vacc10 = vacc12;
+        c1 += 2;
+        c0[0] = vacc00;
+        c0[1] = vacc01;
+        vacc00 = vacc02;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        c1[0] = vacc10;
+        c0[0] = vacc00;
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x12-aarch64-neonfma-cortex-a53.S b/src/f32-igemm/4x12-aarch64-neonfma-cortex-a53.S
new file mode 100644
index 0000000..16927f8
--- /dev/null
+++ b/src/f32-igemm/4x12-aarch64-neonfma-cortex-a53.S
@@ -0,0 +1,616 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# size_t ks, x3 / x9
+# const float**restrict a, x4
+# const float*restrict w, x5
+# float*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x10
+# size_t a_offset, [sp + 8] -> x11
+# const float* zero, [sp + 16] -> x12
+# const xnn_f32_output_params params [sp + 24] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x22 a0
+# x23 a1
+# x24 a2
+# x25 a3
+
+# C pointers
+# x6 c0
+# x20 c1
+# x21 c2
+# x7 c3 / cm_stride
+
+# Vector register usage and GPR shadows
+# a0 v0 first set of A
+# a1 v0[1] x13
+# a2 v1
+# a3 v1[1] x8
+# a0 v2 second set of A
+# a1 v2[1] x13
+# a2 v3
+# a3 v3[1] x8
+# B v4 v5 v6 x14 x15 x16 first set of B
+# B v7 v8 v9 x17 x18 x19
+# B v14 v15 v16 x14 x15 x16 second set of B (same x as first set)
+# B v17 v18 v19 x17 x18 x19
+# C v20 v21 v22
+# C v23 v24 v25
+# C v26 v27 v28
+# C v29 v30 v31
+# Clamp v10 v11
+# v12 to v13 unused.
+
+BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53
+
+        # 4x12 f32 IGEMM micro-kernel tuned for Cortex-A53: software-pipelined
+        # schedule pairing one 64-bit LDR + INS per 128-bit vector with FMLAs
+        # so loads dual-issue with arithmetic.
+        # NOTE(review): this kernel uses x18, the platform register on some
+        # ABIs (e.g. Windows, Darwin) — confirm those targets are excluded.
+
+        # Load cn_stride, a_offset
+        LDP x10, x11, [sp]
+
+        # Load zero, clamping params pointer
+        LDP x12, x8, [sp, 16]
+
+        # Save x19-x25 on stack
+        STR x19, [sp, -112]!
+        STP x20, x21, [sp, 16]
+        STP x22, x23, [sp, 32]
+        STP x24, x25, [sp, 48]
+
+        # Save d8-d11,d14,d15 on stack
+        STP d8, d9, [sp, 64]
+        STP d10, d11, [sp, 80]
+        STP d14, d15, [sp, 96]
+
+        # Load clamping_params values
+        # v10 = first param (upper bound, used by FMIN below),
+        # v11 = second param (lower bound, used by FMAX below)
+        LD2R {v10.4s, v11.4s}, [x8]
+
+        # Clamp C pointers
+        ADD x20, x6, x7 // c1 = c0 + cm_stride
+
+        CMP x0, 2 // if mr < 2
+        CSEL x20, x6, x20, LO // c1 = c0
+
+        ADD x21, x20, x7 // c2 = c1 + cm_stride
+        // if mr <= 2
+
+        CSEL x21, x20, x21, LS // c2 = c1
+
+        ADD x7, x21, x7 // c3 = c2 + cm_stride
+        CMP x0, 4 // if mr < 4
+        CSEL x7, x21, x7, LO // c3 = c2
+
+0:
+        # Load initial bias from w into accumulators
+        LD1 {v20.16b, v21.16b, v22.16b}, [x5], 48
+        MOV v23.16b, v20.16b
+        PRFM PLDL1KEEP, [x5]
+        MOV v24.16b, v21.16b
+        PRFM PLDL1KEEP, [x5, 64]
+        MOV v25.16b, v22.16b
+        PRFM PLDL1KEEP, [x5, 128]
+        MOV v26.16b, v20.16b
+        PRFM PLDL1KEEP, [x5, 192]
+        MOV v27.16b, v21.16b
+        PRFM PLDL1KEEP, [x5, 256]
+        MOV v28.16b, v22.16b
+        PRFM PLDL1KEEP, [x5, 320]
+        MOV v29.16b, v20.16b
+        MOV v30.16b, v21.16b
+        MOV v31.16b, v22.16b
+
+        MOV x9, x3 // p = ks
+
+1:
+        # Load next 4 A pointers
+        LDP x22, x23, [x4], 16
+        LDP x24, x25, [x4], 16
+
+        # Indirection: pointers equal to `zero` select the zero buffer and
+        # are NOT adjusted by a_offset.
+        CMP x22, x12 // if a0 == zero
+        ADD x22, x22, x11 // a0 += a_offset
+        CSEL x22, x12, x22, EQ // a0 = zero, else a0 += a_offset
+        CMP x23, x12 // if a1 == zero
+        ADD x23, x23, x11 // a1 += a_offset
+        CSEL x23, x12, x23, EQ // a1 = zero, else a1 += a_offset
+        CMP x24, x12 // if a2 == zero
+        ADD x24, x24, x11 // a2 += a_offset
+        CSEL x24, x12, x24, EQ // a2 = zero, else a2 += a_offset
+        CMP x25, x12 // if a3 == zero
+        ADD x25, x25, x11 // a3 += a_offset
+        CSEL x25, x12, x25, EQ // a3 = zero, else a3 += a_offset
+
+        # Is there at least 4 floats (16 bytes) for prologue + epilogue?
+        SUBS x0, x2, 16 // k = kc - 16
+        B.LO 4f
+
+        SUBS x0, x0, 16 // 4 floats for main loop
+
+        # Prologue - loads for first group of 24 FMA
+
+        # Read first block of 4 A.
+        LDR d0, [x22], 8 // a0
+        LDR x13, [x23], 8 // a1
+        LDR d1, [x24], 8 // a2
+        LDR x8, [x25], 8 // a3
+
+        LDR d4, [x5] // vb0x0123
+        LDR x14, [x5, 8]
+
+        LDR d5, [x5, 16] // vb0x4567
+        LDR x15, [x5, 24]
+
+        LDR d6, [x5, 32] // vb0x89AB
+        LDR x16, [x5, 40]
+
+        LDR d7, [x5, 48] // vb1x0123
+        INS v0.d[1], x13
+        LDR x17, [x5, 56]
+
+        LDR d8, [x5, 64] // vb1x4567
+        INS v1.d[1], x8
+        LDR x18, [x5, 72]
+
+        LDR d9, [x5, 80] // vb1x89AB
+        LDR x19, [x5, 88]
+        INS v4.d[1], x14
+        ADD x5, x5, 96
+
+        # Is there at least 4 floats (16 bytes) for main loop?
+        B.LO 3f
+
+        # Main loop - 4 floats of A (16 bytes)
+2:
+        # First group of 24 fma. 8 blocks of 4 cycles. LDR + 3 FMA
+        # A is loaded for 2nd group into v2/v3
+        # INS is 4 blocks (16 cycles) after load
+
+        # BLOCK 0
+        LDR d2, [x22], 8 // a0
+        INS v5.d[1], x15
+        FMLA v20.4s, v4.4s, v0.s[0]
+        LDR x13, [x23], 8 // a1
+        FMLA v23.4s, v4.4s, v0.s[2]
+        PRFM PLDL1KEEP, [x5, 192]
+        FMLA v26.4s, v4.4s, v1.s[0]
+
+        # BLOCK 1
+        LDR d3, [x24], 8 // a2
+        INS v6.d[1], x16
+        FMLA v29.4s, v4.4s, v1.s[2]
+        LDR x8, [x25], 8 // a3
+        FMLA v21.4s, v5.4s, v0.s[0]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v24.4s, v5.4s, v0.s[2]
+
+        # BLOCK 2
+        LDR d14, [x5] // vb0x0123
+        INS v7.d[1], x17
+        FMLA v27.4s, v5.4s, v1.s[0]
+        LDR x14, [x5, 8]
+        FMLA v30.4s, v5.4s, v1.s[2]
+        PRFM PLDL1KEEP, [x5, 320]
+        FMLA v22.4s, v6.4s, v0.s[0]
+
+        # BLOCK 3
+        LDR d15, [x5, 16] // vb0x4567
+        INS v8.d[1], x18
+        FMLA v25.4s, v6.4s, v0.s[2]
+        LDR x15, [x5, 24]
+        FMLA v28.4s, v6.4s, v1.s[0]
+        FMLA v31.4s, v6.4s, v1.s[2]
+
+        # BLOCK 4
+        LDR d16, [x5, 32] // vb0x89AB
+        INS v9.d[1], x19
+        FMLA v20.4s, v7.4s, v0.s[1]
+        LDR x16, [x5, 40]
+        FMLA v23.4s, v7.4s, v0.s[3]
+        FMLA v26.4s, v7.4s, v1.s[1]
+
+        # BLOCK 5
+        LDR d17, [x5, 48] // vb1x0123
+        INS v2.d[1], x13 // a1 was loaded in block 0
+        FMLA v29.4s, v7.4s, v1.s[3]
+        LDR x17, [x5, 56]
+        FMLA v21.4s, v8.4s, v0.s[1]
+        FMLA v24.4s, v8.4s, v0.s[3]
+
+        # BLOCK 6
+        LDR d18, [x5, 64] // vb1x4567
+        INS v3.d[1], x8 // a3 was loaded in block 1
+        FMLA v27.4s, v8.4s, v1.s[1]
+        LDR x18, [x5, 72]
+        FMLA v30.4s, v8.4s, v1.s[3]
+        FMLA v22.4s, v9.4s, v0.s[1]
+
+        # BLOCK 7
+        LDR d19, [x5, 80] // vb1x89AB
+        INS v14.d[1], x14 // v14 was loaded in block 2
+        FMLA v25.4s, v9.4s, v0.s[3]
+        LDR x19, [x5, 88]
+        FMLA v28.4s, v9.4s, v1.s[1]
+        FMLA v31.4s, v9.4s, v1.s[3]
+
+        # Second group of 24 fma. 8 blocks of 4 cycles. LDR + 3 FMA
+        # A is loaded for 1st group into v0/v1
+
+        # BLOCK 0
+        LDR d0, [x22], 8 // a0
+        INS v15.d[1], x15
+        FMLA v20.4s, v14.4s, v2.s[0]
+        LDR x13, [x23], 8 // a1
+        FMLA v23.4s, v14.4s, v2.s[2]
+        FMLA v26.4s, v14.4s, v3.s[0]
+
+        # BLOCK 1
+        LDR d1, [x24], 8 // a2
+        INS v16.d[1], x16
+        FMLA v29.4s, v14.4s, v3.s[2]
+        LDR x8, [x25], 8 // a3
+        FMLA v21.4s, v15.4s, v2.s[0]
+        FMLA v24.4s, v15.4s, v2.s[2]
+
+        # BLOCK 2
+        LDR d4, [x5, 96] // vb0x0123
+        INS v17.d[1], x17
+        FMLA v27.4s, v15.4s, v3.s[0]
+        LDR x14, [x5, 104]
+        FMLA v30.4s, v15.4s, v3.s[2]
+        FMLA v22.4s, v16.4s, v2.s[0]
+
+        # BLOCK 3
+        LDR d5, [x5, 112] // vb0x4567
+        INS v18.d[1], x18
+        FMLA v25.4s, v16.4s, v2.s[2]
+        LDR x15, [x5, 120]
+        FMLA v28.4s, v16.4s, v3.s[0]
+        FMLA v31.4s, v16.4s, v3.s[2]
+
+        # BLOCK 4
+        LDR d6, [x5, 128] // vb0x89AB
+        INS v19.d[1], x19
+        FMLA v20.4s, v17.4s, v2.s[1]
+        LDR x16, [x5, 136]
+        FMLA v23.4s, v17.4s, v2.s[3]
+        FMLA v26.4s, v17.4s, v3.s[1]
+
+        # BLOCK 5
+        LDR d7, [x5, 144] // vb1x0123
+        INS v0.d[1], x13 // a1
+        FMLA v29.4s, v17.4s, v3.s[3]
+        LDR x17, [x5, 152]
+        FMLA v21.4s, v18.4s, v2.s[1]
+        FMLA v24.4s, v18.4s, v2.s[3]
+
+        # BLOCK 6
+        LDR d8, [x5, 160] // vb1x4567
+        INS v1.d[1], x8 // a3
+        FMLA v27.4s, v18.4s, v3.s[1]
+        LDR x18, [x5, 168]
+        FMLA v30.4s, v18.4s, v3.s[3]
+        SUBS x0, x0, 16
+        FMLA v22.4s, v19.4s, v2.s[1]
+
+        # BLOCK 7
+        LDR d9, [x5, 176] // vb1x89AB
+        INS v4.d[1], x14
+        FMLA v25.4s, v19.4s, v2.s[3]
+        LDR x19, [x5, 184]
+        FMLA v28.4s, v19.4s, v3.s[1]
+        ADD x5, x5, 192
+        FMLA v31.4s, v19.4s, v3.s[3]
+        B.HS 2b
+
+        # Epilogue
+        # First block same as main loop. Second block has no loads.
+3:
+        # BLOCK 0
+        LDR d2, [x22], 8 // a0
+        INS v5.d[1], x15
+        FMLA v20.4s, v4.4s, v0.s[0]
+        LDR x13, [x23], 8 // a1
+        FMLA v23.4s, v4.4s, v0.s[2]
+        PRFM PLDL1KEEP, [x5, 192]
+        FMLA v26.4s, v4.4s, v1.s[0]
+
+        # BLOCK 1
+        LDR d3, [x24], 8 // a2
+        INS v6.d[1], x16
+        FMLA v29.4s, v4.4s, v1.s[2]
+        LDR x8, [x25], 8 // a3
+        FMLA v21.4s, v5.4s, v0.s[0]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v24.4s, v5.4s, v0.s[2]
+
+        # BLOCK 2
+        LDR d14, [x5] // vb0x0123
+        INS v7.d[1], x17
+        FMLA v27.4s, v5.4s, v1.s[0]
+        LDR x14, [x5, 8]
+        FMLA v30.4s, v5.4s, v1.s[2]
+        PRFM PLDL1KEEP, [x5, 320]
+        FMLA v22.4s, v6.4s, v0.s[0]
+
+        # BLOCK 3
+        LDR d15, [x5, 16] // vb0x4567
+        INS v8.d[1], x18
+        FMLA v25.4s, v6.4s, v0.s[2]
+        LDR x15, [x5, 24]
+        FMLA v28.4s, v6.4s, v1.s[0]
+        FMLA v31.4s, v6.4s, v1.s[2]
+
+        # BLOCK 4
+        LDR d16, [x5, 32] // vb0x89AB
+        INS v9.d[1], x19
+        FMLA v20.4s, v7.4s, v0.s[1]
+        LDR x16, [x5, 40]
+        FMLA v23.4s, v7.4s, v0.s[3]
+        FMLA v26.4s, v7.4s, v1.s[1]
+
+        # BLOCK 5
+        LDR d17, [x5, 48] // vb1x0123
+        INS v2.d[1], x13 // a1 was loaded in block 0
+        FMLA v29.4s, v7.4s, v1.s[3]
+        LDR x17, [x5, 56]
+        FMLA v21.4s, v8.4s, v0.s[1]
+        FMLA v24.4s, v8.4s, v0.s[3]
+
+        # BLOCK 6
+        LDR d18, [x5, 64] // vb1x4567
+        INS v3.d[1], x8 // a3 was loaded in block 1
+        FMLA v27.4s, v8.4s, v1.s[1]
+        LDR x18, [x5, 72]
+        FMLA v30.4s, v8.4s, v1.s[3]
+        FMLA v22.4s, v9.4s, v0.s[1]
+
+        # BLOCK 7
+        LDR d19, [x5, 80] // vb1x89AB
+        INS v14.d[1], x14 // v14 was loaded in block 2
+        FMLA v25.4s, v9.4s, v0.s[3]
+        LDR x19, [x5, 88]
+        FMLA v28.4s, v9.4s, v1.s[1]
+        ADD x5, x5, 96
+        FMLA v31.4s, v9.4s, v1.s[3]
+
+        # Second group of 24 fma. 8 blocks of 4 cycles.
+        # Epilogue version does no loads
+
+        # BLOCK 0
+        INS v15.d[1], x15
+        FMLA v20.4s, v14.4s, v2.s[0]
+        FMLA v23.4s, v14.4s, v2.s[2]
+        FMLA v26.4s, v14.4s, v3.s[0]
+
+        # BLOCK 1
+        INS v16.d[1], x16
+        FMLA v29.4s, v14.4s, v3.s[2]
+        FMLA v21.4s, v15.4s, v2.s[0]
+        FMLA v24.4s, v15.4s, v2.s[2]
+
+        # BLOCK 2
+        INS v17.d[1], x17
+        FMLA v27.4s, v15.4s, v3.s[0]
+        FMLA v30.4s, v15.4s, v3.s[2]
+        FMLA v22.4s, v16.4s, v2.s[0]
+
+        # BLOCK 3
+        INS v18.d[1], x18
+        FMLA v25.4s, v16.4s, v2.s[2]
+        FMLA v28.4s, v16.4s, v3.s[0]
+        FMLA v31.4s, v16.4s, v3.s[2]
+
+        # BLOCK 4
+        INS v19.d[1], x19
+        FMLA v20.4s, v17.4s, v2.s[1]
+        FMLA v23.4s, v17.4s, v2.s[3]
+        FMLA v26.4s, v17.4s, v3.s[1]
+
+        # BLOCK 5
+        FMLA v29.4s, v17.4s, v3.s[3]
+        FMLA v21.4s, v18.4s, v2.s[1]
+        FMLA v24.4s, v18.4s, v2.s[3]
+
+        # BLOCK 6
+        FMLA v27.4s, v18.4s, v3.s[1]
+        FMLA v30.4s, v18.4s, v3.s[3]
+        FMLA v22.4s, v19.4s, v2.s[1]
+
+        # BLOCK 7
+        FMLA v25.4s, v19.4s, v2.s[3]
+        FMLA v28.4s, v19.4s, v3.s[1]
+        FMLA v31.4s, v19.4s, v3.s[3]
+
+4:
+        # Is there a remainder? - 2 floats of A (8 bytes)
+        TBNZ x0, 3, 6f
+        # Is there a remainder? - 1 float of A (4 bytes)
+        TBNZ x0, 2, 7f
+
+5:
+        # ks loop
+        SUBS x9, x9, 32 // ks -= MR * sizeof(void*)
+        B.NE 1b
+
+        # Clamp
+        FMIN v20.4s, v20.4s, v10.4s
+        FMIN v21.4s, v21.4s, v10.4s
+        FMIN v22.4s, v22.4s, v10.4s
+        FMIN v23.4s, v23.4s, v10.4s
+        FMIN v24.4s, v24.4s, v10.4s
+        FMIN v25.4s, v25.4s, v10.4s
+        FMIN v26.4s, v26.4s, v10.4s
+        FMIN v27.4s, v27.4s, v10.4s
+        FMIN v28.4s, v28.4s, v10.4s
+        FMIN v29.4s, v29.4s, v10.4s
+        FMIN v30.4s, v30.4s, v10.4s
+        FMIN v31.4s, v31.4s, v10.4s
+        FMAX v20.4s, v20.4s, v11.4s
+        FMAX v21.4s, v21.4s, v11.4s
+        FMAX v22.4s, v22.4s, v11.4s
+        FMAX v23.4s, v23.4s, v11.4s
+        FMAX v24.4s, v24.4s, v11.4s
+        FMAX v25.4s, v25.4s, v11.4s
+        FMAX v26.4s, v26.4s, v11.4s
+        FMAX v27.4s, v27.4s, v11.4s
+        FMAX v28.4s, v28.4s, v11.4s
+        FMAX v29.4s, v29.4s, v11.4s
+        FMAX v30.4s, v30.4s, v11.4s
+        FMAX v31.4s, v31.4s, v11.4s
+
+        # Store full 4 x 12
+        CMP x1, 12
+        B.LO 8f
+
+        ST1 {v29.16b, v30.16b, v31.16b}, [x7], x10
+        ST1 {v26.16b, v27.16b, v28.16b}, [x21], x10
+        ST1 {v23.16b, v24.16b, v25.16b}, [x20], x10
+        ST1 {v20.16b, v21.16b, v22.16b}, [x6], x10
+
+        SUB x4, x4, x3 // a -= ks
+
+        # nc loop
+        SUBS x1, x1, 12
+        B.HI 0b
+
+        # Restore d8-d11,d14,d15 from stack
+        LDP d14, d15, [sp, 96]
+        LDP d10, d11, [sp, 80]
+        LDP d8, d9, [sp, 64]
+
+        # Restore x19-x25 from stack
+        LDP x24, x25, [sp, 48]
+        LDP x22, x23, [sp, 32]
+        LDP x20, x21, [sp, 16]
+        LDR x19, [sp], 112
+        RET
+
+6:
+        # Remainder - 2 floats of A (8 bytes)
+        # Read first block of 4 A.
+        LDR d0, [x22], 8 // a0
+        LDR d1, [x23], 8 // a1
+        LDR d2, [x24], 8 // a2
+        LDR d3, [x25], 8 // a3
+        LD1 {v4.16b, v5.16b, v6.16b}, [x5], 48
+        LD1 {v7.16b, v8.16b, v9.16b}, [x5], 48
+
+        # First block of 3 B
+        FMLA v20.4s, v4.4s, v0.s[0]
+        FMLA v23.4s, v4.4s, v1.s[0]
+        FMLA v26.4s, v4.4s, v2.s[0]
+        FMLA v29.4s, v4.4s, v3.s[0]
+        FMLA v21.4s, v5.4s, v0.s[0]
+        FMLA v24.4s, v5.4s, v1.s[0]
+        FMLA v27.4s, v5.4s, v2.s[0]
+        FMLA v30.4s, v5.4s, v3.s[0]
+        FMLA v22.4s, v6.4s, v0.s[0]
+        FMLA v25.4s, v6.4s, v1.s[0]
+        FMLA v28.4s, v6.4s, v2.s[0]
+        FMLA v31.4s, v6.4s, v3.s[0]
+
+        # Second block of 3 B
+        FMLA v20.4s, v7.4s, v0.s[1]
+        FMLA v23.4s, v7.4s, v1.s[1]
+        FMLA v26.4s, v7.4s, v2.s[1]
+        FMLA v29.4s, v7.4s, v3.s[1]
+        FMLA v21.4s, v8.4s, v0.s[1]
+        FMLA v24.4s, v8.4s, v1.s[1]
+        FMLA v27.4s, v8.4s, v2.s[1]
+        FMLA v30.4s, v8.4s, v3.s[1]
+        FMLA v22.4s, v9.4s, v0.s[1]
+        FMLA v25.4s, v9.4s, v1.s[1]
+        FMLA v28.4s, v9.4s, v2.s[1]
+        FMLA v31.4s, v9.4s, v3.s[1]
+
+        TBZ x0, 2, 5b
+7:
+        # Remainder - 1 float of A (4 bytes)
+        LDR s0, [x22], 4 // a0
+        LDR s1, [x23], 4 // a1
+        LDR s2, [x24], 4 // a2
+        LDR s3, [x25], 4 // a3
+        LD1 {v4.16b, v5.16b, v6.16b}, [x5], 48
+
+        FMLA v20.4s, v4.4s, v0.s[0]
+        FMLA v23.4s, v4.4s, v1.s[0]
+        FMLA v26.4s, v4.4s, v2.s[0]
+        FMLA v29.4s, v4.4s, v3.s[0]
+        FMLA v21.4s, v5.4s, v0.s[0]
+        FMLA v24.4s, v5.4s, v1.s[0]
+        FMLA v27.4s, v5.4s, v2.s[0]
+        FMLA v30.4s, v5.4s, v3.s[0]
+        FMLA v22.4s, v6.4s, v0.s[0]
+        FMLA v25.4s, v6.4s, v1.s[0]
+        FMLA v28.4s, v6.4s, v2.s[0]
+        FMLA v31.4s, v6.4s, v3.s[0]
+        B 5b
+
+8:
+        # Store odd channels (nc < 12): emit 8/4/2/1 columns per set bit of nc
+        TBZ x1, 3, 9f
+        STP q29, q30, [x7], 32
+        MOV v29.16b, v31.16b
+        STP q26, q27, [x21], 32
+        MOV v26.16b, v28.16b
+        STP q23, q24, [x20], 32
+        MOV v23.16b, v25.16b
+        STP q20, q21, [x6], 32
+        MOV v20.16b, v22.16b
+
+9:
+        # 4 remaining channels (bit 2 of nc)
+        TBZ x1, 2, 10f
+        STR q29, [x7], 16
+        MOV v29.16b, v30.16b
+        STR q26, [x21], 16
+        MOV v26.16b, v27.16b
+        STR q23, [x20], 16
+        MOV v23.16b, v24.16b
+        STR q20, [x6], 16
+        MOV v20.16b, v21.16b
+
+10:
+        # 2 remaining channels (bit 1 of nc)
+        TBZ x1, 1, 11f
+        STR d29, [x7], 8
+        DUP d29, v29.d[1]
+        STR d26, [x21], 8
+        DUP d26, v26.d[1]
+        STR d23, [x20], 8
+        DUP d23, v23.d[1]
+        STR d20, [x6], 8
+        DUP d20, v20.d[1]
+
+11:
+        # 1 remaining channel (bit 0 of nc)
+        TBZ x1, 0, 12f
+        STR s29, [x7]
+        STR s26, [x21]
+        STR s23, [x20]
+        STR s20, [x6]
+12:
+        # Restore d8-d11,d14,d15 from stack
+        LDP d14, d15, [sp, 96]
+        LDP d10, d11, [sp, 80]
+        LDP d8, d9, [sp, 64]
+
+        # Restore x19-x25 from stack
+        LDP x24, x25, [sp, 48]
+        LDP x22, x23, [sp, 32]
+        LDP x20, x21, [sp, 16]
+        LDR x19, [sp], 112
+        RET
+
+END_FUNCTION xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-igemm/4x12-neon-ld64.c b/src/f32-igemm/4x12-neon-ld64.c
new file mode 100644
index 0000000..fbe3fa4
--- /dev/null
+++ b/src/f32-igemm/4x12-neon-ld64.c
@@ -0,0 +1,254 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+// 4x12 f32 IGEMM (indirect GEMM) micro-kernel using NEON multiply-accumulate,
+// reading A 64 bits (2 floats) at a time.
+//
+// Computes a tile of up to 4 rows x 12 columns of C = A*B + bias and clamps
+// the result to [params->scalar.min, params->scalar.max].  A rows are reached
+// indirectly through the pointer array `a`; an entry equal to `zero` selects
+// the zero buffer and is not adjusted by `a_offset`.
+//
+// Fix: `&params` had been mangled into the mojibake `¶ms` in the two
+// vld1q_dup_f32() calls below, which does not compile.
+void xnn_f32_igemm_ukernel_4x12__neon_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (4 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  // Row pointers into C; when mr < 4 the extra rows alias the previous row,
+  // so their stores rewrite the same memory and are harmless.
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    c3 = c2;
+  }
+
+  do {
+    // Initialize all 4 rows of accumulators with the bias stored first in w.
+    float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+    float32x4_t vacc0x89AB = vld1q_f32(w); w += 4;
+    float32x4_t vacc1x0123 = vacc0x0123;
+    float32x4_t vacc1x4567 = vacc0x4567;
+    float32x4_t vacc1x89AB = vacc0x89AB;
+    float32x4_t vacc2x0123 = vacc0x0123;
+    float32x4_t vacc2x4567 = vacc0x4567;
+    float32x4_t vacc2x89AB = vacc0x89AB;
+    float32x4_t vacc3x0123 = vacc0x0123;
+    float32x4_t vacc3x4567 = vacc0x4567;
+    float32x4_t vacc3x89AB = vacc0x89AB;
+
+    size_t p = ks;
+    do {
+      // Resolve the next 4 indirection pointers; `zero` entries skip a_offset.
+      const float* restrict a0 = a[0];
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      const float* restrict a1 = a[1];
+      if XNN_UNPREDICTABLE(a1 != zero) {
+        a1 = (const float*) ((uintptr_t) a1 + a_offset);
+      }
+      const float* restrict a2 = a[2];
+      if XNN_UNPREDICTABLE(a2 != zero) {
+        a2 = (const float*) ((uintptr_t) a2 + a_offset);
+      }
+      const float* restrict a3 = a[3];
+      if XNN_UNPREDICTABLE(a3 != zero) {
+        a3 = (const float*) ((uintptr_t) a3 + a_offset);
+      }
+      a += 4;
+
+      // Main loop: 2 floats of A per iteration, lane-broadcast MLA.
+      size_t k = kc;
+      for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+        const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+        const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+        const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+        const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+        const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+        const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+        const float32x4_t vb89ABc0 = vld1q_f32(w); w += 4;
+
+        vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+        vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+        vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+        vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+        vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+        vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+        vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+        vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+        vacc0x89AB = vmlaq_lane_f32(vacc0x89AB, vb89ABc0, va0, 0);
+        vacc1x89AB = vmlaq_lane_f32(vacc1x89AB, vb89ABc0, va1, 0);
+        vacc2x89AB = vmlaq_lane_f32(vacc2x89AB, vb89ABc0, va2, 0);
+        vacc3x89AB = vmlaq_lane_f32(vacc3x89AB, vb89ABc0, va3, 0);
+        const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+        const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+        const float32x4_t vb89ABc1 = vld1q_f32(w); w += 4;
+
+        vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+        vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+        vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+        vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+        vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+        vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+        vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+        vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+        vacc0x89AB = vmlaq_lane_f32(vacc0x89AB, vb89ABc1, va0, 1);
+        vacc1x89AB = vmlaq_lane_f32(vacc1x89AB, vb89ABc1, va1, 1);
+        vacc2x89AB = vmlaq_lane_f32(vacc2x89AB, vb89ABc1, va2, 1);
+        vacc3x89AB = vmlaq_lane_f32(vacc3x89AB, vb89ABc1, va3, 1);
+      }
+      // Remainder: kc odd — one trailing float of A.
+      if XNN_UNLIKELY(k != 0) {
+        const float32x4_t va0 = vld1q_dup_f32(a0);
+        const float32x4_t va1 = vld1q_dup_f32(a1);
+        const float32x4_t va2 = vld1q_dup_f32(a2);
+        const float32x4_t va3 = vld1q_dup_f32(a3);
+
+        const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+        const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+        const float32x4_t vb89AB = vld1q_f32(w); w += 4;
+
+        vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+        vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+        vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+        vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+        vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+        vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+        vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+        vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+        vacc0x89AB = vmlaq_f32(vacc0x89AB, va0, vb89AB);
+        vacc1x89AB = vmlaq_f32(vacc1x89AB, va1, vb89AB);
+        vacc2x89AB = vmlaq_f32(vacc2x89AB, va2, vb89AB);
+        vacc3x89AB = vmlaq_f32(vacc3x89AB, va3, vb89AB);
+      }
+      p -= 4 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp to the output range (max first, then min — order is irrelevant).
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+    vacc0x89AB = vminq_f32(vacc0x89AB, vmax);
+    vacc1x89AB = vminq_f32(vacc1x89AB, vmax);
+    vacc2x89AB = vminq_f32(vacc2x89AB, vmax);
+    vacc3x89AB = vminq_f32(vacc3x89AB, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+    vacc0x89AB = vmaxq_f32(vacc0x89AB, vmin);
+    vacc1x89AB = vmaxq_f32(vacc1x89AB, vmin);
+    vacc2x89AB = vmaxq_f32(vacc2x89AB, vmin);
+    vacc3x89AB = vmaxq_f32(vacc3x89AB, vmin);
+
+    if XNN_LIKELY(nc >= 12) {
+      // Full-width store, highest row first.
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      vst1q_f32(c3 + 8, vacc3x89AB);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      vst1q_f32(c2 + 8, vacc2x89AB);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      vst1q_f32(c1 + 8, vacc1x89AB);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      vst1q_f32(c0 + 8, vacc0x89AB);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind a to reuse the same indirection entries for the next columns.
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 12;
+    } else {
+      // Tail: store 8/4/2/1 columns according to the bits of nc, shifting the
+      // remaining values down into the x0123 accumulators as we go.
+      if (nc & 8) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+        vst1q_f32(c3, vacc3x4567); c3 += 4;
+        vst1q_f32(c2, vacc2x4567); c2 += 4;
+        vst1q_f32(c1, vacc1x4567); c1 += 4;
+        vst1q_f32(c0, vacc0x4567); c0 += 4;
+
+        vacc3x0123 = vacc3x89AB;
+        vacc2x0123 = vacc2x89AB;
+        vacc1x0123 = vacc1x89AB;
+        vacc0x0123 = vacc0x89AB;
+      }
+      if (nc & 4) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x12-neonfma-ld64.c b/src/f32-igemm/4x12-neonfma-ld64.c
new file mode 100644
index 0000000..9f83c12
--- /dev/null
+++ b/src/f32-igemm/4x12-neonfma-ld64.c
@@ -0,0 +1,292 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+// 4x12 f32 IGEMM (indirect GEMM) micro-kernel using NEON fused
+// multiply-accumulate, reading A 64 bits (2 floats) at a time.
+//
+// Computes a tile of up to 4 rows x 12 columns of C = A*B + bias and clamps
+// the result to [params->scalar.min, params->scalar.max].  A rows are reached
+// indirectly through the pointer array `a`; an entry equal to `zero` selects
+// the zero buffer and is not adjusted by `a_offset`.
+//
+// Fix: `&params` had been mangled into the mojibake `¶ms` in the two
+// vld1q_dup_f32() calls below, which does not compile.
+void xnn_f32_igemm_ukernel_4x12__neonfma_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (4 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  // Row pointers into C; when mr < 4 the extra rows alias the previous row,
+  // so their stores rewrite the same memory and are harmless.
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    c3 = c2;
+  }
+
+  do {
+    // Initialize all 4 rows of accumulators with the bias stored first in w.
+    float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+    float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+    float32x4_t vacc0x89AB = vld1q_f32(w); w += 4;
+    float32x4_t vacc1x0123 = vacc0x0123;
+    float32x4_t vacc1x4567 = vacc0x4567;
+    float32x4_t vacc1x89AB = vacc0x89AB;
+    float32x4_t vacc2x0123 = vacc0x0123;
+    float32x4_t vacc2x4567 = vacc0x4567;
+    float32x4_t vacc2x89AB = vacc0x89AB;
+    float32x4_t vacc3x0123 = vacc0x0123;
+    float32x4_t vacc3x4567 = vacc0x4567;
+    float32x4_t vacc3x89AB = vacc0x89AB;
+
+    size_t p = ks;
+    do {
+      // Resolve the next 4 indirection pointers; `zero` entries skip a_offset.
+      const float* restrict a0 = a[0];
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      const float* restrict a1 = a[1];
+      if XNN_UNPREDICTABLE(a1 != zero) {
+        a1 = (const float*) ((uintptr_t) a1 + a_offset);
+      }
+      const float* restrict a2 = a[2];
+      if XNN_UNPREDICTABLE(a2 != zero) {
+        a2 = (const float*) ((uintptr_t) a2 + a_offset);
+      }
+      const float* restrict a3 = a[3];
+      if XNN_UNPREDICTABLE(a3 != zero) {
+        a3 = (const float*) ((uintptr_t) a3 + a_offset);
+      }
+      a += 4;
+
+      // Main loop: 2 floats of A per iteration.  AArch64 has lane-indexed FMA
+      // (vfmaq_lane_f32); AArch32 must broadcast the lane first.
+      size_t k = kc;
+      for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+        const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+        const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+        const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+        const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+        const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+        const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+        const float32x4_t vb89ABc0 = vld1q_f32(w); w += 4;
+
+        #if defined(__aarch64__)
+          vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+          vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+          vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+          vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+          vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+          vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+          vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+          vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+          vacc0x89AB = vfmaq_lane_f32(vacc0x89AB, vb89ABc0, va0, 0);
+          vacc1x89AB = vfmaq_lane_f32(vacc1x89AB, vb89ABc0, va1, 0);
+          vacc2x89AB = vfmaq_lane_f32(vacc2x89AB, vb89ABc0, va2, 0);
+          vacc3x89AB = vfmaq_lane_f32(vacc3x89AB, vb89ABc0, va3, 0);
+        #else
+          const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+          const float32x4_t va1c0 = vdupq_lane_f32(va1, 0);
+          const float32x4_t va2c0 = vdupq_lane_f32(va2, 0);
+          const float32x4_t va3c0 = vdupq_lane_f32(va3, 0);
+          vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+          vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+          vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+          vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+          vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+          vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+          vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+          vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+          vacc0x89AB = vfmaq_f32(vacc0x89AB, va0c0, vb89ABc0);
+          vacc1x89AB = vfmaq_f32(vacc1x89AB, va1c0, vb89ABc0);
+          vacc2x89AB = vfmaq_f32(vacc2x89AB, va2c0, vb89ABc0);
+          vacc3x89AB = vfmaq_f32(vacc3x89AB, va3c0, vb89ABc0);
+        #endif
+        const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+        const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+        const float32x4_t vb89ABc1 = vld1q_f32(w); w += 4;
+
+        #if defined(__aarch64__)
+          vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+          vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+          vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+          vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+          vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+          vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+          vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+          vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+          vacc0x89AB = vfmaq_lane_f32(vacc0x89AB, vb89ABc1, va0, 1);
+          vacc1x89AB = vfmaq_lane_f32(vacc1x89AB, vb89ABc1, va1, 1);
+          vacc2x89AB = vfmaq_lane_f32(vacc2x89AB, vb89ABc1, va2, 1);
+          vacc3x89AB = vfmaq_lane_f32(vacc3x89AB, vb89ABc1, va3, 1);
+        #else
+          const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+          const float32x4_t va1c1 = vdupq_lane_f32(va1, 1);
+          const float32x4_t va2c1 = vdupq_lane_f32(va2, 1);
+          const float32x4_t va3c1 = vdupq_lane_f32(va3, 1);
+          vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+          vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+          vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+          vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+          vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+          vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+          vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+          vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+          vacc0x89AB = vfmaq_f32(vacc0x89AB, va0c1, vb89ABc1);
+          vacc1x89AB = vfmaq_f32(vacc1x89AB, va1c1, vb89ABc1);
+          vacc2x89AB = vfmaq_f32(vacc2x89AB, va2c1, vb89ABc1);
+          vacc3x89AB = vfmaq_f32(vacc3x89AB, va3c1, vb89ABc1);
+        #endif
+      }
+      // Remainder: kc odd — one trailing float of A.
+      if XNN_UNLIKELY(k != 0) {
+        const float32x4_t va0 = vld1q_dup_f32(a0);
+        const float32x4_t va1 = vld1q_dup_f32(a1);
+        const float32x4_t va2 = vld1q_dup_f32(a2);
+        const float32x4_t va3 = vld1q_dup_f32(a3);
+
+        const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+        const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+        const float32x4_t vb89AB = vld1q_f32(w); w += 4;
+
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+        vacc0x89AB = vfmaq_f32(vacc0x89AB, va0, vb89AB);
+        vacc1x89AB = vfmaq_f32(vacc1x89AB, va1, vb89AB);
+        vacc2x89AB = vfmaq_f32(vacc2x89AB, va2, vb89AB);
+        vacc3x89AB = vfmaq_f32(vacc3x89AB, va3, vb89AB);
+      }
+      p -= 4 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp to the output range (max first, then min — order is irrelevant).
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+    vacc0x89AB = vminq_f32(vacc0x89AB, vmax);
+    vacc1x89AB = vminq_f32(vacc1x89AB, vmax);
+    vacc2x89AB = vminq_f32(vacc2x89AB, vmax);
+    vacc3x89AB = vminq_f32(vacc3x89AB, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+    vacc0x89AB = vmaxq_f32(vacc0x89AB, vmin);
+    vacc1x89AB = vmaxq_f32(vacc1x89AB, vmin);
+    vacc2x89AB = vmaxq_f32(vacc2x89AB, vmin);
+    vacc3x89AB = vmaxq_f32(vacc3x89AB, vmin);
+
+    if XNN_LIKELY(nc >= 12) {
+      // Full-width store, highest row first.
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      vst1q_f32(c3 + 8, vacc3x89AB);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      vst1q_f32(c2 + 8, vacc2x89AB);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      vst1q_f32(c1 + 8, vacc1x89AB);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      vst1q_f32(c0 + 8, vacc0x89AB);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind a to reuse the same indirection entries for the next columns.
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 12;
+    } else {
+      // Tail: store 8/4/2/1 columns according to the bits of nc, shifting the
+      // remaining values down into the x0123 accumulators as we go.
+      if (nc & 8) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+        vst1q_f32(c3, vacc3x4567); c3 += 4;
+        vst1q_f32(c2, vacc2x4567); c2 += 4;
+        vst1q_f32(c1, vacc1x4567); c1 += 4;
+        vst1q_f32(c0, vacc0x4567); c0 += 4;
+
+        vacc3x0123 = vacc3x89AB;
+        vacc2x0123 = vacc2x89AB;
+        vacc1x0123 = vacc1x89AB;
+        vacc0x0123 = vacc0x89AB;
+      }
+      if (nc & 4) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x2-neon-ld64.c b/src/f32-igemm/4x2-neon-ld64.c
new file mode 100644
index 0000000..48c5463
--- /dev/null
+++ b/src/f32-igemm/4x2-neon-ld64.c
@@ -0,0 +1,150 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/MRx2-neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_4x2__neon_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{  // 4x2 f32 indirect GEMM micro-kernel: NEON, 64-bit (2-float) loads from A.
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;  // per-row output pointers; rows beyond mr alias the previous row
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;  // mr == 1: row 1 writes are harmlessly redirected onto row 0
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {  // loop over output columns, 2 at a time
+ float32x2_t vacc0x01 = vld1_f32(w); w += 2;  // initialize accumulators with bias
+ float32x2_t vacc1x01 = vacc0x01;
+ float32x2_t vacc2x01 = vacc0x01;
+ float32x2_t vacc3x01 = vacc0x01;
+
+ size_t p = ks;  // walk the indirection buffer, 4 A pointers per step
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);  // zero pointer is never offset
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {  // main loop: 2 K steps/iteration
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+ const float32x2_t vb01c0 = vld1_f32(w); w += 2;
+
+ vacc0x01 = vmla_lane_f32(vacc0x01, vb01c0, va0, 0);
+ vacc1x01 = vmla_lane_f32(vacc1x01, vb01c0, va1, 0);
+ vacc2x01 = vmla_lane_f32(vacc2x01, vb01c0, va2, 0);
+ vacc3x01 = vmla_lane_f32(vacc3x01, vb01c0, va3, 0);
+ const float32x2_t vb01c1 = vld1_f32(w); w += 2;
+
+ vacc0x01 = vmla_lane_f32(vacc0x01, vb01c1, va0, 1);
+ vacc1x01 = vmla_lane_f32(vacc1x01, vb01c1, va1, 1);
+ vacc2x01 = vmla_lane_f32(vacc2x01, vb01c1, va2, 1);
+ vacc3x01 = vmla_lane_f32(vacc3x01, vb01c1, va3, 1);
+ }
+ if XNN_UNLIKELY(k != 0) {  // remainder: a single leftover K step
+ const float32x2_t va0 = vld1_dup_f32(a0);
+ const float32x2_t va1 = vld1_dup_f32(a1);
+ const float32x2_t va2 = vld1_dup_f32(a2);
+ const float32x2_t va3 = vld1_dup_f32(a3);
+
+ const float32x2_t vb01 = vld1_f32(w); w += 2;
+
+ vacc0x01 = vmla_f32(vacc0x01, va0, vb01);
+ vacc1x01 = vmla_f32(vacc1x01, va1, vb01);
+ vacc2x01 = vmla_f32(vacc2x01, va2, vb01);
+ vacc3x01 = vmla_f32(vacc3x01, va3, vb01);
+ }
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+ const float32x2_t vmax = vld1_dup_f32(&params->scalar.max);  // fixed mojibake: was "&para;ms"
+ vacc0x01 = vmin_f32(vacc0x01, vmax);  // clamp above
+ vacc1x01 = vmin_f32(vacc1x01, vmax);
+ vacc2x01 = vmin_f32(vacc2x01, vmax);
+ vacc3x01 = vmin_f32(vacc3x01, vmax);
+
+ const float32x2_t vmin = vld1_dup_f32(&params->scalar.min);  // fixed mojibake: was "&para;ms"
+ vacc0x01 = vmax_f32(vacc0x01, vmin);  // clamp below
+ vacc1x01 = vmax_f32(vacc1x01, vmin);
+ vacc2x01 = vmax_f32(vacc2x01, vmin);
+ vacc3x01 = vmax_f32(vacc3x01, vmin);
+
+ if XNN_LIKELY(nc >= 2) {  // full 2-column store
+ vst1_f32(c3, vacc3x01);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1_f32(c2, vacc2x01);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1_f32(c1, vacc1x01);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1_f32(c0, vacc0x01);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);  // rewind indirection buffer for next columns
+ nc -= 2;
+ } else {
+ assert(nc == 1);  // tail: store the single remaining column
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x2-neonfma-ld64.c b/src/f32-igemm/4x2-neonfma-ld64.c
new file mode 100644
index 0000000..f97cc5f
--- /dev/null
+++ b/src/f32-igemm/4x2-neonfma-ld64.c
@@ -0,0 +1,172 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/MRx2-neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_4x2__neonfma_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{  // 4x2 f32 indirect GEMM micro-kernel: NEON FMA, 64-bit (2-float) loads from A.
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;  // per-row output pointers; rows beyond mr alias the previous row
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {  // loop over output columns, 2 at a time
+ float32x2_t vacc0x01 = vld1_f32(w); w += 2;  // initialize accumulators with bias
+ float32x2_t vacc1x01 = vacc0x01;
+ float32x2_t vacc2x01 = vacc0x01;
+ float32x2_t vacc3x01 = vacc0x01;
+
+ size_t p = ks;  // walk the indirection buffer, 4 A pointers per step
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);  // zero pointer is never offset
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {  // main loop: 2 K steps/iteration
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+ const float32x2_t vb01c0 = vld1_f32(w); w += 2;
+
+ #if defined(__aarch64__)
+ vacc0x01 = vfma_lane_f32(vacc0x01, vb01c0, va0, 0);  // lane-FMA only on AArch64
+ vacc1x01 = vfma_lane_f32(vacc1x01, vb01c0, va1, 0);
+ vacc2x01 = vfma_lane_f32(vacc2x01, vb01c0, va2, 0);
+ vacc3x01 = vfma_lane_f32(vacc3x01, vb01c0, va3, 0);
+ #else
+ const float32x2_t va0c0 = vdup_lane_f32(va0, 0);  // AArch32: broadcast lane, then plain FMA
+ const float32x2_t va1c0 = vdup_lane_f32(va1, 0);
+ const float32x2_t va2c0 = vdup_lane_f32(va2, 0);
+ const float32x2_t va3c0 = vdup_lane_f32(va3, 0);
+ vacc0x01 = vfma_f32(vacc0x01, va0c0, vb01c0);
+ vacc1x01 = vfma_f32(vacc1x01, va1c0, vb01c0);
+ vacc2x01 = vfma_f32(vacc2x01, va2c0, vb01c0);
+ vacc3x01 = vfma_f32(vacc3x01, va3c0, vb01c0);
+ #endif
+ const float32x2_t vb01c1 = vld1_f32(w); w += 2;
+
+ #if defined(__aarch64__)
+ vacc0x01 = vfma_lane_f32(vacc0x01, vb01c1, va0, 1);
+ vacc1x01 = vfma_lane_f32(vacc1x01, vb01c1, va1, 1);
+ vacc2x01 = vfma_lane_f32(vacc2x01, vb01c1, va2, 1);
+ vacc3x01 = vfma_lane_f32(vacc3x01, vb01c1, va3, 1);
+ #else
+ const float32x2_t va0c1 = vdup_lane_f32(va0, 1);
+ const float32x2_t va1c1 = vdup_lane_f32(va1, 1);
+ const float32x2_t va2c1 = vdup_lane_f32(va2, 1);
+ const float32x2_t va3c1 = vdup_lane_f32(va3, 1);
+ vacc0x01 = vfma_f32(vacc0x01, va0c1, vb01c1);
+ vacc1x01 = vfma_f32(vacc1x01, va1c1, vb01c1);
+ vacc2x01 = vfma_f32(vacc2x01, va2c1, vb01c1);
+ vacc3x01 = vfma_f32(vacc3x01, va3c1, vb01c1);
+ #endif
+ }
+ if XNN_UNLIKELY(k != 0) {  // remainder: a single leftover K step
+ const float32x2_t va0 = vld1_dup_f32(a0);
+ const float32x2_t va1 = vld1_dup_f32(a1);
+ const float32x2_t va2 = vld1_dup_f32(a2);
+ const float32x2_t va3 = vld1_dup_f32(a3);
+
+ const float32x2_t vb01 = vld1_f32(w); w += 2;
+
+ vacc0x01 = vfma_f32(vacc0x01, va0, vb01);
+ vacc1x01 = vfma_f32(vacc1x01, va1, vb01);
+ vacc2x01 = vfma_f32(vacc2x01, va2, vb01);
+ vacc3x01 = vfma_f32(vacc3x01, va3, vb01);
+ }
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+ const float32x2_t vmax = vld1_dup_f32(&params->scalar.max);  // fixed mojibake: was "&para;ms"
+ vacc0x01 = vmin_f32(vacc0x01, vmax);  // clamp above
+ vacc1x01 = vmin_f32(vacc1x01, vmax);
+ vacc2x01 = vmin_f32(vacc2x01, vmax);
+ vacc3x01 = vmin_f32(vacc3x01, vmax);
+
+ const float32x2_t vmin = vld1_dup_f32(&params->scalar.min);  // fixed mojibake: was "&para;ms"
+ vacc0x01 = vmax_f32(vacc0x01, vmin);  // clamp below
+ vacc1x01 = vmax_f32(vacc1x01, vmin);
+ vacc2x01 = vmax_f32(vacc2x01, vmin);
+ vacc3x01 = vmax_f32(vacc3x01, vmin);
+
+ if XNN_LIKELY(nc >= 2) {  // full 2-column store
+ vst1_f32(c3, vacc3x01);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1_f32(c2, vacc2x01);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1_f32(c1, vacc1x01);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1_f32(c0, vacc0x01);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);  // rewind indirection buffer for next columns
+ nc -= 2;
+ } else {
+ assert(nc == 1);  // tail: store the single remaining column
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x2-scalar.c b/src/f32-igemm/4x2-scalar.c
new file mode 100644
index 0000000..76ef02a
--- /dev/null
+++ b/src/f32-igemm/4x2-scalar.c
@@ -0,0 +1,156 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_igemm_ukernel_4x2__scalar(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{  // 4x2 f32 indirect GEMM micro-kernel: portable scalar reference implementation.
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;  // per-row output pointers; rows beyond mr alias the previous row
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {  // loop over output columns, 2 at a time
+ float vacc00 = w[0];  // initialize all 4x2 accumulators with the 2 bias values
+ float vacc01 = w[1];
+ float vacc10 = vacc00;
+ float vacc11 = vacc01;
+ float vacc20 = vacc00;
+ float vacc21 = vacc01;
+ float vacc30 = vacc00;
+ float vacc31 = vacc01;
+ w += 2;
+
+ size_t p = ks;  // walk the indirection buffer, 4 A pointers per step
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);  // zero pointer is never offset
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ do {  // one K step per iteration: rank-1 update of the 4x2 tile
+ const float va0 = *a0++;
+ const float va1 = *a1++;
+ const float va2 = *a2++;
+ const float va3 = *a3++;
+
+ const float vb0 = w[0];
+ const float vb1 = w[1];
+ w += 2;
+
+ vacc00 += va0 * vb0;
+ vacc01 += va0 * vb1;
+ vacc10 += va1 * vb0;
+ vacc11 += va1 * vb1;
+ vacc20 += va2 * vb0;
+ vacc21 += va2 * vb1;
+ vacc30 += va3 * vb0;
+ vacc31 += va3 * vb1;
+
+ k -= sizeof(float);
+ } while (k != 0);
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+ const float vmin = params->scalar.min;  // clamp below
+ vacc00 = math_max_f32(vacc00, vmin);
+ vacc01 = math_max_f32(vacc01, vmin);
+ vacc10 = math_max_f32(vacc10, vmin);
+ vacc11 = math_max_f32(vacc11, vmin);
+ vacc20 = math_max_f32(vacc20, vmin);
+ vacc21 = math_max_f32(vacc21, vmin);
+ vacc30 = math_max_f32(vacc30, vmin);
+ vacc31 = math_max_f32(vacc31, vmin);
+
+ const float vmax = params->scalar.max;  // clamp above
+ vacc00 = math_min_f32(vacc00, vmax);
+ vacc01 = math_min_f32(vacc01, vmax);
+ vacc10 = math_min_f32(vacc10, vmax);
+ vacc11 = math_min_f32(vacc11, vmax);
+ vacc20 = math_min_f32(vacc20, vmax);
+ vacc21 = math_min_f32(vacc21, vmax);
+ vacc30 = math_min_f32(vacc30, vmax);
+ vacc31 = math_min_f32(vacc31, vmax);
+
+ if XNN_LIKELY(nc >= 2) {  // full 2-column store
+ c3[0] = vacc30;
+ c3[1] = vacc31;
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ c2[0] = vacc20;
+ c2[1] = vacc21;
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ c1[0] = vacc10;
+ c1[1] = vacc11;
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ c0[0] = vacc00;
+ c0[1] = vacc01;
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);  // rewind indirection buffer for next columns
+ nc -= 2;
+ } else {
+ if (nc & 1) {  // tail: store the single remaining column
+ c3[0] = vacc30;
+ c2[0] = vacc20;
+ c1[0] = vacc10;
+ c0[0] = vacc00;
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x2c4-psimd.c b/src/f32-igemm/4x2c4-psimd.c
new file mode 100644
index 0000000..626e31c
--- /dev/null
+++ b/src/f32-igemm/4x2c4-psimd.c
@@ -0,0 +1,173 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/MRx2c4-psimd.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_4x2c4__psimd(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{  // 4x2 f32 indirect GEMM micro-kernel, psimd, 4 K steps accumulated per SIMD lane (c4 layout).
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;  // per-row output pointers; rows beyond mr alias the previous row
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {  // loop over output columns, 2 at a time
+ psimd_f32 vacc0x0c4 = psimd_load1_f32(w);  // bias in lane 0, partial sums accumulate across 4 lanes
+ psimd_f32 vacc0x1c4 = psimd_load1_f32(w + 1);
+ psimd_f32 vacc1x0c4 = vacc0x0c4;
+ psimd_f32 vacc1x1c4 = vacc0x1c4;
+ psimd_f32 vacc2x0c4 = vacc0x0c4;
+ psimd_f32 vacc2x1c4 = vacc0x1c4;
+ psimd_f32 vacc3x0c4 = vacc0x0c4;
+ psimd_f32 vacc3x1c4 = vacc0x1c4;
+ w += 2;
+
+ size_t p = ks;  // walk the indirection buffer, 4 A pointers per step
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);  // zero pointer is never offset
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ for (; k >= 4 * sizeof(float); k -= 4 * sizeof(float)) {  // main loop: 4 K steps/iteration
+ const psimd_f32 va0 = psimd_load_f32(a0);
+ a0 += 4;
+ const psimd_f32 va1 = psimd_load_f32(a1);
+ a1 += 4;
+ const psimd_f32 va2 = psimd_load_f32(a2);
+ a2 += 4;
+ const psimd_f32 va3 = psimd_load_f32(a3);
+ a3 += 4;
+
+ const psimd_f32 vb0 = psimd_load_f32(w);
+ const psimd_f32 vb1 = psimd_load_f32(w + 4);
+ w += 8;
+
+ vacc0x0c4 = psimd_qfma_f32(vacc0x0c4, va0, vb0);
+ vacc0x1c4 = psimd_qfma_f32(vacc0x1c4, va0, vb1);
+ vacc1x0c4 = psimd_qfma_f32(vacc1x0c4, va1, vb0);
+ vacc1x1c4 = psimd_qfma_f32(vacc1x1c4, va1, vb1);
+ vacc2x0c4 = psimd_qfma_f32(vacc2x0c4, va2, vb0);
+ vacc2x1c4 = psimd_qfma_f32(vacc2x1c4, va2, vb1);
+ vacc3x0c4 = psimd_qfma_f32(vacc3x0c4, va3, vb0);
+ vacc3x1c4 = psimd_qfma_f32(vacc3x1c4, va3, vb1);
+ }
+ if XNN_UNLIKELY(k != 0) {  // remainder: weights are zero-padded; mask A where B == 0
+ const psimd_f32 va0 = psimd_load_f32(a0);
+ const psimd_f32 va1 = psimd_load_f32(a1);
+ const psimd_f32 va2 = psimd_load_f32(a2);
+ const psimd_f32 va3 = psimd_load_f32(a3);
+
+ const psimd_f32 vb0 = psimd_load_f32(w);
+ const psimd_f32 vb1 = psimd_load_f32(w + 4);
+ w += 8;
+
+ const psimd_f32 vzero = psimd_splat_f32(0.0f);
+ const psimd_s32 vmask0 = vb0 != vzero;  // masking avoids NaN/Inf from garbage A lanes
+ const psimd_s32 vmask1 = vb1 != vzero;
+
+ vacc0x0c4 = psimd_qfma_f32(vacc0x0c4, psimd_andmask_f32(vmask0, va0), vb0);
+ vacc0x1c4 = psimd_qfma_f32(vacc0x1c4, psimd_andmask_f32(vmask1, va0), vb1);
+ vacc1x0c4 = psimd_qfma_f32(vacc1x0c4, psimd_andmask_f32(vmask0, va1), vb0);
+ vacc1x1c4 = psimd_qfma_f32(vacc1x1c4, psimd_andmask_f32(vmask1, va1), vb1);
+ vacc2x0c4 = psimd_qfma_f32(vacc2x0c4, psimd_andmask_f32(vmask0, va2), vb0);
+ vacc2x1c4 = psimd_qfma_f32(vacc2x1c4, psimd_andmask_f32(vmask1, va2), vb1);
+ vacc3x0c4 = psimd_qfma_f32(vacc3x0c4, psimd_andmask_f32(vmask0, va3), vb0);
+ vacc3x1c4 = psimd_qfma_f32(vacc3x1c4, psimd_andmask_f32(vmask1, va3), vb1);
+ }
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+ const psimd_f32 vacc0x01c2 = psimd_add_f32(psimd_interleave_lo_f32(vacc0x0c4, vacc0x1c4), psimd_interleave_hi_f32(vacc0x0c4, vacc0x1c4));  // horizontal reduction of c4 partial sums
+ const psimd_f32 vacc1x01c2 = psimd_add_f32(psimd_interleave_lo_f32(vacc1x0c4, vacc1x1c4), psimd_interleave_hi_f32(vacc1x0c4, vacc1x1c4));
+ const psimd_f32 vacc2x01c2 = psimd_add_f32(psimd_interleave_lo_f32(vacc2x0c4, vacc2x1c4), psimd_interleave_hi_f32(vacc2x0c4, vacc2x1c4));
+ const psimd_f32 vacc3x01c2 = psimd_add_f32(psimd_interleave_lo_f32(vacc3x0c4, vacc3x1c4), psimd_interleave_hi_f32(vacc3x0c4, vacc3x1c4));
+
+ psimd_f32 vacc01x01 = psimd_add_f32(psimd_concat_lo_f32(vacc0x01c2, vacc1x01c2), psimd_concat_hi_f32(vacc0x01c2, vacc1x01c2));  // rows 0-1 results packed in one vector
+ psimd_f32 vacc23x01 = psimd_add_f32(psimd_concat_lo_f32(vacc2x01c2, vacc3x01c2), psimd_concat_hi_f32(vacc2x01c2, vacc3x01c2));  // rows 2-3 results packed in one vector
+
+ const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);  // fixed mojibake: was "&para;ms"
+ vacc01x01 = psimd_min_f32(vacc01x01, vmax);  // clamp above
+ vacc23x01 = psimd_min_f32(vacc23x01, vmax);
+
+ const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);  // fixed mojibake: was "&para;ms"
+ vacc01x01 = psimd_max_f32(vacc01x01, vmin);  // clamp below
+ vacc23x01 = psimd_max_f32(vacc23x01, vmin);
+
+ if XNN_LIKELY(nc >= 2) {  // full 2-column store
+ psimd_store2_f32(c3, psimd_concat_hi_f32(vacc23x01, vacc23x01));
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ psimd_store2_f32(c2, vacc23x01);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ psimd_store2_f32(c1, psimd_concat_hi_f32(vacc01x01, vacc01x01));
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ psimd_store2_f32(c0, vacc01x01);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);  // rewind indirection buffer for next columns
+ nc -= 2;
+ } else {
+ assert(nc == 1);  // tail: store the single remaining column
+ psimd_store1_f32(c3, psimd_concat_hi_f32(vacc23x01, vacc23x01));
+ psimd_store1_f32(c2, vacc23x01);
+ psimd_store1_f32(c1, psimd_concat_hi_f32(vacc01x01, vacc01x01));
+ psimd_store1_f32(c0, vacc01x01);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x2c4-sse.c b/src/f32-igemm/4x2c4-sse.c
new file mode 100644
index 0000000..178cec6
--- /dev/null
+++ b/src/f32-igemm/4x2c4-sse.c
@@ -0,0 +1,172 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/MRx2c4-sse.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_4x2c4__sse(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{  // 4x2 f32 indirect GEMM micro-kernel, SSE, 4 K steps accumulated per SIMD lane (c4 layout).
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;  // per-row output pointers; rows beyond mr alias the previous row
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {  // loop over output columns, 2 at a time
+ __m128 vacc0x0c4 = _mm_load_ss(w);  // bias in lane 0, partial sums accumulate across 4 lanes
+ __m128 vacc0x1c4 = _mm_load_ss(w + 1);
+ __m128 vacc1x0c4 = vacc0x0c4;
+ __m128 vacc1x1c4 = vacc0x1c4;
+ __m128 vacc2x0c4 = vacc0x0c4;
+ __m128 vacc2x1c4 = vacc0x1c4;
+ __m128 vacc3x0c4 = vacc0x0c4;
+ __m128 vacc3x1c4 = vacc0x1c4;
+ w += 2;
+
+ size_t p = ks;  // walk the indirection buffer, 4 A pointers per step
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);  // zero pointer is never offset
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ for (; k >= 4 * sizeof(float); k -= 4 * sizeof(float)) {  // main loop: 4 K steps/iteration
+ const __m128 va0 = _mm_loadu_ps(a0);
+ a0 += 4;
+ const __m128 va1 = _mm_loadu_ps(a1);
+ a1 += 4;
+ const __m128 va2 = _mm_loadu_ps(a2);
+ a2 += 4;
+ const __m128 va3 = _mm_loadu_ps(a3);
+ a3 += 4;
+
+ const __m128 vb0 = _mm_loadu_ps(w);
+ const __m128 vb1 = _mm_loadu_ps(w + 4);
+ w += 8;
+
+ vacc0x0c4 = _mm_add_ps(vacc0x0c4, _mm_mul_ps(va0, vb0));
+ vacc0x1c4 = _mm_add_ps(vacc0x1c4, _mm_mul_ps(va0, vb1));
+ vacc1x0c4 = _mm_add_ps(vacc1x0c4, _mm_mul_ps(va1, vb0));
+ vacc1x1c4 = _mm_add_ps(vacc1x1c4, _mm_mul_ps(va1, vb1));
+ vacc2x0c4 = _mm_add_ps(vacc2x0c4, _mm_mul_ps(va2, vb0));
+ vacc2x1c4 = _mm_add_ps(vacc2x1c4, _mm_mul_ps(va2, vb1));
+ vacc3x0c4 = _mm_add_ps(vacc3x0c4, _mm_mul_ps(va3, vb0));
+ vacc3x1c4 = _mm_add_ps(vacc3x1c4, _mm_mul_ps(va3, vb1));
+ }
+ if XNN_UNLIKELY(k != 0) {  // remainder: weights are zero-padded; mask A where B == 0
+ const __m128 va0 = _mm_loadu_ps(a0);
+ const __m128 va1 = _mm_loadu_ps(a1);
+ const __m128 va2 = _mm_loadu_ps(a2);
+ const __m128 va3 = _mm_loadu_ps(a3);
+
+ const __m128 vb0 = _mm_loadu_ps(w);
+ const __m128 vb1 = _mm_loadu_ps(w + 4);
+ w += 8;
+
+ const __m128 vmask0 = _mm_cmpeq_ps(_mm_setzero_ps(), vb0);  // masking avoids NaN/Inf from garbage A lanes
+ const __m128 vmask1 = _mm_cmpeq_ps(_mm_setzero_ps(), vb1);
+
+ vacc0x0c4 = _mm_add_ps(vacc0x0c4, _mm_mul_ps(_mm_andnot_ps(vmask0, va0), vb0));
+ vacc0x1c4 = _mm_add_ps(vacc0x1c4, _mm_mul_ps(_mm_andnot_ps(vmask1, va0), vb1));
+ vacc1x0c4 = _mm_add_ps(vacc1x0c4, _mm_mul_ps(_mm_andnot_ps(vmask0, va1), vb0));
+ vacc1x1c4 = _mm_add_ps(vacc1x1c4, _mm_mul_ps(_mm_andnot_ps(vmask1, va1), vb1));
+ vacc2x0c4 = _mm_add_ps(vacc2x0c4, _mm_mul_ps(_mm_andnot_ps(vmask0, va2), vb0));
+ vacc2x1c4 = _mm_add_ps(vacc2x1c4, _mm_mul_ps(_mm_andnot_ps(vmask1, va2), vb1));
+ vacc3x0c4 = _mm_add_ps(vacc3x0c4, _mm_mul_ps(_mm_andnot_ps(vmask0, va3), vb0));
+ vacc3x1c4 = _mm_add_ps(vacc3x1c4, _mm_mul_ps(_mm_andnot_ps(vmask1, va3), vb1));
+ }
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+ const __m128 vacc0x01c2 = _mm_add_ps(_mm_unpacklo_ps(vacc0x0c4, vacc0x1c4), _mm_unpackhi_ps(vacc0x0c4, vacc0x1c4));  // horizontal reduction of c4 partial sums
+ const __m128 vacc1x01c2 = _mm_add_ps(_mm_unpacklo_ps(vacc1x0c4, vacc1x1c4), _mm_unpackhi_ps(vacc1x0c4, vacc1x1c4));
+ const __m128 vacc2x01c2 = _mm_add_ps(_mm_unpacklo_ps(vacc2x0c4, vacc2x1c4), _mm_unpackhi_ps(vacc2x0c4, vacc2x1c4));
+ const __m128 vacc3x01c2 = _mm_add_ps(_mm_unpacklo_ps(vacc3x0c4, vacc3x1c4), _mm_unpackhi_ps(vacc3x0c4, vacc3x1c4));
+
+ __m128 vacc01x01 = _mm_add_ps(_mm_movelh_ps(vacc0x01c2, vacc1x01c2), _mm_movehl_ps(vacc1x01c2, vacc0x01c2));  // rows 0-1 results packed in one vector
+ __m128 vacc23x01 = _mm_add_ps(_mm_movelh_ps(vacc2x01c2, vacc3x01c2), _mm_movehl_ps(vacc3x01c2, vacc2x01c2));  // rows 2-3 results packed in one vector
+
+ const __m128 vmax = _mm_load_ps(params->sse.max);  // clamp above
+ vacc01x01 = _mm_min_ps(vacc01x01, vmax);
+ vacc23x01 = _mm_min_ps(vacc23x01, vmax);
+
+ const __m128 vmin = _mm_load_ps(params->sse.min);  // clamp below
+ vacc01x01 = _mm_max_ps(vacc01x01, vmin);
+ vacc23x01 = _mm_max_ps(vacc23x01, vmin);
+
+ if XNN_LIKELY(nc >= 2) {  // full 2-column store
+ _mm_storeh_pi((__m64*) c3, vacc23x01);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ _mm_storel_pi((__m64*) c2, vacc23x01);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ _mm_storeh_pi((__m64*) c1, vacc01x01);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ _mm_storel_pi((__m64*) c0, vacc01x01);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);  // rewind indirection buffer for next columns
+ nc -= 2;
+ } else {
+ assert(nc == 1);  // tail: store the single remaining column
+ _mm_store_ss(c3, _mm_movehl_ps(vacc23x01, vacc23x01));
+ _mm_store_ss(c2, vacc23x01);
+ _mm_store_ss(c1, _mm_movehl_ps(vacc01x01, vacc01x01));
+ _mm_store_ss(c0, vacc01x01);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x4-neon-ld64.c b/src/f32-igemm/4x4-neon-ld64.c
new file mode 100644
index 0000000..70288e2
--- /dev/null
+++ b/src/f32-igemm/4x4-neon-ld64.c
@@ -0,0 +1,166 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_4x4__neon_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{  // 4x4 f32 indirect GEMM micro-kernel: NEON, 64-bit (2-float) loads from A.
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;  // per-row output pointers; rows beyond mr alias the previous row
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {  // loop over output columns, 4 at a time
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;  // initialize accumulators with bias
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc3x0123 = vacc0x0123;
+
+ size_t p = ks;  // walk the indirection buffer, 4 A pointers per step
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);  // zero pointer is never offset
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {  // main loop: 2 K steps/iteration
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+ }
+ if XNN_UNLIKELY(k != 0) {  // remainder: a single leftover K step
+ const float32x4_t va0 = vld1q_dup_f32(a0);
+ const float32x4_t va1 = vld1q_dup_f32(a1);
+ const float32x4_t va2 = vld1q_dup_f32(a2);
+ const float32x4_t va3 = vld1q_dup_f32(a3);
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+ }
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);  // fixed mojibake: was "&para;ms"
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);  // clamp above
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);  // fixed mojibake: was "&para;ms"
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);  // clamp below
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+
+ if XNN_LIKELY(nc >= 4) {  // full 4-column store
+ vst1q_f32(c3, vacc3x0123);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);  // rewind indirection buffer for next columns
+ nc -= 4;
+ } else {
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);  // tail: store 2, then 1 remaining columns
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x4-neonfma-ld64.c b/src/f32-igemm/4x4-neonfma-ld64.c
new file mode 100644
index 0000000..863f65e
--- /dev/null
+++ b/src/f32-igemm/4x4-neonfma-ld64.c
@@ -0,0 +1,188 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_4x4__neonfma_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc3x0123 = vacc0x0123;
+
+ size_t p = ks;
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+ #else
+ const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+ const float32x4_t va1c0 = vdupq_lane_f32(va1, 0);
+ const float32x4_t va2c0 = vdupq_lane_f32(va2, 0);
+ const float32x4_t va3c0 = vdupq_lane_f32(va3, 0);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+ #endif
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+ #else
+ const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+ const float32x4_t va1c1 = vdupq_lane_f32(va1, 1);
+ const float32x4_t va2c1 = vdupq_lane_f32(va2, 1);
+ const float32x4_t va3c1 = vdupq_lane_f32(va3, 1);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+ #endif
+ }
+ if XNN_UNLIKELY(k != 0) {
+ const float32x4_t va0 = vld1q_dup_f32(a0);
+ const float32x4_t va1 = vld1q_dup_f32(a1);
+ const float32x4_t va2 = vld1q_dup_f32(a2);
+ const float32x4_t va3 = vld1q_dup_f32(a3);
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+ }
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+
+ if XNN_LIKELY(nc >= 4) {
+ vst1q_f32(c3, vacc3x0123);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= 4;
+ } else {
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x4-scalar.c b/src/f32-igemm/4x4-scalar.c
new file mode 100644
index 0000000..997d862
--- /dev/null
+++ b/src/f32-igemm/4x4-scalar.c
@@ -0,0 +1,216 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_igemm_ukernel_4x4__scalar(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {
+ float vacc00 = w[0];
+ float vacc01 = w[1];
+ float vacc02 = w[2];
+ float vacc03 = w[3];
+ float vacc10 = vacc00;
+ float vacc11 = vacc01;
+ float vacc12 = vacc02;
+ float vacc13 = vacc03;
+ float vacc20 = vacc00;
+ float vacc21 = vacc01;
+ float vacc22 = vacc02;
+ float vacc23 = vacc03;
+ float vacc30 = vacc00;
+ float vacc31 = vacc01;
+ float vacc32 = vacc02;
+ float vacc33 = vacc03;
+ w += 4;
+
+ size_t p = ks;
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ do {
+ const float va0 = *a0++;
+ const float va1 = *a1++;
+ const float va2 = *a2++;
+ const float va3 = *a3++;
+
+ const float vb0 = w[0];
+ const float vb1 = w[1];
+ const float vb2 = w[2];
+ const float vb3 = w[3];
+ w += 4;
+
+ vacc00 += va0 * vb0;
+ vacc01 += va0 * vb1;
+ vacc02 += va0 * vb2;
+ vacc03 += va0 * vb3;
+ vacc10 += va1 * vb0;
+ vacc11 += va1 * vb1;
+ vacc12 += va1 * vb2;
+ vacc13 += va1 * vb3;
+ vacc20 += va2 * vb0;
+ vacc21 += va2 * vb1;
+ vacc22 += va2 * vb2;
+ vacc23 += va2 * vb3;
+ vacc30 += va3 * vb0;
+ vacc31 += va3 * vb1;
+ vacc32 += va3 * vb2;
+ vacc33 += va3 * vb3;
+
+ k -= sizeof(float);
+ } while (k != 0);
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+ const float vmin = params->scalar.min;
+ vacc00 = math_max_f32(vacc00, vmin);
+ vacc01 = math_max_f32(vacc01, vmin);
+ vacc02 = math_max_f32(vacc02, vmin);
+ vacc03 = math_max_f32(vacc03, vmin);
+ vacc10 = math_max_f32(vacc10, vmin);
+ vacc11 = math_max_f32(vacc11, vmin);
+ vacc12 = math_max_f32(vacc12, vmin);
+ vacc13 = math_max_f32(vacc13, vmin);
+ vacc20 = math_max_f32(vacc20, vmin);
+ vacc21 = math_max_f32(vacc21, vmin);
+ vacc22 = math_max_f32(vacc22, vmin);
+ vacc23 = math_max_f32(vacc23, vmin);
+ vacc30 = math_max_f32(vacc30, vmin);
+ vacc31 = math_max_f32(vacc31, vmin);
+ vacc32 = math_max_f32(vacc32, vmin);
+ vacc33 = math_max_f32(vacc33, vmin);
+
+ const float vmax = params->scalar.max;
+ vacc00 = math_min_f32(vacc00, vmax);
+ vacc01 = math_min_f32(vacc01, vmax);
+ vacc02 = math_min_f32(vacc02, vmax);
+ vacc03 = math_min_f32(vacc03, vmax);
+ vacc10 = math_min_f32(vacc10, vmax);
+ vacc11 = math_min_f32(vacc11, vmax);
+ vacc12 = math_min_f32(vacc12, vmax);
+ vacc13 = math_min_f32(vacc13, vmax);
+ vacc20 = math_min_f32(vacc20, vmax);
+ vacc21 = math_min_f32(vacc21, vmax);
+ vacc22 = math_min_f32(vacc22, vmax);
+ vacc23 = math_min_f32(vacc23, vmax);
+ vacc30 = math_min_f32(vacc30, vmax);
+ vacc31 = math_min_f32(vacc31, vmax);
+ vacc32 = math_min_f32(vacc32, vmax);
+ vacc33 = math_min_f32(vacc33, vmax);
+
+ if XNN_LIKELY(nc >= 4) {
+ c3[0] = vacc30;
+ c3[1] = vacc31;
+ c3[2] = vacc32;
+ c3[3] = vacc33;
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ c2[0] = vacc20;
+ c2[1] = vacc21;
+ c2[2] = vacc22;
+ c2[3] = vacc23;
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ c1[0] = vacc10;
+ c1[1] = vacc11;
+ c1[2] = vacc12;
+ c1[3] = vacc13;
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ c0[0] = vacc00;
+ c0[1] = vacc01;
+ c0[2] = vacc02;
+ c0[3] = vacc03;
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= 4;
+ } else {
+ if (nc & 2) {
+ c3[0] = vacc30;
+ c3[1] = vacc31;
+ vacc30 = vacc32;
+ c3 += 2;
+ c2[0] = vacc20;
+ c2[1] = vacc21;
+ vacc20 = vacc22;
+ c2 += 2;
+ c1[0] = vacc10;
+ c1[1] = vacc11;
+ vacc10 = vacc12;
+ c1 += 2;
+ c0[0] = vacc00;
+ c0[1] = vacc01;
+ vacc00 = vacc02;
+ c0 += 2;
+ }
+ if (nc & 1) {
+ c3[0] = vacc30;
+ c2[0] = vacc20;
+ c1[0] = vacc10;
+ c0[0] = vacc00;
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x8-aarch64-neonfma-cortex-a75.S b/src/f32-igemm/4x8-aarch64-neonfma-cortex-a75.S
new file mode 100644
index 0000000..bd57f6b
--- /dev/null
+++ b/src/f32-igemm/4x8-aarch64-neonfma-cortex-a75.S
@@ -0,0 +1,496 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a75(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# size_t ks, x3 / x9
+# const float**restrict a, x4
+# const float*restrict w, x5
+# float*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x10
+# size_t a_offset, [sp + 8] -> x11
+# const float* zero, [sp + 16] -> x12
+# const xnn_f32_output_params params [sp + 24] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x20 a0
+# x13 a1
+# x14 a2
+# x15 a3
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x7 c3 / cm_stride
+
+# Vector register usage
+# A0 v0 v4
+# A1 v1 v5
+# A2 v2 v6
+# A3 v3 v7
+# B v8 v9 v10 v11
+# B v12 v13 v14 v15
+# B v20 v21 v22 v23
+# B v24 v25 v26 v27
+# C v16 v17
+# C v18 v19
+# C v28 v29
+# C v30 v31
+# Clamp v4 v5
+
+BEGIN_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a75
+
+ # Load cn_stride, a_offset
+ LDP x10, x11, [sp]
+
+ # Load zero, clamping params pointer
+ LDP x12, x8, [sp, 16]
+
+ # Load clamping_params values
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ # Save x20 on stack
+ STR x20, [sp, -80]!
+
+ # Save d8-d15 on stack
+ STP d8, d9, [sp, 16]
+ STP d10, d11, [sp, 32]
+ STP d12, d13, [sp, 48]
+ STP d14, d15, [sp, 64]
+
+ # Clamp C pointers
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ ADD x7, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x7, x17, x7, LO // c3 = c2
+
+0:
+ # Load initial bias from w into accumulators
+ LDP q16, q17, [x5], 32
+ MOV v18.16b, v16.16b
+ MOV v19.16b, v17.16b
+ MOV v28.16b, v16.16b
+ MOV v29.16b, v17.16b
+ MOV v30.16b, v16.16b
+ MOV v31.16b, v17.16b
+
+ MOV x9, x3 // p = ks
+
+1:
+ # Load next 4 A pointers
+ LDP x20, x13, [x4], 16
+ LDP x14, x15, [x4], 16
+
+ CMP x20, x12 // if a0 == zero
+ ADD x20, x20, x11 // a0 += a_offset
+ CSEL x20, x12, x20, EQ // a0 = zero, else += a0 + a_offset
+ CMP x13, x12 // if a1 == zero
+ ADD x13, x13, x11 // a1 += a_offset
+ CSEL x13, x12, x13, EQ // a1 = zero, else += a1 + a_offset
+ CMP x14, x12 // if a2 == zero
+ ADD x14, x14, x11 // a2 += a_offset
+ CSEL x14, x12, x14, EQ // a2 = zero, else += a2 + a_offset
+ CMP x15, x12 // if a3 == zero
+ ADD x15, x15, x11 // a3 += a_offset
+ CSEL x15, x12, x15, EQ // a3 = zero, else += a3 + a_offset
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 4f
+
+ # 16 prologue
+ # Read first block of 4 A and B.
+ LDR q0, [x20], 16
+ LDP q20, q21, [x5], 32
+ LDR q1, [x13], 16
+ LDR q2, [x14], 16
+ LDR q3, [x15], 16
+ LDP q22, q23, [x5], 32
+ LDP q24, q25, [x5], 32
+ LDP q26, q27, [x5], 32
+
+        # Is there at least another 32 bytes (8 floats) of A?  If so, run main loop.
+ SUBS x0, x0, 32
+ B.LO 3f
+
+ # Main loop - 8 floats of A
+2:
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDP q8, q9, [x5], 32
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ LDP q10, q11, [x5], 32
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ LDP q12, q13, [x5], 32
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ LDR q4, [x20], 16
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ LDR q5, [x13], 16
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ LDR q6, [x14], 16
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ LDR q7, [x15], 16
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ PRFM PLDL1KEEP, [x5, 128]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ PRFM PLDL1KEEP, [x5, 192]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ PRFM PLDL1KEEP, [x5, 320]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+        # Second block of 4.  FMA for second 4, loads for 1st block of 4.
+ FMLA v16.4s, v8.4s, v4.s[0]
+ LDP q20, q21, [x5], 32
+ FMLA v17.4s, v9.4s, v4.s[0]
+ FMLA v18.4s, v8.4s, v5.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v19.4s, v9.4s, v5.s[0]
+ FMLA v28.4s, v8.4s, v6.s[0]
+ LDP q24, q25, [x5], 32
+ FMLA v29.4s, v9.4s, v6.s[0]
+ FMLA v30.4s, v8.4s, v7.s[0]
+ LDP q26, q27, [x5], 32
+ FMLA v31.4s, v9.4s, v7.s[0]
+ FMLA v16.4s, v10.4s, v4.s[1]
+ LDR q0, [x20], 16
+ FMLA v17.4s, v11.4s, v4.s[1]
+ FMLA v18.4s, v10.4s, v5.s[1]
+ LDR q1, [x13], 16
+ FMLA v19.4s, v11.4s, v5.s[1]
+ FMLA v28.4s, v10.4s, v6.s[1]
+ LDR q2, [x14], 16
+ FMLA v29.4s, v11.4s, v6.s[1]
+ FMLA v30.4s, v10.4s, v7.s[1]
+ LDR q3, [x15], 16
+ FMLA v31.4s, v11.4s, v7.s[1]
+ FMLA v16.4s, v12.4s, v4.s[2]
+ FMLA v17.4s, v13.4s, v4.s[2]
+ FMLA v18.4s, v12.4s, v5.s[2]
+ FMLA v19.4s, v13.4s, v5.s[2]
+ FMLA v28.4s, v12.4s, v6.s[2]
+ FMLA v29.4s, v13.4s, v6.s[2]
+ FMLA v30.4s, v12.4s, v7.s[2]
+ FMLA v31.4s, v13.4s, v7.s[2]
+ FMLA v16.4s, v14.4s, v4.s[3]
+ FMLA v17.4s, v15.4s, v4.s[3]
+ FMLA v18.4s, v14.4s, v5.s[3]
+ FMLA v19.4s, v15.4s, v5.s[3]
+ FMLA v28.4s, v14.4s, v6.s[3]
+ FMLA v29.4s, v15.4s, v6.s[3]
+ SUBS x0, x0, 32
+ FMLA v30.4s, v14.4s, v7.s[3]
+ FMLA v31.4s, v15.4s, v7.s[3]
+
+ B.HS 2b
+
+3:
+ # Epilogue
+ # First block of 4. FMA for first 4, loads for 2nd block of 4.
+ FMLA v16.4s, v20.4s, v0.s[0]
+ LDP q8, q9, [x5], 32
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ LDP q10, q11, [x5], 32
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ LDP q12, q13, [x5], 32
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ LDP q14, q15, [x5], 32
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ LDR q4, [x20], 16
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ LDR q5, [x13], 16
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ LDR q6, [x14], 16
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ LDR q7, [x15], 16
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+ # Second block of 4. FMA for second 4, noloads
+ FMLA v16.4s, v8.4s, v4.s[0]
+ FMLA v17.4s, v9.4s, v4.s[0]
+ FMLA v18.4s, v8.4s, v5.s[0]
+ FMLA v19.4s, v9.4s, v5.s[0]
+ FMLA v28.4s, v8.4s, v6.s[0]
+ FMLA v29.4s, v9.4s, v6.s[0]
+ FMLA v30.4s, v8.4s, v7.s[0]
+ FMLA v31.4s, v9.4s, v7.s[0]
+ FMLA v16.4s, v10.4s, v4.s[1]
+ FMLA v17.4s, v11.4s, v4.s[1]
+ FMLA v18.4s, v10.4s, v5.s[1]
+ FMLA v19.4s, v11.4s, v5.s[1]
+ FMLA v28.4s, v10.4s, v6.s[1]
+ FMLA v29.4s, v11.4s, v6.s[1]
+ FMLA v30.4s, v10.4s, v7.s[1]
+ FMLA v31.4s, v11.4s, v7.s[1]
+ FMLA v16.4s, v12.4s, v4.s[2]
+ FMLA v17.4s, v13.4s, v4.s[2]
+ FMLA v18.4s, v12.4s, v5.s[2]
+ FMLA v19.4s, v13.4s, v5.s[2]
+ FMLA v28.4s, v12.4s, v6.s[2]
+ FMLA v29.4s, v13.4s, v6.s[2]
+ FMLA v30.4s, v12.4s, v7.s[2]
+ FMLA v31.4s, v13.4s, v7.s[2]
+
+ FMLA v16.4s, v14.4s, v4.s[3]
+ FMLA v17.4s, v15.4s, v4.s[3]
+ FMLA v18.4s, v14.4s, v5.s[3]
+ FMLA v19.4s, v15.4s, v5.s[3]
+
+ # Load clamping_params values
+ LD2R {v4.4s, v5.4s}, [x8]
+
+ FMLA v28.4s, v14.4s, v6.s[3]
+ FMLA v29.4s, v15.4s, v6.s[3]
+ FMLA v30.4s, v14.4s, v7.s[3]
+ FMLA v31.4s, v15.4s, v7.s[3]
+
+4:
+ # Remainder- 4 floats of A
+ TBZ x0, 4, 5f
+
+ LDR q0, [x20], 16
+ LDP q20, q21, [x5], 32
+ LDR q1, [x13], 16
+ LDR q2, [x14], 16
+ LDR q3, [x15], 16
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ LDP q24, q25, [x5], 32
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ LDP q26, q27, [x5], 32
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+ FMLA v16.4s, v24.4s, v0.s[2]
+ FMLA v17.4s, v25.4s, v0.s[2]
+ FMLA v18.4s, v24.4s, v1.s[2]
+ FMLA v19.4s, v25.4s, v1.s[2]
+ FMLA v28.4s, v24.4s, v2.s[2]
+ FMLA v29.4s, v25.4s, v2.s[2]
+ FMLA v30.4s, v24.4s, v3.s[2]
+ FMLA v31.4s, v25.4s, v3.s[2]
+ FMLA v16.4s, v26.4s, v0.s[3]
+ FMLA v17.4s, v27.4s, v0.s[3]
+ FMLA v18.4s, v26.4s, v1.s[3]
+ FMLA v19.4s, v27.4s, v1.s[3]
+ FMLA v28.4s, v26.4s, v2.s[3]
+ FMLA v29.4s, v27.4s, v2.s[3]
+ FMLA v30.4s, v26.4s, v3.s[3]
+ FMLA v31.4s, v27.4s, v3.s[3]
+
+5:
+ # Remainder- 2 floats of A
+ TBZ x0, 3, 6f
+
+ LDR d0, [x20], 8
+ LDP q20, q21, [x5], 32
+ LDR d1, [x13], 8
+ LDR d2, [x14], 8
+ LDR d3, [x15], 8
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ LDP q22, q23, [x5], 32
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+ FMLA v16.4s, v22.4s, v0.s[1]
+ FMLA v17.4s, v23.4s, v0.s[1]
+ FMLA v18.4s, v22.4s, v1.s[1]
+ FMLA v19.4s, v23.4s, v1.s[1]
+ FMLA v28.4s, v22.4s, v2.s[1]
+ FMLA v29.4s, v23.4s, v2.s[1]
+ FMLA v30.4s, v22.4s, v3.s[1]
+ FMLA v31.4s, v23.4s, v3.s[1]
+
+6:
+ # Remainder- 1 float of A
+ TBZ x0, 2, 7f
+
+ LDR s0, [x20], 4
+ LDP q20, q21, [x5], 32
+ LDR s1, [x13], 4
+ LDR s2, [x14], 4
+ LDR s3, [x15], 4
+ FMLA v16.4s, v20.4s, v0.s[0]
+ FMLA v17.4s, v21.4s, v0.s[0]
+ FMLA v18.4s, v20.4s, v1.s[0]
+ FMLA v19.4s, v21.4s, v1.s[0]
+ FMLA v28.4s, v20.4s, v2.s[0]
+ FMLA v29.4s, v21.4s, v2.s[0]
+ FMLA v30.4s, v20.4s, v3.s[0]
+ FMLA v31.4s, v21.4s, v3.s[0]
+
+7:
+ # ks loop
+ SUBS x9, x9, 32 // ks -= MR * sizeof(void*)
+ B.NE 1b
+
+ # Clamp
+ FMIN v16.4s, v16.4s, v4.4s
+ FMIN v17.4s, v17.4s, v4.4s
+ FMIN v18.4s, v18.4s, v4.4s
+ FMIN v19.4s, v19.4s, v4.4s
+ FMIN v28.4s, v28.4s, v4.4s
+ FMIN v29.4s, v29.4s, v4.4s
+ FMIN v30.4s, v30.4s, v4.4s
+ FMIN v31.4s, v31.4s, v4.4s
+ FMAX v16.4s, v16.4s, v5.4s
+ FMAX v17.4s, v17.4s, v5.4s
+ FMAX v18.4s, v18.4s, v5.4s
+ FMAX v19.4s, v19.4s, v5.4s
+ FMAX v28.4s, v28.4s, v5.4s
+ FMAX v29.4s, v29.4s, v5.4s
+ FMAX v30.4s, v30.4s, v5.4s
+ FMAX v31.4s, v31.4s, v5.4s
+
+ # Store full 4 x 8
+ CMP x1, 8
+ B.LO 8f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x10
+ STP q28, q29, [x17]
+ ADD x17, x17, x10
+ STP q18, q19, [x16]
+ ADD x16, x16, x10
+ STP q16, q17, [x6]
+ ADD x6, x6, x10
+
+ SUB x4, x4, x3 // a -= ks
+
+ # nc loop
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 64]
+ LDP d12, d13, [sp, 48]
+ LDP d10, d11, [sp, 32]
+ LDP d8, d9, [sp, 16]
+
+ # Restore x20 from stack
+ LDR x20, [sp], 80
+ RET
+
+ # Store odd width
+8:
+ TBZ x1, 2, 9f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x17], 16
+ MOV v28.16b, v29.16b
+ STR q18, [x16], 16
+ MOV v18.16b, v19.16b
+ STR q16, [x6], 16
+ MOV v16.16b, v17.16b
+
+9:
+ TBZ x1, 1, 10f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x17], 8
+ DUP d28, v28.d[1]
+ STR d18, [x16], 8
+ DUP d18, v18.d[1]
+ STR d16, [x6], 8
+ DUP d16, v16.d[1]
+
+10:
+ TBZ x1, 0, 11f
+ STR s30, [x7]
+ STR s28, [x17]
+ STR s18, [x16]
+ STR s16, [x6]
+11:
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 64]
+ LDP d12, d13, [sp, 48]
+ LDP d10, d11, [sp, 32]
+ LDP d8, d9, [sp, 16]
+
+ # Restore x20 from stack
+ LDR x20, [sp], 80
+ RET
+
+END_FUNCTION xnn_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-igemm/4x8-neon-ld128.c b/src/f32-igemm/4x8-neon-ld128.c
new file mode 100644
index 0000000..4e1f0cd
--- /dev/null
+++ b/src/f32-igemm/4x8-neon-ld128.c
@@ -0,0 +1,239 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/neon-ld128.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_4x8__neon_ld128(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc1x4567 = vacc0x4567;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc2x4567 = vacc0x4567;
+ float32x4_t vacc3x0123 = vacc0x0123;
+ float32x4_t vacc3x4567 = vacc0x4567;
+
+ size_t p = ks;
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ for (; k >= 4 * sizeof(float); k -= 4 * sizeof(float)) {
+ const float32x4_t va0 = vld1q_f32(a0); a0 += 4;
+ const float32x4_t va1 = vld1q_f32(a1); a1 += 4;
+ const float32x4_t va2 = vld1q_f32(a2); a2 += 4;
+ const float32x4_t va3 = vld1q_f32(a3); a3 += 4;
+
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, vget_low_f32(va0), 0);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, vget_low_f32(va1), 0);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, vget_low_f32(va2), 0);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, vget_low_f32(va3), 0);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, vget_low_f32(va0), 0);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, vget_low_f32(va1), 0);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, vget_low_f32(va2), 0);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, vget_low_f32(va3), 0);
+
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, vget_low_f32(va0), 1);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, vget_low_f32(va1), 1);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, vget_low_f32(va2), 1);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, vget_low_f32(va3), 1);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, vget_low_f32(va0), 1);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, vget_low_f32(va1), 1);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, vget_low_f32(va2), 1);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, vget_low_f32(va3), 1);
+
+ const float32x4_t vb0123c2 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c2 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c2, vget_high_f32(va0), 0);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c2, vget_high_f32(va1), 0);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c2, vget_high_f32(va2), 0);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c2, vget_high_f32(va3), 0);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c2, vget_high_f32(va0), 0);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c2, vget_high_f32(va1), 0);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c2, vget_high_f32(va2), 0);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c2, vget_high_f32(va3), 0);
+
+ const float32x4_t vb0123c3 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c3 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c3, vget_high_f32(va0), 1);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c3, vget_high_f32(va1), 1);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c3, vget_high_f32(va2), 1);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c3, vget_high_f32(va3), 1);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c3, vget_high_f32(va0), 1);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c3, vget_high_f32(va1), 1);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c3, vget_high_f32(va2), 1);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c3, vget_high_f32(va3), 1);
+ }
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+ const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+ const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+ const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+ vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+
+ k -= sizeof(float);
+ } while (k != 0);
+ }
+
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x8-neon-ld64.c b/src/f32-igemm/4x8-neon-ld64.c
new file mode 100644
index 0000000..591a076
--- /dev/null
+++ b/src/f32-igemm/4x8-neon-ld64.c
@@ -0,0 +1,208 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_4x8__neon_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc1x4567 = vacc0x4567;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc2x4567 = vacc0x4567;
+ float32x4_t vacc3x0123 = vacc0x0123;
+ float32x4_t vacc3x4567 = vacc0x4567;
+
+ size_t p = ks;
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+ vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+ vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+ vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+ vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+ }
+ if XNN_UNLIKELY(k != 0) {
+ const float32x4_t va0 = vld1q_dup_f32(a0);
+ const float32x4_t va1 = vld1q_dup_f32(a1);
+ const float32x4_t va2 = vld1q_dup_f32(a2);
+ const float32x4_t va3 = vld1q_dup_f32(a3);
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+ vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+ }
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x8-neonfma-ld128.c b/src/f32-igemm/4x8-neonfma-ld128.c
new file mode 100644
index 0000000..19d6ee5
--- /dev/null
+++ b/src/f32-igemm/4x8-neonfma-ld128.c
@@ -0,0 +1,299 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/neon-ld128.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_4x8__neonfma_ld128(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc1x4567 = vacc0x4567;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc2x4567 = vacc0x4567;
+ float32x4_t vacc3x0123 = vacc0x0123;
+ float32x4_t vacc3x4567 = vacc0x4567;
+
+ size_t p = ks;
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ for (; k >= 4 * sizeof(float); k -= 4 * sizeof(float)) {
+ const float32x4_t va0 = vld1q_f32(a0); a0 += 4;
+ const float32x4_t va1 = vld1q_f32(a1); a1 += 4;
+ const float32x4_t va2 = vld1q_f32(a2); a2 += 4;
+ const float32x4_t va3 = vld1q_f32(a3); a3 += 4;
+
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123c0, va3, 0);
+ vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567c0, va0, 0);
+ vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567c0, va1, 0);
+ vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567c0, va2, 0);
+ vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567c0, va3, 0);
+ #else
+ const float32x4_t va0c0 = vdupq_lane_f32(vget_low_f32(va0), 0);
+ const float32x4_t va1c0 = vdupq_lane_f32(vget_low_f32(va1), 0);
+ const float32x4_t va2c0 = vdupq_lane_f32(vget_low_f32(va2), 0);
+ const float32x4_t va3c0 = vdupq_lane_f32(vget_low_f32(va3), 0);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+ #endif
+
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123c1, va3, 1);
+ vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567c1, va0, 1);
+ vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567c1, va1, 1);
+ vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567c1, va2, 1);
+ vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567c1, va3, 1);
+ #else
+ const float32x4_t va0c1 = vdupq_lane_f32(vget_low_f32(va0), 1);
+ const float32x4_t va1c1 = vdupq_lane_f32(vget_low_f32(va1), 1);
+ const float32x4_t va2c1 = vdupq_lane_f32(vget_low_f32(va2), 1);
+ const float32x4_t va3c1 = vdupq_lane_f32(vget_low_f32(va3), 1);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+ #endif
+
+ const float32x4_t vb0123c2 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c2 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123c2, va0, 2);
+ vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123c2, va1, 2);
+ vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123c2, va2, 2);
+ vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123c2, va3, 2);
+ vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567c2, va0, 2);
+ vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567c2, va1, 2);
+ vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567c2, va2, 2);
+ vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567c2, va3, 2);
+ #else
+ const float32x4_t va0c2 = vdupq_lane_f32(vget_high_f32(va0), 0);
+ const float32x4_t va1c2 = vdupq_lane_f32(vget_high_f32(va1), 0);
+ const float32x4_t va2c2 = vdupq_lane_f32(vget_high_f32(va2), 0);
+ const float32x4_t va3c2 = vdupq_lane_f32(vget_high_f32(va3), 0);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c2, vb0123c2);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c2, vb0123c2);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c2, vb0123c2);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c2, vb0123c2);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c2, vb4567c2);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c2, vb4567c2);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c2, vb4567c2);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c2, vb4567c2);
+ #endif
+
+ const float32x4_t vb0123c3 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c3 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123c3, va0, 3);
+ vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123c3, va1, 3);
+ vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123c3, va2, 3);
+ vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123c3, va3, 3);
+ vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567c3, va0, 3);
+ vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567c3, va1, 3);
+ vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567c3, va2, 3);
+ vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567c3, va3, 3);
+ #else
+ const float32x4_t va0c3 = vdupq_lane_f32(vget_high_f32(va0), 1);
+ const float32x4_t va1c3 = vdupq_lane_f32(vget_high_f32(va1), 1);
+ const float32x4_t va2c3 = vdupq_lane_f32(vget_high_f32(va2), 1);
+ const float32x4_t va3c3 = vdupq_lane_f32(vget_high_f32(va3), 1);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c3, vb0123c3);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c3, vb0123c3);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c3, vb0123c3);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c3, vb0123c3);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c3, vb4567c3);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c3, vb4567c3);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c3, vb4567c3);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c3, vb4567c3);
+ #endif
+ }
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ const float32x4_t va0 = vld1q_dup_f32(a0); a0 += 1;
+ const float32x4_t va1 = vld1q_dup_f32(a1); a1 += 1;
+ const float32x4_t va2 = vld1q_dup_f32(a2); a2 += 1;
+ const float32x4_t va3 = vld1q_dup_f32(a3); a3 += 1;
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+
+ k -= sizeof(float);
+ } while (k != 0);
+ }
+
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x8-neonfma-ld64.c b/src/f32-igemm/4x8-neonfma-ld64.c
new file mode 100644
index 0000000..d388df2
--- /dev/null
+++ b/src/f32-igemm/4x8-neonfma-ld64.c
@@ -0,0 +1,238 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_4x8__neonfma_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc1x4567 = vacc0x4567;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc2x4567 = vacc0x4567;
+ float32x4_t vacc3x0123 = vacc0x0123;
+ float32x4_t vacc3x4567 = vacc0x4567;
+
+ size_t p = ks;
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+ vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+ vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+ vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+ vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+ #else
+ const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+ const float32x4_t va1c0 = vdupq_lane_f32(va1, 0);
+ const float32x4_t va2c0 = vdupq_lane_f32(va2, 0);
+ const float32x4_t va3c0 = vdupq_lane_f32(va3, 0);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+ #endif
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+ vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+ vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+ vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+ vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+ #else
+ const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+ const float32x4_t va1c1 = vdupq_lane_f32(va1, 1);
+ const float32x4_t va2c1 = vdupq_lane_f32(va2, 1);
+ const float32x4_t va3c1 = vdupq_lane_f32(va3, 1);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+ #endif
+ }
+ if XNN_UNLIKELY(k != 0) {
+ const float32x4_t va0 = vld1q_dup_f32(a0);
+ const float32x4_t va1 = vld1q_dup_f32(a1);
+ const float32x4_t va2 = vld1q_dup_f32(a2);
+ const float32x4_t va3 = vld1q_dup_f32(a3);
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+ }
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x8-psimd-loadsplat.c b/src/f32-igemm/4x8-psimd-loadsplat.c
new file mode 100644
index 0000000..fc4d284
--- /dev/null
+++ b/src/f32-igemm/4x8-psimd-loadsplat.c
@@ -0,0 +1,192 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/psimd-loadsplat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_4x8__psimd_loadsplat(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (4 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ do {
+ psimd_f32 vacc0x0123 = psimd_load_f32(w);
+ psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+ psimd_f32 vacc1x0123 = vacc0x0123;
+ psimd_f32 vacc1x4567 = vacc0x4567;
+ psimd_f32 vacc2x0123 = vacc0x0123;
+ psimd_f32 vacc2x4567 = vacc0x4567;
+ psimd_f32 vacc3x0123 = vacc0x0123;
+ psimd_f32 vacc3x4567 = vacc0x4567;
+ w += 8;
+
+ size_t p = ks;
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = kc;
+ do {
+ const psimd_f32 vb0123 = psimd_load_f32(w);
+ const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+ w += 8;
+
+ const psimd_f32 va0 = psimd_load_splat_f32(a0);
+ a0 += 1;
+ const psimd_f32 va1 = psimd_load_splat_f32(a1);
+ a1 += 1;
+ const psimd_f32 va2 = psimd_load_splat_f32(a2);
+ a2 += 1;
+ const psimd_f32 va3 = psimd_load_splat_f32(a3);
+ a3 += 1;
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+ k -= sizeof(float);
+ } while (k != 0);
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+ vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+ vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+ vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+ vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+ vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+ vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+ vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+ vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+ vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+ vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+ vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+ vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+ vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+ vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+ vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+ vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ psimd_store_f32(c0, vacc0x0123);
+ psimd_store_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c0, vacc0x0123);
+
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+
+ c3 += 4;
+ c2 += 4;
+ c1 += 4;
+ c0 += 4;
+ }
+ if (nc & 2) {
+ psimd_store2_f32(c3, vacc3x0123);
+ psimd_store2_f32(c2, vacc2x0123);
+ psimd_store2_f32(c1, vacc1x0123);
+ psimd_store2_f32(c0, vacc0x0123);
+
+ vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+ vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+ vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+ vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+ c3 += 2;
+ c2 += 2;
+ c1 += 2;
+ c0 += 2;
+ }
+ if (nc & 1) {
+ psimd_store1_f32(c3, vacc3x0123);
+ psimd_store1_f32(c2, vacc2x0123);
+ psimd_store1_f32(c1, vacc1x0123);
+ psimd_store1_f32(c0, vacc0x0123);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x8-psimd-splat.c b/src/f32-igemm/4x8-psimd-splat.c
new file mode 100644
index 0000000..e42b33c
--- /dev/null
+++ b/src/f32-igemm/4x8-psimd-splat.c
@@ -0,0 +1,272 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/psimd-splat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+// f32 indirect GEMM (IGEMM) micro-kernel: accumulates an output tile of up to
+// 4 rows x 8 columns. A is read indirectly through the pointer array `a`
+// (ks pointers per tile position); entries equal to `zero` are NOT adjusted by
+// a_offset — presumably they point at a zero-filled buffer (TODO confirm with
+// the indirection-buffer setup code). Each lane of the A vectors is splatted
+// across a full SIMD register and multiplied against packed B columns in `w`.
+void xnn_f32_igemm_ukernel_4x8__psimd_splat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (4 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  // Output row pointers. When mr < 4, the surplus rows alias the previous
+  // row, so their stores write duplicate data to valid memory instead of
+  // requiring branches in the hot path.
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    c3 = c2;
+  }
+
+  // Outer loop: one iteration per tile of up to 8 output columns.
+  do {
+    // Seed all four row accumulators with the bias (first 8 floats of w).
+    psimd_f32 vacc0x0123 = psimd_load_f32(w);
+    psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+    psimd_f32 vacc1x0123 = vacc0x0123;
+    psimd_f32 vacc1x4567 = vacc0x4567;
+    psimd_f32 vacc2x0123 = vacc0x0123;
+    psimd_f32 vacc2x4567 = vacc0x4567;
+    psimd_f32 vacc3x0123 = vacc0x0123;
+    psimd_f32 vacc3x4567 = vacc0x4567;
+    w += 8;
+
+    // Walk the indirection buffer: 4 (== MR) pointers consumed per step.
+    size_t p = ks;
+    do {
+      const float* restrict a0 = a[0];
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      const float* restrict a1 = a[1];
+      if XNN_UNPREDICTABLE(a1 != zero) {
+        a1 = (const float*) ((uintptr_t) a1 + a_offset);
+      }
+      const float* restrict a2 = a[2];
+      if XNN_UNPREDICTABLE(a2 != zero) {
+        a2 = (const float*) ((uintptr_t) a2 + a_offset);
+      }
+      const float* restrict a3 = a[3];
+      if XNN_UNPREDICTABLE(a3 != zero) {
+        a3 = (const float*) ((uintptr_t) a3 + a_offset);
+      }
+      a += 4;
+
+      // Main loop: consume 4 k-elements per iteration, splatting each lane of
+      // the loaded A vectors against the corresponding 8-wide B row.
+      size_t k = kc;
+      while (k >= 4 * sizeof(float)) {
+        const psimd_f32 va0 = psimd_load_f32(a0);
+        a0 += 4;
+        const psimd_f32 va1 = psimd_load_f32(a1);
+        a1 += 4;
+        const psimd_f32 va2 = psimd_load_f32(a2);
+        a2 += 4;
+        const psimd_f32 va3 = psimd_load_f32(a3);
+        a3 += 4;
+
+        const psimd_f32 va0c0 = psimd_splat0_f32(va0);
+        const psimd_f32 va1c0 = psimd_splat0_f32(va1);
+        const psimd_f32 va2c0 = psimd_splat0_f32(va2);
+        const psimd_f32 va3c0 = psimd_splat0_f32(va3);
+
+        const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+        const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c0, vb0123c0);
+        vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c0, vb0123c0);
+        vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c0, vb0123c0);
+        vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c0, vb0123c0);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c0, vb4567c0);
+        vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c0, vb4567c0);
+        vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c0, vb4567c0);
+        vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c0, vb4567c0);
+        const psimd_f32 va0c1 = psimd_splat1_f32(va0);
+        const psimd_f32 va1c1 = psimd_splat1_f32(va1);
+        const psimd_f32 va2c1 = psimd_splat1_f32(va2);
+        const psimd_f32 va3c1 = psimd_splat1_f32(va3);
+
+        const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+        const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c1, vb0123c1);
+        vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c1, vb0123c1);
+        vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c1, vb0123c1);
+        vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c1, vb0123c1);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c1, vb4567c1);
+        vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c1, vb4567c1);
+        vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c1, vb4567c1);
+        vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c1, vb4567c1);
+        const psimd_f32 va0c2 = psimd_splat2_f32(va0);
+        const psimd_f32 va1c2 = psimd_splat2_f32(va1);
+        const psimd_f32 va2c2 = psimd_splat2_f32(va2);
+        const psimd_f32 va3c2 = psimd_splat2_f32(va3);
+
+        const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+        const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c2, vb0123c2);
+        vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c2, vb0123c2);
+        vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c2, vb0123c2);
+        vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c2, vb0123c2);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c2, vb4567c2);
+        vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c2, vb4567c2);
+        vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c2, vb4567c2);
+        vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c2, vb4567c2);
+        const psimd_f32 va0c3 = psimd_splat3_f32(va0);
+        const psimd_f32 va1c3 = psimd_splat3_f32(va1);
+        const psimd_f32 va2c3 = psimd_splat3_f32(va2);
+        const psimd_f32 va3c3 = psimd_splat3_f32(va3);
+
+        const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+        const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c3, vb0123c3);
+        vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c3, vb0123c3);
+        vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c3, vb0123c3);
+        vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c3, vb0123c3);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c3, vb4567c3);
+        vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c3, vb4567c3);
+        vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c3, vb4567c3);
+        vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c3, vb4567c3);
+
+        w += 32;
+        k -= 4 * sizeof(float);
+      }
+      // Remainder loop for kc not divisible by 4: one k-element per step.
+      if XNN_UNLIKELY(k != 0) {
+        do {
+          const psimd_f32 vb0123 = psimd_load_f32(w);
+          const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+          w += 8;
+
+          const psimd_f32 va0 = psimd_load_splat_f32(a0);
+          a0 += 1;
+          const psimd_f32 va1 = psimd_load_splat_f32(a1);
+          a1 += 1;
+          const psimd_f32 va2 = psimd_load_splat_f32(a2);
+          a2 += 1;
+          const psimd_f32 va3 = psimd_load_splat_f32(a3);
+          a3 += 1;
+
+          vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+          vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+          vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+          vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+          vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+          vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+          vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+          vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+          k -= sizeof(float);
+        } while (k != 0);
+      }
+      p -= 4 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp accumulators to [min, max] from the output parameters.
+    // Fixed: '&params' had been corrupted to a pilcrow character ('¶ms').
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+    vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+    vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+    vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+    vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+    vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+    vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+    vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+    vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+    vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+    vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+
+    // Store the tile; for partial tiles (nc < 8) store in 4/2/1-column steps.
+    if XNN_LIKELY(nc >= 8) {
+      psimd_store_f32(c3, vacc3x0123);
+      psimd_store_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      psimd_store_f32(c2, vacc2x0123);
+      psimd_store_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      psimd_store_f32(c1, vacc1x0123);
+      psimd_store_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind the indirection buffer for the next column tile.
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        psimd_store_f32(c3, vacc3x0123);
+        psimd_store_f32(c2, vacc2x0123);
+        psimd_store_f32(c1, vacc1x0123);
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c3, vacc3x0123);
+        psimd_store2_f32(c2, vacc2x0123);
+        psimd_store2_f32(c1, vacc1x0123);
+        psimd_store2_f32(c0, vacc0x0123);
+
+        // Shift the upper half of each accumulator into the lower lanes.
+        vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+        vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+        vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c3, vacc3x0123);
+        psimd_store1_f32(c2, vacc2x0123);
+        psimd_store1_f32(c1, vacc1x0123);
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x8-sse-dup.c b/src/f32-igemm/4x8-sse-dup.c
new file mode 100644
index 0000000..5d1a17c
--- /dev/null
+++ b/src/f32-igemm/4x8-sse-dup.c
@@ -0,0 +1,276 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/sse-dup.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/igemm.h>
+
+
+// f32 indirect GEMM (IGEMM) micro-kernel: accumulates an output tile of up to
+// 4 rows x 8 columns using SSE intrinsics. A is read indirectly through the
+// pointer array `a`; entries equal to `zero` are NOT adjusted by a_offset —
+// presumably they point at a zero-filled buffer (TODO confirm with the
+// indirection-buffer setup code). "dup" variant: each A lane is broadcast
+// with _mm_shuffle_ps before being multiplied against packed B columns in w.
+void xnn_f32_igemm_ukernel_4x8__sse_dup(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (4 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  // Output row pointers. When mr < 4, surplus rows alias the previous row so
+  // their stores write duplicates to valid memory instead of branching.
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    c3 = c2;
+  }
+
+  // Outer loop: one iteration per tile of up to 8 output columns.
+  do {
+    // Seed all four row accumulators with the bias (first 8 floats of w).
+    __m128 vacc0x0123 = _mm_load_ps(w);
+    __m128 vacc0x4567 = _mm_load_ps(w + 4);
+    __m128 vacc1x0123 = vacc0x0123;
+    __m128 vacc1x4567 = vacc0x4567;
+    __m128 vacc2x0123 = vacc0x0123;
+    __m128 vacc2x4567 = vacc0x4567;
+    __m128 vacc3x0123 = vacc0x0123;
+    __m128 vacc3x4567 = vacc0x4567;
+    w += 8;
+
+    // Walk the indirection buffer: 4 (== MR) pointers consumed per step.
+    size_t p = ks;
+    do {
+      const float* restrict a0 = a[0];
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      const float* restrict a1 = a[1];
+      if XNN_UNPREDICTABLE(a1 != zero) {
+        a1 = (const float*) ((uintptr_t) a1 + a_offset);
+      }
+      const float* restrict a2 = a[2];
+      if XNN_UNPREDICTABLE(a2 != zero) {
+        a2 = (const float*) ((uintptr_t) a2 + a_offset);
+      }
+      const float* restrict a3 = a[3];
+      if XNN_UNPREDICTABLE(a3 != zero) {
+        a3 = (const float*) ((uintptr_t) a3 + a_offset);
+      }
+      a += 4;
+
+      // Main loop: consume 4 k-elements per iteration; each lane of the A
+      // vectors is broadcast (duplicated) and FMA'd against an 8-wide B row.
+      size_t k = kc;
+      while (k >= 4 * sizeof(float)) {
+        const __m128 va0 = _mm_loadu_ps(a0);
+        a0 += 4;
+        const __m128 va1 = _mm_loadu_ps(a1);
+        a1 += 4;
+        const __m128 va2 = _mm_loadu_ps(a2);
+        a2 += 4;
+        const __m128 va3 = _mm_loadu_ps(a3);
+        a3 += 4;
+
+
+        const __m128 va0c0000 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 0, 0, 0));
+        const __m128 va1c0000 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(0, 0, 0, 0));
+        const __m128 va2c0000 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(0, 0, 0, 0));
+        const __m128 va3c0000 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(0, 0, 0, 0));
+
+        const __m128 vb0123c0 = _mm_load_ps(w + 0);
+        const __m128 vb4567c0 = _mm_load_ps(w + 4);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c0000, vb0123c0));
+        vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1c0000, vb0123c0));
+        vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2c0000, vb0123c0));
+        vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3c0000, vb0123c0));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c0000, vb4567c0));
+        vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1c0000, vb4567c0));
+        vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2c0000, vb4567c0));
+        vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3c0000, vb4567c0));
+
+        const __m128 va0c1111 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(1, 1, 1, 1));
+        const __m128 va1c1111 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(1, 1, 1, 1));
+        const __m128 va2c1111 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(1, 1, 1, 1));
+        const __m128 va3c1111 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(1, 1, 1, 1));
+
+        const __m128 vb0123c1 = _mm_load_ps(w + 8);
+        const __m128 vb4567c1 = _mm_load_ps(w + 12);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c1111, vb0123c1));
+        vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1c1111, vb0123c1));
+        vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2c1111, vb0123c1));
+        vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3c1111, vb0123c1));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c1111, vb4567c1));
+        vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1c1111, vb4567c1));
+        vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2c1111, vb4567c1));
+        vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3c1111, vb4567c1));
+
+        const __m128 va0c2222 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(2, 2, 2, 2));
+        const __m128 va1c2222 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(2, 2, 2, 2));
+        const __m128 va2c2222 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(2, 2, 2, 2));
+        const __m128 va3c2222 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(2, 2, 2, 2));
+
+        const __m128 vb0123c2 = _mm_load_ps(w + 16);
+        const __m128 vb4567c2 = _mm_load_ps(w + 20);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c2222, vb0123c2));
+        vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1c2222, vb0123c2));
+        vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2c2222, vb0123c2));
+        vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3c2222, vb0123c2));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c2222, vb4567c2));
+        vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1c2222, vb4567c2));
+        vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2c2222, vb4567c2));
+        vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3c2222, vb4567c2));
+
+        const __m128 va0c3333 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(3, 3, 3, 3));
+        const __m128 va1c3333 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(3, 3, 3, 3));
+        const __m128 va2c3333 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(3, 3, 3, 3));
+        const __m128 va3c3333 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(3, 3, 3, 3));
+
+        const __m128 vb0123c3 = _mm_load_ps(w + 24);
+        const __m128 vb4567c3 = _mm_load_ps(w + 28);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0c3333, vb0123c3));
+        vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1c3333, vb0123c3));
+        vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2c3333, vb0123c3));
+        vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3c3333, vb0123c3));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0c3333, vb4567c3));
+        vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1c3333, vb4567c3));
+        vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2c3333, vb4567c3));
+        vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3c3333, vb4567c3));
+
+        w += 32;
+        k -= 4 * sizeof(float);
+      }
+      // Remainder loop for kc not divisible by 4: one k-element per step.
+      if XNN_UNLIKELY(k != 0) {
+        do {
+          const __m128 vb0123 = _mm_load_ps(w);
+          const __m128 vb4567 = _mm_load_ps(w + 4);
+          w += 8;
+
+          const __m128 va0 = _mm_load1_ps(a0);
+          a0 += 1;
+          const __m128 va1 = _mm_load1_ps(a1);
+          a1 += 1;
+          const __m128 va2 = _mm_load1_ps(a2);
+          a2 += 1;
+          const __m128 va3 = _mm_load1_ps(a3);
+          a3 += 1;
+
+          vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+          vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+          vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123));
+          vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567));
+          vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123));
+          vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567));
+          vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123));
+          vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567));
+          k -= sizeof(float);
+        } while (k != 0);
+      }
+      p -= 4 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp accumulators to [min, max] from the output parameters.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc1x0123 = _mm_min_ps(vacc1x0123, vmax);
+    vacc2x0123 = _mm_min_ps(vacc2x0123, vmax);
+    vacc3x0123 = _mm_min_ps(vacc3x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+    vacc1x4567 = _mm_min_ps(vacc1x4567, vmax);
+    vacc2x4567 = _mm_min_ps(vacc2x4567, vmax);
+    vacc3x4567 = _mm_min_ps(vacc3x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc1x0123 = _mm_max_ps(vacc1x0123, vmin);
+    vacc2x0123 = _mm_max_ps(vacc2x0123, vmin);
+    vacc3x0123 = _mm_max_ps(vacc3x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+    vacc1x4567 = _mm_max_ps(vacc1x4567, vmin);
+    vacc2x4567 = _mm_max_ps(vacc2x4567, vmin);
+    vacc3x4567 = _mm_max_ps(vacc3x4567, vmin);
+
+    // Store the tile; for partial tiles (nc < 8) store in 4/2/1-column steps.
+    if XNN_LIKELY(nc >= 8) {
+      _mm_storeu_ps(c3, vacc3x0123);
+      _mm_storeu_ps(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      _mm_storeu_ps(c2, vacc2x0123);
+      _mm_storeu_ps(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      _mm_storeu_ps(c1, vacc1x0123);
+      _mm_storeu_ps(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind the indirection buffer for the next column tile.
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        _mm_storeu_ps(c3, vacc3x0123);
+        _mm_storeu_ps(c2, vacc2x0123);
+        _mm_storeu_ps(c1, vacc1x0123);
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c3, vacc3x0123);
+        _mm_storel_pi((__m64*) c2, vacc2x0123);
+        _mm_storel_pi((__m64*) c1, vacc1x0123);
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        // Shift the upper half of each accumulator into the lower lanes.
+        vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
+        vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
+        vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c3, vacc3x0123);
+        _mm_store_ss(c2, vacc2x0123);
+        _mm_store_ss(c1, vacc1x0123);
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x8-sse-load1.c b/src/f32-igemm/4x8-sse-load1.c
new file mode 100644
index 0000000..b99e908
--- /dev/null
+++ b/src/f32-igemm/4x8-sse-load1.c
@@ -0,0 +1,192 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/sse-load1.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/igemm.h>
+
+
+// f32 indirect GEMM (IGEMM) micro-kernel: accumulates an output tile of up to
+// 4 rows x 8 columns using SSE intrinsics. A is read indirectly through the
+// pointer array `a`; entries equal to `zero` are NOT adjusted by a_offset —
+// presumably they point at a zero-filled buffer (TODO confirm with the
+// indirection-buffer setup code). "load1" variant: one scalar of A is
+// broadcast per inner-loop iteration (no 4-wide unrolling over k).
+void xnn_f32_igemm_ukernel_4x8__sse_load1(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (4 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  // Output row pointers. When mr < 4, surplus rows alias the previous row so
+  // their stores write duplicates to valid memory instead of branching.
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    c3 = c2;
+  }
+
+  // Outer loop: one iteration per tile of up to 8 output columns.
+  do {
+    // Seed all four row accumulators with the bias (first 8 floats of w).
+    __m128 vacc0x0123 = _mm_load_ps(w);
+    __m128 vacc0x4567 = _mm_load_ps(w + 4);
+    __m128 vacc1x0123 = vacc0x0123;
+    __m128 vacc1x4567 = vacc0x4567;
+    __m128 vacc2x0123 = vacc0x0123;
+    __m128 vacc2x4567 = vacc0x4567;
+    __m128 vacc3x0123 = vacc0x0123;
+    __m128 vacc3x4567 = vacc0x4567;
+    w += 8;
+
+    // Walk the indirection buffer: 4 (== MR) pointers consumed per step.
+    size_t p = ks;
+    do {
+      const float* restrict a0 = a[0];
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      const float* restrict a1 = a[1];
+      if XNN_UNPREDICTABLE(a1 != zero) {
+        a1 = (const float*) ((uintptr_t) a1 + a_offset);
+      }
+      const float* restrict a2 = a[2];
+      if XNN_UNPREDICTABLE(a2 != zero) {
+        a2 = (const float*) ((uintptr_t) a2 + a_offset);
+      }
+      const float* restrict a3 = a[3];
+      if XNN_UNPREDICTABLE(a3 != zero) {
+        a3 = (const float*) ((uintptr_t) a3 + a_offset);
+      }
+      a += 4;
+
+      // Inner loop: one k-element per iteration; broadcast one scalar per row
+      // and multiply-accumulate it against an 8-wide B row.
+      size_t k = kc;
+      do {
+        const __m128 vb0123 = _mm_load_ps(w);
+        const __m128 vb4567 = _mm_load_ps(w + 4);
+        w += 8;
+
+        const __m128 va0 = _mm_load1_ps(a0);
+        a0 += 1;
+        const __m128 va1 = _mm_load1_ps(a1);
+        a1 += 1;
+        const __m128 va2 = _mm_load1_ps(a2);
+        a2 += 1;
+        const __m128 va3 = _mm_load1_ps(a3);
+        a3 += 1;
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+        vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123));
+        vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567));
+        vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123));
+        vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567));
+        vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123));
+        vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567));
+        k -= sizeof(float);
+      } while (k != 0);
+      p -= 4 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp accumulators to [min, max] from the output parameters.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc1x0123 = _mm_min_ps(vacc1x0123, vmax);
+    vacc2x0123 = _mm_min_ps(vacc2x0123, vmax);
+    vacc3x0123 = _mm_min_ps(vacc3x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+    vacc1x4567 = _mm_min_ps(vacc1x4567, vmax);
+    vacc2x4567 = _mm_min_ps(vacc2x4567, vmax);
+    vacc3x4567 = _mm_min_ps(vacc3x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc1x0123 = _mm_max_ps(vacc1x0123, vmin);
+    vacc2x0123 = _mm_max_ps(vacc2x0123, vmin);
+    vacc3x0123 = _mm_max_ps(vacc3x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+    vacc1x4567 = _mm_max_ps(vacc1x4567, vmin);
+    vacc2x4567 = _mm_max_ps(vacc2x4567, vmin);
+    vacc3x4567 = _mm_max_ps(vacc3x4567, vmin);
+
+    // Store the tile; for partial tiles (nc < 8) store in 4/2/1-column steps.
+    if XNN_LIKELY(nc >= 8) {
+      _mm_storeu_ps(c3, vacc3x0123);
+      _mm_storeu_ps(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      _mm_storeu_ps(c2, vacc2x0123);
+      _mm_storeu_ps(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      _mm_storeu_ps(c1, vacc1x0123);
+      _mm_storeu_ps(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind the indirection buffer for the next column tile.
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        _mm_storeu_ps(c3, vacc3x0123);
+        _mm_storeu_ps(c2, vacc2x0123);
+        _mm_storeu_ps(c1, vacc1x0123);
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c3, vacc3x0123);
+        _mm_storel_pi((__m64*) c2, vacc2x0123);
+        _mm_storel_pi((__m64*) c1, vacc1x0123);
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        // Shift the upper half of each accumulator into the lower lanes.
+        vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
+        vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
+        vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c3, vacc3x0123);
+        _mm_store_ss(c2, vacc2x0123);
+        _mm_store_ss(c1, vacc1x0123);
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x8s4-psimd.c b/src/f32-igemm/4x8s4-psimd.c
new file mode 100644
index 0000000..6f2a4b6
--- /dev/null
+++ b/src/f32-igemm/4x8s4-psimd.c
@@ -0,0 +1,272 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/psimd-s4.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+// f32 indirect GEMM (IGEMM) micro-kernel: accumulates an output tile of up to
+// 4 rows x 8 columns. A is read indirectly through the pointer array `a`;
+// entries equal to `zero` are NOT adjusted by a_offset — presumably they point
+// at a zero-filled buffer (TODO confirm with the indirection-buffer setup
+// code). "s4" variant: instead of splatting, the whole A vector is rotated by
+// one lane between the four B rows (B is packed to match that rotation).
+void xnn_f32_igemm_ukernel_4x8s4__psimd(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (4 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  // Output row pointers. When mr < 4, surplus rows alias the previous row so
+  // their stores write duplicates to valid memory instead of branching.
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    c3 = c2;
+  }
+
+  // Outer loop: one iteration per tile of up to 8 output columns.
+  do {
+    // Seed all four row accumulators with the bias (first 8 floats of w).
+    psimd_f32 vacc0x0123 = psimd_load_f32(w);
+    psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+    psimd_f32 vacc1x0123 = vacc0x0123;
+    psimd_f32 vacc1x4567 = vacc0x4567;
+    psimd_f32 vacc2x0123 = vacc0x0123;
+    psimd_f32 vacc2x4567 = vacc0x4567;
+    psimd_f32 vacc3x0123 = vacc0x0123;
+    psimd_f32 vacc3x4567 = vacc0x4567;
+    w += 8;
+
+    // Walk the indirection buffer: 4 (== MR) pointers consumed per step.
+    size_t p = ks;
+    do {
+      const float* restrict a0 = a[0];
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      const float* restrict a1 = a[1];
+      if XNN_UNPREDICTABLE(a1 != zero) {
+        a1 = (const float*) ((uintptr_t) a1 + a_offset);
+      }
+      const float* restrict a2 = a[2];
+      if XNN_UNPREDICTABLE(a2 != zero) {
+        a2 = (const float*) ((uintptr_t) a2 + a_offset);
+      }
+      const float* restrict a3 = a[3];
+      if XNN_UNPREDICTABLE(a3 != zero) {
+        a3 = (const float*) ((uintptr_t) a3 + a_offset);
+      }
+      a += 4;
+
+      // Main loop: consume 4 k-elements per iteration. After each B row, the
+      // A vectors are rotated left by one lane (shuffle 1,2,3,0) so each lane
+      // meets its matching packed-B coefficients.
+      size_t k = kc;
+      while (k >= 4 * sizeof(float)) {
+        psimd_f32 va0 = psimd_load_f32(a0);
+        a0 += 4;
+        psimd_f32 va1 = psimd_load_f32(a1);
+        a1 += 4;
+        psimd_f32 va2 = psimd_load_f32(a2);
+        a2 += 4;
+        psimd_f32 va3 = psimd_load_f32(a3);
+        a3 += 4;
+
+
+        const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+        const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c0);
+        vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c0);
+        vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c0);
+        vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c0);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c0);
+        vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c0);
+        vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c0);
+        vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c0);
+
+        va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+        va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+        va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+        va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+
+        const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+        const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c1);
+        vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c1);
+        vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c1);
+        vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c1);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c1);
+        vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c1);
+        vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c1);
+        vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c1);
+
+        va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+        va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+        va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+        va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+
+        const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+        const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c2);
+        vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c2);
+        vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c2);
+        vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c2);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c2);
+        vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c2);
+        vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c2);
+        vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c2);
+
+        va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+        va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+        va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+        va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+
+        const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+        const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+        vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c3);
+        vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c3);
+        vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c3);
+        vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c3);
+        vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c3);
+        vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c3);
+        vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c3);
+        vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c3);
+
+
+        w += 32;
+        k -= 4 * sizeof(float);
+      }
+      // Remainder loop for kc not divisible by 4: one k-element per step.
+      if XNN_UNLIKELY(k != 0) {
+        do {
+          const psimd_f32 vb0123 = psimd_load_f32(w);
+          const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+          w += 8;
+
+          const psimd_f32 va0 = psimd_load_splat_f32(a0);
+          a0 += 1;
+          const psimd_f32 va1 = psimd_load_splat_f32(a1);
+          a1 += 1;
+          const psimd_f32 va2 = psimd_load_splat_f32(a2);
+          a2 += 1;
+          const psimd_f32 va3 = psimd_load_splat_f32(a3);
+          a3 += 1;
+
+          vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+          vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+          vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+          vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+          vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+          vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+          vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+          vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+          k -= sizeof(float);
+        } while (k != 0);
+      }
+      p -= 4 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp accumulators to [min, max] from the output parameters.
+    // Fixed: '&params' had been corrupted to a pilcrow character ('¶ms').
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+    vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+    vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+    vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+    vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+    vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+    vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+    vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+    vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+    vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+    vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+
+    // Store the tile; for partial tiles (nc < 8) store in 4/2/1-column steps.
+    if XNN_LIKELY(nc >= 8) {
+      psimd_store_f32(c3, vacc3x0123);
+      psimd_store_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      psimd_store_f32(c2, vacc2x0123);
+      psimd_store_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      psimd_store_f32(c1, vacc1x0123);
+      psimd_store_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      // Rewind the indirection buffer for the next column tile.
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        psimd_store_f32(c3, vacc3x0123);
+        psimd_store_f32(c2, vacc2x0123);
+        psimd_store_f32(c1, vacc1x0123);
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c3, vacc3x0123);
+        psimd_store2_f32(c2, vacc2x0123);
+        psimd_store2_f32(c1, vacc1x0123);
+        psimd_store2_f32(c0, vacc0x0123);
+
+        // Shift the upper half of each accumulator into the lower lanes.
+        vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+        vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+        vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c3, vacc3x0123);
+        psimd_store1_f32(c2, vacc2x0123);
+        psimd_store1_f32(c1, vacc1x0123);
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/4x8s4-sse.c b/src/f32-igemm/4x8s4-sse.c
new file mode 100644
index 0000000..643bec0
--- /dev/null
+++ b/src/f32-igemm/4x8s4-sse.c
@@ -0,0 +1,272 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/sse-shuffle.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_4x8s4__sse(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  // Indirect GEMM (IGEMM) micro-kernel: accumulates an up-to-4x8 tile of the
+  // output C. A rows are reached through the indirection buffer `a`, which is
+  // consumed in groups of 4 pointers (ks bytes in total per tile). Packed
+  // weights `w` start with 8 accumulator-init values, then for every 4 K
+  // elements contain 4 "shuffled" groups of 8 coefficients (the s4 layout):
+  // each group is multiplied with the A vectors rotated one lane further.
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (4 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  // Output row pointers. When mr < 4 the unused rows alias the previous row;
+  // stores are issued from c3 down to c0, so the valid row is written last.
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    c3 = c2;
+  }
+
+  do {  // Loop over output columns, 8 at a time (nc).
+    // Initialize all 4 rows of accumulators from the packed weights.
+    __m128 vacc0x0123 = _mm_load_ps(w);
+    __m128 vacc0x4567 = _mm_load_ps(w + 4);
+    __m128 vacc1x0123 = vacc0x0123;
+    __m128 vacc1x4567 = vacc0x4567;
+    __m128 vacc2x0123 = vacc0x0123;
+    __m128 vacc2x4567 = vacc0x4567;
+    __m128 vacc3x0123 = vacc0x0123;
+    __m128 vacc3x4567 = vacc0x4567;
+    w += 8;
+
+    size_t p = ks;
+    do {  // Loop over the indirection buffer, 4 A row pointers per step.
+      // Rows that point at the `zero` buffer (padding) are NOT adjusted by
+      // a_offset; real rows are rebased by a_offset bytes.
+      const float* restrict a0 = a[0];
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      const float* restrict a1 = a[1];
+      if XNN_UNPREDICTABLE(a1 != zero) {
+        a1 = (const float*) ((uintptr_t) a1 + a_offset);
+      }
+      const float* restrict a2 = a[2];
+      if XNN_UNPREDICTABLE(a2 != zero) {
+        a2 = (const float*) ((uintptr_t) a2 + a_offset);
+      }
+      const float* restrict a3 = a[3];
+      if XNN_UNPREDICTABLE(a3 != zero) {
+        a3 = (const float*) ((uintptr_t) a3 + a_offset);
+      }
+      a += 4;
+
+      size_t k = kc;
+      // Main K loop: consume 4 floats of each A row per iteration against
+      // 4 shuffled weight groups (32 packed floats).
+      while (k >= 4 * sizeof(float)) {
+        __m128 va0 = _mm_loadu_ps(a0);
+        a0 += 4;
+        __m128 va1 = _mm_loadu_ps(a1);
+        a1 += 4;
+        __m128 va2 = _mm_loadu_ps(a2);
+        a2 += 4;
+        __m128 va3 = _mm_loadu_ps(a3);
+        a3 += 4;
+
+
+        const __m128 vb0123c0 = _mm_load_ps(w + 0);
+        const __m128 vb4567c0 = _mm_load_ps(w + 4);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c0));
+        vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123c0));
+        vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123c0));
+        vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123c0));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c0));
+        vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567c0));
+        vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567c0));
+        vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567c0));
+
+        // Rotate each A vector one lane left so it lines up with the next
+        // shuffled weight group.
+        va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+        va1 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(0, 3, 2, 1));
+        va2 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(0, 3, 2, 1));
+        va3 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(0, 3, 2, 1));
+
+        const __m128 vb0123c1 = _mm_load_ps(w + 8);
+        const __m128 vb4567c1 = _mm_load_ps(w + 12);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c1));
+        vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123c1));
+        vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123c1));
+        vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123c1));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c1));
+        vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567c1));
+        vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567c1));
+        vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567c1));
+
+        va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+        va1 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(0, 3, 2, 1));
+        va2 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(0, 3, 2, 1));
+        va3 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(0, 3, 2, 1));
+
+        const __m128 vb0123c2 = _mm_load_ps(w + 16);
+        const __m128 vb4567c2 = _mm_load_ps(w + 20);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c2));
+        vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123c2));
+        vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123c2));
+        vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123c2));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c2));
+        vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567c2));
+        vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567c2));
+        vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567c2));
+
+        va0 = _mm_shuffle_ps(va0, va0, _MM_SHUFFLE(0, 3, 2, 1));
+        va1 = _mm_shuffle_ps(va1, va1, _MM_SHUFFLE(0, 3, 2, 1));
+        va2 = _mm_shuffle_ps(va2, va2, _MM_SHUFFLE(0, 3, 2, 1));
+        va3 = _mm_shuffle_ps(va3, va3, _MM_SHUFFLE(0, 3, 2, 1));
+
+        const __m128 vb0123c3 = _mm_load_ps(w + 24);
+        const __m128 vb4567c3 = _mm_load_ps(w + 28);
+
+        vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123c3));
+        vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123c3));
+        vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123c3));
+        vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123c3));
+        vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567c3));
+        vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567c3));
+        vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567c3));
+        vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567c3));
+
+
+        w += 32;
+        k -= 4 * sizeof(float);
+      }
+      // K remainder (1-3 floats): broadcast one A element per row against a
+      // plain (non-shuffled) group of 8 weights.
+      if XNN_UNLIKELY(k != 0) {
+        do {
+          const __m128 vb0123 = _mm_load_ps(w);
+          const __m128 vb4567 = _mm_load_ps(w + 4);
+          w += 8;
+
+          const __m128 va0 = _mm_load1_ps(a0);
+          a0 += 1;
+          const __m128 va1 = _mm_load1_ps(a1);
+          a1 += 1;
+          const __m128 va2 = _mm_load1_ps(a2);
+          a2 += 1;
+          const __m128 va3 = _mm_load1_ps(a3);
+          a3 += 1;
+
+          vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0, vb0123));
+          vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0, vb4567));
+          vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1, vb0123));
+          vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1, vb4567));
+          vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2, vb0123));
+          vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2, vb4567));
+          vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3, vb0123));
+          vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3, vb4567));
+          k -= sizeof(float);
+        } while (k != 0);
+      }
+      p -= 4 * sizeof(void*);
+    } while (p != 0);
+
+    // Clamp accumulators to the [min, max] output range from params.
+    const __m128 vmax = _mm_load_ps(params->sse.max);
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc1x0123 = _mm_min_ps(vacc1x0123, vmax);
+    vacc2x0123 = _mm_min_ps(vacc2x0123, vmax);
+    vacc3x0123 = _mm_min_ps(vacc3x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+    vacc1x4567 = _mm_min_ps(vacc1x4567, vmax);
+    vacc2x4567 = _mm_min_ps(vacc2x4567, vmax);
+    vacc3x4567 = _mm_min_ps(vacc3x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc1x0123 = _mm_max_ps(vacc1x0123, vmin);
+    vacc2x0123 = _mm_max_ps(vacc2x0123, vmin);
+    vacc3x0123 = _mm_max_ps(vacc3x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+    vacc1x4567 = _mm_max_ps(vacc1x4567, vmin);
+    vacc2x4567 = _mm_max_ps(vacc2x4567, vmin);
+    vacc3x4567 = _mm_max_ps(vacc3x4567, vmin);
+
+    // Full 8-wide tile: store rows (c3 first, c0 last so aliased rows end up
+    // holding the valid row), rewind the indirection buffer, and continue.
+    if XNN_LIKELY(nc >= 8) {
+      _mm_storeu_ps(c3, vacc3x0123);
+      _mm_storeu_ps(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      _mm_storeu_ps(c2, vacc2x0123);
+      _mm_storeu_ps(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      _mm_storeu_ps(c1, vacc1x0123);
+      _mm_storeu_ps(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);
+      nc -= 8;
+    } else {
+      // Partial tile: emit 4, then 2, then 1 columns, shifting the surviving
+      // lanes down after each step.
+      if (nc & 4) {
+        _mm_storeu_ps(c3, vacc3x0123);
+        _mm_storeu_ps(c2, vacc2x0123);
+        _mm_storeu_ps(c1, vacc1x0123);
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c3, vacc3x0123);
+        _mm_storel_pi((__m64*) c2, vacc2x0123);
+        _mm_storel_pi((__m64*) c1, vacc1x0123);
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        // Move the high pair into the low lanes for the final 1-column store.
+        vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
+        vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
+        vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c3, vacc3x0123);
+        _mm_store_ss(c2, vacc2x0123);
+        _mm_store_ss(c1, vacc1x0123);
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/5x8-aarch64-neonfma-cortex-a75.S b/src/f32-igemm/5x8-aarch64-neonfma-cortex-a75.S
new file mode 100644
index 0000000..6835c5e
--- /dev/null
+++ b/src/f32-igemm/5x8-aarch64-neonfma-cortex-a75.S
@@ -0,0 +1,603 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_igemm_ukernel_5x8__aarch64_neonfma_cortex_a75(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# size_t ks, x3 / x9
+# const float**restrict a, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x10
+# size_t a_offset, [sp + 8] -> x11
+# const float* zero, [sp + 16] -> x12
+# const xnn_f32_output_params params [sp + 24] -> x8
+
+# 5x8 strips the following out of 6x8
+# x23 a5
+# x7 c5 x13 unused
+# A5 v10 v11
+# C v30 v31
+
+# d8-d15 need to be preserved if used.
+# x19-x30 need to be preserved if used. x18 is reserved for OS.
+
+# A pointers
+# x14 a0
+# x15 a1
+# x20 a2
+# x21 a3
+# x8 a4
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x13 c3
+# x7 c4
+
+# Vector register usage
+# A0 v0 v1
+# A1 v2 v3
+# A2 v4 v5
+# A3 v6 v7
+# A4 v8 v9
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# Clamp v30 v31
+
+BEGIN_FUNCTION xnn_f32_igemm_ukernel_5x8__aarch64_neonfma_cortex_a75
+
+        # Computes an up-to-5-row x 8-column tile of indirect GEMM output:
+        # accumulators v20-v29 hold 5 rows x 8 floats of C.
+
+        # Clamp C pointers / Save d8-d15 on stack
+        STP d8, d9, [sp, -64]!
+        ADD x16, x6, x7          // c1 = c0 + cm_stride
+        CMP x0, 2                // if mr < 2
+        CSEL x16, x6, x16, LO    //   c1 = c0
+
+        STP d12, d13, [sp, 16]
+        ADD x17, x16, x7         // c2 = c1 + cm_stride
+                                 // if mr <= 2
+        CSEL x17, x16, x17, LS   //   c2 = c1
+
+        STP d14, d15, [sp, 32]
+        ADD x13, x17, x7         // c3 = c2 + cm_stride
+        CMP x0, 4                // if mr < 4
+        CSEL x13, x17, x13, LO   //   c3 = c2
+
+        # Load zero, clamping params pointer
+        LDP x12, x8, [sp, 80]
+        ADD x7, x13, x7          // c4 = c3 + cm_stride
+                                 // if mr <= 4 (flags still from CMP x0, 4;
+                                 // c4 is only distinct when mr == 5)
+        CSEL x7, x13, x7, LS     //   c4 = c3
+
+        # Save x20,x21 on stack
+        STP x20, x21, [sp, 48]
+
+        # Load clamp values (v30 = max, used by FMIN; v31 = min, used by FMAX)
+        LD2R {v30.4s, v31.4s}, [x8]
+
+        # Load cn_stride, a_offset
+        LDP x10, x11, [sp, 64]
+
+0:
+        # Load initial bias from w into accumulators
+        LDP q20, q21, [x5], 32
+        MOV v22.16b, v20.16b
+        MOV v23.16b, v21.16b
+        PRFM PLDL1KEEP, [x5, 0]  // Prefetch B
+        MOV v24.16b, v20.16b
+        MOV v25.16b, v21.16b
+        PRFM PLDL1KEEP, [x5, 64]
+        MOV v26.16b, v20.16b
+        MOV v27.16b, v21.16b
+        PRFM PLDL1KEEP, [x5, 128]
+        MOV v28.16b, v20.16b
+        MOV v29.16b, v21.16b
+        PRFM PLDL1KEEP, [x5, 192]
+
+        MOV x9, x3               // p = ks
+
+1:
+        # Load next 5 A pointers
+        LDP x14, x15, [x4], 16
+        LDP x20, x21, [x4], 16
+        LDR x8, [x4], 8
+
+        # Rows pointing at the zero buffer are not rebased by a_offset.
+        CMP x14, x12             // if a0 == zero
+        ADD x14, x14, x11        // a0 += a_offset
+        CSEL x14, x12, x14, EQ   //   a0 = zero, else += a0 + a_offset
+        CMP x15, x12             // if a1 == zero
+        ADD x15, x15, x11        // a1 += a_offset
+        CSEL x15, x12, x15, EQ   //   a1 = zero, else += a1 + a_offset
+        CMP x20, x12             // if a2 == zero
+        ADD x20, x20, x11        // a2 += a_offset
+        CSEL x20, x12, x20, EQ   //   a2 = zero, else += a2 + a_offset
+        CMP x21, x12             // if a3 == zero
+        ADD x21, x21, x11        // a3 += a_offset
+        CSEL x21, x12, x21, EQ   //   a3 = zero, else += a3 + a_offset
+        CMP x8, x12              // if a4 == zero
+        ADD x8, x8, x11          // a4 += a_offset
+        CSEL x8, x12, x8, EQ     //   a4 = zero, else += a4 + a_offset
+
+        # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+        SUBS x0, x2, 32          // k = kc - 32
+        B.LO 5f
+
+        # Prologue - loads for main loop of 96 FMA
+        LDR q0, [x14], 16
+        LDR q2, [x15], 16
+        LDR q4, [x20], 16
+        LDR q6, [x21], 16
+        LDR q8, [x8], 16
+        LDP q12, q13, [x5], 32   // Fetch 3 B (4th deferred)
+        LDP q14, q15, [x5], 32
+        LDP q16, q17, [x5], 32
+
+        # Is there at least 8 floats (32 bytes) for main loop?
+        SUBS x0, x0, 32
+        B.LO 3f
+
+        # Main loop - 8 floats of A (32 bytes)
+        # 80 FMA + 5 LDP A + 8 LDP B
+2:
+        # First group of 4 A. 40 FMA.
+        FMLA v20.4s, v12.4s, v0.s[0]
+        LDP q18, q19, [x5], 32   // Load last B
+        FMLA v22.4s, v12.4s, v2.s[0]
+        FMLA v24.4s, v12.4s, v4.s[0]
+        FMLA v26.4s, v12.4s, v6.s[0]
+        PRFM PLDL1KEEP, [x5, 128]  // Prefetch B
+        FMLA v28.4s, v12.4s, v8.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v2.s[0]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v25.4s, v13.4s, v4.s[0]
+        FMLA v27.4s, v13.4s, v6.s[0]
+        FMLA v29.4s, v13.4s, v8.s[0]
+        LDR q1, [x14], 16        // Load next 5 A
+
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v2.s[1]
+        FMLA v24.4s, v14.4s, v4.s[1]
+        LDR q3, [x15], 16
+        FMLA v26.4s, v14.4s, v6.s[1]
+        FMLA v28.4s, v14.4s, v8.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        LDR q5, [x20], 16
+        FMLA v23.4s, v15.4s, v2.s[1]
+        FMLA v25.4s, v15.4s, v4.s[1]
+        FMLA v27.4s, v15.4s, v6.s[1]
+        LDR q7, [x21], 16
+        FMLA v29.4s, v15.4s, v8.s[1]
+
+        FMLA v20.4s, v16.4s, v0.s[2]
+        FMLA v22.4s, v16.4s, v2.s[2]
+        LDR q9, [x8], 16
+        FMLA v24.4s, v16.4s, v4.s[2]
+        FMLA v26.4s, v16.4s, v6.s[2]
+        FMLA v28.4s, v16.4s, v8.s[2]
+        LDP q12, q13, [x5], 32   // Load 4 B
+        FMLA v21.4s, v17.4s, v0.s[2]
+        FMLA v23.4s, v17.4s, v2.s[2]
+        FMLA v25.4s, v17.4s, v4.s[2]
+        FMLA v27.4s, v17.4s, v6.s[2]
+        FMLA v29.4s, v17.4s, v8.s[2]
+
+        FMLA v20.4s, v18.4s, v0.s[3]
+        FMLA v22.4s, v18.4s, v2.s[3]
+        FMLA v24.4s, v18.4s, v4.s[3]
+        FMLA v26.4s, v18.4s, v6.s[3]
+        LDP q14, q15, [x5], 32
+        FMLA v28.4s, v18.4s, v8.s[3]
+        FMLA v21.4s, v19.4s, v0.s[3]
+        FMLA v23.4s, v19.4s, v2.s[3]
+        LDP q16, q17, [x5], 32
+        FMLA v25.4s, v19.4s, v4.s[3]
+        FMLA v27.4s, v19.4s, v6.s[3]
+        FMLA v29.4s, v19.4s, v8.s[3]
+        LDP q18, q19, [x5], 32
+
+        # Second group of 4 A. 40 FMA.
+        FMLA v20.4s, v12.4s, v1.s[0]
+        FMLA v22.4s, v12.4s, v3.s[0]
+        FMLA v24.4s, v12.4s, v5.s[0]
+        LDR q0, [x14], 16        // Load next 5 A
+        FMLA v26.4s, v12.4s, v7.s[0]
+        FMLA v28.4s, v12.4s, v9.s[0]
+        FMLA v21.4s, v13.4s, v1.s[0]
+        LDR q2, [x15], 16
+        FMLA v23.4s, v13.4s, v3.s[0]
+        FMLA v25.4s, v13.4s, v5.s[0]
+        FMLA v27.4s, v13.4s, v7.s[0]
+        LDR q4, [x20], 16
+        FMLA v29.4s, v13.4s, v9.s[0]
+
+        FMLA v20.4s, v14.4s, v1.s[1]
+        FMLA v22.4s, v14.4s, v3.s[1]
+        LDR q6, [x21], 16
+        FMLA v24.4s, v14.4s, v5.s[1]
+        FMLA v26.4s, v14.4s, v7.s[1]
+        FMLA v28.4s, v14.4s, v9.s[1]
+        LDR q8, [x8], 16
+        FMLA v21.4s, v15.4s, v1.s[1]
+        FMLA v23.4s, v15.4s, v3.s[1]
+        FMLA v25.4s, v15.4s, v5.s[1]
+        LDP q12, q13, [x5], 32   // Load next 3 B (not last)
+        FMLA v27.4s, v15.4s, v7.s[1]
+        FMLA v29.4s, v15.4s, v9.s[1]
+
+        FMLA v20.4s, v16.4s, v1.s[2]
+        FMLA v22.4s, v16.4s, v3.s[2]
+        FMLA v24.4s, v16.4s, v5.s[2]
+        FMLA v26.4s, v16.4s, v7.s[2]
+        FMLA v28.4s, v16.4s, v9.s[2]
+        FMLA v21.4s, v17.4s, v1.s[2]
+        FMLA v23.4s, v17.4s, v3.s[2]
+        LDP q14, q15, [x5], 32
+        FMLA v25.4s, v17.4s, v5.s[2]
+        FMLA v27.4s, v17.4s, v7.s[2]
+        FMLA v29.4s, v17.4s, v9.s[2]
+        LDP q16, q17, [x5], 32
+
+        FMLA v20.4s, v18.4s, v1.s[3]
+        FMLA v22.4s, v18.4s, v3.s[3]
+        SUBS x0, x0, 32
+        FMLA v24.4s, v18.4s, v5.s[3]
+        FMLA v26.4s, v18.4s, v7.s[3]
+        FMLA v28.4s, v18.4s, v9.s[3]
+        FMLA v21.4s, v19.4s, v1.s[3]
+        FMLA v23.4s, v19.4s, v3.s[3]
+        FMLA v25.4s, v19.4s, v5.s[3]
+        FMLA v27.4s, v19.4s, v7.s[3]
+        FMLA v29.4s, v19.4s, v9.s[3]
+        B.HS 2b
+
+        # Epilogue - 8 floats of A (32 bytes)
+        # 80 FMA + 5 LDP A + 8 LDP B
+        # First block same as main loop. Second block has no preloads.
+3:
+        # First group of 4 A. 40 FMA.
+        FMLA v20.4s, v12.4s, v0.s[0]
+        LDP q18, q19, [x5], 32   // Load last B
+        FMLA v22.4s, v12.4s, v2.s[0]
+        FMLA v24.4s, v12.4s, v4.s[0]
+        FMLA v26.4s, v12.4s, v6.s[0]
+        PRFM PLDL1KEEP, [x5, 128]  // Prefetch B
+        FMLA v28.4s, v12.4s, v8.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v2.s[0]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v25.4s, v13.4s, v4.s[0]
+        FMLA v27.4s, v13.4s, v6.s[0]
+        FMLA v29.4s, v13.4s, v8.s[0]
+        LDR q1, [x14], 16        // Load next 5 A
+
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v2.s[1]
+        FMLA v24.4s, v14.4s, v4.s[1]
+        LDR q3, [x15], 16
+        FMLA v26.4s, v14.4s, v6.s[1]
+        FMLA v28.4s, v14.4s, v8.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        LDR q5, [x20], 16
+        FMLA v23.4s, v15.4s, v2.s[1]
+        FMLA v25.4s, v15.4s, v4.s[1]
+        FMLA v27.4s, v15.4s, v6.s[1]
+        LDR q7, [x21], 16
+        FMLA v29.4s, v15.4s, v8.s[1]
+
+        FMLA v20.4s, v16.4s, v0.s[2]
+        FMLA v22.4s, v16.4s, v2.s[2]
+        LDR q9, [x8], 16
+        FMLA v24.4s, v16.4s, v4.s[2]
+        FMLA v26.4s, v16.4s, v6.s[2]
+        FMLA v28.4s, v16.4s, v8.s[2]
+        LDP q12, q13, [x5], 32   // Load 4 B
+        FMLA v21.4s, v17.4s, v0.s[2]
+        FMLA v23.4s, v17.4s, v2.s[2]
+        FMLA v25.4s, v17.4s, v4.s[2]
+        FMLA v27.4s, v17.4s, v6.s[2]
+        FMLA v29.4s, v17.4s, v8.s[2]
+
+        FMLA v20.4s, v18.4s, v0.s[3]
+        FMLA v22.4s, v18.4s, v2.s[3]
+        FMLA v24.4s, v18.4s, v4.s[3]
+        FMLA v26.4s, v18.4s, v6.s[3]
+        LDP q14, q15, [x5], 32
+        FMLA v28.4s, v18.4s, v8.s[3]
+        FMLA v21.4s, v19.4s, v0.s[3]
+        FMLA v23.4s, v19.4s, v2.s[3]
+        LDP q16, q17, [x5], 32
+        FMLA v25.4s, v19.4s, v4.s[3]
+        FMLA v27.4s, v19.4s, v6.s[3]
+        FMLA v29.4s, v19.4s, v8.s[3]
+        LDP q18, q19, [x5], 32
+
+        # Second group of 4 A. 40 FMA.
+        FMLA v20.4s, v12.4s, v1.s[0]
+        FMLA v22.4s, v12.4s, v3.s[0]
+        FMLA v24.4s, v12.4s, v5.s[0]
+        FMLA v26.4s, v12.4s, v7.s[0]
+        FMLA v28.4s, v12.4s, v9.s[0]
+        FMLA v21.4s, v13.4s, v1.s[0]
+        FMLA v23.4s, v13.4s, v3.s[0]
+        FMLA v25.4s, v13.4s, v5.s[0]
+        FMLA v27.4s, v13.4s, v7.s[0]
+        FMLA v29.4s, v13.4s, v9.s[0]
+
+        FMLA v20.4s, v14.4s, v1.s[1]
+        FMLA v22.4s, v14.4s, v3.s[1]
+        FMLA v24.4s, v14.4s, v5.s[1]
+        FMLA v26.4s, v14.4s, v7.s[1]
+        FMLA v28.4s, v14.4s, v9.s[1]
+        FMLA v21.4s, v15.4s, v1.s[1]
+        FMLA v23.4s, v15.4s, v3.s[1]
+        FMLA v25.4s, v15.4s, v5.s[1]
+        FMLA v27.4s, v15.4s, v7.s[1]
+        FMLA v29.4s, v15.4s, v9.s[1]
+
+        FMLA v20.4s, v16.4s, v1.s[2]
+        FMLA v22.4s, v16.4s, v3.s[2]
+        FMLA v24.4s, v16.4s, v5.s[2]
+        FMLA v26.4s, v16.4s, v7.s[2]
+        FMLA v28.4s, v16.4s, v9.s[2]
+        FMLA v21.4s, v17.4s, v1.s[2]
+        FMLA v23.4s, v17.4s, v3.s[2]
+        FMLA v25.4s, v17.4s, v5.s[2]
+        FMLA v27.4s, v17.4s, v7.s[2]
+        FMLA v29.4s, v17.4s, v9.s[2]
+
+        FMLA v20.4s, v18.4s, v1.s[3]
+        FMLA v22.4s, v18.4s, v3.s[3]
+        FMLA v24.4s, v18.4s, v5.s[3]
+        FMLA v26.4s, v18.4s, v7.s[3]
+        FMLA v28.4s, v18.4s, v9.s[3]
+        FMLA v21.4s, v19.4s, v1.s[3]
+        FMLA v23.4s, v19.4s, v3.s[3]
+        FMLA v25.4s, v19.4s, v5.s[3]
+        FMLA v27.4s, v19.4s, v7.s[3]
+        FMLA v29.4s, v19.4s, v9.s[3]
+        # Is there a remainder? - 4 floats of A (16 bytes) or less
+        TST x0, 31
+        B.NE 5f
+
+4:
+        # ks loop
+        SUBS x9, x9, 40          // ks -= MR * sizeof(void*), MR == 5
+        B.NE 1b
+
+        # Clamp
+        FMIN v20.4s, v20.4s, v30.4s
+        FMIN v21.4s, v21.4s, v30.4s
+        FMIN v22.4s, v22.4s, v30.4s
+        FMIN v23.4s, v23.4s, v30.4s
+        FMIN v24.4s, v24.4s, v30.4s
+        FMIN v25.4s, v25.4s, v30.4s
+        FMIN v26.4s, v26.4s, v30.4s
+        FMIN v27.4s, v27.4s, v30.4s
+        FMIN v28.4s, v28.4s, v30.4s
+        FMIN v29.4s, v29.4s, v30.4s
+        FMAX v20.4s, v20.4s, v31.4s
+        FMAX v21.4s, v21.4s, v31.4s
+        FMAX v22.4s, v22.4s, v31.4s
+        FMAX v23.4s, v23.4s, v31.4s
+        FMAX v24.4s, v24.4s, v31.4s
+        FMAX v25.4s, v25.4s, v31.4s
+        FMAX v26.4s, v26.4s, v31.4s
+        FMAX v27.4s, v27.4s, v31.4s
+        FMAX v28.4s, v28.4s, v31.4s
+        FMAX v29.4s, v29.4s, v31.4s
+
+        # Store full 5 x 8
+        CMP x1, 8
+        B.LO 8f
+
+        # Rows stored c4 down to c0 so aliased rows end up with the valid row.
+        STP q28, q29, [x7]
+        ADD x7, x7, x10
+        STP q26, q27, [x13]
+        ADD x13, x13, x10
+        STP q24, q25, [x17]
+        ADD x17, x17, x10
+        STP q22, q23, [x16]
+        ADD x16, x16, x10
+        STP q20, q21, [x6]
+        ADD x6, x6, x10
+
+        SUB x4, x4, x3           // a -= ks
+
+        # nc loop
+        SUBS x1, x1, 8
+        B.HI 0b
+
+        # Restore x20,x21 from stack
+        LDP x20, x21, [sp, 48]
+
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 32]
+        LDP d12, d13, [sp, 16]
+        LDP d8, d9, [sp], 64
+        RET
+
+5:
+        # Is there a remainder? - 4 floats of A (16 bytes)
+        TBZ x0, 4, 6f
+
+        # Remainder - 4 floats of A (16 bytes)
+        # Load A
+        LDR q0, [x14], 16
+        LDR q2, [x15], 16
+        LDR q4, [x20], 16
+        LDR q6, [x21], 16
+        LDR q8, [x8], 16
+        # Load B
+        LDP q12, q13, [x5], 32
+        LDP q14, q15, [x5], 32
+        LDP q16, q17, [x5], 32
+        LDP q18, q19, [x5], 32
+
+        FMLA v20.4s, v12.4s, v0.s[0]
+        FMLA v22.4s, v12.4s, v2.s[0]
+        FMLA v24.4s, v12.4s, v4.s[0]
+        FMLA v26.4s, v12.4s, v6.s[0]
+        FMLA v28.4s, v12.4s, v8.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v2.s[0]
+        FMLA v25.4s, v13.4s, v4.s[0]
+        FMLA v27.4s, v13.4s, v6.s[0]
+        FMLA v29.4s, v13.4s, v8.s[0]
+
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v2.s[1]
+        FMLA v24.4s, v14.4s, v4.s[1]
+        FMLA v26.4s, v14.4s, v6.s[1]
+        FMLA v28.4s, v14.4s, v8.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        FMLA v23.4s, v15.4s, v2.s[1]
+        FMLA v25.4s, v15.4s, v4.s[1]
+        FMLA v27.4s, v15.4s, v6.s[1]
+        FMLA v29.4s, v15.4s, v8.s[1]
+
+        FMLA v20.4s, v16.4s, v0.s[2]
+        FMLA v22.4s, v16.4s, v2.s[2]
+        FMLA v24.4s, v16.4s, v4.s[2]
+        FMLA v26.4s, v16.4s, v6.s[2]
+        FMLA v28.4s, v16.4s, v8.s[2]
+        FMLA v21.4s, v17.4s, v0.s[2]
+        FMLA v23.4s, v17.4s, v2.s[2]
+        FMLA v25.4s, v17.4s, v4.s[2]
+        FMLA v27.4s, v17.4s, v6.s[2]
+        FMLA v29.4s, v17.4s, v8.s[2]
+
+        FMLA v20.4s, v18.4s, v0.s[3]
+        FMLA v22.4s, v18.4s, v2.s[3]
+        FMLA v24.4s, v18.4s, v4.s[3]
+        FMLA v26.4s, v18.4s, v6.s[3]
+        FMLA v28.4s, v18.4s, v8.s[3]
+        FMLA v21.4s, v19.4s, v0.s[3]
+        FMLA v23.4s, v19.4s, v2.s[3]
+        FMLA v25.4s, v19.4s, v4.s[3]
+        FMLA v27.4s, v19.4s, v6.s[3]
+        FMLA v29.4s, v19.4s, v8.s[3]
+
+        # Is there a remainder? - 2 floats of A (8 bytes)
+6:
+        TBZ x0, 3, 7f
+
+        # Remainder - 2 floats of A (8 bytes)
+        # Load A
+        LDR d0, [x14], 8
+        LDR d2, [x15], 8
+        LDR d4, [x20], 8
+        LDR d6, [x21], 8
+        LDR d8, [x8], 8
+        # Load B
+        LDP q12, q13, [x5], 32
+        LDP q14, q15, [x5], 32
+
+        FMLA v20.4s, v12.4s, v0.s[0]
+        FMLA v22.4s, v12.4s, v2.s[0]
+        FMLA v24.4s, v12.4s, v4.s[0]
+        FMLA v26.4s, v12.4s, v6.s[0]
+        FMLA v28.4s, v12.4s, v8.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v2.s[0]
+        FMLA v25.4s, v13.4s, v4.s[0]
+        FMLA v27.4s, v13.4s, v6.s[0]
+        FMLA v29.4s, v13.4s, v8.s[0]
+
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v2.s[1]
+        FMLA v24.4s, v14.4s, v4.s[1]
+        FMLA v26.4s, v14.4s, v6.s[1]
+        FMLA v28.4s, v14.4s, v8.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        FMLA v23.4s, v15.4s, v2.s[1]
+        FMLA v25.4s, v15.4s, v4.s[1]
+        FMLA v27.4s, v15.4s, v6.s[1]
+        FMLA v29.4s, v15.4s, v8.s[1]
+
+        # Is there a remainder? - 1 float of A (4 bytes)
+7:
+        TBZ x0, 2, 4b
+
+        # Remainder - 1 float of A (4 bytes)
+        # Load A
+        LDR s0, [x14], 4
+        LDR s2, [x15], 4
+        LDR s4, [x20], 4
+        LDR s6, [x21], 4
+        LDR s8, [x8], 4
+        # Load B
+        LDP q12, q13, [x5], 32
+
+        FMLA v20.4s, v12.4s, v0.s[0]
+        FMLA v22.4s, v12.4s, v2.s[0]
+        FMLA v24.4s, v12.4s, v4.s[0]
+        FMLA v26.4s, v12.4s, v6.s[0]
+        FMLA v28.4s, v12.4s, v8.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v2.s[0]
+        FMLA v25.4s, v13.4s, v4.s[0]
+        FMLA v27.4s, v13.4s, v6.s[0]
+        FMLA v29.4s, v13.4s, v8.s[0]
+        B 4b
+
+        # Store odd width (nc remainder of 4 / 2 / 1 columns)
+8:
+        TBZ x1, 2, 9f
+        STR q28, [x7], 16
+        MOV v28.16b, v29.16b
+        STR q26, [x13], 16
+        MOV v26.16b, v27.16b
+        STR q24, [x17], 16
+        MOV v24.16b, v25.16b
+        STR q22, [x16], 16
+        MOV v22.16b, v23.16b
+        STR q20, [x6], 16
+        MOV v20.16b, v21.16b
+9:
+        TBZ x1, 1, 10f
+        STR d28, [x7], 8
+        DUP d28, v28.d[1]
+        STR d26, [x13], 8
+        DUP d26, v26.d[1]
+        STR d24, [x17], 8
+        DUP d24, v24.d[1]
+        STR d22, [x16], 8
+        DUP d22, v22.d[1]
+        STR d20, [x6], 8
+        DUP d20, v20.d[1]
+
+10:
+        TBZ x1, 0, 11f
+        STR s28, [x7]
+        STR s26, [x13]
+        STR s24, [x17]
+        STR s22, [x16]
+        STR s20, [x6]
+11:
+        # Restore x20,x21 from stack
+        LDP x20, x21, [sp, 48]
+
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 32]
+        LDP d12, d13, [sp, 16]
+        LDP d8, d9, [sp], 64
+        RET
+
+END_FUNCTION xnn_f32_igemm_ukernel_5x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-igemm/6x8-aarch64-neonfma-cortex-a57.S b/src/f32-igemm/6x8-aarch64-neonfma-cortex-a57.S
new file mode 100644
index 0000000..cbe2b95
--- /dev/null
+++ b/src/f32-igemm/6x8-aarch64-neonfma-cortex-a57.S
@@ -0,0 +1,683 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a57(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# size_t ks, x3 / x9
+# const float**restrict a, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x10
+# size_t a_offset, [sp + 8] -> x11
+# const float* zero, [sp + 16] -> x12
+# const xnn_f32_output_params params [sp + 24] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-x30 need to be preserved if used.
+
+# A pointers
+# x14 a0
+# x15 a1
+# x20 a2
+# x21 a3
+# x22 a4
+# x23 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3  (NOTE(review): x18 is the platform register reserved for the OS per
+#          the companion 5x8 kernel's comment — confirm it is safe to use here)
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0 v6
+# A1 v1 v7
+# A2 v2 v8
+# A3 v3 v9
+# A4 v4 v10
+# A5 v5 v11
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+
+# A57 kernel based on A75 but with PRFM removed from main loop
+
+BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a57
+
+ # Clamp C pointers / Save d8-d15 on stack
+ STP d8, d9, [sp, -96]!
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+ // if mr <= 5
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ # Save x20,x21,x22,x23 on stack
+ STP x20, x21, [sp, 64]
+ STP x22, x23, [sp, 80]
+
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load cn_stride, a_offset
+ LDP x10, x11, [sp, 96]
+
+ # Load zero, clamping params pointer
+ LDP x12, x8, [sp, 112]
+
+0:
+ # Load initial bias from w into accumulators
+ LDP q20, q21, [x5], 32
+ MOV v22.16b, v20.16b
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v24.16b, v20.16b
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v26.16b, v20.16b
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v28.16b, v20.16b
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v30.16b, v20.16b
+ MOV v31.16b, v21.16b
+
+ MOV x9, x3 // p = ks
+
+1:
+ # Load next 6 A pointers
+ LDP x14, x15, [x4], 16
+ LDP x20, x21, [x4], 16
+ LDP x22, x23, [x4], 16
+
+ CMP x14, x12 // if a0 == zero
+ ADD x14, x14, x11 // a0 += a_offset
+ CSEL x14, x12, x14, EQ // a0 = zero, else += a0 + a_offset
+ CMP x15, x12 // if a1 == zero
+ ADD x15, x15, x11 // a1 += a_offset
+ CSEL x15, x12, x15, EQ // a1 = zero, else += a1 + a_offset
+ CMP x20, x12 // if a2 == zero
+ ADD x20, x20, x11 // a2 += a_offset
+ CSEL x20, x12, x20, EQ // a2 = zero, else += a2 + a_offset
+ CMP x21, x12 // if a3 == zero
+ ADD x21, x21, x11 // a3 += a_offset
+ CSEL x21, x12, x21, EQ // a3 = zero, else += a3 + a_offset
+ CMP x22, x12 // if a4 == zero
+ ADD x22, x22, x11 // a4 += a_offset
+ CSEL x22, x12, x22, EQ // a4 = zero, else += a4 + a_offset
+ CMP x23, x12 // if a5 == zero
+ ADD x23, x23, x11 // a5 += a_offset
+ CSEL x23, x12, x23, EQ // a5 = zero, else += a5 + a_offset
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 5f
+
+ # Prologue - loads for main loop of 96 FMA
+ LDR q0, [x14], 16
+ LDR q1, [x15], 16
+ LDR q2, [x20], 16
+ LDR q3, [x21], 16
+ LDR q4, [x22], 16
+ LDR q5, [x23], 16
+ LDP q12, q13, [x5], 32 // Fetch 3 B (4th deferred)
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+
+ # Is there at least 8 floats (32 bytes) for main loop?
+ SUBS x0, x0, 32
+ B.LO 3f
+
+ # Main loop - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+2:
+ # First group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+
+ FMLA v31.4s, v13.4s, v5.s[0]
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ LDR q6, [x14], 16 // Load next 6 A
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ LDR q7, [x15], 16
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ LDR q8, [x20], 16
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ LDR q9, [x21], 16
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ LDR q10, [x22], 16
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+ LDR q11, [x23], 16
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ FMLA v24.4s, v12.4s, v8.s[0]
+ LDR q0, [x14], 16 // Load next 6 A
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ LDR q1, [x15], 16
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ LDR q2, [x20], 16
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+ LDR q3, [x21], 16
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ LDR q4, [x22], 16
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ LDR q5, [x23], 16
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ LDP q12, q13, [x5], 32 // Load next 3 B (not last)
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+ LDP q16, q17, [x5], 32
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ SUBS x0, x0, 32
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.HS 2b
+
+ # Epilogue - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+ # First block same as main loop. Second block has no preloads.
+3:
+ # First group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v0.s[0]
+ LDP q18, q19, [x5], 32 // Load last B
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+
+ FMLA v31.4s, v13.4s, v5.s[0]
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ LDR q6, [x14], 16 // Load next 6 A
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+ LDR q7, [x15], 16
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ LDR q8, [x20], 16
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ LDR q9, [x21], 16
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ LDR q10, [x22], 16
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+ LDR q11, [x23], 16
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ LDP q12, q13, [x5], 32 // Load 4 B
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ LDP q16, q17, [x5], 32
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+ LDP q18, q19, [x5], 32
+
+ # Second group of 4 A. 48 FMA.
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ # Is there a remainder? - 4 floats of A (16 bytes) or less
+ TST x0, 31
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.NE 5f
+
+4:
+ # ks loop
+ SUBS x9, x9, 48 // ks -= MR * sizeof(void*)
+ B.NE 1b
+
+ # Clamp
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 8f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x10
+ STP q28, q29, [x13]
+ ADD x13, x13, x10
+ STP q26, q27, [x18]
+ ADD x18, x18, x10
+ STP q24, q25, [x17]
+ ADD x17, x17, x10
+ STP q22, q23, [x16]
+ ADD x16, x16, x10
+ STP q20, q21, [x6]
+ ADD x6, x6, x10
+
+ SUB x4, x4, x3 // a -= ks
+
+ # nc loop
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore x20,x21,x22,x23 from stack
+ LDP x22, x23, [sp, 80]
+ LDP x20, x21, [sp, 64]
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 96
+ RET
+
+5:
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ # Is there a remainder? - 4 floats of A (16 bytes)
+ TBZ x0, 4, 6f
+
+ # Remainder- 4 floats of A (16 bytes)
+ # Load A
+ LDR q0, [x14], 16
+ LDR q1, [x15], 16
+ LDR q2, [x20], 16
+ LDR q3, [x21], 16
+ LDR q4, [x22], 16
+ LDR q5, [x23], 16
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+ LDP q18, q19, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Is there a remainder? - 2 floats of A (8 bytes)
+6:
+ TBZ x0, 3, 7f
+
+ # Remainder- 2 floats of A (8 bytes)
+ # Load A
+ LDR d0, [x14], 8
+ LDR d1, [x15], 8
+ LDR d2, [x20], 8
+ LDR d3, [x21], 8
+ LDR d4, [x22], 8
+ LDR d5, [x23], 8
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ # Is there a remainder? - 1 float of A (4 bytes)
+7:
+ TBZ x0, 2, 4b
+
+ # Remainder- 1 float of A (4 bytes)
+ # Load A
+ LDR s0, [x14], 4
+ LDR s1, [x15], 4
+ LDR s2, [x20], 4
+ LDR s3, [x21], 4
+ LDR s4, [x22], 4
+ LDR s5, [x23], 4
+ # Load B
+ LDP q12, q13, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 4b
+
+ # Store odd width
+8:
+ TBZ x1, 2, 9f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+9:
+ TBZ x1, 1, 10f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+
+10:
+ TBZ x1, 0, 11f
+ STR s30, [x7]
+ STR s28, [x13]
+ STR s26, [x18]
+ STR s24, [x17]
+ STR s22, [x16]
+ STR s20, [x6]
+11:
+ # Restore x20,x21,x22,x23 from stack
+ LDP x22, x23, [sp, 80]
+ LDP x20, x21, [sp, 64]
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 96
+ RET
+
+END_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a57
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-igemm/6x8-aarch64-neonfma-cortex-a73.S b/src/f32-igemm/6x8-aarch64-neonfma-cortex-a73.S
new file mode 100644
index 0000000..aaf8196
--- /dev/null
+++ b/src/f32-igemm/6x8-aarch64-neonfma-cortex-a73.S
@@ -0,0 +1,683 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a73(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# size_t ks, x3 / x9
+# const float**restrict a, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x10
+# size_t a_offset, [sp + 8] -> x11
+# const float* zero, [sp + 16] -> x12
+# const xnn_f32_output_params params [sp + 24] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x14 a0
+# x15 a1
+# x20 a2
+# x21 a3
+# x22 a4
+# x23 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0 v6
+# A1 v1 v7
+# A2 v2 v8
+# A3 v3 v9
+# A4 v4 v10
+# A5 v5 v11
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+
+BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a73
+
+ # Load cn_stride, a_offset (redundant: reloaded from [sp, 96] after the stack adjustment below)
+ LDP x10, x11, [sp]
+
+ # Load zero, clamping params pointer (redundant: reloaded from [sp, 112] after the stack adjustment below)
+ LDP x12, x8, [sp, 16]
+
+ # Clamp C pointers
+ STP d8, d9, [sp, -96]!
+ ADD x16, x6, x7 // c1 = c0 + cm_stride
+ CMP x0, 2 // if mr < 2
+ CSEL x16, x6, x16, LO // c1 = c0
+
+ STP d10, d11, [sp, 16]
+ ADD x17, x16, x7 // c2 = c1 + cm_stride
+ // if mr <= 2
+ CSEL x17, x16, x17, LS // c2 = c1
+
+ STP d12, d13, [sp, 32]
+ ADD x18, x17, x7 // c3 = c2 + cm_stride
+ CMP x0, 4 // if mr < 4
+ CSEL x18, x17, x18, LO // c3 = c2
+
+ STP d14, d15, [sp, 48]
+ ADD x13, x18, x7 // c4 = c3 + cm_stride
+ // if mr <= 4
+ CSEL x13, x18, x13, LS // c4 = c3
+
+ # Save x20,x21,x22,x23 on stack
+ STP x20, x21, [sp, 64]
+ STP x22, x23, [sp, 80]
+
+ ADD x7, x13, x7 // c5 = c4 + cm_stride
+ CMP x0, 6 // if mr < 6
+ CSEL x7, x13, x7, LO // c5 = c4
+
+ # Load zero, clamping params pointer
+ LDP x12, x8, [sp, 112]
+
+ # Load cn_stride, a_offset
+ LDP x10, x11, [sp, 96]
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+0:
+ # Load initial bias from w into accumulators
+ LD1 {v20.16b, v21.16b}, [x5], 32
+ MOV v22.16b, v20.16b
+ MOV v23.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+ MOV v24.16b, v20.16b
+ MOV v25.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 64]
+ MOV v26.16b, v20.16b
+ MOV v27.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 128]
+ MOV v28.16b, v20.16b
+ MOV v29.16b, v21.16b
+ PRFM PLDL1KEEP, [x5, 192]
+ MOV v30.16b, v20.16b
+ MOV v31.16b, v21.16b
+
+ MOV x9, x3 // p = ks
+
+1:
+ # Load next 6 A pointers
+ LDP x14, x15, [x4], 16
+ LDP x20, x21, [x4], 16
+ LDP x22, x23, [x4], 16
+
+ CMP x14, x12 // if a0 == zero
+ ADD x14, x14, x11 // a0 += a_offset
+ CSEL x14, x12, x14, EQ // a0 = zero, else a0 += a_offset
+ CMP x15, x12 // if a1 == zero
+ ADD x15, x15, x11 // a1 += a_offset
+ CSEL x15, x12, x15, EQ // a1 = zero, else a1 += a_offset
+ CMP x20, x12 // if a2 == zero
+ ADD x20, x20, x11 // a2 += a_offset
+ CSEL x20, x12, x20, EQ // a2 = zero, else a2 += a_offset
+ CMP x21, x12 // if a3 == zero
+ ADD x21, x21, x11 // a3 += a_offset
+ CSEL x21, x12, x21, EQ // a3 = zero, else a3 += a_offset
+ CMP x22, x12 // if a4 == zero
+ ADD x22, x22, x11 // a4 += a_offset
+ CSEL x22, x12, x22, EQ // a4 = zero, else a4 += a_offset
+ CMP x23, x12 // if a5 == zero
+ ADD x23, x23, x11 // a5 += a_offset
+ CSEL x23, x12, x23, EQ // a5 = zero, else a5 += a_offset
+
+ # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+ SUBS x0, x2, 32 // k = kc - 32
+ B.LO 5f
+
+ # Prologue - loads for main loop of 96 FMA
+ # load A0 to A4 but not A5
+ LDP q0, q6, [x14], 32
+ LDP q1, q7, [x15], 32
+ LDP q2, q8, [x20], 32
+ LDP q3, q9, [x21], 32
+ LDP q4, q10, [x22], 32
+ # load first set of B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ # Is there at least 8 floats (32 bytes) for main loop?
+ SUBS x0, x0, 32
+ B.LO 3f
+
+ # Main loop - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+2:
+ # First group of 4 A. 48 FMA. Loads A5
+
+ LDP q5, q11, [x23], 32
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ LDP q16, q17, [x5], 32
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ LDP q12, q13, [x5], 32
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ PRFM PLDL1KEEP, [x5, 256]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Second group of 4 A. 48 FMA. Loads A0 - A4
+
+ LDP q16, q17, [x5], 32
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ LDP q12, q13, [x5], 32
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v20.4s, v18.4s, v6.s[3]
+ LDP q14, q15, [x5], 32
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ LDP q0, q6, [x14], 32
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v23.4s, v19.4s, v7.s[3]
+ LDP q1, q7, [x15], 32
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v25.4s, v19.4s, v8.s[3]
+ LDP q2, q8, [x20], 32
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ LDP q3, q9, [x21], 32
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v29.4s, v19.4s, v10.s[3]
+ LDP q4, q10, [x22], 32
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ SUBS x0, x0, 32
+ FMLA v31.4s, v17.4s, v11.s[2]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.HS 2b
+
+ # Epilogue - 8 floats of A (32 bytes)
+ # 96 FMA + 6 LDP A + 8 LDP B
+ # First block same as main loop. Second block has no preloads.
+3:
+ # First group of 4 A. 48 FMA. Loads A5
+
+ LDP q5, q11, [x23], 32
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ LDP q16, q17, [x5], 32
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ LDP q12, q13, [x5], 32
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ LDP q14, q15, [x5], 32
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Second group of 4 A. 48 FMA. No A Loads, No last B load
+
+ LDP q16, q17, [x5], 32
+ FMLA v20.4s, v12.4s, v6.s[0]
+ FMLA v22.4s, v12.4s, v7.s[0]
+ LDP q18, q19, [x5], 32
+ FMLA v24.4s, v12.4s, v8.s[0]
+ FMLA v26.4s, v12.4s, v9.s[0]
+ FMLA v28.4s, v12.4s, v10.s[0]
+ FMLA v30.4s, v12.4s, v11.s[0]
+ FMLA v21.4s, v13.4s, v6.s[0]
+ FMLA v23.4s, v13.4s, v7.s[0]
+ FMLA v25.4s, v13.4s, v8.s[0]
+ FMLA v27.4s, v13.4s, v9.s[0]
+ FMLA v29.4s, v13.4s, v10.s[0]
+ FMLA v31.4s, v13.4s, v11.s[0]
+
+ FMLA v20.4s, v14.4s, v6.s[1]
+ FMLA v22.4s, v14.4s, v7.s[1]
+ FMLA v24.4s, v14.4s, v8.s[1]
+ FMLA v26.4s, v14.4s, v9.s[1]
+ FMLA v28.4s, v14.4s, v10.s[1]
+ FMLA v30.4s, v14.4s, v11.s[1]
+ FMLA v21.4s, v15.4s, v6.s[1]
+ FMLA v23.4s, v15.4s, v7.s[1]
+ FMLA v25.4s, v15.4s, v8.s[1]
+ FMLA v27.4s, v15.4s, v9.s[1]
+ FMLA v29.4s, v15.4s, v10.s[1]
+ FMLA v31.4s, v15.4s, v11.s[1]
+
+ # Last part of epilogue has loads removed.
+
+ FMLA v20.4s, v16.4s, v6.s[2]
+ FMLA v22.4s, v16.4s, v7.s[2]
+ FMLA v24.4s, v16.4s, v8.s[2]
+ FMLA v26.4s, v16.4s, v9.s[2]
+ FMLA v28.4s, v16.4s, v10.s[2]
+ FMLA v30.4s, v16.4s, v11.s[2]
+ FMLA v21.4s, v17.4s, v6.s[2]
+ FMLA v23.4s, v17.4s, v7.s[2]
+ FMLA v25.4s, v17.4s, v8.s[2]
+ FMLA v27.4s, v17.4s, v9.s[2]
+ FMLA v29.4s, v17.4s, v10.s[2]
+ FMLA v31.4s, v17.4s, v11.s[2]
+
+ FMLA v20.4s, v18.4s, v6.s[3]
+ FMLA v22.4s, v18.4s, v7.s[3]
+ FMLA v24.4s, v18.4s, v8.s[3]
+ FMLA v26.4s, v18.4s, v9.s[3]
+ FMLA v28.4s, v18.4s, v10.s[3]
+ FMLA v30.4s, v18.4s, v11.s[3]
+ FMLA v21.4s, v19.4s, v6.s[3]
+ FMLA v23.4s, v19.4s, v7.s[3]
+
+ # Load clamping_params values
+ LD2R {v6.4s, v7.4s}, [x8]
+
+ FMLA v25.4s, v19.4s, v8.s[3]
+ FMLA v27.4s, v19.4s, v9.s[3]
+ TST x0, 31
+ FMLA v29.4s, v19.4s, v10.s[3]
+ FMLA v31.4s, v19.4s, v11.s[3]
+ B.NE 5f
+
+4:
+ # ks loop
+ SUBS x9, x9, 48 // ks -= MR * sizeof(void*)
+ B.NE 1b
+
+ # Clamp
+ FMIN v20.4s, v20.4s, v6.4s
+ FMIN v21.4s, v21.4s, v6.4s
+ FMIN v22.4s, v22.4s, v6.4s
+ FMIN v23.4s, v23.4s, v6.4s
+ FMIN v24.4s, v24.4s, v6.4s
+ FMIN v25.4s, v25.4s, v6.4s
+ FMIN v26.4s, v26.4s, v6.4s
+ FMIN v27.4s, v27.4s, v6.4s
+ FMIN v28.4s, v28.4s, v6.4s
+ FMIN v29.4s, v29.4s, v6.4s
+ FMIN v30.4s, v30.4s, v6.4s
+ FMIN v31.4s, v31.4s, v6.4s
+ FMAX v20.4s, v20.4s, v7.4s
+ FMAX v21.4s, v21.4s, v7.4s
+ FMAX v22.4s, v22.4s, v7.4s
+ FMAX v23.4s, v23.4s, v7.4s
+ FMAX v24.4s, v24.4s, v7.4s
+ FMAX v25.4s, v25.4s, v7.4s
+ FMAX v26.4s, v26.4s, v7.4s
+ FMAX v27.4s, v27.4s, v7.4s
+ FMAX v28.4s, v28.4s, v7.4s
+ FMAX v29.4s, v29.4s, v7.4s
+ FMAX v30.4s, v30.4s, v7.4s
+ FMAX v31.4s, v31.4s, v7.4s
+
+ # Store full 6 x 8
+ CMP x1, 8
+ B.LO 8f
+
+ STP q30, q31, [x7]
+ ADD x7, x7, x10
+ STP q28, q29, [x13]
+ ADD x13, x13, x10
+ STP q26, q27, [x18]
+ ADD x18, x18, x10
+ STP q24, q25, [x17]
+ ADD x17, x17, x10
+ STP q22, q23, [x16]
+ ADD x16, x16, x10
+ STP q20, q21, [x6]
+ ADD x6, x6, x10
+
+ SUB x4, x4, x3 // a -= ks
+
+ # nc loop
+ SUBS x1, x1, 8
+ B.HI 0b
+
+ # Restore x20,x21,x22,x23 from stack
+ LDP x22, x23, [sp, 80]
+ LDP x20, x21, [sp, 64]
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 96
+ RET
+
+5:
+ # Is there a remainder? - 4 floats of A (16 bytes)
+ TBZ x0, 4, 6f
+
+ # Remainder- 4 floats of A (16 bytes)
+ # Load A
+ LDR q0, [x14], 16
+ LDR q1, [x15], 16
+ LDR q2, [x20], 16
+ LDR q3, [x21], 16
+ LDR q4, [x22], 16
+ LDR q5, [x23], 16
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+ LDP q16, q17, [x5], 32
+ LDP q18, q19, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ FMLA v20.4s, v16.4s, v0.s[2]
+ FMLA v22.4s, v16.4s, v1.s[2]
+ FMLA v24.4s, v16.4s, v2.s[2]
+ FMLA v26.4s, v16.4s, v3.s[2]
+ FMLA v28.4s, v16.4s, v4.s[2]
+ FMLA v30.4s, v16.4s, v5.s[2]
+ FMLA v21.4s, v17.4s, v0.s[2]
+ FMLA v23.4s, v17.4s, v1.s[2]
+ FMLA v25.4s, v17.4s, v2.s[2]
+ FMLA v27.4s, v17.4s, v3.s[2]
+ FMLA v29.4s, v17.4s, v4.s[2]
+ FMLA v31.4s, v17.4s, v5.s[2]
+
+ FMLA v20.4s, v18.4s, v0.s[3]
+ FMLA v22.4s, v18.4s, v1.s[3]
+ FMLA v24.4s, v18.4s, v2.s[3]
+ FMLA v26.4s, v18.4s, v3.s[3]
+ FMLA v28.4s, v18.4s, v4.s[3]
+ FMLA v30.4s, v18.4s, v5.s[3]
+ FMLA v21.4s, v19.4s, v0.s[3]
+ FMLA v23.4s, v19.4s, v1.s[3]
+ FMLA v25.4s, v19.4s, v2.s[3]
+ FMLA v27.4s, v19.4s, v3.s[3]
+ FMLA v29.4s, v19.4s, v4.s[3]
+ FMLA v31.4s, v19.4s, v5.s[3]
+
+ # Is there a remainder? - 2 floats of A (8 bytes)
+6:
+ TBZ x0, 3, 7f
+
+ # Remainder- 2 floats of A (8 bytes)
+ # Load A
+ LDR d0, [x14], 8
+ LDR d1, [x15], 8
+ LDR d2, [x20], 8
+ LDR d3, [x21], 8
+ LDR d4, [x22], 8
+ LDR d5, [x23], 8
+ # Load B
+ LDP q12, q13, [x5], 32
+ LDP q14, q15, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+
+ FMLA v20.4s, v14.4s, v0.s[1]
+ FMLA v22.4s, v14.4s, v1.s[1]
+ FMLA v24.4s, v14.4s, v2.s[1]
+ FMLA v26.4s, v14.4s, v3.s[1]
+ FMLA v28.4s, v14.4s, v4.s[1]
+ FMLA v30.4s, v14.4s, v5.s[1]
+ FMLA v21.4s, v15.4s, v0.s[1]
+ FMLA v23.4s, v15.4s, v1.s[1]
+ FMLA v25.4s, v15.4s, v2.s[1]
+ FMLA v27.4s, v15.4s, v3.s[1]
+ FMLA v29.4s, v15.4s, v4.s[1]
+ FMLA v31.4s, v15.4s, v5.s[1]
+
+ # Is there a remainder? - 1 float of A (4 bytes)
+7:
+ TBZ x0, 2, 4b
+
+ # Remainder- 1 float of A (4 bytes)
+ # Load A
+ LDR s0, [x14], 4
+ LDR s1, [x15], 4
+ LDR s2, [x20], 4
+ LDR s3, [x21], 4
+ LDR s4, [x22], 4
+ LDR s5, [x23], 4
+ # Load B
+ LDP q12, q13, [x5], 32
+
+ FMLA v20.4s, v12.4s, v0.s[0]
+ FMLA v22.4s, v12.4s, v1.s[0]
+ FMLA v24.4s, v12.4s, v2.s[0]
+ FMLA v26.4s, v12.4s, v3.s[0]
+ FMLA v28.4s, v12.4s, v4.s[0]
+ FMLA v30.4s, v12.4s, v5.s[0]
+ FMLA v21.4s, v13.4s, v0.s[0]
+ FMLA v23.4s, v13.4s, v1.s[0]
+ FMLA v25.4s, v13.4s, v2.s[0]
+ FMLA v27.4s, v13.4s, v3.s[0]
+ FMLA v29.4s, v13.4s, v4.s[0]
+ FMLA v31.4s, v13.4s, v5.s[0]
+ B 4b
+
+ # Store odd width
+8:
+ TBZ x1, 2, 9f
+ STR q30, [x7], 16
+ MOV v30.16b, v31.16b
+ STR q28, [x13], 16
+ MOV v28.16b, v29.16b
+ STR q26, [x18], 16
+ MOV v26.16b, v27.16b
+ STR q24, [x17], 16
+ MOV v24.16b, v25.16b
+ STR q22, [x16], 16
+ MOV v22.16b, v23.16b
+ STR q20, [x6], 16
+ MOV v20.16b, v21.16b
+9:
+ TBZ x1, 1, 10f
+ STR d30, [x7], 8
+ DUP d30, v30.d[1]
+ STR d28, [x13], 8
+ DUP d28, v28.d[1]
+ STR d26, [x18], 8
+ DUP d26, v26.d[1]
+ STR d24, [x17], 8
+ DUP d24, v24.d[1]
+ STR d22, [x16], 8
+ DUP d22, v22.d[1]
+ STR d20, [x6], 8
+ DUP d20, v20.d[1]
+
+10:
+ TBZ x1, 0, 11f
+ STR s30, [x7]
+ STR s28, [x13]
+ STR s26, [x18]
+ STR s24, [x17]
+ STR s22, [x16]
+ STR s20, [x6]
+11:
+ # Restore x20,x21,x22,x23 from stack
+ LDP x22, x23, [sp, 80]
+ LDP x20, x21, [sp, 64]
+
+ # Restore d8-d15 from stack
+ LDP d14, d15, [sp, 48]
+ LDP d12, d13, [sp, 32]
+ LDP d10, d11, [sp, 16]
+ LDP d8, d9, [sp], 96
+ RET
+
+END_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a73
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-igemm/6x8-aarch64-neonfma-cortex-a75.S b/src/f32-igemm/6x8-aarch64-neonfma-cortex-a75.S
new file mode 100644
index 0000000..7177b25
--- /dev/null
+++ b/src/f32-igemm/6x8-aarch64-neonfma-cortex-a75.S
@@ -0,0 +1,685 @@
+# Copyright 2019 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#include <xnnpack/assembly.h>
+
+# void xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a75(
+# size_t mr, x0
+# size_t nc, x1
+# size_t kc, x2 / x0
+# size_t ks, x3 / x9
+# const float**restrict a, x4
+# const void*restrict w, x5
+# uint8_t*restrict c, x6
+# size_t cm_stride, x7
+# size_t cn_stride, [sp] -> x10
+# size_t a_offset, [sp + 8] -> x11
+# const float* zero, [sp + 16] -> x12
+# const xnn_f32_output_params params [sp + 24] -> x8
+
+# d8-d15 need to be preserved if used.
+# x19-30 need to be preserved if used.
+
+# A pointers
+# x14 a0
+# x15 a1
+# x20 a2
+# x21 a3
+# x22 a4
+# x23 a5
+
+# C pointers
+# x6 c0
+# x16 c1
+# x17 c2
+# x18 c3
+# x13 c4
+# x7 c5
+
+# Vector register usage
+# A0 v0 v6
+# A1 v1 v7
+# A2 v2 v8
+# A3 v3 v9
+# A4 v4 v10
+# A5 v5 v11
+# B v12 v13 v14 v15
+# B v16 v17 v18 v19
+# C v20 v21
+# C v22 v23
+# C v24 v25
+# C v26 v27
+# C v28 v29
+# C v30 v31
+# Clamp v6 v7
+
+BEGIN_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a75
+
+        # Clamp C pointers / Save d8-d15 on stack
+        STP d8, d9, [sp, -96]!
+        ADD x16, x6, x7 // c1 = c0 + cm_stride
+        CMP x0, 2 // if mr < 2
+        CSEL x16, x6, x16, LO // c1 = c0
+
+        STP d10, d11, [sp, 16]
+        ADD x17, x16, x7 // c2 = c1 + cm_stride
+        // if mr <= 2
+        CSEL x17, x16, x17, LS // c2 = c1
+
+        STP d12, d13, [sp, 32]
+        ADD x18, x17, x7 // c3 = c2 + cm_stride  NOTE(review): x18 is the platform register on some AArch64 ABIs - confirm target
+        CMP x0, 4 // if mr < 4
+        CSEL x18, x17, x18, LO // c3 = c2
+
+        STP d14, d15, [sp, 48]
+        ADD x13, x18, x7 // c4 = c3 + cm_stride
+        // if mr <= 4
+        CSEL x13, x18, x13, LS // c4 = c3
+
+        # Save x20,x21,x22,x23 on stack
+        STP x20, x21, [sp, 64]
+        STP x22, x23, [sp, 80]
+
+        ADD x7, x13, x7 // c5 = c4 + cm_stride
+        CMP x0, 6 // if mr < 6
+        CSEL x7, x13, x7, LO // c5 = c4
+
+        # Load cn_stride, a_offset
+        LDP x10, x11, [sp, 96]
+
+        # Load zero, clamping params pointer
+        LDP x12, x8, [sp, 112]
+
+0:
+        # Load initial bias from w into accumulators
+        LDP q20, q21, [x5], 32
+        MOV v22.16b, v20.16b
+        MOV v23.16b, v21.16b
+        PRFM PLDL1KEEP, [x5, 0] // Prefetch B
+        MOV v24.16b, v20.16b
+        MOV v25.16b, v21.16b
+        PRFM PLDL1KEEP, [x5, 64]
+        MOV v26.16b, v20.16b
+        MOV v27.16b, v21.16b
+        PRFM PLDL1KEEP, [x5, 128]
+        MOV v28.16b, v20.16b
+        MOV v29.16b, v21.16b
+        PRFM PLDL1KEEP, [x5, 192]
+        MOV v30.16b, v20.16b
+        MOV v31.16b, v21.16b
+
+        MOV x9, x3 // p = ks
+
+1:
+        # Load next 6 A pointers
+        LDP x14, x15, [x4], 16
+        LDP x20, x21, [x4], 16
+        LDP x22, x23, [x4], 16
+
+        CMP x14, x12 // if a0 == zero
+        ADD x14, x14, x11 // a0 += a_offset
+        CSEL x14, x12, x14, EQ // a0 = zero, else a0 += a_offset
+        CMP x15, x12 // if a1 == zero
+        ADD x15, x15, x11 // a1 += a_offset
+        CSEL x15, x12, x15, EQ // a1 = zero, else a1 += a_offset
+        CMP x20, x12 // if a2 == zero
+        ADD x20, x20, x11 // a2 += a_offset
+        CSEL x20, x12, x20, EQ // a2 = zero, else a2 += a_offset
+        CMP x21, x12 // if a3 == zero
+        ADD x21, x21, x11 // a3 += a_offset
+        CSEL x21, x12, x21, EQ // a3 = zero, else a3 += a_offset
+        CMP x22, x12 // if a4 == zero
+        ADD x22, x22, x11 // a4 += a_offset
+        CSEL x22, x12, x22, EQ // a4 = zero, else a4 += a_offset
+        CMP x23, x12 // if a5 == zero
+        ADD x23, x23, x11 // a5 += a_offset
+        CSEL x23, x12, x23, EQ // a5 = zero, else a5 += a_offset
+
+        # Is there at least 8 floats (32 bytes) for prologue + epilogue?
+        SUBS x0, x2, 32 // k = kc - 32
+        B.LO 5f
+
+        # Prologue - loads for main loop of 96 FMA
+        LDR q0, [x14], 16
+        LDR q1, [x15], 16
+        LDR q2, [x20], 16
+        LDR q3, [x21], 16
+        LDR q4, [x22], 16
+        LDR q5, [x23], 16
+        LDP q12, q13, [x5], 32 // Fetch 3 B (4th deferred)
+        LDP q14, q15, [x5], 32
+        LDP q16, q17, [x5], 32
+
+        # Is there at least 8 floats (32 bytes) for main loop?
+        SUBS x0, x0, 32
+        B.LO 3f
+
+        # Main loop - 8 floats of A (32 bytes)
+        # 96 FMA + 12 LDR A + 8 LDP B
+2:
+        # First group of 4 A. 48 FMA.
+        FMLA v20.4s, v12.4s, v0.s[0]
+        LDP q18, q19, [x5], 32 // Load last B
+        FMLA v22.4s, v12.4s, v1.s[0]
+        FMLA v24.4s, v12.4s, v2.s[0]
+        FMLA v26.4s, v12.4s, v3.s[0]
+        FMLA v28.4s, v12.4s, v4.s[0]
+        FMLA v30.4s, v12.4s, v5.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v1.s[0]
+        FMLA v25.4s, v13.4s, v2.s[0]
+        FMLA v27.4s, v13.4s, v3.s[0]
+        FMLA v29.4s, v13.4s, v4.s[0]
+
+        FMLA v31.4s, v13.4s, v5.s[0]
+        FMLA v20.4s, v14.4s, v0.s[1]
+        PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+        FMLA v22.4s, v14.4s, v1.s[1]
+        FMLA v24.4s, v14.4s, v2.s[1]
+        FMLA v26.4s, v14.4s, v3.s[1]
+        FMLA v28.4s, v14.4s, v4.s[1]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v30.4s, v14.4s, v5.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        FMLA v23.4s, v15.4s, v1.s[1]
+        FMLA v25.4s, v15.4s, v2.s[1]
+        LDR q6, [x14], 16 // Load next 6 A
+        FMLA v27.4s, v15.4s, v3.s[1]
+        FMLA v29.4s, v15.4s, v4.s[1]
+        FMLA v31.4s, v15.4s, v5.s[1]
+        LDR q7, [x15], 16
+
+        FMLA v20.4s, v16.4s, v0.s[2]
+        FMLA v22.4s, v16.4s, v1.s[2]
+        FMLA v24.4s, v16.4s, v2.s[2]
+        LDR q8, [x20], 16
+        FMLA v26.4s, v16.4s, v3.s[2]
+        FMLA v28.4s, v16.4s, v4.s[2]
+        FMLA v30.4s, v16.4s, v5.s[2]
+        LDR q9, [x21], 16
+        FMLA v21.4s, v17.4s, v0.s[2]
+        FMLA v23.4s, v17.4s, v1.s[2]
+        FMLA v25.4s, v17.4s, v2.s[2]
+        LDR q10, [x22], 16
+        FMLA v27.4s, v17.4s, v3.s[2]
+        FMLA v29.4s, v17.4s, v4.s[2]
+        FMLA v31.4s, v17.4s, v5.s[2]
+        LDR q11, [x23], 16
+
+        FMLA v20.4s, v18.4s, v0.s[3]
+        FMLA v22.4s, v18.4s, v1.s[3]
+        FMLA v24.4s, v18.4s, v2.s[3]
+        LDP q12, q13, [x5], 32 // Load 4 B
+        FMLA v26.4s, v18.4s, v3.s[3]
+        FMLA v28.4s, v18.4s, v4.s[3]
+        FMLA v30.4s, v18.4s, v5.s[3]
+        LDP q14, q15, [x5], 32
+        FMLA v21.4s, v19.4s, v0.s[3]
+        FMLA v23.4s, v19.4s, v1.s[3]
+        FMLA v25.4s, v19.4s, v2.s[3]
+        LDP q16, q17, [x5], 32
+        FMLA v27.4s, v19.4s, v3.s[3]
+        FMLA v29.4s, v19.4s, v4.s[3]
+        FMLA v31.4s, v19.4s, v5.s[3]
+        LDP q18, q19, [x5], 32
+
+        # Second group of 4 A. 48 FMA.
+        FMLA v20.4s, v12.4s, v6.s[0]
+        FMLA v22.4s, v12.4s, v7.s[0]
+        FMLA v24.4s, v12.4s, v8.s[0]
+        LDR q0, [x14], 16 // Load next 6 A
+        FMLA v26.4s, v12.4s, v9.s[0]
+        FMLA v28.4s, v12.4s, v10.s[0]
+        FMLA v30.4s, v12.4s, v11.s[0]
+        LDR q1, [x15], 16
+        FMLA v21.4s, v13.4s, v6.s[0]
+        FMLA v23.4s, v13.4s, v7.s[0]
+        FMLA v25.4s, v13.4s, v8.s[0]
+        LDR q2, [x20], 16
+        FMLA v27.4s, v13.4s, v9.s[0]
+        FMLA v29.4s, v13.4s, v10.s[0]
+        FMLA v31.4s, v13.4s, v11.s[0]
+        LDR q3, [x21], 16
+
+        FMLA v20.4s, v14.4s, v6.s[1]
+        FMLA v22.4s, v14.4s, v7.s[1]
+        FMLA v24.4s, v14.4s, v8.s[1]
+        LDR q4, [x22], 16
+        FMLA v26.4s, v14.4s, v9.s[1]
+        FMLA v28.4s, v14.4s, v10.s[1]
+        FMLA v30.4s, v14.4s, v11.s[1]
+        LDR q5, [x23], 16
+        FMLA v21.4s, v15.4s, v6.s[1]
+        FMLA v23.4s, v15.4s, v7.s[1]
+        FMLA v25.4s, v15.4s, v8.s[1]
+        LDP q12, q13, [x5], 32 // Load next 3 B (not last)
+        FMLA v27.4s, v15.4s, v9.s[1]
+        FMLA v29.4s, v15.4s, v10.s[1]
+        FMLA v31.4s, v15.4s, v11.s[1]
+        LDP q14, q15, [x5], 32
+
+        FMLA v20.4s, v16.4s, v6.s[2]
+        FMLA v22.4s, v16.4s, v7.s[2]
+        FMLA v24.4s, v16.4s, v8.s[2]
+        FMLA v26.4s, v16.4s, v9.s[2]
+        FMLA v28.4s, v16.4s, v10.s[2]
+        FMLA v30.4s, v16.4s, v11.s[2]
+        FMLA v21.4s, v17.4s, v6.s[2]
+        FMLA v23.4s, v17.4s, v7.s[2]
+        FMLA v25.4s, v17.4s, v8.s[2]
+        FMLA v27.4s, v17.4s, v9.s[2]
+        FMLA v29.4s, v17.4s, v10.s[2]
+        FMLA v31.4s, v17.4s, v11.s[2]
+        LDP q16, q17, [x5], 32
+
+        FMLA v20.4s, v18.4s, v6.s[3]
+        FMLA v22.4s, v18.4s, v7.s[3]
+        SUBS x0, x0, 32
+        FMLA v24.4s, v18.4s, v8.s[3]
+        FMLA v26.4s, v18.4s, v9.s[3]
+        FMLA v28.4s, v18.4s, v10.s[3]
+        FMLA v30.4s, v18.4s, v11.s[3]
+        FMLA v21.4s, v19.4s, v6.s[3]
+        FMLA v23.4s, v19.4s, v7.s[3]
+        FMLA v25.4s, v19.4s, v8.s[3]
+        FMLA v27.4s, v19.4s, v9.s[3]
+        FMLA v29.4s, v19.4s, v10.s[3]
+        FMLA v31.4s, v19.4s, v11.s[3]
+        B.HS 2b
+
+        # Epilogue - 8 floats of A (32 bytes)
+        # 96 FMA + 6 LDR A + 5 LDP B
+        # First block same as main loop. Second block has no preloads.
+3:
+        # First group of 4 A. 48 FMA.
+        FMLA v20.4s, v12.4s, v0.s[0]
+        LDP q18, q19, [x5], 32 // Load last B
+        FMLA v22.4s, v12.4s, v1.s[0]
+        FMLA v24.4s, v12.4s, v2.s[0]
+        FMLA v26.4s, v12.4s, v3.s[0]
+        FMLA v28.4s, v12.4s, v4.s[0]
+        FMLA v30.4s, v12.4s, v5.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v1.s[0]
+        FMLA v25.4s, v13.4s, v2.s[0]
+        FMLA v27.4s, v13.4s, v3.s[0]
+        FMLA v29.4s, v13.4s, v4.s[0]
+
+        FMLA v31.4s, v13.4s, v5.s[0]
+        FMLA v20.4s, v14.4s, v0.s[1]
+        PRFM PLDL1KEEP, [x5, 128] // Prefetch B
+        FMLA v22.4s, v14.4s, v1.s[1]
+        FMLA v24.4s, v14.4s, v2.s[1]
+        FMLA v26.4s, v14.4s, v3.s[1]
+        FMLA v28.4s, v14.4s, v4.s[1]
+        PRFM PLDL1KEEP, [x5, 256]
+        FMLA v30.4s, v14.4s, v5.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        FMLA v23.4s, v15.4s, v1.s[1]
+        FMLA v25.4s, v15.4s, v2.s[1]
+        LDR q6, [x14], 16 // Load next 6 A
+        FMLA v27.4s, v15.4s, v3.s[1]
+        FMLA v29.4s, v15.4s, v4.s[1]
+        FMLA v31.4s, v15.4s, v5.s[1]
+        LDR q7, [x15], 16
+
+        FMLA v20.4s, v16.4s, v0.s[2]
+        FMLA v22.4s, v16.4s, v1.s[2]
+        FMLA v24.4s, v16.4s, v2.s[2]
+        LDR q8, [x20], 16
+        FMLA v26.4s, v16.4s, v3.s[2]
+        FMLA v28.4s, v16.4s, v4.s[2]
+        FMLA v30.4s, v16.4s, v5.s[2]
+        LDR q9, [x21], 16
+        FMLA v21.4s, v17.4s, v0.s[2]
+        FMLA v23.4s, v17.4s, v1.s[2]
+        FMLA v25.4s, v17.4s, v2.s[2]
+        LDR q10, [x22], 16
+        FMLA v27.4s, v17.4s, v3.s[2]
+        FMLA v29.4s, v17.4s, v4.s[2]
+        FMLA v31.4s, v17.4s, v5.s[2]
+        LDR q11, [x23], 16
+
+        FMLA v20.4s, v18.4s, v0.s[3]
+        FMLA v22.4s, v18.4s, v1.s[3]
+        FMLA v24.4s, v18.4s, v2.s[3]
+        LDP q12, q13, [x5], 32 // Load 4 B
+        FMLA v26.4s, v18.4s, v3.s[3]
+        FMLA v28.4s, v18.4s, v4.s[3]
+        FMLA v30.4s, v18.4s, v5.s[3]
+        LDP q14, q15, [x5], 32
+        FMLA v21.4s, v19.4s, v0.s[3]
+        FMLA v23.4s, v19.4s, v1.s[3]
+        FMLA v25.4s, v19.4s, v2.s[3]
+        LDP q16, q17, [x5], 32
+        FMLA v27.4s, v19.4s, v3.s[3]
+        FMLA v29.4s, v19.4s, v4.s[3]
+        FMLA v31.4s, v19.4s, v5.s[3]
+        LDP q18, q19, [x5], 32
+
+        # Second group of 4 A. 48 FMA.
+        FMLA v20.4s, v12.4s, v6.s[0]
+        FMLA v22.4s, v12.4s, v7.s[0]
+        FMLA v24.4s, v12.4s, v8.s[0]
+        FMLA v26.4s, v12.4s, v9.s[0]
+        FMLA v28.4s, v12.4s, v10.s[0]
+        FMLA v30.4s, v12.4s, v11.s[0]
+        FMLA v21.4s, v13.4s, v6.s[0]
+        FMLA v23.4s, v13.4s, v7.s[0]
+        FMLA v25.4s, v13.4s, v8.s[0]
+        FMLA v27.4s, v13.4s, v9.s[0]
+        FMLA v29.4s, v13.4s, v10.s[0]
+        FMLA v31.4s, v13.4s, v11.s[0]
+
+        FMLA v20.4s, v14.4s, v6.s[1]
+        FMLA v22.4s, v14.4s, v7.s[1]
+        FMLA v24.4s, v14.4s, v8.s[1]
+        FMLA v26.4s, v14.4s, v9.s[1]
+        FMLA v28.4s, v14.4s, v10.s[1]
+        FMLA v30.4s, v14.4s, v11.s[1]
+        FMLA v21.4s, v15.4s, v6.s[1]
+        FMLA v23.4s, v15.4s, v7.s[1]
+        FMLA v25.4s, v15.4s, v8.s[1]
+        FMLA v27.4s, v15.4s, v9.s[1]
+        FMLA v29.4s, v15.4s, v10.s[1]
+        FMLA v31.4s, v15.4s, v11.s[1]
+
+        FMLA v20.4s, v16.4s, v6.s[2]
+        FMLA v22.4s, v16.4s, v7.s[2]
+        FMLA v24.4s, v16.4s, v8.s[2]
+        FMLA v26.4s, v16.4s, v9.s[2]
+        FMLA v28.4s, v16.4s, v10.s[2]
+        FMLA v30.4s, v16.4s, v11.s[2]
+        FMLA v21.4s, v17.4s, v6.s[2]
+        FMLA v23.4s, v17.4s, v7.s[2]
+        FMLA v25.4s, v17.4s, v8.s[2]
+        FMLA v27.4s, v17.4s, v9.s[2]
+        FMLA v29.4s, v17.4s, v10.s[2]
+        FMLA v31.4s, v17.4s, v11.s[2]
+
+        FMLA v20.4s, v18.4s, v6.s[3]
+        FMLA v22.4s, v18.4s, v7.s[3]
+        FMLA v24.4s, v18.4s, v8.s[3]
+        FMLA v26.4s, v18.4s, v9.s[3]
+        FMLA v28.4s, v18.4s, v10.s[3]
+        FMLA v30.4s, v18.4s, v11.s[3]
+        FMLA v21.4s, v19.4s, v6.s[3]
+        FMLA v23.4s, v19.4s, v7.s[3]
+
+        # Load clamping_params values
+        LD2R {v6.4s, v7.4s}, [x8] // v6 = max, v7 = min
+
+        FMLA v25.4s, v19.4s, v8.s[3]
+        FMLA v27.4s, v19.4s, v9.s[3]
+        # Is there a remainder?- 4 floats of A (16 bytes) or less
+        TST x0, 31
+        FMLA v29.4s, v19.4s, v10.s[3]
+        FMLA v31.4s, v19.4s, v11.s[3]
+        B.NE 5f
+
+4:
+        # ks loop
+        SUBS x9, x9, 48 // ks -= MR * sizeof(void*)
+        B.NE 1b
+
+        # Clamp
+        FMIN v20.4s, v20.4s, v6.4s
+        FMIN v21.4s, v21.4s, v6.4s
+        FMIN v22.4s, v22.4s, v6.4s
+        FMIN v23.4s, v23.4s, v6.4s
+        FMIN v24.4s, v24.4s, v6.4s
+        FMIN v25.4s, v25.4s, v6.4s
+        FMIN v26.4s, v26.4s, v6.4s
+        FMIN v27.4s, v27.4s, v6.4s
+        FMIN v28.4s, v28.4s, v6.4s
+        FMIN v29.4s, v29.4s, v6.4s
+        FMIN v30.4s, v30.4s, v6.4s
+        FMIN v31.4s, v31.4s, v6.4s
+        FMAX v20.4s, v20.4s, v7.4s
+        FMAX v21.4s, v21.4s, v7.4s
+        FMAX v22.4s, v22.4s, v7.4s
+        FMAX v23.4s, v23.4s, v7.4s
+        FMAX v24.4s, v24.4s, v7.4s
+        FMAX v25.4s, v25.4s, v7.4s
+        FMAX v26.4s, v26.4s, v7.4s
+        FMAX v27.4s, v27.4s, v7.4s
+        FMAX v28.4s, v28.4s, v7.4s
+        FMAX v29.4s, v29.4s, v7.4s
+        FMAX v30.4s, v30.4s, v7.4s
+        FMAX v31.4s, v31.4s, v7.4s
+
+        # Store full 6 x 8
+        CMP x1, 8
+        B.LO 8f
+
+        STP q30, q31, [x7]
+        ADD x7, x7, x10
+        STP q28, q29, [x13]
+        ADD x13, x13, x10
+        STP q26, q27, [x18]
+        ADD x18, x18, x10
+        STP q24, q25, [x17]
+        ADD x17, x17, x10
+        STP q22, q23, [x16]
+        ADD x16, x16, x10
+        STP q20, q21, [x6]
+        ADD x6, x6, x10
+
+        SUB x4, x4, x3 // a -= ks
+
+        # nc loop
+        SUBS x1, x1, 8
+        B.HI 0b
+
+        # Restore x20,x21,x22,x23 from stack
+        LDP x22, x23, [sp, 80]
+        LDP x20, x21, [sp, 64]
+
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 48]
+        LDP d12, d13, [sp, 32]
+        LDP d10, d11, [sp, 16]
+        LDP d8, d9, [sp], 96
+        RET
+
+5:
+        # Load clamping_params values
+        LD2R {v6.4s, v7.4s}, [x8] // v6 = max, v7 = min
+
+        # Is there a remainder?- 4 floats of A (16 bytes)
+        TBZ x0, 4, 6f
+
+        # Remainder- 4 floats of A (16 bytes)
+        # Load A
+        LDR q0, [x14], 16
+        LDR q1, [x15], 16
+        LDR q2, [x20], 16
+        LDR q3, [x21], 16
+        LDR q4, [x22], 16
+        LDR q5, [x23], 16
+        # Load B
+        LDP q12, q13, [x5], 32
+        LDP q14, q15, [x5], 32
+        LDP q16, q17, [x5], 32
+        LDP q18, q19, [x5], 32
+
+        FMLA v20.4s, v12.4s, v0.s[0]
+        FMLA v22.4s, v12.4s, v1.s[0]
+        FMLA v24.4s, v12.4s, v2.s[0]
+        FMLA v26.4s, v12.4s, v3.s[0]
+        FMLA v28.4s, v12.4s, v4.s[0]
+        FMLA v30.4s, v12.4s, v5.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v1.s[0]
+        FMLA v25.4s, v13.4s, v2.s[0]
+        FMLA v27.4s, v13.4s, v3.s[0]
+        FMLA v29.4s, v13.4s, v4.s[0]
+        FMLA v31.4s, v13.4s, v5.s[0]
+
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v1.s[1]
+        FMLA v24.4s, v14.4s, v2.s[1]
+        FMLA v26.4s, v14.4s, v3.s[1]
+        FMLA v28.4s, v14.4s, v4.s[1]
+        FMLA v30.4s, v14.4s, v5.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        FMLA v23.4s, v15.4s, v1.s[1]
+        FMLA v25.4s, v15.4s, v2.s[1]
+        FMLA v27.4s, v15.4s, v3.s[1]
+        FMLA v29.4s, v15.4s, v4.s[1]
+        FMLA v31.4s, v15.4s, v5.s[1]
+
+        FMLA v20.4s, v16.4s, v0.s[2]
+        FMLA v22.4s, v16.4s, v1.s[2]
+        FMLA v24.4s, v16.4s, v2.s[2]
+        FMLA v26.4s, v16.4s, v3.s[2]
+        FMLA v28.4s, v16.4s, v4.s[2]
+        FMLA v30.4s, v16.4s, v5.s[2]
+        FMLA v21.4s, v17.4s, v0.s[2]
+        FMLA v23.4s, v17.4s, v1.s[2]
+        FMLA v25.4s, v17.4s, v2.s[2]
+        FMLA v27.4s, v17.4s, v3.s[2]
+        FMLA v29.4s, v17.4s, v4.s[2]
+        FMLA v31.4s, v17.4s, v5.s[2]
+
+        FMLA v20.4s, v18.4s, v0.s[3]
+        FMLA v22.4s, v18.4s, v1.s[3]
+        FMLA v24.4s, v18.4s, v2.s[3]
+        FMLA v26.4s, v18.4s, v3.s[3]
+        FMLA v28.4s, v18.4s, v4.s[3]
+        FMLA v30.4s, v18.4s, v5.s[3]
+        FMLA v21.4s, v19.4s, v0.s[3]
+        FMLA v23.4s, v19.4s, v1.s[3]
+        FMLA v25.4s, v19.4s, v2.s[3]
+        FMLA v27.4s, v19.4s, v3.s[3]
+        FMLA v29.4s, v19.4s, v4.s[3]
+        FMLA v31.4s, v19.4s, v5.s[3]
+
+        # Is there a remainder?- 2 floats of A (8 bytes)
+6:
+        TBZ x0, 3, 7f
+
+        # Remainder- 2 floats of A (8 bytes)
+        # Load A
+        LDR d0, [x14], 8
+        LDR d1, [x15], 8
+        LDR d2, [x20], 8
+        LDR d3, [x21], 8
+        LDR d4, [x22], 8
+        LDR d5, [x23], 8
+        # Load B
+        LDP q12, q13, [x5], 32
+        LDP q14, q15, [x5], 32
+
+        FMLA v20.4s, v12.4s, v0.s[0]
+        FMLA v22.4s, v12.4s, v1.s[0]
+        FMLA v24.4s, v12.4s, v2.s[0]
+        FMLA v26.4s, v12.4s, v3.s[0]
+        FMLA v28.4s, v12.4s, v4.s[0]
+        FMLA v30.4s, v12.4s, v5.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v1.s[0]
+        FMLA v25.4s, v13.4s, v2.s[0]
+        FMLA v27.4s, v13.4s, v3.s[0]
+        FMLA v29.4s, v13.4s, v4.s[0]
+        FMLA v31.4s, v13.4s, v5.s[0]
+
+        FMLA v20.4s, v14.4s, v0.s[1]
+        FMLA v22.4s, v14.4s, v1.s[1]
+        FMLA v24.4s, v14.4s, v2.s[1]
+        FMLA v26.4s, v14.4s, v3.s[1]
+        FMLA v28.4s, v14.4s, v4.s[1]
+        FMLA v30.4s, v14.4s, v5.s[1]
+        FMLA v21.4s, v15.4s, v0.s[1]
+        FMLA v23.4s, v15.4s, v1.s[1]
+        FMLA v25.4s, v15.4s, v2.s[1]
+        FMLA v27.4s, v15.4s, v3.s[1]
+        FMLA v29.4s, v15.4s, v4.s[1]
+        FMLA v31.4s, v15.4s, v5.s[1]
+
+        # Is there a remainder?- 1 float of A (4 bytes)
+7:
+        TBZ x0, 2, 4b
+
+        # Remainder- 1 float of A (4 bytes)
+        # Load A
+        LDR s0, [x14], 4
+        LDR s1, [x15], 4
+        LDR s2, [x20], 4
+        LDR s3, [x21], 4
+        LDR s4, [x22], 4
+        LDR s5, [x23], 4
+        # Load B
+        LDP q12, q13, [x5], 32
+
+        FMLA v20.4s, v12.4s, v0.s[0]
+        FMLA v22.4s, v12.4s, v1.s[0]
+        FMLA v24.4s, v12.4s, v2.s[0]
+        FMLA v26.4s, v12.4s, v3.s[0]
+        FMLA v28.4s, v12.4s, v4.s[0]
+        FMLA v30.4s, v12.4s, v5.s[0]
+        FMLA v21.4s, v13.4s, v0.s[0]
+        FMLA v23.4s, v13.4s, v1.s[0]
+        FMLA v25.4s, v13.4s, v2.s[0]
+        FMLA v27.4s, v13.4s, v3.s[0]
+        FMLA v29.4s, v13.4s, v4.s[0]
+        FMLA v31.4s, v13.4s, v5.s[0]
+        B 4b
+
+        # Store odd width
+8:
+        TBZ x1, 2, 9f
+        STR q30, [x7], 16
+        MOV v30.16b, v31.16b
+        STR q28, [x13], 16
+        MOV v28.16b, v29.16b
+        STR q26, [x18], 16
+        MOV v26.16b, v27.16b
+        STR q24, [x17], 16
+        MOV v24.16b, v25.16b
+        STR q22, [x16], 16
+        MOV v22.16b, v23.16b
+        STR q20, [x6], 16
+        MOV v20.16b, v21.16b
+9:
+        TBZ x1, 1, 10f
+        STR d30, [x7], 8
+        DUP d30, v30.d[1]
+        STR d28, [x13], 8
+        DUP d28, v28.d[1]
+        STR d26, [x18], 8
+        DUP d26, v26.d[1]
+        STR d24, [x17], 8
+        DUP d24, v24.d[1]
+        STR d22, [x16], 8
+        DUP d22, v22.d[1]
+        STR d20, [x6], 8
+        DUP d20, v20.d[1]
+
+10:
+        TBZ x1, 0, 11f
+        STR s30, [x7]
+        STR s28, [x13]
+        STR s26, [x18]
+        STR s24, [x17]
+        STR s22, [x16]
+        STR s20, [x6]
+11:
+        # Restore x20,x21,x22,x23 from stack
+        LDP x22, x23, [sp, 80]
+        LDP x20, x21, [sp, 64]
+
+        # Restore d8-d15 from stack
+        LDP d14, d15, [sp, 48]
+        LDP d12, d13, [sp, 32]
+        LDP d10, d11, [sp, 16]
+        LDP d8, d9, [sp], 96
+        RET
+
+END_FUNCTION xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a75
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/f32-igemm/6x8-neon-ld64.c b/src/f32-igemm/6x8-neon-ld64.c
new file mode 100644
index 0000000..f12a96c
--- /dev/null
+++ b/src/f32-igemm/6x8-neon-ld64.c
@@ -0,0 +1,270 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_6x8__neon_ld64(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 6);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (6 * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;  // row output pointers; when mr < 6, unused rows alias the last valid row
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 4) {
+    c3 = c2;
+  }
+  float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 4) {
+    c4 = c3;
+  }
+  float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 6) {
+    c5 = c4;
+  }
+
+  do {
+    float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;  // init accumulators with bias (first 8 floats of w)
+    float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+    float32x4_t vacc1x0123 = vacc0x0123;
+    float32x4_t vacc1x4567 = vacc0x4567;
+    float32x4_t vacc2x0123 = vacc0x0123;
+    float32x4_t vacc2x4567 = vacc0x4567;
+    float32x4_t vacc3x0123 = vacc0x0123;
+    float32x4_t vacc3x4567 = vacc0x4567;
+    float32x4_t vacc4x0123 = vacc0x0123;
+    float32x4_t vacc4x4567 = vacc0x4567;
+    float32x4_t vacc5x0123 = vacc0x0123;
+    float32x4_t vacc5x4567 = vacc0x4567;
+
+    size_t p = ks;
+    do {
+      const float* restrict a0 = a[0];
+      if XNN_UNPREDICTABLE(a0 != zero) {  // only real rows get a_offset; the zero-row sentinel stays as-is
+        a0 = (const float*) ((uintptr_t) a0 + a_offset);
+      }
+      const float* restrict a1 = a[1];
+      if XNN_UNPREDICTABLE(a1 != zero) {
+        a1 = (const float*) ((uintptr_t) a1 + a_offset);
+      }
+      const float* restrict a2 = a[2];
+      if XNN_UNPREDICTABLE(a2 != zero) {
+        a2 = (const float*) ((uintptr_t) a2 + a_offset);
+      }
+      const float* restrict a3 = a[3];
+      if XNN_UNPREDICTABLE(a3 != zero) {
+        a3 = (const float*) ((uintptr_t) a3 + a_offset);
+      }
+      const float* restrict a4 = a[4];
+      if XNN_UNPREDICTABLE(a4 != zero) {
+        a4 = (const float*) ((uintptr_t) a4 + a_offset);
+      }
+      const float* restrict a5 = a[5];
+      if XNN_UNPREDICTABLE(a5 != zero) {
+        a5 = (const float*) ((uintptr_t) a5 + a_offset);
+      }
+      a += 6;  // advance to the next group of MR (= 6) row pointers
+
+      size_t k = kc;
+      for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {  // main loop: 2 floats of A per row per iteration
+        const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+        const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+        const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+        const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+        const float32x2_t va4 = vld1_f32(a4); a4 += 2;
+        const float32x2_t va5 = vld1_f32(a5); a5 += 2;
+
+        const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+        const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+        vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+        vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+        vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+        vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+        vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123c0, va4, 0);
+        vacc5x0123 = vmlaq_lane_f32(vacc5x0123, vb0123c0, va5, 0);
+        vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+        vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+        vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+        vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+        vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567c0, va4, 0);
+        vacc5x4567 = vmlaq_lane_f32(vacc5x4567, vb4567c0, va5, 0);
+        const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+        const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+        vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+        vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+        vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+        vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+        vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123c1, va4, 1);
+        vacc5x0123 = vmlaq_lane_f32(vacc5x0123, vb0123c1, va5, 1);
+        vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+        vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+        vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+        vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+        vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567c1, va4, 1);
+        vacc5x4567 = vmlaq_lane_f32(vacc5x4567, vb4567c1, va5, 1);
+      }
+      if XNN_UNLIKELY(k != 0) {  // remainder: exactly 1 float of A (kc is a multiple of sizeof(float))
+        const float32x4_t va0 = vld1q_dup_f32(a0);
+        const float32x4_t va1 = vld1q_dup_f32(a1);
+        const float32x4_t va2 = vld1q_dup_f32(a2);
+        const float32x4_t va3 = vld1q_dup_f32(a3);
+        const float32x4_t va4 = vld1q_dup_f32(a4);
+        const float32x4_t va5 = vld1q_dup_f32(a5);
+
+        const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+        const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+        vacc0x0123 = vmlaq_f32(vacc0x0123, va0, vb0123);
+        vacc1x0123 = vmlaq_f32(vacc1x0123, va1, vb0123);
+        vacc2x0123 = vmlaq_f32(vacc2x0123, va2, vb0123);
+        vacc3x0123 = vmlaq_f32(vacc3x0123, va3, vb0123);
+        vacc4x0123 = vmlaq_f32(vacc4x0123, va4, vb0123);
+        vacc5x0123 = vmlaq_f32(vacc5x0123, va5, vb0123);
+        vacc0x4567 = vmlaq_f32(vacc0x4567, va0, vb4567);
+        vacc1x4567 = vmlaq_f32(vacc1x4567, va1, vb4567);
+        vacc2x4567 = vmlaq_f32(vacc2x4567, va2, vb4567);
+        vacc3x4567 = vmlaq_f32(vacc3x4567, va3, vb4567);
+        vacc4x4567 = vmlaq_f32(vacc4x4567, va4, vb4567);
+        vacc5x4567 = vmlaq_f32(vacc5x4567, va5, vb4567);
+      }
+      p -= 6 * sizeof(void*);  // consumed MR pointers of the indirection buffer
+    } while (p != 0);
+
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);  // clamp accumulators to [min, max]
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc4x0123 = vminq_f32(vacc4x0123, vmax);
+    vacc5x0123 = vminq_f32(vacc5x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+    vacc4x4567 = vminq_f32(vacc4x4567, vmax);
+    vacc5x4567 = vminq_f32(vacc5x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc4x0123 = vmaxq_f32(vacc4x0123, vmin);
+    vacc5x0123 = vmaxq_f32(vacc5x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+    vacc4x4567 = vmaxq_f32(vacc4x4567, vmin);
+    vacc5x4567 = vmaxq_f32(vacc5x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {  // full 6x8 tile store, highest row first
+      vst1q_f32(c5, vacc5x0123);
+      vst1q_f32(c5 + 4, vacc5x4567);
+      c5 = (float*) ((uintptr_t) c5 + cn_stride);
+      vst1q_f32(c4, vacc4x0123);
+      vst1q_f32(c4 + 4, vacc4x4567);
+      c4 = (float*) ((uintptr_t) c4 + cn_stride);
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);  // rewind indirection buffer for the next nc tile
+      nc -= 8;
+    } else {
+      if (nc & 4) {  // partial-width store: 4, then 2, then 1 column(s)
+        vst1q_f32(c5, vacc5x0123); c5 += 4;
+        vst1q_f32(c4, vacc4x0123); c4 += 4;
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc5x0123 = vacc5x4567;  // shift the unstored upper half down
+        vacc4x0123 = vacc4x4567;
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc5x01 = vget_low_f32(vacc5x0123);
+      float32x2_t vacc4x01 = vget_low_f32(vacc4x0123);
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c5, vacc5x01); c5 += 2;
+        vst1_f32(c4, vacc4x01); c4 += 2;
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc5x01 = vget_high_f32(vacc5x0123);
+        vacc4x01 = vget_high_f32(vacc4x0123);
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c5, vacc5x01, 0);
+        vst1_lane_f32(c4, vacc4x01, 0);
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/6x8-neonfma-ld64.c b/src/f32-igemm/6x8-neonfma-ld64.c
new file mode 100644
index 0000000..2dfa93e
--- /dev/null
+++ b/src/f32-igemm/6x8-neonfma-ld64.c
@@ -0,0 +1,312 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/neon-ld64.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_6x8__neonfma_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 6);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (6 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ c3 = c2;
+ }
+ float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ c4 = c3;
+ }
+ float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 6) {
+ c5 = c4;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;
+ float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc1x4567 = vacc0x4567;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc2x4567 = vacc0x4567;
+ float32x4_t vacc3x0123 = vacc0x0123;
+ float32x4_t vacc3x4567 = vacc0x4567;
+ float32x4_t vacc4x0123 = vacc0x0123;
+ float32x4_t vacc4x4567 = vacc0x4567;
+ float32x4_t vacc5x0123 = vacc0x0123;
+ float32x4_t vacc5x4567 = vacc0x4567;
+
+ size_t p = ks;
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ const float* restrict a4 = a[4];
+ if XNN_UNPREDICTABLE(a4 != zero) {
+ a4 = (const float*) ((uintptr_t) a4 + a_offset);
+ }
+ const float* restrict a5 = a[5];
+ if XNN_UNPREDICTABLE(a5 != zero) {
+ a5 = (const float*) ((uintptr_t) a5 + a_offset);
+ }
+ a += 6;
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+ const float32x2_t va0 = vld1_f32(a0); a0 += 2;
+ const float32x2_t va1 = vld1_f32(a1); a1 += 2;
+ const float32x2_t va2 = vld1_f32(a2); a2 += 2;
+ const float32x2_t va3 = vld1_f32(a3); a3 += 2;
+ const float32x2_t va4 = vld1_f32(a4); a4 += 2;
+ const float32x2_t va5 = vld1_f32(a5); a5 += 2;
+
+ const float32x4_t vb0123c0 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c0 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c0, va0, 0);
+ vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c0, va1, 0);
+ vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c0, va2, 0);
+ vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c0, va3, 0);
+ vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123c0, va4, 0);
+ vacc5x0123 = vfmaq_lane_f32(vacc5x0123, vb0123c0, va5, 0);
+ vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c0, va0, 0);
+ vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c0, va1, 0);
+ vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c0, va2, 0);
+ vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c0, va3, 0);
+ vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567c0, va4, 0);
+ vacc5x4567 = vfmaq_lane_f32(vacc5x4567, vb4567c0, va5, 0);
+ #else
+ const float32x4_t va0c0 = vdupq_lane_f32(va0, 0);
+ const float32x4_t va1c0 = vdupq_lane_f32(va1, 0);
+ const float32x4_t va2c0 = vdupq_lane_f32(va2, 0);
+ const float32x4_t va3c0 = vdupq_lane_f32(va3, 0);
+ const float32x4_t va4c0 = vdupq_lane_f32(va4, 0);
+ const float32x4_t va5c0 = vdupq_lane_f32(va5, 0);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c0, vb0123c0);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c0, vb0123c0);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c0, vb0123c0);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c0, vb0123c0);
+ vacc4x0123 = vfmaq_f32(vacc4x0123, va4c0, vb0123c0);
+ vacc5x0123 = vfmaq_f32(vacc5x0123, va5c0, vb0123c0);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c0, vb4567c0);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c0, vb4567c0);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c0, vb4567c0);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c0, vb4567c0);
+ vacc4x4567 = vfmaq_f32(vacc4x4567, va4c0, vb4567c0);
+ vacc5x4567 = vfmaq_f32(vacc5x4567, va5c0, vb4567c0);
+ #endif
+ const float32x4_t vb0123c1 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567c1 = vld1q_f32(w); w += 4;
+
+ #if defined(__aarch64__)
+ vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123c1, va0, 1);
+ vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123c1, va1, 1);
+ vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123c1, va2, 1);
+ vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123c1, va3, 1);
+ vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123c1, va4, 1);
+ vacc5x0123 = vfmaq_lane_f32(vacc5x0123, vb0123c1, va5, 1);
+ vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567c1, va0, 1);
+ vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567c1, va1, 1);
+ vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567c1, va2, 1);
+ vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567c1, va3, 1);
+ vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567c1, va4, 1);
+ vacc5x4567 = vfmaq_lane_f32(vacc5x4567, vb4567c1, va5, 1);
+ #else
+ const float32x4_t va0c1 = vdupq_lane_f32(va0, 1);
+ const float32x4_t va1c1 = vdupq_lane_f32(va1, 1);
+ const float32x4_t va2c1 = vdupq_lane_f32(va2, 1);
+ const float32x4_t va3c1 = vdupq_lane_f32(va3, 1);
+ const float32x4_t va4c1 = vdupq_lane_f32(va4, 1);
+ const float32x4_t va5c1 = vdupq_lane_f32(va5, 1);
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0c1, vb0123c1);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1c1, vb0123c1);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2c1, vb0123c1);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3c1, vb0123c1);
+ vacc4x0123 = vfmaq_f32(vacc4x0123, va4c1, vb0123c1);
+ vacc5x0123 = vfmaq_f32(vacc5x0123, va5c1, vb0123c1);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0c1, vb4567c1);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1c1, vb4567c1);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2c1, vb4567c1);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3c1, vb4567c1);
+ vacc4x4567 = vfmaq_f32(vacc4x4567, va4c1, vb4567c1);
+ vacc5x4567 = vfmaq_f32(vacc5x4567, va5c1, vb4567c1);
+ #endif
+ }
+ if XNN_UNLIKELY(k != 0) {
+ const float32x4_t va0 = vld1q_dup_f32(a0);
+ const float32x4_t va1 = vld1q_dup_f32(a1);
+ const float32x4_t va2 = vld1q_dup_f32(a2);
+ const float32x4_t va3 = vld1q_dup_f32(a3);
+ const float32x4_t va4 = vld1q_dup_f32(a4);
+ const float32x4_t va5 = vld1q_dup_f32(a5);
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0, vb0123);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1, vb0123);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2, vb0123);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3, vb0123);
+ vacc4x0123 = vfmaq_f32(vacc4x0123, va4, vb0123);
+ vacc5x0123 = vfmaq_f32(vacc5x0123, va5, vb0123);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0, vb4567);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1, vb4567);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2, vb4567);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3, vb4567);
+ vacc4x4567 = vfmaq_f32(vacc4x4567, va4, vb4567);
+ vacc5x4567 = vfmaq_f32(vacc5x4567, va5, vb4567);
+ }
+ p -= 6 * sizeof(void*);
+ } while (p != 0);
+
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc4x0123 = vminq_f32(vacc4x0123, vmax);
+ vacc5x0123 = vminq_f32(vacc5x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+ vacc4x4567 = vminq_f32(vacc4x4567, vmax);
+ vacc5x4567 = vminq_f32(vacc5x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc4x0123 = vmaxq_f32(vacc4x0123, vmin);
+ vacc5x0123 = vmaxq_f32(vacc5x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+ vacc4x4567 = vmaxq_f32(vacc4x4567, vmin);
+ vacc5x4567 = vmaxq_f32(vacc5x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1q_f32(c5, vacc5x0123);
+ vst1q_f32(c5 + 4, vacc5x4567);
+ c5 = (float*) ((uintptr_t) c5 + cn_stride);
+ vst1q_f32(c4, vacc4x0123);
+ vst1q_f32(c4 + 4, vacc4x4567);
+ c4 = (float*) ((uintptr_t) c4 + cn_stride);
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_f32(c5, vacc5x0123); c5 += 4;
+ vst1q_f32(c4, vacc4x0123); c4 += 4;
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc5x0123 = vacc5x4567;
+ vacc4x0123 = vacc4x4567;
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc5x01 = vget_low_f32(vacc5x0123);
+ float32x2_t vacc4x01 = vget_low_f32(vacc4x0123);
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c5, vacc5x01); c5 += 2;
+ vst1_f32(c4, vacc4x01); c4 += 2;
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc5x01 = vget_high_f32(vacc5x0123);
+ vacc4x01 = vget_high_f32(vacc4x0123);
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c5, vacc5x01, 0);
+ vst1_lane_f32(c4, vacc4x01, 0);
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/6x8-psimd-loadsplat.c b/src/f32-igemm/6x8-psimd-loadsplat.c
new file mode 100644
index 0000000..3e03d23
--- /dev/null
+++ b/src/f32-igemm/6x8-psimd-loadsplat.c
@@ -0,0 +1,248 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/psimd-loadsplat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_6x8__psimd_loadsplat(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 6);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (6 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ c3 = c2;
+ }
+ float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ c4 = c3;
+ }
+ float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 6) {
+ c5 = c4;
+ }
+
+ do {
+ psimd_f32 vacc0x0123 = psimd_load_f32(w);
+ psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+ psimd_f32 vacc1x0123 = vacc0x0123;
+ psimd_f32 vacc1x4567 = vacc0x4567;
+ psimd_f32 vacc2x0123 = vacc0x0123;
+ psimd_f32 vacc2x4567 = vacc0x4567;
+ psimd_f32 vacc3x0123 = vacc0x0123;
+ psimd_f32 vacc3x4567 = vacc0x4567;
+ psimd_f32 vacc4x0123 = vacc0x0123;
+ psimd_f32 vacc4x4567 = vacc0x4567;
+ psimd_f32 vacc5x0123 = vacc0x0123;
+ psimd_f32 vacc5x4567 = vacc0x4567;
+ w += 8;
+
+ size_t p = ks;
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ const float* restrict a4 = a[4];
+ if XNN_UNPREDICTABLE(a4 != zero) {
+ a4 = (const float*) ((uintptr_t) a4 + a_offset);
+ }
+ const float* restrict a5 = a[5];
+ if XNN_UNPREDICTABLE(a5 != zero) {
+ a5 = (const float*) ((uintptr_t) a5 + a_offset);
+ }
+ a += 6;
+
+ size_t k = kc;
+ do {
+ const psimd_f32 vb0123 = psimd_load_f32(w);
+ const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+ w += 8;
+
+ const psimd_f32 va0 = psimd_load_splat_f32(a0);
+ a0 += 1;
+ const psimd_f32 va1 = psimd_load_splat_f32(a1);
+ a1 += 1;
+ const psimd_f32 va2 = psimd_load_splat_f32(a2);
+ a2 += 1;
+ const psimd_f32 va3 = psimd_load_splat_f32(a3);
+ a3 += 1;
+ const psimd_f32 va4 = psimd_load_splat_f32(a4);
+ a4 += 1;
+ const psimd_f32 va5 = psimd_load_splat_f32(a5);
+ a5 += 1;
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567);
+ k -= sizeof(float);
+ } while (k != 0);
+ p -= 6 * sizeof(void*);
+ } while (p != 0);
+
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+ vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+ vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+ vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+ vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+ vacc4x0123 = psimd_min_f32(vacc4x0123, vmax);
+ vacc5x0123 = psimd_min_f32(vacc5x0123, vmax);
+ vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+ vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+ vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+ vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+ vacc4x4567 = psimd_min_f32(vacc4x4567, vmax);
+ vacc5x4567 = psimd_min_f32(vacc5x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+ vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+ vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+ vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+ vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+ vacc4x0123 = psimd_max_f32(vacc4x0123, vmin);
+ vacc5x0123 = psimd_max_f32(vacc5x0123, vmin);
+ vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+ vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+ vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+ vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+ vacc4x4567 = psimd_max_f32(vacc4x4567, vmin);
+ vacc5x4567 = psimd_max_f32(vacc5x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ psimd_store_f32(c5, vacc5x0123);
+ psimd_store_f32(c5 + 4, vacc5x4567);
+ c5 = (float*) ((uintptr_t) c5 + cn_stride);
+ psimd_store_f32(c4, vacc4x0123);
+ psimd_store_f32(c4 + 4, vacc4x4567);
+ c4 = (float*) ((uintptr_t) c4 + cn_stride);
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ psimd_store_f32(c0, vacc0x0123);
+ psimd_store_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ psimd_store_f32(c5, vacc5x0123);
+ psimd_store_f32(c4, vacc4x0123);
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c0, vacc0x0123);
+
+ vacc5x0123 = vacc5x4567;
+ vacc4x0123 = vacc4x4567;
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+
+ c5 += 4;
+ c4 += 4;
+ c3 += 4;
+ c2 += 4;
+ c1 += 4;
+ c0 += 4;
+ }
+ if (nc & 2) {
+ psimd_store2_f32(c5, vacc5x0123);
+ psimd_store2_f32(c4, vacc4x0123);
+ psimd_store2_f32(c3, vacc3x0123);
+ psimd_store2_f32(c2, vacc2x0123);
+ psimd_store2_f32(c1, vacc1x0123);
+ psimd_store2_f32(c0, vacc0x0123);
+
+ vacc5x0123 = psimd_concat_hi_f32(vacc5x0123, vacc5x0123);
+ vacc4x0123 = psimd_concat_hi_f32(vacc4x0123, vacc4x0123);
+ vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+ vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+ vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+ vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+ c5 += 2;
+ c4 += 2;
+ c3 += 2;
+ c2 += 2;
+ c1 += 2;
+ c0 += 2;
+ }
+ if (nc & 1) {
+ psimd_store1_f32(c5, vacc5x0123);
+ psimd_store1_f32(c4, vacc4x0123);
+ psimd_store1_f32(c3, vacc3x0123);
+ psimd_store1_f32(c2, vacc2x0123);
+ psimd_store1_f32(c1, vacc1x0123);
+ psimd_store1_f32(c0, vacc0x0123);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/6x8-psimd-splat.c b/src/f32-igemm/6x8-psimd-splat.c
new file mode 100644
index 0000000..8e3d1f1
--- /dev/null
+++ b/src/f32-igemm/6x8-psimd-splat.c
@@ -0,0 +1,356 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/psimd-splat.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_6x8__psimd_splat(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 6);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (6 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ c3 = c2;
+ }
+ float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ c4 = c3;
+ }
+ float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 6) {
+ c5 = c4;
+ }
+
+ do {
+ psimd_f32 vacc0x0123 = psimd_load_f32(w);
+ psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+ psimd_f32 vacc1x0123 = vacc0x0123;
+ psimd_f32 vacc1x4567 = vacc0x4567;
+ psimd_f32 vacc2x0123 = vacc0x0123;
+ psimd_f32 vacc2x4567 = vacc0x4567;
+ psimd_f32 vacc3x0123 = vacc0x0123;
+ psimd_f32 vacc3x4567 = vacc0x4567;
+ psimd_f32 vacc4x0123 = vacc0x0123;
+ psimd_f32 vacc4x4567 = vacc0x4567;
+ psimd_f32 vacc5x0123 = vacc0x0123;
+ psimd_f32 vacc5x4567 = vacc0x4567;
+ w += 8;
+
+ size_t p = ks;
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ const float* restrict a4 = a[4];
+ if XNN_UNPREDICTABLE(a4 != zero) {
+ a4 = (const float*) ((uintptr_t) a4 + a_offset);
+ }
+ const float* restrict a5 = a[5];
+ if XNN_UNPREDICTABLE(a5 != zero) {
+ a5 = (const float*) ((uintptr_t) a5 + a_offset);
+ }
+ a += 6;
+
+ size_t k = kc;
+ while (k >= 4 * sizeof(float)) {
+ const psimd_f32 va0 = psimd_load_f32(a0);
+ a0 += 4;
+ const psimd_f32 va1 = psimd_load_f32(a1);
+ a1 += 4;
+ const psimd_f32 va2 = psimd_load_f32(a2);
+ a2 += 4;
+ const psimd_f32 va3 = psimd_load_f32(a3);
+ a3 += 4;
+ const psimd_f32 va4 = psimd_load_f32(a4);
+ a4 += 4;
+ const psimd_f32 va5 = psimd_load_f32(a5);
+ a5 += 4;
+
+ const psimd_f32 va0c0 = psimd_splat0_f32(va0);
+ const psimd_f32 va1c0 = psimd_splat0_f32(va1);
+ const psimd_f32 va2c0 = psimd_splat0_f32(va2);
+ const psimd_f32 va3c0 = psimd_splat0_f32(va3);
+ const psimd_f32 va4c0 = psimd_splat0_f32(va4);
+ const psimd_f32 va5c0 = psimd_splat0_f32(va5);
+
+ const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+ const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c0, vb0123c0);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c0, vb0123c0);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c0, vb0123c0);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c0, vb0123c0);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4c0, vb0123c0);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5c0, vb0123c0);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c0, vb4567c0);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c0, vb4567c0);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c0, vb4567c0);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c0, vb4567c0);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4c0, vb4567c0);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5c0, vb4567c0);
+ const psimd_f32 va0c1 = psimd_splat1_f32(va0);
+ const psimd_f32 va1c1 = psimd_splat1_f32(va1);
+ const psimd_f32 va2c1 = psimd_splat1_f32(va2);
+ const psimd_f32 va3c1 = psimd_splat1_f32(va3);
+ const psimd_f32 va4c1 = psimd_splat1_f32(va4);
+ const psimd_f32 va5c1 = psimd_splat1_f32(va5);
+
+ const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+ const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c1, vb0123c1);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c1, vb0123c1);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c1, vb0123c1);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c1, vb0123c1);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4c1, vb0123c1);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5c1, vb0123c1);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c1, vb4567c1);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c1, vb4567c1);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c1, vb4567c1);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c1, vb4567c1);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4c1, vb4567c1);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5c1, vb4567c1);
+ const psimd_f32 va0c2 = psimd_splat2_f32(va0);
+ const psimd_f32 va1c2 = psimd_splat2_f32(va1);
+ const psimd_f32 va2c2 = psimd_splat2_f32(va2);
+ const psimd_f32 va3c2 = psimd_splat2_f32(va3);
+ const psimd_f32 va4c2 = psimd_splat2_f32(va4);
+ const psimd_f32 va5c2 = psimd_splat2_f32(va5);
+
+ const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+ const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c2, vb0123c2);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c2, vb0123c2);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c2, vb0123c2);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c2, vb0123c2);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4c2, vb0123c2);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5c2, vb0123c2);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c2, vb4567c2);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c2, vb4567c2);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c2, vb4567c2);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c2, vb4567c2);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4c2, vb4567c2);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5c2, vb4567c2);
+ const psimd_f32 va0c3 = psimd_splat3_f32(va0);
+ const psimd_f32 va1c3 = psimd_splat3_f32(va1);
+ const psimd_f32 va2c3 = psimd_splat3_f32(va2);
+ const psimd_f32 va3c3 = psimd_splat3_f32(va3);
+ const psimd_f32 va4c3 = psimd_splat3_f32(va4);
+ const psimd_f32 va5c3 = psimd_splat3_f32(va5);
+
+ const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+ const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0c3, vb0123c3);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1c3, vb0123c3);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2c3, vb0123c3);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3c3, vb0123c3);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4c3, vb0123c3);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5c3, vb0123c3);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0c3, vb4567c3);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1c3, vb4567c3);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2c3, vb4567c3);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3c3, vb4567c3);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4c3, vb4567c3);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5c3, vb4567c3);
+
+ w += 32;
+ k -= 4 * sizeof(float);
+ }
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ const psimd_f32 vb0123 = psimd_load_f32(w);
+ const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+ w += 8;
+
+ const psimd_f32 va0 = psimd_load_splat_f32(a0);
+ a0 += 1;
+ const psimd_f32 va1 = psimd_load_splat_f32(a1);
+ a1 += 1;
+ const psimd_f32 va2 = psimd_load_splat_f32(a2);
+ a2 += 1;
+ const psimd_f32 va3 = psimd_load_splat_f32(a3);
+ a3 += 1;
+ const psimd_f32 va4 = psimd_load_splat_f32(a4);
+ a4 += 1;
+ const psimd_f32 va5 = psimd_load_splat_f32(a5);
+ a5 += 1;
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567);
+ k -= sizeof(float);
+ } while (k != 0);
+ }
+ p -= 6 * sizeof(void*);
+ } while (p != 0);
+
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+ vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+ vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+ vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+ vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+ vacc4x0123 = psimd_min_f32(vacc4x0123, vmax);
+ vacc5x0123 = psimd_min_f32(vacc5x0123, vmax);
+ vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+ vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+ vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+ vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+ vacc4x4567 = psimd_min_f32(vacc4x4567, vmax);
+ vacc5x4567 = psimd_min_f32(vacc5x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+ vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+ vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+ vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+ vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+ vacc4x0123 = psimd_max_f32(vacc4x0123, vmin);
+ vacc5x0123 = psimd_max_f32(vacc5x0123, vmin);
+ vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+ vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+ vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+ vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+ vacc4x4567 = psimd_max_f32(vacc4x4567, vmin);
+ vacc5x4567 = psimd_max_f32(vacc5x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ psimd_store_f32(c5, vacc5x0123);
+ psimd_store_f32(c5 + 4, vacc5x4567);
+ c5 = (float*) ((uintptr_t) c5 + cn_stride);
+ psimd_store_f32(c4, vacc4x0123);
+ psimd_store_f32(c4 + 4, vacc4x4567);
+ c4 = (float*) ((uintptr_t) c4 + cn_stride);
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ psimd_store_f32(c0, vacc0x0123);
+ psimd_store_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ psimd_store_f32(c5, vacc5x0123);
+ psimd_store_f32(c4, vacc4x0123);
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c0, vacc0x0123);
+
+ vacc5x0123 = vacc5x4567;
+ vacc4x0123 = vacc4x4567;
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+
+ c5 += 4;
+ c4 += 4;
+ c3 += 4;
+ c2 += 4;
+ c1 += 4;
+ c0 += 4;
+ }
+ if (nc & 2) {
+ psimd_store2_f32(c5, vacc5x0123);
+ psimd_store2_f32(c4, vacc4x0123);
+ psimd_store2_f32(c3, vacc3x0123);
+ psimd_store2_f32(c2, vacc2x0123);
+ psimd_store2_f32(c1, vacc1x0123);
+ psimd_store2_f32(c0, vacc0x0123);
+
+ vacc5x0123 = psimd_concat_hi_f32(vacc5x0123, vacc5x0123);
+ vacc4x0123 = psimd_concat_hi_f32(vacc4x0123, vacc4x0123);
+ vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+ vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+ vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+ vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+ c5 += 2;
+ c4 += 2;
+ c3 += 2;
+ c2 += 2;
+ c1 += 2;
+ c0 += 2;
+ }
+ if (nc & 1) {
+ psimd_store1_f32(c5, vacc5x0123);
+ psimd_store1_f32(c4, vacc4x0123);
+ psimd_store1_f32(c3, vacc3x0123);
+ psimd_store1_f32(c2, vacc2x0123);
+ psimd_store1_f32(c1, vacc1x0123);
+ psimd_store1_f32(c0, vacc0x0123);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/6x8s4-psimd.c b/src/f32-igemm/6x8s4-psimd.c
new file mode 100644
index 0000000..ce9097c
--- /dev/null
+++ b/src/f32-igemm/6x8s4-psimd.c
@@ -0,0 +1,354 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-igemm/psimd-s4.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_6x8s4__psimd(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 6);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (6 * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ c3 = c2;
+ }
+ float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ c4 = c3;
+ }
+ float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 6) {
+ c5 = c4;
+ }
+
+ do {
+ psimd_f32 vacc0x0123 = psimd_load_f32(w);
+ psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);
+ psimd_f32 vacc1x0123 = vacc0x0123;
+ psimd_f32 vacc1x4567 = vacc0x4567;
+ psimd_f32 vacc2x0123 = vacc0x0123;
+ psimd_f32 vacc2x4567 = vacc0x4567;
+ psimd_f32 vacc3x0123 = vacc0x0123;
+ psimd_f32 vacc3x4567 = vacc0x4567;
+ psimd_f32 vacc4x0123 = vacc0x0123;
+ psimd_f32 vacc4x4567 = vacc0x4567;
+ psimd_f32 vacc5x0123 = vacc0x0123;
+ psimd_f32 vacc5x4567 = vacc0x4567;
+ w += 8;
+
+ size_t p = ks;
+ do {
+ const float* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const float*) ((uintptr_t) a0 + a_offset);
+ }
+ const float* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const float*) ((uintptr_t) a1 + a_offset);
+ }
+ const float* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const float*) ((uintptr_t) a2 + a_offset);
+ }
+ const float* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const float*) ((uintptr_t) a3 + a_offset);
+ }
+ const float* restrict a4 = a[4];
+ if XNN_UNPREDICTABLE(a4 != zero) {
+ a4 = (const float*) ((uintptr_t) a4 + a_offset);
+ }
+ const float* restrict a5 = a[5];
+ if XNN_UNPREDICTABLE(a5 != zero) {
+ a5 = (const float*) ((uintptr_t) a5 + a_offset);
+ }
+ a += 6;
+
+ size_t k = kc;
+ while (k >= 4 * sizeof(float)) {
+ psimd_f32 va0 = psimd_load_f32(a0);
+ a0 += 4;
+ psimd_f32 va1 = psimd_load_f32(a1);
+ a1 += 4;
+ psimd_f32 va2 = psimd_load_f32(a2);
+ a2 += 4;
+ psimd_f32 va3 = psimd_load_f32(a3);
+ a3 += 4;
+ psimd_f32 va4 = psimd_load_f32(a4);
+ a4 += 4;
+ psimd_f32 va5 = psimd_load_f32(a5);
+ a5 += 4;
+
+
+ const psimd_f32 vb0123c0 = psimd_load_f32(w + 0);
+ const psimd_f32 vb4567c0 = psimd_load_f32(w + 4);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c0);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c0);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c0);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c0);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123c0);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123c0);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c0);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c0);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c0);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c0);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567c0);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567c0);
+
+ va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+ va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+ va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+ va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+ va4 = __builtin_shufflevector(va4, va4, 1, 2, 3, 0);
+ va5 = __builtin_shufflevector(va5, va5, 1, 2, 3, 0);
+
+ const psimd_f32 vb0123c1 = psimd_load_f32(w + 8);
+ const psimd_f32 vb4567c1 = psimd_load_f32(w + 12);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c1);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c1);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c1);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c1);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123c1);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123c1);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c1);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c1);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c1);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c1);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567c1);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567c1);
+
+ va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+ va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+ va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+ va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+ va4 = __builtin_shufflevector(va4, va4, 1, 2, 3, 0);
+ va5 = __builtin_shufflevector(va5, va5, 1, 2, 3, 0);
+
+ const psimd_f32 vb0123c2 = psimd_load_f32(w + 16);
+ const psimd_f32 vb4567c2 = psimd_load_f32(w + 20);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c2);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c2);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c2);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c2);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123c2);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123c2);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c2);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c2);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c2);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c2);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567c2);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567c2);
+
+ va0 = __builtin_shufflevector(va0, va0, 1, 2, 3, 0);
+ va1 = __builtin_shufflevector(va1, va1, 1, 2, 3, 0);
+ va2 = __builtin_shufflevector(va2, va2, 1, 2, 3, 0);
+ va3 = __builtin_shufflevector(va3, va3, 1, 2, 3, 0);
+ va4 = __builtin_shufflevector(va4, va4, 1, 2, 3, 0);
+ va5 = __builtin_shufflevector(va5, va5, 1, 2, 3, 0);
+
+ const psimd_f32 vb0123c3 = psimd_load_f32(w + 24);
+ const psimd_f32 vb4567c3 = psimd_load_f32(w + 28);
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123c3);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123c3);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123c3);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123c3);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123c3);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123c3);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567c3);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567c3);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567c3);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567c3);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567c3);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567c3);
+
+
+ w += 32;
+ k -= 4 * sizeof(float);
+ }
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ const psimd_f32 vb0123 = psimd_load_f32(w);
+ const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+ w += 8;
+
+ const psimd_f32 va0 = psimd_load_splat_f32(a0);
+ a0 += 1;
+ const psimd_f32 va1 = psimd_load_splat_f32(a1);
+ a1 += 1;
+ const psimd_f32 va2 = psimd_load_splat_f32(a2);
+ a2 += 1;
+ const psimd_f32 va3 = psimd_load_splat_f32(a3);
+ a3 += 1;
+ const psimd_f32 va4 = psimd_load_splat_f32(a4);
+ a4 += 1;
+ const psimd_f32 va5 = psimd_load_splat_f32(a5);
+ a5 += 1;
+
+ vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0, vb0123);
+ vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0, vb4567);
+ vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1, vb0123);
+ vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1, vb4567);
+ vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2, vb0123);
+ vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2, vb4567);
+ vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3, vb0123);
+ vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3, vb4567);
+ vacc4x0123 = psimd_qfma_f32(vacc4x0123, va4, vb0123);
+ vacc4x4567 = psimd_qfma_f32(vacc4x4567, va4, vb4567);
+ vacc5x0123 = psimd_qfma_f32(vacc5x0123, va5, vb0123);
+ vacc5x4567 = psimd_qfma_f32(vacc5x4567, va5, vb4567);
+ k -= sizeof(float);
+ } while (k != 0);
+ }
+ p -= 6 * sizeof(void*);
+ } while (p != 0);
+
+ const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+ vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+ vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+ vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+ vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+ vacc4x0123 = psimd_min_f32(vacc4x0123, vmax);
+ vacc5x0123 = psimd_min_f32(vacc5x0123, vmax);
+ vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+ vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+ vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+ vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+ vacc4x4567 = psimd_min_f32(vacc4x4567, vmax);
+ vacc5x4567 = psimd_min_f32(vacc5x4567, vmax);
+
+ const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+ vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+ vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+ vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+ vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+ vacc4x0123 = psimd_max_f32(vacc4x0123, vmin);
+ vacc5x0123 = psimd_max_f32(vacc5x0123, vmin);
+ vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+ vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+ vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+ vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+ vacc4x4567 = psimd_max_f32(vacc4x4567, vmin);
+ vacc5x4567 = psimd_max_f32(vacc5x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {
+ psimd_store_f32(c5, vacc5x0123);
+ psimd_store_f32(c5 + 4, vacc5x4567);
+ c5 = (float*) ((uintptr_t) c5 + cn_stride);
+ psimd_store_f32(c4, vacc4x0123);
+ psimd_store_f32(c4 + 4, vacc4x4567);
+ c4 = (float*) ((uintptr_t) c4 + cn_stride);
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ psimd_store_f32(c0, vacc0x0123);
+ psimd_store_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ psimd_store_f32(c5, vacc5x0123);
+ psimd_store_f32(c4, vacc4x0123);
+ psimd_store_f32(c3, vacc3x0123);
+ psimd_store_f32(c2, vacc2x0123);
+ psimd_store_f32(c1, vacc1x0123);
+ psimd_store_f32(c0, vacc0x0123);
+
+ vacc5x0123 = vacc5x4567;
+ vacc4x0123 = vacc4x4567;
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+
+ c5 += 4;
+ c4 += 4;
+ c3 += 4;
+ c2 += 4;
+ c1 += 4;
+ c0 += 4;
+ }
+ if (nc & 2) {
+ psimd_store2_f32(c5, vacc5x0123);
+ psimd_store2_f32(c4, vacc4x0123);
+ psimd_store2_f32(c3, vacc3x0123);
+ psimd_store2_f32(c2, vacc2x0123);
+ psimd_store2_f32(c1, vacc1x0123);
+ psimd_store2_f32(c0, vacc0x0123);
+
+ vacc5x0123 = psimd_concat_hi_f32(vacc5x0123, vacc5x0123);
+ vacc4x0123 = psimd_concat_hi_f32(vacc4x0123, vacc4x0123);
+ vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+ vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+ vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+ vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+ c5 += 2;
+ c4 += 2;
+ c3 += 2;
+ c2 += 2;
+ c1 += 2;
+ c0 += 2;
+ }
+ if (nc & 1) {
+ psimd_store1_f32(c5, vacc5x0123);
+ psimd_store1_f32(c4, vacc4x0123);
+ psimd_store1_f32(c3, vacc3x0123);
+ psimd_store1_f32(c2, vacc2x0123);
+ psimd_store1_f32(c1, vacc1x0123);
+ psimd_store1_f32(c0, vacc0x0123);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/MRx2-neon-ld64.c.in b/src/f32-igemm/MRx2-neon-ld64.c.in
new file mode 100644
index 0000000..e36ed22
--- /dev/null
+++ b/src/f32-igemm/MRx2-neon-ld64.c.in
@@ -0,0 +1,129 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR == 2
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= ${MR});
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (${MR} * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ $for M in range(1, MR):
+ float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(mr <= ${M}) {
+ c${M} = c${M-1};
+ }
+ $elif M + 1 == MR:
+ if XNN_UNPREDICTABLE(mr != ${M+1}) {
+ c${M} = c${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(mr < ${M+1}) {
+ c${M} = c${M-1};
+ }
+
+ do {
+ float32x2_t vacc0x01 = vld1_f32(w); w += 2;
+ $for M in range(1, MR):
+ float32x2_t vacc${M}x01 = vacc0x01;
+
+ size_t p = ks;
+ do {
+ $for M in range(MR):
+ const float* restrict a${M} = a[${M}];
+ if XNN_UNPREDICTABLE(a${M} != zero) {
+ a${M} = (const float*) ((uintptr_t) a${M} + a_offset);
+ }
+ a += ${MR};
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+ $for M in range(MR):
+ const float32x2_t va${M} = vld1_f32(a${M}); a${M} += 2;
+
+ $for L in range(2):
+ const float32x2_t vb01c${L} = vld1_f32(w); w += 2;
+
+ $if FMA:
+ #if defined(__aarch64__)
+ $for M in range(MR):
+ vacc${M}x01 = vfma_lane_f32(vacc${M}x01, vb01c${L}, va${M}, ${L});
+ #else
+ $for M in range(MR):
+ const float32x2_t va${M}c${L} = vdup_lane_f32(va${M}, ${L});
+ $for M in range(MR):
+ vacc${M}x01 = vfma_f32(vacc${M}x01, va${M}c${L}, vb01c${L});
+ #endif
+ $else:
+ $for M in range(MR):
+ vacc${M}x01 = vmla_lane_f32(vacc${M}x01, vb01c${L}, va${M}, ${L});
+ }
+ if XNN_UNLIKELY(k != 0) {
+ $for M in range(MR):
+ const float32x2_t va${M} = vld1_dup_f32(a${M});
+
+ const float32x2_t vb01 = vld1_f32(w); w += 2;
+
+ $for M in range(MR):
+ $if FMA:
+ vacc${M}x01 = vfma_f32(vacc${M}x01, va${M}, vb01);
+ $else:
+ vacc${M}x01 = vmla_f32(vacc${M}x01, va${M}, vb01);
+ }
+ p -= ${MR} * sizeof(void*);
+ } while (p != 0);
+
+ const float32x2_t vmax = vld1_dup_f32(&params->scalar.max);
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x01 = vmin_f32(vacc${M}x01, vmax);
+
+ const float32x2_t vmin = vld1_dup_f32(&params->scalar.min);
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x01 = vmax_f32(vacc${M}x01, vmin);
+
+ if XNN_LIKELY(nc >= ${NR}) {
+ $for M in reversed(range(MR)):
+ vst1_f32(c${M}, vacc${M}x01);
+ c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= ${NR};
+ } else {
+ assert(nc == 1);
+ $for M in reversed(range(MR)):
+ vst1_lane_f32(c${M}, vacc${M}x01, 0);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/MRx2c4-psimd.c.in b/src/f32-igemm/MRx2c4-psimd.c.in
new file mode 100644
index 0000000..3a037ac
--- /dev/null
+++ b/src/f32-igemm/MRx2c4-psimd.c.in
@@ -0,0 +1,139 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR == 2
+$assert MR % 2 == 0
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_${MR}x${NR}c4__psimd(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= ${MR});
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (${MR} * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ $for M in range(1, MR):
+ float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(mr <= ${M}) {
+ c${M} = c${M-1};
+ }
+ $elif M + 1 == MR:
+ if XNN_UNPREDICTABLE(mr != ${M+1}) {
+ c${M} = c${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(mr < ${M+1}) {
+ c${M} = c${M-1};
+ }
+
+ do {
+ psimd_f32 vacc0x0c4 = psimd_load1_f32(w);
+ $for N in range(1, NR):
+ psimd_f32 vacc0x${N}c4 = psimd_load1_f32(w + ${N});
+ $for M in range(1, MR):
+ $for N in range(NR):
+ psimd_f32 vacc${M}x${N}c4 = vacc0x${N}c4;
+ w += ${NR};
+
+ size_t p = ks;
+ do {
+ $for M in range(MR):
+ const float* restrict a${M} = a[${M}];
+ if XNN_UNPREDICTABLE(a${M} != zero) {
+ a${M} = (const float*) ((uintptr_t) a${M} + a_offset);
+ }
+ a += ${MR};
+
+ size_t k = kc;
+ for (; k >= 4 * sizeof(float); k -= 4 * sizeof(float)) {
+ $for M in range(MR):
+ const psimd_f32 va${M} = psimd_load_f32(a${M});
+ a${M} += 4;
+
+ const psimd_f32 vb0 = psimd_load_f32(w);
+ $for N in range(1, NR):
+ const psimd_f32 vb${N} = psimd_load_f32(w + ${N * 4});
+ w += ${NR * 4};
+
+ $for M in range(MR):
+ $for N in range(NR):
+ vacc${M}x${N}c4 = psimd_qfma_f32(vacc${M}x${N}c4, va${M}, vb${N});
+ }
+ if XNN_UNLIKELY(k != 0) {
+ $for M in range(MR):
+ const psimd_f32 va${M} = psimd_load_f32(a${M});
+
+ const psimd_f32 vb0 = psimd_load_f32(w);
+ $for N in range(1, NR):
+ const psimd_f32 vb${N} = psimd_load_f32(w + ${N * 4});
+ w += ${NR * 4};
+
+ const psimd_f32 vzero = psimd_splat_f32(0.0f);
+ $for N in range(NR):
+ const psimd_s32 vmask${N} = vb${N} != vzero;
+
+ $for M in range(MR):
+ $for N in range(NR):
+ vacc${M}x${N}c4 = psimd_qfma_f32(vacc${M}x${N}c4, psimd_andmask_f32(vmask${N}, va${M}), vb${N});
+ }
+ p -= ${MR} * sizeof(void*);
+ } while (p != 0);
+
+ $for M in range(MR):
+ const psimd_f32 vacc${M}x01c2 = psimd_add_f32(psimd_interleave_lo_f32(vacc${M}x0c4, vacc${M}x1c4), psimd_interleave_hi_f32(vacc${M}x0c4, vacc${M}x1c4));
+
+ $for M in range(0, MR, 2):
+ psimd_f32 vacc${M}${M+1}x01 = psimd_add_f32(psimd_concat_lo_f32(vacc${M}x01c2, vacc${M+1}x01c2), psimd_concat_hi_f32(vacc${M}x01c2, vacc${M+1}x01c2));
+
+ const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+ $for M in range(0, MR, 2):
+ vacc${M}${M+1}x01 = psimd_min_f32(vacc${M}${M+1}x01, vmax);
+
+ const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+ $for M in range(0, MR, 2):
+ vacc${M}${M+1}x01 = psimd_max_f32(vacc${M}${M+1}x01, vmin);
+
+ if XNN_LIKELY(nc >= ${NR}) {
+ $for M in reversed(range(0, MR, 2)):
+ psimd_store2_f32(c${M+1}, psimd_concat_hi_f32(vacc${M}${M+1}x01, vacc${M}${M+1}x01));
+ c${M+1} = (float*) ((uintptr_t) c${M+1} + cn_stride);
+ psimd_store2_f32(c${M}, vacc${M}${M+1}x01);
+ c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= ${NR};
+ } else {
+ assert(nc == 1);
+ $for M in reversed(range(0, MR, 2)):
+ psimd_store1_f32(c${M+1}, psimd_concat_hi_f32(vacc${M}${M+1}x01, vacc${M}${M+1}x01));
+ psimd_store1_f32(c${M}, vacc${M}${M+1}x01);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/MRx2c4-sse.c.in b/src/f32-igemm/MRx2c4-sse.c.in
new file mode 100644
index 0000000..9851646
--- /dev/null
+++ b/src/f32-igemm/MRx2c4-sse.c.in
@@ -0,0 +1,138 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR == 2
+$assert MR % 2 == 0
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_${MR}x${NR}c4__sse(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= ${MR});
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (${MR} * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ $for M in range(1, MR):
+ float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(mr <= ${M}) {
+ c${M} = c${M-1};
+ }
+ $elif M + 1 == MR:
+ if XNN_UNPREDICTABLE(mr != ${M+1}) {
+ c${M} = c${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(mr < ${M+1}) {
+ c${M} = c${M-1};
+ }
+
+ do {
+ __m128 vacc0x0c4 = _mm_load_ss(w);
+ $for N in range(1, NR):
+ __m128 vacc0x${N}c4 = _mm_load_ss(w + ${N});
+ $for M in range(1, MR):
+ $for N in range(NR):
+ __m128 vacc${M}x${N}c4 = vacc0x${N}c4;
+ w += ${NR};
+
+ size_t p = ks;
+ do {
+ $for M in range(MR):
+ const float* restrict a${M} = a[${M}];
+ if XNN_UNPREDICTABLE(a${M} != zero) {
+ a${M} = (const float*) ((uintptr_t) a${M} + a_offset);
+ }
+ a += ${MR};
+
+ size_t k = kc;
+ for (; k >= 4 * sizeof(float); k -= 4 * sizeof(float)) {
+ $for M in range(MR):
+ const __m128 va${M} = _mm_loadu_ps(a${M});
+ a${M} += 4;
+
+ const __m128 vb0 = _mm_loadu_ps(w);
+ $for N in range(1, NR):
+ const __m128 vb${N} = _mm_loadu_ps(w + ${N * 4});
+ w += ${NR * 4};
+
+ $for M in range(MR):
+ $for N in range(NR):
+ vacc${M}x${N}c4 = _mm_add_ps(vacc${M}x${N}c4, _mm_mul_ps(va${M}, vb${N}));
+ }
+ if XNN_UNLIKELY(k != 0) {
+ $for M in range(MR):
+ const __m128 va${M} = _mm_loadu_ps(a${M});
+
+ const __m128 vb0 = _mm_loadu_ps(w);
+ $for N in range(1, NR):
+ const __m128 vb${N} = _mm_loadu_ps(w + ${N * 4});
+ w += ${NR * 4};
+
+ $for N in range(NR):
+ const __m128 vmask${N} = _mm_cmpeq_ps(_mm_setzero_ps(), vb${N});
+
+ $for M in range(MR):
+ $for N in range(NR):
+ vacc${M}x${N}c4 = _mm_add_ps(vacc${M}x${N}c4, _mm_mul_ps(_mm_andnot_ps(vmask${N}, va${M}), vb${N}));
+ }
+ p -= ${MR} * sizeof(void*);
+ } while (p != 0);
+
+ $for M in range(MR):
+ const __m128 vacc${M}x01c2 = _mm_add_ps(_mm_unpacklo_ps(vacc${M}x0c4, vacc${M}x1c4), _mm_unpackhi_ps(vacc${M}x0c4, vacc${M}x1c4));
+
+ $for M in range(0, MR, 2):
+ __m128 vacc${M}${M+1}x01 = _mm_add_ps(_mm_movelh_ps(vacc${M}x01c2, vacc${M+1}x01c2), _mm_movehl_ps(vacc${M+1}x01c2, vacc${M}x01c2));
+
+ const __m128 vmax = _mm_load_ps(params->sse.max);
+ $for M in range(0, MR, 2):
+ vacc${M}${M+1}x01 = _mm_min_ps(vacc${M}${M+1}x01, vmax);
+
+ const __m128 vmin = _mm_load_ps(params->sse.min);
+ $for M in range(0, MR, 2):
+ vacc${M}${M+1}x01 = _mm_max_ps(vacc${M}${M+1}x01, vmin);
+
+ if XNN_LIKELY(nc >= ${NR}) {
+ $for M in reversed(range(0, MR, 2)):
+ _mm_storeh_pi((__m64*) c${M+1}, vacc${M}${M+1}x01);
+ c${M+1} = (float*) ((uintptr_t) c${M+1} + cn_stride);
+ _mm_storel_pi((__m64*) c${M}, vacc${M}${M+1}x01);
+ c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= ${NR};
+ } else {
+ assert(nc == 1);
+ $for M in reversed(range(0, MR, 2)):
+ _mm_store_ss(c${M+1}, _mm_movehl_ps(vacc${M}${M+1}x01, vacc${M}${M+1}x01));
+ _mm_store_ss(c${M}, vacc${M}${M+1}x01);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/neon-ld128.c.in b/src/f32-igemm/neon-ld128.c.in
new file mode 100644
index 0000000..9e67349
--- /dev/null
+++ b/src/f32-igemm/neon-ld128.c.in
@@ -0,0 +1,168 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$IDLETTERS = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}_ld128(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= ${MR});
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (${MR} * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ $for M in range(1, MR):
+ float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(mr <= ${M}) {
+ c${M} = c${M-1};
+ }
+ $elif M + 1 == MR:
+ if XNN_UNPREDICTABLE(mr != ${M+1}) {
+ c${M} = c${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(mr < ${M+1}) {
+ c${M} = c${M-1};
+ }
+
+ do {
+ $for N in range(0, NR, 4):
+ float32x4_t vacc0x${IDLETTERS[N:N+4]} = vld1q_f32(w); w += 4;
+ $for M in range(1, MR):
+ $for N in range(0, NR, 4):
+ float32x4_t vacc${M}x${IDLETTERS[N:N+4]} = vacc0x${IDLETTERS[N:N+4]};
+
+ size_t p = ks;
+ do {
+ $for M in range(MR):
+ const float* restrict a${M} = a[${M}];
+ if XNN_UNPREDICTABLE(a${M} != zero) {
+ a${M} = (const float*) ((uintptr_t) a${M} + a_offset);
+ }
+ a += ${MR};
+
+ size_t k = kc;
+ for (; k >= 4 * sizeof(float); k -= 4 * sizeof(float)) {
+ $for M in range(MR):
+ const float32x4_t va${M} = vld1q_f32(a${M}); a${M} += 4;
+
+ $for L in range(4):
+ $VGET_PART_F32 = "vget_low_f32" if L < 2 else "vget_high_f32"
+
+ $for N in range(0, NR, 4):
+ const float32x4_t vb${IDLETTERS[N:N+4]}c${L} = vld1q_f32(w); w += 4;
+
+ $if FMA:
+ #if defined(__aarch64__)
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = vfmaq_laneq_f32(vacc${M}x${IDLETTERS[N:N+4]}, vb${IDLETTERS[N:N+4]}c${L}, va${M}, ${L});
+ #else
+ $for M in range(MR):
+ const float32x4_t va${M}c${L} = vdupq_lane_f32(${VGET_PART_F32}(va${M}), ${L % 2});
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = vfmaq_f32(vacc${M}x${IDLETTERS[N:N+4]}, va${M}c${L}, vb${IDLETTERS[N:N+4]}c${L});
+ #endif
+ $else:
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = vmlaq_lane_f32(vacc${M}x${IDLETTERS[N:N+4]}, vb${IDLETTERS[N:N+4]}c${L}, ${VGET_PART_F32}(va${M}), ${L % 2});
+ }
+ if XNN_UNLIKELY(k != 0) {
+ do {
+ $for M in range(MR):
+ const float32x4_t va${M} = vld1q_dup_f32(a${M}); a${M} += 1;
+
+ $for N in range(0, NR, 4):
+ const float32x4_t vb${IDLETTERS[N:N+4]} = vld1q_f32(w); w += 4;
+
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ $if FMA:
+ vacc${M}x${IDLETTERS[N:N+4]} = vfmaq_f32(vacc${M}x${IDLETTERS[N:N+4]}, va${M}, vb${IDLETTERS[N:N+4]});
+ $else:
+ vacc${M}x${IDLETTERS[N:N+4]} = vmlaq_f32(vacc${M}x${IDLETTERS[N:N+4]}, va${M}, vb${IDLETTERS[N:N+4]});
+
+ k -= sizeof(float);
+ } while (k != 0);
+ }
+
+ p -= ${MR} * sizeof(void*);
+ } while (p != 0);
+
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = vminq_f32(vacc${M}x${IDLETTERS[N:N+4]}, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = vmaxq_f32(vacc${M}x${IDLETTERS[N:N+4]}, vmin);
+
+ if XNN_LIKELY(nc >= ${NR}) {
+ $for M in reversed(range(MR)):
+ vst1q_f32(c${M}, vacc${M}x${IDLETTERS[0:4]});
+ $for N in range(4, NR, 4):
+ vst1q_f32(c${M} + ${N}, vacc${M}x${IDLETTERS[N:N+4]});
+ c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= ${NR};
+ } else {
+ $for LOG2N in reversed(range(NR.bit_length())):
+ $if LOG2N == 1:
+ $for M in reversed(range(MR)):
+ float32x2_t vacc${M}x${IDLETTERS[0:2]} = vget_low_f32(vacc${M}x${IDLETTERS[0:4]});
+ $if 1 << LOG2N != NR:
+ if (nc & ${1 << LOG2N}) {
+ $if LOG2N >= 2:
+ $for N in range(0, 1 << LOG2N, 4):
+ $for M in reversed(range(MR)):
+ vst1q_f32(c${M}, vacc${M}x${IDLETTERS[N:N+4]}); c${M} += 4;
+
+ $for M in reversed(range(MR)):
+ $for N in range(0, 1 << (LOG2N - 1), 4):
+ vacc${M}x${IDLETTERS[N:N+4]} = vacc${M}x${IDLETTERS[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+ $elif LOG2N == 1:
+ $for M in reversed(range(MR)):
+ vst1_f32(c${M}, vacc${M}x${IDLETTERS[0:2]}); c${M} += 2;
+
+ $for M in reversed(range(MR)):
+ vacc${M}x${IDLETTERS[0:2]} = vget_high_f32(vacc${M}x${IDLETTERS[0:4]});
+ $elif LOG2N == 0:
+ $for M in reversed(range(MR)):
+ vst1_lane_f32(c${M}, vacc${M}x${IDLETTERS[0:2]}, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/neon-ld64.c.in b/src/f32-igemm/neon-ld64.c.in
new file mode 100644
index 0000000..5090686
--- /dev/null
+++ b/src/f32-igemm/neon-ld64.c.in
@@ -0,0 +1,161 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$IDLETTERS = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}_ld64(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const float**restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= ${MR});
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+ assert(ks != 0);
+ assert(ks % (${MR} * sizeof(void*)) == 0);
+ assert(a_offset % sizeof(float) == 0);
+
+ float* c0 = c;
+ $for M in range(1, MR):
+ float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(mr <= ${M}) {
+ c${M} = c${M-1};
+ }
+ $elif M + 1 == MR:
+ if XNN_UNPREDICTABLE(mr != ${M+1}) {
+ c${M} = c${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(mr < ${M+1}) {
+ c${M} = c${M-1};
+ }
+
+ do {
+ $for N in range(0, NR, 4):
+ float32x4_t vacc0x${IDLETTERS[N:N+4]} = vld1q_f32(w); w += 4;
+ $for M in range(1, MR):
+ $for N in range(0, NR, 4):
+ float32x4_t vacc${M}x${IDLETTERS[N:N+4]} = vacc0x${IDLETTERS[N:N+4]};
+
+ size_t p = ks;
+ do {
+ $for M in range(MR):
+ const float* restrict a${M} = a[${M}];
+ if XNN_UNPREDICTABLE(a${M} != zero) {
+ a${M} = (const float*) ((uintptr_t) a${M} + a_offset);
+ }
+ a += ${MR};
+
+ size_t k = kc;
+ for (; k >= 2 * sizeof(float); k -= 2 * sizeof(float)) {
+ $for M in range(MR):
+ const float32x2_t va${M} = vld1_f32(a${M}); a${M} += 2;
+
+ $for L in range(2):
+ $for N in range(0, NR, 4):
+ const float32x4_t vb${IDLETTERS[N:N+4]}c${L} = vld1q_f32(w); w += 4;
+
+ $if FMA:
+ #if defined(__aarch64__)
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = vfmaq_lane_f32(vacc${M}x${IDLETTERS[N:N+4]}, vb${IDLETTERS[N:N+4]}c${L}, va${M}, ${L});
+ #else
+ $for M in range(MR):
+ const float32x4_t va${M}c${L} = vdupq_lane_f32(va${M}, ${L});
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = vfmaq_f32(vacc${M}x${IDLETTERS[N:N+4]}, va${M}c${L}, vb${IDLETTERS[N:N+4]}c${L});
+ #endif
+ $else:
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = vmlaq_lane_f32(vacc${M}x${IDLETTERS[N:N+4]}, vb${IDLETTERS[N:N+4]}c${L}, va${M}, ${L});
+ }
+ if XNN_UNLIKELY(k != 0) {
+ $for M in range(MR):
+ const float32x4_t va${M} = vld1q_dup_f32(a${M});
+
+ $for N in range(0, NR, 4):
+ const float32x4_t vb${IDLETTERS[N:N+4]} = vld1q_f32(w); w += 4;
+
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ $if FMA:
+ vacc${M}x${IDLETTERS[N:N+4]} = vfmaq_f32(vacc${M}x${IDLETTERS[N:N+4]}, va${M}, vb${IDLETTERS[N:N+4]});
+ $else:
+ vacc${M}x${IDLETTERS[N:N+4]} = vmlaq_f32(vacc${M}x${IDLETTERS[N:N+4]}, va${M}, vb${IDLETTERS[N:N+4]});
+ }
+ p -= ${MR} * sizeof(void*);
+ } while (p != 0);
+
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = vminq_f32(vacc${M}x${IDLETTERS[N:N+4]}, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = vmaxq_f32(vacc${M}x${IDLETTERS[N:N+4]}, vmin);
+
+ if XNN_LIKELY(nc >= ${NR}) {
+ $for M in reversed(range(MR)):
+ vst1q_f32(c${M}, vacc${M}x${IDLETTERS[0:4]});
+ $for N in range(4, NR, 4):
+ vst1q_f32(c${M} + ${N}, vacc${M}x${IDLETTERS[N:N+4]});
+ c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+ a = (const float**restrict) ((uintptr_t) a - ks);
+ nc -= ${NR};
+ } else {
+ $for LOG2N in reversed(range(NR.bit_length())):
+ $if LOG2N == 1:
+ $for M in reversed(range(MR)):
+ float32x2_t vacc${M}x${IDLETTERS[0:2]} = vget_low_f32(vacc${M}x${IDLETTERS[0:4]});
+ $if 1 << LOG2N != NR:
+ if (nc & ${1 << LOG2N}) {
+ $if LOG2N >= 2:
+ $for N in range(0, 1 << LOG2N, 4):
+ $for M in reversed(range(MR)):
+ vst1q_f32(c${M}, vacc${M}x${IDLETTERS[N:N+4]}); c${M} += 4;
+
+ $for M in reversed(range(MR)):
+ $for N in range(0, 1 << (LOG2N - 1), 4):
+ vacc${M}x${IDLETTERS[N:N+4]} = vacc${M}x${IDLETTERS[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+ $elif LOG2N == 1:
+ $for M in reversed(range(MR)):
+ vst1_f32(c${M}, vacc${M}x${IDLETTERS[0:2]}); c${M} += 2;
+
+ $for M in reversed(range(MR)):
+ vacc${M}x${IDLETTERS[0:2]} = vget_high_f32(vacc${M}x${IDLETTERS[0:4]});
+ $elif LOG2N == 0:
+ $for M in reversed(range(MR)):
+ vst1_lane_f32(c${M}, vacc${M}x${IDLETTERS[0:2]}, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-igemm/psimd-loadsplat.c.in b/src/f32-igemm/psimd-loadsplat.c.in
new file mode 100644
index 0000000..159fe86
--- /dev/null
+++ b/src/f32-igemm/psimd-loadsplat.c.in
@@ -0,0 +1,143 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_${MR}x${NR}__psimd_loadsplat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{  // IGEMM (indirect GEMM) template: up to MR rows x NR cols; A rows come from the indirection buffer a.
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (${MR} * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;  // Per-row output pointers; rows beyond mr alias the row above, so their stores are harmless duplicates.
+  $for M in range(1, MR):
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        c${M} = c${M-1};
+      }
+
+  do {  // Process the output NR columns at a time.
+    psimd_f32 vacc0x${ABC[0:4]} = psimd_load_f32(w);  // Accumulators start from the leading packed weights (bias, per packing layout).
+    $for N in range(4, NR, 4):
+      psimd_f32 vacc0x${ABC[N:N+4]} = psimd_load_f32(w + ${N});
+    $for M in range(1, MR):
+      $for N in range(0, NR, 4):
+        psimd_f32 vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
+    w += ${NR};
+
+    size_t p = ks;  // ks bytes of indirection pointers; MR pointers are consumed per pass.
+    do {
+      $for M in range(MR):
+        const float* restrict a${M} = a[${M}];
+        if XNN_UNPREDICTABLE(a${M} != zero) {  // the shared "zero" row is used as-is, without a_offset
+          a${M} = (const float*) ((uintptr_t) a${M} + a_offset);
+        }
+      a += ${MR};
+
+      size_t k = kc;  // accumulate along K, one float of A per step
+      do {
+        const psimd_f32 vb${ABC[0:4]} = psimd_load_f32(w);
+        $for N in range(4, NR, 4):
+          const psimd_f32 vb${ABC[N:N+4]} = psimd_load_f32(w + ${N});
+        w += ${NR};
+
+        $for M in range(MR):
+          const psimd_f32 va${M} = psimd_load_splat_f32(a${M});
+          a${M} += 1;
+
+        $for M in range(MR):
+          $for N in range(0, NR, 4):
+            vacc${M}x${ABC[N:N+4]} = psimd_qfma_f32(vacc${M}x${ABC[N:N+4]}, va${M}, vb${ABC[N:N+4]});
+        k -= sizeof(float);
+      } while (k != 0);
+      p -= ${MR} * sizeof(void*);
+    } while (p != 0);
+
+    const psimd_f32 vmax = psimd_load_splat_f32(¶ms->scalar.max);  // Clamp results to the operator's output range.
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = psimd_min_f32(vacc${M}x${ABC[N:N+4]}, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(¶ms->scalar.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = psimd_max_f32(vacc${M}x${ABC[N:N+4]}, vmin);
+
+    if XNN_LIKELY(nc >= ${NR}) {  // full tile: store all NR columns of every row
+      $for M in reversed(range(MR)):
+        psimd_store_f32(c${M}, vacc${M}x${ABC[0:4]});
+        $for N in range(4, NR, 4):
+          psimd_store_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);  // rewind the indirection buffer for the next column tile
+      nc -= ${NR};
+    } else {  // partial tile: store the remaining nc (< NR) columns in power-of-two chunks
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for M in reversed(range(MR)):
+                psimd_store_f32(c${M}, vacc${M}x${ABC[0:4]});
+                $for N in range(4, 1 << LOG2N, 4):
+                  psimd_store_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+              $for M in reversed(range(MR)):
+                c${M} += ${1 << LOG2N};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                psimd_store2_f32(c${M}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${ABC[0:4]} = psimd_concat_hi_f32(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                c${M} += 2;
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                psimd_store1_f32(c${M}, vacc${M}x${ABC[0:4]});
+          }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/psimd-s4.c.in b/src/f32-igemm/psimd-s4.c.in
new file mode 100644
index 0000000..f595803
--- /dev/null
+++ b/src/f32-igemm/psimd-s4.c.in
@@ -0,0 +1,166 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_${MR}x${NR}s4__psimd(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{  // IGEMM template, "s4" variant: weights are packed with a shift-by-4 layout; A lanes are rotated to match.
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (${MR} * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;  // Per-row output pointers; rows beyond mr alias the row above, so their stores are harmless duplicates.
+  $for M in range(1, MR):
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        c${M} = c${M-1};
+      }
+
+  do {  // Process the output NR columns at a time.
+    psimd_f32 vacc0x${ABC[0:4]} = psimd_load_f32(w);  // Accumulators start from the leading packed weights (bias, per packing layout).
+    $for N in range(4, NR, 4):
+      psimd_f32 vacc0x${ABC[N:N+4]} = psimd_load_f32(w + ${N});
+    $for M in range(1, MR):
+      $for N in range(0, NR, 4):
+        psimd_f32 vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
+    w += ${NR};
+
+    size_t p = ks;  // ks bytes of indirection pointers; MR pointers are consumed per pass.
+    do {
+      $for M in range(MR):
+        const float* restrict a${M} = a[${M}];
+        if XNN_UNPREDICTABLE(a${M} != zero) {  // the shared "zero" row is used as-is, without a_offset
+          a${M} = (const float*) ((uintptr_t) a${M} + a_offset);
+        }
+      a += ${MR};
+
+      size_t k = kc;  // main loop: 4 K-steps per iteration, using vector loads of A
+      while (k >= 4 * sizeof(float)) {
+        $for M in range(MR):
+          psimd_f32 va${M} = psimd_load_f32(a${M});
+          a${M} += 4;
+
+        $for L in range(4):
+
+          $for N in range(0, NR, 4):
+            const psimd_f32 vb${ABC[N:N+4]}c${L} = psimd_load_f32(w + ${L * NR + N});
+
+          $for N in range(0, NR, 4):
+            $for M in range(MR):
+              vacc${M}x${ABC[N:N+4]} = psimd_qfma_f32(vacc${M}x${ABC[N:N+4]}, va${M}, vb${ABC[N:N+4]}c${L});
+
+          $if L + 1 != 4:
+            $for M in range(MR):
+              va${M} = __builtin_shufflevector(va${M}, va${M}, 1, 2, 3, 0);  // rotate A lanes left by one to line up with the next shifted weight group
+
+        w += ${4 * NR};
+        k -= 4 * sizeof(float);
+      }
+      if XNN_UNLIKELY(k != 0) {  // remainder: fall back to one broadcast A element per step
+        do {
+          const psimd_f32 vb${ABC[0:4]} = psimd_load_f32(w);
+          $for N in range(4, NR, 4):
+            const psimd_f32 vb${ABC[N:N+4]} = psimd_load_f32(w + ${N});
+          w += ${NR};
+
+          $for M in range(MR):
+            const psimd_f32 va${M} = psimd_load_splat_f32(a${M});
+            a${M} += 1;
+
+          $for M in range(MR):
+            $for N in range(0, NR, 4):
+              vacc${M}x${ABC[N:N+4]} = psimd_qfma_f32(vacc${M}x${ABC[N:N+4]}, va${M}, vb${ABC[N:N+4]});
+          k -= sizeof(float);
+        } while (k != 0);
+      }
+      p -= ${MR} * sizeof(void*);
+    } while (p != 0);
+
+    const psimd_f32 vmax = psimd_load_splat_f32(¶ms->scalar.max);  // Clamp results to the operator's output range.
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = psimd_min_f32(vacc${M}x${ABC[N:N+4]}, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(¶ms->scalar.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = psimd_max_f32(vacc${M}x${ABC[N:N+4]}, vmin);
+
+    if XNN_LIKELY(nc >= ${NR}) {  // full tile: store all NR columns of every row
+      $for M in reversed(range(MR)):
+        psimd_store_f32(c${M}, vacc${M}x${ABC[0:4]});
+        $for N in range(4, NR, 4):
+          psimd_store_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);  // rewind the indirection buffer for the next column tile
+      nc -= ${NR};
+    } else {  // partial tile: store the remaining nc (< NR) columns in power-of-two chunks
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for M in reversed(range(MR)):
+                psimd_store_f32(c${M}, vacc${M}x${ABC[0:4]});
+                $for N in range(4, 1 << LOG2N, 4):
+                  psimd_store_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+              $for M in reversed(range(MR)):
+                c${M} += ${1 << LOG2N};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                psimd_store2_f32(c${M}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${ABC[0:4]} = psimd_concat_hi_f32(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                c${M} += 2;
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                psimd_store1_f32(c${M}, vacc${M}x${ABC[0:4]});
+          }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/psimd-splat.c.in b/src/f32-igemm/psimd-splat.c.in
new file mode 100644
index 0000000..69f782a
--- /dev/null
+++ b/src/f32-igemm/psimd-splat.c.in
@@ -0,0 +1,164 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_${MR}x${NR}__psimd_splat(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{  // IGEMM template, "splat" variant: vector-load 4 A elements, then broadcast each lane in turn.
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (${MR} * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;  // Per-row output pointers; rows beyond mr alias the row above, so their stores are harmless duplicates.
+  $for M in range(1, MR):
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        c${M} = c${M-1};
+      }
+
+  do {  // Process the output NR columns at a time.
+    psimd_f32 vacc0x${ABC[0:4]} = psimd_load_f32(w);  // Accumulators start from the leading packed weights (bias, per packing layout).
+    $for N in range(4, NR, 4):
+      psimd_f32 vacc0x${ABC[N:N+4]} = psimd_load_f32(w + ${N});
+    $for M in range(1, MR):
+      $for N in range(0, NR, 4):
+        psimd_f32 vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
+    w += ${NR};
+
+    size_t p = ks;  // ks bytes of indirection pointers; MR pointers are consumed per pass.
+    do {
+      $for M in range(MR):
+        const float* restrict a${M} = a[${M}];
+        if XNN_UNPREDICTABLE(a${M} != zero) {  // the shared "zero" row is used as-is, without a_offset
+          a${M} = (const float*) ((uintptr_t) a${M} + a_offset);
+        }
+      a += ${MR};
+
+      size_t k = kc;  // main loop: 4 K-steps per iteration, using vector loads of A
+      while (k >= 4 * sizeof(float)) {
+        $for M in range(MR):
+          const psimd_f32 va${M} = psimd_load_f32(a${M});
+          a${M} += 4;
+
+        $for L in range(4):
+          $for M in range(MR):
+            const psimd_f32 va${M}c${L} = psimd_splat${L}_f32(va${M});  // broadcast lane L of A across the vector
+
+          $for N in range(0, NR, 4):
+            const psimd_f32 vb${ABC[N:N+4]}c${L} = psimd_load_f32(w + ${L * NR + N});
+
+          $for N in range(0, NR, 4):
+            $for M in range(MR):
+              vacc${M}x${ABC[N:N+4]} = psimd_qfma_f32(vacc${M}x${ABC[N:N+4]}, va${M}c${L}, vb${ABC[N:N+4]}c${L});
+
+        w += ${4 * NR};
+        k -= 4 * sizeof(float);
+      }
+      if XNN_UNLIKELY(k != 0) {  // remainder: fall back to one broadcast A element per step
+        do {
+          const psimd_f32 vb${ABC[0:4]} = psimd_load_f32(w);
+          $for N in range(4, NR, 4):
+            const psimd_f32 vb${ABC[N:N+4]} = psimd_load_f32(w + ${N});
+          w += ${NR};
+
+          $for M in range(MR):
+            const psimd_f32 va${M} = psimd_load_splat_f32(a${M});
+            a${M} += 1;
+
+          $for M in range(MR):
+            $for N in range(0, NR, 4):
+              vacc${M}x${ABC[N:N+4]} = psimd_qfma_f32(vacc${M}x${ABC[N:N+4]}, va${M}, vb${ABC[N:N+4]});
+          k -= sizeof(float);
+        } while (k != 0);
+      }
+      p -= ${MR} * sizeof(void*);
+    } while (p != 0);
+
+    const psimd_f32 vmax = psimd_load_splat_f32(¶ms->scalar.max);  // Clamp results to the operator's output range.
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = psimd_min_f32(vacc${M}x${ABC[N:N+4]}, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(¶ms->scalar.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${ABC[N:N+4]} = psimd_max_f32(vacc${M}x${ABC[N:N+4]}, vmin);
+
+    if XNN_LIKELY(nc >= ${NR}) {  // full tile: store all NR columns of every row
+      $for M in reversed(range(MR)):
+        psimd_store_f32(c${M}, vacc${M}x${ABC[0:4]});
+        $for N in range(4, NR, 4):
+          psimd_store_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);  // rewind the indirection buffer for the next column tile
+      nc -= ${NR};
+    } else {  // partial tile: store the remaining nc (< NR) columns in power-of-two chunks
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for M in reversed(range(MR)):
+                psimd_store_f32(c${M}, vacc${M}x${ABC[0:4]});
+                $for N in range(4, 1 << LOG2N, 4):
+                  psimd_store_f32(c${M} + ${N}, vacc${M}x${ABC[N:N+4]});
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${ABC[N:N+4]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+              $for M in reversed(range(MR)):
+                c${M} += ${1 << LOG2N};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                psimd_store2_f32(c${M}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${ABC[0:4]} = psimd_concat_hi_f32(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]});
+
+              $for M in reversed(range(MR)):
+                c${M} += 2;
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                psimd_store1_f32(c${M}, vacc${M}x${ABC[0:4]});
+          }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/scalar.c.in b/src/f32-igemm/scalar.c.in
new file mode 100644
index 0000000..4954108
--- /dev/null
+++ b/src/f32-igemm/scalar.c.in
@@ -0,0 +1,119 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_igemm_ukernel_${MR}x${NR}__scalar(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{  // Portable scalar IGEMM template: up to MR rows x NR cols; A rows come from the indirection buffer a.
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (${MR} * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;  // Per-row output pointers; rows beyond mr alias the row above, so their stores are harmless duplicates.
+  $for M in range(1, MR):
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        c${M} = c${M-1};
+      }
+
+  do {  // Process the output NR columns at a time.
+    $for N in range(NR):
+      float vacc0${N} = w[${N}];
+    $for M in range(1, MR):
+      $for N in range(NR):
+        float vacc${M}${N} = vacc0${N};
+    w += ${NR};
+
+    size_t p = ks;  // ks bytes of indirection pointers; MR pointers are consumed per pass.
+    do {
+      $for M in range(MR):
+        const float* restrict a${M} = a[${M}];
+        if XNN_UNPREDICTABLE(a${M} != zero) {  // the shared "zero" row is used as-is, without a_offset
+          a${M} = (const float*) ((uintptr_t) a${M} + a_offset);
+        }
+      a += ${MR};
+
+      size_t k = kc;  // accumulate along K, one float of A per step
+      do {
+        $for M in range(MR):
+          const float va${M} = *a${M}++;
+
+        $for N in range(NR):
+          const float vb${N} = w[${N}];
+        w += ${NR};
+
+        $for M in range(MR):
+          $for N in range(NR):
+            vacc${M}${N} += va${M} * vb${N};
+
+        k -= sizeof(float);
+      } while (k != 0);
+      p -= ${MR} * sizeof(void*);
+    } while (p != 0);
+
+    const float vmin = params->scalar.min;  // Clamp results to the operator's output range.
+    $for M in range(MR):
+      $for N in range(NR):
+        vacc${M}${N} = math_max_f32(vacc${M}${N}, vmin);
+
+    const float vmax = params->scalar.max;
+    $for M in range(MR):
+      $for N in range(NR):
+        vacc${M}${N} = math_min_f32(vacc${M}${N}, vmax);
+
+    if XNN_LIKELY(nc >= ${NR}) {  // full tile: store all NR columns of every row
+      $for M in reversed(range(MR)):
+        $for N in range(NR):
+          c${M}[${N}] = vacc${M}${N};
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);  // rewind the indirection buffer for the next column tile
+      nc -= ${NR};
+    } else {  // partial tile: store the remaining nc (< NR) columns in power-of-two chunks
+      $for LOG2N in reversed(range(NR.bit_length() - 1)):
+        if (nc & ${1 << LOG2N}) {
+          $for M in reversed(range(MR)):
+            $for N in range(1 << LOG2N):
+              c${M}[${N}] = vacc${M}${N};
+            $if LOG2N != 0:
+              $for N in range(1 << (LOG2N - 1)):
+                vacc${M}${N} = vacc${M}${N + (1 << LOG2N)};
+              c${M} += ${1 << LOG2N};
+        }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/sse-dup.c.in b/src/f32-igemm/sse-dup.c.in
new file mode 100644
index 0000000..84fa4d3
--- /dev/null
+++ b/src/f32-igemm/sse-dup.c.in
@@ -0,0 +1,166 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$IDLETTERS = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_${MR}x${NR}__sse_dup(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{  // SSE IGEMM template, "dup" variant: vector-load 4 A elements, then duplicate each lane with a shuffle.
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (${MR} * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;  // Per-row output pointers; rows beyond mr alias the row above, so their stores are harmless duplicates.
+  $for M in range(1, MR):
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        c${M} = c${M-1};
+      }
+
+  do {  // Process the output NR columns at a time.
+    __m128 vacc0x${IDLETTERS[0:4]} = _mm_load_ps(w);  // Accumulators start from the leading packed weights (bias, per packing layout).
+    $for N in range(4, NR, 4):
+      __m128 vacc0x${IDLETTERS[N:N+4]} = _mm_load_ps(w + ${N});
+    $for M in range(1, MR):
+      $for N in range(0, NR, 4):
+        __m128 vacc${M}x${IDLETTERS[N:N+4]} = vacc0x${IDLETTERS[N:N+4]};
+    w += ${NR};
+
+    size_t p = ks;  // ks bytes of indirection pointers; MR pointers are consumed per pass.
+    do {
+      $for M in range(MR):
+        const float* restrict a${M} = a[${M}];
+        if XNN_UNPREDICTABLE(a${M} != zero) {  // the shared "zero" row is used as-is, without a_offset
+          a${M} = (const float*) ((uintptr_t) a${M} + a_offset);
+        }
+      a += ${MR};
+
+      size_t k = kc;  // main loop: 4 K-steps per iteration, using vector loads of A
+      while (k >= 4 * sizeof(float)) {
+        $for M in range(MR):
+          const __m128 va${M} = _mm_loadu_ps(a${M});
+          a${M} += 4;
+
+        $for L in range(4):
+          $LLLL = str(L) * 4
+
+          $for M in range(MR):
+            const __m128 va${M}c${LLLL} = _mm_shuffle_ps(va${M}, va${M}, _MM_SHUFFLE(${L}, ${L}, ${L}, ${L}));  // broadcast lane L of A across the vector
+
+          $for N in range(0, NR, 4):
+            const __m128 vb${IDLETTERS[N:N+4]}c${L} = _mm_load_ps(w + ${L * NR + N});
+
+          $for N in range(0, NR, 4):
+            $for M in range(MR):
+              vacc${M}x${IDLETTERS[N:N+4]} = _mm_add_ps(vacc${M}x${IDLETTERS[N:N+4]}, _mm_mul_ps(va${M}c${LLLL}, vb${IDLETTERS[N:N+4]}c${L}));
+
+        w += ${4 * NR};
+        k -= 4 * sizeof(float);
+      }
+      if XNN_UNLIKELY(k != 0) {  // remainder: fall back to one broadcast A element per step
+        do {
+          const __m128 vb${IDLETTERS[0:4]} = _mm_load_ps(w);
+          $for N in range(4, NR, 4):
+            const __m128 vb${IDLETTERS[N:N+4]} = _mm_load_ps(w + ${N});
+          w += ${NR};
+
+          $for M in range(MR):
+            const __m128 va${M} = _mm_load1_ps(a${M});
+            a${M} += 1;
+
+          $for M in range(MR):
+            $for N in range(0, NR, 4):
+              vacc${M}x${IDLETTERS[N:N+4]} = _mm_add_ps(vacc${M}x${IDLETTERS[N:N+4]}, _mm_mul_ps(va${M}, vb${IDLETTERS[N:N+4]}));
+          k -= sizeof(float);
+        } while (k != 0);
+      }
+      p -= ${MR} * sizeof(void*);
+    } while (p != 0);
+
+    const __m128 vmax = _mm_load_ps(params->sse.max);  // Clamp results to the operator's output range.
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${IDLETTERS[N:N+4]} = _mm_min_ps(vacc${M}x${IDLETTERS[N:N+4]}, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${IDLETTERS[N:N+4]} = _mm_max_ps(vacc${M}x${IDLETTERS[N:N+4]}, vmin);
+
+    if XNN_LIKELY(nc >= ${NR}) {  // full tile: store all NR columns of every row
+      $for M in reversed(range(MR)):
+        _mm_storeu_ps(c${M}, vacc${M}x${IDLETTERS[0:4]});
+        $for N in range(4, NR, 4):
+          _mm_storeu_ps(c${M} + ${N}, vacc${M}x${IDLETTERS[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);  // rewind the indirection buffer for the next column tile
+      nc -= ${NR};
+    } else {  // partial tile: store the remaining nc (< NR) columns in power-of-two chunks
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for M in reversed(range(MR)):
+                _mm_storeu_ps(c${M}, vacc${M}x${IDLETTERS[0:4]});
+                $for N in range(4, 1 << LOG2N, 4):
+                  _mm_storeu_ps(c${M} + ${N}, vacc${M}x${IDLETTERS[N:N+4]});
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${IDLETTERS[N:N+4]} = vacc${M}x${IDLETTERS[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+              $for M in reversed(range(MR)):
+                c${M} += ${1 << LOG2N};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                _mm_storel_pi((__m64*) c${M}, vacc${M}x${IDLETTERS[0:4]});
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${IDLETTERS[0:4]} = _mm_movehl_ps(vacc${M}x${IDLETTERS[0:4]}, vacc${M}x${IDLETTERS[0:4]});
+
+              $for M in reversed(range(MR)):
+                c${M} += 2;
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                _mm_store_ss(c${M}, vacc${M}x${IDLETTERS[0:4]});
+          }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/sse-load1.c.in b/src/f32-igemm/sse-load1.c.in
new file mode 100644
index 0000000..de27a31
--- /dev/null
+++ b/src/f32-igemm/sse-load1.c.in
@@ -0,0 +1,143 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$IDLETTERS = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_${MR}x${NR}__sse_load1(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{  // SSE IGEMM template, "load1" variant: broadcast-load one A element per K-step.
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (${MR} * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;  // Per-row output pointers; rows beyond mr alias the row above, so their stores are harmless duplicates.
+  $for M in range(1, MR):
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        c${M} = c${M-1};
+      }
+
+  do {  // Process the output NR columns at a time.
+    __m128 vacc0x${IDLETTERS[0:4]} = _mm_load_ps(w);  // Accumulators start from the leading packed weights (bias, per packing layout).
+    $for N in range(4, NR, 4):
+      __m128 vacc0x${IDLETTERS[N:N+4]} = _mm_load_ps(w + ${N});
+    $for M in range(1, MR):
+      $for N in range(0, NR, 4):
+        __m128 vacc${M}x${IDLETTERS[N:N+4]} = vacc0x${IDLETTERS[N:N+4]};
+    w += ${NR};
+
+    size_t p = ks;  // ks bytes of indirection pointers; MR pointers are consumed per pass.
+    do {
+      $for M in range(MR):
+        const float* restrict a${M} = a[${M}];
+        if XNN_UNPREDICTABLE(a${M} != zero) {  // the shared "zero" row is used as-is, without a_offset
+          a${M} = (const float*) ((uintptr_t) a${M} + a_offset);
+        }
+      a += ${MR};
+
+      size_t k = kc;  // accumulate along K, one float of A per step
+      do {
+        const __m128 vb${IDLETTERS[0:4]} = _mm_load_ps(w);
+        $for N in range(4, NR, 4):
+          const __m128 vb${IDLETTERS[N:N+4]} = _mm_load_ps(w + ${N});
+        w += ${NR};
+
+        $for M in range(MR):
+          const __m128 va${M} = _mm_load1_ps(a${M});
+          a${M} += 1;
+
+        $for M in range(MR):
+          $for N in range(0, NR, 4):
+            vacc${M}x${IDLETTERS[N:N+4]} = _mm_add_ps(vacc${M}x${IDLETTERS[N:N+4]}, _mm_mul_ps(va${M}, vb${IDLETTERS[N:N+4]}));
+        k -= sizeof(float);
+      } while (k != 0);
+      p -= ${MR} * sizeof(void*);
+    } while (p != 0);
+
+    const __m128 vmax = _mm_load_ps(params->sse.max);  // Clamp results to the operator's output range.
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${IDLETTERS[N:N+4]} = _mm_min_ps(vacc${M}x${IDLETTERS[N:N+4]}, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${IDLETTERS[N:N+4]} = _mm_max_ps(vacc${M}x${IDLETTERS[N:N+4]}, vmin);
+
+    if XNN_LIKELY(nc >= ${NR}) {  // full tile: store all NR columns of every row
+      $for M in reversed(range(MR)):
+        _mm_storeu_ps(c${M}, vacc${M}x${IDLETTERS[0:4]});
+        $for N in range(4, NR, 4):
+          _mm_storeu_ps(c${M} + ${N}, vacc${M}x${IDLETTERS[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);  // rewind the indirection buffer for the next column tile
+      nc -= ${NR};
+    } else {  // partial tile: store the remaining nc (< NR) columns in power-of-two chunks
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for M in reversed(range(MR)):
+                _mm_storeu_ps(c${M}, vacc${M}x${IDLETTERS[0:4]});
+                $for N in range(4, 1 << LOG2N, 4):
+                  _mm_storeu_ps(c${M} + ${N}, vacc${M}x${IDLETTERS[N:N+4]});
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${IDLETTERS[N:N+4]} = vacc${M}x${IDLETTERS[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+              $for M in reversed(range(MR)):
+                c${M} += ${1 << LOG2N};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                _mm_storel_pi((__m64*) c${M}, vacc${M}x${IDLETTERS[0:4]});
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${IDLETTERS[0:4]} = _mm_movehl_ps(vacc${M}x${IDLETTERS[0:4]}, vacc${M}x${IDLETTERS[0:4]});
+
+              $for M in reversed(range(MR)):
+                c${M} += 2;
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                _mm_store_ss(c${M}, vacc${M}x${IDLETTERS[0:4]});
+          }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-igemm/sse-shuffle.c.in b/src/f32-igemm/sse-shuffle.c.in
new file mode 100644
index 0000000..512ec44
--- /dev/null
+++ b/src/f32-igemm/sse-shuffle.c.in
@@ -0,0 +1,166 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert NR % 4 == 0
+$IDLETTERS = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_f32_igemm_ukernel_${MR}x${NR}s4__sse(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const float**restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const float* zero,
+    const union xnn_f32_output_params params[restrict static 1])
+{  // SSE IGEMM template, "s4" variant: weights packed with a shift-by-4 layout; A lanes are rotated to match.
+  assert(mr != 0);
+  assert(mr <= ${MR});
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);
+  assert(ks != 0);
+  assert(ks % (${MR} * sizeof(void*)) == 0);
+  assert(a_offset % sizeof(float) == 0);
+
+  float* c0 = c;  // Per-row output pointers; rows beyond mr alias the row above, so their stores are harmless duplicates.
+  $for M in range(1, MR):
+    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+    $if M % 2 == 0:
+      if XNN_UNPREDICTABLE(mr <= ${M}) {
+        c${M} = c${M-1};
+      }
+    $elif M + 1 == MR:
+      if XNN_UNPREDICTABLE(mr != ${M+1}) {
+        c${M} = c${M-1};
+      }
+    $else:
+      if XNN_UNPREDICTABLE(mr < ${M+1}) {
+        c${M} = c${M-1};
+      }
+
+  do {  // Process the output NR columns at a time.
+    __m128 vacc0x${IDLETTERS[0:4]} = _mm_load_ps(w);  // Accumulators start from the leading packed weights (bias, per packing layout).
+    $for N in range(4, NR, 4):
+      __m128 vacc0x${IDLETTERS[N:N+4]} = _mm_load_ps(w + ${N});
+    $for M in range(1, MR):
+      $for N in range(0, NR, 4):
+        __m128 vacc${M}x${IDLETTERS[N:N+4]} = vacc0x${IDLETTERS[N:N+4]};
+    w += ${NR};
+
+    size_t p = ks;  // ks bytes of indirection pointers; MR pointers are consumed per pass.
+    do {
+      $for M in range(MR):
+        const float* restrict a${M} = a[${M}];
+        if XNN_UNPREDICTABLE(a${M} != zero) {  // the shared "zero" row is used as-is, without a_offset
+          a${M} = (const float*) ((uintptr_t) a${M} + a_offset);
+        }
+      a += ${MR};
+
+      size_t k = kc;  // main loop: 4 K-steps per iteration, using vector loads of A
+      while (k >= 4 * sizeof(float)) {
+        $for M in range(MR):
+          __m128 va${M} = _mm_loadu_ps(a${M});
+          a${M} += 4;
+
+        $for L in range(4):
+
+          $for N in range(0, NR, 4):
+            const __m128 vb${IDLETTERS[N:N+4]}c${L} = _mm_load_ps(w + ${L * NR + N});
+
+          $for N in range(0, NR, 4):
+            $for M in range(MR):
+              vacc${M}x${IDLETTERS[N:N+4]} = _mm_add_ps(vacc${M}x${IDLETTERS[N:N+4]}, _mm_mul_ps(va${M}, vb${IDLETTERS[N:N+4]}c${L}));
+
+          $if L + 1 != 4:
+            $for M in range(MR):
+              va${M} = _mm_shuffle_ps(va${M}, va${M}, _MM_SHUFFLE(0, 3, 2, 1));  // rotate A lanes by one to line up with the next shifted weight group
+
+        w += ${4 * NR};
+        k -= 4 * sizeof(float);
+      }
+      if XNN_UNLIKELY(k != 0) {  // remainder: fall back to one broadcast A element per step
+        do {
+          const __m128 vb${IDLETTERS[0:4]} = _mm_load_ps(w);
+          $for N in range(4, NR, 4):
+            const __m128 vb${IDLETTERS[N:N+4]} = _mm_load_ps(w + ${N});
+          w += ${NR};
+
+          $for M in range(MR):
+            const __m128 va${M} = _mm_load1_ps(a${M});
+            a${M} += 1;
+
+          $for M in range(MR):
+            $for N in range(0, NR, 4):
+              vacc${M}x${IDLETTERS[N:N+4]} = _mm_add_ps(vacc${M}x${IDLETTERS[N:N+4]}, _mm_mul_ps(va${M}, vb${IDLETTERS[N:N+4]}));
+          k -= sizeof(float);
+        } while (k != 0);
+      }
+      p -= ${MR} * sizeof(void*);
+    } while (p != 0);
+
+    const __m128 vmax = _mm_load_ps(params->sse.max);  // Clamp results to the operator's output range.
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${IDLETTERS[N:N+4]} = _mm_min_ps(vacc${M}x${IDLETTERS[N:N+4]}, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    $for N in range(0, NR, 4):
+      $for M in range(MR):
+        vacc${M}x${IDLETTERS[N:N+4]} = _mm_max_ps(vacc${M}x${IDLETTERS[N:N+4]}, vmin);
+
+    if XNN_LIKELY(nc >= ${NR}) {  // full tile: store all NR columns of every row
+      $for M in reversed(range(MR)):
+        _mm_storeu_ps(c${M}, vacc${M}x${IDLETTERS[0:4]});
+        $for N in range(4, NR, 4):
+          _mm_storeu_ps(c${M} + ${N}, vacc${M}x${IDLETTERS[N:N+4]});
+        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+      a = (const float**restrict) ((uintptr_t) a - ks);  // rewind the indirection buffer for the next column tile
+      nc -= ${NR};
+    } else {  // partial tile: store the remaining nc (< NR) columns in power-of-two chunks
+      $for LOG2N in reversed(range(NR.bit_length())):
+        $if NR != 1 << LOG2N:
+          if (nc & ${1 << LOG2N}) {
+            $if LOG2N >= 2:
+              $for M in reversed(range(MR)):
+                _mm_storeu_ps(c${M}, vacc${M}x${IDLETTERS[0:4]});
+                $for N in range(4, 1 << LOG2N, 4):
+                  _mm_storeu_ps(c${M} + ${N}, vacc${M}x${IDLETTERS[N:N+4]});
+
+              $for M in reversed(range(MR)):
+                $for N in range(0, 1 << (LOG2N - 1), 4):
+                  vacc${M}x${IDLETTERS[N:N+4]} = vacc${M}x${IDLETTERS[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+              $for M in reversed(range(MR)):
+                c${M} += ${1 << LOG2N};
+            $elif LOG2N == 1:
+              $for M in reversed(range(MR)):
+                _mm_storel_pi((__m64*) c${M}, vacc${M}x${IDLETTERS[0:4]});
+
+              $for M in reversed(range(MR)):
+                vacc${M}x${IDLETTERS[0:4]} = _mm_movehl_ps(vacc${M}x${IDLETTERS[0:4]}, vacc${M}x${IDLETTERS[0:4]});
+
+              $for M in reversed(range(MR)):
+                c${M} += 2;
+            $elif LOG2N == 0:
+              $for M in reversed(range(MR)):
+                _mm_store_ss(c${M}, vacc${M}x${IDLETTERS[0:4]});
+          }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-maxpool/9p8q-psimd.c b/src/f32-maxpool/9p8q-psimd.c
new file mode 100644
index 0000000..570b3c1
--- /dev/null
+++ b/src/f32-maxpool/9p8q-psimd.c
@@ -0,0 +1,244 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/maxpool.h>
+
+
+void xnn_f32_maxpool_ukernel_9p8q__psimd(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** input,
+ float* output,
+ size_t input_increment,
+ size_t output_increment,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(n != 0);
+ assert(ks != 0);
+ assert(kc != 0);
+
+ const psimd_f32 voutput_max = psimd_load_splat_f32(¶ms->scalar.max);
+ const psimd_f32 voutput_min = psimd_load_splat_f32(¶ms->scalar.min);
+ do {
+ float* o = output;
+ {
+ const float* i0 = *input++;
+ const float* i1 = *input++;
+ const float* i2 = *input++;
+ const float* i3 = *input++;
+ const float* i4 = *input++;
+ const float* i5 = *input++;
+ const float* i6 = *input++;
+ const float* i7 = *input++;
+ const float* i8 = *input++;
+ if (ks < 2) {
+ i1 = i0;
+ }
+ if (ks <= 2) {
+ i2 = i0;
+ }
+ if (ks < 4) {
+ i3 = i0;
+ }
+ if (ks <= 4) {
+ i4 = i0;
+ }
+ if (ks < 6) {
+ i5 = i0;
+ }
+ if (ks <= 6) {
+ i6 = i0;
+ }
+ if (ks < 8) {
+ i7 = i0;
+ }
+ if (ks <= 8) {
+ i8 = i0;
+ }
+
+ size_t k = kc;
+ for (; k >= 4; k -= 4) {
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ i0 += 4;
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ i1 += 4;
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ i2 += 4;
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ i3 += 4;
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ i4 += 4;
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ i5 += 4;
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ i6 += 4;
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ i7 += 4;
+ const psimd_f32 vi8 = psimd_load_f32(i8);
+ i8 += 4;
+
+ const psimd_f32 vmax018 = psimd_max_f32(psimd_max_f32(vi0, vi1), vi8);
+ const psimd_f32 vmax23 = psimd_max_f32(vi2, vi3);
+ const psimd_f32 vmax45 = psimd_max_f32(vi4, vi5);
+ const psimd_f32 vmax67 = psimd_max_f32(vi6, vi7);
+
+ const psimd_f32 vmax2345 = psimd_max_f32(vmax23, vmax45);
+ const psimd_f32 vmax01678 = psimd_max_f32(vmax018, vmax67);
+ const psimd_f32 vmax = psimd_max_f32(vmax2345, vmax01678);
+ const psimd_f32 vout = psimd_max_f32(psimd_min_f32(vmax, voutput_max), voutput_min);
+
+ psimd_store_f32(o, vout);
+ o += 4;
+ }
+ if (k != 0) {
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ i0 += 4;
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ i1 += 4;
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ i2 += 4;
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ i3 += 4;
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ i4 += 4;
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ i5 += 4;
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ i6 += 4;
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ i7 += 4;
+ const psimd_f32 vi8 = psimd_load_f32(i8);
+ i8 += 4;
+
+ const psimd_f32 vmax018 = psimd_max_f32(psimd_max_f32(vi0, vi1), vi8);
+ const psimd_f32 vmax23 = psimd_max_f32(vi2, vi3);
+ const psimd_f32 vmax45 = psimd_max_f32(vi4, vi5);
+ const psimd_f32 vmax67 = psimd_max_f32(vi6, vi7);
+
+ const psimd_f32 vmax2345 = psimd_max_f32(vmax23, vmax45);
+ const psimd_f32 vmax01678 = psimd_max_f32(vmax018, vmax67);
+ const psimd_f32 vmax = psimd_max_f32(vmax2345, vmax01678);
+ psimd_f32 vout = psimd_max_f32(psimd_min_f32(vmax, voutput_max), voutput_min);
+
+ if (k & 2) {
+ psimd_store2_f32(o, vout);
+ vout = psimd_concat_hi_f32(vout, vout);
+ o += 2;
+ }
+ if (k & 1) {
+ psimd_store1_f32(o, vout);
+ o += 1;
+ }
+ }
+ }
+
+ for (ptrdiff_t m = (ptrdiff_t) ks - 9; m > 0; m -= 8) {
+ const float* i0 = *input++;
+ const float* i1 = *input++;
+ const float* i2 = *input++;
+ const float* i3 = *input++;
+ const float* i4 = *input++;
+ const float* i5 = *input++;
+ const float* i6 = *input++;
+ const float* i7 = *input++;
+ if (m < 2) {
+ i1 = i0;
+ }
+ if (m <= 2) {
+ i2 = i0;
+ }
+ if (m < 4) {
+ i3 = i0;
+ }
+ if (m <= 4) {
+ i4 = i0;
+ }
+ if (m < 6) {
+ i5 = i0;
+ }
+ if (m <= 6) {
+ i6 = i0;
+ }
+ if (m < 8) {
+ i7 = i0;
+ }
+
+ o = output;
+ size_t k = kc;
+ for (; k >= 4; k -= 4) {
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ i0 += 4;
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ i1 += 4;
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ i2 += 4;
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ i3 += 4;
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ i4 += 4;
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ i5 += 4;
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ i6 += 4;
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ i7 += 4;
+ const psimd_f32 vo = psimd_load_f32(o);
+
+ const psimd_f32 vmax01 = psimd_max_f32(psimd_max_f32(vi0, vi1), vo);
+ const psimd_f32 vmax23 = psimd_max_f32(vi2, vi3);
+ const psimd_f32 vmax45 = psimd_max_f32(vi4, vi5);
+ const psimd_f32 vmax67 = psimd_max_f32(vi6, vi7);
+
+ const psimd_f32 vmax2345 = psimd_max_f32(vmax23, vmax45);
+ const psimd_f32 vmax0167 = psimd_max_f32(vmax01, vmax67);
+ const psimd_f32 vmax = psimd_max_f32(vmax2345, vmax0167);
+ const psimd_f32 vout = psimd_max_f32(psimd_min_f32(vmax, voutput_max), voutput_min);
+
+ psimd_store_f32(o, vout);
+ o += 4;
+ }
+ if (k != 0) {
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ const psimd_f32 vo = psimd_load_f32(o);
+
+ const psimd_f32 vmax01 = psimd_max_f32(psimd_max_f32(vi0, vi1), vo);
+ const psimd_f32 vmax23 = psimd_max_f32(vi2, vi3);
+ const psimd_f32 vmax45 = psimd_max_f32(vi4, vi5);
+ const psimd_f32 vmax67 = psimd_max_f32(vi6, vi7);
+
+ const psimd_f32 vmax2345 = psimd_max_f32(vmax23, vmax45);
+ const psimd_f32 vmax0167 = psimd_max_f32(vmax01, vmax67);
+ const psimd_f32 vmax = psimd_max_f32(vmax2345, vmax0167);
+ psimd_f32 vout = psimd_max_f32(psimd_min_f32(vmax, voutput_max), voutput_min);
+
+ if (k & 2) {
+ psimd_store2_f32(o, vout);
+ vout = psimd_concat_hi_f32(vout, vout);
+ o += 2;
+ }
+ if (k & 1) {
+ psimd_store1_f32(o, vout);
+ o += 1;
+ }
+ }
+ }
+ input = (const float**) ((uintptr_t) input + input_increment);
+ output = (float*) ((uintptr_t) o + output_increment);
+ } while (--n != 0);
+}
diff --git a/src/f32-maxpool/9p8q-scalar.c b/src/f32-maxpool/9p8q-scalar.c
new file mode 100644
index 0000000..e05d025
--- /dev/null
+++ b/src/f32-maxpool/9p8q-scalar.c
@@ -0,0 +1,157 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/maxpool.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_maxpool_ukernel_9p8q__scalar(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** input,
+ float* output,
+ size_t input_increment,
+ size_t output_increment,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(n != 0);
+ assert(ks != 0);
+ assert(kc != 0);
+
+ const float voutput_min = params->scalar.min;
+ const float voutput_max = params->scalar.max;
+ do {
+ float* o = output;
+ {
+ const float* i0 = *input++;
+ const float* i1 = *input++;
+ const float* i2 = *input++;
+ const float* i3 = *input++;
+ const float* i4 = *input++;
+ const float* i5 = *input++;
+ const float* i6 = *input++;
+ const float* i7 = *input++;
+ const float* i8 = *input++;
+ if (ks < 2) {
+ i1 = i0;
+ }
+ if (ks <= 2) {
+ i2 = i0;
+ }
+ if (ks < 4) {
+ i3 = i0;
+ }
+ if (ks <= 4) {
+ i4 = i0;
+ }
+ if (ks < 6) {
+ i5 = i0;
+ }
+ if (ks <= 6) {
+ i6 = i0;
+ }
+ if (ks < 8) {
+ i7 = i0;
+ }
+ if (ks <= 8) {
+ i8 = i0;
+ }
+
+ size_t k = kc;
+ do {
+ const float vi0 = *i0++;
+ const float vi1 = *i1++;
+ const float vi2 = *i2++;
+ const float vi3 = *i3++;
+ const float vi4 = *i4++;
+ const float vi5 = *i5++;
+ const float vi6 = *i6++;
+ const float vi7 = *i7++;
+ const float vi8 = *i8++;
+
+ const float vmax01 = math_max_f32(vi0, vi1);
+ const float vmax23 = math_max_f32(vi2, vi3);
+ const float vmax45 = math_max_f32(vi4, vi5);
+ const float vmax67 = math_max_f32(vi6, vi7);
+ const float vmax018 = math_max_f32(vmax01, vi8);
+
+ const float vmax2345 = math_max_f32(vmax23, vmax45);
+ const float vmax01678 = math_max_f32(vmax018, vmax67);
+ float vout = math_max_f32(vmax2345, vmax01678);
+ vout = math_max_f32(vout, voutput_min);
+ vout = math_min_f32(vout, voutput_max);
+
+ *o++ = vout;
+ } while (--k != 0);
+ }
+
+ for (ptrdiff_t m = (ptrdiff_t) ks - 9; m > 0; m -= 8) {
+ const float* i0 = *input++;
+ const float* i1 = *input++;
+ const float* i2 = *input++;
+ const float* i3 = *input++;
+ const float* i4 = *input++;
+ const float* i5 = *input++;
+ const float* i6 = *input++;
+ const float* i7 = *input++;
+ if (m < 2) {
+ i1 = i0;
+ }
+ if (m <= 2) {
+ i2 = i0;
+ }
+ if (m < 4) {
+ i3 = i0;
+ }
+ if (m <= 4) {
+ i4 = i0;
+ }
+ if (m < 6) {
+ i5 = i0;
+ }
+ if (m <= 6) {
+ i6 = i0;
+ }
+ if (m < 8) {
+ i7 = i0;
+ }
+
+ o = output;
+ size_t k = kc;
+ do {
+ const float vi0 = *i0++;
+ const float vi1 = *i1++;
+ const float vi2 = *i2++;
+ const float vi3 = *i3++;
+ const float vi4 = *i4++;
+ const float vi5 = *i5++;
+ const float vi6 = *i6++;
+ const float vi7 = *i7++;
+ const float vi8 = *o;
+
+ const float vmax01 = math_max_f32(vi0, vi1);
+ const float vmax23 = math_max_f32(vi2, vi3);
+ const float vmax45 = math_max_f32(vi4, vi5);
+ const float vmax67 = math_max_f32(vi6, vi7);
+ const float vmax018 = math_max_f32(vmax01, vi8);
+
+ const float vmax2345 = math_max_f32(vmax23, vmax45);
+ const float vmax01678 = math_max_f32(vmax018, vmax67);
+ float vout = math_max_f32(vmax2345, vmax01678);
+ vout = math_max_f32(vout, voutput_min);
+ vout = math_min_f32(vout, voutput_max);
+
+ *o++ = vout;
+ } while (--k != 0);
+ }
+ input = (const float**) ((uintptr_t) input + input_increment);
+ output = (float*) ((uintptr_t) o + output_increment);
+ } while (--n != 0);
+}
diff --git a/src/f32-maxpool/9p8q-sse.c b/src/f32-maxpool/9p8q-sse.c
new file mode 100644
index 0000000..dc8c117
--- /dev/null
+++ b/src/f32-maxpool/9p8q-sse.c
@@ -0,0 +1,244 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/maxpool.h>
+
+
// f32 max-pooling microkernel, SSE variant.
//
// "9p8q": the first pass reduces 9 pooling elements (input rows) into the
// output row; every subsequent pass folds 8 more rows into the partially
// reduced output, re-reading it as an accumulator.
//
//   n                - number of output pixels to process
//   ks               - pooling kernel size (rows per output pixel)
//   kc               - number of channels
//   input            - indirection buffer of input row pointers
//   input_increment  - bytes to advance the indirection pointer per pixel
//   output_increment - bytes to advance past the kc outputs of a pixel
//   params           - output clamping bounds (SSE layout, 16-byte aligned)
//
// NOTE(review): the channel-remainder paths load a full 4-lane vector even
// when only 1-3 channels remain, so input rows are assumed readable past
// kc elements — confirm against the operator's buffer padding.
void xnn_f32_maxpool_ukernel_9p8q__sse(
    size_t n,
    size_t ks,
    size_t kc,
    const float** input,
    float* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(n != 0);
  assert(ks != 0);
  assert(kc != 0);

  const __m128 voutput_max = _mm_load_ps(params->sse.max);
  const __m128 voutput_min = _mm_load_ps(params->sse.min);
  do {
    float* o = output;
    {
      // First pass: pooling elements 0-8.
      const float* i0 = *input++;
      const float* i1 = *input++;
      const float* i2 = *input++;
      const float* i3 = *input++;
      const float* i4 = *input++;
      const float* i5 = *input++;
      const float* i6 = *input++;
      const float* i7 = *input++;
      const float* i8 = *input++;
      // Taps beyond the kernel size alias tap 0, which is neutral for max.
      if (ks < 2) {
        i1 = i0;
      }
      if (ks <= 2) {
        i2 = i0;
      }
      if (ks < 4) {
        i3 = i0;
      }
      if (ks <= 4) {
        i4 = i0;
      }
      if (ks < 6) {
        i5 = i0;
      }
      if (ks <= 6) {
        i6 = i0;
      }
      if (ks < 8) {
        i7 = i0;
      }
      if (ks <= 8) {
        i8 = i0;
      }

      size_t k = kc;
      for (; k >= 4; k -= 4) {
        const __m128 vi0 = _mm_loadu_ps(i0);
        i0 += 4;
        const __m128 vi1 = _mm_loadu_ps(i1);
        i1 += 4;
        const __m128 vi2 = _mm_loadu_ps(i2);
        i2 += 4;
        const __m128 vi3 = _mm_loadu_ps(i3);
        i3 += 4;
        const __m128 vi4 = _mm_loadu_ps(i4);
        i4 += 4;
        const __m128 vi5 = _mm_loadu_ps(i5);
        i5 += 4;
        const __m128 vi6 = _mm_loadu_ps(i6);
        i6 += 4;
        const __m128 vi7 = _mm_loadu_ps(i7);
        i7 += 4;
        const __m128 vi8 = _mm_loadu_ps(i8);
        i8 += 4;

        // Balanced reduction tree over the 9 rows.
        const __m128 vmax018 = _mm_max_ps(_mm_max_ps(vi0, vi1), vi8);
        const __m128 vmax23 = _mm_max_ps(vi2, vi3);
        const __m128 vmax45 = _mm_max_ps(vi4, vi5);
        const __m128 vmax67 = _mm_max_ps(vi6, vi7);

        const __m128 vmax2345 = _mm_max_ps(vmax23, vmax45);
        const __m128 vmax01678 = _mm_max_ps(vmax018, vmax67);
        const __m128 vmax = _mm_max_ps(vmax2345, vmax01678);
        const __m128 vout = _mm_max_ps(_mm_min_ps(vmax, voutput_max), voutput_min);

        _mm_storeu_ps(o, vout);
        o += 4;
      }
      if (k != 0) {
        // 1-3 remaining channels: reduce a full vector, store only k lanes.
        const __m128 vi0 = _mm_loadu_ps(i0);
        i0 += 4;
        const __m128 vi1 = _mm_loadu_ps(i1);
        i1 += 4;
        const __m128 vi2 = _mm_loadu_ps(i2);
        i2 += 4;
        const __m128 vi3 = _mm_loadu_ps(i3);
        i3 += 4;
        const __m128 vi4 = _mm_loadu_ps(i4);
        i4 += 4;
        const __m128 vi5 = _mm_loadu_ps(i5);
        i5 += 4;
        const __m128 vi6 = _mm_loadu_ps(i6);
        i6 += 4;
        const __m128 vi7 = _mm_loadu_ps(i7);
        i7 += 4;
        const __m128 vi8 = _mm_loadu_ps(i8);
        i8 += 4;

        const __m128 vmax018 = _mm_max_ps(_mm_max_ps(vi0, vi1), vi8);
        const __m128 vmax23 = _mm_max_ps(vi2, vi3);
        const __m128 vmax45 = _mm_max_ps(vi4, vi5);
        const __m128 vmax67 = _mm_max_ps(vi6, vi7);

        const __m128 vmax2345 = _mm_max_ps(vmax23, vmax45);
        const __m128 vmax01678 = _mm_max_ps(vmax018, vmax67);
        const __m128 vmax = _mm_max_ps(vmax2345, vmax01678);
        __m128 vout = _mm_max_ps(_mm_min_ps(vmax, voutput_max), voutput_min);

        if (k & 2) {
          // Store the low 2 lanes, then shift the high lanes down.
          _mm_storel_pi((__m64*) o, vout);
          o += 2;
          vout = _mm_movehl_ps(vout, vout);
        }
        if (k & 1) {
          _mm_store_ss(o, vout);
          o += 1;
        }
      }
    }

    // Subsequent passes: fold 8 more rows into the partial maximum held in
    // the output row.
    for (ptrdiff_t m = (ptrdiff_t) ks - 9; m > 0; m -= 8) {
      const float* i0 = *input++;
      const float* i1 = *input++;
      const float* i2 = *input++;
      const float* i3 = *input++;
      const float* i4 = *input++;
      const float* i5 = *input++;
      const float* i6 = *input++;
      const float* i7 = *input++;
      if (m < 2) {
        i1 = i0;
      }
      if (m <= 2) {
        i2 = i0;
      }
      if (m < 4) {
        i3 = i0;
      }
      if (m <= 4) {
        i4 = i0;
      }
      if (m < 6) {
        i5 = i0;
      }
      if (m <= 6) {
        i6 = i0;
      }
      if (m < 8) {
        i7 = i0;
      }

      o = output;
      size_t k = kc;
      for (; k >= 4; k -= 4) {
        const __m128 vi0 = _mm_loadu_ps(i0);
        i0 += 4;
        const __m128 vi1 = _mm_loadu_ps(i1);
        i1 += 4;
        const __m128 vi2 = _mm_loadu_ps(i2);
        i2 += 4;
        const __m128 vi3 = _mm_loadu_ps(i3);
        i3 += 4;
        const __m128 vi4 = _mm_loadu_ps(i4);
        i4 += 4;
        const __m128 vi5 = _mm_loadu_ps(i5);
        i5 += 4;
        const __m128 vi6 = _mm_loadu_ps(i6);
        i6 += 4;
        const __m128 vi7 = _mm_loadu_ps(i7);
        i7 += 4;
        const __m128 vo = _mm_loadu_ps(o);  // partial max from earlier passes

        const __m128 vmax01 = _mm_max_ps(_mm_max_ps(vi0, vi1), vo);
        const __m128 vmax23 = _mm_max_ps(vi2, vi3);
        const __m128 vmax45 = _mm_max_ps(vi4, vi5);
        const __m128 vmax67 = _mm_max_ps(vi6, vi7);

        const __m128 vmax2345 = _mm_max_ps(vmax23, vmax45);
        const __m128 vmax0167 = _mm_max_ps(vmax01, vmax67);
        const __m128 vmax = _mm_max_ps(vmax2345, vmax0167);
        const __m128 vout = _mm_max_ps(_mm_min_ps(vmax, voutput_max), voutput_min);

        _mm_storeu_ps(o, vout);
        o += 4;
      }
      if (k != 0) {
        // Remainder channels; row pointers are not reused afterwards, so no
        // increments are needed here.
        const __m128 vi0 = _mm_loadu_ps(i0);
        const __m128 vi1 = _mm_loadu_ps(i1);
        const __m128 vi2 = _mm_loadu_ps(i2);
        const __m128 vi3 = _mm_loadu_ps(i3);
        const __m128 vi4 = _mm_loadu_ps(i4);
        const __m128 vi5 = _mm_loadu_ps(i5);
        const __m128 vi6 = _mm_loadu_ps(i6);
        const __m128 vi7 = _mm_loadu_ps(i7);
        const __m128 vo = _mm_loadu_ps(o);

        const __m128 vmax01 = _mm_max_ps(_mm_max_ps(vi0, vi1), vo);
        const __m128 vmax23 = _mm_max_ps(vi2, vi3);
        const __m128 vmax45 = _mm_max_ps(vi4, vi5);
        const __m128 vmax67 = _mm_max_ps(vi6, vi7);

        const __m128 vmax2345 = _mm_max_ps(vmax23, vmax45);
        const __m128 vmax0167 = _mm_max_ps(vmax01, vmax67);
        const __m128 vmax = _mm_max_ps(vmax2345, vmax0167);
        __m128 vout = _mm_max_ps(_mm_min_ps(vmax, voutput_max), voutput_min);

        if (k & 2) {
          _mm_storel_pi((__m64*) o, vout);
          o += 2;
          vout = _mm_movehl_ps(vout, vout);
        }
        if (k & 1) {
          _mm_store_ss(o, vout);
          o += 1;
        }
      }
    }
    input = (const float**) ((uintptr_t) input + input_increment);
    // 'o' is already past the kc outputs of this pixel.
    output = (float*) ((uintptr_t) o + output_increment);
  } while (--n != 0);
}
diff --git a/src/f32-pavgpool/mp9p8q-neon.c b/src/f32-pavgpool/mp9p8q-neon.c
new file mode 100644
index 0000000..70422c5
--- /dev/null
+++ b/src/f32-pavgpool/mp9p8q-neon.c
@@ -0,0 +1,208 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/pavgpool.h>
+
+
+void xnn_f32_pavgpool_ukernel_mp9p8q__neon(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** input,
+ const float* zero,
+ const float* multiplier,
+ float* buffer,
+ float* output,
+ size_t input_increment,
+ size_t output_increment,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(n != 0);
+ assert(ks > 9);
+ assert(kc != 0);
+
+ const float32x4_t voutput_min = vld1q_dup_f32(¶ms->scalar.min);
+ const float32x4_t voutput_max = vld1q_dup_f32(¶ms->scalar.max);
+
+ do {
+ {
+ const float* i0 = *input++;
+ const float* i1 = *input++;
+ const float* i2 = *input++;
+ const float* i3 = *input++;
+ const float* i4 = *input++;
+ const float* i5 = *input++;
+ const float* i6 = *input++;
+ const float* i7 = *input++;
+ const float* i8 = *input++;
+
+ float* b = buffer;
+ for (size_t k = 0; k < kc; k += 4) {
+ const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
+ const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;
+ const float32x4_t vi8 = vld1q_f32(i8); i8 += 4;
+
+ const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+ const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+ const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+ const float32x4_t vsum67 = vaddq_f32(vi6, vi7);
+ const float32x4_t vsum018 = vaddq_f32(vsum01, vi8);
+ const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
+ const float32x4_t vsum01678 = vaddq_f32(vsum018, vsum67);
+ const float32x4_t vsum = vaddq_f32(vsum2345, vsum01678);
+
+ vst1q_f32(b, vsum); b += 4;
+ }
+ }
+
+ size_t m = ks;
+ for (m -= 9; m > 8; m -= 8) {
+ const float* i0 = *input++;
+ const float* i1 = *input++;
+ const float* i2 = *input++;
+ const float* i3 = *input++;
+ const float* i4 = *input++;
+ const float* i5 = *input++;
+ const float* i6 = *input++;
+ const float* i7 = *input++;
+
+ float* b = buffer;
+ for (size_t k = 0; k < kc; k += 4) {
+ const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
+ const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;
+ const float32x4_t vacc = vld1q_f32(b);
+
+ const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+ const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+ const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+ const float32x4_t vsum67 = vaddq_f32(vi6, vi7);
+ const float32x4_t vsum01a = vaddq_f32(vsum01, vacc);
+ const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
+ const float32x4_t vsum0167a = vaddq_f32(vsum01a, vsum67);
+ const float32x4_t vsum = vaddq_f32(vsum2345, vsum0167a);
+
+ vst1q_f32(b, vsum); b += 4;
+ }
+ }
+
+ {
+ const float* i0 = input[0];
+ const float* i1 = input[1];
+ const float* i2 = input[2];
+ const float* i3 = input[3];
+ const float* i4 = input[4];
+ const float* i5 = input[5];
+ const float* i6 = input[6];
+ const float* i7 = input[7];
+ input = (const float**) ((uintptr_t) input + input_increment);
+ if (m < 2) {
+ i1 = zero;
+ }
+ if (m <= 2) {
+ i2 = zero;
+ }
+ if (m < 4) {
+ i3 = zero;
+ }
+ if (m <= 4) {
+ i4 = zero;
+ }
+ if (m < 6) {
+ i5 = zero;
+ }
+ if (m <= 6) {
+ i6 = zero;
+ }
+ if (m != 8) {
+ i7 = zero;
+ }
+
+ const float32x4_t vmultiplier = vld1q_dup_f32(multiplier); multiplier += 1;
+
+ size_t k = kc;
+ float* b = buffer;
+ while (k >= 4) {
+ const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
+ const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
+ const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
+ const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
+ const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
+ const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
+ const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
+ const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;
+ const float32x4_t vacc = vld1q_f32(b); b += 4;
+
+ const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+ const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+ const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+ const float32x4_t vsum67 = vaddq_f32(vi6, vi7);
+ const float32x4_t vsum01a = vaddq_f32(vsum01, vacc);
+ const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
+ const float32x4_t vsum0167a = vaddq_f32(vsum01a, vsum67);
+ const float32x4_t vsum = vaddq_f32(vsum2345, vsum0167a);
+
+ float32x4_t vout = vmulq_f32(vsum, vmultiplier);
+ vout = vmaxq_f32(vout, voutput_min);
+ vout = vminq_f32(vout, voutput_max);
+
+ vst1q_f32(output, vout); output += 4;
+
+ k -= 4;
+ }
+ if (k != 0) {
+ const float32x4_t vi0 = vld1q_f32(i0);
+ const float32x4_t vi1 = vld1q_f32(i1);
+ const float32x4_t vi2 = vld1q_f32(i2);
+ const float32x4_t vi3 = vld1q_f32(i3);
+ const float32x4_t vi4 = vld1q_f32(i4);
+ const float32x4_t vi5 = vld1q_f32(i5);
+ const float32x4_t vi6 = vld1q_f32(i6);
+ const float32x4_t vi7 = vld1q_f32(i7);
+ const float32x4_t vacc = vld1q_f32(b);
+
+ const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
+ const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
+ const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
+ const float32x4_t vsum67 = vaddq_f32(vi6, vi7);
+ const float32x4_t vsum01a = vaddq_f32(vsum01, vacc);
+ const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
+ const float32x4_t vsum0167a = vaddq_f32(vsum01a, vsum67);
+ const float32x4_t vsum = vaddq_f32(vsum2345, vsum0167a);
+
+ float32x4_t vout = vmulq_f32(vsum, vmultiplier);
+ vout = vmaxq_f32(vout, voutput_min);
+ vout = vminq_f32(vout, voutput_max);
+
+ float32x2_t vout_lo = vget_low_f32(vout);
+ if (k & 2) {
+ vst1_f32(output, vout_lo); output += 2;
+ vout_lo = vget_high_f32(vout);
+ }
+ if (k & 1) {
+ vst1_lane_f32(output, vout_lo, 0); output += 1;
+ }
+ }
+ }
+ output = (float*) ((uintptr_t) output + output_increment);
+ } while (--n != 0);
+}
diff --git a/src/f32-pavgpool/mp9p8q-psimd.c b/src/f32-pavgpool/mp9p8q-psimd.c
new file mode 100644
index 0000000..78fd150
--- /dev/null
+++ b/src/f32-pavgpool/mp9p8q-psimd.c
@@ -0,0 +1,239 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/pavgpool.h>
+
+
+void xnn_f32_pavgpool_ukernel_mp9p8q__psimd(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** input,
+ const float* zero,
+ const float* multiplier,
+ float* buffer,
+ float* output,
+ size_t input_increment,
+ size_t output_increment,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(n != 0);
+ assert(ks > 9);
+ assert(kc != 0);
+
+ const psimd_f32 voutput_min = psimd_load_splat_f32(¶ms->scalar.min);
+ const psimd_f32 voutput_max = psimd_load_splat_f32(¶ms->scalar.max);
+
+ do {
+ {
+ const float* i0 = *input++;
+ const float* i1 = *input++;
+ const float* i2 = *input++;
+ const float* i3 = *input++;
+ const float* i4 = *input++;
+ const float* i5 = *input++;
+ const float* i6 = *input++;
+ const float* i7 = *input++;
+ const float* i8 = *input++;
+
+ float* b = buffer;
+ for (size_t k = 0; k < kc; k += 4) {
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ i0 += 4;
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ i1 += 4;
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ i2 += 4;
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ i3 += 4;
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ i4 += 4;
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ i5 += 4;
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ i6 += 4;
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ i7 += 4;
+ const psimd_f32 vi8 = psimd_load_f32(i8);
+ i8 += 4;
+
+ const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+ const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+ const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+ const psimd_f32 vsum67 = psimd_add_f32(vi6, vi7);
+ const psimd_f32 vsum018 = psimd_add_f32(vsum01, vi8);
+ const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+ const psimd_f32 vsum01678 = psimd_add_f32(vsum018, vsum67);
+ const psimd_f32 vsum = psimd_add_f32(vsum2345, vsum01678);
+
+ psimd_store_f32(b, vsum);
+ b += 4;
+ }
+ }
+
+ size_t m = ks;
+ for (m -= 9; m > 8; m -= 8) {
+ const float* i0 = *input++;
+ const float* i1 = *input++;
+ const float* i2 = *input++;
+ const float* i3 = *input++;
+ const float* i4 = *input++;
+ const float* i5 = *input++;
+ const float* i6 = *input++;
+ const float* i7 = *input++;
+
+ float* b = buffer;
+ for (size_t k = 0; k < kc; k += 4) {
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ i0 += 4;
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ i1 += 4;
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ i2 += 4;
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ i3 += 4;
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ i4 += 4;
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ i5 += 4;
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ i6 += 4;
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ i7 += 4;
+ const psimd_f32 vacc = psimd_load_f32(b);
+
+ const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+ const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+ const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+ const psimd_f32 vsum67 = psimd_add_f32(vi6, vi7);
+ const psimd_f32 vsum01a = psimd_add_f32(vsum01, vacc);
+ const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+ const psimd_f32 vsum0167a = psimd_add_f32(vsum01a, vsum67);
+ const psimd_f32 vsum = psimd_add_f32(vsum2345, vsum0167a);
+
+ psimd_store_f32(b, vsum);
+ b += 4;
+ }
+ }
+
+ {
+ const float* i0 = input[0];
+ const float* i1 = input[1];
+ const float* i2 = input[2];
+ const float* i3 = input[3];
+ const float* i4 = input[4];
+ const float* i5 = input[5];
+ const float* i6 = input[6];
+ const float* i7 = input[7];
+ input = (const float**) ((uintptr_t) input + input_increment);
+ if (m < 2) {
+ i1 = zero;
+ }
+ if (m <= 2) {
+ i2 = zero;
+ }
+ if (m < 4) {
+ i3 = zero;
+ }
+ if (m <= 4) {
+ i4 = zero;
+ }
+ if (m < 6) {
+ i5 = zero;
+ }
+ if (m <= 6) {
+ i6 = zero;
+ }
+ if (m != 8) {
+ i7 = zero;
+ }
+
+ const psimd_f32 vmultiplier = psimd_load_splat_f32(multiplier);
+ multiplier += 1;
+
+ size_t k = kc;
+ float* b = buffer;
+ while (k >= 4) {
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ i0 += 4;
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ i1 += 4;
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ i2 += 4;
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ i3 += 4;
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ i4 += 4;
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ i5 += 4;
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ i6 += 4;
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ i7 += 4;
+ const psimd_f32 vacc = psimd_load_f32(b);
+ b += 4;
+
+ const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+ const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+ const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+ const psimd_f32 vsum67 = psimd_add_f32(vi6, vi7);
+ const psimd_f32 vsum01a = psimd_add_f32(vsum01, vacc);
+ const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+ const psimd_f32 vsum0167a = psimd_add_f32(vsum01a, vsum67);
+ const psimd_f32 vsum = psimd_add_f32(vsum2345, vsum0167a);
+
+ psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
+ vout = psimd_max_f32(vout, voutput_min);
+ vout = psimd_min_f32(vout, voutput_max);
+
+ psimd_store_f32(output, vout);
+ output += 4;
+
+ k -= 4;
+ }
+ if (k != 0) {
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ const psimd_f32 vacc = psimd_load_f32(b);
+
+ const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+ const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+ const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+ const psimd_f32 vsum67 = psimd_add_f32(vi6, vi7);
+ const psimd_f32 vsum01a = psimd_add_f32(vsum01, vacc);
+ const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+ const psimd_f32 vsum0167a = psimd_add_f32(vsum01a, vsum67);
+ const psimd_f32 vsum = psimd_add_f32(vsum2345, vsum0167a);
+
+ psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
+ vout = psimd_max_f32(vout, voutput_min);
+ vout = psimd_min_f32(vout, voutput_max);
+
+ if (k & 2) {
+ psimd_store2_f32(output, vout);
+ output += 2;
+ vout = psimd_concat_hi_f32(vout, vout);
+ }
+ if (k & 1) {
+ psimd_store1_f32(output, vout);
+ output += 1;
+ }
+ }
+ }
+ output = (float*) ((uintptr_t) output + output_increment);
+ } while (--n != 0);
+}
diff --git a/src/f32-pavgpool/mp9p8q-scalar.c b/src/f32-pavgpool/mp9p8q-scalar.c
new file mode 100644
index 0000000..b618240
--- /dev/null
+++ b/src/f32-pavgpool/mp9p8q-scalar.c
@@ -0,0 +1,174 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/pavgpool.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_pavgpool_ukernel_mp9p8q__scalar(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** input,
+ const float* zero,
+ const float* multiplier,
+ float* buffer,
+ float* output,
+ size_t input_increment,
+ size_t output_increment,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(n != 0);
+ assert(ks > 9);
+ assert(kc != 0);
+
+ const float voutput_min = params->scalar.min;
+ const float voutput_max = params->scalar.max;
+
+ do {
+ {
+ const float* i0 = *input++;
+ const float* i1 = *input++;
+ const float* i2 = *input++;
+ const float* i3 = *input++;
+ const float* i4 = *input++;
+ const float* i5 = *input++;
+ const float* i6 = *input++;
+ const float* i7 = *input++;
+ const float* i8 = *input++;
+
+ float* b = buffer;
+ size_t k = kc;
+ do {
+ const float vi0 = *i0++;
+ const float vi1 = *i1++;
+ const float vi2 = *i2++;
+ const float vi3 = *i3++;
+ const float vi4 = *i4++;
+ const float vi5 = *i5++;
+ const float vi6 = *i6++;
+ const float vi7 = *i7++;
+ const float vi8 = *i8++;
+
+ const float vsum01 = vi0 + vi1;
+ const float vsum23 = vi2 + vi3;
+ const float vsum45 = vi4 + vi5;
+ const float vsum67 = vi6 + vi7;
+ const float vsum018 = vsum01 + vi8;
+ const float vsum2345 = vsum23 + vsum45;
+ const float vsum01678 = vsum018 + vsum67;
+ const float vsum = vsum2345 + vsum01678;
+
+ *b++ = vsum;
+ } while (--k != 0);
+ }
+
+ size_t m = ks;
+ for (m -= 9; m > 8; m -= 8) {
+ const float* i0 = *input++;
+ const float* i1 = *input++;
+ const float* i2 = *input++;
+ const float* i3 = *input++;
+ const float* i4 = *input++;
+ const float* i5 = *input++;
+ const float* i6 = *input++;
+ const float* i7 = *input++;
+
+ float* b = buffer;
+ size_t k = kc;
+ do {
+ const float vi0 = *i0++;
+ const float vi1 = *i1++;
+ const float vi2 = *i2++;
+ const float vi3 = *i3++;
+ const float vi4 = *i4++;
+ const float vi5 = *i5++;
+ const float vi6 = *i6++;
+ const float vi7 = *i7++;
+ const float vacc = *b;
+
+ const float vsum01 = vi0 + vi1;
+ const float vsum23 = vi2 + vi3;
+ const float vsum45 = vi4 + vi5;
+ const float vsum67 = vi6 + vi7;
+ const float vsum01a = vsum01 + vacc;
+ const float vsum2345 = vsum23 + vsum45;
+ const float vsum0167a = vsum01a + vsum67;
+ const float vsum = vsum2345 + vsum0167a;
+
+ *b++ = vsum;
+ } while (--k != 0);
+ }
+
+ {
+ const float* i0 = input[0];
+ const float* i1 = input[1];
+ const float* i2 = input[2];
+ const float* i3 = input[3];
+ const float* i4 = input[4];
+ const float* i5 = input[5];
+ const float* i6 = input[6];
+ const float* i7 = input[7];
+ input = (const float**) ((uintptr_t) input + input_increment);
+ if (m < 2) {
+ i1 = zero;
+ }
+ if (m <= 2) {
+ i2 = zero;
+ }
+ if (m < 4) {
+ i3 = zero;
+ }
+ if (m <= 4) {
+ i4 = zero;
+ }
+ if (m < 6) {
+ i5 = zero;
+ }
+ if (m <= 6) {
+ i6 = zero;
+ }
+ if (m != 8) {
+ i7 = zero;
+ }
+
+ const float vmultiplier = *multiplier++;
+
+ size_t k = kc;
+ float* b = buffer;
+ do {
+ const float vi0 = *i0++;
+ const float vi1 = *i1++;
+ const float vi2 = *i2++;
+ const float vi3 = *i3++;
+ const float vi4 = *i4++;
+ const float vi5 = *i5++;
+ const float vi6 = *i6++;
+ const float vi7 = *i7++;
+ const float vacc = *b++;
+
+ const float vsum01 = vi0 + vi1;
+ const float vsum23 = vi2 + vi3;
+ const float vsum45 = vi4 + vi5;
+ const float vsum67 = vi6 + vi7;
+ const float vsum01a = vsum01 + vacc;
+ const float vsum2345 = vsum23 + vsum45;
+ const float vsum0167a = vsum01a + vsum67;
+ const float vsum = vsum2345 + vsum0167a;
+
+ float vout = vsum * vmultiplier;
+ vout = math_max_f32(vout, voutput_min);
+ vout = math_min_f32(vout, voutput_max);
+
+ *output++ = vout;
+ } while (--k != 0);
+ }
+ output = (float*) ((uintptr_t) output + output_increment);
+ } while (--n != 0);
+}
diff --git a/src/f32-pavgpool/mp9p8q-sse.c b/src/f32-pavgpool/mp9p8q-sse.c
new file mode 100644
index 0000000..89b97ed
--- /dev/null
+++ b/src/f32-pavgpool/mp9p8q-sse.c
@@ -0,0 +1,237 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/pavgpool.h>
+
+
// Pixelwise average pooling microkernel (SSE), multi-pass variant for pooling
// windows of more than 9 elements (ks > 9). Each output pixel is reduced in
// passes: the first pass sums 9 input rows into `buffer`, each middle pass
// accumulates 8 further rows, and the final pass (1-8 remaining rows) adds
// the rest, scales by the per-pixel multiplier, clamps, and stores.
//
// n                - number of output pixels.
// ks               - pooling window size (rows per pixel); must be > 9.
// kc               - number of channels; vectorized 4 at a time.
// input            - indirection buffer of input row pointers.
// zero             - zero-filled row substituted for out-of-window rows in
//                    the final pass.
// multiplier       - one scale factor per output pixel (presumably the
//                    reciprocal of the pooling area -- confirm at call site).
// buffer           - scratch accumulator of at least kc floats; must be
//                    16-byte aligned (accessed with _mm_load_ps/_mm_store_ps).
// input_increment  - byte adjustment of the indirection pointer per pixel.
// output_increment - byte adjustment of `output` after each pixel.
// params           - clamping bounds (params->sse.min / params->sse.max).
void xnn_f32_pavgpool_ukernel_mp9p8q__sse(
    size_t n,
    size_t ks,
    size_t kc,
    const float** input,
    const float* zero,
    const float* multiplier,
    float* buffer,
    float* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(n != 0);
  assert(ks > 9);
  assert(kc != 0);

  const __m128 voutput_min = _mm_load_ps(params->sse.min);
  const __m128 voutput_max = _mm_load_ps(params->sse.max);

  do {
    // First pass: sum rows 0-8 of the window into the scratch buffer.
    {
      const float* i0 = *input++;
      const float* i1 = *input++;
      const float* i2 = *input++;
      const float* i3 = *input++;
      const float* i4 = *input++;
      const float* i5 = *input++;
      const float* i6 = *input++;
      const float* i7 = *input++;
      const float* i8 = *input++;

      float* b = buffer;
      // NOTE(review): when kc is not a multiple of 4 the last iteration reads
      // and writes up to 3 elements past kc; rows and buffer are presumably
      // padded -- confirm against the operator's buffer allocation.
      for (size_t k = 0; k < kc; k += 4) {
        const __m128 vi0 = _mm_loadu_ps(i0);
        i0 += 4;
        const __m128 vi1 = _mm_loadu_ps(i1);
        i1 += 4;
        const __m128 vi2 = _mm_loadu_ps(i2);
        i2 += 4;
        const __m128 vi3 = _mm_loadu_ps(i3);
        i3 += 4;
        const __m128 vi4 = _mm_loadu_ps(i4);
        i4 += 4;
        const __m128 vi5 = _mm_loadu_ps(i5);
        i5 += 4;
        const __m128 vi6 = _mm_loadu_ps(i6);
        i6 += 4;
        const __m128 vi7 = _mm_loadu_ps(i7);
        i7 += 4;
        const __m128 vi8 = _mm_loadu_ps(i8);
        i8 += 4;

        // Fixed-association addition tree; the order matches the other
        // pavgpool variants in this file, keeping results reproducible.
        const __m128 vsum01 = _mm_add_ps(vi0, vi1);
        const __m128 vsum23 = _mm_add_ps(vi2, vi3);
        const __m128 vsum45 = _mm_add_ps(vi4, vi5);
        const __m128 vsum67 = _mm_add_ps(vi6, vi7);
        const __m128 vsum018 = _mm_add_ps(vsum01, vi8);
        const __m128 vsum2345 = _mm_add_ps(vsum23, vsum45);
        const __m128 vsum01678 = _mm_add_ps(vsum018, vsum67);
        const __m128 vsum = _mm_add_ps(vsum2345, vsum01678);

        _mm_store_ps(b, vsum); b += 4;
      }
    }

    // Middle passes: accumulate 8 further rows into the buffer while more
    // than 8 rows remain.
    size_t m = ks;
    for (m -= 9; m > 8; m -= 8) {
      const float* i0 = *input++;
      const float* i1 = *input++;
      const float* i2 = *input++;
      const float* i3 = *input++;
      const float* i4 = *input++;
      const float* i5 = *input++;
      const float* i6 = *input++;
      const float* i7 = *input++;

      float* b = buffer;
      for (size_t k = 0; k < kc; k += 4) {
        const __m128 vi0 = _mm_loadu_ps(i0);
        i0 += 4;
        const __m128 vi1 = _mm_loadu_ps(i1);
        i1 += 4;
        const __m128 vi2 = _mm_loadu_ps(i2);
        i2 += 4;
        const __m128 vi3 = _mm_loadu_ps(i3);
        i3 += 4;
        const __m128 vi4 = _mm_loadu_ps(i4);
        i4 += 4;
        const __m128 vi5 = _mm_loadu_ps(i5);
        i5 += 4;
        const __m128 vi6 = _mm_loadu_ps(i6);
        i6 += 4;
        const __m128 vi7 = _mm_loadu_ps(i7);
        i7 += 4;
        const __m128 vacc = _mm_load_ps(b);  // running sum from prior passes

        const __m128 vsum01 = _mm_add_ps(vi0, vi1);
        const __m128 vsum23 = _mm_add_ps(vi2, vi3);
        const __m128 vsum45 = _mm_add_ps(vi4, vi5);
        const __m128 vsum67 = _mm_add_ps(vi6, vi7);
        const __m128 vsum01a = _mm_add_ps(vsum01, vacc);
        const __m128 vsum2345 = _mm_add_ps(vsum23, vsum45);
        const __m128 vsum0167a = _mm_add_ps(vsum01a, vsum67);
        const __m128 vsum = _mm_add_ps(vsum2345, vsum0167a);

        _mm_store_ps(b, vsum); b += 4;
      }
    }

    // Final pass: add the last m (1-8) rows, then scale, clamp, and store.
    {
      const float* i0 = input[0];
      const float* i1 = input[1];
      const float* i2 = input[2];
      const float* i3 = input[3];
      const float* i4 = input[4];
      const float* i5 = input[5];
      const float* i6 = input[6];
      const float* i7 = input[7];
      input = (const float**) ((uintptr_t) input + input_increment);
      // Row j contributes zeros when fewer than j+1 rows remain (j >= m).
      if (m < 2) {
        i1 = zero;
      }
      if (m <= 2) {
        i2 = zero;
      }
      if (m < 4) {
        i3 = zero;
      }
      if (m <= 4) {
        i4 = zero;
      }
      if (m < 6) {
        i5 = zero;
      }
      if (m <= 6) {
        i6 = zero;
      }
      if (m != 8) {
        i7 = zero;
      }

      const __m128 vmultiplier = _mm_load1_ps(multiplier);
      multiplier += 1;

      size_t k = kc;
      float* b = buffer;
      while (k >= 4) {
        const __m128 vi0 = _mm_loadu_ps(i0);
        i0 += 4;
        const __m128 vi1 = _mm_loadu_ps(i1);
        i1 += 4;
        const __m128 vi2 = _mm_loadu_ps(i2);
        i2 += 4;
        const __m128 vi3 = _mm_loadu_ps(i3);
        i3 += 4;
        const __m128 vi4 = _mm_loadu_ps(i4);
        i4 += 4;
        const __m128 vi5 = _mm_loadu_ps(i5);
        i5 += 4;
        const __m128 vi6 = _mm_loadu_ps(i6);
        i6 += 4;
        const __m128 vi7 = _mm_loadu_ps(i7);
        i7 += 4;
        const __m128 vacc = _mm_load_ps(b);
        b += 4;

        const __m128 vsum01 = _mm_add_ps(vi0, vi1);
        const __m128 vsum23 = _mm_add_ps(vi2, vi3);
        const __m128 vsum45 = _mm_add_ps(vi4, vi5);
        const __m128 vsum67 = _mm_add_ps(vi6, vi7);
        const __m128 vsum01a = _mm_add_ps(vsum01, vacc);
        const __m128 vsum2345 = _mm_add_ps(vsum23, vsum45);
        const __m128 vsum0167a = _mm_add_ps(vsum01a, vsum67);
        const __m128 vsum = _mm_add_ps(vsum2345, vsum0167a);

        // Scale by the reciprocal pooling area, then clamp to [min, max].
        __m128 vout = _mm_mul_ps(vsum, vmultiplier);
        vout = _mm_max_ps(vout, voutput_min);
        vout = _mm_min_ps(vout, voutput_max);

        _mm_storeu_ps(output, vout);
        output += 4;

        k -= 4;
      }
      if (k != 0) {
        // Channel remainder (1-3): compute a full 4-wide result but store
        // only k lanes. The loads read 4 floats regardless of k; presumably
        // the rows carry readable padding -- confirm at the caller.
        const __m128 vi0 = _mm_loadu_ps(i0);
        const __m128 vi1 = _mm_loadu_ps(i1);
        const __m128 vi2 = _mm_loadu_ps(i2);
        const __m128 vi3 = _mm_loadu_ps(i3);
        const __m128 vi4 = _mm_loadu_ps(i4);
        const __m128 vi5 = _mm_loadu_ps(i5);
        const __m128 vi6 = _mm_loadu_ps(i6);
        const __m128 vi7 = _mm_loadu_ps(i7);
        const __m128 vacc = _mm_load_ps(b);

        const __m128 vsum01 = _mm_add_ps(vi0, vi1);
        const __m128 vsum23 = _mm_add_ps(vi2, vi3);
        const __m128 vsum45 = _mm_add_ps(vi4, vi5);
        const __m128 vsum67 = _mm_add_ps(vi6, vi7);
        const __m128 vsum01a = _mm_add_ps(vsum01, vacc);
        const __m128 vsum2345 = _mm_add_ps(vsum23, vsum45);
        const __m128 vsum0167a = _mm_add_ps(vsum01a, vsum67);
        const __m128 vsum = _mm_add_ps(vsum2345, vsum0167a);

        __m128 vout = _mm_mul_ps(vsum, vmultiplier);
        vout = _mm_max_ps(vout, voutput_min);
        vout = _mm_min_ps(vout, voutput_max);

        if (k & 2) {
          // Store the low 2 lanes, then shift the high lanes down.
          _mm_storel_pi((__m64*) output, vout);
          vout = _mm_movehl_ps(vout, vout);
          output += 2;
        }
        if (k & 1) {
          _mm_store_ss(output, vout);
          output += 1;
        }
      }
    }
    output = (float*) ((uintptr_t) output + output_increment);
  } while (--n != 0);
}
diff --git a/src/f32-pavgpool/up9-neon.c b/src/f32-pavgpool/up9-neon.c
new file mode 100644
index 0000000..2195019
--- /dev/null
+++ b/src/f32-pavgpool/up9-neon.c
@@ -0,0 +1,137 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/pavgpool.h>
+
+
// Pixelwise average pooling microkernel (NEON), unit-pass variant for pooling
// windows of up to 9 elements (ks <= 9). For each output pixel, sums up to 9
// input rows channel-wise, scales the sum by a per-pixel multiplier, clamps
// it to [min, max], and stores the result.
//
// n                - number of output pixels.
// ks               - pooling window size (rows per pixel); 1..9.
// kc               - number of channels; vectorized 4 at a time with a
//                    sub-vector tail.
// input            - indirection buffer of row pointers; 9 entries are read
//                    per pixel, with entries at index >= ks replaced by zero.
// zero             - zero-filled row used in place of out-of-window rows.
// multiplier       - one scale factor per output pixel (presumably the
//                    reciprocal of the pooling area -- confirm at call site).
// input_increment  - byte adjustment of the indirection pointer per pixel.
// output_increment - byte adjustment of `output` after each pixel.
// params           - clamping bounds (params->scalar.min / params->scalar.max).
void xnn_f32_pavgpool_ukernel_up9__neon(
    size_t n,
    size_t ks,
    size_t kc,
    const float** input,
    const float* zero,
    const float* multiplier,
    float* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(n != 0);
  assert(ks != 0);
  assert(ks <= 9);
  assert(kc != 0);

  const float32x4_t voutput_min = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t voutput_max = vld1q_dup_f32(&params->scalar.max);

  do {
    const float* i0 = input[0];
    const float* i1 = input[1];
    const float* i2 = input[2];
    const float* i3 = input[3];
    const float* i4 = input[4];
    const float* i5 = input[5];
    const float* i6 = input[6];
    const float* i7 = input[7];
    const float* i8 = input[8];
    input = (const float**) ((uintptr_t) input + input_increment);
    // Rows beyond the pooling window (index >= ks) read from the zero row.
    if (ks < 2) {
      i1 = zero;
    }
    if (ks <= 2) {
      i2 = zero;
    }
    if (ks < 4) {
      i3 = zero;
    }
    if (ks <= 4) {
      i4 = zero;
    }
    if (ks < 6) {
      i5 = zero;
    }
    if (ks <= 6) {
      i6 = zero;
    }
    if (ks < 8) {
      i7 = zero;
    }
    if (ks <= 8) {
      i8 = zero;
    }

    const float32x4_t vmultiplier = vld1q_dup_f32(multiplier); multiplier += 1;

    size_t k = kc;
    while (k >= 4) {
      const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
      const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
      const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
      const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
      const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
      const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
      const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
      const float32x4_t vi7 = vld1q_f32(i7); i7 += 4;
      const float32x4_t vi8 = vld1q_f32(i8); i8 += 4;

      // Fixed-association addition tree shared with the other pavgpool
      // variants, keeping results reproducible across implementations.
      const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
      const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
      const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
      const float32x4_t vsum67 = vaddq_f32(vi6, vi7);
      const float32x4_t vsum018 = vaddq_f32(vsum01, vi8);
      const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
      const float32x4_t vsum01678 = vaddq_f32(vsum018, vsum67);
      const float32x4_t vsum = vaddq_f32(vsum2345, vsum01678);

      // Scale by the per-pixel multiplier, then clamp.
      float32x4_t vout = vmulq_f32(vsum, vmultiplier);
      vout = vmaxq_f32(vout, voutput_min);
      vout = vminq_f32(vout, voutput_max);

      vst1q_f32(output, vout); output += 4;

      k -= 4;
    }
    if (k != 0) {
      // Channel remainder (1-3): compute a full vector, store only k lanes.
      // Loads read 4 floats regardless of k; rows presumably carry readable
      // padding -- confirm at the caller.
      const float32x4_t vi0 = vld1q_f32(i0);
      const float32x4_t vi1 = vld1q_f32(i1);
      const float32x4_t vi2 = vld1q_f32(i2);
      const float32x4_t vi3 = vld1q_f32(i3);
      const float32x4_t vi4 = vld1q_f32(i4);
      const float32x4_t vi5 = vld1q_f32(i5);
      const float32x4_t vi6 = vld1q_f32(i6);
      const float32x4_t vi7 = vld1q_f32(i7);
      const float32x4_t vi8 = vld1q_f32(i8);

      const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
      const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
      const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
      const float32x4_t vsum67 = vaddq_f32(vi6, vi7);
      const float32x4_t vsum018 = vaddq_f32(vsum01, vi8);
      const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
      const float32x4_t vsum01678 = vaddq_f32(vsum018, vsum67);
      const float32x4_t vsum = vaddq_f32(vsum2345, vsum01678);

      float32x4_t vout = vmulq_f32(vsum, vmultiplier);
      vout = vmaxq_f32(vout, voutput_min);
      vout = vminq_f32(vout, voutput_max);

      float32x2_t vout_lo = vget_low_f32(vout);
      if (k & 2) {
        vst1_f32(output, vout_lo); output += 2;
        vout_lo = vget_high_f32(vout);
      }
      if (k & 1) {
        vst1_lane_f32(output, vout_lo, 0); output += 1;
      }
    }
    output = (float*) ((uintptr_t) output + output_increment);
  } while (--n != 0);
}
diff --git a/src/f32-pavgpool/up9-psimd.c b/src/f32-pavgpool/up9-psimd.c
new file mode 100644
index 0000000..49637b4
--- /dev/null
+++ b/src/f32-pavgpool/up9-psimd.c
@@ -0,0 +1,149 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/pavgpool.h>
+
+
+void xnn_f32_pavgpool_ukernel_up9__psimd(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** input,
+ const float* zero,
+ const float* multiplier,
+ float* output,
+ size_t input_increment,
+ size_t output_increment,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(n != 0);
+ assert(ks != 0);
+ assert(ks <= 9);
+ assert(kc != 0);
+
+ const psimd_f32 voutput_min = psimd_load_splat_f32(¶ms->scalar.min);
+ const psimd_f32 voutput_max = psimd_load_splat_f32(¶ms->scalar.max);
+
+ do {
+ const float* i0 = input[0];
+ const float* i1 = input[1];
+ const float* i2 = input[2];
+ const float* i3 = input[3];
+ const float* i4 = input[4];
+ const float* i5 = input[5];
+ const float* i6 = input[6];
+ const float* i7 = input[7];
+ const float* i8 = input[8];
+ input = (const float**) ((uintptr_t) input + input_increment);
+ if (ks < 2) {
+ i1 = zero;
+ }
+ if (ks <= 2) {
+ i2 = zero;
+ }
+ if (ks < 4) {
+ i3 = zero;
+ }
+ if (ks <= 4) {
+ i4 = zero;
+ }
+ if (ks < 6) {
+ i5 = zero;
+ }
+ if (ks <= 6) {
+ i6 = zero;
+ }
+ if (ks < 8) {
+ i7 = zero;
+ }
+ if (ks <= 8) {
+ i8 = zero;
+ }
+
+ const psimd_f32 vmultiplier = psimd_load_splat_f32(multiplier);
+ multiplier += 1;
+
+ size_t k = kc;
+ while (k >= 4) {
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ i0 += 4;
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ i1 += 4;
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ i2 += 4;
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ i3 += 4;
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ i4 += 4;
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ i5 += 4;
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ i6 += 4;
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ i7 += 4;
+ const psimd_f32 vi8 = psimd_load_f32(i8);
+ i8 += 4;
+
+ const psimd_f32 vsum018 = psimd_add_f32(psimd_add_f32(vi0, vi1), vi8);
+ const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+ const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+ const psimd_f32 vsum67 = psimd_add_f32(vi6, vi7);
+
+ const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+ const psimd_f32 vsum01678 = psimd_add_f32(vsum018, vsum67);
+ const psimd_f32 vsum = psimd_add_f32(vsum2345, vsum01678);
+
+ psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
+ vout = psimd_max_f32(vout, voutput_min);
+ vout = psimd_min_f32(vout, voutput_max);
+
+ psimd_store_f32(output, vout);
+ output += 4;
+
+ k -= 4;
+ }
+ if (k != 0) {
+ const psimd_f32 vi0 = psimd_load_f32(i0);
+ const psimd_f32 vi1 = psimd_load_f32(i1);
+ const psimd_f32 vi2 = psimd_load_f32(i2);
+ const psimd_f32 vi3 = psimd_load_f32(i3);
+ const psimd_f32 vi4 = psimd_load_f32(i4);
+ const psimd_f32 vi5 = psimd_load_f32(i5);
+ const psimd_f32 vi6 = psimd_load_f32(i6);
+ const psimd_f32 vi7 = psimd_load_f32(i7);
+ const psimd_f32 vi8 = psimd_load_f32(i8);
+
+ const psimd_f32 vsum01 = psimd_add_f32(vi0, vi1);
+ const psimd_f32 vsum23 = psimd_add_f32(vi2, vi3);
+ const psimd_f32 vsum45 = psimd_add_f32(vi4, vi5);
+ const psimd_f32 vsum67 = psimd_add_f32(vi6, vi7);
+ const psimd_f32 vsum018 = psimd_add_f32(vsum01, vi8);
+ const psimd_f32 vsum2345 = psimd_add_f32(vsum23, vsum45);
+ const psimd_f32 vsum01678 = psimd_add_f32(vsum018, vsum67);
+ const psimd_f32 vsum = psimd_add_f32(vsum2345, vsum01678);
+
+ psimd_f32 vout = psimd_mul_f32(vsum, vmultiplier);
+ vout = psimd_max_f32(vout, voutput_min);
+ vout = psimd_min_f32(vout, voutput_max);
+
+ if (k & 2) {
+ psimd_store2_f32(output, vout);
+ output += 2;
+ vout = psimd_concat_hi_f32(vout, vout);
+ }
+ if (k & 1) {
+ psimd_store1_f32(output, vout);
+ output += 1;
+ }
+ }
+ output = (float*) ((uintptr_t) output + output_increment);
+ } while (--n != 0);
+}
diff --git a/src/f32-pavgpool/up9-scalar.c b/src/f32-pavgpool/up9-scalar.c
new file mode 100644
index 0000000..1778d87
--- /dev/null
+++ b/src/f32-pavgpool/up9-scalar.c
@@ -0,0 +1,101 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/pavgpool.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_pavgpool_ukernel_up9__scalar(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** input,
+ const float* zero,
+ const float* multiplier,
+ float* output,
+ size_t input_increment,
+ size_t output_increment,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(n != 0);
+ assert(ks != 0);
+ assert(ks <= 9);
+ assert(kc != 0);
+
+ const float voutput_min = params->scalar.min;
+ const float voutput_max = params->scalar.max;
+
+ do {
+ const float* i0 = input[0];
+ const float* i1 = input[1];
+ const float* i2 = input[2];
+ const float* i3 = input[3];
+ const float* i4 = input[4];
+ const float* i5 = input[5];
+ const float* i6 = input[6];
+ const float* i7 = input[7];
+ const float* i8 = input[8];
+ input = (const float**) ((uintptr_t) input + input_increment);
+ if (ks < 2) {
+ i1 = zero;
+ }
+ if (ks <= 2) {
+ i2 = zero;
+ }
+ if (ks < 4) {
+ i3 = zero;
+ }
+ if (ks <= 4) {
+ i4 = zero;
+ }
+ if (ks < 6) {
+ i5 = zero;
+ }
+ if (ks <= 6) {
+ i6 = zero;
+ }
+ if (ks < 8) {
+ i7 = zero;
+ }
+ if (ks <= 8) {
+ i8 = zero;
+ }
+
+ const float vmultiplier = *multiplier++;
+
+ size_t k = kc;
+ do {
+ const float vi0 = *i0++;
+ const float vi1 = *i1++;
+ const float vi2 = *i2++;
+ const float vi3 = *i3++;
+ const float vi4 = *i4++;
+ const float vi5 = *i5++;
+ const float vi6 = *i6++;
+ const float vi7 = *i7++;
+ const float vi8 = *i8++;
+
+ const float vsum01 = vi0 + vi1;
+ const float vsum23 = vi2 + vi3;
+ const float vsum45 = vi4 + vi5;
+ const float vsum67 = vi6 + vi7;
+ const float vsum018 = vsum01 + vi8;
+ const float vsum2345 = vsum23 + vsum45;
+ const float vsum01678 = vsum018 + vsum67;
+ const float vsum = vsum2345 + vsum01678;
+
+ float vout = vsum * vmultiplier;
+ vout = math_max_f32(vout, voutput_min);
+ vout = math_min_f32(vout, voutput_max);
+
+ *output++ = vout;
+ } while (--k != 0);
+ output = (float*) ((uintptr_t) output + output_increment);
+ } while (--n != 0);
+}
diff --git a/src/f32-pavgpool/up9-sse.c b/src/f32-pavgpool/up9-sse.c
new file mode 100644
index 0000000..f10a613
--- /dev/null
+++ b/src/f32-pavgpool/up9-sse.c
@@ -0,0 +1,148 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/pavgpool.h>
+
+
// Pixelwise average pooling microkernel (SSE), unit-pass variant for pooling
// windows of up to 9 elements (ks <= 9). For each output pixel, sums up to 9
// input rows channel-wise, scales the sum by a per-pixel multiplier, clamps
// it to [min, max], and stores the result.
//
// n                - number of output pixels.
// ks               - pooling window size (rows per pixel); 1..9.
// kc               - number of channels; vectorized 4 at a time with a
//                    sub-vector tail.
// input            - indirection buffer of row pointers; 9 entries are read
//                    per pixel, with entries at index >= ks replaced by zero.
// zero             - zero-filled row used in place of out-of-window rows.
// multiplier       - one scale factor per output pixel (presumably the
//                    reciprocal of the pooling area -- confirm at call site).
// input_increment  - byte adjustment of the indirection pointer per pixel.
// output_increment - byte adjustment of `output` after each pixel.
// params           - clamping bounds (params->sse.min / params->sse.max).
void xnn_f32_pavgpool_ukernel_up9__sse(
    size_t n,
    size_t ks,
    size_t kc,
    const float** input,
    const float* zero,
    const float* multiplier,
    float* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(n != 0);
  assert(ks != 0);
  assert(ks <= 9);
  assert(kc != 0);

  const __m128 voutput_min = _mm_load_ps(params->sse.min);
  const __m128 voutput_max = _mm_load_ps(params->sse.max);

  do {
    const float* i0 = input[0];
    const float* i1 = input[1];
    const float* i2 = input[2];
    const float* i3 = input[3];
    const float* i4 = input[4];
    const float* i5 = input[5];
    const float* i6 = input[6];
    const float* i7 = input[7];
    const float* i8 = input[8];
    input = (const float**) ((uintptr_t) input + input_increment);
    // Rows beyond the pooling window (index >= ks) read from the zero row.
    if (ks < 2) {
      i1 = zero;
    }
    if (ks <= 2) {
      i2 = zero;
    }
    if (ks < 4) {
      i3 = zero;
    }
    if (ks <= 4) {
      i4 = zero;
    }
    if (ks < 6) {
      i5 = zero;
    }
    if (ks <= 6) {
      i6 = zero;
    }
    if (ks < 8) {
      i7 = zero;
    }
    if (ks <= 8) {
      i8 = zero;
    }

    const __m128 vmultiplier = _mm_load1_ps(multiplier);
    multiplier += 1;

    size_t k = kc;
    while (k >= 4) {
      const __m128 vi0 = _mm_loadu_ps(i0);
      i0 += 4;
      const __m128 vi1 = _mm_loadu_ps(i1);
      i1 += 4;
      const __m128 vi2 = _mm_loadu_ps(i2);
      i2 += 4;
      const __m128 vi3 = _mm_loadu_ps(i3);
      i3 += 4;
      const __m128 vi4 = _mm_loadu_ps(i4);
      i4 += 4;
      const __m128 vi5 = _mm_loadu_ps(i5);
      i5 += 4;
      const __m128 vi6 = _mm_loadu_ps(i6);
      i6 += 4;
      const __m128 vi7 = _mm_loadu_ps(i7);
      i7 += 4;
      const __m128 vi8 = _mm_loadu_ps(i8);
      i8 += 4;

      // Fixed-association addition tree; the order matches the other
      // pavgpool variants, keeping results reproducible across kernels.
      const __m128 vsum018 = _mm_add_ps(_mm_add_ps(vi0, vi1), vi8);
      const __m128 vsum23 = _mm_add_ps(vi2, vi3);
      const __m128 vsum45 = _mm_add_ps(vi4, vi5);
      const __m128 vsum67 = _mm_add_ps(vi6, vi7);

      const __m128 vsum2345 = _mm_add_ps(vsum23, vsum45);
      const __m128 vsum01678 = _mm_add_ps(vsum018, vsum67);
      const __m128 vsum = _mm_add_ps(vsum2345, vsum01678);

      // Scale by the per-pixel multiplier, then clamp.
      __m128 vout = _mm_mul_ps(vsum, vmultiplier);
      vout = _mm_max_ps(vout, voutput_min);
      vout = _mm_min_ps(vout, voutput_max);

      _mm_storeu_ps(output, vout); output += 4;

      k -= 4;
    }
    if (k != 0) {
      // Channel remainder (1-3): compute a full 4-wide result but store only
      // k lanes. The loads read 4 floats regardless of k; presumably the
      // rows carry readable padding -- confirm at the caller.
      const __m128 vi0 = _mm_loadu_ps(i0);
      const __m128 vi1 = _mm_loadu_ps(i1);
      const __m128 vi2 = _mm_loadu_ps(i2);
      const __m128 vi3 = _mm_loadu_ps(i3);
      const __m128 vi4 = _mm_loadu_ps(i4);
      const __m128 vi5 = _mm_loadu_ps(i5);
      const __m128 vi6 = _mm_loadu_ps(i6);
      const __m128 vi7 = _mm_loadu_ps(i7);
      const __m128 vi8 = _mm_loadu_ps(i8);

      const __m128 vsum01 = _mm_add_ps(vi0, vi1);
      const __m128 vsum23 = _mm_add_ps(vi2, vi3);
      const __m128 vsum45 = _mm_add_ps(vi4, vi5);
      const __m128 vsum67 = _mm_add_ps(vi6, vi7);
      const __m128 vsum018 = _mm_add_ps(vsum01, vi8);
      const __m128 vsum2345 = _mm_add_ps(vsum23, vsum45);
      const __m128 vsum01678 = _mm_add_ps(vsum018, vsum67);
      const __m128 vsum = _mm_add_ps(vsum2345, vsum01678);

      __m128 vout = _mm_mul_ps(vsum, vmultiplier);
      vout = _mm_max_ps(vout, voutput_min);
      vout = _mm_min_ps(vout, voutput_max);

      if (k & 2) {
        // Store the low 2 lanes, then shift the high lanes down.
        _mm_storel_pi((__m64*) output, vout);
        vout = _mm_movehl_ps(vout, vout);
        output += 2;
      }
      if (k & 1) {
        _mm_store_ss(output, vout);
        output += 1;
      }
    }
    output = (float*) ((uintptr_t) output + output_increment);
  } while (--n != 0);
}
diff --git a/src/f32-ppmm/2x4-scalar.c b/src/f32-ppmm/2x4-scalar.c
new file mode 100644
index 0000000..583262f
--- /dev/null
+++ b/src/f32-ppmm/2x4-scalar.c
@@ -0,0 +1,131 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-ppmm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/ppmm.h>
+
+
// 2x4 f32 PPMM (pre-packed matrix multiplication) microkernel, scalar
// variant. Computes a 2-row x 4-column tile of the output. Both operands are
// pre-packed: `a` interleaves the 2 LHS rows (2 floats per reduction step)
// and `w` starts with 4 bias values followed by 4 RHS values per reduction
// step.
//
// mr        - number of valid output rows (1..2); when mr != 2, row 1
//             aliases row 0 so its stores stay in bounds.
// nc        - number of output columns left to produce.
// kc        - length of the reduction dimension in BYTES (multiple of
//             sizeof(float)).
// cm_stride - byte stride between output rows.
// cn_stride - byte stride applied to each row pointer after a full tile.
// params    - output clamp bounds (params->scalar.min / params->scalar.max).
void xnn_f32_ppmm_ukernel_2x4__scalar(
    size_t mr,
    size_t nc,
    size_t kc,
    const float*restrict a,
    const float*restrict w,
    float*restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(mr != 0);
  assert(mr <= 2);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(float) == 0);

  float* c0 = c;
  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
  if XNN_UNPREDICTABLE(mr != 2) {
    c1 = c0;
  }

  do {
    // Initialize both rows' accumulators from the 4 shared bias values.
    float vacc0x0 = w[0];
    float vacc0x1 = w[1];
    float vacc0x2 = w[2];
    float vacc0x3 = w[3];
    float vacc1x0 = vacc0x0;
    float vacc1x1 = vacc0x1;
    float vacc1x2 = vacc0x2;
    float vacc1x3 = vacc0x3;
    w += 4;

    // Reduction loop: one packed LHS pair x 4 packed weights per step.
    size_t k = kc;
    do {
      const float va0 = a[0];
      const float va1 = a[1];
      a += 2;

      const float vb0 = w[0];
      const float vb1 = w[1];
      const float vb2 = w[2];
      const float vb3 = w[3];
      w += 4;

      vacc0x0 += va0 * vb0;
      vacc1x0 += va1 * vb0;
      vacc0x1 += va0 * vb1;
      vacc1x1 += va1 * vb1;
      vacc0x2 += va0 * vb2;
      vacc1x2 += va1 * vb2;
      vacc0x3 += va0 * vb3;
      vacc1x3 += va1 * vb3;

      k -= sizeof(float);
    } while (k != 0);

    // Clamp the tile to [min, max] (upper bound first, then lower bound).
    const float vmax = params->scalar.max;
    vacc0x0 = math_min_f32(vacc0x0, vmax);
    vacc1x0 = math_min_f32(vacc1x0, vmax);
    vacc0x1 = math_min_f32(vacc0x1, vmax);
    vacc1x1 = math_min_f32(vacc1x1, vmax);
    vacc0x2 = math_min_f32(vacc0x2, vmax);
    vacc1x2 = math_min_f32(vacc1x2, vmax);
    vacc0x3 = math_min_f32(vacc0x3, vmax);
    vacc1x3 = math_min_f32(vacc1x3, vmax);

    const float vmin = params->scalar.min;
    vacc0x0 = math_max_f32(vacc0x0, vmin);
    vacc1x0 = math_max_f32(vacc1x0, vmin);
    vacc0x1 = math_max_f32(vacc0x1, vmin);
    vacc1x1 = math_max_f32(vacc1x1, vmin);
    vacc0x2 = math_max_f32(vacc0x2, vmin);
    vacc1x2 = math_max_f32(vacc1x2, vmin);
    vacc0x3 = math_max_f32(vacc0x3, vmin);
    vacc1x3 = math_max_f32(vacc1x3, vmin);

    if XNN_LIKELY(nc >= 4) {
      // Full 4-column tile: store both rows and advance to the next tile.
      c1[0] = vacc1x0;
      c1[1] = vacc1x1;
      c1[2] = vacc1x2;
      c1[3] = vacc1x3;
      c0[0] = vacc0x0;
      c0[1] = vacc0x1;
      c0[2] = vacc0x2;
      c0[3] = vacc0x3;

      // Rewind the packed LHS panel: the loop advanced it by kc*2 bytes
      // (2 floats per sizeof(float) of kc), so the same panel is reused for
      // the next column tile.
      a = (const float*) ((uintptr_t) a - kc * 2);

      c1 = (float*) ((uintptr_t) c1 + cn_stride);
      c0 = (float*) ((uintptr_t) c0 + cn_stride);

      nc -= 4;
    } else {
      // Partial tile (nc = 1..3): store 2 columns if present, shift column 2
      // into position, then store the final single column if present.
      if (nc & 2) {
        c1[0] = vacc1x0;
        c1[1] = vacc1x1;
        c0[0] = vacc0x0;
        c0[1] = vacc0x1;

        vacc1x0 = vacc1x2;
        vacc0x0 = vacc0x2;

        c1 += 2;
        c0 += 2;
      }
      if (nc & 1) {
        *c1 = vacc1x0;
        *c0 = vacc0x0;
      }

      nc = 0;
    }
  } while (nc != 0);
}
diff --git a/src/f32-ppmm/3x3-scalar.c b/src/f32-ppmm/3x3-scalar.c
new file mode 100644
index 0000000..bc1f53b
--- /dev/null
+++ b/src/f32-ppmm/3x3-scalar.c
@@ -0,0 +1,146 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-ppmm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/ppmm.h>
+
+
// 3x3 f32 PPMM (pre-packed matrix multiplication) microkernel, scalar
// variant. Computes a 3-row x 3-column tile of the output. Both operands are
// pre-packed: `a` interleaves the 3 LHS rows (3 floats per reduction step)
// and `w` starts with 3 bias values followed by 3 RHS values per reduction
// step.
//
// mr        - number of valid output rows (1..3); out-of-range row pointers
//             alias the previous row so their stores stay in bounds.
// nc        - number of output columns left to produce.
// kc        - length of the reduction dimension in BYTES (multiple of
//             sizeof(float)).
// cm_stride - byte stride between output rows.
// cn_stride - byte stride applied to each row pointer after a full tile.
// params    - output clamp bounds (params->scalar.min / params->scalar.max).
void xnn_f32_ppmm_ukernel_3x3__scalar(
    size_t mr,
    size_t nc,
    size_t kc,
    const float*restrict a,
    const float*restrict w,
    float*restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(mr != 0);
  assert(mr <= 3);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(float) == 0);

  float* c0 = c;
  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 2) {
    c1 = c0;
  }
  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 2) {
    c2 = c1;
  }

  do {
    // Initialize all rows' accumulators from the 3 shared bias values.
    float vacc0x0 = w[0];
    float vacc0x1 = w[1];
    float vacc0x2 = w[2];
    float vacc1x0 = vacc0x0;
    float vacc1x1 = vacc0x1;
    float vacc1x2 = vacc0x2;
    float vacc2x0 = vacc0x0;
    float vacc2x1 = vacc0x1;
    float vacc2x2 = vacc0x2;
    w += 3;

    // Reduction loop: one packed LHS triple x 3 packed weights per step.
    size_t k = kc;
    do {
      const float va0 = a[0];
      const float va1 = a[1];
      const float va2 = a[2];
      a += 3;

      const float vb0 = w[0];
      const float vb1 = w[1];
      const float vb2 = w[2];
      w += 3;

      vacc0x0 += va0 * vb0;
      vacc1x0 += va1 * vb0;
      vacc2x0 += va2 * vb0;
      vacc0x1 += va0 * vb1;
      vacc1x1 += va1 * vb1;
      vacc2x1 += va2 * vb1;
      vacc0x2 += va0 * vb2;
      vacc1x2 += va1 * vb2;
      vacc2x2 += va2 * vb2;

      k -= sizeof(float);
    } while (k != 0);

    // Clamp the tile to [min, max] (upper bound first, then lower bound).
    const float vmax = params->scalar.max;
    vacc0x0 = math_min_f32(vacc0x0, vmax);
    vacc1x0 = math_min_f32(vacc1x0, vmax);
    vacc2x0 = math_min_f32(vacc2x0, vmax);
    vacc0x1 = math_min_f32(vacc0x1, vmax);
    vacc1x1 = math_min_f32(vacc1x1, vmax);
    vacc2x1 = math_min_f32(vacc2x1, vmax);
    vacc0x2 = math_min_f32(vacc0x2, vmax);
    vacc1x2 = math_min_f32(vacc1x2, vmax);
    vacc2x2 = math_min_f32(vacc2x2, vmax);

    const float vmin = params->scalar.min;
    vacc0x0 = math_max_f32(vacc0x0, vmin);
    vacc1x0 = math_max_f32(vacc1x0, vmin);
    vacc2x0 = math_max_f32(vacc2x0, vmin);
    vacc0x1 = math_max_f32(vacc0x1, vmin);
    vacc1x1 = math_max_f32(vacc1x1, vmin);
    vacc2x1 = math_max_f32(vacc2x1, vmin);
    vacc0x2 = math_max_f32(vacc0x2, vmin);
    vacc1x2 = math_max_f32(vacc1x2, vmin);
    vacc2x2 = math_max_f32(vacc2x2, vmin);

    if XNN_LIKELY(nc >= 3) {
      // Full 3-column tile: store all rows and advance to the next tile.
      c2[0] = vacc2x0;
      c2[1] = vacc2x1;
      c2[2] = vacc2x2;
      c1[0] = vacc1x0;
      c1[1] = vacc1x1;
      c1[2] = vacc1x2;
      c0[0] = vacc0x0;
      c0[1] = vacc0x1;
      c0[2] = vacc0x2;

      // Rewind the packed LHS panel (advanced by kc*3 bytes in the loop) so
      // the same panel is reused for the next column tile.
      a = (const float*) ((uintptr_t) a - kc * 3);

      c2 = (float*) ((uintptr_t) c2 + cn_stride);
      c1 = (float*) ((uintptr_t) c1 + cn_stride);
      c0 = (float*) ((uintptr_t) c0 + cn_stride);

      nc -= 3;
    } else {
      // Partial tile (nc = 1..2): store 2 columns if present, shift column 2
      // into position, then store the final single column if present.
      if (nc & 2) {
        c2[0] = vacc2x0;
        c2[1] = vacc2x1;
        c1[0] = vacc1x0;
        c1[1] = vacc1x1;
        c0[0] = vacc0x0;
        c0[1] = vacc0x1;

        vacc2x0 = vacc2x2;
        vacc1x0 = vacc1x2;
        vacc0x0 = vacc0x2;

        c2 += 2;
        c1 += 2;
        c0 += 2;
      }
      if (nc & 1) {
        *c2 = vacc2x0;
        *c1 = vacc1x0;
        *c0 = vacc0x0;
      }

      nc = 0;
    }
  } while (nc != 0);
}
diff --git a/src/f32-ppmm/4x2-scalar.c b/src/f32-ppmm/4x2-scalar.c
new file mode 100644
index 0000000..d236fd6
--- /dev/null
+++ b/src/f32-ppmm/4x2-scalar.c
@@ -0,0 +1,131 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-ppmm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/ppmm.h>
+
+
// 4x2 f32 PPMM (pre-packed matrix multiplication) microkernel, scalar
// variant. Computes a 4-row x 2-column tile of the output. Both operands are
// pre-packed: `a` interleaves the 4 LHS rows (4 floats per reduction step)
// and `w` starts with 2 bias values followed by 2 RHS values per reduction
// step.
//
// mr        - number of valid output rows (1..4); out-of-range row pointers
//             alias the previous row so their stores stay in bounds.
// nc        - number of output columns left to produce.
// kc        - length of the reduction dimension in BYTES (multiple of
//             sizeof(float)).
// cm_stride - byte stride between output rows.
// cn_stride - byte stride applied to each row pointer after a full tile.
// params    - output clamp bounds (params->scalar.min / params->scalar.max).
void xnn_f32_ppmm_ukernel_4x2__scalar(
    size_t mr,
    size_t nc,
    size_t kc,
    const float*restrict a,
    const float*restrict w,
    float*restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(mr != 0);
  assert(mr <= 4);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(float) == 0);

  float* c0 = c;
  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 2) {
    c1 = c0;
  }
  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 2) {
    c2 = c1;
  }
  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
  if XNN_UNPREDICTABLE(mr != 4) {
    c3 = c2;
  }

  do {
    // Initialize all rows' accumulators from the 2 shared bias values.
    float vacc0x0 = w[0];
    float vacc0x1 = w[1];
    float vacc1x0 = vacc0x0;
    float vacc1x1 = vacc0x1;
    float vacc2x0 = vacc0x0;
    float vacc2x1 = vacc0x1;
    float vacc3x0 = vacc0x0;
    float vacc3x1 = vacc0x1;
    w += 2;

    // Reduction loop: one packed LHS quad x 2 packed weights per step.
    size_t k = kc;
    do {
      const float va0 = a[0];
      const float va1 = a[1];
      const float va2 = a[2];
      const float va3 = a[3];
      a += 4;

      const float vb0 = w[0];
      const float vb1 = w[1];
      w += 2;

      vacc0x0 += va0 * vb0;
      vacc1x0 += va1 * vb0;
      vacc2x0 += va2 * vb0;
      vacc3x0 += va3 * vb0;
      vacc0x1 += va0 * vb1;
      vacc1x1 += va1 * vb1;
      vacc2x1 += va2 * vb1;
      vacc3x1 += va3 * vb1;

      k -= sizeof(float);
    } while (k != 0);

    // Clamp the tile to [min, max] (upper bound first, then lower bound).
    const float vmax = params->scalar.max;
    vacc0x0 = math_min_f32(vacc0x0, vmax);
    vacc1x0 = math_min_f32(vacc1x0, vmax);
    vacc2x0 = math_min_f32(vacc2x0, vmax);
    vacc3x0 = math_min_f32(vacc3x0, vmax);
    vacc0x1 = math_min_f32(vacc0x1, vmax);
    vacc1x1 = math_min_f32(vacc1x1, vmax);
    vacc2x1 = math_min_f32(vacc2x1, vmax);
    vacc3x1 = math_min_f32(vacc3x1, vmax);

    const float vmin = params->scalar.min;
    vacc0x0 = math_max_f32(vacc0x0, vmin);
    vacc1x0 = math_max_f32(vacc1x0, vmin);
    vacc2x0 = math_max_f32(vacc2x0, vmin);
    vacc3x0 = math_max_f32(vacc3x0, vmin);
    vacc0x1 = math_max_f32(vacc0x1, vmin);
    vacc1x1 = math_max_f32(vacc1x1, vmin);
    vacc2x1 = math_max_f32(vacc2x1, vmin);
    vacc3x1 = math_max_f32(vacc3x1, vmin);

    if XNN_LIKELY(nc >= 2) {
      // Full 2-column tile: store all rows and advance to the next tile.
      c3[0] = vacc3x0;
      c3[1] = vacc3x1;
      c2[0] = vacc2x0;
      c2[1] = vacc2x1;
      c1[0] = vacc1x0;
      c1[1] = vacc1x1;
      c0[0] = vacc0x0;
      c0[1] = vacc0x1;

      // Rewind the packed LHS panel (advanced by kc*4 bytes in the loop) so
      // the same panel is reused for the next column tile.
      a = (const float*) ((uintptr_t) a - kc * 4);

      c3 = (float*) ((uintptr_t) c3 + cn_stride);
      c2 = (float*) ((uintptr_t) c2 + cn_stride);
      c1 = (float*) ((uintptr_t) c1 + cn_stride);
      c0 = (float*) ((uintptr_t) c0 + cn_stride);

      nc -= 2;
    } else {
      // Partial tile: a single remaining column.
      if (nc & 1) {
        *c3 = vacc3x0;
        *c2 = vacc2x0;
        *c1 = vacc1x0;
        *c0 = vacc0x0;
      }

      nc = 0;
    }
  } while (nc != 0);
}
diff --git a/src/f32-ppmm/4x4-scalar.c b/src/f32-ppmm/4x4-scalar.c
new file mode 100644
index 0000000..f8df14c
--- /dev/null
+++ b/src/f32-ppmm/4x4-scalar.c
@@ -0,0 +1,193 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-ppmm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/ppmm.h>
+
+
+void xnn_f32_ppmm_ukernel_4x4__scalar(  // 4x4 f32 packed-panel matmul: C[mr x nc] += A*W, clamped to [min, max]
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is in bytes; must be whole floats
+
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;  // when mr < 2, alias row pointers so extra rows write over a valid row
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    c3 = c2;
+  }
+
+  do {
+    float vacc0x0 = w[0];  // all 4 rows start from the same per-column bias packed at w[0..3]
+    float vacc0x1 = w[1];
+    float vacc0x2 = w[2];
+    float vacc0x3 = w[3];
+    float vacc1x0 = vacc0x0;
+    float vacc1x1 = vacc0x1;
+    float vacc1x2 = vacc0x2;
+    float vacc1x3 = vacc0x3;
+    float vacc2x0 = vacc0x0;
+    float vacc2x1 = vacc0x1;
+    float vacc2x2 = vacc0x2;
+    float vacc2x3 = vacc0x3;
+    float vacc3x0 = vacc0x0;
+    float vacc3x1 = vacc0x1;
+    float vacc3x2 = vacc0x2;
+    float vacc3x3 = vacc0x3;
+    w += 4;
+
+    size_t k = kc;  // remaining K, counted in bytes
+    do {
+      const float va0 = a[0];  // A is pre-packed: 4 row values per K step
+      const float va1 = a[1];
+      const float va2 = a[2];
+      const float va3 = a[3];
+      a += 4;
+
+      const float vb0 = w[0];  // W is pre-packed: 4 column values per K step
+      const float vb1 = w[1];
+      const float vb2 = w[2];
+      const float vb3 = w[3];
+      w += 4;
+
+      vacc0x0 += va0 * vb0;
+      vacc1x0 += va1 * vb0;
+      vacc2x0 += va2 * vb0;
+      vacc3x0 += va3 * vb0;
+      vacc0x1 += va0 * vb1;
+      vacc1x1 += va1 * vb1;
+      vacc2x1 += va2 * vb1;
+      vacc3x1 += va3 * vb1;
+      vacc0x2 += va0 * vb2;
+      vacc1x2 += va1 * vb2;
+      vacc2x2 += va2 * vb2;
+      vacc3x2 += va3 * vb2;
+      vacc0x3 += va0 * vb3;
+      vacc1x3 += va1 * vb3;
+      vacc2x3 += va2 * vb3;
+      vacc3x3 += va3 * vb3;
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    const float vmax = params->scalar.max;  // clamp to output range (fused activation)
+    vacc0x0 = math_min_f32(vacc0x0, vmax);
+    vacc1x0 = math_min_f32(vacc1x0, vmax);
+    vacc2x0 = math_min_f32(vacc2x0, vmax);
+    vacc3x0 = math_min_f32(vacc3x0, vmax);
+    vacc0x1 = math_min_f32(vacc0x1, vmax);
+    vacc1x1 = math_min_f32(vacc1x1, vmax);
+    vacc2x1 = math_min_f32(vacc2x1, vmax);
+    vacc3x1 = math_min_f32(vacc3x1, vmax);
+    vacc0x2 = math_min_f32(vacc0x2, vmax);
+    vacc1x2 = math_min_f32(vacc1x2, vmax);
+    vacc2x2 = math_min_f32(vacc2x2, vmax);
+    vacc3x2 = math_min_f32(vacc3x2, vmax);
+    vacc0x3 = math_min_f32(vacc0x3, vmax);
+    vacc1x3 = math_min_f32(vacc1x3, vmax);
+    vacc2x3 = math_min_f32(vacc2x3, vmax);
+    vacc3x3 = math_min_f32(vacc3x3, vmax);
+
+    const float vmin = params->scalar.min;
+    vacc0x0 = math_max_f32(vacc0x0, vmin);
+    vacc1x0 = math_max_f32(vacc1x0, vmin);
+    vacc2x0 = math_max_f32(vacc2x0, vmin);
+    vacc3x0 = math_max_f32(vacc3x0, vmin);
+    vacc0x1 = math_max_f32(vacc0x1, vmin);
+    vacc1x1 = math_max_f32(vacc1x1, vmin);
+    vacc2x1 = math_max_f32(vacc2x1, vmin);
+    vacc3x1 = math_max_f32(vacc3x1, vmin);
+    vacc0x2 = math_max_f32(vacc0x2, vmin);
+    vacc1x2 = math_max_f32(vacc1x2, vmin);
+    vacc2x2 = math_max_f32(vacc2x2, vmin);
+    vacc3x2 = math_max_f32(vacc3x2, vmin);
+    vacc0x3 = math_max_f32(vacc0x3, vmin);
+    vacc1x3 = math_max_f32(vacc1x3, vmin);
+    vacc2x3 = math_max_f32(vacc2x3, vmin);
+    vacc3x3 = math_max_f32(vacc3x3, vmin);
+
+    if XNN_LIKELY(nc >= 4) {  // full 4-column tile: store and advance to the next tile
+      c3[0] = vacc3x0;
+      c3[1] = vacc3x1;
+      c3[2] = vacc3x2;
+      c3[3] = vacc3x3;
+      c2[0] = vacc2x0;
+      c2[1] = vacc2x1;
+      c2[2] = vacc2x2;
+      c2[3] = vacc2x3;
+      c1[0] = vacc1x0;
+      c1[1] = vacc1x1;
+      c1[2] = vacc1x2;
+      c1[3] = vacc1x3;
+      c0[0] = vacc0x0;
+      c0[1] = vacc0x1;
+      c0[2] = vacc0x2;
+      c0[3] = vacc0x3;
+
+      a = (const float*) ((uintptr_t) a - kc * 4);  // rewind A (advanced 4 floats per 4 bytes of k)
+
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      nc -= 4;
+    } else {  // remainder tile: store nc (1..3) columns, decomposed by bits of nc
+      if (nc & 2) {
+        c3[0] = vacc3x0;
+        c3[1] = vacc3x1;
+        c2[0] = vacc2x0;
+        c2[1] = vacc2x1;
+        c1[0] = vacc1x0;
+        c1[1] = vacc1x1;
+        c0[0] = vacc0x0;
+        c0[1] = vacc0x1;
+
+        vacc3x0 = vacc3x2;  // shift columns 2-3 down so the (nc & 1) case reads column index 0
+        vacc2x0 = vacc2x2;
+        vacc1x0 = vacc1x2;
+        vacc0x0 = vacc0x2;
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        *c3 = vacc3x0;
+        *c2 = vacc2x0;
+        *c1 = vacc1x0;
+        *c0 = vacc0x0;
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-ppmm/4x8-neon.c b/src/f32-ppmm/4x8-neon.c
new file mode 100644
index 0000000..ad55a43
--- /dev/null
+++ b/src/f32-ppmm/4x8-neon.c
@@ -0,0 +1,151 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-ppmm/neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/ppmm.h>
+
+
+void xnn_f32_ppmm_ukernel_4x8__neon(  // 4x8 f32 packed-panel matmul (NEON), output clamped to [min, max]
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is in bytes
+
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;  // alias unused row pointers so stores for mr < 4 stay in bounds
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    c3 = c2;
+  }
+
+  do {
+    float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;  // per-column bias, columns 0-3
+    float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;  // per-column bias, columns 4-7
+    float32x4_t vacc1x0123 = vacc0x0123;
+    float32x4_t vacc1x4567 = vacc0x4567;
+    float32x4_t vacc2x0123 = vacc0x0123;
+    float32x4_t vacc2x4567 = vacc0x4567;
+    float32x4_t vacc3x0123 = vacc0x0123;
+    float32x4_t vacc3x4567 = vacc0x4567;
+
+    size_t k = kc;  // remaining K, counted in bytes
+    do {
+      const float32x4_t va0123 = vld1q_f32(a); a += 4;  // one pre-packed A value per row
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123, vget_low_f32(va0123), 0);
+      vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123, vget_low_f32(va0123), 1);
+      vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123, vget_high_f32(va0123), 0);
+      vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123, vget_high_f32(va0123), 1);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567, vget_low_f32(va0123), 0);
+      vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567, vget_low_f32(va0123), 1);
+      vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567, vget_high_f32(va0123), 0);
+      vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567, vget_high_f32(va0123), 1);
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);  // broadcast clamp bounds
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {  // full 8-column tile
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float*) ((uintptr_t) a - kc * 4);  // rewind A for the next column tile
+
+      nc -= 8;
+    } else {  // remainder: store nc (1..7) columns via the bits of nc
+      if (nc & 4) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-ppmm/4x8-neonfma.c b/src/f32-ppmm/4x8-neonfma.c
new file mode 100644
index 0000000..04dd136
--- /dev/null
+++ b/src/f32-ppmm/4x8-neonfma.c
@@ -0,0 +1,167 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-ppmm/neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/ppmm.h>
+
+
+void xnn_f32_ppmm_ukernel_4x8__neonfma(  // 4x8 f32 packed-panel matmul (NEON FMA), output clamped to [min, max]
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is in bytes
+
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;  // alias unused row pointers so stores for mr < 4 stay in bounds
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    c3 = c2;
+  }
+
+  do {
+    float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;  // per-column bias, columns 0-3
+    float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;  // per-column bias, columns 4-7
+    float32x4_t vacc1x0123 = vacc0x0123;
+    float32x4_t vacc1x4567 = vacc0x4567;
+    float32x4_t vacc2x0123 = vacc0x0123;
+    float32x4_t vacc2x4567 = vacc0x4567;
+    float32x4_t vacc3x0123 = vacc0x0123;
+    float32x4_t vacc3x4567 = vacc0x4567;
+
+    size_t k = kc;  // remaining K, counted in bytes
+    do {
+      const float32x4_t va0123 = vld1q_f32(a); a += 4;  // one pre-packed A value per row
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+      #ifdef __aarch64__
+        vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123, va0123, 0);
+        vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123, va0123, 1);
+        vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123, va0123, 2);
+        vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123, va0123, 3);
+        vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567, va0123, 0);
+        vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567, va0123, 1);
+        vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567, va0123, 2);
+        vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567, va0123, 3);
+      #else
+        const float32x4_t va0000 = vdupq_lane_f32(vget_low_f32(va0123), 0);
+        const float32x4_t va1111 = vdupq_lane_f32(vget_low_f32(va0123), 1);
+        const float32x4_t va2222 = vdupq_lane_f32(vget_high_f32(va0123), 0);
+        const float32x4_t va3333 = vdupq_lane_f32(vget_high_f32(va0123), 1);
+
+        vacc0x0123 = vfmaq_f32(vacc0x0123, va0000, vb0123);
+        vacc1x0123 = vfmaq_f32(vacc1x0123, va1111, vb0123);
+        vacc2x0123 = vfmaq_f32(vacc2x0123, va2222, vb0123);
+        vacc3x0123 = vfmaq_f32(vacc3x0123, va3333, vb0123);
+        vacc0x4567 = vfmaq_f32(vacc0x4567, va0000, vb4567);
+        vacc1x4567 = vfmaq_f32(vacc1x4567, va1111, vb4567);
+        vacc2x4567 = vfmaq_f32(vacc2x4567, va2222, vb4567);
+        vacc3x4567 = vfmaq_f32(vacc3x4567, va3333, vb4567);
+      #endif
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);  // broadcast clamp bounds
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {  // full 8-column tile
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float*) ((uintptr_t) a - kc * 4);  // rewind A for the next column tile
+
+      nc -= 8;
+    } else {  // remainder: store nc (1..7) columns via the bits of nc
+      if (nc & 4) {
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-ppmm/4x8-psimd.c b/src/f32-ppmm/4x8-psimd.c
new file mode 100644
index 0000000..dbfca63
--- /dev/null
+++ b/src/f32-ppmm/4x8-psimd.c
@@ -0,0 +1,166 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-ppmm/psimd.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/ppmm.h>
+
+
+void xnn_f32_ppmm_ukernel_4x8__psimd(  // 4x8 f32 packed-panel matmul (psimd), output clamped to [min, max]
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is in bytes
+
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;  // alias unused row pointers so stores for mr < 4 stay in bounds
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    c3 = c2;
+  }
+
+  do {
+    psimd_f32 vacc0x0123 = psimd_load_f32(w);  // per-column bias, columns 0-3
+    psimd_f32 vacc0x4567 = psimd_load_f32(w + 4);  // per-column bias, columns 4-7
+    psimd_f32 vacc1x0123 = vacc0x0123;
+    psimd_f32 vacc1x4567 = vacc0x4567;
+    psimd_f32 vacc2x0123 = vacc0x0123;
+    psimd_f32 vacc2x4567 = vacc0x4567;
+    psimd_f32 vacc3x0123 = vacc0x0123;
+    psimd_f32 vacc3x4567 = vacc0x4567;
+    w += 8;
+
+    size_t k = kc;  // remaining K, counted in bytes
+    do {
+      const psimd_f32 va0123 = psimd_load_f32(a);  // one pre-packed A value per row
+      a += 4;
+
+      const psimd_f32 vb0123 = psimd_load_f32(w);
+      const psimd_f32 vb4567 = psimd_load_f32(w + 4);
+      w += 8;
+
+      const psimd_f32 va0000 = psimd_splat0_f32(va0123);
+      const psimd_f32 va1111 = psimd_splat1_f32(va0123);
+      const psimd_f32 va2222 = psimd_splat2_f32(va0123);
+      const psimd_f32 va3333 = psimd_splat3_f32(va0123);
+
+      vacc0x0123 = psimd_qfma_f32(vacc0x0123, va0000, vb0123);
+      vacc1x0123 = psimd_qfma_f32(vacc1x0123, va1111, vb0123);
+      vacc2x0123 = psimd_qfma_f32(vacc2x0123, va2222, vb0123);
+      vacc3x0123 = psimd_qfma_f32(vacc3x0123, va3333, vb0123);
+      vacc0x4567 = psimd_qfma_f32(vacc0x4567, va0000, vb4567);
+      vacc1x4567 = psimd_qfma_f32(vacc1x4567, va1111, vb4567);
+      vacc2x4567 = psimd_qfma_f32(vacc2x4567, va2222, vb4567);
+      vacc3x4567 = psimd_qfma_f32(vacc3x4567, va3333, vb4567);
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);  // broadcast clamp bounds
+    vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+    vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+    vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
+    vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
+    vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
+    vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
+    vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
+    vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
+
+    const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+    vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+    vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+    vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
+    vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
+    vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
+    vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
+    vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
+    vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {  // full 8-column tile
+      psimd_store_f32(c3, vacc3x0123);
+      psimd_store_f32(c3 + 4, vacc3x4567);
+      psimd_store_f32(c2, vacc2x0123);
+      psimd_store_f32(c2 + 4, vacc2x4567);
+      psimd_store_f32(c1, vacc1x0123);
+      psimd_store_f32(c1 + 4, vacc1x4567);
+      psimd_store_f32(c0, vacc0x0123);
+      psimd_store_f32(c0 + 4, vacc0x4567);
+
+      a = (const float*) ((uintptr_t) a - kc * 4);  // rewind A for the next column tile
+
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      nc -= 8;
+    } else {  // remainder: store nc (1..7) columns via the bits of nc
+      if (nc & 4) {
+        psimd_store_f32(c3, vacc3x0123);
+        psimd_store_f32(c2, vacc2x0123);
+        psimd_store_f32(c1, vacc1x0123);
+        psimd_store_f32(c0, vacc0x0123);
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        psimd_store2_f32(c3, vacc3x0123);
+        psimd_store2_f32(c2, vacc2x0123);
+        psimd_store2_f32(c1, vacc1x0123);
+        psimd_store2_f32(c0, vacc0x0123);
+
+        vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
+        vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
+        vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+        vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        psimd_store1_f32(c3, vacc3x0123);
+        psimd_store1_f32(c2, vacc2x0123);
+        psimd_store1_f32(c1, vacc1x0123);
+        psimd_store1_f32(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-ppmm/4x8-sse.c b/src/f32-ppmm/4x8-sse.c
new file mode 100644
index 0000000..31fa8dd
--- /dev/null
+++ b/src/f32-ppmm/4x8-sse.c
@@ -0,0 +1,166 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-ppmm/sse.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/ppmm.h>
+
+
+void xnn_f32_ppmm_ukernel_4x8__sse(  // 4x8 f32 packed-panel matmul (SSE), output clamped to [min, max]
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is in bytes
+
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;  // alias unused row pointers so stores for mr < 4 stay in bounds
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    c3 = c2;
+  }
+
+  do {
+    __m128 vacc0x0123 = _mm_load_ps(w);  // per-column bias; aligned load — assumes w is 16-byte aligned (packing convention) — TODO confirm
+    __m128 vacc0x4567 = _mm_load_ps(w + 4);
+    __m128 vacc1x0123 = vacc0x0123;
+    __m128 vacc1x4567 = vacc0x4567;
+    __m128 vacc2x0123 = vacc0x0123;
+    __m128 vacc2x4567 = vacc0x4567;
+    __m128 vacc3x0123 = vacc0x0123;
+    __m128 vacc3x4567 = vacc0x4567;
+    w += 8;
+
+    size_t k = kc;  // remaining K, counted in bytes
+    do {
+      const __m128 va0123 = _mm_load_ps(a);  // one pre-packed A value per row
+      a += 4;
+
+      const __m128 vb0123 = _mm_load_ps(w);
+      const __m128 vb4567 = _mm_load_ps(w + 4);
+      w += 8;
+
+      const __m128 va0000 = _mm_shuffle_ps(va0123, va0123, _MM_SHUFFLE(0, 0, 0, 0));
+      const __m128 va1111 = _mm_shuffle_ps(va0123, va0123, _MM_SHUFFLE(1, 1, 1, 1));
+      const __m128 va2222 = _mm_shuffle_ps(va0123, va0123, _MM_SHUFFLE(2, 2, 2, 2));
+      const __m128 va3333 = _mm_shuffle_ps(va0123, va0123, _MM_SHUFFLE(3, 3, 3, 3));
+
+      vacc0x0123 = _mm_add_ps(vacc0x0123, _mm_mul_ps(va0000, vb0123));
+      vacc1x0123 = _mm_add_ps(vacc1x0123, _mm_mul_ps(va1111, vb0123));
+      vacc2x0123 = _mm_add_ps(vacc2x0123, _mm_mul_ps(va2222, vb0123));
+      vacc3x0123 = _mm_add_ps(vacc3x0123, _mm_mul_ps(va3333, vb0123));
+      vacc0x4567 = _mm_add_ps(vacc0x4567, _mm_mul_ps(va0000, vb4567));
+      vacc1x4567 = _mm_add_ps(vacc1x4567, _mm_mul_ps(va1111, vb4567));
+      vacc2x4567 = _mm_add_ps(vacc2x4567, _mm_mul_ps(va2222, vb4567));
+      vacc3x4567 = _mm_add_ps(vacc3x4567, _mm_mul_ps(va3333, vb4567));
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    const __m128 vmax = _mm_load_ps(params->sse.max);  // clamp bounds pre-broadcast in params
+    vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+    vacc1x0123 = _mm_min_ps(vacc1x0123, vmax);
+    vacc2x0123 = _mm_min_ps(vacc2x0123, vmax);
+    vacc3x0123 = _mm_min_ps(vacc3x0123, vmax);
+    vacc0x4567 = _mm_min_ps(vacc0x4567, vmax);
+    vacc1x4567 = _mm_min_ps(vacc1x4567, vmax);
+    vacc2x4567 = _mm_min_ps(vacc2x4567, vmax);
+    vacc3x4567 = _mm_min_ps(vacc3x4567, vmax);
+
+    const __m128 vmin = _mm_load_ps(params->sse.min);
+    vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+    vacc1x0123 = _mm_max_ps(vacc1x0123, vmin);
+    vacc2x0123 = _mm_max_ps(vacc2x0123, vmin);
+    vacc3x0123 = _mm_max_ps(vacc3x0123, vmin);
+    vacc0x4567 = _mm_max_ps(vacc0x4567, vmin);
+    vacc1x4567 = _mm_max_ps(vacc1x4567, vmin);
+    vacc2x4567 = _mm_max_ps(vacc2x4567, vmin);
+    vacc3x4567 = _mm_max_ps(vacc3x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {  // full 8-column tile; unaligned stores since c strides are arbitrary
+      _mm_storeu_ps(c3, vacc3x0123);
+      _mm_storeu_ps(c3 + 4, vacc3x4567);
+      _mm_storeu_ps(c2, vacc2x0123);
+      _mm_storeu_ps(c2 + 4, vacc2x4567);
+      _mm_storeu_ps(c1, vacc1x0123);
+      _mm_storeu_ps(c1 + 4, vacc1x4567);
+      _mm_storeu_ps(c0, vacc0x0123);
+      _mm_storeu_ps(c0 + 4, vacc0x4567);
+
+      a = (const float*) ((uintptr_t) a - kc * 4);  // rewind A for the next column tile
+
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      nc -= 8;
+    } else {  // remainder: store nc (1..7) columns via the bits of nc
+      if (nc & 4) {
+        _mm_storeu_ps(c3, vacc3x0123);
+        _mm_storeu_ps(c2, vacc2x0123);
+        _mm_storeu_ps(c1, vacc1x0123);
+        _mm_storeu_ps(c0, vacc0x0123);
+
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+
+        c3 += 4;
+        c2 += 4;
+        c1 += 4;
+        c0 += 4;
+      }
+      if (nc & 2) {
+        _mm_storel_pi((__m64*) c3, vacc3x0123);
+        _mm_storel_pi((__m64*) c2, vacc2x0123);
+        _mm_storel_pi((__m64*) c1, vacc1x0123);
+        _mm_storel_pi((__m64*) c0, vacc0x0123);
+
+        vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);  // shift upper pair into lower lanes for (nc & 1)
+        vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
+        vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
+        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+
+        c3 += 2;
+        c2 += 2;
+        c1 += 2;
+        c0 += 2;
+      }
+      if (nc & 1) {
+        _mm_store_ss(c3, vacc3x0123);
+        _mm_store_ss(c2, vacc2x0123);
+        _mm_store_ss(c1, vacc1x0123);
+        _mm_store_ss(c0, vacc0x0123);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-ppmm/8x8-neon.c b/src/f32-ppmm/8x8-neon.c
new file mode 100644
index 0000000..06a8b6d
--- /dev/null
+++ b/src/f32-ppmm/8x8-neon.c
@@ -0,0 +1,236 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-ppmm/neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/ppmm.h>
+
+
+void xnn_f32_ppmm_ukernel_8x8__neon(  // 8x8 f32 packed-panel matmul (NEON), output clamped to [min, max]
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const float*restrict a,
+    const float*restrict w,
+    float*restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 8);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(float) == 0);  // kc is in bytes
+
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;  // alias unused row pointers so stores for mr < 8 stay in bounds
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 4) {
+    c3 = c2;
+  }
+  float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 4) {
+    c4 = c3;
+  }
+  float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 6) {
+    c5 = c4;
+  }
+  float* c6 = (float*) ((uintptr_t) c5 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 6) {
+    c6 = c5;
+  }
+  float* c7 = (float*) ((uintptr_t) c6 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 8) {
+    c7 = c6;
+  }
+
+  do {
+    float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;  // per-column bias, columns 0-3
+    float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;  // per-column bias, columns 4-7
+    float32x4_t vacc1x0123 = vacc0x0123;
+    float32x4_t vacc1x4567 = vacc0x4567;
+    float32x4_t vacc2x0123 = vacc0x0123;
+    float32x4_t vacc2x4567 = vacc0x4567;
+    float32x4_t vacc3x0123 = vacc0x0123;
+    float32x4_t vacc3x4567 = vacc0x4567;
+    float32x4_t vacc4x0123 = vacc0x0123;
+    float32x4_t vacc4x4567 = vacc0x4567;
+    float32x4_t vacc5x0123 = vacc0x0123;
+    float32x4_t vacc5x4567 = vacc0x4567;
+    float32x4_t vacc6x0123 = vacc0x0123;
+    float32x4_t vacc6x4567 = vacc0x4567;
+    float32x4_t vacc7x0123 = vacc0x0123;
+    float32x4_t vacc7x4567 = vacc0x4567;
+
+    size_t k = kc;  // remaining K, counted in bytes
+    do {
+      const float32x4_t va0123 = vld1q_f32(a); a += 4;  // pre-packed A values for rows 0-3
+      const float32x4_t va4567 = vld1q_f32(a); a += 4;  // pre-packed A values for rows 4-7
+
+      const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+      const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123, vget_low_f32(va0123), 0);
+      vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123, vget_low_f32(va0123), 1);
+      vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123, vget_high_f32(va0123), 0);
+      vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123, vget_high_f32(va0123), 1);
+      vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123, vget_low_f32(va4567), 0);
+      vacc5x0123 = vmlaq_lane_f32(vacc5x0123, vb0123, vget_low_f32(va4567), 1);
+      vacc6x0123 = vmlaq_lane_f32(vacc6x0123, vb0123, vget_high_f32(va4567), 0);
+      vacc7x0123 = vmlaq_lane_f32(vacc7x0123, vb0123, vget_high_f32(va4567), 1);
+      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567, vget_low_f32(va0123), 0);
+      vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567, vget_low_f32(va0123), 1);
+      vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567, vget_high_f32(va0123), 0);
+      vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567, vget_high_f32(va0123), 1);
+      vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567, vget_low_f32(va4567), 0);
+      vacc5x4567 = vmlaq_lane_f32(vacc5x4567, vb4567, vget_low_f32(va4567), 1);
+      vacc6x4567 = vmlaq_lane_f32(vacc6x4567, vb4567, vget_high_f32(va4567), 0);
+      vacc7x4567 = vmlaq_lane_f32(vacc7x4567, vb4567, vget_high_f32(va4567), 1);
+
+      k -= sizeof(float);
+    } while (k != 0);
+
+    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);  // broadcast clamp bounds
+    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+    vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+    vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+    vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+    vacc4x0123 = vminq_f32(vacc4x0123, vmax);
+    vacc5x0123 = vminq_f32(vacc5x0123, vmax);
+    vacc6x0123 = vminq_f32(vacc6x0123, vmax);
+    vacc7x0123 = vminq_f32(vacc7x0123, vmax);
+    vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+    vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+    vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+    vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+    vacc4x4567 = vminq_f32(vacc4x4567, vmax);
+    vacc5x4567 = vminq_f32(vacc5x4567, vmax);
+    vacc6x4567 = vminq_f32(vacc6x4567, vmax);
+    vacc7x4567 = vminq_f32(vacc7x4567, vmax);
+
+    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+    vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+    vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+    vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+    vacc4x0123 = vmaxq_f32(vacc4x0123, vmin);
+    vacc5x0123 = vmaxq_f32(vacc5x0123, vmin);
+    vacc6x0123 = vmaxq_f32(vacc6x0123, vmin);
+    vacc7x0123 = vmaxq_f32(vacc7x0123, vmin);
+    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+    vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+    vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+    vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+    vacc4x4567 = vmaxq_f32(vacc4x4567, vmin);
+    vacc5x4567 = vmaxq_f32(vacc5x4567, vmin);
+    vacc6x4567 = vmaxq_f32(vacc6x4567, vmin);
+    vacc7x4567 = vmaxq_f32(vacc7x4567, vmin);
+
+    if XNN_LIKELY(nc >= 8) {  // full 8-column tile
+      vst1q_f32(c7, vacc7x0123);
+      vst1q_f32(c7 + 4, vacc7x4567);
+      c7 = (float*) ((uintptr_t) c7 + cn_stride);
+      vst1q_f32(c6, vacc6x0123);
+      vst1q_f32(c6 + 4, vacc6x4567);
+      c6 = (float*) ((uintptr_t) c6 + cn_stride);
+      vst1q_f32(c5, vacc5x0123);
+      vst1q_f32(c5 + 4, vacc5x4567);
+      c5 = (float*) ((uintptr_t) c5 + cn_stride);
+      vst1q_f32(c4, vacc4x0123);
+      vst1q_f32(c4 + 4, vacc4x4567);
+      c4 = (float*) ((uintptr_t) c4 + cn_stride);
+      vst1q_f32(c3, vacc3x0123);
+      vst1q_f32(c3 + 4, vacc3x4567);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+      vst1q_f32(c2, vacc2x0123);
+      vst1q_f32(c2 + 4, vacc2x4567);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      vst1q_f32(c1, vacc1x0123);
+      vst1q_f32(c1 + 4, vacc1x4567);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      vst1q_f32(c0, vacc0x0123);
+      vst1q_f32(c0 + 4, vacc0x4567);
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      a = (const float*) ((uintptr_t) a - kc * 8);  // rewind A (8 packed rows per K step)
+
+      nc -= 8;
+    } else {  // remainder: store nc (1..7) columns via the bits of nc
+      if (nc & 4) {
+        vst1q_f32(c7, vacc7x0123); c7 += 4;
+        vst1q_f32(c6, vacc6x0123); c6 += 4;
+        vst1q_f32(c5, vacc5x0123); c5 += 4;
+        vst1q_f32(c4, vacc4x0123); c4 += 4;
+        vst1q_f32(c3, vacc3x0123); c3 += 4;
+        vst1q_f32(c2, vacc2x0123); c2 += 4;
+        vst1q_f32(c1, vacc1x0123); c1 += 4;
+        vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+        vacc7x0123 = vacc7x4567;
+        vacc6x0123 = vacc6x4567;
+        vacc5x0123 = vacc5x4567;
+        vacc4x0123 = vacc4x4567;
+        vacc3x0123 = vacc3x4567;
+        vacc2x0123 = vacc2x4567;
+        vacc1x0123 = vacc1x4567;
+        vacc0x0123 = vacc0x4567;
+      }
+      float32x2_t vacc7x01 = vget_low_f32(vacc7x0123);
+      float32x2_t vacc6x01 = vget_low_f32(vacc6x0123);
+      float32x2_t vacc5x01 = vget_low_f32(vacc5x0123);
+      float32x2_t vacc4x01 = vget_low_f32(vacc4x0123);
+      float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+      float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      if (nc & 2) {
+        vst1_f32(c7, vacc7x01); c7 += 2;
+        vst1_f32(c6, vacc6x01); c6 += 2;
+        vst1_f32(c5, vacc5x01); c5 += 2;
+        vst1_f32(c4, vacc4x01); c4 += 2;
+        vst1_f32(c3, vacc3x01); c3 += 2;
+        vst1_f32(c2, vacc2x01); c2 += 2;
+        vst1_f32(c1, vacc1x01); c1 += 2;
+        vst1_f32(c0, vacc0x01); c0 += 2;
+
+        vacc7x01 = vget_high_f32(vacc7x0123);
+        vacc6x01 = vget_high_f32(vacc6x0123);
+        vacc5x01 = vget_high_f32(vacc5x0123);
+        vacc4x01 = vget_high_f32(vacc4x0123);
+        vacc3x01 = vget_high_f32(vacc3x0123);
+        vacc2x01 = vget_high_f32(vacc2x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+        vacc0x01 = vget_high_f32(vacc0x0123);
+      }
+      if (nc & 1) {
+        vst1_lane_f32(c7, vacc7x01, 0);
+        vst1_lane_f32(c6, vacc6x01, 0);
+        vst1_lane_f32(c5, vacc5x01, 0);
+        vst1_lane_f32(c4, vacc4x01, 0);
+        vst1_lane_f32(c3, vacc3x01, 0);
+        vst1_lane_f32(c2, vacc2x01, 0);
+        vst1_lane_f32(c1, vacc1x01, 0);
+        vst1_lane_f32(c0, vacc0x01, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/f32-ppmm/8x8-neonfma.c b/src/f32-ppmm/8x8-neonfma.c
new file mode 100644
index 0000000..0eb67af
--- /dev/null
+++ b/src/f32-ppmm/8x8-neonfma.c
@@ -0,0 +1,264 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-ppmm/neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/ppmm.h>
+
+
+void xnn_f32_ppmm_ukernel_8x8__neonfma(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float*restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{  // 8x8 f32 PPMM microkernel (NEON FMA): C = clamp(A * W) on a packed 8-row panel
+ assert(mr != 0);
+ assert(mr <= 8);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+
+ float* c0 = c;  // output row pointers; rows beyond mr alias the previous row so stores stay in-bounds
+ float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ c3 = c2;
+ }
+ float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ c4 = c3;
+ }
+ float* c5 = (float*) ((uintptr_t) c4 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 6) {
+ c5 = c4;
+ }
+ float* c6 = (float*) ((uintptr_t) c5 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 6) {
+ c6 = c5;
+ }
+ float* c7 = (float*) ((uintptr_t) c6 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 8) {
+ c7 = c6;
+ }
+
+ do {
+ float32x4_t vacc0x0123 = vld1q_f32(w); w += 4;  // accumulators start from packed bias in w
+ float32x4_t vacc0x4567 = vld1q_f32(w); w += 4;
+ float32x4_t vacc1x0123 = vacc0x0123;
+ float32x4_t vacc1x4567 = vacc0x4567;
+ float32x4_t vacc2x0123 = vacc0x0123;
+ float32x4_t vacc2x4567 = vacc0x4567;
+ float32x4_t vacc3x0123 = vacc0x0123;
+ float32x4_t vacc3x4567 = vacc0x4567;
+ float32x4_t vacc4x0123 = vacc0x0123;
+ float32x4_t vacc4x4567 = vacc0x4567;
+ float32x4_t vacc5x0123 = vacc0x0123;
+ float32x4_t vacc5x4567 = vacc0x4567;
+ float32x4_t vacc6x0123 = vacc0x0123;
+ float32x4_t vacc6x4567 = vacc0x4567;
+ float32x4_t vacc7x0123 = vacc0x0123;
+ float32x4_t vacc7x4567 = vacc0x4567;
+
+ size_t k = kc;  // bytes of K remaining; one float of each row per iteration
+ do {
+ const float32x4_t va0123 = vld1q_f32(a); a += 4;
+ const float32x4_t va4567 = vld1q_f32(a); a += 4;
+
+ const float32x4_t vb0123 = vld1q_f32(w); w += 4;
+ const float32x4_t vb4567 = vld1q_f32(w); w += 4;
+
+ #ifdef __aarch64__
+ vacc0x0123 = vfmaq_laneq_f32(vacc0x0123, vb0123, va0123, 0);
+ vacc1x0123 = vfmaq_laneq_f32(vacc1x0123, vb0123, va0123, 1);
+ vacc2x0123 = vfmaq_laneq_f32(vacc2x0123, vb0123, va0123, 2);
+ vacc3x0123 = vfmaq_laneq_f32(vacc3x0123, vb0123, va0123, 3);
+ vacc4x0123 = vfmaq_laneq_f32(vacc4x0123, vb0123, va4567, 0);
+ vacc5x0123 = vfmaq_laneq_f32(vacc5x0123, vb0123, va4567, 1);
+ vacc6x0123 = vfmaq_laneq_f32(vacc6x0123, vb0123, va4567, 2);
+ vacc7x0123 = vfmaq_laneq_f32(vacc7x0123, vb0123, va4567, 3);
+ vacc0x4567 = vfmaq_laneq_f32(vacc0x4567, vb4567, va0123, 0);
+ vacc1x4567 = vfmaq_laneq_f32(vacc1x4567, vb4567, va0123, 1);
+ vacc2x4567 = vfmaq_laneq_f32(vacc2x4567, vb4567, va0123, 2);
+ vacc3x4567 = vfmaq_laneq_f32(vacc3x4567, vb4567, va0123, 3);
+ vacc4x4567 = vfmaq_laneq_f32(vacc4x4567, vb4567, va4567, 0);
+ vacc5x4567 = vfmaq_laneq_f32(vacc5x4567, vb4567, va4567, 1);
+ vacc6x4567 = vfmaq_laneq_f32(vacc6x4567, vb4567, va4567, 2);
+ vacc7x4567 = vfmaq_laneq_f32(vacc7x4567, vb4567, va4567, 3);
+ #else
+ const float32x4_t va0000 = vdupq_lane_f32(vget_low_f32(va0123), 0);
+ const float32x4_t va1111 = vdupq_lane_f32(vget_low_f32(va0123), 1);
+ const float32x4_t va2222 = vdupq_lane_f32(vget_high_f32(va0123), 0);
+ const float32x4_t va3333 = vdupq_lane_f32(vget_high_f32(va0123), 1);
+ const float32x4_t va4444 = vdupq_lane_f32(vget_low_f32(va4567), 0);
+ const float32x4_t va5555 = vdupq_lane_f32(vget_low_f32(va4567), 1);
+ const float32x4_t va6666 = vdupq_lane_f32(vget_high_f32(va4567), 0);
+ const float32x4_t va7777 = vdupq_lane_f32(vget_high_f32(va4567), 1);
+
+ vacc0x0123 = vfmaq_f32(vacc0x0123, va0000, vb0123);
+ vacc1x0123 = vfmaq_f32(vacc1x0123, va1111, vb0123);
+ vacc2x0123 = vfmaq_f32(vacc2x0123, va2222, vb0123);
+ vacc3x0123 = vfmaq_f32(vacc3x0123, va3333, vb0123);
+ vacc4x0123 = vfmaq_f32(vacc4x0123, va4444, vb0123);
+ vacc5x0123 = vfmaq_f32(vacc5x0123, va5555, vb0123);
+ vacc6x0123 = vfmaq_f32(vacc6x0123, va6666, vb0123);
+ vacc7x0123 = vfmaq_f32(vacc7x0123, va7777, vb0123);
+ vacc0x4567 = vfmaq_f32(vacc0x4567, va0000, vb4567);
+ vacc1x4567 = vfmaq_f32(vacc1x4567, va1111, vb4567);
+ vacc2x4567 = vfmaq_f32(vacc2x4567, va2222, vb4567);
+ vacc3x4567 = vfmaq_f32(vacc3x4567, va3333, vb4567);
+ vacc4x4567 = vfmaq_f32(vacc4x4567, va4444, vb4567);
+ vacc5x4567 = vfmaq_f32(vacc5x4567, va5555, vb4567);
+ vacc6x4567 = vfmaq_f32(vacc6x4567, va6666, vb4567);
+ vacc7x4567 = vfmaq_f32(vacc7x4567, va7777, vb4567);
+ #endif
+
+ k -= sizeof(float);
+ } while (k != 0);
+
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);  // fixed: '&params' was mojibake '¶ms'
+ vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+ vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+ vacc2x0123 = vminq_f32(vacc2x0123, vmax);
+ vacc3x0123 = vminq_f32(vacc3x0123, vmax);
+ vacc4x0123 = vminq_f32(vacc4x0123, vmax);
+ vacc5x0123 = vminq_f32(vacc5x0123, vmax);
+ vacc6x0123 = vminq_f32(vacc6x0123, vmax);
+ vacc7x0123 = vminq_f32(vacc7x0123, vmax);
+ vacc0x4567 = vminq_f32(vacc0x4567, vmax);
+ vacc1x4567 = vminq_f32(vacc1x4567, vmax);
+ vacc2x4567 = vminq_f32(vacc2x4567, vmax);
+ vacc3x4567 = vminq_f32(vacc3x4567, vmax);
+ vacc4x4567 = vminq_f32(vacc4x4567, vmax);
+ vacc5x4567 = vminq_f32(vacc5x4567, vmax);
+ vacc6x4567 = vminq_f32(vacc6x4567, vmax);
+ vacc7x4567 = vminq_f32(vacc7x4567, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);  // fixed: '&params' was mojibake '¶ms'
+ vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+ vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+ vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
+ vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
+ vacc4x0123 = vmaxq_f32(vacc4x0123, vmin);
+ vacc5x0123 = vmaxq_f32(vacc5x0123, vmin);
+ vacc6x0123 = vmaxq_f32(vacc6x0123, vmin);
+ vacc7x0123 = vmaxq_f32(vacc7x0123, vmin);
+ vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
+ vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
+ vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
+ vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
+ vacc4x4567 = vmaxq_f32(vacc4x4567, vmin);
+ vacc5x4567 = vmaxq_f32(vacc5x4567, vmin);
+ vacc6x4567 = vmaxq_f32(vacc6x4567, vmin);
+ vacc7x4567 = vmaxq_f32(vacc7x4567, vmin);
+
+ if XNN_LIKELY(nc >= 8) {  // full 8-column tile
+ vst1q_f32(c7, vacc7x0123);
+ vst1q_f32(c7 + 4, vacc7x4567);
+ c7 = (float*) ((uintptr_t) c7 + cn_stride);
+ vst1q_f32(c6, vacc6x0123);
+ vst1q_f32(c6 + 4, vacc6x4567);
+ c6 = (float*) ((uintptr_t) c6 + cn_stride);
+ vst1q_f32(c5, vacc5x0123);
+ vst1q_f32(c5 + 4, vacc5x4567);
+ c5 = (float*) ((uintptr_t) c5 + cn_stride);
+ vst1q_f32(c4, vacc4x0123);
+ vst1q_f32(c4 + 4, vacc4x4567);
+ c4 = (float*) ((uintptr_t) c4 + cn_stride);
+ vst1q_f32(c3, vacc3x0123);
+ vst1q_f32(c3 + 4, vacc3x4567);
+ c3 = (float*) ((uintptr_t) c3 + cn_stride);
+ vst1q_f32(c2, vacc2x0123);
+ vst1q_f32(c2 + 4, vacc2x4567);
+ c2 = (float*) ((uintptr_t) c2 + cn_stride);
+ vst1q_f32(c1, vacc1x0123);
+ vst1q_f32(c1 + 4, vacc1x4567);
+ c1 = (float*) ((uintptr_t) c1 + cn_stride);
+ vst1q_f32(c0, vacc0x0123);
+ vst1q_f32(c0 + 4, vacc0x4567);
+ c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const float*) ((uintptr_t) a - kc * 8);  // rewind packed A for the next column tile
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_f32(c7, vacc7x0123); c7 += 4;
+ vst1q_f32(c6, vacc6x0123); c6 += 4;
+ vst1q_f32(c5, vacc5x0123); c5 += 4;
+ vst1q_f32(c4, vacc4x0123); c4 += 4;
+ vst1q_f32(c3, vacc3x0123); c3 += 4;
+ vst1q_f32(c2, vacc2x0123); c2 += 4;
+ vst1q_f32(c1, vacc1x0123); c1 += 4;
+ vst1q_f32(c0, vacc0x0123); c0 += 4;
+
+ vacc7x0123 = vacc7x4567;
+ vacc6x0123 = vacc6x4567;
+ vacc5x0123 = vacc5x4567;
+ vacc4x0123 = vacc4x4567;
+ vacc3x0123 = vacc3x4567;
+ vacc2x0123 = vacc2x4567;
+ vacc1x0123 = vacc1x4567;
+ vacc0x0123 = vacc0x4567;
+ }
+ float32x2_t vacc7x01 = vget_low_f32(vacc7x0123);
+ float32x2_t vacc6x01 = vget_low_f32(vacc6x0123);
+ float32x2_t vacc5x01 = vget_low_f32(vacc5x0123);
+ float32x2_t vacc4x01 = vget_low_f32(vacc4x0123);
+ float32x2_t vacc3x01 = vget_low_f32(vacc3x0123);
+ float32x2_t vacc2x01 = vget_low_f32(vacc2x0123);
+ float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+ float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+ if (nc & 2) {
+ vst1_f32(c7, vacc7x01); c7 += 2;
+ vst1_f32(c6, vacc6x01); c6 += 2;
+ vst1_f32(c5, vacc5x01); c5 += 2;
+ vst1_f32(c4, vacc4x01); c4 += 2;
+ vst1_f32(c3, vacc3x01); c3 += 2;
+ vst1_f32(c2, vacc2x01); c2 += 2;
+ vst1_f32(c1, vacc1x01); c1 += 2;
+ vst1_f32(c0, vacc0x01); c0 += 2;
+
+ vacc7x01 = vget_high_f32(vacc7x0123);
+ vacc6x01 = vget_high_f32(vacc6x0123);
+ vacc5x01 = vget_high_f32(vacc5x0123);
+ vacc4x01 = vget_high_f32(vacc4x0123);
+ vacc3x01 = vget_high_f32(vacc3x0123);
+ vacc2x01 = vget_high_f32(vacc2x0123);
+ vacc1x01 = vget_high_f32(vacc1x0123);
+ vacc0x01 = vget_high_f32(vacc0x0123);
+ }
+ if (nc & 1) {
+ vst1_lane_f32(c7, vacc7x01, 0);
+ vst1_lane_f32(c6, vacc6x01, 0);
+ vst1_lane_f32(c5, vacc5x01, 0);
+ vst1_lane_f32(c4, vacc4x01, 0);
+ vst1_lane_f32(c3, vacc3x01, 0);
+ vst1_lane_f32(c2, vacc2x01, 0);
+ vst1_lane_f32(c1, vacc1x01, 0);
+ vst1_lane_f32(c0, vacc0x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-ppmm/neon.c.in b/src/f32-ppmm/neon.c.in
new file mode 100644
index 0000000..4c66811
--- /dev/null
+++ b/src/f32-ppmm/neon.c.in
@@ -0,0 +1,138 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert MR % 4 == 0
+$assert NR % 4 == 0
+$IDLETTERS = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/ppmm.h>
+
+
+void xnn_f32_ppmm_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float*restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= ${MR});
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+
+ float* c0 = c;
+ $for M in range(1, MR):
+ float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(mr <= ${M}) {
+ c${M} = c${M-1};
+ }
+ $elif M + 1 == MR:
+ if XNN_UNPREDICTABLE(mr != ${M+1}) {
+ c${M} = c${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(mr < ${M+1}) {
+ c${M} = c${M-1};
+ }
+
+ do {
+ $for N in range(0, NR, 4):
+ float32x4_t vacc0x${IDLETTERS[N:N+4]} = vld1q_f32(w); w += 4;
+ $for M in range(1, MR):
+ $for N in range(0, NR, 4):
+ float32x4_t vacc${M}x${IDLETTERS[N:N+4]} = vacc0x${IDLETTERS[N:N+4]};
+
+ size_t k = kc;
+ do {
+ $for M in range(0, MR, 4):
+ const float32x4_t va${IDLETTERS[M:M+4]} = vld1q_f32(a); a += 4;
+
+ $for N in range(0, NR, 4):
+ const float32x4_t vb${IDLETTERS[N:N+4]} = vld1q_f32(w); w += 4;
+
+ $if FMA:
+ #ifdef __aarch64__
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = vfmaq_laneq_f32(vacc${M}x${IDLETTERS[N:N+4]}, vb${IDLETTERS[N:N+4]}, va${IDLETTERS[M&-4:4+M&-4]}, ${M % 4});
+ #else
+ $for M in range(MR):
+ $VGET_PART_F32 = "vget_low_f32" if M % 4 < 2 else "vget_high_f32"
+ $MMMM = str(M) * 4
+ const float32x4_t va${MMMM} = vdupq_lane_f32(${VGET_PART_F32}(va${IDLETTERS[M&-4:4+M&-4]}), ${M % 2});
+
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ $MMMM = str(M) * 4
+ vacc${M}x${IDLETTERS[N:N+4]} = vfmaq_f32(vacc${M}x${IDLETTERS[N:N+4]}, va${MMMM}, vb${IDLETTERS[N:N+4]});
+ #endif
+ $else:
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ $VGET_PART_F32 = "vget_low_f32" if M % 4 < 2 else "vget_high_f32"
+ vacc${M}x${IDLETTERS[N:N+4]} = vmlaq_lane_f32(vacc${M}x${IDLETTERS[N:N+4]}, vb${IDLETTERS[N:N+4]}, ${VGET_PART_F32}(va${IDLETTERS[M&-4:4+M&-4]}), ${M % 2});
+
+ k -= sizeof(float);
+ } while (k != 0);
+
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = vminq_f32(vacc${M}x${IDLETTERS[N:N+4]}, vmax);
+
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = vmaxq_f32(vacc${M}x${IDLETTERS[N:N+4]}, vmin);
+
+ if XNN_LIKELY(nc >= ${NR}) {
+ $for M in reversed(range(MR)):
+ vst1q_f32(c${M}, vacc${M}x${IDLETTERS[0:4]});
+ $for N in range(4, NR, 4):
+ vst1q_f32(c${M} + ${N}, vacc${M}x${IDLETTERS[N:N+4]});
+ c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+ a = (const float*) ((uintptr_t) a - kc * ${MR});
+
+ nc -= ${NR};
+ } else {
+ $for LOG2N in reversed(range(NR.bit_length())):
+ $if NR != 1 << LOG2N:
+ $if LOG2N == 1:
+ $for M in reversed(range(MR)):
+ float32x2_t vacc${M}x01 = vget_low_f32(vacc${M}x0123);
+ if (nc & ${1 << LOG2N}) {
+ $if LOG2N >= 2:
+ $for M in reversed(range(MR)):
+ $for N in range(0, 1 << LOG2N, 4):
+ vst1q_f32(c${M}, vacc${M}x${IDLETTERS[N:N+4]}); c${M} += 4;
+
+ $for M in reversed(range(MR)):
+ $for N in range(0, 1 << (LOG2N - 1), 4):
+ vacc${M}x${IDLETTERS[N:N+4]} = vacc${M}x${IDLETTERS[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+ $elif LOG2N == 1:
+ $for M in reversed(range(MR)):
+ vst1_f32(c${M}, vacc${M}x01); c${M} += 2;
+
+ $for M in reversed(range(MR)):
+ vacc${M}x01 = vget_high_f32(vacc${M}x0123);
+ $elif LOG2N == 0:
+ $for M in reversed(range(MR)):
+ vst1_lane_f32(c${M}, vacc${M}x01, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-ppmm/psimd.c.in b/src/f32-ppmm/psimd.c.in
new file mode 100644
index 0000000..721ad5f
--- /dev/null
+++ b/src/f32-ppmm/psimd.c.in
@@ -0,0 +1,137 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert MR % 4 == 0
+$assert NR % 4 == 0
+$IDLETTERS = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/ppmm.h>
+
+
+void xnn_f32_ppmm_ukernel_${MR}x${NR}__psimd(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float*restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= ${MR});
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+
+ float* c0 = c;
+ $for M in range(1, MR):
+ float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(mr <= ${M}) {
+ c${M} = c${M-1};
+ }
+ $elif M + 1 == MR:
+ if XNN_UNPREDICTABLE(mr != ${M+1}) {
+ c${M} = c${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(mr < ${M+1}) {
+ c${M} = c${M-1};
+ }
+
+ do {
+ psimd_f32 vacc0x${IDLETTERS[0:4]} = psimd_load_f32(w);
+ $for N in range(4, NR, 4):
+ psimd_f32 vacc0x${IDLETTERS[N:N+4]} = psimd_load_f32(w + ${N});
+ $for M in range(1, MR):
+ $for N in range(0, NR, 4):
+ psimd_f32 vacc${M}x${IDLETTERS[N:N+4]} = vacc0x${IDLETTERS[N:N+4]};
+ w += ${NR};
+
+ size_t k = kc;
+ do {
+ const psimd_f32 va${IDLETTERS[0:4]} = psimd_load_f32(a);
+ $for M in range(4, MR, 4):
+ const psimd_f32 va${IDLETTERS[M:M+4]} = psimd_load_f32(a + ${M});
+ a += ${MR};
+
+ const psimd_f32 vb${IDLETTERS[0:4]} = psimd_load_f32(w);
+ $for N in range(4, NR, 4):
+ const psimd_f32 vb${IDLETTERS[N:N+4]} = psimd_load_f32(w + ${N});
+ w += ${NR};
+
+ $for M in range(MR):
+ $MMMM = str(M) * 4
+ const psimd_f32 va${MMMM} = psimd_splat${M % 4}_f32(va${IDLETTERS[M&-4:4+M&-4]});
+
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ $MMMM = str(M) * 4
+ vacc${M}x${IDLETTERS[N:N+4]} = psimd_qfma_f32(vacc${M}x${IDLETTERS[N:N+4]}, va${MMMM}, vb${IDLETTERS[N:N+4]});
+
+ k -= sizeof(float);
+ } while (k != 0);
+
+ const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = psimd_min_f32(vacc${M}x${IDLETTERS[N:N+4]}, vmax);
+
+ const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = psimd_max_f32(vacc${M}x${IDLETTERS[N:N+4]}, vmin);
+
+ if XNN_LIKELY(nc >= ${NR}) {
+ $for M in reversed(range(MR)):
+ psimd_store_f32(c${M}, vacc${M}x${IDLETTERS[0:4]});
+ $for N in range(4, NR, 4):
+ psimd_store_f32(c${M} + ${N}, vacc${M}x${IDLETTERS[N:N+4]});
+
+ a = (const float*) ((uintptr_t) a - kc * ${MR});
+
+ $for M in reversed(range(MR)):
+ c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+ nc -= ${NR};
+ } else {
+ $for LOG2N in reversed(range(NR.bit_length())):
+ $if NR != 1 << LOG2N:
+ if (nc & ${1 << LOG2N}) {
+ $if LOG2N >= 2:
+ $for M in reversed(range(MR)):
+ psimd_store_f32(c${M}, vacc${M}x${IDLETTERS[0:4]});
+ $for N in range(4, 1 << LOG2N, 4):
+ psimd_store_f32(c${M} + ${N}, vacc${M}x${IDLETTERS[N:N+4]});
+
+ $for M in reversed(range(MR)):
+ $for N in range(0, 1 << (LOG2N - 1), 4):
+ vacc${M}x${IDLETTERS[N:N+4]} = vacc${M}x${IDLETTERS[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+ $for M in reversed(range(MR)):
+ c${M} += ${1 << LOG2N};
+ $elif LOG2N == 1:
+ $for M in reversed(range(MR)):
+ psimd_store2_f32(c${M}, vacc${M}x${IDLETTERS[0:4]});
+
+ $for M in reversed(range(MR)):
+ vacc${M}x${IDLETTERS[0:4]} = psimd_concat_hi_f32(vacc${M}x${IDLETTERS[0:4]}, vacc${M}x${IDLETTERS[0:4]});
+
+ $for M in reversed(range(MR)):
+ c${M} += 2;
+ $elif LOG2N == 0:
+ $for M in reversed(range(MR)):
+ psimd_store1_f32(c${M}, vacc${M}x${IDLETTERS[0:4]});
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-ppmm/scalar.c.in b/src/f32-ppmm/scalar.c.in
new file mode 100644
index 0000000..9350695
--- /dev/null
+++ b/src/f32-ppmm/scalar.c.in
@@ -0,0 +1,114 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/ppmm.h>
+
+
+void xnn_f32_ppmm_ukernel_${MR}x${NR}__scalar(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float*restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{  // ${MR}x${NR} f32 PPMM microkernel, scalar reference implementation
+ assert(mr != 0);
+ assert(mr <= ${MR});
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+
+ float* c0 = c;  // row pointers; rows beyond mr alias the previous row so writes stay in-bounds
+ $for M in range(1, MR):
+ float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(mr <= ${M}) {
+ c${M} = c${M-1};
+ }
+ $elif M + 1 == MR:
+ if XNN_UNPREDICTABLE(mr != ${M+1}) {
+ c${M} = c${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(mr < ${M+1}) {
+ c${M} = c${M-1};
+ }
+
+ do {
+ $for N in range(NR):
+ float vacc0x${N} = w[${N}];
+ $for M in range(1, MR):
+ $for N in range(NR):
+ float vacc${M}x${N} = vacc0x${N};
+ w += ${NR};
+
+ size_t k = kc;  // bytes of K remaining; one float per row per iteration
+ do {
+ $for M in range(MR):
+ const float va${M} = a[${M}];
+ a += ${MR};
+
+ $for N in range(NR):
+ const float vb${N} = w[${N}];
+ w += ${NR};
+
+ $for N in range(NR):
+ $for M in range(MR):
+ vacc${M}x${N} += va${M} * vb${N};
+
+ k -= sizeof(float);
+ } while (k != 0);
+
+ const float vmax = params->scalar.max;  // clamp accumulators to the output range
+ $for N in range(NR):
+ $for M in range(MR):
+ vacc${M}x${N} = math_min_f32(vacc${M}x${N}, vmax);
+
+ const float vmin = params->scalar.min;
+ $for N in range(NR):
+ $for M in range(MR):
+ vacc${M}x${N} = math_max_f32(vacc${M}x${N}, vmin);
+
+ if XNN_LIKELY(nc >= ${NR}) {
+ $for M in reversed(range(MR)):
+ $for N in range(NR):
+ c${M}[${N}] = vacc${M}x${N};
+
+ a = (const float*) ((uintptr_t) a - kc * ${MR});
+
+ $for M in reversed(range(MR)):
+ c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+ nc -= ${NR};
+ } else {
+ $for LOG2N in reversed(range(NR.bit_length())):
+ $if NR != 1 << LOG2N:
+ if (nc & ${1 << LOG2N}) {
+ $if LOG2N != 0:
+ $for M in reversed(range(MR)):
+ $for N in range(1 << LOG2N):
+ c${M}[${N}] = vacc${M}x${N};
+
+ $for M in reversed(range(MR)):
+ $for N in range(1 << (LOG2N - 1)):
+ vacc${M}x${N} = vacc${M}x${N + (1 << LOG2N)};
+
+ $for M in reversed(range(MR)):
+ c${M} += ${1 << LOG2N};
+ $else:
+ $for M in reversed(range(MR)):
+ *c${M} = vacc${M}x0;
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-ppmm/sse.c.in b/src/f32-ppmm/sse.c.in
new file mode 100644
index 0000000..d14ca00
--- /dev/null
+++ b/src/f32-ppmm/sse.c.in
@@ -0,0 +1,137 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert MR % 4 == 0
+$assert NR % 4 == 0
+$IDLETTERS = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/ppmm.h>
+
+
+void xnn_f32_ppmm_ukernel_${MR}x${NR}__sse(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float*restrict a,
+ const float*restrict w,
+ float*restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{  // ${MR}x${NR} f32 PPMM microkernel, SSE implementation
+ assert(mr != 0);
+ assert(mr <= ${MR});
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(float) == 0);
+
+ float* c0 = c;  // row pointers; rows beyond mr alias the previous row so stores stay in-bounds
+ $for M in range(1, MR):
+ float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(mr <= ${M}) {
+ c${M} = c${M-1};
+ }
+ $elif M + 1 == MR:
+ if XNN_UNPREDICTABLE(mr != ${M+1}) {
+ c${M} = c${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(mr < ${M+1}) {
+ c${M} = c${M-1};
+ }
+
+ do {
+ __m128 vacc0x${IDLETTERS[0:4]} = _mm_load_ps(w);
+ $for N in range(4, NR, 4):
+ __m128 vacc0x${IDLETTERS[N:N+4]} = _mm_load_ps(w + ${N});
+ $for M in range(1, MR):
+ $for N in range(0, NR, 4):
+ __m128 vacc${M}x${IDLETTERS[N:N+4]} = vacc0x${IDLETTERS[N:N+4]};
+ w += ${NR};
+
+ size_t k = kc;  // bytes of K remaining; one float per row per iteration
+ do {
+ const __m128 va${IDLETTERS[0:4]} = _mm_load_ps(a);
+ $for M in range(4, MR, 4):
+ const __m128 va${IDLETTERS[M:M+4]} = _mm_load_ps(a + ${M});
+ a += ${MR};
+
+ const __m128 vb${IDLETTERS[0:4]} = _mm_load_ps(w);
+ $for N in range(4, NR, 4):
+ const __m128 vb${IDLETTERS[N:N+4]} = _mm_load_ps(w + ${N});
+ w += ${NR};
+
+ $for M in range(MR):
+ $MMMM = str(M) * 4
+ const __m128 va${MMMM} = _mm_shuffle_ps(va${IDLETTERS[M&-4:4+M&-4]}, va${IDLETTERS[M&-4:4+M&-4]}, _MM_SHUFFLE(${M % 4}, ${M % 4}, ${M % 4}, ${M % 4}));
+
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ $MMMM = str(M) * 4
+ vacc${M}x${IDLETTERS[N:N+4]} = _mm_add_ps(vacc${M}x${IDLETTERS[N:N+4]}, _mm_mul_ps(va${MMMM}, vb${IDLETTERS[N:N+4]}));
+
+ k -= sizeof(float);
+ } while (k != 0);
+
+ const __m128 vmax = _mm_load_ps(params->sse.max);  // clamp accumulators to the output range
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = _mm_min_ps(vacc${M}x${IDLETTERS[N:N+4]}, vmax);
+
+ const __m128 vmin = _mm_load_ps(params->sse.min);
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${IDLETTERS[N:N+4]} = _mm_max_ps(vacc${M}x${IDLETTERS[N:N+4]}, vmin);
+
+ if XNN_LIKELY(nc >= ${NR}) {
+ $for M in reversed(range(MR)):
+ _mm_storeu_ps(c${M}, vacc${M}x${IDLETTERS[0:4]});
+ $for N in range(4, NR, 4):
+ _mm_storeu_ps(c${M} + ${N}, vacc${M}x${IDLETTERS[N:N+4]});
+
+ a = (const float*) ((uintptr_t) a - kc * ${MR});
+
+ $for M in reversed(range(MR)):
+ c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
+
+ nc -= ${NR};
+ } else {
+ $for LOG2N in reversed(range(NR.bit_length())):
+ $if NR != 1 << LOG2N:
+ if (nc & ${1 << LOG2N}) {
+ $if LOG2N >= 2:
+ $for M in reversed(range(MR)):
+ _mm_storeu_ps(c${M}, vacc${M}x${IDLETTERS[0:4]});
+ $for N in range(4, 1 << LOG2N, 4):
+ _mm_storeu_ps(c${M} + ${N}, vacc${M}x${IDLETTERS[N:N+4]});
+
+ $for M in reversed(range(MR)):
+ $for N in range(0, 1 << (LOG2N - 1), 4):
+ vacc${M}x${IDLETTERS[N:N+4]} = vacc${M}x${IDLETTERS[N + (1 << LOG2N):N + (1 << LOG2N)+4]};
+
+ $for M in reversed(range(MR)):
+ c${M} += ${1 << LOG2N};
+ $elif LOG2N == 1:
+ $for M in reversed(range(MR)):
+ _mm_storel_pi((__m64*) c${M}, vacc${M}x${IDLETTERS[0:4]});
+
+ $for M in reversed(range(MR)):
+ vacc${M}x${IDLETTERS[0:4]} = _mm_movehl_ps(vacc${M}x${IDLETTERS[0:4]}, vacc${M}x${IDLETTERS[0:4]});
+
+ $for M in reversed(range(MR)):
+ c${M} += 2;
+ $elif LOG2N == 0:
+ $for M in reversed(range(MR)):
+ _mm_store_ss(c${M}, vacc${M}x${IDLETTERS[0:4]});
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/f32-prelu/x4-psimd.c b/src/f32-prelu/x4-psimd.c
new file mode 100644
index 0000000..3c0772d
--- /dev/null
+++ b/src/f32-prelu/x4-psimd.c
@@ -0,0 +1,131 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/prelu.h>
+
+
+void xnn_f32_prelu_ukernel_x4__psimd(
+ size_t mr,
+ size_t n,
+ const float* x,
+ size_t x_stride,
+ const float* w,
+ float* y,
+ size_t y_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{  // PRELU on up to 4 rows: y = (x < 0 ? x*w : x), clamped to [min, max]
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(n != 0);
+ assert(n % sizeof(float) == 0);
+
+ const float* x0 = x;  // input row pointers; rows beyond mr alias the previous row
+ const float* x1 = (const float*) ((uintptr_t) x0 + x_stride);
+ if (mr < 2) {
+ x1 = x0;
+ }
+ const float* x2 = (const float*) ((uintptr_t) x1 + x_stride);
+ if (mr <= 2) {
+ x2 = x1;
+ }
+ const float* x3 = (const float*) ((uintptr_t) x2 + x_stride);
+ if (mr != 4) {
+ x3 = x2;
+ }
+
+ float* y0 = y;
+ float* y1 = (float*) ((uintptr_t) y0 + y_stride);
+ if (mr < 2) {
+ y1 = y0;
+ }
+ float* y2 = (float*) ((uintptr_t) y1 + y_stride);
+ if (mr <= 2) {
+ y2 = y1;
+ }
+ float* y3 = (float*) ((uintptr_t) y2 + y_stride);
+ if (mr != 4) {
+ y3 = y2;
+ }
+
+ const psimd_f32 vy_min = psimd_load_splat_f32(&params->scalar.min);
+ const psimd_f32 vy_max = psimd_load_splat_f32(&params->scalar.max);
+ for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
+ const psimd_f32 vw = psimd_load_f32(w);
+ w += 4;
+ const psimd_f32 vx0 = psimd_load_f32(x0);
+ x0 += 4;
+ const psimd_f32 vx1 = psimd_load_f32(x1);
+ x1 += 4;
+ const psimd_f32 vx2 = psimd_load_f32(x2);
+ x2 += 4;
+ const psimd_f32 vx3 = psimd_load_f32(x3);
+ x3 += 4;
+
+ const psimd_f32 vacc0 = psimd_signblend_f32(vx0, vx0 * vw, vx0);
+ const psimd_f32 vacc1 = psimd_signblend_f32(vx1, vx1 * vw, vx1);
+ const psimd_f32 vacc2 = psimd_signblend_f32(vx2, vx2 * vw, vx2);
+ const psimd_f32 vacc3 = psimd_signblend_f32(vx3, vx3 * vw, vx3);
+
+ const psimd_f32 vy0 = psimd_min_f32(psimd_max_f32(vacc0, vy_min), vy_max);
+ const psimd_f32 vy1 = psimd_min_f32(psimd_max_f32(vacc1, vy_min), vy_max);
+ const psimd_f32 vy2 = psimd_min_f32(psimd_max_f32(vacc2, vy_min), vy_max);
+ const psimd_f32 vy3 = psimd_min_f32(psimd_max_f32(vacc3, vy_min), vy_max);
+
+ psimd_store_f32(y0, vy0);
+ y0 += 4;
+ psimd_store_f32(y1, vy1);
+ y1 += 4;
+ psimd_store_f32(y2, vy2);
+ y2 += 4;
+ psimd_store_f32(y3, vy3);
+ y3 += 4;
+ }
+ if (n != 0) {
+ const psimd_f32 vw = psimd_load_f32(w);
+ const psimd_f32 vx0 = psimd_load_f32(x0);
+ const psimd_f32 vx1 = psimd_load_f32(x1);
+ const psimd_f32 vx2 = psimd_load_f32(x2);
+ const psimd_f32 vx3 = psimd_load_f32(x3);
+
+ const psimd_f32 vacc0 = psimd_signblend_f32(vx0, vx0 * vw, vx0);
+ const psimd_f32 vacc1 = psimd_signblend_f32(vx1, vx1 * vw, vx1);
+ const psimd_f32 vacc2 = psimd_signblend_f32(vx2, vx2 * vw, vx2);
+ const psimd_f32 vacc3 = psimd_signblend_f32(vx3, vx3 * vw, vx3);
+
+ psimd_f32 vy0 = psimd_min_f32(psimd_max_f32(vacc0, vy_min), vy_max);
+ psimd_f32 vy1 = psimd_min_f32(psimd_max_f32(vacc1, vy_min), vy_max);
+ psimd_f32 vy2 = psimd_min_f32(psimd_max_f32(vacc2, vy_min), vy_max);
+ psimd_f32 vy3 = psimd_min_f32(psimd_max_f32(vacc3, vy_min), vy_max);
+
+ if (n & 2 * sizeof(float)) {
+ psimd_store2_f32(y0, vy0);
+ y0 += 2;
+ psimd_store2_f32(y1, vy1);
+ y1 += 2;
+ psimd_store2_f32(y2, vy2);
+ y2 += 2;
+ psimd_store2_f32(y3, vy3);
+ y3 += 2;
+
+ vy0 = psimd_concat_hi_f32(vy0, vy0);
+ vy1 = psimd_concat_hi_f32(vy1, vy1);
+ vy2 = psimd_concat_hi_f32(vy2, vy2);
+ vy3 = psimd_concat_hi_f32(vy3, vy3);
+ }
+ if (n & 1 * sizeof(float)) {
+ psimd_store1_f32(y0, vy0);
+ psimd_store1_f32(y1, vy1);
+ psimd_store1_f32(y2, vy2);
+ psimd_store1_f32(y3, vy3);
+ }
+ }
+}
diff --git a/src/f32-prelu/x4-scalar.c b/src/f32-prelu/x4-scalar.c
new file mode 100644
index 0000000..168b86b
--- /dev/null
+++ b/src/f32-prelu/x4-scalar.c
@@ -0,0 +1,90 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <math.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/prelu.h>
+#include <xnnpack/math.h>
+
+
+void xnn_f32_prelu_ukernel_x4__scalar(
+ size_t mr,
+ size_t n,
+ const float* x,
+ size_t x_stride,
+ const float* w,
+ float* y,
+ size_t y_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{  // PRELU on up to 4 rows: y = (x negative ? x*w : x), then clamp to [min, max]
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(n != 0);
+ assert(n % sizeof(float) == 0);
+
+ const float* x0 = x;  // input row pointers; rows beyond mr alias the previous row
+ const float* x1 = (const float*) ((uintptr_t) x0 + x_stride);
+ if (mr < 2) {
+ x1 = x0;
+ }
+ const float* x2 = (const float*) ((uintptr_t) x1 + x_stride);
+ if (mr <= 2) {
+ x2 = x1;
+ }
+ const float* x3 = (const float*) ((uintptr_t) x2 + x_stride);
+ if (mr != 4) {
+ x3 = x2;
+ }
+
+ float* y0 = y;  // output row pointers, mirroring the input aliasing
+ float* y1 = (float*) ((uintptr_t) y0 + y_stride);
+ if (mr < 2) {
+ y1 = y0;
+ }
+ float* y2 = (float*) ((uintptr_t) y1 + y_stride);
+ if (mr <= 2) {
+ y2 = y1;
+ }
+ float* y3 = (float*) ((uintptr_t) y2 + y_stride);
+ if (mr != 4) {
+ y3 = y2;
+ }
+
+ const float vy_min = params->scalar.min;
+ const float vy_max = params->scalar.max;
+ do {
+ const float vw = *w++;
+ const float vx0 = *x0++;
+ const float vx1 = *x1++;
+ const float vx2 = *x2++;
+ const float vx3 = *x3++;
+
+ float vy0 = signbit(vx0) ? vx0 * vw : vx0;
+ float vy1 = signbit(vx1) ? vx1 * vw : vx1;
+ float vy2 = signbit(vx2) ? vx2 * vw : vx2;
+ float vy3 = signbit(vx3) ? vx3 * vw : vx3;
+
+ vy0 = math_max_f32(vy0, vy_min);
+ vy1 = math_max_f32(vy1, vy_min);
+ vy2 = math_max_f32(vy2, vy_min);
+ vy3 = math_max_f32(vy3, vy_min);
+
+ vy0 = math_min_f32(vy0, vy_max);
+ vy1 = math_min_f32(vy1, vy_max);
+ vy2 = math_min_f32(vy2, vy_max);
+ vy3 = math_min_f32(vy3, vy_max);
+
+ *y0++ = vy0;
+ *y1++ = vy1;
+ *y2++ = vy2;
+ *y3++ = vy3;
+
+ n -= sizeof(float);
+ } while (n != 0);
+}
diff --git a/src/f32-prelu/x4-sse.c b/src/f32-prelu/x4-sse.c
new file mode 100644
index 0000000..2c0038a
--- /dev/null
+++ b/src/f32-prelu/x4-sse.c
@@ -0,0 +1,152 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/prelu.h>
+
+
+void xnn_f32_prelu_ukernel_x4__sse(
+ size_t mr,
+ size_t n,
+ const float* x,
+ size_t x_stride,
+ const float* w,
+ float* y,
+ size_t y_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(n != 0);
+ assert(n % sizeof(float) == 0);
+
+ const float* x0 = x;
+ const float* x1 = (const float*) ((uintptr_t) x0 + x_stride);
+ if (mr < 2) {
+ x1 = x0;
+ }
+ const float* x2 = (const float*) ((uintptr_t) x1 + x_stride);
+ if (mr <= 2) {
+ x2 = x1;
+ }
+ const float* x3 = (const float*) ((uintptr_t) x2 + x_stride);
+ if (mr != 4) {
+ x3 = x2;
+ }
+
+ float* y0 = y;
+ float* y1 = (float*) ((uintptr_t) y0 + y_stride);
+ if (mr < 2) {
+ y1 = y0;
+ }
+ float* y2 = (float*) ((uintptr_t) y1 + y_stride);
+ if (mr <= 2) {
+ y2 = y1;
+ }
+ float* y3 = (float*) ((uintptr_t) y2 + y_stride);
+ if (mr != 4) {
+ y3 = y2;
+ }
+
+ const __m128 vy_min = _mm_load_ps(params->sse.min);
+ const __m128 vy_max = _mm_load_ps(params->sse.max);
+ for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
+ const __m128 vw = _mm_loadu_ps(w);
+ w += 4;
+ const __m128 vx0 = _mm_loadu_ps(x0);
+ x0 += 4;
+ const __m128 vx1 = _mm_loadu_ps(x1);
+ x1 += 4;
+ const __m128 vx2 = _mm_loadu_ps(x2);
+ x2 += 4;
+ const __m128 vx3 = _mm_loadu_ps(x3);
+ x3 += 4;
+
+ const __m128 vwx0 = _mm_mul_ps(vx0, vw);
+ const __m128 vwx1 = _mm_mul_ps(vx1, vw);
+ const __m128 vwx2 = _mm_mul_ps(vx2, vw);
+ const __m128 vwx3 = _mm_mul_ps(vx3, vw);
+
+ const __m128i vmask0 = _mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx0));
+ const __m128i vmask1 = _mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx1));
+ const __m128i vmask2 = _mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx2));
+ const __m128i vmask3 = _mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx3));
+
+ const __m128i vacc0 = _mm_or_si128(_mm_andnot_si128(vmask0, _mm_castps_si128(vx0)), _mm_and_si128(vmask0, _mm_castps_si128(vwx0)));
+ const __m128i vacc1 = _mm_or_si128(_mm_andnot_si128(vmask1, _mm_castps_si128(vx1)), _mm_and_si128(vmask1, _mm_castps_si128(vwx1)));
+ const __m128i vacc2 = _mm_or_si128(_mm_andnot_si128(vmask2, _mm_castps_si128(vx2)), _mm_and_si128(vmask2, _mm_castps_si128(vwx2)));
+ const __m128i vacc3 = _mm_or_si128(_mm_andnot_si128(vmask3, _mm_castps_si128(vx3)), _mm_and_si128(vmask3, _mm_castps_si128(vwx3)));
+
+ const __m128 vy0 = _mm_min_ps(_mm_max_ps(_mm_castsi128_ps(vacc0), vy_min), vy_max);
+ const __m128 vy1 = _mm_min_ps(_mm_max_ps(_mm_castsi128_ps(vacc1), vy_min), vy_max);
+ const __m128 vy2 = _mm_min_ps(_mm_max_ps(_mm_castsi128_ps(vacc2), vy_min), vy_max);
+ const __m128 vy3 = _mm_min_ps(_mm_max_ps(_mm_castsi128_ps(vacc3), vy_min), vy_max);
+
+ _mm_storeu_ps(y0, vy0);
+ y0 += 4;
+ _mm_storeu_ps(y1, vy1);
+ y1 += 4;
+ _mm_storeu_ps(y2, vy2);
+ y2 += 4;
+ _mm_storeu_ps(y3, vy3);
+ y3 += 4;
+ }
+ if (n != 0) {
+ const __m128 vw = _mm_loadu_ps(w);
+ const __m128 vx0 = _mm_loadu_ps(x0);
+ const __m128 vx1 = _mm_loadu_ps(x1);
+ const __m128 vx2 = _mm_loadu_ps(x2);
+ const __m128 vx3 = _mm_loadu_ps(x3);
+
+ const __m128 vwx0 = _mm_mul_ps(vx0, vw);
+ const __m128 vwx1 = _mm_mul_ps(vx1, vw);
+ const __m128 vwx2 = _mm_mul_ps(vx2, vw);
+ const __m128 vwx3 = _mm_mul_ps(vx3, vw);
+
+ const __m128i vmask0 = _mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx0));
+ const __m128i vmask1 = _mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx1));
+ const __m128i vmask2 = _mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx2));
+ const __m128i vmask3 = _mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx3));
+
+ const __m128i vacc0 = _mm_or_si128(_mm_andnot_si128(vmask0, _mm_castps_si128(vx0)), _mm_and_si128(vmask0, _mm_castps_si128(vwx0)));
+ const __m128i vacc1 = _mm_or_si128(_mm_andnot_si128(vmask1, _mm_castps_si128(vx1)), _mm_and_si128(vmask1, _mm_castps_si128(vwx1)));
+ const __m128i vacc2 = _mm_or_si128(_mm_andnot_si128(vmask2, _mm_castps_si128(vx2)), _mm_and_si128(vmask2, _mm_castps_si128(vwx2)));
+ const __m128i vacc3 = _mm_or_si128(_mm_andnot_si128(vmask3, _mm_castps_si128(vx3)), _mm_and_si128(vmask3, _mm_castps_si128(vwx3)));
+
+ __m128 vy0 = _mm_min_ps(_mm_max_ps(vacc0, vy_min), vy_max);
+ __m128 vy1 = _mm_min_ps(_mm_max_ps(vacc1, vy_min), vy_max);
+ __m128 vy2 = _mm_min_ps(_mm_max_ps(vacc2, vy_min), vy_max);
+ __m128 vy3 = _mm_min_ps(_mm_max_ps(vacc3, vy_min), vy_max);
+
+ if (n & 2 * sizeof(float)) {
+ _mm_storel_pi((__m64*) y0, vy0);
+ _mm_storel_pi((__m64*) y1, vy1);
+ _mm_storel_pi((__m64*) y2, vy2);
+ _mm_storel_pi((__m64*) y3, vy3);
+
+ vy0 = _mm_movehl_ps(vy0, vy0);
+ vy1 = _mm_movehl_ps(vy1, vy1);
+ vy2 = _mm_movehl_ps(vy2, vy2);
+ vy3 = _mm_movehl_ps(vy3, vy3);
+
+ y0 += 2;
+ y1 += 2;
+ y2 += 2;
+ y3 += 2;
+ }
+ if (n & 1 * sizeof(float)) {
+ _mm_store_ss(y0, vy0);
+ _mm_store_ss(y1, vy1);
+ _mm_store_ss(y2, vy2);
+ _mm_store_ss(y3, vy3);
+ }
+ }
+}
diff --git a/src/f32-rmax/avx.c b/src/f32-rmax/avx.c
new file mode 100644
index 0000000..e242283
--- /dev/null
+++ b/src/f32-rmax/avx.c
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/rmax.h>
+
+
// AVX reduction micro-kernel: writes the maximum of the n/4 floats at x
// into *y.
//
// n - input length in bytes (non-zero multiple of sizeof(float))
void xnn_f32_rmax_ukernel__avx(
    size_t n,
    const float* x,
    float* y)
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);

  // Seed all four accumulators with the first element, so reducing unused
  // accumulators/lanes cannot change the result.
  __m256 vacc0 = _mm256_broadcast_ss(x);
  __m256 vacc1 = vacc0;
  __m256 vacc2 = vacc0;
  __m256 vacc3 = vacc0;
  // Main loop: 32 floats (128 bytes) per iteration across 4 accumulators.
  while (n >= 128) {
    const __m256 vx0 = _mm256_loadu_ps(x);
    const __m256 vx1 = _mm256_loadu_ps(x + 8);
    const __m256 vx2 = _mm256_loadu_ps(x + 16);
    const __m256 vx3 = _mm256_loadu_ps(x + 24);
    x += 32;

    vacc0 = _mm256_max_ps(vacc0, vx0);
    vacc1 = _mm256_max_ps(vacc1, vx1);
    vacc2 = _mm256_max_ps(vacc2, vx2);
    vacc3 = _mm256_max_ps(vacc3, vx3);
    n -= 128;
  }
  __m256 vacc = _mm256_max_ps(_mm256_max_ps(vacc0, vacc1), _mm256_max_ps(vacc2, vacc3));
  // 8 floats per iteration.
  while (n >= 32) {
    vacc = _mm256_max_ps(vacc, _mm256_loadu_ps(x));
    x += 8;
    n -= 32;
  }
  // Horizontal reduction: 256 -> 128 -> 64 -> 32 bits, result in lane 0.
  __m128 vred = _mm_max_ps(_mm256_castps256_ps128(vacc), _mm256_extractf128_ps(vacc, 1));
  vred = _mm_max_ps(vred, _mm_movehl_ps(vred, vred));
  vred = _mm_max_ss(vred, _mm_shuffle_ps(vred, vred, _MM_SHUFFLE(3, 3, 1, 1)));
  // Scalar tail, one float at a time.
  while (n != 0) {
    vred = _mm_max_ss(vred, _mm_load_ss(x));
    x += 1;
    n -= 4;
  }
  _mm_store_ss(y, vred);
}
diff --git a/src/f32-rmax/avx512f.c b/src/f32-rmax/avx512f.c
new file mode 100644
index 0000000..ce96155
--- /dev/null
+++ b/src/f32-rmax/avx512f.c
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/rmax.h>
+
+
// AVX512F reduction micro-kernel: writes the maximum of the n/4 floats at x
// into *y.
//
// n - input length in bytes (non-zero multiple of sizeof(float))
void xnn_f32_rmax_ukernel__avx512f(
    size_t n,
    const float* x,
    float* y)
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);

  // Seed all four accumulators with the first element, so reducing unused
  // accumulators/lanes cannot change the result.
  __m512 vacc0 = _mm512_broadcastss_ps(_mm_load_ss(x));
  __m512 vacc1 = vacc0;
  __m512 vacc2 = vacc0;
  __m512 vacc3 = vacc0;
  // Main loop: 64 floats (256 bytes) per iteration across 4 accumulators.
  while (n >= 256) {
    const __m512 vx0 = _mm512_loadu_ps(x);
    const __m512 vx1 = _mm512_loadu_ps(x + 16);
    const __m512 vx2 = _mm512_loadu_ps(x + 32);
    const __m512 vx3 = _mm512_loadu_ps(x + 48);
    x += 64;

    vacc0 = _mm512_max_ps(vacc0, vx0);
    vacc1 = _mm512_max_ps(vacc1, vx1);
    vacc2 = _mm512_max_ps(vacc2, vx2);
    vacc3 = _mm512_max_ps(vacc3, vx3);
    n -= 256;
  }
  __m512 vacc = _mm512_max_ps(_mm512_max_ps(vacc0, vacc1), _mm512_max_ps(vacc2, vacc3));
  // 16 floats per iteration.
  while (n >= 64) {
    vacc = _mm512_max_ps(vacc, _mm512_loadu_ps(x));
    x += 16;
    n -= 64;
  }
  // Reduce 512 -> 256 -> 128 bits; finish the tail at 128-bit width.
  __m256 vacc256 = _mm256_max_ps(_mm512_castps512_ps256(vacc), _mm512_castps512_ps256(_mm512_shuffle_f32x4(vacc, vacc, _MM_SHUFFLE(3, 2, 3, 2))));
  __m128 vacc128 = _mm_max_ps(_mm256_castps256_ps128(vacc256), _mm256_extractf128_ps(vacc256, 1));
  // 4 floats per iteration.
  while (n >= 16) {
    vacc128 = _mm_max_ps(vacc128, _mm_loadu_ps(x));
    x += 4;
    n -= 16;
  }
  // Reduce the 4 remaining lanes into lane 0.
  vacc128 = _mm_max_ps(vacc128, _mm_movehl_ps(vacc128, vacc128));
  vacc128 = _mm_max_ss(vacc128, _mm_shuffle_ps(vacc128, vacc128, _MM_SHUFFLE(3, 3, 1, 1)));
  // Scalar tail, one float at a time.
  while (n != 0) {
    vacc128 = _mm_max_ss(vacc128, _mm_load_ss(x));
    x += 1;
    n -= 4;
  }
  _mm_store_ss(y, vacc128);
}
diff --git a/src/f32-rmax/neon.c b/src/f32-rmax/neon.c
new file mode 100644
index 0000000..a0cb076
--- /dev/null
+++ b/src/f32-rmax/neon.c
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/rmax.h>
+
+
// NEON reduction micro-kernel: writes the maximum of the n/4 floats at x
// into *y.
//
// n - input length in bytes (non-zero multiple of sizeof(float))
void xnn_f32_rmax_ukernel__neon(
    size_t n,
    const float* x,
    float* y)
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);

  // Seed all four accumulators with the first element, so reducing unused
  // accumulators/lanes cannot change the result.
  float32x4_t vacc0 = vld1q_dup_f32(x);
  float32x4_t vacc1 = vacc0;
  float32x4_t vacc2 = vacc0;
  float32x4_t vacc3 = vacc0;
  // Main loop: 16 floats (64 bytes) per iteration across 4 accumulators.
  while (n >= 64) {
    const float32x4_t vx0 = vld1q_f32(x); x += 4;
    const float32x4_t vx1 = vld1q_f32(x); x += 4;
    const float32x4_t vx2 = vld1q_f32(x); x += 4;
    const float32x4_t vx3 = vld1q_f32(x); x += 4;

    vacc0 = vmaxq_f32(vacc0, vx0);
    vacc1 = vmaxq_f32(vacc1, vx1);
    vacc2 = vmaxq_f32(vacc2, vx2);
    vacc3 = vmaxq_f32(vacc3, vx3);
    n -= 64;
  }
  float32x4_t vacc = vmaxq_f32(vmaxq_f32(vacc0, vacc1), vmaxq_f32(vacc2, vacc3));
  // 4 floats per iteration.
  while (n >= 16) {
    vacc = vmaxq_f32(vacc, vld1q_f32(x));
    x += 4;
    n -= 16;
  }
  // Reduce 128 -> 64 bits; AArch64 has pairwise max on full vectors.
#ifdef __aarch64__
  float32x2_t vred = vget_low_f32(vpmaxq_f32(vacc, vacc));
#else
  float32x2_t vred = vmax_f32(vget_low_f32(vacc), vget_high_f32(vacc));
#endif
  // Scalar tail, one float at a time (duplicated into both lanes).
  while (n != 0) {
    vred = vmax_f32(vred, vld1_dup_f32(x));
    x += 1;
    n -= 4;
  }
  // Final 64 -> 32 bit reduction and store.
#ifdef __aarch64__
  *y = vmaxv_f32(vred);
#else
  vst1_lane_f32(y, vpmax_f32(vred, vred), 0);
#endif
}
diff --git a/src/f32-rmax/scalar.c b/src/f32-rmax/scalar.c
new file mode 100644
index 0000000..bc8d5f3
--- /dev/null
+++ b/src/f32-rmax/scalar.c
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/rmax.h>
+
+
// Scalar reduction micro-kernel: writes the maximum of the n/4 floats at x
// into *y.
//
// n - input length in bytes (non-zero multiple of sizeof(float))
void xnn_f32_rmax_ukernel__scalar(
    size_t n,
    const float* x,
    float* y)
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);

  // Four independent accumulators shorten the dependent max chain; all are
  // seeded with the first element so idle accumulators are neutral.
  float vacc0 = *x;
  float vacc1 = vacc0;
  float vacc2 = vacc0;
  float vacc3 = vacc0;
  // Main loop: 4 floats (16 bytes) per iteration.
  while (n >= 16) {
    const float vx0 = x[0];
    const float vx1 = x[1];
    const float vx2 = x[2];
    const float vx3 = x[3];
    x += 4;

    vacc0 = math_max_f32(vx0, vacc0);
    vacc1 = math_max_f32(vx1, vacc1);
    vacc2 = math_max_f32(vx2, vacc2);
    vacc3 = math_max_f32(vx3, vacc3);
    n -= 16;
  }
  // Tree-reduce the four accumulators.
  float vacc = math_max_f32(math_max_f32(vacc0, vacc1), math_max_f32(vacc2, vacc3));
  // Scalar tail, one float at a time.
  while (n != 0) {
    vacc = math_max_f32(*x++, vacc);
    n -= 4;
  }
  *y = vacc;
}
diff --git a/src/f32-rmax/sse.c b/src/f32-rmax/sse.c
new file mode 100644
index 0000000..8968565
--- /dev/null
+++ b/src/f32-rmax/sse.c
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/rmax.h>
+
+
// SSE reduction micro-kernel: writes the maximum of the n/4 floats at x
// into *y.
//
// n - input length in bytes (non-zero multiple of sizeof(float))
void xnn_f32_rmax_ukernel__sse(
    size_t n,
    const float* x,
    float* y)
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);

  // Seed all four accumulators with the first element broadcast into every
  // lane, so reducing unused accumulators/lanes cannot change the result.
  __m128 vacc0 = _mm_load_ss(x);
  vacc0 = _mm_shuffle_ps(vacc0, vacc0, _MM_SHUFFLE(0, 0, 0, 0));
  __m128 vacc1 = vacc0;
  __m128 vacc2 = vacc0;
  __m128 vacc3 = vacc0;
  // Main loop: 16 floats (64 bytes) per iteration across 4 accumulators.
  while (n >= 64) {
    const __m128 vx0 = _mm_loadu_ps(x);
    const __m128 vx1 = _mm_loadu_ps(x + 4);
    const __m128 vx2 = _mm_loadu_ps(x + 8);
    const __m128 vx3 = _mm_loadu_ps(x + 12);
    x += 16;

    vacc0 = _mm_max_ps(vacc0, vx0);
    vacc1 = _mm_max_ps(vacc1, vx1);
    vacc2 = _mm_max_ps(vacc2, vx2);
    vacc3 = _mm_max_ps(vacc3, vx3);
    n -= 64;
  }
  __m128 vacc = _mm_max_ps(_mm_max_ps(vacc0, vacc1), _mm_max_ps(vacc2, vacc3));
  // 4 floats per iteration.
  while (n >= 16) {
    vacc = _mm_max_ps(vacc, _mm_loadu_ps(x));
    x += 4;
    n -= 16;
  }
  // Horizontal reduction of the 4 lanes into lane 0.
  __m128 vred = _mm_max_ps(vacc, _mm_movehl_ps(vacc, vacc));
  vred = _mm_max_ss(vred, _mm_shuffle_ps(vred, vred, _MM_SHUFFLE(3, 3, 1, 1)));
  // Scalar tail, one float at a time.
  while (n != 0) {
    vred = _mm_max_ss(vred, _mm_load_ss(x));
    x += 1;
    n -= 4;
  }
  _mm_store_ss(y, vred);
}
diff --git a/src/f32-spmm/12x1-neonfma.c b/src/f32-spmm/12x1-neonfma.c
new file mode 100644
index 0000000..13f2954
--- /dev/null
+++ b/src/f32-spmm/12x1-neonfma.c
@@ -0,0 +1,183 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
// Sparse-weights matrix multiplication (SpMM) micro-kernel, NEON with FMA,
// 12x1 tile: 12 dense rows per outer iteration, 1 output channel per inner
// iteration.
//
// m           - number of dense rows of `a`/`c`
// n           - number of output channels
// a           - dense input
// weights     - packed weights: per channel, one leading value used to seed
//               the accumulators (the bias), then one value per nonzero
// widx_dmap   - per-nonzero pointer increments, in bytes, applied to `a`
// nidx_nnzmap - number of nonzeroes for each output channel
// c           - output, written as n blocks spaced m floats apart
// params      - output clamp bounds (scalar layout, broadcast to vectors)
void xnn_f32_spmm_ukernel_12x1__neonfma(
    uint32_t m,
    uint32_t n,
    const float*restrict a,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict c,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(m != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t i = m;
  // Main loop: full tiles of 12 rows.
  while XNN_LIKELY(i >= 12) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t j = n;
    do {
      uint32_t nnz = *nnzmap++;
      // Seed all accumulators with the channel's leading weight value.
      float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567 = vacc0123;
      float32x4_t vacc89AB = vacc0123;
      if XNN_LIKELY(nnz != 0) {
        do {
          // Each nonzero carries a byte offset to the next input position.
          const intptr_t diff = *dmap++;
          const float32x4_t va0123 = vld1q_f32(a);
          const float32x4_t va4567 = vld1q_f32(a + 4);
          const float32x4_t va89AB = vld1q_f32(a + 8);
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
          const float32x4_t vb = vld1q_dup_f32(w); w += 1;
          vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
          vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
          vacc89AB = vfmaq_f32(vacc89AB, va89AB, vb);
        } while (--nnz != 0);
      }
      // Clamp to [min, max] and store one channel (channels are m apart).
      float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
      float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
      float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
      vout0123 = vmaxq_f32(vout0123, vmin);
      vout4567 = vmaxq_f32(vout4567, vmin);
      vout89AB = vmaxq_f32(vout89AB, vmin);
      vst1q_f32(c, vout0123);
      vst1q_f32(c + 4, vout4567);
      vst1q_f32(c + 8, vout89AB);
      c += m;
    } while (--j != 0);
    // Rewind the output to the start of the next 12-row tile.
    c -= m * n;
    c += 12;
    a += 12;
    i -= 12;
  }
  // Remainder rows: same structure at widths 8, 4, 2, then 1.
  if XNN_UNLIKELY(i != 0) {
    if (i & 8) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567 = vacc0123;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            const float32x4_t va4567 = vld1q_f32(a + 4);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
            vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
        vout0123 = vmaxq_f32(vout0123, vmin);
        vout4567 = vmaxq_f32(vout4567, vmin);
        vst1q_f32(c, vout0123);
        vst1q_f32(c + 4, vout4567);
        c += m;
      } while (--j != 0);
      c -= m * n;
      c += 8;
      a += 8;
    }
    if (i & 4) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        vout0123 = vmaxq_f32(vout0123, vmin);
        vst1q_f32(c, vout0123);
        c += m;
      } while (--j != 0);
      c -= m * n;
      c += 4;
      a += 4;
    }
    if (i & 2) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t va01 = vld1_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_dup_f32(w); w += 1;
            vacc01 = vfma_f32(vacc01, va01, vb);
          } while (--nnz != 0);
        }
        float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
        vout01 = vmax_f32(vout01, vget_low_f32(vmin));
        vst1_f32(c, vout01);
        c += m;
      } while (--j != 0);
      c -= m * n;
      c += 2;
      a += 2;
    }
    if (i & 1) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t va0 = vld1_dup_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_dup_f32(w); w += 1;
            vacc0 = vfma_f32(vacc0, va0, vb);
          } while (--nnz != 0);
        }
        float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
        vout0 = vmax_f32(vout0, vget_low_f32(vmin));
        vst1_lane_f32(c, vout0, 0);
        c += m;
      } while (--j != 0);
      c -= m * n;
      c += 1;
      a += 1;
    }
  }
}
diff --git a/src/f32-spmm/12x2-neonfma.c b/src/f32-spmm/12x2-neonfma.c
new file mode 100644
index 0000000..f2a1c12
--- /dev/null
+++ b/src/f32-spmm/12x2-neonfma.c
@@ -0,0 +1,379 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon-blocked.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
// Sparse-weights matrix multiplication (SpMM) micro-kernel, NEON with FMA,
// 12x2 tile: 12 dense rows per outer iteration, 2 output channels per inner
// iteration (blocked weights), with a 1-channel fallback.
//
// m           - number of dense rows of `a`/`c`
// n           - number of output channels
// a           - dense input
// weights     - packed weights: per channel block, the leading accumulator
//               seeds (biases) followed by the nonzero values
// widx_dmap   - per-nonzero pointer increments, in bytes, applied to `a`
// nidx_nnzmap - number of nonzeroes per channel block
// c           - output, written as n blocks spaced m floats apart
// params      - output clamp bounds (scalar layout, broadcast to vectors)
void xnn_f32_spmm_ukernel_12x2__neonfma(
    uint32_t m,
    uint32_t n,
    const float*restrict a,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict c,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(m != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t i = m;
  // Main loop: full tiles of 12 rows.
  while XNN_LIKELY(i >= 12) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t j = n;
    // Blocked inner loop: 2 output channels share each nonzero's input load.
    while (j >= 2) {
      uint32_t nnz = *nnzmap++;
      float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567c0 = vacc0123c0;
      float32x4_t vacc89ABc0 = vacc0123c0;
      float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567c1 = vacc0123c1;
      float32x4_t vacc89ABc1 = vacc0123c1;
      if XNN_LIKELY(nnz != 0) {
        do {
          // Each nonzero carries a byte offset to the next input position.
          const intptr_t diff = *dmap++;
          const float32x4_t va0123 = vld1q_f32(a);
          const float32x4_t va4567 = vld1q_f32(a + 4);
          const float32x4_t va89AB = vld1q_f32(a + 8);
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
          const float32x2_t vb = vld1_f32(w); w += 2;

          vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
          vacc4567c0 = vfmaq_lane_f32(vacc4567c0, va4567, vb, 0);
          vacc89ABc0 = vfmaq_lane_f32(vacc89ABc0, va89AB, vb, 0);
          vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
          vacc4567c1 = vfmaq_lane_f32(vacc4567c1, va4567, vb, 1);
          vacc89ABc1 = vfmaq_lane_f32(vacc89ABc1, va89AB, vb, 1);
        } while (--nnz != 0);
      }
      // Clamp to [min, max] and store both channels (channels are m apart).
      float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
      float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
      float32x4_t vout89ABc0 = vminq_f32(vacc89ABc0, vmax);
      float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
      float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
      float32x4_t vout89ABc1 = vminq_f32(vacc89ABc1, vmax);

      vout0123c0 = vmaxq_f32(vout0123c0, vmin);
      vout4567c0 = vmaxq_f32(vout4567c0, vmin);
      vout89ABc0 = vmaxq_f32(vout89ABc0, vmin);
      vout0123c1 = vmaxq_f32(vout0123c1, vmin);
      vout4567c1 = vmaxq_f32(vout4567c1, vmin);
      vout89ABc1 = vmaxq_f32(vout89ABc1, vmin);

      vst1q_f32(c + 0 * m + 0, vout0123c0);
      vst1q_f32(c + 0 * m + 4, vout4567c0);
      vst1q_f32(c + 0 * m + 8, vout89ABc0);
      vst1q_f32(c + 1 * m + 0, vout0123c1);
      vst1q_f32(c + 1 * m + 4, vout4567c1);
      vst1q_f32(c + 1 * m + 8, vout89ABc1);
      c += 2 * m;
      j -= 2;
    }

    // clean up loop, fall back to nr=1
    if XNN_UNLIKELY(j != 0) {
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567 = vacc0123;
        float32x4_t vacc89AB = vacc0123;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            const float32x4_t va4567 = vld1q_f32(a + 4);
            const float32x4_t va89AB = vld1q_f32(a + 8);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
            vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
            vacc89AB = vfmaq_f32(vacc89AB, va89AB, vb);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
        float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);

        vout0123 = vmaxq_f32(vout0123, vmin);
        vout4567 = vmaxq_f32(vout4567, vmin);
        vout89AB = vmaxq_f32(vout89AB, vmin);

        vst1q_f32(c + 0, vout0123);
        vst1q_f32(c + 4, vout4567);
        vst1q_f32(c + 8, vout89AB);
        c += m;
        j -= 1;
      } while (j != 0);
    }
    // Rewind the output to the start of the next 12-row tile.
    c -= m * n;
    c += 12;
    a += 12;
    i -= 12;
  }
  // Remainder rows: same structure at widths 8, 4, 2, then 1.
  if XNN_UNLIKELY(i != 0) {
    if (i & 8) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567c0 = vacc0123c0;
        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567c1 = vacc0123c1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            const float32x4_t va4567 = vld1q_f32(a + 4);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
            vacc4567c0 = vfmaq_lane_f32(vacc4567c0, va4567, vb, 0);
            vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
            vacc4567c1 = vfmaq_lane_f32(vacc4567c1, va4567, vb, 1);
          } while (--nnz != 0);
        }
        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
        float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
        float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);

        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
        vout4567c0 = vmaxq_f32(vout4567c0, vmin);
        vout0123c1 = vmaxq_f32(vout0123c1, vmin);
        vout4567c1 = vmaxq_f32(vout4567c1, vmin);

        vst1q_f32(c + 0 * m + 0, vout0123c0);
        vst1q_f32(c + 0 * m + 4, vout4567c0);
        vst1q_f32(c + 1 * m + 0, vout0123c1);
        vst1q_f32(c + 1 * m + 4, vout4567c1);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          float32x4_t vacc4567 = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t va0123 = vld1q_f32(a);
              const float32x4_t va4567 = vld1q_f32(a + 4);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
              vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
          float32x4_t vout4567 = vminq_f32(vacc4567, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);
          vout4567 = vmaxq_f32(vout4567, vmin);

          vst1q_f32(c + 0, vout0123);
          vst1q_f32(c + 4, vout4567);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 8;
      a += 8;
    }
    if (i & 4) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
            vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
          } while (--nnz != 0);
        }
        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);

        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
        vout0123c1 = vmaxq_f32(vout0123c1, vmin);

        vst1q_f32(c + 0 * m + 0, vout0123c0);
        vst1q_f32(c + 1 * m + 0, vout0123c1);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t va0123 = vld1q_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);

          vst1q_f32(c + 0, vout0123);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 4;
      a += 4;
    }
    if (i & 2) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t va01 = vld1_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc01c0 = vfma_lane_f32(vacc01c0, va01, vb, 0);
            vacc01c1 = vfma_lane_f32(vacc01c1, va01, vb, 1);
          } while (--nnz != 0);
        }
        float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax));
        float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax));

        vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin));
        vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin));

        vst1_f32(c + 0 * m + 0, vout01c0);
        vst1_f32(c + 1 * m + 0, vout01c1);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t va01 = vld1_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x2_t vb = vld1_dup_f32(w); w += 1;
              vacc01 = vfma_f32(vacc01, va01, vb);
            } while (--nnz != 0);
          }
          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
          vout01 = vmax_f32(vout01, vget_low_f32(vmin));

          vst1_f32(c, vout01);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 2;
      a += 2;
    }
    if (i & 1) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t va0 = vld1_dup_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc0c0 = vfma_lane_f32(vacc0c0, va0, vb, 0);
            vacc0c1 = vfma_lane_f32(vacc0c1, va0, vb, 1);
          } while (--nnz != 0);
        }
        float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax));
        float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax));

        vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin));
        vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin));

        vst1_lane_f32(c + 0 * m + 0, vout0c0, 0);
        vst1_lane_f32(c + 1 * m + 0, vout0c1, 0);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t va0 = vld1_dup_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x2_t vb = vld1_dup_f32(w); w += 1;
              vacc0 = vfma_f32(vacc0, va0, vb);
            } while (--nnz != 0);
          }
          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
          vout0 = vmax_f32(vout0, vget_low_f32(vmin));

          // NOTE(review): lane 1 is stored here while every sibling path
          // stores lane 0. Both lanes hold the same value (all inputs were
          // loaded with vld1_dup, so the lanes stay identical), but lane 0
          // would be consistent - confirm against the generator template.
          vst1_lane_f32(c, vout0, 1);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 1;
      a += 1;
    }
  }
}
diff --git a/src/f32-spmm/12x4-neonfma.c b/src/f32-spmm/12x4-neonfma.c
new file mode 100644
index 0000000..f89f1ea
--- /dev/null
+++ b/src/f32-spmm/12x4-neonfma.c
@@ -0,0 +1,459 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon-blocked.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// SpMM (sparse-weights times dense-input) micro-kernel, MR=12 x NR=4 tile.
+//   m           - dense dimension: number of input/output elements per channel
+//   n           - number of output channels
+//   a           - dense input; advanced by byte offsets taken from widx_dmap
+//   weights     - stream of initial accumulator values (presumably per-channel
+//                 bias -- confirm against packing code) followed by non-zero
+//                 weight values
+//   widx_dmap   - byte increment applied to `a` after each non-zero weight
+//   nidx_nnzmap - number of non-zero weights per output-channel block
+//   c           - output; consecutive channels are m floats apart
+//   params      - scalar min/max clamping bounds for the output
+void xnn_f32_spmm_ukernel_12x4__neonfma(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  // Main loop: 12 dense elements per iteration, 4 output channels per block.
+  while XNN_LIKELY(i >= 12) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    size_t j = n;
+    while (j >= 4) {
+      uint32_t nnz = *nnzmap++;
+      // Initialize each channel's accumulators from the weights stream.
+      float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc4567c0 = vacc0123c0;
+      float32x4_t vacc89ABc0 = vacc0123c0;
+      float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc4567c1 = vacc0123c1;
+      float32x4_t vacc89ABc1 = vacc0123c1;
+      float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc4567c2 = vacc0123c2;
+      float32x4_t vacc89ABc2 = vacc0123c2;
+      float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc4567c3 = vacc0123c3;
+      float32x4_t vacc89ABc3 = vacc0123c3;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t diff = *dmap++;  // byte offset to the next used input row
+          const float32x4_t va0123 = vld1q_f32(a);
+          const float32x4_t va4567 = vld1q_f32(a + 4);
+          const float32x4_t va89AB = vld1q_f32(a + 8);
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+          const float32x4_t vb = vld1q_f32(w); w += 4;  // one weight per channel
+
+          vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0);
+          vacc4567c0 = vfmaq_laneq_f32(vacc4567c0, va4567, vb, 0);
+          vacc89ABc0 = vfmaq_laneq_f32(vacc89ABc0, va89AB, vb, 0);
+          vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1);
+          vacc4567c1 = vfmaq_laneq_f32(vacc4567c1, va4567, vb, 1);
+          vacc89ABc1 = vfmaq_laneq_f32(vacc89ABc1, va89AB, vb, 1);
+          vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2);
+          vacc4567c2 = vfmaq_laneq_f32(vacc4567c2, va4567, vb, 2);
+          vacc89ABc2 = vfmaq_laneq_f32(vacc89ABc2, va89AB, vb, 2);
+          vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3);
+          vacc4567c3 = vfmaq_laneq_f32(vacc4567c3, va4567, vb, 3);
+          vacc89ABc3 = vfmaq_laneq_f32(vacc89ABc3, va89AB, vb, 3);
+        } while (--nnz != 0);
+      }
+      // Clamp to [min, max] and store 12 elements for each of the 4 channels.
+      float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
+      float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
+      float32x4_t vout89ABc0 = vminq_f32(vacc89ABc0, vmax);
+      float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
+      float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
+      float32x4_t vout89ABc1 = vminq_f32(vacc89ABc1, vmax);
+      float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax);
+      float32x4_t vout4567c2 = vminq_f32(vacc4567c2, vmax);
+      float32x4_t vout89ABc2 = vminq_f32(vacc89ABc2, vmax);
+      float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax);
+      float32x4_t vout4567c3 = vminq_f32(vacc4567c3, vmax);
+      float32x4_t vout89ABc3 = vminq_f32(vacc89ABc3, vmax);
+
+      vout0123c0 = vmaxq_f32(vout0123c0, vmin);
+      vout4567c0 = vmaxq_f32(vout4567c0, vmin);
+      vout89ABc0 = vmaxq_f32(vout89ABc0, vmin);
+      vout0123c1 = vmaxq_f32(vout0123c1, vmin);
+      vout4567c1 = vmaxq_f32(vout4567c1, vmin);
+      vout89ABc1 = vmaxq_f32(vout89ABc1, vmin);
+      vout0123c2 = vmaxq_f32(vout0123c2, vmin);
+      vout4567c2 = vmaxq_f32(vout4567c2, vmin);
+      vout89ABc2 = vmaxq_f32(vout89ABc2, vmin);
+      vout0123c3 = vmaxq_f32(vout0123c3, vmin);
+      vout4567c3 = vmaxq_f32(vout4567c3, vmin);
+      vout89ABc3 = vmaxq_f32(vout89ABc3, vmin);
+
+      vst1q_f32(c + 0 * m + 0, vout0123c0);
+      vst1q_f32(c + 0 * m + 4, vout4567c0);
+      vst1q_f32(c + 0 * m + 8, vout89ABc0);
+      vst1q_f32(c + 1 * m + 0, vout0123c1);
+      vst1q_f32(c + 1 * m + 4, vout4567c1);
+      vst1q_f32(c + 1 * m + 8, vout89ABc1);
+      vst1q_f32(c + 2 * m + 0, vout0123c2);
+      vst1q_f32(c + 2 * m + 4, vout4567c2);
+      vst1q_f32(c + 2 * m + 8, vout89ABc2);
+      vst1q_f32(c + 3 * m + 0, vout0123c3);
+      vst1q_f32(c + 3 * m + 4, vout4567c3);
+      vst1q_f32(c + 3 * m + 8, vout89ABc3);
+      c += 4 * m;
+      j -= 4;
+    }
+
+    // clean up loop, fall back to nr=1
+    if XNN_UNLIKELY(j != 0) {
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567 = vacc0123;
+        float32x4_t vacc89AB = vacc0123;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            const float32x4_t va4567 = vld1q_f32(a + 4);
+            const float32x4_t va89AB = vld1q_f32(a + 8);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+            vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+            vacc89AB = vfmaq_f32(vacc89AB, va89AB, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+        float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
+
+        vout0123 = vmaxq_f32(vout0123, vmin);
+        vout4567 = vmaxq_f32(vout4567, vmin);
+        vout89AB = vmaxq_f32(vout89AB, vmin);
+
+        vst1q_f32(c + 0, vout0123);
+        vst1q_f32(c + 4, vout4567);
+        vst1q_f32(c + 8, vout89AB);
+        c += m;
+        j -= 1;
+      } while (j != 0);
+    }
+    // Rewind to channel 0 and advance to the next group of 12 dense elements.
+    c -= m * n;
+    c += 12;
+    a += 12;
+    i -= 12;
+  }
+  // Remainder handling: decompose the leftover rows as 8 + 4 + 2 + 1.
+  if XNN_UNLIKELY(i != 0) {
+    if (i & 8) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 4) {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567c0 = vacc0123c0;
+        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567c1 = vacc0123c1;
+        float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567c2 = vacc0123c2;
+        float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567c3 = vacc0123c3;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            const float32x4_t va4567 = vld1q_f32(a + 4);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_f32(w); w += 4;
+
+            vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0);
+            vacc4567c0 = vfmaq_laneq_f32(vacc4567c0, va4567, vb, 0);
+            vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1);
+            vacc4567c1 = vfmaq_laneq_f32(vacc4567c1, va4567, vb, 1);
+            vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2);
+            vacc4567c2 = vfmaq_laneq_f32(vacc4567c2, va4567, vb, 2);
+            vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3);
+            vacc4567c3 = vfmaq_laneq_f32(vacc4567c3, va4567, vb, 3);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
+        float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
+        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
+        float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
+        float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax);
+        float32x4_t vout4567c2 = vminq_f32(vacc4567c2, vmax);
+        float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax);
+        float32x4_t vout4567c3 = vminq_f32(vacc4567c3, vmax);
+
+        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
+        vout4567c0 = vmaxq_f32(vout4567c0, vmin);
+        vout0123c1 = vmaxq_f32(vout0123c1, vmin);
+        vout4567c1 = vmaxq_f32(vout4567c1, vmin);
+        vout0123c2 = vmaxq_f32(vout0123c2, vmin);
+        vout4567c2 = vmaxq_f32(vout4567c2, vmin);
+        vout0123c3 = vmaxq_f32(vout0123c3, vmin);
+        vout4567c3 = vmaxq_f32(vout4567c3, vmin);
+
+        vst1q_f32(c + 0 * m + 0, vout0123c0);
+        vst1q_f32(c + 0 * m + 4, vout4567c0);
+        vst1q_f32(c + 1 * m + 0, vout0123c1);
+        vst1q_f32(c + 1 * m + 4, vout4567c1);
+        vst1q_f32(c + 2 * m + 0, vout0123c2);
+        vst1q_f32(c + 2 * m + 4, vout4567c2);
+        vst1q_f32(c + 3 * m + 0, vout0123c3);
+        vst1q_f32(c + 3 * m + 4, vout4567c3);
+        c += 4 * m;
+        j -= 4;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+          float32x4_t vacc4567 = vacc0123;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x4_t va0123 = vld1q_f32(a);
+              const float32x4_t va4567 = vld1q_f32(a + 4);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+              vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+            } while (--nnz != 0);
+          }
+          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+          float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+
+          vout0123 = vmaxq_f32(vout0123, vmin);
+          vout4567 = vmaxq_f32(vout4567, vmin);
+
+          vst1q_f32(c + 0, vout0123);
+          vst1q_f32(c + 4, vout4567);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 8;
+      a += 8;
+    }
+    if (i & 4) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 4) {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_f32(w); w += 4;
+
+            vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0);
+            vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1);
+            vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2);
+            vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
+        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
+        float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax);
+        float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax);
+
+        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
+        vout0123c1 = vmaxq_f32(vout0123c1, vmin);
+        vout0123c2 = vmaxq_f32(vout0123c2, vmin);
+        vout0123c3 = vmaxq_f32(vout0123c3, vmin);
+
+        vst1q_f32(c + 0 * m + 0, vout0123c0);
+        vst1q_f32(c + 1 * m + 0, vout0123c1);
+        vst1q_f32(c + 2 * m + 0, vout0123c2);
+        vst1q_f32(c + 3 * m + 0, vout0123c3);
+        c += 4 * m;
+        j -= 4;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x4_t va0123 = vld1q_f32(a);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+            } while (--nnz != 0);
+          }
+          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+
+          vout0123 = vmaxq_f32(vout0123, vmin);
+
+          vst1q_f32(c + 0, vout0123);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 4;
+      a += 4;
+    }
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 4) {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc01c2 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc01c3 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_f32(w); w += 4;
+
+            vacc01c0 = vfma_laneq_f32(vacc01c0, va01, vb, 0);
+            vacc01c1 = vfma_laneq_f32(vacc01c1, va01, vb, 1);
+            vacc01c2 = vfma_laneq_f32(vacc01c2, va01, vb, 2);
+            vacc01c3 = vfma_laneq_f32(vacc01c3, va01, vb, 3);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax));
+        float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax));
+        float32x2_t vout01c2 = vmin_f32(vacc01c2, vget_low_f32(vmax));
+        float32x2_t vout01c3 = vmin_f32(vacc01c3, vget_low_f32(vmax));
+
+        vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin));
+        vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin));
+        vout01c2 = vmax_f32(vout01c2, vget_low_f32(vmin));
+        vout01c3 = vmax_f32(vout01c3, vget_low_f32(vmin));
+
+        vst1_f32(c + 0 * m + 0, vout01c0);
+        vst1_f32(c + 1 * m + 0, vout01c1);
+        vst1_f32(c + 2 * m + 0, vout01c2);
+        vst1_f32(c + 3 * m + 0, vout01c3);
+        c += 4 * m;
+        j -= 4;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x2_t va01 = vld1_f32(a);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x2_t vb = vld1_dup_f32(w); w += 1;
+              vacc01 = vfma_f32(vacc01, va01, vb);
+            } while (--nnz != 0);
+          }
+          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+          vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+
+          vst1_f32(c, vout01);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 4) {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc0c2 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc0c3 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_f32(w); w += 4;
+
+            vacc0c0 = vfma_laneq_f32(vacc0c0, va0, vb, 0);
+            vacc0c1 = vfma_laneq_f32(vacc0c1, va0, vb, 1);
+            vacc0c2 = vfma_laneq_f32(vacc0c2, va0, vb, 2);
+            vacc0c3 = vfma_laneq_f32(vacc0c3, va0, vb, 3);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax));
+        float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax));
+        float32x2_t vout0c2 = vmin_f32(vacc0c2, vget_low_f32(vmax));
+        float32x2_t vout0c3 = vmin_f32(vacc0c3, vget_low_f32(vmax));
+
+        vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin));
+        vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin));
+        vout0c2 = vmax_f32(vout0c2, vget_low_f32(vmin));
+        vout0c3 = vmax_f32(vout0c3, vget_low_f32(vmin));
+
+        vst1_lane_f32(c + 0 * m + 0, vout0c0, 0);
+        vst1_lane_f32(c + 1 * m + 0, vout0c1, 0);
+        vst1_lane_f32(c + 2 * m + 0, vout0c2, 0);
+        vst1_lane_f32(c + 3 * m + 0, vout0c3, 0);
+        c += 4 * m;
+        j -= 4;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x2_t va0 = vld1_dup_f32(a);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x2_t vb = vld1_dup_f32(w); w += 1;
+              vacc0 = vfma_f32(vacc0, va0, vb);
+            } while (--nnz != 0);
+          }
+          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+          vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+
+          // Lane 1 equals lane 0 here: every vector feeding vacc0 was built
+          // with vld1_dup, so both lanes hold the same value.
+          vst1_lane_f32(c, vout0, 1);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/16x1-neonfma-pipelined.c b/src/f32-spmm/16x1-neonfma-pipelined.c
new file mode 100644
index 0000000..75b11d5
--- /dev/null
+++ b/src/f32-spmm/16x1-neonfma-pipelined.c
@@ -0,0 +1,197 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon-pipelined.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// SpMM micro-kernel, MR=16 x NR=1, software-pipelined main loop: the next
+// weight, dmap offset, and 16 input values are loaded before the current
+// multiply-accumulate, hiding load latency.
+// NOTE(review): the pipelined loop reads one weight/dmap entry and one set of
+// input values beyond what the current iteration consumes -- presumably the
+// packed buffers carry readable padding; confirm against the packing code.
+// Parameters as in the other spmm kernels: m = dense dimension, n = output
+// channels, widx_dmap = byte increments for `a`, nidx_nnzmap = per-channel
+// non-zero counts, output clamped to [params->scalar.min, params->scalar.max].
+void xnn_f32_spmm_ukernel_16x1__neonfma_pipelined(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  while XNN_LIKELY(i >= 16) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    // Prime the pipeline: first accumulator init value, first offset, first inputs.
+    float32x4_t vw = vld1q_dup_f32(w); w += 1;
+    intptr_t diff = *dmap++;
+    float32x4_t va0123 = vld1q_f32(a);
+    float32x4_t va4567 = vld1q_f32(a + 4);
+    float32x4_t va89AB = vld1q_f32(a + 8);
+    float32x4_t vaCDEF = vld1q_f32(a + 12);
+    size_t j = n;
+    do {
+      uint32_t nnz = *nnzmap++;
+      float32x4_t vacc0123 = vw;
+      float32x4_t vacc4567 = vw;
+      float32x4_t vacc89AB = vw;
+      float32x4_t vaccCDEF = vw;
+      vw = vld1q_dup_f32(w); w += 1;  // pre-load the first weight of the channel
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          vacc0123 = vfmaq_f32(vacc0123, va0123, vw);
+          vacc4567 = vfmaq_f32(vacc4567, va4567, vw);
+          vacc89AB = vfmaq_f32(vacc89AB, va89AB, vw);
+          vaccCDEF = vfmaq_f32(vaccCDEF, vaCDEF, vw);
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+
+          // Pre-load the next offset, weight and inputs for the next iteration.
+          diff = *dmap++;
+          vw = vld1q_dup_f32(w); w += 1;
+          va0123 = vld1q_f32(a);
+          va4567 = vld1q_f32(a + 4);
+          va89AB = vld1q_f32(a + 8);
+          vaCDEF = vld1q_f32(a + 12);
+        } while (--nnz != 0);
+      }
+      float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+      float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+      float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
+      float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);
+      vout0123 = vmaxq_f32(vout0123, vmin);
+      vout4567 = vmaxq_f32(vout4567, vmin);
+      vout89AB = vmaxq_f32(vout89AB, vmin);
+      voutCDEF = vmaxq_f32(voutCDEF, vmin);
+      vst1q_f32(c, vout0123);
+      vst1q_f32(c + 4, vout4567);
+      vst1q_f32(c + 8, vout89AB);
+      vst1q_f32(c + 12, voutCDEF);
+      c += m;
+    } while (--j != 0);
+    c -= m * n;
+    c += 16;
+    a += 16;
+    i -= 16;
+  }
+  // Remainder handling (non-pipelined): 8 + 4 + 2 + 1 leftover rows.
+  if XNN_UNLIKELY(i != 0) {
+    if (i & 8) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567 = vacc0123;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            const float32x4_t va4567 = vld1q_f32(a + 4);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+            vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+        vout0123 = vmaxq_f32(vout0123, vmin);
+        vout4567 = vmaxq_f32(vout4567, vmin);
+        vst1q_f32(c, vout0123);
+        vst1q_f32(c + 4, vout4567);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 8;
+      a += 8;
+    }
+    if (i & 4) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+        vout0123 = vmaxq_f32(vout0123, vmin);
+        vst1q_f32(c, vout0123);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 4;
+      a += 4;
+    }
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc01 = vfma_f32(vacc01, va01, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+        vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+        vst1_f32(c, vout01);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc0 = vfma_f32(vacc0, va0, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+        vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+        vst1_lane_f32(c, vout0, 0);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/16x1-neonfma-unroll2.c b/src/f32-spmm/16x1-neonfma-unroll2.c
new file mode 100644
index 0000000..f4a062b
--- /dev/null
+++ b/src/f32-spmm/16x1-neonfma-unroll2.c
@@ -0,0 +1,226 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// SpMM micro-kernel, MR=16 x NR=1, with the non-zero loop unrolled 2x.
+// Two independent accumulator sets (x0 initialized from the weights stream,
+// x1 initialized to zero) are summed after the unrolled loop; an ordinary
+// loop consumes any odd leftover non-zero.  Note: the unrolling changes the
+// floating-point summation order relative to the plain 16x1 kernel, so
+// results may differ in the last ulp -- expected for this kind of variant.
+// Parameters as in the other spmm kernels; output clamped to
+// [params->scalar.min, params->scalar.max].
+void xnn_f32_spmm_ukernel_16x1__neonfma_unroll2(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  while XNN_LIKELY(i >= 16) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    size_t j = n;
+    do {
+      uint32_t nnz = *nnzmap++;
+      float32x4_t vacc0123x0 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc0123x1 = vmovq_n_f32(0.0f);
+      float32x4_t vacc4567x0 = vacc0123x0;
+      float32x4_t vacc4567x1 = vmovq_n_f32(0.0f);
+      float32x4_t vacc89ABx0 = vacc0123x0;
+      float32x4_t vacc89ABx1 = vmovq_n_f32(0.0f);
+      float32x4_t vaccCDEFx0 = vacc0123x0;
+      float32x4_t vaccCDEFx1 = vmovq_n_f32(0.0f);
+      // Process non-zeros two at a time into independent accumulator sets.
+      for (; nnz >= 2; nnz -= 2) {
+        const intptr_t diff0 = dmap[0];
+        const intptr_t diff1 = dmap[1];
+        dmap += 2;
+        const float32x4_t va0123x0 = vld1q_f32(a);
+        const float32x4_t va4567x0 = vld1q_f32(a + 4);
+        const float32x4_t va89ABx0 = vld1q_f32(a + 8);
+        const float32x4_t vaCDEFx0 = vld1q_f32(a + 12);
+        a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff0);
+        const float32x4_t vb0 = vld1q_dup_f32(w); w += 1;
+        vacc0123x0 = vfmaq_f32(vacc0123x0, va0123x0, vb0);
+        vacc4567x0 = vfmaq_f32(vacc4567x0, va4567x0, vb0);
+        vacc89ABx0 = vfmaq_f32(vacc89ABx0, va89ABx0, vb0);
+        vaccCDEFx0 = vfmaq_f32(vaccCDEFx0, vaCDEFx0, vb0);
+        const float32x4_t va0123x1 = vld1q_f32(a);
+        const float32x4_t va4567x1 = vld1q_f32(a + 4);
+        const float32x4_t va89ABx1 = vld1q_f32(a + 8);
+        const float32x4_t vaCDEFx1 = vld1q_f32(a + 12);
+        a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff1);
+        const float32x4_t vb1 = vld1q_dup_f32(w); w += 1;
+        vacc0123x1 = vfmaq_f32(vacc0123x1, va0123x1, vb1);
+        vacc4567x1 = vfmaq_f32(vacc4567x1, va4567x1, vb1);
+        vacc89ABx1 = vfmaq_f32(vacc89ABx1, va89ABx1, vb1);
+        vaccCDEFx1 = vfmaq_f32(vaccCDEFx1, vaCDEFx1, vb1);
+      }
+      // Combine the two accumulator sets.
+      float32x4_t vacc0123 = vacc0123x0;
+      float32x4_t vacc4567 = vacc4567x0;
+      float32x4_t vacc89AB = vacc89ABx0;
+      float32x4_t vaccCDEF = vaccCDEFx0;
+      vacc0123 = vaddq_f32(vacc0123, vacc0123x1);
+      vacc4567 = vaddq_f32(vacc4567, vacc4567x1);
+      vacc89AB = vaddq_f32(vacc89AB, vacc89ABx1);
+      vaccCDEF = vaddq_f32(vaccCDEF, vaccCDEFx1);
+      // Odd leftover non-zero, if any.
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t diff = *dmap++;
+          const float32x4_t va0123 = vld1q_f32(a);
+          const float32x4_t va4567 = vld1q_f32(a + 4);
+          const float32x4_t va89AB = vld1q_f32(a + 8);
+          const float32x4_t vaCDEF = vld1q_f32(a + 12);
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+          const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+          vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+          vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+          vacc89AB = vfmaq_f32(vacc89AB, va89AB, vb);
+          vaccCDEF = vfmaq_f32(vaccCDEF, vaCDEF, vb);
+        } while (--nnz != 0);
+      }
+      float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+      float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+      float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
+      float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);
+      vout0123 = vmaxq_f32(vout0123, vmin);
+      vout4567 = vmaxq_f32(vout4567, vmin);
+      vout89AB = vmaxq_f32(vout89AB, vmin);
+      voutCDEF = vmaxq_f32(voutCDEF, vmin);
+      vst1q_f32(c, vout0123);
+      vst1q_f32(c + 4, vout4567);
+      vst1q_f32(c + 8, vout89AB);
+      vst1q_f32(c + 12, voutCDEF);
+      c += m;
+    } while (--j != 0);
+    c -= m * n;
+    c += 16;
+    a += 16;
+    i -= 16;
+  }
+  // Remainder handling (not unrolled): 8 + 4 + 2 + 1 leftover rows.
+  if XNN_UNLIKELY(i != 0) {
+    if (i & 8) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567 = vacc0123;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            const float32x4_t va4567 = vld1q_f32(a + 4);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+            vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+        vout0123 = vmaxq_f32(vout0123, vmin);
+        vout4567 = vmaxq_f32(vout4567, vmin);
+        vst1q_f32(c, vout0123);
+        vst1q_f32(c + 4, vout4567);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 8;
+      a += 8;
+    }
+    if (i & 4) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+        vout0123 = vmaxq_f32(vout0123, vmin);
+        vst1q_f32(c, vout0123);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 4;
+      a += 4;
+    }
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc01 = vfma_f32(vacc01, va01, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+        vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+        vst1_f32(c, vout01);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc0 = vfma_f32(vacc0, va0, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+        vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+        vst1_lane_f32(c, vout0, 0);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/16x1-neonfma.c b/src/f32-spmm/16x1-neonfma.c
new file mode 100644
index 0000000..c92af5b
--- /dev/null
+++ b/src/f32-spmm/16x1-neonfma.c
@@ -0,0 +1,189 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// SpMM micro-kernel, MR=16 x NR=1 (baseline, no pipelining/unrolling).
+// For each output channel: load the initial accumulator value from the
+// weights stream, then for each non-zero weight load 16 input floats,
+// advance `a` by the byte offset from widx_dmap, and fused-multiply-add.
+// The result is clamped to [params->scalar.min, params->scalar.max] and
+// stored; consecutive channels are m floats apart in `c`.
+void xnn_f32_spmm_ukernel_16x1__neonfma(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  while XNN_LIKELY(i >= 16) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    size_t j = n;
+    do {
+      uint32_t nnz = *nnzmap++;
+      float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc4567 = vacc0123;
+      float32x4_t vacc89AB = vacc0123;
+      float32x4_t vaccCDEF = vacc0123;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t diff = *dmap++;  // byte offset to the next used input row
+          const float32x4_t va0123 = vld1q_f32(a);
+          const float32x4_t va4567 = vld1q_f32(a + 4);
+          const float32x4_t va89AB = vld1q_f32(a + 8);
+          const float32x4_t vaCDEF = vld1q_f32(a + 12);
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+          const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+          vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+          vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+          vacc89AB = vfmaq_f32(vacc89AB, va89AB, vb);
+          vaccCDEF = vfmaq_f32(vaccCDEF, vaCDEF, vb);
+        } while (--nnz != 0);
+      }
+      float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+      float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+      float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
+      float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);
+      vout0123 = vmaxq_f32(vout0123, vmin);
+      vout4567 = vmaxq_f32(vout4567, vmin);
+      vout89AB = vmaxq_f32(vout89AB, vmin);
+      voutCDEF = vmaxq_f32(voutCDEF, vmin);
+      vst1q_f32(c, vout0123);
+      vst1q_f32(c + 4, vout4567);
+      vst1q_f32(c + 8, vout89AB);
+      vst1q_f32(c + 12, voutCDEF);
+      c += m;
+    } while (--j != 0);
+    c -= m * n;
+    c += 16;
+    a += 16;
+    i -= 16;
+  }
+  // Remainder handling: 8 + 4 + 2 + 1 leftover rows.
+  if XNN_UNLIKELY(i != 0) {
+    if (i & 8) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567 = vacc0123;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            const float32x4_t va4567 = vld1q_f32(a + 4);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+            vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+        vout0123 = vmaxq_f32(vout0123, vmin);
+        vout4567 = vmaxq_f32(vout4567, vmin);
+        vst1q_f32(c, vout0123);
+        vst1q_f32(c + 4, vout4567);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 8;
+      a += 8;
+    }
+    if (i & 4) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+        vout0123 = vmaxq_f32(vout0123, vmin);
+        vst1q_f32(c, vout0123);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 4;
+      a += 4;
+    }
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc01 = vfma_f32(vacc01, va01, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+        vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+        vst1_f32(c, vout01);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc0 = vfma_f32(vacc0, va0, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+        vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+        vst1_lane_f32(c, vout0, 0);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/16x2-neonfma.c b/src/f32-spmm/16x2-neonfma.c
new file mode 100644
index 0000000..b74966b
--- /dev/null
+++ b/src/f32-spmm/16x2-neonfma.c
@@ -0,0 +1,396 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon-blocked.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// SpMM (sparse matrix * dense matrix) microkernel, NEON FMA, mr=16 / nr=2:
+// each main-loop iteration produces 16 output elements for each of 2 output
+// channels, clamped to [params->scalar.min, params->scalar.max].
+//
+//   a            dense input; advanced by the byte offsets read from widx_dmap
+//   weights      packed stream: one initial value per output channel
+//                (presumably the bias) followed by the nonzero weights --
+//                layout inferred from the access pattern here; confirm
+//                against the weight-packing routine
+//   nidx_nnzmap  nonzero count per output-channel group
+//   c            output, column-major with leading dimension m
+void xnn_f32_spmm_ukernel_16x2__neonfma(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  // Main loop: 16 rows at a time.
+  while XNN_LIKELY(i >= 16) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    size_t j = n;
+    // 2 output channels at a time; one shared nonzero count per channel pair.
+    while (j >= 2) {
+      uint32_t nnz = *nnzmap++;
+      // Accumulators start from the per-channel initial value in the stream.
+      float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc4567c0 = vacc0123c0;
+      float32x4_t vacc89ABc0 = vacc0123c0;
+      float32x4_t vaccCDEFc0 = vacc0123c0;
+      float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc4567c1 = vacc0123c1;
+      float32x4_t vacc89ABc1 = vacc0123c1;
+      float32x4_t vaccCDEFc1 = vacc0123c1;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t diff = *dmap++;
+          const float32x4_t va0123 = vld1q_f32(a);
+          const float32x4_t va4567 = vld1q_f32(a + 4);
+          const float32x4_t va89AB = vld1q_f32(a + 8);
+          const float32x4_t vaCDEF = vld1q_f32(a + 12);
+          // Advance the input pointer by the packed byte offset.
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+          const float32x2_t vb = vld1_f32(w); w += 2;
+
+          vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
+          vacc4567c0 = vfmaq_lane_f32(vacc4567c0, va4567, vb, 0);
+          vacc89ABc0 = vfmaq_lane_f32(vacc89ABc0, va89AB, vb, 0);
+          vaccCDEFc0 = vfmaq_lane_f32(vaccCDEFc0, vaCDEF, vb, 0);
+          vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
+          vacc4567c1 = vfmaq_lane_f32(vacc4567c1, va4567, vb, 1);
+          vacc89ABc1 = vfmaq_lane_f32(vacc89ABc1, va89AB, vb, 1);
+          vaccCDEFc1 = vfmaq_lane_f32(vaccCDEFc1, vaCDEF, vb, 1);
+        } while (--nnz != 0);
+      }
+      // Clamp to [min, max] and store both channels.
+      float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
+      float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
+      float32x4_t vout89ABc0 = vminq_f32(vacc89ABc0, vmax);
+      float32x4_t voutCDEFc0 = vminq_f32(vaccCDEFc0, vmax);
+      float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
+      float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
+      float32x4_t vout89ABc1 = vminq_f32(vacc89ABc1, vmax);
+      float32x4_t voutCDEFc1 = vminq_f32(vaccCDEFc1, vmax);
+
+      vout0123c0 = vmaxq_f32(vout0123c0, vmin);
+      vout4567c0 = vmaxq_f32(vout4567c0, vmin);
+      vout89ABc0 = vmaxq_f32(vout89ABc0, vmin);
+      voutCDEFc0 = vmaxq_f32(voutCDEFc0, vmin);
+      vout0123c1 = vmaxq_f32(vout0123c1, vmin);
+      vout4567c1 = vmaxq_f32(vout4567c1, vmin);
+      vout89ABc1 = vmaxq_f32(vout89ABc1, vmin);
+      voutCDEFc1 = vmaxq_f32(voutCDEFc1, vmin);
+
+      vst1q_f32(c + 0 * m + 0, vout0123c0);
+      vst1q_f32(c + 0 * m + 4, vout4567c0);
+      vst1q_f32(c + 0 * m + 8, vout89ABc0);
+      vst1q_f32(c + 0 * m + 12, voutCDEFc0);
+      vst1q_f32(c + 1 * m + 0, vout0123c1);
+      vst1q_f32(c + 1 * m + 4, vout4567c1);
+      vst1q_f32(c + 1 * m + 8, vout89ABc1);
+      vst1q_f32(c + 1 * m + 12, voutCDEFc1);
+      c += 2 * m;
+      j -= 2;
+    }
+
+    // clean up loop, fall back to nr=1
+    if XNN_UNLIKELY(j != 0) {
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567 = vacc0123;
+        float32x4_t vacc89AB = vacc0123;
+        float32x4_t vaccCDEF = vacc0123;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            const float32x4_t va4567 = vld1q_f32(a + 4);
+            const float32x4_t va89AB = vld1q_f32(a + 8);
+            const float32x4_t vaCDEF = vld1q_f32(a + 12);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+            vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+            vacc89AB = vfmaq_f32(vacc89AB, va89AB, vb);
+            vaccCDEF = vfmaq_f32(vaccCDEF, vaCDEF, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+        float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
+        float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);
+
+        vout0123 = vmaxq_f32(vout0123, vmin);
+        vout4567 = vmaxq_f32(vout4567, vmin);
+        vout89AB = vmaxq_f32(vout89AB, vmin);
+        voutCDEF = vmaxq_f32(voutCDEF, vmin);
+
+        vst1q_f32(c + 0, vout0123);
+        vst1q_f32(c + 4, vout4567);
+        vst1q_f32(c + 8, vout89AB);
+        vst1q_f32(c + 12, voutCDEF);
+        c += m;
+        j -= 1;
+      } while (j != 0);
+    }
+    // Rewind to the first output channel, then advance to the next row group.
+    c -= m * n;
+    c += 16;
+    a += 16;
+    i -= 16;
+  }
+  // Remainder rows: handled in power-of-two sub-tiles (8, 4, 2, 1).
+  if XNN_UNLIKELY(i != 0) {
+    if (i & 8) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 2) {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567c0 = vacc0123c0;
+        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567c1 = vacc0123c1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            const float32x4_t va4567 = vld1q_f32(a + 4);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_f32(w); w += 2;
+
+            vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
+            vacc4567c0 = vfmaq_lane_f32(vacc4567c0, va4567, vb, 0);
+            vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
+            vacc4567c1 = vfmaq_lane_f32(vacc4567c1, va4567, vb, 1);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
+        float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
+        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
+        float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
+
+        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
+        vout4567c0 = vmaxq_f32(vout4567c0, vmin);
+        vout0123c1 = vmaxq_f32(vout0123c1, vmin);
+        vout4567c1 = vmaxq_f32(vout4567c1, vmin);
+
+        vst1q_f32(c + 0 * m + 0, vout0123c0);
+        vst1q_f32(c + 0 * m + 4, vout4567c0);
+        vst1q_f32(c + 1 * m + 0, vout0123c1);
+        vst1q_f32(c + 1 * m + 4, vout4567c1);
+        c += 2 * m;
+        j -= 2;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+          float32x4_t vacc4567 = vacc0123;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x4_t va0123 = vld1q_f32(a);
+              const float32x4_t va4567 = vld1q_f32(a + 4);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+              vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+            } while (--nnz != 0);
+          }
+          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+          float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+
+          vout0123 = vmaxq_f32(vout0123, vmin);
+          vout4567 = vmaxq_f32(vout4567, vmin);
+
+          vst1q_f32(c + 0, vout0123);
+          vst1q_f32(c + 4, vout4567);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 8;
+      a += 8;
+    }
+    if (i & 4) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 2) {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_f32(w); w += 2;
+
+            vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
+            vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
+        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
+
+        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
+        vout0123c1 = vmaxq_f32(vout0123c1, vmin);
+
+        vst1q_f32(c + 0 * m + 0, vout0123c0);
+        vst1q_f32(c + 1 * m + 0, vout0123c1);
+        c += 2 * m;
+        j -= 2;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x4_t va0123 = vld1q_f32(a);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+            } while (--nnz != 0);
+          }
+          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+
+          vout0123 = vmaxq_f32(vout0123, vmin);
+
+          vst1q_f32(c + 0, vout0123);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 4;
+      a += 4;
+    }
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 2) {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_f32(w); w += 2;
+
+            vacc01c0 = vfma_lane_f32(vacc01c0, va01, vb, 0);
+            vacc01c1 = vfma_lane_f32(vacc01c1, va01, vb, 1);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax));
+        float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax));
+
+        vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin));
+        vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin));
+
+        vst1_f32(c + 0 * m + 0, vout01c0);
+        vst1_f32(c + 1 * m + 0, vout01c1);
+        c += 2 * m;
+        j -= 2;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x2_t va01 = vld1_f32(a);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x2_t vb = vld1_dup_f32(w); w += 1;
+              vacc01 = vfma_f32(vacc01, va01, vb);
+            } while (--nnz != 0);
+          }
+          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+          vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+
+          vst1_f32(c, vout01);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 2) {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_f32(w); w += 2;
+
+            vacc0c0 = vfma_lane_f32(vacc0c0, va0, vb, 0);
+            vacc0c1 = vfma_lane_f32(vacc0c1, va0, vb, 1);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax));
+        float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax));
+
+        vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin));
+        vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin));
+
+        vst1_lane_f32(c + 0 * m + 0, vout0c0, 0);
+        vst1_lane_f32(c + 1 * m + 0, vout0c1, 0);
+        c += 2 * m;
+        j -= 2;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x2_t va0 = vld1_dup_f32(a);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x2_t vb = vld1_dup_f32(w); w += 1;
+              vacc0 = vfma_f32(vacc0, va0, vb);
+            } while (--nnz != 0);
+          }
+          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+          vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+
+          // Store lane 0 for consistency with the nr=2 path above (both lanes
+          // hold the same value here since all inputs were lane-duplicated).
+          vst1_lane_f32(c, vout0, 0);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/16x4-neonfma.c b/src/f32-spmm/16x4-neonfma.c
new file mode 100644
index 0000000..8a927eb
--- /dev/null
+++ b/src/f32-spmm/16x4-neonfma.c
@@ -0,0 +1,486 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon-blocked.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// SpMM (sparse matrix * dense matrix) microkernel, NEON FMA, mr=16 / nr=4:
+// each main-loop iteration produces 16 output elements for each of 4 output
+// channels, clamped to [params->scalar.min, params->scalar.max].
+//
+//   a            dense input; advanced by the byte offsets read from widx_dmap
+//   weights      packed stream: one initial value per output channel
+//                (presumably the bias) followed by the nonzero weights --
+//                layout inferred from the access pattern here; confirm
+//                against the weight-packing routine
+//   nidx_nnzmap  nonzero count per output-channel group
+//   c            output, column-major with leading dimension m
+void xnn_f32_spmm_ukernel_16x4__neonfma(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  // Main loop: 16 rows at a time.
+  while XNN_LIKELY(i >= 16) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    size_t j = n;
+    // 4 output channels at a time; one shared nonzero count per channel quad.
+    while (j >= 4) {
+      uint32_t nnz = *nnzmap++;
+      // Accumulators start from the per-channel initial value in the stream.
+      float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc4567c0 = vacc0123c0;
+      float32x4_t vacc89ABc0 = vacc0123c0;
+      float32x4_t vaccCDEFc0 = vacc0123c0;
+      float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc4567c1 = vacc0123c1;
+      float32x4_t vacc89ABc1 = vacc0123c1;
+      float32x4_t vaccCDEFc1 = vacc0123c1;
+      float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc4567c2 = vacc0123c2;
+      float32x4_t vacc89ABc2 = vacc0123c2;
+      float32x4_t vaccCDEFc2 = vacc0123c2;
+      float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc4567c3 = vacc0123c3;
+      float32x4_t vacc89ABc3 = vacc0123c3;
+      float32x4_t vaccCDEFc3 = vacc0123c3;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t diff = *dmap++;
+          const float32x4_t va0123 = vld1q_f32(a);
+          const float32x4_t va4567 = vld1q_f32(a + 4);
+          const float32x4_t va89AB = vld1q_f32(a + 8);
+          const float32x4_t vaCDEF = vld1q_f32(a + 12);
+          // Advance the input pointer by the packed byte offset.
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+          const float32x4_t vb = vld1q_f32(w); w += 4;
+
+          vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0);
+          vacc4567c0 = vfmaq_laneq_f32(vacc4567c0, va4567, vb, 0);
+          vacc89ABc0 = vfmaq_laneq_f32(vacc89ABc0, va89AB, vb, 0);
+          vaccCDEFc0 = vfmaq_laneq_f32(vaccCDEFc0, vaCDEF, vb, 0);
+          vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1);
+          vacc4567c1 = vfmaq_laneq_f32(vacc4567c1, va4567, vb, 1);
+          vacc89ABc1 = vfmaq_laneq_f32(vacc89ABc1, va89AB, vb, 1);
+          vaccCDEFc1 = vfmaq_laneq_f32(vaccCDEFc1, vaCDEF, vb, 1);
+          vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2);
+          vacc4567c2 = vfmaq_laneq_f32(vacc4567c2, va4567, vb, 2);
+          vacc89ABc2 = vfmaq_laneq_f32(vacc89ABc2, va89AB, vb, 2);
+          vaccCDEFc2 = vfmaq_laneq_f32(vaccCDEFc2, vaCDEF, vb, 2);
+          vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3);
+          vacc4567c3 = vfmaq_laneq_f32(vacc4567c3, va4567, vb, 3);
+          vacc89ABc3 = vfmaq_laneq_f32(vacc89ABc3, va89AB, vb, 3);
+          vaccCDEFc3 = vfmaq_laneq_f32(vaccCDEFc3, vaCDEF, vb, 3);
+        } while (--nnz != 0);
+      }
+      // Clamp to [min, max] and store all four channels.
+      float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
+      float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
+      float32x4_t vout89ABc0 = vminq_f32(vacc89ABc0, vmax);
+      float32x4_t voutCDEFc0 = vminq_f32(vaccCDEFc0, vmax);
+      float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
+      float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
+      float32x4_t vout89ABc1 = vminq_f32(vacc89ABc1, vmax);
+      float32x4_t voutCDEFc1 = vminq_f32(vaccCDEFc1, vmax);
+      float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax);
+      float32x4_t vout4567c2 = vminq_f32(vacc4567c2, vmax);
+      float32x4_t vout89ABc2 = vminq_f32(vacc89ABc2, vmax);
+      float32x4_t voutCDEFc2 = vminq_f32(vaccCDEFc2, vmax);
+      float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax);
+      float32x4_t vout4567c3 = vminq_f32(vacc4567c3, vmax);
+      float32x4_t vout89ABc3 = vminq_f32(vacc89ABc3, vmax);
+      float32x4_t voutCDEFc3 = vminq_f32(vaccCDEFc3, vmax);
+
+      vout0123c0 = vmaxq_f32(vout0123c0, vmin);
+      vout4567c0 = vmaxq_f32(vout4567c0, vmin);
+      vout89ABc0 = vmaxq_f32(vout89ABc0, vmin);
+      voutCDEFc0 = vmaxq_f32(voutCDEFc0, vmin);
+      vout0123c1 = vmaxq_f32(vout0123c1, vmin);
+      vout4567c1 = vmaxq_f32(vout4567c1, vmin);
+      vout89ABc1 = vmaxq_f32(vout89ABc1, vmin);
+      voutCDEFc1 = vmaxq_f32(voutCDEFc1, vmin);
+      vout0123c2 = vmaxq_f32(vout0123c2, vmin);
+      vout4567c2 = vmaxq_f32(vout4567c2, vmin);
+      vout89ABc2 = vmaxq_f32(vout89ABc2, vmin);
+      voutCDEFc2 = vmaxq_f32(voutCDEFc2, vmin);
+      vout0123c3 = vmaxq_f32(vout0123c3, vmin);
+      vout4567c3 = vmaxq_f32(vout4567c3, vmin);
+      vout89ABc3 = vmaxq_f32(vout89ABc3, vmin);
+      voutCDEFc3 = vmaxq_f32(voutCDEFc3, vmin);
+
+      vst1q_f32(c + 0 * m + 0, vout0123c0);
+      vst1q_f32(c + 0 * m + 4, vout4567c0);
+      vst1q_f32(c + 0 * m + 8, vout89ABc0);
+      vst1q_f32(c + 0 * m + 12, voutCDEFc0);
+      vst1q_f32(c + 1 * m + 0, vout0123c1);
+      vst1q_f32(c + 1 * m + 4, vout4567c1);
+      vst1q_f32(c + 1 * m + 8, vout89ABc1);
+      vst1q_f32(c + 1 * m + 12, voutCDEFc1);
+      vst1q_f32(c + 2 * m + 0, vout0123c2);
+      vst1q_f32(c + 2 * m + 4, vout4567c2);
+      vst1q_f32(c + 2 * m + 8, vout89ABc2);
+      vst1q_f32(c + 2 * m + 12, voutCDEFc2);
+      vst1q_f32(c + 3 * m + 0, vout0123c3);
+      vst1q_f32(c + 3 * m + 4, vout4567c3);
+      vst1q_f32(c + 3 * m + 8, vout89ABc3);
+      vst1q_f32(c + 3 * m + 12, voutCDEFc3);
+      c += 4 * m;
+      j -= 4;
+    }
+
+    // clean up loop, fall back to nr=1
+    if XNN_UNLIKELY(j != 0) {
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567 = vacc0123;
+        float32x4_t vacc89AB = vacc0123;
+        float32x4_t vaccCDEF = vacc0123;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            const float32x4_t va4567 = vld1q_f32(a + 4);
+            const float32x4_t va89AB = vld1q_f32(a + 8);
+            const float32x4_t vaCDEF = vld1q_f32(a + 12);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+            vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+            vacc89AB = vfmaq_f32(vacc89AB, va89AB, vb);
+            vaccCDEF = vfmaq_f32(vaccCDEF, vaCDEF, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+        float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
+        float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);
+
+        vout0123 = vmaxq_f32(vout0123, vmin);
+        vout4567 = vmaxq_f32(vout4567, vmin);
+        vout89AB = vmaxq_f32(vout89AB, vmin);
+        voutCDEF = vmaxq_f32(voutCDEF, vmin);
+
+        vst1q_f32(c + 0, vout0123);
+        vst1q_f32(c + 4, vout4567);
+        vst1q_f32(c + 8, vout89AB);
+        vst1q_f32(c + 12, voutCDEF);
+        c += m;
+        j -= 1;
+      } while (j != 0);
+    }
+    // Rewind to the first output channel, then advance to the next row group.
+    c -= m * n;
+    c += 16;
+    a += 16;
+    i -= 16;
+  }
+  // Remainder rows: handled in power-of-two sub-tiles (8, 4, 2, 1).
+  if XNN_UNLIKELY(i != 0) {
+    if (i & 8) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 4) {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567c0 = vacc0123c0;
+        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567c1 = vacc0123c1;
+        float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567c2 = vacc0123c2;
+        float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc4567c3 = vacc0123c3;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            const float32x4_t va4567 = vld1q_f32(a + 4);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_f32(w); w += 4;
+
+            vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0);
+            vacc4567c0 = vfmaq_laneq_f32(vacc4567c0, va4567, vb, 0);
+            vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1);
+            vacc4567c1 = vfmaq_laneq_f32(vacc4567c1, va4567, vb, 1);
+            vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2);
+            vacc4567c2 = vfmaq_laneq_f32(vacc4567c2, va4567, vb, 2);
+            vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3);
+            vacc4567c3 = vfmaq_laneq_f32(vacc4567c3, va4567, vb, 3);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
+        float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
+        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
+        float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
+        float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax);
+        float32x4_t vout4567c2 = vminq_f32(vacc4567c2, vmax);
+        float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax);
+        float32x4_t vout4567c3 = vminq_f32(vacc4567c3, vmax);
+
+        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
+        vout4567c0 = vmaxq_f32(vout4567c0, vmin);
+        vout0123c1 = vmaxq_f32(vout0123c1, vmin);
+        vout4567c1 = vmaxq_f32(vout4567c1, vmin);
+        vout0123c2 = vmaxq_f32(vout0123c2, vmin);
+        vout4567c2 = vmaxq_f32(vout4567c2, vmin);
+        vout0123c3 = vmaxq_f32(vout0123c3, vmin);
+        vout4567c3 = vmaxq_f32(vout4567c3, vmin);
+
+        vst1q_f32(c + 0 * m + 0, vout0123c0);
+        vst1q_f32(c + 0 * m + 4, vout4567c0);
+        vst1q_f32(c + 1 * m + 0, vout0123c1);
+        vst1q_f32(c + 1 * m + 4, vout4567c1);
+        vst1q_f32(c + 2 * m + 0, vout0123c2);
+        vst1q_f32(c + 2 * m + 4, vout4567c2);
+        vst1q_f32(c + 3 * m + 0, vout0123c3);
+        vst1q_f32(c + 3 * m + 4, vout4567c3);
+        c += 4 * m;
+        j -= 4;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+          float32x4_t vacc4567 = vacc0123;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x4_t va0123 = vld1q_f32(a);
+              const float32x4_t va4567 = vld1q_f32(a + 4);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+              vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+            } while (--nnz != 0);
+          }
+          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+          float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+
+          vout0123 = vmaxq_f32(vout0123, vmin);
+          vout4567 = vmaxq_f32(vout4567, vmin);
+
+          vst1q_f32(c + 0, vout0123);
+          vst1q_f32(c + 4, vout4567);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 8;
+      a += 8;
+    }
+    if (i & 4) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 4) {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1;
+        float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_f32(w); w += 4;
+
+            vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0);
+            vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1);
+            vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2);
+            vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
+        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
+        float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax);
+        float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax);
+
+        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
+        vout0123c1 = vmaxq_f32(vout0123c1, vmin);
+        vout0123c2 = vmaxq_f32(vout0123c2, vmin);
+        vout0123c3 = vmaxq_f32(vout0123c3, vmin);
+
+        vst1q_f32(c + 0 * m + 0, vout0123c0);
+        vst1q_f32(c + 1 * m + 0, vout0123c1);
+        vst1q_f32(c + 2 * m + 0, vout0123c2);
+        vst1q_f32(c + 3 * m + 0, vout0123c3);
+        c += 4 * m;
+        j -= 4;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x4_t va0123 = vld1q_f32(a);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+            } while (--nnz != 0);
+          }
+          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+
+          vout0123 = vmaxq_f32(vout0123, vmin);
+
+          vst1q_f32(c + 0, vout0123);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 4;
+      a += 4;
+    }
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 4) {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc01c2 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc01c3 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_f32(w); w += 4;
+
+            vacc01c0 = vfma_laneq_f32(vacc01c0, va01, vb, 0);
+            vacc01c1 = vfma_laneq_f32(vacc01c1, va01, vb, 1);
+            vacc01c2 = vfma_laneq_f32(vacc01c2, va01, vb, 2);
+            vacc01c3 = vfma_laneq_f32(vacc01c3, va01, vb, 3);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax));
+        float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax));
+        float32x2_t vout01c2 = vmin_f32(vacc01c2, vget_low_f32(vmax));
+        float32x2_t vout01c3 = vmin_f32(vacc01c3, vget_low_f32(vmax));
+
+        vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin));
+        vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin));
+        vout01c2 = vmax_f32(vout01c2, vget_low_f32(vmin));
+        vout01c3 = vmax_f32(vout01c3, vget_low_f32(vmin));
+
+        vst1_f32(c + 0 * m + 0, vout01c0);
+        vst1_f32(c + 1 * m + 0, vout01c1);
+        vst1_f32(c + 2 * m + 0, vout01c2);
+        vst1_f32(c + 3 * m + 0, vout01c3);
+        c += 4 * m;
+        j -= 4;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x2_t va01 = vld1_f32(a);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x2_t vb = vld1_dup_f32(w); w += 1;
+              vacc01 = vfma_f32(vacc01, va01, vb);
+            } while (--nnz != 0);
+          }
+          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+          vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+
+          vst1_f32(c, vout01);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 4) {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc0c2 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc0c3 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_f32(w); w += 4;
+
+            vacc0c0 = vfma_laneq_f32(vacc0c0, va0, vb, 0);
+            vacc0c1 = vfma_laneq_f32(vacc0c1, va0, vb, 1);
+            vacc0c2 = vfma_laneq_f32(vacc0c2, va0, vb, 2);
+            vacc0c3 = vfma_laneq_f32(vacc0c3, va0, vb, 3);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax));
+        float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax));
+        float32x2_t vout0c2 = vmin_f32(vacc0c2, vget_low_f32(vmax));
+        float32x2_t vout0c3 = vmin_f32(vacc0c3, vget_low_f32(vmax));
+
+        vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin));
+        vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin));
+        vout0c2 = vmax_f32(vout0c2, vget_low_f32(vmin));
+        vout0c3 = vmax_f32(vout0c3, vget_low_f32(vmin));
+
+        vst1_lane_f32(c + 0 * m + 0, vout0c0, 0);
+        vst1_lane_f32(c + 1 * m + 0, vout0c1, 0);
+        vst1_lane_f32(c + 2 * m + 0, vout0c2, 0);
+        vst1_lane_f32(c + 3 * m + 0, vout0c3, 0);
+        c += 4 * m;
+        j -= 4;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x2_t va0 = vld1_dup_f32(a);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x2_t vb = vld1_dup_f32(w); w += 1;
+              vacc0 = vfma_f32(vacc0, va0, vb);
+            } while (--nnz != 0);
+          }
+          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+          vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+
+          // Store lane 0 for consistency with the nr=4 path above (both lanes
+          // hold the same value here since all inputs were lane-duplicated).
+          vst1_lane_f32(c, vout0, 0);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/1x1-scalar-pipelined.c b/src/f32-spmm/1x1-scalar-pipelined.c
new file mode 100644
index 0000000..a0e3843
--- /dev/null
+++ b/src/f32-spmm/1x1-scalar-pipelined.c
@@ -0,0 +1,65 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/scalar-pipelined.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
+// Scalar SpMM microkernel, mr=1, software-pipelined: the next weight, input
+// offset, and input value are loaded one iteration ahead of their use so the
+// loads overlap the multiply-accumulate.  Each output channel's accumulator
+// starts from the first packed weight value (presumably the bias) and the
+// result is clamped to [params->scalar.min, params->scalar.max].
+// Output is column-major: channel k of row r is written at c[k * m + r].
+void xnn_f32_spmm_ukernel_1x1__scalar_pipelined(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float vmin = params->scalar.min;
+  const float vmax = params->scalar.max;
+  size_t i = m;
+  while XNN_LIKELY(i >= 1) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    // Prime the pipeline: first weight, first input offset, first input value.
+    // NOTE(review): this reads w/dmap one element past the last nonzero of the
+    // final channel -- presumably the packed weight/index buffers are padded
+    // to make that safe; confirm against the packing code.
+    float vw = *w++;
+    intptr_t diff = *dmap++;
+    float va0 = a[0];
+    size_t j = n;
+    do {
+      uint32_t nnz = *nnzmap++;
+      // Previous prefetched weight is this channel's initial value; prefetch
+      // the first nonzero weight for the inner loop.
+      float vacc0 = vw;
+      vw = *w++;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          vacc0 += va0 * vw;
+          // Advance the input pointer by the previously-loaded byte offset,
+          // then prefetch next offset, weight, and input value.
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+
+          diff = *dmap++;
+          vw = *w++;
+          va0 = a[0];
+        } while (--nnz != 0);
+      }
+      // Clamp and store this channel's output, then step to the next channel.
+      float vout0 = math_min_f32(vacc0, vmax);
+      vout0 = math_max_f32(vout0, vmin);
+      c[0] = vout0;
+      c += m;
+    } while (--j != 0);
+    // Rewind to channel 0 and advance to the next row of the input/output.
+    c -= m * n;
+    c += 1;
+    a += 1;
+    i -= 1;
+  }
+  // mr == 1: the main loop consumes every row; no remainder to handle.
+  if XNN_UNLIKELY(i != 0) {
+  }
+}
diff --git a/src/f32-spmm/1x1-scalar-unroll2.c b/src/f32-spmm/1x1-scalar-unroll2.c
new file mode 100644
index 0000000..dc60437
--- /dev/null
+++ b/src/f32-spmm/1x1-scalar-unroll2.c
@@ -0,0 +1,76 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
+// Scalar SpMM microkernel, mr=1, inner loop unrolled by two.  The sparse dot
+// product for each output channel is accumulated in two independent partial
+// sums (better instruction-level parallelism), then combined and finished
+// with a scalar remainder pass.  Results are clamped to
+// [params->scalar.min, params->scalar.max]; output is column-major with
+// leading dimension m.
+void xnn_f32_spmm_ukernel_1x1__scalar_unroll2(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float voutput_min = params->scalar.min;
+  const float voutput_max = params->scalar.max;
+  for (size_t rows_left = m; rows_left >= 1; rows_left -= 1) {
+    const float*restrict wptr = weights;
+    const int32_t* idx = widx_dmap;
+    const uint32_t* counts = nidx_nnzmap;
+    size_t channels_left = n;
+    do {
+      uint32_t count = *counts++;
+      // Two independent accumulators: the even one starts from the packed
+      // per-channel initial value, the odd one from zero.
+      float sum_even = *wptr++;
+      float sum_odd = 0.0f;
+      while (count >= 2) {
+        const intptr_t offset0 = idx[0];
+        const intptr_t offset1 = idx[1];
+        idx += 2;
+        const float input_even = a[0];
+        a = (const float*restrict) ((uintptr_t) a + (uintptr_t) offset0);
+        const float weight_even = *wptr++;
+        sum_even += input_even * weight_even;
+        const float input_odd = a[0];
+        a = (const float*restrict) ((uintptr_t) a + (uintptr_t) offset1);
+        const float weight_odd = *wptr++;
+        sum_odd += input_odd * weight_odd;
+        count -= 2;
+      }
+      // Combine partial sums, then consume the odd leftover nonzero (if any).
+      float total = sum_even;
+      total += sum_odd;
+      while (count != 0) {
+        const intptr_t offset = *idx++;
+        const float input = a[0];
+        a = (const float*restrict) ((uintptr_t) a + (uintptr_t) offset);
+        const float weight = *wptr++;
+        total += input * weight;
+        count -= 1;
+      }
+      // Clamp and store, then step to the same row of the next channel.
+      float clamped = math_min_f32(total, voutput_max);
+      clamped = math_max_f32(clamped, voutput_min);
+      c[0] = clamped;
+      c += m;
+    } while (--channels_left != 0);
+    // Rewind to channel 0 and advance to the next input/output row.
+    c -= m * n;
+    c += 1;
+    a += 1;
+  }
+  // mr == 1: every row is handled above; no remainder tile exists.
+}
diff --git a/src/f32-spmm/1x1-scalar.c b/src/f32-spmm/1x1-scalar.c
new file mode 100644
index 0000000..5d5752f
--- /dev/null
+++ b/src/f32-spmm/1x1-scalar.c
@@ -0,0 +1,60 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
+// Sparse-matrix times dense-matrix micro-kernel, MR=1, plain scalar.
+// widx_dmap supplies byte offsets between consecutive non-zero input
+// elements; nidx_nnzmap gives the non-zero count per output channel;
+// weights stores, per channel, an initial value followed by the non-zero
+// weight values. Outputs are written with column stride m.
+void xnn_f32_spmm_ukernel_1x1__scalar(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float output_min = params->scalar.min;
+  const float output_max = params->scalar.max;
+  for (size_t rows = m; rows >= 1; rows -= 1) {
+    const float*restrict wptr = weights;
+    const int32_t* diffs = widx_dmap;
+    const uint32_t* counts = nidx_nnzmap;
+    for (size_t col = n; col != 0; col -= 1) {
+      uint32_t nnz = *counts++;
+      // Accumulator starts from the stored per-channel initial value.
+      float acc = *wptr++;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t step = *diffs++;
+          const float in = a[0];
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) step);
+          acc += in * (*wptr++);
+        } while (--nnz != 0);
+      }
+      // Clamp to [min, max] and store; columns are spaced m apart.
+      float out = math_min_f32(acc, output_max);
+      out = math_max_f32(out, output_min);
+      c[0] = out;
+      c += m;
+    }
+    // Rewind to column 0 and advance one output row.
+    c -= m * n;
+    c += 1;
+    a += 1;
+  }
+}
diff --git a/src/f32-spmm/2x1-scalar-pipelined.c b/src/f32-spmm/2x1-scalar-pipelined.c
new file mode 100644
index 0000000..fd185f3
--- /dev/null
+++ b/src/f32-spmm/2x1-scalar-pipelined.c
@@ -0,0 +1,103 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/scalar-pipelined.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
+// Software-pipelined sparse*dense micro-kernel, MR=2, scalar.
+// m: number of dense-input rows processed; n: number of output channels.
+// widx_dmap supplies byte offsets between consecutive non-zero input
+// elements; nidx_nnzmap gives the non-zero count per output channel;
+// weights stores, per channel, an initial value followed by the non-zero
+// weight values. Outputs are written with column stride m.
+// NOTE(review): the pipeline pre-loads the next weight/offset/input before
+// each multiply-add, so it reads one element past the data actually
+// consumed — presumably the packed buffers are padded to make this safe;
+// confirm the packing guarantees.
+void xnn_f32_spmm_ukernel_2x1__scalar_pipelined(
+ uint32_t m,
+ uint32_t n,
+ const float*restrict a,
+ const float*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ float*restrict c,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const float vmin = params->scalar.min;
+ const float vmax = params->scalar.max;
+ size_t i = m;
+ while XNN_LIKELY(i >= 2) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ // Prime the pipeline: first weight, first offset, first two inputs.
+ float vw = *w++;
+ intptr_t diff = *dmap++;
+ float va0 = a[0];
+ float va1 = a[1];
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ // Both lanes start from the pre-loaded per-channel initial value.
+ float vacc0 = vw;
+ float vacc1 = vw;
+ vw = *w++;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ vacc0 += va0 * vw;
+ vacc1 += va1 * vw;
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+
+ // Load the operands for the NEXT iteration (software pipelining).
+ diff = *dmap++;
+ vw = *w++;
+ va0 = a[0];
+ va1 = a[1];
+ } while (--nnz != 0);
+ }
+ // Clamp to [vmin, vmax] and store two output rows of this column.
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c += m;
+ } while (--j != 0);
+ // Rewind to column 0 and advance two output rows.
+ c -= m * n;
+ c += 2;
+ a += 2;
+ i -= 2;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ // Remainder: one final row, same pipelined pattern with a single lane.
+ if (i & 1) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ float vw = *w++;
+ intptr_t diff = *dmap++;
+ float va0 = a[0];
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = vw;
+ vw = *w++;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ vacc0 += va0 * vw;
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+
+ diff = *dmap++;
+ vw = *w++;
+ va0 = a[0];
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ c[0] = vout0;
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
diff --git a/src/f32-spmm/2x1-scalar-unroll2.c b/src/f32-spmm/2x1-scalar-unroll2.c
new file mode 100644
index 0000000..e9e1abb
--- /dev/null
+++ b/src/f32-spmm/2x1-scalar-unroll2.c
@@ -0,0 +1,131 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
+// Sparse*dense micro-kernel, MR=2, scalar, 2x unrolled over non-zeros.
+// Each lane keeps two partial accumulators (x0/x1) that are summed before
+// the scalar remainder loop; the unroll changes latency, not results, as
+// long as the accumulation order is preserved exactly.
+// widx_dmap supplies byte offsets between consecutive non-zero inputs;
+// nidx_nnzmap gives the non-zero count per output channel; weights stores
+// a per-channel initial value followed by the non-zero weight values.
+void xnn_f32_spmm_ukernel_2x1__scalar_unroll2(
+ uint32_t m,
+ uint32_t n,
+ const float*restrict a,
+ const float*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ float*restrict c,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const float vmin = params->scalar.min;
+ const float vmax = params->scalar.max;
+ size_t i = m;
+ while (i >= 2) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ // x0 accumulators start from the per-channel initial value; the
+ // x1 accumulators start from zero and are folded in afterwards.
+ float vacc0x0 = *w++;
+ float vacc0x1 = 0.0f;
+ float vacc1x0 = vacc0x0;
+ float vacc1x1 = 0.0f;
+ for (; nnz >= 2; nnz -= 2) {
+ const intptr_t diff0 = dmap[0];
+ const intptr_t diff1 = dmap[1];
+ dmap += 2;
+ const float va0x0 = a[0];
+ const float va1x0 = a[1];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff0);
+ const float vb0 = *w++;
+ vacc0x0 += va0x0 * vb0;
+ vacc1x0 += va1x0 * vb0;
+ const float va0x1 = a[0];
+ const float va1x1 = a[1];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff1);
+ const float vb1 = *w++;
+ vacc0x1 += va0x1 * vb1;
+ vacc1x1 += va1x1 * vb1;
+ }
+ // Merge the unrolled partial sums.
+ float vacc0 = vacc0x0;
+ float vacc1 = vacc1x0;
+ vacc0 += vacc0x1;
+ vacc1 += vacc1x1;
+ // Drain the odd remaining non-zero, if any.
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ const float va1 = a[1];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ vacc1 += va1 * vb;
+ } while (--nnz != 0);
+ }
+ // Clamp to [vmin, vmax] and store; columns are spaced m apart.
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 2;
+ a += 2;
+ i -= 2;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ // Remainder: single row, same unrolled pattern with one lane.
+ if (i & 1) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0x0 = *w++;
+ float vacc0x1 = 0.0f;
+ for (; nnz >= 2; nnz -= 2) {
+ const intptr_t diff0 = dmap[0];
+ const intptr_t diff1 = dmap[1];
+ dmap += 2;
+ const float va0x0 = a[0];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff0);
+ const float vb0 = *w++;
+ vacc0x0 += va0x0 * vb0;
+ const float va0x1 = a[0];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff1);
+ const float vb1 = *w++;
+ vacc0x1 += va0x1 * vb1;
+ }
+ float vacc0 = vacc0x0;
+ vacc0 += vacc0x1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ c[0] = vout0;
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
diff --git a/src/f32-spmm/2x1-scalar.c b/src/f32-spmm/2x1-scalar.c
new file mode 100644
index 0000000..474c2ea
--- /dev/null
+++ b/src/f32-spmm/2x1-scalar.c
@@ -0,0 +1,92 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
+// Sparse-matrix times dense-matrix micro-kernel, MR=2, plain scalar:
+// two output rows are produced per pass over the non-zero weights.
+// widx_dmap supplies byte offsets between consecutive non-zero input
+// elements; nidx_nnzmap gives the non-zero count per output channel;
+// weights stores, per channel, an initial value followed by the non-zero
+// weight values. Outputs are written with column stride m.
+void xnn_f32_spmm_ukernel_2x1__scalar(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float output_min = params->scalar.min;
+  const float output_max = params->scalar.max;
+  size_t rows = m;
+  for (; rows >= 2; rows -= 2) {
+    const float*restrict wptr = weights;
+    const int32_t* diffs = widx_dmap;
+    const uint32_t* counts = nidx_nnzmap;
+    for (size_t col = n; col != 0; col -= 1) {
+      uint32_t nnz = *counts++;
+      // Both lanes start from the stored per-channel initial value.
+      float acc0 = *wptr++;
+      float acc1 = acc0;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t step = *diffs++;
+          const float in0 = a[0];
+          const float in1 = a[1];
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) step);
+          const float wval = *wptr++;
+          acc0 += in0 * wval;
+          acc1 += in1 * wval;
+        } while (--nnz != 0);
+      }
+      // Clamp to [min, max] and store; columns are spaced m apart.
+      float out0 = math_min_f32(acc0, output_max);
+      float out1 = math_min_f32(acc1, output_max);
+      out0 = math_max_f32(out0, output_min);
+      out1 = math_max_f32(out1, output_min);
+      c[0] = out0;
+      c[1] = out1;
+      c += m;
+    }
+    // Rewind to column 0 and advance two output rows.
+    c -= m * n;
+    c += 2;
+    a += 2;
+  }
+  if XNN_UNLIKELY(rows != 0) {
+    // Remainder: exactly one row is left (rows < 2 here).
+    const float*restrict wptr = weights;
+    const int32_t* diffs = widx_dmap;
+    const uint32_t* counts = nidx_nnzmap;
+    for (size_t col = n; col != 0; col -= 1) {
+      uint32_t nnz = *counts++;
+      float acc = *wptr++;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t step = *diffs++;
+          const float in = a[0];
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) step);
+          acc += in * (*wptr++);
+        } while (--nnz != 0);
+      }
+      float out = math_min_f32(acc, output_max);
+      out = math_max_f32(out, output_min);
+      c[0] = out;
+      c += m;
+    }
+    c -= m * n;
+    c += 1;
+    a += 1;
+  }
+}
diff --git a/src/f32-spmm/4x1-neonfma-pipelined.c b/src/f32-spmm/4x1-neonfma-pipelined.c
new file mode 100644
index 0000000..c24b261
--- /dev/null
+++ b/src/f32-spmm/4x1-neonfma-pipelined.c
@@ -0,0 +1,118 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon-pipelined.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// Software-pipelined sparse*dense micro-kernel, MR=4, NEON FMA.
+// widx_dmap supplies byte offsets between consecutive non-zero input
+// elements; nidx_nnzmap gives the non-zero count per output channel;
+// weights stores, per channel, an initial value followed by the non-zero
+// weight values. Outputs are written with column stride m.
+// Fix: "&params" had been mangled into "¶ms" (HTML-entity corruption of
+// "&para"), which does not compile; restored the address-of operator.
+// NOTE(review): the pipeline pre-loads the next weight/offset/input before
+// each FMA and therefore reads one element past the data it consumes —
+// presumably the packed buffers are padded; confirm packing guarantees.
+void xnn_f32_spmm_ukernel_4x1__neonfma_pipelined(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  while XNN_LIKELY(i >= 4) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    // Prime the pipeline: first weight, first offset, first input vector.
+    float32x4_t vw = vld1q_dup_f32(w); w += 1;
+    intptr_t diff = *dmap++;
+    float32x4_t va0123 = vld1q_f32(a);
+    size_t j = n;
+    do {
+      uint32_t nnz = *nnzmap++;
+      // Accumulator starts from the pre-loaded per-channel initial value.
+      float32x4_t vacc0123 = vw;
+      vw = vld1q_dup_f32(w); w += 1;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          vacc0123 = vfmaq_f32(vacc0123, va0123, vw);
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+
+          // Load the operands for the NEXT iteration (software pipelining).
+          diff = *dmap++;
+          vw = vld1q_dup_f32(w); w += 1;
+          va0123 = vld1q_f32(a);
+        } while (--nnz != 0);
+      }
+      // Clamp to [vmin, vmax] and store four output rows of this column.
+      float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+      vout0123 = vmaxq_f32(vout0123, vmin);
+      vst1q_f32(c, vout0123);
+      c += m;
+    } while (--j != 0);
+    // Rewind to column 0 and advance four output rows.
+    c -= m * n;
+    c += 4;
+    a += 4;
+    i -= 4;
+  }
+  if XNN_UNLIKELY(i != 0) {
+    // Remainder rows use simpler non-pipelined 2- and 1-lane loops.
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc01 = vfma_f32(vacc01, va01, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+        vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+        vst1_f32(c, vout01);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc0 = vfma_f32(vacc0, va0, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+        vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+        vst1_lane_f32(c, vout0, 0);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/4x1-neonfma-unroll2.c b/src/f32-spmm/4x1-neonfma-unroll2.c
new file mode 100644
index 0000000..81bcd85
--- /dev/null
+++ b/src/f32-spmm/4x1-neonfma-unroll2.c
@@ -0,0 +1,129 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// Sparse*dense micro-kernel, MR=4, NEON FMA, 2x unrolled over non-zeros.
+// Two vector accumulators (x0/x1) are used in the unrolled loop and summed
+// before the scalar-remainder loop.
+// Fix: "&params" had been mangled into "¶ms" (HTML-entity corruption of
+// "&para"), which does not compile; restored the address-of operator.
+void xnn_f32_spmm_ukernel_4x1__neonfma_unroll2(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  while XNN_LIKELY(i >= 4) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    size_t j = n;
+    do {
+      uint32_t nnz = *nnzmap++;
+      // x0 starts from the per-channel initial value; x1 starts at zero
+      // and is folded in after the unrolled loop.
+      float32x4_t vacc0123x0 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc0123x1 = vmovq_n_f32(0.0f);
+      for (; nnz >= 2; nnz -= 2) {
+        const intptr_t diff0 = dmap[0];
+        const intptr_t diff1 = dmap[1];
+        dmap += 2;
+        const float32x4_t va0123x0 = vld1q_f32(a);
+        a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff0);
+        const float32x4_t vb0 = vld1q_dup_f32(w); w += 1;
+        vacc0123x0 = vfmaq_f32(vacc0123x0, va0123x0, vb0);
+        const float32x4_t va0123x1 = vld1q_f32(a);
+        a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff1);
+        const float32x4_t vb1 = vld1q_dup_f32(w); w += 1;
+        vacc0123x1 = vfmaq_f32(vacc0123x1, va0123x1, vb1);
+      }
+      // Merge partial sums, then drain the odd remaining non-zero.
+      float32x4_t vacc0123 = vacc0123x0;
+      vacc0123 = vaddq_f32(vacc0123, vacc0123x1);
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t diff = *dmap++;
+          const float32x4_t va0123 = vld1q_f32(a);
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+          const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+          vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+        } while (--nnz != 0);
+      }
+      // Clamp to [vmin, vmax] and store; columns are spaced m apart.
+      float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+      vout0123 = vmaxq_f32(vout0123, vmin);
+      vst1q_f32(c, vout0123);
+      c += m;
+    } while (--j != 0);
+    c -= m * n;
+    c += 4;
+    a += 4;
+    i -= 4;
+  }
+  if XNN_UNLIKELY(i != 0) {
+    // Remainder rows: non-unrolled 2- and 1-lane loops.
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc01 = vfma_f32(vacc01, va01, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+        vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+        vst1_f32(c, vout01);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc0 = vfma_f32(vacc0, va0, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+        vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+        vst1_lane_f32(c, vout0, 0);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/4x1-neonfma.c b/src/f32-spmm/4x1-neonfma.c
new file mode 100644
index 0000000..7e40a60
--- /dev/null
+++ b/src/f32-spmm/4x1-neonfma.c
@@ -0,0 +1,113 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// Sparse-matrix times dense-matrix micro-kernel, MR=4, NEON FMA.
+// widx_dmap supplies byte offsets between consecutive non-zero input
+// elements; nidx_nnzmap gives the non-zero count per output channel;
+// weights stores, per channel, an initial value followed by the non-zero
+// weight values. Outputs are written with column stride m.
+// Fix: "&params" had been mangled into "¶ms" (HTML-entity corruption of
+// "&para"), which does not compile; restored the address-of operator.
+void xnn_f32_spmm_ukernel_4x1__neonfma(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  while XNN_LIKELY(i >= 4) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    size_t j = n;
+    do {
+      uint32_t nnz = *nnzmap++;
+      // Accumulator starts from the stored per-channel initial value.
+      float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t diff = *dmap++;
+          const float32x4_t va0123 = vld1q_f32(a);
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+          const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+          vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+        } while (--nnz != 0);
+      }
+      // Clamp to [vmin, vmax] and store; columns are spaced m apart.
+      float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+      vout0123 = vmaxq_f32(vout0123, vmin);
+      vst1q_f32(c, vout0123);
+      c += m;
+    } while (--j != 0);
+    c -= m * n;
+    c += 4;
+    a += 4;
+    i -= 4;
+  }
+  if XNN_UNLIKELY(i != 0) {
+    // Remainder rows: 2- and 1-lane loops.
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc01 = vfma_f32(vacc01, va01, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+        vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+        vst1_f32(c, vout01);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc0 = vfma_f32(vacc0, va0, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+        vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+        vst1_lane_f32(c, vout0, 0);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/4x1-scalar-pipelined.c b/src/f32-spmm/4x1-scalar-pipelined.c
new file mode 100644
index 0000000..4fff03b
--- /dev/null
+++ b/src/f32-spmm/4x1-scalar-pipelined.c
@@ -0,0 +1,155 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/scalar-pipelined.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
+// Software-pipelined sparse*dense micro-kernel, MR=4, scalar.
+// widx_dmap supplies byte offsets between consecutive non-zero input
+// elements; nidx_nnzmap gives the non-zero count per output channel;
+// weights stores, per channel, an initial value followed by the non-zero
+// weight values. Outputs are written with column stride m.
+// NOTE(review): the pipeline pre-loads the next weight/offset/inputs before
+// each multiply-add, so it reads one element past the data actually
+// consumed — presumably the packed buffers are padded; confirm packing.
+void xnn_f32_spmm_ukernel_4x1__scalar_pipelined(
+ uint32_t m,
+ uint32_t n,
+ const float*restrict a,
+ const float*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ float*restrict c,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const float vmin = params->scalar.min;
+ const float vmax = params->scalar.max;
+ size_t i = m;
+ while XNN_LIKELY(i >= 4) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ // Prime the pipeline: first weight, first offset, first four inputs.
+ float vw = *w++;
+ intptr_t diff = *dmap++;
+ float va0 = a[0];
+ float va1 = a[1];
+ float va2 = a[2];
+ float va3 = a[3];
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ // All four lanes start from the pre-loaded per-channel initial value.
+ float vacc0 = vw;
+ float vacc1 = vw;
+ float vacc2 = vw;
+ float vacc3 = vw;
+ vw = *w++;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ vacc0 += va0 * vw;
+ vacc1 += va1 * vw;
+ vacc2 += va2 * vw;
+ vacc3 += va3 * vw;
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+
+ // Load the operands for the NEXT iteration (software pipelining).
+ diff = *dmap++;
+ vw = *w++;
+ va0 = a[0];
+ va1 = a[1];
+ va2 = a[2];
+ va3 = a[3];
+ } while (--nnz != 0);
+ }
+ // Clamp to [vmin, vmax] and store four output rows of this column.
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ float vout2 = math_min_f32(vacc2, vmax);
+ float vout3 = math_min_f32(vacc3, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ vout2 = math_max_f32(vout2, vmin);
+ vout3 = math_max_f32(vout3, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c[2] = vout2;
+ c[3] = vout3;
+ c += m;
+ } while (--j != 0);
+ // Rewind to column 0 and advance four output rows.
+ c -= m * n;
+ c += 4;
+ a += 4;
+ i -= 4;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ // Remainder rows: same pipelined pattern with 2 lanes, then 1 lane.
+ if (i & 2) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ float vw = *w++;
+ intptr_t diff = *dmap++;
+ float va0 = a[0];
+ float va1 = a[1];
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = vw;
+ float vacc1 = vw;
+ vw = *w++;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ vacc0 += va0 * vw;
+ vacc1 += va1 * vw;
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+
+ diff = *dmap++;
+ vw = *w++;
+ va0 = a[0];
+ va1 = a[1];
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 2;
+ a += 2;
+ }
+ if (i & 1) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ float vw = *w++;
+ intptr_t diff = *dmap++;
+ float va0 = a[0];
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = vw;
+ vw = *w++;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ vacc0 += va0 * vw;
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+
+ diff = *dmap++;
+ vw = *w++;
+ va0 = a[0];
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ c[0] = vout0;
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
diff --git a/src/f32-spmm/4x1-scalar-unroll2.c b/src/f32-spmm/4x1-scalar-unroll2.c
new file mode 100644
index 0000000..6685b17
--- /dev/null
+++ b/src/f32-spmm/4x1-scalar-unroll2.c
@@ -0,0 +1,212 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
+// Sparse*dense micro-kernel, MR=4, scalar, 2x unrolled over non-zeros.
+// Each of the four lanes keeps two partial accumulators (x0/x1) that are
+// summed before the scalar remainder loop; preserving this exact
+// accumulation order keeps results bit-identical to the generator's spec.
+// widx_dmap supplies byte offsets between consecutive non-zero inputs;
+// nidx_nnzmap gives the non-zero count per output channel; weights stores
+// a per-channel initial value followed by the non-zero weight values.
+void xnn_f32_spmm_ukernel_4x1__scalar_unroll2(
+ uint32_t m,
+ uint32_t n,
+ const float*restrict a,
+ const float*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ float*restrict c,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const float vmin = params->scalar.min;
+ const float vmax = params->scalar.max;
+ size_t i = m;
+ while (i >= 4) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ // x0 accumulators start from the per-channel initial value; the
+ // x1 accumulators start from zero and are folded in afterwards.
+ float vacc0x0 = *w++;
+ float vacc0x1 = 0.0f;
+ float vacc1x0 = vacc0x0;
+ float vacc1x1 = 0.0f;
+ float vacc2x0 = vacc0x0;
+ float vacc2x1 = 0.0f;
+ float vacc3x0 = vacc0x0;
+ float vacc3x1 = 0.0f;
+ for (; nnz >= 2; nnz -= 2) {
+ const intptr_t diff0 = dmap[0];
+ const intptr_t diff1 = dmap[1];
+ dmap += 2;
+ const float va0x0 = a[0];
+ const float va1x0 = a[1];
+ const float va2x0 = a[2];
+ const float va3x0 = a[3];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff0);
+ const float vb0 = *w++;
+ vacc0x0 += va0x0 * vb0;
+ vacc1x0 += va1x0 * vb0;
+ vacc2x0 += va2x0 * vb0;
+ vacc3x0 += va3x0 * vb0;
+ const float va0x1 = a[0];
+ const float va1x1 = a[1];
+ const float va2x1 = a[2];
+ const float va3x1 = a[3];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff1);
+ const float vb1 = *w++;
+ vacc0x1 += va0x1 * vb1;
+ vacc1x1 += va1x1 * vb1;
+ vacc2x1 += va2x1 * vb1;
+ vacc3x1 += va3x1 * vb1;
+ }
+ // Merge the unrolled partial sums.
+ float vacc0 = vacc0x0;
+ float vacc1 = vacc1x0;
+ float vacc2 = vacc2x0;
+ float vacc3 = vacc3x0;
+ vacc0 += vacc0x1;
+ vacc1 += vacc1x1;
+ vacc2 += vacc2x1;
+ vacc3 += vacc3x1;
+ // Drain the odd remaining non-zero, if any.
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ const float va1 = a[1];
+ const float va2 = a[2];
+ const float va3 = a[3];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ vacc1 += va1 * vb;
+ vacc2 += va2 * vb;
+ vacc3 += va3 * vb;
+ } while (--nnz != 0);
+ }
+ // Clamp to [vmin, vmax] and store; columns are spaced m apart.
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ float vout2 = math_min_f32(vacc2, vmax);
+ float vout3 = math_min_f32(vacc3, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ vout2 = math_max_f32(vout2, vmin);
+ vout3 = math_max_f32(vout3, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c[2] = vout2;
+ c[3] = vout3;
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 4;
+ a += 4;
+ i -= 4;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ // Remainder rows: same unrolled pattern with 2 lanes, then 1 lane.
+ if (i & 2) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0x0 = *w++;
+ float vacc0x1 = 0.0f;
+ float vacc1x0 = vacc0x0;
+ float vacc1x1 = 0.0f;
+ for (; nnz >= 2; nnz -= 2) {
+ const intptr_t diff0 = dmap[0];
+ const intptr_t diff1 = dmap[1];
+ dmap += 2;
+ const float va0x0 = a[0];
+ const float va1x0 = a[1];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff0);
+ const float vb0 = *w++;
+ vacc0x0 += va0x0 * vb0;
+ vacc1x0 += va1x0 * vb0;
+ const float va0x1 = a[0];
+ const float va1x1 = a[1];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff1);
+ const float vb1 = *w++;
+ vacc0x1 += va0x1 * vb1;
+ vacc1x1 += va1x1 * vb1;
+ }
+ float vacc0 = vacc0x0;
+ float vacc1 = vacc1x0;
+ vacc0 += vacc0x1;
+ vacc1 += vacc1x1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ const float va1 = a[1];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ vacc1 += va1 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 2;
+ a += 2;
+ }
+ if (i & 1) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0x0 = *w++;
+ float vacc0x1 = 0.0f;
+ for (; nnz >= 2; nnz -= 2) {
+ const intptr_t diff0 = dmap[0];
+ const intptr_t diff1 = dmap[1];
+ dmap += 2;
+ const float va0x0 = a[0];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff0);
+ const float vb0 = *w++;
+ vacc0x0 += va0x0 * vb0;
+ const float va0x1 = a[0];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff1);
+ const float vb1 = *w++;
+ vacc0x1 += va0x1 * vb1;
+ }
+ float vacc0 = vacc0x0;
+ vacc0 += vacc0x1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ c[0] = vout0;
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
diff --git a/src/f32-spmm/4x1-scalar.c b/src/f32-spmm/4x1-scalar.c
new file mode 100644
index 0000000..353f633
--- /dev/null
+++ b/src/f32-spmm/4x1-scalar.c
@@ -0,0 +1,136 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
+// Sparse-matrix times dense-matrix micro-kernel, MR=4, plain scalar:
+// four output rows are produced per pass over the non-zero weights.
+// widx_dmap supplies byte offsets between consecutive non-zero input
+// elements; nidx_nnzmap gives the non-zero count per output channel;
+// weights stores, per channel, an initial value followed by the non-zero
+// weight values. Outputs are written with column stride m.
+void xnn_f32_spmm_ukernel_4x1__scalar(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float output_min = params->scalar.min;
+  const float output_max = params->scalar.max;
+  size_t rows = m;
+  for (; rows >= 4; rows -= 4) {
+    const float*restrict wptr = weights;
+    const int32_t* diffs = widx_dmap;
+    const uint32_t* counts = nidx_nnzmap;
+    for (size_t col = n; col != 0; col -= 1) {
+      uint32_t nnz = *counts++;
+      // All four lanes start from the stored per-channel initial value.
+      float acc0 = *wptr++;
+      float acc1 = acc0;
+      float acc2 = acc0;
+      float acc3 = acc0;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t step = *diffs++;
+          const float in0 = a[0];
+          const float in1 = a[1];
+          const float in2 = a[2];
+          const float in3 = a[3];
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) step);
+          const float wval = *wptr++;
+          acc0 += in0 * wval;
+          acc1 += in1 * wval;
+          acc2 += in2 * wval;
+          acc3 += in3 * wval;
+        } while (--nnz != 0);
+      }
+      // Clamp to [min, max] and store; columns are spaced m apart.
+      float out0 = math_min_f32(acc0, output_max);
+      float out1 = math_min_f32(acc1, output_max);
+      float out2 = math_min_f32(acc2, output_max);
+      float out3 = math_min_f32(acc3, output_max);
+      out0 = math_max_f32(out0, output_min);
+      out1 = math_max_f32(out1, output_min);
+      out2 = math_max_f32(out2, output_min);
+      out3 = math_max_f32(out3, output_min);
+      c[0] = out0;
+      c[1] = out1;
+      c[2] = out2;
+      c[3] = out3;
+      c += m;
+    }
+    // Rewind to column 0 and advance four output rows.
+    c -= m * n;
+    c += 4;
+    a += 4;
+  }
+  if XNN_UNLIKELY(rows != 0) {
+    // Remainder rows: process 2 lanes, then 1 lane.
+    if (rows & 2) {
+      const float*restrict wptr = weights;
+      const int32_t* diffs = widx_dmap;
+      const uint32_t* counts = nidx_nnzmap;
+      for (size_t col = n; col != 0; col -= 1) {
+        uint32_t nnz = *counts++;
+        float acc0 = *wptr++;
+        float acc1 = acc0;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t step = *diffs++;
+            const float in0 = a[0];
+            const float in1 = a[1];
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) step);
+            const float wval = *wptr++;
+            acc0 += in0 * wval;
+            acc1 += in1 * wval;
+          } while (--nnz != 0);
+        }
+        float out0 = math_min_f32(acc0, output_max);
+        float out1 = math_min_f32(acc1, output_max);
+        out0 = math_max_f32(out0, output_min);
+        out1 = math_max_f32(out1, output_min);
+        c[0] = out0;
+        c[1] = out1;
+        c += m;
+      }
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (rows & 1) {
+      const float*restrict wptr = weights;
+      const int32_t* diffs = widx_dmap;
+      const uint32_t* counts = nidx_nnzmap;
+      for (size_t col = n; col != 0; col -= 1) {
+        uint32_t nnz = *counts++;
+        float acc = *wptr++;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t step = *diffs++;
+            const float in = a[0];
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) step);
+            acc += in * (*wptr++);
+          } while (--nnz != 0);
+        }
+        float out = math_min_f32(acc, output_max);
+        out = math_max_f32(out, output_min);
+        c[0] = out;
+        c += m;
+      }
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/4x1-sse.c b/src/f32-spmm/4x1-sse.c
new file mode 100644
index 0000000..85eaa26
--- /dev/null
+++ b/src/f32-spmm/4x1-sse.c
@@ -0,0 +1,115 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/sse.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/spmm.h>
+
+
+// SpMM (sparse-matrix times dense-matrix) micro-kernel, SSE variant, MR=4 NR=1.
+//
+// Packed sparse-weight layout (as consumed here): for each of the n output
+// channels, one leading value (presumably the bias -- confirm against the
+// weight-packing code) followed by that channel's nonzero weights.
+// nidx_nnzmap holds the nonzero count per channel; widx_dmap holds, per
+// nonzero, the BYTE offset added to the input pointer `a` to reach the input
+// element used by the next nonzero.
+//
+// Output addressing: `c` advances by m per channel and is rewound by m*n
+// after the channel loop, then stepped by the pixel count handled.
+// NOTE(review): the m=pixels / n=channels interpretation is inferred from
+// this pointer arithmetic -- confirm against the SpMM operator setup code.
+void xnn_f32_spmm_ukernel_4x1__sse(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const __m128 vmin = _mm_load_ps(params->sse.min);
+  const __m128 vmax = _mm_load_ps(params->sse.max);
+  size_t i = m;
+  // Main loop: 4 output pixels per iteration.
+  while XNN_LIKELY(i >= 4) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    size_t j = n;
+    do {
+      uint32_t nnz = *nnzmap++;
+      // All 4 accumulators start from the channel's leading weight value.
+      __m128 vacc0123 = _mm_load1_ps(w); w += 1;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t diff = *dmap++;
+          const __m128 va0123 = _mm_loadu_ps(a);
+          // diff is a byte offset chaining `a` to the next nonzero's input row.
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+          const __m128 vb = _mm_load1_ps(w); w += 1;
+          vacc0123 = _mm_add_ps(vacc0123, _mm_mul_ps(va0123, vb));
+        } while (--nnz != 0);
+      }
+      // Clamp to [min, max] and store this channel's 4 outputs.
+      __m128 vout0123 = _mm_min_ps(vacc0123, vmax);
+      vout0123 = _mm_max_ps(vout0123, vmin);
+      _mm_storeu_ps(c, vout0123);
+      c += 1 * m;
+    } while (--j != 0);
+    c -= m * n;
+    c += 4;
+    a += 4;
+    i -= 4;
+  }
+  // Remainder pixels: 2-wide then 1-wide tails with narrower loads/stores.
+  if XNN_UNLIKELY(i != 0) {
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        __m128 vacc01 = _mm_load_ss(w); w += 1;
+        // Broadcast the scalar into both low lanes.
+        vacc01 = _mm_unpacklo_ps(vacc01, vacc01);
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const __m128 va01 = _mm_loadl_pi(_mm_undefined_ps(), (const __m64*) a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            __m128 vb = _mm_load_ss(w); w += 1;
+            vb = _mm_unpacklo_ps(vb, vb);
+            vacc01 = _mm_add_ps(vacc01, _mm_mul_ps(va01, vb));
+          } while (--nnz != 0);
+        }
+        __m128 vout01 = _mm_min_ps(vacc01, vmax);
+        vout01 = _mm_max_ps(vout01, vmin);
+        _mm_storel_pi((__m64*) c, vout01);
+        c += 1 * m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        __m128 vacc0 = _mm_load_ss(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const __m128 va0 = _mm_load_ss(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const __m128 vb = _mm_load_ss(w); w += 1;
+            vacc0 = _mm_add_ss(vacc0, _mm_mul_ss(va0, vb));
+          } while (--nnz != 0);
+        }
+        __m128 vout0 = _mm_min_ss(vacc0, vmax);
+        vout0 = _mm_max_ss(vout0, vmin);
+        _mm_store_ss(c, vout0);
+        c += 1 * m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/4x2-neonfma.c b/src/f32-spmm/4x2-neonfma.c
new file mode 100644
index 0000000..56809df
--- /dev/null
+++ b/src/f32-spmm/4x2-neonfma.c
@@ -0,0 +1,210 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon-blocked.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// SpMM micro-kernel, NEON-FMA variant, MR=4 with output channels blocked 2 at
+// a time (NR=2) plus an NR=1 cleanup loop.
+//
+// Weight stream per 2-channel block: two leading values (presumably the two
+// biases -- confirm with the packing code), then pairs of nonzero weights
+// consumed via vld1_f32 and applied by lane. widx_dmap entries are BYTE
+// offsets chaining the input pointer `a` between nonzeroes; nidx_nnzmap gives
+// nonzero counts per (blocked) channel group.
+void xnn_f32_spmm_ukernel_4x2__neonfma(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  // Main loop: 4 output pixels per iteration.
+  while XNN_LIKELY(i >= 4) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    size_t j = n;
+    // Channel-blocked loop: two output channels share each input load.
+    while (j >= 2) {
+      uint32_t nnz = *nnzmap++;
+      float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t diff = *dmap++;
+          const float32x4_t va0123 = vld1q_f32(a);
+          // Byte-offset chase to the next nonzero's input row.
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+          const float32x2_t vb = vld1_f32(w); w += 2;
+
+          vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
+          vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
+        } while (--nnz != 0);
+      }
+      float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
+      float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
+
+      vout0123c0 = vmaxq_f32(vout0123c0, vmin);
+      vout0123c1 = vmaxq_f32(vout0123c1, vmin);
+
+      vst1q_f32(c + 0 * m + 0, vout0123c0);
+      vst1q_f32(c + 1 * m + 0, vout0123c1);
+      c += 2 * m;
+      j -= 2;
+    }
+
+    // clean up loop, fall back to nr=1
+    if XNN_UNLIKELY(j != 0) {
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+
+        vout0123 = vmaxq_f32(vout0123, vmin);
+
+        vst1q_f32(c + 0, vout0123);
+        c += m;
+        j -= 1;
+      } while (j != 0);
+    }
+    c -= m * n;
+    c += 4;
+    a += 4;
+    i -= 4;
+  }
+  // Remainder pixels: 2-wide then 1-wide tails.
+  if XNN_UNLIKELY(i != 0) {
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 2) {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_f32(w); w += 2;
+
+            vacc01c0 = vfma_lane_f32(vacc01c0, va01, vb, 0);
+            vacc01c1 = vfma_lane_f32(vacc01c1, va01, vb, 1);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax));
+        float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax));
+
+        vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin));
+        vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin));
+
+        vst1_f32(c + 0 * m + 0, vout01c0);
+        vst1_f32(c + 1 * m + 0, vout01c1);
+        c += 2 * m;
+        j -= 2;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x2_t va01 = vld1_f32(a);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x2_t vb = vld1_dup_f32(w); w += 1;
+              vacc01 = vfma_f32(vacc01, va01, vb);
+            } while (--nnz != 0);
+          }
+          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+          vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+
+          vst1_f32(c, vout01);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 2) {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_f32(w); w += 2;
+
+            vacc0c0 = vfma_lane_f32(vacc0c0, va0, vb, 0);
+            vacc0c1 = vfma_lane_f32(vacc0c1, va0, vb, 1);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax));
+        float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax));
+
+        vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin));
+        vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin));
+
+        vst1_lane_f32(c + 0 * m + 0, vout0c0, 0);
+        vst1_lane_f32(c + 1 * m + 0, vout0c1, 0);
+        c += 2 * m;
+        j -= 2;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x2_t va0 = vld1_dup_f32(a);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x2_t vb = vld1_dup_f32(w); w += 1;
+              vacc0 = vfma_f32(vacc0, va0, vb);
+            } while (--nnz != 0);
+          }
+          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+          vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+
+          // NOTE(review): stores lane 1, while the blocked loop above stores
+          // lane 0. Both lanes are identical here (vacc0/va0/vb all come from
+          // vld1_dup_f32), so behavior is the same, but lane 0 would be
+          // clearer -- confirm this is intentional in the generator template.
+          vst1_lane_f32(c, vout0, 1);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/4x4-neonfma.c b/src/f32-spmm/4x4-neonfma.c
new file mode 100644
index 0000000..f506166
--- /dev/null
+++ b/src/f32-spmm/4x4-neonfma.c
@@ -0,0 +1,240 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon-blocked.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// SpMM micro-kernel, NEON-FMA variant, MR=4 with output channels blocked 4 at
+// a time (NR=4) plus an NR=1 cleanup loop.
+//
+// Weight stream per 4-channel block: four leading values (presumably the four
+// biases -- confirm with the packing code), then groups of four nonzero
+// weights loaded with vld1q_f32 and applied by lane (vfmaq_laneq_f32, an
+// AArch64-only intrinsic). widx_dmap entries are BYTE offsets chaining the
+// input pointer `a` between nonzeroes.
+void xnn_f32_spmm_ukernel_4x4__neonfma(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  // Main loop: 4 output pixels per iteration.
+  while XNN_LIKELY(i >= 4) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    size_t j = n;
+    // Channel-blocked loop: four output channels share each input load.
+    while (j >= 4) {
+      uint32_t nnz = *nnzmap++;
+      float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t diff = *dmap++;
+          const float32x4_t va0123 = vld1q_f32(a);
+          // Byte-offset chase to the next nonzero's input row.
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+          const float32x4_t vb = vld1q_f32(w); w += 4;
+
+          vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0);
+          vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1);
+          vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2);
+          vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3);
+        } while (--nnz != 0);
+      }
+      float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
+      float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
+      float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax);
+      float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax);
+
+      vout0123c0 = vmaxq_f32(vout0123c0, vmin);
+      vout0123c1 = vmaxq_f32(vout0123c1, vmin);
+      vout0123c2 = vmaxq_f32(vout0123c2, vmin);
+      vout0123c3 = vmaxq_f32(vout0123c3, vmin);
+
+      vst1q_f32(c + 0 * m + 0, vout0123c0);
+      vst1q_f32(c + 1 * m + 0, vout0123c1);
+      vst1q_f32(c + 2 * m + 0, vout0123c2);
+      vst1q_f32(c + 3 * m + 0, vout0123c3);
+      c += 4 * m;
+      j -= 4;
+    }
+
+    // clean up loop, fall back to nr=1
+    if XNN_UNLIKELY(j != 0) {
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+
+        vout0123 = vmaxq_f32(vout0123, vmin);
+
+        vst1q_f32(c + 0, vout0123);
+        c += m;
+        j -= 1;
+      } while (j != 0);
+    }
+    c -= m * n;
+    c += 4;
+    a += 4;
+    i -= 4;
+  }
+  // Remainder pixels: 2-wide then 1-wide tails.
+  if XNN_UNLIKELY(i != 0) {
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 4) {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc01c2 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc01c3 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_f32(w); w += 4;
+
+            vacc01c0 = vfma_laneq_f32(vacc01c0, va01, vb, 0);
+            vacc01c1 = vfma_laneq_f32(vacc01c1, va01, vb, 1);
+            vacc01c2 = vfma_laneq_f32(vacc01c2, va01, vb, 2);
+            vacc01c3 = vfma_laneq_f32(vacc01c3, va01, vb, 3);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax));
+        float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax));
+        float32x2_t vout01c2 = vmin_f32(vacc01c2, vget_low_f32(vmax));
+        float32x2_t vout01c3 = vmin_f32(vacc01c3, vget_low_f32(vmax));
+
+        vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin));
+        vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin));
+        vout01c2 = vmax_f32(vout01c2, vget_low_f32(vmin));
+        vout01c3 = vmax_f32(vout01c3, vget_low_f32(vmin));
+
+        vst1_f32(c + 0 * m + 0, vout01c0);
+        vst1_f32(c + 1 * m + 0, vout01c1);
+        vst1_f32(c + 2 * m + 0, vout01c2);
+        vst1_f32(c + 3 * m + 0, vout01c3);
+        c += 4 * m;
+        j -= 4;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x2_t va01 = vld1_f32(a);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x2_t vb = vld1_dup_f32(w); w += 1;
+              vacc01 = vfma_f32(vacc01, va01, vb);
+            } while (--nnz != 0);
+          }
+          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+          vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+
+          vst1_f32(c, vout01);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      while (j >= 4) {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc0c2 = vld1_dup_f32(w); w += 1;
+        float32x2_t vacc0c3 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_f32(w); w += 4;
+
+            vacc0c0 = vfma_laneq_f32(vacc0c0, va0, vb, 0);
+            vacc0c1 = vfma_laneq_f32(vacc0c1, va0, vb, 1);
+            vacc0c2 = vfma_laneq_f32(vacc0c2, va0, vb, 2);
+            vacc0c3 = vfma_laneq_f32(vacc0c3, va0, vb, 3);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax));
+        float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax));
+        float32x2_t vout0c2 = vmin_f32(vacc0c2, vget_low_f32(vmax));
+        float32x2_t vout0c3 = vmin_f32(vacc0c3, vget_low_f32(vmax));
+
+        vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin));
+        vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin));
+        vout0c2 = vmax_f32(vout0c2, vget_low_f32(vmin));
+        vout0c3 = vmax_f32(vout0c3, vget_low_f32(vmin));
+
+        vst1_lane_f32(c + 0 * m + 0, vout0c0, 0);
+        vst1_lane_f32(c + 1 * m + 0, vout0c1, 0);
+        vst1_lane_f32(c + 2 * m + 0, vout0c2, 0);
+        vst1_lane_f32(c + 3 * m + 0, vout0c3, 0);
+        c += 4 * m;
+        j -= 4;
+      }
+
+      // clean up loop, fall back to nr=1
+      if XNN_UNLIKELY(j != 0) {
+        do {
+          uint32_t nnz = *nnzmap++;
+          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+          if XNN_LIKELY(nnz != 0) {
+            do {
+              const intptr_t diff = *dmap++;
+              const float32x2_t va0 = vld1_dup_f32(a);
+              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+              const float32x2_t vb = vld1_dup_f32(w); w += 1;
+              vacc0 = vfma_f32(vacc0, va0, vb);
+            } while (--nnz != 0);
+          }
+          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+          vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+
+          // NOTE(review): stores lane 1, while the blocked loop above stores
+          // lane 0. Both lanes are identical here (all dup loads), so the
+          // result is the same, but lane 0 would be clearer -- confirm this
+          // is intentional in the generator template.
+          vst1_lane_f32(c, vout0, 1);
+          c += m;
+          j -= 1;
+        } while (j != 0);
+      }
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/8x1-neonfma-pipelined.c b/src/f32-spmm/8x1-neonfma-pipelined.c
new file mode 100644
index 0000000..ece6ad7
--- /dev/null
+++ b/src/f32-spmm/8x1-neonfma-pipelined.c
@@ -0,0 +1,151 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon-pipelined.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// SpMM micro-kernel, NEON-FMA variant, MR=8 NR=1, software-pipelined.
+//
+// In the main (8-pixel) loop the weight, dmap offset, and input vectors for
+// the NEXT multiply-accumulate are loaded before the current one completes,
+// hiding load latency. As a consequence *dmap++ / *w++ / vld1q_f32(a) run one
+// element ahead of consumption; this presumably relies on the packed weight
+// and dmap streams having at least one readable trailing element -- confirm
+// against the weight-packing code.
+void xnn_f32_spmm_ukernel_8x1__neonfma_pipelined(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  while XNN_LIKELY(i >= 8) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    // Prime the pipeline: first weight, first offset, first 8 inputs.
+    float32x4_t vw = vld1q_dup_f32(w); w += 1;
+    intptr_t diff = *dmap++;
+    float32x4_t va0123 = vld1q_f32(a);
+    float32x4_t va4567 = vld1q_f32(a + 4);
+    size_t j = n;
+    do {
+      uint32_t nnz = *nnzmap++;
+      // vw currently holds this channel's leading value (presumably the bias).
+      float32x4_t vacc0123 = vw;
+      float32x4_t vacc4567 = vw;
+      vw = vld1q_dup_f32(w); w += 1;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          // Consume the pre-loaded values, then issue the next loads.
+          vacc0123 = vfmaq_f32(vacc0123, va0123, vw);
+          vacc4567 = vfmaq_f32(vacc4567, va4567, vw);
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+
+          diff = *dmap++;
+          vw = vld1q_dup_f32(w); w += 1;
+          va0123 = vld1q_f32(a);
+          va4567 = vld1q_f32(a + 4);
+        } while (--nnz != 0);
+      }
+      float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+      float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+      vout0123 = vmaxq_f32(vout0123, vmin);
+      vout4567 = vmaxq_f32(vout4567, vmin);
+      vst1q_f32(c, vout0123);
+      vst1q_f32(c + 4, vout4567);
+      c += m;
+    } while (--j != 0);
+    c -= m * n;
+    c += 8;
+    a += 8;
+    i -= 8;
+  }
+  // Remainder pixels use the plain (non-pipelined) pattern: 4, 2, then 1 wide.
+  if XNN_UNLIKELY(i != 0) {
+    if (i & 4) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+        vout0123 = vmaxq_f32(vout0123, vmin);
+        vst1q_f32(c, vout0123);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 4;
+      a += 4;
+    }
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc01 = vfma_f32(vacc01, va01, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+        vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+        vst1_f32(c, vout01);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc0 = vfma_f32(vacc0, va0, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+        vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+        vst1_lane_f32(c, vout0, 0);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/8x1-neonfma-unroll2.c b/src/f32-spmm/8x1-neonfma-unroll2.c
new file mode 100644
index 0000000..a68d586
--- /dev/null
+++ b/src/f32-spmm/8x1-neonfma-unroll2.c
@@ -0,0 +1,168 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// SpMM micro-kernel, NEON-FMA variant, MR=8 NR=1, nonzero loop unrolled 2x.
+//
+// The nnz loop processes two nonzeroes per iteration into two independent
+// accumulator sets (x0 and x1) to break the FMA dependency chain; the x1 set
+// starts at zero so the channel's leading value (presumably the bias) is
+// counted exactly once when the sets are summed. A scalar-step loop handles
+// an odd trailing nonzero. widx_dmap entries are BYTE offsets chaining the
+// input pointer `a` between nonzeroes.
+void xnn_f32_spmm_ukernel_8x1__neonfma_unroll2(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  while XNN_LIKELY(i >= 8) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    size_t j = n;
+    do {
+      uint32_t nnz = *nnzmap++;
+      float32x4_t vacc0123x0 = vld1q_dup_f32(w); w += 1;
+      // Second accumulator set starts at zero to avoid double-counting the
+      // leading value.
+      float32x4_t vacc0123x1 = vmovq_n_f32(0.0f);
+      float32x4_t vacc4567x0 = vacc0123x0;
+      float32x4_t vacc4567x1 = vmovq_n_f32(0.0f);
+      // 2x-unrolled body: two nonzeroes per iteration.
+      for (; nnz >= 2; nnz -= 2) {
+        const intptr_t diff0 = dmap[0];
+        const intptr_t diff1 = dmap[1];
+        dmap += 2;
+        const float32x4_t va0123x0 = vld1q_f32(a);
+        const float32x4_t va4567x0 = vld1q_f32(a + 4);
+        a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff0);
+        const float32x4_t vb0 = vld1q_dup_f32(w); w += 1;
+        vacc0123x0 = vfmaq_f32(vacc0123x0, va0123x0, vb0);
+        vacc4567x0 = vfmaq_f32(vacc4567x0, va4567x0, vb0);
+        const float32x4_t va0123x1 = vld1q_f32(a);
+        const float32x4_t va4567x1 = vld1q_f32(a + 4);
+        a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff1);
+        const float32x4_t vb1 = vld1q_dup_f32(w); w += 1;
+        vacc0123x1 = vfmaq_f32(vacc0123x1, va0123x1, vb1);
+        vacc4567x1 = vfmaq_f32(vacc4567x1, va4567x1, vb1);
+      }
+      // Combine the two accumulator sets.
+      float32x4_t vacc0123 = vacc0123x0;
+      float32x4_t vacc4567 = vacc4567x0;
+      vacc0123 = vaddq_f32(vacc0123, vacc0123x1);
+      vacc4567 = vaddq_f32(vacc4567, vacc4567x1);
+      // Scalar-step loop for an odd trailing nonzero.
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t diff = *dmap++;
+          const float32x4_t va0123 = vld1q_f32(a);
+          const float32x4_t va4567 = vld1q_f32(a + 4);
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+          const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+          vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+          vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+        } while (--nnz != 0);
+      }
+      float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+      float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+      vout0123 = vmaxq_f32(vout0123, vmin);
+      vout4567 = vmaxq_f32(vout4567, vmin);
+      vst1q_f32(c, vout0123);
+      vst1q_f32(c + 4, vout4567);
+      c += m;
+    } while (--j != 0);
+    c -= m * n;
+    c += 8;
+    a += 8;
+    i -= 8;
+  }
+  // Remainder pixels: 4, 2, then 1 wide, without unrolling.
+  if XNN_UNLIKELY(i != 0) {
+    if (i & 4) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+        vout0123 = vmaxq_f32(vout0123, vmin);
+        vst1q_f32(c, vout0123);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 4;
+      a += 4;
+    }
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc01 = vfma_f32(vacc01, va01, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+        vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+        vst1_f32(c, vout01);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc0 = vfma_f32(vacc0, va0, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+        vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+        vst1_lane_f32(c, vout0, 0);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/8x1-neonfma.c b/src/f32-spmm/8x1-neonfma.c
new file mode 100644
index 0000000..d851769
--- /dev/null
+++ b/src/f32-spmm/8x1-neonfma.c
@@ -0,0 +1,145 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+// SpMM micro-kernel, NEON-FMA variant, MR=8 NR=1 (straightforward, no
+// unrolling or pipelining).
+//
+// Per output channel the weight stream carries one leading value (presumably
+// the bias -- confirm with the packing code) followed by nnz nonzero weights;
+// widx_dmap entries are BYTE offsets chaining the input pointer `a` between
+// nonzeroes, and nidx_nnzmap gives the nonzero count per channel. Results are
+// clamped to [params->scalar.min, params->scalar.max] before storing.
+void xnn_f32_spmm_ukernel_8x1__neonfma(
+    uint32_t m,
+    uint32_t n,
+    const float*restrict a,
+    const float*restrict weights,
+    const int32_t*restrict widx_dmap,
+    const uint32_t*restrict nidx_nnzmap,
+    float*restrict c,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  size_t i = m;
+  // Main loop: 8 output pixels per iteration (two 4-lane accumulators).
+  while XNN_LIKELY(i >= 8) {
+    const float*restrict w = weights;
+    const int32_t* dmap = widx_dmap;
+    const uint32_t* nnzmap = nidx_nnzmap;
+    size_t j = n;
+    do {
+      uint32_t nnz = *nnzmap++;
+      float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+      float32x4_t vacc4567 = vacc0123;
+      if XNN_LIKELY(nnz != 0) {
+        do {
+          const intptr_t diff = *dmap++;
+          const float32x4_t va0123 = vld1q_f32(a);
+          const float32x4_t va4567 = vld1q_f32(a + 4);
+          // Byte-offset chase to the next nonzero's input row.
+          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+          const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+          vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+          vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+        } while (--nnz != 0);
+      }
+      float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+      float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+      vout0123 = vmaxq_f32(vout0123, vmin);
+      vout4567 = vmaxq_f32(vout4567, vmin);
+      vst1q_f32(c, vout0123);
+      vst1q_f32(c + 4, vout4567);
+      c += m;
+    } while (--j != 0);
+    c -= m * n;
+    c += 8;
+    a += 8;
+    i -= 8;
+  }
+  // Remainder pixels: 4, 2, then 1 wide.
+  if XNN_UNLIKELY(i != 0) {
+    if (i & 4) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x4_t va0123 = vld1q_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+          } while (--nnz != 0);
+        }
+        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+        vout0123 = vmaxq_f32(vout0123, vmin);
+        vst1q_f32(c, vout0123);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 4;
+      a += 4;
+    }
+    if (i & 2) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va01 = vld1_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc01 = vfma_f32(vacc01, va01, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+        vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+        vst1_f32(c, vout01);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 2;
+      a += 2;
+    }
+    if (i & 1) {
+      const float*restrict w = weights;
+      const int32_t* dmap = widx_dmap;
+      const uint32_t* nnzmap = nidx_nnzmap;
+      size_t j = n;
+      do {
+        uint32_t nnz = *nnzmap++;
+        float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+        if XNN_LIKELY(nnz != 0) {
+          do {
+            const intptr_t diff = *dmap++;
+            const float32x2_t va0 = vld1_dup_f32(a);
+            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+            const float32x2_t vb = vld1_dup_f32(w); w += 1;
+            vacc0 = vfma_f32(vacc0, va0, vb);
+          } while (--nnz != 0);
+        }
+        float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+        vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+        vst1_lane_f32(c, vout0, 0);
+        c += m;
+      } while (--j != 0);
+      c -= m * n;
+      c += 1;
+      a += 1;
+    }
+  }
+}
diff --git a/src/f32-spmm/8x1-scalar-pipelined.c b/src/f32-spmm/8x1-scalar-pipelined.c
new file mode 100644
index 0000000..e84a32c
--- /dev/null
+++ b/src/f32-spmm/8x1-scalar-pipelined.c
@@ -0,0 +1,235 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/scalar-pipelined.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
// Sparse-matrix * dense-matrix product (SpMM), scalar software-pipelined
// variant, MR = 8: up to 8 consecutive elements of the dense input `a` are
// processed per outer pass.  For each of the `n` output channels the kernel
// starts from a bias, then accumulates `nnz` weight*input products, walking
// the dense input with the byte increments stored in `widx_dmap`.  Results
// are clamped to [params->scalar.min, params->scalar.max] and written to `c`
// with a stride of `m` floats between channels.
//
// Pipelining: the weight (vw), the next pointer increment (diff) and the
// next 8 input values (va0..va7) are loaded one iteration AHEAD of the
// multiply-accumulate that consumes them, to hide load latency.
// NOTE(review): the pipeline primes vw/diff/va0..va7 before checking nnz, so
// one weight, one increment and 8 inputs are read even for channels with
// nnz == 0 — presumably the packed weight/dmap streams are padded to make
// this safe; confirm against the weight-packing code.
void xnn_f32_spmm_ukernel_8x1__scalar_pipelined(
    uint32_t m,
    uint32_t n,
    const float*restrict a,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict c,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(m != 0);

  const float vmin = params->scalar.min;
  const float vmax = params->scalar.max;
  size_t i = m;
  // Main loop: blocks of 8 dense elements.
  while XNN_LIKELY(i >= 8) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    // Prime the software pipeline: first bias, first increment, first 8 inputs.
    float vw = *w++;
    intptr_t diff = *dmap++;
    float va0 = a[0];
    float va1 = a[1];
    float va2 = a[2];
    float va3 = a[3];
    float va4 = a[4];
    float va5 = a[5];
    float va6 = a[6];
    float va7 = a[7];
    size_t j = n;
    do {
      uint32_t nnz = *nnzmap++;
      // All 8 accumulators start from the channel bias (currently in vw).
      float vacc0 = vw;
      float vacc1 = vw;
      float vacc2 = vw;
      float vacc3 = vw;
      float vacc4 = vw;
      float vacc5 = vw;
      float vacc6 = vw;
      float vacc7 = vw;
      vw = *w++;  // pre-load the first nonzero weight (or the next bias if nnz == 0)
      if XNN_LIKELY(nnz != 0) {
        do {
          // Consume the pre-loaded inputs and weight.
          vacc0 += va0 * vw;
          vacc1 += va1 * vw;
          vacc2 += va2 * vw;
          vacc3 += va3 * vw;
          vacc4 += va4 * vw;
          vacc5 += va5 * vw;
          vacc6 += va6 * vw;
          vacc7 += va7 * vw;
          // Advance the input pointer by the pre-loaded byte increment.
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);

          // Pre-load increment, weight and inputs for the next iteration.
          diff = *dmap++;
          vw = *w++;
          va0 = a[0];
          va1 = a[1];
          va2 = a[2];
          va3 = a[3];
          va4 = a[4];
          va5 = a[5];
          va6 = a[6];
          va7 = a[7];
        } while (--nnz != 0);
      }
      // Clamp to [vmin, vmax] and store one output channel (stride m).
      float vout0 = math_min_f32(vacc0, vmax);
      float vout1 = math_min_f32(vacc1, vmax);
      float vout2 = math_min_f32(vacc2, vmax);
      float vout3 = math_min_f32(vacc3, vmax);
      float vout4 = math_min_f32(vacc4, vmax);
      float vout5 = math_min_f32(vacc5, vmax);
      float vout6 = math_min_f32(vacc6, vmax);
      float vout7 = math_min_f32(vacc7, vmax);
      vout0 = math_max_f32(vout0, vmin);
      vout1 = math_max_f32(vout1, vmin);
      vout2 = math_max_f32(vout2, vmin);
      vout3 = math_max_f32(vout3, vmin);
      vout4 = math_max_f32(vout4, vmin);
      vout5 = math_max_f32(vout5, vmin);
      vout6 = math_max_f32(vout6, vmin);
      vout7 = math_max_f32(vout7, vmin);
      c[0] = vout0;
      c[1] = vout1;
      c[2] = vout2;
      c[3] = vout3;
      c[4] = vout4;
      c[5] = vout5;
      c[6] = vout6;
      c[7] = vout7;
      c += m;
    } while (--j != 0);
    // Rewind to the start of the output panel, then advance to the next
    // group of 8 dense elements.
    c -= m * n;
    c += 8;
    a += 8;
    i -= 8;
  }
  // Remainder: handle 4, 2 and 1 leftover dense elements with the same
  // pipelined pattern at narrower widths.
  if XNN_UNLIKELY(i != 0) {
    if (i & 4) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      float vw = *w++;
      intptr_t diff = *dmap++;
      float va0 = a[0];
      float va1 = a[1];
      float va2 = a[2];
      float va3 = a[3];
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float vacc0 = vw;
        float vacc1 = vw;
        float vacc2 = vw;
        float vacc3 = vw;
        vw = *w++;
        if XNN_LIKELY(nnz != 0) {
          do {
            vacc0 += va0 * vw;
            vacc1 += va1 * vw;
            vacc2 += va2 * vw;
            vacc3 += va3 * vw;
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);

            diff = *dmap++;
            vw = *w++;
            va0 = a[0];
            va1 = a[1];
            va2 = a[2];
            va3 = a[3];
          } while (--nnz != 0);
        }
        float vout0 = math_min_f32(vacc0, vmax);
        float vout1 = math_min_f32(vacc1, vmax);
        float vout2 = math_min_f32(vacc2, vmax);
        float vout3 = math_min_f32(vacc3, vmax);
        vout0 = math_max_f32(vout0, vmin);
        vout1 = math_max_f32(vout1, vmin);
        vout2 = math_max_f32(vout2, vmin);
        vout3 = math_max_f32(vout3, vmin);
        c[0] = vout0;
        c[1] = vout1;
        c[2] = vout2;
        c[3] = vout3;
        c += m;
      } while (--j != 0);
      c -= m * n;
      c += 4;
      a += 4;
    }
    if (i & 2) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      float vw = *w++;
      intptr_t diff = *dmap++;
      float va0 = a[0];
      float va1 = a[1];
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float vacc0 = vw;
        float vacc1 = vw;
        vw = *w++;
        if XNN_LIKELY(nnz != 0) {
          do {
            vacc0 += va0 * vw;
            vacc1 += va1 * vw;
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);

            diff = *dmap++;
            vw = *w++;
            va0 = a[0];
            va1 = a[1];
          } while (--nnz != 0);
        }
        float vout0 = math_min_f32(vacc0, vmax);
        float vout1 = math_min_f32(vacc1, vmax);
        vout0 = math_max_f32(vout0, vmin);
        vout1 = math_max_f32(vout1, vmin);
        c[0] = vout0;
        c[1] = vout1;
        c += m;
      } while (--j != 0);
      c -= m * n;
      c += 2;
      a += 2;
    }
    if (i & 1) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      float vw = *w++;
      intptr_t diff = *dmap++;
      float va0 = a[0];
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float vacc0 = vw;
        vw = *w++;
        if XNN_LIKELY(nnz != 0) {
          do {
            vacc0 += va0 * vw;
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);

            diff = *dmap++;
            vw = *w++;
            va0 = a[0];
          } while (--nnz != 0);
        }
        float vout0 = math_min_f32(vacc0, vmax);
        vout0 = math_max_f32(vout0, vmin);
        c[0] = vout0;
        c += m;
      } while (--j != 0);
      c -= m * n;
      c += 1;
      a += 1;
    }
  }
}
diff --git a/src/f32-spmm/8x1-scalar-unroll2.c b/src/f32-spmm/8x1-scalar-unroll2.c
new file mode 100644
index 0000000..73d0ad4
--- /dev/null
+++ b/src/f32-spmm/8x1-scalar-unroll2.c
@@ -0,0 +1,345 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
// Sparse-matrix * dense-matrix product (SpMM), scalar variant with the
// nonzero loop unrolled by 2, MR = 8.  Each of up to 8 dense elements keeps
// TWO independent accumulator chains (vaccNx0 / vaccNx1) that are summed
// after the unrolled loop; a scalar tail loop handles an odd trailing
// nonzero.  Outputs are clamped to [params->scalar.min, params->scalar.max].
// Note: the dual-chain summation order intentionally differs from the plain
// scalar kernel, so results may differ in the last ulp.
void xnn_f32_spmm_ukernel_8x1__scalar_unroll2(
    uint32_t m,
    uint32_t n,
    const float*restrict a,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict c,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(m != 0);

  const float vmin = params->scalar.min;
  const float vmax = params->scalar.max;
  size_t i = m;
  // Main loop: blocks of 8 dense elements.
  while (i >= 8) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t j = n;
    do {
      uint32_t nnz = *nnzmap++;
      // Chain 0 starts from the channel bias; chain 1 starts from zero.
      float vacc0x0 = *w++;
      float vacc0x1 = 0.0f;
      float vacc1x0 = vacc0x0;
      float vacc1x1 = 0.0f;
      float vacc2x0 = vacc0x0;
      float vacc2x1 = 0.0f;
      float vacc3x0 = vacc0x0;
      float vacc3x1 = 0.0f;
      float vacc4x0 = vacc0x0;
      float vacc4x1 = 0.0f;
      float vacc5x0 = vacc0x0;
      float vacc5x1 = 0.0f;
      float vacc6x0 = vacc0x0;
      float vacc6x1 = 0.0f;
      float vacc7x0 = vacc0x0;
      float vacc7x1 = 0.0f;
      // Unrolled-by-2 loop over the channel's nonzero weights.
      for (; nnz >= 2; nnz -= 2) {
        const intptr_t diff0 = dmap[0];
        const intptr_t diff1 = dmap[1];
        dmap += 2;
        const float va0x0 = a[0];
        const float va1x0 = a[1];
        const float va2x0 = a[2];
        const float va3x0 = a[3];
        const float va4x0 = a[4];
        const float va5x0 = a[5];
        const float va6x0 = a[6];
        const float va7x0 = a[7];
        a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff0);
        const float vb0 = *w++;
        vacc0x0 += va0x0 * vb0;
        vacc1x0 += va1x0 * vb0;
        vacc2x0 += va2x0 * vb0;
        vacc3x0 += va3x0 * vb0;
        vacc4x0 += va4x0 * vb0;
        vacc5x0 += va5x0 * vb0;
        vacc6x0 += va6x0 * vb0;
        vacc7x0 += va7x0 * vb0;
        const float va0x1 = a[0];
        const float va1x1 = a[1];
        const float va2x1 = a[2];
        const float va3x1 = a[3];
        const float va4x1 = a[4];
        const float va5x1 = a[5];
        const float va6x1 = a[6];
        const float va7x1 = a[7];
        a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff1);
        const float vb1 = *w++;
        vacc0x1 += va0x1 * vb1;
        vacc1x1 += va1x1 * vb1;
        vacc2x1 += va2x1 * vb1;
        vacc3x1 += va3x1 * vb1;
        vacc4x1 += va4x1 * vb1;
        vacc5x1 += va5x1 * vb1;
        vacc6x1 += va6x1 * vb1;
        vacc7x1 += va7x1 * vb1;
      }
      // Merge the two accumulator chains.
      float vacc0 = vacc0x0;
      float vacc1 = vacc1x0;
      float vacc2 = vacc2x0;
      float vacc3 = vacc3x0;
      float vacc4 = vacc4x0;
      float vacc5 = vacc5x0;
      float vacc6 = vacc6x0;
      float vacc7 = vacc7x0;
      vacc0 += vacc0x1;
      vacc1 += vacc1x1;
      vacc2 += vacc2x1;
      vacc3 += vacc3x1;
      vacc4 += vacc4x1;
      vacc5 += vacc5x1;
      vacc6 += vacc6x1;
      vacc7 += vacc7x1;
      // Tail: at most one remaining nonzero.
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const float va0 = a[0];
          const float va1 = a[1];
          const float va2 = a[2];
          const float va3 = a[3];
          const float va4 = a[4];
          const float va5 = a[5];
          const float va6 = a[6];
          const float va7 = a[7];
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
          const float vb = *w++;
          vacc0 += va0 * vb;
          vacc1 += va1 * vb;
          vacc2 += va2 * vb;
          vacc3 += va3 * vb;
          vacc4 += va4 * vb;
          vacc5 += va5 * vb;
          vacc6 += va6 * vb;
          vacc7 += va7 * vb;
        } while (--nnz != 0);
      }
      // Clamp and store one output channel (stride m).
      float vout0 = math_min_f32(vacc0, vmax);
      float vout1 = math_min_f32(vacc1, vmax);
      float vout2 = math_min_f32(vacc2, vmax);
      float vout3 = math_min_f32(vacc3, vmax);
      float vout4 = math_min_f32(vacc4, vmax);
      float vout5 = math_min_f32(vacc5, vmax);
      float vout6 = math_min_f32(vacc6, vmax);
      float vout7 = math_min_f32(vacc7, vmax);
      vout0 = math_max_f32(vout0, vmin);
      vout1 = math_max_f32(vout1, vmin);
      vout2 = math_max_f32(vout2, vmin);
      vout3 = math_max_f32(vout3, vmin);
      vout4 = math_max_f32(vout4, vmin);
      vout5 = math_max_f32(vout5, vmin);
      vout6 = math_max_f32(vout6, vmin);
      vout7 = math_max_f32(vout7, vmin);
      c[0] = vout0;
      c[1] = vout1;
      c[2] = vout2;
      c[3] = vout3;
      c[4] = vout4;
      c[5] = vout5;
      c[6] = vout6;
      c[7] = vout7;
      c += m;
    } while (--j != 0);
    c -= m * n;
    c += 8;
    a += 8;
    i -= 8;
  }
  // Remainder: 4, 2, then 1 leftover dense elements, same dual-chain pattern.
  if XNN_UNLIKELY(i != 0) {
    if (i & 4) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float vacc0x0 = *w++;
        float vacc0x1 = 0.0f;
        float vacc1x0 = vacc0x0;
        float vacc1x1 = 0.0f;
        float vacc2x0 = vacc0x0;
        float vacc2x1 = 0.0f;
        float vacc3x0 = vacc0x0;
        float vacc3x1 = 0.0f;
        for (; nnz >= 2; nnz -= 2) {
          const intptr_t diff0 = dmap[0];
          const intptr_t diff1 = dmap[1];
          dmap += 2;
          const float va0x0 = a[0];
          const float va1x0 = a[1];
          const float va2x0 = a[2];
          const float va3x0 = a[3];
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff0);
          const float vb0 = *w++;
          vacc0x0 += va0x0 * vb0;
          vacc1x0 += va1x0 * vb0;
          vacc2x0 += va2x0 * vb0;
          vacc3x0 += va3x0 * vb0;
          const float va0x1 = a[0];
          const float va1x1 = a[1];
          const float va2x1 = a[2];
          const float va3x1 = a[3];
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff1);
          const float vb1 = *w++;
          vacc0x1 += va0x1 * vb1;
          vacc1x1 += va1x1 * vb1;
          vacc2x1 += va2x1 * vb1;
          vacc3x1 += va3x1 * vb1;
        }
        float vacc0 = vacc0x0;
        float vacc1 = vacc1x0;
        float vacc2 = vacc2x0;
        float vacc3 = vacc3x0;
        vacc0 += vacc0x1;
        vacc1 += vacc1x1;
        vacc2 += vacc2x1;
        vacc3 += vacc3x1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float va0 = a[0];
            const float va1 = a[1];
            const float va2 = a[2];
            const float va3 = a[3];
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float vb = *w++;
            vacc0 += va0 * vb;
            vacc1 += va1 * vb;
            vacc2 += va2 * vb;
            vacc3 += va3 * vb;
          } while (--nnz != 0);
        }
        float vout0 = math_min_f32(vacc0, vmax);
        float vout1 = math_min_f32(vacc1, vmax);
        float vout2 = math_min_f32(vacc2, vmax);
        float vout3 = math_min_f32(vacc3, vmax);
        vout0 = math_max_f32(vout0, vmin);
        vout1 = math_max_f32(vout1, vmin);
        vout2 = math_max_f32(vout2, vmin);
        vout3 = math_max_f32(vout3, vmin);
        c[0] = vout0;
        c[1] = vout1;
        c[2] = vout2;
        c[3] = vout3;
        c += m;
      } while (--j != 0);
      c -= m * n;
      c += 4;
      a += 4;
    }
    if (i & 2) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float vacc0x0 = *w++;
        float vacc0x1 = 0.0f;
        float vacc1x0 = vacc0x0;
        float vacc1x1 = 0.0f;
        for (; nnz >= 2; nnz -= 2) {
          const intptr_t diff0 = dmap[0];
          const intptr_t diff1 = dmap[1];
          dmap += 2;
          const float va0x0 = a[0];
          const float va1x0 = a[1];
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff0);
          const float vb0 = *w++;
          vacc0x0 += va0x0 * vb0;
          vacc1x0 += va1x0 * vb0;
          const float va0x1 = a[0];
          const float va1x1 = a[1];
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff1);
          const float vb1 = *w++;
          vacc0x1 += va0x1 * vb1;
          vacc1x1 += va1x1 * vb1;
        }
        float vacc0 = vacc0x0;
        float vacc1 = vacc1x0;
        vacc0 += vacc0x1;
        vacc1 += vacc1x1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float va0 = a[0];
            const float va1 = a[1];
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float vb = *w++;
            vacc0 += va0 * vb;
            vacc1 += va1 * vb;
          } while (--nnz != 0);
        }
        float vout0 = math_min_f32(vacc0, vmax);
        float vout1 = math_min_f32(vacc1, vmax);
        vout0 = math_max_f32(vout0, vmin);
        vout1 = math_max_f32(vout1, vmin);
        c[0] = vout0;
        c[1] = vout1;
        c += m;
      } while (--j != 0);
      c -= m * n;
      c += 2;
      a += 2;
    }
    if (i & 1) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        float vacc0x0 = *w++;
        float vacc0x1 = 0.0f;
        for (; nnz >= 2; nnz -= 2) {
          const intptr_t diff0 = dmap[0];
          const intptr_t diff1 = dmap[1];
          dmap += 2;
          const float va0x0 = a[0];
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff0);
          const float vb0 = *w++;
          vacc0x0 += va0x0 * vb0;
          const float va0x1 = a[0];
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff1);
          const float vb1 = *w++;
          vacc0x1 += va0x1 * vb1;
        }
        float vacc0 = vacc0x0;
        vacc0 += vacc0x1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float va0 = a[0];
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float vb = *w++;
            vacc0 += va0 * vb;
          } while (--nnz != 0);
        }
        float vout0 = math_min_f32(vacc0, vmax);
        vout0 = math_max_f32(vout0, vmin);
        c[0] = vout0;
        c += m;
      } while (--j != 0);
      c -= m * n;
      c += 1;
      a += 1;
    }
  }
}
diff --git a/src/f32-spmm/8x1-scalar.c b/src/f32-spmm/8x1-scalar.c
new file mode 100644
index 0000000..2a1ac08
--- /dev/null
+++ b/src/f32-spmm/8x1-scalar.c
@@ -0,0 +1,204 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/scalar.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
+void xnn_f32_spmm_ukernel_8x1__scalar(
+ uint32_t m,
+ uint32_t n,
+ const float*restrict a,
+ const float*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ float*restrict c,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const float vmin = params->scalar.min;
+ const float vmax = params->scalar.max;
+ size_t i = m;
+ while (i >= 8) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ float vacc1 = vacc0;
+ float vacc2 = vacc0;
+ float vacc3 = vacc0;
+ float vacc4 = vacc0;
+ float vacc5 = vacc0;
+ float vacc6 = vacc0;
+ float vacc7 = vacc0;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ const float va1 = a[1];
+ const float va2 = a[2];
+ const float va3 = a[3];
+ const float va4 = a[4];
+ const float va5 = a[5];
+ const float va6 = a[6];
+ const float va7 = a[7];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ vacc1 += va1 * vb;
+ vacc2 += va2 * vb;
+ vacc3 += va3 * vb;
+ vacc4 += va4 * vb;
+ vacc5 += va5 * vb;
+ vacc6 += va6 * vb;
+ vacc7 += va7 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ float vout2 = math_min_f32(vacc2, vmax);
+ float vout3 = math_min_f32(vacc3, vmax);
+ float vout4 = math_min_f32(vacc4, vmax);
+ float vout5 = math_min_f32(vacc5, vmax);
+ float vout6 = math_min_f32(vacc6, vmax);
+ float vout7 = math_min_f32(vacc7, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ vout2 = math_max_f32(vout2, vmin);
+ vout3 = math_max_f32(vout3, vmin);
+ vout4 = math_max_f32(vout4, vmin);
+ vout5 = math_max_f32(vout5, vmin);
+ vout6 = math_max_f32(vout6, vmin);
+ vout7 = math_max_f32(vout7, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c[2] = vout2;
+ c[3] = vout3;
+ c[4] = vout4;
+ c[5] = vout5;
+ c[6] = vout6;
+ c[7] = vout7;
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 8;
+ a += 8;
+ i -= 8;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ if (i & 4) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ float vacc1 = vacc0;
+ float vacc2 = vacc0;
+ float vacc3 = vacc0;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ const float va1 = a[1];
+ const float va2 = a[2];
+ const float va3 = a[3];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ vacc1 += va1 * vb;
+ vacc2 += va2 * vb;
+ vacc3 += va3 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ float vout2 = math_min_f32(vacc2, vmax);
+ float vout3 = math_min_f32(vacc3, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ vout2 = math_max_f32(vout2, vmin);
+ vout3 = math_max_f32(vout3, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c[2] = vout2;
+ c[3] = vout3;
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 4;
+ a += 4;
+ }
+ if (i & 2) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ float vacc1 = vacc0;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ const float va1 = a[1];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ vacc1 += va1 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ float vout1 = math_min_f32(vacc1, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ vout1 = math_max_f32(vout1, vmin);
+ c[0] = vout0;
+ c[1] = vout1;
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 2;
+ a += 2;
+ }
+ if (i & 1) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ float vacc0 = *w++;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float va0 = a[0];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ vacc0 += va0 * vb;
+ } while (--nnz != 0);
+ }
+ float vout0 = math_min_f32(vacc0, vmax);
+ vout0 = math_max_f32(vout0, vmin);
+ c[0] = vout0;
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
diff --git a/src/f32-spmm/8x1-sse.c b/src/f32-spmm/8x1-sse.c
new file mode 100644
index 0000000..a411b94
--- /dev/null
+++ b/src/f32-spmm/8x1-sse.c
@@ -0,0 +1,147 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/sse.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/spmm.h>
+
+
// Sparse-matrix * dense-matrix product (SpMM), SSE variant, MR = 8:
// two 4-lane vectors cover 8 dense elements per pass.  Per output channel:
// broadcast the bias, accumulate nnz broadcast-weight * input products
// (walking `a` by the byte increments in `widx_dmap`), clamp to
// [params->sse.min, params->sse.max], store at stride `m` in `c`.
// params->sse.min/max are pre-broadcast 4-float arrays (aligned load).
void xnn_f32_spmm_ukernel_8x1__sse(
    uint32_t m,
    uint32_t n,
    const float*restrict a,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict c,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(m != 0);

  const __m128 vmin = _mm_load_ps(params->sse.min);
  const __m128 vmax = _mm_load_ps(params->sse.max);
  size_t i = m;
  // Main loop: blocks of 8 dense elements (two __m128 accumulators).
  while XNN_LIKELY(i >= 8) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t j = n;
    do {
      uint32_t nnz = *nnzmap++;
      // Broadcast the channel bias into both accumulators.
      __m128 vacc0123 = _mm_load1_ps(w); w += 1;
      __m128 vacc4567 = vacc0123;
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const __m128 va0123 = _mm_loadu_ps(a);
          const __m128 va4567 = _mm_loadu_ps(a + 4);
          // Hop to the row of the next nonzero weight (byte offset).
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
          const __m128 vb = _mm_load1_ps(w); w += 1;  // broadcast weight
          vacc0123 = _mm_add_ps(vacc0123, _mm_mul_ps(va0123, vb));
          vacc4567 = _mm_add_ps(vacc4567, _mm_mul_ps(va4567, vb));
        } while (--nnz != 0);
      }
      // Clamp and store one output channel (stride m).
      __m128 vout0123 = _mm_min_ps(vacc0123, vmax);
      __m128 vout4567 = _mm_min_ps(vacc4567, vmax);
      vout0123 = _mm_max_ps(vout0123, vmin);
      vout4567 = _mm_max_ps(vout4567, vmin);
      _mm_storeu_ps(c, vout0123);
      _mm_storeu_ps(c + 4, vout4567);
      c += 1 * m;
    } while (--j != 0);
    c -= m * n;
    c += 8;
    a += 8;
    i -= 8;
  }
  // Remainder: 4, 2, then 1 leftover dense elements.
  if XNN_UNLIKELY(i != 0) {
    if (i & 4) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        __m128 vacc0123 = _mm_load1_ps(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const __m128 va0123 = _mm_loadu_ps(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const __m128 vb = _mm_load1_ps(w); w += 1;
            vacc0123 = _mm_add_ps(vacc0123, _mm_mul_ps(va0123, vb));
          } while (--nnz != 0);
        }
        __m128 vout0123 = _mm_min_ps(vacc0123, vmax);
        vout0123 = _mm_max_ps(vout0123, vmin);
        _mm_storeu_ps(c, vout0123);
        c += 1 * m;
      } while (--j != 0);
      c -= m * n;
      c += 4;
      a += 4;
    }
    if (i & 2) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        // Only the low two lanes are meaningful; the bias is duplicated into
        // lanes 0-1 via unpacklo.
        __m128 vacc01 = _mm_load_ss(w); w += 1;
        vacc01 = _mm_unpacklo_ps(vacc01, vacc01);
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            // 64-bit load of two floats into the low half; upper half is
            // undefined but never stored.
            const __m128 va01 = _mm_loadl_pi(_mm_undefined_ps(), (const __m64*) a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            __m128 vb = _mm_load_ss(w); w += 1;
            vb = _mm_unpacklo_ps(vb, vb);
            vacc01 = _mm_add_ps(vacc01, _mm_mul_ps(va01, vb));
          } while (--nnz != 0);
        }
        __m128 vout01 = _mm_min_ps(vacc01, vmax);
        vout01 = _mm_max_ps(vout01, vmin);
        _mm_storel_pi((__m64*) c, vout01);  // store low 2 lanes only
        c += 1 * m;
      } while (--j != 0);
      c -= m * n;
      c += 2;
      a += 2;
    }
    if (i & 1) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      do {
        uint32_t nnz = *nnzmap++;
        // Pure scalar (xx_ss) path for the final element.
        __m128 vacc0 = _mm_load_ss(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const __m128 va0 = _mm_load_ss(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const __m128 vb = _mm_load_ss(w); w += 1;
            vacc0 = _mm_add_ss(vacc0, _mm_mul_ss(va0, vb));
          } while (--nnz != 0);
        }
        __m128 vout0 = _mm_min_ss(vacc0, vmax);
        vout0 = _mm_max_ss(vout0, vmin);
        _mm_store_ss(c, vout0);
        c += 1 * m;
      } while (--j != 0);
      c -= m * n;
      c += 1;
      a += 1;
    }
  }
}
diff --git a/src/f32-spmm/8x2-neonfma.c b/src/f32-spmm/8x2-neonfma.c
new file mode 100644
index 0000000..07fcc6a
--- /dev/null
+++ b/src/f32-spmm/8x2-neonfma.c
@@ -0,0 +1,286 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon-blocked.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
// Sparse-matrix * dense-matrix product (SpMM), NEON-FMA blocked variant,
// MR = 8, NR = 2: two output channels are produced per pass through the
// nonzero loop, sharing the input loads; a NR = 1 cleanup loop handles an
// odd trailing channel.  Outputs are clamped to
// [params->scalar.min, params->scalar.max] and stored at stride `m`.
//
// FIX(review): the min/max loads read `&params->scalar.min/max`; the previous
// text had these garbled by an HTML-entity mangling (`&para` -> U+00B6), which
// did not compile.  Also, the final NR = 1 cleanup stored lane 1 of vout0;
// normalized to lane 0 for consistency with the other tails (behavior is
// identical: every operation in that path is lane-duplicated, so both lanes
// hold the same value).
void xnn_f32_spmm_ukernel_8x2__neonfma(
    uint32_t m,
    uint32_t n,
    const float*restrict a,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict c,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(m != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t i = m;
  // Main loop: blocks of 8 dense elements.
  while XNN_LIKELY(i >= 8) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t j = n;
    // Two output channels at a time; each starts from its own broadcast bias.
    while (j >= 2) {
      uint32_t nnz = *nnzmap++;
      float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567c0 = vacc0123c0;
      float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567c1 = vacc0123c1;
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const float32x4_t va0123 = vld1q_f32(a);
          const float32x4_t va4567 = vld1q_f32(a + 4);
          // Hop to the row of the next nonzero (byte offset).
          a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
          const float32x2_t vb = vld1_f32(w); w += 2;  // one weight per channel

          vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
          vacc4567c0 = vfmaq_lane_f32(vacc4567c0, va4567, vb, 0);
          vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
          vacc4567c1 = vfmaq_lane_f32(vacc4567c1, va4567, vb, 1);
        } while (--nnz != 0);
      }
      // Clamp and store two output channels.
      float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
      float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
      float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
      float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);

      vout0123c0 = vmaxq_f32(vout0123c0, vmin);
      vout4567c0 = vmaxq_f32(vout4567c0, vmin);
      vout0123c1 = vmaxq_f32(vout0123c1, vmin);
      vout4567c1 = vmaxq_f32(vout4567c1, vmin);

      vst1q_f32(c + 0 * m + 0, vout0123c0);
      vst1q_f32(c + 0 * m + 4, vout4567c0);
      vst1q_f32(c + 1 * m + 0, vout0123c1);
      vst1q_f32(c + 1 * m + 4, vout4567c1);
      c += 2 * m;
      j -= 2;
    }

    // clean up loop, fall back to nr=1
    if XNN_UNLIKELY(j != 0) {
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567 = vacc0123;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            const float32x4_t va4567 = vld1q_f32(a + 4);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x4_t vb = vld1q_dup_f32(w); w += 1;
            vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
            vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);

        vout0123 = vmaxq_f32(vout0123, vmin);
        vout4567 = vmaxq_f32(vout4567, vmin);

        vst1q_f32(c + 0, vout0123);
        vst1q_f32(c + 4, vout4567);
        c += m;
        j -= 1;
      } while (j != 0);
    }
    c -= m * n;
    c += 8;
    a += 8;
    i -= 8;
  }
  // Remainder: 4, 2, then 1 leftover dense elements, same NR = 2 blocking.
  if XNN_UNLIKELY(i != 0) {
    if (i & 4) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t va0123 = vld1q_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc0123c0 = vfmaq_lane_f32(vacc0123c0, va0123, vb, 0);
            vacc0123c1 = vfmaq_lane_f32(vacc0123c1, va0123, vb, 1);
          } while (--nnz != 0);
        }
        float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
        float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);

        vout0123c0 = vmaxq_f32(vout0123c0, vmin);
        vout0123c1 = vmaxq_f32(vout0123c1, vmin);

        vst1q_f32(c + 0 * m + 0, vout0123c0);
        vst1q_f32(c + 1 * m + 0, vout0123c1);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t va0123 = vld1q_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x4_t vb = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);

          vst1q_f32(c + 0, vout0123);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 4;
      a += 4;
    }
    if (i & 2) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t va01 = vld1_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc01c0 = vfma_lane_f32(vacc01c0, va01, vb, 0);
            vacc01c1 = vfma_lane_f32(vacc01c1, va01, vb, 1);
          } while (--nnz != 0);
        }
        float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax));
        float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax));

        vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin));
        vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin));

        vst1_f32(c + 0 * m + 0, vout01c0);
        vst1_f32(c + 1 * m + 0, vout01c1);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t va01 = vld1_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x2_t vb = vld1_dup_f32(w); w += 1;
              vacc01 = vfma_f32(vacc01, va01, vb);
            } while (--nnz != 0);
          }
          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
          vout01 = vmax_f32(vout01, vget_low_f32(vmin));

          vst1_f32(c, vout01);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 2;
      a += 2;
    }
    if (i & 1) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t j = n;
      while (j >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t va0 = vld1_dup_f32(a);
            a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
            const float32x2_t vb = vld1_f32(w); w += 2;

            vacc0c0 = vfma_lane_f32(vacc0c0, va0, vb, 0);
            vacc0c1 = vfma_lane_f32(vacc0c1, va0, vb, 1);
          } while (--nnz != 0);
        }
        float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax));
        float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax));

        vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin));
        vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin));

        vst1_lane_f32(c + 0 * m + 0, vout0c0, 0);
        vst1_lane_f32(c + 1 * m + 0, vout0c1, 0);
        c += 2 * m;
        j -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(j != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t va0 = vld1_dup_f32(a);
              a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
              const float32x2_t vb = vld1_dup_f32(w); w += 1;
              vacc0 = vfma_f32(vacc0, va0, vb);
            } while (--nnz != 0);
          }
          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
          vout0 = vmax_f32(vout0, vget_low_f32(vmin));

          // Both lanes of vout0 are identical here (all inputs were
          // lane-duplicated); store lane 0 for consistency.
          vst1_lane_f32(c, vout0, 0);
          c += m;
          j -= 1;
        } while (j != 0);
      }
      c -= m * n;
      c += 1;
      a += 1;
    }
  }
}
diff --git a/src/f32-spmm/8x4-neonfma.c b/src/f32-spmm/8x4-neonfma.c
new file mode 100644
index 0000000..c1fc602
--- /dev/null
+++ b/src/f32-spmm/8x4-neonfma.c
@@ -0,0 +1,336 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-spmm/neon-blocked.c.in
+// Generator: tools/xngen
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+void xnn_f32_spmm_ukernel_8x4__neonfma(
+ uint32_t m,
+ uint32_t n,
+ const float*restrict a,
+ const float*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ float*restrict c,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ size_t i = m;
+ while XNN_LIKELY(i >= 8) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ while (j >= 4) {
+ uint32_t nnz = *nnzmap++;
+ float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
+ float32x4_t vacc4567c0 = vacc0123c0;
+ float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
+ float32x4_t vacc4567c1 = vacc0123c1;
+ float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1;
+ float32x4_t vacc4567c2 = vacc0123c2;
+ float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1;
+ float32x4_t vacc4567c3 = vacc0123c3;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float32x4_t va0123 = vld1q_f32(a);
+ const float32x4_t va4567 = vld1q_f32(a + 4);
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float32x4_t vb = vld1q_f32(w); w += 4;
+
+ vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0);
+ vacc4567c0 = vfmaq_laneq_f32(vacc4567c0, va4567, vb, 0);
+ vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1);
+ vacc4567c1 = vfmaq_laneq_f32(vacc4567c1, va4567, vb, 1);
+ vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2);
+ vacc4567c2 = vfmaq_laneq_f32(vacc4567c2, va4567, vb, 2);
+ vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3);
+ vacc4567c3 = vfmaq_laneq_f32(vacc4567c3, va4567, vb, 3);
+ } while (--nnz != 0);
+ }
+ float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
+ float32x4_t vout4567c0 = vminq_f32(vacc4567c0, vmax);
+ float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
+ float32x4_t vout4567c1 = vminq_f32(vacc4567c1, vmax);
+ float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax);
+ float32x4_t vout4567c2 = vminq_f32(vacc4567c2, vmax);
+ float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax);
+ float32x4_t vout4567c3 = vminq_f32(vacc4567c3, vmax);
+
+ vout0123c0 = vmaxq_f32(vout0123c0, vmin);
+ vout4567c0 = vmaxq_f32(vout4567c0, vmin);
+ vout0123c1 = vmaxq_f32(vout0123c1, vmin);
+ vout4567c1 = vmaxq_f32(vout4567c1, vmin);
+ vout0123c2 = vmaxq_f32(vout0123c2, vmin);
+ vout4567c2 = vmaxq_f32(vout4567c2, vmin);
+ vout0123c3 = vmaxq_f32(vout0123c3, vmin);
+ vout4567c3 = vmaxq_f32(vout4567c3, vmin);
+
+ vst1q_f32(c + 0 * m + 0, vout0123c0);
+ vst1q_f32(c + 0 * m + 4, vout4567c0);
+ vst1q_f32(c + 1 * m + 0, vout0123c1);
+ vst1q_f32(c + 1 * m + 4, vout4567c1);
+ vst1q_f32(c + 2 * m + 0, vout0123c2);
+ vst1q_f32(c + 2 * m + 4, vout4567c2);
+ vst1q_f32(c + 3 * m + 0, vout0123c3);
+ vst1q_f32(c + 3 * m + 4, vout4567c3);
+ c += 4 * m;
+ j -= 4;
+ }
+
+ // clean up loop, fall back to nr=1
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+ float32x4_t vacc4567 = vacc0123;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float32x4_t va0123 = vld1q_f32(a);
+ const float32x4_t va4567 = vld1q_f32(a + 4);
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+ vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+ vacc4567 = vfmaq_f32(vacc4567, va4567, vb);
+ } while (--nnz != 0);
+ }
+ float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+ float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
+
+ vout0123 = vmaxq_f32(vout0123, vmin);
+ vout4567 = vmaxq_f32(vout4567, vmin);
+
+ vst1q_f32(c + 0, vout0123);
+ vst1q_f32(c + 4, vout4567);
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
+ c -= m * n;
+ c += 8;
+ a += 8;
+ i -= 8;
+ }
+ if XNN_UNLIKELY(i != 0) {
+ if (i & 4) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ while (j >= 4) {
+ uint32_t nnz = *nnzmap++;
+ float32x4_t vacc0123c0 = vld1q_dup_f32(w); w += 1;
+ float32x4_t vacc0123c1 = vld1q_dup_f32(w); w += 1;
+ float32x4_t vacc0123c2 = vld1q_dup_f32(w); w += 1;
+ float32x4_t vacc0123c3 = vld1q_dup_f32(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float32x4_t va0123 = vld1q_f32(a);
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float32x4_t vb = vld1q_f32(w); w += 4;
+
+ vacc0123c0 = vfmaq_laneq_f32(vacc0123c0, va0123, vb, 0);
+ vacc0123c1 = vfmaq_laneq_f32(vacc0123c1, va0123, vb, 1);
+ vacc0123c2 = vfmaq_laneq_f32(vacc0123c2, va0123, vb, 2);
+ vacc0123c3 = vfmaq_laneq_f32(vacc0123c3, va0123, vb, 3);
+ } while (--nnz != 0);
+ }
+ float32x4_t vout0123c0 = vminq_f32(vacc0123c0, vmax);
+ float32x4_t vout0123c1 = vminq_f32(vacc0123c1, vmax);
+ float32x4_t vout0123c2 = vminq_f32(vacc0123c2, vmax);
+ float32x4_t vout0123c3 = vminq_f32(vacc0123c3, vmax);
+
+ vout0123c0 = vmaxq_f32(vout0123c0, vmin);
+ vout0123c1 = vmaxq_f32(vout0123c1, vmin);
+ vout0123c2 = vmaxq_f32(vout0123c2, vmin);
+ vout0123c3 = vmaxq_f32(vout0123c3, vmin);
+
+ vst1q_f32(c + 0 * m + 0, vout0123c0);
+ vst1q_f32(c + 1 * m + 0, vout0123c1);
+ vst1q_f32(c + 2 * m + 0, vout0123c2);
+ vst1q_f32(c + 3 * m + 0, vout0123c3);
+ c += 4 * m;
+ j -= 4;
+ }
+
+ // clean up loop, fall back to nr=1
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float32x4_t va0123 = vld1q_f32(a);
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+ vacc0123 = vfmaq_f32(vacc0123, va0123, vb);
+ } while (--nnz != 0);
+ }
+ float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
+
+ vout0123 = vmaxq_f32(vout0123, vmin);
+
+ vst1q_f32(c + 0, vout0123);
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
+ c -= m * n;
+ c += 4;
+ a += 4;
+ }
+ if (i & 2) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ while (j >= 4) {
+ uint32_t nnz = *nnzmap++;
+ float32x2_t vacc01c0 = vld1_dup_f32(w); w += 1;
+ float32x2_t vacc01c1 = vld1_dup_f32(w); w += 1;
+ float32x2_t vacc01c2 = vld1_dup_f32(w); w += 1;
+ float32x2_t vacc01c3 = vld1_dup_f32(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float32x2_t va01 = vld1_f32(a);
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float32x4_t vb = vld1q_f32(w); w += 4;
+
+ vacc01c0 = vfma_laneq_f32(vacc01c0, va01, vb, 0);
+ vacc01c1 = vfma_laneq_f32(vacc01c1, va01, vb, 1);
+ vacc01c2 = vfma_laneq_f32(vacc01c2, va01, vb, 2);
+ vacc01c3 = vfma_laneq_f32(vacc01c3, va01, vb, 3);
+ } while (--nnz != 0);
+ }
+ float32x2_t vout01c0 = vmin_f32(vacc01c0, vget_low_f32(vmax));
+ float32x2_t vout01c1 = vmin_f32(vacc01c1, vget_low_f32(vmax));
+ float32x2_t vout01c2 = vmin_f32(vacc01c2, vget_low_f32(vmax));
+ float32x2_t vout01c3 = vmin_f32(vacc01c3, vget_low_f32(vmax));
+
+ vout01c0 = vmax_f32(vout01c0, vget_low_f32(vmin));
+ vout01c1 = vmax_f32(vout01c1, vget_low_f32(vmin));
+ vout01c2 = vmax_f32(vout01c2, vget_low_f32(vmin));
+ vout01c3 = vmax_f32(vout01c3, vget_low_f32(vmin));
+
+ vst1_f32(c + 0 * m + 0, vout01c0);
+ vst1_f32(c + 1 * m + 0, vout01c1);
+ vst1_f32(c + 2 * m + 0, vout01c2);
+ vst1_f32(c + 3 * m + 0, vout01c3);
+ c += 4 * m;
+ j -= 4;
+ }
+
+ // clean up loop, fall back to nr=1
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float32x2_t va01 = vld1_f32(a);
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float32x2_t vb = vld1_dup_f32(w); w += 1;
+ vacc01 = vfma_f32(vacc01, va01, vb);
+ } while (--nnz != 0);
+ }
+ float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
+ vout01 = vmax_f32(vout01, vget_low_f32(vmin));
+
+ vst1_f32(c, vout01);
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
+ c -= m * n;
+ c += 2;
+ a += 2;
+ }
+ if (i & 1) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ while (j >= 4) {
+ uint32_t nnz = *nnzmap++;
+ float32x2_t vacc0c0 = vld1_dup_f32(w); w += 1;
+ float32x2_t vacc0c1 = vld1_dup_f32(w); w += 1;
+ float32x2_t vacc0c2 = vld1_dup_f32(w); w += 1;
+ float32x2_t vacc0c3 = vld1_dup_f32(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float32x2_t va0 = vld1_dup_f32(a);
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float32x4_t vb = vld1q_f32(w); w += 4;
+
+ vacc0c0 = vfma_laneq_f32(vacc0c0, va0, vb, 0);
+ vacc0c1 = vfma_laneq_f32(vacc0c1, va0, vb, 1);
+ vacc0c2 = vfma_laneq_f32(vacc0c2, va0, vb, 2);
+ vacc0c3 = vfma_laneq_f32(vacc0c3, va0, vb, 3);
+ } while (--nnz != 0);
+ }
+ float32x2_t vout0c0 = vmin_f32(vacc0c0, vget_low_f32(vmax));
+ float32x2_t vout0c1 = vmin_f32(vacc0c1, vget_low_f32(vmax));
+ float32x2_t vout0c2 = vmin_f32(vacc0c2, vget_low_f32(vmax));
+ float32x2_t vout0c3 = vmin_f32(vacc0c3, vget_low_f32(vmax));
+
+ vout0c0 = vmax_f32(vout0c0, vget_low_f32(vmin));
+ vout0c1 = vmax_f32(vout0c1, vget_low_f32(vmin));
+ vout0c2 = vmax_f32(vout0c2, vget_low_f32(vmin));
+ vout0c3 = vmax_f32(vout0c3, vget_low_f32(vmin));
+
+ vst1_lane_f32(c + 0 * m + 0, vout0c0, 0);
+ vst1_lane_f32(c + 1 * m + 0, vout0c1, 0);
+ vst1_lane_f32(c + 2 * m + 0, vout0c2, 0);
+ vst1_lane_f32(c + 3 * m + 0, vout0c3, 0);
+ c += 4 * m;
+ j -= 4;
+ }
+
+ // clean up loop, fall back to nr=1
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float32x2_t va0 = vld1_dup_f32(a);
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float32x2_t vb = vld1_dup_f32(w); w += 1;
+ vacc0 = vfma_f32(vacc0, va0, vb);
+ } while (--nnz != 0);
+ }
+ float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
+ vout0 = vmax_f32(vout0, vget_low_f32(vmin));
+
+ vst1_lane_f32(c, vout0, 1);
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
+ c -= m * n;
+ c += 1;
+ a += 1;
+ }
+ }
+}
diff --git a/src/f32-spmm/neon-blocked.c.in b/src/f32-spmm/neon-blocked.c.in
new file mode 100644
index 0000000..104816d
--- /dev/null
+++ b/src/f32-spmm/neon-blocked.c.in
@@ -0,0 +1,253 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert MR % 4 == 0
+$assert NR in [1, 2, 4]
+$ABC = "0123456789ABCDEFGHIJK"
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+void xnn_f32_spmm_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}(
+ uint32_t m,
+ uint32_t n,
+ const float*restrict a,
+ const float*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ float*restrict c,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ size_t i = m;
+ while XNN_LIKELY(i >= ${MR}) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ while (j >= ${NR}) {
+ uint32_t nnz = *nnzmap++;
+ $for N in range(0, NR, 1):
+ float32x4_t vacc${ABC[0:4]}c${N} = vld1q_dup_f32(w); w += 1;
+ $for M in range(4, MR, 4):
+ float32x4_t vacc${ABC[M:M+4]}c${N} = vacc${ABC[0:4]}c${N};
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float32x4_t va${ABC[0:4]} = vld1q_f32(a);
+ $for M in range(4, MR, 4):
+ const float32x4_t va${ABC[M:M+4]} = vld1q_f32(a + ${M});
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ $if NR == 1:
+ const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+ $elif NR == 2:
+ const float32x2_t vb = vld1_f32(w); w += 2;
+ $elif NR == 4:
+ const float32x4_t vb = vld1q_f32(w); w += 4;
+
+ $if NR == 1:
+ $for M in range(0, MR, 4):
+ vacc${ABC[M:M+4]}c0 = vfmaq_f32(vacc${ABC[M:M+4]}c0, va${ABC[M:M+4]}, vb);
+ $else:
+ $for N in range(NR):
+ $for M in range(0, MR, 4):
+ vacc${ABC[M:M+4]}c${N} = vfmaq_lane${"q" if NR == 4 else ""}_f32(vacc${ABC[M:M+4]}c${N}, va${ABC[M:M+4]}, vb, ${N});
+ } while (--nnz != 0);
+ }
+ $for N in range(0, NR, 1):
+ $for M in range(0, MR, 4):
+ float32x4_t vout${ABC[M:M+4]}c${N} = vminq_f32(vacc${ABC[M:M+4]}c${N}, vmax);
+
+ $for N in range(0, NR, 1):
+ $for M in range(0, MR, 4):
+ vout${ABC[M:M+4]}c${N} = vmaxq_f32(vout${ABC[M:M+4]}c${N}, vmin);
+
+ $for N in range(0, NR, 1):
+ $for M in range(0, MR, 4):
+ vst1q_f32(c + ${N} * m + ${M}, vout${ABC[M:M+4]}c${N});
+ c += ${NR} * m;
+ j -= ${NR};
+ }
+
+ // clean up loop, fall back to nr=1
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ float32x4_t vacc${ABC[0:4]} = vld1q_dup_f32(w); w += 1;
+ $for M in range(4, MR, 4):
+ float32x4_t vacc${ABC[M:M+4]} = vacc${ABC[0:4]};
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float32x4_t va${ABC[0:4]} = vld1q_f32(a);
+ $for M in range(4, MR, 4):
+ const float32x4_t va${ABC[M:M+4]} = vld1q_f32(a + ${M});
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+ $for M in range(0, MR, 4):
+ vacc${ABC[M:M+4]} = vfmaq_f32(vacc${ABC[M:M+4]}, va${ABC[M:M+4]}, vb);
+ } while (--nnz != 0);
+ }
+ $for M in range(0, MR, 4):
+ float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
+
+ $for M in range(0, MR, 4):
+ vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
+
+ $for M in range(0, MR, 4):
+ vst1q_f32(c + ${M}, vout${ABC[M:M+4]});
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
+ c -= m * n;
+ c += ${MR};
+ a += ${MR};
+ i -= ${MR};
+ }
+ if XNN_UNLIKELY(i != 0) {
+ $for LOG2M in reversed(range((MR - 1).bit_length())):
+ $SUBMR = 1 << LOG2M
+ if (i & ${SUBMR}) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ while (j >= ${NR}) {
+ uint32_t nnz = *nnzmap++;
+ $for N in range(0, NR, 1):
+ $if SUBMR < 4:
+ float32x2_t vacc${ABC[0:SUBMR]}c${N} = vld1_dup_f32(w); w += 1;
+ $else:
+ float32x4_t vacc${ABC[0:4]}c${N} = vld1q_dup_f32(w); w += 1;
+ $for M in range(4, SUBMR, 4):
+ float32x4_t vacc${ABC[M:M+4]}c${N} = vacc${ABC[0:4]}c${N};
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ $if SUBMR == 1:
+ const float32x2_t va${ABC[0]} = vld1_dup_f32(a);
+ $elif SUBMR == 2:
+ const float32x2_t va${ABC[0:2]} = vld1_f32(a);
+ $else:
+ const float32x4_t va${ABC[0:4]} = vld1q_f32(a);
+ $for M in range(4, SUBMR, 4):
+ const float32x4_t va${ABC[M:M+4]} = vld1q_f32(a + ${M});
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ $if NR == 1:
+ $if SUBMR < 4:
+ const float32x2_t vb = vld1_dup_f32(w); w += 1;
+ $else:
+ const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+ $elif NR == 2:
+ const float32x2_t vb = vld1_f32(w); w += 2;
+ $elif NR == 4:
+ const float32x4_t vb = vld1q_f32(w); w += 4;
+
+ $if NR == 1:
+ $if SUBMR < 4:
+ vacc${ABC[0:SUBMR]}c0 = vfmaq_f32(vacc${ABC[0:SUBMR]}c0, va${ABC[0:SUBMR]}, vb);
+ $else:
+ $for M in range(0, SUBMR, 4):
+ vacc${ABC[M:M+4]}c0 = vfmaq_f32(vacc${ABC[M:M+4]}c0, va${ABC[M:M+4]}, vb);
+ $else:
+ $for N in range(NR):
+ $if SUBMR < 4:
+ vacc${ABC[0:SUBMR]}c${N} = vfma_lane${"q" if NR == 4 else ""}_f32(vacc${ABC[0:SUBMR]}c${N}, va${ABC[0:SUBMR]}, vb, ${N});
+ $else:
+ $for M in range(0, SUBMR, 4):
+ vacc${ABC[M:M+4]}c${N} = vfmaq_lane${"q" if NR == 4 else ""}_f32(vacc${ABC[M:M+4]}c${N}, va${ABC[M:M+4]}, vb, ${N});
+ } while (--nnz != 0);
+ }
+ $for N in range(0, NR, 1):
+ $if SUBMR < 4:
+ float32x2_t vout${ABC[0:SUBMR]}c${N} = vmin_f32(vacc${ABC[0:SUBMR]}c${N}, vget_low_f32(vmax));
+ $else:
+ $for M in range(0, SUBMR, 4):
+ float32x4_t vout${ABC[M:M+4]}c${N} = vminq_f32(vacc${ABC[M:M+4]}c${N}, vmax);
+
+ $for N in range(0, NR, 1):
+ $if SUBMR < 4:
+ vout${ABC[0:SUBMR]}c${N} = vmax_f32(vout${ABC[0:SUBMR]}c${N}, vget_low_f32(vmin));
+ $else:
+ $for M in range(0, SUBMR, 4):
+ vout${ABC[M:M+4]}c${N} = vmaxq_f32(vout${ABC[M:M+4]}c${N}, vmin);
+
+ $for N in range(0, NR, 1):
+ $if SUBMR == 1:
+ vst1_lane_f32(c + ${N} * m + ${M}, vout${ABC[0:SUBMR]}c${N}, 0);
+ $elif SUBMR == 2:
+ vst1_f32(c + ${N} * m + ${M}, vout${ABC[0:SUBMR]}c${N});
+ $else:
+ $for M in range(0, SUBMR, 4):
+ vst1q_f32(c + ${N} * m + ${M}, vout${ABC[M:M+4]}c${N});
+ c += ${NR} * m;
+ j -= ${NR};
+ }
+
+ // clean up loop, fall back to nr=1
+ if XNN_UNLIKELY(j != 0) {
+ do {
+ uint32_t nnz = *nnzmap++;
+ $if SUBMR < 4:
+ float32x2_t vacc${ABC[0:SUBMR]} = vld1_dup_f32(w); w += 1;
+ $else:
+ float32x4_t vacc${ABC[0:4]} = vld1q_dup_f32(w); w += 1;
+ $for M in range(4, SUBMR, 4):
+ float32x4_t vacc${ABC[M:M+4]} = vacc${ABC[0:4]};
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ $if SUBMR == 1:
+ const float32x2_t va${ABC[0:1]} = vld1_dup_f32(a);
+ $elif SUBMR == 2:
+ const float32x2_t va${ABC[0:2]} = vld1_f32(a);
+ $else:
+ const float32x4_t va${ABC[0:4]} = vld1q_f32(a);
+ $for M in range(4, SUBMR, 4):
+ const float32x4_t va${ABC[M:M+4]} = vld1q_f32(a + ${M});
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ $if SUBMR < 4:
+ const float32x2_t vb = vld1_dup_f32(w); w += 1;
+ vacc${ABC[0:SUBMR]} = vfma_f32(vacc${ABC[0:SUBMR]}, va${ABC[0:SUBMR]}, vb);
+ $else:
+ const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+ $for M in range(0, SUBMR, 4):
+ vacc${ABC[M:M+4]} = vfmaq_f32(vacc${ABC[M:M+4]}, va${ABC[M:M+4]}, vb);
+ } while (--nnz != 0);
+ }
+ $if SUBMR < 4:
+ float32x2_t vout${ABC[0:SUBMR]} = vmin_f32(vacc${ABC[0:SUBMR]}, vget_low_f32(vmax));
+ vout${ABC[0:SUBMR]} = vmax_f32(vout${ABC[0:SUBMR]}, vget_low_f32(vmin));
+ $else:
+ $for M in range(0, SUBMR, 4):
+ float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
+
+ $for M in range(0, SUBMR, 4):
+ vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
+
+ $if SUBMR == 1:
+ vst1_lane_f32(c, vout${ABC[0:1]}, 1);
+ $elif SUBMR == 2:
+ vst1_f32(c, vout${ABC[0:2]});
+ $else:
+ $for M in range(0, SUBMR, 4):
+ vst1q_f32(c + ${M}, vout${ABC[M:M+4]});
+ c += m;
+ j -= 1;
+ } while (j != 0);
+ }
+ c -= m * n;
+ c += ${SUBMR};
+ a += ${SUBMR};
+ }
+ }
+}
diff --git a/src/f32-spmm/neon-pipelined.c.in b/src/f32-spmm/neon-pipelined.c.in
new file mode 100644
index 0000000..99a5a9e
--- /dev/null
+++ b/src/f32-spmm/neon-pipelined.c.in
@@ -0,0 +1,133 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert MR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJK"
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+void xnn_f32_spmm_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}_pipelined(
+ uint32_t m,
+ uint32_t n,
+ const float*restrict a,
+ const float*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ float*restrict c,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ size_t i = m;
+ while XNN_LIKELY(i >= ${MR}) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ float32x4_t vw = vld1q_dup_f32(w); w += 1;
+ intptr_t diff = *dmap++;
+ float32x4_t va0123 = vld1q_f32(a);
+ $for M in range(4, MR, 4):
+ float32x4_t va${ABC[M:M+4]} = vld1q_f32(a + ${M});
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ $for M in range(0, MR, 4):
+ float32x4_t vacc${ABC[M:M+4]} = vw;
+ vw = vld1q_dup_f32(w); w += 1;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ $for M in range(0, MR, 4):
+ vacc${ABC[M:M+4]} = vfmaq_f32(vacc${ABC[M:M+4]}, va${ABC[M:M+4]}, vw);
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+
+ diff = *dmap++;
+ vw = vld1q_dup_f32(w); w += 1;
+ va0123 = vld1q_f32(a);
+ $for M in range(4, MR, 4):
+ va${ABC[M:M+4]} = vld1q_f32(a + ${M});
+ } while (--nnz != 0);
+ }
+ $for M in range(0, MR, 4):
+ float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
+ $for M in range(0, MR, 4):
+ vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
+ vst1q_f32(c, vout0123);
+ $for M in range(4, MR, 4):
+ vst1q_f32(c + ${M}, vout${ABC[M:M+4]});
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += ${MR};
+ a += ${MR};
+ i -= ${MR};
+ }
+ if XNN_UNLIKELY(i != 0) {
+ $for LOG2M in reversed(range((MR - 1).bit_length())):
+ $SUBMR = 1 << LOG2M
+ if (i & ${SUBMR}) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ $if SUBMR <= 2:
+ float32x2_t vacc${ABC[0:SUBMR]} = vld1_dup_f32(w); w += 1;
+ $else:
+ float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+ $for M in range(4, SUBMR, 4):
+ float32x4_t vacc${ABC[M:M+4]} = vacc0123;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ $if SUBMR == 1:
+ const float32x2_t va0 = vld1_dup_f32(a);
+ $elif SUBMR == 2:
+ const float32x2_t va01 = vld1_f32(a);
+ $else:
+ const float32x4_t va0123 = vld1q_f32(a);
+ $for M in range(4, SUBMR, 4):
+ const float32x4_t va${ABC[M:M+4]} = vld1q_f32(a + ${M});
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ $if SUBMR <= 2:
+ const float32x2_t vb = vld1_dup_f32(w); w += 1;
+ $else:
+ const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+ $if SUBMR <= 2:
+ vacc${ABC[0:SUBMR]} = vfma_f32(vacc${ABC[0:SUBMR]}, va${ABC[0:SUBMR]}, vb);
+ $else:
+ $for M in range(0, SUBMR, 4):
+ vacc${ABC[M:M+4]} = vfmaq_f32(vacc${ABC[M:M+4]}, va${ABC[M:M+4]}, vb);
+ } while (--nnz != 0);
+ }
+ $if SUBMR <= 2:
+ float32x2_t vout${ABC[0:SUBMR]} = vmin_f32(vacc${ABC[0:SUBMR]}, vget_low_f32(vmax));
+ vout${ABC[0:SUBMR]} = vmax_f32(vout${ABC[0:SUBMR]}, vget_low_f32(vmin));
+ $if SUBMR == 1:
+ vst1_lane_f32(c, vout${ABC[0]}, 0);
+ $else:
+ vst1_f32(c, vout${ABC[0:SUBMR]});
+ $else:
+ $for M in range(0, SUBMR, 4):
+ float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
+ $for M in range(0, SUBMR, 4):
+ vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
+ vst1q_f32(c, vout0123);
+ $for M in range(4, SUBMR, 4):
+ vst1q_f32(c + ${M}, vout${ABC[M:M+4]});
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += ${SUBMR};
+ a += ${SUBMR};
+ }
+ }
+}
diff --git a/src/f32-spmm/neon.c.in b/src/f32-spmm/neon.c.in
new file mode 100644
index 0000000..aabc871
--- /dev/null
+++ b/src/f32-spmm/neon.c.in
@@ -0,0 +1,154 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert MR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJK"
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/spmm.h>
+
+
+void xnn_f32_spmm_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}${"_unroll" + str(UNROLL) if UNROLL > 1 else ""}(
+ uint32_t m,
+ uint32_t n,
+ const float*restrict a,
+ const float*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ float*restrict c,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ size_t i = m;
+ while XNN_LIKELY(i >= ${MR}) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ $if UNROLL > 1:
+ float32x4_t vacc0123x0 = vld1q_dup_f32(w); w += 1;
+ $for K in range(1, UNROLL):
+ float32x4_t vacc0123x${K} = vmovq_n_f32(0.0f);
+ $for M in range(4, MR, 4):
+ float32x4_t vacc${ABC[M:M+4]}x0 = vacc0123x0;
+ $for K in range(1, UNROLL):
+ float32x4_t vacc${ABC[M:M+4]}x${K} = vmovq_n_f32(0.0f);
+ for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
+ $for K in range(UNROLL):
+ const intptr_t diff${K} = dmap[${K}];
+ dmap += ${UNROLL};
+ $for K in range(UNROLL):
+ const float32x4_t va0123x${K} = vld1q_f32(a);
+ $for M in range(4, MR, 4):
+ const float32x4_t va${ABC[M:M+4]}x${K} = vld1q_f32(a + ${M});
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff${K});
+ const float32x4_t vb${K} = vld1q_dup_f32(w); w += 1;
+ $for M in range(0, MR, 4):
+ vacc${ABC[M:M+4]}x${K} = vfmaq_f32(vacc${ABC[M:M+4]}x${K}, va${ABC[M:M+4]}x${K}, vb${K});
+ }
+ $for M in range(0, MR, 4):
+ float32x4_t vacc${ABC[M:M+4]} = vacc${ABC[M:M+4]}x0;
+ $for K in range(1, UNROLL):
+ $for M in range(0, MR, 4):
+ vacc${ABC[M:M+4]} = vaddq_f32(vacc${ABC[M:M+4]}, vacc${ABC[M:M+4]}x${K});
+ $else:
+ float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+ $for M in range(4, MR, 4):
+ float32x4_t vacc${ABC[M:M+4]} = vacc0123;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const float32x4_t va0123 = vld1q_f32(a);
+ $for M in range(4, MR, 4):
+ const float32x4_t va${ABC[M:M+4]} = vld1q_f32(a + ${M});
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+ $for M in range(0, MR, 4):
+ vacc${ABC[M:M+4]} = vfmaq_f32(vacc${ABC[M:M+4]}, va${ABC[M:M+4]}, vb);
+ } while (--nnz != 0);
+ }
+ $for M in range(0, MR, 4):
+ float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
+ $for M in range(0, MR, 4):
+ vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
+ vst1q_f32(c, vout0123);
+ $for M in range(4, MR, 4):
+ vst1q_f32(c + ${M}, vout${ABC[M:M+4]});
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += ${MR};
+ a += ${MR};
+ i -= ${MR};
+ }
+ if XNN_UNLIKELY(i != 0) {
+ $for LOG2M in reversed(range((MR - 1).bit_length())):
+ $SUBMR = 1 << LOG2M
+ if (i & ${SUBMR}) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ $if SUBMR <= 2:
+ float32x2_t vacc${ABC[0:SUBMR]} = vld1_dup_f32(w); w += 1;
+ $else:
+ float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
+ $for M in range(4, SUBMR, 4):
+ float32x4_t vacc${ABC[M:M+4]} = vacc0123;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ $if SUBMR == 1:
+ const float32x2_t va0 = vld1_dup_f32(a);
+ $elif SUBMR == 2:
+ const float32x2_t va01 = vld1_f32(a);
+ $else:
+ const float32x4_t va0123 = vld1q_f32(a);
+ $for M in range(4, SUBMR, 4):
+ const float32x4_t va${ABC[M:M+4]} = vld1q_f32(a + ${M});
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ $if SUBMR <= 2:
+ const float32x2_t vb = vld1_dup_f32(w); w += 1;
+ $else:
+ const float32x4_t vb = vld1q_dup_f32(w); w += 1;
+ $if SUBMR <= 2:
+ vacc${ABC[0:SUBMR]} = vfma_f32(vacc${ABC[0:SUBMR]}, va${ABC[0:SUBMR]}, vb);
+ $else:
+ $for M in range(0, SUBMR, 4):
+ vacc${ABC[M:M+4]} = vfmaq_f32(vacc${ABC[M:M+4]}, va${ABC[M:M+4]}, vb);
+ } while (--nnz != 0);
+ }
+ $if SUBMR <= 2:
+ float32x2_t vout${ABC[0:SUBMR]} = vmin_f32(vacc${ABC[0:SUBMR]}, vget_low_f32(vmax));
+ vout${ABC[0:SUBMR]} = vmax_f32(vout${ABC[0:SUBMR]}, vget_low_f32(vmin));
+ $if SUBMR == 1:
+ vst1_lane_f32(c, vout${ABC[0]}, 0);
+ $else:
+ vst1_f32(c, vout${ABC[0:SUBMR]});
+ $else:
+ $for M in range(0, SUBMR, 4):
+ float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
+ $for M in range(0, SUBMR, 4):
+ vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
+ vst1q_f32(c, vout0123);
+ $for M in range(4, SUBMR, 4):
+ vst1q_f32(c + ${M}, vout${ABC[M:M+4]});
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += ${SUBMR};
+ a += ${SUBMR};
+ }
+ }
+}
diff --git a/src/f32-spmm/scalar-pipelined.c.in b/src/f32-spmm/scalar-pipelined.c.in
new file mode 100644
index 0000000..290227d
--- /dev/null
+++ b/src/f32-spmm/scalar-pipelined.c.in
@@ -0,0 +1,109 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$ABC = "0123456789ABCDEFGHIJK"
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
+void xnn_f32_spmm_ukernel_${MR}x${NR}__scalar_pipelined(
+ uint32_t m,
+ uint32_t n,
+ const float*restrict a,
+ const float*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ float*restrict c,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const float vmin = params->scalar.min;
+ const float vmax = params->scalar.max;
+ size_t i = m;
+ while XNN_LIKELY(i >= ${MR}) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ float vw = *w++;
+ intptr_t diff = *dmap++;
+ $for M in range(MR):
+ float va${ABC[M]} = a[${M}];
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ $for M in range(MR):
+ float vacc${ABC[M]} = vw;
+ vw = *w++;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ $for M in range(MR):
+ vacc${ABC[M]} += va${ABC[M]} * vw;
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+
+ diff = *dmap++;
+ vw = *w++;
+ $for M in range(MR):
+ va${ABC[M]} = a[${M}];
+ } while (--nnz != 0);
+ }
+ $for M in range(MR):
+ float vout${ABC[M]} = math_min_f32(vacc${ABC[M]}, vmax);
+ $for M in range(MR):
+ vout${ABC[M]} = math_max_f32(vout${ABC[M]}, vmin);
+ $for M in range(MR):
+ c[${M}] = vout${ABC[M]};
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += ${MR};
+ a += ${MR};
+ i -= ${MR};
+ }
+ if XNN_UNLIKELY(i != 0) {
+ $for LOG2M in reversed(range((MR - 1).bit_length())):
+ $SUBMR = 1 << LOG2M
+ if (i & ${SUBMR}) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ float vw = *w++;
+ intptr_t diff = *dmap++;
+ $for M in range(SUBMR):
+ float va${ABC[M]} = a[${M}];
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ $for M in range(SUBMR):
+ float vacc${ABC[M]} = vw;
+ vw = *w++;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ $for M in range(SUBMR):
+ vacc${ABC[M]} += va${ABC[M]} * vw;
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+
+ diff = *dmap++;
+ vw = *w++;
+ $for M in range(SUBMR):
+ va${ABC[M]} = a[${M}];
+ } while (--nnz != 0);
+ }
+ $for M in range(SUBMR):
+ float vout${ABC[M]} = math_min_f32(vacc${ABC[M]}, vmax);
+ $for M in range(SUBMR):
+ vout${ABC[M]} = math_max_f32(vout${ABC[M]}, vmin);
+ $for M in range(SUBMR):
+ c[${M}] = vout${ABC[M]};
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += ${SUBMR};
+ a += ${SUBMR};
+ }
+ }
+}
diff --git a/src/f32-spmm/scalar.c.in b/src/f32-spmm/scalar.c.in
new file mode 100644
index 0000000..16aa159
--- /dev/null
+++ b/src/f32-spmm/scalar.c.in
@@ -0,0 +1,151 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$ABC = "0123456789ABCDEFGHIJK"
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/spmm.h>
+
+
+// Sparse-matrix times dense-matrix (SpMM) micro-kernel, scalar template.
+// NOTE(review): this is an xngen Python template (.c.in): lines starting with
+// '$' and ${...} substitutions are expanded by tools/xngen with template
+// parameters MR (batch tile), NR, and UNROLL (accumulator unrolling).
+//
+// Sparse weight encoding (established by the pointer walks below):
+//   weights     - per output channel: one initial value (bias) read before the
+//                 non-zero loop, followed by the channel's non-zero weights
+//   widx_dmap   - byte deltas used to advance the dense input pointer 'a'
+//                 between consecutive non-zeros
+//   nidx_nnzmap - per output channel: count of non-zero weights
+// Output 'c' is written in columns of stride m; each result is clamped to
+// [params->scalar.min, params->scalar.max] (min applied after max).
+void xnn_f32_spmm_ukernel_${MR}x${NR}__scalar${"_unroll" + str(UNROLL) if UNROLL > 1 else ""}(
+ uint32_t m,
+ uint32_t n,
+ const float*restrict a,
+ const float*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ float*restrict c,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const float vmin = params->scalar.min;
+ const float vmax = params->scalar.max;
+ size_t i = m;
+ // Main loop: compute ${MR} batch elements for all n output channels per pass.
+ while (i >= ${MR}) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ $if UNROLL > 1:
+ float vacc0x0 = *w++;
+ $for K in range(1, UNROLL):
+ float vacc0x${K} = 0.0f;
+ $for M in range(1, MR):
+ float vacc${ABC[M]}x0 = vacc0x0;
+ $for K in range(1, UNROLL):
+ float vacc${ABC[M]}x${K} = 0.0f;
+ for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
+ $for K in range(UNROLL):
+ const intptr_t diff${K} = dmap[${K}];
+ dmap += ${UNROLL};
+ $for K in range(UNROLL):
+ $for M in range(MR):
+ const float va${ABC[M]}x${K} = a[${M}];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff${K});
+ const float vb${K} = *w++;
+ $for M in range(0, MR):
+ vacc${ABC[M]}x${K} += va${ABC[M]}x${K} * vb${K};
+ }
+ $for M in range(MR):
+ float vacc${ABC[M]} = vacc${ABC[M]}x0;
+ $for K in range(1, UNROLL):
+ $for M in range(MR):
+ vacc${ABC[M]} += vacc${ABC[M]}x${K};
+ $else:
+ float vacc0 = *w++;
+ $for M in range(1, MR):
+ float vacc${ABC[M]} = vacc0;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ $for M in range(MR):
+ const float va${ABC[M]} = a[${M}];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ $for M in range(MR):
+ vacc${ABC[M]} += va${ABC[M]} * vb;
+ } while (--nnz != 0);
+ }
+ $for M in range(MR):
+ float vout${ABC[M]} = math_min_f32(vacc${ABC[M]}, vmax);
+ $for M in range(MR):
+ vout${ABC[M]} = math_max_f32(vout${ABC[M]}, vmin);
+ $for M in range(MR):
+ c[${M}] = vout${ABC[M]};
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += ${MR};
+ a += ${MR};
+ i -= ${MR};
+ }
+ // Remainder: handle the final i < ${MR} batch elements using power-of-two
+ // sub-tiles, largest first, so every residual size is covered exactly once.
+ if XNN_UNLIKELY(i != 0) {
+ $for LOG2M in reversed(range((MR - 1).bit_length())):
+ $SUBMR = 1 << LOG2M
+ if (i & ${SUBMR}) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ $if UNROLL > 1:
+ float vacc0x0 = *w++;
+ $for K in range(1, UNROLL):
+ float vacc0x${K} = 0.0f;
+ $for M in range(1, SUBMR):
+ float vacc${ABC[M]}x0 = vacc0x0;
+ $for K in range(1, UNROLL):
+ float vacc${ABC[M]}x${K} = 0.0f;
+ for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
+ $for K in range(UNROLL):
+ const intptr_t diff${K} = dmap[${K}];
+ dmap += ${UNROLL};
+ $for K in range(UNROLL):
+ $for M in range(SUBMR):
+ const float va${ABC[M]}x${K} = a[${M}];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff${K});
+ const float vb${K} = *w++;
+ $for M in range(0, SUBMR):
+ vacc${ABC[M]}x${K} += va${ABC[M]}x${K} * vb${K};
+ }
+ $for M in range(SUBMR):
+ float vacc${ABC[M]} = vacc${ABC[M]}x0;
+ $for K in range(1, UNROLL):
+ $for M in range(SUBMR):
+ vacc${ABC[M]} += vacc${ABC[M]}x${K};
+ $else:
+ float vacc0 = *w++;
+ $for M in range(1, SUBMR):
+ float vacc${ABC[M]} = vacc0;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ $for M in range(SUBMR):
+ const float va${ABC[M]} = a[${M}];
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const float vb = *w++;
+ $for M in range(SUBMR):
+ vacc${ABC[M]} += va${ABC[M]} * vb;
+ } while (--nnz != 0);
+ }
+ $for M in range(SUBMR):
+ float vout${ABC[M]} = math_min_f32(vacc${ABC[M]}, vmax);
+ $for M in range(SUBMR):
+ vout${ABC[M]} = math_max_f32(vout${ABC[M]}, vmin);
+ $for M in range(SUBMR):
+ c[${M}] = vout${ABC[M]};
+ c += m;
+ } while (--j != 0);
+ c -= m * n;
+ c += ${SUBMR};
+ a += ${SUBMR};
+ }
+ }
+}
diff --git a/src/f32-spmm/sse.c.in b/src/f32-spmm/sse.c.in
new file mode 100644
index 0000000..29689bb
--- /dev/null
+++ b/src/f32-spmm/sse.c.in
@@ -0,0 +1,163 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+$assert MR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJK"
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/spmm.h>
+
+
+// Sparse-matrix times dense-matrix (SpMM) micro-kernel, SSE template.
+// NOTE(review): this is an xngen Python template (.c.in): '$'-lines and
+// ${...} substitutions are expanded by tools/xngen (MR must be a multiple of
+// 4 per the $assert above; UNROLL controls accumulator unrolling).
+// Same sparse encoding as the scalar template: 'weights' holds, per output
+// channel, an initial value followed by the non-zero weights; 'widx_dmap'
+// holds byte deltas for the dense input pointer; 'nidx_nnzmap' holds the
+// per-channel non-zero counts. Results are clamped to [params->sse.min,
+// params->sse.max] via _mm_min_ps/_mm_max_ps before the store.
+void xnn_f32_spmm_ukernel_${MR}x${NR}__sse${"_unroll" + str(UNROLL) if UNROLL > 1 else ""}(
+ uint32_t m,
+ uint32_t n,
+ const float*restrict a,
+ const float*restrict weights,
+ const int32_t*restrict widx_dmap,
+ const uint32_t*restrict nidx_nnzmap,
+ float*restrict c,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+
+ const __m128 vmin = _mm_load_ps(params->sse.min);
+ const __m128 vmax = _mm_load_ps(params->sse.max);
+ size_t i = m;
+ // Main loop: ${MR} batch elements (in groups of 4 lanes) per pass over all n channels.
+ while XNN_LIKELY(i >= ${MR}) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ $if UNROLL > 1:
+ __m128 vacc0123x0 = _mm_load1_ps(w);
+ w += 1;
+ $for K in range(1, UNROLL):
+ __m128 vacc0123x${K} = _mm_setzero_ps();
+ $for M in range(4, MR, 4):
+ __m128 vacc${ABC[M:M+4]}x0 = vacc0123x0;
+ $for K in range(1, UNROLL):
+ __m128 vacc${ABC[M:M+4]}x${K} = _mm_setzero_ps();
+ for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
+ $for K in range(UNROLL):
+ const intptr_t diff${K} = dmap[${K}];
+ dmap += ${UNROLL};
+ $for K in range(UNROLL):
+ const __m128 va0123x${K} = _mm_loadu_ps(a);
+ $for M in range(4, MR, 4):
+ const __m128 va${ABC[M:M+4]}x${K} = _mm_loadu_ps(a + ${M});
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff${K});
+ const __m128 vb${K} = _mm_load1_ps(w);
+ w += 1;
+ $for M in range(0, MR, 4):
+ vacc${ABC[M:M+4]}x${K} = _mm_add_ps(vacc${ABC[M:M+4]}x${K}, _mm_mul_ps(va${ABC[M:M+4]}x${K}, vb${K}));
+ }
+ $for M in range(0, MR, 4):
+ __m128 vacc${ABC[M:M+4]} = vacc${ABC[M:M+4]}x0;
+ $for K in range(1, UNROLL):
+ $for M in range(0, MR, 4):
+ vacc${ABC[M:M+4]} = _mm_add_ps(vacc${ABC[M:M+4]}, vacc${ABC[M:M+4]}x${K});
+ $else:
+ __m128 vacc0123 = _mm_load1_ps(w); w += 1;
+ $for M in range(4, MR, 4):
+ __m128 vacc${ABC[M:M+4]} = vacc0123;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ const __m128 va0123 = _mm_loadu_ps(a);
+ $for M in range(4, MR, 4):
+ const __m128 va${ABC[M:M+4]} = _mm_loadu_ps(a + ${M});
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ const __m128 vb = _mm_load1_ps(w); w += 1;
+ $for M in range(0, MR, 4):
+ vacc${ABC[M:M+4]} = _mm_add_ps(vacc${ABC[M:M+4]}, _mm_mul_ps(va${ABC[M:M+4]}, vb));
+ } while (--nnz != 0);
+ }
+ $for M in range(0, MR, 4):
+ __m128 vout${ABC[M:M+4]} = _mm_min_ps(vacc${ABC[M:M+4]}, vmax);
+ $for M in range(0, MR, 4):
+ vout${ABC[M:M+4]} = _mm_max_ps(vout${ABC[M:M+4]}, vmin);
+ _mm_storeu_ps(c, vout0123);
+ $for M in range(4, MR, 4):
+ _mm_storeu_ps(c + ${M}, vout${ABC[M:M+4]});
+ c += ${NR} * m;
+ } while (--j != 0);
+ c -= m * n;
+ c += ${MR};
+ a += ${MR};
+ i -= ${MR};
+ }
+ // Remainder: residual i < ${MR} batch elements in power-of-two sub-tiles;
+ // SUBMR == 1 and 2 use scalar/64-bit loads and stores instead of full vectors.
+ if XNN_UNLIKELY(i != 0) {
+ $for LOG2M in reversed(range((MR - 1).bit_length())):
+ $SUBMR = 1 << LOG2M
+ if (i & ${SUBMR}) {
+ const float*restrict w = weights;
+ const int32_t* dmap = widx_dmap;
+ const uint32_t* nnzmap = nidx_nnzmap;
+ size_t j = n;
+ do {
+ uint32_t nnz = *nnzmap++;
+ $if SUBMR == 1:
+ __m128 vacc0 = _mm_load_ss(w); w += 1;
+ $elif SUBMR == 2:
+ __m128 vacc01 = _mm_load_ss(w); w += 1;
+ vacc01 = _mm_unpacklo_ps(vacc01, vacc01);
+ $else:
+ __m128 vacc0123 = _mm_load1_ps(w); w += 1;
+ $for M in range(4, SUBMR, 4):
+ __m128 vacc${ABC[M:M+4]} = vacc0123;
+ if XNN_LIKELY(nnz != 0) {
+ do {
+ const intptr_t diff = *dmap++;
+ $if SUBMR >= 4:
+ const __m128 va0123 = _mm_loadu_ps(a);
+ $elif SUBMR == 2:
+ const __m128 va01 = _mm_loadl_pi(_mm_undefined_ps(), (const __m64*) a);
+ $elif SUBMR == 1:
+ const __m128 va0 = _mm_load_ss(a);
+ $for M in range(4, SUBMR, 4):
+ const __m128 va${ABC[M:M+4]} = _mm_loadu_ps(a + ${M});
+ a = (const float*restrict) ((uintptr_t) a + (uintptr_t) diff);
+ $if SUBMR >= 4:
+ const __m128 vb = _mm_load1_ps(w); w += 1;
+ $elif SUBMR == 2:
+ __m128 vb = _mm_load_ss(w); w += 1;
+ vb = _mm_unpacklo_ps(vb, vb);
+ $else:
+ const __m128 vb = _mm_load_ss(w); w += 1;
+ $if SUBMR == 1:
+ vacc${ABC[0]} = _mm_add_ss(vacc${ABC[0]}, _mm_mul_ss(va${ABC[0]}, vb));
+ $else:
+ $for M in range(0, SUBMR, 4):
+ vacc${ABC[M:min(M+4,SUBMR)]} = _mm_add_ps(vacc${ABC[M:min(M+4,SUBMR)]}, _mm_mul_ps(va${ABC[M:min(M+4,SUBMR)]}, vb));
+ } while (--nnz != 0);
+ }
+ $if SUBMR == 1:
+ __m128 vout${ABC[0]} = _mm_min_ss(vacc${ABC[0]}, vmax);
+ vout${ABC[0]} = _mm_max_ss(vout${ABC[0]}, vmin);
+ $else:
+ $for M in range(0, SUBMR, 4):
+ __m128 vout${ABC[M:min(M+4,SUBMR)]} = _mm_min_ps(vacc${ABC[M:min(M+4,SUBMR)]}, vmax);
+ $for M in range(0, SUBMR, 4):
+ vout${ABC[M:min(M+4,SUBMR)]} = _mm_max_ps(vout${ABC[M:min(M+4,SUBMR)]}, vmin);
+ $if SUBMR >= 4:
+ _mm_storeu_ps(c, vout0123);
+ $elif SUBMR == 2:
+ _mm_storel_pi((__m64*) c, vout01);
+ $elif SUBMR == 1:
+ _mm_store_ss(c, vout0);
+ $for M in range(4, SUBMR, 4):
+ _mm_storeu_ps(c + ${M}, vout${ABC[M:M+4]});
+ c += ${NR} * m;
+ } while (--j != 0);
+ c -= m * n;
+ c += ${SUBMR};
+ a += ${SUBMR};
+ }
+ }
+}
diff --git a/src/f32-vadd/psimd.c b/src/f32-vadd/psimd.c
new file mode 100644
index 0000000..2e06ff5
--- /dev/null
+++ b/src/f32-vadd/psimd.c
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/vadd.h>
+
+
+// Elementwise vector addition with output clamping, portable-SIMD (psimd)
+// version: y[i] = clamp(a[i] + b[i], params->scalar.min, params->scalar.max).
+//
+// n is the byte size of the arrays (must be a non-zero multiple of
+// sizeof(float)). Main loop processes 8 floats per iteration; a 4-float step
+// and a masked 1-3 float tail handle the remainder.
+//
+// Fix: the original text contained mojibake -- '&para' of '&params' had been
+// collapsed into the 'PILCROW SIGN' character, producing the non-compiling
+// token 'params' preceded by a pilcrow; restored to '&params->scalar.min/max'.
+void xnn_f32_vadd_ukernel__psimd(
+    size_t n,
+    const float* a,
+    const float* b,
+    float* y,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  const psimd_f32 vy_min = psimd_load_splat_f32(&params->scalar.min);
+  const psimd_f32 vy_max = psimd_load_splat_f32(&params->scalar.max);
+
+  for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
+    const psimd_f32 va0 = psimd_load_f32(a);
+    const psimd_f32 va1 = psimd_load_f32(a + 4);
+    a += 8;
+
+    const psimd_f32 vb0 = psimd_load_f32(b);
+    const psimd_f32 vb1 = psimd_load_f32(b + 4);
+    b += 8;
+
+    const psimd_f32 vacc0 = psimd_add_f32(va0, vb0);
+    const psimd_f32 vacc1 = psimd_add_f32(va1, vb1);
+    const psimd_f32 vy0 = psimd_min_f32(psimd_max_f32(vacc0, vy_min), vy_max);
+    const psimd_f32 vy1 = psimd_min_f32(psimd_max_f32(vacc1, vy_min), vy_max);
+
+    psimd_store_f32(y, vy0);
+    psimd_store_f32(y + 4, vy1);
+    y += 8;
+  }
+  if (n >= 4 * sizeof(float)) {
+    const psimd_f32 va = psimd_load_f32(a);
+    a += 4;
+    const psimd_f32 vb = psimd_load_f32(b);
+    b += 4;
+    const psimd_f32 vacc = psimd_add_f32(va, vb);
+    const psimd_f32 vy = psimd_min_f32(psimd_max_f32(vacc, vy_min), vy_max);
+    psimd_store_f32(y, vy);
+    y += 4;
+    n -= 4 * sizeof(float);
+  }
+  if (n != 0) {
+    // Tail of 1-3 floats: compute a full vector, then store the low lanes
+    // piecewise (2 lanes, then 1 lane after shifting the high half down).
+    // NOTE(review): the tail loads read a full 4-float vector; presumably the
+    // operator-level buffers guarantee this over-read is in bounds -- confirm.
+    const psimd_f32 va = psimd_load_f32(a);
+    const psimd_f32 vb = psimd_load_f32(b);
+    const psimd_f32 vacc = psimd_add_f32(va, vb);
+    psimd_f32 vy = psimd_min_f32(psimd_max_f32(vacc, vy_min), vy_max);
+    if (n & 2 * sizeof(float)) {
+      psimd_store2_f32(y, vy);
+      vy = psimd_concat_hi_f32(vy, vy);
+      y += 2;
+    }
+    if (n & 1 * sizeof(float)) {
+      psimd_store1_f32(y, vy);
+    }
+  }
+}
diff --git a/src/f32-vadd/scalar.c b/src/f32-vadd/scalar.c
new file mode 100644
index 0000000..075fcb9
--- /dev/null
+++ b/src/f32-vadd/scalar.c
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/math.h>
+#include <xnnpack/vadd.h>
+
+
+// Elementwise vector addition with output clamping, scalar version:
+// y[i] = min(max(a[i] + b[i], params->scalar.min), params->scalar.max).
+// n is the byte size of the arrays and must be a non-zero multiple of
+// sizeof(float); two elements are handled per iteration, plus at most one
+// leftover element.
+void xnn_f32_vadd_ukernel__scalar(
+    size_t n,
+    const float* a,
+    const float* b,
+    float* y,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  const float voutput_min = params->scalar.min;
+  const float voutput_max = params->scalar.max;
+
+  while (n >= 2 * sizeof(float)) {
+    // Clamp order matters for NaN propagation: max with min bound first,
+    // then min with max bound, exactly as in the vector kernels.
+    const float vacc0 = math_min_f32(math_max_f32(a[0] + b[0], voutput_min), voutput_max);
+    const float vacc1 = math_min_f32(math_max_f32(a[1] + b[1], voutput_min), voutput_max);
+    a += 2;
+    b += 2;
+
+    y[0] = vacc0;
+    y[1] = vacc1;
+    y += 2;
+
+    n -= 2 * sizeof(float);
+  }
+  if XNN_UNLIKELY(n != 0) {
+    // Exactly one float remains (n is a multiple of sizeof(float)).
+    *y = math_min_f32(math_max_f32(*a + *b, voutput_min), voutput_max);
+  }
+}
diff --git a/src/f32-vadd/sse.c b/src/f32-vadd/sse.c
new file mode 100644
index 0000000..2f49638
--- /dev/null
+++ b/src/f32-vadd/sse.c
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/vadd.h>
+
+
+// Elementwise vector addition with output clamping, SSE version:
+// y[i] = min(max(a[i] + b[i], params->sse.min), params->sse.max).
+// n is the byte size of the arrays (non-zero multiple of sizeof(float)).
+// Main loop handles 8 floats per iteration; then a 4-float step and a
+// 1-3 float tail.
+void xnn_f32_vadd_ukernel__sse(
+    size_t n,
+    const float* a,
+    const float* b,
+    float* y,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  const __m128 voutput_min = _mm_load_ps(params->sse.min);
+  const __m128 voutput_max = _mm_load_ps(params->sse.max);
+
+  while (n >= 8 * sizeof(float)) {
+    __m128 vacc_lo = _mm_add_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
+    __m128 vacc_hi = _mm_add_ps(_mm_loadu_ps(a + 4), _mm_loadu_ps(b + 4));
+    a += 8;
+    b += 8;
+
+    vacc_lo = _mm_max_ps(vacc_lo, voutput_min);
+    vacc_hi = _mm_max_ps(vacc_hi, voutput_min);
+    vacc_lo = _mm_min_ps(vacc_lo, voutput_max);
+    vacc_hi = _mm_min_ps(vacc_hi, voutput_max);
+
+    _mm_storeu_ps(y, vacc_lo);
+    _mm_storeu_ps(y + 4, vacc_hi);
+    y += 8;
+
+    n -= 8 * sizeof(float);
+  }
+  if (n >= 4 * sizeof(float)) {
+    __m128 vacc = _mm_add_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
+    a += 4;
+    b += 4;
+    vacc = _mm_min_ps(_mm_max_ps(vacc, voutput_min), voutput_max);
+    _mm_storeu_ps(y, vacc);
+    y += 4;
+    n -= 4 * sizeof(float);
+  }
+  if (n != 0) {
+    // Tail of 1-3 floats: compute a full vector, then store two lanes and/or
+    // one lane, moving the high half down between the partial stores.
+    // NOTE(review): tail loads read a full 128-bit vector; presumably the
+    // operator-level buffers guarantee this over-read is in bounds -- confirm.
+    __m128 vacc = _mm_add_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
+    vacc = _mm_min_ps(_mm_max_ps(vacc, voutput_min), voutput_max);
+    if (n & (2 * sizeof(float))) {
+      _mm_storel_pi((__m64*) y, vacc);
+      vacc = _mm_movehl_ps(vacc, vacc);
+      y += 2;
+    }
+    if (n & (1 * sizeof(float))) {
+      _mm_store_ss(y, vacc);
+    }
+  }
+}
diff --git a/src/f32-vmul/psimd.c b/src/f32-vmul/psimd.c
new file mode 100644
index 0000000..e42ff9f
--- /dev/null
+++ b/src/f32-vmul/psimd.c
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/vmul.h>
+
+
+// Elementwise vector multiplication with output clamping, portable-SIMD
+// (psimd) version: y[i] = clamp(a[i] * b[i], params->scalar.min,
+// params->scalar.max). n is the byte size of the arrays (non-zero multiple of
+// sizeof(float)); 8 floats per main-loop iteration, then 4-float and 1-3
+// float tails.
+//
+// Fix: the original text contained mojibake -- '&para' of '&params' had been
+// collapsed into the 'PILCROW SIGN' character, producing a non-compiling
+// token; restored to '&params->scalar.min/max'.
+void xnn_f32_vmul_ukernel__psimd(
+    size_t n,
+    const float* a,
+    const float* b,
+    float* y,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  const psimd_f32 vy_min = psimd_load_splat_f32(&params->scalar.min);
+  const psimd_f32 vy_max = psimd_load_splat_f32(&params->scalar.max);
+
+  for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
+    const psimd_f32 va0 = psimd_load_f32(a);
+    const psimd_f32 va1 = psimd_load_f32(a + 4);
+    a += 8;
+
+    const psimd_f32 vb0 = psimd_load_f32(b);
+    const psimd_f32 vb1 = psimd_load_f32(b + 4);
+    b += 8;
+
+    const psimd_f32 vprod0 = psimd_mul_f32(va0, vb0);
+    const psimd_f32 vprod1 = psimd_mul_f32(va1, vb1);
+    const psimd_f32 vy0 = psimd_min_f32(psimd_max_f32(vprod0, vy_min), vy_max);
+    const psimd_f32 vy1 = psimd_min_f32(psimd_max_f32(vprod1, vy_min), vy_max);
+
+    psimd_store_f32(y, vy0);
+    psimd_store_f32(y + 4, vy1);
+    y += 8;
+  }
+  if (n >= 4 * sizeof(float)) {
+    const psimd_f32 va = psimd_load_f32(a);
+    a += 4;
+    const psimd_f32 vb = psimd_load_f32(b);
+    b += 4;
+    const psimd_f32 vprod = psimd_mul_f32(va, vb);
+    const psimd_f32 vy = psimd_min_f32(psimd_max_f32(vprod, vy_min), vy_max);
+    psimd_store_f32(y, vy);
+    y += 4;
+    n -= 4 * sizeof(float);
+  }
+  if (n != 0) {
+    // Tail of 1-3 floats: full-vector compute, piecewise store.
+    // NOTE(review): the tail loads read a full 4-float vector; presumably the
+    // operator-level buffers guarantee this over-read is in bounds -- confirm.
+    const psimd_f32 va = psimd_load_f32(a);
+    const psimd_f32 vb = psimd_load_f32(b);
+    const psimd_f32 vprod = psimd_mul_f32(va, vb);
+    psimd_f32 vy = psimd_min_f32(psimd_max_f32(vprod, vy_min), vy_max);
+    if (n & 2 * sizeof(float)) {
+      psimd_store2_f32(y, vy);
+      vy = psimd_concat_hi_f32(vy, vy);
+      y += 2;
+    }
+    if (n & 1 * sizeof(float)) {
+      psimd_store1_f32(y, vy);
+    }
+  }
+}
diff --git a/src/f32-vmul/scalar.c b/src/f32-vmul/scalar.c
new file mode 100644
index 0000000..8f5c2f3
--- /dev/null
+++ b/src/f32-vmul/scalar.c
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/math.h>
+#include <xnnpack/vmul.h>
+
+
+// Elementwise vector multiplication with output clamping, scalar version:
+// y[i] = min(max(a[i] * b[i], params->scalar.min), params->scalar.max).
+// n is the byte size of the arrays and must be a non-zero multiple of
+// sizeof(float); two elements per iteration, plus at most one leftover.
+void xnn_f32_vmul_ukernel__scalar(
+    size_t n,
+    const float* a,
+    const float* b,
+    float* y,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  const float voutput_min = params->scalar.min;
+  const float voutput_max = params->scalar.max;
+
+  while (n >= 2 * sizeof(float)) {
+    // Clamp order (max with lower bound first, then min) matches the
+    // vectorized kernels.
+    const float vprod0 = math_min_f32(math_max_f32(a[0] * b[0], voutput_min), voutput_max);
+    const float vprod1 = math_min_f32(math_max_f32(a[1] * b[1], voutput_min), voutput_max);
+    a += 2;
+    b += 2;
+
+    y[0] = vprod0;
+    y[1] = vprod1;
+    y += 2;
+
+    n -= 2 * sizeof(float);
+  }
+  if XNN_UNLIKELY(n != 0) {
+    // Exactly one float remains (n is a multiple of sizeof(float)).
+    *y = math_min_f32(math_max_f32(*a * *b, voutput_min), voutput_max);
+  }
+}
diff --git a/src/f32-vmul/sse.c b/src/f32-vmul/sse.c
new file mode 100644
index 0000000..cb1087a
--- /dev/null
+++ b/src/f32-vmul/sse.c
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/vmul.h>
+
+
+// Elementwise vector multiplication with output clamping, SSE version:
+// y[i] = min(max(a[i] * b[i], params->sse.min), params->sse.max).
+// n is the byte size of the arrays (non-zero multiple of sizeof(float)).
+// 8 floats per main-loop iteration, then 4-float and 1-3 float tails.
+void xnn_f32_vmul_ukernel__sse(
+    size_t n,
+    const float* a,
+    const float* b,
+    float* y,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+
+  const __m128 voutput_min = _mm_load_ps(params->sse.min);
+  const __m128 voutput_max = _mm_load_ps(params->sse.max);
+
+  while (n >= 8 * sizeof(float)) {
+    __m128 vprod_lo = _mm_mul_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
+    __m128 vprod_hi = _mm_mul_ps(_mm_loadu_ps(a + 4), _mm_loadu_ps(b + 4));
+    a += 8;
+    b += 8;
+
+    vprod_lo = _mm_max_ps(vprod_lo, voutput_min);
+    vprod_hi = _mm_max_ps(vprod_hi, voutput_min);
+    vprod_lo = _mm_min_ps(vprod_lo, voutput_max);
+    vprod_hi = _mm_min_ps(vprod_hi, voutput_max);
+
+    _mm_storeu_ps(y, vprod_lo);
+    _mm_storeu_ps(y + 4, vprod_hi);
+    y += 8;
+
+    n -= 8 * sizeof(float);
+  }
+  if (n >= 4 * sizeof(float)) {
+    __m128 vprod = _mm_mul_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
+    a += 4;
+    b += 4;
+    vprod = _mm_min_ps(_mm_max_ps(vprod, voutput_min), voutput_max);
+    _mm_storeu_ps(y, vprod);
+    y += 4;
+    n -= 4 * sizeof(float);
+  }
+  if (n != 0) {
+    // Tail of 1-3 floats: compute a full vector, store two lanes and/or one
+    // lane, moving the high half down between the partial stores.
+    // NOTE(review): tail loads read a full 128-bit vector; presumably the
+    // operator-level buffers guarantee this over-read is in bounds -- confirm.
+    __m128 vprod = _mm_mul_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
+    vprod = _mm_min_ps(_mm_max_ps(vprod, voutput_min), voutput_max);
+    if (n & (2 * sizeof(float))) {
+      _mm_storel_pi((__m64*) y, vprod);
+      vprod = _mm_movehl_ps(vprod, vprod);
+      y += 2;
+    }
+    if (n & (1 * sizeof(float))) {
+      _mm_store_ss(y, vprod);
+    }
+  }
+}
diff --git a/src/f32-vmulcaddc/c1-scalar-x2.c b/src/f32-vmulcaddc/c1-scalar-x2.c
new file mode 100644
index 0000000..1b7b7e9
--- /dev/null
+++ b/src/f32-vmulcaddc/c1-scalar-x2.c
@@ -0,0 +1,107 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-vmulcaddc/scalar.c.in
+// Generator: tools/xngen
+//
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/vmulcaddc.h>
+
+
+// Per-channel multiply-add with clamping (y = clamp(x * scale + bias)),
+// scalar version, processing 2 batch rows (x2) and 1 channel (c1) per step.
+// 'weights' interleaves {scale, bias} pairs per channel; 'channels' is a byte
+// count (asserted to be a multiple of sizeof(float)); row strides are bytes.
+//
+// Fix: the original remainder branch computed vacc0x0/vacc1x0 but never
+// stored them to y0/y1, silently dropping the results. The branch is
+// unreachable here (the main loop consumes channels in sizeof(float) steps
+// and channels % sizeof(float) == 0), but the stores are restored so the
+// code is correct if the channel step ever changes.
+void xnn_f32_vmulcaddc_ukernel_c1__scalar_x2(
+    size_t m,
+    size_t channels,
+    const float*restrict x,
+    size_t x_stride,
+    const float*restrict weights,
+    float*restrict y,
+    size_t y_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+  assert(channels != 0);
+  assert(channels % sizeof(float) == 0);
+
+  // Skip over the second processed row; the channel mask is a no-op for the
+  // 1-float step but kept for consistency with the wider-channel kernels.
+  const size_t x_increment = x_stride * 2 - (channels & -(1 * sizeof(float)));
+  const size_t y_increment = y_stride * 2 - channels;
+
+  const float* x0 = x;
+  float* y0 = y;
+  const float* x1 = (const float*) ((uintptr_t) x0 + x_stride);
+  float* y1 = (float*) ((uintptr_t) y0 + y_stride);
+  if XNN_UNPREDICTABLE(m < 2) {
+    // Odd batch: alias row 1 onto row 0 so it is computed (redundantly) into
+    // the same location instead of reading/writing out of bounds.
+    x1 = x0;
+    y1 = y0;
+  }
+
+  const float vmin = params->scalar.min;
+  const float vmax = params->scalar.max;
+  do {
+    const float* w = weights;
+    size_t c = channels;
+    for (; c >= 1 * sizeof(float); c -= 1 * sizeof(float)) {
+      const float vscale0 = w[0];
+
+      const float vx0x0 = x0[0];
+      x0 += 1;
+      const float vx1x0 = x1[0];
+      x1 += 1;
+
+      const float vbias0 = w[1];
+
+      float vacc0x0 = vx0x0 * vscale0 + vbias0;
+      float vacc1x0 = vx1x0 * vscale0 + vbias0;
+
+      vacc0x0 = math_max_f32(vacc0x0, vmin);
+      vacc1x0 = math_max_f32(vacc1x0, vmin);
+
+      vacc0x0 = math_min_f32(vacc0x0, vmax);
+      vacc1x0 = math_min_f32(vacc1x0, vmax);
+
+      y0[0] = vacc0x0;
+      y0 += 1;
+      y1[0] = vacc1x0;
+      y1 += 1;
+
+      w += 2;
+    }
+    // Unreachable for the current 1-float channel step (see fix note above);
+    // kept complete, including the previously missing stores.
+    if XNN_UNLIKELY(c != 0) {
+      const float vscale0 = w[0];
+
+      const float vx0x0 = x0[0];
+      x0 += 1;
+      const float vx1x0 = x1[0];
+      x1 += 1;
+
+      const float vbias0 = w[1];
+
+      float vacc0x0 = vx0x0 * vscale0 + vbias0;
+      float vacc1x0 = vx1x0 * vscale0 + vbias0;
+
+      vacc0x0 = math_max_f32(vacc0x0, vmin);
+      vacc1x0 = math_max_f32(vacc1x0, vmin);
+
+      vacc0x0 = math_min_f32(vacc0x0, vmax);
+      vacc1x0 = math_min_f32(vacc1x0, vmax);
+
+      y0[0] = vacc0x0;
+      y0 += 1;
+      y1[0] = vacc1x0;
+      y1 += 1;
+
+      w += 2;
+    }
+    x0 = (const float*) ((uintptr_t) x0 + x_increment);
+    y0 = (float*) ((uintptr_t) y0 + y_increment);
+    x1 = (const float*) ((uintptr_t) x1 + x_increment);
+    y1 = (float*) ((uintptr_t) y1 + y_increment);
+    if XNN_UNPREDICTABLE(m < 4) {
+      x1 = x0;
+      y1 = y0;
+    }
+    m = doz(m, 2);  // saturating subtraction: m = max(m - 2, 0)
+  } while (m != 0);
+}
diff --git a/src/f32-vmulcaddc/c4-neon-x2.c b/src/f32-vmulcaddc/c4-neon-x2.c
new file mode 100644
index 0000000..cee2f83
--- /dev/null
+++ b/src/f32-vmulcaddc/c4-neon-x2.c
@@ -0,0 +1,118 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-vmulcaddc/neon.c.in
+// Generator: tools/xngen
+//
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/vmulcaddc.h>
+
+
+// Per-channel multiply-add with clamping (y = clamp(x * scale + bias)),
+// NEON version, processing 2 batch rows (x2) and 4 channels (c4) per step.
+// 'weights' holds, per 4-channel group, 4 scales followed by 4 biases;
+// 'channels' is a byte count (asserted to be a multiple of sizeof(float));
+// row strides are bytes.
+//
+// Fix: the original text contained mojibake -- '&para' of '&params' had been
+// collapsed into the 'PILCROW SIGN' character, producing a non-compiling
+// token; restored to '&params->scalar.min/max'.
+void xnn_f32_vmulcaddc_ukernel_c4__neon_x2(
+    size_t m,
+    size_t channels,
+    const float*restrict x,
+    size_t x_stride,
+    const float*restrict weights,
+    float*restrict y,
+    size_t y_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+  assert(channels != 0);
+  assert(channels % sizeof(float) == 0);
+
+  const size_t x_increment = x_stride * 2 - channels;
+  const size_t y_increment = y_stride * 2 - channels;
+
+  const float* x0 = x;
+  float* y0 = y;
+  const float* x1 = (const float*) ((uintptr_t) x0 + x_stride);
+  float* y1 = (float*) ((uintptr_t) y0 + y_stride);
+  if XNN_UNPREDICTABLE(m < 2) {
+    // Odd batch: alias row 1 onto row 0 so no out-of-bounds row is touched.
+    x1 = x0;
+    y1 = y0;
+  }
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  do {
+    const float* w = weights;
+    size_t c = channels;
+    for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) {
+      const float32x4_t vscale0123 = vld1q_f32(w); w += 4;
+
+      const float32x4_t vx0x0123 = vld1q_f32(x0); x0 += 4;
+      const float32x4_t vx1x0123 = vld1q_f32(x1); x1 += 4;
+
+      float32x4_t vacc0x0123 = vmulq_f32(vx0x0123, vscale0123);
+      float32x4_t vacc1x0123 = vmulq_f32(vx1x0123, vscale0123);
+
+      const float32x4_t vbias0123 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vaddq_f32(vacc0x0123, vbias0123);
+      vacc1x0123 = vaddq_f32(vacc1x0123, vbias0123);
+
+      vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+      vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+
+      vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+      vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+
+      vst1q_f32(y0, vacc0x0123); y0 += 4;
+      vst1q_f32(y1, vacc1x0123); y1 += 4;
+    }
+    if XNN_UNLIKELY(c != 0) {
+      // Remainder of 1-3 channels: compute a full vector, then store 2 lanes
+      // and/or 1 lane.
+      // NOTE(review): the remainder loads read a full 128-bit vector even
+      // when c < 4 floats; presumably the operator-level buffers guarantee
+      // this over-read is in bounds -- confirm against the allocator.
+      const float32x4_t vscale0123 = vld1q_f32(w); w += 4;
+
+      const float32x4_t vx0x0123 = vld1q_f32(x0); x0 = (const float*) ((uintptr_t) x0 + c);
+      const float32x4_t vx1x0123 = vld1q_f32(x1); x1 = (const float*) ((uintptr_t) x1 + c);
+
+      float32x4_t vacc0x0123 = vmulq_f32(vx0x0123, vscale0123);
+      float32x4_t vacc1x0123 = vmulq_f32(vx1x0123, vscale0123);
+
+      const float32x4_t vbias0123 = vld1q_f32(w); w += 4;
+
+      vacc0x0123 = vaddq_f32(vacc0x0123, vbias0123);
+      vacc1x0123 = vaddq_f32(vacc1x0123, vbias0123);
+
+      vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+      vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+
+      vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+      vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      if (c & (2 * sizeof(float))) {
+        vst1_f32(y0, vacc0x01); y0 += 2;
+        vst1_f32(y1, vacc1x01); y1 += 2;
+
+        vacc0x01 = vget_high_f32(vacc0x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+      }
+      if (c & (1 * sizeof(float))) {
+        vst1_lane_f32(y0, vacc0x01, 0); y0 += 1;
+        vst1_lane_f32(y1, vacc1x01, 0); y1 += 1;
+      }
+    }
+    x0 = (const float*) ((uintptr_t) x0 + x_increment);
+    y0 = (float*) ((uintptr_t) y0 + y_increment);
+    x1 = (const float*) ((uintptr_t) x1 + x_increment);
+    y1 = (float*) ((uintptr_t) y1 + y_increment);
+    if XNN_UNPREDICTABLE(m < 4) {
+      x1 = x0;
+      y1 = y0;
+    }
+    m = doz(m, 2);  // saturating subtraction: m = max(m - 2, 0)
+  } while (m != 0);
+}
diff --git a/src/f32-vmulcaddc/c4-neonfma-x2.c b/src/f32-vmulcaddc/c4-neonfma-x2.c
new file mode 100644
index 0000000..f395592
--- /dev/null
+++ b/src/f32-vmulcaddc/c4-neonfma-x2.c
@@ -0,0 +1,114 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-vmulcaddc/neon.c.in
+// Generator: tools/xngen
+//
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/vmulcaddc.h>
+
+
+// Per-channel multiply-add with clamping (y = clamp(x * scale + bias)),
+// NEON-FMA version: identical to the plain NEON kernel except the
+// multiply+add is fused via vfmaq_f32(bias, x, scale). Processes 2 batch
+// rows (x2) and 4 channels (c4) per step; 'channels' is a byte count.
+//
+// Fix: the original text contained mojibake -- '&para' of '&params' had been
+// collapsed into the 'PILCROW SIGN' character, producing a non-compiling
+// token; restored to '&params->scalar.min/max'.
+void xnn_f32_vmulcaddc_ukernel_c4__neonfma_x2(
+    size_t m,
+    size_t channels,
+    const float*restrict x,
+    size_t x_stride,
+    const float*restrict weights,
+    float*restrict y,
+    size_t y_stride,
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(m != 0);
+  assert(channels != 0);
+  assert(channels % sizeof(float) == 0);
+
+  const size_t x_increment = x_stride * 2 - channels;
+  const size_t y_increment = y_stride * 2 - channels;
+
+  const float* x0 = x;
+  float* y0 = y;
+  const float* x1 = (const float*) ((uintptr_t) x0 + x_stride);
+  float* y1 = (float*) ((uintptr_t) y0 + y_stride);
+  if XNN_UNPREDICTABLE(m < 2) {
+    // Odd batch: alias row 1 onto row 0 so no out-of-bounds row is touched.
+    x1 = x0;
+    y1 = y0;
+  }
+
+  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+  do {
+    const float* w = weights;
+    size_t c = channels;
+    for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) {
+      const float32x4_t vscale0123 = vld1q_f32(w); w += 4;
+
+      const float32x4_t vx0x0123 = vld1q_f32(x0); x0 += 4;
+      const float32x4_t vx1x0123 = vld1q_f32(x1); x1 += 4;
+
+      const float32x4_t vbias0123 = vld1q_f32(w); w += 4;
+
+      float32x4_t vacc0x0123 = vfmaq_f32(vbias0123, vx0x0123, vscale0123);
+      float32x4_t vacc1x0123 = vfmaq_f32(vbias0123, vx1x0123, vscale0123);
+
+      vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+      vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+
+      vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+      vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+
+      vst1q_f32(y0, vacc0x0123); y0 += 4;
+      vst1q_f32(y1, vacc1x0123); y1 += 4;
+    }
+    if XNN_UNLIKELY(c != 0) {
+      // Remainder of 1-3 channels: compute a full vector, then store 2 lanes
+      // and/or 1 lane.
+      // NOTE(review): the remainder loads read a full 128-bit vector even
+      // when c < 4 floats; presumably the operator-level buffers guarantee
+      // this over-read is in bounds -- confirm against the allocator.
+      const float32x4_t vscale0123 = vld1q_f32(w); w += 4;
+
+      const float32x4_t vx0x0123 = vld1q_f32(x0); x0 = (const float*) ((uintptr_t) x0 + c);
+      const float32x4_t vx1x0123 = vld1q_f32(x1); x1 = (const float*) ((uintptr_t) x1 + c);
+
+      const float32x4_t vbias0123 = vld1q_f32(w); w += 4;
+
+      float32x4_t vacc0x0123 = vfmaq_f32(vbias0123, vx0x0123, vscale0123);
+      float32x4_t vacc1x0123 = vfmaq_f32(vbias0123, vx1x0123, vscale0123);
+
+      vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
+      vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
+
+      vacc0x0123 = vminq_f32(vacc0x0123, vmax);
+      vacc1x0123 = vminq_f32(vacc1x0123, vmax);
+
+      float32x2_t vacc0x01 = vget_low_f32(vacc0x0123);
+      float32x2_t vacc1x01 = vget_low_f32(vacc1x0123);
+      if (c & (2 * sizeof(float))) {
+        vst1_f32(y0, vacc0x01); y0 += 2;
+        vst1_f32(y1, vacc1x01); y1 += 2;
+
+        vacc0x01 = vget_high_f32(vacc0x0123);
+        vacc1x01 = vget_high_f32(vacc1x0123);
+      }
+      if (c & (1 * sizeof(float))) {
+        vst1_lane_f32(y0, vacc0x01, 0); y0 += 1;
+        vst1_lane_f32(y1, vacc1x01, 0); y1 += 1;
+      }
+    }
+    x0 = (const float*) ((uintptr_t) x0 + x_increment);
+    y0 = (float*) ((uintptr_t) y0 + y_increment);
+    x1 = (const float*) ((uintptr_t) x1 + x_increment);
+    y1 = (float*) ((uintptr_t) y1 + y_increment);
+    if XNN_UNPREDICTABLE(m < 4) {
+      x1 = x0;
+      y1 = y0;
+    }
+    m = doz(m, 2);  // saturating subtraction: m = max(m - 2, 0)
+  } while (m != 0);
+}
diff --git a/src/f32-vmulcaddc/c4-psimd-x2.c b/src/f32-vmulcaddc/c4-psimd-x2.c
new file mode 100644
index 0000000..dc9344c
--- /dev/null
+++ b/src/f32-vmulcaddc/c4-psimd-x2.c
@@ -0,0 +1,124 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-vmulcaddc/psimd.c.in
+// Generator: tools/xngen
+//
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/vmulcaddc.h>
+
+
+// Fused multiply-add micro-kernel for channelwise scale+bias with clamping:
+//   y[r][c] = clamp(x[r][c] * scale[c] + bias[c], min, max)
+// processing 2 rows per pass and 4 channels per vector.
+// NOTE(review): the weights blob appears packed as 4 scales followed by
+// 4 biases per channel quad (w[0..3]=scale, w[4..7]=bias) -- confirm against
+// the packing routine.
+void xnn_f32_vmulcaddc_ukernel_c4__psimd_x2(
+ size_t m,
+ size_t channels,
+ const float*restrict x,
+ size_t x_stride,
+ const float*restrict weights,
+ float*restrict y,
+ size_t y_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+ assert(channels != 0);
+ assert(channels % sizeof(float) == 0);
+
+ // Byte deltas that move the row pointers to the next pair of rows. The x
+ // pointers are NOT advanced in the remainder path below, so only channels
+ // rounded down to a whole number of 4-float groups has been consumed.
+ const size_t x_increment = x_stride * 2 - (channels & -(4 * sizeof(float)));
+ const size_t y_increment = y_stride * 2 - channels;
+
+ const float* x0 = x;
+ float* y0 = y;
+ const float* x1 = (const float*) ((uintptr_t) x0 + x_stride);
+ float* y1 = (float*) ((uintptr_t) y0 + y_stride);
+ if XNN_UNPREDICTABLE(m < 2) {
+ // Only one row: alias row 1 onto row 0 so its loads/stores are benign.
+ x1 = x0;
+ y1 = y0;
+ }
+
+ const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+ const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+ do {
+ const float* w = weights;
+ size_t c = channels;
+ // Main loop: whole groups of 4 channels.
+ for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) {
+ const psimd_f32 vscale0123 = psimd_load_f32(w);
+
+ const psimd_f32 vx0x0123 = psimd_load_f32(x0);
+ x0 += 4;
+ const psimd_f32 vx1x0123 = psimd_load_f32(x1);
+ x1 += 4;
+
+ const psimd_f32 vbias0123 = psimd_load_f32(w + 4);
+
+ psimd_f32 vacc0x0123 = psimd_qfma_f32(vbias0123, vx0x0123, vscale0123);
+ psimd_f32 vacc1x0123 = psimd_qfma_f32(vbias0123, vx1x0123, vscale0123);
+
+ vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+ vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+
+ vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+ vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+
+ psimd_store_f32(y0, vacc0x0123);
+ y0 += 4;
+ psimd_store_f32(y1, vacc1x0123);
+ y1 += 4;
+
+ w += 8;
+ }
+ // Remainder: 1-3 trailing channels. Full 4-float vectors are loaded (reads
+ // may extend past the row tail), but the stores write exactly c bytes.
+ if XNN_UNLIKELY(c != 0) {
+ const psimd_f32 vscale0123 = psimd_load_f32(w);
+
+ const psimd_f32 vx0x0123 = psimd_load_f32(x0);
+ const psimd_f32 vx1x0123 = psimd_load_f32(x1);
+
+ const psimd_f32 vbias0123 = psimd_load_f32(w + 4);
+
+ psimd_f32 vacc0x0123 = psimd_qfma_f32(vbias0123, vx0x0123, vscale0123);
+ psimd_f32 vacc1x0123 = psimd_qfma_f32(vbias0123, vx1x0123, vscale0123);
+
+ vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
+ vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
+
+ vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
+ vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
+
+ w += 8;
+
+ if (c & (2 * sizeof(float))) {
+ psimd_store2_f32(y0, vacc0x0123);
+ psimd_store2_f32(y1, vacc1x0123);
+
+ y0 += 2;
+ y1 += 2;
+
+ // Move lanes 2,3 down so a possible final single store reads lane 0.
+ vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
+ vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
+ }
+ if (c & (1 * sizeof(float))) {
+ psimd_store1_f32(y0, vacc0x0123);
+ psimd_store1_f32(y1, vacc1x0123);
+
+ y0 += 1;
+ y1 += 1;
+ }
+ }
+ // Hop both row pointers over the pair of rows just processed.
+ x0 = (const float*) ((uintptr_t) x0 + x_increment);
+ y0 = (float*) ((uintptr_t) y0 + y_increment);
+ x1 = (const float*) ((uintptr_t) x1 + x_increment);
+ y1 = (float*) ((uintptr_t) y1 + y_increment);
+ if XNN_UNPREDICTABLE(m < 4) {
+ // Fewer than 2 rows remain after this pass: alias row 1 onto row 0.
+ x1 = x0;
+ y1 = y0;
+ }
+ m = doz(m, 2); // doz(): difference-or-zero (saturating subtraction)
+ } while (m != 0);
+}
diff --git a/src/f32-vmulcaddc/c4-sse-x2.c b/src/f32-vmulcaddc/c4-sse-x2.c
new file mode 100644
index 0000000..b48acf1
--- /dev/null
+++ b/src/f32-vmulcaddc/c4-sse-x2.c
@@ -0,0 +1,130 @@
+// Auto-generated file. Do not edit!
+// Template: src/f32-vmulcaddc/sse.c.in
+// Generator: tools/xngen
+//
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/vmulcaddc.h>
+
+
+// Fused multiply-add micro-kernel for channelwise scale+bias with clamping:
+//   y[r][c] = clamp(x[r][c] * scale[c] + bias[c], min, max)
+// processing 2 rows per pass and 4 channels per vector.
+// NOTE(review): weights appear packed as 4 scales followed by 4 biases per
+// channel quad (w[0..3]=scale, w[4..7]=bias) -- confirm against packing code.
+void xnn_f32_vmulcaddc_ukernel_c4__sse_x2(
+ size_t m,
+ size_t channels,
+ const float*restrict x,
+ size_t x_stride,
+ const float*restrict weights,
+ float*restrict y,
+ size_t y_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+ assert(channels != 0);
+ assert(channels % sizeof(float) == 0);
+
+ // Byte deltas that move the row pointers to the next pair of rows. The x
+ // pointers are NOT advanced in the remainder path below, hence channels is
+ // rounded down to a whole number of 4-float groups here.
+ const size_t x_increment = x_stride * 2 - (channels & -(4 * sizeof(float)));
+ const size_t y_increment = y_stride * 2 - channels;
+
+ const float* x0 = x;
+ float* y0 = y;
+ const float* x1 = (const float*) ((uintptr_t) x0 + x_stride);
+ float* y1 = (float*) ((uintptr_t) y0 + y_stride);
+ if XNN_UNPREDICTABLE(m < 2) {
+ // Only one row: alias row 1 onto row 0 so its loads/stores are benign.
+ x1 = x0;
+ y1 = y0;
+ }
+
+ // params->sse.min/max are 4-wide aligned arrays (presumably the scalar
+ // clamp bounds broadcast 4x -- confirm in params.h).
+ const __m128 vmin = _mm_load_ps(params->sse.min);
+ const __m128 vmax = _mm_load_ps(params->sse.max);
+ do {
+ const float* w = weights;
+ size_t c = channels;
+ // Main loop: whole groups of 4 channels.
+ for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) {
+ const __m128 vscale0123 = _mm_load_ps(w);
+
+ const __m128 vx0x0123 = _mm_loadu_ps(x0);
+ x0 += 4;
+ const __m128 vx1x0123 = _mm_loadu_ps(x1);
+ x1 += 4;
+
+ __m128 vacc0x0123 = _mm_mul_ps(vx0x0123, vscale0123);
+ __m128 vacc1x0123 = _mm_mul_ps(vx1x0123, vscale0123);
+
+ const __m128 vbias0123 = _mm_load_ps(w + 4);
+
+ vacc0x0123 = _mm_add_ps(vacc0x0123, vbias0123);
+ vacc1x0123 = _mm_add_ps(vacc1x0123, vbias0123);
+
+ vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+ vacc1x0123 = _mm_max_ps(vacc1x0123, vmin);
+
+ vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+ vacc1x0123 = _mm_min_ps(vacc1x0123, vmax);
+
+ _mm_storeu_ps(y0, vacc0x0123);
+ y0 += 4;
+ _mm_storeu_ps(y1, vacc1x0123);
+ y1 += 4;
+
+ w += 8;
+ }
+ // Remainder: 1-3 trailing channels. Full vectors are loaded (reads may
+ // extend past the row tail); the stores below write exactly c bytes.
+ if XNN_UNLIKELY(c != 0) {
+ const __m128 vscale0123 = _mm_load_ps(w);
+
+ const __m128 vx0x0123 = _mm_loadu_ps(x0);
+ const __m128 vx1x0123 = _mm_loadu_ps(x1);
+
+ __m128 vacc0x0123 = _mm_mul_ps(vx0x0123, vscale0123);
+ __m128 vacc1x0123 = _mm_mul_ps(vx1x0123, vscale0123);
+
+ const __m128 vbias0123 = _mm_load_ps(w + 4);
+
+ vacc0x0123 = _mm_add_ps(vacc0x0123, vbias0123);
+ vacc1x0123 = _mm_add_ps(vacc1x0123, vbias0123);
+
+ vacc0x0123 = _mm_max_ps(vacc0x0123, vmin);
+ vacc1x0123 = _mm_max_ps(vacc1x0123, vmin);
+
+ vacc0x0123 = _mm_min_ps(vacc0x0123, vmax);
+ vacc1x0123 = _mm_min_ps(vacc1x0123, vmax);
+
+ w += 8;
+
+ if (c & (2 * sizeof(float))) {
+ _mm_storel_pi((__m64*) y0, vacc0x0123);
+ _mm_storel_pi((__m64*) y1, vacc1x0123);
+
+ y0 += 2;
+ y1 += 2;
+
+ // Move lanes 2,3 down so a possible final single store reads lane 0.
+ vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
+ vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
+ }
+ if (c & (1 * sizeof(float))) {
+ _mm_store_ss(y0, vacc0x0123);
+ _mm_store_ss(y1, vacc1x0123);
+
+ y0 += 1;
+ y1 += 1;
+ }
+ }
+ // Hop both row pointers over the pair of rows just processed.
+ x0 = (const float*) ((uintptr_t) x0 + x_increment);
+ y0 = (float*) ((uintptr_t) y0 + y_increment);
+ x1 = (const float*) ((uintptr_t) x1 + x_increment);
+ y1 = (float*) ((uintptr_t) y1 + y_increment);
+ if XNN_UNPREDICTABLE(m < 4) {
+ // Fewer than 2 rows remain after this pass: alias row 1 onto row 0.
+ x1 = x0;
+ y1 = y0;
+ }
+ m = doz(m, 2); // doz(): difference-or-zero (saturating subtraction)
+ } while (m != 0);
+}
diff --git a/src/f32-vmulcaddc/neon.c.in b/src/f32-vmulcaddc/neon.c.in
new file mode 100644
index 0000000..dd7b824
--- /dev/null
+++ b/src/f32-vmulcaddc/neon.c.in
@@ -0,0 +1,166 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+$assert CR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/vmulcaddc.h>
+
+
+// Template for f32 vmulcaddc NEON/NEONFMA micro-kernels:
+//   y[r][c] = clamp(x[r][c] * scale[c] + bias[c], min, max)
+// processing MR rows and CR channels per pass.
+// NOTE(review): weights appear packed as CR scales followed by CR biases per
+// channel group -- confirm against the packing routine.
+// Fix: the register shift-down statement in the remainder tail was missing
+// its trailing semicolon, producing non-compiling code for CR >= 8 variants.
+void xnn_f32_vmulcaddc_ukernel_c${CR}__${"neonfma" if FMA else "neon"}_x${MR}(
+ size_t m,
+ size_t channels,
+ const float*restrict x,
+ size_t x_stride,
+ const float*restrict weights,
+ float*restrict y,
+ size_t y_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+ assert(channels != 0);
+ assert(channels % sizeof(float) == 0);
+
+ // The remainder path advances x by exactly c bytes, so the full channel
+ // count is subtracted here (the psimd/sse templates mask instead).
+ const size_t x_increment = x_stride * ${MR} - channels;
+ const size_t y_increment = y_stride * ${MR} - channels;
+
+ const float* x0 = x;
+ float* y0 = y;
+ $for M in range(1, MR):
+ const float* x${M} = (const float*) ((uintptr_t) x${M-1} + x_stride);
+ float* y${M} = (float*) ((uintptr_t) y${M-1} + y_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(m <= ${M}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(m < ${M+1}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+
+ const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
+ const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
+ do {
+ const float* w = weights;
+ size_t c = channels;
+ for (; c >= ${CR} * sizeof(float); c -= ${CR} * sizeof(float)) {
+ $for C in range(0, CR, 4):
+ const float32x4_t vscale${ABC[C:C+4]} = vld1q_f32(w); w += 4;
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ const float32x4_t vx${M}x${ABC[C:C+4]} = vld1q_f32(x${M}); x${M} += 4;
+
+ $if not FMA:
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ float32x4_t vacc${M}x${ABC[C:C+4]} = vmulq_f32(vx${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]});
+
+ $for C in range(0, CR, 4):
+ const float32x4_t vbias${ABC[C:C+4]} = vld1q_f32(w); w += 4;
+
+ $if not FMA:
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = vaddq_f32(vacc${M}x${ABC[C:C+4]}, vbias${ABC[C:C+4]});
+ $else:
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ float32x4_t vacc${M}x${ABC[C:C+4]} = vfmaq_f32(vbias${ABC[C:C+4]}, vx${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]});
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = vmaxq_f32(vacc${M}x${ABC[C:C+4]}, vmin);
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = vminq_f32(vacc${M}x${ABC[C:C+4]}, vmax);
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vst1q_f32(y${M}, vacc${M}x${ABC[C:C+4]}); y${M} += 4;
+ }
+ // Remainder: 1 to CR-1 trailing channels; x is advanced by exactly c bytes.
+ if XNN_UNLIKELY(c != 0) {
+ $for C in range(0, CR, 4):
+ const float32x4_t vscale${ABC[C:C+4]} = vld1q_f32(w); w += 4;
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ const float32x4_t vx${M}x${ABC[C:C+4]} = vld1q_f32(x${M}); x${M} = (const float*) ((uintptr_t) x${M} + c);
+
+ $if not FMA:
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ float32x4_t vacc${M}x${ABC[C:C+4]} = vmulq_f32(vx${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]});
+
+ $for C in range(0, CR, 4):
+ const float32x4_t vbias${ABC[C:C+4]} = vld1q_f32(w); w += 4;
+
+ $if not FMA:
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = vaddq_f32(vacc${M}x${ABC[C:C+4]}, vbias${ABC[C:C+4]});
+ $else:
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ float32x4_t vacc${M}x${ABC[C:C+4]} = vfmaq_f32(vbias${ABC[C:C+4]}, vx${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]});
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = vmaxq_f32(vacc${M}x${ABC[C:C+4]}, vmin);
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = vminq_f32(vacc${M}x${ABC[C:C+4]}, vmax);
+
+ $for LOG2C in reversed(range(CR.bit_length())):
+ $if CR != 1 << LOG2C:
+ if (c & (${1 << LOG2C} * sizeof(float))) {
+ $if LOG2C >= 2:
+ $for M in range(MR):
+ $for C in range(0, 1 << LOG2C, 4):
+ vst1q_f32(y${M}, vacc${M}x${ABC[C:C+4]}); y${M} += 4;
+
+ // Shift the upper registers down for the narrower stores below.
+ $for M in range(MR):
+ $for C in range(0, 1 << (LOG2C - 1), 4):
+ vacc${M}x${ABC[C:C+4]} = vacc${M}x${ABC[C+(1<<LOG2C):C+(1<<LOG2C)+4]};
+ $elif LOG2C == 1:
+ $for M in range(MR):
+ vst1_f32(y${M}, vacc${M}x${ABC[0:2]}); y${M} += 2;
+
+ $for M in range(MR):
+ vacc${M}x${ABC[0:2]} = vget_high_f32(vacc${M}x${ABC[0:4]});
+ $elif LOG2C == 0:
+ $for M in range(MR):
+ vst1_lane_f32(y${M}, vacc${M}x${ABC[0:2]}, 0); y${M} += 1;
+ }
+ $if LOG2C == 2:
+ $for M in range(MR):
+ float32x2_t vacc${M}x${ABC[0:2]} = vget_low_f32(vacc${M}x${ABC[0:4]});
+ }
+ $for M in range(MR):
+ x${M} = (const float*) ((uintptr_t) x${M} + x_increment);
+ y${M} = (float*) ((uintptr_t) y${M} + y_increment);
+ $if M % 2 == 1:
+ if XNN_UNPREDICTABLE(m < ${MR+M+1}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+ $elif M != 0:
+ if XNN_UNPREDICTABLE(m <= ${MR+M}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+ m = doz(m, ${MR});
+ } while (m != 0);
+}
diff --git a/src/f32-vmulcaddc/psimd.c.in b/src/f32-vmulcaddc/psimd.c.in
new file mode 100644
index 0000000..51c27b1
--- /dev/null
+++ b/src/f32-vmulcaddc/psimd.c.in
@@ -0,0 +1,164 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+$assert CR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/vmulcaddc.h>
+
+
+// Template for f32 vmulcaddc psimd micro-kernels:
+//   y[r][c] = clamp(x[r][c] * scale[c] + bias[c], min, max)
+// processing MR rows and CR channels per pass.
+// NOTE(review): weights appear packed as CR scales followed by CR biases per
+// channel group -- confirm against the packing routine.
+// Fixes: the partial store in the remainder tail referenced the stale loop
+// variable C (vacc${M}x${ABC[C:C+4]}) instead of the low group, and the
+// register shift-down statement was missing its semicolon; both broke code
+// generation for CR >= 8 variants.
+void xnn_f32_vmulcaddc_ukernel_c${CR}__psimd_x${MR}(
+ size_t m,
+ size_t channels,
+ const float*restrict x,
+ size_t x_stride,
+ const float*restrict weights,
+ float*restrict y,
+ size_t y_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+ assert(channels != 0);
+ assert(channels % sizeof(float) == 0);
+
+ // x pointers are not advanced in the remainder path, so only the part of
+ // channels rounded down to whole CR-float groups is subtracted back out.
+ const size_t x_increment = x_stride * ${MR} - (channels & -(${CR} * sizeof(float)));
+ const size_t y_increment = y_stride * ${MR} - channels;
+
+ const float* x0 = x;
+ float* y0 = y;
+ $for M in range(1, MR):
+ const float* x${M} = (const float*) ((uintptr_t) x${M-1} + x_stride);
+ float* y${M} = (float*) ((uintptr_t) y${M-1} + y_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(m <= ${M}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(m < ${M+1}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+
+ const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
+ const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
+ do {
+ const float* w = weights;
+ size_t c = channels;
+ for (; c >= ${CR} * sizeof(float); c -= ${CR} * sizeof(float)) {
+ const psimd_f32 vscale${ABC[0:4]} = psimd_load_f32(w);
+ $for C in range(4, CR, 4):
+ const psimd_f32 vscale${ABC[C:C+4]} = psimd_load_f32(w + ${C});
+
+ $for M in range(MR):
+ const psimd_f32 vx${M}x${ABC[0:4]} = psimd_load_f32(x${M});
+ $for C in range(4, CR, 4):
+ const psimd_f32 vx${M}x${ABC[C:C+4]} = psimd_load_f32(x${M} + ${C});
+ x${M} += ${CR};
+
+ $for C in range(0, CR, 4):
+ const psimd_f32 vbias${ABC[C:C+4]} = psimd_load_f32(w + ${C + CR});
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ psimd_f32 vacc${M}x${ABC[C:C+4]} = psimd_qfma_f32(vbias${ABC[C:C+4]}, vx${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]});
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = psimd_max_f32(vacc${M}x${ABC[C:C+4]}, vmin);
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = psimd_min_f32(vacc${M}x${ABC[C:C+4]}, vmax);
+
+ $for M in range(MR):
+ psimd_store_f32(y${M}, vacc${M}x${ABC[0:4]});
+ $for C in range(4, CR, 4):
+ psimd_store_f32(y${M} + ${C}, vacc${M}x${ABC[C:C+4]});
+ y${M} += ${CR};
+
+ w += ${CR * 2};
+ }
+ // Remainder: 1 to CR-1 trailing channels; x pointers are left in place
+ // (x_increment already accounts for the unprocessed tail).
+ if XNN_UNLIKELY(c != 0) {
+ const psimd_f32 vscale${ABC[0:4]} = psimd_load_f32(w);
+ $for C in range(4, CR, 4):
+ const psimd_f32 vscale${ABC[C:C+4]} = psimd_load_f32(w + ${C});
+
+ $for M in range(MR):
+ const psimd_f32 vx${M}x${ABC[0:4]} = psimd_load_f32(x${M});
+ $for C in range(4, CR, 4):
+ const psimd_f32 vx${M}x${ABC[C:C+4]} = psimd_load_f32(x${M} + ${C});
+
+ $for C in range(0, CR, 4):
+ const psimd_f32 vbias${ABC[C:C+4]} = psimd_load_f32(w + ${C + CR});
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ psimd_f32 vacc${M}x${ABC[C:C+4]} = psimd_qfma_f32(vbias${ABC[C:C+4]}, vx${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]});
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = psimd_max_f32(vacc${M}x${ABC[C:C+4]}, vmin);
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = psimd_min_f32(vacc${M}x${ABC[C:C+4]}, vmax);
+
+ w += ${CR * 2};
+
+ $for LOG2C in reversed(range(CR.bit_length())):
+ $if CR != 1 << LOG2C:
+ if (c & (${1 << LOG2C} * sizeof(float))) {
+ $if LOG2C >= 2:
+ $for M in range(MR):
+ psimd_store_f32(y${M}, vacc${M}x${ABC[0:4]});
+ $for C in range(4, 1 << LOG2C, 4):
+ psimd_store_f32(y${M} + ${C}, vacc${M}x${ABC[C:C+4]});
+
+ $for M in range(MR):
+ y${M} += ${1 << LOG2C};
+
+ // Shift the upper registers down for the narrower stores below.
+ $for M in range(MR):
+ $for C in range(0, 1 << (LOG2C - 1), 4):
+ vacc${M}x${ABC[C:C+4]} = vacc${M}x${ABC[C+(1<<LOG2C):C+(1<<LOG2C)+4]};
+ $elif LOG2C == 1:
+ $for M in range(MR):
+ psimd_store2_f32(y${M}, vacc${M}x${ABC[0:4]});
+
+ $for M in range(MR):
+ y${M} += 2;
+
+ $for M in range(MR):
+ vacc${M}x${ABC[0:4]} = psimd_concat_hi_f32(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]});
+ $elif LOG2C == 0:
+ $for M in range(MR):
+ psimd_store1_f32(y${M}, vacc${M}x${ABC[0:4]});
+
+ $for M in range(MR):
+ y${M} += 1;
+ }
+ }
+ $for M in range(MR):
+ x${M} = (const float*) ((uintptr_t) x${M} + x_increment);
+ y${M} = (float*) ((uintptr_t) y${M} + y_increment);
+ $if M % 2 == 1:
+ if XNN_UNPREDICTABLE(m < ${MR + M+1}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+ $elif M != 0:
+ if XNN_UNPREDICTABLE(m <= ${MR + M}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+ m = doz(m, ${MR});
+ } while (m != 0);
+}
diff --git a/src/f32-vmulcaddc/scalar.c.in b/src/f32-vmulcaddc/scalar.c.in
new file mode 100644
index 0000000..df86ff2
--- /dev/null
+++ b/src/f32-vmulcaddc/scalar.c.in
@@ -0,0 +1,137 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+$assert CR > 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/vmulcaddc.h>
+
+
+// Template for f32 vmulcaddc scalar micro-kernels:
+//   y[r][c] = clamp(x[r][c] * scale[c] + bias[c], min, max)
+// processing MR rows and CR channels per pass.
+// NOTE(review): weights appear packed as CR scales followed by CR biases per
+// channel group -- confirm against the packing routine.
+// Fixes in the remainder tail (dead code for CR == 1, but wrong for CR > 1):
+//  - stores went through undefined pointer c${M} instead of y${M};
+//  - vacc names dropped the ABC[...] digit mapping and the 'x' separator;
+//  - the bit test compared the byte count c against an element count
+//    (missing "* sizeof(float)", as used by the SIMD templates);
+//  - x${M} was advanced by CR floats although x_increment is computed with
+//    the masked convention (would overshoot every row group by CR floats).
+void xnn_f32_vmulcaddc_ukernel_c${CR}__scalar_x${MR}(
+ size_t m,
+ size_t channels,
+ const float*restrict x,
+ size_t x_stride,
+ const float*restrict weights,
+ float*restrict y,
+ size_t y_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+ assert(channels != 0);
+ assert(channels % sizeof(float) == 0);
+
+ // x pointers are not advanced in the remainder path, so only the part of
+ // channels rounded down to whole CR-float groups is subtracted back out.
+ const size_t x_increment = x_stride * ${MR} - (channels & -(${CR} * sizeof(float)));
+ const size_t y_increment = y_stride * ${MR} - channels;
+
+ const float* x0 = x;
+ float* y0 = y;
+ $for M in range(1, MR):
+ const float* x${M} = (const float*) ((uintptr_t) x${M-1} + x_stride);
+ float* y${M} = (float*) ((uintptr_t) y${M-1} + y_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(m <= ${M}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(m < ${M+1}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+
+ const float vmin = params->scalar.min;
+ const float vmax = params->scalar.max;
+ do {
+ const float* w = weights;
+ size_t c = channels;
+ for (; c >= ${CR} * sizeof(float); c -= ${CR} * sizeof(float)) {
+ $for C in range(CR):
+ const float vscale${ABC[C]} = w[${C}];
+
+ $for M in range(MR):
+ $for C in range(CR):
+ const float vx${M}x${ABC[C]} = x${M}[${C}];
+ x${M} += ${CR};
+
+ $for C in range(CR):
+ const float vbias${ABC[C]} = w[${C + CR}];
+
+ $for M in range(MR):
+ $for C in range(CR):
+ float vacc${M}x${ABC[C]} = vx${M}x${ABC[C]} * vscale${ABC[C]} + vbias${ABC[C]};
+
+ $for M in range(MR):
+ $for C in range(CR):
+ vacc${M}x${ABC[C]} = math_max_f32(vacc${M}x${ABC[C]}, vmin);
+
+ $for M in range(MR):
+ $for C in range(CR):
+ vacc${M}x${ABC[C]} = math_min_f32(vacc${M}x${ABC[C]}, vmax);
+
+ $for M in range(MR):
+ $for C in range(CR):
+ y${M}[${C}] = vacc${M}x${ABC[C]};
+ y${M} += ${CR};
+
+ w += ${CR * 2};
+ }
+ // Remainder: 1 to CR-1 trailing channels. x pointers are deliberately not
+ // advanced here -- x_increment above already skips the unprocessed tail.
+ // NOTE(review): for CR > 1 this reads CR inputs although fewer remain in
+ // the row; confirm that over-read past the row tail is acceptable.
+ if XNN_UNLIKELY(c != 0) {
+ $for C in range(CR):
+ const float vscale${ABC[C]} = w[${C}];
+
+ $for M in range(MR):
+ $for C in range(CR):
+ const float vx${M}x${ABC[C]} = x${M}[${C}];
+
+ $for C in range(CR):
+ const float vbias${ABC[C]} = w[${C + CR}];
+
+ $for M in range(MR):
+ $for C in range(CR):
+ float vacc${M}x${ABC[C]} = vx${M}x${ABC[C]} * vscale${ABC[C]} + vbias${ABC[C]};
+
+ $for M in range(MR):
+ $for C in range(CR):
+ vacc${M}x${ABC[C]} = math_max_f32(vacc${M}x${ABC[C]}, vmin);
+
+ $for M in range(MR):
+ $for C in range(CR):
+ vacc${M}x${ABC[C]} = math_min_f32(vacc${M}x${ABC[C]}, vmax);
+
+ w += ${CR * 2};
+
+ $for LOG2C in reversed(range(CR.bit_length() - 1)):
+ if (c & (${1 << LOG2C} * sizeof(float))) {
+ $for M in range(MR):
+ $for C in range(1 << LOG2C):
+ y${M}[${C}] = vacc${M}x${ABC[C]};
+ $if LOG2C != 0:
+ $for C in range(1 << (LOG2C - 1)):
+ vacc${M}x${ABC[C]} = vacc${M}x${ABC[C + (1 << LOG2C)]};
+ y${M} += ${1 << LOG2C};
+ }
+ }
+ $for M in range(MR):
+ x${M} = (const float*) ((uintptr_t) x${M} + x_increment);
+ y${M} = (float*) ((uintptr_t) y${M} + y_increment);
+ $if M % 2 == 1:
+ if XNN_UNPREDICTABLE(m < ${MR + M+1}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+ $elif M != 0:
+ if XNN_UNPREDICTABLE(m <= ${MR + M}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+ m = doz(m, ${MR});
+ } while (m != 0);
+}
diff --git a/src/f32-vmulcaddc/sse.c.in b/src/f32-vmulcaddc/sse.c.in
new file mode 100644
index 0000000..e358ec8
--- /dev/null
+++ b/src/f32-vmulcaddc/sse.c.in
@@ -0,0 +1,172 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+$assert CR % 4 == 0
+$ABC = "0123456789ABCDEFGHIJKLMN"
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/math.h>
+#include <xnnpack/vmulcaddc.h>
+
+
+// Template for f32 vmulcaddc SSE micro-kernels:
+//   y[r][c] = clamp(x[r][c] * scale[c] + bias[c], min, max)
+// processing MR rows and CR channels per pass.
+// NOTE(review): weights appear packed as CR scales followed by CR biases per
+// channel group -- confirm against the packing routine.
+// Fix: the register shift-down statement in the remainder tail was missing
+// its trailing semicolon, producing non-compiling code for CR >= 8 variants.
+void xnn_f32_vmulcaddc_ukernel_c${CR}__sse_x${MR}(
+ size_t m,
+ size_t channels,
+ const float*restrict x,
+ size_t x_stride,
+ const float*restrict weights,
+ float*restrict y,
+ size_t y_stride,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(m != 0);
+ assert(channels != 0);
+ assert(channels % sizeof(float) == 0);
+
+ // x pointers are not advanced in the remainder path, so only the part of
+ // channels rounded down to whole CR-float groups is subtracted back out.
+ const size_t x_increment = x_stride * ${MR} - (channels & -(${CR} * sizeof(float)));
+ const size_t y_increment = y_stride * ${MR} - channels;
+
+ const float* x0 = x;
+ float* y0 = y;
+ $for M in range(1, MR):
+ const float* x${M} = (const float*) ((uintptr_t) x${M-1} + x_stride);
+ float* y${M} = (float*) ((uintptr_t) y${M-1} + y_stride);
+ $if M % 2 == 0:
+ if XNN_UNPREDICTABLE(m <= ${M}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+ $else:
+ if XNN_UNPREDICTABLE(m < ${M+1}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+
+ const __m128 vmin = _mm_load_ps(params->sse.min);
+ const __m128 vmax = _mm_load_ps(params->sse.max);
+ do {
+ const float* w = weights;
+ size_t c = channels;
+ for (; c >= ${CR} * sizeof(float); c -= ${CR} * sizeof(float)) {
+ const __m128 vscale${ABC[0:4]} = _mm_load_ps(w);
+ $for C in range(4, CR, 4):
+ const __m128 vscale${ABC[C:C+4]} = _mm_load_ps(w + ${C});
+
+ $for M in range(MR):
+ const __m128 vx${M}x${ABC[0:4]} = _mm_loadu_ps(x${M});
+ $for C in range(4, CR, 4):
+ const __m128 vx${M}x${ABC[C:C+4]} = _mm_loadu_ps(x${M} + ${C});
+ x${M} += ${CR};
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ __m128 vacc${M}x${ABC[C:C+4]} = _mm_mul_ps(vx${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]});
+
+ $for C in range(0, CR, 4):
+ const __m128 vbias${ABC[C:C+4]} = _mm_load_ps(w + ${C + CR});
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = _mm_add_ps(vacc${M}x${ABC[C:C+4]}, vbias${ABC[C:C+4]});
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = _mm_max_ps(vacc${M}x${ABC[C:C+4]}, vmin);
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = _mm_min_ps(vacc${M}x${ABC[C:C+4]}, vmax);
+
+ $for M in range(MR):
+ _mm_storeu_ps(y${M}, vacc${M}x${ABC[0:4]});
+ $for C in range(4, CR, 4):
+ _mm_storeu_ps(y${M} + ${C}, vacc${M}x${ABC[C:C+4]});
+ y${M} += ${CR};
+
+ w += ${CR * 2};
+ }
+ // Remainder: 1 to CR-1 trailing channels; x pointers are left in place
+ // (x_increment already accounts for the unprocessed tail).
+ if XNN_UNLIKELY(c != 0) {
+ const __m128 vscale${ABC[0:4]} = _mm_load_ps(w);
+ $for C in range(4, CR, 4):
+ const __m128 vscale${ABC[C:C+4]} = _mm_load_ps(w + ${C});
+
+ $for M in range(MR):
+ const __m128 vx${M}x${ABC[0:4]} = _mm_loadu_ps(x${M});
+ $for C in range(4, CR, 4):
+ const __m128 vx${M}x${ABC[C:C+4]} = _mm_loadu_ps(x${M} + ${C});
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ __m128 vacc${M}x${ABC[C:C+4]} = _mm_mul_ps(vx${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]});
+
+ $for C in range(0, CR, 4):
+ const __m128 vbias${ABC[C:C+4]} = _mm_load_ps(w + ${C + CR});
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = _mm_add_ps(vacc${M}x${ABC[C:C+4]}, vbias${ABC[C:C+4]});
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = _mm_max_ps(vacc${M}x${ABC[C:C+4]}, vmin);
+
+ $for M in range(MR):
+ $for C in range(0, CR, 4):
+ vacc${M}x${ABC[C:C+4]} = _mm_min_ps(vacc${M}x${ABC[C:C+4]}, vmax);
+
+ w += ${CR * 2};
+
+ $for LOG2C in reversed(range(CR.bit_length())):
+ $if CR != 1 << LOG2C:
+ if (c & (${1 << LOG2C} * sizeof(float))) {
+ $if LOG2C >= 2:
+ $for M in range(MR):
+ _mm_storeu_ps(y${M}, vacc${M}x${ABC[0:4]});
+ $for C in range(4, 1 << LOG2C, 4):
+ _mm_storeu_ps(y${M} + ${C}, vacc${M}x${ABC[C:C+4]});
+
+ $for M in range(MR):
+ y${M} += ${1 << LOG2C};
+
+ // Shift the upper registers down for the narrower stores below.
+ $for M in range(MR):
+ $for C in range(0, 1 << (LOG2C - 1), 4):
+ vacc${M}x${ABC[C:C+4]} = vacc${M}x${ABC[C+(1<<LOG2C):C+(1<<LOG2C)+4]};
+ $elif LOG2C == 1:
+ $for M in range(MR):
+ _mm_storel_pi((__m64*) y${M}, vacc${M}x${ABC[0:4]});
+
+ $for M in range(MR):
+ y${M} += 2;
+
+ $for M in range(MR):
+ vacc${M}x${ABC[0:4]} = _mm_movehl_ps(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]});
+ $elif LOG2C == 0:
+ $for M in range(MR):
+ _mm_store_ss(y${M}, vacc${M}x${ABC[0:4]});
+
+ $for M in range(MR):
+ y${M} += 1;
+ }
+ }
+ $for M in range(MR):
+ x${M} = (const float*) ((uintptr_t) x${M} + x_increment);
+ y${M} = (float*) ((uintptr_t) y${M} + y_increment);
+ $if M % 2 == 1:
+ if XNN_UNPREDICTABLE(m < ${MR + M+1}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+ $elif M != 0:
+ if XNN_UNPREDICTABLE(m <= ${MR + M}) {
+ x${M} = x${M-1};
+ y${M} = y${M-1};
+ }
+ m = doz(m, ${MR});
+ } while (m != 0);
+}
diff --git a/src/f32-vsub/psimd.c b/src/f32-vsub/psimd.c
new file mode 100644
index 0000000..bdf701a
--- /dev/null
+++ b/src/f32-vsub/psimd.c
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/vsub.h>
+
+
+// Elementwise subtraction with output clamping:
+//   y[i] = min(max(a[i] - b[i], min), max)
+// psimd variant; n is a byte count and must be a non-zero multiple of
+// sizeof(float). Processes 8 floats per main-loop iteration.
+void xnn_f32_vsub_ukernel__psimd(
+ size_t n,
+ const float* a,
+ const float* b,
+ float* y,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(n != 0);
+ assert(n % sizeof(float) == 0);
+
+ const psimd_f32 voutput_min = psimd_load_splat_f32(&params->scalar.min);
+ const psimd_f32 voutput_max = psimd_load_splat_f32(&params->scalar.max);
+
+ // Main loop: two vectors (8 floats) per iteration.
+ while (n >= 8 * sizeof(float)) {
+ psimd_f32 vout_lo = psimd_sub_f32(psimd_load_f32(a), psimd_load_f32(b));
+ psimd_f32 vout_hi = psimd_sub_f32(psimd_load_f32(a + 4), psimd_load_f32(b + 4));
+ a += 8;
+ b += 8;
+
+ vout_lo = psimd_max_f32(vout_lo, voutput_min);
+ vout_hi = psimd_max_f32(vout_hi, voutput_min);
+ vout_lo = psimd_min_f32(vout_lo, voutput_max);
+ vout_hi = psimd_min_f32(vout_hi, voutput_max);
+
+ psimd_store_f32(y, vout_lo);
+ psimd_store_f32(y + 4, vout_hi);
+ y += 8;
+
+ n -= 8 * sizeof(float);
+ }
+ // One full vector, if present.
+ if (n >= 4 * sizeof(float)) {
+ psimd_f32 vout = psimd_sub_f32(psimd_load_f32(a), psimd_load_f32(b));
+ a += 4;
+ b += 4;
+ vout = psimd_min_f32(psimd_max_f32(vout, voutput_min), voutput_max);
+ psimd_store_f32(y, vout);
+ y += 4;
+ n -= 4 * sizeof(float);
+ }
+ // Tail of 1-3 floats: compute a full vector, store it piecewise.
+ if (n != 0) {
+ psimd_f32 vout = psimd_sub_f32(psimd_load_f32(a), psimd_load_f32(b));
+ vout = psimd_min_f32(psimd_max_f32(vout, voutput_min), voutput_max);
+ if (n & (2 * sizeof(float))) {
+ psimd_store2_f32(y, vout);
+ vout = psimd_concat_hi_f32(vout, vout);
+ y += 2;
+ }
+ if (n & (1 * sizeof(float))) {
+ psimd_store1_f32(y, vout);
+ }
+ }
+}
diff --git a/src/f32-vsub/scalar.c b/src/f32-vsub/scalar.c
new file mode 100644
index 0000000..a53b419
--- /dev/null
+++ b/src/f32-vsub/scalar.c
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/math.h>
+#include <xnnpack/vsub.h>
+
+
+// Elementwise subtraction with output clamping:
+//   y[i] = min(max(a[i] - b[i], min), max)
+// Scalar variant; n is a byte count and must be a non-zero multiple of
+// sizeof(float). Processes two elements per main-loop iteration.
+void xnn_f32_vsub_ukernel__scalar(
+ size_t n,
+ const float* a,
+ const float* b,
+ float* y,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(n != 0);
+ assert(n % sizeof(float) == 0);
+
+ const float voutput_min = params->scalar.min;
+ const float voutput_max = params->scalar.max;
+
+ while (n >= 2 * sizeof(float)) {
+ const float vdiff0 = a[0] - b[0];
+ const float vdiff1 = a[1] - b[1];
+ a += 2;
+ b += 2;
+
+ y[0] = math_min_f32(math_max_f32(vdiff0, voutput_min), voutput_max);
+ y[1] = math_min_f32(math_max_f32(vdiff1, voutput_min), voutput_max);
+ y += 2;
+
+ n -= 2 * sizeof(float);
+ }
+ // At most one trailing element (n is a multiple of sizeof(float)).
+ if XNN_UNLIKELY(n != 0) {
+ const float vdiff = *a - *b;
+ *y = math_min_f32(math_max_f32(vdiff, voutput_min), voutput_max);
+ }
+}
diff --git a/src/f32-vsub/sse.c b/src/f32-vsub/sse.c
new file mode 100644
index 0000000..0722622
--- /dev/null
+++ b/src/f32-vsub/sse.c
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/vsub.h>
+
+
+// Elementwise subtraction with output clamping:
+//   y[i] = min(max(a[i] - b[i], min), max)
+// SSE variant; n is a byte count and must be a non-zero multiple of
+// sizeof(float). Processes 8 floats per main-loop iteration.
+void xnn_f32_vsub_ukernel__sse(
+ size_t n,
+ const float* a,
+ const float* b,
+ float* y,
+ const union xnn_f32_output_params params[restrict static 1])
+{
+ assert(n != 0);
+ assert(n % sizeof(float) == 0);
+
+ const __m128 voutput_min = _mm_load_ps(params->sse.min);
+ const __m128 voutput_max = _mm_load_ps(params->sse.max);
+
+ // Main loop: two vectors (8 floats) per iteration.
+ while (n >= 8 * sizeof(float)) {
+ __m128 vout_lo = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
+ __m128 vout_hi = _mm_sub_ps(_mm_loadu_ps(a + 4), _mm_loadu_ps(b + 4));
+ a += 8;
+ b += 8;
+
+ vout_lo = _mm_max_ps(vout_lo, voutput_min);
+ vout_hi = _mm_max_ps(vout_hi, voutput_min);
+ vout_lo = _mm_min_ps(vout_lo, voutput_max);
+ vout_hi = _mm_min_ps(vout_hi, voutput_max);
+
+ _mm_storeu_ps(y, vout_lo);
+ _mm_storeu_ps(y + 4, vout_hi);
+ y += 8;
+
+ n -= 8 * sizeof(float);
+ }
+ // One full vector, if present.
+ if (n >= 4 * sizeof(float)) {
+ __m128 vout = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
+ a += 4;
+ b += 4;
+ vout = _mm_min_ps(_mm_max_ps(vout, voutput_min), voutput_max);
+ _mm_storeu_ps(y, vout);
+ y += 4;
+ n -= 4 * sizeof(float);
+ }
+ // Tail of 1-3 floats: compute a full vector, store it piecewise.
+ if (n != 0) {
+ __m128 vout = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
+ vout = _mm_min_ps(_mm_max_ps(vout, voutput_min), voutput_max);
+ if (n & (2 * sizeof(float))) {
+ _mm_storel_pi((__m64*) y, vout);
+ vout = _mm_movehl_ps(vout, vout);
+ y += 2;
+ }
+ if (n & (1 * sizeof(float))) {
+ _mm_store_ss(y, vout);
+ }
+ }
+}
diff --git a/src/fully-connected.c b/src/fully-connected.c
new file mode 100644
index 0000000..b1e9dc9
--- /dev/null
+++ b/src/fully-connected.c
@@ -0,0 +1,437 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+#include <math.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/requantization.h>
+#include <xnnpack/log.h>
+#include <xnnpack/math.h>
+#include <xnnpack/pack.h>
+#include <xnnpack/params.h>
+
+
+// Creates a Q8 (quantized uint8) Fully Connected operator.
+//
+// Validates the quantization parameters, packs the
+// [output_channels x input_channels] kernel and the int32 bias into the layout
+// expected by the Q8 GEMM micro-kernels, and precomputes the requantization
+// parameters. The caller retains ownership of `kernel` and `bias`; their
+// contents are copied during packing.
+enum xnn_status xnn_create_fully_connected_nc_q8(
+  size_t input_channels,
+  size_t output_channels,
+  size_t input_stride,
+  size_t output_stride,
+  uint8_t input_zero_point,
+  float input_scale,
+  uint8_t kernel_zero_point,
+  float kernel_scale,
+  const uint8_t* kernel,
+  const int32_t* bias,
+  uint8_t output_zero_point,
+  float output_scale,
+  uint8_t output_min,
+  uint8_t output_max,
+  uint32_t flags,
+  xnn_operator_t* fully_connected_op_out)
+{
+  xnn_operator_t fully_connected_op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to create Fully Connected operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  status = xnn_status_invalid_parameter;
+
+  if (input_channels == 0) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with %zu input channels: number of channels must be non-zero",
+      input_channels);
+    goto error;
+  }
+
+  if (output_channels == 0) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with %zu output channels: number of channels must be non-zero",
+      output_channels);
+    goto error;
+  }
+
+  if (input_stride < input_channels) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with input element stride of %zu: "
+      "stride must be at least as large as the number of input channels (%zu)",
+      input_stride, input_channels);
+    goto error;
+  }
+
+  if (output_stride < output_channels) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with output element stride of %zu: "
+      "stride must be at least as large as the number of output channels (%zu)",
+      output_stride, output_channels);
+    goto error;
+  }
+
+  if (input_scale <= 0.0f || !isnormal(input_scale)) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with %.7g input scale: scale must be finite, normalized, and positive",
+      input_scale);
+    goto error;
+  }
+
+  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with %.7g kernel scale: scale must be finite, normalized, and positive",
+      kernel_scale);
+    goto error;
+  }
+
+  if (output_scale <= 0.0f || !isnormal(output_scale)) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with %.7g output scale: scale must be finite, normalized, and positive",
+      output_scale);
+    goto error;
+  }
+
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with [%" PRIu8 ", %" PRIu8 "] output range: "
+      "range min must be below range max",
+      output_min, output_max);
+    goto error;
+  }
+
+  status = xnn_status_unsupported_parameter;
+
+  // The Q8 requantization scheme requires the combined scale to be < 1.0.
+  const float requantization_scale = input_scale * kernel_scale / output_scale;
+  if (requantization_scale >= 1.0f) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
+      "requantization scale %.7g is greater or equal to 1.0",
+      input_scale, kernel_scale, output_scale, requantization_scale);
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  fully_connected_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+  if (fully_connected_op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for Fully Connected operator descriptor", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  const uint32_t nr = xnn_params.q8.gemm.nr;
+  const uint32_t kr = UINT32_C(1) << xnn_params.q8.gemm.log2_kr;
+
+  // Packed dimensions are kept in size_t: the channel counts are size_t, and
+  // rounding them up in a uint32_t could truncate the packed-weights size
+  // computation for very large layers on 64-bit targets.
+  const size_t n_stride = round_up(output_channels, nr);
+  const size_t k_stride = round_up_po2(input_channels, kr);
+  const size_t packed_weights_size = n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t));
+
+  fully_connected_op->packed_weights = xnn_allocate_memory(packed_weights_size);
+  if (fully_connected_op->packed_weights == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for packed weights", packed_weights_size);
+    goto error;
+  }
+  // Fill padding with the kernel zero point so padded columns contribute zero
+  // to the quantized accumulation.
+  memset(fully_connected_op->packed_weights, kernel_zero_point, packed_weights_size);
+
+  xnn_pack_q8_gemm_goi_w(
+    1, output_channels, input_channels,
+    nr, kr,
+    input_zero_point, kernel_zero_point,
+    kernel, bias,
+    fully_connected_op->packed_weights);
+
+  fully_connected_op->group_input_channels = input_channels;
+  fully_connected_op->group_output_channels = output_channels;
+  fully_connected_op->input_pixel_stride = input_stride;
+  fully_connected_op->output_pixel_stride = output_stride;
+
+  fully_connected_op->kernel_zero_point = kernel_zero_point;
+
+  fully_connected_op->q8_gemm_params =
+    xnn_compute_q8_gemm_params(
+      input_zero_point, kernel_zero_point,
+      requantization_scale, output_zero_point, output_min, output_max);
+
+  fully_connected_op->type = xnn_operator_type_fully_connected_q8;
+
+  fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
+  fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
+    .default_function = xnn_params.q8.gemm.gemm,
+    .mr = xnn_params.q8.gemm.mr,
+    .nr = nr,
+    .kr = kr,
+  };
+
+  fully_connected_op->state = xnn_run_state_invalid;
+
+  *fully_connected_op_out = fully_connected_op;
+  return xnn_status_success;
+
+error:
+  // Safe on the NULL / partially-initialized operator.
+  xnn_delete_operator(fully_connected_op);
+  return status;
+}
+
+// Creates an F32 Fully Connected operator.
+//
+// Packs the [output_channels x input_channels] kernel and the float bias into
+// the layout expected by the F32 GEMM micro-kernels. The caller retains
+// ownership of `kernel` and `bias`; their contents are copied during packing.
+enum xnn_status xnn_create_fully_connected_nc_f32(
+  size_t input_channels,
+  size_t output_channels,
+  size_t input_stride,
+  size_t output_stride,
+  const float* kernel,
+  const float* bias,
+  float output_min,
+  float output_max,
+  uint32_t flags,
+  xnn_operator_t* fully_connected_op_out)
+{
+  xnn_operator_t fully_connected_op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to create Fully Connected operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  status = xnn_status_invalid_parameter;
+
+  if (input_channels == 0) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with %zu input channels: number of channels must be non-zero",
+      input_channels);
+    goto error;
+  }
+
+  if (output_channels == 0) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with %zu output channels: number of channels must be non-zero",
+      output_channels);
+    goto error;
+  }
+
+  if (input_stride < input_channels) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with input element stride of %zu: "
+      "stride must be at least as large as the number of input channels (%zu)",
+      input_stride, input_channels);
+    goto error;
+  }
+
+  if (output_stride < output_channels) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with output element stride of %zu: "
+      "stride must be at least as large as the number of output channels (%zu)",
+      output_stride, output_channels);
+    goto error;
+  }
+
+  if (isnan(output_min)) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with NaN output lower bound: lower bound must be non-NaN");
+    goto error;
+  }
+
+  if (isnan(output_max)) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with NaN output upper bound: upper bound must be non-NaN");
+    goto error;
+  }
+
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create Fully Connected operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
+      output_min, output_max);
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  fully_connected_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+  if (fully_connected_op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for Fully Connected operator descriptor", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  const uint32_t nr = xnn_params.f32.gemm.nr;
+  const uint32_t kr = UINT32_C(1) << xnn_params.f32.gemm.log2_kr;
+
+  // Packed dimensions are kept in size_t: the channel counts are size_t, and
+  // rounding them up in a uint32_t could truncate the packed-weights size
+  // computation for very large layers on 64-bit targets.
+  const size_t n_stride = round_up(output_channels, nr);
+  const size_t k_stride = round_up_po2(input_channels, kr);
+  const size_t packed_weights_size = n_stride * (k_stride * sizeof(float) + sizeof(float));
+
+  fully_connected_op->packed_weights = xnn_allocate_memory(packed_weights_size);
+  if (fully_connected_op->packed_weights == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for packed weights", packed_weights_size);
+    goto error;
+  }
+  // Zero-fill so padded rows/columns contribute nothing to the accumulation.
+  memset(fully_connected_op->packed_weights, 0, packed_weights_size);
+
+  xnn_pack_f32_gemm_goi_w(
+    1, output_channels, input_channels,
+    nr, kr, 1 /* sr */,
+    kernel, bias,
+    fully_connected_op->packed_weights);
+
+  fully_connected_op->group_input_channels = input_channels;
+  fully_connected_op->group_output_channels = output_channels;
+  fully_connected_op->input_pixel_stride = input_stride;
+  fully_connected_op->output_pixel_stride = output_stride;
+
+  fully_connected_op->f32_output_params = xnn_compute_f32_output_params(output_min, output_max);
+
+  fully_connected_op->type = xnn_operator_type_fully_connected_f32;
+
+  fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
+  fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
+    .default_function = xnn_params.f32.gemm.gemm,
+    .mr1_function = xnn_params.f32.gemm.gemm1,
+    .mr = xnn_params.f32.gemm.mr,
+    .nr = nr,
+    .kr = kr,
+  };
+
+  fully_connected_op->state = xnn_run_state_invalid;
+
+  *fully_connected_op_out = fully_connected_op;
+  return xnn_status_success;
+
+error:
+  // Safe on the NULL / partially-initialized operator.
+  xnn_delete_operator(fully_connected_op);
+  return status;
+}
+
+// Shared setup logic for the Q8 and F32 Fully Connected entry points.
+// Treats the operation as a GEMM: [batch_size x input_channels] input times
+// packed [input_channels x output_channels] weights, parallelized as a 2D
+// tiling over (batch rows, output channels).
+// NOTE(review): log2_filter_element_size is currently unused; w_stride below
+// is computed from log2_input_element_size, which matches only because filter
+// and input elements have the same size for both Q8 and F32 - confirm.
+static enum xnn_status setup_fully_connected_nc(
+  xnn_operator_t fully_connected_op,
+  size_t batch_size,
+  const void* input,
+  void* output,
+  uint32_t log2_input_element_size,
+  uint32_t log2_filter_element_size,
+  uint32_t bias_element_size,
+  uint32_t log2_output_element_size,
+  const void* params,
+  size_t num_threads)
+{
+  // Invalidate up front so an early error leaves the operator unrunnable.
+  fully_connected_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup Fully Connected operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  // An empty batch is a successful no-op.
+  if (batch_size == 0) {
+    fully_connected_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  // Model the batch as a single "image" of height batch_size and width 1.
+  fully_connected_op->batch_size = 1;
+  fully_connected_op->input_height = batch_size;
+  fully_connected_op->input_width = 1;
+  fully_connected_op->input = input;
+
+  fully_connected_op->output_height = batch_size;
+  fully_connected_op->output_width = 1;
+  fully_connected_op->output = output;
+
+  const size_t input_channels = fully_connected_op->group_input_channels;
+  const size_t output_channels = fully_connected_op->group_output_channels;
+
+  uint32_t mr = fully_connected_op->ukernel.gemm.mr;
+  const uint32_t nr = fully_connected_op->ukernel.gemm.nr;
+
+  // For a single-row batch, prefer the specialized mr=1 micro-kernel if the
+  // GEMM configuration provides one.
+  xnn_gemm_ukernel_function gemm_ukernel = fully_connected_op->ukernel.gemm.default_function;
+  if (batch_size == 1 && fully_connected_op->ukernel.gemm.mr1_function != NULL) {
+    gemm_ukernel = fully_connected_op->ukernel.gemm.mr1_function;
+    mr = 1;
+  }
+
+  fully_connected_op->context.gemm = (struct gemm_context) {
+    .k_scaled = input_channels << log2_input_element_size,
+    // Bytes per packed weights column group: rounded-up K plus the bias slot.
+    .w_stride = (round_up_po2(input_channels, fully_connected_op->ukernel.gemm.kr) << log2_input_element_size) + bias_element_size,
+    .a = input,
+    .a_stride = fully_connected_op->input_pixel_stride << log2_input_element_size,
+    .packed_w = fully_connected_op->packed_weights,
+    .c = output,
+    .cm_stride = fully_connected_op->output_pixel_stride << log2_output_element_size,
+    .cn_stride = nr << log2_output_element_size,
+    .log2_csize = log2_output_element_size,
+    .ukernel = gemm_ukernel,
+  };
+  // params is the Q8 or F32 output-clamping parameter struct; copy by size of
+  // the context's params field.
+  memcpy(&fully_connected_op->context.gemm.params, params, sizeof(fully_connected_op->context.gemm.params));
+
+  // Shrink the output-channel tile so that, with multiple threads, there are
+  // roughly target_tiles_per_thread tiles per thread for load balancing,
+  // while keeping the tile a multiple of nr.
+  size_t nc = output_channels;
+  if (num_threads > 1) {
+    const size_t num_other_tiles = divide_round_up(batch_size, mr);
+    const size_t target_tiles_per_thread = 5;
+    const size_t max_nc = divide_round_up(output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
+    if (max_nc < nc) {
+      nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
+    }
+  }
+  fully_connected_op->compute.type = xnn_parallelization_type_2d_tile_2d;
+  fully_connected_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
+  fully_connected_op->compute.range[0] = batch_size;
+  fully_connected_op->compute.range[1] = output_channels;
+  fully_connected_op->compute.tile[0] = mr;
+  fully_connected_op->compute.tile[1] = nc;
+  fully_connected_op->state = xnn_run_state_ready;
+
+  return xnn_status_success;
+}
+
+// Q8 entry point: validates the operator type, then delegates to the shared
+// setup with 1-byte (log2 = 0) input/filter/output elements and int32 biases.
+enum xnn_status xnn_setup_fully_connected_nc_q8(
+  xnn_operator_t fully_connected_op,
+  size_t batch_size,
+  const uint8_t* input,
+  uint8_t* output,
+  pthreadpool_t threadpool)
+{
+  // Reject operators created by a different factory function.
+  if (fully_connected_op->type != xnn_operator_type_fully_connected_q8) {
+    xnn_log_error("failed to setup Fully Connected (Q8) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+
+  const size_t num_threads = pthreadpool_get_threads_count(threadpool);
+  return setup_fully_connected_nc(
+    fully_connected_op, batch_size,
+    input, output,
+    0 /* log2(sizeof(uint8_t)) */,
+    0 /* log2(sizeof(uint8_t)) */,
+    sizeof(int32_t) /* bias element */,
+    0 /* log2(sizeof(uint8_t)) */,
+    &fully_connected_op->q8_gemm_params,
+    num_threads);
+}
+
+// F32 entry point: validates the operator type, then delegates to the shared
+// setup with 4-byte (log2 = 2) float elements and float biases.
+enum xnn_status xnn_setup_fully_connected_nc_f32(
+  xnn_operator_t fully_connected_op,
+  size_t batch_size,
+  const float* input,
+  float* output,
+  pthreadpool_t threadpool)
+{
+  // Reject operators created by a different factory function.
+  if (fully_connected_op->type != xnn_operator_type_fully_connected_f32) {
+    xnn_log_error("failed to setup Fully Connected (F32) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+
+  const size_t num_threads = pthreadpool_get_threads_count(threadpool);
+  return setup_fully_connected_nc(
+    fully_connected_op, batch_size,
+    input, output,
+    2 /* log2(sizeof(float)) */,
+    2 /* log2(sizeof(float)) */,
+    sizeof(float) /* bias element */,
+    2 /* log2(sizeof(float)) */,
+    &fully_connected_op->f32_output_params,
+    num_threads);
+}
diff --git a/src/global-average-pooling-spnchw.c b/src/global-average-pooling-spnchw.c
new file mode 100644
index 0000000..d85f9e2
--- /dev/null
+++ b/src/global-average-pooling-spnchw.c
@@ -0,0 +1,163 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <math.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/requantization.h>
+#include <xnnpack/log.h>
+#include <xnnpack/params.h>
+
+
+// Creates an F32 Global Average Pooling operator for the SpNCHW layout.
+enum xnn_status xnn_create_global_average_pooling_spnchw_f32(
+  size_t channels,
+  float output_min,
+  float output_max,
+  uint32_t flags,
+  xnn_operator_t* global_average_pooling_op_out)
+{
+  xnn_operator_t op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to create Global Average Pooling operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  // Parameter validation.
+  status = xnn_status_invalid_parameter;
+
+  if (channels == 0) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with %zu channels: number of channels must be non-zero",
+      channels);
+    goto error;
+  }
+
+  if (isnan(output_min)) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with NaN output lower bound: lower bound must be non-NaN");
+    goto error;
+  }
+
+  if (isnan(output_max)) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with NaN output upper bound: upper bound must be non-NaN");
+    goto error;
+  }
+
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with [%.7g, %.7g] output range: "
+      "lower bound must be below upper bound",
+      output_min, output_max);
+    goto error;
+  }
+
+  // A SpNCHW global-average-pooling micro-kernel is only available on some
+  // configurations.
+  status = xnn_status_unsupported_parameter;
+  if (xnn_params.f32.spchw_gavgpool.ukernel == NULL) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator: "
+      "only selected configurations parameters are supported");
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+  if (op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for Global Average Pooling operator descriptor", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  op->channels = channels;
+  // The NaN scale is a placeholder; the real 1/(width*height) multiplier is
+  // filled in at setup time via xnn_update_f32_gavgpool_params.
+  op->f32_gavgpool_params =
+    xnn_compute_f32_gavgpool_params(nanf(""), output_min, output_max, 0);
+
+  op->type = xnn_operator_type_global_average_pooling_spnchw_f32;
+  op->ukernel.type = xnn_ukernel_type_global_average_pooling;
+
+  op->state = xnn_run_state_invalid;
+
+  *global_average_pooling_op_out = op;
+  return xnn_status_success;
+
+error:
+  xnn_delete_operator(op);
+  return status;
+}
+
+// Configures an SpNCHW F32 Global Average Pooling operator for a concrete
+// (batch, height, width) and binds the input/output pointers. Parallelized
+// over (batch, channels).
+enum xnn_status xnn_setup_global_average_pooling_spnchw_f32(
+  xnn_operator_t global_average_pooling_op,
+  size_t batch_size,
+  size_t height,
+  size_t width,
+  const float* input,
+  float* output,
+  pthreadpool_t threadpool)
+{
+  if (global_average_pooling_op->type != xnn_operator_type_global_average_pooling_spnchw_f32) {
+    xnn_log_error("failed to setup Global Average Pooling (F32, SpNCHW) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+  // Invalidate up front so an early error leaves the operator unrunnable.
+  global_average_pooling_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup Global Average Pooling operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  if (height == 0) {
+    xnn_log_error("failed to setup Global Average Pooling operator with height %zu: height must be non-zero", height);
+    return xnn_status_invalid_parameter;
+  }
+
+  if (width == 0) {
+    xnn_log_error("failed to setup Global Average Pooling operator with width %zu: width must be non-zero", width);
+    return xnn_status_invalid_parameter;
+  }
+
+  // An empty batch is a successful no-op.
+  if (batch_size == 0) {
+    global_average_pooling_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  global_average_pooling_op->batch_size = batch_size;
+  global_average_pooling_op->input_height = height;
+  global_average_pooling_op->input_width = width;
+  global_average_pooling_op->input = input;
+  global_average_pooling_op->output = output;
+
+  // Now that the spatial size is known, fill in the 1/(width*height)
+  // averaging multiplier that was left as NaN at creation time.
+  xnn_update_f32_gavgpool_params(&global_average_pooling_op->f32_gavgpool_params,
+    1.0f / (float) (width * height), width * height);
+
+  // In NCHW each channel is a contiguous width*height plane, so channel and
+  // per-plane strides coincide; outputs are one float per channel.
+  global_average_pooling_op->context.global_average_pooling_spnchw = (struct global_average_pooling_spnchw_context) {
+    .input_elements = width * height * sizeof(float),
+    .input = input,
+    .input_channel_stride = width * height * sizeof(float),
+    .input_batch_stride = global_average_pooling_op->channels * width * height * sizeof(float),
+    .output = output,
+    .output_channel_stride = sizeof(float),
+    .output_batch_stride = global_average_pooling_op->channels * sizeof(float),
+    .ukernel = xnn_params.f32.spchw_gavgpool.ukernel,
+    .params.f32 = global_average_pooling_op->f32_gavgpool_params,
+  };
+
+  global_average_pooling_op->compute.type = xnn_parallelization_type_2d_tile_1d;
+  global_average_pooling_op->compute.task_2d_tile_1d =
+    (pthreadpool_task_2d_tile_1d_t) xnn_compute_global_average_pooling_spnchw;
+  global_average_pooling_op->compute.range[0] = batch_size;
+  global_average_pooling_op->compute.range[1] = global_average_pooling_op->channels;
+  // NOTE(review): tiling by the full channel count means each batch item is a
+  // single tile (no parallelism across channels); the commented-out
+  // xnn_params.f32.spchw_gavgpool.channel_tile suggests finer tiling was
+  // intended - confirm whether this is deliberate.
+  global_average_pooling_op->compute.tile[0] = global_average_pooling_op->channels; //xnn_params.f32.spchw_gavgpool.channel_tile;
+
+  global_average_pooling_op->state = xnn_run_state_ready;
+
+  return xnn_status_success;
+}
diff --git a/src/global-average-pooling.c b/src/global-average-pooling.c
new file mode 100644
index 0000000..79b35d1
--- /dev/null
+++ b/src/global-average-pooling.c
@@ -0,0 +1,372 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <math.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/requantization.h>
+#include <xnnpack/log.h>
+#include <xnnpack/params.h>
+
+
+// Creates a Q8 (quantized uint8) Global Average Pooling operator for the NWC
+// layout. Quantization parameters are stored on the operator; the actual
+// requantization parameters are computed at setup time, once the pooling
+// width is known.
+enum xnn_status xnn_create_global_average_pooling_nwc_q8(
+  size_t channels,
+  size_t input_stride,
+  size_t output_stride,
+  uint8_t input_zero_point,
+  float input_scale,
+  uint8_t output_zero_point,
+  float output_scale,
+  uint8_t output_min,
+  uint8_t output_max,
+  uint32_t flags,
+  xnn_operator_t* global_average_pooling_op_out)
+{
+  xnn_operator_t global_average_pooling_op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to create Global Average Pooling operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  status = xnn_status_invalid_parameter;
+
+  if (channels == 0) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with %zu channels: number of channels must be non-zero",
+      channels);
+    goto error;
+  }
+
+  if (input_stride < channels) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with input element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      input_stride, channels);
+    goto error;
+  }
+
+  if (output_stride < channels) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with output element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      output_stride, channels);
+    goto error;
+  }
+
+  if (input_scale <= 0.0f || !isnormal(input_scale)) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with %.7g input scale: "
+      "scale must be finite, normalized, and positive",
+      input_scale);
+    goto error;
+  }
+
+  if (output_scale <= 0.0f || !isnormal(output_scale)) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with %.7g output scale: "
+      "scale must be finite, normalized, and positive",
+      output_scale);
+    goto error;
+  }
+
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with [%" PRIu8 ", %" PRIu8 "] output range: "
+      "range min must be below range max",
+      output_min, output_max);
+    goto error;
+  }
+
+  status = xnn_status_unsupported_parameter;
+
+  // The Q8 average-pooling requantization presumably supports only scale
+  // ratios in [2**-8, 2**8), as enforced here - the check matches that range.
+  const float input_output_scale = input_scale / output_scale;
+  if (input_output_scale < 0x1.0p-8f || input_output_scale >= 0x1.0p+8f) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with %.7g input-to-output scale ratio: "
+      "scale ratio must be in [2**-8, 2**8) range",
+      input_output_scale);
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  global_average_pooling_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+  if (global_average_pooling_op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for Global Average Pooling operator descriptor", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  // All-zero row handed to the micro-kernels as the `zero` context pointer at
+  // setup time; padded with XNN_EXTRA_BYTES for over-reading kernels.
+  void* zero_buffer = xnn_allocate_zero_memory(channels * sizeof(uint8_t) + XNN_EXTRA_BYTES);
+  if (zero_buffer == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for Global Average Pooling zero padding",
+      channels * sizeof(uint8_t) + XNN_EXTRA_BYTES);
+    goto error;
+  }
+  global_average_pooling_op->zero_buffer = zero_buffer;
+
+  // Quantization parameters are kept raw; q8_avgpool_params are derived from
+  // them in setup, once the pooling width is known.
+  global_average_pooling_op->channels = channels;
+  global_average_pooling_op->input_pixel_stride = input_stride;
+  global_average_pooling_op->output_pixel_stride = output_stride;
+  global_average_pooling_op->input_zero_point = input_zero_point;
+  global_average_pooling_op->output_zero_point = output_zero_point;
+  global_average_pooling_op->input_scale = input_scale;
+  global_average_pooling_op->output_scale = output_scale;
+  global_average_pooling_op->output_min = output_min;
+  global_average_pooling_op->output_max = output_max;
+
+  global_average_pooling_op->type = xnn_operator_type_global_average_pooling_q8;
+  global_average_pooling_op->ukernel.type = xnn_ukernel_type_global_average_pooling;
+
+  global_average_pooling_op->state = xnn_run_state_invalid;
+
+  *global_average_pooling_op_out = global_average_pooling_op;
+  return xnn_status_success;
+
+error:
+  // Safe on the NULL / partially-initialized operator.
+  xnn_delete_operator(global_average_pooling_op);
+  return status;
+}
+
+// Creates an F32 Global Average Pooling operator for the NWC layout.
+enum xnn_status xnn_create_global_average_pooling_nwc_f32(
+  size_t channels,
+  size_t input_stride,
+  size_t output_stride,
+  float output_min,
+  float output_max,
+  uint32_t flags,
+  xnn_operator_t* global_average_pooling_op_out)
+{
+  xnn_operator_t op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to create Global Average Pooling operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  // Parameter validation.
+  status = xnn_status_invalid_parameter;
+
+  if (channels == 0) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with %zu channels: number of channels must be non-zero",
+      channels);
+    goto error;
+  }
+
+  if (input_stride < channels) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with input element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      input_stride, channels);
+    goto error;
+  }
+
+  if (output_stride < channels) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with output element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      output_stride, channels);
+    goto error;
+  }
+
+  if (isnan(output_min)) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with NaN output lower bound: lower bound must be non-NaN");
+    goto error;
+  }
+
+  if (isnan(output_max)) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with NaN output upper bound: upper bound must be non-NaN");
+    goto error;
+  }
+
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create Global Average Pooling operator with [%.7g, %.7g] output range: "
+      "lower bound must be below upper bound",
+      output_min, output_max);
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+  if (op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for Global Average Pooling operator descriptor", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  // All-zero row handed to the micro-kernels as the `zero` context pointer at
+  // setup time; padded with XNN_EXTRA_BYTES for over-reading kernels.
+  void* zero_buffer = xnn_allocate_zero_memory(channels * sizeof(float) + XNN_EXTRA_BYTES);
+  if (zero_buffer == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for Global Average Pooling zero padding",
+      channels * sizeof(float) + XNN_EXTRA_BYTES);
+    goto error;
+  }
+  op->zero_buffer = zero_buffer;
+
+  op->channels = channels;
+  op->input_pixel_stride = input_stride;
+  op->output_pixel_stride = output_stride;
+  // The NaN scale is a placeholder; the real 1/width averaging multiplier is
+  // filled in at setup time via xnn_update_f32_avgpool_params.
+  op->f32_avgpool_params =
+    xnn_compute_f32_avgpool_params(nanf(""), output_min, output_max);
+
+  op->type = xnn_operator_type_global_average_pooling_f32;
+  op->ukernel.type = xnn_ukernel_type_global_average_pooling;
+
+  op->state = xnn_run_state_invalid;
+
+  *global_average_pooling_op_out = op;
+  return xnn_status_success;
+
+error:
+  xnn_delete_operator(op);
+  return status;
+}
+
+// Configures a Q8 NWC Global Average Pooling operator for a concrete
+// (batch_size, width), derives the requantization parameters from the stored
+// quantization values, and selects a unipass or multipass micro-kernel based
+// on the pooling width. Parallelized over batch items.
+enum xnn_status xnn_setup_global_average_pooling_nwc_q8(
+  xnn_operator_t global_average_pooling_op,
+  size_t batch_size,
+  size_t width,
+  const uint8_t* input,
+  uint8_t* output,
+  pthreadpool_t threadpool)
+{
+  if (global_average_pooling_op->type != xnn_operator_type_global_average_pooling_q8) {
+    xnn_log_error("failed to setup Global Average Pooling (Q8) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+  // Invalidate up front so an early error leaves the operator unrunnable.
+  global_average_pooling_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup Global Average Pooling operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  if (width == 0) {
+    xnn_log_error("failed to setup Global Average Pooling operator with width %zu: width must be non-zero", width);
+    return xnn_status_invalid_parameter;
+  }
+
+  // An empty batch is a successful no-op.
+  if (batch_size == 0) {
+    global_average_pooling_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  global_average_pooling_op->batch_size = batch_size;
+  global_average_pooling_op->input_width = width;
+  global_average_pooling_op->input = input;
+  global_average_pooling_op->output = output;
+
+  // Requantization parameters: the bias pre-subtracts width * input_zero_point
+  // from the accumulated sum, and the scale folds the 1/width averaging into
+  // the input/output scale ratio.
+  // NOTE(review): the (int32_t) width cast presumably assumes
+  // width * input_zero_point fits in int32 - confirm the supported width range.
+  global_average_pooling_op->q8_avgpool_params =
+    xnn_compute_q8_avgpool_params(
+      -(int32_t) width * (int32_t) (uint32_t) global_average_pooling_op->input_zero_point,
+      global_average_pooling_op->input_scale / (global_average_pooling_op->output_scale * (float) width),
+      global_average_pooling_op->output_zero_point,
+      global_average_pooling_op->output_min,
+      global_average_pooling_op->output_max);
+
+  const size_t input_stride_in_bytes = global_average_pooling_op->input_pixel_stride * sizeof(uint8_t);
+  const size_t channels = global_average_pooling_op->channels;
+  global_average_pooling_op->context.global_average_pooling = (struct global_average_pooling_context) {
+    .input = input,
+    .zero = global_average_pooling_op->zero_buffer,
+    .input_pixel_stride = input_stride_in_bytes,
+    .input_batch_stride = input_stride_in_bytes * width,
+    .input_elements = width,
+    .channels = channels,
+    .output = output,
+    .output_batch_stride = global_average_pooling_op->output_pixel_stride * sizeof(uint8_t),
+    .params.q8 = global_average_pooling_op->q8_avgpool_params,
+  };
+  global_average_pooling_op->compute.type = xnn_parallelization_type_1d;
+  global_average_pooling_op->compute.range[0] = batch_size;
+
+  // Narrow pooling windows fit a single accumulation pass (unipass); wider
+  // ones need the multipass kernel with intermediate accumulation.
+  if (width <= xnn_params.q8.gavgpool.mr) {
+    global_average_pooling_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_global_average_pooling_unipass;
+    global_average_pooling_op->context.global_average_pooling.unipass_ukernel = xnn_params.q8.gavgpool.up;
+  } else {
+    global_average_pooling_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_global_average_pooling_multipass;
+    global_average_pooling_op->context.global_average_pooling.multipass_ukernel = xnn_params.q8.gavgpool.mp;
+  }
+  global_average_pooling_op->state = xnn_run_state_ready;
+
+  return xnn_status_success;
+}
+
+// Configures an F32 NWC Global Average Pooling operator for a concrete
+// (batch_size, width) and selects a unipass or multipass micro-kernel based
+// on the pooling width. Parallelized over batch items.
+enum xnn_status xnn_setup_global_average_pooling_nwc_f32(
+  xnn_operator_t global_average_pooling_op,
+  size_t batch_size,
+  size_t width,
+  const float* input,
+  float* output,
+  pthreadpool_t threadpool)
+{
+  if (global_average_pooling_op->type != xnn_operator_type_global_average_pooling_f32) {
+    xnn_log_error("failed to setup Global Average Pooling (F32) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+  // Invalidate up front so an early error leaves the operator unrunnable.
+  global_average_pooling_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup Global Average Pooling operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  if (width == 0) {
+    xnn_log_error("failed to setup Global Average Pooling operator with width %zu: width must be non-zero", width);
+    return xnn_status_invalid_parameter;
+  }
+
+  // An empty batch is a successful no-op.
+  if (batch_size == 0) {
+    global_average_pooling_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  global_average_pooling_op->batch_size = batch_size;
+  global_average_pooling_op->input_width = width;
+  global_average_pooling_op->input = input;
+  global_average_pooling_op->output = output;
+
+  // Now that the width is known, fill in the 1/width averaging multiplier
+  // that was left as NaN at creation time.
+  xnn_update_f32_avgpool_params(&global_average_pooling_op->f32_avgpool_params, 1.0f / (float) width);
+
+  const size_t input_stride_in_bytes = global_average_pooling_op->input_pixel_stride * sizeof(float);
+  const size_t channels = global_average_pooling_op->channels;
+  global_average_pooling_op->context.global_average_pooling = (struct global_average_pooling_context) {
+    .input = input,
+    .zero = global_average_pooling_op->zero_buffer,
+    .input_pixel_stride = input_stride_in_bytes,
+    .input_batch_stride = input_stride_in_bytes * width,
+    .input_elements = width,
+    .channels = channels,
+    .output = output,
+    .output_batch_stride = global_average_pooling_op->output_pixel_stride * sizeof(float),
+    .params.f32 = global_average_pooling_op->f32_avgpool_params,
+  };
+  global_average_pooling_op->compute.type = xnn_parallelization_type_1d;
+  global_average_pooling_op->compute.range[0] = batch_size;
+
+  // Narrow pooling windows fit a single accumulation pass (unipass); wider
+  // ones need the multipass kernel with intermediate accumulation.
+  if (width <= xnn_params.f32.gavgpool.mr) {
+    global_average_pooling_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_global_average_pooling_unipass;
+    global_average_pooling_op->context.global_average_pooling.unipass_ukernel = xnn_params.f32.gavgpool.up;
+  } else {
+    global_average_pooling_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_global_average_pooling_multipass;
+    global_average_pooling_op->context.global_average_pooling.multipass_ukernel = xnn_params.f32.gavgpool.mp;
+  }
+  global_average_pooling_op->state = xnn_run_state_ready;
+
+  return xnn_status_success;
+}
diff --git a/src/hardswish.c b/src/hardswish.c
new file mode 100644
index 0000000..64b151a
--- /dev/null
+++ b/src/hardswish.c
@@ -0,0 +1,140 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <math.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <xnnpack.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/requantization.h>
+#include <xnnpack/log.h>
+
+
+// Creates a HardSwish (F32, NC layout) operator.
+//
+// Validates the channel count and element strides, allocates and
+// zero-initializes the operator structure, and records the parameters
+// needed at setup/run time. On failure the partially-constructed operator
+// is released and an error status is returned; on success the new operator
+// is stored in *hardswish_op_out.
+enum xnn_status xnn_create_hardswish_nc_f32(
+  size_t channels,
+  size_t input_stride,
+  size_t output_stride,
+  uint32_t flags,
+  xnn_operator_t* hardswish_op_out)
+{
+  xnn_operator_t op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to create HardSwish operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  // Parameter validation: each stride must cover at least one full pixel.
+  status = xnn_status_invalid_parameter;
+
+  if (channels == 0) {
+    xnn_log_error(
+      "failed to create HardSwish operator with %zu channels: number of channels must be non-zero", channels);
+    goto error;
+  }
+
+  if (input_stride < channels) {
+    xnn_log_error(
+      "failed to create HardSwish operator with input element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      input_stride, channels);
+    goto error;
+  }
+
+  if (output_stride < channels) {
+    xnn_log_error(
+      "failed to create HardSwish operator with output element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      output_stride, channels);
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  op = calloc(1, sizeof *op);
+  if (op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for xnn_operator structure", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  // Record shape/stride parameters and precompute the HardSwish constants.
+  op->channels = channels;
+  op->input_pixel_stride = input_stride;
+  op->output_pixel_stride = output_stride;
+  op->f32_hswish_params = xnn_compute_f32_hswish_params();
+
+  op->type = xnn_operator_type_hswish_f32;
+  op->ukernel.type = xnn_ukernel_type_hswish;
+
+  // The operator only becomes runnable after a successful setup call.
+  op->state = xnn_run_state_invalid;
+
+  *hardswish_op_out = op;
+  return xnn_status_success;
+
+error:
+  xnn_delete_operator(op);
+  return status;
+}
+
+// Prepares a previously created HardSwish (F32) operator to run on a batch.
+//
+// Chooses between two parallelization strategies:
+//  - contiguous: when both strides equal the channel count (densely packed
+//    rows) or there is a single row, the tensor is processed as one flat
+//    byte range, tiled in fixed-size blocks;
+//  - strided: otherwise each row is processed as an independent task.
+enum xnn_status xnn_setup_hardswish_nc_f32(
+  xnn_operator_t hardswish_op,
+  size_t batch_size,
+  const float* input,
+  float* output,
+  pthreadpool_t threadpool)
+{
+  if (hardswish_op->type != xnn_operator_type_hswish_f32) {
+    xnn_log_error("failed to setup HardSwish (F32) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+  // Invalidate up front so a failed setup leaves the operator non-runnable.
+  hardswish_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup HardSwish operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  if (batch_size == 0) {
+    // Nothing to compute; mark the operator so running it is a no-op.
+    hardswish_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  const size_t channels = hardswish_op->channels;
+  const size_t input_stride = hardswish_op->input_pixel_stride;
+  const size_t output_stride = hardswish_op->output_pixel_stride;
+  // (input_stride ^ channels) | (output_stride ^ channels) == 0 iff both
+  // strides equal the channel count, i.e. the data is fully contiguous.
+  if ((((input_stride ^ channels) | (output_stride ^ channels)) == 0) || batch_size == 1) {
+    const size_t block_size = 4096;  // bytes of work per parallel task
+    hardswish_op->context.univector_contiguous = (struct univector_contiguous_context) {
+      .x = input,
+      .x_stride = input_stride * sizeof(float),
+      .y = output,
+      .y_stride = output_stride * sizeof(float),
+      .ukernel = xnn_params.f32.hswish,
+      .params.f32_hswish = hardswish_op->f32_hswish_params,
+    };
+    hardswish_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+    hardswish_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_contiguous;
+    // Range and tile are expressed in bytes for the contiguous path.
+    hardswish_op->compute.range[0] = batch_size * channels * sizeof(float);
+    hardswish_op->compute.tile[0] = block_size;
+  } else {
+    hardswish_op->context.univector_strided = (struct univector_strided_context) {
+      .n = channels * sizeof(float),
+      .x = input,
+      .x_stride = input_stride * sizeof(float),
+      .y = output,
+      .y_stride = output_stride * sizeof(float),
+      .ukernel = xnn_params.f32.hswish,
+      .params.f32_hswish = hardswish_op->f32_hswish_params,
+    };
+    hardswish_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+    hardswish_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_strided;
+    // One task per batch row; range is a row count here, not bytes.
+    hardswish_op->compute.range[0] = batch_size;
+    hardswish_op->compute.tile[0] = 1;
+  }
+  hardswish_op->state = xnn_run_state_ready;
+
+  return xnn_status_success;
+}
diff --git a/src/im2col.c b/src/im2col.c
new file mode 100644
index 0000000..ca7a639
--- /dev/null
+++ b/src/im2col.c
@@ -0,0 +1,52 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include <xnnpack/im2col.h>
+
+
+// Packs an input image into im2col layout for GEMM-based convolution:
+// for every output pixel, the input patch covered by the kernel is copied
+// out contiguously (channels innermost), with out-of-bounds taps
+// zero-filled.
+void xnn_im2col_conv2d(
+  size_t output_height,
+  size_t output_width,
+  size_t kernel_height,
+  size_t kernel_width,
+  size_t subsampling_height,
+  size_t subsampling_width,
+  size_t dilation_height,
+  size_t dilation_width,
+  size_t input_width,
+  size_t input_padding_top,
+  size_t input_padding_left,
+  size_t group_input_channels_in_bytes,
+  size_t input_pixel_stride_in_bytes,
+  const void* input,
+  void* output)
+{
+  for (size_t output_y = 0; output_y < output_height; output_y++) {
+    for (size_t output_x = 0; output_x < output_width; output_x++) {
+      for (size_t kernel_y = 0; kernel_y < kernel_height; kernel_y++) {
+        // size_t arithmetic wraps, so a tap above the top padding
+        // underflows to a huge value and fails the bound check below.
+        const size_t input_y = output_y * subsampling_height + kernel_y * dilation_height - input_padding_top;
+        // NOTE(review): the input row is checked against output_height
+        // (and the input column against output_width below); there is no
+        // input_height parameter. This is only safe when input and output
+        // extents match - confirm intended usage with callers.
+        if (input_y < output_height) {
+          for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) {
+            const size_t input_x = output_x * subsampling_width + kernel_x * dilation_width - input_padding_left;
+            if (input_x < output_width) {
+              // Copy one pixel's worth of channels from the input image.
+              memcpy(output,
+                (const void*) ((uintptr_t) input + (input_y * input_width + input_x) * input_pixel_stride_in_bytes),
+                group_input_channels_in_bytes);
+            } else {
+              // Horizontal padding tap: emit zeros.
+              memset(output, 0, group_input_channels_in_bytes);
+            }
+            output = (void*) ((uintptr_t) output + group_input_channels_in_bytes);
+          }
+        } else {
+          // Entire kernel row falls into vertical padding: zero it at once.
+          memset(output, 0, kernel_width * group_input_channels_in_bytes);
+          output = (void*) ((uintptr_t) output + kernel_width * group_input_channels_in_bytes);
+        }
+      }
+    }
+  }
+}
diff --git a/src/indirection.c b/src/indirection.c
new file mode 100644
index 0000000..c1e7dfc
--- /dev/null
+++ b/src/indirection.c
@@ -0,0 +1,327 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <stddef.h>
+
+#include <fxdiv.h>
+
+#include <xnnpack/indirection.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/math.h>
+
+
+// Populates the indirection buffer for an IGEMM-based convolution.
+//
+// For every tile of output_tile_size output pixels and every kernel tap,
+// the buffer holds a pointer to the input pixel that tap reads, or to the
+// shared zero buffer when the tap falls into the implicit padding.
+// log2_element_size converts the element-count pixel stride into bytes.
+void xnn_indirection_init_conv2d(
+  xnn_operator_t op,
+  size_t output_tile_size,
+  uint32_t log2_element_size)
+{
+  const void** indirection_buffer = op->indirection_buffer;
+  const void* input = op->input;
+  const void* zero = op->zero_buffer;
+  const size_t input_pixel_stride = op->input_pixel_stride << log2_element_size;
+  const size_t input_height = op->input_height;
+  const size_t input_width = op->input_width;
+  const size_t output_height = op->output_height;
+  const size_t output_width = op->output_width;
+  const size_t kernel_height = op->kernel_height;
+  const size_t kernel_width = op->kernel_width;
+  const size_t stride_height = op->stride_height;
+  const size_t stride_width = op->stride_width;
+  const size_t dilation_height = op->dilation_height;
+  const size_t dilation_width = op->dilation_width;
+  const size_t input_padding_top = op->padding_top;
+  const size_t input_padding_left = op->padding_left;
+
+  const size_t output_size = output_height * output_width;
+  // Round up so the final partial tile is fully populated; its extra slots
+  // replicate the last valid output pixel (see the min() below).
+  const size_t tiled_output_size = round_up(output_size, output_tile_size);
+  const size_t kernel_size = kernel_height * kernel_width;
+
+  // fxdiv avoids a hardware divide per output pixel when splitting the
+  // linear output index into (y, x).
+  const struct fxdiv_divisor_size_t output_width_divisor = fxdiv_init_size_t(output_width);
+
+  for (size_t output_tile_start = 0; output_tile_start < tiled_output_size; output_tile_start += output_tile_size) {
+    for (size_t output_tile_offset = 0; output_tile_offset < output_tile_size; output_tile_offset++) {
+      const size_t output_index = min(output_tile_start + output_tile_offset, output_size - 1);
+      const struct fxdiv_result_size_t output_y_x = fxdiv_divide_size_t(output_index, output_width_divisor);
+      const size_t output_x = output_y_x.remainder;
+      const size_t output_y = output_y_x.quotient;
+      for (size_t kernel_y = 0; kernel_y < kernel_height; kernel_y++) {
+        // Unsigned wraparound: taps above the top padding become huge
+        // values and fail the input_height bound check.
+        const size_t input_y = output_y * stride_height + kernel_y * dilation_height - input_padding_top;
+        if (input_y < input_height) {
+          for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) {
+            const size_t input_x = output_x * stride_width + kernel_x * dilation_width - input_padding_left;
+            const size_t kernel_index = kernel_y * kernel_width + kernel_x;
+            // Tile-major layout: all pointers of one tile are contiguous,
+            // ordered tap-major with the tile offset innermost.
+            const size_t index = output_tile_start * kernel_size + kernel_index * output_tile_size + output_tile_offset;
+            if (input_x < input_width) {
+              indirection_buffer[index] = (const void*)
+                ((uintptr_t) input + (input_y * input_width + input_x) * input_pixel_stride);
+            } else {
+              indirection_buffer[index] = zero;
+            }
+          }
+        } else {
+          // Whole kernel row is in vertical padding: point every tap at
+          // the zero buffer.
+          for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) {
+            const size_t kernel_index = kernel_y * kernel_width + kernel_x;
+            const size_t index = output_tile_start * kernel_size + kernel_index * output_tile_size + output_tile_offset;
+            indirection_buffer[index] = zero;
+          }
+        }
+      }
+    }
+  }
+}
+
+// Populates the indirection buffer for a depthwise convolution, starting
+// at batch_start (earlier batches are assumed already filled).
+//
+// Index layout: one run of step_height pointers per (batch, output_y) row;
+// within a row, output_x-major, then kernel_x, with kernel_y innermost
+// (step_width pointers per output pixel).
+void xnn_indirection_init_dwconv2d(
+  xnn_operator_t op,
+  size_t batch_start,
+  size_t step_height,
+  size_t step_width,
+  uint32_t log2_element_size)
+{
+  const void** indirection_buffer = op->indirection_buffer;
+  const void* input = op->input;
+  const size_t input_pixel_stride = op->input_pixel_stride << log2_element_size;
+  const void* zero = op->zero_buffer;
+  const size_t batch_size = op->batch_size;
+  const size_t input_height = op->input_height;
+  const size_t input_width = op->input_width;
+  const size_t output_height = op->output_height;
+  const size_t output_width = op->output_width;
+  const size_t kernel_height = op->kernel_height;
+  const size_t kernel_width = op->kernel_width;
+  const size_t stride_height = op->stride_height;
+  const size_t stride_width = op->stride_width;
+  const size_t dilation_height = op->dilation_height;
+  const size_t dilation_width = op->dilation_width;
+  const size_t input_padding_top = op->padding_top;
+  const size_t input_padding_left = op->padding_left;
+
+  for (size_t batch_index = batch_start; batch_index < batch_size; batch_index++) {
+    for (size_t output_y = 0; output_y < output_height; output_y++) {
+      for (size_t kernel_y = 0; kernel_y < kernel_height; kernel_y++) {
+        // Unsigned wraparound makes taps above the top padding fail the
+        // input_height bound check.
+        const size_t input_y = output_y * stride_height + kernel_y * dilation_height - input_padding_top;
+        if (input_y < input_height) {
+          for (size_t output_x = 0; output_x < output_width; output_x++) {
+            for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) {
+              const size_t input_x = output_x * stride_width + kernel_x * dilation_width - input_padding_left;
+              const size_t index = (batch_index * output_height + output_y) * step_height + output_x * step_width * kernel_height + kernel_x * kernel_height + kernel_y;
+              if (input_x < input_width) {
+                indirection_buffer[index] =
+                  (const void*) ((uintptr_t) input + ((batch_index * input_height + input_y) * input_width + input_x) * input_pixel_stride);
+              } else {
+                // Horizontal padding tap: read from the zero buffer.
+                indirection_buffer[index] = zero;
+              }
+            }
+          }
+        } else {
+          // Whole kernel row is in vertical padding: zero-buffer each tap.
+          for (size_t output_x = 0; output_x < output_width; output_x++) {
+            for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) {
+              const size_t index = (batch_index * output_height + output_y) * step_height + output_x * step_width * kernel_height + kernel_x * kernel_height + kernel_y;
+              indirection_buffer[index] = zero;
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+// Populates the indirection buffer for a deconvolution (transposed
+// convolution) computed via IGEMM.
+//
+// Each output pixel maps back to an input coordinate via
+//   input = (output + padding - kernel * dilation) / stride,
+// and a real input pointer is stored only when that division is exact and
+// the coordinate is in range; otherwise the tap reads the zero buffer.
+void xnn_indirection_init_deconv2d(
+  xnn_operator_t op,
+  size_t output_tile_size,
+  uint32_t log2_element_size)
+{
+  const void** indirection_buffer = op->indirection_buffer;
+  const void* input = op->input;
+  const size_t input_pixel_stride = op->input_pixel_stride << log2_element_size;
+  const void* zero = op->zero_buffer;
+  const size_t input_height = op->input_height;
+  const size_t input_width = op->input_width;
+  const size_t output_height = op->output_height;
+  const size_t output_width = op->output_width;
+  const size_t kernel_height = op->kernel_height;
+  const size_t kernel_width = op->kernel_width;
+  const size_t stride_height = op->stride_height;
+  const size_t stride_width = op->stride_width;
+  const size_t dilation_height = op->dilation_height;
+  const size_t dilation_width = op->dilation_width;
+  const size_t padding_top = op->padding_top;
+  const size_t padding_left = op->padding_left;
+
+  const size_t output_size = output_height * output_width;
+  // Partial final tile replicates the last output pixel (min() below).
+  const size_t tiled_output_size = round_up(output_size, output_tile_size);
+  const size_t kernel_size = kernel_height * kernel_width;
+
+  // Precomputed magic-number divisors: three divisions per tap otherwise.
+  const struct fxdiv_divisor_size_t output_width_divisor = fxdiv_init_size_t(output_width);
+  const struct fxdiv_divisor_size_t stride_height_divisor = fxdiv_init_size_t(stride_height);
+  const struct fxdiv_divisor_size_t stride_width_divisor = fxdiv_init_size_t(stride_width);
+
+  for (size_t output_tile_start = 0; output_tile_start < tiled_output_size; output_tile_start += output_tile_size) {
+    for (size_t output_tile_offset = 0; output_tile_offset < output_tile_size; output_tile_offset++) {
+      const size_t output_index = min(output_tile_start + output_tile_offset, output_size - 1);
+      const struct fxdiv_result_size_t output_y_x = fxdiv_divide_size_t(output_index, output_width_divisor);
+      const size_t output_x = output_y_x.remainder;
+      const size_t output_y = output_y_x.quotient;
+      for (size_t kernel_y = 0; kernel_y < kernel_height; kernel_y++) {
+        // y may wrap (unsigned) when the tap reaches above the output;
+        // the range check on input_y below rejects those cases.
+        const size_t y = output_y + padding_top - kernel_y * dilation_height;
+        const size_t input_y = fxdiv_quotient_size_t(y, stride_height_divisor);
+        for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) {
+          const size_t x = output_x + padding_left - kernel_x * dilation_width;
+          const size_t input_x = fxdiv_quotient_size_t(x, stride_width_divisor);
+          const size_t kernel_index = kernel_y * kernel_width + kernel_x;
+          // Tile-major layout, matching xnn_indirection_init_conv2d.
+          const size_t index = output_tile_start * kernel_size + kernel_index * output_tile_size + output_tile_offset;
+          // The multiply-back test (input * stride == coordinate) keeps
+          // only taps that land exactly on an input pixel.
+          if (input_y * stride_height == y && input_y < input_height && input_x * stride_width == x && input_x < input_width) {
+            indirection_buffer[index] = (const void*) ((uintptr_t) input + (input_y * input_width + input_x) * input_pixel_stride);
+          } else {
+            indirection_buffer[index] = zero;
+          }
+        }
+      }
+    }
+  }
+}
+
+// Populates the indirection buffer for a deconvolution decomposed into
+// stride_height * stride_width subconvolutions, one per output phase
+// (offset_y, offset_x). Each subconvolution covers the output pixels
+// congruent to its phase and only the kernel taps that can reach them,
+// so no exactness test per tap is needed (unlike the deconv2d variant).
+//
+// NOTE(review): subconvolution_params->indirection_x_stride is read here
+// but never written - it is presumably initialized by the caller before
+// this runs; verify.
+// NOTE(review): assert() is used but <assert.h> is not included directly
+// in this file - presumably pulled in via another header; verify.
+void xnn_indirection_init_subconv2d(
+  xnn_operator_t op,
+  size_t output_tile_size,
+  uint32_t log2_element_size)
+{
+  const void** indirection_buffer = op->indirection_buffer;
+  struct subconvolution_params* subconvolution_params = op->subconvolution_buffer;
+  const void* input = op->input;
+  const size_t input_pixel_stride = op->input_pixel_stride << log2_element_size;
+  const void* zero = op->zero_buffer;
+  const size_t input_height = op->input_height;
+  const size_t input_width = op->input_width;
+  const size_t output_height = op->output_height;
+  const size_t output_width = op->output_width;
+  const size_t kernel_height = op->kernel_height;
+  const size_t kernel_width = op->kernel_width;
+  const size_t stride_height = op->stride_height;
+  const size_t stride_width = op->stride_width;
+  const size_t padding_top = op->padding_top;
+  const size_t padding_left = op->padding_left;
+
+  const size_t modulo_padding_top = padding_top % stride_height;
+  const size_t modulo_padding_left = padding_left % stride_width;
+  for (size_t offset_y = 0; offset_y < stride_height; offset_y++) {
+    // First output row belonging to this vertical phase.
+    const size_t output_y_start = subtract_modulo(offset_y, modulo_padding_top, stride_height);
+    for (size_t offset_x = 0; offset_x < stride_width; offset_x++) {
+      const size_t output_x_start = subtract_modulo(offset_x, modulo_padding_left, stride_width);
+      // Number of output columns this subconvolution produces.
+      const size_t sliced_output_width = divide_round_up(output_width - output_x_start, stride_width);
+
+      // Record where this subconvolution's pointer run begins;
+      // indirection_buffer advances as a cursor in the loops below.
+      subconvolution_params->indirection_buffer = indirection_buffer;
+      subconvolution_params->indirection_y_stride =
+        subconvolution_params->indirection_x_stride * round_up(sliced_output_width, output_tile_size);
+      ++subconvolution_params;
+
+      for (size_t output_y = output_y_start; output_y < output_height; output_y += stride_height) {
+        for (size_t output_tile_start = 0; output_tile_start < sliced_output_width; output_tile_start += output_tile_size) {
+          // Only kernel rows congruent to this phase can touch output_y.
+          for (size_t kernel_y = offset_y; kernel_y < kernel_height; kernel_y += stride_height) {
+            assert(doz(output_y + padding_top, kernel_y) % stride_height == 0);
+            // May wrap (unsigned); rejected by the input_y range check.
+            const size_t y = output_y + padding_top - kernel_y;
+            const size_t input_y = y / stride_height;
+
+            for (size_t kernel_x = offset_x; kernel_x < kernel_width; kernel_x += stride_width) {
+              for (size_t output_tile_offset = 0; output_tile_offset < output_tile_size; output_tile_offset++) {
+                // Partial final tile replicates the last output column.
+                const size_t sliced_output_x = min(output_tile_start + output_tile_offset, sliced_output_width - 1);
+                const size_t output_x = output_x_start + sliced_output_x * stride_width;
+
+                assert(doz(output_x + padding_left, kernel_x) % stride_width == 0);
+                const size_t x = output_x + padding_left - kernel_x;
+                const size_t input_x = x / stride_width;
+
+                if (input_y < input_height && input_x < input_width) {
+                  *indirection_buffer++ =
+                    (const void*) ((uintptr_t) input + (input_y * input_width + input_x) * input_pixel_stride);
+                } else {
+                  *indirection_buffer++ = zero;
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+// Populates the indirection buffer for 2D max pooling, starting at
+// batch_start (earlier batches are assumed already filled).
+//
+// For every output pixel and pooling tap, stores a pointer to the input
+// pixel it reads. Taps falling outside the image are clamped to the
+// nearest edge pixel - clamping cannot change the maximum, so no zero
+// buffer is needed. Index layout matches the dwconv2d initializer:
+// per-(batch, output_y) runs of step_height pointers, output_x-major,
+// pooling_x next, pooling_y innermost.
+void xnn_indirection_init_maxpool2d(
+  xnn_operator_t op,
+  size_t batch_start,
+  size_t step_height,
+  size_t step_width,
+  uint32_t log2_element_size)
+{
+  const void** indirection_buffer = op->indirection_buffer;
+  const void* input = op->input;
+  const size_t input_pixel_stride = op->input_pixel_stride << log2_element_size;
+  const size_t batch_size = op->batch_size;
+  const size_t input_height = op->input_height;
+  const size_t input_width = op->input_width;
+  const size_t output_height = op->output_height;
+  const size_t output_width = op->output_width;
+  const size_t pooling_height = op->kernel_height;
+  const size_t pooling_width = op->kernel_width;
+  const size_t stride_height = op->stride_height;
+  const size_t stride_width = op->stride_width;
+  const size_t dilation_height = op->dilation_height;
+  const size_t dilation_width = op->dilation_width;
+  const size_t input_padding_top = op->padding_top;
+  const size_t input_padding_left = op->padding_left;
+
+  for (size_t image = batch_start; image < batch_size; image++) {
+    for (size_t output_y = 0; output_y < output_height; output_y++) {
+      for (size_t pooling_y = 0; pooling_y < pooling_height; pooling_y++) {
+        // doz(a, b) == max(a - b, 0): taps in the top padding clamp to 0.
+        const size_t input_y = doz(output_y * stride_height + pooling_y * dilation_height, input_padding_top);
+        const size_t clamped_input_y = min(input_y, input_height - 1);
+        for (size_t output_x = 0; output_x < output_width; output_x++) {
+          for (size_t pooling_x = 0; pooling_x < pooling_width; pooling_x++) {
+            const size_t input_x = doz(output_x * stride_width + pooling_x * dilation_width, input_padding_left);
+            const size_t clamped_input_x = min(input_x, input_width - 1);
+            const size_t index = (image * output_height + output_y) * step_height + output_x * step_width * pooling_height + pooling_x * pooling_height + pooling_y;
+            // Cast through uintptr_t: arithmetic on void* is a GCC
+            // extension, and the other initializers in this file already
+            // use the portable form.
+            indirection_buffer[index] = (const void*)
+              ((uintptr_t) input + ((image * input_height + clamped_input_y) * input_width + clamped_input_x) * input_pixel_stride);
+          }
+        }
+      }
+    }
+  }
+}
+
+// Populates the indirection buffer for 2D max unpooling, starting at
+// batch_start (earlier batches are assumed already filled).
+//
+// Unpooling scatters: for every *input* pixel and pooling tap, stores a
+// pointer to the candidate *output* pixel that tap may write. Positions
+// past the output bounds are clamped to the last row/column. Layout:
+// input-pixel-major, pooling_x next, pooling_y innermost.
+void xnn_indirection_init_unpool2d(
+  xnn_operator_t op,
+  size_t batch_start,
+  uint32_t log2_element_size)
+{
+  const void** indirection_buffer = op->indirection_buffer;
+  const void* output = op->output;
+  const size_t output_pixel_stride = op->output_pixel_stride << log2_element_size;
+  const size_t batch_size = op->batch_size;
+  const size_t input_height = op->input_height;
+  const size_t input_width = op->input_width;
+  const size_t output_height = op->output_height;
+  const size_t output_width = op->output_width;
+  const size_t pooling_height = op->kernel_height;
+  const size_t pooling_width = op->kernel_width;
+  const size_t output_padding_top = op->padding_top;
+  const size_t output_padding_left = op->padding_left;
+
+  for (size_t image = batch_start; image < batch_size; image++) {
+    for (size_t input_y = 0; input_y < input_height; input_y++) {
+      for (size_t pooling_y = 0; pooling_y < pooling_height; pooling_y++) {
+        // doz() clamps positions in the top padding to 0; min() clamps
+        // positions past the bottom edge to the last output row.
+        const size_t output_y = min(doz(input_y * pooling_height + pooling_y, output_padding_top), output_height - 1);
+        for (size_t input_x = 0; input_x < input_width; input_x++) {
+          for (size_t pooling_x = 0; pooling_x < pooling_width; pooling_x++) {
+            const size_t output_x = min(doz(input_x * pooling_width + pooling_x, output_padding_left), output_width - 1);
+            // Cast through uintptr_t: arithmetic on void* is a GCC
+            // extension, and the other initializers in this file already
+            // use the portable form.
+            indirection_buffer[(((image * input_height + input_y) * input_width + input_x) * pooling_width + pooling_x) * pooling_height + pooling_y] =
+              (const void*) ((uintptr_t) output + ((image * output_height + output_y) * output_width + output_x) * output_pixel_stride);
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/src/init.c b/src/init.c
new file mode 100644
index 0000000..b0ad23d
--- /dev/null
+++ b/src/init.c
@@ -0,0 +1,969 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include <pthread.h>
+
+#include <cpuinfo.h>
+
+#include <xnnpack.h>
+#include <xnnpack/argmaxpool.h>
+#include <xnnpack/avgpool.h>
+#include <xnnpack/clamp.h>
+#include <xnnpack/conv.h>
+#include <xnnpack/dwconv.h>
+#include <xnnpack/gavgpool.h>
+#include <xnnpack/gemm.h>
+#include <xnnpack/hswish.h>
+#include <xnnpack/igemm.h>
+#include <xnnpack/log.h>
+#include <xnnpack/lut.h>
+#include <xnnpack/maxpool.h>
+#include <xnnpack/pad.h>
+#include <xnnpack/params.h>
+#include <xnnpack/pavgpool.h>
+#include <xnnpack/prelu.h>
+#include <xnnpack/rmax.h>
+#include <xnnpack/spmm.h>
+#include <xnnpack/unpool.h>
+#include <xnnpack/vadd.h>
+#include <xnnpack/vmulcaddc.h>
+#include <xnnpack/zip.h>
+
+// Assembly micro-kernels are used by default; build with
+// XNN_ENABLE_ASSEMBLY=0 to force the C/intrinsics implementations.
+#ifndef XNN_ENABLE_ASSEMBLY
+  #define XNN_ENABLE_ASSEMBLY 1
+#endif
+
+// Ensures init() runs at most once per process (via pthread_once).
+static pthread_once_t init_guard = PTHREAD_ONCE_INIT;
+
+// Global table of selected micro-kernels; stays flagged uninitialized
+// until init() completes successfully.
+struct xnn_parameters xnn_params = {
+  .initialized = false
+};
+
+// External f32 sub/min stubs for WebAssembly/PNaCl builds - presumably
+// provided by the embedding environment; confirm against build scripts.
+#if CPUINFO_ARCH_PNACL || CPUINFO_ARCH_ASMJS || CPUINFO_ARCH_WASM || CPUINFO_ARCH_WASMSIMD
+  extern uint32_t xnn_stub_wasm_f32_sub(uint32_t a, uint32_t b);
+#endif
+#if CPUINFO_ARCH_PNACL || CPUINFO_ARCH_WASM || CPUINFO_ARCH_WASMSIMD
+  extern uint32_t xnn_stub_wasm_f32_min(uint32_t a, uint32_t b);
+#endif
+
+static void init(void) {
+#if CPUINFO_ARCH_ARM
+ if (!cpuinfo_has_arm_neon()) {
+ xnn_log_error("XNNPACK initialization failed: NEON is not supported");
+ return;
+ }
+
+ /**************************** Q8 micro-kernels ****************************/
+ xnn_params.q8.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_4x8__neon,
+ .igemm = (xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_4x8__neon,
+ .mr = 4,
+ .nr = 8,
+ };
+
+#if XNN_ENABLE_ASSEMBLY
+ xnn_params.q8.dwconv[0] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_q8_dwconv_ukernel_up8x9__aarch32_neon,
+ .cr = 8,
+ .mr = 9,
+ };
+#else
+ xnn_params.q8.dwconv[0] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_q8_dwconv_ukernel_up8x9__neon,
+ .cr = 8,
+ .mr = 9,
+ };
+#endif
+ xnn_params.q8.avgpool = (struct avgpool_parameters) {
+ .up = (xnn_avgpool_up_ukernel_function) xnn_q8_avgpool_ukernel_up9__neon,
+ .mp = (xnn_avgpool_mp_ukernel_function) xnn_q8_avgpool_ukernel_mp9p8q__neon,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.q8.gavgpool = (struct gavgpool_parameters) {
+ .up = (xnn_gavgpool_up_ukernel_function) xnn_q8_gavgpool_ukernel_up7__neon,
+ .mp = (xnn_gavgpool_mp_ukernel_function) xnn_q8_gavgpool_ukernel_mp7p7q__neon,
+ .mr = 7,
+ };
+ xnn_params.q8.vadd = (xnn_vadd_ukernel_function) xnn_q8_vadd_ukernel__neon;
+
+ /**************************** U8 micro-kernels ****************************/
+ xnn_params.u8.maxpool = (struct maxpool_parameters) {
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8q__neon,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.u8.clamp = (xnn_univector_ukernel_function) xnn_u8_clamp_ukernel__neon;
+ xnn_params.u8.rmax = xnn_u8_rmax_ukernel__neon;
+ xnn_params.u8.lut32norm = xnn_u8_lut32norm_ukernel__scalar;
+
+ /**************************** X8 micro-kernels ****************************/
+ xnn_params.x8.lut = xnn_x8_lut_ukernel__scalar;
+ xnn_params.x8.zip = (struct zip_parameters) {
+ .x2 = (xnn_zipc_ukernel_function) xnn_x8_zip_x2_ukernel__neon,
+ .x3 = (xnn_zipc_ukernel_function) xnn_x8_zip_x3_ukernel__neon,
+ .x4 = (xnn_zipc_ukernel_function) xnn_x8_zip_x4_ukernel__neon,
+ .xm = (xnn_zipv_ukernel_function) xnn_x8_zip_xm_ukernel__neon,
+ };
+
+ /**************************** F32 micro-kernels ****************************/
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__neon_ld128,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__neon_ld128,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_ld64,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_ld64,
+ .mr = 4,
+ .nr = 8,
+ };
+ xnn_params.f32.gemm2 = (struct gemm_parameters) {
+ .gemm = NULL,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2__neon_ld64,
+ .mr = 4,
+ .nr = 2,
+ };
+ xnn_params.f32.dwconv[0] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up4x4__psimd,
+ .cr = 4,
+ .mr = 4,
+ };
+ xnn_params.f32.dwconv[1] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up4x9__neon,
+ .cr = 4,
+ .mr = 9,
+ };
+ xnn_params.f32.dwconv[2] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up4x25__psimd,
+ .cr = 4,
+ .mr = 25,
+ };
+ xnn_params.f32.avgpool = (struct avgpool_parameters) {
+ .up = (xnn_avgpool_up_ukernel_function) xnn_f32_avgpool_ukernel_up9__neon,
+ .mp = (xnn_avgpool_mp_ukernel_function) xnn_f32_avgpool_ukernel_mp9p8q__neon,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.pavgpool = (struct pavgpool_parameters) {
+ .up = (xnn_pavgpool_up_ukernel_function) xnn_f32_pavgpool_ukernel_up9__neon,
+ .mp = (xnn_pavgpool_mp_ukernel_function) xnn_f32_pavgpool_ukernel_mp9p8q__neon,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.gavgpool = (struct gavgpool_parameters) {
+ .up = (xnn_gavgpool_up_ukernel_function) xnn_f32_gavgpool_ukernel_up7__neon,
+ .mp = (xnn_gavgpool_mp_ukernel_function) xnn_f32_gavgpool_ukernel_mp7p7q__neon,
+ .mr = 7,
+ };
+ xnn_params.f32.maxpool = (struct maxpool_parameters) {
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8q__psimd,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.argmaxpool[0] = (struct argmaxpool_parameters) {
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up4__psimd,
+ .mr = 4,
+ };
+ xnn_params.f32.argmaxpool[1] = (struct argmaxpool_parameters) {
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up9__psimd,
+ .mr = 9,
+ };
+ xnn_params.f32.argmaxpool[2] = (struct argmaxpool_parameters) {
+ .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_mp9p8q__psimd,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__neon;
+ xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__neon;
+ xnn_params.f32.prelu = (struct prelu_parameters) {
+ .ukernel = (xnn_prelu_ukernel_function) xnn_f32_prelu_ukernel_x4__psimd,
+ .mr = 4,
+ };
+ xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__psimd;
+ xnn_params.f32.vmulcaddc = (struct vmulcaddc_parameters) {
+ .ukernel = (xnn_vmulcaddc_ukernel_function) xnn_f32_vmulcaddc_ukernel_c4__neon_x2,
+ .cr = 4,
+ .mr = 2,
+ };
+
+ /**************************** X32 micro-kernels ****************************/
+ xnn_params.x32.pad = (struct pad_parameters) {
+ .ukernel = xnn_x32_pad_x2__neon,
+ .mr = 2,
+ };
+ xnn_params.x32.unpool = (xnn_unpool_ukernel_function) xnn_x32_unpool_ukernel__psimd;
+ xnn_params.x32.zip = (struct zip_parameters) {
+ .x2 = (xnn_zipc_ukernel_function) xnn_x32_zip_x2_ukernel__neon,
+ .x3 = (xnn_zipc_ukernel_function) xnn_x32_zip_x3_ukernel__neon,
+ .x4 = (xnn_zipc_ukernel_function) xnn_x32_zip_x4_ukernel__neon,
+ .xm = (xnn_zipv_ukernel_function) xnn_x32_zip_xm_ukernel__neon,
+ };
+
+#elif CPUINFO_ARCH_ARM64
+
+ /**************************** Q8 micro-kernels ****************************/
+ xnn_params.q8.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_8x8__neon,
+ .igemm = (xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_8x8__neon,
+ .mr = 8,
+ .nr = 8,
+ };
+ xnn_params.q8.dwconv[0] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_q8_dwconv_ukernel_up8x9__neon,
+ .cr = 8,
+ .mr = 9,
+ };
+ xnn_params.q8.avgpool = (struct avgpool_parameters) {
+ .up = (xnn_avgpool_up_ukernel_function) xnn_q8_avgpool_ukernel_up9__neon,
+ .mp = (xnn_avgpool_mp_ukernel_function) xnn_q8_avgpool_ukernel_mp9p8q__neon,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.q8.gavgpool = (struct gavgpool_parameters) {
+ .up = (xnn_gavgpool_up_ukernel_function) xnn_q8_gavgpool_ukernel_up7__neon,
+ .mp = (xnn_gavgpool_mp_ukernel_function) xnn_q8_gavgpool_ukernel_mp7p7q__neon,
+ .mr = 7,
+ };
+ xnn_params.q8.vadd = (xnn_vadd_ukernel_function) xnn_q8_vadd_ukernel__neon;
+
+ /**************************** U8 micro-kernels ****************************/
+ xnn_params.u8.maxpool = (struct maxpool_parameters) {
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8q__neon,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.u8.clamp = (xnn_univector_ukernel_function) xnn_u8_clamp_ukernel__neon;
+ xnn_params.u8.lut32norm = xnn_u8_lut32norm_ukernel__scalar;
+ xnn_params.u8.rmax = xnn_u8_rmax_ukernel__neon;
+
+ /**************************** X8 micro-kernels ****************************/
+ xnn_params.x8.lut = xnn_x8_lut_ukernel__scalar;
+ xnn_params.x8.zip = (struct zip_parameters) {
+ .x2 = (xnn_zipc_ukernel_function) xnn_x8_zip_x2_ukernel__neon,
+ .x3 = (xnn_zipc_ukernel_function) xnn_x8_zip_x3_ukernel__neon,
+ .x4 = (xnn_zipc_ukernel_function) xnn_x8_zip_x4_ukernel__neon,
+ .xm = (xnn_zipv_ukernel_function) xnn_x8_zip_xm_ukernel__neon,
+ };
+
+ /**************************** F32 micro-kernels ****************************/
+#if XNN_ENABLE_ASSEMBLY
+ switch (cpuinfo_get_core(0)->uarch) {
+ case cpuinfo_uarch_kryo:
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a57,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a75,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .mr = 4,
+ .nr = 8,
+ };
+ break;
+ case cpuinfo_uarch_cortex_a57:
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a57,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a57,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a57,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a57,
+ .mr = 6,
+ .nr = 8,
+ };
+ break;
+ case cpuinfo_uarch_cortex_a72:
+ case cpuinfo_uarch_cortex_a76:
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a75,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a75,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .mr = 4,
+ .nr = 8,
+ };
+ break;
+ case cpuinfo_uarch_cortex_a75:
+ case cpuinfo_uarch_exynos_m1:
+ case cpuinfo_uarch_exynos_m2:
+ case cpuinfo_uarch_exynos_m3:
+ case cpuinfo_uarch_exynos_m4:
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a75,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a75,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .mr = 6,
+ .nr = 8,
+ };
+ break;
+ case cpuinfo_uarch_cortex_a53:
+ case cpuinfo_uarch_cortex_a55:
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53,
+ .mr = 4,
+ .nr = 12,
+ };
+ break;
+ case cpuinfo_uarch_cortex_a73:
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a73,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a73,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .mr = 6,
+ .nr = 8,
+ };
+ break;
+ default:
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__neonfma_ld64,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__neonfma_ld64,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .mr = 4,
+ .nr = 8,
+ };
+ break;
+ }
+#else // XNN_ENABLE_ASSEMBLY
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__neonfma_ld64,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__neonfma_ld64,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neonfma_ld64,
+ // TODO(b/140592595): xnn_f32_igemm_ukernel_1x8__neonfma_ld64
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .mr = 4,
+ .nr = 8,
+ };
+#endif
+
+ xnn_params.f32.gemm2 = (struct gemm_parameters) {
+ .gemm = NULL,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2__neonfma_ld64,
+ .mr = 4,
+ .nr = 2,
+ };
+ xnn_params.f32.dwconv[0] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up4x4__psimd,
+ .cr = 4,
+ .mr = 4,
+ };
+ switch (cpuinfo_get_core(0)->uarch) {
+ case cpuinfo_uarch_kryo:
+ xnn_params.f32.dwconv[1] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up4x9__neonfma,
+ .cr = 4,
+ .mr = 9,
+ };
+ break;
+#if XNN_ENABLE_ASSEMBLY
+ case cpuinfo_uarch_cortex_a53:
+ case cpuinfo_uarch_cortex_a55:
+ xnn_params.f32.dwconv[1] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up4x9__aarch64_neonfma_cortex_a55,
+ .cr = 4,
+ .mr = 9,
+ };
+ break;
+#endif
+ default:
+ xnn_params.f32.dwconv[1] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up8x9__neonfma,
+ .cr = 8,
+ .mr = 9,
+ };
+ break;
+ }
+ xnn_params.f32.dwconv[2] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up4x25__psimd,
+ .cr = 4,
+ .mr = 25,
+ };
+ xnn_params.f32.avgpool = (struct avgpool_parameters) {
+ .up = (xnn_avgpool_up_ukernel_function) xnn_f32_avgpool_ukernel_up9__neon,
+ .mp = (xnn_avgpool_mp_ukernel_function) xnn_f32_avgpool_ukernel_mp9p8q__neon,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.pavgpool = (struct pavgpool_parameters) {
+ .up = (xnn_pavgpool_up_ukernel_function) xnn_f32_pavgpool_ukernel_up9__neon,
+ .mp = (xnn_pavgpool_mp_ukernel_function) xnn_f32_pavgpool_ukernel_mp9p8q__neon,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.gavgpool = (struct gavgpool_parameters) {
+ .up = (xnn_gavgpool_up_ukernel_function) xnn_f32_gavgpool_ukernel_up7__neon,
+ .mp = (xnn_gavgpool_mp_ukernel_function) xnn_f32_gavgpool_ukernel_mp7p7q__neon,
+ .mr = 7,
+ };
+ xnn_params.f32.maxpool = (struct maxpool_parameters) {
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8q__psimd,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.argmaxpool[0] = (struct argmaxpool_parameters) {
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up4__psimd,
+ .mr = 4,
+ };
+ xnn_params.f32.argmaxpool[1] = (struct argmaxpool_parameters) {
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up9__psimd,
+ .mr = 9,
+ };
+ xnn_params.f32.argmaxpool[2] = (struct argmaxpool_parameters) {
+ .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_mp9p8q__psimd,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__neon;
+ xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__neonfma;
+ xnn_params.f32.prelu = (struct prelu_parameters) {
+ .ukernel = (xnn_prelu_ukernel_function) xnn_f32_prelu_ukernel_x4__psimd,
+ .mr = 4,
+ };
+ xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__psimd;
+ xnn_params.f32.vmulcaddc = (struct vmulcaddc_parameters) {
+ .ukernel = (xnn_vmulcaddc_ukernel_function) xnn_f32_vmulcaddc_ukernel_c4__neonfma_x2,
+ .cr = 4,
+ .mr = 2,
+ };
+ xnn_params.f32.spmm = (struct spmm_parameters) {
+ .ukernel = (xnn_spmm_ukernel_function) xnn_f32_spmm_ukernel_16x1__neonfma,
+ .mr = 16,
+ .nr = 1,
+ };
+ xnn_params.f32.spmm2 = (struct spmm_parameters) {
+ .ukernel = (xnn_spmm_ukernel_function) xnn_f32_spmm_ukernel_16x2__neonfma,
+ .mr = 16,
+ .nr = 2,
+ };
+ xnn_params.f32.spmm4 = (struct spmm_parameters) {
+ .ukernel = (xnn_spmm_ukernel_function) xnn_f32_spmm_ukernel_16x4__neonfma,
+ .mr = 16,
+ .nr = 4,
+ };
+ xnn_params.f32.hwc2spchw_dconv3x3c3s2 = (struct hwc2spchw_dconv_parameters) {
+ .ukernel_with_symm_padding =
+ (xnn_conv_hwc2spchw_ukernel_function) xnn_f32_conv_hwc2spchw_ukernel_3x3s2p1c3x4__neonfma_2x2,
+ .output_channel_tile = 4,
+ .output_height_tile = 2,
+ .output_width_tile = 2,
+ };
+ xnn_params.f32.spchw_dwconv3x3 = (struct spchw_dwconv_parameters) {
+ .ukernel = (xnn_dwconv_spchw_ukernel_function) xnn_f32_dwconv_spchw_ukernel_3x3p1__neonfma,
+ .input_width_tile = 4,
+ .output_width_tile = 4,
+ .output_height_tile = 3,
+ };
+ xnn_params.f32.spchw_dwconv3x3s2 = (struct spchw_dwconv_parameters) {
+ .ukernel = (xnn_dwconv_spchw_ukernel_function) xnn_f32_dwconv_spchw_ukernel_3x3s2p1__neonfma,
+ .input_width_tile = 4,
+ .output_width_tile = 4,
+ .output_height_tile = 1,
+ };
+ xnn_params.f32.spchw_gavgpool = (struct spchw_gavgpool_parameters) {
+ .ukernel = (xnn_gavgpool_spchw_ukernel_function) xnn_f32_gavgpool_spchw_ukernel__neon_x4,
+ .channel_tile = 4,
+ };
+
+ /**************************** X32 micro-kernels ****************************/
+ xnn_params.x32.pad = (struct pad_parameters) {
+ .ukernel = xnn_x32_pad_x2__neon,
+ .mr = 2,
+ };
+ xnn_params.x32.unpool = (xnn_unpool_ukernel_function) xnn_x32_unpool_ukernel__psimd;
+ xnn_params.x32.zip = (struct zip_parameters) {
+ .x2 = (xnn_zipc_ukernel_function) xnn_x32_zip_x2_ukernel__neon,
+ .x3 = (xnn_zipc_ukernel_function) xnn_x32_zip_x3_ukernel__neon,
+ .x4 = (xnn_zipc_ukernel_function) xnn_x32_zip_x4_ukernel__neon,
+ .xm = (xnn_zipv_ukernel_function) xnn_x32_zip_xm_ukernel__neon,
+ };
+
+#elif CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ if (!cpuinfo_has_x86_sse2()) {
+ xnn_log_error("XNNPACK initialization failed: SSE2 is not supported");
+ return;
+ }
+
+ /**************************** Q8 micro-kernels ****************************/
+ xnn_params.q8.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_4x4c2__sse2,
+ .igemm = (xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_4x4c2__sse2,
+ .mr = 4,
+ .nr = 4,
+ .log2_kr = 1,
+ };
+ xnn_params.q8.dwconv[0] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_q8_dwconv_ukernel_up8x9__sse2,
+ .cr = 8,
+ .mr = 9,
+ };
+ xnn_params.q8.avgpool = (struct avgpool_parameters) {
+ .up = (xnn_avgpool_up_ukernel_function) xnn_q8_avgpool_ukernel_up9__sse2,
+ .mp = (xnn_avgpool_mp_ukernel_function) xnn_q8_avgpool_ukernel_mp9p8q__sse2,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.q8.gavgpool = (struct gavgpool_parameters) {
+ .up = (xnn_gavgpool_up_ukernel_function) xnn_q8_gavgpool_ukernel_up7__sse2,
+ .mp = (xnn_gavgpool_mp_ukernel_function) xnn_q8_gavgpool_ukernel_mp7p7q__sse2,
+ .mr = 7,
+ };
+ xnn_params.q8.vadd = (xnn_vadd_ukernel_function) xnn_q8_vadd_ukernel__sse2;
+
+ /**************************** U8 micro-kernels ****************************/
+ xnn_params.u8.maxpool = (struct maxpool_parameters) {
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8q__sse2,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.u8.clamp = (xnn_univector_ukernel_function) xnn_u8_clamp_ukernel__sse2;
+ xnn_params.u8.lut32norm = xnn_u8_lut32norm_ukernel__scalar;
+ xnn_params.u8.rmax = xnn_u8_rmax_ukernel__sse2;
+
+ /**************************** X8 micro-kernels ****************************/
+ xnn_params.x8.lut = xnn_x8_lut_ukernel__scalar;
+ xnn_params.x8.zip = (struct zip_parameters) {
+ .x2 = (xnn_zipc_ukernel_function) xnn_x8_zip_x2_ukernel__sse2,
+ .x3 = (xnn_zipc_ukernel_function) xnn_x8_zip_x3_ukernel__sse2,
+ .x4 = (xnn_zipc_ukernel_function) xnn_x8_zip_x4_ukernel__sse2,
+ .xm = (xnn_zipv_ukernel_function) xnn_x8_zip_xm_ukernel__sse2,
+ };
+
+ /**************************** F32 micro-kernels ****************************/
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__sse_load1,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__sse_load1,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__sse_load1,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__sse_load1,
+ .mr = 4,
+ .nr = 8,
+ };
+ xnn_params.f32.gemm2 = (struct gemm_parameters) {
+ .gemm = NULL,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2c4__sse,
+ .mr = 4,
+ .nr = 2,
+ .log2_kr = 2,
+ };
+ xnn_params.f32.dwconv[0] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up4x4__sse,
+ .cr = 4,
+ .mr = 4,
+ };
+ xnn_params.f32.dwconv[1] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up4x9__sse,
+ .cr = 4,
+ .mr = 9,
+ };
+ xnn_params.f32.dwconv[2] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up4x25__sse,
+ .cr = 4,
+ .mr = 25,
+ };
+ xnn_params.f32.avgpool = (struct avgpool_parameters) {
+ .up = (xnn_avgpool_up_ukernel_function) xnn_f32_avgpool_ukernel_up9__sse,
+ .mp = (xnn_avgpool_mp_ukernel_function) xnn_f32_avgpool_ukernel_mp9p8q__sse,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.pavgpool = (struct pavgpool_parameters) {
+ .up = (xnn_pavgpool_up_ukernel_function) xnn_f32_pavgpool_ukernel_up9__sse,
+ .mp = (xnn_pavgpool_mp_ukernel_function) xnn_f32_pavgpool_ukernel_mp9p8q__sse,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.gavgpool = (struct gavgpool_parameters) {
+ .up = (xnn_gavgpool_up_ukernel_function) xnn_f32_gavgpool_ukernel_up7__sse,
+ .mp = (xnn_gavgpool_mp_ukernel_function) xnn_f32_gavgpool_ukernel_mp7p7q__sse,
+ .mr = 7,
+ };
+ xnn_params.f32.maxpool = (struct maxpool_parameters) {
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8q__sse,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.argmaxpool[0] = (struct argmaxpool_parameters) {
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up4__sse2,
+ .mr = 4,
+ };
+ xnn_params.f32.argmaxpool[1] = (struct argmaxpool_parameters) {
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up9__sse2,
+ .mr = 9,
+ };
+ xnn_params.f32.argmaxpool[2] = (struct argmaxpool_parameters) {
+ .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_mp9p8q__sse2,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__sse;
+ xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__sse;
+ xnn_params.f32.prelu = (struct prelu_parameters) {
+ .ukernel = (xnn_prelu_ukernel_function) xnn_f32_prelu_ukernel_x4__sse,
+ .mr = 4,
+ };
+ xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__sse;
+ xnn_params.f32.vmulcaddc = (struct vmulcaddc_parameters) {
+ .ukernel = (xnn_vmulcaddc_ukernel_function) xnn_f32_vmulcaddc_ukernel_c4__sse_x2,
+ .cr = 4,
+ .mr = 2,
+ };
+ xnn_params.f32.spmm = (struct spmm_parameters) {
+ .ukernel = (xnn_spmm_ukernel_function) xnn_f32_spmm_ukernel_4x1__sse,
+ .mr = 4,
+ .nr = 1,
+ };
+ xnn_params.f32.spchw_dwconv3x3 = (struct spchw_dwconv_parameters) {
+ .ukernel = (xnn_dwconv_spchw_ukernel_function) xnn_f32_dwconv_spchw_ukernel_3x3p1__sse,
+ .input_width_tile = 4,
+ .output_width_tile = 4,
+ .output_height_tile = 1,
+ };
+ xnn_params.f32.spchw_dwconv3x3s2 = (struct spchw_dwconv_parameters) {
+ .ukernel = (xnn_dwconv_spchw_ukernel_function) xnn_f32_dwconv_spchw_ukernel_3x3s2p1__sse,
+ .input_width_tile = 4,
+ .output_width_tile = 4,
+ .output_height_tile = 1,
+ };
+ xnn_params.f32.spchw_gavgpool = (struct spchw_gavgpool_parameters) {
+ .ukernel = (xnn_gavgpool_spchw_ukernel_function) xnn_f32_gavgpool_spchw_ukernel__sse_x4,
+ .channel_tile = 4,
+ };
+
+ /**************************** X32 micro-kernels ****************************/
+ xnn_params.x32.pad = (struct pad_parameters) {
+ .ukernel = xnn_x32_pad_x2__sse2,
+ .mr = 2,
+ };
+ xnn_params.x32.unpool = (xnn_unpool_ukernel_function) xnn_x32_unpool_ukernel__psimd;
+ xnn_params.x32.zip = (struct zip_parameters) {
+ .x2 = (xnn_zipc_ukernel_function) xnn_x32_zip_x2_ukernel__sse2,
+ .x3 = (xnn_zipc_ukernel_function) xnn_x32_zip_x3_ukernel__sse2,
+ .x4 = (xnn_zipc_ukernel_function) xnn_x32_zip_x4_ukernel__sse2,
+ .xm = (xnn_zipv_ukernel_function) xnn_x32_zip_xm_ukernel__sse2,
+ };
+
+#elif CPUINFO_ARCH_PNACL || CPUINFO_ARCH_WASMSIMD
+ /**************************** Q8 micro-kernels ****************************/
+ xnn_params.q8.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_2x2__scalar,
+ .igemm = (xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_2x2__scalar,
+ .mr = 2,
+ .nr = 2,
+ };
+ xnn_params.q8.dwconv[0] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_q8_dwconv_ukernel_up1x9__scalar,
+ .cr = 1,
+ .mr = 9,
+ };
+ xnn_params.q8.avgpool = (struct avgpool_parameters) {
+ .up = (xnn_avgpool_up_ukernel_function) xnn_q8_avgpool_ukernel_up9__scalar,
+ .mp = (xnn_avgpool_mp_ukernel_function) xnn_q8_avgpool_ukernel_mp9p8q__scalar,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.q8.gavgpool = (struct gavgpool_parameters) {
+ .up = (xnn_gavgpool_up_ukernel_function) xnn_q8_gavgpool_ukernel_up7__scalar,
+ .mp = (xnn_gavgpool_mp_ukernel_function) xnn_q8_gavgpool_ukernel_mp7p7q__scalar,
+ .mr = 7,
+ };
+ xnn_params.q8.vadd = (xnn_vadd_ukernel_function) xnn_q8_vadd_ukernel__scalar;
+
+ /**************************** U8 micro-kernels ****************************/
+ xnn_params.u8.maxpool = (struct maxpool_parameters) {
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8q__scalar,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.u8.clamp = (xnn_univector_ukernel_function) xnn_u8_clamp_ukernel__scalar;
+ xnn_params.u8.lut32norm = xnn_u8_lut32norm_ukernel__scalar;
+ xnn_params.u8.rmax = xnn_u8_rmax_ukernel__scalar;
+
+ /**************************** X8 micro-kernels ****************************/
+ xnn_params.x8.lut = xnn_x8_lut_ukernel__scalar;
+ xnn_params.x8.zip = (struct zip_parameters) {
+ .x2 = (xnn_zipc_ukernel_function) xnn_x8_zip_x2_ukernel__scalar,
+ .x3 = (xnn_zipc_ukernel_function) xnn_x8_zip_x3_ukernel__scalar,
+ .x4 = (xnn_zipc_ukernel_function) xnn_x8_zip_x4_ukernel__scalar,
+ .xm = (xnn_zipv_ukernel_function) xnn_x8_zip_xm_ukernel__scalar,
+ };
+
+ /**************************** F32 micro-kernels ****************************/
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__psimd_splat,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__psimd_splat,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__psimd_loadsplat,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__psimd_loadsplat,
+ .mr = 4,
+ .nr = 8,
+ };
+ xnn_params.f32.gemm2 = (struct gemm_parameters) {
+ .gemm = NULL,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2c4__psimd,
+ .mr = 4,
+ .nr = 2,
+ .log2_kr = 2,
+ };
+ xnn_params.f32.dwconv[0] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up4x4__psimd,
+ .cr = 4,
+ .mr = 4,
+ };
+ xnn_params.f32.dwconv[1] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up4x9__psimd,
+ .cr = 4,
+ .mr = 9,
+ };
+ xnn_params.f32.dwconv[2] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up4x25__psimd,
+ .cr = 4,
+ .mr = 25,
+ };
+ xnn_params.f32.avgpool = (struct avgpool_parameters) {
+ .up = (xnn_avgpool_up_ukernel_function) xnn_f32_avgpool_ukernel_up9__psimd,
+ .mp = (xnn_avgpool_mp_ukernel_function) xnn_f32_avgpool_ukernel_mp9p8q__psimd,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.pavgpool = (struct pavgpool_parameters) {
+ .up = (xnn_pavgpool_up_ukernel_function) xnn_f32_pavgpool_ukernel_up9__psimd,
+ .mp = (xnn_pavgpool_mp_ukernel_function) xnn_f32_pavgpool_ukernel_mp9p8q__psimd,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.gavgpool = (struct gavgpool_parameters) {
+ .up = (xnn_gavgpool_up_ukernel_function) xnn_f32_gavgpool_ukernel_up7__psimd,
+ .mp = (xnn_gavgpool_mp_ukernel_function) xnn_f32_gavgpool_ukernel_mp7p7q__psimd,
+ .mr = 7,
+ };
+ xnn_params.f32.maxpool = (struct maxpool_parameters) {
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8q__psimd,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.argmaxpool[0] = (struct argmaxpool_parameters) {
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up4__psimd,
+ .mr = 4,
+ };
+ xnn_params.f32.argmaxpool[1] = (struct argmaxpool_parameters) {
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up9__psimd,
+ .mr = 9,
+ };
+ xnn_params.f32.argmaxpool[2] = (struct argmaxpool_parameters) {
+ .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_mp9p8q__psimd,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__psimd;
+ xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__psimd;
+ xnn_params.f32.prelu = (struct prelu_parameters) {
+ .ukernel = (xnn_prelu_ukernel_function) xnn_f32_prelu_ukernel_x4__psimd,
+ .mr = 4,
+ };
+ xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__psimd;
+ xnn_params.f32.vmulcaddc = (struct vmulcaddc_parameters) {
+ .ukernel = (xnn_vmulcaddc_ukernel_function) xnn_f32_vmulcaddc_ukernel_c4__psimd_x2,
+ .cr = 4,
+ .mr = 2,
+ };
+
+ /**************************** X32 micro-kernels ****************************/
+ xnn_params.x32.pad = (struct pad_parameters) {
+ .ukernel = xnn_x32_pad_x2__psimd,
+ .mr = 2,
+ };
+ xnn_params.x32.unpool = (xnn_unpool_ukernel_function) xnn_x32_unpool_ukernel__psimd;
+ xnn_params.x32.zip = (struct zip_parameters) {
+ .x2 = (xnn_zipc_ukernel_function) xnn_x32_zip_x2_ukernel__psimd,
+ .x3 = (xnn_zipc_ukernel_function) xnn_x32_zip_x3_ukernel__psimd,
+ .x4 = (xnn_zipc_ukernel_function) xnn_x32_zip_x4_ukernel__psimd,
+ .xm = (xnn_zipv_ukernel_function) xnn_x32_zip_xm_ukernel__psimd,
+ };
+
+#elif CPUINFO_ARCH_WASM || CPUINFO_ARCH_ASMJS
+ // Unlike most other architectures, on x86/x86-64 when floating-point instructions
+ // have no NaN arguments, but produce NaN output, the output NaN has sign bit set.
+ // We use it to distinguish x86/x86-64 from other architectures, by doing subtraction
+ // of two infinities (must produce NaN per IEEE 754 standard).
+ static volatile uint32_t minus_inf = UINT32_C(0xFF800000);
+ const bool is_wasm_x86 = (int32_t) xnn_stub_wasm_f32_sub(minus_inf, minus_inf) < 0;
+
+ /**************************** Q8 micro-kernels ****************************/
+ xnn_params.q8.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_2x2__scalar,
+ .igemm = (xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_2x2__scalar,
+ .mr = 2,
+ .nr = 2,
+ };
+ xnn_params.q8.dwconv[0] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_q8_dwconv_ukernel_up1x9__scalar,
+ .cr = 1,
+ .mr = 9,
+ };
+ xnn_params.q8.avgpool = (struct avgpool_parameters) {
+ .up = (xnn_avgpool_up_ukernel_function) xnn_q8_avgpool_ukernel_up9__scalar,
+ .mp = (xnn_avgpool_mp_ukernel_function) xnn_q8_avgpool_ukernel_mp9p8q__scalar,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.q8.gavgpool = (struct gavgpool_parameters) {
+ .up = (xnn_gavgpool_up_ukernel_function) xnn_q8_gavgpool_ukernel_up7__scalar,
+ .mp = (xnn_gavgpool_mp_ukernel_function) xnn_q8_gavgpool_ukernel_mp7p7q__scalar,
+ .mr = 7,
+ };
+ xnn_params.q8.vadd = (xnn_vadd_ukernel_function) xnn_q8_vadd_ukernel__scalar;
+
+ /**************************** U8 micro-kernels ****************************/
+ xnn_params.u8.maxpool = (struct maxpool_parameters) {
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_ukernel_9p8q__scalar,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.u8.clamp = (xnn_univector_ukernel_function) xnn_u8_clamp_ukernel__scalar;
+ xnn_params.u8.lut32norm = xnn_u8_lut32norm_ukernel__scalar;
+ xnn_params.u8.rmax = xnn_u8_rmax_ukernel__scalar;
+
+ /**************************** X8 micro-kernels ****************************/
+ xnn_params.x8.lut = xnn_x8_lut_ukernel__scalar;
+ xnn_params.x8.zip = (struct zip_parameters) {
+ .x2 = (xnn_zipc_ukernel_function) xnn_x8_zip_x2_ukernel__scalar,
+ .x3 = (xnn_zipc_ukernel_function) xnn_x8_zip_x3_ukernel__scalar,
+ .x4 = (xnn_zipc_ukernel_function) xnn_x8_zip_x4_ukernel__scalar,
+ .xm = (xnn_zipv_ukernel_function) xnn_x8_zip_xm_ukernel__scalar,
+ };
+
+ /**************************** F32 micro-kernels ****************************/
+ if (is_wasm_x86) {
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_2x4__scalar,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_2x4__scalar,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x4__scalar,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x4__scalar,
+ .mr = 2,
+ .nr = 4,
+ };
+ } else {
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x4__scalar,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x4__scalar,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x4__scalar,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x4__scalar,
+ .mr = 4,
+ .nr = 4,
+ };
+ }
+ xnn_params.f32.gemm2 = (struct gemm_parameters) {
+ .gemm = NULL,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2__scalar,
+ .mr = 4,
+ .nr = 2,
+ };
+ xnn_params.f32.dwconv[0] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up1x4__scalar,
+ .cr = 1,
+ .mr = 4,
+ };
+ xnn_params.f32.dwconv[1] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up1x9__scalar,
+ .cr = 1,
+ .mr = 9,
+ };
+ xnn_params.f32.dwconv[2] = (struct dwconv_parameters) {
+ .up = (xnn_dwconv_up_ukernel_function) xnn_f32_dwconv_ukernel_up1x25__scalar,
+ .cr = 1,
+ .mr = 25,
+ };
+ xnn_params.f32.avgpool = (struct avgpool_parameters) {
+ .up = (xnn_avgpool_up_ukernel_function) xnn_f32_avgpool_ukernel_up9__scalar,
+ .mp = (xnn_avgpool_mp_ukernel_function) xnn_f32_avgpool_ukernel_mp9p8q__scalar,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.pavgpool = (struct pavgpool_parameters) {
+ .up = (xnn_pavgpool_up_ukernel_function) xnn_f32_pavgpool_ukernel_up9__scalar,
+ .mp = (xnn_pavgpool_mp_ukernel_function) xnn_f32_pavgpool_ukernel_mp9p8q__scalar,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.gavgpool = (struct gavgpool_parameters) {
+ .up = (xnn_gavgpool_up_ukernel_function) xnn_f32_gavgpool_ukernel_up7__scalar,
+ .mp = (xnn_gavgpool_mp_ukernel_function) xnn_f32_gavgpool_ukernel_mp7p7q__scalar,
+ .mr = 7,
+ };
+ xnn_params.f32.maxpool = (struct maxpool_parameters) {
+ .ukernel = (xnn_maxpool_ukernel_function) xnn_f32_maxpool_ukernel_9p8q__scalar,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.argmaxpool[0] = (struct argmaxpool_parameters) {
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up4__scalar,
+ .mr = 4,
+ };
+ xnn_params.f32.argmaxpool[1] = (struct argmaxpool_parameters) {
+ .up = (xnn_argmaxpool_up_ukernel_function) xnn_f32_argmaxpool_ukernel_up9__scalar,
+ .mr = 9,
+ };
+ xnn_params.f32.argmaxpool[2] = (struct argmaxpool_parameters) {
+ .mp = (xnn_argmaxpool_mp_ukernel_function) xnn_f32_argmaxpool_ukernel_mp9p8q__scalar,
+ .mr = 9,
+ .qr = 8,
+ };
+ xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__scalar;
+ xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__scalar;
+ xnn_params.f32.prelu = (struct prelu_parameters) {
+ .ukernel = (xnn_prelu_ukernel_function) xnn_f32_prelu_ukernel_x4__scalar,
+ .mr = 4,
+ };
+ xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__scalar;
+ xnn_params.f32.vmulcaddc = (struct vmulcaddc_parameters) {
+ .ukernel = (xnn_vmulcaddc_ukernel_function) xnn_f32_vmulcaddc_ukernel_c1__scalar_x2,
+ .cr = 1,
+ .mr = 2,
+ };
+ xnn_params.f32.spmm = (struct spmm_parameters) {
+ .ukernel = (xnn_spmm_ukernel_function) xnn_f32_spmm_ukernel_4x1__scalar,
+ .mr = 4,
+ .nr = 1,
+ };
+
+ /**************************** X32 micro-kernels ****************************/
+ xnn_params.x32.pad = (struct pad_parameters) {
+ .ukernel = xnn_x32_pad_x2__scalar,
+ .mr = 2,
+ };
+ xnn_params.x32.unpool = (xnn_unpool_ukernel_function) xnn_x32_unpool_ukernel__scalar;
+ xnn_params.x32.zip = (struct zip_parameters) {
+ .x2 = (xnn_zipc_ukernel_function) xnn_x32_zip_x2_ukernel__scalar,
+ .x3 = (xnn_zipc_ukernel_function) xnn_x32_zip_x3_ukernel__scalar,
+ .x4 = (xnn_zipc_ukernel_function) xnn_x32_zip_x4_ukernel__scalar,
+ .xm = (xnn_zipv_ukernel_function) xnn_x32_zip_xm_ukernel__scalar,
+ };
+
+#else
+ #error "Unsupported architecture"
+#endif
+ xnn_params.initialized = true;
+}
+
+enum xnn_status xnn_initialize(void) {
+ if (!cpuinfo_initialize()) {
+ return xnn_status_out_of_memory;
+ }
+ pthread_once(&init_guard, &init);
+ if (xnn_params.initialized) {
+ return xnn_status_success;
+ } else {
+ return xnn_status_unsupported_hardware;
+ }
+}
+
// Releases the cpuinfo resources acquired by xnn_initialize().
// Note: does not reset xnn_params; the selected kernel tables remain in place.
// Always succeeds.
enum xnn_status xnn_deinitialize(void) {
  cpuinfo_deinitialize();
  return xnn_status_success;
}
diff --git a/src/leaky-relu.c b/src/leaky-relu.c
new file mode 100644
index 0000000..74ca098
--- /dev/null
+++ b/src/leaky-relu.c
@@ -0,0 +1,217 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/operator.h>
#include <xnnpack/log.h>
+
+
// Creates a quantized (Q8) Leaky ReLU operator.
//
// The operator is implemented as a 256-entry lookup table mapping every
// possible input byte to an output byte, built once here from the
// quantization parameters; at setup/run time only table lookups happen.
//
// Arguments:
//   channels          - number of channels per element; must be non-zero.
//   input_stride      - stride (in elements) between consecutive input rows;
//                       must be >= channels.
//   output_stride     - stride (in elements) between consecutive output rows;
//                       must be >= channels.
//   negative_slope    - multiplier for negative inputs; must be a positive
//                       normal float and must not exceed 1.0.
//   input_zero_point / input_scale   - quantization parameters of the input.
//   output_zero_point / output_scale - quantization parameters of the output.
//   output_min / output_max          - saturation bounds (quantized); min < max.
//   flags             - currently unused.
//   leaky_relu_op_out - receives the created operator on success.
//
// Returns xnn_status_success, or an error status; on error nothing is
// allocated (partially-built state is freed via xnn_delete_operator).
enum xnn_status xnn_create_leaky_relu_nc_q8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float negative_slope,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* leaky_relu_op_out)
{
  xnn_operator_t leaky_relu_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to create Leaky ReLU operator: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (channels == 0) {
    xnn_log_error(
      "failed to create Leaky ReLU operator with %zu channels: number of channels must be non-zero", channels);
    goto error;
  }

  if (input_stride < channels) {
    xnn_log_error(
      "failed to create Leaky ReLU operator with input element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      input_stride, channels);
    goto error;
  }

  if (output_stride < channels) {
    xnn_log_error(
      "failed to create Leaky ReLU operator with output element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      output_stride, channels);
    goto error;
  }

  // isnormal() also rejects NaN, infinities, zero, and denormals.
  if (negative_slope <= 0.0f || !isnormal(negative_slope)) {
    xnn_log_error(
      "failed to create Leaky ReLU operator with %.7g negative slope: slope must be finite, normalized, and positive",
      negative_slope);
    goto error;
  }

  if (negative_slope > 1.0f) {
    xnn_log_error(
      "failed to create Leaky ReLU operator with %.7g negative slope: slope must not exceed 1.0", negative_slope);
    goto error;
  }

  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create Leaky ReLU operator with %.7g input scale: scale must be finite, normalized, and positive",
      input_scale);
    goto error;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create Leaky ReLU operator with %.7g output scale: scale must be finite, normalized, and positive",
      output_scale);
    goto error;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create Leaky ReLU operator with [%" PRIu8 ", %" PRIu8 "] output range: "
      "range min must be below range max",
      output_min, output_max);
    goto error;
  }

  status = xnn_status_unsupported_parameter;

  // The table is computed in units of the output scale; the ratio of scales
  // must stay in a range where the arithmetic below remains accurate.
  const float input_output_scale = input_scale / output_scale;
  if (input_output_scale < 0x1.0p-8f || input_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create Leaky ReLU operator with %.7g input-to-output scale ratio: "
      "scale ratio must be in [2**-8, 2**8) range",
      input_output_scale);
    goto error;
  }

  status = xnn_status_out_of_memory;

  leaky_relu_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
  if (leaky_relu_op == NULL) {
    xnn_log_error("failed to allocate %zu bytes for Leaky ReLU operator descriptor", sizeof(struct xnn_operator));
    goto error;
  }

  leaky_relu_op->lookup_table = xnn_allocate_memory(256 * sizeof(uint8_t));
  if (leaky_relu_op->lookup_table == NULL) {
    xnn_log_error("failed to allocate 256 bytes for Leaky ReLU lookup table");
    goto error;
  }

  // Build the lookup table: for each possible quantized input byte i,
  // dequantize (relative to the output scale), apply leaky ReLU, clamp to the
  // [output_min, output_max] range expressed relative to the output zero
  // point, then round and re-quantize.
  uint8_t* lookup_table = leaky_relu_op->lookup_table;
  const float scaled_min_less_zero_point = (float) ((int32_t) output_min - (int32_t) output_zero_point);
  const float scaled_max_less_zero_point = (float) ((int32_t) output_max - (int32_t) output_zero_point);
  for (int32_t i = 0; i < 256; i++) {
    const float x = input_output_scale * (float) (i - (int32_t) (uint32_t) input_zero_point);
    float y = x < 0.0f ? x * negative_slope : x;
    if (y < scaled_min_less_zero_point) {
      y = scaled_min_less_zero_point;
    }
    if (y > scaled_max_less_zero_point) {
      y = scaled_max_less_zero_point;
    }
    // lrintf rounds to nearest (ties per current rounding mode); the clamp
    // above guarantees the final value fits in uint8_t.
    lookup_table[(uint32_t) i] = (uint8_t) (lrintf(y) + (long) output_zero_point);
  }

  leaky_relu_op->channels = channels;
  leaky_relu_op->input_pixel_stride = input_stride;
  leaky_relu_op->output_pixel_stride = output_stride;

  leaky_relu_op->type = xnn_operator_type_leaky_relu_q8;
  leaky_relu_op->ukernel.type = xnn_ukernel_type_lut;

  leaky_relu_op->state = xnn_run_state_invalid;

  *leaky_relu_op_out = leaky_relu_op;
  return xnn_status_success;

error:
  // Safe when leaky_relu_op is still NULL; frees the table if it was allocated.
  xnn_delete_operator(leaky_relu_op);
  return status;
}
+
+enum xnn_status xnn_setup_leaky_relu_nc_q8(
+ xnn_operator_t leaky_relu_op,
+ size_t batch_size,
+ const uint8_t* input,
+ uint8_t* output,
+ pthreadpool_t threadpool)
+{
+ if (leaky_relu_op->type != xnn_operator_type_leaky_relu_q8) {
+ xnn_log_error("failed to setup Leaky ReLU (Q8) operator: operator type mismatch");
+ return xnn_status_invalid_parameter;
+ }
+ leaky_relu_op->state = xnn_run_state_invalid;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to setup Leaky ReLU operator: XNNPACK is not initialized");
+ return xnn_status_uninitialized;
+ }
+
+ if (batch_size == 0) {
+ leaky_relu_op->state = xnn_run_state_skip;
+ return xnn_status_success;
+ }
+
+ const size_t channels = leaky_relu_op->channels;
+ const size_t input_stride = leaky_relu_op->input_pixel_stride;
+ const size_t output_stride = leaky_relu_op->output_pixel_stride;
+ if ((((input_stride ^ channels) | (output_stride ^ channels)) == 0) || batch_size == 1) {
+ const size_t block_size = 1024;
+ leaky_relu_op->context.lut_contiguous = (struct lut_contiguous_context) {
+ .x = input,
+ .x_stride = input_stride * sizeof(uint8_t),
+ .t = leaky_relu_op->lookup_table,
+ .y = output,
+ .y_stride = output_stride * sizeof(uint8_t),
+ .ukernel = xnn_params.x8.lut,
+ };
+ leaky_relu_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+ leaky_relu_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_lut_contiguous;
+ leaky_relu_op->compute.range[0] = batch_size * channels * sizeof(uint8_t);
+ leaky_relu_op->compute.tile[0] = block_size;
+ } else {
+ leaky_relu_op->context.lut_strided = (struct lut_strided_context) {
+ .n = channels,
+ .x = input,
+ .x_stride = input_stride * sizeof(uint8_t),
+ .t = leaky_relu_op->lookup_table,
+ .y = output,
+ .y_stride = output_stride * sizeof(uint8_t),
+ .ukernel = xnn_params.x8.lut,
+ };
+ leaky_relu_op->compute.type = xnn_parallelization_type_1d;
+ leaky_relu_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_lut_strided;
+ leaky_relu_op->compute.range[0] = batch_size;
+ leaky_relu_op->compute.tile[0] = 0;
+ }
+ leaky_relu_op->state = xnn_run_state_ready;
+
+ return xnn_status_success;
+}
diff --git a/src/max-pooling.c b/src/max-pooling.c
new file mode 100644
index 0000000..c8a4d67
--- /dev/null
+++ b/src/max-pooling.c
@@ -0,0 +1,548 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/operator.h>
#include <xnnpack/log.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/indirection.h>
+
+
// Computes one spatial output dimension of a pooling operation.
//
// A kernel of size `kernel_dimension` dilated by `dilation_dimension` spans
// (kernel - 1) * dilation + 1 input elements; the window is slid across the
// padded input in steps of `stride_dimension`.  The caller is responsible
// for ensuring the dilated kernel fits inside the padded input.
static inline size_t compute_output_dimension(
  size_t padded_input_dimension,
  size_t kernel_dimension,
  size_t dilation_dimension,
  size_t stride_dimension)
{
  const size_t dilated_kernel_dimension = dilation_dimension * (kernel_dimension - 1) + 1;
  const size_t last_window_offset = padded_input_dimension - dilated_kernel_dimension;
  return last_window_offset / stride_dimension + 1;
}
+
// Creates a Max Pooling operator for NHWC-layout uint8 tensors.
//
// Validates the pooling geometry (kernel, stride, dilation), the channel
// count and pixel strides, and the [output_min, output_max] clamping range,
// then allocates a zero-initialized operator descriptor and records the
// parameters on it.  The operator starts in the invalid run state and must
// be configured with xnn_setup_max_pooling2d_nhwc_u8 before it can run.
//
// On success writes the new operator to *max_pooling_op_out and returns
// xnn_status_success; on failure returns an error status and leaves
// *max_pooling_op_out unwritten.
enum xnn_status xnn_create_max_pooling2d_nhwc_u8(
  uint32_t input_padding_top,
  uint32_t input_padding_right,
  uint32_t input_padding_bottom,
  uint32_t input_padding_left,
  uint32_t pooling_height,
  uint32_t pooling_width,
  uint32_t stride_height,
  uint32_t stride_width,
  uint32_t dilation_height,
  uint32_t dilation_width,
  size_t channels,
  size_t input_pixel_stride,
  size_t output_pixel_stride,
  uint8_t output_min,
  uint8_t output_max,
  uint32_t flags,
  xnn_operator_t* max_pooling_op_out)
{
  xnn_operator_t max_pooling_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to create Max Pooling operator: XNNPACK is not initialized");
    goto error;
  }

  // Parameter validation: every check below logs a specific message and
  // bails out through the shared error path.
  status = xnn_status_invalid_parameter;

  const uint32_t pooling_size = pooling_height * pooling_width;
  if (pooling_size == 0) {
    xnn_log_error(
      "failed to create Max Pooling operator with %" PRIu32 "x%" PRIu32 " pooling size: "
      "pooling size dimensions must be non-zero",
      pooling_width, pooling_height);
    goto error;
  }

  if (pooling_size == 1) {
    xnn_log_error(
      "failed to create Max Pooling operator with 1 pooling element: 1x1 pooling is meaningless");
    goto error;
  }

  if (stride_height == 0 || stride_width == 0) {
    xnn_log_error(
      "failed to create Max Pooling operator with %" PRIu32 "x%" PRIu32 " stride: "
      "stride dimensions must be non-zero",
      stride_width, stride_height);
    goto error;
  }

  if (dilation_height == 0 || dilation_width == 0) {
    xnn_log_error(
      "failed to create Max Pooling operator with %" PRIu32 "x%" PRIu32 " dilation: "
      "dilation dimensions must be non-zero",
      dilation_width, dilation_height);
    goto error;
  }

  if (channels == 0) {
    xnn_log_error(
      "failed to create Max Pooling operator with %zu channels: number of channels must be non-zero",
      channels);
    goto error;
  }

  if (input_pixel_stride < channels) {
    xnn_log_error(
      "failed to create Max Pooling operator with input pixel stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      input_pixel_stride, channels);
    goto error;
  }

  if (output_pixel_stride < channels) {
    xnn_log_error(
      "failed to create Max Pooling operator with output pixel stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      output_pixel_stride, channels);
    goto error;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create Max Pooling operator with [%" PRIu8 ", %" PRIu8 "] output range: "
      "range min must be below range max",
      output_min, output_max);
    goto error;
  }

  status = xnn_status_out_of_memory;

  // Zero-initialized so unused descriptor fields (buffers, caches) start NULL.
  max_pooling_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
  if (max_pooling_op == NULL) {
    xnn_log_error("failed to allocate %zu bytes for Max Pooling operator descriptor", sizeof(struct xnn_operator));
    goto error;
  }

  max_pooling_op->padding_top = input_padding_top;
  max_pooling_op->padding_right = input_padding_right;
  max_pooling_op->padding_bottom = input_padding_bottom;
  max_pooling_op->padding_left = input_padding_left;

  max_pooling_op->kernel_height = pooling_height;
  max_pooling_op->kernel_width = pooling_width;
  max_pooling_op->stride_height = stride_height;
  max_pooling_op->stride_width = stride_width;
  max_pooling_op->dilation_height = dilation_height;
  max_pooling_op->dilation_width = dilation_width;
  max_pooling_op->channels = channels;
  max_pooling_op->input_pixel_stride = input_pixel_stride;
  max_pooling_op->output_pixel_stride = output_pixel_stride;

  // Precompute the clamping parameters consumed by the u8 maxpool micro-kernel.
  max_pooling_op->u8_output_params = xnn_compute_u8_output_params(output_min, output_max);

  max_pooling_op->type = xnn_operator_type_max_pooling_u8;
  max_pooling_op->ukernel.type = xnn_ukernel_type_max_pooling;

  // Not runnable until a setup call provides batch/input dimensions.
  max_pooling_op->state = xnn_run_state_invalid;

  *max_pooling_op_out = max_pooling_op;
  return xnn_status_success;

error:
  xnn_delete_operator(max_pooling_op);
  return status;
}
+
+enum xnn_status xnn_create_max_pooling2d_nhwc_f32(
+ uint32_t input_padding_top,
+ uint32_t input_padding_right,
+ uint32_t input_padding_bottom,
+ uint32_t input_padding_left,
+ uint32_t pooling_height,
+ uint32_t pooling_width,
+ uint32_t stride_height,
+ uint32_t stride_width,
+ uint32_t dilation_height,
+ uint32_t dilation_width,
+ size_t channels,
+ size_t input_pixel_stride,
+ size_t output_pixel_stride,
+ float output_min,
+ float output_max,
+ uint32_t flags,
+ xnn_operator_t* max_pooling_op_out)
+{
+ xnn_operator_t max_pooling_op = NULL;
+ enum xnn_status status = xnn_status_uninitialized;
+
+ if (!xnn_params.initialized) {
+ xnn_log_error("failed to setup Max Pooling operator: XNNPACK is not initialized");
+ return xnn_status_uninitialized;
+ }
+
+ status = xnn_status_invalid_parameter;
+
+ const uint32_t pooling_size = pooling_height * pooling_width;
+ if (pooling_size == 0) {
+ xnn_log_error(
+ "failed to create Max Pooling operator with %" PRIu32 "x%" PRIu32 " pooling size: "
+ "pooling size dimensions must be non-zero",
+ pooling_width, pooling_height);
+ goto error;
+ }
+
+ if (pooling_size == 1) {
+ xnn_log_error(
+ "failed to create Max Pooling operator with 1 pooling element: "
+ "1x1 pooling is meaningless");
+ goto error;
+ }
+
+ if (stride_height == 0 || stride_width == 0) {
+ xnn_log_error(
+ "failed to create Max Pooling operator with %" PRIu32 "x%" PRIu32 " stride: "
+ "stride dimensions must be non-zero",
+ stride_width, stride_height);
+ goto error;
+ }
+
+ if (dilation_height == 0 || dilation_width == 0) {
+ xnn_log_error(
+ "failed to create Max Pooling operator with %" PRIu32 "x%" PRIu32 " dilation: "
+ "dilation dimensions must be non-zero",
+ dilation_width, dilation_height);
+ goto error;
+ }
+
+ if (channels == 0) {
+ xnn_log_error(
+ "failed to create Max Pooling operator with %zu channels: number of channels must be non-zero",
+ channels);
+ goto error;
+ }
+
+ if (input_pixel_stride < channels) {
+ xnn_log_error(
+ "failed to create Max Pooling operator with input pixel stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ input_pixel_stride, channels);
+ goto error;
+ }
+
+ if (output_pixel_stride < channels) {
+ xnn_log_error(
+ "failed to create Max Pooling operator with output pixel stride of %zu: "
+ "stride must be at least as large as the number of channels (%zu)",
+ output_pixel_stride, channels);
+ goto error;
+ }
+
+ if (isnan(output_min)) {
+ xnn_log_error(
+ "failed to create Max Pooling with NaN output lower bound: lower bound must be non-NaN");
+ goto error;
+ }
+
+ if (isnan(output_max)) {
+ xnn_log_error(
+ "failed to create Max Pooling with NaN output upper bound: upper bound must be non-NaN");
+ goto error;
+ }
+
+ if (output_min >= output_max) {
+ xnn_log_error(
+ "failed to create Max Pooling with [%.7g, %.7g] output range: lower bound must be below upper bound",
+ output_min, output_max);
+ goto error;
+ }
+
+ status = xnn_status_out_of_memory;
+
+ max_pooling_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+ if (max_pooling_op == NULL) {
+ xnn_log_error("failed to allocate %zu bytes for Max Pooling operator descriptor", sizeof(struct xnn_operator));
+ goto error;
+ }
+
+ max_pooling_op->padding_top = input_padding_top;
+ max_pooling_op->padding_right = input_padding_right;
+ max_pooling_op->padding_bottom = input_padding_bottom;
+ max_pooling_op->padding_left = input_padding_left;
+
+ max_pooling_op->kernel_height = pooling_height;
+ max_pooling_op->kernel_width = pooling_width;
+ max_pooling_op->stride_height = stride_height;
+ max_pooling_op->stride_width = stride_width;
+ max_pooling_op->dilation_height = dilation_height;
+ max_pooling_op->dilation_width = dilation_width;
+ max_pooling_op->channels = channels;
+ max_pooling_op->input_pixel_stride = input_pixel_stride;
+ max_pooling_op->output_pixel_stride = output_pixel_stride;
+
+ max_pooling_op->f32_output_params = xnn_compute_f32_output_params(output_min, output_max);
+
+ max_pooling_op->type = xnn_operator_type_max_pooling_f32;
+ max_pooling_op->ukernel.type = xnn_ukernel_type_max_pooling;
+
+ max_pooling_op->state = xnn_run_state_invalid;
+
+ *max_pooling_op_out = max_pooling_op;
+ return xnn_status_success;
+
+error:
+ xnn_delete_operator(max_pooling_op);
+ return status;
+}
+
// Configures a uint8 Max Pooling operator for a concrete batch size and
// input spatial size, (re)building the indirection buffer when needed.
//
// Must be called after xnn_create_max_pooling2d_nhwc_u8 and before running
// the operator.  batch_size == 0 puts the operator into the skip state.
// The threadpool argument is unused here; parallelization is described in
// the compute descriptor and applied at run time.
enum xnn_status xnn_setup_max_pooling2d_nhwc_u8(
  xnn_operator_t max_pooling_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  const uint8_t* input,
  uint8_t* output,
  pthreadpool_t threadpool)
{
  if (max_pooling_op->type != xnn_operator_type_max_pooling_u8) {
    xnn_log_error("failed to setup Max Pooling (U8) operator: operator type mismatch");
    return xnn_status_invalid_parameter;
  }
  // Invalidate up front so any failure below leaves the operator non-runnable.
  max_pooling_op->state = xnn_run_state_invalid;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to setup Max Pooling operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (input_width == 0 || input_height == 0) {
    xnn_log_error(
      "failed to setup Max Pooling operator with %zux%zu input: input dimensions must be non-zero",
      input_width, input_height);
    return xnn_status_invalid_parameter;
  }

  if (batch_size == 0) {
    max_pooling_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  max_pooling_op->batch_size = batch_size;
  max_pooling_op->input_height = input_height;
  max_pooling_op->input_width = input_width;
  max_pooling_op->input = input;

  // Output spatial size from padded input, kernel, dilation and stride.
  max_pooling_op->output_height = compute_output_dimension(
    max_pooling_op->padding_top + input_height + max_pooling_op->padding_bottom,
    max_pooling_op->kernel_height,
    max_pooling_op->dilation_height,
    max_pooling_op->stride_height);
  max_pooling_op->output_width = compute_output_dimension(
    max_pooling_op->padding_left + input_width + max_pooling_op->padding_right,
    max_pooling_op->kernel_width,
    max_pooling_op->dilation_width,
    max_pooling_op->stride_width);
  max_pooling_op->output = output;

  // Reuse path: if the input pointer and spatial dimensions match the
  // previous setup, indirection entries for the first valid_batch_size
  // batches are still correct; a batch size within that range needs no
  // buffer work at all.
  size_t valid_batch_size = 0;
  if (input == max_pooling_op->last_input &&
      input_height == max_pooling_op->last_input_height &&
      input_width == max_pooling_op->last_input_width)
  {
    valid_batch_size = max_pooling_op->valid_batch_size;
    if (batch_size <= valid_batch_size) {
      max_pooling_op->compute.range[0] = batch_size;
      max_pooling_op->state = xnn_run_state_ready;
      return xnn_status_success;
    }
  }

  const size_t pooling_height = max_pooling_op->kernel_height;
  const size_t pooling_width = max_pooling_op->kernel_width;
  const size_t pooling_size = pooling_height * pooling_width;
  const size_t output_height = max_pooling_op->output_height;
  const size_t output_width = max_pooling_op->output_width;
  // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
  const uint32_t mr = xnn_params.u8.maxpool.mr;

  // step_width/step_height size the indirection buffer: entries advanced per
  // output column and per output row.  NOTE(review): with dilation > 1 the
  // full pooling_width is used, presumably because dilated windows cannot
  // share column entries -- confirm against xnn_indirection_init_maxpool2d.
  const size_t step_width =
    max_pooling_op->dilation_width > 1 ? pooling_width : min(max_pooling_op->stride_width, pooling_width);
  const size_t step_height = pooling_size + (output_width * step_width - 1) * pooling_height;
  const size_t indirection_buffer_size = sizeof(void*) * ((mr - 1) + batch_size * output_height * step_height);

  // On realloc failure the old buffer is still owned by the operator, so the
  // field is only overwritten after the NULL check.
  const void** indirection_buffer = (const void**) realloc(max_pooling_op->indirection_buffer, indirection_buffer_size);
  if (indirection_buffer == NULL) {
    xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
    return xnn_status_out_of_memory;
  }
  max_pooling_op->indirection_buffer = indirection_buffer;

  // Fill entries starting after the still-valid batch prefix.
  xnn_indirection_init_maxpool2d(max_pooling_op, valid_batch_size, step_height, step_width, 0 /* log2(sizeof(uint8_t)) */);

  const uint32_t qr = xnn_params.u8.maxpool.qr;
  const size_t channels = max_pooling_op->channels;

  const size_t indirect_input_height_stride = step_height * sizeof(void*);
  const size_t output_width_stride = max_pooling_op->output_pixel_stride * sizeof(uint8_t);
  const size_t output_height_stride = output_width * output_width_stride;
  // NOTE(review): appears to account for micro-kernels that consume mr
  // pointers in a first pass and qr per subsequent pass -- confirm against
  // the maxpool micro-kernel contract.
  const size_t multipass_adjustment = round_up(doz(pooling_size, mr), qr) + mr;

  max_pooling_op->context.max_pooling = (struct max_pooling_context) {
    .indirect_input = indirection_buffer,
    .indirect_input_batch_stride = output_height * indirect_input_height_stride,
    .indirect_input_height_stride = indirect_input_height_stride,
    .output = output,
    .output_batch_stride = output_height * output_height_stride,
    .output_height_stride = output_height_stride,
    .output_width = output_width,
    .pooling_size = pooling_size,
    .channels = channels,
    .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
    .output_increment = output_width_stride - channels * sizeof(uint8_t),
    .params.u8 = max_pooling_op->u8_output_params,
    .ukernel = xnn_params.u8.maxpool.ukernel,
  };
  // Parallelize over (batch index, output row).
  max_pooling_op->compute.type = xnn_parallelization_type_2d;
  max_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_max_pooling;
  max_pooling_op->compute.range[0] = batch_size;
  max_pooling_op->compute.range[1] = output_height;
  max_pooling_op->state = xnn_run_state_ready;

  // Remember this configuration for the reuse check above.
  max_pooling_op->last_input = input;
  max_pooling_op->last_input_height = input_height;
  max_pooling_op->last_input_width = input_width;
  max_pooling_op->valid_batch_size = max(valid_batch_size, batch_size);

  return xnn_status_success;
}
+
// Configures a float32 Max Pooling operator for a concrete batch size and
// input spatial size, (re)building the indirection buffer when needed.
//
// Must be called after xnn_create_max_pooling2d_nhwc_f32 and before running
// the operator.  batch_size == 0 puts the operator into the skip state.
// The threadpool argument is unused here; parallelization is described in
// the compute descriptor and applied at run time.
enum xnn_status xnn_setup_max_pooling2d_nhwc_f32(
  xnn_operator_t max_pooling_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  const float* input,
  float* output,
  pthreadpool_t threadpool)
{
  if (max_pooling_op->type != xnn_operator_type_max_pooling_f32) {
    xnn_log_error("failed to setup Max Pooling (F32) operator: operator type mismatch");
    return xnn_status_invalid_parameter;
  }
  // Invalidate up front so any failure below leaves the operator non-runnable.
  max_pooling_op->state = xnn_run_state_invalid;

  if (!xnn_params.initialized) {
    xnn_log_error(
      "failed to setup Max Pooling operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (input_width == 0 || input_height == 0) {
    xnn_log_error(
      "failed to setup Max Pooling operator with %zux%zu input: input dimensions must be non-zero",
      input_width, input_height);
    return xnn_status_invalid_parameter;
  }

  if (batch_size == 0) {
    max_pooling_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  max_pooling_op->batch_size = batch_size;
  max_pooling_op->input_height = input_height;
  max_pooling_op->input_width = input_width;
  max_pooling_op->input = input;

  // Output spatial size from padded input, kernel, dilation and stride.
  max_pooling_op->output_height = compute_output_dimension(
    max_pooling_op->padding_top + input_height + max_pooling_op->padding_bottom,
    max_pooling_op->kernel_height,
    max_pooling_op->dilation_height,
    max_pooling_op->stride_height);
  max_pooling_op->output_width = compute_output_dimension(
    max_pooling_op->padding_left + input_width + max_pooling_op->padding_right,
    max_pooling_op->kernel_width,
    max_pooling_op->dilation_width,
    max_pooling_op->stride_width);
  max_pooling_op->output = output;

  // Reuse path: if the input pointer and spatial dimensions match the
  // previous setup, indirection entries for the first valid_batch_size
  // batches are still correct; a batch size within that range needs no
  // buffer work at all.
  size_t valid_batch_size = 0;
  if (input == max_pooling_op->last_input &&
      input_height == max_pooling_op->last_input_height &&
      input_width == max_pooling_op->last_input_width)
  {
    valid_batch_size = max_pooling_op->valid_batch_size;
    if (batch_size <= valid_batch_size) {
      max_pooling_op->compute.range[0] = batch_size;
      max_pooling_op->state = xnn_run_state_ready;
      return xnn_status_success;
    }
  }

  const size_t pooling_height = max_pooling_op->kernel_height;
  const size_t pooling_width = max_pooling_op->kernel_width;
  const size_t pooling_size = pooling_height * pooling_width;
  const size_t output_height = max_pooling_op->output_height;
  const size_t output_width = max_pooling_op->output_width;
  /* Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer */
  const uint32_t mr = xnn_params.f32.maxpool.mr;

  // step_width/step_height size the indirection buffer: entries advanced per
  // output column and per output row.  NOTE(review): with dilation > 1 the
  // full pooling_width is used, presumably because dilated windows cannot
  // share column entries -- confirm against xnn_indirection_init_maxpool2d.
  const size_t step_width =
    max_pooling_op->dilation_width > 1 ? pooling_width : min(max_pooling_op->stride_width, pooling_width);
  const size_t step_height = pooling_size + (output_width * step_width - 1) * pooling_height;
  const size_t indirection_buffer_size = sizeof(void*) * ((mr - 1) + batch_size * output_height * step_height);

  // On realloc failure the old buffer is still owned by the operator, so the
  // field is only overwritten after the NULL check.
  const void** indirection_buffer = (const void**) realloc(max_pooling_op->indirection_buffer, indirection_buffer_size);
  if (indirection_buffer == NULL) {
    xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
    return xnn_status_out_of_memory;
  }
  max_pooling_op->indirection_buffer = indirection_buffer;

  // Fill entries starting after the still-valid batch prefix.
  xnn_indirection_init_maxpool2d(max_pooling_op, valid_batch_size, step_height, step_width, 2 /* log2(sizeof(float)) */);

  const uint32_t qr = xnn_params.f32.maxpool.qr;
  const size_t channels = max_pooling_op->channels;

  const size_t indirect_input_height_stride = step_height * sizeof(void*);
  const size_t output_width_stride = max_pooling_op->output_pixel_stride * sizeof(float);
  const size_t output_height_stride = output_width * output_width_stride;
  // NOTE(review): appears to account for micro-kernels that consume mr
  // pointers in a first pass and qr per subsequent pass -- confirm against
  // the maxpool micro-kernel contract.
  const size_t multipass_adjustment = round_up(doz(pooling_size, mr), qr) + mr;

  max_pooling_op->context.max_pooling = (struct max_pooling_context) {
    .indirect_input = indirection_buffer,
    .indirect_input_batch_stride = output_height * indirect_input_height_stride,
    .indirect_input_height_stride = indirect_input_height_stride,
    .output = output,
    .output_batch_stride = output_height * output_height_stride,
    .output_height_stride = output_height_stride,
    .output_width = output_width,
    .pooling_size = pooling_size,
    .channels = channels,
    .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
    .output_increment = output_width_stride - channels * sizeof(float),
    .params.f32 = max_pooling_op->f32_output_params,
    .ukernel = xnn_params.f32.maxpool.ukernel,
  };
  // Parallelize over (batch index, output row).
  max_pooling_op->compute.type = xnn_parallelization_type_2d;
  max_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_max_pooling;
  max_pooling_op->compute.range[0] = batch_size;
  max_pooling_op->compute.range[1] = output_height;
  max_pooling_op->state = xnn_run_state_ready;

  // Remember this configuration for the reuse check above.
  max_pooling_op->last_input = input;
  max_pooling_op->last_input_height = input_height;
  max_pooling_op->last_input_width = input_width;
  max_pooling_op->valid_batch_size = max(valid_batch_size, batch_size);

  return xnn_status_success;
}
diff --git a/src/operator-delete.c b/src/operator-delete.c
new file mode 100644
index 0000000..1a8e73e
--- /dev/null
+++ b/src/operator-delete.c
@@ -0,0 +1,38 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <stdlib.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/log.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/params.h>
+
+
// Destroys an operator created by any xnn_create_*() function, releasing
// every buffer it owns and the descriptor itself.
//
// Returns xnn_status_success on success, xnn_status_uninitialized if XNNPACK
// was never initialized, and xnn_status_invalid_parameter for a NULL operator.
enum xnn_status xnn_delete_operator(xnn_operator_t op)
{
  if (!xnn_params.initialized) {
    xnn_log_error("failed to delete operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (op == NULL) {
    return xnn_status_invalid_parameter;
  }

  // Each field must be released with the allocator that created it: free()
  // pairs with malloc()/realloc() allocations (e.g. the indirection buffer),
  // xnn_release_memory() with the xnn_allocate_*() family.  All calls accept
  // NULL, so fields unused by this operator type are safe to pass.
  free(op->indirection_buffer);
  xnn_release_memory(op->packed_weights);
  free(op->a_sum);
  xnn_release_memory(op->zero_buffer);
  free(op->pixelwise_buffer);
  free(op->subconvolution_buffer);
  xnn_release_memory(op->lookup_table);
  xnn_release_memory(op);
  return xnn_status_success;
}
diff --git a/src/operator-run.c b/src/operator-run.c
new file mode 100644
index 0000000..0c35481
--- /dev/null
+++ b/src/operator-run.c
@@ -0,0 +1,784 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include <xnnpack.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/log.h>
+#include <xnnpack/common.h>
+#include <xnnpack/math.h>
+#include <xnnpack/params.h>
+#include <xnnpack/compute.h>
+
+
// Computes one MRxNR output tile of one group of a grouped GEMM by
// dispatching to the GEMM micro-kernel.
//
// group_index selects the per-group slices of A (offset k_scaled bytes per
// group), the packed weights (wg_stride), and C (cg_stride); mr_block_start
// and nr_block_start locate the tile within the group's output matrix.
void xnn_compute_ggemm(
    const struct gemm_context context[restrict static 1],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  // All offsets are byte offsets; the C column offset is scaled by the
  // element size via log2_csize.
  context->ukernel(
    mr_block_size,
    nr_block_size,
    k_scaled,
    (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
    a_stride,
    (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
    (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
    cm_stride,
    context->cn_stride,
    &context->params);
}
+
+void xnn_compute_gemm(
+ const struct gemm_context context[restrict static 1],
+ size_t mr_block_start,
+ size_t nr_block_start,
+ size_t mr_block_size,
+ size_t nr_block_size)
+{
+ const size_t a_stride = context->a_stride;
+ const size_t cm_stride = context->cm_stride;
+
+ context->ukernel(
+ mr_block_size,
+ nr_block_size,
+ context->k_scaled,
+ (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
+ a_stride,
+ (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
+ (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
+ cm_stride,
+ context->cn_stride,
+ &context->params);
+}
+
// Computes an MR-block of output rows for one batch element of a sparse
// GEMM (sparse packed weights times dense float activations).
//
// The sizeof(float) scaling of mr_block_start shows A and C are float
// matrices; weights sparsity structure is carried by input_increments and
// output_channel_nonzeros.
void xnn_compute_spmm(
    const struct spmm_context context[restrict static 1],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size)
{
  context->ukernel(
    mr_block_size,
    context->n,
    (const void*) ((uintptr_t) context->a + batch_index * context->batched_a_stride + mr_block_start * sizeof(float)),
    context->packed_weights,
    context->input_increments,
    context->output_channel_nonzeros,
    (void*) ((uintptr_t) context->c + batch_index * context->batched_c_stride + mr_block_start * sizeof(float)),
    &context->params);
}
+
// Computes one MRxNR output tile of one (batch, group) pair of a grouped
// indirect GEMM, dispatching to the IGEMM micro-kernel.
//
// A is accessed through an indirection-pointer array (indirect_a) rather
// than directly; a_offset plus the per-group (ga_stride) and per-batch
// (ba_stride) byte offsets are added to each indirected pointer by the
// micro-kernel, and context->zero supplies the padding row.
void xnn_compute_gigemm(
    const struct igemm_context context[restrict static 1],
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
    mr_block_size,
    nr_block_size,
    context->kc,
    context->ks_scaled,
    // Each output row consumes ks indirection pointers.
    (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
    (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
    (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
    cm_stride,
    context->cn_stride,
    context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
    context->zero,
    &context->params);
}
+
// Computes one MRxNR output tile of one batch element of a (non-grouped)
// indirect GEMM, dispatching to the IGEMM micro-kernel.
//
// A is accessed through the indirection-pointer array indirect_a; a_offset
// plus the per-batch byte offset (ba_stride) is added to each indirected
// pointer by the micro-kernel, and context->zero supplies the padding row.
void xnn_compute_igemm(
    const struct igemm_context context[restrict static 1],
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
    mr_block_size,
    nr_block_size,
    context->kc,
    context->ks_scaled,
    // Each output row consumes ks indirection pointers.
    (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
    (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
    (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
    cm_stride,
    context->cn_stride,
    context->a_offset + batch_index * context->ba_stride,
    context->zero,
    &context->params);
}
+
// Computes one tile of one subkernel slice of a grouped sub-convolution
// (deconvolution decomposed into subkernels).
//
// The parallelization grid is sized for the largest subkernel, so tiles
// that fall outside this subkernel's slice_height/slice_width are rejected
// by the early returns below rather than computed.
void xnn_compute_gsubconv2d(
    const struct subconv_context context[restrict static 1],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  // Clamp the tile width to what remains of this slice.
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel(
    slice_x_size,
    nc_block_size,
    context->kc,
    subconvolution_params->scaled_kernel_size,
    (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
    (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
    (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
    cx_stride,
    context->cn_stride,
    context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
    context->zero,
    &context->params);
}
+
// Computes one tile of one subkernel slice of a (non-grouped)
// sub-convolution; see xnn_compute_gsubconv2d for the grouped variant.
//
// The parallelization grid is sized for the largest subkernel, so tiles
// that fall outside this subkernel's slice bounds are rejected by the
// early returns below rather than computed.
void xnn_compute_subconv2d(
    const struct subconv_context context[restrict static 1],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  // Clamp the tile width to what remains of this slice.
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel(
    slice_x_size,
    nc_block_size,
    context->kc,
    subconvolution_params->scaled_kernel_size,
    (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
    (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
    (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
    cx_stride,
    context->cn_stride,
    context->a_offset + batch_index * context->ba_stride,
    context->zero,
    &context->params);
}
+
// Runs the direct-convolution micro-kernel that reads HWC-layout input and
// writes SpCHW-layout output, for one batch element and a slice of output
// rows [output_y_start, output_y_start + output_y_slice).
//
// context->zero supplies the padding row; only top padding is passed, so
// other padding handling presumably lives inside the micro-kernel or its
// configuration -- confirm against the hwc2spchw micro-kernel contract.
void xnn_compute_dconv2d_hwc2spchw(
    const struct dconv2d_context context[restrict static 1],
    size_t batch_index,
    size_t output_y_start,
    size_t output_y_slice)
{
  context->hwc2spchw_ukernel(
    context->input_height,
    context->input_width,
    output_y_start,
    output_y_start + output_y_slice,
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
    context->zero,
    context->packed_weights,
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
    context->input_padding_top,
    context->output_channels,
    context->output_height_stride,
    context->output_channel_stride,
    &context->params);
}
+
// Runs the single-pass depthwise-convolution micro-kernel on one output row.
//
// The indirection buffer and output are advanced by whole-row strides; the
// row strides are expressed in the element types of the respective context
// fields (pointer arithmetic on the members, not byte arithmetic).
void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict static 1],
    size_t output_y)
{
  context->unipass_ukernel(
    context->groups,
    context->output_width,
    context->indirection_buffer + output_y * context->indirection_buffer_row_stride,
    context->packed_weights,
    context->output + output_y * context->output_row_stride,
    context->indirection_buffer_col_stride,
    context->output_col_increment,
    &context->params);
}
+
// Runs the SpCHW-layout depthwise-convolution micro-kernel for one
// (batch element, channel) pair: each channel's 2D plane is convolved
// independently with that channel's packed weights.
void xnn_compute_dwconv2d_spchw(
    const struct dwconv2d_context context[restrict static 1],
    size_t batch_index,
    size_t channel)
{
  context->spchw_ukernel(
    context->output_height,
    context->input_width,
    (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
    (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
    (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
    context->input_tuple_stride,
    context->output_tuple_stride,
    context->input_pixel_stride,
    context->output_pixel_stride,
    &context->params);
}
+
// Runs the single-pass argmax-pooling micro-kernel on one output row of one
// batch element, producing both pooled values and uint32 argmax indices.
void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  // Resolve the row bases for the indirection pointers, the output values,
  // and the index output; all strides are byte strides.
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index =
    (uint32_t*) ((uintptr_t) context->index + batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  context->unipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, output, index,
    context->input_increment, context->output_increment,
    &context->params);
}
+
// Runs the multi-pass argmax-pooling micro-kernel on one output row of one
// batch element, for pooling windows too large for the single-pass kernel.
//
// Intermediate values/indices are accumulated in stack-allocated VLA scratch
// buffers sized by channels; this assumes channels is small enough for the
// stack.  The XNN_EXTRA_BYTES padding presumably allows the micro-kernel to
// write slightly past the logical end -- confirm against its contract.
void xnn_compute_argmax_pooling_multipass(
    const struct argmax_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  // Resolve the row bases for the indirection pointers, the output values,
  // and the index output; all strides are byte strides.
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index =
    (uint32_t*) ((uintptr_t) context->index + batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  XNN_ALIGN(16) float multipass_output_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(float)];
  XNN_ALIGN(16) uint32_t multipass_index_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint32_t)];

  context->multipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, multipass_output_buffer, multipass_index_buffer, output, index,
    context->input_increment, context->output_increment,
    &context->params);
}
+
+// Computes one row of max-pooling output for one batch element.
+void xnn_compute_max_pooling(
+    const struct max_pooling_context context[restrict static 1],
+    size_t batch_index,
+    size_t output_y)
+{
+  // Byte offsets of this (batch, output row) within each buffer.
+  const size_t indirection_offset =
+    batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride;
+  const size_t output_offset =
+    batch_index * context->output_batch_stride + output_y * context->output_height_stride;
+  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input + indirection_offset);
+  void* output = (void*) ((uintptr_t) context->output + output_offset);
+
+  context->ukernel(
+    context->output_width, context->pooling_size, context->channels,
+    indirect_input, output,
+    context->input_increment, context->output_increment,
+    &context->params);
+}
+
+// Scatters one input pixel's channels to the output locations selected by the
+// recorded pooling indices, filling unselected locations with fill_value.
+void xnn_compute_unpooling(
+    const struct unpooling_context context[restrict static 1],
+    size_t input_y,
+    size_t input_x)
+{
+  // Byte offsets of this input pixel in the input, index, and indirection buffers.
+  const size_t input_offset = input_y * context->input_height_stride + input_x * context->input_width_stride;
+  const size_t index_offset = input_y * context->index_height_stride + input_x * context->index_width_stride;
+  const size_t indirection_offset =
+    input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride;
+
+  const void* input = (const void*) ((uintptr_t) context->input + input_offset);
+  const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index + index_offset);
+  void** indirect_output = (void**) ((uintptr_t) context->indirect_output + indirection_offset);
+
+  context->ukernel(
+    context->pooling_size,
+    context->channels,
+    context->fill_value,
+    input, index, indirect_output);
+}
+
+// Computes one row of average-pooling output in a single pass over the
+// pooling window; context->zero substitutes for out-of-bounds rows.
+void xnn_compute_average_pooling_unipass(
+    const struct average_pooling_context context[restrict static 1],
+    size_t batch_index,
+    size_t output_y)
+{
+  const size_t indirection_offset =
+    batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride;
+  const size_t output_offset =
+    batch_index * context->output_batch_stride + output_y * context->output_height_stride;
+  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input + indirection_offset);
+  void* output = (void*) ((uintptr_t) context->output + output_offset);
+
+  context->unipass_ukernel(
+    context->output_width, context->pooling_size, context->channels,
+    indirect_input, context->zero, output,
+    context->input_increment, context->output_increment,
+    &context->params);
+}
+
+// Computes one row of average-pooling output for pooling windows too large
+// for the unipass micro-kernel, accumulating into on-stack int32 scratch.
+void xnn_compute_average_pooling_multipass(
+    const struct average_pooling_context context[restrict static 1],
+    size_t batch_index,
+    size_t output_y)
+{
+  // Strides are in bytes.
+  const void** indirect_input =
+    (const void**) ((uintptr_t) context->indirect_input +
+      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
+  void* output =
+    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
+  // int32 accumulator scratch. NOTE(review): the padding term divides
+  // XNN_EXTRA_BYTES by sizeof(uint8_t) although elements are int32_t, so this
+  // over-allocates relative to an int32-sized pad — harmless, but confirm the
+  // intended sizing convention.
+  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];
+
+  context->multipass_ukernel(
+    context->output_width, context->pooling_size, context->channels,
+    indirect_input, context->zero, multipass_buffer, output,
+    context->input_increment, context->output_increment,
+    &context->params);
+}
+
+// Computes one row of average-pooling output where each output pixel has its
+// own scale factor (taken from pixelwise_buffer), in a single pass.
+void xnn_compute_pixelwise_average_pooling_unipass(
+    const struct pixelwise_average_pooling_context context[restrict static 1],
+    size_t batch_index,
+    size_t output_y)
+{
+  const size_t indirection_offset =
+    batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride;
+  const size_t output_offset =
+    batch_index * context->output_batch_stride + output_y * context->output_height_stride;
+  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input + indirection_offset);
+  // Per-pixel multipliers for this output row; indexed by output row only.
+  const void* pixelwise_buffer =
+    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
+  void* output = (void*) ((uintptr_t) context->output + output_offset);
+
+  context->unipass_ukernel(
+    context->output_width, context->pooling_size, context->channels,
+    indirect_input, context->zero, pixelwise_buffer, output,
+    context->input_increment, context->output_increment,
+    &context->params);
+}
+
+// Multi-pass variant of pixelwise average pooling: accumulates large pooling
+// windows into on-stack int32 scratch before applying per-pixel scaling.
+void xnn_compute_pixelwise_average_pooling_multipass(
+    const struct pixelwise_average_pooling_context context[restrict static 1],
+    size_t batch_index,
+    size_t output_y)
+{
+  // Strides are in bytes.
+  const void** indirect_input =
+    (const void**) ((uintptr_t) context->indirect_input +
+      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
+  // Per-pixel multipliers for this output row.
+  const void* pixelwise_buffer =
+    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
+  void* output =
+    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
+  // int32 accumulator scratch; padding divisor matches the sizing convention
+  // used by the other multipass wrappers in this file.
+  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];
+
+  context->multipass_ukernel(
+    context->output_width, context->pooling_size, context->channels,
+    indirect_input, context->zero, pixelwise_buffer, multipass_buffer, output,
+    context->input_increment, context->output_increment,
+    &context->params);
+}
+
+// Computes global average pooling for one batch element in a single pass.
+void xnn_compute_global_average_pooling_unipass(
+    const struct global_average_pooling_context context[restrict static 1],
+    size_t batch_index)
+{
+  // Batch strides are in bytes.
+  const void* input = (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
+  void* output = (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
+
+  context->unipass_ukernel(
+    context->input_elements,
+    context->channels,
+    input,
+    context->input_pixel_stride,
+    context->zero,
+    output,
+    &context->params);
+}
+
+// Computes global average pooling for one batch element when the reduction is
+// too large for the unipass micro-kernel, accumulating into on-stack scratch.
+void xnn_compute_global_average_pooling_multipass(
+    const struct global_average_pooling_context context[restrict static 1],
+    size_t batch_index)
+{
+  // Batch strides are in bytes.
+  const void* input =
+    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
+  void* output =
+    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
+  // int32 accumulator scratch; padding divisor matches the other multipass
+  // wrappers in this file.
+  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];
+
+  context->multipass_ukernel(
+    context->input_elements,
+    context->channels,
+    input,
+    context->input_pixel_stride,
+    context->zero,
+    multipass_buffer,
+    output,
+    &context->params);
+}
+
+// Computes global average pooling over a slice of channels for one batch
+// element in the spatially-packed NCHW layout.
+void xnn_compute_global_average_pooling_spnchw(
+    const struct global_average_pooling_spnchw_context context[restrict static 1],
+    size_t batch_index,
+    size_t channels_start,
+    size_t channels_slice)
+{
+  // Byte offsets of the first channel of this slice within this batch element.
+  const size_t input_offset =
+    channels_start * context->input_channel_stride + batch_index * context->input_batch_stride;
+  const size_t output_offset =
+    channels_start * context->output_channel_stride + batch_index * context->output_batch_stride;
+  const void* input = (const void*) ((uintptr_t) context->input + input_offset);
+  void* output = (void*) ((uintptr_t) context->output + output_offset);
+
+  context->ukernel(
+    context->input_elements,
+    channels_slice,
+    input,
+    output,
+    &context->params);
+}
+
+// Applies the PReLU micro-kernel to a tile of batch_range rows starting at
+// batch_start.
+void xnn_compute_prelu(
+    const struct prelu_context context[restrict static 1],
+    size_t batch_start,
+    size_t batch_range)
+{
+  // Row strides are in bytes; advance input and output to the tile's first row.
+  const size_t input_stride = context->x_stride;
+  const size_t output_stride = context->y_stride;
+  const void* input = (const void*) ((uintptr_t) context->x + input_stride * batch_start);
+  void* output = (void*) ((uintptr_t) context->y + output_stride * batch_start);
+
+  context->ukernel(batch_range, context->n, input, input_stride, context->w, output, output_stride, &context->params);
+}
+
+// Applies the channel-padding micro-kernel to batch_range rows starting at
+// batch_start.
+void xnn_compute_channel_pad(
+    const struct channel_pad_context context[restrict static 1],
+    size_t batch_start,
+    size_t batch_range)
+{
+  // Row strides are in bytes; advance input and output to the tile's first row.
+  const size_t input_stride = context->x_stride;
+  const size_t output_stride = context->y_stride;
+  const void* input = (const void*) ((uintptr_t) context->x + input_stride * batch_start);
+  void* output = (void*) ((uintptr_t) context->y + output_stride * batch_start);
+
+  context->ukernel(batch_range, context->n, context->l, context->r, context->c, input, input_stride, output, output_stride);
+}
+
+// Adds one row of A to the matching row of B, writing one row of Y.
+// Rows may have independent (byte) strides.
+void xnn_compute_add_strided(
+    const struct add_strided_context context[restrict static 1],
+    size_t batch_index,
+    size_t batch_range /* always 1 */)
+{
+  assert(batch_range == 1);
+
+  // Advance each operand to the row selected by batch_index.
+  const void* a = (const void*) ((uintptr_t) context->a + context->a_stride * batch_index);
+  const void* b = (const void*) ((uintptr_t) context->b + context->b_stride * batch_index);
+  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
+
+  context->ukernel(context->n, a, b, y, &context->params);
+}
+
+// Adds a contiguous span of size bytes at the given byte offset: Y = A + B.
+void xnn_compute_add_contiguous(
+    const struct add_contiguous_context context[restrict static 1],
+    size_t offset,
+    size_t size)
+{
+  // All three tensors share the same byte offset into their buffers.
+  const void* a_ptr = (const void*) ((uintptr_t) context->a + offset);
+  const void* b_ptr = (const void*) ((uintptr_t) context->b + offset);
+  void* y_ptr = (void*) ((uintptr_t) context->y + offset);
+
+  context->ukernel(size, a_ptr, b_ptr, y_ptr, &context->params);
+}
+
+// Runs the fixed-group-count channel-shuffle micro-kernel on one element.
+void xnn_compute_channel_shuffle_fixed(
+    const struct channel_shuffle_context context[restrict static 1],
+    size_t index)
+{
+  // Strides are in bytes; select the element processed by this invocation.
+  const void* input = (const void*) ((uintptr_t) context->x + index * context->x_stride);
+  void* output = (void*) ((uintptr_t) context->y + index * context->y_stride);
+
+  context->fixed_ukernel(context->n, input, output);
+}
+
+// Runs the variable-group-count channel-shuffle micro-kernel on one element;
+// the group count is passed at run time via context->m.
+void xnn_compute_channel_shuffle_variable(
+    const struct channel_shuffle_context context[restrict static 1],
+    size_t index)
+{
+  // Strides are in bytes; select the element processed by this invocation.
+  const void* input = (const void*) ((uintptr_t) context->x + index * context->x_stride);
+  void* output = (void*) ((uintptr_t) context->y + index * context->y_stride);
+
+  context->variable_ukernel(context->n, context->m, input, output);
+}
+
+// Applies the lookup-table micro-kernel to one batch row.
+void xnn_compute_lut_strided(
+    const struct lut_strided_context context[restrict static 1],
+    size_t batch_index)
+{
+  // Row strides are in bytes.
+  const void* input = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
+  void* output = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
+
+  context->ukernel(context->n, input, context->t, output);
+}
+
+// Applies the lookup-table micro-kernel to a contiguous span of size bytes at
+// the given byte offset.
+void xnn_compute_lut_contiguous(
+    const struct lut_contiguous_context context[restrict static 1],
+    size_t offset,
+    size_t size)
+{
+  const void* input = (const void*) ((uintptr_t) context->x + offset);
+  void* output = (void*) ((uintptr_t) context->y + offset);
+
+  context->ukernel(size, input, context->t, output);
+}
+
+// Applies a unary vector micro-kernel to one batch row.
+void xnn_compute_univector_strided(
+    const struct univector_strided_context context[restrict static 1],
+    size_t batch_index,
+    size_t batch_range /* always 1 */)
+{
+  assert(batch_range == 1);
+
+  // Row strides are in bytes.
+  const void* input = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
+  void* output = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
+
+  context->ukernel(context->n, input, output, &context->params);
+}
+
+// Applies a unary vector micro-kernel to a contiguous span of size bytes at
+// the given byte offset.
+void xnn_compute_univector_contiguous(
+    const struct univector_contiguous_context context[restrict static 1],
+    size_t offset,
+    size_t size)
+{
+  const void* input = (const void*) ((uintptr_t) context->x + offset);
+  void* output = (void*) ((uintptr_t) context->y + offset);
+
+  context->ukernel(size, input, output, &context->params);
+}
+
+// Computes U8 soft(arg)max for one batch row in two micro-kernel calls:
+// a max-reduction followed by a LUT-based exponentiation with normalization.
+void xnn_compute_u8_softargmax(
+    const struct u8_softargmax_context context[restrict static 1],
+    size_t batch_index,
+    size_t output_y)
+{
+  const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
+  uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
+  const size_t n = context->n;
+
+  // Find the row maximum so the exponent lookup can be shifted for numerical
+  // range: the LUT pointer is advanced by (255 - x_max), making the maximum
+  // input index the LUT entry for a zero shifted exponent.
+  uint8_t x_max = 0;
+  context->rmax_ukernel(n, x, &x_max);
+  // x_max ^ 255 == 255 - x_max for x_max in [0, 255].
+  const size_t adjustment = x_max ^ 255;
+  const uint32_t* t = (const uint32_t*) context->t + adjustment;
+  context->lut_norm_ukernel(n, x, t, y);
+}
+
+// Applies the vector multiply-add-by-channel micro-kernel to a tile of
+// batch_size rows starting at batch_start.
+void xnn_compute_vmulcaddc(
+    const struct vmulcaddc_context context[restrict static 1],
+    size_t batch_start,
+    size_t batch_size)
+{
+  // Row strides are in bytes; advance input and output to the tile's first row.
+  const size_t input_stride = context->x_stride;
+  const size_t output_stride = context->y_stride;
+  const void* input = (const void*) ((uintptr_t) context->x + input_stride * batch_start);
+  void* output = (void*) ((uintptr_t) context->y + output_stride * batch_start);
+
+  context->ukernel(
+    batch_size,
+    context->n,
+    input, input_stride,
+    context->w,
+    output, output_stride,
+    &context->params);
+}
+
+// Executes a previously set-up operator on the given thread pool.
+// Dispatches to the pthreadpool parallelization primitive matching the
+// compute type that the operator's setup function recorded in op->compute.
+// Returns xnn_status_success on completion (including the batch-size-0 skip
+// case), or an error status if the library or operator is not ready.
+enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
+{
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to run operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+  switch (op->state) {
+    case xnn_run_state_invalid:
+      xnn_log_error("failed to run operator: operator was not successfully setup");
+      return xnn_status_invalid_state;
+    case xnn_run_state_ready:
+      break;
+    case xnn_run_state_skip:
+      // Setup marked this operator as a no-op (e.g. zero batch size).
+      return xnn_status_success;
+  }
+
+  // Each case forwards the ranges/tiles recorded at setup time to the matching
+  // pthreadpool primitive. Denormalized floats are flushed during execution.
+  switch (op->compute.type) {
+    case xnn_parallelization_type_invalid:
+      // NOTE(review): falls through to xnn_status_success without running
+      // anything — presumably unreachable for a ready operator; confirm.
+      break;
+    case xnn_parallelization_type_1d:
+      assert(op->compute.range[0] != 0);
+      pthreadpool_parallelize_1d(
+          threadpool,
+          op->compute.task_1d,
+          &op->context,
+          op->compute.range[0],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_1d_tile_1d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.tile[0] != 0);
+      pthreadpool_parallelize_1d_tile_1d(
+          threadpool,
+          op->compute.task_1d_tile_1d,
+          &op->context,
+          op->compute.range[0],
+          op->compute.tile[0],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_2d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      pthreadpool_parallelize_2d(
+          threadpool,
+          op->compute.task_2d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_2d_tile_1d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      assert(op->compute.tile[0] != 0);
+      pthreadpool_parallelize_2d_tile_1d(
+          threadpool,
+          op->compute.task_2d_tile_1d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1],
+          op->compute.tile[0],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_2d_tile_2d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      assert(op->compute.tile[0] != 0);
+      assert(op->compute.tile[1] != 0);
+      pthreadpool_parallelize_2d_tile_2d(
+          threadpool,
+          op->compute.task_2d_tile_2d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1],
+          op->compute.tile[0], op->compute.tile[1],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_3d_tile_2d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      assert(op->compute.range[2] != 0);
+      assert(op->compute.tile[0] != 0);
+      assert(op->compute.tile[1] != 0);
+      pthreadpool_parallelize_3d_tile_2d(
+          threadpool,
+          op->compute.task_3d_tile_2d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1], op->compute.range[2],
+          op->compute.tile[0], op->compute.tile[1],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_4d_tile_2d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      assert(op->compute.range[2] != 0);
+      assert(op->compute.range[3] != 0);
+      assert(op->compute.tile[0] != 0);
+      assert(op->compute.tile[1] != 0);
+      pthreadpool_parallelize_4d_tile_2d(
+          threadpool,
+          op->compute.task_4d_tile_2d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
+          op->compute.tile[0], op->compute.tile[1],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_5d_tile_2d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      assert(op->compute.range[2] != 0);
+      assert(op->compute.range[3] != 0);
+      assert(op->compute.range[4] != 0);
+      assert(op->compute.tile[0] != 0);
+      assert(op->compute.tile[1] != 0);
+      pthreadpool_parallelize_5d_tile_2d(
+          threadpool,
+          op->compute.task_5d_tile_2d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
+          op->compute.tile[0], op->compute.tile[1],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    case xnn_parallelization_type_6d_tile_2d:
+      assert(op->compute.range[0] != 0);
+      assert(op->compute.range[1] != 0);
+      assert(op->compute.range[2] != 0);
+      assert(op->compute.range[3] != 0);
+      assert(op->compute.range[4] != 0);
+      assert(op->compute.range[5] != 0);
+      assert(op->compute.tile[0] != 0);
+      assert(op->compute.tile[1] != 0);
+      pthreadpool_parallelize_6d_tile_2d(
+          threadpool,
+          op->compute.task_6d_tile_2d,
+          &op->context,
+          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
+          op->compute.tile[0], op->compute.tile[1],
+          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+      break;
+    default:
+      XNN_UNREACHABLE;
+  }
+  return xnn_status_success;
+}
diff --git a/src/prelu.c b/src/prelu.c
new file mode 100644
index 0000000..8c64e5d
--- /dev/null
+++ b/src/prelu.c
@@ -0,0 +1,143 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <math.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/log.h>
+
+
+// Creates a PReLU (parametric ReLU) operator for F32 tensors in NC layout.
+//
+//   channels       - number of channels; must be non-zero.
+//   input_stride   - input row stride, in elements; must be >= channels.
+//   output_stride  - output row stride, in elements; must be >= channels.
+//   negative_slope - per-channel slopes applied to negative inputs; copied
+//                    into the operator, caller retains ownership.
+//   output_min/max - output clamping range; min must be below max.
+//   prelu_op_out   - receives the created operator on success; caller frees
+//                    it with xnn_delete_operator.
+enum xnn_status xnn_create_prelu_nc_f32(
+    size_t channels,
+    size_t input_stride,
+    size_t output_stride,
+    const float* negative_slope,
+    float output_min,
+    float output_max,
+    uint32_t flags,
+    xnn_operator_t* prelu_op_out)
+{
+  xnn_operator_t prelu_op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to create PReLU operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  status = xnn_status_invalid_parameter;
+
+  if (channels == 0) {
+    xnn_log_error(
+      "failed to create PReLU operator with %zu channels: number of channels must be non-zero", channels);
+    goto error;
+  }
+
+  if (input_stride < channels) {
+    xnn_log_error(
+      "failed to create PReLU operator with input element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      input_stride, channels);
+    goto error;
+  }
+
+  if (output_stride < channels) {
+    xnn_log_error(
+      "failed to create PReLU operator with output element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      output_stride, channels);
+    goto error;
+  }
+
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create PReLU operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
+      output_min, output_max);
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  prelu_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+  if (prelu_op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for PReLU operator descriptor", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  // Round the slope buffer up so the micro-kernel may over-read up to
+  // XNN_EXTRA_BYTES past the last channel. Allocate zero-initialized memory
+  // so that padding is defined data rather than uninitialized heap (only the
+  // first `channels` floats are copied below).
+  const size_t packed_channels = round_up(channels, XNN_EXTRA_BYTES / sizeof(float));
+  prelu_op->packed_weights = xnn_allocate_zero_memory(packed_channels * sizeof(float));
+  if (prelu_op->packed_weights == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for packed slope data",
+      packed_channels * sizeof(float));
+    goto error;
+  }
+  memcpy(prelu_op->packed_weights, negative_slope, channels * sizeof(float));
+
+  prelu_op->channels = channels;
+  prelu_op->input_pixel_stride = input_stride;
+  prelu_op->output_pixel_stride = output_stride;
+  // Pre-compute the clamping parameters used by the F32 micro-kernel.
+  prelu_op->f32_output_params = xnn_compute_f32_output_params(output_min, output_max);
+
+  prelu_op->type = xnn_operator_type_prelu_f32;
+  prelu_op->ukernel.type = xnn_ukernel_type_prelu;
+
+  // The operator becomes runnable only after a successful setup call.
+  prelu_op->state = xnn_run_state_invalid;
+
+  *prelu_op_out = prelu_op;
+  return xnn_status_success;
+
+error:
+  // xnn_delete_operator(NULL) is a safe no-op on the early-exit paths.
+  xnn_delete_operator(prelu_op);
+  return status;
+}
+
+// Binds input/output pointers and batch size to a created PReLU operator and
+// prepares its parallelization plan. Must be called before xnn_run_operator.
+enum xnn_status xnn_setup_prelu_nc_f32(
+    xnn_operator_t prelu_op,
+    size_t batch_size,
+    const float* input,
+    float* output,
+    pthreadpool_t threadpool)
+{
+  // Guard against being handed an operator of a different type.
+  if (prelu_op->type != xnn_operator_type_prelu_f32) {
+    xnn_log_error("failed to setup PReLU (F32) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+  // Invalidate first: the operator only becomes runnable if setup succeeds.
+  prelu_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup PReLU operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  // An empty batch is a no-op: mark the operator to be skipped at run time.
+  if (batch_size == 0) {
+    prelu_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  // Element counts and strides are converted to bytes for the micro-kernel.
+  prelu_op->context.prelu = (struct prelu_context) {
+    .n = prelu_op->channels * sizeof(float),
+    .x = input,
+    .x_stride = prelu_op->input_pixel_stride * sizeof(float),
+    .w = prelu_op->packed_weights,
+    .y = output,
+    .y_stride = prelu_op->output_pixel_stride * sizeof(float),
+    .ukernel = xnn_params.f32.prelu.ukernel,
+    .params = prelu_op->f32_output_params,
+  };
+  // Parallelize over batch rows, tiled by the micro-kernel's row count (mr).
+  prelu_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+  prelu_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_prelu;
+  prelu_op->compute.range[0] = batch_size;
+  prelu_op->compute.tile[0] = xnn_params.f32.prelu.mr;
+  prelu_op->state = xnn_run_state_ready;
+
+  return xnn_status_success;
+}
diff --git a/src/q8-avgpool/mp9p8q-neon.c b/src/q8-avgpool/mp9p8q-neon.c
new file mode 100644
index 0000000..08a0b49
--- /dev/null
+++ b/src/q8-avgpool/mp9p8q-neon.c
@@ -0,0 +1,320 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/avgpool.h>
+
+
+// Q8 average-pooling multi-pass (MP9P8Q) NEON micro-kernel: handles pooling
+// windows larger than 9 elements by accumulating 9 rows in the first pass and
+// 8 rows per subsequent pass into an int32 scratch buffer, then requantizing
+// in the final pass (1-8 remaining rows, absent rows replaced by `zero`).
+//
+//   n      - number of output pixels
+//   ks     - pooling window size; must be > 9 (smaller windows use the
+//            unipass kernel)
+//   kc     - number of channels
+//   input  - array of input-row pointers, advanced by input_increment bytes
+//            per output pixel
+//   buffer - int32 accumulator scratch, one element per channel (padded)
+void xnn_q8_avgpool_ukernel_mp9p8q__neon(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const uint8_t** input,
+    const uint8_t* zero,
+    int32_t* buffer,
+    uint8_t* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_q8_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks > 9);
+  assert(kc != 0);
+
+  // Requantization parameters: bias, fixed-point multiplier, rounding shift,
+  // output zero point, and saturation bounds.
+  const int32x4_t vbias = vld1q_dup_s32(&params->neon.bias);
+#ifdef __aarch64__
+  const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+#else
+  const int32x2_t vmultiplier = vld1_dup_s32(&params->neon.multiplier);
+#endif
+  const int64x2_t vleft_shift = vld1q_dup_s64(&params->neon.left_shift);
+  const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+  const uint8x8_t voutput_min = vld1_dup_u8(&params->neon.output_min);
+  const uint8x8_t voutput_max = vld1_dup_u8(&params->neon.output_max);
+
+  do {
+    // First pass: sum the first 9 rows plus bias into the scratch buffer.
+    {
+      const uint8_t* i0 = *input++;
+      const uint8_t* i1 = *input++;
+      const uint8_t* i2 = *input++;
+      const uint8_t* i3 = *input++;
+      const uint8_t* i4 = *input++;
+      const uint8_t* i5 = *input++;
+      const uint8_t* i6 = *input++;
+      const uint8_t* i7 = *input++;
+      const uint8_t* i8 = *input++;
+
+      int32_t* acc = buffer;
+      // 8 channels per iteration; kc is padded so over-read is allowed.
+      for (size_t k = 0; k < kc; k += 8) {
+        const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
+        const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
+        const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
+        const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
+        const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
+        const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
+        const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
+        const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
+        const uint8x8_t vi8 = vld1_u8(i8); i8 += 8;
+
+        // Pairwise widening adds: 9 u8 rows fit comfortably in u16.
+        const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8);
+        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
+        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
+        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);
+
+        const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
+        const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67);
+        const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678);
+
+        // Widen to s32 and add the requantization bias once.
+        const int32x4_t vacc_lo = vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum)));
+        const int32x4_t vacc_hi = vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum)));
+
+        vst1q_s32(acc, vacc_lo); acc += 4;
+        vst1q_s32(acc, vacc_hi); acc += 4;
+      }
+    }
+
+    // Intermediate passes: accumulate 8 more rows per pass while more than 8
+    // rows remain after the first 9.
+    size_t m = ks;
+    for (m -= 9; m > 8; m -= 8) {
+      const uint8_t* i0 = *input++;
+      const uint8_t* i1 = *input++;
+      const uint8_t* i2 = *input++;
+      const uint8_t* i3 = *input++;
+      const uint8_t* i4 = *input++;
+      const uint8_t* i5 = *input++;
+      const uint8_t* i6 = *input++;
+      const uint8_t* i7 = *input++;
+
+      int32_t* acc = buffer;
+      for (size_t k = 0; k < kc; k += 8) {
+        const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
+        const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
+        const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
+        const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
+        const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
+        const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
+        const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
+        const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
+        int32x4_t vacc_lo = vld1q_s32(acc);
+        int32x4_t vacc_hi = vld1q_s32(acc + 4);
+
+        const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
+        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
+        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
+        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);
+
+        const uint16x8_t vsum0123 = vaddq_u16(vsum01, vsum23);
+        const uint16x8_t vsum4567 = vaddq_u16(vsum45, vsum67);
+        const uint16x8_t vsum = vaddq_u16(vsum0123, vsum4567);
+
+        // Accumulate in-place into the scratch buffer.
+        vacc_lo = vaddw_s16(vacc_lo, vreinterpret_s16_u16(vget_low_u16(vsum)));
+        vacc_hi = vaddw_s16(vacc_hi, vreinterpret_s16_u16(vget_high_u16(vsum)));
+
+        vst1q_s32(acc, vacc_lo); acc += 4;
+        vst1q_s32(acc, vacc_hi); acc += 4;
+      }
+    }
+
+    // Final pass: 1-8 rows remain (m); absent rows are replaced by `zero`.
+    // This pass also requantizes the accumulator and stores the u8 output.
+    {
+      const uint8_t* i0 = input[0];
+      const uint8_t* i1 = input[1];
+      const uint8_t* i2 = input[2];
+      const uint8_t* i3 = input[3];
+      const uint8_t* i4 = input[4];
+      const uint8_t* i5 = input[5];
+      const uint8_t* i6 = input[6];
+      const uint8_t* i7 = input[7];
+      // Advance the indirection pointer to the next output pixel's rows.
+      input = (const uint8_t**) ((uintptr_t) input + input_increment);
+      if (m < 2) {
+        i1 = zero;
+      }
+      if (m <= 2) {
+        i2 = zero;
+      }
+      if (m < 4) {
+        i3 = zero;
+      }
+      if (m <= 4) {
+        i4 = zero;
+      }
+      if (m < 6) {
+        i5 = zero;
+      }
+      if (m <= 6) {
+        i6 = zero;
+      }
+      if (m != 8) {
+        i7 = zero;
+      }
+
+      size_t k = kc;
+      int32_t* acc = buffer;
+      // Main loop: full groups of 8 channels.
+      while (k >= 8) {
+        const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
+        const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
+        const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
+        const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
+        const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
+        const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
+        const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
+        const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
+        int32x4_t vacc_lo = vld1q_s32(acc); acc += 4;
+        int32x4_t vacc_hi = vld1q_s32(acc); acc += 4;
+
+        const int16x8_t vsum01 = vreinterpretq_s16_u16(vaddl_u8(vi0, vi1));
+        const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
+        const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));
+        const int16x8_t vsum67 = vreinterpretq_s16_u16(vaddl_u8(vi6, vi7));
+
+        const int16x8_t vsum0123 = vaddq_s16(vsum01, vsum23);
+        const int16x8_t vsum4567 = vaddq_s16(vsum45, vsum67);
+        const int16x8_t vsum = vaddq_s16(vsum0123, vsum4567);
+
+        vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
+        vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));
+
+        // Sign masks adjust the 64-bit products so the rounding shift below
+        // rounds negative accumulators the same way as positive ones.
+        const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
+        const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));
+
+        // Fixed-point scaling: widen to 64-bit products, then rounding-shift.
+#if defined(__aarch64__)
+        const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
+        const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
+        const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
+        const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);
+
+        const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+        const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
+        const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+        const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
+#else
+        const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
+        const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
+        const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
+        const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);
+
+        const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+        const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
+        const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+        const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
+#endif
+
+        // vleft_shift is the (negative) rounding right shift in vrshlq form.
+        const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
+        const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
+        const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
+        const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);
+
+        // Narrow 64->32->16 with saturation, add the output zero point.
+#ifdef __aarch64__
+        vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
+        vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));
+
+        const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
+#else
+        vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
+        vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));
+
+        const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
+#endif
+
+        // Saturate to u8 and clamp to the output range.
+        uint8x8_t vout = vqmovun_s16(vacc);
+        vout = vmax_u8(vout, voutput_min);
+        vout = vmin_u8(vout, voutput_max);
+
+        vst1_u8(output, vout); output += 8;
+
+        k -= 8;
+      }
+      // Remainder: 1-7 channels; same math, partial (unaligned) stores.
+      if (k != 0) {
+        const uint8x8_t vi0 = vld1_u8(i0);
+        const uint8x8_t vi1 = vld1_u8(i1);
+        const uint8x8_t vi2 = vld1_u8(i2);
+        const uint8x8_t vi3 = vld1_u8(i3);
+        const uint8x8_t vi4 = vld1_u8(i4);
+        const uint8x8_t vi5 = vld1_u8(i5);
+        const uint8x8_t vi6 = vld1_u8(i6);
+        const uint8x8_t vi7 = vld1_u8(i7);
+        int32x4_t vacc_lo = vld1q_s32(acc); acc += 4;
+        int32x4_t vacc_hi = vld1q_s32(acc);
+
+        const int16x8_t vsum01 = vreinterpretq_s16_u16(vaddl_u8(vi0, vi1));
+        const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
+        const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));
+        const int16x8_t vsum67 = vreinterpretq_s16_u16(vaddl_u8(vi6, vi7));
+
+        const int16x8_t vsum0123 = vaddq_s16(vsum01, vsum23);
+        const int16x8_t vsum4567 = vaddq_s16(vsum45, vsum67);
+        const int16x8_t vsum = vaddq_s16(vsum0123, vsum4567);
+
+        vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
+        vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));
+
+        const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
+        const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));
+
+#if defined(__aarch64__)
+        const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
+        const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
+        const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
+        const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);
+
+        const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+        const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
+        const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+        const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
+#else
+        const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
+        const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
+        const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
+        const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);
+
+        const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+        const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
+        const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+        const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
+#endif
+
+        const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
+        const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
+        const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
+        const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);
+
+#ifdef __aarch64__
+        vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
+        vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));
+
+        const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
+#else
+        vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
+        vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));
+
+        const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
+#endif
+
+        uint8x8_t vout = vqmovun_s16(vacc);
+        vout = vmax_u8(vout, voutput_min);
+        vout = vmin_u8(vout, voutput_max);
+
+        // Store the remaining 1-7 bytes lane-by-lane; the assume_aligned(,1)
+        // annotations mark the stores as unaligned.
+        if (k & 4) {
+          vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); output += 4;
+          vout = vext_u8(vout, vout, 4);
+        }
+        if (k & 2) {
+          vst1_lane_u16(__builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); output += 2;
+          vout = vext_u8(vout, vout, 2);
+        }
+        if (k & 1) {
+          vst1_lane_u8(output, vout, 0); output += 1;
+        }
+      }
+    }
+    // Advance to the next output pixel.
+    output = (uint8_t*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/q8-avgpool/mp9p8q-scalar.c b/src/q8-avgpool/mp9p8q-scalar.c
new file mode 100644
index 0000000..6971756
--- /dev/null
+++ b/src/q8-avgpool/mp9p8q-scalar.c
@@ -0,0 +1,184 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/scalar-utils.h>
+#include <xnnpack/avgpool.h>
+
+
+/*
+ * Q8 (quantized uint8) average-pooling micro-kernel: scalar, multi-pass
+ * variant for pooling windows larger than 9 elements (ks > 9).
+ *
+ * The window is processed in passes: the first pass accumulates 9 input
+ * rows plus the bias into the int32 scratch `buffer`, each intermediate
+ * pass accumulates 8 more rows, and the last pass adds the remaining
+ * 1-8 rows, requantizes the accumulator, and writes uint8 output.
+ *
+ *   n                - number of output pixels; must be non-zero.
+ *   ks               - pooling window size in elements; must be > 9.
+ *   kc               - number of channels; must be non-zero.
+ *   input            - indirection buffer of input-row pointers, consumed
+ *                      pass by pass (9 + 8 per intermediate pass), then
+ *                      advanced by input_increment bytes per output pixel.
+ *   zero             - zero-filled row substituted for window positions
+ *                      past the end of the window in the last pass.
+ *   buffer           - int32 scratch buffer with at least kc elements.
+ *   output           - uint8 output; advanced by kc + output_increment
+ *                      bytes per output pixel.
+ *   params           - scalar requantization parameters: bias, fixed-point
+ *                      multiplier, rounding constant, right shift, and
+ *                      pre-zero-point clamping bounds.
+ */
+void xnn_q8_avgpool_ukernel_mp9p8q__scalar(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const uint8_t** input,
+    const uint8_t* zero,
+    int32_t* buffer,
+    uint8_t* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_q8_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks > 9);
+  assert(kc != 0);
+
+  const int32_t vbias = params->scalar.bias;
+  const int32_t vmultiplier = params->scalar.multiplier;
+  const int64_t vrounding = params->scalar.rounding;
+  const uint32_t vshift = params->scalar.right_shift;
+  const int32_t voutput_min = params->scalar.output_min_less_zero_point;
+  const int32_t voutput_max = params->scalar.output_max_less_zero_point;
+  const int32_t voutput_zero_point = params->scalar.output_zero_point;
+  do {
+    /* First pass: buffer[c] = bias + sum of the first 9 rows. */ {
+      const uint8_t* i0 = *input++;
+      const uint8_t* i1 = *input++;
+      const uint8_t* i2 = *input++;
+      const uint8_t* i3 = *input++;
+      const uint8_t* i4 = *input++;
+      const uint8_t* i5 = *input++;
+      const uint8_t* i6 = *input++;
+      const uint8_t* i7 = *input++;
+      const uint8_t* i8 = *input++;
+
+      int32_t* b = buffer;
+      size_t k = kc;
+      do {
+        const uint32_t vi0 = (uint32_t) *i0++;
+        const uint32_t vi1 = (uint32_t) *i1++;
+        const uint32_t vi2 = (uint32_t) *i2++;
+        const uint32_t vi3 = (uint32_t) *i3++;
+        const uint32_t vi4 = (uint32_t) *i4++;
+        const uint32_t vi5 = (uint32_t) *i5++;
+        const uint32_t vi6 = (uint32_t) *i6++;
+        const uint32_t vi7 = (uint32_t) *i7++;
+        const uint32_t vi8 = (uint32_t) *i8++;
+
+        /* Pairwise summation tree over the 9 samples. */
+        const uint32_t vsum01 = vi0 + vi1;
+        const uint32_t vsum23 = vi2 + vi3;
+        const uint32_t vsum45 = vi4 + vi5;
+        const uint32_t vsum67 = vi6 + vi7;
+        const uint32_t vsum018 = vsum01 + vi8;
+        const uint32_t vsum2345 = vsum23 + vsum45;
+        const uint32_t vsum01678 = vsum018 + vsum67;
+        int32_t vacc = vbias + (int32_t) vsum2345;
+        vacc += (int32_t) vsum01678;
+        *b++ = vacc;
+      } while (--k != 0);
+    }
+
+    size_t m = ks;
+    /* Intermediate passes: accumulate 8 more rows per pass while more than
+       8 rows remain. On exit, m is the number of rows (1-8) left for the
+       last pass. */
+    for (m -= 9; m > 8; m -= 8) {
+      const uint8_t* i0 = *input++;
+      const uint8_t* i1 = *input++;
+      const uint8_t* i2 = *input++;
+      const uint8_t* i3 = *input++;
+      const uint8_t* i4 = *input++;
+      const uint8_t* i5 = *input++;
+      const uint8_t* i6 = *input++;
+      const uint8_t* i7 = *input++;
+
+      int32_t* b = buffer;
+      size_t k = kc;
+      do {
+        int32_t vacc = *b;
+
+        const uint32_t vi0 = (uint32_t) *i0++;
+        const uint32_t vi1 = (uint32_t) *i1++;
+        const uint32_t vi2 = (uint32_t) *i2++;
+        const uint32_t vi3 = (uint32_t) *i3++;
+        const uint32_t vi4 = (uint32_t) *i4++;
+        const uint32_t vi5 = (uint32_t) *i5++;
+        const uint32_t vi6 = (uint32_t) *i6++;
+        const uint32_t vi7 = (uint32_t) *i7++;
+
+        const uint32_t vsum01 = vi0 + vi1;
+        const uint32_t vsum23 = vi2 + vi3;
+        const uint32_t vsum45 = vi4 + vi5;
+        const uint32_t vsum67 = vi6 + vi7;
+        const uint32_t vsum0123 = vsum01 + vsum23;
+        const uint32_t vsum4567 = vsum45 + vsum67;
+        vacc += (int32_t) vsum0123;
+        vacc += (int32_t) vsum4567;
+
+        *b++ = vacc;
+      } while (--k != 0);
+    }
+
+    /* Last pass: accumulate the final m rows (1-8), requantize, store. */ {
+      const uint8_t* i0 = input[0];
+      const uint8_t* i1 = input[1];
+      const uint8_t* i2 = input[2];
+      const uint8_t* i3 = input[3];
+      const uint8_t* i4 = input[4];
+      const uint8_t* i5 = input[5];
+      const uint8_t* i6 = input[6];
+      const uint8_t* i7 = input[7];
+      input = (const uint8_t**) ((uintptr_t) input + input_increment);
+      /* Substitute the zero row for positions beyond the m remaining rows;
+         i0 is always a real row because m >= 1. */
+      if (m < 2) {
+        i1 = zero;
+      }
+      if (m <= 2) {
+        i2 = zero;
+      }
+      if (m < 4) {
+        i3 = zero;
+      }
+      if (m <= 4) {
+        i4 = zero;
+      }
+      if (m < 6) {
+        i5 = zero;
+      }
+      if (m <= 6) {
+        i6 = zero;
+      }
+      if (m != 8) {
+        i7 = zero;
+      }
+
+      size_t k = kc;
+      int32_t* b = buffer;
+      do {
+        int32_t vacc = *b++;
+
+        const uint32_t vi0 = (uint32_t) *i0++;
+        const uint32_t vi1 = (uint32_t) *i1++;
+        const uint32_t vi2 = (uint32_t) *i2++;
+        const uint32_t vi3 = (uint32_t) *i3++;
+        const uint32_t vi4 = (uint32_t) *i4++;
+        const uint32_t vi5 = (uint32_t) *i5++;
+        const uint32_t vi6 = (uint32_t) *i6++;
+        const uint32_t vi7 = (uint32_t) *i7++;
+
+        const uint32_t vsum01 = vi0 + vi1;
+        const uint32_t vsum23 = vi2 + vi3;
+        const uint32_t vsum45 = vi4 + vi5;
+        const uint32_t vsum67 = vi6 + vi7;
+        const uint32_t vsum0123 = vsum01 + vsum23;
+        const uint32_t vsum4567 = vsum45 + vsum67;
+        vacc += (int32_t) vsum0123;
+        vacc += (int32_t) vsum4567;
+
+        /* Fixed-point requantization: scale by the multiplier in 64-bit,
+           then round via the rounding constant and arithmetic right shift.
+           The extra -1 applied to negative accumulators keeps the rounding
+           symmetric about zero, matching the SIMD variants' sign-magnitude
+           computation. */
+        const int64_t vproduct = (int64_t) vacc * (int64_t) vmultiplier;
+        const int64_t vadjusted_product = vproduct - (int64_t) (vacc < 0);
+        int32_t vout = (int32_t) asr_s64(vadjusted_product + vrounding, vshift);
+        /* Clamp in the zero-point-free domain (bounds are "less zero
+           point"), then re-add the output zero point. */
+        vout = vout < voutput_min ? voutput_min : vout;
+        vout = vout > voutput_max ? voutput_max : vout;
+        vout += voutput_zero_point;
+
+        *output++ = (uint8_t) vout;
+      } while (--k != 0);
+    }
+    output = (uint8_t*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/q8-avgpool/mp9p8q-sse2.c b/src/q8-avgpool/mp9p8q-sse2.c
new file mode 100644
index 0000000..a27074a
--- /dev/null
+++ b/src/q8-avgpool/mp9p8q-sse2.c
@@ -0,0 +1,341 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/avgpool.h>
+
+
+/*
+ * Q8 (quantized uint8) average-pooling micro-kernel: SSE2, multi-pass
+ * variant for pooling windows larger than 9 elements (ks > 9).
+ *
+ * The first pass accumulates 9 input rows plus the bias into the int32
+ * scratch `buffer`, each intermediate pass accumulates 8 more rows, and
+ * the last pass adds the remaining 1-8 rows, requantizes, and stores
+ * uint8 output.
+ *
+ * Fix: the address-of expressions on the params loads had been garbled
+ * into a pilcrow character (a mis-decoded "&para;" HTML entity); the
+ * "&params" spellings are restored below.
+ *
+ * NOTE(review): the channel loops advance 8 elements at a time and the
+ * remainder path still loads 8 bytes per row, so input rows and `buffer`
+ * are presumably padded to a multiple of 8 - confirm with callers.
+ */
+void xnn_q8_avgpool_ukernel_mp9p8q__sse2(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const uint8_t** input,
+    const uint8_t* zero,
+    int32_t* buffer,
+    uint8_t* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_q8_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks > 9);
+  assert(kc != 0);
+
+  const __m128i vbias = _mm_load_si128((const __m128i*) &params->sse2.bias);
+  const __m128i vzero = _mm_setzero_si128();
+  const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->sse2.multiplier);
+  const __m128i vrounding = _mm_load_si128((const __m128i*) params->sse2.rounding);
+  const __m128i vright_shift = _mm_loadl_epi64((const __m128i*) params->sse2.right_shift);
+
+  do {
+    /* First pass: buffer[c] = bias + sum of the first 9 rows. */
+    {
+      const uint8_t* i0 = *input++;
+      const uint8_t* i1 = *input++;
+      const uint8_t* i2 = *input++;
+      const uint8_t* i3 = *input++;
+      const uint8_t* i4 = *input++;
+      const uint8_t* i5 = *input++;
+      const uint8_t* i6 = *input++;
+      const uint8_t* i7 = *input++;
+      const uint8_t* i8 = *input++;
+
+      int32_t* acc = buffer;
+      for (size_t k = 0; k < kc; k += 8) {
+        const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0); i0 += 8;
+        const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1); i1 += 8;
+        const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2); i2 += 8;
+        const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3); i3 += 8;
+        const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4); i4 += 8;
+        const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5); i5 += 8;
+        const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6); i6 += 8;
+        const __m128i vi7 = _mm_loadl_epi64((const __m128i*) i7); i7 += 8;
+        const __m128i vi8 = _mm_loadl_epi64((const __m128i*) i8); i8 += 8;
+
+        /* Zero-extend uint8 -> uint16. */
+        const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+        const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+        const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+        const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+        const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+        const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+        const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+        const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);
+        const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero);
+
+        /* Pairwise summation tree over the 9 samples in uint16. */
+        const __m128i vsum018 = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), vxi8);
+        const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
+        const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
+        const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);
+
+        const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
+        const __m128i vsum01678 = _mm_add_epi16(vsum018, vsum67);
+        const __m128i vsum = _mm_add_epi16(vsum2345, vsum01678);
+
+        /* Widen uint16 -> int32 and add the bias. */
+        const __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero));
+        const __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero));
+
+        _mm_store_si128((__m128i*) acc, vacc_lo);
+        _mm_store_si128((__m128i*) acc + 1, vacc_hi);
+        acc += 8;
+      }
+    }
+
+    size_t m = ks;
+    /* Intermediate passes: add 8 rows per pass while more than 8 remain;
+       on exit, m is the number of rows (1-8) left for the last pass. */
+    for (m -= 9; m > 8; m -= 8) {
+      const uint8_t* i0 = *input++;
+      const uint8_t* i1 = *input++;
+      const uint8_t* i2 = *input++;
+      const uint8_t* i3 = *input++;
+      const uint8_t* i4 = *input++;
+      const uint8_t* i5 = *input++;
+      const uint8_t* i6 = *input++;
+      const uint8_t* i7 = *input++;
+
+      int32_t* acc = buffer;
+      for (size_t k = 0; k < kc; k += 8) {
+        const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0); i0 += 8;
+        const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1); i1 += 8;
+        const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2); i2 += 8;
+        const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3); i3 += 8;
+        const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4); i4 += 8;
+        const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5); i5 += 8;
+        const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6); i6 += 8;
+        const __m128i vi7 = _mm_loadl_epi64((const __m128i*) i7); i7 += 8;
+        __m128i vacc_lo = _mm_load_si128((const __m128i*) acc);
+        __m128i vacc_hi = _mm_load_si128((const __m128i*) acc + 1);
+
+        const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+        const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+        const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+        const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+        const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+        const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+        const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+        const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);
+
+        const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1);
+        const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
+        const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
+        const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);
+
+        const __m128i vsum0123 = _mm_add_epi16(vsum01, vsum23);
+        const __m128i vsum4567 = _mm_add_epi16(vsum45, vsum67);
+        const __m128i vsum = _mm_add_epi16(vsum0123, vsum4567);
+
+        vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero));
+        vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero));
+
+        _mm_store_si128((__m128i*) acc, vacc_lo);
+        _mm_store_si128((__m128i*) acc + 1, vacc_hi);
+        acc += 8;
+      }
+    }
+
+    /* Last pass: add the final m rows (1-8), requantize, and store. */
+    {
+      const uint8_t* i0 = input[0];
+      const uint8_t* i1 = input[1];
+      const uint8_t* i2 = input[2];
+      const uint8_t* i3 = input[3];
+      const uint8_t* i4 = input[4];
+      const uint8_t* i5 = input[5];
+      const uint8_t* i6 = input[6];
+      const uint8_t* i7 = input[7];
+      input = (const uint8_t**) ((uintptr_t) input + input_increment);
+      /* Substitute the zero row for positions beyond the m remaining rows;
+         i0 is always a real row because m >= 1. */
+      if (m < 2) {
+        i1 = zero;
+      }
+      if (m <= 2) {
+        i2 = zero;
+      }
+      if (m < 4) {
+        i3 = zero;
+      }
+      if (m <= 4) {
+        i4 = zero;
+      }
+      if (m < 6) {
+        i5 = zero;
+      }
+      if (m <= 6) {
+        i6 = zero;
+      }
+      if (m != 8) {
+        i7 = zero;
+      }
+
+      size_t k = kc;
+      int32_t* acc = buffer;
+      while (k >= 8) {
+        const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0); i0 += 8;
+        const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1); i1 += 8;
+        const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2); i2 += 8;
+        const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3); i3 += 8;
+        const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4); i4 += 8;
+        const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5); i5 += 8;
+        const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6); i6 += 8;
+        const __m128i vi7 = _mm_loadl_epi64((const __m128i*) i7); i7 += 8;
+        __m128i vacc_lo = _mm_load_si128((const __m128i*) acc);
+        __m128i vacc_hi = _mm_load_si128((const __m128i*) acc + 1);
+        acc += 8;
+
+        const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+        const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+        const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+        const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+        const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+        const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+        const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+        const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);
+
+        const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1);
+        const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
+        const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
+        const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);
+
+        const __m128i vsum0123 = _mm_add_epi16(vsum01, vsum23);
+        const __m128i vsum4567 = _mm_add_epi16(vsum45, vsum67);
+        const __m128i vsum = _mm_add_epi16(vsum0123, vsum4567);
+
+        vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero));
+        vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero));
+
+        /* Requantize in sign-magnitude form: take |acc|, multiply the even
+           and odd 32-bit lanes by the unsigned multiplier (_mm_mul_epu32),
+           add the rounding constant, logically shift the 64-bit products
+           right, re-interleave the lanes, and re-apply the sign. */
+        const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo);
+        const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi);
+
+        const __m128i vabs_lo0123 = _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo);
+        const __m128i vabs_hi0123 = _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi);
+
+        const __m128i vabs_lo1032 = _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1));
+        const __m128i vabs_hi1032 = _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1));
+
+        const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier);
+        const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier);
+
+        const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier);
+        const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier);
+
+        const __m128i vabs_scaled_lo02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift);
+        const __m128i vabs_scaled_lo13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift);
+        const __m128i vabs_scaled_hi02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift);
+        const __m128i vabs_scaled_hi13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift);
+
+        const __m128i vabs_scaled_lo0213 = _mm_castps_si128(
+            _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_lo02), _mm_castsi128_ps(vabs_scaled_lo13), _MM_SHUFFLE(2, 0, 2, 0)));
+        const __m128i vabs_scaled_hi0213 = _mm_castps_si128(
+            _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_hi02), _mm_castsi128_ps(vabs_scaled_hi13), _MM_SHUFFLE(2, 0, 2, 0)));
+
+        const __m128i vabs_scaled_lo = _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0));
+        const __m128i vabs_scaled_hi = _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+        const __m128i vscaled_lo = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo);
+        const __m128i vscaled_hi = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi);
+
+        /* Narrow with saturation, add the zero point, and clamp. */
+        __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
+        vout = _mm_adds_epi16(vout, _mm_load_si128((const __m128i*) &params->sse2.output_zero_point));
+        vout = _mm_packus_epi16(vout, vout);
+        vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) &params->sse2.output_max));
+        vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) &params->sse2.output_min));
+
+        _mm_storel_epi64((__m128i*) output, vout);
+        output += 8;
+
+        k -= 8;
+      }
+      if (k != 0) {
+        const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0);
+        const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1);
+        const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2);
+        const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3);
+        const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4);
+        const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5);
+        const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6);
+        const __m128i vi7 = _mm_loadl_epi64((const __m128i*) i7);
+        __m128i vacc_lo = _mm_load_si128((const __m128i*) acc);
+        __m128i vacc_hi = _mm_load_si128((const __m128i*) acc + 1);
+
+        const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+        const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+        const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+        const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+        const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+        const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+        const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+        const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);
+
+        const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1);
+        const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
+        const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
+        const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);
+
+        const __m128i vsum0123 = _mm_add_epi16(vsum01, vsum23);
+        const __m128i vsum4567 = _mm_add_epi16(vsum45, vsum67);
+        const __m128i vsum = _mm_add_epi16(vsum0123, vsum4567);
+
+        vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero));
+        vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero));
+
+        const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo);
+        const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi);
+
+        const __m128i vabs_lo0123 = _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo);
+        const __m128i vabs_hi0123 = _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi);
+
+        const __m128i vabs_lo1032 = _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1));
+        const __m128i vabs_hi1032 = _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1));
+
+        const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier);
+        const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier);
+
+        const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier);
+        const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier);
+
+        const __m128i vabs_scaled_lo02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift);
+        const __m128i vabs_scaled_lo13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift);
+        const __m128i vabs_scaled_hi02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift);
+        const __m128i vabs_scaled_hi13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift);
+
+        const __m128i vabs_scaled_lo0213 = _mm_castps_si128(
+            _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_lo02), _mm_castsi128_ps(vabs_scaled_lo13), _MM_SHUFFLE(2, 0, 2, 0)));
+        const __m128i vabs_scaled_hi0213 = _mm_castps_si128(
+            _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_hi02), _mm_castsi128_ps(vabs_scaled_hi13), _MM_SHUFFLE(2, 0, 2, 0)));
+
+        const __m128i vabs_scaled_lo = _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0));
+        const __m128i vabs_scaled_hi = _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+        const __m128i vscaled_lo = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo);
+        const __m128i vscaled_hi = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi);
+
+        __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
+        vout = _mm_adds_epi16(vout, _mm_load_si128((const __m128i*) &params->sse2.output_zero_point));
+        vout = _mm_packus_epi16(vout, vout);
+        vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) &params->sse2.output_max));
+        vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) &params->sse2.output_min));
+
+        /* Store the 1-7 remainder bytes with progressively narrower stores. */
+        if (k & 4) {
+          *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout);
+          output += 4;
+          vout = _mm_srli_epi64(vout, 32);
+        }
+        if (k & 2) {
+          *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout, 0);
+          output += 2;
+          vout = _mm_srli_epi32(vout, 16);
+        }
+        if (k & 1) {
+          *((uint8_t*) output) = (uint8_t) _mm_cvtsi128_si32(vout);
+          output += 1;
+        }
+      }
+    }
+    output = (uint8_t*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/q8-avgpool/up9-neon.c b/src/q8-avgpool/up9-neon.c
new file mode 100644
index 0000000..d228186
--- /dev/null
+++ b/src/q8-avgpool/up9-neon.c
@@ -0,0 +1,238 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/avgpool.h>
+
+
+/*
+ * Q8 (quantized uint8) average-pooling micro-kernel: NEON, unipass variant
+ * for pooling windows of at most 9 elements (1 <= ks <= 9).
+ *
+ * For each output pixel, 9 row pointers are taken from the indirection
+ * buffer (positions >= ks read the zero row), summed with the bias,
+ * requantized with a fixed-point multiplier and rounding shift, and
+ * written out as clamped uint8.
+ *
+ * Fix: the address-of expressions on the params loads had been garbled
+ * into a pilcrow character (a mis-decoded "&para;" HTML entity); the
+ * "&params" spellings are restored below.
+ */
+void xnn_q8_avgpool_ukernel_up9__neon(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const uint8_t** input,
+    const uint8_t* zero,
+    uint8_t* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_q8_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks <= 9);
+  assert(kc != 0);
+
+  const int32x4_t vbias = vld1q_dup_s32(&params->neon.bias);
+#ifdef __aarch64__
+  const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+#else
+  const int32x2_t vmultiplier = vld1_dup_s32(&params->neon.multiplier);
+#endif
+  const int64x2_t vleft_shift = vld1q_dup_s64(&params->neon.left_shift);
+  const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+  const uint8x8_t voutput_min = vld1_dup_u8(&params->neon.output_min);
+  const uint8x8_t voutput_max = vld1_dup_u8(&params->neon.output_max);
+
+  do {
+    const uint8_t* i0 = input[0];
+    const uint8_t* i1 = input[1];
+    const uint8_t* i2 = input[2];
+    const uint8_t* i3 = input[3];
+    const uint8_t* i4 = input[4];
+    const uint8_t* i5 = input[5];
+    const uint8_t* i6 = input[6];
+    const uint8_t* i7 = input[7];
+    const uint8_t* i8 = input[8];
+    input = (const uint8_t**) ((uintptr_t) input + input_increment);
+    /* Window positions >= ks read from the zero row; i0 is always a real
+       row because ks >= 1. */
+    if (ks < 2) {
+      i1 = zero;
+    }
+    if (ks <= 2) {
+      i2 = zero;
+    }
+    if (ks < 4) {
+      i3 = zero;
+    }
+    if (ks <= 4) {
+      i4 = zero;
+    }
+    if (ks < 6) {
+      i5 = zero;
+    }
+    if (ks <= 6) {
+      i6 = zero;
+    }
+    if (ks < 8) {
+      i7 = zero;
+    }
+    if (ks <= 8) {
+      i8 = zero;
+    }
+
+    size_t k = kc;
+    while (k >= 8) {
+      const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
+      const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
+      const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
+      const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
+      const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
+      const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
+      const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
+      const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
+      const uint8x8_t vi8 = vld1_u8(i8); i8 += 8;
+
+      /* Pairwise summation tree of the 9 samples, widening to uint16. */
+      const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8);
+      const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
+      const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
+      const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);
+
+      const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
+      const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67);
+      const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678);
+
+      int32x4_t vacc_lo = vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum)));
+      int32x4_t vacc_hi = vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum)));
+
+      /* Requantize: widen acc * multiplier to 64 bits, add the negative
+         mask (i.e. subtract 1 from negative lanes so rounding stays
+         symmetric about zero), then apply a rounding shift. NOTE(review):
+         vrshlq_s64 with `left_shift` acts as a rounding right shift only
+         if left_shift is non-positive - confirm with params setup. */
+      const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
+      const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));
+
+#if defined(__aarch64__)
+      const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
+      const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
+      const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
+      const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);
+
+      const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+      const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
+      const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+      const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
+#else
+      const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
+      const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
+      const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
+      const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);
+
+      const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+      const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
+      const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+      const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
+#endif
+
+      const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
+      const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
+      const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
+      const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);
+
+      /* Narrow 64 -> 32 (low halves), then 32 -> 16 with saturation, and
+         add the output zero point. */
+#ifdef __aarch64__
+      vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
+      vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));
+
+      const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
+#else
+      vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
+      vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));
+
+      const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
+#endif
+
+      uint8x8_t vout = vqmovun_s16(vacc);
+      vout = vmax_u8(vout, voutput_min);
+      vout = vmin_u8(vout, voutput_max);
+
+      vst1_u8(output, vout); output += 8;
+
+      k -= 8;
+    }
+    if (k != 0) {
+      /* Remainder of 1-7 channels. NOTE(review): the loads below still read
+         a full 8 bytes per row - presumably rows are padded; confirm. */
+      const uint8x8_t vi0 = vld1_u8(i0);
+      const uint8x8_t vi1 = vld1_u8(i1);
+      const uint8x8_t vi2 = vld1_u8(i2);
+      const uint8x8_t vi3 = vld1_u8(i3);
+      const uint8x8_t vi4 = vld1_u8(i4);
+      const uint8x8_t vi5 = vld1_u8(i5);
+      const uint8x8_t vi6 = vld1_u8(i6);
+      const uint8x8_t vi7 = vld1_u8(i7);
+      const uint8x8_t vi8 = vld1_u8(i8);
+
+      const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8);
+      const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
+      const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
+      const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);
+
+      const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
+      const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67);
+      const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678);
+
+      int32x4_t vacc_lo = vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum)));
+      int32x4_t vacc_hi = vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum)));
+
+      const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
+      const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));
+
+#if defined(__aarch64__)
+      const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
+      const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
+      const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
+      const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);
+
+      const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+      const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
+      const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+      const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
+#else
+      const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
+      const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
+      const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
+      const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);
+
+      const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+      const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
+      const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+      const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
+#endif
+
+      const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
+      const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
+      const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
+      const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);
+
+#ifdef __aarch64__
+      vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
+      vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));
+
+      const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
+#else
+      vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
+      vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));
+
+      const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
+#endif
+
+      uint8x8_t vout = vqmovun_s16(vacc);
+      vout = vmax_u8(vout, voutput_min);
+      vout = vmin_u8(vout, voutput_max);
+
+      /* Store the 1-7 remainder bytes with progressively narrower stores. */
+      if (k & 4) {
+        vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); output += 4;
+        vout = vext_u8(vout, vout, 4);
+      }
+      if (k & 2) {
+        vst1_lane_u16(__builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); output += 2;
+        vout = vext_u8(vout, vout, 2);
+      }
+      if (k & 1) {
+        vst1_lane_u8(output, vout, 0); output += 1;
+      }
+    }
+    output = (uint8_t*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/q8-avgpool/up9-scalar.c b/src/q8-avgpool/up9-scalar.c
new file mode 100644
index 0000000..b8f2fa4
--- /dev/null
+++ b/src/q8-avgpool/up9-scalar.c
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/scalar-utils.h>
+#include <xnnpack/avgpool.h>
+
+
+/*
+ * Q8 (quantized uint8) average-pooling micro-kernel: scalar, unipass
+ * variant for pooling windows of at most 9 elements (1 <= ks <= 9).
+ *
+ * For each output pixel, 9 row pointers are taken from the indirection
+ * buffer (positions >= ks read the zero row), summed with the bias, and
+ * requantized to clamped uint8 output.
+ *
+ *   n                - number of output pixels; must be non-zero.
+ *   ks               - pooling window size in elements; 1 to 9.
+ *   kc               - number of channels; must be non-zero.
+ *   input            - indirection buffer with 9 row pointers per output
+ *                      pixel; advanced by input_increment bytes per pixel.
+ *   zero             - zero-filled row substituted for window positions
+ *                      past ks.
+ *   output           - uint8 output; advanced by kc + output_increment
+ *                      bytes per output pixel.
+ *   params           - scalar requantization parameters: bias, fixed-point
+ *                      multiplier, rounding constant, right shift, and
+ *                      pre-zero-point clamping bounds.
+ */
+void xnn_q8_avgpool_ukernel_up9__scalar(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const uint8_t** input,
+    const uint8_t* zero,
+    uint8_t* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_q8_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks <= 9);
+  assert(kc != 0);
+
+  const int32_t vbias = params->scalar.bias;
+  const int32_t vmultiplier = params->scalar.multiplier;
+  const int64_t vrounding = params->scalar.rounding;
+  const uint32_t vshift = params->scalar.right_shift;
+  const int32_t voutput_min = params->scalar.output_min_less_zero_point;
+  const int32_t voutput_max = params->scalar.output_max_less_zero_point;
+  const int32_t voutput_zero_point = params->scalar.output_zero_point;
+  do {
+    const uint8_t* i0 = input[0];
+    const uint8_t* i1 = input[1];
+    const uint8_t* i2 = input[2];
+    const uint8_t* i3 = input[3];
+    const uint8_t* i4 = input[4];
+    const uint8_t* i5 = input[5];
+    const uint8_t* i6 = input[6];
+    const uint8_t* i7 = input[7];
+    const uint8_t* i8 = input[8];
+    input = (const uint8_t**) ((uintptr_t) input + input_increment);
+    /* Window positions >= ks read from the zero row; i0 is always a real
+       row because ks >= 1. */
+    if (ks < 2) {
+      i1 = zero;
+    }
+    if (ks <= 2) {
+      i2 = zero;
+    }
+    if (ks < 4) {
+      i3 = zero;
+    }
+    if (ks <= 4) {
+      i4 = zero;
+    }
+    if (ks < 6) {
+      i5 = zero;
+    }
+    if (ks <= 6) {
+      i6 = zero;
+    }
+    if (ks < 8) {
+      i7 = zero;
+    }
+    if (ks <= 8) {
+      i8 = zero;
+    }
+
+    size_t k = kc;
+    do {
+      const uint32_t vi0 = (uint32_t) *i0++;
+      const uint32_t vi1 = (uint32_t) *i1++;
+      const uint32_t vi2 = (uint32_t) *i2++;
+      const uint32_t vi3 = (uint32_t) *i3++;
+      const uint32_t vi4 = (uint32_t) *i4++;
+      const uint32_t vi5 = (uint32_t) *i5++;
+      const uint32_t vi6 = (uint32_t) *i6++;
+      const uint32_t vi7 = (uint32_t) *i7++;
+      const uint32_t vi8 = (uint32_t) *i8++;
+
+      /* Pairwise summation tree over the 9 samples. */
+      const uint32_t vsum01 = vi0 + vi1;
+      const uint32_t vsum23 = vi2 + vi3;
+      const uint32_t vsum45 = vi4 + vi5;
+      const uint32_t vsum67 = vi6 + vi7;
+      const uint32_t vsum018 = vsum01 + vi8;
+      const uint32_t vsum2345 = vsum23 + vsum45;
+      const uint32_t vsum01678 = vsum018 + vsum67;
+      const uint32_t vsum = vsum2345 + vsum01678;
+
+      const int32_t vacc = vbias + (int32_t) vsum;
+
+      /* Fixed-point requantization: scale by the multiplier in 64-bit,
+         then round via the rounding constant and arithmetic right shift.
+         The extra -1 applied to negative accumulators keeps the rounding
+         symmetric about zero, matching the SIMD variants' sign-magnitude
+         computation. */
+      const int64_t vproduct = (int64_t) vacc * (int64_t) vmultiplier;
+      const int64_t vadjusted_product = vproduct - (int64_t) (vacc < 0);
+      int32_t vout = (int32_t) asr_s64(vadjusted_product + vrounding, vshift);
+      /* Clamp in the zero-point-free domain (bounds are "less zero
+         point"), then re-add the output zero point. */
+      vout = vout < voutput_min ? voutput_min : vout;
+      vout = vout > voutput_max ? voutput_max : vout;
+      vout += voutput_zero_point;
+
+      *output++ = (uint8_t) vout;
+    } while (--k != 0);
+    output = (uint8_t*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/q8-avgpool/up9-sse2.c b/src/q8-avgpool/up9-sse2.c
new file mode 100644
index 0000000..5998790
--- /dev/null
+++ b/src/q8-avgpool/up9-sse2.c
@@ -0,0 +1,239 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/avgpool.h>
+
+
+void xnn_q8_avgpool_ukernel_up9__sse2(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const uint8_t** input,
+    const uint8_t* zero,
+    uint8_t* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_q8_avgpool_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(ks <= 9);
+  assert(kc != 0);
+
+  /* Requantization constants from the 16-byte-aligned SSE2 params variant.
+   * (The &params references below were garbled to "¶ms" by HTML-entity
+   * mangling of "&para"; restored.) */
+  const __m128i vbias = _mm_load_si128((const __m128i*) &params->sse2.bias);
+  const __m128i vzero = _mm_setzero_si128();
+  const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->sse2.multiplier);
+  const __m128i vrounding = _mm_load_si128((const __m128i*) params->sse2.rounding);
+  const __m128i vright_shift = _mm_loadl_epi64((const __m128i*) params->sse2.right_shift);
+
+  do {
+    const uint8_t* i0 = input[0];
+    const uint8_t* i1 = input[1];
+    const uint8_t* i2 = input[2];
+    const uint8_t* i3 = input[3];
+    const uint8_t* i4 = input[4];
+    const uint8_t* i5 = input[5];
+    const uint8_t* i6 = input[6];
+    const uint8_t* i7 = input[7];
+    const uint8_t* i8 = input[8];
+    input = (const uint8_t**) ((uintptr_t) input + input_increment);
+    /* When the pooling window has fewer than 9 elements (ks < 9), redirect
+     * the unused tap pointers at the caller-provided zero buffer so they
+     * contribute nothing to the sum. */
+    if (ks < 2) {
+      i1 = zero;
+    }
+    if (ks <= 2) {
+      i2 = zero;
+    }
+    if (ks < 4) {
+      i3 = zero;
+    }
+    if (ks <= 4) {
+      i4 = zero;
+    }
+    if (ks < 6) {
+      i5 = zero;
+    }
+    if (ks <= 6) {
+      i6 = zero;
+    }
+    if (ks < 8) {
+      i7 = zero;
+    }
+    if (ks <= 8) {
+      i8 = zero;
+    }
+
+    /* Main loop: 8 channels per iteration. */
+    size_t k = kc;
+    while (k >= 8) {
+      const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0); i0 += 8;
+      const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1); i1 += 8;
+      const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2); i2 += 8;
+      const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3); i3 += 8;
+      const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4); i4 += 8;
+      const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5); i5 += 8;
+      const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6); i6 += 8;
+      const __m128i vi7 = _mm_loadl_epi64((const __m128i*) i7); i7 += 8;
+      const __m128i vi8 = _mm_loadl_epi64((const __m128i*) i8); i8 += 8;
+
+      /* Widen u8 -> u16 (zero-extension via unpack with zero). */
+      const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+      const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+      const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+      const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+      const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+      const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+      const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+      const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);
+      const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero);
+
+      /* Balanced addition tree: at most 9*255 fits comfortably in 16 bits. */
+      const __m128i vsum018 = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), vxi8);
+      const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
+      const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
+      const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);
+
+      const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
+      const __m128i vsum01678 = _mm_add_epi16(vsum018, vsum67);
+      const __m128i vsum = _mm_add_epi16(vsum2345, vsum01678);
+
+      const __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero));
+      const __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero));
+
+      /* Requantize by fixed-point multiply + rounding shift. SSE2 has no
+       * signed 32x32 multiply-high, so take |acc|, multiply unsigned
+       * (even/odd lane pairs via _mm_mul_epu32), then restore the sign. */
+      const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo);
+      const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi);
+
+      const __m128i vabs_lo0123 = _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo);
+      const __m128i vabs_hi0123 = _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi);
+
+      const __m128i vabs_lo1032 = _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1));
+      const __m128i vabs_hi1032 = _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1));
+
+      const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier);
+      const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier);
+
+      const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier);
+      const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier);
+
+      const __m128i vabs_scaled_lo02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift);
+      const __m128i vabs_scaled_lo13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift);
+      const __m128i vabs_scaled_hi02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift);
+      const __m128i vabs_scaled_hi13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift);
+
+      /* Re-interleave the even/odd 64-bit lanes back into 32-bit lane order. */
+      const __m128i vabs_scaled_lo0213 = _mm_castps_si128(
+          _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_lo02), _mm_castsi128_ps(vabs_scaled_lo13), _MM_SHUFFLE(2, 0, 2, 0)));
+      const __m128i vabs_scaled_hi0213 = _mm_castps_si128(
+          _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_hi02), _mm_castsi128_ps(vabs_scaled_hi13), _MM_SHUFFLE(2, 0, 2, 0)));
+
+      const __m128i vabs_scaled_lo = _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0));
+      const __m128i vabs_scaled_hi = _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+      const __m128i vscaled_lo = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo);
+      const __m128i vscaled_hi = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi);
+
+      /* Narrow to u8 with saturation, add zero point, clamp to [min, max]. */
+      __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
+      vout = _mm_adds_epi16(vout, _mm_load_si128((const __m128i*) &params->sse2.output_zero_point));
+      vout = _mm_packus_epi16(vout, vout);
+      vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) &params->sse2.output_max));
+      vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) &params->sse2.output_min));
+
+      _mm_storel_epi64((__m128i*) output, vout);
+      output += 8;
+
+      k -= 8;
+    }
+    /* Tail: 1-7 remaining channels; same computation, partial store. */
+    if (k != 0) {
+      const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0);
+      const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1);
+      const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2);
+      const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3);
+      const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4);
+      const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5);
+      const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6);
+      const __m128i vi7 = _mm_loadl_epi64((const __m128i*) i7);
+      const __m128i vi8 = _mm_loadl_epi64((const __m128i*) i8);
+
+      const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+      const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+      const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+      const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+      const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+      const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+      const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+      const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);
+      const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero);
+
+      const __m128i vsum018 = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), vxi8);
+      const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
+      const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
+      const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7);
+
+      const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
+      const __m128i vsum01678 = _mm_add_epi16(vsum018, vsum67);
+      const __m128i vsum = _mm_add_epi16(vsum2345, vsum01678);
+
+      const __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero));
+      const __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero));
+
+      const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo);
+      const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi);
+
+      const __m128i vabs_lo0123 = _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo);
+      const __m128i vabs_hi0123 = _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi);
+
+      const __m128i vabs_lo1032 = _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1));
+      const __m128i vabs_hi1032 = _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1));
+
+      const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier);
+      const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier);
+
+      const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier);
+      const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier);
+
+      const __m128i vabs_scaled_lo02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift);
+      const __m128i vabs_scaled_lo13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift);
+      const __m128i vabs_scaled_hi02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift);
+      const __m128i vabs_scaled_hi13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift);
+
+      const __m128i vabs_scaled_lo0213 = _mm_castps_si128(
+          _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_lo02), _mm_castsi128_ps(vabs_scaled_lo13), _MM_SHUFFLE(2, 0, 2, 0)));
+      const __m128i vabs_scaled_hi0213 = _mm_castps_si128(
+          _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_hi02), _mm_castsi128_ps(vabs_scaled_hi13), _MM_SHUFFLE(2, 0, 2, 0)));
+
+      const __m128i vabs_scaled_lo = _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0));
+      const __m128i vabs_scaled_hi = _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+      const __m128i vscaled_lo = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo);
+      const __m128i vscaled_hi = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi);
+
+      __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
+      vout = _mm_adds_epi16(vout, _mm_load_si128((const __m128i*) &params->sse2.output_zero_point));
+      vout = _mm_packus_epi16(vout, vout);
+      vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) &params->sse2.output_max));
+      vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) &params->sse2.output_min));
+
+      /* Store 4/2/1 bytes according to the remaining channel count,
+       * shifting consumed lanes out of the vector between stores. */
+      if (k & 4) {
+        *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout);
+        output += 4;
+        vout = _mm_srli_epi64(vout, 32);
+      }
+      if (k & 2) {
+        *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout, 0);
+        output += 2;
+        vout = _mm_srli_epi32(vout, 16);
+      }
+      if (k & 1) {
+        *((uint8_t*) output) = (uint8_t) _mm_cvtsi128_si32(vout);
+        output += 1;
+      }
+    }
+    output = (uint8_t*) ((uintptr_t) output + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/q8-dwconv/up1x9-scalar.c b/src/q8-dwconv/up1x9-scalar.c
new file mode 100644
index 0000000..2014e86
--- /dev/null
+++ b/src/q8-dwconv/up1x9-scalar.c
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <xnnpack/scalar-utils.h>
+#include <xnnpack/dwconv.h>
+
+
+void xnn_q8_dwconv_ukernel_up1x9__scalar(
+    size_t channels,
+    size_t output_width,
+    const uint8_t** input,
+    const void* weights,
+    uint8_t* output,
+    size_t input_stride,
+    size_t output_increment,
+    const union xnn_q8_gemm_params params[restrict static 1])
+{
+  /* Requantization parameters: Q31 fixed-point multiplier with
+   * round-to-nearest (ties away from zero via the remainder correction). */
+  const int32_t vkernel_zero_point = params->scalar.kernel_zero_point;
+  const int32_t vmultiplier = params->scalar.multiplier;
+  const int32_t vq31rounding = INT32_C(0x40000000);
+  const int32_t vremainder_mask = params->scalar.remainder_mask;
+  const uint32_t vshift = params->scalar.shift;
+  const int32_t vremainder_threshold = params->scalar.remainder_threshold;
+  const int32_t vout_min = params->scalar.output_min_less_zero_point;
+  const int32_t vout_max = params->scalar.output_max_less_zero_point;
+  const int32_t voutput_zero_point = params->scalar.output_zero_point;
+  do {
+    /* Nine input rows: one per tap of the 3x3 depthwise kernel. */
+    const uint8_t* i0 = input[0];
+    const uint8_t* i1 = input[1];
+    const uint8_t* i2 = input[2];
+    const uint8_t* i3 = input[3];
+    const uint8_t* i4 = input[4];
+    const uint8_t* i5 = input[5];
+    const uint8_t* i6 = input[6];
+    const uint8_t* i7 = input[7];
+    const uint8_t* i8 = input[8];
+
+    input = (const uint8_t**) ((uintptr_t) input + input_stride);
+
+    /* One channel per iteration. Per-channel weight layout:
+     * int32 bias followed by 9 uint8 kernel taps. */
+    size_t c = channels;
+    const void* w = weights;
+    do {
+      int32_t vacc = *((const int32_t*) w);
+
+      /* Inputs are used as-is; the input zero point is folded into the bias.
+       * Only the kernel zero point is subtracted explicitly. */
+      const int32_t vi0 = (int32_t) (uint32_t) *i0++;
+      const uint32_t vk0 = (uint32_t) ((const uint8_t*) w)[4];
+      const int32_t vxk0 = (int32_t) vk0 - vkernel_zero_point;
+      vacc += vi0 * vxk0;
+
+      const int32_t vi1 = (int32_t) (uint32_t) *i1++;
+      const uint32_t vk1 = (uint32_t) ((const uint8_t*) w)[5];
+      const int32_t vxk1 = (int32_t) vk1 - vkernel_zero_point;
+      vacc += vi1 * vxk1;
+
+      const int32_t vi2 = (int32_t) (uint32_t) *i2++;
+      const uint32_t vk2 = (uint32_t) ((const uint8_t*) w)[6];
+      const int32_t vxk2 = (int32_t) vk2 - vkernel_zero_point;
+      vacc += vi2 * vxk2;
+
+      const int32_t vi3 = (int32_t) (uint32_t) *i3++;
+      const uint32_t vk3 = (uint32_t) ((const uint8_t*) w)[7];
+      const int32_t vxk3 = (int32_t) vk3 - vkernel_zero_point;
+      vacc += vi3 * vxk3;
+
+      const int32_t vi4 = (int32_t) (uint32_t) *i4++;
+      const uint32_t vk4 = (uint32_t) ((const uint8_t*) w)[8];
+      const int32_t vxk4 = (int32_t) vk4 - vkernel_zero_point;
+      vacc += vi4 * vxk4;
+
+      const int32_t vi5 = (int32_t) (uint32_t) *i5++;
+      const uint32_t vk5 = (uint32_t) ((const uint8_t*) w)[9];
+      const int32_t vxk5 = (int32_t) vk5 - vkernel_zero_point;
+      vacc += vi5 * vxk5;
+
+      const int32_t vi6 = (int32_t) (uint32_t) *i6++;
+      const uint32_t vk6 = (uint32_t) ((const uint8_t*) w)[10];
+      const int32_t vxk6 = (int32_t) vk6 - vkernel_zero_point;
+      vacc += vi6 * vxk6;
+
+      const int32_t vi7 = (int32_t) (uint32_t) *i7++;
+      const uint32_t vk7 = (uint32_t) ((const uint8_t*) w)[11];
+      const int32_t vxk7 = (int32_t) vk7 - vkernel_zero_point;
+      vacc += vi7 * vxk7;
+
+      const int32_t vi8 = (int32_t) (uint32_t) *i8++;
+      const uint32_t vk8 = (uint32_t) ((const uint8_t*) w)[12];
+      const int32_t vxk8 = (int32_t) vk8 - vkernel_zero_point;
+      vacc += vi8 * vxk8;
+
+      w = (const void*) ((uintptr_t) w + sizeof(int32_t) + 9 * sizeof(uint8_t));
+
+      /* Requantize: multiply into Q31 with rounding, then shift right with
+       * the remainder correction to round to nearest; clamp and re-center. */
+      const int64_t vproduct = (int64_t) vacc * (int64_t) vmultiplier;
+      const int32_t vq31product = (int32_t) (uint32_t) ((uint64_t) (vproduct + (int64_t) vq31rounding) >> 31);
+      const int32_t vremainder = (vq31product & vremainder_mask) - (int32_t) (vq31product < 0);
+      int32_t vout = asr_s32(vq31product, vshift) + (int32_t) (vremainder > vremainder_threshold);
+      vout = vout < vout_min ? vout_min : vout;
+      vout = vout > vout_max ? vout_max : vout;
+      vout += voutput_zero_point;
+
+      /* Explicit narrowing cast, consistent with the other q8 scalar kernels
+       * (value is already clamped to the uint8 range). */
+      *output++ = (uint8_t) vout;
+    } while (--c != 0);
+
+    output = (uint8_t*) ((uintptr_t) output + output_increment);
+  } while (--output_width != 0);
+}
diff --git a/src/q8-dwconv/up8x9-aarch32-neon.S b/src/q8-dwconv/up8x9-aarch32-neon.S
new file mode 100644
index 0000000..aceabf9
--- /dev/null
+++ b/src/q8-dwconv/up8x9-aarch32-neon.S
@@ -0,0 +1,362 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <xnnpack/assembly.h>
+
+.syntax unified
+
+# void xnn_q8_dwconv_ukernel_up8x9__aarch32_neon(
+# size_t channels,
+# size_t output_width,
+# const uint8_t** input,
+# const void* weights,
+# uint8_t* output,
+# size_t input_stride,
+# size_t output_increment,
+# const union xnn_q8_gemm_params params[restrict static 1])
+BEGIN_FUNCTION xnn_q8_dwconv_ukernel_up8x9__aarch32_neon
+ .arm
+#ifndef __APPLE__
+ .arch armv7-a
+ .fpu neon
+#endif
+
+ # Load params
+ # - r12 = params
+ # (AAPCS: args 5-8 are on the stack; params is the 8th arg, at [sp, #12] on entry)
+ LDR r12, [sp, 12]
+
+ PUSH {r4, r5, r6, r7, r8, r9, r10, r11, lr}
+ VPUSH {d8-d15}
+
+ # Stash channels (r0) and weights (r3) below SP for per-pixel reload.
+ # NOTE(review): memory below SP is not guaranteed preserved on AArch32
+ # (no red zone) - confirm this is safe on all supported targets.
+ STR r0, [sp, #-8]
+ STR r3, [sp, #-4]
+
+ # r4 = 4: post-increment step so the 1-byte kernel_zero_point load below
+ # advances r12 past its (presumably 4-byte-padded) field - TODO confirm layout
+ MOV r4, 4
+
+ # Load o:
+ # - lr = o = output
+ # (stack args shifted by 100 bytes: 36 for PUSH + 64 for VPUSH)
+ LDR lr, [sp, 100]
+
+ # Load kernel zero point:
+ # - d31 = vkernel_zero_point
+ VLD1.8 {d31[]}, [r12], r4
+
+ # Load multiplier:
+ # - q14 = d28:d29 = vmultiplier
+ VLD1.32 {d28[], d29[]}, [r12]!
+
+ # Load right shift:
+ # - q13 = d26:d27 = vright_shift
+ VLD1.32 {d26[], d27[]}, [r12]!
+
+ # Load output zero point:
+ # - q12 = d24:d25 = voutput_zero_point
+ VLD1.16 {d24[], d25[]}, [r12]!
+
+ # Compute vzero_shift_mask
+ # - q11 = vzero_shift_mask
+ VCEQ.S32 q11, q13, 0
+
+ # Load output max:
+ # - d20 = voutput_max
+ VLD1.8 {d20[]}, [r12]!
+
+ # Load output min:
+ # - d21 = voutput_min
+ VLD1.8 {d21[]}, [r12]
+
+ # Outer loop: one output pixel per iteration.
+ .p2align 3
+0:
+ # Load input stride
+ # - r3 = input_stride
+ LDR r3, [sp, 104]
+
+ # Load c:
+ # - r0 = c = channels
+ LDR r0, [sp, #-8]
+
+ # Load i0, i1, i2, i3, i4, i5, i6, i7, i8
+ # - r4 = i0
+ # - r5 = i1
+ # - r6 = i2
+ # - r7 = i3
+ # - r8 = i4
+ # - r9 = i5
+ # - r10 = i6
+ # - r11 = i7
+ # - r12 = i8
+ LDM r2, {r4, r5, r6, r7, r8, r9, r10, r11, r12}
+
+ # Pre-decrement c
+ SUBS r0, r0, 8
+
+ # Increment input by input stride
+ # - input = r2 := input + input_stride
+ ADD r2, r2, r3
+
+ # Load w:
+ # - r3 = w = weights
+ LDR r3, [sp, #-4]
+
+ BLO 2f
+
+ # Main loop: 8 channels per iteration across all 9 taps.
+ # q0 (lanes 0-3) and q1 (lanes 4-7) accumulate int32 products; input
+ # loads, widening (VMOVL/VSUBL with kernel zero point) and multiply-
+ # accumulates are interleaved to hide latency.
+ .p2align 4
+1:
+ # Load per-channel bias into accumulators q0:q1 (weights layout:
+ # 8x int32 bias, then 9 groups of 8 uint8 taps read via [r3]! below)
+ VLDM r3!, {d0-d3}
+
+ VLD1.8 {d4}, [r4]!
+ VLD1.8 {d6}, [r3]!
+
+ VLD1.8 {d8}, [r5]!
+ VLD1.8 {d10}, [r3]!
+
+ VMOVL.U8 q2, d4
+ VSUBL.U8 q3, d6, d31
+
+ VLD1.8 {d12}, [r6]!
+ VLD1.8 {d14}, [r3]!
+
+ VMOVL.U8 q4, d8
+ VSUBL.U8 q5, d10, d31
+
+ VMLAL.S16 q0, d4, d6
+ VMLAL.S16 q1, d5, d7
+
+ VLD1.8 {d4}, [r7]!
+ VLD1.8 {d6}, [r3]!
+
+ VMOVL.U8 q6, d12
+ VSUBL.U8 q7, d14, d31
+
+ VMLAL.S16 q0, d8, d10
+ VMLAL.S16 q1, d9, d11
+
+ VLD1.8 {d8}, [r8]!
+ VLD1.8 {d10}, [r3]!
+
+ VMOVL.U8 q2, d4
+ VSUBL.U8 q3, d6, d31
+
+ VMLAL.S16 q0, d12, d14
+ VMLAL.S16 q1, d13, d15
+
+ VLD1.8 {d12}, [r9]!
+ VLD1.8 {d14}, [r3]!
+
+ VMOVL.U8 q4, d8
+ VSUBL.U8 q5, d10, d31
+
+ VMLAL.S16 q0, d4, d6
+ VMLAL.S16 q1, d5, d7
+
+ VLD1.8 {d4}, [r10]!
+ VLD1.8 {d6}, [r3]!
+
+ VMOVL.U8 q6, d12
+ VSUBL.U8 q7, d14, d31
+
+ VMLAL.S16 q0, d8, d10
+ VMLAL.S16 q1, d9, d11
+
+ VLD1.8 {d8}, [r11]!
+ VLD1.8 {d10}, [r3]!
+
+ VMOVL.U8 q2, d4
+ VSUBL.U8 q3, d6, d31
+
+ VMLAL.S16 q0, d12, d14
+ VMLAL.S16 q1, d13, d15
+
+ VLD1.8 {d12}, [r12]!
+ VLD1.8 {d14}, [r3]!
+
+ VMOVL.U8 q4, d8
+ VSUBL.U8 q5, d10, d31
+
+ VMLAL.S16 q0, d4, d6
+ VMLAL.S16 q1, d5, d7
+
+ VMOVL.U8 q6, d12
+ VSUBL.U8 q7, d14, d31
+
+ VMLAL.S16 q0, d8, d10
+ VMLAL.S16 q1, d9, d11
+
+ VMLAL.S16 q0, d12, d14
+ VMLAL.S16 q1, d13, d15
+
+ # Requantize: saturating Q31 multiply-high with rounding (VQRDMULH),
+ # then a rounding right shift (VRSHL by negative q13). The VBIC/VSRA
+ # pair adjusts negative values so the shift rounds toward nearest
+ # (q11 masks the fix-up out when the shift amount is zero).
+ VQRDMULH.S32 q0, q0, q14
+ VQRDMULH.S32 q1, q1, q14
+
+ VBIC q2, q0, q11
+ VBIC q3, q1, q11
+
+ VSRA.S32 q0, q2, 31
+ VSRA.S32 q1, q3, 31
+
+ VRSHL.S32 q0, q0, q13
+ VRSHL.S32 q1, q1, q13
+
+ # Narrow to int16, add output zero point, narrow to uint8 with
+ # saturation, clamp to [output_min, output_max].
+ VQMOVN.S32 d0, q0
+ VQMOVN.S32 d1, q1
+
+ VQADD.S16 q0, q12
+ VQMOVUN.S16 d0, q0
+ VMIN.U8 d0, d0, d20
+ VMAX.U8 d0, d0, d21
+
+ VST1.8 {d0}, [lr]!
+ SUBS r0, r0, 8
+ BHS 1b
+
+ # Tail: handle 1-7 remaining channels (r0 is now remainder - 8;
+ # r0 == -8 means channels was a multiple of 8 and nothing remains).
+2:
+ CMP r0, -8
+ BEQ 5f
+
+ VLDM r3!, {d0-d3}
+
+ VLD1.8 {d4}, [r4]!
+ VLD1.8 {d6}, [r3]!
+
+ VLD1.8 {d8}, [r5]!
+ VLD1.8 {d10}, [r3]!
+
+ VMOVL.U8 q2, d4
+ VSUBL.U8 q3, d6, d31
+
+ VLD1.8 {d12}, [r6]!
+ VLD1.8 {d14}, [r3]!
+
+ VMOVL.U8 q4, d8
+ VSUBL.U8 q5, d10, d31
+
+ VMLAL.S16 q0, d4, d6
+ VMLAL.S16 q1, d5, d7
+
+ VLD1.8 {d4}, [r7]!
+ VLD1.8 {d6}, [r3]!
+
+ VMOVL.U8 q6, d12
+ VSUBL.U8 q7, d14, d31
+
+ VMLAL.S16 q0, d8, d10
+ VMLAL.S16 q1, d9, d11
+
+ VLD1.8 {d8}, [r8]!
+ VLD1.8 {d10}, [r3]!
+
+ VMOVL.U8 q2, d4
+ VSUBL.U8 q3, d6, d31
+
+ VMLAL.S16 q0, d12, d14
+ VMLAL.S16 q1, d13, d15
+
+ VLD1.8 {d12}, [r9]!
+ VLD1.8 {d14}, [r3]!
+
+ VMOVL.U8 q4, d8
+ VSUBL.U8 q5, d10, d31
+
+ VMLAL.S16 q0, d4, d6
+ VMLAL.S16 q1, d5, d7
+
+ VLD1.8 {d4}, [r10]!
+ VLD1.8 {d6}, [r3]!
+
+ VMOVL.U8 q6, d12
+ VSUBL.U8 q7, d14, d31
+
+ VMLAL.S16 q0, d8, d10
+ VMLAL.S16 q1, d9, d11
+
+ VLD1.8 {d8}, [r11]!
+ VLD1.8 {d10}, [r3]!
+
+ VMOVL.U8 q2, d4
+ VSUBL.U8 q3, d6, d31
+
+ VMLAL.S16 q0, d12, d14
+ VMLAL.S16 q1, d13, d15
+
+ VLD1.8 {d12}, [r12]!
+ VLD1.8 {d14}, [r3]!
+
+ VMOVL.U8 q4, d8
+ VSUBL.U8 q5, d10, d31
+
+ VMLAL.S16 q0, d4, d6
+ VMLAL.S16 q1, d5, d7
+
+ VMOVL.U8 q6, d12
+ VSUBL.U8 q7, d14, d31
+
+ VMLAL.S16 q0, d8, d10
+ VMLAL.S16 q1, d9, d11
+
+ VMLAL.S16 q0, d12, d14
+ VMLAL.S16 q1, d13, d15
+
+ VQRDMULH.S32 q0, q0, q14
+ VQRDMULH.S32 q1, q1, q14
+
+ VBIC q2, q0, q11
+ VBIC q3, q1, q11
+
+ VSRA.S32 q0, q2, 31
+ VSRA.S32 q1, q3, 31
+
+ VRSHL.S32 q0, q0, q13
+ VRSHL.S32 q1, q1, q13
+
+ VQMOVN.S32 d0, q0
+ VQMOVN.S32 d1, q1
+
+ VQADD.S16 q0, q12
+ VQMOVUN.S16 d0, q0
+ VMIN.U8 d0, d0, d20
+ VMAX.U8 d0, d0, d21
+
+ # Store 4/2/1 bytes according to the remaining channel count, rotating
+ # consumed lanes out of d0 between stores.
+ TST r0, 4
+ BEQ 3f
+ VST1.32 {d0[0]}, [lr]!
+ VEXT.8 d0, d0, 4
+
+3:
+ TST r0, 2
+ BEQ 4f
+ VST1.16 {d0[0]}, [lr]!
+ VEXT.8 d0, d0, 2
+
+4:
+ TST r0, 1
+ BEQ 5f
+ VST1.8 {d0[0]}, [lr]!
+
+5:
+ # Load output increment
+ # - r3 = output_increment
+ LDR r3, [sp, 108]
+
+ # Decrement output width
+ SUBS r1, r1, 1
+
+ # Increment output by output_increment
+ ADD lr, lr, r3
+
+ # If output width is non-zero, process another pixel
+ BNE 0b
+
+ VPOP {d8-d15}
+ POP {r4, r5, r6, r7, r8, r9, r10, r11, pc}
+END_FUNCTION xnn_q8_dwconv_ukernel_up8x9__aarch32_neon
+
+#ifdef __ELF__
+.section ".note.GNU-stack","",%progbits
+#endif
diff --git a/src/q8-dwconv/up8x9-neon.c b/src/q8-dwconv/up8x9-neon.c
new file mode 100644
index 0000000..30fae41
--- /dev/null
+++ b/src/q8-dwconv/up8x9-neon.c
@@ -0,0 +1,614 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <arm_neon.h>
+
+#include <xnnpack/dwconv.h>
+
+
+void xnn_q8_dwconv_ukernel_up8x9__neon(
+ size_t channels,
+ size_t output_width,
+ const uint8_t** input,
+ const void* weights,
+ uint8_t* output,
+ size_t input_stride,
+ size_t output_increment,
+ const union xnn_q8_gemm_params params[restrict static 1])
+{
+ const uint8x8_t vkernel_zero_point = vld1_dup_u8((const uint8_t*) ¶ms->neon.kernel_zero_point);
+ const int32x4_t vmultiplier = vld1q_dup_s32(¶ms->neon.multiplier);
+ const int32x4_t vright_shift = vld1q_dup_s32(¶ms->neon.right_shift);
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(¶ms->neon.output_zero_point);
+ const uint8x8_t voutput_min = vld1_dup_u8(¶ms->neon.output_min);
+ const uint8x8_t voutput_max = vld1_dup_u8(¶ms->neon.output_max);
+
+#ifdef __aarch64__
+ /* Larger number of registers on AArch64 make it possible to process few pixels at a time */
+ if (input_stride == 3 * sizeof(void*)) {
+ for (; output_width >= 3; output_width -= 3) {
+ const uint8_t* i00 = input[ 0];
+ const uint8_t* i10 = input[ 1];
+ const uint8_t* i20 = input[ 2];
+ const uint8_t* i01 = input[ 3];
+ const uint8_t* i11 = input[ 4];
+ const uint8_t* i21 = input[ 5];
+ const uint8_t* i02 = input[ 6];
+ const uint8_t* i12 = input[ 7];
+ const uint8_t* i22 = input[ 8];
+ const uint8_t* i03 = input[ 9];
+ const uint8_t* i13 = input[10];
+ const uint8_t* i23 = input[11];
+ const uint8_t* i04 = input[12];
+ const uint8_t* i14 = input[13];
+ const uint8_t* i24 = input[14];
+
+ uint8_t* output0 = output;
+ uint8_t* output1 = output0 + channels + output_increment;
+ uint8_t* output2 = output1 + channels + output_increment;
+
+ input += 9;
+
+ size_t c = channels;
+ const void* w = weights;
+ for (; c >= 8; c -= 8) {
+ int32x4_t vacc0_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
+ int32x4_t vacc0_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
+ int32x4_t vacc1_lo = vacc0_lo;
+ int32x4_t vacc2_lo = vacc0_lo;
+ int32x4_t vacc1_hi = vacc0_hi;
+ int32x4_t vacc2_hi = vacc0_hi;
+
+ const uint8x8_t vk00 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi00 = vld1_u8(i00); i00 += 8;
+ const uint8x8_t vi01 = vld1_u8(i01); i01 += 8;
+ const uint8x8_t vi02 = vld1_u8(i02); i02 += 8;
+ const int16x8_t vxk00 = vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point));
+ const int16x8_t vxi00 = vreinterpretq_s16_u16(vmovl_u8(vi00));
+ const int16x8_t vxi01 = vreinterpretq_s16_u16(vmovl_u8(vi01));
+ const int16x8_t vxi02 = vreinterpretq_s16_u16(vmovl_u8(vi02));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02);
+
+ const uint8x8_t vk10 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi10 = vld1_u8(i10); i10 += 8;
+ const uint8x8_t vi11 = vld1_u8(i11); i11 += 8;
+ const uint8x8_t vi12 = vld1_u8(i12); i12 += 8;
+ const int16x8_t vxk10 = vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point));
+ const int16x8_t vxi10 = vreinterpretq_s16_u16(vmovl_u8(vi10));
+ const int16x8_t vxi11 = vreinterpretq_s16_u16(vmovl_u8(vi11));
+ const int16x8_t vxi12 = vreinterpretq_s16_u16(vmovl_u8(vi12));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12);
+
+ const uint8x8_t vk20 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi20 = vld1_u8(i20); i20 += 8;
+ const uint8x8_t vi21 = vld1_u8(i21); i21 += 8;
+ const uint8x8_t vi22 = vld1_u8(i22); i22 += 8;
+ const int16x8_t vxk20 = vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point));
+ const int16x8_t vxi20 = vreinterpretq_s16_u16(vmovl_u8(vi20));
+ const int16x8_t vxi21 = vreinterpretq_s16_u16(vmovl_u8(vi21));
+ const int16x8_t vxi22 = vreinterpretq_s16_u16(vmovl_u8(vi22));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22);
+
+ const uint8x8_t vk01 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi03 = vld1_u8(i03); i03 += 8;
+ const int16x8_t vxk01 = vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point));
+ const int16x8_t vxi03 = vreinterpretq_s16_u16(vmovl_u8(vi03));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03);
+
+ const uint8x8_t vk11 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi13 = vld1_u8(i13); i13 += 8;
+ const int16x8_t vxk11 = vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point));
+ const int16x8_t vxi13 = vreinterpretq_s16_u16(vmovl_u8(vi13));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13);
+
+ const uint8x8_t vk21 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi23 = vld1_u8(i23); i23 += 8;
+ const int16x8_t vxk21 = vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point));
+ const int16x8_t vxi23 = vreinterpretq_s16_u16(vmovl_u8(vi23));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23);
+
+ const uint8x8_t vk02 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi04 = vld1_u8(i04); i04 += 8;
+ const int16x8_t vxk02 = vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point));
+ const int16x8_t vxi04 = vreinterpretq_s16_u16(vmovl_u8(vi04));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04);
+
+ const uint8x8_t vk12 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi14 = vld1_u8(i14); i14 += 8;
+ const int16x8_t vxk12 = vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point));
+ const int16x8_t vxi14 = vreinterpretq_s16_u16(vmovl_u8(vi14));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14);
+
+ const uint8x8_t vk22 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi24 = vld1_u8(i24); i24 += 8;
+ const int16x8_t vxk22 = vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point));
+ const int16x8_t vxi24 = vreinterpretq_s16_u16(vmovl_u8(vi24));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);
+
+ vacc0_lo = vqrdmulhq_s32(vacc0_lo, vmultiplier);
+ vacc0_hi = vqrdmulhq_s32(vacc0_hi, vmultiplier);
+ vacc1_lo = vqrdmulhq_s32(vacc1_lo, vmultiplier);
+ vacc1_hi = vqrdmulhq_s32(vacc1_hi, vmultiplier);
+ vacc2_lo = vqrdmulhq_s32(vacc2_lo, vmultiplier);
+ vacc2_hi = vqrdmulhq_s32(vacc2_hi, vmultiplier);
+
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
+ vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
+ vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
+ vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);
+ vacc2_lo = vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31);
+ vacc2_hi = vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31);
+
+ vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
+ vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
+ vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
+ vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);
+ vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift);
+ vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift);
+
+ const int16x8_t vacc0 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), voutput_zero_point);
+ const int16x8_t vacc1 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), voutput_zero_point);
+ const int16x8_t vacc2 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), voutput_zero_point);
+ uint8x8_t vout0 = vqmovun_s16(vacc0);
+ uint8x8_t vout1 = vqmovun_s16(vacc1);
+ uint8x8_t vout2 = vqmovun_s16(vacc2);
+ vout0 = vmax_u8(vout0, voutput_min);
+ vout1 = vmax_u8(vout1, voutput_min);
+ vout2 = vmax_u8(vout2, voutput_min);
+ vout0 = vmin_u8(vout0, voutput_max);
+ vout1 = vmin_u8(vout1, voutput_max);
+ vout2 = vmin_u8(vout2, voutput_max);
+
+ vst1_u8(output0, vout0); output0 += 8;
+ vst1_u8(output1, vout1); output1 += 8;
+ vst1_u8(output2, vout2); output2 += 8;
+ }
+ if (c != 0) {
+ int32x4_t vacc0_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
+ int32x4_t vacc0_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
+ int32x4_t vacc1_lo = vacc0_lo;
+ int32x4_t vacc2_lo = vacc0_lo;
+ int32x4_t vacc1_hi = vacc0_hi;
+ int32x4_t vacc2_hi = vacc0_hi;
+
+ const uint8x8_t vk00 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi00 = vld1_u8(i00);
+ const uint8x8_t vi01 = vld1_u8(i01);
+ const uint8x8_t vi02 = vld1_u8(i02);
+ const int16x8_t vxk00 = vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point));
+ const int16x8_t vxi00 = vreinterpretq_s16_u16(vmovl_u8(vi00));
+ const int16x8_t vxi01 = vreinterpretq_s16_u16(vmovl_u8(vi01));
+ const int16x8_t vxi02 = vreinterpretq_s16_u16(vmovl_u8(vi02));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02);
+
+ const uint8x8_t vk10 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi10 = vld1_u8(i10);
+ const uint8x8_t vi11 = vld1_u8(i11);
+ const uint8x8_t vi12 = vld1_u8(i12);
+ const int16x8_t vxk10 = vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point));
+ const int16x8_t vxi10 = vreinterpretq_s16_u16(vmovl_u8(vi10));
+ const int16x8_t vxi11 = vreinterpretq_s16_u16(vmovl_u8(vi11));
+ const int16x8_t vxi12 = vreinterpretq_s16_u16(vmovl_u8(vi12));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12);
+
+ const uint8x8_t vk20 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi20 = vld1_u8(i20);
+ const uint8x8_t vi21 = vld1_u8(i21);
+ const uint8x8_t vi22 = vld1_u8(i22);
+ const int16x8_t vxk20 = vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point));
+ const int16x8_t vxi20 = vreinterpretq_s16_u16(vmovl_u8(vi20));
+ const int16x8_t vxi21 = vreinterpretq_s16_u16(vmovl_u8(vi21));
+ const int16x8_t vxi22 = vreinterpretq_s16_u16(vmovl_u8(vi22));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22);
+
+ const uint8x8_t vk01 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi03 = vld1_u8(i03);
+ const int16x8_t vxk01 = vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point));
+ const int16x8_t vxi03 = vreinterpretq_s16_u16(vmovl_u8(vi03));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03);
+
+ const uint8x8_t vk11 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi13 = vld1_u8(i13);
+ const int16x8_t vxk11 = vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point));
+ const int16x8_t vxi13 = vreinterpretq_s16_u16(vmovl_u8(vi13));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13);
+
+ const uint8x8_t vk21 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi23 = vld1_u8(i23);
+ const int16x8_t vxk21 = vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point));
+ const int16x8_t vxi23 = vreinterpretq_s16_u16(vmovl_u8(vi23));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23);
+
+ const uint8x8_t vk02 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi04 = vld1_u8(i04);
+ const int16x8_t vxk02 = vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point));
+ const int16x8_t vxi04 = vreinterpretq_s16_u16(vmovl_u8(vi04));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04);
+
+ const uint8x8_t vk12 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi14 = vld1_u8(i14);
+ const int16x8_t vxk12 = vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point));
+ const int16x8_t vxi14 = vreinterpretq_s16_u16(vmovl_u8(vi14));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14);
+
+ const uint8x8_t vk22 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi24 = vld1_u8(i24);
+ const int16x8_t vxk22 = vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point));
+ const int16x8_t vxi24 = vreinterpretq_s16_u16(vmovl_u8(vi24));
+ vacc0_lo = vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22));
+ vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22);
+ vacc1_lo = vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23));
+ vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23);
+ vacc2_lo = vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
+ vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);
+
+ vacc0_lo = vqrdmulhq_s32(vacc0_lo, vmultiplier);
+ vacc0_hi = vqrdmulhq_s32(vacc0_hi, vmultiplier);
+ vacc1_lo = vqrdmulhq_s32(vacc1_lo, vmultiplier);
+ vacc1_hi = vqrdmulhq_s32(vacc1_hi, vmultiplier);
+ vacc2_lo = vqrdmulhq_s32(vacc2_lo, vmultiplier);
+ vacc2_hi = vqrdmulhq_s32(vacc2_hi, vmultiplier);
+
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
+ vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
+ vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
+ vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);
+ vacc2_lo = vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31);
+ vacc2_hi = vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31);
+
+ vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
+ vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
+ vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
+ vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);
+ vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift);
+ vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift);
+
+ const int16x8_t vacc0 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), voutput_zero_point);
+ const int16x8_t vacc1 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), voutput_zero_point);
+ const int16x8_t vacc2 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), voutput_zero_point);
+ uint8x8_t vout0 = vqmovun_s16(vacc0);
+ uint8x8_t vout1 = vqmovun_s16(vacc1);
+ uint8x8_t vout2 = vqmovun_s16(vacc2);
+ vout0 = vmax_u8(vout0, voutput_min);
+ vout1 = vmax_u8(vout1, voutput_min);
+ vout2 = vmax_u8(vout2, voutput_min);
+ vout0 = vmin_u8(vout0, voutput_max);
+ vout1 = vmin_u8(vout1, voutput_max);
+ vout2 = vmin_u8(vout2, voutput_max);
+
+ if (c & 4) {
+ vst1_lane_u32(__builtin_assume_aligned(output0, 1), vreinterpret_u32_u8(vout0), 0); output0 += 4;
+ vst1_lane_u32(__builtin_assume_aligned(output1, 1), vreinterpret_u32_u8(vout1), 0); output1 += 4;
+ vst1_lane_u32(__builtin_assume_aligned(output2, 1), vreinterpret_u32_u8(vout2), 0); output2 += 4;
+ vout0 = vext_u8(vout0, vout0, 4);
+ vout1 = vext_u8(vout1, vout1, 4);
+ vout2 = vext_u8(vout2, vout2, 4);
+ }
+ if (c & 2) {
+ vst1_lane_u16(__builtin_assume_aligned(output0, 1), vreinterpret_u16_u8(vout0), 0); output0 += 2;
+ vst1_lane_u16(__builtin_assume_aligned(output1, 1), vreinterpret_u16_u8(vout1), 0); output1 += 2;
+ vst1_lane_u16(__builtin_assume_aligned(output2, 1), vreinterpret_u16_u8(vout2), 0); output2 += 2;
+ vout0 = vext_u8(vout0, vout0, 2);
+ vout1 = vext_u8(vout1, vout1, 2);
+ vout2 = vext_u8(vout2, vout2, 2);
+ }
+ if (c & 1) {
+ vst1_lane_u8(__builtin_assume_aligned(output0, 1), vout0, 0); output0++;
+ vst1_lane_u8(__builtin_assume_aligned(output1, 1), vout1, 0); output1++;
+ vst1_lane_u8(__builtin_assume_aligned(output2, 1), vout2, 0); output2++;
+ }
+ }
+
+ output = (uint8_t*) ((uintptr_t) output2 + output_increment);
+ }
+ if (output_width == 0) {
+ return;
+ }
+ }
+#endif
+
+ do {
+ const uint8_t* i0 = input[0];
+ const uint8_t* i1 = input[1];
+ const uint8_t* i2 = input[2];
+ const uint8_t* i3 = input[3];
+ const uint8_t* i4 = input[4];
+ const uint8_t* i5 = input[5];
+ const uint8_t* i6 = input[6];
+ const uint8_t* i7 = input[7];
+ const uint8_t* i8 = input[8];
+
+ input = (const uint8_t**) ((uintptr_t) input + input_stride);
+
+ size_t c = channels;
+ const void* w = weights;
+ for (; c >= 8; c -= 8) {
+ int32x4_t vaccX1_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
+ int32x4_t vaccX1_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
+
+ const uint8x8_t vk0 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
+ const int16x8_t vxk0 = vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
+ const int16x8_t vxi0 = vreinterpretq_s16_u16(vmovl_u8(vi0));
+ int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
+ int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
+
+ const uint8x8_t vk1 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
+ const int16x8_t vxk1 = vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
+ const int16x8_t vxi1 = vreinterpretq_s16_u16(vmovl_u8(vi1));
+ vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
+ vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
+
+ const uint8x8_t vk2 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
+ const int16x8_t vxk2 = vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
+ const int16x8_t vxi2 = vreinterpretq_s16_u16(vmovl_u8(vi2));
+ vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
+ vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
+
+ const uint8x8_t vk3 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
+ const int16x8_t vxk3 = vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
+ const int16x8_t vxi3 = vreinterpretq_s16_u16(vmovl_u8(vi3));
+ vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
+ vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
+
+ const uint8x8_t vk4 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
+ const int16x8_t vxk4 = vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
+ const int16x8_t vxi4 = vreinterpretq_s16_u16(vmovl_u8(vi4));
+ vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
+ vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
+
+ const uint8x8_t vk5 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
+ const int16x8_t vxk5 = vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
+ const int16x8_t vxi5 = vreinterpretq_s16_u16(vmovl_u8(vi5));
+ vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
+ vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
+
+ const uint8x8_t vk6 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
+ const int16x8_t vxk6 = vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
+ const int16x8_t vxi6 = vreinterpretq_s16_u16(vmovl_u8(vi6));
+ vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
+ vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
+
+ const uint8x8_t vk7 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi7 = vld1_u8(i7); i7 += 8;
+ const int16x8_t vxk7 = vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
+ const int16x8_t vxi7 = vreinterpretq_s16_u16(vmovl_u8(vi7));
+ vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
+ vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
+
+ const uint8x8_t vk8 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi8 = vld1_u8(i8); i8 += 8;
+ const int16x8_t vxk8 = vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
+ const int16x8_t vxi8 = vreinterpretq_s16_u16(vmovl_u8(vi8));
+ vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
+ vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
+
+ int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
+ int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
+
+ vacc_lo = vqrdmulhq_s32(vacc_lo, vmultiplier);
+ vacc_hi = vqrdmulhq_s32(vacc_hi, vmultiplier);
+
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc_lo = vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
+ vacc_hi = vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);
+
+ vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
+ vacc_hi = vrshlq_s32(vacc_hi, vright_shift);
+
+#ifdef __aarch64__
+ const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
+#else
+ const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
+#endif
+ uint8x8_t vout = vqmovun_s16(vacc);
+ vout = vmax_u8(vout, voutput_min);
+ vout = vmin_u8(vout, voutput_max);
+
+ vst1_u8(output, vout); output += 8;
+ }
+ if (c != 0) {
+ int32x4_t vaccX1_lo = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
+ int32x4_t vaccX1_hi = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
+
+ const uint8x8_t vk0 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi0 = vld1_u8(i0);
+ const int16x8_t vxk0 = vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
+ const int16x8_t vxi0 = vreinterpretq_s16_u16(vmovl_u8(vi0));
+ int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
+ int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
+
+ const uint8x8_t vk1 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi1 = vld1_u8(i1);
+ const int16x8_t vxk1 = vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
+ const int16x8_t vxi1 = vreinterpretq_s16_u16(vmovl_u8(vi1));
+ vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
+ vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
+
+ const uint8x8_t vk2 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi2 = vld1_u8(i2);
+ const int16x8_t vxk2 = vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
+ const int16x8_t vxi2 = vreinterpretq_s16_u16(vmovl_u8(vi2));
+ vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
+ vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
+
+ const uint8x8_t vk3 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi3 = vld1_u8(i3);
+ const int16x8_t vxk3 = vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
+ const int16x8_t vxi3 = vreinterpretq_s16_u16(vmovl_u8(vi3));
+ vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
+ vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
+
+ const uint8x8_t vk4 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi4 = vld1_u8(i4);
+ const int16x8_t vxk4 = vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
+ const int16x8_t vxi4 = vreinterpretq_s16_u16(vmovl_u8(vi4));
+ vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
+ vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
+
+ const uint8x8_t vk5 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi5 = vld1_u8(i5);
+ const int16x8_t vxk5 = vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
+ const int16x8_t vxi5 = vreinterpretq_s16_u16(vmovl_u8(vi5));
+ vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
+ vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
+
+ const uint8x8_t vk6 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi6 = vld1_u8(i6);
+ const int16x8_t vxk6 = vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
+ const int16x8_t vxi6 = vreinterpretq_s16_u16(vmovl_u8(vi6));
+ vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
+ vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
+
+ const uint8x8_t vk7 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const uint8x8_t vi7 = vld1_u8(i7);
+ const int16x8_t vxk7 = vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
+ const int16x8_t vxi7 = vreinterpretq_s16_u16(vmovl_u8(vi7));
+ vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
+ vaccX1_hi = vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
+
+ const uint8x8_t vk8 = vld1_u8(w);
+ const uint8x8_t vi8 = vld1_u8(i8);
+ const int16x8_t vxk8 = vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
+ const int16x8_t vxi8 = vreinterpretq_s16_u16(vmovl_u8(vi8));
+ vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
+ vaccX0_hi = vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
+
+ int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
+ int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
+
+ vacc_lo = vqrdmulhq_s32(vacc_lo, vmultiplier);
+ vacc_hi = vqrdmulhq_s32(vacc_hi, vmultiplier);
+
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc_lo = vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
+ vacc_hi = vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);
+
+ vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
+ vacc_hi = vrshlq_s32(vacc_hi, vright_shift);
+
+#ifdef __aarch64__
+ const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
+#else
+ const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
+#endif
+ uint8x8_t vout = vqmovun_s16(vacc);
+ vout = vmax_u8(vout, voutput_min);
+ vout = vmin_u8(vout, voutput_max);
+
+ if (c & 4) {
+ vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); output += 4;
+ vout = vext_u8(vout, vout, 4);
+ }
+ if (c & 2) {
+ vst1_lane_u16(__builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); output += 2;
+ vout = vext_u8(vout, vout, 2);
+ }
+ if (c & 1) {
+ vst1_lane_u8(__builtin_assume_aligned(output, 1), vout, 0); output++;
+ }
+ }
+
+ output = (uint8_t*) ((uintptr_t) output + output_increment);
+ } while (--output_width != 0);
+}
diff --git a/src/q8-dwconv/up8x9-sse2.c b/src/q8-dwconv/up8x9-sse2.c
new file mode 100644
index 0000000..d296127
--- /dev/null
+++ b/src/q8-dwconv/up8x9-sse2.c
@@ -0,0 +1,362 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <immintrin.h>
+
+#include <xnnpack/dwconv.h>
+
+
+void xnn_q8_dwconv_ukernel_up8x9__sse2(
+ size_t channels,
+ size_t output_width,
+ const uint8_t** input,
+ const void* weights,
+ uint8_t* output,
+ size_t input_stride,
+ size_t output_increment,
+ const union xnn_q8_gemm_params params[restrict static 1])
+{
+ const __m128i vkernel_zero_point = _mm_load_si128((const __m128i*) params->sse2.kernel_zero_point);
+ const __m128i vzero = _mm_setzero_si128();
+
+ do {
+ const uint8_t* i0 = input[0];
+ const uint8_t* i1 = input[1];
+ const uint8_t* i2 = input[2];
+ const uint8_t* i3 = input[3];
+ const uint8_t* i4 = input[4];
+ const uint8_t* i5 = input[5];
+ const uint8_t* i6 = input[6];
+ const uint8_t* i7 = input[7];
+ const uint8_t* i8 = input[8];
+
+ input = (const uint8_t**) ((uintptr_t) input + input_stride);
+
+ size_t c = channels;
+ const void* w = weights;
+ for (; c >= 8; c -= 8) {
+ __m128i vacc_lo = _mm_loadu_si128((const __m128i*) w);
+ __m128i vacc_hi = _mm_loadu_si128((const __m128i*) ((uintptr_t) w + 16));
+
+ const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0); i0 += 8;
+ const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+ const __m128i vk0 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 32));
+ const __m128i vxk0 = _mm_sub_epi16(_mm_unpacklo_epi8(vk0, vzero), vkernel_zero_point);
+ const __m128i vprod0_odd = _mm_mullo_epi16(vxi0, vxk0);
+ const __m128i vprod0_even = _mm_mulhi_epi16(vxi0, vxk0);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod0_odd, vprod0_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod0_odd, vprod0_even));
+
+ const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1); i1 += 8;
+ const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+ const __m128i vk1 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 40));
+ const __m128i vxk1 = _mm_sub_epi16(_mm_unpacklo_epi8(vk1, vzero), vkernel_zero_point);
+ const __m128i vprod1_odd = _mm_mullo_epi16(vxi1, vxk1);
+ const __m128i vprod1_even = _mm_mulhi_epi16(vxi1, vxk1);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod1_odd, vprod1_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod1_odd, vprod1_even));
+
+ const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2); i2 += 8;
+ const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+ const __m128i vk2 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 48));
+ const __m128i vxk2 = _mm_sub_epi16(_mm_unpacklo_epi8(vk2, vzero), vkernel_zero_point);
+ const __m128i vprod2_odd = _mm_mullo_epi16(vxi2, vxk2);
+ const __m128i vprod2_even = _mm_mulhi_epi16(vxi2, vxk2);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod2_odd, vprod2_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod2_odd, vprod2_even));
+
+ const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3); i3 += 8;
+ const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+ const __m128i vk3 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 56));
+ const __m128i vxk3 = _mm_sub_epi16(_mm_unpacklo_epi8(vk3, vzero), vkernel_zero_point);
+ const __m128i vprod3_odd = _mm_mullo_epi16(vxi3, vxk3);
+ const __m128i vprod3_even = _mm_mulhi_epi16(vxi3, vxk3);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod3_odd, vprod3_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod3_odd, vprod3_even));
+
+ const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4); i4 += 8;
+ const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+ const __m128i vk4 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 64));
+ const __m128i vxk4 = _mm_sub_epi16(_mm_unpacklo_epi8(vk4, vzero), vkernel_zero_point);
+ const __m128i vprod4_odd = _mm_mullo_epi16(vxi4, vxk4);
+ const __m128i vprod4_even = _mm_mulhi_epi16(vxi4, vxk4);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod4_odd, vprod4_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod4_odd, vprod4_even));
+
+ const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5); i5 += 8;
+ const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+ const __m128i vk5 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 72));
+ const __m128i vxk5 = _mm_sub_epi16(_mm_unpacklo_epi8(vk5, vzero), vkernel_zero_point);
+ const __m128i vprod5_odd = _mm_mullo_epi16(vxi5, vxk5);
+ const __m128i vprod5_even = _mm_mulhi_epi16(vxi5, vxk5);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod5_odd, vprod5_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod5_odd, vprod5_even));
+
+ const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6); i6 += 8;
+ const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+ const __m128i vk6 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 80));
+ const __m128i vxk6 = _mm_sub_epi16(_mm_unpacklo_epi8(vk6, vzero), vkernel_zero_point);
+ const __m128i vprod6_odd = _mm_mullo_epi16(vxi6, vxk6);
+ const __m128i vprod6_even = _mm_mulhi_epi16(vxi6, vxk6);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod6_odd, vprod6_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod6_odd, vprod6_even));
+
+ const __m128i vi7 = _mm_loadl_epi64((const __m128i*) i7); i7 += 8;
+ const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);
+ const __m128i vk7 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 88));
+ const __m128i vxk7 = _mm_sub_epi16(_mm_unpacklo_epi8(vk7, vzero), vkernel_zero_point);
+ const __m128i vprod7_odd = _mm_mullo_epi16(vxi7, vxk7);
+ const __m128i vprod7_even = _mm_mulhi_epi16(vxi7, vxk7);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod7_odd, vprod7_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod7_odd, vprod7_even));
+
+ const __m128i vi8 = _mm_loadl_epi64((const __m128i*) i8); i8 += 8;
+ const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero);
+ const __m128i vk8 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 96));
+ const __m128i vxk8 = _mm_sub_epi16(_mm_unpacklo_epi8(vk8, vzero), vkernel_zero_point);
+ const __m128i vprod8_odd = _mm_mullo_epi16(vxi8, vxk8);
+ const __m128i vprod8_even = _mm_mulhi_epi16(vxi8, vxk8);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod8_odd, vprod8_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod8_odd, vprod8_even));
+
+ w = (void*) ((uintptr_t) w + 104);
+
+ const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->sse2.multiplier);
+ const __m128i vrounding = _mm_load_si128((const __m128i*) params->sse2.rounding);
+
+ const __m128i vnmask_lo0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo);
+ const __m128i vnmask_hi0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi);
+
+ const __m128i vabsacc_lo0123 = _mm_sub_epi32(_mm_xor_si128(vacc_lo, vnmask_lo0123), vnmask_lo0123);
+ const __m128i vabsacc_hi0123 = _mm_sub_epi32(_mm_xor_si128(vacc_hi, vnmask_hi0123), vnmask_hi0123);
+
+ const __m128i vabsacc_lo1032 = _mm_shuffle_epi32(vabsacc_lo0123, _MM_SHUFFLE(2, 3, 0, 1));
+ const __m128i vabsacc_hi1032 = _mm_shuffle_epi32(vabsacc_hi0123, _MM_SHUFFLE(2, 3, 0, 1));
+
+ const __m128i vabsprod_lo02 = _mm_mul_epu32(vabsacc_lo0123, vmultiplier);
+ const __m128i vabsprod_hi02 = _mm_mul_epu32(vabsacc_hi0123, vmultiplier);
+
+ const __m128i vnmask_lo02 = _mm_shuffle_epi32(vnmask_lo0123, _MM_SHUFFLE(2, 2, 0, 0));
+ const __m128i vnmask_hi02 = _mm_shuffle_epi32(vnmask_hi0123, _MM_SHUFFLE(2, 2, 0, 0));
+
+ const __m128i vprod_lo02 = _mm_sub_epi64(_mm_xor_si128(vabsprod_lo02, vnmask_lo02), vnmask_lo02);
+ const __m128i vprod_hi02 = _mm_sub_epi64(_mm_xor_si128(vabsprod_hi02, vnmask_hi02), vnmask_hi02);
+
+ const __m128i vq31prod_lo02 = _mm_srli_epi64(_mm_add_epi64(vprod_lo02, vrounding), 31);
+ const __m128i vq31prod_hi02 = _mm_srli_epi64(_mm_add_epi64(vprod_hi02, vrounding), 31);
+
+ const __m128i vabsprod_lo13 = _mm_mul_epu32(vabsacc_lo1032, vmultiplier);
+ const __m128i vabsprod_hi13 = _mm_mul_epu32(vabsacc_hi1032, vmultiplier);
+
+ const __m128i vnmask_lo13 = _mm_shuffle_epi32(vnmask_lo0123, _MM_SHUFFLE(3, 3, 1, 1));
+ const __m128i vnmask_hi13 = _mm_shuffle_epi32(vnmask_hi0123, _MM_SHUFFLE(3, 3, 1, 1));
+
+ const __m128i vprod_lo13 = _mm_sub_epi64(_mm_xor_si128(vabsprod_lo13, vnmask_lo13), vnmask_lo13);
+ const __m128i vprod_hi13 = _mm_sub_epi64(_mm_xor_si128(vabsprod_hi13, vnmask_hi13), vnmask_hi13);
+
+ const __m128i vq31prod_lo13 = _mm_srli_epi64(_mm_add_epi64(vprod_lo13, vrounding), 31);
+ const __m128i vq31prod_hi13 = _mm_srli_epi64(_mm_add_epi64(vprod_hi13, vrounding), 31);
+
+ const __m128i vq31prod_lo0213 = _mm_castps_si128(_mm_shuffle_ps(
+ _mm_castsi128_ps(vq31prod_lo02), _mm_castsi128_ps(vq31prod_lo13), _MM_SHUFFLE(2, 0, 2, 0)));
+ const __m128i vq31prod_hi0213 = _mm_castps_si128(_mm_shuffle_ps(
+ _mm_castsi128_ps(vq31prod_hi02), _mm_castsi128_ps(vq31prod_hi13), _MM_SHUFFLE(2, 0, 2, 0)));
+
+ const __m128i vq31prod_lo0123 = _mm_shuffle_epi32(vq31prod_lo0213, _MM_SHUFFLE(3, 1, 2, 0));
+ const __m128i vq31prod_hi0123 = _mm_shuffle_epi32(vq31prod_hi0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+ const __m128i vremainder_mask = _mm_load_si128((const __m128i*) params->sse2.remainder_mask);
+
+ const __m128i vrem_lo0123 =
+ _mm_add_epi32(_mm_and_si128(vq31prod_lo0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod_lo0123));
+ const __m128i vrem_hi0123 =
+ _mm_add_epi32(_mm_and_si128(vq31prod_hi0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod_hi0123));
+
+ const __m128i vremainder_threshold = _mm_load_si128((const __m128i*) params->sse2.remainder_threshold);
+ const __m128i vshift = _mm_load_si128((const __m128i*) params->sse2.shift);
+
+ const __m128i vout_lo = _mm_sub_epi32(_mm_sra_epi32(vq31prod_lo0123, vshift), _mm_cmpgt_epi32(vrem_lo0123, vremainder_threshold));
+ const __m128i vout_hi = _mm_sub_epi32(_mm_sra_epi32(vq31prod_hi0123, vshift), _mm_cmpgt_epi32(vrem_hi0123, vremainder_threshold));
+
+ const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
+ __m128i vout = _mm_adds_epi16(_mm_packs_epi32(vout_lo, vout_hi), voutput_zero_point);
+ vout = _mm_packus_epi16(vout, vout);
+ vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_max));
+ vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_min));
+
+ _mm_storel_epi64((__m128i*) output, vout); output += 8;
+ }
+ if (c != 0) {
+ __m128i vacc_lo = _mm_loadu_si128((const __m128i*) w);
+ __m128i vacc_hi = _mm_loadu_si128((const __m128i*) ((uintptr_t) w + 16));
+
+ const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0); i0 += 8;
+ const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+ const __m128i vk0 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 32));
+ const __m128i vxk0 = _mm_sub_epi16(_mm_unpacklo_epi8(vk0, vzero), vkernel_zero_point);
+ const __m128i vprod0_odd = _mm_mullo_epi16(vxi0, vxk0);
+ const __m128i vprod0_even = _mm_mulhi_epi16(vxi0, vxk0);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod0_odd, vprod0_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod0_odd, vprod0_even));
+
+ const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1); i1 += 8;
+ const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+ const __m128i vk1 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 40));
+ const __m128i vxk1 = _mm_sub_epi16(_mm_unpacklo_epi8(vk1, vzero), vkernel_zero_point);
+ const __m128i vprod1_odd = _mm_mullo_epi16(vxi1, vxk1);
+ const __m128i vprod1_even = _mm_mulhi_epi16(vxi1, vxk1);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod1_odd, vprod1_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod1_odd, vprod1_even));
+
+ const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2); i2 += 8;
+ const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+ const __m128i vk2 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 48));
+ const __m128i vxk2 = _mm_sub_epi16(_mm_unpacklo_epi8(vk2, vzero), vkernel_zero_point);
+ const __m128i vprod2_odd = _mm_mullo_epi16(vxi2, vxk2);
+ const __m128i vprod2_even = _mm_mulhi_epi16(vxi2, vxk2);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod2_odd, vprod2_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod2_odd, vprod2_even));
+
+ const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3); i3 += 8;
+ const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+ const __m128i vk3 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 56));
+ const __m128i vxk3 = _mm_sub_epi16(_mm_unpacklo_epi8(vk3, vzero), vkernel_zero_point);
+ const __m128i vprod3_odd = _mm_mullo_epi16(vxi3, vxk3);
+ const __m128i vprod3_even = _mm_mulhi_epi16(vxi3, vxk3);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod3_odd, vprod3_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod3_odd, vprod3_even));
+
+ const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4); i4 += 8;
+ const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+ const __m128i vk4 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 64));
+ const __m128i vxk4 = _mm_sub_epi16(_mm_unpacklo_epi8(vk4, vzero), vkernel_zero_point);
+ const __m128i vprod4_odd = _mm_mullo_epi16(vxi4, vxk4);
+ const __m128i vprod4_even = _mm_mulhi_epi16(vxi4, vxk4);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod4_odd, vprod4_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod4_odd, vprod4_even));
+
+ const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5); i5 += 8;
+ const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+ const __m128i vk5 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 72));
+ const __m128i vxk5 = _mm_sub_epi16(_mm_unpacklo_epi8(vk5, vzero), vkernel_zero_point);
+ const __m128i vprod5_odd = _mm_mullo_epi16(vxi5, vxk5);
+ const __m128i vprod5_even = _mm_mulhi_epi16(vxi5, vxk5);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod5_odd, vprod5_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod5_odd, vprod5_even));
+
+ const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6); i6 += 8;
+ const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+ const __m128i vk6 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 80));
+ const __m128i vxk6 = _mm_sub_epi16(_mm_unpacklo_epi8(vk6, vzero), vkernel_zero_point);
+ const __m128i vprod6_odd = _mm_mullo_epi16(vxi6, vxk6);
+ const __m128i vprod6_even = _mm_mulhi_epi16(vxi6, vxk6);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod6_odd, vprod6_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod6_odd, vprod6_even));
+
+ const __m128i vi7 = _mm_loadl_epi64((const __m128i*) i7); i7 += 8;
+ const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero);
+ const __m128i vk7 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 88));
+ const __m128i vxk7 = _mm_sub_epi16(_mm_unpacklo_epi8(vk7, vzero), vkernel_zero_point);
+ const __m128i vprod7_odd = _mm_mullo_epi16(vxi7, vxk7);
+ const __m128i vprod7_even = _mm_mulhi_epi16(vxi7, vxk7);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod7_odd, vprod7_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod7_odd, vprod7_even));
+
+ const __m128i vi8 = _mm_loadl_epi64((const __m128i*) i8); i8 += 8;
+ const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero);
+ const __m128i vk8 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 96));
+ const __m128i vxk8 = _mm_sub_epi16(_mm_unpacklo_epi8(vk8, vzero), vkernel_zero_point);
+ const __m128i vprod8_odd = _mm_mullo_epi16(vxi8, vxk8);
+ const __m128i vprod8_even = _mm_mulhi_epi16(vxi8, vxk8);
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod8_odd, vprod8_even));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod8_odd, vprod8_even));
+
+ const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->sse2.multiplier);
+ const __m128i vrounding = _mm_load_si128((const __m128i*) params->sse2.rounding);
+
+ const __m128i vnmask_lo0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo);
+ const __m128i vnmask_hi0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi);
+
+ const __m128i vabsacc_lo0123 = _mm_sub_epi32(_mm_xor_si128(vacc_lo, vnmask_lo0123), vnmask_lo0123);
+ const __m128i vabsacc_hi0123 = _mm_sub_epi32(_mm_xor_si128(vacc_hi, vnmask_hi0123), vnmask_hi0123);
+
+ const __m128i vabsacc_lo1032 = _mm_shuffle_epi32(vabsacc_lo0123, _MM_SHUFFLE(2, 3, 0, 1));
+ const __m128i vabsacc_hi1032 = _mm_shuffle_epi32(vabsacc_hi0123, _MM_SHUFFLE(2, 3, 0, 1));
+
+ const __m128i vabsprod_lo02 = _mm_mul_epu32(vabsacc_lo0123, vmultiplier);
+ const __m128i vabsprod_hi02 = _mm_mul_epu32(vabsacc_hi0123, vmultiplier);
+
+ const __m128i vnmask_lo02 = _mm_shuffle_epi32(vnmask_lo0123, _MM_SHUFFLE(2, 2, 0, 0));
+ const __m128i vnmask_hi02 = _mm_shuffle_epi32(vnmask_hi0123, _MM_SHUFFLE(2, 2, 0, 0));
+
+ const __m128i vprod_lo02 = _mm_sub_epi64(_mm_xor_si128(vabsprod_lo02, vnmask_lo02), vnmask_lo02);
+ const __m128i vprod_hi02 = _mm_sub_epi64(_mm_xor_si128(vabsprod_hi02, vnmask_hi02), vnmask_hi02);
+
+ const __m128i vq31prod_lo02 = _mm_srli_epi64(_mm_add_epi64(vprod_lo02, vrounding), 31);
+ const __m128i vq31prod_hi02 = _mm_srli_epi64(_mm_add_epi64(vprod_hi02, vrounding), 31);
+
+ const __m128i vabsprod_lo13 = _mm_mul_epu32(vabsacc_lo1032, vmultiplier);
+ const __m128i vabsprod_hi13 = _mm_mul_epu32(vabsacc_hi1032, vmultiplier);
+
+ const __m128i vnmask_lo13 = _mm_shuffle_epi32(vnmask_lo0123, _MM_SHUFFLE(3, 3, 1, 1));
+ const __m128i vnmask_hi13 = _mm_shuffle_epi32(vnmask_hi0123, _MM_SHUFFLE(3, 3, 1, 1));
+
+ const __m128i vprod_lo13 = _mm_sub_epi64(_mm_xor_si128(vabsprod_lo13, vnmask_lo13), vnmask_lo13);
+ const __m128i vprod_hi13 = _mm_sub_epi64(_mm_xor_si128(vabsprod_hi13, vnmask_hi13), vnmask_hi13);
+
+ const __m128i vq31prod_lo13 = _mm_srli_epi64(_mm_add_epi64(vprod_lo13, vrounding), 31);
+ const __m128i vq31prod_hi13 = _mm_srli_epi64(_mm_add_epi64(vprod_hi13, vrounding), 31);
+
+ const __m128i vq31prod_lo0213 = _mm_castps_si128(_mm_shuffle_ps(
+ _mm_castsi128_ps(vq31prod_lo02), _mm_castsi128_ps(vq31prod_lo13), _MM_SHUFFLE(2, 0, 2, 0)));
+ const __m128i vq31prod_hi0213 = _mm_castps_si128(_mm_shuffle_ps(
+ _mm_castsi128_ps(vq31prod_hi02), _mm_castsi128_ps(vq31prod_hi13), _MM_SHUFFLE(2, 0, 2, 0)));
+
+ const __m128i vq31prod_lo0123 = _mm_shuffle_epi32(vq31prod_lo0213, _MM_SHUFFLE(3, 1, 2, 0));
+ const __m128i vq31prod_hi0123 = _mm_shuffle_epi32(vq31prod_hi0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+ const __m128i vremainder_mask = _mm_load_si128((const __m128i*) params->sse2.remainder_mask);
+
+ const __m128i vrem_lo0123 =
+ _mm_add_epi32(_mm_and_si128(vq31prod_lo0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod_lo0123));
+ const __m128i vrem_hi0123 =
+ _mm_add_epi32(_mm_and_si128(vq31prod_hi0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod_hi0123));
+
+ const __m128i vremainder_threshold = _mm_load_si128((const __m128i*) params->sse2.remainder_threshold);
+ const __m128i vshift = _mm_load_si128((const __m128i*) params->sse2.shift);
+
+ const __m128i vout_lo = _mm_sub_epi32(_mm_sra_epi32(vq31prod_lo0123, vshift), _mm_cmpgt_epi32(vrem_lo0123, vremainder_threshold));
+ const __m128i vout_hi = _mm_sub_epi32(_mm_sra_epi32(vq31prod_hi0123, vshift), _mm_cmpgt_epi32(vrem_hi0123, vremainder_threshold));
+
+ const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
+ __m128i vout = _mm_adds_epi16(_mm_packs_epi32(vout_lo, vout_hi), voutput_zero_point);
+ vout = _mm_packus_epi16(vout, vout);
+ vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_max));
+ vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_min));
+
+ if (c & 4) {
+ *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout);
+ output += 4;
+ vout = _mm_srli_epi64(vout, 32);
+ }
+ if (c & 2) {
+ *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout, 0);
+ output += 2;
+ vout = _mm_srli_epi32(vout, 16);
+ }
+ if (c & 1) {
+ *((uint8_t*) output) = (uint8_t) _mm_cvtsi128_si32(vout);
+ output += 1;
+ }
+ }
+
+ output = (uint8_t*) ((uintptr_t) output + output_increment);
+ } while (--output_width != 0);
+}
diff --git a/src/q8-gavgpool/mp7p7q-neon.c b/src/q8-gavgpool/mp7p7q-neon.c
new file mode 100644
index 0000000..3a4aa58
--- /dev/null
+++ b/src/q8-gavgpool/mp7p7q-neon.c
@@ -0,0 +1,293 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gavgpool.h>
+#include <xnnpack/math.h>
+
+
+// Q8 (quantized uint8) global average pooling, multi-pass NEON microkernel.
+// Accumulates 7 input rows per pass into a 32-bit per-channel scratch buffer,
+// then requantizes the totals in a final pass (covering the last 1-7 rows)
+// and writes clamped uint8 output.
+//
+//   m            - number of input rows (pooling elements); must be > 7.
+//   n            - number of channels; processed 8 at a time.
+//   input        - first input row; rows are input_stride bytes apart.
+//   zero         - caller-provided zero vector, substituted for rows past the
+//                  end in the final partial pass.
+//   buffer       - int32 scratch accumulator of at least round_up_po2(n, 8)
+//                  elements.
+//   output       - n output bytes.
+//   params       - requantization parameters (bias, multiplier, left_shift,
+//                  output zero point, and output clamping bounds).
+void xnn_q8_gavgpool_ukernel_mp7p7q__neon(
+    size_t m,
+    size_t n,
+    const uint8_t* input,
+    size_t input_stride,
+    const uint8_t* zero,
+    int32_t* buffer,
+    uint8_t* output,
+    const union xnn_q8_avgpool_params params[restrict static 1])
+{
+  assert(m > 7);
+  assert(n != 0);
+
+  const uint8_t* i0 = input;
+  const uint8_t* i1 = (const uint8_t*) ((uintptr_t) i0 + input_stride);
+  const uint8_t* i2 = (const uint8_t*) ((uintptr_t) i1 + input_stride);
+  const uint8_t* i3 = (const uint8_t*) ((uintptr_t) i2 + input_stride);
+  const uint8_t* i4 = (const uint8_t*) ((uintptr_t) i3 + input_stride);
+  const uint8_t* i5 = (const uint8_t*) ((uintptr_t) i4 + input_stride);
+  const uint8_t* i6 = (const uint8_t*) ((uintptr_t) i5 + input_stride);
+  // Each pass reads packed_n bytes per row (channels rounded up to 8), so
+  // stepping over the remaining 6 rows lands on the next group of 7.
+  const size_t packed_n = round_up_po2(n, 8);
+  const size_t input_increment = 7 * input_stride - packed_n;
+  const int32x4_t vbias = vld1q_dup_s32(&params->neon.bias);
+
+  // First pass: seed each channel accumulator with bias + sum of rows 0-6.
+  int32_t* acc = buffer;
+  for (size_t k = 0; k < n; k += 8) {
+    const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
+    const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
+    const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
+    const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
+    const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
+    const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
+    const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
+
+    const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
+    const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
+    const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
+
+    const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
+    const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
+
+    const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));
+
+    const int32x4_t vacc_lo = vaddw_s16(vbias, vget_low_s16(vsum));
+    const int32x4_t vacc_hi = vaddw_s16(vbias, vget_high_s16(vsum));
+
+    vst1q_s32(acc, vacc_lo); acc += 4;
+    vst1q_s32(acc, vacc_hi); acc += 4;
+  }
+  // Intermediate passes: add 7 more rows into the accumulators while more
+  // than 7 rows remain.
+  for (m -= 7; m > 7; m -= 7) {
+    acc = buffer;
+
+    i0 = (const uint8_t*) ((uintptr_t) i0 + input_increment);
+    i1 = (const uint8_t*) ((uintptr_t) i1 + input_increment);
+    i2 = (const uint8_t*) ((uintptr_t) i2 + input_increment);
+    i3 = (const uint8_t*) ((uintptr_t) i3 + input_increment);
+    i4 = (const uint8_t*) ((uintptr_t) i4 + input_increment);
+    i5 = (const uint8_t*) ((uintptr_t) i5 + input_increment);
+    i6 = (const uint8_t*) ((uintptr_t) i6 + input_increment);
+
+    for (size_t k = 0; k < n; k += 8) {
+      const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
+      const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
+      const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
+      const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
+      const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
+      const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
+      const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
+      const int32x4_t vacc_lo = vld1q_s32(acc);
+      const int32x4_t vacc_hi = vld1q_s32(acc + 4);
+
+      const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
+      const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
+      const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
+
+      const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
+      const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
+
+      const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));
+
+      vst1q_s32(acc, vaddw_s16(vacc_lo, vget_low_s16(vsum))); acc += 4;
+      vst1q_s32(acc, vaddw_s16(vacc_hi, vget_high_s16(vsum))); acc += 4;
+    }
+  }
+
+  // Final pass: requantization constants.
+#ifdef __aarch64__
+  const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+#else
+  const int32x2_t vmultiplier = vld1_dup_s32(&params->neon.multiplier);
+#endif
+  const int64x2_t vleft_shift = vld1q_dup_s64(&params->neon.left_shift);
+  const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+  const uint8x8_t voutput_min = vld1_dup_u8(&params->neon.output_min);
+  const uint8x8_t voutput_max = vld1_dup_u8(&params->neon.output_max);
+
+  // 1-7 rows remain (m is the leftover count). Point absent rows at the
+  // zero vector so they contribute nothing to the sum.
+  i0 = (const uint8_t*) ((uintptr_t) i0 + input_increment);
+  i1 = (const uint8_t*) ((uintptr_t) i1 + input_increment);
+  if (m < 2) {
+    i1 = zero;
+  }
+  i2 = (const uint8_t*) ((uintptr_t) i2 + input_increment);
+  if (m <= 2) {
+    i2 = zero;
+  }
+  i3 = (const uint8_t*) ((uintptr_t) i3 + input_increment);
+  if (m < 4) {
+    i3 = zero;
+  }
+  i4 = (const uint8_t*) ((uintptr_t) i4 + input_increment);
+  if (m <= 4) {
+    i4 = zero;
+  }
+  i5 = (const uint8_t*) ((uintptr_t) i5 + input_increment);
+  if (m < 6) {
+    i5 = zero;
+  }
+  i6 = (const uint8_t*) ((uintptr_t) i6 + input_increment);
+  if (m <= 6) {
+    i6 = zero;
+  }
+
+  acc = buffer;
+  while (n >= 8) {
+    const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
+    const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
+    const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
+    const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
+    const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
+    const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
+    const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
+    int32x4_t vacc_lo = vld1q_s32(acc); acc += 4;
+    int32x4_t vacc_hi = vld1q_s32(acc); acc += 4;
+
+    const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
+    const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
+    const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
+
+    const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
+    const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
+
+    const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));
+    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
+    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));
+
+    const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
+    const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));
+
+    // Fixed-point scaling: widen to 64 bits, multiply by the multiplier, and
+    // subtract 1 (add the all-ones negative mask) for negative accumulators
+    // to bias the subsequent rounding shift toward zero.
+#if defined(__aarch64__)
+    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
+    const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
+    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
+    const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);
+
+    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+    const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
+    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+    const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
+#else
+    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
+    const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
+    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
+    const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);
+
+    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+    const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
+    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+    const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
+#endif
+
+    // vleft_shift is expected to be non-positive here, so vrshlq_s64 acts as
+    // a rounding arithmetic shift right.
+    const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
+    const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
+    const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
+    const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);
+
+    // Narrow 64->32->16 bits, add the output zero point, saturate to uint8,
+    // and clamp to the output range.
+#ifdef __aarch64__
+    vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
+    vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));
+
+    const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
+#else
+    vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
+    vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));
+
+    const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
+#endif
+
+    uint8x8_t vout = vqmovun_s16(vacc);
+    vout = vmax_u8(vout, voutput_min);
+    vout = vmin_u8(vout, voutput_max);
+
+    vst1_u8(output, vout); output += 8;
+
+    n -= 8;
+  }
+  // Remainder: process the final 1-7 channels (buffer is padded to 8, so the
+  // full 8-wide loads are in bounds; only n bytes are stored).
+  if (n != 0) {
+    const uint8x8_t vi0 = vld1_u8(i0);
+    const uint8x8_t vi1 = vld1_u8(i1);
+    const uint8x8_t vi2 = vld1_u8(i2);
+    const uint8x8_t vi3 = vld1_u8(i3);
+    const uint8x8_t vi4 = vld1_u8(i4);
+    const uint8x8_t vi5 = vld1_u8(i5);
+    const uint8x8_t vi6 = vld1_u8(i6);
+    int32x4_t vacc_lo = vld1q_s32(acc); acc += 4;
+    int32x4_t vacc_hi = vld1q_s32(acc);
+
+    const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
+    const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
+    const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
+
+    const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
+    const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
+
+    const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));
+    vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
+    vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));
+
+    const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
+    const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));
+
+#if defined(__aarch64__)
+    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
+    const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
+    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
+    const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);
+
+    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+    const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
+    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+    const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
+#else
+    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
+    const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
+    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
+    const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);
+
+    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+    const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
+    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+    const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
+#endif
+
+    const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
+    const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
+    const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
+    const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);
+
+#ifdef __aarch64__
+    vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
+    vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));
+
+    const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
+#else
+    vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
+    vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));
+
+    const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
+#endif
+
+    uint8x8_t vout = vqmovun_s16(vacc);
+    vout = vmax_u8(vout, voutput_min);
+    vout = vmin_u8(vout, voutput_max);
+
+    // Store 4, 2, then 1 byte(s) depending on the remainder, shifting the
+    // consumed lanes out of the vector between stores.
+    if (n & 4) {
+      vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); output += 4;
+      vout = vext_u8(vout, vout, 4);
+    }
+    if (n & 2) {
+      vst1_lane_u16(__builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); output += 2;
+      vout = vext_u8(vout, vout, 2);
+    }
+    if (n & 1) {
+      vst1_lane_u8(output, vout, 0);
+    }
+  }
+}
diff --git a/src/q8-gavgpool/mp7p7q-scalar.c b/src/q8-gavgpool/mp7p7q-scalar.c
new file mode 100644
index 0000000..dd6658c
--- /dev/null
+++ b/src/q8-gavgpool/mp7p7q-scalar.c
@@ -0,0 +1,163 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/scalar-utils.h>
+#include <xnnpack/gavgpool.h>
+
+
+// Q8 global average pooling, multi-pass scalar microkernel: sums 7 input
+// rows per pass into a 32-bit per-channel scratch buffer, then requantizes
+// the totals (last 1-7 rows folded in) and emits clamped uint8 output.
+//
+//   m      - number of input rows (pooling elements); must be > 7.
+//   n      - number of channels.
+//   input  - first input row; rows are input_stride bytes apart.
+//   zero   - zero vector substituted for absent rows in the final pass.
+//   buffer - int32 scratch accumulator of at least n elements.
+//   output - n output bytes.
+//   params - requantization parameters (bias, multiplier, shift, clamps).
+void xnn_q8_gavgpool_ukernel_mp7p7q__scalar(
+    size_t m,
+    size_t n,
+    const uint8_t* input,
+    size_t input_stride,
+    const uint8_t* zero,
+    int32_t* buffer,
+    uint8_t* output,
+    const union xnn_q8_avgpool_params params[restrict static 1])
+{
+  assert(m > 7);
+  assert(n != 0);
+
+  // Seven row pointers advance together; after each pass they are bumped by
+  // input_increment (each pass consumes n bytes per row, then skips the
+  // remaining 6 rows) to land on the next group of 7 rows.
+  const uint8_t* row0 = input;
+  const uint8_t* row1 = (const uint8_t*) ((uintptr_t) row0 + input_stride);
+  const uint8_t* row2 = (const uint8_t*) ((uintptr_t) row1 + input_stride);
+  const uint8_t* row3 = (const uint8_t*) ((uintptr_t) row2 + input_stride);
+  const uint8_t* row4 = (const uint8_t*) ((uintptr_t) row3 + input_stride);
+  const uint8_t* row5 = (const uint8_t*) ((uintptr_t) row4 + input_stride);
+  const uint8_t* row6 = (const uint8_t*) ((uintptr_t) row5 + input_stride);
+  const size_t input_increment = 7 * input_stride - n;
+
+  // First pass: seed each channel accumulator with bias + sum of rows 0-6.
+  {
+    const int32_t bias = params->scalar.bias;
+    int32_t* b = buffer;
+    for (size_t c = n; c != 0; c--) {
+      uint32_t row_sum = (uint32_t) *row0++;
+      row_sum += (uint32_t) *row1++;
+      row_sum += (uint32_t) *row2++;
+      row_sum += (uint32_t) *row3++;
+      row_sum += (uint32_t) *row4++;
+      row_sum += (uint32_t) *row5++;
+      row_sum += (uint32_t) *row6++;
+      *b++ = bias + (int32_t) row_sum;
+    }
+  }
+
+  // Intermediate passes: fold in 7 more rows while more than 7 remain.
+  m -= 7;
+  while (m > 7) {
+    row0 = (const uint8_t*) ((uintptr_t) row0 + input_increment);
+    row1 = (const uint8_t*) ((uintptr_t) row1 + input_increment);
+    row2 = (const uint8_t*) ((uintptr_t) row2 + input_increment);
+    row3 = (const uint8_t*) ((uintptr_t) row3 + input_increment);
+    row4 = (const uint8_t*) ((uintptr_t) row4 + input_increment);
+    row5 = (const uint8_t*) ((uintptr_t) row5 + input_increment);
+    row6 = (const uint8_t*) ((uintptr_t) row6 + input_increment);
+
+    int32_t* b = buffer;
+    for (size_t c = n; c != 0; c--) {
+      uint32_t row_sum = (uint32_t) *row0++;
+      row_sum += (uint32_t) *row1++;
+      row_sum += (uint32_t) *row2++;
+      row_sum += (uint32_t) *row3++;
+      row_sum += (uint32_t) *row4++;
+      row_sum += (uint32_t) *row5++;
+      row_sum += (uint32_t) *row6++;
+      *b++ += (int32_t) row_sum;
+    }
+    m -= 7;
+  }
+
+  // Last pass: add the final 1-7 rows, then requantize and store.
+  {
+    const int32_t multiplier = params->scalar.multiplier;
+    const int64_t rounding = params->scalar.rounding;
+    const uint32_t shift = params->scalar.right_shift;
+    const int32_t out_min = params->scalar.output_min_less_zero_point;
+    const int32_t out_max = params->scalar.output_max_less_zero_point;
+    const int32_t out_zero_point = params->scalar.output_zero_point;
+
+    // m now holds the leftover row count (1-7); rows beyond it read from the
+    // caller-provided zero vector so they contribute nothing to the sum.
+    row0 = (const uint8_t*) ((uintptr_t) row0 + input_increment);
+    row1 = (const uint8_t*) ((uintptr_t) row1 + input_increment);
+    if (m < 2) {
+      row1 = zero;
+    }
+    row2 = (const uint8_t*) ((uintptr_t) row2 + input_increment);
+    if (m < 3) {
+      row2 = zero;
+    }
+    row3 = (const uint8_t*) ((uintptr_t) row3 + input_increment);
+    if (m < 4) {
+      row3 = zero;
+    }
+    row4 = (const uint8_t*) ((uintptr_t) row4 + input_increment);
+    if (m < 5) {
+      row4 = zero;
+    }
+    row5 = (const uint8_t*) ((uintptr_t) row5 + input_increment);
+    if (m < 6) {
+      row5 = zero;
+    }
+    row6 = (const uint8_t*) ((uintptr_t) row6 + input_increment);
+    if (m < 7) {
+      row6 = zero;
+    }
+
+    const int32_t* b = buffer;
+    for (size_t c = n; c != 0; c--) {
+      uint32_t row_sum = (uint32_t) *row0++;
+      row_sum += (uint32_t) *row1++;
+      row_sum += (uint32_t) *row2++;
+      row_sum += (uint32_t) *row3++;
+      row_sum += (uint32_t) *row4++;
+      row_sum += (uint32_t) *row5++;
+      row_sum += (uint32_t) *row6++;
+      const int32_t acc = *b++ + (int32_t) row_sum;
+
+      // Fixed-point requantization: widen to 64 bits, scale by multiplier,
+      // subtract 1 for negative accumulators to bias rounding toward zero,
+      // then rounding arithmetic shift right; finally clamp (bounds are
+      // pre-offset by the zero point) and re-add the output zero point.
+      const int64_t product = (int64_t) acc * (int64_t) multiplier;
+      const int64_t adjusted_product = product - (int64_t) (acc < 0);
+      int32_t out = (int32_t) asr_s64(adjusted_product + rounding, shift);
+      if (out < out_min) {
+        out = out_min;
+      }
+      if (out > out_max) {
+        out = out_max;
+      }
+      *output++ = (uint8_t) (out + out_zero_point);
+    }
+  }
+}
diff --git a/src/q8-gavgpool/mp7p7q-sse2.c b/src/q8-gavgpool/mp7p7q-sse2.c
new file mode 100644
index 0000000..77b3d02
--- /dev/null
+++ b/src/q8-gavgpool/mp7p7q-sse2.c
@@ -0,0 +1,307 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/gavgpool.h>
+#include <xnnpack/math.h>
+
+
+// Multi-pass Q8 global average pooling micro-kernel (SSE2), for m > 7 rows.
+// Pass 1 sums the first 7 rows into the int32 buffer (seeded with bias);
+// intermediate passes accumulate 7 rows at a time into the buffer; the final
+// pass (1..7 remaining rows, absent rows redirected to `zero`) adds the last
+// rows and requantizes to uint8 output.
+void xnn_q8_gavgpool_ukernel_mp7p7q__sse2(
+    size_t m,
+    size_t n,
+    const uint8_t* input,
+    size_t input_stride,
+    const uint8_t* zero,
+    int32_t* buffer,
+    uint8_t* output,
+    const union xnn_q8_avgpool_params params[restrict static 1])
+{
+  assert(m > 7);
+  assert(n != 0);
+
+  const uint8_t* i0 = input;
+  const uint8_t* i1 = (const uint8_t*) ((uintptr_t) i0 + input_stride);
+  const uint8_t* i2 = (const uint8_t*) ((uintptr_t) i1 + input_stride);
+  const uint8_t* i3 = (const uint8_t*) ((uintptr_t) i2 + input_stride);
+  const uint8_t* i4 = (const uint8_t*) ((uintptr_t) i3 + input_stride);
+  const uint8_t* i5 = (const uint8_t*) ((uintptr_t) i4 + input_stride);
+  const uint8_t* i6 = (const uint8_t*) ((uintptr_t) i5 + input_stride);
+  const size_t packed_n = round_up_po2(n, 8);
+  // Each pass advances every row pointer by packed_n bytes; the increment
+  // jumps them forward to the next group of 7 rows.
+  const size_t input_increment = 7 * input_stride - packed_n;
+  // Fixed: was garbled as "&para;ms" (mojibake of "&params") in the original.
+  const __m128i vbias = _mm_load_si128((const __m128i*) &params->sse2.bias);
+  const __m128i vzero = _mm_setzero_si128();
+
+  // First pass: buffer[c] = bias + sum of rows 0..6 at channel c.
+  int32_t* acc = buffer;
+  for (size_t k = 0; k < n; k += 8) {
+    const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0); i0 += 8;
+    const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1); i1 += 8;
+    const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2); i2 += 8;
+    const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3); i3 += 8;
+    const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4); i4 += 8;
+    const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5); i5 += 8;
+    const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6); i6 += 8;
+
+    const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+    const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+    const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+    const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+    const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+    const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+    const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+
+    const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1);
+    const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
+    const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
+
+    const __m128i vsum016 = _mm_add_epi16(vsum01, vxi6);
+    const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
+    const __m128i vsum = _mm_add_epi16(vsum016, vsum2345);
+
+    const __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero));
+    const __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero));
+
+    _mm_store_si128((__m128i*) acc, vacc_lo);
+    _mm_store_si128((__m128i*) acc + 1, vacc_hi);
+    acc += 8;
+  }
+  // Intermediate passes: accumulate 7 more rows into the buffer while more
+  // than 7 rows remain.
+  for (m -= 7; m > 7; m -= 7) {
+    acc = buffer;
+    i0 = (const uint8_t*) ((uintptr_t) i0 + input_increment);
+    i1 = (const uint8_t*) ((uintptr_t) i1 + input_increment);
+    i2 = (const uint8_t*) ((uintptr_t) i2 + input_increment);
+    i3 = (const uint8_t*) ((uintptr_t) i3 + input_increment);
+    i4 = (const uint8_t*) ((uintptr_t) i4 + input_increment);
+    i5 = (const uint8_t*) ((uintptr_t) i5 + input_increment);
+    i6 = (const uint8_t*) ((uintptr_t) i6 + input_increment);
+
+    for (size_t k = 0; k < n; k += 8) {
+      const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0); i0 += 8;
+      const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1); i1 += 8;
+      const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2); i2 += 8;
+      const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3); i3 += 8;
+      const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4); i4 += 8;
+      const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5); i5 += 8;
+      const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6); i6 += 8;
+      __m128i vacc_lo = _mm_load_si128((const __m128i*) acc);
+      __m128i vacc_hi = _mm_load_si128((const __m128i*) acc + 1);
+
+      const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+      const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+      const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+      const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+      const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+      const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+      const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+
+      const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1);
+      const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
+      const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
+
+      const __m128i vsum016 = _mm_add_epi16(vsum01, vxi6);
+      const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
+      const __m128i vsum = _mm_add_epi16(vsum016, vsum2345);
+
+      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero));
+      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero));
+
+      _mm_store_si128((__m128i*) acc, vacc_lo);
+      _mm_store_si128((__m128i*) acc + 1, vacc_hi);
+      acc += 8;
+    }
+  }
+
+  const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->sse2.multiplier);
+  const __m128i vrounding = _mm_load_si128((const __m128i*) params->sse2.rounding);
+  const __m128i vright_shift = _mm_loadl_epi64((const __m128i*) params->sse2.right_shift);
+
+  // Final pass: 1 <= m <= 7 rows remain; absent rows read from `zero`.
+  i0 = (const uint8_t*) ((uintptr_t) i0 + input_increment);
+  i1 = (const uint8_t*) ((uintptr_t) i1 + input_increment);
+  if (m < 2) {
+    i1 = zero;
+  }
+  i2 = (const uint8_t*) ((uintptr_t) i2 + input_increment);
+  if (m <= 2) {
+    i2 = zero;
+  }
+  i3 = (const uint8_t*) ((uintptr_t) i3 + input_increment);
+  if (m < 4) {
+    i3 = zero;
+  }
+  i4 = (const uint8_t*) ((uintptr_t) i4 + input_increment);
+  if (m <= 4) {
+    i4 = zero;
+  }
+  i5 = (const uint8_t*) ((uintptr_t) i5 + input_increment);
+  if (m < 6) {
+    i5 = zero;
+  }
+  i6 = (const uint8_t*) ((uintptr_t) i6 + input_increment);
+  if (m <= 6) {
+    i6 = zero;
+  }
+
+  acc = buffer;
+  while (n >= 8) {
+    const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0); i0 += 8;
+    const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1); i1 += 8;
+    const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2); i2 += 8;
+    const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3); i3 += 8;
+    const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4); i4 += 8;
+    const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5); i5 += 8;
+    const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6); i6 += 8;
+    __m128i vacc_lo = _mm_load_si128((const __m128i*) acc);
+    __m128i vacc_hi = _mm_load_si128((const __m128i*) acc + 1);
+    acc += 8;
+
+    const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+    const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+    const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+    const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+    const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+    const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+    const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+
+    const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1);
+    const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
+    const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
+
+    const __m128i vsum016 = _mm_add_epi16(vsum01, vxi6);
+    const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
+    const __m128i vsum = _mm_add_epi16(vsum016, vsum2345);
+
+    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero));
+    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero));
+
+    // Requantization: SSE2 has no 64-bit signed multiply, so multiply the
+    // absolute value with _mm_mul_epu32 and restore the sign afterwards.
+    const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo);
+    const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi);
+
+    const __m128i vabs_lo0123 = _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo);
+    const __m128i vabs_hi0123 = _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi);
+
+    const __m128i vabs_lo1032 = _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i vabs_hi1032 = _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1));
+
+    const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier);
+    const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier);
+
+    const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier);
+    const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier);
+
+    const __m128i vabs_scaled_lo02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift);
+    const __m128i vabs_scaled_lo13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift);
+    const __m128i vabs_scaled_hi02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift);
+    const __m128i vabs_scaled_hi13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift);
+
+    const __m128i vabs_scaled_lo0213 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_lo02), _mm_castsi128_ps(vabs_scaled_lo13), _MM_SHUFFLE(2, 0, 2, 0)));
+    const __m128i vabs_scaled_hi0213 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_hi02), _mm_castsi128_ps(vabs_scaled_hi13), _MM_SHUFFLE(2, 0, 2, 0)));
+
+    const __m128i vabs_scaled_lo = _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0));
+    const __m128i vabs_scaled_hi = _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+    const __m128i vscaled_lo = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo);
+    const __m128i vscaled_hi = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi);
+
+    __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
+    vout = _mm_adds_epi16(vout, _mm_load_si128((const __m128i*) params->sse2.output_zero_point));
+    vout = _mm_packus_epi16(vout, vout);
+    vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_max));
+    vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_min));
+
+    _mm_storel_epi64((__m128i*) output, vout); output += 8;
+
+    n -= 8;
+  }
+  if (n != 0) {
+    // Remainder: process the last 1-7 channels (loads may over-read within
+    // the packed_n padding; stores are exact).
+    const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0);
+    const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1);
+    const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2);
+    const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3);
+    const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4);
+    const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5);
+    const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6);
+    __m128i vacc_lo = _mm_load_si128((const __m128i*) acc);
+    __m128i vacc_hi = _mm_load_si128((const __m128i*) acc + 1);
+
+    const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+    const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+    const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+    const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+    const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+    const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+    const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+
+    const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1);
+    const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
+    const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
+
+    const __m128i vsum016 = _mm_add_epi16(vsum01, vxi6);
+    const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
+    const __m128i vsum = _mm_add_epi16(vsum016, vsum2345);
+
+    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero));
+    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero));
+
+    const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo);
+    const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi);
+
+    const __m128i vabs_lo0123 = _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo);
+    const __m128i vabs_hi0123 = _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi);
+
+    const __m128i vabs_lo1032 = _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i vabs_hi1032 = _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1));
+
+    const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier);
+    const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier);
+
+    const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier);
+    const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier);
+
+    const __m128i vabs_scaled_lo02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift);
+    const __m128i vabs_scaled_lo13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift);
+    const __m128i vabs_scaled_hi02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift);
+    const __m128i vabs_scaled_hi13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift);
+
+    const __m128i vabs_scaled_lo0213 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_lo02), _mm_castsi128_ps(vabs_scaled_lo13), _MM_SHUFFLE(2, 0, 2, 0)));
+    const __m128i vabs_scaled_hi0213 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_hi02), _mm_castsi128_ps(vabs_scaled_hi13), _MM_SHUFFLE(2, 0, 2, 0)));
+
+    const __m128i vabs_scaled_lo = _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0));
+    const __m128i vabs_scaled_hi = _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+    const __m128i vscaled_lo = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo);
+    const __m128i vscaled_hi = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi);
+
+    __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
+    vout = _mm_adds_epi16(vout, _mm_load_si128((const __m128i*) params->sse2.output_zero_point));
+    vout = _mm_packus_epi16(vout, vout);
+    vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_max));
+    vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_min));
+
+    if (n & 4) {
+      *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout);
+      output += 4;
+      vout = _mm_srli_epi64(vout, 32);
+    }
+    if (n & 2) {
+      *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout, 0);
+      output += 2;
+      vout = _mm_srli_epi32(vout, 16);
+    }
+    if (n & 1) {
+      *((uint8_t*) output) = (uint8_t) _mm_cvtsi128_si32(vout);
+    }
+  }
+}
diff --git a/src/q8-gavgpool/up7-neon.c b/src/q8-gavgpool/up7-neon.c
new file mode 100644
index 0000000..fb4aad8
--- /dev/null
+++ b/src/q8-gavgpool/up7-neon.c
@@ -0,0 +1,215 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gavgpool.h>
+
+
+// Unipass Q8 global average pooling micro-kernel (NEON), for 1 <= m <= 7 rows.
+// Absent rows (m < 7) are redirected to the `zero` buffer so all seven row
+// pointers are always valid. Sums are requantized with a 64-bit fixed-point
+// multiply and rounding shift, then clamped to [output_min, output_max].
+void xnn_q8_gavgpool_ukernel_up7__neon(
+    size_t m,
+    size_t n,
+    const uint8_t* input,
+    size_t input_stride,
+    const uint8_t* zero,
+    uint8_t* output,
+    const union xnn_q8_avgpool_params params[restrict static 1])
+{
+  assert(m != 0);
+  assert(m <= 7);
+  assert(n != 0);
+
+  const uint8_t* i0 = input;
+  const uint8_t* i1 = (const uint8_t*) ((uintptr_t) i0 + input_stride);
+  if (m < 2) {
+    i1 = zero;
+  }
+  const uint8_t* i2 = (const uint8_t*) ((uintptr_t) i1 + input_stride);
+  if (m <= 2) {
+    i2 = zero;
+  }
+  const uint8_t* i3 = (const uint8_t*) ((uintptr_t) i2 + input_stride);
+  if (m < 4) {
+    i3 = zero;
+  }
+  const uint8_t* i4 = (const uint8_t*) ((uintptr_t) i3 + input_stride);
+  if (m <= 4) {
+    i4 = zero;
+  }
+  const uint8_t* i5 = (const uint8_t*) ((uintptr_t) i4 + input_stride);
+  if (m < 6) {
+    i5 = zero;
+  }
+  const uint8_t* i6 = (const uint8_t*) ((uintptr_t) i5 + input_stride);
+  if (m <= 6) {
+    i6 = zero;
+  }
+
+  // Fixed: the "&params" expressions below were garbled as "&para;ms"
+  // (HTML-entity mojibake) in the original patch.
+  const int32x4_t vbias = vld1q_dup_s32(&params->neon.bias);
+#ifdef __aarch64__
+  const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+#else
+  const int32x2_t vmultiplier = vld1_dup_s32(&params->neon.multiplier);
+#endif
+  const int64x2_t vleft_shift = vld1q_dup_s64(&params->neon.left_shift);
+  const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+  const uint8x8_t voutput_min = vld1_dup_u8(&params->neon.output_min);
+  const uint8x8_t voutput_max = vld1_dup_u8(&params->neon.output_max);
+  while (n >= 8) {
+    const uint8x8_t vi0 = vld1_u8(i0); i0 += 8;
+    const uint8x8_t vi1 = vld1_u8(i1); i1 += 8;
+    const uint8x8_t vi2 = vld1_u8(i2); i2 += 8;
+    const uint8x8_t vi3 = vld1_u8(i3); i3 += 8;
+    const uint8x8_t vi4 = vld1_u8(i4); i4 += 8;
+    const uint8x8_t vi5 = vld1_u8(i5); i5 += 8;
+    const uint8x8_t vi6 = vld1_u8(i6); i6 += 8;
+
+    // Sum 7 rows as a balanced tree to limit dependency chains.
+    const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
+    const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
+    const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
+
+    const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
+    const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
+
+    const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));
+    int32x4_t vacc_lo = vaddw_s16(vbias, vget_low_s16(vsum));
+    int32x4_t vacc_hi = vaddw_s16(vbias, vget_high_s16(vsum));
+
+    // Negative-accumulator mask (-1 per negative lane) adjusts the product so
+    // the rounding shift rounds ties away from zero.
+    const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
+    const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));
+
+#if defined(__aarch64__)
+    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
+    const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
+    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
+    const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);
+
+    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+    const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
+    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+    const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
+#else
+    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
+    const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
+    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
+    const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);
+
+    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+    const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
+    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+    const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
+#endif
+
+    const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
+    const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
+    const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
+    const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);
+
+#ifdef __aarch64__
+    vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
+    vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));
+
+    const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
+#else
+    vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
+    vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));
+
+    const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
+#endif
+
+    uint8x8_t vout = vqmovun_s16(vacc);
+    vout = vmax_u8(vout, voutput_min);
+    vout = vmin_u8(vout, voutput_max);
+
+    vst1_u8(output, vout); output += 8;
+
+    n -= 8;
+  }
+  if (n != 0) {
+    // Remainder: same computation for the last 1-7 channels, stored piecewise.
+    const uint8x8_t vi0 = vld1_u8(i0);
+    const uint8x8_t vi1 = vld1_u8(i1);
+    const uint8x8_t vi2 = vld1_u8(i2);
+    const uint8x8_t vi3 = vld1_u8(i3);
+    const uint8x8_t vi4 = vld1_u8(i4);
+    const uint8x8_t vi5 = vld1_u8(i5);
+    const uint8x8_t vi6 = vld1_u8(i6);
+
+    const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
+    const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
+    const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
+
+    const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6);
+    const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
+
+    const int16x8_t vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345));
+    int32x4_t vacc_lo = vaddw_s16(vbias, vget_low_s16(vsum));
+    int32x4_t vacc_hi = vaddw_s16(vbias, vget_high_s16(vsum));
+
+    const int32x4_t vneg_mask_lo = vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0)));
+    const int32x4_t vneg_mask_hi = vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0)));
+
+#if defined(__aarch64__)
+    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier));
+    const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier);
+    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier));
+    const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier);
+
+    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+    const int64x2_t vadjusted_product23 = vaddw_high_s32(vproduct23, vneg_mask_lo);
+    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+    const int64x2_t vadjusted_product67 = vaddw_high_s32(vproduct67, vneg_mask_hi);
+#else
+    const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier);
+    const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier);
+    const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier);
+    const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier);
+
+    const int64x2_t vadjusted_product01 = vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo));
+    const int64x2_t vadjusted_product23 = vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo));
+    const int64x2_t vadjusted_product45 = vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi));
+    const int64x2_t vadjusted_product67 = vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi));
+#endif
+
+    const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift);
+    const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift);
+    const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift);
+    const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift);
+
+#ifdef __aarch64__
+    vacc_lo = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc01), vreinterpretq_s32_s64(vscaled_acc23));
+    vacc_hi = vuzp1q_s32(vreinterpretq_s32_s64(vscaled_acc45), vreinterpretq_s32_s64(vscaled_acc67));
+
+    const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
+#else
+    vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23));
+    vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67));
+
+    const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), voutput_zero_point);
+#endif
+
+    uint8x8_t vout = vqmovun_s16(vacc);
+    vout = vmax_u8(vout, voutput_min);
+    vout = vmin_u8(vout, voutput_max);
+
+    if (n & 4) {
+      vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); output += 4;
+      vout = vext_u8(vout, vout, 4);
+    }
+    if (n & 2) {
+      vst1_lane_u16(__builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); output += 2;
+      vout = vext_u8(vout, vout, 2);
+    }
+    if (n & 1) {
+      vst1_lane_u8(output, vout, 0);
+    }
+  }
+}
diff --git a/src/q8-gavgpool/up7-scalar.c b/src/q8-gavgpool/up7-scalar.c
new file mode 100644
index 0000000..b437ef9
--- /dev/null
+++ b/src/q8-gavgpool/up7-scalar.c
@@ -0,0 +1,88 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/scalar-utils.h>
+#include <xnnpack/gavgpool.h>
+
+
+// Unipass Q8 global average pooling micro-kernel (scalar reference), for
+// 1 <= m <= 7 rows of n channels each. Rows beyond m are redirected to the
+// `zero` buffer so all seven row pointers are always dereferenceable.
+// Per channel: sum the 7 row bytes, add bias, requantize via a 64-bit
+// fixed-point multiply with rounding shift, clamp, add output zero point.
+void xnn_q8_gavgpool_ukernel_up7__scalar(
+    size_t m,
+    size_t n,
+    const uint8_t* input,
+    size_t input_stride,
+    const uint8_t* zero,
+    uint8_t* output,
+    const union xnn_q8_avgpool_params params[restrict static 1])
+{
+  assert(m != 0);
+  assert(m <= 7);
+  assert(n != 0);
+
+  // Point unused rows at `zero`: i1 for m < 2, i2 for m <= 2, etc., so the
+  // inner loop can unconditionally read 7 rows.
+  const uint8_t* i0 = input;
+  const uint8_t* i1 = (const uint8_t*) ((uintptr_t) i0 + input_stride);
+  if (m < 2) {
+    i1 = zero;
+  }
+  const uint8_t* i2 = (const uint8_t*) ((uintptr_t) i1 + input_stride);
+  if (m <= 2) {
+    i2 = zero;
+  }
+  const uint8_t* i3 = (const uint8_t*) ((uintptr_t) i2 + input_stride);
+  if (m < 4) {
+    i3 = zero;
+  }
+  const uint8_t* i4 = (const uint8_t*) ((uintptr_t) i3 + input_stride);
+  if (m <= 4) {
+    i4 = zero;
+  }
+  const uint8_t* i5 = (const uint8_t*) ((uintptr_t) i4 + input_stride);
+  if (m < 6) {
+    i5 = zero;
+  }
+  const uint8_t* i6 = (const uint8_t*) ((uintptr_t) i5 + input_stride);
+  if (m <= 6) {
+    i6 = zero;
+  }
+
+  // Precomputed requantization parameters (fixed-point multiplier + shift).
+  const int32_t vbias = params->scalar.bias;
+  const int32_t vmultiplier = params->scalar.multiplier;
+  const int64_t vrounding = params->scalar.rounding;
+  const uint32_t vshift = params->scalar.right_shift;
+  const int32_t voutput_min = params->scalar.output_min_less_zero_point;
+  const int32_t voutput_max = params->scalar.output_max_less_zero_point;
+  const int32_t voutput_zero_point = params->scalar.output_zero_point;
+  do {
+    const uint32_t vi0 = (uint32_t) *i0++;
+    const uint32_t vi1 = (uint32_t) *i1++;
+    const uint32_t vi2 = (uint32_t) *i2++;
+    const uint32_t vi3 = (uint32_t) *i3++;
+    const uint32_t vi4 = (uint32_t) *i4++;
+    const uint32_t vi5 = (uint32_t) *i5++;
+    const uint32_t vi6 = (uint32_t) *i6++;
+
+    // Balanced addition tree for the 7 inputs.
+    const uint32_t vsum01 = vi0 + vi1;
+    const uint32_t vsum23 = vi2 + vi3;
+    const uint32_t vsum45 = vi4 + vi5;
+
+    const uint32_t vsum016 = vsum01 + vi6;
+    const uint32_t vsum2345 = vsum23 + vsum45;
+
+    const uint32_t vsum = vsum016 + vsum2345;
+    const int32_t vacc = vbias + (int32_t) vsum;
+
+    // Requantize: scale by multiplier, subtract 1 for negative accumulators
+    // so the rounding arithmetic shift rounds ties away from zero, then
+    // clamp before adding the output zero point.
+    const int64_t vproduct = (int64_t) vacc * (int64_t) vmultiplier;
+    const int64_t vadjusted_product = vproduct - (int64_t) (vacc < 0);
+    int32_t vout = (int32_t) asr_s64(vadjusted_product + vrounding, vshift);
+    vout = vout < voutput_min ? voutput_min : vout;
+    vout = vout > voutput_max ? voutput_max : vout;
+    vout += voutput_zero_point;
+
+    *output++ = (uint8_t) vout;
+  } while (--n != 0);
+}
diff --git a/src/q8-gavgpool/up7-sse2.c b/src/q8-gavgpool/up7-sse2.c
new file mode 100644
index 0000000..6d8c1bc
--- /dev/null
+++ b/src/q8-gavgpool/up7-sse2.c
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/gavgpool.h>
+
+
+// Unipass Q8 global average pooling micro-kernel (SSE2), for 1 <= m <= 7
+// rows. Absent rows are redirected to `zero`. Since SSE2 lacks a 64-bit
+// signed multiply, requantization multiplies absolute values with
+// _mm_mul_epu32 and restores the sign afterwards.
+void xnn_q8_gavgpool_ukernel_up7__sse2(
+    size_t m,
+    size_t n,
+    const uint8_t* input,
+    size_t input_stride,
+    const uint8_t* zero,
+    uint8_t* output,
+    const union xnn_q8_avgpool_params params[restrict static 1])
+{
+  assert(m != 0);
+  assert(m <= 7);
+  assert(n != 0);
+
+  const uint8_t* i0 = input;
+  const uint8_t* i1 = (const uint8_t*) ((uintptr_t) i0 + input_stride);
+  if (m < 2) {
+    i1 = zero;
+  }
+  const uint8_t* i2 = (const uint8_t*) ((uintptr_t) i1 + input_stride);
+  if (m <= 2) {
+    i2 = zero;
+  }
+  const uint8_t* i3 = (const uint8_t*) ((uintptr_t) i2 + input_stride);
+  if (m < 4) {
+    i3 = zero;
+  }
+  const uint8_t* i4 = (const uint8_t*) ((uintptr_t) i3 + input_stride);
+  if (m <= 4) {
+    i4 = zero;
+  }
+  const uint8_t* i5 = (const uint8_t*) ((uintptr_t) i4 + input_stride);
+  if (m < 6) {
+    i5 = zero;
+  }
+  const uint8_t* i6 = (const uint8_t*) ((uintptr_t) i5 + input_stride);
+  if (m <= 6) {
+    i6 = zero;
+  }
+
+  // Fixed: was garbled as "&para;ms" (mojibake of "&params") in the original.
+  const __m128i vbias = _mm_load_si128((const __m128i*) &params->sse2.bias);
+  const __m128i vzero = _mm_setzero_si128();
+  const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->sse2.multiplier);
+  const __m128i vrounding = _mm_load_si128((const __m128i*) params->sse2.rounding);
+  const __m128i vright_shift = _mm_loadl_epi64((const __m128i*) params->sse2.right_shift);
+
+  while (n >= 8) {
+    const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0); i0 += 8;
+    const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1); i1 += 8;
+    const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2); i2 += 8;
+    const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3); i3 += 8;
+    const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4); i4 += 8;
+    const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5); i5 += 8;
+    const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6); i6 += 8;
+
+    // Zero-extend bytes to 16 bits and sum the 7 rows as a balanced tree.
+    const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+    const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+    const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+    const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+    const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+    const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+    const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+
+    const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1);
+    const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
+    const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
+
+    const __m128i vsum016 = _mm_add_epi16(vsum01, vxi6);
+    const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
+    const __m128i vsum = _mm_add_epi16(vsum016, vsum2345);
+
+    __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero));
+    __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero));
+
+    // Sign-separated requantization: |acc| * multiplier, rounded right shift,
+    // then re-apply the sign via XOR/SUB with the negative mask.
+    const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo);
+    const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi);
+
+    const __m128i vabs_lo0123 = _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo);
+    const __m128i vabs_hi0123 = _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi);
+
+    const __m128i vabs_lo1032 = _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i vabs_hi1032 = _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1));
+
+    const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier);
+    const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier);
+
+    const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier);
+    const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier);
+
+    const __m128i vabs_scaled_lo02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift);
+    const __m128i vabs_scaled_lo13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift);
+    const __m128i vabs_scaled_hi02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift);
+    const __m128i vabs_scaled_hi13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift);
+
+    // Re-interleave the even/odd 64-bit lanes back into 0123 order.
+    const __m128i vabs_scaled_lo0213 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_lo02), _mm_castsi128_ps(vabs_scaled_lo13), _MM_SHUFFLE(2, 0, 2, 0)));
+    const __m128i vabs_scaled_hi0213 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_hi02), _mm_castsi128_ps(vabs_scaled_hi13), _MM_SHUFFLE(2, 0, 2, 0)));
+
+    const __m128i vabs_scaled_lo = _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0));
+    const __m128i vabs_scaled_hi = _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+    const __m128i vscaled_lo = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo);
+    const __m128i vscaled_hi = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi);
+
+    __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
+    vout = _mm_adds_epi16(vout, _mm_load_si128((const __m128i*) params->sse2.output_zero_point));
+    vout = _mm_packus_epi16(vout, vout);
+    vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_max));
+    vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_min));
+
+    _mm_storel_epi64((__m128i*) output, vout); output += 8;
+
+    n -= 8;
+  }
+  if (n != 0) {
+    // Remainder: same computation for the last 1-7 channels, stored piecewise.
+    const __m128i vi0 = _mm_loadl_epi64((const __m128i*) i0);
+    const __m128i vi1 = _mm_loadl_epi64((const __m128i*) i1);
+    const __m128i vi2 = _mm_loadl_epi64((const __m128i*) i2);
+    const __m128i vi3 = _mm_loadl_epi64((const __m128i*) i3);
+    const __m128i vi4 = _mm_loadl_epi64((const __m128i*) i4);
+    const __m128i vi5 = _mm_loadl_epi64((const __m128i*) i5);
+    const __m128i vi6 = _mm_loadl_epi64((const __m128i*) i6);
+
+    const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
+    const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
+    const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
+    const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
+    const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
+    const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
+    const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
+
+    const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1);
+    const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3);
+    const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5);
+
+    const __m128i vsum016 = _mm_add_epi16(vsum01, vxi6);
+    const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45);
+    const __m128i vsum = _mm_add_epi16(vsum016, vsum2345);
+
+    __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero));
+    __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero));
+
+    const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo);
+    const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi);
+
+    const __m128i vabs_lo0123 = _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo);
+    const __m128i vabs_hi0123 = _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi);
+
+    const __m128i vabs_lo1032 = _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i vabs_hi1032 = _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1));
+
+    const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier);
+    const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier);
+
+    const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier);
+    const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier);
+
+    const __m128i vabs_scaled_lo02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift);
+    const __m128i vabs_scaled_lo13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift);
+    const __m128i vabs_scaled_hi02 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift);
+    const __m128i vabs_scaled_hi13 = _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift);
+
+    const __m128i vabs_scaled_lo0213 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_lo02), _mm_castsi128_ps(vabs_scaled_lo13), _MM_SHUFFLE(2, 0, 2, 0)));
+    const __m128i vabs_scaled_hi0213 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vabs_scaled_hi02), _mm_castsi128_ps(vabs_scaled_hi13), _MM_SHUFFLE(2, 0, 2, 0)));
+
+    const __m128i vabs_scaled_lo = _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0));
+    const __m128i vabs_scaled_hi = _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+    const __m128i vscaled_lo = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo);
+    const __m128i vscaled_hi = _mm_sub_epi32(_mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi);
+
+    __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
+    vout = _mm_adds_epi16(vout, _mm_load_si128((const __m128i*) params->sse2.output_zero_point));
+    vout = _mm_packus_epi16(vout, vout);
+    vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_max));
+    vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_min));
+
+    if (n & 4) {
+      *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout);
+      output += 4;
+      vout = _mm_srli_epi64(vout, 32);
+    }
+    if (n & 2) {
+      *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout, 0);
+      output += 2;
+      vout = _mm_srli_epi32(vout, 16);
+    }
+    if (n & 1) {
+      *((uint8_t*) output) = (uint8_t) _mm_cvtsi128_si32(vout);
+    }
+  }
+}
diff --git a/src/q8-gemm/2x2-scalar.c b/src/q8-gemm/2x2-scalar.c
new file mode 100644
index 0000000..4c7f892
--- /dev/null
+++ b/src/q8-gemm/2x2-scalar.c
@@ -0,0 +1,132 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/scalar-utils.h>
+#include <xnnpack/gemm.h>
+
+
+// Q8 (asymmetrically quantized uint8) GEMM microkernel: computes a 2x2 tile of
+// C = A*B in int32 and requantizes it to uint8 using a Q31 fixed-point
+// multiplier followed by a rounding arithmetic right shift (the remainder
+// comparison implements round-to-nearest with ties rounded away from zero).
+//
+// mr: rows of the tile actually in use (1 or 2). nc: output columns remaining.
+// kc: reduction depth in elements. a/a_stride: input matrix and its row stride
+// in bytes. w: packed weights - for each 2-column group, two int32 biases
+// followed by kc interleaved pairs of uint8 kernel bytes. cm_stride: byte
+// stride between rows of C; cn_stride: byte stride to the next 2-column group.
+void xnn_q8_gemm_ukernel_2x2__scalar(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const uint8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_q8_gemm_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 2);
+ assert(nc != 0);
+ assert(kc != 0);
+
+ // Row pointers. For mr == 1, row 1 aliases row 0 so the unrolled code below
+ // simply recomputes (and harmlessly overwrites) a duplicate of row 0.
+ const uint8_t* a0 = a;
+ uint8_t* c0 = c;
+ const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
+ uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+ if (mr != 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+
+ const int32_t vb_zero_point = params->scalar.kernel_zero_point;
+
+ do {
+ // Seed the accumulators with the per-column int32 bias stored in w.
+ int32_t vacc0x0 = ((const int32_t*) w)[0];
+ int32_t vacc0x1 = ((const int32_t*) w)[1];
+ int32_t vacc1x0 = vacc0x0;
+ int32_t vacc1x1 = vacc0x1;
+ w = (const void*) ((uintptr_t) w + 2 * sizeof(int32_t));
+
+ size_t k = kc;
+ do {
+ const int32_t va0 = (int32_t) (uint32_t) *a0++;
+ const int32_t va1 = (int32_t) (uint32_t) *a1++;
+
+ const uint32_t vb0 = ((const uint8_t*) w)[0];
+ const uint32_t vb1 = ((const uint8_t*) w)[1];
+ w = (const void*) ((uintptr_t) w + 2 * sizeof(uint8_t));
+
+ // Only the kernel zero point is subtracted here; presumably the input
+ // zero-point contribution is pre-folded into the packed bias by the
+ // weight-packing routine -- TODO(review): confirm against the packer.
+ const int32_t vxb0 = (int32_t) vb0 - vb_zero_point;
+ const int32_t vxb1 = (int32_t) vb1 - vb_zero_point;
+
+ vacc0x0 += va0 * vxb0;
+ vacc0x1 += va0 * vxb1;
+ vacc1x0 += va1 * vxb0;
+ vacc1x1 += va1 * vxb1;
+
+ k -= sizeof(uint8_t);
+ } while (k != 0);
+
+ // Requantization, step 1: widen to 64 bits and apply the Q31 multiplier,
+ // then take the high 31 bits with round-to-nearest (0x40000000 rounding).
+ const int32_t vmultiplier = params->scalar.multiplier;
+ const int64_t vproduct0x0 = (int64_t) vacc0x0 * (int64_t) vmultiplier;
+ const int64_t vproduct0x1 = (int64_t) vacc0x1 * (int64_t) vmultiplier;
+ const int64_t vproduct1x0 = (int64_t) vacc1x0 * (int64_t) vmultiplier;
+ const int64_t vproduct1x1 = (int64_t) vacc1x1 * (int64_t) vmultiplier;
+
+ const int64_t vq31rounding = INT64_C(0x40000000);
+ const int32_t vq31product0x0 = (int32_t) (uint32_t) ((uint64_t) (vproduct0x0 + vq31rounding) >> 31);
+ const int32_t vq31product0x1 = (int32_t) (uint32_t) ((uint64_t) (vproduct0x1 + vq31rounding) >> 31);
+ const int32_t vq31product1x0 = (int32_t) (uint32_t) ((uint64_t) (vproduct1x0 + vq31rounding) >> 31);
+ const int32_t vq31product1x1 = (int32_t) (uint32_t) ((uint64_t) (vproduct1x1 + vq31rounding) >> 31);
+
+ // Step 2: rounding arithmetic right shift. The remainder (biased by -1 for
+ // negative values) is compared against the threshold to decide whether to
+ // round the truncated quotient up by one.
+ const int32_t vremainder_mask = params->scalar.remainder_mask;
+ const int32_t vremainder0x0 = (vq31product0x0 & vremainder_mask) - (int32_t) (vq31product0x0 < 0);
+ const int32_t vremainder0x1 = (vq31product0x1 & vremainder_mask) - (int32_t) (vq31product0x1 < 0);
+ const int32_t vremainder1x0 = (vq31product1x0 & vremainder_mask) - (int32_t) (vq31product1x0 < 0);
+ const int32_t vremainder1x1 = (vq31product1x1 & vremainder_mask) - (int32_t) (vq31product1x1 < 0);
+
+ const uint32_t vshift = params->scalar.shift;
+ const int32_t vremainder_threshold = params->scalar.remainder_threshold;
+ int32_t vout0x0 = asr_s32(vq31product0x0, vshift) + (int32_t) (vremainder0x0 > vremainder_threshold);
+ int32_t vout0x1 = asr_s32(vq31product0x1, vshift) + (int32_t) (vremainder0x1 > vremainder_threshold);
+ int32_t vout1x0 = asr_s32(vq31product1x0, vshift) + (int32_t) (vremainder1x0 > vremainder_threshold);
+ int32_t vout1x1 = asr_s32(vq31product1x1, vshift) + (int32_t) (vremainder1x1 > vremainder_threshold);
+
+ // Clamp to the output range; min/max are expressed relative to the output
+ // zero point so clamping happens before the zero point is added back.
+ const int32_t vout_min = params->scalar.output_min_less_zero_point;
+ vout0x0 = vout0x0 < vout_min ? vout_min : vout0x0;
+ vout0x1 = vout0x1 < vout_min ? vout_min : vout0x1;
+ vout1x0 = vout1x0 < vout_min ? vout_min : vout1x0;
+ vout1x1 = vout1x1 < vout_min ? vout_min : vout1x1;
+
+ const int32_t vout_max = params->scalar.output_max_less_zero_point;
+ vout0x0 = vout0x0 > vout_max ? vout_max : vout0x0;
+ vout0x1 = vout0x1 > vout_max ? vout_max : vout0x1;
+ vout1x0 = vout1x0 > vout_max ? vout_max : vout1x0;
+ vout1x1 = vout1x1 > vout_max ? vout_max : vout1x1;
+
+ const int32_t voutput_zero_point = params->scalar.output_zero_point;
+ vout0x0 += voutput_zero_point;
+ vout0x1 += voutput_zero_point;
+ vout1x0 += voutput_zero_point;
+ vout1x1 += voutput_zero_point;
+
+ if XNN_LIKELY(nc >= 2) {
+ // Full 2-column group: store both columns, rewind A to the row start,
+ // and advance C to the next column group.
+ c0[0] = (uint8_t) vout0x0;
+ c0[1] = (uint8_t) vout0x1;
+ c1[0] = (uint8_t) vout1x0;
+ c1[1] = (uint8_t) vout1x1;
+
+ a0 = (const uint8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const uint8_t*) ((uintptr_t) a1 - kc);
+
+ c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
+ c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
+
+ nc -= 2;
+ } else {
+ // Final partial group: nc == 1, store only column 0 of each row.
+ c0[0] = (uint8_t) vout0x0;
+ c1[0] = (uint8_t) vout1x0;
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/q8-gemm/2x4c8-sse2.c b/src/q8-gemm/2x4c8-sse2.c
new file mode 100644
index 0000000..82e2053
--- /dev/null
+++ b/src/q8-gemm/2x4c8-sse2.c
@@ -0,0 +1,206 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/math.h>
+
+
+// Reduces four int32x4 vectors to one int32x4 of their horizontal sums:
+// result = { sum(x), sum(y), sum(z), sum(w) }, one lane per input vector.
+// The SSSE3 hadd path is disabled on Android -- NOTE(review): presumably to
+// keep a single SSE2-only x86 binary there; confirm the reason for the guard.
+static inline __m128i sse_reduce4_i32(__m128i x, __m128i y, __m128i z, __m128i w) {
+#if defined(__SSSE3__) && !defined(__ANDROID__)
+ /* xxyy = ( y2 + y3, y0 + y1, x2 + x3, x0 + x1 ) */
+ const __m128i xxyy = _mm_hadd_epi32(x, y);
+ /* zzww = ( w2 + w3, w0 + w1, z2 + z3, z0 + z1 ) */
+ const __m128i zzww = _mm_hadd_epi32(z, w);
+ /* xyzw = ( w0 + w1 + w2 + w3, y0 + y1 + y2 + y3, z0 + z1 + z2 + z3, x0 + x1 +
+ * x2 + x3 ) */
+ return _mm_hadd_epi32(xxyy, zzww);
+#else
+ /* SSE2-only fallback: same sums built from unpack/add pairs. */
+ /* xzxz = ( z1 + z3, x1 + x3, z0 + z2, x0 + x2 ) */
+ const __m128i xzxz =
+ _mm_add_epi32(_mm_unpacklo_epi32(x, z), _mm_unpackhi_epi32(x, z));
+ /* ywyw = ( w1 + w3, y1 + y3, w0 + w2, y0 + y2 ) */
+ const __m128i ywyw =
+ _mm_add_epi32(_mm_unpacklo_epi32(y, w), _mm_unpackhi_epi32(y, w));
+ /* xyzw = ( w0 + w2 + w1 + w3, y0 + y2 + y1 + y3, z0 + z2 + z1 + z3, x0 + x2 +
+ * x1 + x3 ) */
+ return _mm_add_epi32(
+ _mm_unpacklo_epi32(xzxz, ywyw), _mm_unpackhi_epi32(xzxz, ywyw));
+#endif
+}
+
+// Q8 GEMM microkernel, 2x4 output tile, SSE2, "c8" layout: weights are packed
+// in channel groups of 8 so each PMADDWD consumes 8 reduction elements per
+// output column; the four per-column accumulators are then horizontally
+// reduced with sse_reduce4_i32. Requantization matches the scalar kernel:
+// Q31 multiply with round-to-nearest, then rounding arithmetic right shift.
+// SSE2 has no 64-bit signed multiply, so products are formed on absolute
+// values (sign-magnitude) and the sign is restored with xor/sub masks.
+void xnn_q8_gemm_ukernel_2x4c8__sse2(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const uint8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_q8_gemm_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 2);
+ assert(nc != 0);
+ assert(kc != 0);
+
+ // Row pointers; row 1 aliases row 0 when mr == 1.
+ const uint8_t* a0 = a;
+ uint8_t* c0 = c;
+ const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
+ uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+ if (mr != 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+
+ // The k loop below always consumes full 8-byte groups, so A advances by
+ // kc rounded up to 8 and must be rewound by the same amount. This assumes
+ // A rows are readable (padded) up to kc_stride bytes -- packing contract.
+ const size_t kc_stride = round_up_po2(kc, 8);
+ const __m128i vb_zero_point = _mm_load_si128((const __m128i*) params->sse2.kernel_zero_point);
+
+ do {
+ // Seed one accumulator per output column with its int32 bias; each holds
+ // 4 partial sums that are reduced to a single lane after the k loop.
+ __m128i vacc00 = _mm_cvtsi32_si128((int) ((const int32_t*) w)[0]);
+ __m128i vacc01 = _mm_cvtsi32_si128((int) ((const int32_t*) w)[1]);
+ __m128i vacc02 = _mm_cvtsi32_si128((int) ((const int32_t*) w)[2]);
+ __m128i vacc03 = _mm_cvtsi32_si128((int) ((const int32_t*) w)[3]);
+ __m128i vacc10 = vacc00;
+ __m128i vacc11 = vacc01;
+ __m128i vacc12 = vacc02;
+ __m128i vacc13 = vacc03;
+ w = (const void*) ((uintptr_t) w + 16);
+
+ const __m128i vzero = _mm_setzero_si128();
+ for (size_t k = 0; k < kc; k += 8 * sizeof(uint8_t)) {
+ // Load 8 input bytes per row and zero-extend to int16.
+ const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
+ const __m128i vxa0 = _mm_unpacklo_epi8(va0, vzero);
+ a0 += 8;
+ const __m128i va1 = _mm_loadl_epi64((const __m128i*) a1);
+ const __m128i vxa1 = _mm_unpacklo_epi8(va1, vzero);
+ a1 += 8;
+
+ // Load 8 kernel bytes per output column, widen, subtract zero point.
+ const __m128i vb0 = _mm_loadl_epi64((const __m128i*) w);
+ const __m128i vxb0 = _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);
+ const __m128i vb1 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 8));
+ const __m128i vxb1 = _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);
+ const __m128i vb2 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16));
+ const __m128i vxb2 = _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);
+ const __m128i vb3 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 24));
+ const __m128i vxb3 = _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
+ w = (const void*) ((uintptr_t) w + 32);
+
+ // PMADDWD: 8 int16 products pairwise-added into 4 int32 partial sums.
+ vacc00 = _mm_add_epi32(vacc00, _mm_madd_epi16(vxa0, vxb0));
+ vacc01 = _mm_add_epi32(vacc01, _mm_madd_epi16(vxa0, vxb1));
+ vacc02 = _mm_add_epi32(vacc02, _mm_madd_epi16(vxa0, vxb2));
+ vacc03 = _mm_add_epi32(vacc03, _mm_madd_epi16(vxa0, vxb3));
+ vacc10 = _mm_add_epi32(vacc10, _mm_madd_epi16(vxa1, vxb0));
+ vacc11 = _mm_add_epi32(vacc11, _mm_madd_epi16(vxa1, vxb1));
+ vacc12 = _mm_add_epi32(vacc12, _mm_madd_epi16(vxa1, vxb2));
+ vacc13 = _mm_add_epi32(vacc13, _mm_madd_epi16(vxa1, vxb3));
+ }
+
+ // Collapse each column's 4 partial sums into one lane per column.
+ __m128i vacc0x0123 = sse_reduce4_i32(vacc00, vacc01, vacc02, vacc03);
+ __m128i vacc1x0123 = sse_reduce4_i32(vacc10, vacc11, vacc12, vacc13);
+
+ const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->sse2.multiplier);
+ const __m128i vrounding = _mm_load_si128((const __m128i*) params->sse2.rounding);
+
+ // Sign-magnitude Q31 multiply: take |acc|, multiply even/odd lanes as
+ // unsigned 64-bit (PMULUDQ), restore sign, add rounding, shift by 31.
+ const __m128i vnmask0x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc0x0123);
+ const __m128i vnmask1x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc1x0123);
+
+ const __m128i vabsacc0x0123 = _mm_sub_epi32(_mm_xor_si128(vacc0x0123, vnmask0x0123), vnmask0x0123);
+ const __m128i vabsacc1x0123 = _mm_sub_epi32(_mm_xor_si128(vacc1x0123, vnmask1x0123), vnmask1x0123);
+
+ const __m128i vabsacc0x1032 = _mm_shuffle_epi32(vabsacc0x0123, _MM_SHUFFLE(2, 3, 0, 1));
+ const __m128i vabsacc1x1032 = _mm_shuffle_epi32(vabsacc1x0123, _MM_SHUFFLE(2, 3, 0, 1));
+
+ const __m128i vabsprod0x02 = _mm_mul_epu32(vabsacc0x0123, vmultiplier);
+ const __m128i vabsprod1x02 = _mm_mul_epu32(vabsacc1x0123, vmultiplier);
+
+ const __m128i vnmask0x02 = _mm_shuffle_epi32(vnmask0x0123, _MM_SHUFFLE(2, 2, 0, 0));
+ const __m128i vnmask1x02 = _mm_shuffle_epi32(vnmask1x0123, _MM_SHUFFLE(2, 2, 0, 0))
+
+ const __m128i vprod0x02 = _mm_sub_epi64(_mm_xor_si128(vabsprod0x02, vnmask0x02), vnmask0x02);
+ const __m128i vprod1x02 = _mm_sub_epi64(_mm_xor_si128(vabsprod1x02, vnmask1x02), vnmask1x02);
+
+ const __m128i vq31prod0x02 = _mm_srli_epi64(_mm_add_epi64(vprod0x02, vrounding), 31);
+ const __m128i vq31prod1x02 = _mm_srli_epi64(_mm_add_epi64(vprod1x02, vrounding), 31);
+
+ const __m128i vabsprod0x13 = _mm_mul_epu32(vabsacc0x1032, vmultiplier);
+ const __m128i vabsprod1x13 = _mm_mul_epu32(vabsacc1x1032, vmultiplier);
+
+ const __m128i vnmask0x13 = _mm_shuffle_epi32(vnmask0x0123, _MM_SHUFFLE(3, 3, 1, 1));
+ const __m128i vnmask1x13 = _mm_shuffle_epi32(vnmask1x0123, _MM_SHUFFLE(3, 3, 1, 1));
+
+ const __m128i vprod0x13 = _mm_sub_epi64(_mm_xor_si128(vabsprod0x13, vnmask0x13), vnmask0x13);
+ const __m128i vprod1x13 = _mm_sub_epi64(_mm_xor_si128(vabsprod1x13, vnmask1x13), vnmask1x13);
+
+ const __m128i vq31prod0x13 = _mm_srli_epi64(_mm_add_epi64(vprod0x13, vrounding), 31);
+ const __m128i vq31prod1x13 = _mm_srli_epi64(_mm_add_epi64(vprod1x13, vrounding), 31);
+
+ // Re-interleave the even (02) and odd (13) lane results back to 0123.
+ const __m128i vq31prod0x0213 = _mm_castps_si128(_mm_shuffle_ps(
+ _mm_castsi128_ps(vq31prod0x02), _mm_castsi128_ps(vq31prod0x13), _MM_SHUFFLE(2, 0, 2, 0)));
+ const __m128i vq31prod1x0213 = _mm_castps_si128(_mm_shuffle_ps(
+ _mm_castsi128_ps(vq31prod1x02), _mm_castsi128_ps(vq31prod1x13), _MM_SHUFFLE(2, 0, 2, 0)));
+
+ const __m128i vq31prod0x0123 = _mm_shuffle_epi32(vq31prod0x0213, _MM_SHUFFLE(3, 1, 2, 0));
+ const __m128i vq31prod1x0123 = _mm_shuffle_epi32(vq31prod1x0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+ // Rounding arithmetic right shift, vector form of the scalar kernel's
+ // remainder test (cmpgt yields -1, hence add/sub instead of +1).
+ const __m128i vremainder_mask = _mm_load_si128((const __m128i*) params->sse2.remainder_mask);
+
+ const __m128i vrem0x0123 =
+ _mm_add_epi32(_mm_and_si128(vq31prod0x0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod0x0123));
+ const __m128i vrem1x0123 =
+ _mm_add_epi32(_mm_and_si128(vq31prod1x0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod1x0123));
+
+ const __m128i vremainder_threshold = _mm_load_si128((const __m128i*) params->sse2.remainder_threshold);
+ const __m128i vshift = _mm_load_si128((const __m128i*) params->sse2.shift);
+
+ vacc0x0123 =
+ _mm_sub_epi32(_mm_sra_epi32(vq31prod0x0123, vshift), _mm_cmpgt_epi32(vrem0x0123, vremainder_threshold));
+ vacc1x0123 =
+ _mm_sub_epi32(_mm_sra_epi32(vq31prod1x0123, vshift), _mm_cmpgt_epi32(vrem1x0123, vremainder_threshold));
+
+ // Add output zero point with signed saturation, pack to uint8 with
+ // unsigned saturation, then clamp to [output_min, output_max].
+ const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
+ const __m128i vacc01x0123 = _mm_adds_epi16(_mm_packs_epi32(vacc0x0123, vacc1x0123), voutput_zero_point);
+ __m128i vout = _mm_packus_epi16(vacc01x0123, vacc01x0123);
+ vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_max));
+ vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_min));
+
+ if (nc >= 4) {
+ // vout holds row 0 in bytes 0-3 and row 1 in bytes 4-7.
+ *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
+ *((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(_mm_srli_epi64(vout, 32));
+
+ // Rewind by kc_stride (not kc): the k loop consumed full 8-byte groups.
+ a0 = (const uint8_t*) ((uintptr_t) a0 - kc_stride);
+ a1 = (const uint8_t*) ((uintptr_t) a1 - kc_stride);
+
+ c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
+ c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
+
+ nc -= 4;
+ } else {
+ // Final partial group: peel off 2 columns, then 1, shifting vout so the
+ // next store always reads the low lanes.
+ if (nc & 2) {
+ *((uint16_t*) c0) = (uint16_t) _mm_extract_epi16(vout, 0);
+ c0 += 2;
+ *((uint16_t*) c1) = (uint16_t) _mm_extract_epi16(vout, 2);
+ c1 += 2;
+ vout = _mm_srli_epi32(vout, 16);
+ }
+ if (nc & 1) {
+ *((uint8_t*) c0) = (uint8_t) _mm_cvtsi128_si32(vout);
+ *((uint8_t*) c1) = (uint8_t) _mm_extract_epi16(vout, 2);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/q8-gemm/4x4c2-sse2.c b/src/q8-gemm/4x4c2-sse2.c
new file mode 100644
index 0000000..714280b
--- /dev/null
+++ b/src/q8-gemm/4x4c2-sse2.c
@@ -0,0 +1,346 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/gemm.h>
+
+
+// Q8 GEMM microkernel, 4x4 output tile, SSE2, "c2" layout: weights are packed
+// in channel groups of 2, so one PMADDWD of a broadcast input pair against an
+// 8x int16 weight vector updates all 4 output columns at once. The main loop
+// processes k in steps of 8 (4 pairs); a guarded remainder ladder handles
+// k in {1..7}. Requantization is identical to the 2x4c8 SSE2 kernel:
+// sign-magnitude Q31 multiply (PMULUDQ), rounding shift by 31, then a
+// remainder-based rounding arithmetic right shift.
+void xnn_q8_gemm_ukernel_4x4c2__sse2(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const uint8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_q8_gemm_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+
+ // Row pointers; each unused row aliases the previous one so the fully
+ // unrolled 4-row body stays branch-free regardless of mr.
+ const uint8_t* a0 = a;
+ uint8_t* c0 = c;
+ const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
+ uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const uint8_t* a2 = (const uint8_t*) ((uintptr_t) a1 + a_stride);
+ uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const uint8_t* a3 = (const uint8_t*) ((uintptr_t) a2 + a_stride);
+ uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + cm_stride);
+ if (mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ const __m128i vb_zero_point = _mm_load_si128((const __m128i*) params->sse2.kernel_zero_point);
+
+ do {
+ // One accumulator vector per row, seeded with the 4 int32 column biases.
+ __m128i vacc0x0123 = _mm_loadu_si128((const __m128i*) w);
+ __m128i vacc1x0123 = vacc0x0123;
+ __m128i vacc2x0123 = vacc0x0123;
+ __m128i vacc3x0123 = vacc0x0123;
+ w = (const void*) ((uintptr_t) w + 16);
+
+ const __m128i vzero = _mm_setzero_si128();
+ size_t k = kc;
+ while (k >= 8 * sizeof(uint8_t)) {
+ // Load 8 input bytes per row, zero-extended to 8x int16.
+ const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
+ const __m128i vxa0 = _mm_unpacklo_epi8(va0, vzero);
+ a0 += 8;
+ const __m128i va1 = _mm_loadl_epi64((const __m128i*) a1);
+ const __m128i vxa1 = _mm_unpacklo_epi8(va1, vzero);
+ a1 += 8;
+ const __m128i va2 = _mm_loadl_epi64((const __m128i*) a2);
+ const __m128i vxa2 = _mm_unpacklo_epi8(va2, vzero);
+ a2 += 8;
+ const __m128i va3 = _mm_loadl_epi64((const __m128i*) a3);
+ const __m128i vxa3 = _mm_unpacklo_epi8(va3, vzero);
+ a3 += 8;
+
+ // For each of the 4 input pairs: broadcast the pair to all lanes with
+ // a dword shuffle, PMADDWD against the 4-column x 2-channel weight
+ // vector, and accumulate.
+ const __m128i vb0 = _mm_loadl_epi64((const __m128i*) w);
+ const __m128i vxb0 = _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);
+
+ vacc0x0123 = _mm_add_epi32(vacc0x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+ vacc1x0123 = _mm_add_epi32(vacc1x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+ vacc2x0123 = _mm_add_epi32(vacc2x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+ vacc3x0123 = _mm_add_epi32(vacc3x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+
+ const __m128i vb1 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 8));
+ const __m128i vxb1 = _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);
+
+ vacc0x0123 = _mm_add_epi32(vacc0x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+ vacc1x0123 = _mm_add_epi32(vacc1x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+ vacc2x0123 = _mm_add_epi32(vacc2x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+ vacc3x0123 = _mm_add_epi32(vacc3x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+
+ const __m128i vb2 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16));
+ const __m128i vxb2 = _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);
+
+ vacc0x0123 = _mm_add_epi32(vacc0x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+ vacc1x0123 = _mm_add_epi32(vacc1x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+ vacc2x0123 = _mm_add_epi32(vacc2x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+ vacc3x0123 = _mm_add_epi32(vacc3x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+
+ const __m128i vb3 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 24));
+ const __m128i vxb3 = _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
+ w = (const void*) ((uintptr_t) w + 32);
+
+ vacc0x0123 = _mm_add_epi32(vacc0x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+ vacc1x0123 = _mm_add_epi32(vacc1x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+ vacc2x0123 = _mm_add_epi32(vacc2x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+ vacc3x0123 = _mm_add_epi32(vacc3x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+
+ k -= 8 * sizeof(uint8_t);
+ }
+ if (k != 0) {
+ // Remainder k in 1..7: an 8-byte load is still issued (assumes A is
+ // readable past the row end -- packing/padding contract), but A only
+ // advances by k, and weight pairs beyond k are skipped by the nested
+ // guards below. NOTE(review): odd k relies on the packer zero-padding
+ // the trailing input/weight byte so the extra lane contributes 0.
+ const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
+ a0 = (const uint8_t*) ((uintptr_t) a0 + k);
+ const __m128i vxa0 = _mm_unpacklo_epi8(va0, vzero);
+ const __m128i va1 = _mm_loadl_epi64((const __m128i*) a1);
+ a1 = (const uint8_t*) ((uintptr_t) a1 + k);
+ const __m128i vxa1 = _mm_unpacklo_epi8(va1, vzero);
+ const __m128i va2 = _mm_loadl_epi64((const __m128i*) a2);
+ a2 = (const uint8_t*) ((uintptr_t) a2 + k);
+ const __m128i vxa2 = _mm_unpacklo_epi8(va2, vzero);
+ const __m128i va3 = _mm_loadl_epi64((const __m128i*) a3);
+ a3 = (const uint8_t*) ((uintptr_t) a3 + k);
+ const __m128i vxa3 = _mm_unpacklo_epi8(va3, vzero);
+
+ const __m128i vb0 = _mm_loadl_epi64((const __m128i*) w);
+ w = (const void*) ((uintptr_t) w + 8);
+ const __m128i vxb0 = _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);
+
+ vacc0x0123 = _mm_add_epi32(vacc0x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+ vacc1x0123 = _mm_add_epi32(vacc1x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+ vacc2x0123 = _mm_add_epi32(vacc2x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+ vacc3x0123 = _mm_add_epi32(vacc3x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+
+ if (k > 2 * sizeof(uint8_t)) {
+ const __m128i vb1 = _mm_loadl_epi64((const __m128i*) w);
+ w = (const void*) ((uintptr_t) w + 8);
+ const __m128i vxb1 = _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);
+
+ vacc0x0123 = _mm_add_epi32(vacc0x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+ vacc1x0123 = _mm_add_epi32(vacc1x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+ vacc2x0123 = _mm_add_epi32(vacc2x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+ vacc3x0123 = _mm_add_epi32(vacc3x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+
+ if (k > 4 * sizeof(uint8_t)) {
+ const __m128i vb2 = _mm_loadl_epi64((const __m128i*) w);
+ w = (const void*) ((uintptr_t) w + 8);
+ const __m128i vxb2 = _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);
+
+ vacc0x0123 = _mm_add_epi32(vacc0x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+ vacc1x0123 = _mm_add_epi32(vacc1x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+ vacc2x0123 = _mm_add_epi32(vacc2x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+ vacc3x0123 = _mm_add_epi32(vacc3x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+
+ if (k > 6 * sizeof(uint8_t)) {
+ const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
+ w = (const void*) ((uintptr_t) w + 8);
+ const __m128i vxb3 = _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
+
+ vacc0x0123 = _mm_add_epi32(vacc0x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+ vacc1x0123 = _mm_add_epi32(vacc1x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+ vacc2x0123 = _mm_add_epi32(vacc2x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+ vacc3x0123 = _mm_add_epi32(vacc3x0123,
+ _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+ }
+ }
+ }
+ }
+
+ // Requantization: sign-magnitude Q31 multiply via PMULUDQ on even (02)
+ // and odd (13) lanes, round-to-nearest shift by 31, lane re-interleave,
+ // then a remainder-based rounding arithmetic right shift. Same scheme as
+ // the other SSE2 Q8 kernels.
+ const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->sse2.multiplier);
+ const __m128i vrounding = _mm_load_si128((const __m128i*) params->sse2.rounding);
+
+ const __m128i vnmask0x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc0x0123);
+ const __m128i vnmask1x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc1x0123);
+ const __m128i vnmask2x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc2x0123);
+ const __m128i vnmask3x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc3x0123);
+
+ const __m128i vabsacc0x0123 = _mm_sub_epi32(_mm_xor_si128(vacc0x0123, vnmask0x0123), vnmask0x0123);
+ const __m128i vabsacc1x0123 = _mm_sub_epi32(_mm_xor_si128(vacc1x0123, vnmask1x0123), vnmask1x0123);
+ const __m128i vabsacc2x0123 = _mm_sub_epi32(_mm_xor_si128(vacc2x0123, vnmask2x0123), vnmask2x0123);
+ const __m128i vabsacc3x0123 = _mm_sub_epi32(_mm_xor_si128(vacc3x0123, vnmask3x0123), vnmask3x0123);
+
+ const __m128i vabsacc0x1032 = _mm_shuffle_epi32(vabsacc0x0123, _MM_SHUFFLE(2, 3, 0, 1));
+ const __m128i vabsacc1x1032 = _mm_shuffle_epi32(vabsacc1x0123, _MM_SHUFFLE(2, 3, 0, 1));
+ const __m128i vabsacc2x1032 = _mm_shuffle_epi32(vabsacc2x0123, _MM_SHUFFLE(2, 3, 0, 1));
+ const __m128i vabsacc3x1032 = _mm_shuffle_epi32(vabsacc3x0123, _MM_SHUFFLE(2, 3, 0, 1));
+
+ const __m128i vabsprod0x02 = _mm_mul_epu32(vabsacc0x0123, vmultiplier);
+ const __m128i vabsprod1x02 = _mm_mul_epu32(vabsacc1x0123, vmultiplier);
+ const __m128i vabsprod2x02 = _mm_mul_epu32(vabsacc2x0123, vmultiplier);
+ const __m128i vabsprod3x02 = _mm_mul_epu32(vabsacc3x0123, vmultiplier);
+
+ const __m128i vnmask0x02 = _mm_shuffle_epi32(vnmask0x0123, _MM_SHUFFLE(2, 2, 0, 0));
+ const __m128i vnmask1x02 = _mm_shuffle_epi32(vnmask1x0123, _MM_SHUFFLE(2, 2, 0, 0))
+ const __m128i vnmask2x02 = _mm_shuffle_epi32(vnmask2x0123, _MM_SHUFFLE(2, 2, 0, 0));
+ const __m128i vnmask3x02 = _mm_shuffle_epi32(vnmask3x0123, _MM_SHUFFLE(2, 2, 0, 0));
+
+ const __m128i vprod0x02 = _mm_sub_epi64(_mm_xor_si128(vabsprod0x02, vnmask0x02), vnmask0x02);
+ const __m128i vprod1x02 = _mm_sub_epi64(_mm_xor_si128(vabsprod1x02, vnmask1x02), vnmask1x02);
+ const __m128i vprod2x02 = _mm_sub_epi64(_mm_xor_si128(vabsprod2x02, vnmask2x02), vnmask2x02);
+ const __m128i vprod3x02 = _mm_sub_epi64(_mm_xor_si128(vabsprod3x02, vnmask3x02), vnmask3x02);
+
+ const __m128i vq31prod0x02 = _mm_srli_epi64(_mm_add_epi64(vprod0x02, vrounding), 31);
+ const __m128i vq31prod1x02 = _mm_srli_epi64(_mm_add_epi64(vprod1x02, vrounding), 31);
+ const __m128i vq31prod2x02 = _mm_srli_epi64(_mm_add_epi64(vprod2x02, vrounding), 31);
+ const __m128i vq31prod3x02 = _mm_srli_epi64(_mm_add_epi64(vprod3x02, vrounding), 31);
+
+ const __m128i vabsprod0x13 = _mm_mul_epu32(vabsacc0x1032, vmultiplier);
+ const __m128i vabsprod1x13 = _mm_mul_epu32(vabsacc1x1032, vmultiplier);
+ const __m128i vabsprod2x13 = _mm_mul_epu32(vabsacc2x1032, vmultiplier);
+ const __m128i vabsprod3x13 = _mm_mul_epu32(vabsacc3x1032, vmultiplier);
+
+ const __m128i vnmask0x13 = _mm_shuffle_epi32(vnmask0x0123, _MM_SHUFFLE(3, 3, 1, 1));
+ const __m128i vnmask1x13 = _mm_shuffle_epi32(vnmask1x0123, _MM_SHUFFLE(3, 3, 1, 1));
+ const __m128i vnmask2x13 = _mm_shuffle_epi32(vnmask2x0123, _MM_SHUFFLE(3, 3, 1, 1));
+ const __m128i vnmask3x13 = _mm_shuffle_epi32(vnmask3x0123, _MM_SHUFFLE(3, 3, 1, 1));
+
+ const __m128i vprod0x13 = _mm_sub_epi64(_mm_xor_si128(vabsprod0x13, vnmask0x13), vnmask0x13);
+ const __m128i vprod1x13 = _mm_sub_epi64(_mm_xor_si128(vabsprod1x13, vnmask1x13), vnmask1x13);
+ const __m128i vprod2x13 = _mm_sub_epi64(_mm_xor_si128(vabsprod2x13, vnmask2x13), vnmask2x13);
+ const __m128i vprod3x13 = _mm_sub_epi64(_mm_xor_si128(vabsprod3x13, vnmask3x13), vnmask3x13);
+
+ const __m128i vq31prod0x13 = _mm_srli_epi64(_mm_add_epi64(vprod0x13, vrounding), 31);
+ const __m128i vq31prod1x13 = _mm_srli_epi64(_mm_add_epi64(vprod1x13, vrounding), 31);
+ const __m128i vq31prod2x13 = _mm_srli_epi64(_mm_add_epi64(vprod2x13, vrounding), 31);
+ const __m128i vq31prod3x13 = _mm_srli_epi64(_mm_add_epi64(vprod3x13, vrounding), 31);
+
+ // Re-interleave even/odd lane results back into 0123 order.
+ const __m128i vq31prod0x0213 = _mm_castps_si128(_mm_shuffle_ps(
+ _mm_castsi128_ps(vq31prod0x02), _mm_castsi128_ps(vq31prod0x13), _MM_SHUFFLE(2, 0, 2, 0)));
+ const __m128i vq31prod1x0213 = _mm_castps_si128(_mm_shuffle_ps(
+ _mm_castsi128_ps(vq31prod1x02), _mm_castsi128_ps(vq31prod1x13), _MM_SHUFFLE(2, 0, 2, 0)));
+ const __m128i vq31prod2x0213 = _mm_castps_si128(_mm_shuffle_ps(
+ _mm_castsi128_ps(vq31prod2x02), _mm_castsi128_ps(vq31prod2x13), _MM_SHUFFLE(2, 0, 2, 0)));
+ const __m128i vq31prod3x0213 = _mm_castps_si128(_mm_shuffle_ps(
+ _mm_castsi128_ps(vq31prod3x02), _mm_castsi128_ps(vq31prod3x13), _MM_SHUFFLE(2, 0, 2, 0)));
+
+ const __m128i vq31prod0x0123 = _mm_shuffle_epi32(vq31prod0x0213, _MM_SHUFFLE(3, 1, 2, 0));
+ const __m128i vq31prod1x0123 = _mm_shuffle_epi32(vq31prod1x0213, _MM_SHUFFLE(3, 1, 2, 0));
+ const __m128i vq31prod2x0123 = _mm_shuffle_epi32(vq31prod2x0213, _MM_SHUFFLE(3, 1, 2, 0));
+ const __m128i vq31prod3x0123 = _mm_shuffle_epi32(vq31prod3x0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+ // Rounding arithmetic right shift via the remainder/threshold comparison
+ // (cmpgt yields -1, hence add/sub rather than an explicit +1).
+ const __m128i vremainder_mask = _mm_load_si128((const __m128i*) params->sse2.remainder_mask);
+
+ const __m128i vrem0x0123 =
+ _mm_add_epi32(_mm_and_si128(vq31prod0x0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod0x0123));
+ const __m128i vrem1x0123 =
+ _mm_add_epi32(_mm_and_si128(vq31prod1x0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod1x0123));
+ const __m128i vrem2x0123 =
+ _mm_add_epi32(_mm_and_si128(vq31prod2x0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod2x0123));
+ const __m128i vrem3x0123 =
+ _mm_add_epi32(_mm_and_si128(vq31prod3x0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod3x0123));
+
+ const __m128i vremainder_threshold = _mm_load_si128((const __m128i*) params->sse2.remainder_threshold);
+ const __m128i vshift = _mm_load_si128((const __m128i*) params->sse2.shift);
+
+ vacc0x0123 =
+ _mm_sub_epi32(_mm_sra_epi32(vq31prod0x0123, vshift), _mm_cmpgt_epi32(vrem0x0123, vremainder_threshold));
+ vacc1x0123 =
+ _mm_sub_epi32(_mm_sra_epi32(vq31prod1x0123, vshift), _mm_cmpgt_epi32(vrem1x0123, vremainder_threshold));
+ vacc2x0123 =
+ _mm_sub_epi32(_mm_sra_epi32(vq31prod2x0123, vshift), _mm_cmpgt_epi32(vrem2x0123, vremainder_threshold));
+ vacc3x0123 =
+ _mm_sub_epi32(_mm_sra_epi32(vq31prod3x0123, vshift), _mm_cmpgt_epi32(vrem3x0123, vremainder_threshold));
+
+ // Add output zero point with signed saturation, pack all 4 rows into one
+ // uint8 vector (row r occupies bytes 4r..4r+3), then clamp.
+ const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
+ const __m128i vacc01x0123 = _mm_adds_epi16(_mm_packs_epi32(vacc0x0123, vacc1x0123), voutput_zero_point);
+ const __m128i vacc23x0123 = _mm_adds_epi16(_mm_packs_epi32(vacc2x0123, vacc3x0123), voutput_zero_point);
+ __m128i vout = _mm_packus_epi16(vacc01x0123, vacc23x0123);
+ vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_max));
+ vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_min));
+
+ if (nc >= 4) {
+ *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
+ *((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(_mm_srli_epi64(vout, 32));
+ *((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(_mm_unpackhi_epi32(vout, vout));
+ *((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(_mm_srli_si128(vout, 12));
+
+ // Rewind A by exactly kc: the remainder path advanced a* by k only, so
+ // total advance per pass is kc (unlike the c8 kernel's kc_stride).
+ a0 = (const uint8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const uint8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const uint8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const uint8_t*) ((uintptr_t) a3 - kc);
+
+ c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
+ c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
+ c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
+ c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride);
+
+ nc -= 4;
+ } else {
+ // Final partial group: peel 2 columns, then 1, shifting vout so each
+ // row's next byte(s) sit at the extracted word positions.
+ if (nc & 2) {
+ *((uint16_t*) c0) = (uint16_t) _mm_extract_epi16(vout, 0);
+ c0 += 2;
+ *((uint16_t*) c1) = (uint16_t) _mm_extract_epi16(vout, 2);
+ c1 += 2;
+ *((uint16_t*) c2) = (uint16_t) _mm_extract_epi16(vout, 4);
+ c2 += 2;
+ *((uint16_t*) c3) = (uint16_t) _mm_extract_epi16(vout, 6);
+ c3 += 2;
+ vout = _mm_srli_epi32(vout, 16);
+ }
+ if (nc & 1) {
+ *((uint8_t*) c0) = (uint8_t) _mm_cvtsi128_si32(vout);
+ *((uint8_t*) c1) = (uint8_t) _mm_extract_epi16(vout, 2);
+ *((uint8_t*) c2) = (uint8_t) _mm_extract_epi16(vout, 4);
+ *((uint8_t*) c3) = (uint8_t) _mm_extract_epi16(vout, 6);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/q8-gemm/4x8-neon.c b/src/q8-gemm/4x8-neon.c
new file mode 100644
index 0000000..4b025ab
--- /dev/null
+++ b/src/q8-gemm/4x8-neon.c
@@ -0,0 +1,379 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_q8_gemm_ukernel_4x8__neon(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const uint8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_q8_gemm_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+
+ const uint8_t* a0 = a;
+ uint8_t* c0 = c;
+ const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
+ uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const uint8_t* a2 = (const uint8_t*) ((uintptr_t) a1 + a_stride);
+ uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const uint8_t* a3 = (const uint8_t*) ((uintptr_t) a2 + a_stride);
+ uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + cm_stride);
+ if (mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+  const uint8x8_t vb_zero_point = vld1_dup_u8((const uint8_t*) &params->neon.kernel_zero_point);
+
+ do {
+ int32x4_t vacc0x0123 = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 16);
+ int32x4_t vacc0x4567 = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 16);
+ int32x4_t vacc1x0123 = vacc0x0123;
+ int32x4_t vacc1x4567 = vacc0x4567;
+ int32x4_t vacc2x0123 = vacc0x0123;
+ int32x4_t vacc2x4567 = vacc0x4567;
+ int32x4_t vacc3x0123 = vacc0x0123;
+ int32x4_t vacc3x4567 = vacc0x4567;
+
+ size_t k = kc;
+ while (k >= 8 * sizeof(uint8_t)) {
+ const uint8x8_t va0 = vld1_u8(a0); a0 += 8;
+ const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0));
+ const uint8x8_t va1 = vld1_u8(a1); a1 += 8;
+ const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1));
+ const uint8x8_t va2 = vld1_u8(a2); a2 += 8;
+ const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2));
+ const uint8x8_t va3 = vld1_u8(a3); a3 += 8;
+ const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3));
+
+ const uint8x8_t vb01234567c0 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c0 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
+
+ const uint8x8_t vb01234567c1 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c1 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
+
+ const uint8x8_t vb01234567c2 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c2 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
+
+ const uint8x8_t vb01234567c3 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c3 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
+
+ const uint8x8_t vb01234567c4 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c4 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0);
+
+ const uint8x8_t vb01234567c5 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c5 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1);
+
+ const uint8x8_t vb01234567c6 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2);
+
+ const uint8x8_t vb01234567c7 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c7 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c7, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa0), 3);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa0), 3);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa1), 3);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa1), 3);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa2), 3);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa2), 3);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa3), 3);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa3), 3);
+
+ k -= 8 * sizeof(uint8_t);
+ }
+ if (k != 0) {
+ const uint8x8_t va0 = vld1_u8(a0); a0 = (const uint8_t*) ((uintptr_t) a0 + k);
+ const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0));
+ const uint8x8_t va1 = vld1_u8(a1); a1 = (const uint8_t*) ((uintptr_t) a1 + k);
+ const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1));
+ const uint8x8_t va2 = vld1_u8(a2); a2 = (const uint8_t*) ((uintptr_t) a2 + k);
+ const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2));
+ const uint8x8_t va3 = vld1_u8(a3); a3 = (const uint8_t*) ((uintptr_t) a3 + k);
+ const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3));
+
+ const uint8x8_t vb01234567c0 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c0 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
+
+ if (k >= 2 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567c1 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c1 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
+
+ if (k >= 3 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567c2 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c2 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
+
+ if (k >= 4 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567c3 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c3 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
+
+ if (k >= 5 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567c4 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c4 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0);
+
+ if (k >= 6 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567c5 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c5 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1);
+
+ if (k >= 7 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567c6 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+ vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+ vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier);
+ vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier);
+
+    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+ vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+ vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31);
+ vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+ vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift);
+ vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift);
+
+    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
+#ifdef __aarch64__
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);
+
+ uint8x16_t vout0x01234567_1x01234567 = vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567);
+ uint8x16_t vout2x01234567_3x01234567 = vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567);
+#else
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point);
+
+ uint8x16_t vout0x01234567_1x01234567 = vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567));
+ uint8x16_t vout2x01234567_3x01234567 = vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567));
+#endif
+    const uint8x16_t voutput_min = vld1q_dup_u8(&params->neon.output_min);
+    const uint8x16_t voutput_max = vld1q_dup_u8(&params->neon.output_max);
+
+ vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min);
+ vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min);
+ vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max);
+ vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max);
+
+ if (nc >= 8) {
+ vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
+ vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
+ vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
+ vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567)); c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride);
+
+ a0 = (const uint8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const uint8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const uint8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const uint8_t*) ((uintptr_t) a3 - kc);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 0); c2 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 2); c3 += 4;
+ vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 0); c2 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 4); c3 += 2;
+ vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0);
+ vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8);
+ vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0);
+ vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/q8-gemm/8x8-neon.c b/src/q8-gemm/8x8-neon.c
new file mode 100644
index 0000000..675f312
--- /dev/null
+++ b/src/q8-gemm/8x8-neon.c
@@ -0,0 +1,619 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/gemm.h>
+
+
+void xnn_q8_gemm_ukernel_8x8__neon(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const uint8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_q8_gemm_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 8);
+ assert(nc != 0);
+ assert(kc != 0);
+
+ const uint8_t* a0 = a;
+ uint8_t* c0 = c;
+ const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
+ uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const uint8_t* a2 = (const uint8_t*) ((uintptr_t) a1 + a_stride);
+ uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const uint8_t* a3 = (const uint8_t*) ((uintptr_t) a2 + a_stride);
+ uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+ const uint8_t* a4 = (const uint8_t*) ((uintptr_t) a3 + a_stride);
+ uint8_t* c4 = (uint8_t*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ a4 = a3;
+ c4 = c3;
+ }
+ const uint8_t* a5 = (const uint8_t*) ((uintptr_t) a4 + a_stride);
+ uint8_t* c5 = (uint8_t*) ((uintptr_t) c4 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 6) {
+ a5 = a4;
+ c5 = c4;
+ }
+ const uint8_t* a6 = (const uint8_t*) ((uintptr_t) a5 + a_stride);
+ uint8_t* c6 = (uint8_t*) ((uintptr_t) c5 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 6) {
+ a6 = a5;
+ c6 = c5;
+ }
+ const uint8_t* a7 = (const uint8_t*) ((uintptr_t) a6 + a_stride);
+ uint8_t* c7 = (uint8_t*) ((uintptr_t) c6 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 8) {
+ a7 = a6;
+ c7 = c6;
+ }
+
+  const uint8x8_t vb_zero_point = vld1_dup_u8((const uint8_t*) &params->neon.kernel_zero_point);
+
+ do {
+ int32x4_t vacc0x0123 = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 16);
+ int32x4_t vacc0x4567 = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 16);
+ int32x4_t vacc1x0123 = vacc0x0123;
+ int32x4_t vacc1x4567 = vacc0x4567;
+ int32x4_t vacc2x0123 = vacc0x0123;
+ int32x4_t vacc2x4567 = vacc0x4567;
+ int32x4_t vacc3x0123 = vacc0x0123;
+ int32x4_t vacc3x4567 = vacc0x4567;
+ int32x4_t vacc4x0123 = vacc0x0123;
+ int32x4_t vacc4x4567 = vacc0x4567;
+ int32x4_t vacc5x0123 = vacc0x0123;
+ int32x4_t vacc5x4567 = vacc0x4567;
+ int32x4_t vacc6x0123 = vacc0x0123;
+ int32x4_t vacc6x4567 = vacc0x4567;
+ int32x4_t vacc7x0123 = vacc0x0123;
+ int32x4_t vacc7x4567 = vacc0x4567;
+
+ size_t k = kc;
+ while (k >= 8 * sizeof(uint8_t)) {
+ const uint8x8_t va0 = vld1_u8(a0);
+ const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0)); a0 += 8;
+ const uint8x8_t va1 = vld1_u8(a1);
+ const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1)); a1 += 8;
+ const uint8x8_t va2 = vld1_u8(a2);
+ const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2)); a2 += 8;
+ const uint8x8_t va3 = vld1_u8(a3);
+ const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3)); a3 += 8;
+ const uint8x8_t va4 = vld1_u8(a4);
+ const int16x8_t vxa4 = vreinterpretq_s16_u16(vmovl_u8(va4)); a4 += 8;
+ const uint8x8_t va5 = vld1_u8(a5);
+ const int16x8_t vxa5 = vreinterpretq_s16_u16(vmovl_u8(va5)); a5 += 8;
+ const uint8x8_t va6 = vld1_u8(a6);
+ const int16x8_t vxa6 = vreinterpretq_s16_u16(vmovl_u8(va6)); a6 += 8;
+ const uint8x8_t va7 = vld1_u8(a7);
+ const int16x8_t vxa7 = vreinterpretq_s16_u16(vmovl_u8(va7)); a7 += 8;
+
+ const uint8x8_t vb01234567c0 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c0 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa4), 0);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa4), 0);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa5), 0);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa5), 0);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa6), 0);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa6), 0);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa7), 0);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa7), 0);
+
+ const uint8x8_t vb01234567c1 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c1 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa4), 1);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa4), 1);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa5), 1);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa5), 1);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa6), 1);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa6), 1);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa7), 1);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa7), 1);
+
+ const uint8x8_t vb01234567c2 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c2 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa4), 2);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa4), 2);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa5), 2);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa5), 2);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa6), 2);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa6), 2);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa7), 2);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa7), 2);
+
+ const uint8x8_t vb01234567c3 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c3 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa4), 3);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa4), 3);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa5), 3);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa5), 3);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa6), 3);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa6), 3);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa7), 3);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa7), 3);
+
+ const uint8x8_t vb01234567c4 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c4 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa4), 0);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa4), 0);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa5), 0);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa5), 0);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa6), 0);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa6), 0);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa7), 0);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa7), 0);
+
+ const uint8x8_t vb01234567c5 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c5 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa4), 1);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa4), 1);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa5), 1);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa5), 1);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa6), 1);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa6), 1);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa7), 1);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa7), 1);
+
+ const uint8x8_t vb01234567c6 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa4), 2);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa4), 2);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa5), 2);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa5), 2);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa6), 2);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa6), 2);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa7), 2);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa7), 2);
+
+ const uint8x8_t vb01234567c7 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c7 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c7, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa0), 3);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa0), 3);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa1), 3);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa1), 3);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa2), 3);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa2), 3);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa3), 3);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa3), 3);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa4), 3);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa4), 3);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa5), 3);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa5), 3);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa6), 3);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa6), 3);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa7), 3);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa7), 3);
+
+ k -= 8 * sizeof(uint8_t);
+ }
+ if (k != 0) {
+ const uint8x8_t va0 = vld1_u8(a0); a0 = (const uint8_t*) ((uintptr_t) a0 + k);
+ const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0));
+ const uint8x8_t va1 = vld1_u8(a1); a1 = (const uint8_t*) ((uintptr_t) a1 + k);
+ const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1));
+ const uint8x8_t va2 = vld1_u8(a2); a2 = (const uint8_t*) ((uintptr_t) a2 + k);
+ const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2));
+ const uint8x8_t va3 = vld1_u8(a3); a3 = (const uint8_t*) ((uintptr_t) a3 + k);
+ const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3));
+ const uint8x8_t va4 = vld1_u8(a4); a4 = (const uint8_t*) ((uintptr_t) a4 + k);
+ const int16x8_t vxa4 = vreinterpretq_s16_u16(vmovl_u8(va4));
+ const uint8x8_t va5 = vld1_u8(a5); a5 = (const uint8_t*) ((uintptr_t) a5 + k);
+ const int16x8_t vxa5 = vreinterpretq_s16_u16(vmovl_u8(va5));
+ const uint8x8_t va6 = vld1_u8(a6); a6 = (const uint8_t*) ((uintptr_t) a6 + k);
+ const int16x8_t vxa6 = vreinterpretq_s16_u16(vmovl_u8(va6));
+ const uint8x8_t va7 = vld1_u8(a7); a7 = (const uint8_t*) ((uintptr_t) a7 + k);
+ const int16x8_t vxa7 = vreinterpretq_s16_u16(vmovl_u8(va7));
+
+ const uint8x8_t vb01234567c0 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c0 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa4), 0);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa4), 0);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa5), 0);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa5), 0);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa6), 0);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa6), 0);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa7), 0);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa7), 0);
+
+ if (k >= 2 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567c1 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c1 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa4), 1);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa4), 1);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa5), 1);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa5), 1);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa6), 1);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa6), 1);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa7), 1);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa7), 1);
+
+ if (k > 2 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567c2 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c2 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa4), 2);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa4), 2);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa5), 2);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa5), 2);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa6), 2);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa6), 2);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa7), 2);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa7), 2);
+
+ if (k >= 4 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567c3 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c3 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa4), 3);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa4), 3);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa5), 3);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa5), 3);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa6), 3);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa6), 3);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa7), 3);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa7), 3);
+
+ if (k > 4 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567c4 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c4 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa4), 0);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa4), 0);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa5), 0);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa5), 0);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa6), 0);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa6), 0);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa7), 0);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa7), 0);
+
+ if (k >= 6 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567c5 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c5 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa4), 1);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa4), 1);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa5), 1);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa5), 1);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa6), 1);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa6), 1);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa7), 1);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa7), 1);
+
+ if (k > 6 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567c6 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8);
+ const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa4), 2);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa4), 2);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa5), 2);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa5), 2);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa6), 2);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa6), 2);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa7), 2);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa7), 2);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+ vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+ vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier);
+ vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier);
+ vacc4x0123 = vqrdmulhq_s32(vacc4x0123, vmultiplier);
+ vacc4x4567 = vqrdmulhq_s32(vacc4x4567, vmultiplier);
+ vacc5x0123 = vqrdmulhq_s32(vacc5x0123, vmultiplier);
+ vacc5x4567 = vqrdmulhq_s32(vacc5x4567, vmultiplier);
+ vacc6x0123 = vqrdmulhq_s32(vacc6x0123, vmultiplier);
+ vacc6x4567 = vqrdmulhq_s32(vacc6x4567, vmultiplier);
+ vacc7x0123 = vqrdmulhq_s32(vacc7x0123, vmultiplier);
+ vacc7x4567 = vqrdmulhq_s32(vacc7x4567, vmultiplier);
+
+    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+ vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+ vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31);
+ vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31);
+ vacc4x0123 = vsraq_n_s32(vacc4x0123, vbicq_s32(vacc4x0123, vzero_shift_mask), 31);
+ vacc4x4567 = vsraq_n_s32(vacc4x4567, vbicq_s32(vacc4x4567, vzero_shift_mask), 31);
+ vacc5x0123 = vsraq_n_s32(vacc5x0123, vbicq_s32(vacc5x0123, vzero_shift_mask), 31);
+ vacc5x4567 = vsraq_n_s32(vacc5x4567, vbicq_s32(vacc5x4567, vzero_shift_mask), 31);
+ vacc6x0123 = vsraq_n_s32(vacc6x0123, vbicq_s32(vacc6x0123, vzero_shift_mask), 31);
+ vacc6x4567 = vsraq_n_s32(vacc6x4567, vbicq_s32(vacc6x4567, vzero_shift_mask), 31);
+ vacc7x0123 = vsraq_n_s32(vacc7x0123, vbicq_s32(vacc7x0123, vzero_shift_mask), 31);
+ vacc7x4567 = vsraq_n_s32(vacc7x4567, vbicq_s32(vacc7x4567, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+ vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift);
+ vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift);
+ vacc4x0123 = vrshlq_s32(vacc4x0123, vright_shift);
+ vacc4x4567 = vrshlq_s32(vacc4x4567, vright_shift);
+ vacc5x0123 = vrshlq_s32(vacc5x0123, vright_shift);
+ vacc5x4567 = vrshlq_s32(vacc5x4567, vright_shift);
+ vacc6x0123 = vrshlq_s32(vacc6x0123, vright_shift);
+ vacc6x4567 = vrshlq_s32(vacc6x4567, vright_shift);
+ vacc7x0123 = vrshlq_s32(vacc7x0123, vright_shift);
+ vacc7x4567 = vrshlq_s32(vacc7x4567, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(¶ms->neon.output_zero_point);
+#ifdef __aarch64__
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);
+ const int16x8_t vacc4x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc4x0123), vacc4x4567), voutput_zero_point);
+ const int16x8_t vacc5x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc5x0123), vacc5x4567), voutput_zero_point);
+ const int16x8_t vacc6x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc6x0123), vacc6x4567), voutput_zero_point);
+ const int16x8_t vacc7x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc7x0123), vacc7x4567), voutput_zero_point);
+
+ uint8x16_t vout0x01234567_1x01234567 = vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567);
+ uint8x16_t vout2x01234567_3x01234567 = vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567);
+ uint8x16_t vout4x01234567_5x01234567 = vqmovun_high_s16(vqmovun_s16(vacc4x01234567), vacc5x01234567);
+ uint8x16_t vout6x01234567_7x01234567 = vqmovun_high_s16(vqmovun_s16(vacc6x01234567), vacc7x01234567);
+#else
+ const int16x8_t vacc0x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+ const int16x8_t vacc3x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point);
+ const int16x8_t vacc4x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc4x0123), vqmovn_s32(vacc4x4567)), voutput_zero_point);
+ const int16x8_t vacc5x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc5x0123), vqmovn_s32(vacc5x4567)), voutput_zero_point);
+ const int16x8_t vacc6x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc6x0123), vqmovn_s32(vacc6x4567)), voutput_zero_point);
+ const int16x8_t vacc7x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc7x0123), vqmovn_s32(vacc7x4567)), voutput_zero_point);
+
+ uint8x16_t vout0x01234567_1x01234567 = vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567));
+ uint8x16_t vout2x01234567_3x01234567 = vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567));
+ uint8x16_t vout4x01234567_5x01234567 = vcombine_u8(vqmovun_s16(vacc4x01234567), vqmovun_s16(vacc5x01234567));
+ uint8x16_t vout6x01234567_7x01234567 = vcombine_u8(vqmovun_s16(vacc6x01234567), vqmovun_s16(vacc7x01234567));
+#endif
+ const uint8x16_t voutput_min = vld1q_dup_u8(¶ms->neon.output_min);
+ const uint8x16_t voutput_max = vld1q_dup_u8(¶ms->neon.output_max);
+
+ vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min);
+ vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min);
+ vout4x01234567_5x01234567 = vmaxq_u8(vout4x01234567_5x01234567, voutput_min);
+ vout6x01234567_7x01234567 = vmaxq_u8(vout6x01234567_7x01234567, voutput_min);
+ vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max);
+ vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max);
+ vout4x01234567_5x01234567 = vminq_u8(vout4x01234567_5x01234567, voutput_max);
+ vout6x01234567_7x01234567 = vminq_u8(vout6x01234567_7x01234567, voutput_max);
+
+ if (nc >= 8) {
+ vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
+ vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
+ vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
+ vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567)); c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride);
+ vst1_u8(c4, vget_low_u8(vout4x01234567_5x01234567)); c4 = (uint8_t*) ((uintptr_t) c4 + cn_stride);
+ vst1_u8(c5, vget_high_u8(vout4x01234567_5x01234567)); c5 = (uint8_t*) ((uintptr_t) c5 + cn_stride);
+ vst1_u8(c6, vget_low_u8(vout6x01234567_7x01234567)); c6 = (uint8_t*) ((uintptr_t) c6 + cn_stride);
+ vst1_u8(c7, vget_high_u8(vout6x01234567_7x01234567)); c7 = (uint8_t*) ((uintptr_t) c7 + cn_stride);
+
+ a0 = (const uint8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const uint8_t*) ((uintptr_t) a1 - kc);
+ a2 = (const uint8_t*) ((uintptr_t) a2 - kc);
+ a3 = (const uint8_t*) ((uintptr_t) a3 - kc);
+ a4 = (const uint8_t*) ((uintptr_t) a4 - kc);
+ a5 = (const uint8_t*) ((uintptr_t) a5 - kc);
+ a6 = (const uint8_t*) ((uintptr_t) a6 - kc);
+ a7 = (const uint8_t*) ((uintptr_t) a7 - kc);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 0); c2 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 2); c3 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c4, 1), vreinterpretq_u32_u8(vout4x01234567_5x01234567), 0); c4 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c5, 1), vreinterpretq_u32_u8(vout4x01234567_5x01234567), 2); c5 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c6, 1), vreinterpretq_u32_u8(vout6x01234567_7x01234567), 0); c6 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c7, 1), vreinterpretq_u32_u8(vout6x01234567_7x01234567), 2); c7 += 4;
+ vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
+ vout4x01234567_5x01234567 = vextq_u8(vout4x01234567_5x01234567, vout4x01234567_5x01234567, 4);
+ vout6x01234567_7x01234567 = vextq_u8(vout6x01234567_7x01234567, vout6x01234567_7x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 0); c2 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 4); c3 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c4, 1), vreinterpretq_u16_u8(vout4x01234567_5x01234567), 0); c4 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c5, 1), vreinterpretq_u16_u8(vout4x01234567_5x01234567), 4); c5 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c6, 1), vreinterpretq_u16_u8(vout6x01234567_7x01234567), 0); c6 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c7, 1), vreinterpretq_u16_u8(vout6x01234567_7x01234567), 4); c7 += 2;
+ vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
+ vout4x01234567_5x01234567 = vextq_u8(vout4x01234567_5x01234567, vout4x01234567_5x01234567, 2);
+ vout6x01234567_7x01234567 = vextq_u8(vout6x01234567_7x01234567, vout6x01234567_7x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0);
+ vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8);
+ vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0);
+ vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8);
+ vst1q_lane_u8(c4, vout4x01234567_5x01234567, 0);
+ vst1q_lane_u8(c5, vout4x01234567_5x01234567, 8);
+ vst1q_lane_u8(c6, vout6x01234567_7x01234567, 0);
+ vst1q_lane_u8(c7, vout6x01234567_7x01234567, 8);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/q8-igemm/2x2-scalar.c b/src/q8-igemm/2x2-scalar.c
new file mode 100644
index 0000000..18a398c
--- /dev/null
+++ b/src/q8-igemm/2x2-scalar.c
@@ -0,0 +1,145 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/scalar-utils.h>
+#include <xnnpack/igemm.h>
+
+
+// Q8 (quantized uint8) indirect-GEMM micro-kernel: 2x2 output tile, portable scalar code.
+//
+// IGEMM reads the A matrix through an indirection buffer (`a`), which lets the
+// convolution operator reuse one kernel for arbitrary padding/stride: each entry
+// either points at real input data (then `a_offset` is applied) or at the shared
+// `zero` buffer (then it is used unmodified).
+//
+//   mr        - number of output rows actually computed (1 or 2)
+//   nc        - number of output columns remaining
+//   kc        - length of the reduction (K) dimension, in elements, per A pointer
+//   ks        - size of one indirection-buffer step, in bytes; multiple of 2 pointers
+//   w         - packed weights: per column tile, [2 x int32 bias][uint8 weights]
+//   cm_stride - byte stride between output rows
+//   cn_stride - byte stride between 2-column output tiles
+//   params    - precomputed fixed-point requantization constants
+void xnn_q8_igemm_ukernel_2x2__scalar(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const uint8_t** restrict a,
+    const void* restrict w,
+    uint8_t* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const uint8_t* zero,
+    const union xnn_q8_gemm_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 2);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(ks != 0);
+  assert(ks % (2 * sizeof(void*)) == 0);
+
+  uint8_t* c0 = c;
+  uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+  // With a single output row, alias row 1 onto row 0 so its stores are harmless
+  // duplicates instead of out-of-bounds writes.
+  if (mr != 2) {
+    c1 = c0;
+  }
+
+  const int32_t vb_zero_point = params->scalar.kernel_zero_point;
+
+  do {
+    // Initialize accumulators from the packed per-output-channel int32 bias.
+    int32_t vacc0x0 = ((const int32_t*) w)[0];
+    int32_t vacc0x1 = ((const int32_t*) w)[1];
+    int32_t vacc1x0 = vacc0x0;
+    int32_t vacc1x1 = vacc0x1;
+    w = (const void*) ((uintptr_t) w + 2 * sizeof(int32_t));
+
+    size_t p = ks;
+    do {
+      // Fetch the next pair of A-row pointers from the indirection buffer.
+      // Pointers equal to `zero` reference the shared zero vector and must not
+      // be offset by `a_offset`.
+      const uint8_t* restrict a0 = a[0];
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
+      }
+      const uint8_t* restrict a1 = a[1];
+      if XNN_UNPREDICTABLE(a1 != zero) {
+        a1 = (const uint8_t*) ((uintptr_t) a1 + a_offset);
+      }
+      a += 2;
+
+      size_t k = kc;
+      do {
+        const int32_t va0 = (int32_t) (uint32_t) *a0++;
+        const int32_t va1 = (int32_t) (uint32_t) *a1++;
+
+        const uint32_t vb0 = ((const uint8_t*) w)[0];
+        const uint32_t vb1 = ((const uint8_t*) w)[1];
+        w = (const void*) ((uintptr_t) w + 2 * sizeof(uint8_t));
+
+        // Only the kernel (weight) zero point is subtracted here; the input
+        // zero-point contribution is presumably folded into the packed bias
+        // during weight packing - confirm against the packing code.
+        const int32_t vxb0 = (int32_t) vb0 - vb_zero_point;
+        const int32_t vxb1 = (int32_t) vb1 - vb_zero_point;
+
+        vacc0x0 += va0 * vxb0;
+        vacc0x1 += va0 * vxb1;
+        vacc1x0 += va1 * vxb0;
+        vacc1x1 += va1 * vxb1;
+
+      } while (--k != 0);
+      p -= 2 * sizeof(void*);
+    } while (p != 0);
+
+    // Fixed-point requantization: widen to 64 bits and multiply by the Q31
+    // multiplier.
+    const int32_t vmultiplier = params->scalar.multiplier;
+    const int64_t vproduct0x0 = (int64_t) vacc0x0 * (int64_t) vmultiplier;
+    const int64_t vproduct0x1 = (int64_t) vacc0x1 * (int64_t) vmultiplier;
+    const int64_t vproduct1x0 = (int64_t) vacc1x0 * (int64_t) vmultiplier;
+    const int64_t vproduct1x1 = (int64_t) vacc1x1 * (int64_t) vmultiplier;
+
+    // Round-to-nearest take of bits [62:31] of each 64-bit product (Q31).
+    const int64_t vq31rounding = INT64_C(0x40000000);
+    const int32_t vq31product0x0 = (int32_t) (uint32_t) ((uint64_t) (vproduct0x0 + vq31rounding) >> 31);
+    const int32_t vq31product0x1 = (int32_t) (uint32_t) ((uint64_t) (vproduct0x1 + vq31rounding) >> 31);
+    const int32_t vq31product1x0 = (int32_t) (uint32_t) ((uint64_t) (vproduct1x0 + vq31rounding) >> 31);
+    const int32_t vq31product1x1 = (int32_t) (uint32_t) ((uint64_t) (vproduct1x1 + vq31rounding) >> 31);
+
+    // Rounding arithmetic right shift with round-half-away-from-zero: compute
+    // the shifted-out remainder (biased by 1 for negative values) and bump the
+    // result when it exceeds the threshold.
+    const int32_t vremainder_mask = params->scalar.remainder_mask;
+    const int32_t vremainder0x0 = (vq31product0x0 & vremainder_mask) - (int32_t) (vq31product0x0 < 0);
+    const int32_t vremainder0x1 = (vq31product0x1 & vremainder_mask) - (int32_t) (vq31product0x1 < 0);
+    const int32_t vremainder1x0 = (vq31product1x0 & vremainder_mask) - (int32_t) (vq31product1x0 < 0);
+    const int32_t vremainder1x1 = (vq31product1x1 & vremainder_mask) - (int32_t) (vq31product1x1 < 0);
+
+    const uint32_t vshift = params->scalar.shift;
+    const int32_t vremainder_threshold = params->scalar.remainder_threshold;
+    int32_t vout0x0 = asr_s32(vq31product0x0, vshift) + (int32_t) (vremainder0x0 > vremainder_threshold);
+    int32_t vout0x1 = asr_s32(vq31product0x1, vshift) + (int32_t) (vremainder0x1 > vremainder_threshold);
+    int32_t vout1x0 = asr_s32(vq31product1x0, vshift) + (int32_t) (vremainder1x0 > vremainder_threshold);
+    int32_t vout1x1 = asr_s32(vq31product1x1, vshift) + (int32_t) (vremainder1x1 > vremainder_threshold);
+
+    // Clamp to the quantized output range (bounds are pre-shifted by the output
+    // zero point), then add the output zero point.
+    const int32_t vout_min = params->scalar.output_min_less_zero_point;
+    vout0x0 = vout0x0 < vout_min ? vout_min : vout0x0;
+    vout0x1 = vout0x1 < vout_min ? vout_min : vout0x1;
+    vout1x0 = vout1x0 < vout_min ? vout_min : vout1x0;
+    vout1x1 = vout1x1 < vout_min ? vout_min : vout1x1;
+
+    const int32_t vout_max = params->scalar.output_max_less_zero_point;
+    vout0x0 = vout0x0 > vout_max ? vout_max : vout0x0;
+    vout0x1 = vout0x1 > vout_max ? vout_max : vout0x1;
+    vout1x0 = vout1x0 > vout_max ? vout_max : vout1x0;
+    vout1x1 = vout1x1 > vout_max ? vout_max : vout1x1;
+
+    const int32_t voutput_zero_point = params->scalar.output_zero_point;
+    vout0x0 += voutput_zero_point;
+    vout0x1 += voutput_zero_point;
+    vout1x0 += voutput_zero_point;
+    vout1x1 += voutput_zero_point;
+
+    if XNN_LIKELY(nc >= 2) {
+      // Full 2x2 tile: store both columns of both rows (highest row first),
+      // advance the output pointers, and rewind the indirection buffer for the
+      // next column tile.
+      c1[0] = (uint8_t) vout1x0;
+      c1[1] = (uint8_t) vout1x1;
+      c0[0] = (uint8_t) vout0x0;
+      c0[1] = (uint8_t) vout0x1;
+
+      c1 += cn_stride;
+      c0 += cn_stride;
+
+      a = (const uint8_t**restrict) ((uintptr_t) a - ks);
+
+      nc -= 2;
+    } else {
+      // Partial tile (nc == 1): store only column 0 of each row and finish.
+      c1[0] = (uint8_t) vout1x0;
+      c0[0] = (uint8_t) vout0x0;
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/q8-igemm/4x4c2-sse2.c b/src/q8-igemm/4x4c2-sse2.c
new file mode 100644
index 0000000..1c00a26
--- /dev/null
+++ b/src/q8-igemm/4x4c2-sse2.c
@@ -0,0 +1,314 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/igemm.h>
+
+
+// Q8 (quantized uint8) indirect-GEMM micro-kernel: 4x4 output tile, SSE2.
+//
+// "4x4c2" means a 4-row x 4-column tile with weights packed 2 K-elements deep
+// per channel group, so each _mm_madd_epi16 reduces a pair of K values into one
+// int32 lane per output channel.
+//
+//   mr        - number of output rows actually computed (1..4)
+//   nc        - number of output columns remaining
+//   kc        - length of the reduction (K) dimension, in elements, per A pointer
+//   ks        - size of one indirection-buffer step, in bytes; multiple of 4 pointers
+//   a         - indirection buffer: pointers to A rows, or to the `zero` buffer
+//   w         - packed weights: per column tile, [4 x int32 bias][uint8 weights]
+//   cm_stride - byte stride between output rows
+//   cn_stride - byte stride between 4-column output tiles
+//   a_offset  - byte offset applied to every A pointer that is not `zero`
+//   params    - precomputed fixed-point requantization constants (SSE2 layout)
+void xnn_q8_igemm_ukernel_4x4c2__sse2(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const uint8_t** restrict a,
+    const void* restrict w,
+    uint8_t* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const uint8_t* zero,
+    const union xnn_q8_gemm_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(ks != 0);
+  assert(ks % (4 * sizeof(void*)) == 0);
+
+  // Rows beyond mr are aliased onto the previous row so their stores are
+  // harmless duplicates instead of out-of-bounds writes.
+  uint8_t* c0 = c;
+  uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;
+  }
+  uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + cm_stride);
+  if (mr != 4) {
+    c3 = c2;
+  }
+
+  const __m128i vb_zero_point = _mm_load_si128((const __m128i*) params->sse2.kernel_zero_point);
+
+  do {
+    // Initialize all four row accumulators from the packed per-channel bias.
+    __m128i vacc0x0123 = _mm_loadu_si128((const __m128i*) w);
+    __m128i vacc1x0123 = vacc0x0123;
+    __m128i vacc2x0123 = vacc0x0123;
+    __m128i vacc3x0123 = vacc0x0123;
+    w = (const void*) ((uintptr_t) w + 16);
+
+    const __m128i vzero = _mm_setzero_si128();
+    size_t p = ks;
+    do {
+      // Fetch the next four A-row pointers from the indirection buffer.
+      // Pointers equal to `zero` reference the shared zero vector and must not
+      // be offset by `a_offset`.
+      const uint8_t* restrict a0 = a[0];
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
+      }
+      const uint8_t* restrict a1 = a[1];
+      if XNN_UNPREDICTABLE(a1 != zero) {
+        a1 = (const uint8_t*) ((uintptr_t) a1 + a_offset);
+      }
+      const uint8_t* restrict a2 = a[2];
+      if XNN_UNPREDICTABLE(a2 != zero) {
+        a2 = (const uint8_t*) ((uintptr_t) a2 + a_offset);
+      }
+      const uint8_t* restrict a3 = a[3];
+      if XNN_UNPREDICTABLE(a3 != zero) {
+        a3 = (const uint8_t*) ((uintptr_t) a3 + a_offset);
+      }
+      a += 4;
+
+      size_t k = kc;
+      // Main loop: 8 K elements per iteration. Each _mm_shuffle_epi32 selects
+      // one pair of A values (zero-extended to int16), and _mm_madd_epi16
+      // multiply-accumulates it against one 4-channel x 2-K weight group.
+      while (k >= 8 * sizeof(uint8_t)) {
+        const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
+        const __m128i vxa0 = _mm_unpacklo_epi8(va0, vzero);
+        a0 += 8;
+        const __m128i va1 = _mm_loadl_epi64((const __m128i*) a1);
+        const __m128i vxa1 = _mm_unpacklo_epi8(va1, vzero);
+        a1 += 8;
+        const __m128i va2 = _mm_loadl_epi64((const __m128i*) a2);
+        const __m128i vxa2 = _mm_unpacklo_epi8(va2, vzero);
+        a2 += 8;
+        const __m128i va3 = _mm_loadl_epi64((const __m128i*) a3);
+        const __m128i vxa3 = _mm_unpacklo_epi8(va3, vzero);
+        a3 += 8;
+
+        const __m128i vb0 = _mm_loadl_epi64((const __m128i*) w);
+        const __m128i vxb0 = _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);
+        vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+        vacc1x0123 = _mm_add_epi32(vacc1x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+        vacc2x0123 = _mm_add_epi32(vacc2x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+        vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+
+        const __m128i vb1 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 8));
+        const __m128i vxb1 = _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);
+        vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+        vacc1x0123 = _mm_add_epi32(vacc1x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+        vacc2x0123 = _mm_add_epi32(vacc2x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+        vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+
+        const __m128i vb2 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16));
+        const __m128i vxb2 = _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);
+        vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+        vacc1x0123 = _mm_add_epi32(vacc1x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+        vacc2x0123 = _mm_add_epi32(vacc2x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+        vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+
+        const __m128i vb3 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 24));
+        const __m128i vxb3 = _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
+        vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+        vacc1x0123 = _mm_add_epi32(vacc1x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+        vacc2x0123 = _mm_add_epi32(vacc2x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+        vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+
+        w = (void*) ((uintptr_t) w + 32);
+
+        k -= 8 * sizeof(uint8_t);
+      }
+      // Remainder of 1-7 K elements, consumed 2 at a time like the main loop.
+      // NOTE(review): this loads a full 8-byte vector from each A pointer even
+      // when fewer bytes remain - presumably the indirection/packing buffers
+      // are padded to make the overread safe; confirm against the packing code.
+      if (k != 0) {
+        const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
+        const __m128i vxa0 = _mm_unpacklo_epi8(va0, vzero);
+        const __m128i va1 = _mm_loadl_epi64((const __m128i*) a1);
+        const __m128i vxa1 = _mm_unpacklo_epi8(va1, vzero);
+        const __m128i va2 = _mm_loadl_epi64((const __m128i*) a2);
+        const __m128i vxa2 = _mm_unpacklo_epi8(va2, vzero);
+        const __m128i va3 = _mm_loadl_epi64((const __m128i*) a3);
+        const __m128i vxa3 = _mm_unpacklo_epi8(va3, vzero);
+
+        const __m128i vb0 = _mm_loadl_epi64((const __m128i*) w);
+        const __m128i vxb0 = _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);
+        w = (void*) ((uintptr_t) w + 8);
+
+        vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+        vacc1x0123 = _mm_add_epi32(vacc1x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+        vacc2x0123 = _mm_add_epi32(vacc2x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+        vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
+
+        if (k > 2 * sizeof(uint8_t)) {
+          const __m128i vb1 = _mm_loadl_epi64((const __m128i*) w);
+          const __m128i vxb1 = _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);
+          w = (void*) ((uintptr_t) w + 8);
+
+          vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+          vacc1x0123 = _mm_add_epi32(vacc1x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+          vacc2x0123 = _mm_add_epi32(vacc2x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+          vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
+
+          if (k > 4 * sizeof(uint8_t)) {
+            const __m128i vb2 = _mm_loadl_epi64((const __m128i*) w);
+            const __m128i vxb2 = _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);
+            w = (void*) ((uintptr_t) w + 8);
+
+            vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+            vacc1x0123 = _mm_add_epi32(vacc1x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+            vacc2x0123 = _mm_add_epi32(vacc2x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+            vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
+
+            if (k > 6 * sizeof(uint8_t)) {
+              const __m128i vb3 = _mm_loadl_epi64((const __m128i*) w);
+              const __m128i vxb3 = _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
+              w = (void*) ((uintptr_t) w + 8);
+
+              vacc0x0123 = _mm_add_epi32(vacc0x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+              vacc1x0123 = _mm_add_epi32(vacc1x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+              vacc2x0123 = _mm_add_epi32(vacc2x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+              vacc3x0123 = _mm_add_epi32(vacc3x0123, _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
+            }
+          }
+        }
+      }
+      p -= 4 * sizeof(void*);
+    } while (p != 0);
+
+    // Fixed-point requantization (Q31 multiply + rounding shift). SSE2 has no
+    // signed 32x32->64 multiply, so multiply absolute values with _mm_mul_epu32
+    // and restore the sign afterwards using the negative-lane masks.
+    const __m128i vmultiplier = _mm_load_si128((const __m128i*) params->sse2.multiplier);
+    const __m128i vrounding = _mm_load_si128((const __m128i*) params->sse2.rounding);
+
+    // All-ones in every lane where the accumulator is negative.
+    const __m128i vnmask0x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc0x0123);
+    const __m128i vnmask1x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc1x0123);
+    const __m128i vnmask2x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc2x0123);
+    const __m128i vnmask3x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc3x0123);
+
+    // |acc| via two's complement: (acc ^ mask) - mask.
+    const __m128i vabsacc0x0123 = _mm_sub_epi32(_mm_xor_si128(vacc0x0123, vnmask0x0123), vnmask0x0123);
+    const __m128i vabsacc1x0123 = _mm_sub_epi32(_mm_xor_si128(vacc1x0123, vnmask1x0123), vnmask1x0123);
+    const __m128i vabsacc2x0123 = _mm_sub_epi32(_mm_xor_si128(vacc2x0123, vnmask2x0123), vnmask2x0123);
+    const __m128i vabsacc3x0123 = _mm_sub_epi32(_mm_xor_si128(vacc3x0123, vnmask3x0123), vnmask3x0123);
+
+    // Swap even/odd lanes so the odd lanes can also go through _mm_mul_epu32,
+    // which only multiplies lanes 0 and 2.
+    const __m128i vabsacc0x1032 = _mm_shuffle_epi32(vabsacc0x0123, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i vabsacc1x1032 = _mm_shuffle_epi32(vabsacc1x0123, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i vabsacc2x1032 = _mm_shuffle_epi32(vabsacc2x0123, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i vabsacc3x1032 = _mm_shuffle_epi32(vabsacc3x0123, _MM_SHUFFLE(2, 3, 0, 1));
+
+    // 64-bit |acc| * multiplier products for lanes 0 and 2.
+    const __m128i vabsprod0x02 = _mm_mul_epu32(vabsacc0x0123, vmultiplier);
+    const __m128i vabsprod1x02 = _mm_mul_epu32(vabsacc1x0123, vmultiplier);
+    const __m128i vabsprod2x02 = _mm_mul_epu32(vabsacc2x0123, vmultiplier);
+    const __m128i vabsprod3x02 = _mm_mul_epu32(vabsacc3x0123, vmultiplier);
+
+    // Widen the sign masks of lanes 0/2 to 64 bits and re-apply the sign.
+    const __m128i vnmask0x02 = _mm_shuffle_epi32(vnmask0x0123, _MM_SHUFFLE(2, 2, 0, 0));
+    const __m128i vnmask1x02 = _mm_shuffle_epi32(vnmask1x0123, _MM_SHUFFLE(2, 2, 0, 0));
+    const __m128i vnmask2x02 = _mm_shuffle_epi32(vnmask2x0123, _MM_SHUFFLE(2, 2, 0, 0));
+    const __m128i vnmask3x02 = _mm_shuffle_epi32(vnmask3x0123, _MM_SHUFFLE(2, 2, 0, 0));
+
+    const __m128i vprod0x02 = _mm_sub_epi64(_mm_xor_si128(vabsprod0x02, vnmask0x02), vnmask0x02);
+    const __m128i vprod1x02 = _mm_sub_epi64(_mm_xor_si128(vabsprod1x02, vnmask1x02), vnmask1x02);
+    const __m128i vprod2x02 = _mm_sub_epi64(_mm_xor_si128(vabsprod2x02, vnmask2x02), vnmask2x02);
+    const __m128i vprod3x02 = _mm_sub_epi64(_mm_xor_si128(vabsprod3x02, vnmask3x02), vnmask3x02);
+
+    // Round and take bits [62:31] of each 64-bit product (Q31).
+    const __m128i vq31prod0x02 = _mm_srli_epi64(_mm_add_epi64(vprod0x02, vrounding), 31);
+    const __m128i vq31prod1x02 = _mm_srli_epi64(_mm_add_epi64(vprod1x02, vrounding), 31);
+    const __m128i vq31prod2x02 = _mm_srli_epi64(_mm_add_epi64(vprod2x02, vrounding), 31);
+    const __m128i vq31prod3x02 = _mm_srli_epi64(_mm_add_epi64(vprod3x02, vrounding), 31);
+
+    // Same sequence for lanes 1 and 3 (now in the even positions of x1032).
+    const __m128i vabsprod0x13 = _mm_mul_epu32(vabsacc0x1032, vmultiplier);
+    const __m128i vabsprod1x13 = _mm_mul_epu32(vabsacc1x1032, vmultiplier);
+    const __m128i vabsprod2x13 = _mm_mul_epu32(vabsacc2x1032, vmultiplier);
+    const __m128i vabsprod3x13 = _mm_mul_epu32(vabsacc3x1032, vmultiplier);
+
+    const __m128i vnmask0x13 = _mm_shuffle_epi32(vnmask0x0123, _MM_SHUFFLE(3, 3, 1, 1));
+    const __m128i vnmask1x13 = _mm_shuffle_epi32(vnmask1x0123, _MM_SHUFFLE(3, 3, 1, 1));
+    const __m128i vnmask2x13 = _mm_shuffle_epi32(vnmask2x0123, _MM_SHUFFLE(3, 3, 1, 1));
+    const __m128i vnmask3x13 = _mm_shuffle_epi32(vnmask3x0123, _MM_SHUFFLE(3, 3, 1, 1));
+
+    const __m128i vprod0x13 = _mm_sub_epi64(_mm_xor_si128(vabsprod0x13, vnmask0x13), vnmask0x13);
+    const __m128i vprod1x13 = _mm_sub_epi64(_mm_xor_si128(vabsprod1x13, vnmask1x13), vnmask1x13);
+    const __m128i vprod2x13 = _mm_sub_epi64(_mm_xor_si128(vabsprod2x13, vnmask2x13), vnmask2x13);
+    const __m128i vprod3x13 = _mm_sub_epi64(_mm_xor_si128(vabsprod3x13, vnmask3x13), vnmask3x13);
+
+    const __m128i vq31prod0x13 = _mm_srli_epi64(_mm_add_epi64(vprod0x13, vrounding), 31);
+    const __m128i vq31prod1x13 = _mm_srli_epi64(_mm_add_epi64(vprod1x13, vrounding), 31);
+    const __m128i vq31prod2x13 = _mm_srli_epi64(_mm_add_epi64(vprod2x13, vrounding), 31);
+    const __m128i vq31prod3x13 = _mm_srli_epi64(_mm_add_epi64(vprod3x13, vrounding), 31);
+
+    // Re-interleave the 0/2 and 1/3 results back into lane order 0,1,2,3.
+    const __m128i vq31prod0x0213 = _mm_castps_si128(_mm_shuffle_ps(
+        _mm_castsi128_ps(vq31prod0x02), _mm_castsi128_ps(vq31prod0x13), _MM_SHUFFLE(2, 0, 2, 0)));
+    const __m128i vq31prod1x0213 = _mm_castps_si128(_mm_shuffle_ps(
+        _mm_castsi128_ps(vq31prod1x02), _mm_castsi128_ps(vq31prod1x13), _MM_SHUFFLE(2, 0, 2, 0)));
+    const __m128i vq31prod2x0213 = _mm_castps_si128(_mm_shuffle_ps(
+        _mm_castsi128_ps(vq31prod2x02), _mm_castsi128_ps(vq31prod2x13), _MM_SHUFFLE(2, 0, 2, 0)));
+    const __m128i vq31prod3x0213 = _mm_castps_si128(_mm_shuffle_ps(
+        _mm_castsi128_ps(vq31prod3x02), _mm_castsi128_ps(vq31prod3x13), _MM_SHUFFLE(2, 0, 2, 0)));
+
+    const __m128i vq31prod0x0123 = _mm_shuffle_epi32(vq31prod0x0213, _MM_SHUFFLE(3, 1, 2, 0));
+    const __m128i vq31prod1x0123 = _mm_shuffle_epi32(vq31prod1x0213, _MM_SHUFFLE(3, 1, 2, 0));
+    const __m128i vq31prod2x0123 = _mm_shuffle_epi32(vq31prod2x0213, _MM_SHUFFLE(3, 1, 2, 0));
+    const __m128i vq31prod3x0123 = _mm_shuffle_epi32(vq31prod3x0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+    // Rounding arithmetic right shift with round-half-away-from-zero: compare
+    // the shifted-out remainder (biased for negative inputs) against the
+    // threshold and subtract the resulting all-ones/zero mask.
+    const __m128i vremainder_mask = _mm_load_si128((const __m128i*) params->sse2.remainder_mask);
+
+    const __m128i vrem0x0123 =
+      _mm_add_epi32(_mm_and_si128(vq31prod0x0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod0x0123));
+    const __m128i vrem1x0123 =
+      _mm_add_epi32(_mm_and_si128(vq31prod1x0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod1x0123));
+    const __m128i vrem2x0123 =
+      _mm_add_epi32(_mm_and_si128(vq31prod2x0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod2x0123));
+    const __m128i vrem3x0123 =
+      _mm_add_epi32(_mm_and_si128(vq31prod3x0123, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod3x0123));
+
+    const __m128i vremainder_threshold = _mm_load_si128((const __m128i*) params->sse2.remainder_threshold);
+    const __m128i vshift = _mm_load_si128((const __m128i*) params->sse2.shift);
+
+    vacc0x0123 = _mm_sub_epi32(_mm_sra_epi32(vq31prod0x0123, vshift), _mm_cmpgt_epi32(vrem0x0123, vremainder_threshold));
+    vacc1x0123 = _mm_sub_epi32(_mm_sra_epi32(vq31prod1x0123, vshift), _mm_cmpgt_epi32(vrem1x0123, vremainder_threshold));
+    vacc2x0123 = _mm_sub_epi32(_mm_sra_epi32(vq31prod2x0123, vshift), _mm_cmpgt_epi32(vrem2x0123, vremainder_threshold));
+    vacc3x0123 = _mm_sub_epi32(_mm_sra_epi32(vq31prod3x0123, vshift), _mm_cmpgt_epi32(vrem3x0123, vremainder_threshold));
+
+    // Add the output zero point with saturation, pack int32 -> int16 -> uint8,
+    // and clamp to the quantized output range.
+    const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->sse2.output_zero_point);
+    const __m128i vacc01x0123 = _mm_adds_epi16(_mm_packs_epi32(vacc0x0123, vacc1x0123), voutput_zero_point);
+    const __m128i vacc23x0123 = _mm_adds_epi16(_mm_packs_epi32(vacc2x0123, vacc3x0123), voutput_zero_point);
+    __m128i vout = _mm_packus_epi16(vacc01x0123, vacc23x0123);
+    vout = _mm_min_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_max));
+    vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->sse2.output_min));
+
+    if XNN_LIKELY(nc >= 4) {
+      // Full 4-column tile: each row's 4 output bytes occupy one 32-bit lane of
+      // vout (row 0 in lane 0, ..., row 3 in lane 3); stores go highest row first.
+      *((uint32_t*) c3) = (uint32_t) _mm_cvtsi128_si32(_mm_srli_si128(vout, 12));
+      c3 += cn_stride;
+      *((uint32_t*) c2) = (uint32_t) _mm_cvtsi128_si32(_mm_unpackhi_epi32(vout, vout));
+      c2 += cn_stride;
+      *((uint32_t*) c1) = (uint32_t) _mm_cvtsi128_si32(_mm_srli_epi64(vout, 32));
+      c1 += cn_stride;
+      *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
+      c0 += cn_stride;
+
+      // Rewind the indirection buffer for the next column tile.
+      a = (const uint8_t**restrict) ((uintptr_t) a - ks);
+
+      nc -= 4;
+    } else {
+      // Partial tile: store 2 and/or 1 byte per row, shifting vout down as
+      // columns are consumed.
+      if (nc & 2) {
+        *((uint16_t*) c3) = (uint16_t) _mm_extract_epi16(vout, 6); c3 += 2;
+        *((uint16_t*) c2) = (uint16_t) _mm_extract_epi16(vout, 4); c2 += 2;
+        *((uint16_t*) c1) = (uint16_t) _mm_extract_epi16(vout, 2); c1 += 2;
+        *((uint16_t*) c0) = (uint16_t) _mm_extract_epi16(vout, 0); c0 += 2;
+        vout = _mm_srli_epi32(vout, 16);
+      }
+      if (nc & 1) {
+        *((uint8_t*) c3) = (uint8_t) _mm_extract_epi16(vout, 6);
+        *((uint8_t*) c2) = (uint8_t) _mm_extract_epi16(vout, 4);
+        *((uint8_t*) c1) = (uint8_t) _mm_extract_epi16(vout, 2);
+        *((uint8_t*) c0) = (uint8_t) _mm_cvtsi128_si32(vout);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/q8-igemm/4x8-neon.c b/src/q8-igemm/4x8-neon.c
new file mode 100644
index 0000000..611d50b
--- /dev/null
+++ b/src/q8-igemm/4x8-neon.c
@@ -0,0 +1,412 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_q8_igemm_ukernel_4x8__neon(  // Q8 (uint8) indirect GEMM microkernel: 4 rows x 8 columns, NEON
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const uint8_t** restrict a,
+    const void* restrict w,
+    uint8_t* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const uint8_t* zero,
+    const union xnn_q8_gemm_params params[restrict static 1])
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(ks != 0);
+  assert(ks % (4 * sizeof(void*)) == 0);
+
+  uint8_t* c0 = c;  // clamp row pointers so rows beyond mr alias the last valid row
+  uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;
+  }
+  uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {  // was a bare if: use the same branch hint as the clamps above
+    c3 = c2;
+  }
+
+  const uint8x8_t vb_zero_point = vld1_dup_u8((const uint8_t*) &params->neon.kernel_zero_point);  // fixed mojibake: was garbled "&para"
+
+  do {
+    int32x4_t vacc0x0123 = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));  // init accumulators from packed bias
+    int32x4_t vacc0x4567 = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
+    int32x4_t vacc1x0123 = vacc0x0123;
+    int32x4_t vacc1x4567 = vacc0x4567;
+    int32x4_t vacc2x0123 = vacc0x0123;
+    int32x4_t vacc2x4567 = vacc0x4567;
+    int32x4_t vacc3x0123 = vacc0x0123;
+    int32x4_t vacc3x4567 = vacc0x4567;
+
+    size_t p = ks;
+    do {
+      const uint8_t* restrict a0 = a[0];  // indirection: offset real rows, leave the shared zero row alone
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
+      }
+      const uint8_t* restrict a1 = a[1];
+      if XNN_UNPREDICTABLE(a1 != zero) {
+        a1 = (const uint8_t*) ((uintptr_t) a1 + a_offset);
+      }
+      const uint8_t* restrict a2 = a[2];
+      if XNN_UNPREDICTABLE(a2 != zero) {
+        a2 = (const uint8_t*) ((uintptr_t) a2 + a_offset);
+      }
+      const uint8_t* restrict a3 = a[3];
+      if XNN_UNPREDICTABLE(a3 != zero) {
+        a3 = (const uint8_t*) ((uintptr_t) a3 + a_offset);
+      }
+      a += 4;
+
+      size_t k = kc;
+      while (k >= 8 * sizeof(uint8_t)) {  // main loop: 8 K-elements per iteration
+        const uint8x8_t va0 = vld1_u8(a0); a0 += 8;
+        const uint8x8_t va1 = vld1_u8(a1); a1 += 8;
+        const uint8x8_t va2 = vld1_u8(a2); a2 += 8;
+        const uint8x8_t va3 = vld1_u8(a3); a3 += 8;
+        const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0));
+        const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1));
+        const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2));
+        const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3));
+
+        {
+          const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+          const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+          vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0);
+          vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0);
+          vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0);
+          vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0);
+          vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0);
+          vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0);
+          vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0);
+          vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0);
+        }
+
+        {
+          const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+          const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+          vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1);
+          vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1);
+          vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1);
+          vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1);
+          vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1);
+          vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1);
+          vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1);
+          vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1);
+        }
+
+        {
+          const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+          const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+          vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2);
+          vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2);
+          vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2);
+          vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2);
+          vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2);
+          vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2);
+          vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2);
+          vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2);
+        }
+
+        {
+          const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+          const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+          vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3);
+          vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3);
+          vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3);
+          vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3);
+          vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3);
+          vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3);
+          vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3);
+          vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3);
+        }
+
+        {
+          const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+          const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+          vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 0);
+          vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 0);
+          vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 0);
+          vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 0);
+          vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 0);
+          vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 0);
+          vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 0);
+          vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 0);
+        }
+
+        {
+          const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+          const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+          vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 1);
+          vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 1);
+          vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 1);
+          vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 1);
+          vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 1);
+          vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 1);
+          vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 1);
+          vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 1);
+        }
+
+        {
+          const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+          const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+          vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 2);
+          vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 2);
+          vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 2);
+          vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 2);
+          vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 2);
+          vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 2);
+          vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 2);
+          vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 2);
+        }
+
+        {
+          const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+          const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+          vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 3);
+          vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 3);
+          vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 3);
+          vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 3);
+          vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 3);
+          vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 3);
+          vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 3);
+          vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 3);
+        }
+
+        k -= 8 * sizeof(uint8_t);
+      }
+      if (k != 0) {  // remainder: 1-7 K-elements, processed via nested range checks
+        const uint8x8_t va0 = vld1_u8(a0);
+        const uint8x8_t va1 = vld1_u8(a1);
+        const uint8x8_t va2 = vld1_u8(a2);
+        const uint8x8_t va3 = vld1_u8(a3);
+        const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0));
+        const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1));
+        const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2));
+        const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3));
+
+        {
+          const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+          const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+          vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0);
+          vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0);
+          vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0);
+          vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0);
+          vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0);
+          vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0);
+          vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0);
+          vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0);
+        }
+
+        if (k >= 2 * sizeof(uint8_t)) {
+          const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+          const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+          vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1);
+          vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1);
+          vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1);
+          vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1);
+          vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1);
+          vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1);
+          vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1);
+          vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1);
+
+          if (k > 2 * sizeof(uint8_t)) {
+            const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+            const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+            vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2);
+            vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2);
+            vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2);
+            vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2);
+            vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2);
+            vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2);
+            vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2);
+            vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2);
+
+            if (k >= 4 * sizeof(uint8_t)) {
+              const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+              const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+              vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3);
+              vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3);
+              vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3);
+              vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3);
+              vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3);
+              vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3);
+              vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3);
+              vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3);
+
+              if (k > 4 * sizeof(uint8_t)) {
+                const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+                const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+                vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 0);
+                vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 0);
+                vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 0);
+                vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 0);
+                vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 0);
+                vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 0);
+                vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 0);
+                vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 0);
+
+                if (k >= 6 * sizeof(uint8_t)) {
+                  const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+                  const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+                  vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 1);
+                  vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 1);
+                  vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 1);
+                  vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 1);
+                  vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 1);
+                  vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 1);
+                  vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 1);
+                  vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 1);
+
+                  if (k > 6 * sizeof(uint8_t)) {
+                    const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+                    const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+                    vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 2);
+                    vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 2);
+                    vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 2);
+                    vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 2);
+                    vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 2);
+                    vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 2);
+                    vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 2);
+                    vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 2);
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      p -= 4 * sizeof(void*);
+    } while (p != 0);
+
+    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);  // fixed mojibake: was garbled "&para"
+    vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);  // requantize: fixed-point multiply
+    vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+    vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+    vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+    vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+    vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+    vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier);
+    vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier);
+
+    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);  // fixed mojibake: was garbled "&para"
+    const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+    vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);  // adjust negatives for round-to-nearest shift
+    vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+    vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+    vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+    vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+    vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+    vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31);
+    vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31);
+
+    vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);  // rounding right shift (vrshlq with negative shift)
+    vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+    vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+    vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+    vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+    vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+    vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift);
+    vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift);
+
+    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);  // fixed mojibake: was garbled "&para"
+  #ifdef __aarch64__
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+    const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+    const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+    const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);
+
+    uint8x16_t vout0x01234567_1x01234567 = vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567);
+    uint8x16_t vout2x01234567_3x01234567 = vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567);
+  #else
+    const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+    const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+    const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+    const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point);
+
+    uint8x16_t vout0x01234567_1x01234567 = vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567));
+    uint8x16_t vout2x01234567_3x01234567 = vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567));
+  #endif
+    const uint8x16_t voutput_min = vld1q_dup_u8(&params->neon.output_min);  // fixed mojibake: was garbled "&para"
+    const uint8x16_t voutput_max = vld1q_dup_u8(&params->neon.output_max);  // fixed mojibake: was garbled "&para"
+
+    vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min);
+    vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min);
+    vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max);
+    vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max);
+
+    if XNN_LIKELY(nc >= 8) {
+      vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567)); c3 += cn_stride;
+      vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); c2 += cn_stride;
+      vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); c1 += cn_stride;
+      vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); c0 += cn_stride;
+
+      a = (const uint8_t**restrict) ((uintptr_t) a - ks);  // rewind indirection pointers for the next column block
+      nc -= 8;
+    } else {
+      if (nc & 4) {
+        vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 2); c3 += 4;
+        vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 0); c2 += 4;
+        vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 2); c1 += 4;
+        vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 0); c0 += 4;
+        vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+        vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
+      }
+      if (nc & 2) {
+        vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 4); c3 += 2;
+        vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 0); c2 += 2;
+        vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 4); c1 += 2;
+        vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 0); c0 += 2;
+        vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+        vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
+      }
+      if (nc & 1) {
+        vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8);
+        vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0);
+        vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8);
+        vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0);
+      }
+
+      nc = 0;
+    }
+  } while (nc != 0);
+}
diff --git a/src/q8-igemm/8x8-neon.c b/src/q8-igemm/8x8-neon.c
new file mode 100644
index 0000000..1f68057
--- /dev/null
+++ b/src/q8-igemm/8x8-neon.c
@@ -0,0 +1,657 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/igemm.h>
+
+
+void xnn_q8_igemm_ukernel_8x8__neon(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const uint8_t** restrict a,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const uint8_t* zero,
+ const union xnn_q8_gemm_params params[restrict static 1])
+{
+ assert(mr != 0);
+ assert(mr <= 8);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(ks != 0);
+ assert(ks % (8 * sizeof(void*)) == 0);
+
+ uint8_t* c0 = c;
+ uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 4) {
+ c3 = c2;
+ }
+ uint8_t* c4 = (uint8_t*) ((uintptr_t) c3 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 4) {
+ c4 = c3;
+ }
+ uint8_t* c5 = (uint8_t*) ((uintptr_t) c4 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 6) {
+ c5 = c4;
+ }
+ uint8_t* c6 = (uint8_t*) ((uintptr_t) c5 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 6) {
+ c6 = c5;
+ }
+ uint8_t* c7 = (uint8_t*) ((uintptr_t) c6 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 8) {
+ c7 = c6;
+ }
+
+ const uint8x8_t vb_zero_point = vld1_dup_u8((const uint8_t*) ¶ms->neon.kernel_zero_point);
+
+ do {
+ int32x4_t vacc0x0123 = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
+ int32x4_t vacc0x4567 = vld1q_s32(w); w = (void*) ((uintptr_t) w + sizeof(int32x4_t));
+ int32x4_t vacc1x0123 = vacc0x0123;
+ int32x4_t vacc1x4567 = vacc0x4567;
+ int32x4_t vacc2x0123 = vacc0x0123;
+ int32x4_t vacc2x4567 = vacc0x4567;
+ int32x4_t vacc3x0123 = vacc0x0123;
+ int32x4_t vacc3x4567 = vacc0x4567;
+ int32x4_t vacc4x0123 = vacc0x0123;
+ int32x4_t vacc4x4567 = vacc0x4567;
+ int32x4_t vacc5x0123 = vacc0x0123;
+ int32x4_t vacc5x4567 = vacc0x4567;
+ int32x4_t vacc6x0123 = vacc0x0123;
+ int32x4_t vacc6x4567 = vacc0x4567;
+ int32x4_t vacc7x0123 = vacc0x0123;
+ int32x4_t vacc7x4567 = vacc0x4567;
+
+ size_t p = ks;
+ do {
+ const uint8_t* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
+ }
+ const uint8_t* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const uint8_t*) ((uintptr_t) a1 + a_offset);
+ }
+ const uint8_t* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const uint8_t*) ((uintptr_t) a2 + a_offset);
+ }
+ const uint8_t* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const uint8_t*) ((uintptr_t) a3 + a_offset);
+ }
+ const uint8_t* restrict a4 = a[4];
+ if XNN_UNPREDICTABLE(a4 != zero) {
+ a4 = (const uint8_t*) ((uintptr_t) a4 + a_offset);
+ }
+ const uint8_t* restrict a5 = a[5];
+ if XNN_UNPREDICTABLE(a5 != zero) {
+ a5 = (const uint8_t*) ((uintptr_t) a5 + a_offset);
+ }
+ const uint8_t* restrict a6 = a[6];
+ if XNN_UNPREDICTABLE(a6 != zero) {
+ a6 = (const uint8_t*) ((uintptr_t) a6 + a_offset);
+ }
+ const uint8_t* restrict a7 = a[7];
+ if XNN_UNPREDICTABLE(a7 != zero) {
+ a7 = (const uint8_t*) ((uintptr_t) a7 + a_offset);
+ }
+ a += 8;
+
+ size_t k = kc;
+ while (k >= 8 * sizeof(uint8_t)) {
+ const uint8x8_t va0 = vld1_u8(a0); a0 += 8;
+ const uint8x8_t va1 = vld1_u8(a1); a1 += 8;
+ const uint8x8_t va2 = vld1_u8(a2); a2 += 8;
+ const uint8x8_t va3 = vld1_u8(a3); a3 += 8;
+ const uint8x8_t va4 = vld1_u8(a4); a4 += 8;
+ const uint8x8_t va5 = vld1_u8(a5); a5 += 8;
+ const uint8x8_t va6 = vld1_u8(a6); a6 += 8;
+ const uint8x8_t va7 = vld1_u8(a7); a7 += 8;
+ const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0));
+ const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1));
+ const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2));
+ const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3));
+ const int16x8_t vxa4 = vreinterpretq_s16_u16(vmovl_u8(va4));
+ const int16x8_t vxa5 = vreinterpretq_s16_u16(vmovl_u8(va5));
+ const int16x8_t vxa6 = vreinterpretq_s16_u16(vmovl_u8(va6));
+ const int16x8_t vxa7 = vreinterpretq_s16_u16(vmovl_u8(va7));
+
+ {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 0);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 0);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 0);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 0);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 0);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 0);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 0);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 0);
+ }
+
+ {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 1);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 1);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 1);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 1);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 1);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 1);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 1);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 1);
+ }
+
+ {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 2);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 2);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 2);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 2);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 2);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 2);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 2);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 2);
+ }
+
+ {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 3);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 3);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 3);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 3);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 3);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 3);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 3);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 3);
+ }
+
+ {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 0);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 0);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 0);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 0);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 0);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 0);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 0);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 0);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa4), 0);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa4), 0);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa5), 0);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa5), 0);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa6), 0);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa6), 0);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa7), 0);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa7), 0);
+ }
+
+ {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 1);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 1);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 1);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 1);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 1);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 1);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 1);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 1);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa4), 1);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa4), 1);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa5), 1);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa5), 1);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa6), 1);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa6), 1);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa7), 1);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa7), 1);
+ }
+
+ {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 2);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 2);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 2);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 2);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 2);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 2);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 2);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 2);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa4), 2);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa4), 2);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa5), 2);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa5), 2);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa6), 2);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa6), 2);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa7), 2);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa7), 2);
+ }
+
+ {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 3);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 3);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 3);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 3);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 3);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 3);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 3);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 3);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa4), 3);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa4), 3);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa5), 3);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa5), 3);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa6), 3);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa6), 3);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa7), 3);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa7), 3);
+ }
+
+ k -= 8 * sizeof(uint8_t);
+ }
+ if (k != 0) {
+ const uint8x8_t va0 = vld1_u8(a0);
+ const uint8x8_t va1 = vld1_u8(a1);
+ const uint8x8_t va2 = vld1_u8(a2);
+ const uint8x8_t va3 = vld1_u8(a3);
+ const uint8x8_t va4 = vld1_u8(a4);
+ const uint8x8_t va5 = vld1_u8(a5);
+ const uint8x8_t va6 = vld1_u8(a6);
+ const uint8x8_t va7 = vld1_u8(a7);
+ const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0));
+ const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1));
+ const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2));
+ const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3));
+ const int16x8_t vxa4 = vreinterpretq_s16_u16(vmovl_u8(va4));
+ const int16x8_t vxa5 = vreinterpretq_s16_u16(vmovl_u8(va5));
+ const int16x8_t vxa6 = vreinterpretq_s16_u16(vmovl_u8(va6));
+ const int16x8_t vxa7 = vreinterpretq_s16_u16(vmovl_u8(va7));
+
+ {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 0);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 0);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 0);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 0);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 0);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 0);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 0);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 0);
+ }
+
+ if (k >= 2 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 1);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 1);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 1);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 1);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 1);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 1);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 1);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 1);
+
+ if (k > 2 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 2);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 2);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 2);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 2);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 2);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 2);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 2);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 2);
+
+ if (k >= 4 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 3);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 3);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 3);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 3);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 3);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 3);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 3);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 3);
+
+ if (k > 4 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 0);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 0);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 0);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 0);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 0);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 0);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 0);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 0);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa4), 0);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa4), 0);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa5), 0);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa5), 0);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa6), 0);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa6), 0);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa7), 0);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa7), 0);
+
+ if (k >= 6 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 1);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 1);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 1);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 1);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 1);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 1);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 1);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 1);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa4), 1);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa4), 1);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa5), 1);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa5), 1);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa6), 1);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa6), 1);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa7), 1);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa7), 1);
+
+ if (k > 6 * sizeof(uint8_t)) {
+ const uint8x8_t vb01234567 = vld1_u8(w); w = (void*) ((uintptr_t) w + sizeof(uint8x8_t));
+ const int16x8_t vxb01234567 = vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
+
+ vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 2);
+ vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 2);
+ vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 2);
+ vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 2);
+ vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 2);
+ vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 2);
+ vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 2);
+ vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 2);
+ vacc4x0123 = vmlal_lane_s16(vacc4x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa4), 2);
+ vacc4x4567 = vmlal_lane_s16(vacc4x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa4), 2);
+ vacc5x0123 = vmlal_lane_s16(vacc5x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa5), 2);
+ vacc5x4567 = vmlal_lane_s16(vacc5x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa5), 2);
+ vacc6x0123 = vmlal_lane_s16(vacc6x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa6), 2);
+ vacc6x4567 = vmlal_lane_s16(vacc6x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa6), 2);
+ vacc7x0123 = vmlal_lane_s16(vacc7x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa7), 2);
+ vacc7x4567 = vmlal_lane_s16(vacc7x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa7), 2);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ p -= 8 * sizeof(void*);
+ } while (p != 0);
+
+ const int32x4_t vmultiplier = vld1q_dup_s32(¶ms->neon.multiplier);
+ vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
+ vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
+ vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
+ vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
+ vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
+ vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
+ vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier);
+ vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier);
+ vacc4x0123 = vqrdmulhq_s32(vacc4x0123, vmultiplier);
+ vacc4x4567 = vqrdmulhq_s32(vacc4x4567, vmultiplier);
+ vacc5x0123 = vqrdmulhq_s32(vacc5x0123, vmultiplier);
+ vacc5x4567 = vqrdmulhq_s32(vacc5x4567, vmultiplier);
+ vacc6x0123 = vqrdmulhq_s32(vacc6x0123, vmultiplier);
+ vacc6x4567 = vqrdmulhq_s32(vacc6x4567, vmultiplier);
+ vacc7x0123 = vqrdmulhq_s32(vacc7x0123, vmultiplier);
+ vacc7x4567 = vqrdmulhq_s32(vacc7x4567, vmultiplier);
+
+ const int32x4_t vright_shift = vld1q_dup_s32(¶ms->neon.right_shift);
+ const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
+ vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
+ vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
+ vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
+ vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
+ vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
+ vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
+ vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31);
+ vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31);
+ vacc4x0123 = vsraq_n_s32(vacc4x0123, vbicq_s32(vacc4x0123, vzero_shift_mask), 31);
+ vacc4x4567 = vsraq_n_s32(vacc4x4567, vbicq_s32(vacc4x4567, vzero_shift_mask), 31);
+ vacc5x0123 = vsraq_n_s32(vacc5x0123, vbicq_s32(vacc5x0123, vzero_shift_mask), 31);
+ vacc5x4567 = vsraq_n_s32(vacc5x4567, vbicq_s32(vacc5x4567, vzero_shift_mask), 31);
+ vacc6x0123 = vsraq_n_s32(vacc6x0123, vbicq_s32(vacc6x0123, vzero_shift_mask), 31);
+ vacc6x4567 = vsraq_n_s32(vacc6x4567, vbicq_s32(vacc6x4567, vzero_shift_mask), 31);
+ vacc7x0123 = vsraq_n_s32(vacc7x0123, vbicq_s32(vacc7x0123, vzero_shift_mask), 31);
+ vacc7x4567 = vsraq_n_s32(vacc7x4567, vbicq_s32(vacc7x4567, vzero_shift_mask), 31);
+
+ vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
+ vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
+ vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
+ vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
+ vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
+ vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
+ vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift);
+ vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift);
+ vacc4x0123 = vrshlq_s32(vacc4x0123, vright_shift);
+ vacc4x4567 = vrshlq_s32(vacc4x4567, vright_shift);
+ vacc5x0123 = vrshlq_s32(vacc5x0123, vright_shift);
+ vacc5x4567 = vrshlq_s32(vacc5x4567, vright_shift);
+ vacc6x0123 = vrshlq_s32(vacc6x0123, vright_shift);
+ vacc6x4567 = vrshlq_s32(vacc6x4567, vright_shift);
+ vacc7x0123 = vrshlq_s32(vacc7x0123, vright_shift);
+ vacc7x4567 = vrshlq_s32(vacc7x4567, vright_shift);
+
+ const int16x8_t voutput_zero_point = vld1q_dup_s16(¶ms->neon.output_zero_point);
+ #ifdef __aarch64__
+ const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
+ const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
+ const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
+ const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);
+ const int16x8_t vacc4x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc4x0123), vacc4x4567), voutput_zero_point);
+ const int16x8_t vacc5x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc5x0123), vacc5x4567), voutput_zero_point);
+ const int16x8_t vacc6x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc6x0123), vacc6x4567), voutput_zero_point);
+ const int16x8_t vacc7x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc7x0123), vacc7x4567), voutput_zero_point);
+
+ uint8x16_t vout0x01234567_1x01234567 = vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567);
+ uint8x16_t vout2x01234567_3x01234567 = vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567);
+ uint8x16_t vout4x01234567_5x01234567 = vqmovun_high_s16(vqmovun_s16(vacc4x01234567), vacc5x01234567);
+ uint8x16_t vout6x01234567_7x01234567 = vqmovun_high_s16(vqmovun_s16(vacc6x01234567), vacc7x01234567);
+ #else
+ const int16x8_t vacc0x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
+ const int16x8_t vacc1x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
+ const int16x8_t vacc2x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
+ const int16x8_t vacc3x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point);
+ const int16x8_t vacc4x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc4x0123), vqmovn_s32(vacc4x4567)), voutput_zero_point);
+ const int16x8_t vacc5x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc5x0123), vqmovn_s32(vacc5x4567)), voutput_zero_point);
+ const int16x8_t vacc6x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc6x0123), vqmovn_s32(vacc6x4567)), voutput_zero_point);
+ const int16x8_t vacc7x01234567 =
+ vqaddq_s16(vcombine_s16(vqmovn_s32(vacc7x0123), vqmovn_s32(vacc7x4567)), voutput_zero_point);
+
+ uint8x16_t vout0x01234567_1x01234567 = vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567));
+ uint8x16_t vout2x01234567_3x01234567 = vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567));
+ uint8x16_t vout4x01234567_5x01234567 = vcombine_u8(vqmovun_s16(vacc4x01234567), vqmovun_s16(vacc5x01234567));
+ uint8x16_t vout6x01234567_7x01234567 = vcombine_u8(vqmovun_s16(vacc6x01234567), vqmovun_s16(vacc7x01234567));
+ #endif
+ const uint8x16_t voutput_min = vld1q_dup_u8(¶ms->neon.output_min);
+ const uint8x16_t voutput_max = vld1q_dup_u8(¶ms->neon.output_max);
+
+ vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min);
+ vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min);
+ vout4x01234567_5x01234567 = vmaxq_u8(vout4x01234567_5x01234567, voutput_min);
+ vout6x01234567_7x01234567 = vmaxq_u8(vout6x01234567_7x01234567, voutput_min);
+ vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max);
+ vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max);
+ vout4x01234567_5x01234567 = vminq_u8(vout4x01234567_5x01234567, voutput_max);
+ vout6x01234567_7x01234567 = vminq_u8(vout6x01234567_7x01234567, voutput_max);
+
+ if XNN_LIKELY(nc >= 8) {
+ vst1_u8(c7, vget_high_u8(vout6x01234567_7x01234567)); c7 += cn_stride;
+ vst1_u8(c6, vget_low_u8(vout6x01234567_7x01234567)); c6 += cn_stride;
+ vst1_u8(c5, vget_high_u8(vout4x01234567_5x01234567)); c5 += cn_stride;
+ vst1_u8(c4, vget_low_u8(vout4x01234567_5x01234567)); c4 += cn_stride;
+ vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567)); c3 += cn_stride;
+ vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); c2 += cn_stride;
+ vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); c1 += cn_stride;
+ vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); c0 += cn_stride;
+
+ a = (const uint8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -= 8;
+ } else {
+ if (nc & 4) {
+ vst1q_lane_u32(__builtin_assume_aligned(c7, 1), vreinterpretq_u32_u8(vout6x01234567_7x01234567), 2); c7 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c6, 1), vreinterpretq_u32_u8(vout6x01234567_7x01234567), 0); c6 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c5, 1), vreinterpretq_u32_u8(vout4x01234567_5x01234567), 2); c5 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c4, 1), vreinterpretq_u32_u8(vout4x01234567_5x01234567), 0); c4 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 2); c3 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 0); c2 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 2); c1 += 4;
+ vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 0); c0 += 4;
+ vout6x01234567_7x01234567 = vextq_u8(vout6x01234567_7x01234567, vout6x01234567_7x01234567, 4);
+ vout4x01234567_5x01234567 = vextq_u8(vout4x01234567_5x01234567, vout4x01234567_5x01234567, 4);
+ vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
+ vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
+ }
+ if (nc & 2) {
+ vst1q_lane_u16(__builtin_assume_aligned(c7, 1), vreinterpretq_u16_u8(vout6x01234567_7x01234567), 4); c7 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c6, 1), vreinterpretq_u16_u8(vout6x01234567_7x01234567), 0); c6 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c5, 1), vreinterpretq_u16_u8(vout4x01234567_5x01234567), 4); c5 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c4, 1), vreinterpretq_u16_u8(vout4x01234567_5x01234567), 0); c4 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 4); c3 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 0); c2 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 4); c1 += 2;
+ vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 0); c0 += 2;
+ vout6x01234567_7x01234567 = vextq_u8(vout6x01234567_7x01234567, vout6x01234567_7x01234567, 2);
+ vout4x01234567_5x01234567 = vextq_u8(vout4x01234567_5x01234567, vout4x01234567_5x01234567, 2);
+ vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
+ vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
+ }
+ if (nc & 1) {
+ vst1q_lane_u8(c7, vout6x01234567_7x01234567, 8);
+ vst1q_lane_u8(c6, vout6x01234567_7x01234567, 0);
+ vst1q_lane_u8(c5, vout4x01234567_5x01234567, 8);
+ vst1q_lane_u8(c4, vout4x01234567_5x01234567, 0);
+ vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8);
+ vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0);
+ vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8);
+ vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0);
+ }
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/q8-vadd/neon.c b/src/q8-vadd/neon.c
new file mode 100644
index 0000000..5b3b421
--- /dev/null
+++ b/src/q8-vadd/neon.c
@@ -0,0 +1,248 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <arm_neon.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/vadd.h>
+
+
/*
 * Q8 element-wise addition microkernel, NEON implementation.
 *
 * For each of the n input bytes: subtracts the per-operand zero points,
 * multiplies both operands by fixed-point multipliers, accumulates in 32 bits,
 * applies a rounding arithmetic right shift, adds the output zero point with
 * saturation, and clamps the packed result to [y_min, y_max].
 *
 *   n      - number of elements (bytes) to process.
 *   a, b   - input operand arrays of n uint8_t values.
 *   y      - output array of n uint8_t values.
 *   params - precomputed requantization parameters (NEON layout).
 *
 * NOTE(review): the remainder path loads a full 8-byte vector even when fewer
 * than 8 elements remain — presumably relies on XNNPACK's padded-buffer
 * convention; confirm callers guarantee readable slack past the arrays.
 */
void xnn_q8_vadd_ukernel__neon(
    size_t n,
    const uint8_t* a,
    const uint8_t* b,
    uint8_t* y,
    const union xnn_q8_add_params params[restrict static 1])
{
  const uint8x8_t va_zero_point = vld1_dup_u8(&params->neon.a_zero_point);
  const uint8x8_t vb_zero_point = vld1_dup_u8(&params->neon.b_zero_point);
  const int16x8_t vy_zero_point = vld1q_dup_s16(&params->neon.y_zero_point);
  const int32x4_t va_multiplier = vld1q_dup_s32(&params->neon.a_multiplier);
  const int32x4_t vb_multiplier = vld1q_dup_s32(&params->neon.b_multiplier);
  const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
  /* All-ones lanes where right_shift == 0: suppresses the pre-rounding sign
   * adjustment below when no shift (and therefore no rounding) happens. */
  const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
  const uint8x16_t vy_max = vld1q_dup_u8(&params->neon.y_max);
  const uint8x16_t vy_min = vld1q_dup_u8(&params->neon.y_min);
#ifdef __aarch64__
  /* AArch64 main loop: 32 elements per iteration. */
  for (; n >= 32 * sizeof(uint8_t); n -= 32 * sizeof(uint8_t)) {
    const uint8x16_t va01 = vld1q_u8(a); a += 16;
    const uint8x16_t vb01 = vld1q_u8(b); b += 16;
    const uint8x16_t va23 = vld1q_u8(a); a += 16;
    const uint8x16_t vb23 = vld1q_u8(b); b += 16;

    /* Subtract zero point */
    const int16x8_t vxa0 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(va01), va_zero_point));
    const int16x8_t vxb0 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vb01), vb_zero_point));
    const int16x8_t vxa1 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(va01), va_zero_point));
    const int16x8_t vxb1 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vb01), vb_zero_point));
    const int16x8_t vxa2 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(va23), va_zero_point));
    const int16x8_t vxb2 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vb23), vb_zero_point));
    const int16x8_t vxa3 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(va23), va_zero_point));
    const int16x8_t vxb3 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vb23), vb_zero_point));

    /* Multiply by factors and accumulate products */
    int32x4_t vacc0_lo = vmulq_s32(vmovl_s16(vget_low_s16(vxa0)), va_multiplier);
    int32x4_t vacc1_lo = vmulq_s32(vmovl_s16(vget_low_s16(vxa1)), va_multiplier);
    int32x4_t vacc2_lo = vmulq_s32(vmovl_s16(vget_low_s16(vxa2)), va_multiplier);
    int32x4_t vacc3_lo = vmulq_s32(vmovl_s16(vget_low_s16(vxa3)), va_multiplier);
    int32x4_t vacc0_hi = vmulq_s32(vmovl_high_s16(vxa0), va_multiplier);
    int32x4_t vacc1_hi = vmulq_s32(vmovl_high_s16(vxa1), va_multiplier);
    int32x4_t vacc2_hi = vmulq_s32(vmovl_high_s16(vxa2), va_multiplier);
    int32x4_t vacc3_hi = vmulq_s32(vmovl_high_s16(vxa3), va_multiplier);

    vacc0_lo = vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vxb0)), vb_multiplier);
    vacc1_lo = vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vxb1)), vb_multiplier);
    vacc2_lo = vmlaq_s32(vacc2_lo, vmovl_s16(vget_low_s16(vxb2)), vb_multiplier);
    vacc3_lo = vmlaq_s32(vacc3_lo, vmovl_s16(vget_low_s16(vxb3)), vb_multiplier);
    vacc0_hi = vmlaq_s32(vacc0_hi, vmovl_high_s16(vxb0), vb_multiplier);
    vacc1_hi = vmlaq_s32(vacc1_hi, vmovl_high_s16(vxb1), vb_multiplier);
    vacc2_hi = vmlaq_s32(vacc2_hi, vmovl_high_s16(vxb2), vb_multiplier);
    vacc3_hi = vmlaq_s32(vacc3_hi, vmovl_high_s16(vxb3), vb_multiplier);

    /* Shift right and round */
    vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
    vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
    vacc2_lo = vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31);
    vacc3_lo = vsraq_n_s32(vacc3_lo, vbicq_s32(vacc3_lo, vzero_shift_mask), 31);
    vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
    vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);
    vacc2_hi = vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31);
    vacc3_hi = vsraq_n_s32(vacc3_hi, vbicq_s32(vacc3_hi, vzero_shift_mask), 31);

    vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
    vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
    vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift);
    vacc3_lo = vrshlq_s32(vacc3_lo, vright_shift);
    vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
    vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);
    vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift);
    vacc3_hi = vrshlq_s32(vacc3_hi, vright_shift);

    /* Pack, saturate, and add output zero point */
    const int16x8_t vacc0 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), vy_zero_point);
    const int16x8_t vacc1 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), vy_zero_point);
    const int16x8_t vacc2 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), vy_zero_point);
    const int16x8_t vacc3 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3_lo), vacc3_hi), vy_zero_point);

    uint8x16_t vy01 = vqmovun_high_s16(vqmovun_s16(vacc0), vacc1);
    uint8x16_t vy23 = vqmovun_high_s16(vqmovun_s16(vacc2), vacc3);

    vy01 = vmaxq_u8(vy01, vy_min);
    vy23 = vmaxq_u8(vy23, vy_min);
    vy01 = vminq_u8(vy01, vy_max);
    vy23 = vminq_u8(vy23, vy_max);

    vst1q_u8(y, vy01); y += 16;
    vst1q_u8(y, vy23); y += 16;
  }
#else
  /* AArch32 main loop: 16 elements per iteration (fewer registers). */
  for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
    const uint8x16_t va01 = vld1q_u8(a); a += 16;
    const uint8x16_t vb01 = vld1q_u8(b); b += 16;

    /* Subtract zero point */
    const int16x8_t vxa0 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(va01), va_zero_point));
    const int16x8_t vxb0 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vb01), vb_zero_point));
    const int16x8_t vxa1 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(va01), va_zero_point));
    const int16x8_t vxb1 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vb01), vb_zero_point));

    /* Multiply by factors and accumulate products */
    int32x4_t vacc0_lo = vmulq_s32(vmovl_s16(vget_low_s16(vxa0)), va_multiplier);
    int32x4_t vacc1_lo = vmulq_s32(vmovl_s16(vget_low_s16(vxa1)), va_multiplier);
    int32x4_t vacc0_hi = vmulq_s32(vmovl_s16(vget_high_s16(vxa0)), va_multiplier);
    int32x4_t vacc1_hi = vmulq_s32(vmovl_s16(vget_high_s16(vxa1)), va_multiplier);

    __builtin_prefetch(a + 640);
    __builtin_prefetch(b + 640);

    vacc0_lo = vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vxb0)), vb_multiplier);
    vacc1_lo = vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vxb1)), vb_multiplier);
    vacc0_hi = vmlaq_s32(vacc0_hi, vmovl_s16(vget_high_s16(vxb0)), vb_multiplier);
    vacc1_hi = vmlaq_s32(vacc1_hi, vmovl_s16(vget_high_s16(vxb1)), vb_multiplier);

    /* Shift right and round */
    vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
    vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
    vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
    vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);

    vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
    vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
    vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
    vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);

    /* Pack, saturate, and add output zero point */
    const int16x8_t vacc0 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0_lo), vqmovn_s32(vacc0_hi)), vy_zero_point);
    const int16x8_t vacc1 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1_lo), vqmovn_s32(vacc1_hi)), vy_zero_point);

    uint8x16_t vy01 = vcombine_u8(vqmovun_s16(vacc0), vqmovun_s16(vacc1));
    vy01 = vmaxq_u8(vy01, vy_min);
    vy01 = vminq_u8(vy01, vy_max);

    vst1q_u8(y, vy01); y += 16;
  }
#endif
  /* 8-element tail loop. */
  for (; n >= 8 * sizeof(uint8_t); n -= 8 * sizeof(uint8_t)) {
    const uint8x8_t va = vld1_u8(a); a += 8;
    const uint8x8_t vb = vld1_u8(b); b += 8;

    /* Subtract zero point */
    const int16x8_t vxa = vreinterpretq_s16_u16(vsubl_u8(va, va_zero_point));
    const int16x8_t vxb = vreinterpretq_s16_u16(vsubl_u8(vb, vb_zero_point));

    /* Multiply by factors and accumulate products */
    int32x4_t vacc_lo = vmulq_s32(vmovl_s16(vget_low_s16(vxa)), va_multiplier);
#ifdef __aarch64__
    int32x4_t vacc_hi = vmulq_s32(vmovl_high_s16(vxa), va_multiplier);
#else
    int32x4_t vacc_hi = vmulq_s32(vmovl_s16(vget_high_s16(vxa)), va_multiplier);
#endif

    vacc_lo = vmlaq_s32(vacc_lo, vmovl_s16(vget_low_s16(vxb)), vb_multiplier);
#ifdef __aarch64__
    vacc_hi = vmlaq_s32(vacc_hi, vmovl_high_s16(vxb), vb_multiplier);
#else
    vacc_hi = vmlaq_s32(vacc_hi, vmovl_s16(vget_high_s16(vxb)), vb_multiplier);
#endif

    /* Shift right and round */
    vacc_lo = vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
    vacc_hi = vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);

    vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
    vacc_hi = vrshlq_s32(vacc_hi, vright_shift);

    /* Pack, saturate, and add output zero point */
#ifdef __aarch64__
    const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), vy_zero_point);
#else
    const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), vy_zero_point);
#endif

    uint8x8_t vy = vqmovun_s16(vacc);
    vy = vmax_u8(vy, vget_low_u8(vy_min));
    vy = vmin_u8(vy, vget_low_u8(vy_max));

    vst1_u8(y, vy); y += 8;
  }
  /* Remainder of 1-7 elements: compute a full 8-wide result, then store it
   * piecewise (4/2/1 bytes) without writing past the output buffer. */
  if (n != 0) {
    const uint8x8_t va = vld1_u8(a);
    const uint8x8_t vb = vld1_u8(b);

    /* Subtract zero point */
    const int16x8_t vxa = vreinterpretq_s16_u16(vsubl_u8(va, va_zero_point));
    const int16x8_t vxb = vreinterpretq_s16_u16(vsubl_u8(vb, vb_zero_point));

    /* Multiply by factors and accumulate products */
    int32x4_t vacc_lo = vmulq_s32(vmovl_s16(vget_low_s16(vxa)), va_multiplier);
#ifdef __aarch64__
    int32x4_t vacc_hi = vmulq_s32(vmovl_high_s16(vxa), va_multiplier);
#else
    int32x4_t vacc_hi = vmulq_s32(vmovl_s16(vget_high_s16(vxa)), va_multiplier);
#endif

    vacc_lo = vmlaq_s32(vacc_lo, vmovl_s16(vget_low_s16(vxb)), vb_multiplier);
#ifdef __aarch64__
    vacc_hi = vmlaq_s32(vacc_hi, vmovl_high_s16(vxb), vb_multiplier);
#else
    vacc_hi = vmlaq_s32(vacc_hi, vmovl_s16(vget_high_s16(vxb)), vb_multiplier);
#endif

    /* Shift right and round */
    vacc_lo = vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
    vacc_hi = vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);

    vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
    vacc_hi = vrshlq_s32(vacc_hi, vright_shift);

    /* Pack, saturate, and add output zero point */
#ifdef __aarch64__
    const int16x8_t vacc = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), vy_zero_point);
#else
    const int16x8_t vacc = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), vy_zero_point);
#endif

    uint8x8_t vy = vqmovun_s16(vacc);
    vy = vmax_u8(vy, vget_low_u8(vy_min));
    vy = vmin_u8(vy, vget_low_u8(vy_max));

    if (n & (4 * sizeof(uint8_t))) {
      vst1_lane_u32(__builtin_assume_aligned(y, 1), vreinterpret_u32_u8(vy), 0); y += 4;
      vy = vext_u8(vy, vy, 4);
    }
    if (n & (2 * sizeof(uint8_t))) {
      vst1_lane_u16(__builtin_assume_aligned(y, 1), vreinterpret_u16_u8(vy), 0); y += 2;
      vy = vext_u8(vy, vy, 2);
    }
    if (n & (1 * sizeof(uint8_t))) {
      vst1_lane_u8(y, vy, 0);
    }
  }
}
diff --git a/src/q8-vadd/scalar.c b/src/q8-vadd/scalar.c
new file mode 100644
index 0000000..6c20a0e
--- /dev/null
+++ b/src/q8-vadd/scalar.c
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/scalar-utils.h>
+#include <xnnpack/vadd.h>
+
+
+void xnn_q8_vadd_ukernel__scalar(
+ size_t n,
+ const uint8_t* a,
+ const uint8_t* b,
+ uint8_t* y,
+ const union xnn_q8_add_params params[restrict static 1])
+{
+ assert(n != 0);
+
+ const int32_t vzero_point_product = params->scalar.zero_point_product;
+ const uint32_t va_multiplier = params->scalar.a_multiplier;
+ const uint32_t vb_multiplier = params->scalar.b_multiplier;
+ const uint32_t vshift = params->scalar.shift;
+ const int32_t vremainder_mask = params->scalar.remainder_mask;
+ const int32_t vremainder_threshold = params->scalar.remainder_threshold;
+ const int32_t vy_zero_point = params->scalar.y_zero_point;
+ const int32_t vy_max = params->scalar.y_max;
+ const int32_t vy_min = params->scalar.y_min;
+
+ do {
+ const int32_t va = (int32_t) (uint32_t) *a++;
+ const int32_t vb = (int32_t) (uint32_t) *b++;
+
+ /* Multiply by factors */
+ const int32_t va_product = va * va_multiplier;
+ const int32_t vb_product = vb * vb_multiplier;
+
+ /* Accumulate products */
+ const int32_t vacc = vzero_point_product + va_product + vb_product;
+
+ /* Shift right and round */
+ const int32_t vremainder = (vacc & vremainder_mask) - (int32_t) (vacc < 0);
+ int32_t vy = asr_s32(vacc, vshift) + (int32_t) (vremainder > vremainder_threshold);
+
+ /* Pack, saturate, and add output zero point */
+ vy += vy_zero_point;
+ vy = vy < vy_min ? vy_min : vy;
+ vy = vy > vy_max ? vy_max : vy;
+
+ *y++ = vy;
+
+ n -= sizeof(uint8_t);
+ } while (n != 0);
+}
diff --git a/src/q8-vadd/sse2.c b/src/q8-vadd/sse2.c
new file mode 100644
index 0000000..c962f23
--- /dev/null
+++ b/src/q8-vadd/sse2.c
@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <immintrin.h>
+
+#include <xnnpack/common.h>
+#include <xnnpack/scalar-utils.h>
+#include <xnnpack/vadd.h>
+
+
+void xnn_q8_vadd_ukernel__sse2(
+ size_t n,
+ const uint8_t* a,
+ const uint8_t* b,
+ uint8_t* y,
+ const union xnn_q8_add_params params[restrict static 1])
+{
+ const __m128i vzero_point_product = _mm_load_si128((const __m128i*) ¶ms->sse2.zero_point_product);
+ const __m128i va_multiplier_lo = _mm_load_si128((const __m128i*) ¶ms->sse2.a_multiplier_lo);
+ const __m128i va_multiplier_hi = _mm_load_si128((const __m128i*) ¶ms->sse2.a_multiplier_hi);
+ const __m128i vb_multiplier_lo = _mm_load_si128((const __m128i*) ¶ms->sse2.b_multiplier_lo);
+ const __m128i vb_multiplier_hi = _mm_load_si128((const __m128i*) ¶ms->sse2.b_multiplier_hi);
+ const __m128i vremainder_mask = _mm_load_si128((const __m128i*) params->sse2.remainder_mask);
+ const __m128i vremainder_threshold = _mm_load_si128((const __m128i*) params->sse2.remainder_threshold);
+ const __m128i vshift = _mm_cvtsi32_si128((int) params->sse2.shift);
+
+ const __m128i vzero = _mm_setzero_si128();
+ for (; n >= 8 * sizeof(uint8_t); n -= 8 * sizeof(uint8_t)) {
+ const __m128i va = _mm_loadl_epi64((const __m128i*) a);
+ a += 8;
+ const __m128i vb = _mm_loadl_epi64((const __m128i*) b);
+ b += 8;
+
+ const __m128i vxa = _mm_unpacklo_epi8(va, vzero);
+ const __m128i vxb = _mm_unpacklo_epi8(vb, vzero);
+
+ /* Multiply by factors */
+ const __m128i va_product_lo = _mm_mullo_epi16(vxa, va_multiplier_lo);
+ const __m128i va_product_hi =
+ _mm_add_epi16(_mm_mulhi_epu16(vxa, va_multiplier_lo), _mm_mullo_epi16(vxa, va_multiplier_hi));
+
+ const __m128i vb_product_lo = _mm_mullo_epi16(vxb, vb_multiplier_lo);
+ const __m128i vb_product_hi =
+ _mm_add_epi16(_mm_mulhi_epu16(vxb, vb_multiplier_lo), _mm_mullo_epi16(vxb, vb_multiplier_hi));
+
+ /* Accumulate products */
+ __m128i vacc_lo = _mm_add_epi32(vzero_point_product, _mm_unpacklo_epi16(va_product_lo, va_product_hi));
+ __m128i vacc_hi = _mm_add_epi32(vzero_point_product, _mm_unpackhi_epi16(va_product_lo, va_product_hi));
+
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vb_product_lo, vb_product_hi));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vb_product_lo, vb_product_hi));
+
+ /* Shift right and round */
+ const __m128i vrem_lo =
+ _mm_add_epi32(_mm_and_si128(vacc_lo, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo));
+ const __m128i vrem_hi =
+ _mm_add_epi32(_mm_and_si128(vacc_hi, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi));
+
+ vacc_lo = _mm_sub_epi32(_mm_sra_epi32(vacc_lo, vshift), _mm_cmpgt_epi32(vrem_lo, vremainder_threshold));
+ vacc_hi = _mm_sub_epi32(_mm_sra_epi32(vacc_hi, vshift), _mm_cmpgt_epi32(vrem_hi, vremainder_threshold));
+
+ /* Pack, saturate, and add output zero point */
+ const __m128i vy_zero_point = _mm_load_si128((const __m128i*) params->sse2.y_zero_point);
+ const __m128i vacc = _mm_adds_epi16(_mm_packs_epi32(vacc_lo, vacc_hi), vy_zero_point);
+ __m128i vy = _mm_packus_epi16(vacc, vacc);
+ vy = _mm_max_epu8(vy, _mm_load_si128((const __m128i*) params->sse2.y_min));
+ vy = _mm_min_epu8(vy, _mm_load_si128((const __m128i*) params->sse2.y_max));
+
+ _mm_storel_epi64((__m128i*) y, vy);
+ y += 8;
+ }
+ if (n != 0) {
+ const __m128i va = _mm_loadl_epi64((const __m128i*) a);
+ const __m128i vb = _mm_loadl_epi64((const __m128i*) b);
+
+ const __m128i vxa = _mm_unpacklo_epi8(va, vzero);
+ const __m128i vxb = _mm_unpacklo_epi8(vb, vzero);
+
+ /* Multiply by factors */
+ const __m128i va_product_lo = _mm_mullo_epi16(vxa, va_multiplier_lo);
+ const __m128i va_product_hi =
+ _mm_add_epi16(_mm_mulhi_epu16(vxa, va_multiplier_lo), _mm_mullo_epi16(vxa, va_multiplier_hi));
+
+ const __m128i vb_product_lo = _mm_mullo_epi16(vxb, vb_multiplier_lo);
+ const __m128i vb_product_hi =
+ _mm_add_epi16(_mm_mulhi_epu16(vxb, vb_multiplier_lo), _mm_mullo_epi16(vxb, vb_multiplier_hi));
+
+ /* Accumulate products */
+ __m128i vacc_lo = _mm_add_epi32(vzero_point_product, _mm_unpacklo_epi16(va_product_lo, va_product_hi));
+ __m128i vacc_hi = _mm_add_epi32(vzero_point_product, _mm_unpackhi_epi16(va_product_lo, va_product_hi));
+
+ vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vb_product_lo, vb_product_hi));
+ vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vb_product_lo, vb_product_hi));
+
+ /* Shift right and round */
+ const __m128i vrem_lo =
+ _mm_add_epi32(_mm_and_si128(vacc_lo, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo));
+ const __m128i vrem_hi =
+ _mm_add_epi32(_mm_and_si128(vacc_hi, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi));
+
+ vacc_lo = _mm_sub_epi32(_mm_sra_epi32(vacc_lo, vshift), _mm_cmpgt_epi32(vrem_lo, vremainder_threshold));
+ vacc_hi = _mm_sub_epi32(_mm_sra_epi32(vacc_hi, vshift), _mm_cmpgt_epi32(vrem_hi, vremainder_threshold));
+
+ /* Pack, saturate, and add output zero point */
+ const __m128i vy_zero_point = _mm_load_si128((const __m128i*) params->sse2.y_zero_point);
+ const __m128i vacc = _mm_adds_epi16(_mm_packs_epi32(vacc_lo, vacc_hi), vy_zero_point);
+ __m128i vy = _mm_packus_epi16(vacc, vacc);
+ vy = _mm_max_epu8(vy, _mm_load_si128((const __m128i*) params->sse2.y_min));
+ vy = _mm_min_epu8(vy, _mm_load_si128((const __m128i*) params->sse2.y_max));
+
+ if (n & (4 * sizeof(uint8_t))) {
+ *((uint32_t*) y) = (uint32_t) _mm_cvtsi128_si32(vy);
+ vy = _mm_srli_epi64(vy, 32);
+ y += 4;
+ }
+ if (n & (2 * sizeof(uint8_t))) {
+ *((uint16_t*) y) = (uint16_t) _mm_extract_epi16(vy, 0);
+ vy = _mm_srli_epi32(vy, 16);
+ y += 2;
+ }
+ if (n & (1 * sizeof(uint8_t))) {
+ *((uint8_t*) y) = (uint8_t) _mm_cvtsi128_si32(vy);
+ }
+ }
+}
diff --git a/src/requantization/fp32-neon.c b/src/requantization/fp32-neon.c
new file mode 100644
index 0000000..4d2f67c
--- /dev/null
+++ b/src/requantization/fp32-neon.c
@@ -0,0 +1,152 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/requantization-stubs.h>
+
+
+/*
+ * Requantizes n (multiple of 16) int32 values to uint8:
+ *   output[i] = clamp(round(input[i] * scale) + zero_point, qmin, qmax)
+ * using ARM NEON. Rounding is to nearest with ties to even on both paths.
+ */
+void xnn_requantize_fp32__neon(
+ size_t n,
+ const int32_t* input,
+ float scale,
+ uint8_t zero_point,
+ uint8_t qmin,
+ uint8_t qmax,
+ uint8_t* output)
+{
+ assert(n % 16 == 0);
+ assert(scale < 1.0f);
+ assert(scale >= 0x1.0p-32f);
+
+ const float32x4_t vscale = vdupq_n_f32(scale);
+#ifdef __aarch64__
+ /* AArch64: round in the integer domain, add zero point after narrowing to int16 */
+ const int16x8_t vzero_point = vdupq_n_s16((int16_t)(uint16_t) zero_point);
+ const uint8x16_t vqmin = vdupq_n_u8(qmin);
+ const uint8x16_t vqmax = vdupq_n_u8(qmax);
+#else
+ /* AArch32: clamp in the FP32 domain; zero point is folded into the integer magic bias */
+ const float32x4_t vfmin = vdupq_n_f32((float) ((int32_t)(uint32_t) qmin - (int32_t)(uint32_t) zero_point));
+ const float32x4_t vfmax = vdupq_n_f32((float) ((int32_t)(uint32_t) qmax - (int32_t)(uint32_t) zero_point));
+ const float32x4_t vfmagic = vdupq_n_f32(12582912.0f); /* 1.5 * 2**23 */
+ const int32x4_t vimagic = vdupq_n_s32(INT32_C(0x4B400000) - (int32_t)(uint32_t) zero_point);
+#endif
+ for (; n != 0; n -= 16) {
+ const int32x4_t x = vld1q_s32(input);
+ const int32x4_t y = vld1q_s32(input + 4);
+ const int32x4_t z = vld1q_s32(input + 8);
+ const int32x4_t w = vld1q_s32(input + 12);
+ input += 16;
+
+ /*
+ * Convert int32_t input to FP32 and multiply by FP32 scale.
+ * Both operations involve statistically unbiased roundings:
+ * - Large int32_t values can't be exactly represented as FP32. The conversion instruction in ARM NEON would
+ * round it to nearest FP32 value with ties to even.
+ * - Product of two FP32 values is generally not exactly representable as an FP32 value, and will be rounded
+ * to nearest FP32 value with ties to even.
+ */
+ const float32x4_t x_scaled = vmulq_f32(vcvtq_f32_s32(x), vscale);
+ const float32x4_t y_scaled = vmulq_f32(vcvtq_f32_s32(y), vscale);
+ const float32x4_t z_scaled = vmulq_f32(vcvtq_f32_s32(z), vscale);
+ const float32x4_t w_scaled = vmulq_f32(vcvtq_f32_s32(w), vscale);
+
+#ifdef __aarch64__
+ /*
+ * Leverage "Floating-point Convert to Signed integer, rounding to nearest with ties to even" instruction.
+ * This is an ARMv8 instruction (always available in AArch64), which saturates result on overflow.
+ * We don't need to specifically consider saturated results, they will be clamped at the last stage.
+ */
+ const int32x4_t x_rounded = vcvtnq_s32_f32(x_scaled);
+ const int32x4_t y_rounded = vcvtnq_s32_f32(y_scaled);
+ const int32x4_t z_rounded = vcvtnq_s32_f32(z_scaled);
+ const int32x4_t w_rounded = vcvtnq_s32_f32(w_scaled);
+
+ /*
+ * Standard final sequence on ARM NEON:
+ * - Pack to int16_t and saturate
+ * - Add zero point
+ * - Pack to uint8_t and saturate
+ * - Clamp between qmin and qmax
+ */
+ const int16x8_t xy_packed = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(x_rounded), y_rounded), vzero_point);
+ const int16x8_t zw_packed = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(z_rounded), w_rounded), vzero_point);
+ const uint8x16_t xyzw_packed = vqmovun_high_s16(vqmovun_s16(xy_packed), zw_packed);
+ const uint8x16_t xyzw_clamped = vmaxq_u8(vminq_u8(xyzw_packed, vqmax), vqmin);
+
+ vst1q_u8(output, xyzw_clamped);
+ output += 16;
+#else
+ /*
+ * ARMv7 NEON offers only a floating-point to integer conversion instruction with rounding towards zero.
+ * In lieu of conversion instruction with rounding-to-nearest-even, we use a magic trick of adding a large
+ * number (1.5 * 2**23) to scaled value to cause rounding to integer, and then subtracting this magic number as
+ * integer. This trick works only in a limited range (absolute value of input must be less than 2**22), so
+ * generally we have to clamp input to this range before using the magic. However, clamping to any smaller range
+ * works just as well, and thus we clamp to [qmin - zero point, qmax - zero point] range so that after we add
+ * zero point to the result, it gets into target [qmin, qmax] range.
+ */
+ const float32x4_t x_clamped = vminq_f32(vmaxq_f32(x_scaled, vfmin), vfmax);
+ const float32x4_t y_clamped = vminq_f32(vmaxq_f32(y_scaled, vfmin), vfmax);
+ const float32x4_t z_clamped = vminq_f32(vmaxq_f32(z_scaled, vfmin), vfmax);
+ const float32x4_t w_clamped = vminq_f32(vmaxq_f32(w_scaled, vfmin), vfmax);
+
+ /*
+ * Conversion to integer using the "magic trick". Rounding is performed in the output of addition operation,
+ * and result is rounded to nearest integer with ties to even.
+ */
+ const int32x4_t x_biased = vsubq_s32(vreinterpretq_s32_f32(vaddq_f32(x_clamped, vfmagic)), vimagic);
+ const int32x4_t y_biased = vsubq_s32(vreinterpretq_s32_f32(vaddq_f32(y_clamped, vfmagic)), vimagic);
+ const int32x4_t z_biased = vsubq_s32(vreinterpretq_s32_f32(vaddq_f32(z_clamped, vfmagic)), vimagic);
+ const int32x4_t w_biased = vsubq_s32(vreinterpretq_s32_f32(vaddq_f32(w_clamped, vfmagic)), vimagic);
+
+ /*
+ * Select low 8 bits of each 32-bit integer in the vectors for the output.
+ * Since result is already clamped to [qmin, qmax] subrange of [0, 255], saturation is not needed.
+ */
+ const int16x8_t xy_packed = vcombine_s16(vmovn_s32(x_biased), vmovn_s32(y_biased));
+ const int16x8_t zw_packed = vcombine_s16(vmovn_s32(z_biased), vmovn_s32(w_biased));
+ const uint8x16_t xyzw_packed = vreinterpretq_u8_s8(vcombine_s8(vmovn_s16(xy_packed), vmovn_s16(zw_packed)));
+
+ /*
+ * AArch32 version:
+ * 4x VCVT.F32.S32 Qd, Qm
+ * 4x VMUL.F32 Qd, Qm, Qn
+ * 4x VMIN.F32 Qd, Qm, Qn
+ * 4x VMAX.F32 Qd, Qm, Qn
+ * 4x VADD.F32 Qd, Qm, Qn
+ * 4x VSUB.S32 Qd, Qm, Qn
+ * 4x VMOVN.I32 Dd, Qm
+ * 2x VMOVN.I16 Dd, Qm
+ * ---------------------
+ * 30 instructions total
+ *
+ * AArch64 version:
+ * 4x SCVTF Vd.4S, Vn.4S
+ * 4x FMUL Vd.4S, Vn.4S, Vm.4S
+ * 4x FCVTNS Vd.4S, Vn.4S
+ * 2x SQXTN Vd.4H, Vn.4S
+ * 2x SQXTN2 Vd.8H, Vn.4S
+ * 2x ADD Vd.8H, Vn.8H, Vm.8H
+ * 1x SQXTUN Vd.8B, Vn.8H
+ * 1x SQXTUN2 Vd.16B, Vn.8H
+ * 1x UMIN Vd.16B, Vn.16B, Vm.16B
+ * 1x UMAX Vd.16B, Vn.16B, Vm.16B
+ * ---------------------
+ * 22 instructions total
+ */
+
+ vst1q_u8(output, xyzw_packed);
+ output += 16;
+#endif
+ }
+}
diff --git a/src/requantization/fp32-psimd.c b/src/requantization/fp32-psimd.c
new file mode 100644
index 0000000..922a7e2
--- /dev/null
+++ b/src/requantization/fp32-psimd.c
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <psimd.h>
+
+#include <xnnpack/requantization-stubs.h>
+
+
+/*
+ * Requantizes n (multiple of 16) int32 values to uint8 with portable SIMD:
+ *   output[i] = clamp(round(input[i] * scale) + zero_point, qmin, qmax)
+ * using the FP32 "magic number" trick for round-to-nearest-even conversion.
+ */
+void xnn_requantize_fp32__psimd(
+ size_t n,
+ const int32_t* input,
+ float scale,
+ uint8_t zero_point,
+ uint8_t qmin,
+ uint8_t qmax,
+ uint8_t* output)
+{
+ assert(n % 16 == 0);
+ assert(scale < 1.0f);
+ assert(scale >= 0x1.0p-32f);
+
+ const psimd_f32 vscale = psimd_splat_f32(scale);
+ /* Clamp bounds are relative to the zero point, which is folded into the integer magic bias below */
+ const psimd_f32 vfmin = psimd_splat_f32((float) ((int32_t)(uint32_t) qmin - (int32_t)(uint32_t) zero_point));
+ const psimd_f32 vfmax = psimd_splat_f32((float) ((int32_t)(uint32_t) qmax - (int32_t)(uint32_t) zero_point));
+ const psimd_f32 vfmagic = psimd_splat_f32(12582912.0f); /* 1.5 * 2**23 */
+ const psimd_s32 vimagic = psimd_splat_s32(INT32_C(0x4B400000) - (int32_t)(uint32_t) zero_point);
+ for (; n != 0; n -= 16) {
+ const psimd_s32 x = psimd_load_s32(input);
+ const psimd_s32 y = psimd_load_s32(input + 4);
+ const psimd_s32 z = psimd_load_s32(input + 8);
+ const psimd_s32 w = psimd_load_s32(input + 12);
+ input += 16;
+
+ /*
+ * Convert int32_t input to FP32 and multiply by FP32 scale.
+ * Both operations involve roundings:
+ * - Large int32_t values can't be exactly represented as FP32. We expect that conversion instruction would
+ * round it to nearest FP32 value with ties to even, but Clang documentation for __builtin_convertvector does
+ * not guarantee that.
+ * - Product of two FP32 values is generally not exactly representable as an FP32 value, and will be rounded
+ * to nearest FP32 value with ties to even.
+ */
+ const psimd_f32 x_scaled = psimd_cvt_s32_f32(x) * vscale;
+ const psimd_f32 y_scaled = psimd_cvt_s32_f32(y) * vscale;
+ const psimd_f32 z_scaled = psimd_cvt_s32_f32(z) * vscale;
+ const psimd_f32 w_scaled = psimd_cvt_s32_f32(w) * vscale;
+
+ /*
+ * Clang/gcc vector extension does not provide an intrinsic for a floating-point to integer conversion
+ * operation with rounding-to-nearest-even. In lieu of such intrinsic, we use a magic trick of adding a large
+ * number (1.5 * 2**23) to scaled value to cause rounding to integer, and then subtracting this magic number as
+ * integer. This trick works only in a limited range (absolute value of input must be less than 2**22), so
+ * generally we have to clamp input to this range before using the magic. However, clamping to any smaller range
+ * works just as well, and thus we clamp to [qmin - zero point, qmax - zero point] range so that after we add
+ * zero point to the result, it gets into target [qmin, qmax] range.
+ */
+ const psimd_f32 x_clamped = psimd_min_f32(psimd_max_f32(x_scaled, vfmin), vfmax);
+ const psimd_f32 y_clamped = psimd_min_f32(psimd_max_f32(y_scaled, vfmin), vfmax);
+ const psimd_f32 z_clamped = psimd_min_f32(psimd_max_f32(z_scaled, vfmin), vfmax);
+ const psimd_f32 w_clamped = psimd_min_f32(psimd_max_f32(w_scaled, vfmin), vfmax);
+
+ /*
+ * Conversion to integer using the "magic trick". Rounding is performed in the output of addition operation,
+ * and result is rounded to nearest integer with ties to even.
+ */
+ const psimd_s32 x_biased = (psimd_s32)(x_clamped + vfmagic) - vimagic;
+ const psimd_s32 y_biased = (psimd_s32)(y_clamped + vfmagic) - vimagic;
+ const psimd_s32 z_biased = (psimd_s32)(z_clamped + vfmagic) - vimagic;
+ const psimd_s32 w_biased = (psimd_s32)(w_clamped + vfmagic) - vimagic;
+
+ /*
+ * Select low 8 bits of each 32-bit integer in the vectors for the output.
+ * Since result is already clamped to [qmin, qmax] subrange of [0, 255], saturation is not needed.
+ */
+ const psimd_u16 xy_packed = psimd_concat_even_u16((psimd_u16) x_biased, (psimd_u16) y_biased);
+ const psimd_u16 zw_packed = psimd_concat_even_u16((psimd_u16) z_biased, (psimd_u16) w_biased);
+
+ const psimd_u8 xyzw_packed = psimd_concat_even_u8((psimd_u8) xy_packed, (psimd_u8) zw_packed);
+
+ psimd_store_u8(output, xyzw_packed);
+ output += 16;
+ }
+}
diff --git a/src/requantization/fp32-scalar.c b/src/requantization/fp32-scalar.c
new file mode 100644
index 0000000..52fa1c9
--- /dev/null
+++ b/src/requantization/fp32-scalar.c
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <math.h>
+#include <stdint.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+
+void xnn_requantize_fp32__scalar_lrintf(
+ size_t n,
+ const int32_t* input,
+ float scale,
+ uint8_t zero_point,
+ uint8_t qmin,
+ uint8_t qmax,
+ uint8_t* output)
+{
+ assert(n % 4 == 0);
+ assert(scale < 1.0f);
+ assert(scale >= 0x1.0p-32f);
+
+ const long lmin = (long) ((int32_t)(uint32_t) qmin - (int32_t)(uint32_t) zero_point);
+ const long lmax = (long) ((int32_t)(uint32_t) qmax - (int32_t)(uint32_t) zero_point);
+ for (; n != 0; n -= 4) {
+ const int32_t x = input[0];
+ const int32_t y = input[1];
+ const int32_t z = input[2];
+ const int32_t w = input[3];
+ input += 4;
+
+ const float x_scaled = (float) x * scale;
+ const float y_scaled = (float) y * scale;
+ const float z_scaled = (float) z * scale;
+ const float w_scaled = (float) w * scale;
+
+ const long x_rounded = lrintf(x_scaled);
+ const long y_rounded = lrintf(y_scaled);
+ const long z_rounded = lrintf(z_scaled);
+ const long w_rounded = lrintf(w_scaled);
+
+ const int32_t x_clamped = (int32_t)(x_rounded < lmin ? lmin : x_rounded > lmax ? lmax : x_rounded);
+ const int32_t y_clamped = (int32_t)(y_rounded < lmin ? lmin : y_rounded > lmax ? lmax : y_rounded);
+ const int32_t z_clamped = (int32_t)(z_rounded < lmin ? lmin : z_rounded > lmax ? lmax : z_rounded);
+ const int32_t w_clamped = (int32_t)(w_rounded < lmin ? lmin : w_rounded > lmax ? lmax : w_rounded);
+
+ const int32_t x_biased = x_clamped + (int32_t)(uint32_t) zero_point;
+ const int32_t y_biased = y_clamped + (int32_t)(uint32_t) zero_point;
+ const int32_t z_biased = z_clamped + (int32_t)(uint32_t) zero_point;
+ const int32_t w_biased = w_clamped + (int32_t)(uint32_t) zero_point;
+
+ output[0] = (uint8_t) x_biased;
+ output[1] = (uint8_t) y_biased;
+ output[2] = (uint8_t) z_biased;
+ output[3] = (uint8_t) w_biased;
+ output += 4;
+ }
+}
+
/*
 * Requantizes n (multiple of 4) int32 values to uint8 using the FP32
 * "magic number" round-to-nearest-even trick:
 *   output[i] = clamp(round(input[i] * scale) + zero_point, qmin, qmax)
 */
void xnn_requantize_fp32__scalar_magic(
    size_t n,
    const int32_t* input,
    float scale,
    uint8_t zero_point,
    uint8_t qmin,
    uint8_t qmax,
    uint8_t* output)
{
  assert(n % 4 == 0);
  assert(scale < 1.0f);
  assert(scale >= 0x1.0p-32f);

  /* Clamp bounds are relative to the zero point, which is folded into the integer magic bias */
  const float fmin = (float) ((int32_t)(uint32_t) qmin - (int32_t)(uint32_t) zero_point);
  const float fmax = (float) ((int32_t)(uint32_t) qmax - (int32_t)(uint32_t) zero_point);
  const float fmagic = 12582912.0f; /* 1.5 * 2**23: forces rounding to integer in the mantissa */
  const int32_t imagic = INT32_C(0x4B400000) - (int32_t)(uint32_t) zero_point;
  for (; n != 0; n -= 4) {
    for (size_t i = 0; i < 4; i++) {
      const float scaled = (float) input[i] * scale;
      float clamped = scaled;
      if (clamped < fmin) {
        clamped = fmin;
      }
      if (clamped > fmax) {
        clamped = fmax;
      }
      /* Adding fmagic rounds to a nearest-even integer held in the low mantissa bits;
         reinterpreting as int32 and subtracting the biased magic recovers the integer
         result with the zero point already added */
      output[i] = (uint8_t) ((int32_t) fp32_to_bits(clamped + fmagic) - imagic);
    }
    input += 4;
    output += 4;
  }
}
diff --git a/src/requantization/fp32-sse2.c b/src/requantization/fp32-sse2.c
new file mode 100644
index 0000000..b574a12
--- /dev/null
+++ b/src/requantization/fp32-sse2.c
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/requantization-stubs.h>
+
+
+/*
+ * Requantizes n (multiple of 16) int32 values to uint8 with SSE2:
+ *   output[i] = clamp(round(input[i] * scale) + zero_point, qmin, qmax)
+ * Rounding is to nearest with ties to even (default MXCSR rounding mode).
+ */
+void xnn_requantize_fp32__sse2(
+ size_t n,
+ const int32_t* input,
+ float scale,
+ uint8_t zero_point,
+ uint8_t qmin,
+ uint8_t qmax,
+ uint8_t* output)
+{
+ assert(n % 16 == 0);
+ assert(scale < 1.0f);
+ assert(scale >= 0x1.0p-32f);
+
+ const __m128 vscale = _mm_set1_ps(scale);
+ const __m128i vzero_point = _mm_set1_epi16((short) (uint16_t) zero_point);
+ const __m128i vqmin = _mm_set1_epi8((char) qmin);
+ const __m128i vqmax = _mm_set1_epi8((char) qmax);
+ for (; n != 0; n -= 16) {
+ const __m128i x = _mm_loadu_si128((const __m128i*) input);
+ const __m128i y = _mm_loadu_si128((const __m128i*) (input + 4));
+ const __m128i z = _mm_loadu_si128((const __m128i*) (input + 8));
+ const __m128i w = _mm_loadu_si128((const __m128i*) (input + 12));
+ input += 16;
+
+ /*
+ * Convert int32_t input to FP32 and multiply by FP32 scale.
+ * Both operations involve statistically unbiased roundings (with default MXCSR rounding mode):
+ * - Large int32_t values can't be exactly represented as FP32. CVTDQ2PS instruction on x86 would round it
+ * according to nearest FP32 value with ties to even (assuming default MXCSR rounding mode).
+ * - Product of two FP32 values is generally not exactly representable as an FP32 value, and will be rounded
+ * to nearest FP32 value with ties to even with default MXCSR rounding mode.
+ */
+ const __m128 x_scaled = _mm_mul_ps(_mm_cvtepi32_ps(x), vscale);
+ const __m128 y_scaled = _mm_mul_ps(_mm_cvtepi32_ps(y), vscale);
+ const __m128 z_scaled = _mm_mul_ps(_mm_cvtepi32_ps(z), vscale);
+ const __m128 w_scaled = _mm_mul_ps(_mm_cvtepi32_ps(w), vscale);
+
+ /*
+ * Convert scaled FP32 result to int32_t using CVTPS2DQ instruction from x86 SSE2. CVTPS2DQ instruction rounds
+ * result according to nearest FP32 value with ties to even (assuming default MXCSR rounding mode).
+ * However, when conversion overflows, it produces INT32_MIN as a result. For large positive inputs the result
+ * of conversion can become negative, which affects the final requantization result. Note that on x86 SSE2 we
+ * have e.g. int32_t(float(INT32_MAX)) == INT32_MIN! This happens because float(INT32_MAX) rounds to 2**31,
+ * which overflows int32_t when it is converted back to integer.
+ *
+ * Thankfully, we can prove that overflow never happens in this requantization scheme. The largest positive
+ * input is INT32_MAX (2**31 - 1), which turns into 2**31 when converted to float. The largest scale value
+ * is 0x1.FFFFFEp-1. When multiplied together, the result is 2147483520 (compare to INT32_MAX = 2147483647),
+ * which fits into int32_t without overflow.
+ */
+ const __m128i x_rounded = _mm_cvtps_epi32(x_scaled);
+ const __m128i y_rounded = _mm_cvtps_epi32(y_scaled);
+ const __m128i z_rounded = _mm_cvtps_epi32(z_scaled);
+ const __m128i w_rounded = _mm_cvtps_epi32(w_scaled);
+
+ /*
+ * Standard final sequence on x86 SSE2:
+ * - Pack to int16_t and saturate
+ * - Add zero point
+ * - Pack to uint8_t and saturate
+ * - Clamp between qmin and qmax
+ */
+ const __m128i xy_packed = _mm_adds_epi16(_mm_packs_epi32(x_rounded, y_rounded), vzero_point);
+ const __m128i zw_packed = _mm_adds_epi16(_mm_packs_epi32(z_rounded, w_rounded), vzero_point);
+ const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed);
+ const __m128i xyzw_clamped = _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin);
+
+ /*
+ * 4x CVTDQ2PS
+ * 4x MULPS
+ * 4x CVTPS2DQ
+ * 2x PACKSSDW
+ * 1x PACKUSWB
+ * 2x PADDW
+ * 1x PMAXUB
+ * 1x PMINUB
+ * ---------------------
+ * 19 instructions total
+ */
+
+ _mm_storeu_si128((__m128i*) output, xyzw_clamped);
+ output += 16;
+ }
+}
diff --git a/src/requantization/gemmlowp-neon.c b/src/requantization/gemmlowp-neon.c
new file mode 100644
index 0000000..37b6718
--- /dev/null
+++ b/src/requantization/gemmlowp-neon.c
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <arm_neon.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+
+/*
+ * The requantization implementation below is adapted from Google's gemmlowp
+ * library. It is only used in XNNPACK unit tests and comparative benchmarks,
+ * but not the library itself.
+ */
+
+// Copyright 2015 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
/*
 * gemmlowp-style requantization (NEON): multiply each int32 by a Q31
 * fixed-point multiplier with VQRDMULH, rounding-divide by a power of two,
 * add the zero point, and clamp to [qmin, qmax]. n must be a multiple of 16.
 */
void xnn_requantize_gemmlowp__neon(
    size_t n,
    const int32_t* input,
    float scale,
    uint8_t zero_point,
    uint8_t qmin,
    uint8_t qmax,
    uint8_t* output)
{
  assert(n % 16 == 0);
  assert(scale < 1.0f);
  assert(scale >= 0x1.0p-32f);

  const uint32_t scale_bits = fp32_to_bits(scale);

  /* Compute requantization parameters: Q31 multiplier from the FP32 mantissa, shift from the exponent */
  const uint32_t multiplier = ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7;
  const int32_t exponent = (fp32_to_bits(scale) >> 23) - 127 - 23 - 7;
  const int32_t shift = -(32 /* using high 32 bits in VQRDMUL */ - 1 /* doubling in VQRDMUL */ + exponent);

  const int32x4_t vmultiplier = vdupq_n_s32(multiplier);
  const int16x8_t vzero_point = vdupq_n_s16((int16_t)(uint16_t) zero_point);
  const int32x4_t vshift = vdupq_n_s32(-shift);
  const uint8x16_t vqmin = vdupq_n_u8(qmin);
  const uint8x16_t vqmax = vdupq_n_u8(qmax);
  for (; n != 0; n -= 16) {
    const int32x4_t x = vld1q_s32(input);
    const int32x4_t y = vld1q_s32(input + 4);
    const int32x4_t z = vld1q_s32(input + 8);
    const int32x4_t w = vld1q_s32(input + 12);
    input += 16;

    const int32x4_t x_product = vqrdmulhq_s32(x, vmultiplier);
    const int32x4_t y_product = vqrdmulhq_s32(y, vmultiplier);
    const int32x4_t z_product = vqrdmulhq_s32(z, vmultiplier);
    const int32x4_t w_product = vqrdmulhq_s32(w, vmultiplier);

    /*
     * gemmlowp RoundingDivideByPOT fixup: the correction term must be derived
     * from the value being shifted (the VQRDMULH product), not the raw input.
     * Deriving it from the input diverges when the product and input signs or
     * low bits differ (e.g. a tiny negative input whose product is zero).
     */
    const int32x4_t x_product_fixup = vshrq_n_s32(vandq_s32(x_product, vshift), 31);
    const int32x4_t y_product_fixup = vshrq_n_s32(vandq_s32(y_product, vshift), 31);
    const int32x4_t z_product_fixup = vshrq_n_s32(vandq_s32(z_product, vshift), 31);
    const int32x4_t w_product_fixup = vshrq_n_s32(vandq_s32(w_product, vshift), 31);

    const int32x4_t x_adjusted_product = vqaddq_s32(x_product, x_product_fixup);
    const int32x4_t y_adjusted_product = vqaddq_s32(y_product, y_product_fixup);
    const int32x4_t z_adjusted_product = vqaddq_s32(z_product, z_product_fixup);
    const int32x4_t w_adjusted_product = vqaddq_s32(w_product, w_product_fixup);

    /* Rounding shift right by `shift` (vshift holds the negated shift amount) */
    const int32x4_t x_scaled = vrshlq_s32(x_adjusted_product, vshift);
    const int32x4_t y_scaled = vrshlq_s32(y_adjusted_product, vshift);
    const int32x4_t z_scaled = vrshlq_s32(z_adjusted_product, vshift);
    const int32x4_t w_scaled = vrshlq_s32(w_adjusted_product, vshift);

    /* Pack to int16 with saturation, add zero point, pack to uint8 with saturation */
#ifdef __aarch64__
    const int16x8_t xy_packed = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(x_scaled), y_scaled), vzero_point);
    const int16x8_t zw_packed = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(z_scaled), w_scaled), vzero_point);
    const uint8x16_t xyzw_packed = vqmovun_high_s16(vqmovun_s16(xy_packed), zw_packed);
#else
    const int16x8_t xy_packed = vqaddq_s16(vcombine_s16(vqmovn_s32(x_scaled), vqmovn_s32(y_scaled)), vzero_point);
    const int16x8_t zw_packed = vqaddq_s16(vcombine_s16(vqmovn_s32(z_scaled), vqmovn_s32(w_scaled)), vzero_point);
    const uint8x16_t xyzw_packed = vcombine_u8(vqmovun_s16(xy_packed), vqmovun_s16(zw_packed));
#endif

    const uint8x16_t xyzw_clamped = vmaxq_u8(vminq_u8(xyzw_packed, vqmax), vqmin);

    vst1q_u8(output, xyzw_clamped);
    output += 16;
  }
}
diff --git a/src/requantization/gemmlowp-scalar.c b/src/requantization/gemmlowp-scalar.c
new file mode 100644
index 0000000..78f61b1
--- /dev/null
+++ b/src/requantization/gemmlowp-scalar.c
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+#include <xnnpack/scalar-utils.h>
+
+#include "gemmlowp-scalar.h"
+
+
/*
 * gemmlowp-style requantization (scalar reference): multiply each int32 by a
 * Q31 fixed-point multiplier with saturating rounding doubling high multiply,
 * rounding-divide by a power of two, add the zero point, and clamp to
 * [qmin, qmax]. n must be a multiple of 4.
 */
void xnn_requantize_gemmlowp__scalar(
    size_t n,
    const int32_t* input,
    float scale,
    uint8_t zero_point,
    uint8_t qmin,
    uint8_t qmax,
    uint8_t* output)
{
  assert(n % 4 == 0);
  assert(scale < 1.0f);
  assert(scale >= 0x1.0p-32f);

  const uint32_t scale_bits = fp32_to_bits(scale);

  /* Compute requantization parameters: Q31 multiplier from the FP32 mantissa, shift from the exponent */
  const uint32_t multiplier = ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7;
  const int32_t exponent = (fp32_to_bits(scale) >> 23) - 127 - 23 - 7;
  const int32_t shift = -(32 /* using high 32 bits in VQRDMUL */ - 1 /* doubling in VQRDMUL */ + exponent);

  const int32_t smin = (int32_t)(uint32_t) qmin;
  const int32_t smax = (int32_t)(uint32_t) qmax;
  for (; n != 0; n -= 4) {
    for (size_t i = 0; i < 4; i++) {
      const int32_t product = gemmlowp_scalar_vqrdmulh_s32(input[i], multiplier);
      const int32_t scaled_value = gemmlowp_scalar_rdivbypo2_s32(product, shift);

      /* Add zero point, then clamp to the [qmin, qmax] output range */
      int32_t biased = scaled_value + zero_point;
      if (biased < smin) {
        biased = smin;
      }
      if (biased > smax) {
        biased = smax;
      }
      output[i] = (uint8_t) biased;
    }
    input += 4;
    output += 4;
  }
}
diff --git a/src/requantization/gemmlowp-scalar.h b/src/requantization/gemmlowp-scalar.h
new file mode 100644
index 0000000..91f8bf4
--- /dev/null
+++ b/src/requantization/gemmlowp-scalar.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <stdint.h>
+#include <limits.h>
+
+/*
+ * The code below is adapted from Google's gemmlowp library.
+ * It is only used in XNNPACK unit tests and comparative benchmarks,
+ * but not the library itself.
+ */
+
+// Copyright 2015 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
/*
 * Scalar analogue of ARM NEON VQRDMULH: returns the high 32 bits of 2*a*b
 * with rounding, saturating the single overflow case a == b == INT32_MIN
 * (where the exact result 2**31 is not representable) to INT32_MAX.
 *
 * Note: plain int is used for the overflow flag because this header includes
 * only <stdint.h>/<limits.h>; using `bool` here would require <stdbool.h>
 * (pre-C23) and made the header non-self-contained.
 */
inline static int32_t gemmlowp_scalar_vqrdmulh_s32(int32_t a, int32_t b)
{
  const int overflow = (a == b && a == INT32_MIN);
  const int64_t ab_64 = (int64_t) a * (int64_t) b;
  /* Nudge implements round-half-away-from-zero before taking the high half of the doubled product */
  const int32_t nudge = (a ^ b) >= 0 ? INT32_C(0x40000000) : -INT32_C(0x3FFFFFFF);
  const int32_t ab_x2_high32 = (int32_t) ((ab_64 + nudge) / INT64_C(0x80000000));
  return overflow ? INT32_MAX : ab_x2_high32;
}
+
/*
 * Rounding arithmetic right shift by `exponent` (gemmlowp RoundingDivideByPOT):
 * divides x by 2**exponent, rounding half away from zero.
 *
 * The mask is built with an unsigned shift: `1 << exponent` on a signed int
 * is undefined behavior for exponent == 31 (the sibling SSE implementation
 * already uses an unsigned shift for the same reason).
 */
inline static int32_t gemmlowp_scalar_rdivbypo2_s32(int32_t x, int exponent)
{
  const int32_t mask = (int32_t) ((UINT32_C(1) << exponent) - UINT32_C(1));
  const int32_t remainder = x & mask;
  /* Threshold is mask/2, bumped by one for negative x so that ties round away from zero */
  const int32_t threshold = (mask >> 1) + (int32_t) (x < 0);
  return asr_s32(x, exponent) + (int32_t) (remainder > threshold);
}
diff --git a/src/requantization/gemmlowp-sse.h b/src/requantization/gemmlowp-sse.h
new file mode 100644
index 0000000..1335fe1
--- /dev/null
+++ b/src/requantization/gemmlowp-sse.h
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <limits.h>
+
+#include <immintrin.h>
+
+/*
+ * The code below is adapted from Google's gemmlowp library.
+ * It is only used in XNNPACK unit tests and comparative benchmarks,
+ * but not the library itself.
+ */
+
+// Copyright 2015 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
/*
 * Rounding arithmetic right shift by `exponent` (gemmlowp RoundingDivideByPOT):
 * divides every 32-bit lane of x by 2**exponent, rounding half away from zero.
 */
static inline __m128i gemmlowp_sse_rdivbypo2_s32(__m128i x, int exponent) {
  /* Low `exponent` bits of each lane: the remainder discarded by the shift */
  const __m128i vmask = _mm_set1_epi32((int32_t) ((UINT64_C(1) << exponent) - UINT64_C(1)));
  const __m128i vremainder = _mm_and_si128(x, vmask);
  /* Threshold is mask/2, bumped by one in negative lanes (CMPLT yields all-ones = -1) */
  const __m128i vthreshold =
      _mm_sub_epi32(_mm_srli_epi32(vmask, 1), _mm_cmplt_epi32(x, _mm_setzero_si128()));
  const __m128i vshifted = _mm_sra_epi32(x, _mm_cvtsi32_si128(exponent));
  /* CMPGT yields -1 in lanes that must round up; subtracting it adds one */
  return _mm_sub_epi32(vshifted, _mm_cmpgt_epi32(vremainder, vthreshold));
}
+
+/*
+ * Signed 64-bit products of 32-bit lanes 0 and 2 of a and b (lanes 1 and 3
+ * are ignored): an SSE2/SSSE3 emulation of the SSE4.1 PMULDQ instruction,
+ * built from the unsigned PMULUDQ plus explicit sign handling.
+ */
+static inline __m128i gemmlowp_sse_mul_s32(__m128i a, __m128i b) {
+#ifdef __SSE4_1__
+ return _mm_mul_epi32(a, b);
+#else
+ __m128i sign, zero, mul_us, a_neg, b_neg, mul_us_neg;
+ sign = _mm_xor_si128(a, b);
+ sign = _mm_srai_epi32(sign, 31); // promote sign bit to all fields, all fff if
+ // negative and all 0 if positive
+ sign = _mm_shuffle_epi32(
+ sign,
+ _MM_SHUFFLE(2, 2, 0, 0)); // promote sign bit to 3 and 1st data lanes
+ zero = _mm_setzero_si128();
+#ifdef __SSSE3__
+ a_neg = _mm_abs_epi32(a); // negate a and b
+ b_neg = _mm_abs_epi32(b); // negate a and b
+#else /* pre-SSSE3 */
+ /* Absolute value without PABSD: flip bits via the sign mask, then subtract the mask (adds 1 for negatives) */
+ const __m128i a_neg_mask = _mm_cmplt_epi32(a, zero);
+ a_neg = _mm_sub_epi32(_mm_xor_si128(a, a_neg_mask), a_neg_mask);
+ const __m128i b_neg_mask = _mm_cmplt_epi32(b, zero);
+ b_neg = _mm_sub_epi32(_mm_xor_si128(b, b_neg_mask), b_neg_mask);
+#endif /* pre-SSSE3 */
+ mul_us = _mm_mul_epu32(a_neg, b_neg); // uses 0 and 2nd data lanes, (abs), the
+ // multiplication gives 64 bit result
+ /* Re-apply the product sign: select negated 64-bit product where the operand signs differed */
+ mul_us_neg = _mm_sub_epi64(zero, mul_us);
+ mul_us_neg = _mm_and_si128(sign, mul_us_neg);
+ mul_us = _mm_andnot_si128(sign, mul_us);
+ return _mm_or_si128(mul_us, mul_us_neg);
+#endif
+}
+
+/*
+ * SSE analogue of ARM NEON VQRDMULH: per-lane high 32 bits of 2*a*b with a
+ * fixed +2**30 rounding nudge, saturated to INT32_MIN's complement case —
+ * the only overflow, a == b == INT32_MIN, is replaced with INT32_MIN per the
+ * blend below (gemmlowp's SSE convention).
+ */
+static inline __m128i gemmlowp_sse_vqrdmulh_s32(__m128i a, __m128i b) {
+ // saturation only happen if a == b == INT32_MIN
+ const __m128i min = _mm_set1_epi32(INT32_MIN);
+ const __m128i saturation_mask =
+ _mm_and_si128(_mm_cmpeq_epi32(a, b), _mm_cmpeq_epi32(a, min));
+
+ // a = a0 | a1 | a2 | a3
+ // b = b0 | b1 | b2 | b3
+ const __m128i a0_a2 = a;
+ const __m128i a1_a3 = _mm_srli_si128(a, 4);
+ const __m128i b0_b2 = b;
+ const __m128i b1_b3 = _mm_srli_si128(b, 4);
+
+ const __m128i a0b0_a2b2 = gemmlowp_sse_mul_s32(a0_a2, b0_b2);
+ const __m128i a1b1_a3b3 = gemmlowp_sse_mul_s32(a1_a3, b1_b3);
+
+ // do the rounding and take into account that it will be doubled
+ const __m128i nudge = _mm_set1_epi64x(1 << 30);
+ const __m128i a0b0_a2b2_rounded = _mm_add_epi64(a0b0_a2b2, nudge);
+ const __m128i a1b1_a3b3_rounded = _mm_add_epi64(a1b1_a3b3, nudge);
+
+ // do the doubling
+ const __m128i a0b0_a2b2_rounded_2x = _mm_slli_epi64(a0b0_a2b2_rounded, 1);
+ const __m128i a1b1_a3b3_rounded_2x = _mm_slli_epi64(a1b1_a3b3_rounded, 1);
+
+// get the high part of the products
+#ifdef __SSE4_1__
+ const __m128i result = _mm_blend_epi16(
+ _mm_srli_epi64(a0b0_a2b2_rounded_2x, 32), a1b1_a3b3_rounded_2x, 0xCC);
+#else
+ /* SHUFPS picks the odd (high) 32-bit halves of both 64-bit pairs, then the
+ second shuffle restores the original lane order 0,1,2,3 */
+ const __m128i result0213 = _mm_castps_si128(_mm_shuffle_ps(
+ _mm_castsi128_ps(a0b0_a2b2_rounded_2x),
+ _mm_castsi128_ps(a1b1_a3b3_rounded_2x),
+ _MM_SHUFFLE(3, 1, 3, 1)));
+ const __m128i result = _mm_shuffle_epi32(result0213, _MM_SHUFFLE(3, 1, 2, 0));
+#endif
+
+// saturate those which overflowed
+#ifdef __SSE4_1__
+ const __m128i saturated_result = _mm_blendv_epi8(result, min, saturation_mask);
+#else
+ const __m128i saturated_result = _mm_or_si128(
+ _mm_and_si128(saturation_mask, min),
+ _mm_andnot_si128(saturation_mask, result));
+#endif
+ return saturated_result;
+}
diff --git a/src/requantization/gemmlowp-sse2.c b/src/requantization/gemmlowp-sse2.c
new file mode 100644
index 0000000..93a869e
--- /dev/null
+++ b/src/requantization/gemmlowp-sse2.c
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <emmintrin.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+#include "gemmlowp-sse.h"
+
+
+/*
+ * Requantizes 32-bit accumulators to uint8 using the gemmlowp fixed-point
+ * scheme: a saturating rounding doubling high multiply (NEON VQRDMULH
+ * emulation) followed by a rounding division by a power of two, then
+ * zero-point addition and clamping. SSE2 variant; 16 elements per iteration.
+ *
+ * n          - number of elements; must be a multiple of 16
+ * input      - int32 accumulators to requantize
+ * scale      - requantization scale, in [0x1.0p-32, 1.0)
+ * zero_point - uint8 zero point added after scaling
+ * qmin, qmax - inclusive clamping bounds for the uint8 output
+ * output     - destination for the requantized uint8 values
+ */
+void xnn_requantize_gemmlowp__sse2(
+ size_t n,
+ const int32_t* input,
+ float scale,
+ uint8_t zero_point,
+ uint8_t qmin,
+ uint8_t qmax,
+ uint8_t* output)
+{
+ assert(n % 16 == 0);
+ assert(scale < 1.0f);
+ assert(scale >= 0x1.0p-32f);
+
+ const uint32_t scale_bits = fp32_to_bits(scale);
+
+ /* Compute requantization parameters */
+ /* multiplier: fp32 mantissa with implicit bit restored, shifted into bits 30:7 */
+ const uint32_t multiplier = ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7;
+ /* NOTE(review): fp32_to_bits(scale) here equals scale_bits computed above */
+ const int32_t exponent = (fp32_to_bits(scale) >> 23) - 127 - 23 - 7;
+ const int32_t shift = -(32 /* using high 32 bits in VQRDMUL */ - 1 /* doubling in VQRDMUL */ + exponent);
+
+ const __m128i vmultiplier = _mm_set1_epi32(multiplier);
+ const __m128i vzero_point = _mm_set1_epi16((short) (uint16_t) zero_point);
+ const __m128i vqmin = _mm_set1_epi8((char) qmin);
+ const __m128i vqmax = _mm_set1_epi8((char) qmax);
+ for (; n != 0; n -= 16) {
+ const __m128i x = _mm_loadu_si128((const __m128i*) input);
+ const __m128i y = _mm_loadu_si128((const __m128i*) (input + 4));
+ const __m128i z = _mm_loadu_si128((const __m128i*) (input + 8));
+ const __m128i w = _mm_loadu_si128((const __m128i*) (input + 12));
+ input += 16;
+
+ /* Fixed-point multiply: high 32 bits of the doubled, rounded product */
+ const __m128i x_product = gemmlowp_sse_vqrdmulh_s32(x, vmultiplier);
+ const __m128i y_product = gemmlowp_sse_vqrdmulh_s32(y, vmultiplier);
+ const __m128i z_product = gemmlowp_sse_vqrdmulh_s32(z, vmultiplier);
+ const __m128i w_product = gemmlowp_sse_vqrdmulh_s32(w, vmultiplier);
+
+ /* Rounding arithmetic right shift by the power-of-two exponent */
+ const __m128i x_scaled = gemmlowp_sse_rdivbypo2_s32(x_product, shift);
+ const __m128i y_scaled = gemmlowp_sse_rdivbypo2_s32(y_product, shift);
+ const __m128i z_scaled = gemmlowp_sse_rdivbypo2_s32(z_product, shift);
+ const __m128i w_scaled = gemmlowp_sse_rdivbypo2_s32(w_product, shift);
+
+ /* Pack to int16 with signed saturation, add zero point (saturating) */
+ const __m128i xy_packed = _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point);
+ const __m128i zw_packed = _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point);
+ /* Pack to uint8 with unsigned saturation, then clamp to [qmin, qmax] */
+ const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed);
+ const __m128i xyzw_clamped = _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin);
+
+ _mm_storeu_si128((__m128i*) output, xyzw_clamped);
+ output += 16;
+ }
+}
diff --git a/src/requantization/gemmlowp-sse4.c b/src/requantization/gemmlowp-sse4.c
new file mode 100644
index 0000000..a315746
--- /dev/null
+++ b/src/requantization/gemmlowp-sse4.c
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <smmintrin.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+#include "gemmlowp-sse.h"
+
+
+/*
+ * Requantizes 32-bit accumulators to uint8 using the gemmlowp fixed-point
+ * scheme: a saturating rounding doubling high multiply (NEON VQRDMULH
+ * emulation) followed by a rounding division by a power of two, then
+ * zero-point addition and clamping. SSE4.1 variant; 16 elements per iteration.
+ *
+ * n          - number of elements; must be a multiple of 16
+ * input      - int32 accumulators to requantize
+ * scale      - requantization scale, in [0x1.0p-32, 1.0)
+ * zero_point - uint8 zero point added after scaling
+ * qmin, qmax - inclusive clamping bounds for the uint8 output
+ * output     - destination for the requantized uint8 values
+ */
+void xnn_requantize_gemmlowp__sse4(
+ size_t n,
+ const int32_t* input,
+ float scale,
+ uint8_t zero_point,
+ uint8_t qmin,
+ uint8_t qmax,
+ uint8_t* output)
+{
+ assert(n % 16 == 0);
+ assert(scale < 1.0f);
+ assert(scale >= 0x1.0p-32f);
+
+ const uint32_t scale_bits = fp32_to_bits(scale);
+
+ /* Compute requantization parameters */
+ /* multiplier: fp32 mantissa with implicit bit restored, shifted into bits 30:7 */
+ const uint32_t multiplier = ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7;
+ /* NOTE(review): fp32_to_bits(scale) here equals scale_bits computed above */
+ const int32_t exponent = (fp32_to_bits(scale) >> 23) - 127 - 23 - 7;
+ const int32_t shift = -(32 /* using high 32 bits in VQRDMUL */ - 1 /* doubling in VQRDMUL */ + exponent);
+
+ const __m128i vmultiplier = _mm_set1_epi32(multiplier);
+ const __m128i vzero_point = _mm_set1_epi16((short) (uint16_t) zero_point);
+ const __m128i vqmin = _mm_set1_epi8((char) qmin);
+ const __m128i vqmax = _mm_set1_epi8((char) qmax);
+ for (; n != 0; n -= 16) {
+ const __m128i x = _mm_loadu_si128((const __m128i*) input);
+ const __m128i y = _mm_loadu_si128((const __m128i*) (input + 4));
+ const __m128i z = _mm_loadu_si128((const __m128i*) (input + 8));
+ const __m128i w = _mm_loadu_si128((const __m128i*) (input + 12));
+ input += 16;
+
+ /* Fixed-point multiply: high 32 bits of the doubled, rounded product */
+ const __m128i x_product = gemmlowp_sse_vqrdmulh_s32(x, vmultiplier);
+ const __m128i y_product = gemmlowp_sse_vqrdmulh_s32(y, vmultiplier);
+ const __m128i z_product = gemmlowp_sse_vqrdmulh_s32(z, vmultiplier);
+ const __m128i w_product = gemmlowp_sse_vqrdmulh_s32(w, vmultiplier);
+
+ /* Rounding arithmetic right shift by the power-of-two exponent */
+ const __m128i x_scaled = gemmlowp_sse_rdivbypo2_s32(x_product, shift);
+ const __m128i y_scaled = gemmlowp_sse_rdivbypo2_s32(y_product, shift);
+ const __m128i z_scaled = gemmlowp_sse_rdivbypo2_s32(z_product, shift);
+ const __m128i w_scaled = gemmlowp_sse_rdivbypo2_s32(w_product, shift);
+
+ /* Pack to int16 with signed saturation, add zero point (saturating) */
+ const __m128i xy_packed = _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point);
+ const __m128i zw_packed = _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point);
+ /* Pack to uint8 with unsigned saturation, then clamp to [qmin, qmax] */
+ const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed);
+ const __m128i xyzw_clamped = _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin);
+
+ _mm_storeu_si128((__m128i*) output, xyzw_clamped);
+ output += 16;
+ }
+}
diff --git a/src/requantization/gemmlowp-ssse3.c b/src/requantization/gemmlowp-ssse3.c
new file mode 100644
index 0000000..700e000
--- /dev/null
+++ b/src/requantization/gemmlowp-ssse3.c
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <tmmintrin.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+#include "gemmlowp-sse.h"
+
+
+/*
+ * Requantizes 32-bit accumulators to uint8 using the gemmlowp fixed-point
+ * scheme: a saturating rounding doubling high multiply (NEON VQRDMULH
+ * emulation) followed by a rounding division by a power of two, then
+ * zero-point addition and clamping. SSSE3 variant; 16 elements per iteration.
+ *
+ * n          - number of elements; must be a multiple of 16
+ * input      - int32 accumulators to requantize
+ * scale      - requantization scale, in [0x1.0p-32, 1.0)
+ * zero_point - uint8 zero point added after scaling
+ * qmin, qmax - inclusive clamping bounds for the uint8 output
+ * output     - destination for the requantized uint8 values
+ */
+void xnn_requantize_gemmlowp__ssse3(
+ size_t n,
+ const int32_t* input,
+ float scale,
+ uint8_t zero_point,
+ uint8_t qmin,
+ uint8_t qmax,
+ uint8_t* output)
+{
+ assert(n % 16 == 0);
+ assert(scale < 1.0f);
+ assert(scale >= 0x1.0p-32f);
+
+ const uint32_t scale_bits = fp32_to_bits(scale);
+
+ /* Compute requantization parameters */
+ /* multiplier: fp32 mantissa with implicit bit restored, shifted into bits 30:7 */
+ const uint32_t multiplier = ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7;
+ /* NOTE(review): fp32_to_bits(scale) here equals scale_bits computed above */
+ const int32_t exponent = (fp32_to_bits(scale) >> 23) - 127 - 23 - 7;
+ const int32_t shift = -(32 /* using high 32 bits in VQRDMUL */ - 1 /* doubling in VQRDMUL */ + exponent);
+
+ const __m128i vmultiplier = _mm_set1_epi32(multiplier);
+ const __m128i vzero_point = _mm_set1_epi16((short) (uint16_t) zero_point);
+ const __m128i vqmin = _mm_set1_epi8((char) qmin);
+ const __m128i vqmax = _mm_set1_epi8((char) qmax);
+ for (; n != 0; n -= 16) {
+ const __m128i x = _mm_loadu_si128((const __m128i*) input);
+ const __m128i y = _mm_loadu_si128((const __m128i*) (input + 4));
+ const __m128i z = _mm_loadu_si128((const __m128i*) (input + 8));
+ const __m128i w = _mm_loadu_si128((const __m128i*) (input + 12));
+ input += 16;
+
+ /* Fixed-point multiply: high 32 bits of the doubled, rounded product */
+ const __m128i x_product = gemmlowp_sse_vqrdmulh_s32(x, vmultiplier);
+ const __m128i y_product = gemmlowp_sse_vqrdmulh_s32(y, vmultiplier);
+ const __m128i z_product = gemmlowp_sse_vqrdmulh_s32(z, vmultiplier);
+ const __m128i w_product = gemmlowp_sse_vqrdmulh_s32(w, vmultiplier);
+
+ /* Rounding arithmetic right shift by the power-of-two exponent */
+ const __m128i x_scaled = gemmlowp_sse_rdivbypo2_s32(x_product, shift);
+ const __m128i y_scaled = gemmlowp_sse_rdivbypo2_s32(y_product, shift);
+ const __m128i z_scaled = gemmlowp_sse_rdivbypo2_s32(z_product, shift);
+ const __m128i w_scaled = gemmlowp_sse_rdivbypo2_s32(w_product, shift);
+
+ /* Pack to int16 with signed saturation, add zero point (saturating) */
+ const __m128i xy_packed = _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point);
+ const __m128i zw_packed = _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point);
+ /* Pack to uint8 with unsigned saturation, then clamp to [qmin, qmax] */
+ const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed);
+ const __m128i xyzw_clamped = _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin);
+
+ _mm_storeu_si128((__m128i*) output, xyzw_clamped);
+ output += 16;
+ }
+}
diff --git a/src/requantization/precise-neon.c b/src/requantization/precise-neon.c
new file mode 100644
index 0000000..2e796e3
--- /dev/null
+++ b/src/requantization/precise-neon.c
@@ -0,0 +1,168 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <arm_neon.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+
+/*
+ * "Precise" requantization of 32-bit accumulators to uint8 on NEON:
+ * multiplies by the 24-bit mantissa of the scale, performs an exact rounding
+ * right shift on the full 56-bit product, adds the zero point with saturation
+ * and clamps to [qmin, qmax]. Processes 16 elements per iteration.
+ *
+ * n          - number of elements; must be a multiple of 16
+ * input      - int32 accumulators to requantize
+ * scale      - requantization scale, in [0x1.0p-32, 1.0)
+ * zero_point - uint8 zero point added after scaling
+ * qmin, qmax - inclusive clamping bounds for the uint8 output
+ * output     - destination for the requantized uint8 values
+ */
+void xnn_requantize_precise__neon(
+ size_t n,
+ const int32_t* input,
+ float scale,
+ uint8_t zero_point,
+ uint8_t qmin,
+ uint8_t qmax,
+ uint8_t* output)
+{
+ assert(n % 16 == 0);
+ assert(scale < 1.0f);
+ assert(scale >= 0x1.0p-32f);
+
+ const uint32_t scale_bits = fp32_to_bits(scale);
+ /* 24-bit mantissa of scale with the implicit leading bit restored */
+ const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
+ /* scale == multiplier * 2**-shift */
+ const int32_t shift = 127 + 23 - (scale_bits >> 23);
+ assert(shift >= 24);
+ assert(shift < 56);
+
+#if defined(__aarch64__)
+ const int32x4_t vmultiplier = vdupq_n_s32(multiplier);
+#else
+ const int32x2_t vmultiplier = vdup_n_s32(multiplier);
+#endif
+ const int16x8_t vzero_point = vdupq_n_s16((int16_t)(uint16_t) zero_point);
+ /* negative shift count: VRSHL/SRSHL shifts right with rounding */
+ const int64x2_t vshift = vdupq_n_s64(-shift);
+ const uint8x16_t vqmin = vdupq_n_u8(qmin);
+ const uint8x16_t vqmax = vdupq_n_u8(qmax);
+ for (; n != 0; n -= 16) {
+ const int32x4_t x = vld1q_s32(input);
+ const int32x4_t y = vld1q_s32(input + 4);
+ const int32x4_t z = vld1q_s32(input + 8);
+ const int32x4_t w = vld1q_s32(input + 12);
+ input += 16;
+
+ /* All-ones (-1 as signed) in lanes where the input is negative */
+ const uint32x4_t x_neg_mask = vcltq_s32(x, vmovq_n_s32(0));
+ const uint32x4_t y_neg_mask = vcltq_s32(y, vmovq_n_s32(0));
+ const uint32x4_t z_neg_mask = vcltq_s32(z, vmovq_n_s32(0));
+ const uint32x4_t w_neg_mask = vcltq_s32(w, vmovq_n_s32(0));
+
+/* Widening 32x32->64 multiply of each accumulator by the multiplier */
+#if defined(__aarch64__)
+ const int64x2_t x01_product = vmull_s32(vget_low_s32(x), vget_low_s32(vmultiplier));
+ const int64x2_t x23_product = vmull_high_s32(x, vmultiplier);
+ const int64x2_t y01_product = vmull_s32(vget_low_s32(y), vget_low_s32(vmultiplier));
+ const int64x2_t y23_product = vmull_high_s32(y, vmultiplier);
+ const int64x2_t z01_product = vmull_s32(vget_low_s32(z), vget_low_s32(vmultiplier));
+ const int64x2_t z23_product = vmull_high_s32(z, vmultiplier);
+ const int64x2_t w01_product = vmull_s32(vget_low_s32(w), vget_low_s32(vmultiplier));
+ const int64x2_t w23_product = vmull_high_s32(w, vmultiplier);
+#else
+ const int64x2_t x01_product = vmull_s32(vget_low_s32(x), vmultiplier);
+ const int64x2_t x23_product = vmull_s32(vget_high_s32(x), vmultiplier);
+ const int64x2_t y01_product = vmull_s32(vget_low_s32(y), vmultiplier);
+ const int64x2_t y23_product = vmull_s32(vget_high_s32(y), vmultiplier);
+ const int64x2_t z01_product = vmull_s32(vget_low_s32(z), vmultiplier);
+ const int64x2_t z23_product = vmull_s32(vget_high_s32(z), vmultiplier);
+ const int64x2_t w01_product = vmull_s32(vget_low_s32(w), vmultiplier);
+ const int64x2_t w23_product = vmull_s32(vget_high_s32(w), vmultiplier);
+#endif
+
+/*
+ * Add -1 (the widened neg_mask) to products of negative inputs, so that the
+ * rounding shift below (which rounds ties up) rounds midpoints away from
+ * zero for negative values as well.
+ */
+#if defined(__aarch64__)
+ const int64x2_t x01_adjusted_product = vaddw_s32(x01_product, vreinterpret_s32_u32(vget_low_u32(x_neg_mask)));
+ const int64x2_t x23_adjusted_product = vaddw_high_s32(x23_product, vreinterpretq_s32_u32(x_neg_mask));
+ const int64x2_t y01_adjusted_product = vaddw_s32(y01_product, vreinterpret_s32_u32(vget_low_u32(y_neg_mask)));
+ const int64x2_t y23_adjusted_product = vaddw_high_s32(y23_product, vreinterpretq_s32_u32(y_neg_mask));
+ const int64x2_t z01_adjusted_product = vaddw_s32(z01_product, vreinterpret_s32_u32(vget_low_u32(z_neg_mask)));
+ const int64x2_t z23_adjusted_product = vaddw_high_s32(z23_product, vreinterpretq_s32_u32(z_neg_mask));
+ const int64x2_t w01_adjusted_product = vaddw_s32(w01_product, vreinterpret_s32_u32(vget_low_u32(w_neg_mask)));
+ const int64x2_t w23_adjusted_product = vaddw_high_s32(w23_product, vreinterpretq_s32_u32(w_neg_mask));
+#else
+ const int64x2_t x01_adjusted_product = vaddw_s32(x01_product, vreinterpret_s32_u32(vget_low_u32(x_neg_mask)));
+ const int64x2_t x23_adjusted_product = vaddw_s32(x23_product, vreinterpret_s32_u32(vget_high_u32(x_neg_mask)));
+ const int64x2_t y01_adjusted_product = vaddw_s32(y01_product, vreinterpret_s32_u32(vget_low_u32(y_neg_mask)));
+ const int64x2_t y23_adjusted_product = vaddw_s32(y23_product, vreinterpret_s32_u32(vget_high_u32(y_neg_mask)));
+ const int64x2_t z01_adjusted_product = vaddw_s32(z01_product, vreinterpret_s32_u32(vget_low_u32(z_neg_mask)));
+ const int64x2_t z23_adjusted_product = vaddw_s32(z23_product, vreinterpret_s32_u32(vget_high_u32(z_neg_mask)));
+ const int64x2_t w01_adjusted_product = vaddw_s32(w01_product, vreinterpret_s32_u32(vget_low_u32(w_neg_mask)));
+ const int64x2_t w23_adjusted_product = vaddw_s32(w23_product, vreinterpret_s32_u32(vget_high_u32(w_neg_mask)));
+#endif
+
+ /* 64-bit rounding shift right by the scale exponent */
+ const int64x2_t x01_scaled = vrshlq_s64(x01_adjusted_product, vshift);
+ const int64x2_t x23_scaled = vrshlq_s64(x23_adjusted_product, vshift);
+ const int64x2_t y01_scaled = vrshlq_s64(y01_adjusted_product, vshift);
+ const int64x2_t y23_scaled = vrshlq_s64(y23_adjusted_product, vshift);
+ const int64x2_t z01_scaled = vrshlq_s64(z01_adjusted_product, vshift);
+ const int64x2_t z23_scaled = vrshlq_s64(z23_adjusted_product, vshift);
+ const int64x2_t w01_scaled = vrshlq_s64(w01_adjusted_product, vshift);
+ const int64x2_t w23_scaled = vrshlq_s64(w23_adjusted_product, vshift);
+
+/* Narrow 64->32, pack to 16 bits, add zero point, pack to unsigned 8 bits */
+#ifdef __aarch64__
+ const int32x4_t x_scaled = vuzp1q_s32(vreinterpretq_s32_s64(x01_scaled), vreinterpretq_s32_s64(x23_scaled));
+ const int32x4_t y_scaled = vuzp1q_s32(vreinterpretq_s32_s64(y01_scaled), vreinterpretq_s32_s64(y23_scaled));
+ const int32x4_t z_scaled = vuzp1q_s32(vreinterpretq_s32_s64(z01_scaled), vreinterpretq_s32_s64(z23_scaled));
+ const int32x4_t w_scaled = vuzp1q_s32(vreinterpretq_s32_s64(w01_scaled), vreinterpretq_s32_s64(w23_scaled));
+
+ const int16x8_t xy_packed = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(x_scaled), y_scaled), vzero_point);
+ const int16x8_t zw_packed = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(z_scaled), w_scaled), vzero_point);
+ const uint8x16_t xyzw_packed = vqmovun_high_s16(vqmovun_s16(xy_packed), zw_packed);
+#else
+ const int32x4_t x_scaled = vcombine_s32(vmovn_s64(x01_scaled), vmovn_s64(x23_scaled));
+ const int32x4_t y_scaled = vcombine_s32(vmovn_s64(y01_scaled), vmovn_s64(y23_scaled));
+ const int32x4_t z_scaled = vcombine_s32(vmovn_s64(z01_scaled), vmovn_s64(z23_scaled));
+ const int32x4_t w_scaled = vcombine_s32(vmovn_s64(w01_scaled), vmovn_s64(w23_scaled));
+
+ const int16x8_t xy_packed = vqaddq_s16(vcombine_s16(vqmovn_s32(x_scaled), vqmovn_s32(y_scaled)), vzero_point);
+ const int16x8_t zw_packed = vqaddq_s16(vcombine_s16(vqmovn_s32(z_scaled), vqmovn_s32(w_scaled)), vzero_point);
+ const uint8x16_t xyzw_packed = vcombine_u8(vqmovun_s16(xy_packed), vqmovun_s16(zw_packed));
+#endif
+
+ const uint8x16_t xyzw_clamped = vmaxq_u8(vminq_u8(xyzw_packed, vqmax), vqmin);
+
+ /*
+ * AArch32 version:
+ * 4x VCLT.S32 Qd, Qm, #0
+ * 8x VMULL.S32 Qd, Dm, Dn
+ * 8x VADDW.S32 Qd, Qm, Dn
+ * 8x VRSHL.S64 Qd, Qm, Qn
+ * 8x VMOVN.S64 Dd, Qm
+ * 4x VQMOVN.S32 Dd, Qm
+ * 2x VQADD.S16 Qd, Qm, Qn
+ * 2x VQMOVUN.S16 Dd, Qm
+ * 1x VMAX.U8 Qd, Qm, Qn
+ * 1x VMIN.U8 Qd, Qm, Qn
+ * ---------------------
+ * 46 instructions total
+ *
+ * AArch64 version:
+ * 4x CMLT Vd.4S, Vn.4S, #0
+ * 4x SMULL Vd.2D, Vn.2S, Vm.2S
+ * 4x SMULL2 Vd.2D, Vn.4S, Vm.4S
+ * 4x SADDW Vd.2D, Vn.2D, Vm.2S
+ * 4x SADDW2 Vd.2D, Vn.2D, Vm.4S
+ * 8x SRSHL Vd.2D, Vn.2D, Vm.2D
+ * 4x UZP1 Vd.4S, Vn.4S, Vm.4S
+ * 2x SQXTN Vd.4H, Vn.4S
+ * 2x SQXTN2 Vd.8H, Vn.4S
+ * 2x SQADD Vd.8H, Vn.8H, Vm.8H
+ * 1x SQXTUN Vd.8B, Vn.8H
+ * 1x SQXTUN2 Vd.16B, Vn.8H
+ * 1x UMIN Vd.16B, Vn.16B, Vm.16B
+ * 1x UMAX Vd.16B, Vn.16B, Vm.16B
+ * ---------------------
+ * 42 instructions total
+ */
+
+ vst1q_u8(output, xyzw_clamped);
+ output += 16;
+ }
+}
diff --git a/src/requantization/precise-psimd.c b/src/requantization/precise-psimd.c
new file mode 100644
index 0000000..5228155
--- /dev/null
+++ b/src/requantization/precise-psimd.c
@@ -0,0 +1,139 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <psimd.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+
+/*
+ * "Precise" requantization of 32-bit accumulators to uint8 with psimd:
+ * works on absolute values, emulates a 32x32->64 unsigned multiply via
+ * 16-bit limbs, applies a rounding right shift, restores the sign, clamps
+ * in the zero-point-free domain and re-biases. Processes 16 elements per
+ * iteration.
+ *
+ * n          - number of elements; must be a multiple of 16
+ * input      - int32 accumulators to requantize
+ * scale      - requantization scale, in [0x1.0p-32, 1.0)
+ * zero_point - uint8 zero point added after scaling
+ * qmin, qmax - inclusive clamping bounds for the uint8 output
+ * output     - destination for the requantized uint8 values
+ */
+void xnn_requantize_precise__psimd(
+ size_t n,
+ const int32_t* input,
+ float scale,
+ uint8_t zero_point,
+ uint8_t qmin,
+ uint8_t qmax,
+ uint8_t* output)
+{
+ assert(n % 16 == 0);
+ assert(scale < 1.0f);
+ assert(scale >= 0x1.0p-32f);
+
+ const uint32_t scale_bits = fp32_to_bits(scale);
+ /* 32-bit multiplier: mantissa with implicit bit, normalized to [2**31, 2**32) */
+ const uint32_t multiplier = (scale_bits << 8) | UINT32_C(0x80000000);
+ /* scale == multiplier * 2**-shift */
+ const uint32_t shift = 127 + 31 - (scale_bits >> 23);
+ assert(shift >= 32);
+ assert(shift < 64);
+ const uint64_t rounding = UINT64_C(1) << (shift - 1);
+
+ const psimd_u32 vmultiplier_lo = psimd_splat_u32(multiplier & UINT32_C(0x0000FFFF));
+ const psimd_u32 vmultiplier_hi = psimd_splat_u32(multiplier >> 16);
+ const psimd_s32 vzero_point = psimd_splat_s32((int32_t)(uint32_t) zero_point);
+ /* Clamping bounds shifted into the zero-point-free domain */
+ const psimd_s32 vsmin = psimd_splat_s32((int32_t)(uint32_t) qmin - (int32_t)(uint32_t) zero_point);
+ const psimd_s32 vsmax = psimd_splat_s32((int32_t)(uint32_t) qmax - (int32_t)(uint32_t) zero_point);
+ const psimd_u32 vrounding_lo = psimd_splat_u32((uint32_t) rounding);
+ const psimd_u32 vrounding_hi = psimd_splat_u32((uint32_t)(rounding >> 32));
+ /* Only the high 64->32 word is kept, so shift by (shift - 32) */
+ const psimd_u32 vshift = psimd_splat_u32(shift - 32);
+ for (; n != 0; n -= 16) {
+ const psimd_s32 x = psimd_load_s32(input);
+ const psimd_s32 y = psimd_load_s32(input + 4);
+ const psimd_s32 z = psimd_load_s32(input + 8);
+ const psimd_s32 w = psimd_load_s32(input + 12);
+ input += 16;
+
+ /* Arithmetic shift: all-ones (-1) in lanes with negative input */
+ const psimd_s32 x_neg_mask = x >> psimd_splat_s32(31);
+ const psimd_s32 y_neg_mask = y >> psimd_splat_s32(31);
+ const psimd_s32 z_neg_mask = z >> psimd_splat_s32(31);
+ const psimd_s32 w_neg_mask = w >> psimd_splat_s32(31);
+
+ /* abs(v) via (v ^ mask) - mask; further math is unsigned to avoid UB */
+ const psimd_u32 x_abs = (psimd_u32)((x ^ x_neg_mask) - x_neg_mask);
+ const psimd_u32 y_abs = (psimd_u32)((y ^ y_neg_mask) - y_neg_mask);
+ const psimd_u32 z_abs = (psimd_u32)((z ^ z_neg_mask) - z_neg_mask);
+ const psimd_u32 w_abs = (psimd_u32)((w ^ w_neg_mask) - w_neg_mask);
+
+ /* Split operands into 16-bit limbs for the 32x32->64 schoolbook multiply */
+ const psimd_u32 x_abs_lo = x_abs & psimd_splat_u32(UINT32_C(0x0000FFFF));
+ const psimd_u32 x_abs_hi = x_abs >> psimd_splat_u32(16);
+ const psimd_u32 y_abs_lo = y_abs & psimd_splat_u32(UINT32_C(0x0000FFFF));
+ const psimd_u32 y_abs_hi = y_abs >> psimd_splat_u32(16);
+ const psimd_u32 z_abs_lo = z_abs & psimd_splat_u32(UINT32_C(0x0000FFFF));
+ const psimd_u32 z_abs_hi = z_abs >> psimd_splat_u32(16);
+ const psimd_u32 w_abs_lo = w_abs & psimd_splat_u32(UINT32_C(0x0000FFFF));
+ const psimd_u32 w_abs_hi = w_abs >> psimd_splat_u32(16);
+
+ const psimd_u32 x_product_ll = x_abs_lo * vmultiplier_lo;
+ const psimd_u32 y_product_ll = y_abs_lo * vmultiplier_lo;
+ const psimd_u32 z_product_ll = z_abs_lo * vmultiplier_lo;
+ const psimd_u32 w_product_ll = w_abs_lo * vmultiplier_lo;
+
+ const psimd_u32 x_product_lh = x_abs_lo * vmultiplier_hi + (x_product_ll >> psimd_splat_u32(16));
+ const psimd_u32 y_product_lh = y_abs_lo * vmultiplier_hi + (y_product_ll >> psimd_splat_u32(16));
+ const psimd_u32 z_product_lh = z_abs_lo * vmultiplier_hi + (z_product_ll >> psimd_splat_u32(16));
+ const psimd_u32 w_product_lh = w_abs_lo * vmultiplier_hi + (w_product_ll >> psimd_splat_u32(16));
+
+ const psimd_u32 x_product_hl = x_abs_hi * vmultiplier_lo + (x_product_lh & psimd_splat_u32(UINT32_C(0x0000FFFF)));
+ const psimd_u32 y_product_hl = y_abs_hi * vmultiplier_lo + (y_product_lh & psimd_splat_u32(UINT32_C(0x0000FFFF)));
+ const psimd_u32 z_product_hl = z_abs_hi * vmultiplier_lo + (z_product_lh & psimd_splat_u32(UINT32_C(0x0000FFFF)));
+ const psimd_u32 w_product_hl = w_abs_hi * vmultiplier_lo + (w_product_lh & psimd_splat_u32(UINT32_C(0x0000FFFF)));
+
+ /* Low 32 bits of the 64-bit product (needed only for the rounding carry) */
+ const psimd_u32 x_product_lo =
+ (x_product_hl << psimd_splat_u32(16)) + (x_product_ll & psimd_splat_u32(UINT32_C(0x0000FFFF)));
+ const psimd_u32 y_product_lo =
+ (y_product_hl << psimd_splat_u32(16)) + (y_product_ll & psimd_splat_u32(UINT32_C(0x0000FFFF)));
+ const psimd_u32 z_product_lo =
+ (z_product_hl << psimd_splat_u32(16)) + (z_product_ll & psimd_splat_u32(UINT32_C(0x0000FFFF)));
+ const psimd_u32 w_product_lo =
+ (w_product_hl << psimd_splat_u32(16)) + (w_product_ll & psimd_splat_u32(UINT32_C(0x0000FFFF)));
+
+ /* High 32 bits of the 64-bit product */
+ const psimd_u32 x_product_hi =
+ x_abs_hi * vmultiplier_hi + (x_product_lh >> psimd_splat_u32(16)) + (x_product_hl >> psimd_splat_u32(16));
+ const psimd_u32 y_product_hi =
+ y_abs_hi * vmultiplier_hi + (y_product_lh >> psimd_splat_u32(16)) + (y_product_hl >> psimd_splat_u32(16));
+ const psimd_u32 z_product_hi =
+ z_abs_hi * vmultiplier_hi + (z_product_lh >> psimd_splat_u32(16)) + (z_product_hl >> psimd_splat_u32(16));
+ const psimd_u32 w_product_hi =
+ w_abs_hi * vmultiplier_hi + (w_product_lh >> psimd_splat_u32(16)) + (w_product_hl >> psimd_splat_u32(16));
+
+ /*
+ * Add rounding to the 64-bit product using 32-bit ops only: the low-word
+ * carry is 1 exactly when (product_lo & rounding_lo) has bit 31 set, which
+ * the arithmetic shift turns into -1 that is then subtracted (i.e. added).
+ */
+ const psimd_u32 x_adjusted_product =
+ (x_product_hi + vrounding_hi) - ((psimd_s32)(x_product_lo & vrounding_lo) >> psimd_splat_s32(31));
+ const psimd_u32 y_adjusted_product =
+ (y_product_hi + vrounding_hi) - ((psimd_s32)(y_product_lo & vrounding_lo) >> psimd_splat_s32(31));
+ const psimd_u32 z_adjusted_product =
+ (z_product_hi + vrounding_hi) - ((psimd_s32)(z_product_lo & vrounding_lo) >> psimd_splat_s32(31));
+ const psimd_u32 w_adjusted_product =
+ (w_product_hi + vrounding_hi) - ((psimd_s32)(w_product_lo & vrounding_lo) >> psimd_splat_s32(31));
+
+ const psimd_u32 x_abs_scaled = x_adjusted_product >> vshift;
+ const psimd_u32 y_abs_scaled = y_adjusted_product >> vshift;
+ const psimd_u32 z_abs_scaled = z_adjusted_product >> vshift;
+ const psimd_u32 w_abs_scaled = w_adjusted_product >> vshift;
+
+ /* Restore the sign of the scaled absolute values */
+ const psimd_s32 x_scaled = (psimd_s32)(x_abs_scaled ^ x_neg_mask) - x_neg_mask;
+ const psimd_s32 y_scaled = (psimd_s32)(y_abs_scaled ^ y_neg_mask) - y_neg_mask;
+ const psimd_s32 z_scaled = (psimd_s32)(z_abs_scaled ^ z_neg_mask) - z_neg_mask;
+ const psimd_s32 w_scaled = (psimd_s32)(w_abs_scaled ^ w_neg_mask) - w_neg_mask;
+
+ /* Clamp in the zero-point-free domain, then add zero point back */
+ const psimd_u32 x_clamped = (psimd_u32) psimd_max_s32(psimd_min_s32(x_scaled, vsmax), vsmin) + vzero_point;
+ const psimd_u32 y_clamped = (psimd_u32) psimd_max_s32(psimd_min_s32(y_scaled, vsmax), vsmin) + vzero_point;
+ const psimd_u32 z_clamped = (psimd_u32) psimd_max_s32(psimd_min_s32(z_scaled, vsmax), vsmin) + vzero_point;
+ const psimd_u32 w_clamped = (psimd_u32) psimd_max_s32(psimd_min_s32(w_scaled, vsmax), vsmin) + vzero_point;
+
+ /* Values now fit in 8 bits; narrow by keeping even (low) halves twice */
+ const psimd_u16 xy_clamped = psimd_concat_even_u16((psimd_u16) x_clamped, (psimd_u16) y_clamped);
+ const psimd_u16 zw_clamped = psimd_concat_even_u16((psimd_u16) z_clamped, (psimd_u16) w_clamped);
+
+ const psimd_u8 xyzw_clamped = psimd_concat_even_u8((psimd_u8) xy_clamped, (psimd_u8) zw_clamped);
+
+ psimd_store_u8(output, xyzw_clamped);
+ output += 16;
+ }
+}
diff --git a/src/requantization/precise-scalar.c b/src/requantization/precise-scalar.c
new file mode 100644
index 0000000..e93ae0e
--- /dev/null
+++ b/src/requantization/precise-scalar.c
@@ -0,0 +1,321 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/scalar-utils.h>
+#include <xnnpack/requantization-stubs.h>
+
+
+/*
+ * "Precise" scalar requantization of 32-bit accumulators to uint8 that
+ * avoids all 64-bit arithmetic after the initial 32x32->64 multiply:
+ * rounding and shifting of the 64-bit product are reduced to 32-bit
+ * operations. Processes 4 elements per iteration.
+ *
+ * n          - number of elements; must be a multiple of 4
+ * input      - int32 accumulators to requantize
+ * scale      - requantization scale, in [0x1.0p-32, 1.0)
+ * zero_point - uint8 zero point added after scaling
+ * qmin, qmax - inclusive clamping bounds for the uint8 output
+ * output     - destination for the requantized uint8 values
+ */
+void xnn_requantize_precise__scalar_unsigned32(
+ size_t n,
+ const int32_t* input,
+ float scale,
+ uint8_t zero_point,
+ uint8_t qmin,
+ uint8_t qmax,
+ uint8_t* output)
+{
+ assert(n % 4 == 0);
+ assert(scale < 1.0f);
+ assert(scale >= 0x1.0p-32f);
+
+ const uint32_t scale_bits = fp32_to_bits(scale);
+ /* 32-bit multiplier: mantissa with implicit bit, normalized to [2**31, 2**32) */
+ const uint32_t multiplier = (scale_bits << 8) | UINT32_C(0x80000000);
+ /* scale == multiplier * 2**-shift */
+ const uint32_t shift = 127 + 31 - (scale_bits >> 23);
+ assert(shift >= 32);
+ assert(shift < 64);
+
+ const uint64_t rounding = UINT64_C(1) << (shift - 1);
+ const uint32_t rounding_hi = (uint32_t)(rounding >> 32);
+ const uint32_t rounding_lo = (uint32_t) rounding;
+ const uint32_t shift_minus_32 = shift - 32;
+ /* Clamping bounds shifted into the zero-point-free domain */
+ const int32_t smin = (int32_t)(uint32_t) qmin - (int32_t)(uint32_t) zero_point;
+ const int32_t smax = (int32_t)(uint32_t) qmax - (int32_t)(uint32_t) zero_point;
+ for (; n != 0; n -= 4) {
+ const int32_t x = input[0];
+ const int32_t y = input[1];
+ const int32_t z = input[2];
+ const int32_t w = input[3];
+ input += 4;
+
+ /*
+ * Compute absolute value of input as unsigned 32-bit int.
+ * All further computations will work with unsigned values to avoid undefined behaviour on signed operations.
+ */
+ const uint32_t x_abs = (x >= 0) ? (uint32_t) x : -(uint32_t) x;
+ const uint32_t y_abs = (y >= 0) ? (uint32_t) y : -(uint32_t) y;
+ const uint32_t z_abs = (z >= 0) ? (uint32_t) z : -(uint32_t) z;
+ const uint32_t w_abs = (w >= 0) ? (uint32_t) w : -(uint32_t) w;
+
+ /* Compute full 64-bit product of 32-bit factors */
+ const uint64_t x_product = (uint64_t) x_abs * (uint64_t) multiplier;
+ const uint64_t y_product = (uint64_t) y_abs * (uint64_t) multiplier;
+ const uint64_t z_product = (uint64_t) z_abs * (uint64_t) multiplier;
+ const uint64_t w_product = (uint64_t) w_abs * (uint64_t) multiplier;
+
+ /*
+ * Shift the full 64-bit product right with rounding.
+ * Rounding is performed towards closest integer, with midpoints rounded up (same as away from zero).
+ *
+ * Generally, this operation requires both 64-bit addition and 64-bit shift, but we use two tricks to replace
+ * 64-bit operations with 32-bit operations.
+ *
+ * To avoid full 64-bit addition we make use of three facts:
+ * - 64-bit rounding value added before the shift is a power of 2, and thus has only one bit set.
+ * - When shift == 32 (i.e. 0x1.0p-1f <= scale < 1.0f), the non-zero bit in rounding is in the low 32 bits,
+ * and rounding is exactly 0x80000000 (2**31), because rounding is 2**(shift-1) and shift >= 32. In this
+ * case, addition of rounding can affect high 32 bits of the product only through overflow, which happens if
+ * low 32-bit part of the product equals or exceeds 0x80000000. We can reformulate the latter condition
+ * as low 32-bit part of the product has the bit 31 set, and then overflow happens if both the low 32-bit part
+ * of the product and the low 32-bit part of the rounding value have bit 31 set. Since 32-bit numbers with the
+ * bit 31 set are negative when interpreted as signed integers, we can check the overflow condition as
+ * (int32_t) (LOW(product) & LOW(rounding)) < 0
+ * - When shift > 32 (i.e. scale < 0x1.0p-1f), the non-zero bit is in the high 32 bits of rounding. We just
+ * need to do 32-bit addition of high 32 bits of rounding and high 32 bits of product. This addition never
+ * overflows because product <= 0x80000000 * 0xFFFFFF00 < 2**63 and rounding = 2**(shift-1) <= 2**62.
+ *
+ * To avoid full 64-bit shift, we leverage the fact that shift >= 32, and do it in two steps:
+ * - Shift by 32, which can be implemented by extracting the high 32-bit word on 32-bit systems.
+ * - Shift by (shift - 32), which can be implemented as a 32-bit shift of high word of addition result.
+ */
+ const uint32_t x_carry_lo = (uint32_t)((int32_t)((uint32_t) x_product & rounding_lo) < 0);
+ const uint32_t y_carry_lo = (uint32_t)((int32_t)((uint32_t) y_product & rounding_lo) < 0);
+ const uint32_t z_carry_lo = (uint32_t)((int32_t)((uint32_t) z_product & rounding_lo) < 0);
+ const uint32_t w_carry_lo = (uint32_t)((int32_t)((uint32_t) w_product & rounding_lo) < 0);
+
+ const uint32_t x_product_hi = (uint32_t)(x_product >> 32);
+ const uint32_t y_product_hi = (uint32_t)(y_product >> 32);
+ const uint32_t z_product_hi = (uint32_t)(z_product >> 32);
+ const uint32_t w_product_hi = (uint32_t)(w_product >> 32);
+
+ const uint32_t x_abs_scaled = (uint32_t)(x_product_hi + rounding_hi + x_carry_lo) >> shift_minus_32;
+ const uint32_t y_abs_scaled = (uint32_t)(y_product_hi + rounding_hi + y_carry_lo) >> shift_minus_32;
+ const uint32_t z_abs_scaled = (uint32_t)(z_product_hi + rounding_hi + z_carry_lo) >> shift_minus_32;
+ const uint32_t w_abs_scaled = (uint32_t)(w_product_hi + rounding_hi + w_carry_lo) >> shift_minus_32;
+
+ /* Copy the sign of input to scaled absolute input value */
+ const int32_t x_scaled = (int32_t)(x >= 0 ? x_abs_scaled : -x_abs_scaled);
+ const int32_t y_scaled = (int32_t)(y >= 0 ? y_abs_scaled : -y_abs_scaled);
+ const int32_t z_scaled = (int32_t)(z >= 0 ? z_abs_scaled : -z_abs_scaled);
+ const int32_t w_scaled = (int32_t)(w >= 0 ? w_abs_scaled : -w_abs_scaled);
+
+ /*
+ * Clamp scaled value with zero point between (qmin - zero point) and (qmax - zero point).
+ */
+ const int32_t x_clamped = x_scaled < smin ? smin : x_scaled > smax ? smax : x_scaled;
+ const int32_t y_clamped = y_scaled < smin ? smin : y_scaled > smax ? smax : y_scaled;
+ const int32_t z_clamped = z_scaled < smin ? smin : z_scaled > smax ? smax : z_scaled;
+ const int32_t w_clamped = w_scaled < smin ? smin : w_scaled > smax ? smax : w_scaled;
+
+ /*
+ * Add zero point to clamped value.
+ * The result is guaranteed to be in [qmin, qmax] range.
+ *
+ * This addition can not be safely done before clamping, because scaled values are in [-2147483520, 2147483519]
+ * range, so addition of zero point (which can be up to 255) can overflow signed 32-bit integer.
+ */
+ const int32_t x_biased = x_clamped + zero_point;
+ const int32_t y_biased = y_clamped + zero_point;
+ const int32_t z_biased = z_clamped + zero_point;
+ const int32_t w_biased = w_clamped + zero_point;
+
+ output[0] = (uint8_t) x_biased;
+ output[1] = (uint8_t) y_biased;
+ output[2] = (uint8_t) z_biased;
+ output[3] = (uint8_t) w_biased;
+ output += 4;
+ }
+}
+
/*
 * Reference "precise" requantization using unsigned 64-bit arithmetic.
 *
 * The positive scale (in [2**-32, 1) range) is decomposed into a 24-bit
 * integer multiplier (the float significand with its implicit leading bit)
 * and a right shift in [24, 56) range. Each input is scaled on its absolute
 * value with round-half-up, which after sign restoration rounds ties away
 * from zero; the result is clamped and re-biased by the zero point.
 *
 * n must be a multiple of 4.
 */
void xnn_requantize_precise__scalar_unsigned64(
    size_t n,
    const int32_t* input,
    float scale,
    uint8_t zero_point,
    uint8_t qmin,
    uint8_t qmax,
    uint8_t* output)
{
  assert(n % 4 == 0);
  assert(scale < 1.0f);
  assert(scale >= 0x1.0p-32f);

  const uint32_t scale_bits = fp32_to_bits(scale);
  const uint32_t multiplier = (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000);
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 24);
  assert(shift < 56);

  const uint64_t rounding = UINT64_C(1) << (shift - 1);
  const int32_t smin = (int32_t)(uint32_t) qmin - (int32_t)(uint32_t) zero_point;
  const int32_t smax = (int32_t)(uint32_t) qmax - (int32_t)(uint32_t) zero_point;
  while (n != 0) {
    for (size_t i = 0; i < 4; i++) {
      const int32_t value = input[i];

      /* Work on the absolute value as unsigned to avoid signed-overflow UB. */
      const uint32_t abs_value = (value >= 0) ? (uint32_t) value : -(uint32_t) value;

      /* Full 64-bit product, then logical shift right with round-half-up. */
      const uint64_t product = (uint64_t) abs_value * (uint64_t) multiplier;
      const uint32_t abs_scaled = (uint32_t)((product + rounding) >> shift);

      /* Restore the sign of the scaled absolute value. */
      const int32_t scaled = (int32_t)(value >= 0 ? abs_scaled : -abs_scaled);

      /*
       * Clamp before re-biasing: scaled values span nearly the full int32
       * range, so adding the zero point first could overflow.
       */
      int32_t clamped = scaled;
      if (clamped < smin) {
        clamped = smin;
      }
      if (clamped > smax) {
        clamped = smax;
      }
      output[i] = (uint8_t)(clamped + (int32_t) zero_point);
    }
    input += 4;
    output += 4;
    n -= 4;
  }
}
+
/*
 * Reference "precise" requantization using signed 64-bit arithmetic.
 *
 * Same decomposition as the unsigned64 variant: a 24-bit multiplier and a
 * right shift in [24, 56) range. Here the full signed 64-bit product is
 * adjusted (decremented for negative inputs) so that a single round-half-up
 * arithmetic shift produces rounding with ties away from zero.
 *
 * n must be a multiple of 4.
 */
void xnn_requantize_precise__scalar_signed64(
    size_t n,
    const int32_t* input,
    float scale,
    uint8_t zero_point,
    uint8_t qmin,
    uint8_t qmax,
    uint8_t* output)
{
  assert(n % 4 == 0);
  assert(scale < 1.0f);
  assert(scale >= 0x1.0p-32f);

  const uint32_t scale_bits = fp32_to_bits(scale);
  const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 24);
  assert(shift < 56);

  const int64_t rounding = INT64_C(1) << (shift - 1);
  const int32_t smin = (int32_t)(uint32_t) qmin - (int32_t)(uint32_t) zero_point;
  const int32_t smax = (int32_t)(uint32_t) qmax - (int32_t)(uint32_t) zero_point;
  while (n != 0) {
    for (size_t i = 0; i < 4; i++) {
      const int32_t value = input[i];

      /* Full 64-bit product of the signed 32-bit factors. */
      const int64_t product = (int64_t) value * (int64_t) multiplier;

      /*
       * Subtract 1 from products of negative inputs so that the
       * round-half-up shift below rounds ties away from zero instead.
       */
      const int64_t adjusted_product = product - (int64_t)(value < 0);

      /* Arithmetic shift right with rounding (single instruction on NEON/ARM64). */
      const int32_t scaled = (int32_t) asr_s64(adjusted_product + rounding, shift);

      /*
       * Clamp before re-biasing: scaled values span nearly the full int32
       * range, so adding the zero point first could overflow.
       */
      int32_t clamped = scaled;
      if (clamped < smin) {
        clamped = smin;
      }
      if (clamped > smax) {
        clamped = smax;
      }
      output[i] = (uint8_t)(clamped + (int32_t) zero_point);
    }
    input += 4;
    output += 4;
    n -= 4;
  }
}
diff --git a/src/requantization/precise-sse2.c b/src/requantization/precise-sse2.c
new file mode 100644
index 0000000..c82361c
--- /dev/null
+++ b/src/requantization/precise-sse2.c
@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <emmintrin.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+
/*
 * "Precise" requantization of 32-bit accumulators to uint8 using SSE2.
 *
 * The positive scale (in [2**-32, 1) range) is decomposed into a 24-bit
 * integer multiplier (the float significand with its implicit leading bit)
 * and a right shift in [24, 56) range. Scaling is performed on absolute
 * values in 64-bit lanes with round-half-up; restoring the sign afterwards
 * yields rounding with ties away from zero, matching the scalar reference
 * implementation bit-for-bit.
 *
 * n must be a multiple of 16 (16 elements are processed per iteration).
 */
void xnn_requantize_precise__sse2(
    size_t n,
    const int32_t* input,
    float scale,
    uint8_t zero_point,
    uint8_t qmin,
    uint8_t qmax,
    uint8_t* output)
{
  assert(n % 16 == 0);
  assert(scale < 1.0f);
  assert(scale >= 0x1.0p-32f);

  /* Extract multiplier (significand) and shift (from exponent) out of the scale bits. */
  const uint32_t scale_bits = fp32_to_bits(scale);
  const uint32_t multiplier = (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000);
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 24);
  assert(shift < 56);
  const uint64_t rounding = UINT64_C(1) << (shift - 1);

  const __m128i vmultiplier = _mm_set1_epi32(multiplier);
  const __m128i vzero_point = _mm_set1_epi16((short) (uint16_t) zero_point);
  const __m128i vqmin = _mm_set1_epi8((char) qmin);
  const __m128i vqmax = _mm_set1_epi8((char) qmax);
  const __m128i vshift = _mm_cvtsi32_si128((int) shift);
  const __m128i vrounding = _mm_set1_epi64x(rounding);
  for (; n != 0; n -= 16) {
    const __m128i x = _mm_loadu_si128((const __m128i*) input);
    const __m128i y = _mm_loadu_si128((const __m128i*) (input + 4));
    const __m128i z = _mm_loadu_si128((const __m128i*) (input + 8));
    const __m128i w = _mm_loadu_si128((const __m128i*) (input + 12));
    input += 16;

    /* Per-lane sign masks: all-ones in lanes holding negative values. */
    const __m128i x_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), x);
    const __m128i y_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), y);
    const __m128i z_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), z);
    const __m128i w_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), w);

    /* Absolute value via the two's complement identity (v ^ mask) - mask. */
    const __m128i x_abs0123 = _mm_sub_epi32(_mm_xor_si128(x, x_neg_mask), x_neg_mask);
    const __m128i y_abs0123 = _mm_sub_epi32(_mm_xor_si128(y, y_neg_mask), y_neg_mask);
    const __m128i z_abs0123 = _mm_sub_epi32(_mm_xor_si128(z, z_neg_mask), z_neg_mask);
    const __m128i w_abs0123 = _mm_sub_epi32(_mm_xor_si128(w, w_neg_mask), w_neg_mask);

    /* PMULUDQ only multiplies lanes 0 and 2; swap lane pairs to also reach lanes 1 and 3. */
    const __m128i x_abs1032 = _mm_shuffle_epi32(x_abs0123, _MM_SHUFFLE(2, 3, 0, 1));
    const __m128i y_abs1032 = _mm_shuffle_epi32(y_abs0123, _MM_SHUFFLE(2, 3, 0, 1));
    const __m128i z_abs1032 = _mm_shuffle_epi32(z_abs0123, _MM_SHUFFLE(2, 3, 0, 1));
    const __m128i w_abs1032 = _mm_shuffle_epi32(w_abs0123, _MM_SHUFFLE(2, 3, 0, 1));

    /* Full 64-bit products of |v| and the multiplier, for even and odd lanes. */
    const __m128i x_absmul02 = _mm_mul_epu32(x_abs0123, vmultiplier);
    const __m128i y_absmul02 = _mm_mul_epu32(y_abs0123, vmultiplier);
    const __m128i z_absmul02 = _mm_mul_epu32(z_abs0123, vmultiplier);
    const __m128i w_absmul02 = _mm_mul_epu32(w_abs0123, vmultiplier);

    const __m128i x_absmul13 = _mm_mul_epu32(x_abs1032, vmultiplier);
    const __m128i y_absmul13 = _mm_mul_epu32(y_abs1032, vmultiplier);
    const __m128i z_absmul13 = _mm_mul_epu32(z_abs1032, vmultiplier);
    const __m128i w_absmul13 = _mm_mul_epu32(w_abs1032, vmultiplier);

    /* 64-bit logical shift right with round-half-up (rounding constant pre-added). */
    const __m128i x_abs_scaled02 = _mm_srl_epi64(_mm_add_epi64(x_absmul02, vrounding), vshift);
    const __m128i x_abs_scaled13 = _mm_srl_epi64(_mm_add_epi64(x_absmul13, vrounding), vshift);
    const __m128i y_abs_scaled02 = _mm_srl_epi64(_mm_add_epi64(y_absmul02, vrounding), vshift);
    const __m128i y_abs_scaled13 = _mm_srl_epi64(_mm_add_epi64(y_absmul13, vrounding), vshift);
    const __m128i z_abs_scaled02 = _mm_srl_epi64(_mm_add_epi64(z_absmul02, vrounding), vshift);
    const __m128i z_abs_scaled13 = _mm_srl_epi64(_mm_add_epi64(z_absmul13, vrounding), vshift);
    const __m128i w_abs_scaled02 = _mm_srl_epi64(_mm_add_epi64(w_absmul02, vrounding), vshift);
    const __m128i w_abs_scaled13 = _mm_srl_epi64(_mm_add_epi64(w_absmul13, vrounding), vshift);

    /* Gather the low 32 bits of each 64-bit lane back into a single vector (order 0,2,1,3). */
    const __m128i x_abs_scaled0213 = _mm_castps_si128(
        _mm_shuffle_ps(_mm_castsi128_ps(x_abs_scaled02), _mm_castsi128_ps(x_abs_scaled13), _MM_SHUFFLE(2, 0, 2, 0)));
    const __m128i y_abs_scaled0213 = _mm_castps_si128(
        _mm_shuffle_ps(_mm_castsi128_ps(y_abs_scaled02), _mm_castsi128_ps(y_abs_scaled13), _MM_SHUFFLE(2, 0, 2, 0)));
    const __m128i z_abs_scaled0213 = _mm_castps_si128(
        _mm_shuffle_ps(_mm_castsi128_ps(z_abs_scaled02), _mm_castsi128_ps(z_abs_scaled13), _MM_SHUFFLE(2, 0, 2, 0)));
    const __m128i w_abs_scaled0213 = _mm_castps_si128(
        _mm_shuffle_ps(_mm_castsi128_ps(w_abs_scaled02), _mm_castsi128_ps(w_abs_scaled13), _MM_SHUFFLE(2, 0, 2, 0)));

    /* Restore the original 0,1,2,3 lane order. */
    const __m128i x_abs_scaled = _mm_shuffle_epi32(x_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0));
    const __m128i y_abs_scaled = _mm_shuffle_epi32(y_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0));
    const __m128i z_abs_scaled = _mm_shuffle_epi32(z_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0));
    const __m128i w_abs_scaled = _mm_shuffle_epi32(w_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0));

    /* Restore signs via the same (v ^ mask) - mask identity. */
    const __m128i x_scaled = _mm_sub_epi32(_mm_xor_si128(x_abs_scaled, x_neg_mask), x_neg_mask);
    const __m128i y_scaled = _mm_sub_epi32(_mm_xor_si128(y_abs_scaled, y_neg_mask), y_neg_mask);
    const __m128i z_scaled = _mm_sub_epi32(_mm_xor_si128(z_abs_scaled, z_neg_mask), z_neg_mask);
    const __m128i w_scaled = _mm_sub_epi32(_mm_xor_si128(w_abs_scaled, w_neg_mask), w_neg_mask);

    /* Pack to int16 with saturation, add the zero point, pack to uint8, then clamp to [qmin, qmax]. */
    const __m128i xy_packed = _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point);
    const __m128i zw_packed = _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point);
    const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed);
    const __m128i xyzw_clamped = _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin);

    /*
     * 4x PXOR (setzero)
     * 8x PSUBD
     * 8x PXOR
     * 8x PSHUFD
     * 8x PMULUDQ
     * 8x PSRLQ
     * 8x PADDQ
     * 4x SHUFPS
     * 2x PACKSSDW
     * 1x PACKUSWB
     * 2x PADDSW
     * 1x PMAXUB
     * 1x PMINUB
     * ---------------------
     * 63 instructions total
     */

    _mm_storeu_si128((__m128i*) output, xyzw_clamped);
    output += 16;
  }
}
diff --git a/src/requantization/precise-sse4.c b/src/requantization/precise-sse4.c
new file mode 100644
index 0000000..974a3b1
--- /dev/null
+++ b/src/requantization/precise-sse4.c
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <smmintrin.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+
/*
 * "Precise" requantization of 32-bit accumulators to uint8 using SSE4.1
 * (plus SSSE3 PABSD/PSIGND).
 *
 * Unlike the SSE2/SSSE3 variants, the scale is decomposed into a full 32-bit
 * multiplier: the implicit leading one at bit 31 (OR with 0x80000000) and the
 * 23 significand bits in bits 8-30, with a correspondingly larger right shift
 * in [32, 64) range. Scaling works on absolute values with round-half-up;
 * restoring the sign afterwards yields rounding with ties away from zero.
 *
 * n must be a multiple of 16 (16 elements are processed per iteration).
 */
void xnn_requantize_precise__sse4(
    size_t n,
    const int32_t* input,
    float scale,
    uint8_t zero_point,
    uint8_t qmin,
    uint8_t qmax,
    uint8_t* output)
{
  assert(n % 16 == 0);
  assert(scale < 1.0f);
  assert(scale >= 0x1.0p-32f);

  const uint32_t scale_bits = fp32_to_bits(scale);
  const uint32_t multiplier = (scale_bits << 8) | UINT32_C(0x80000000);
  const uint32_t shift = 127 + 31 - (scale_bits >> 23);
  assert(shift >= 32);
  assert(shift < 64);
  const uint64_t rounding = UINT64_C(1) << (shift - 1);

  const __m128i vmultiplier = _mm_set1_epi32(multiplier);
  const __m128i vzero_point = _mm_set1_epi16((short) (uint16_t) zero_point);
  const __m128i vqmin = _mm_set1_epi8((char) qmin);
  const __m128i vqmax = _mm_set1_epi8((char) qmax);
  /* Since shift >= 32, (p >> shift) == ((p >> 32) >> (shift - 32)); vshifthi is the 32-bit remainder shift. */
  const __m128i vshiftlo = _mm_cvtsi32_si128((int) shift);
  const __m128i vshifthi = _mm_cvtsi32_si128((int) shift - 32);
  const __m128i vrounding = _mm_set1_epi64x(rounding);
  for (; n != 0; n -= 16) {
    const __m128i x = _mm_loadu_si128((const __m128i*) input);
    const __m128i y = _mm_loadu_si128((const __m128i*) (input + 4));
    const __m128i z = _mm_loadu_si128((const __m128i*) (input + 8));
    const __m128i w = _mm_loadu_si128((const __m128i*) (input + 12));
    input += 16;

    /* Absolute values via SSSE3 PABSD. */
    const __m128i x_abs0123 = _mm_abs_epi32(x);
    const __m128i y_abs0123 = _mm_abs_epi32(y);
    const __m128i z_abs0123 = _mm_abs_epi32(z);
    const __m128i w_abs0123 = _mm_abs_epi32(w);

    /* PMULUDQ only multiplies lanes 0 and 2; swap lane pairs to also reach lanes 1 and 3. */
    const __m128i x_abs1032 = _mm_shuffle_epi32(x_abs0123, _MM_SHUFFLE(2, 3, 0, 1));
    const __m128i y_abs1032 = _mm_shuffle_epi32(y_abs0123, _MM_SHUFFLE(2, 3, 0, 1));
    const __m128i z_abs1032 = _mm_shuffle_epi32(z_abs0123, _MM_SHUFFLE(2, 3, 0, 1));
    const __m128i w_abs1032 = _mm_shuffle_epi32(w_abs0123, _MM_SHUFFLE(2, 3, 0, 1));

    /* Full 64-bit products of |v| and the multiplier, for even and odd lanes. */
    const __m128i x_absmul02 = _mm_mul_epu32(x_abs0123, vmultiplier);
    const __m128i y_absmul02 = _mm_mul_epu32(y_abs0123, vmultiplier);
    const __m128i z_absmul02 = _mm_mul_epu32(z_abs0123, vmultiplier);
    const __m128i w_absmul02 = _mm_mul_epu32(w_abs0123, vmultiplier);

    const __m128i x_absmul13 = _mm_mul_epu32(x_abs1032, vmultiplier);
    const __m128i y_absmul13 = _mm_mul_epu32(y_abs1032, vmultiplier);
    const __m128i z_absmul13 = _mm_mul_epu32(z_abs1032, vmultiplier);
    const __m128i w_absmul13 = _mm_mul_epu32(w_abs1032, vmultiplier);

    /*
     * Round-half-up shift right. Even lanes: a true 64-bit shift (PSRLQ)
     * leaves the result in the low dword of each 64-bit lane. Odd lanes:
     * a 32-bit shift (PSRLD) by (shift - 32) of the high dword leaves the
     * result in the odd dword positions, ready for PBLENDW below.
     */
    const __m128i x_abs_scaled02 = _mm_srl_epi64(_mm_add_epi64(x_absmul02, vrounding), vshiftlo);
    const __m128i x_abs_scaled13 = _mm_srl_epi32(_mm_add_epi64(x_absmul13, vrounding), vshifthi);
    const __m128i y_abs_scaled02 = _mm_srl_epi64(_mm_add_epi64(y_absmul02, vrounding), vshiftlo);
    const __m128i y_abs_scaled13 = _mm_srl_epi32(_mm_add_epi64(y_absmul13, vrounding), vshifthi);
    const __m128i z_abs_scaled02 = _mm_srl_epi64(_mm_add_epi64(z_absmul02, vrounding), vshiftlo);
    const __m128i z_abs_scaled13 = _mm_srl_epi32(_mm_add_epi64(z_absmul13, vrounding), vshifthi);
    const __m128i w_abs_scaled02 = _mm_srl_epi64(_mm_add_epi64(w_absmul02, vrounding), vshiftlo);
    const __m128i w_abs_scaled13 = _mm_srl_epi32(_mm_add_epi64(w_absmul13, vrounding), vshifthi);

    /* Blend mask 0xCC selects dwords 1 and 3 from the odd-lane results. */
    const __m128i x_abs_scaled = _mm_blend_epi16(x_abs_scaled02, x_abs_scaled13, 0xCC);
    const __m128i y_abs_scaled = _mm_blend_epi16(y_abs_scaled02, y_abs_scaled13, 0xCC);
    const __m128i z_abs_scaled = _mm_blend_epi16(z_abs_scaled02, z_abs_scaled13, 0xCC);
    const __m128i w_abs_scaled = _mm_blend_epi16(w_abs_scaled02, w_abs_scaled13, 0xCC);

    /* Restore signs via SSSE3 PSIGND. */
    const __m128i x_scaled = _mm_sign_epi32(x_abs_scaled, x);
    const __m128i y_scaled = _mm_sign_epi32(y_abs_scaled, y);
    const __m128i z_scaled = _mm_sign_epi32(z_abs_scaled, z);
    const __m128i w_scaled = _mm_sign_epi32(w_abs_scaled, w);

    /* Pack to int16 with saturation, add the zero point, pack to uint8, then clamp to [qmin, qmax]. */
    const __m128i xy_packed = _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point);
    const __m128i zw_packed = _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point);
    const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed);
    const __m128i xyzw_clamped = _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin);

    /*
     * 4x PABSD
     * 4x PSHUFD
     * 8x PMULUDQ
     * 4x PSRLQ
     * 4x PSRLD
     * 8x PADDQ
     * 4x PBLENDW
     * 4x PSIGND
     * 2x PACKSSDW
     * 1x PACKUSWB
     * 2x PADDSW
     * 1x PMAXUB
     * 1x PMINUB
     * ---------------------
     * 47 instructions total
     */

    _mm_storeu_si128((__m128i*) output, xyzw_clamped);
    output += 16;
  }
}
diff --git a/src/requantization/precise-ssse3.c b/src/requantization/precise-ssse3.c
new file mode 100644
index 0000000..626c0eb
--- /dev/null
+++ b/src/requantization/precise-ssse3.c
@@ -0,0 +1,126 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <tmmintrin.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+
/*
 * "Precise" requantization of 32-bit accumulators to uint8 using SSSE3.
 *
 * Identical algorithm to the SSE2 variant (24-bit multiplier, right shift in
 * [24, 56) range, scaling on absolute values with round-half-up, sign
 * restored afterwards so ties round away from zero), but absolute value and
 * sign restoration use the SSSE3 PABSD/PSIGND instructions instead of the
 * XOR/SUB sequences.
 *
 * n must be a multiple of 16 (16 elements are processed per iteration).
 */
void xnn_requantize_precise__ssse3(
    size_t n,
    const int32_t* input,
    float scale,
    uint8_t zero_point,
    uint8_t qmin,
    uint8_t qmax,
    uint8_t* output)
{
  assert(n % 16 == 0);
  assert(scale < 1.0f);
  assert(scale >= 0x1.0p-32f);

  /* Extract multiplier (significand) and shift (from exponent) out of the scale bits. */
  const uint32_t scale_bits = fp32_to_bits(scale);
  const uint32_t multiplier = (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000);
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 24);
  assert(shift < 56);
  const uint64_t rounding = UINT64_C(1) << (shift - 1);

  const __m128i vmultiplier = _mm_set1_epi32(multiplier);
  const __m128i vzero_point = _mm_set1_epi16((short) (uint16_t) zero_point);
  const __m128i vqmin = _mm_set1_epi8((char) qmin);
  const __m128i vqmax = _mm_set1_epi8((char) qmax);
  const __m128i vshift = _mm_cvtsi32_si128((int) shift);
  const __m128i vrounding = _mm_set1_epi64x(rounding);
  for (; n != 0; n -= 16) {
    const __m128i x = _mm_loadu_si128((const __m128i*) input);
    const __m128i y = _mm_loadu_si128((const __m128i*) (input + 4));
    const __m128i z = _mm_loadu_si128((const __m128i*) (input + 8));
    const __m128i w = _mm_loadu_si128((const __m128i*) (input + 12));
    input += 16;

    /* Absolute values via SSSE3 PABSD. */
    const __m128i x_abs0123 = _mm_abs_epi32(x);
    const __m128i y_abs0123 = _mm_abs_epi32(y);
    const __m128i z_abs0123 = _mm_abs_epi32(z);
    const __m128i w_abs0123 = _mm_abs_epi32(w);

    /* PMULUDQ only multiplies lanes 0 and 2; swap lane pairs to also reach lanes 1 and 3. */
    const __m128i x_abs1032 = _mm_shuffle_epi32(x_abs0123, _MM_SHUFFLE(2, 3, 0, 1));
    const __m128i y_abs1032 = _mm_shuffle_epi32(y_abs0123, _MM_SHUFFLE(2, 3, 0, 1));
    const __m128i z_abs1032 = _mm_shuffle_epi32(z_abs0123, _MM_SHUFFLE(2, 3, 0, 1));
    const __m128i w_abs1032 = _mm_shuffle_epi32(w_abs0123, _MM_SHUFFLE(2, 3, 0, 1));

    /* Full 64-bit products of |v| and the multiplier, for even and odd lanes. */
    const __m128i x_absmul02 = _mm_mul_epu32(x_abs0123, vmultiplier);
    const __m128i y_absmul02 = _mm_mul_epu32(y_abs0123, vmultiplier);
    const __m128i z_absmul02 = _mm_mul_epu32(z_abs0123, vmultiplier);
    const __m128i w_absmul02 = _mm_mul_epu32(w_abs0123, vmultiplier);

    const __m128i x_absmul13 = _mm_mul_epu32(x_abs1032, vmultiplier);
    const __m128i y_absmul13 = _mm_mul_epu32(y_abs1032, vmultiplier);
    const __m128i z_absmul13 = _mm_mul_epu32(z_abs1032, vmultiplier);
    const __m128i w_absmul13 = _mm_mul_epu32(w_abs1032, vmultiplier);

    /* 64-bit logical shift right with round-half-up (rounding constant pre-added). */
    const __m128i x_abs_scaled02 = _mm_srl_epi64(_mm_add_epi64(x_absmul02, vrounding), vshift);
    const __m128i x_abs_scaled13 = _mm_srl_epi64(_mm_add_epi64(x_absmul13, vrounding), vshift);
    const __m128i y_abs_scaled02 = _mm_srl_epi64(_mm_add_epi64(y_absmul02, vrounding), vshift);
    const __m128i y_abs_scaled13 = _mm_srl_epi64(_mm_add_epi64(y_absmul13, vrounding), vshift);
    const __m128i z_abs_scaled02 = _mm_srl_epi64(_mm_add_epi64(z_absmul02, vrounding), vshift);
    const __m128i z_abs_scaled13 = _mm_srl_epi64(_mm_add_epi64(z_absmul13, vrounding), vshift);
    const __m128i w_abs_scaled02 = _mm_srl_epi64(_mm_add_epi64(w_absmul02, vrounding), vshift);
    const __m128i w_abs_scaled13 = _mm_srl_epi64(_mm_add_epi64(w_absmul13, vrounding), vshift);

    /* Gather the low 32 bits of each 64-bit lane back into a single vector (order 0,2,1,3). */
    const __m128i x_abs_scaled0213 = _mm_castps_si128(
        _mm_shuffle_ps(_mm_castsi128_ps(x_abs_scaled02), _mm_castsi128_ps(x_abs_scaled13), _MM_SHUFFLE(2, 0, 2, 0)));
    const __m128i y_abs_scaled0213 = _mm_castps_si128(
        _mm_shuffle_ps(_mm_castsi128_ps(y_abs_scaled02), _mm_castsi128_ps(y_abs_scaled13), _MM_SHUFFLE(2, 0, 2, 0)));
    const __m128i z_abs_scaled0213 = _mm_castps_si128(
        _mm_shuffle_ps(_mm_castsi128_ps(z_abs_scaled02), _mm_castsi128_ps(z_abs_scaled13), _MM_SHUFFLE(2, 0, 2, 0)));
    const __m128i w_abs_scaled0213 = _mm_castps_si128(
        _mm_shuffle_ps(_mm_castsi128_ps(w_abs_scaled02), _mm_castsi128_ps(w_abs_scaled13), _MM_SHUFFLE(2, 0, 2, 0)));

    /* Restore the original 0,1,2,3 lane order. */
    const __m128i x_abs_scaled = _mm_shuffle_epi32(x_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0));
    const __m128i y_abs_scaled = _mm_shuffle_epi32(y_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0));
    const __m128i z_abs_scaled = _mm_shuffle_epi32(z_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0));
    const __m128i w_abs_scaled = _mm_shuffle_epi32(w_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0));

    /* Restore signs via SSSE3 PSIGND. */
    const __m128i x_scaled = _mm_sign_epi32(x_abs_scaled, x);
    const __m128i y_scaled = _mm_sign_epi32(y_abs_scaled, y);
    const __m128i z_scaled = _mm_sign_epi32(z_abs_scaled, z);
    const __m128i w_scaled = _mm_sign_epi32(w_abs_scaled, w);

    /* Pack to int16 with saturation, add the zero point, pack to uint8, then clamp to [qmin, qmax]. */
    const __m128i xy_packed = _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point);
    const __m128i zw_packed = _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point);
    const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed);
    const __m128i xyzw_clamped = _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin);

    /*
     * 4x PABSD
     * 8x PSHUFD
     * 8x PMULUDQ
     * 8x PSRLQ
     * 8x PADDQ
     * 4x SHUFPS
     * 4x PSIGND
     * 2x PACKSSDW
     * 1x PACKUSWB
     * 2x PADDSW
     * 1x PMAXUB
     * 1x PMINUB
     * ---------------------
     * 51 instructions total
     */

    _mm_storeu_si128((__m128i*) output, xyzw_clamped);
    output += 16;
  }
}
diff --git a/src/requantization/q31-neon.c b/src/requantization/q31-neon.c
new file mode 100644
index 0000000..37986bc
--- /dev/null
+++ b/src/requantization/q31-neon.c
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <arm_neon.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+
/*
 * Q31 requantization of 32-bit accumulators to uint8 using NEON.
 *
 * The scale is decomposed into a Q31 fixed-point multiplier in
 * [0x40000000, 0x7FFFFF80] range and a right shift in [0, 31] range. The
 * Q31 multiply maps directly to VQRDMULH/SQRDMULH, and the shift to the
 * rounding shift VRSHL/SRSHL, with a pre-shift adjustment so that midpoint
 * rounding matches the scalar Q31 reference (ties away from zero).
 *
 * n must be a multiple of 16 (16 elements are processed per iteration).
 */
void xnn_requantize_q31__neon(
    size_t n,
    const int32_t* input,
    float scale,
    uint8_t zero_point,
    uint8_t qmin,
    uint8_t qmax,
    uint8_t* output)
{
  assert(n % 16 == 0);
  assert(scale < 1.0f);
  assert(scale >= 0x1.0p-32f);

  /* Compute requantization parameters */
  const uint32_t scale_bits = fp32_to_bits(scale);

  /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
  const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
  assert(multiplier >= INT32_C(0x40000000));
  assert(multiplier <= INT32_C(0x7FFFFF80));

  /* Shift is in [0, 31] range */
  const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
  assert(shift >= 0);
  assert(shift < 32);

  const int32x4_t vmultiplier = vdupq_n_s32(multiplier);
  const int16x8_t vzero_point = vdupq_n_s16((int16_t)(uint16_t) zero_point);
  /* VRSHL/SRSHL shift right when the shift operand is negative. */
  const int32x4_t vshift = vdupq_n_s32(-shift);
  /* All-ones when shift == 0; used to suppress the pre-shift adjustment in that case. */
  const int32x4_t vshift_eq_0_mask = vreinterpretq_s32_u32(vceqq_s32(vshift, vmovq_n_s32(0)));
  const uint8x16_t vqmin = vdupq_n_u8(qmin);
  const uint8x16_t vqmax = vdupq_n_u8(qmax);
  for (; n != 0; n -= 16) {
    const int32x4_t x = vld1q_s32(input);
    const int32x4_t y = vld1q_s32(input + 4);
    const int32x4_t z = vld1q_s32(input + 8);
    const int32x4_t w = vld1q_s32(input + 12);
    input += 16;

    /*
     * Directly use VQRDMULH/SQRDMULH instruction for Q31 multiplication with rounding.
     * Although these instruction saturate out-of-range outputs, we never hit this case in requantization.
     */
    const int32x4_t x_product = vqrdmulhq_s32(x, vmultiplier);
    const int32x4_t y_product = vqrdmulhq_s32(y, vmultiplier);
    const int32x4_t z_product = vqrdmulhq_s32(z, vmultiplier);
    const int32x4_t w_product = vqrdmulhq_s32(w, vmultiplier);

    /*
     * Shift the 32-bit product right with rounding.
     * Rounding is performed towards closest integer, with midpoints rounded up (same as away from zero).
     *
     * We leverage the "right shift with rounding" instruction (VRSHL.S32 on ARM NEON, SRSHL in ARM64 Advanced
     * SIMD) to do the shift. However, as this instruction rounds midpoints up, rather than away from zero, we
     * adjust the product by subtracting 1 for negative inputs (VBIC masks the sign bits out when shift is zero,
     * because no adjustment is needed in that case; SSRA then accumulates sign >> 31, i.e. -1 or 0).
     */
    const int32x4_t x_adjusted_product = vsraq_n_s32(x_product, vbicq_s32(x, vshift_eq_0_mask), 31);
    const int32x4_t y_adjusted_product = vsraq_n_s32(y_product, vbicq_s32(y, vshift_eq_0_mask), 31);
    const int32x4_t z_adjusted_product = vsraq_n_s32(z_product, vbicq_s32(z, vshift_eq_0_mask), 31);
    const int32x4_t w_adjusted_product = vsraq_n_s32(w_product, vbicq_s32(w, vshift_eq_0_mask), 31);

    const int32x4_t x_scaled = vrshlq_s32(x_adjusted_product, vshift);
    const int32x4_t y_scaled = vrshlq_s32(y_adjusted_product, vshift);
    const int32x4_t z_scaled = vrshlq_s32(z_adjusted_product, vshift);
    const int32x4_t w_scaled = vrshlq_s32(w_adjusted_product, vshift);

    /* Narrow to int16 with saturation, add the zero point (saturating), then narrow to uint8. */
#ifdef __aarch64__
    const int16x8_t xy_packed = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(x_scaled), y_scaled), vzero_point);
    const int16x8_t zw_packed = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(z_scaled), w_scaled), vzero_point);
    const uint8x16_t xyzw_packed = vqmovun_high_s16(vqmovun_s16(xy_packed), zw_packed);
#else
    const int16x8_t xy_packed = vqaddq_s16(vcombine_s16(vqmovn_s32(x_scaled), vqmovn_s32(y_scaled)), vzero_point);
    const int16x8_t zw_packed = vqaddq_s16(vcombine_s16(vqmovn_s32(z_scaled), vqmovn_s32(w_scaled)), vzero_point);
    const uint8x16_t xyzw_packed = vcombine_u8(vqmovun_s16(xy_packed), vqmovun_s16(zw_packed));
#endif

    const uint8x16_t xyzw_clamped = vmaxq_u8(vminq_u8(xyzw_packed, vqmax), vqmin);

    /*
     * AArch32 version:
     * 4x VQRDMULH.S32 Qd, Qm, Qn
     * 4x VBIC Qd, Qm, Qn
     * 4x VSRA.S32 Qd, Qm, #31
     * 4x VRSHL.S32 Qd, Qm, Qn
     * 4x VQMOVN.S32 Dd, Qm
     * 2x VQADD.S16 Qd, Qm, Qn
     * 2x VQMOVUN.S16 Dd, Qm
     * 1x VMAX.U8 Qd, Qm, Qn
     * 1x VMIN.U8 Qd, Qm, Qn
     * ---------------------
     * 26 instructions total
     *
     * AArch64 version:
     * 4x SQRDMULH Vd.4S, Vn.4S, Vm.4S
     * 4x BIC Vd.16B, Vn.16B, Vm.16B
     * 4x SSRA Vd.4S, Vn.4S, #31
     * 4x SRSHL Vd.4S, Vn.4S, Vm.4S
     * 2x SQXTN Vd.4H, Vn.4S
     * 2x SQXTN2 Vd.8H, Vn.4S
     * 2x SQADD Vd.8H, Vn.8H, Vm.8H
     * 1x SQXTUN Vd.8B, Vn.8H
     * 1x SQXTUN2 Vd.16B, Vn.8H
     * 1x UMIN Vd.16B, Vn.16B, Vm.16B
     * 1x UMAX Vd.16B, Vn.16B, Vm.16B
     * ---------------------
     * 26 instructions total
     */

    vst1q_u8(output, xyzw_clamped);
    output += 16;
  }
}
diff --git a/src/requantization/q31-scalar.c b/src/requantization/q31-scalar.c
new file mode 100644
index 0000000..1677d0b
--- /dev/null
+++ b/src/requantization/q31-scalar.c
@@ -0,0 +1,140 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/scalar-utils.h>
+#include <xnnpack/requantization-stubs.h>
+
+
/* Requantizes one 32-bit value with the Q31 algorithm; see the caller below for parameter semantics. */
static inline uint8_t requantize_q31_single(
    int32_t value,
    int32_t multiplier,
    int32_t shift,
    int32_t remainder_mask,
    int32_t threshold,
    int32_t smin,
    int32_t smax,
    uint8_t zero_point)
{
  /* Full 64-bit product of the signed 32-bit factors (multiplier sign is irrelevant here). */
  const int64_t product = (int64_t) value * (int64_t) multiplier;

  /*
   * Q31 multiplication result: bits 31-62 of the product, with rounding up.
   * Casts through unsigned types avoid undefined behaviour on the shift.
   * Given the multiplier range, the result is in [-2147483520, 2147483519].
   */
  const int32_t q31product = (int32_t)(uint32_t)((uint64_t)(product + INT64_C(0x40000000)) >> 31);

  /*
   * Arithmetic shift right with rounding to nearest, ties away from zero.
   * Pre-adding a rounding constant could overflow int32, so instead perform a
   * truncating shift and adjust using the remainder of the division by 2**shift:
   * increment when remainder - (q31product < 0) > ((2**shift - 1) >> 1).
   * When shift is 0 the remainder is 0 and no adjustment is made.
   */
  const int32_t remainder = (q31product & remainder_mask) - (int32_t)(q31product < 0);
  const int32_t scaled = asr_s32(q31product, shift) + (int32_t)(remainder > threshold);

  /*
   * Clamp in the zero-point-shifted domain first: adding the zero point (up to
   * 255) to an unclamped scaled value could overflow a signed 32-bit integer.
   */
  const int32_t clamped = scaled < smin ? smin : scaled > smax ? smax : scaled;
  return (uint8_t)(clamped + (int32_t) zero_point);
}

/*
 * Reference scalar implementation of Q31 requantization.
 *
 * The scale is decomposed into a Q31 fixed-point multiplier in
 * [0x40000000, 0x7FFFFF80] range and a right shift in [0, 31] range;
 * rounding is to nearest with ties away from zero at both steps.
 *
 * n must be a multiple of 4.
 */
void xnn_requantize_q31__scalar(
    size_t n,
    const int32_t* input,
    float scale,
    uint8_t zero_point,
    uint8_t qmin,
    uint8_t qmax,
    uint8_t* output)
{
  assert(n % 4 == 0);
  assert(scale < 1.0f);
  assert(scale >= 0x1.0p-32f);

  /* Compute requantization parameters from the float's bit pattern. */
  const uint32_t scale_bits = fp32_to_bits(scale);

  /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
  const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
  assert(multiplier >= INT32_C(0x40000000));
  assert(multiplier <= INT32_C(0x7FFFFF80));

  /* Shift is in [0, 31] range */
  const int32_t shift = 127 + 31 - 32 - (scale_bits >> 23);
  assert(shift >= 0);
  assert(shift < 32);

  const int32_t remainder_mask = (int32_t)((UINT32_C(1) << shift) - UINT32_C(1));
  const int32_t threshold = (int32_t)((uint32_t) remainder_mask >> 1);
  const int32_t smin = (int32_t)(uint32_t) qmin - (int32_t)(uint32_t) zero_point;
  const int32_t smax = (int32_t)(uint32_t) qmax - (int32_t)(uint32_t) zero_point;
  while (n != 0) {
    output[0] = requantize_q31_single(input[0], multiplier, shift, remainder_mask, threshold, smin, smax, zero_point);
    output[1] = requantize_q31_single(input[1], multiplier, shift, remainder_mask, threshold, smin, smax, zero_point);
    output[2] = requantize_q31_single(input[2], multiplier, shift, remainder_mask, threshold, smin, smax, zero_point);
    output[3] = requantize_q31_single(input[3], multiplier, shift, remainder_mask, threshold, smin, smax, zero_point);
    input += 4;
    output += 4;
    n -= 4;
  }
}
diff --git a/src/requantization/q31-sse2.c b/src/requantization/q31-sse2.c
new file mode 100644
index 0000000..4223e8a
--- /dev/null
+++ b/src/requantization/q31-sse2.c
@@ -0,0 +1,189 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <emmintrin.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+
+/*
+ * Q31 requantization with SSE2 intrinsics: converts signed 32-bit values to
+ * unsigned 8-bit values by multiplying with a Q31 fixed-point multiplier
+ * (rounding doubling high multiply), dividing by 2**shift with rounding,
+ * adding the zero point, and clamping to [qmin, qmax].
+ *
+ * n          - number of elements; must be a multiple of 16.
+ * input      - pointer to n signed 32-bit input values.
+ * scale      - requantization scale; must be in [2**-32, 1).
+ * zero_point - zero point added to the scaled result.
+ * qmin/qmax  - inclusive clamping bounds for the output.
+ * output     - pointer to n unsigned 8-bit output values.
+ *
+ * Processes 16 elements per loop iteration (4 vectors of 4).
+ */
+void xnn_requantize_q31__sse2(
+    size_t n,
+    const int32_t* input,
+    float scale,
+    uint8_t zero_point,
+    uint8_t qmin,
+    uint8_t qmax,
+    uint8_t* output)
+{
+  assert(n % 16 == 0);
+  assert(scale < 1.0f);
+  assert(scale >= 0x1.0p-32f);
+
+  /* Compute requantization parameters */
+  const uint32_t scale_bits = fp32_to_bits(scale);
+
+  /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
+  const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
+  assert(multiplier >= INT32_C(0x40000000));
+  assert(multiplier <= INT32_C(0x7FFFFF80));
+
+  /* Shift is in [0, 31] range */
+  const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
+  assert(shift >= 0);
+  assert(shift < 32);
+
+  const __m128i vmultiplier = _mm_set1_epi32(multiplier);
+  const __m128i vzero_point = _mm_set1_epi16((short) (uint16_t) zero_point);
+  const __m128i vqmin = _mm_set1_epi8((char) qmin);
+  const __m128i vqmax = _mm_set1_epi8((char) qmax);
+  const __m128i vshift = _mm_cvtsi32_si128((int) shift);
+  const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
+  const __m128i vremainder_mask = _mm_set1_epi32((int) remainder_mask);
+  const __m128i vthreshold = _mm_set1_epi32((int) (remainder_mask >> 1));
+  const __m128i vq31rounding = _mm_set1_epi64x(UINT64_C(0x40000000));
+  for (; n != 0; n -= 16) {
+    const __m128i x = _mm_loadu_si128((const __m128i*) input);
+    const __m128i y = _mm_loadu_si128((const __m128i*) (input + 4));
+    const __m128i z = _mm_loadu_si128((const __m128i*) (input + 8));
+    const __m128i w = _mm_loadu_si128((const __m128i*) (input + 12));
+    input += 16;
+
+    /*
+     * SSE2 has no signed 32x32->64 multiplication (PMULDQ is SSE4.1), so
+     * multiply absolute values with unsigned PMULUDQ and restore the sign
+     * of the 64-bit products afterwards via the XOR/SUB complement trick.
+     */
+    const __m128i x_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), x);
+    const __m128i y_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), y);
+    const __m128i z_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), z);
+    const __m128i w_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), w);
+
+    const __m128i x_abs = _mm_sub_epi32(_mm_xor_si128(x, x_neg_mask), x_neg_mask);
+    const __m128i y_abs = _mm_sub_epi32(_mm_xor_si128(y, y_neg_mask), y_neg_mask);
+    const __m128i z_abs = _mm_sub_epi32(_mm_xor_si128(z, z_neg_mask), z_neg_mask);
+    const __m128i w_abs = _mm_sub_epi32(_mm_xor_si128(w, w_neg_mask), w_neg_mask);
+
+    /* PMULUDQ multiplies only even-indexed 32-bit lanes; rotate each pair to expose the odd lanes. */
+    const __m128i x_abs_rev = _mm_shuffle_epi32(x_abs, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i y_abs_rev = _mm_shuffle_epi32(y_abs, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i z_abs_rev = _mm_shuffle_epi32(z_abs, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i w_abs_rev = _mm_shuffle_epi32(w_abs, _MM_SHUFFLE(2, 3, 0, 1));
+
+    const __m128i x_abs_product_even = _mm_mul_epu32(x_abs, vmultiplier);
+    const __m128i y_abs_product_even = _mm_mul_epu32(y_abs, vmultiplier);
+    const __m128i z_abs_product_even = _mm_mul_epu32(z_abs, vmultiplier);
+    const __m128i w_abs_product_even = _mm_mul_epu32(w_abs, vmultiplier);
+
+    /* Broadcast the sign masks of the even lanes across the matching 64-bit product lanes. */
+    const __m128i x_neg_mask_even = _mm_shuffle_epi32(x_neg_mask, _MM_SHUFFLE(2, 2, 0, 0));
+    const __m128i y_neg_mask_even = _mm_shuffle_epi32(y_neg_mask, _MM_SHUFFLE(2, 2, 0, 0));
+    const __m128i z_neg_mask_even = _mm_shuffle_epi32(z_neg_mask, _MM_SHUFFLE(2, 2, 0, 0));
+    const __m128i w_neg_mask_even = _mm_shuffle_epi32(w_neg_mask, _MM_SHUFFLE(2, 2, 0, 0));
+
+    const __m128i x_product_even = _mm_sub_epi64(_mm_xor_si128(x_abs_product_even, x_neg_mask_even), x_neg_mask_even);
+    const __m128i y_product_even = _mm_sub_epi64(_mm_xor_si128(y_abs_product_even, y_neg_mask_even), y_neg_mask_even);
+    const __m128i z_product_even = _mm_sub_epi64(_mm_xor_si128(z_abs_product_even, z_neg_mask_even), z_neg_mask_even);
+    const __m128i w_product_even = _mm_sub_epi64(_mm_xor_si128(w_abs_product_even, w_neg_mask_even), w_neg_mask_even);
+
+    const __m128i x_rounded_product_even = _mm_add_epi64(x_product_even, vq31rounding);
+    const __m128i y_rounded_product_even = _mm_add_epi64(y_product_even, vq31rounding);
+    const __m128i z_rounded_product_even = _mm_add_epi64(z_product_even, vq31rounding);
+    const __m128i w_rounded_product_even = _mm_add_epi64(w_product_even, vq31rounding);
+
+    const __m128i x_abs_product_odd = _mm_mul_epu32(x_abs_rev, vmultiplier);
+    const __m128i y_abs_product_odd = _mm_mul_epu32(y_abs_rev, vmultiplier);
+    const __m128i z_abs_product_odd = _mm_mul_epu32(z_abs_rev, vmultiplier);
+    const __m128i w_abs_product_odd = _mm_mul_epu32(w_abs_rev, vmultiplier);
+
+    const __m128i x_neg_mask_odd = _mm_shuffle_epi32(x_neg_mask, _MM_SHUFFLE(3, 3, 1, 1));
+    const __m128i y_neg_mask_odd = _mm_shuffle_epi32(y_neg_mask, _MM_SHUFFLE(3, 3, 1, 1));
+    const __m128i z_neg_mask_odd = _mm_shuffle_epi32(z_neg_mask, _MM_SHUFFLE(3, 3, 1, 1));
+    const __m128i w_neg_mask_odd = _mm_shuffle_epi32(w_neg_mask, _MM_SHUFFLE(3, 3, 1, 1));
+
+    const __m128i x_product_odd = _mm_sub_epi64(_mm_xor_si128(x_abs_product_odd, x_neg_mask_odd), x_neg_mask_odd);
+    const __m128i y_product_odd = _mm_sub_epi64(_mm_xor_si128(y_abs_product_odd, y_neg_mask_odd), y_neg_mask_odd);
+    const __m128i z_product_odd = _mm_sub_epi64(_mm_xor_si128(z_abs_product_odd, z_neg_mask_odd), z_neg_mask_odd);
+    const __m128i w_product_odd = _mm_sub_epi64(_mm_xor_si128(w_abs_product_odd, w_neg_mask_odd), w_neg_mask_odd);
+
+    const __m128i x_rounded_product_odd = _mm_add_epi64(x_product_odd, vq31rounding);
+    const __m128i y_rounded_product_odd = _mm_add_epi64(y_product_odd, vq31rounding);
+    const __m128i z_rounded_product_odd = _mm_add_epi64(z_product_odd, vq31rounding);
+    const __m128i w_rounded_product_odd = _mm_add_epi64(w_product_odd, vq31rounding);
+
+    /* Take bits 31..62 of each rounded 64-bit product: the Q31 doubling-high-multiply result. */
+    const __m128i x_q31product_even = _mm_srli_epi64(x_rounded_product_even, 31);
+    const __m128i x_q31product_odd = _mm_srli_epi64(x_rounded_product_odd, 31);
+    const __m128i y_q31product_even = _mm_srli_epi64(y_rounded_product_even, 31);
+    const __m128i y_q31product_odd = _mm_srli_epi64(y_rounded_product_odd, 31);
+    const __m128i z_q31product_even = _mm_srli_epi64(z_rounded_product_even, 31);
+    const __m128i z_q31product_odd = _mm_srli_epi64(z_rounded_product_odd, 31);
+    const __m128i w_q31product_even = _mm_srli_epi64(w_rounded_product_even, 31);
+    const __m128i w_q31product_odd = _mm_srli_epi64(w_rounded_product_odd, 31);
+
+    /* Re-interleave the even/odd 64-bit results into a single vector of four 32-bit lanes. */
+    const __m128i x_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps(
+        _mm_castsi128_ps(x_q31product_even), _mm_castsi128_ps(x_q31product_odd), _MM_SHUFFLE(2, 0, 2, 0)));
+    const __m128i y_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps(
+        _mm_castsi128_ps(y_q31product_even), _mm_castsi128_ps(y_q31product_odd), _MM_SHUFFLE(2, 0, 2, 0)));
+    const __m128i z_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps(
+        _mm_castsi128_ps(z_q31product_even), _mm_castsi128_ps(z_q31product_odd), _MM_SHUFFLE(2, 0, 2, 0)));
+    const __m128i w_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps(
+        _mm_castsi128_ps(w_q31product_even), _mm_castsi128_ps(w_q31product_odd), _MM_SHUFFLE(2, 0, 2, 0)));
+
+    const __m128i x_q31product = _mm_shuffle_epi32(x_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0));
+    const __m128i y_q31product = _mm_shuffle_epi32(y_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0));
+    const __m128i z_q31product = _mm_shuffle_epi32(z_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0));
+    const __m128i w_q31product = _mm_shuffle_epi32(w_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+    /*
+     * Rounding adjustment for the arithmetic shift right: the remainder of the
+     * division by 2**shift (biased by -1 for negative products, via the cmpgt
+     * mask) is compared against the threshold to round the result.
+     */
+    const __m128i x_remainder =
+        _mm_add_epi32(_mm_and_si128(x_q31product, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), x_q31product));
+    const __m128i y_remainder =
+        _mm_add_epi32(_mm_and_si128(y_q31product, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), y_q31product));
+    const __m128i z_remainder =
+        _mm_add_epi32(_mm_and_si128(z_q31product, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), z_q31product));
+    const __m128i w_remainder =
+        _mm_add_epi32(_mm_and_si128(w_q31product, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), w_q31product));
+
+    /* Note: subtracting the cmpgt mask (-1 when remainder > threshold) adds 1. */
+    const __m128i x_scaled =
+        _mm_sub_epi32(_mm_sra_epi32(x_q31product, vshift), _mm_cmpgt_epi32(x_remainder, vthreshold));
+    const __m128i y_scaled =
+        _mm_sub_epi32(_mm_sra_epi32(y_q31product, vshift), _mm_cmpgt_epi32(y_remainder, vthreshold));
+    const __m128i z_scaled =
+        _mm_sub_epi32(_mm_sra_epi32(z_q31product, vshift), _mm_cmpgt_epi32(z_remainder, vthreshold));
+    const __m128i w_scaled =
+        _mm_sub_epi32(_mm_sra_epi32(w_q31product, vshift), _mm_cmpgt_epi32(w_remainder, vthreshold));
+
+    /* Pack to 16 bits with signed saturation, add zero point with saturation,
+       pack to 8 bits with unsigned saturation, then clamp to [qmin, qmax]. */
+    const __m128i xy_packed = _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point);
+    const __m128i zw_packed = _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point);
+    const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed);
+    const __m128i xyzw_clamped = _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin);
+
+    /*
+     * 16x PSHUFD
+     * 4x SHUFPS
+     * 8x PMULUDQ
+     * 8x PXOR (setzero)
+     * 12x PXOR
+     * 4x PAND
+     * 8x PADDQ
+     * 4x PADDD
+     * 2x PADDW
+     * 8x PSUBQ
+     * 8x PSUBD
+     * 8x PSRLQ (immediate)
+     * 4x PSRAD (register)
+     * 12x PCMPGTD
+     * 2x PACKSSDW
+     * 1x PACKUSWB
+     * 1x PMAXUB
+     * 1x PMINUB
+     * ---------------------
+     * 111 instructions total
+     */
+
+    _mm_storeu_si128((__m128i*) output, xyzw_clamped);
+    output += 16;
+  }
+}
diff --git a/src/requantization/q31-sse4.c b/src/requantization/q31-sse4.c
new file mode 100644
index 0000000..c598d6b
--- /dev/null
+++ b/src/requantization/q31-sse4.c
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <smmintrin.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+
+/*
+ * Q31 requantization with SSE4.1 intrinsics: converts signed 32-bit values to
+ * unsigned 8-bit values by multiplying with a Q31 fixed-point multiplier
+ * (rounding doubling high multiply), dividing by 2**shift with rounding,
+ * adding the zero point, and clamping to [qmin, qmax].
+ *
+ * Unlike the SSE2 version, this uses the signed PMULDQ multiplication
+ * (SSE4.1), avoiding the abs/sign-restore dance entirely.
+ *
+ * n          - number of elements; must be a multiple of 16.
+ * input      - pointer to n signed 32-bit input values.
+ * scale      - requantization scale; must be in [2**-32, 1).
+ * zero_point - zero point added to the scaled result.
+ * qmin/qmax  - inclusive clamping bounds for the output.
+ * output     - pointer to n unsigned 8-bit output values.
+ */
+void xnn_requantize_q31__sse4(
+    size_t n,
+    const int32_t* input,
+    float scale,
+    uint8_t zero_point,
+    uint8_t qmin,
+    uint8_t qmax,
+    uint8_t* output)
+{
+  assert(n % 16 == 0);
+  assert(scale < 1.0f);
+  assert(scale >= 0x1.0p-32f);
+
+  /* Compute requantization parameters */
+  const uint32_t scale_bits = fp32_to_bits(scale);
+
+  /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
+  const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
+  assert(multiplier >= INT32_C(0x40000000));
+  assert(multiplier <= INT32_C(0x7FFFFF80));
+
+  /* Shift is in [0, 31] range */
+  const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
+  assert(shift >= 0);
+  assert(shift < 32);
+
+  const __m128i vmultiplier = _mm_set1_epi32(multiplier);
+  const __m128i vzero_point = _mm_set1_epi16((short) (uint16_t) zero_point);
+  const __m128i vqmin = _mm_set1_epi8((char) qmin);
+  const __m128i vqmax = _mm_set1_epi8((char) qmax);
+  const __m128i vshift = _mm_cvtsi32_si128((int) shift);
+  const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
+  const __m128i vremainder_mask = _mm_set1_epi32((int) remainder_mask);
+  const __m128i vthreshold = _mm_set1_epi32((int) (remainder_mask >> 1));
+  const __m128i vq31rounding = _mm_set1_epi64x(UINT64_C(0x40000000));
+  for (; n != 0; n -= 16) {
+    const __m128i x = _mm_loadu_si128((const __m128i*) input);
+    const __m128i y = _mm_loadu_si128((const __m128i*) (input + 4));
+    const __m128i z = _mm_loadu_si128((const __m128i*) (input + 8));
+    const __m128i w = _mm_loadu_si128((const __m128i*) (input + 12));
+    input += 16;
+
+    /* PMULDQ multiplies only even-indexed 32-bit lanes; rotate each pair to expose the odd lanes. */
+    const __m128i x_rev = _mm_shuffle_epi32(x, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i y_rev = _mm_shuffle_epi32(y, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i z_rev = _mm_shuffle_epi32(z, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i w_rev = _mm_shuffle_epi32(w, _MM_SHUFFLE(2, 3, 0, 1));
+
+    const __m128i x_product_even = _mm_add_epi64(_mm_mul_epi32(x, vmultiplier), vq31rounding);
+    const __m128i y_product_even = _mm_add_epi64(_mm_mul_epi32(y, vmultiplier), vq31rounding);
+    const __m128i z_product_even = _mm_add_epi64(_mm_mul_epi32(z, vmultiplier), vq31rounding);
+    const __m128i w_product_even = _mm_add_epi64(_mm_mul_epi32(w, vmultiplier), vq31rounding);
+
+    const __m128i x_product_odd = _mm_add_epi64(_mm_mul_epi32(x_rev, vmultiplier), vq31rounding);
+    const __m128i y_product_odd = _mm_add_epi64(_mm_mul_epi32(y_rev, vmultiplier), vq31rounding);
+    const __m128i z_product_odd = _mm_add_epi64(_mm_mul_epi32(z_rev, vmultiplier), vq31rounding);
+    const __m128i w_product_odd = _mm_add_epi64(_mm_mul_epi32(w_rev, vmultiplier), vq31rounding);
+
+    /*
+     * Even products are shifted right by 31, leaving the Q31 result in the low
+     * 32 bits of each 64-bit lane; odd products are instead doubled (shifted
+     * left by 1), leaving it in the high 32 bits, so a single PBLENDW per
+     * vector merges both into the right 32-bit lanes.
+     */
+    const __m128i x_q31product_even = _mm_srli_epi64(x_product_even, 31);
+    const __m128i x_q31product_odd = _mm_add_epi64(x_product_odd, x_product_odd);
+    const __m128i y_q31product_even = _mm_srli_epi64(y_product_even, 31);
+    const __m128i y_q31product_odd = _mm_add_epi64(y_product_odd, y_product_odd);
+    const __m128i z_q31product_even = _mm_srli_epi64(z_product_even, 31);
+    const __m128i z_q31product_odd = _mm_add_epi64(z_product_odd, z_product_odd);
+    const __m128i w_q31product_even = _mm_srli_epi64(w_product_even, 31);
+    const __m128i w_q31product_odd = _mm_add_epi64(w_product_odd, w_product_odd);
+
+    /* Mask 0xCC selects 16-bit words 2,3 and 6,7 (the upper half of each 64-bit lane) from the odd products. */
+    const __m128i x_q31product = _mm_blend_epi16(x_q31product_even, x_q31product_odd, 0xCC);
+    const __m128i y_q31product = _mm_blend_epi16(y_q31product_even, y_q31product_odd, 0xCC);
+    const __m128i z_q31product = _mm_blend_epi16(z_q31product_even, z_q31product_odd, 0xCC);
+    const __m128i w_q31product = _mm_blend_epi16(w_q31product_even, w_q31product_odd, 0xCC);
+
+    /*
+     * Rounding adjustment for the arithmetic shift right: the remainder of the
+     * division by 2**shift (biased by -1 for negative products, via the cmpgt
+     * mask) is compared against the threshold to round the result.
+     */
+    const __m128i x_remainder =
+        _mm_add_epi32(_mm_and_si128(x_q31product, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), x_q31product));
+    const __m128i y_remainder =
+        _mm_add_epi32(_mm_and_si128(y_q31product, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), y_q31product));
+    const __m128i z_remainder =
+        _mm_add_epi32(_mm_and_si128(z_q31product, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), z_q31product));
+    const __m128i w_remainder =
+        _mm_add_epi32(_mm_and_si128(w_q31product, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), w_q31product));
+
+    /* Note: subtracting the cmpgt mask (-1 when remainder > threshold) adds 1. */
+    const __m128i x_scaled =
+        _mm_sub_epi32(_mm_sra_epi32(x_q31product, vshift), _mm_cmpgt_epi32(x_remainder, vthreshold));
+    const __m128i y_scaled =
+        _mm_sub_epi32(_mm_sra_epi32(y_q31product, vshift), _mm_cmpgt_epi32(y_remainder, vthreshold));
+    const __m128i z_scaled =
+        _mm_sub_epi32(_mm_sra_epi32(z_q31product, vshift), _mm_cmpgt_epi32(z_remainder, vthreshold));
+    const __m128i w_scaled =
+        _mm_sub_epi32(_mm_sra_epi32(w_q31product, vshift), _mm_cmpgt_epi32(w_remainder, vthreshold));
+
+    /* Pack to 16 bits with signed saturation, add zero point with saturation,
+       pack to 8 bits with unsigned saturation, then clamp to [qmin, qmax]. */
+    const __m128i xy_packed = _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point);
+    const __m128i zw_packed = _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point);
+    const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed);
+    const __m128i xyzw_clamped = _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin);
+
+    /*
+     * 4x PSHUFD
+     * 8x PMULDQ
+     * 12x PADDQ
+     * 4x PADDD
+     * 2x PADDW
+     * 4x PSUBD
+     * 4x PSRLQ (immediate)
+     * 4x PSRAD (register)
+     * 4x PBLENDW
+     * 4x PAND
+     * 4x PXOR (setzero)
+     * 8x PCMPGTD
+     * 2x PACKSSDW
+     * 1x PACKUSWB
+     * 1x PMAXUB
+     * 1x PMINUB
+     * ---------------------
+     * 67 instructions total
+     */
+
+    _mm_storeu_si128((__m128i*) output, xyzw_clamped);
+    output += 16;
+  }
+}
diff --git a/src/requantization/q31-ssse3.c b/src/requantization/q31-ssse3.c
new file mode 100644
index 0000000..368ae75
--- /dev/null
+++ b/src/requantization/q31-ssse3.c
@@ -0,0 +1,190 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <tmmintrin.h>
+
+#include <fp16/bitcasts.h>
+#include <xnnpack/requantization-stubs.h>
+
+
+/*
+ * Q31 requantization with SSSE3 intrinsics: converts signed 32-bit values to
+ * unsigned 8-bit values by multiplying with a Q31 fixed-point multiplier
+ * (rounding doubling high multiply), dividing by 2**shift with rounding,
+ * adding the zero point, and clamping to [qmin, qmax].
+ *
+ * Identical structure to the SSE2 version, except absolute values are
+ * computed with a single PABSD (SSSE3) instead of the XOR/SUB sequence.
+ *
+ * n          - number of elements; must be a multiple of 16.
+ * input      - pointer to n signed 32-bit input values.
+ * scale      - requantization scale; must be in [2**-32, 1).
+ * zero_point - zero point added to the scaled result.
+ * qmin/qmax  - inclusive clamping bounds for the output.
+ * output     - pointer to n unsigned 8-bit output values.
+ */
+void xnn_requantize_q31__ssse3(
+    size_t n,
+    const int32_t* input,
+    float scale,
+    uint8_t zero_point,
+    uint8_t qmin,
+    uint8_t qmax,
+    uint8_t* output)
+{
+  assert(n % 16 == 0);
+  assert(scale < 1.0f);
+  assert(scale >= 0x1.0p-32f);
+
+  /* Compute requantization parameters */
+  const uint32_t scale_bits = fp32_to_bits(scale);
+
+  /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
+  const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
+  assert(multiplier >= INT32_C(0x40000000));
+  assert(multiplier <= INT32_C(0x7FFFFF80));
+
+  /* Shift is in [0, 31] range */
+  const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
+  assert(shift >= 0);
+  assert(shift < 32);
+
+  const __m128i vmultiplier = _mm_set1_epi32(multiplier);
+  const __m128i vzero_point = _mm_set1_epi16((short) (uint16_t) zero_point);
+  const __m128i vqmin = _mm_set1_epi8((char) qmin);
+  const __m128i vqmax = _mm_set1_epi8((char) qmax);
+  const __m128i vshift = _mm_cvtsi32_si128((int) shift);
+  const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
+  const __m128i vremainder_mask = _mm_set1_epi32((int) remainder_mask);
+  const __m128i vthreshold = _mm_set1_epi32((int) (remainder_mask >> 1));
+  const __m128i vq31rounding = _mm_set1_epi64x(UINT64_C(0x40000000));
+  for (; n != 0; n -= 16) {
+    const __m128i x = _mm_loadu_si128((const __m128i*) input);
+    const __m128i y = _mm_loadu_si128((const __m128i*) (input + 4));
+    const __m128i z = _mm_loadu_si128((const __m128i*) (input + 8));
+    const __m128i w = _mm_loadu_si128((const __m128i*) (input + 12));
+    input += 16;
+
+    /* SSSE3 provides PABSD, so abs() is a single instruction per vector. */
+    const __m128i x_abs = _mm_abs_epi32(x);
+    const __m128i y_abs = _mm_abs_epi32(y);
+    const __m128i z_abs = _mm_abs_epi32(z);
+    const __m128i w_abs = _mm_abs_epi32(w);
+
+    /* Sign masks are still needed to restore the sign of the 64-bit products. */
+    const __m128i x_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), x);
+    const __m128i y_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), y);
+    const __m128i z_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), z);
+    const __m128i w_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), w);
+
+    /* PMULUDQ multiplies only even-indexed 32-bit lanes; rotate each pair to expose the odd lanes. */
+    const __m128i x_abs_rev = _mm_shuffle_epi32(x_abs, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i y_abs_rev = _mm_shuffle_epi32(y_abs, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i z_abs_rev = _mm_shuffle_epi32(z_abs, _MM_SHUFFLE(2, 3, 0, 1));
+    const __m128i w_abs_rev = _mm_shuffle_epi32(w_abs, _MM_SHUFFLE(2, 3, 0, 1));
+
+    const __m128i x_abs_product_even = _mm_mul_epu32(x_abs, vmultiplier);
+    const __m128i y_abs_product_even = _mm_mul_epu32(y_abs, vmultiplier);
+    const __m128i z_abs_product_even = _mm_mul_epu32(z_abs, vmultiplier);
+    const __m128i w_abs_product_even = _mm_mul_epu32(w_abs, vmultiplier);
+
+    const __m128i x_neg_mask_even = _mm_shuffle_epi32(x_neg_mask, _MM_SHUFFLE(2, 2, 0, 0));
+    const __m128i y_neg_mask_even = _mm_shuffle_epi32(y_neg_mask, _MM_SHUFFLE(2, 2, 0, 0));
+    const __m128i z_neg_mask_even = _mm_shuffle_epi32(z_neg_mask, _MM_SHUFFLE(2, 2, 0, 0));
+    const __m128i w_neg_mask_even = _mm_shuffle_epi32(w_neg_mask, _MM_SHUFFLE(2, 2, 0, 0));
+
+    /* XOR/SUB with the broadcast sign mask negates the 64-bit product where the input was negative. */
+    const __m128i x_product_even = _mm_sub_epi64(_mm_xor_si128(x_abs_product_even, x_neg_mask_even), x_neg_mask_even);
+    const __m128i y_product_even = _mm_sub_epi64(_mm_xor_si128(y_abs_product_even, y_neg_mask_even), y_neg_mask_even);
+    const __m128i z_product_even = _mm_sub_epi64(_mm_xor_si128(z_abs_product_even, z_neg_mask_even), z_neg_mask_even);
+    const __m128i w_product_even = _mm_sub_epi64(_mm_xor_si128(w_abs_product_even, w_neg_mask_even), w_neg_mask_even);
+
+    const __m128i x_rounded_product_even = _mm_add_epi64(x_product_even, vq31rounding);
+    const __m128i y_rounded_product_even = _mm_add_epi64(y_product_even, vq31rounding);
+    const __m128i z_rounded_product_even = _mm_add_epi64(z_product_even, vq31rounding);
+    const __m128i w_rounded_product_even = _mm_add_epi64(w_product_even, vq31rounding);
+
+    const __m128i x_abs_product_odd = _mm_mul_epu32(x_abs_rev, vmultiplier);
+    const __m128i y_abs_product_odd = _mm_mul_epu32(y_abs_rev, vmultiplier);
+    const __m128i z_abs_product_odd = _mm_mul_epu32(z_abs_rev, vmultiplier);
+    const __m128i w_abs_product_odd = _mm_mul_epu32(w_abs_rev, vmultiplier);
+
+    const __m128i x_neg_mask_odd = _mm_shuffle_epi32(x_neg_mask, _MM_SHUFFLE(3, 3, 1, 1));
+    const __m128i y_neg_mask_odd = _mm_shuffle_epi32(y_neg_mask, _MM_SHUFFLE(3, 3, 1, 1));
+    const __m128i z_neg_mask_odd = _mm_shuffle_epi32(z_neg_mask, _MM_SHUFFLE(3, 3, 1, 1));
+    const __m128i w_neg_mask_odd = _mm_shuffle_epi32(w_neg_mask, _MM_SHUFFLE(3, 3, 1, 1));
+
+    const __m128i x_product_odd = _mm_sub_epi64(_mm_xor_si128(x_abs_product_odd, x_neg_mask_odd), x_neg_mask_odd);
+    const __m128i y_product_odd = _mm_sub_epi64(_mm_xor_si128(y_abs_product_odd, y_neg_mask_odd), y_neg_mask_odd);
+    const __m128i z_product_odd = _mm_sub_epi64(_mm_xor_si128(z_abs_product_odd, z_neg_mask_odd), z_neg_mask_odd);
+    const __m128i w_product_odd = _mm_sub_epi64(_mm_xor_si128(w_abs_product_odd, w_neg_mask_odd), w_neg_mask_odd);
+
+    const __m128i x_rounded_product_odd = _mm_add_epi64(x_product_odd, vq31rounding);
+    const __m128i y_rounded_product_odd = _mm_add_epi64(y_product_odd, vq31rounding);
+    const __m128i z_rounded_product_odd = _mm_add_epi64(z_product_odd, vq31rounding);
+    const __m128i w_rounded_product_odd = _mm_add_epi64(w_product_odd, vq31rounding);
+
+    /* Take bits 31..62 of each rounded 64-bit product: the Q31 doubling-high-multiply result. */
+    const __m128i x_q31product_even = _mm_srli_epi64(x_rounded_product_even, 31);
+    const __m128i x_q31product_odd = _mm_srli_epi64(x_rounded_product_odd, 31);
+    const __m128i y_q31product_even = _mm_srli_epi64(y_rounded_product_even, 31);
+    const __m128i y_q31product_odd = _mm_srli_epi64(y_rounded_product_odd, 31);
+    const __m128i z_q31product_even = _mm_srli_epi64(z_rounded_product_even, 31);
+    const __m128i z_q31product_odd = _mm_srli_epi64(z_rounded_product_odd, 31);
+    const __m128i w_q31product_even = _mm_srli_epi64(w_rounded_product_even, 31);
+    const __m128i w_q31product_odd = _mm_srli_epi64(w_rounded_product_odd, 31);
+
+    /* Re-interleave the even/odd 64-bit results into a single vector of four 32-bit lanes. */
+    const __m128i x_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps(
+        _mm_castsi128_ps(x_q31product_even), _mm_castsi128_ps(x_q31product_odd), _MM_SHUFFLE(2, 0, 2, 0)));
+    const __m128i y_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps(
+        _mm_castsi128_ps(y_q31product_even), _mm_castsi128_ps(y_q31product_odd), _MM_SHUFFLE(2, 0, 2, 0)));
+    const __m128i z_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps(
+        _mm_castsi128_ps(z_q31product_even), _mm_castsi128_ps(z_q31product_odd), _MM_SHUFFLE(2, 0, 2, 0)));
+    const __m128i w_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps(
+        _mm_castsi128_ps(w_q31product_even), _mm_castsi128_ps(w_q31product_odd), _MM_SHUFFLE(2, 0, 2, 0)));
+
+    const __m128i x_q31product = _mm_shuffle_epi32(x_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0));
+    const __m128i y_q31product = _mm_shuffle_epi32(y_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0));
+    const __m128i z_q31product = _mm_shuffle_epi32(z_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0));
+    const __m128i w_q31product = _mm_shuffle_epi32(w_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0));
+
+    /*
+     * Rounding adjustment for the arithmetic shift right: the remainder of the
+     * division by 2**shift (biased by -1 for negative products, via the cmpgt
+     * mask) is compared against the threshold to round the result.
+     */
+    const __m128i x_remainder =
+        _mm_add_epi32(_mm_and_si128(x_q31product, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), x_q31product));
+    const __m128i y_remainder =
+        _mm_add_epi32(_mm_and_si128(y_q31product, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), y_q31product));
+    const __m128i z_remainder =
+        _mm_add_epi32(_mm_and_si128(z_q31product, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), z_q31product));
+    const __m128i w_remainder =
+        _mm_add_epi32(_mm_and_si128(w_q31product, vremainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), w_q31product));
+
+    /* Note: subtracting the cmpgt mask (-1 when remainder > threshold) adds 1. */
+    const __m128i x_scaled =
+        _mm_sub_epi32(_mm_sra_epi32(x_q31product, vshift), _mm_cmpgt_epi32(x_remainder, vthreshold));
+    const __m128i y_scaled =
+        _mm_sub_epi32(_mm_sra_epi32(y_q31product, vshift), _mm_cmpgt_epi32(y_remainder, vthreshold));
+    const __m128i z_scaled =
+        _mm_sub_epi32(_mm_sra_epi32(z_q31product, vshift), _mm_cmpgt_epi32(z_remainder, vthreshold));
+    const __m128i w_scaled =
+        _mm_sub_epi32(_mm_sra_epi32(w_q31product, vshift), _mm_cmpgt_epi32(w_remainder, vthreshold));
+
+    /* Pack to 16 bits with signed saturation, add zero point with saturation,
+       pack to 8 bits with unsigned saturation, then clamp to [qmin, qmax]. */
+    const __m128i xy_packed = _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point);
+    const __m128i zw_packed = _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point);
+    const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed);
+    const __m128i xyzw_clamped = _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin);
+
+    /*
+     * 16x PSHUFD
+     * 4x SHUFPS
+     * 8x PMULUDQ
+     * 8x PXOR (setzero)
+     * 8x PXOR
+     * 4x PAND
+     * 8x PADDQ
+     * 4x PADDD
+     * 2x PADDW
+     * 8x PSUBQ
+     * 4x PSUBD
+     * 8x PSRLQ (immediate)
+     * 4x PSRAD (register)
+     * 12x PCMPGTD
+     * 4x PABSD
+     * 2x PACKSSDW
+     * 1x PACKUSWB
+     * 1x PMAXUB
+     * 1x PMINUB
+     * ---------------------
+     * 107 instructions total
+     */
+
+    _mm_storeu_si128((__m128i*) output, xyzw_clamped);
+    output += 16;
+  }
+}
diff --git a/src/sigmoid.c b/src/sigmoid.c
new file mode 100644
index 0000000..f11f87a
--- /dev/null
+++ b/src/sigmoid.c
@@ -0,0 +1,212 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <inttypes.h>
+#include <math.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/log.h>
+#include <xnnpack/operator.h>
+
+
+/*
+ * Creates a Sigmoid operator for NC-layout quantized (Q8) data.
+ *
+ * The operator evaluates sigmoid through a 256-entry lookup table built here:
+ * for each possible input byte, the dequantized value is passed through an
+ * expf-based sigmoid, scaled by 256 (the required fixed output scale),
+ * clamped to [output_min, output_max], and rounded to the nearest integer.
+ *
+ * Only output_scale == 1/256 and output_zero_point == 0 are supported, since
+ * both are baked into the table.
+ *
+ * On success, returns xnn_status_success and stores the new operator in
+ * *sigmoid_op_out; on failure, logs the reason, returns an error status, and
+ * creates no operator.
+ */
+enum xnn_status xnn_create_sigmoid_nc_q8(
+    size_t channels,
+    size_t input_stride,
+    size_t output_stride,
+    uint8_t input_zero_point,
+    float input_scale,
+    uint8_t output_zero_point,
+    float output_scale,
+    uint8_t output_min,
+    uint8_t output_max,
+    uint32_t flags,
+    xnn_operator_t* sigmoid_op_out)
+{
+  xnn_operator_t sigmoid_op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to create Sigmoid operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  /* Parameter validation. */
+  status = xnn_status_invalid_parameter;
+
+  if (channels == 0) {
+    xnn_log_error(
+      "failed to create Sigmoid operator with %zu channels: number of channels must be non-zero", channels);
+    goto error;
+  }
+
+  if (input_stride < channels) {
+    xnn_log_error(
+      "failed to create Sigmoid operator with input element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      input_stride, channels);
+    goto error;
+  }
+
+  if (output_stride < channels) {
+    xnn_log_error(
+      "failed to create Sigmoid operator with output element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      output_stride, channels);
+    goto error;
+  }
+
+  if (input_scale <= 0.0f || !isnormal(input_scale)) {
+    xnn_log_error(
+      "failed to create Sigmoid operator with %.7g input scale: scale must be finite, normalized, and positive",
+      input_scale);
+    goto error;
+  }
+
+  if (output_scale <= 0.0f || !isnormal(output_scale)) {
+    xnn_log_error(
+      "failed to create Sigmoid operator with %.7g output scale: scale must be finite, normalized, and positive",
+      output_scale);
+    goto error;
+  }
+
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create Sigmoid operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
+      output_min, output_max);
+    goto error;
+  }
+
+  /* The lookup table below bakes in output scale 1/256 and zero point 0. */
+  status = xnn_status_unsupported_parameter;
+
+  if (output_scale != 0x1.0p-8f) {
+    xnn_log_error(
+      "failed to create Sigmoid operator with %.7g output scale: only output scale of 1/256 is supported",
+      output_scale);
+    goto error;
+  }
+
+  if (output_zero_point != 0) {
+    xnn_log_error(
+      "failed to create Sigmoid operator with %" PRIu8 " output zero point: only output zero point of 0 is supported",
+      output_zero_point);
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  sigmoid_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+  if (sigmoid_op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for Sigmoid operator descriptor", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  sigmoid_op->lookup_table = xnn_allocate_memory(256 * sizeof(uint8_t));
+  if (sigmoid_op->lookup_table == NULL) {
+    xnn_log_error("failed to allocate 256 bytes for Sigmoid lookup table");
+    goto error;
+  }
+
+  /* Build the table: one entry per possible quantized input byte. */
+  uint8_t* lookup_table = sigmoid_op->lookup_table;
+  const float scaled_min = (float) (int32_t) output_min;
+  const float scaled_max = (float) (int32_t) output_max;
+  for (int32_t i = 0; i < 256; i++) {
+    /* Dequantize the input byte i before applying sigmoid. */
+    const float x = input_scale * (float) (i - (int32_t) (uint32_t) input_zero_point);
+    // Scale sigmoid(x) by 1 / output scale = 256.0
+    float scaled_sigmoid_x = 256.0f / (1.0f + expf(-x));
+    if (scaled_sigmoid_x < scaled_min) {
+      scaled_sigmoid_x = scaled_min;
+    }
+    if (scaled_sigmoid_x > scaled_max) {
+      scaled_sigmoid_x = scaled_max;
+    }
+    /* lrintf rounds to nearest; the value is already clamped to [0, 255]. */
+    lookup_table[(uint32_t) i] = (uint8_t) lrintf(scaled_sigmoid_x);
+  }
+
+  sigmoid_op->channels = channels;
+  sigmoid_op->input_pixel_stride = input_stride;
+  sigmoid_op->output_pixel_stride = output_stride;
+
+  sigmoid_op->type = xnn_operator_type_sigmoid_q8;
+  sigmoid_op->ukernel.type = xnn_ukernel_type_lut;
+
+  /* The operator must be setup (xnn_setup_sigmoid_nc_q8) before it can run. */
+  sigmoid_op->state = xnn_run_state_invalid;
+
+  *sigmoid_op_out = sigmoid_op;
+  return xnn_status_success;
+
+error:
+  /* Release whatever was allocated before the failure. */
+  xnn_delete_operator(sigmoid_op);
+  return status;
+}
+
+/*
+ * Binds input/output buffers and a batch size to a previously created
+ * Sigmoid (Q8) operator and selects the lookup-table compute strategy:
+ * a single contiguous byte range split into fixed-size tiles when rows are
+ * densely packed (or there is only one row), otherwise one task per row.
+ * Returns xnn_status_success on success; logs and returns an error status
+ * otherwise. The threadpool argument is not used during setup.
+ */
+enum xnn_status xnn_setup_sigmoid_nc_q8(
+    xnn_operator_t sigmoid_op,
+    size_t batch_size,
+    const uint8_t* input,
+    uint8_t* output,
+    pthreadpool_t threadpool)
+{
+  if (sigmoid_op->type != xnn_operator_type_sigmoid_q8) {
+    xnn_log_error("failed to setup Sigmoid (Q8) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+  /* Invalidate any previous setup until this one completes. */
+  sigmoid_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup Sigmoid operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  if (batch_size == 0) {
+    /* Nothing to compute: mark the operator so that running it is a no-op. */
+    sigmoid_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  sigmoid_op->batch_size = batch_size;
+  sigmoid_op->input = input;
+  sigmoid_op->output = output;
+
+  const size_t num_channels = sigmoid_op->channels;
+  const size_t in_stride = sigmoid_op->input_pixel_stride;
+  const size_t out_stride = sigmoid_op->output_pixel_stride;
+  if (batch_size == 1 || (in_stride == num_channels && out_stride == num_channels)) {
+    /* Rows are densely packed (or there is a single row): treat the whole
+       tensor as one contiguous byte array, processed in fixed-size tiles. */
+    const size_t tile_size = 1024;
+    sigmoid_op->context.lut_contiguous = (struct lut_contiguous_context) {
+      .x = input,
+      .x_stride = in_stride * sizeof(uint8_t),
+      .t = sigmoid_op->lookup_table,
+      .y = output,
+      .y_stride = out_stride * sizeof(uint8_t),
+      .ukernel = xnn_params.x8.lut,
+    };
+    sigmoid_op->compute.type = xnn_parallelization_type_1d_tile_1d;
+    sigmoid_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_lut_contiguous;
+    sigmoid_op->compute.range[0] = batch_size * num_channels * sizeof(uint8_t);
+    sigmoid_op->compute.tile[0] = tile_size;
+  } else {
+    /* Rows are padded: dispatch one task per batch element, walking strides. */
+    sigmoid_op->context.lut_strided = (struct lut_strided_context) {
+      .n = num_channels,
+      .x = input,
+      .x_stride = in_stride * sizeof(uint8_t),
+      .t = sigmoid_op->lookup_table,
+      .y = output,
+      .y_stride = out_stride * sizeof(uint8_t),
+      .ukernel = xnn_params.x8.lut,
+    };
+    sigmoid_op->compute.type = xnn_parallelization_type_1d;
+    sigmoid_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_lut_strided;
+    sigmoid_op->compute.range[0] = batch_size;
+    sigmoid_op->compute.tile[0] = 0;
+  }
+  sigmoid_op->state = xnn_run_state_ready;
+
+  return xnn_status_success;
+}
diff --git a/src/softargmax.c b/src/softargmax.c
new file mode 100644
index 0000000..5228b0d
--- /dev/null
+++ b/src/softargmax.c
@@ -0,0 +1,175 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <inttypes.h>
+#include <math.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/log.h>
+
+
+// Creates a SoftArgMax (SoftMax) operator for quantized uint8 data in
+// NC (batch x channels) layout.
+//
+// The operator precomputes a 256-entry lookup table of scaled exponentials;
+// at run time each row is reduced (rmax) and normalized (lut32norm) against
+// this table. Because every row is renormalized, only an output scale of
+// 1/256 and an output zero point of 0 are supported.
+//
+// On success stores the new operator in *softargmax_op_out and returns
+// xnn_status_success; on failure frees any partial state and returns an
+// error status.
+enum xnn_status xnn_create_softargmax_nc_q8(
+  size_t channels,
+  size_t input_stride,
+  size_t output_stride,
+  float input_scale,
+  uint8_t output_zero_point,
+  float output_scale,
+  uint32_t flags,
+  xnn_operator_t* softargmax_op_out)
+{
+  xnn_operator_t softargmax_op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to create SoftArgMax operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  status = xnn_status_invalid_parameter;
+
+  if (channels == 0) {
+    xnn_log_error(
+      "failed to create SoftArgMax operator with %zu channels: number of channels must be non-zero", channels);
+    goto error;
+  }
+
+  if (input_stride < channels) {
+    // Bug fix: this message previously said "Sigmoid" (copy-paste error).
+    xnn_log_error(
+      "failed to create SoftArgMax operator with input element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      input_stride, channels);
+    goto error;
+  }
+
+  if (output_stride < channels) {
+    // Bug fix: this message previously said "Sigmoid" (copy-paste error).
+    xnn_log_error(
+      "failed to create SoftArgMax operator with output element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      output_stride, channels);
+    goto error;
+  }
+
+  if (input_scale <= 0.0f || !isnormal(input_scale)) {
+    xnn_log_error(
+      "failed to create SoftArgMax operator with %.7g input scale: scale must be finite, normalized, and positive",
+      input_scale);
+    goto error;
+  }
+
+  if (output_scale <= 0.0f || !isnormal(output_scale)) {
+    xnn_log_error(
+      "failed to create SoftArgMax operator with %.7g output scale: scale must be finite, normalized, and positive",
+      output_scale);
+    goto error;
+  }
+
+  status = xnn_status_unsupported_parameter;
+
+  if (output_scale != 0x1.0p-8f) {
+    xnn_log_error(
+      "failed to create SoftArgMax operator with %.7g output scale: only output scale of 1/256 is supported",
+      output_scale);
+    goto error;
+  }
+
+  if (output_zero_point != 0) {
+    xnn_log_error(
+      "failed to create SoftArgMax operator with %" PRIu8 " output zero point: "
+      "only output zero point of 0 is supported",
+      output_zero_point);
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  softargmax_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
+  if (softargmax_op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for SoftArgMax operator descriptor", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  softargmax_op->lookup_table = xnn_allocate_memory(256 * sizeof(uint32_t));
+  if (softargmax_op->lookup_table == NULL) {
+    // Bug fix: report the actual allocation size (1024 bytes), not "256 bytes".
+    xnn_log_error(
+      "failed to allocate %zu bytes for SoftArgMax lookup table", 256 * sizeof(uint32_t));
+    goto error;
+  }
+
+  // Build the table: lookup_table[i] ~= qscale * exp((i - 255) * input_scale).
+  // qscale is capped at 2^23-1 so that a table entry shifted left by 8 in the
+  // lut32norm normalization kernel cannot overflow 32 bits, and at
+  // UINT32_MAX / channels so that the per-row sum of entries cannot overflow.
+  uint32_t* lookup_table = softargmax_op->lookup_table;
+  const double qscale = fmin(((double) UINT32_MAX) / (double) channels, 8388607.0);
+  for (int32_t i = 0; i < 256; i++) {
+    const double scaled_exp_xi = qscale * exp((double) (i - 255) * (double) input_scale);
+    lookup_table[(uint32_t) i] = (uint32_t) lrint(scaled_exp_xi);
+  }
+
+  softargmax_op->channels = channels;
+  softargmax_op->input_pixel_stride = input_stride;
+  softargmax_op->output_pixel_stride = output_stride;
+
+  softargmax_op->type = xnn_operator_type_softargmax_q8;
+  softargmax_op->ukernel.type = xnn_ukernel_type_softargmax;
+
+  // The operator is not runnable until xnn_setup_softargmax_nc_q8 is called.
+  softargmax_op->state = xnn_run_state_invalid;
+
+  *softargmax_op_out = softargmax_op;
+  return xnn_status_success;
+
+error:
+  // Single cleanup path for failures before and after allocation;
+  // xnn_delete_operator is used with a possibly-NULL operator here, matching
+  // the other operator-creation functions in this library.
+  xnn_delete_operator(softargmax_op);
+  return status;
+}
+
+// Prepares a previously created SoftArgMax operator to process `batch_size`
+// rows from `input` to `output`. Only records pointers and a parallelization
+// plan; the actual computation happens when the operator is run.
+enum xnn_status xnn_setup_softargmax_nc_q8(
+  xnn_operator_t softargmax_op,
+  size_t batch_size,
+  const uint8_t* input,
+  uint8_t* output,
+  pthreadpool_t threadpool)
+{
+  if (softargmax_op->type != xnn_operator_type_softargmax_q8) {
+    xnn_log_error("failed to setup SoftArgMax (Q8) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+  // Invalidate first: if any check below fails, the operator must not remain
+  // runnable with stale settings from a previous setup.
+  softargmax_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup SoftArgMax operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  if (batch_size == 0) {
+    // Nothing to compute; mark the operator so the run step skips it.
+    softargmax_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  softargmax_op->batch_size = batch_size;
+  softargmax_op->input = input;
+  softargmax_op->output = output;
+
+  // Strides are converted from elements to bytes for the microkernels.
+  softargmax_op->context.u8_softargmax = (struct u8_softargmax_context) {
+    .n = softargmax_op->channels,
+    .x = input,
+    .x_stride = softargmax_op->input_pixel_stride * sizeof(uint8_t),
+    .t = softargmax_op->lookup_table,
+    .y = output,
+    .y_stride = softargmax_op->output_pixel_stride * sizeof(uint8_t),
+    .rmax_ukernel = xnn_params.u8.rmax,
+    .lut_norm_ukernel = xnn_params.u8.lut32norm,
+  };
+  // One parallel task per batch row: each row is normalized independently.
+  // NOTE(review): `threadpool` is unused here; parallelism appears to be
+  // applied at run time from the recorded compute plan — confirm.
+  softargmax_op->compute.type = xnn_parallelization_type_1d;
+  softargmax_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_u8_softargmax;
+  softargmax_op->compute.range[0] = batch_size;
+  softargmax_op->state = xnn_run_state_ready;
+
+  return xnn_status_success;
+}
diff --git a/src/u8-clamp/neon.c b/src/u8-clamp/neon.c
new file mode 100644
index 0000000..04725f9
--- /dev/null
+++ b/src/u8-clamp/neon.c
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/clamp.h>
+
+
+// Clamps each of `n` uint8 elements of `x` into the range
+// [params->neon.min, params->neon.max] and writes the results to `y`.
+void xnn_u8_clamp_ukernel__neon(
+    size_t n,
+    const uint8_t* x,
+    uint8_t* y,
+    const union xnn_u8_output_params params[restrict static 1])
+{
+  assert(n != 0);
+
+  // Broadcast the clamping bounds across all 16 lanes.
+  const uint8x16_t voutput_max = vld1q_dup_u8(&params->neon.max);
+  const uint8x16_t voutput_min = vld1q_dup_u8(&params->neon.min);
+
+  // Main loop: 64 elements (4 x 16-byte vectors) per iteration.
+  for (; n >= 64; n -= 64) {
+    const uint8x16_t vx0 = vld1q_u8(x); x += 16;
+    const uint8x16_t vx1 = vld1q_u8(x); x += 16;
+    const uint8x16_t vx2 = vld1q_u8(x); x += 16;
+    const uint8x16_t vx3 = vld1q_u8(x); x += 16;
+
+    // clamp(v) = min(max(v, lo), hi)
+    const uint8x16_t vy0 = vminq_u8(vmaxq_u8(vx0, voutput_min), voutput_max);
+    const uint8x16_t vy1 = vminq_u8(vmaxq_u8(vx1, voutput_min), voutput_max);
+    const uint8x16_t vy2 = vminq_u8(vmaxq_u8(vx2, voutput_min), voutput_max);
+    const uint8x16_t vy3 = vminq_u8(vmaxq_u8(vx3, voutput_min), voutput_max);
+
+    // Prefetch input well ahead of the current position.
+    __builtin_prefetch(x + 640);
+
+    vst1q_u8(y, vy0); y += 16;
+    vst1q_u8(y, vy1); y += 16;
+    vst1q_u8(y, vy2); y += 16;
+    vst1q_u8(y, vy3); y += 16;
+  }
+  // 8-element vectors for the medium remainder.
+  for (; n >= 8; n -= 8) {
+    uint8x8_t vout = vld1_u8(x); x += 8;
+    vout = vmin_u8(vout, vget_low_u8(voutput_max));
+    vout = vmax_u8(vout, vget_low_u8(voutput_min));
+    vst1_u8(y, vout); y += 8;
+  }
+  // Final 1-7 elements: load a full 8-byte vector, store only `n` bytes.
+  // NOTE(review): this reads up to 7 bytes past the last input element;
+  // presumably callers guarantee the over-read is safe — confirm.
+  if XNN_UNLIKELY(n != 0) {
+    uint8x8_t vout = vld1_u8(x);
+    vout = vmin_u8(vout, vget_low_u8(voutput_max));
+    vout = vmax_u8(vout, vget_low_u8(voutput_min));
+
+    // Store 4/2/1-byte pieces, rotating consumed lanes out after each store.
+    if (n & 4) {
+      vst1_lane_u32(__builtin_assume_aligned(y, 1), vreinterpret_u32_u8(vout), 0); y += 4;
+      vout = vext_u8(vout, vout, 4);
+    }
+    if (n & 2) {
+      vst1_lane_u16(__builtin_assume_aligned(y, 1), vreinterpret_u16_u8(vout), 0); y += 2;
+      vout = vext_u8(vout, vout, 2);
+    }
+    if (n & 1) {
+      vst1_lane_u8(y, vout, 0);
+    }
+  }
+}
diff --git a/src/u8-clamp/scalar.c b/src/u8-clamp/scalar.c
new file mode 100644
index 0000000..6b513b0
--- /dev/null
+++ b/src/u8-clamp/scalar.c
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/clamp.h>
+
+
+// Scalar reference version: clamps each of `n` uint8 elements of `x` into
+// [params->scalar.min, params->scalar.max] and writes the results to `y`.
+void xnn_u8_clamp_ukernel__scalar(
+    size_t n,
+    const uint8_t* x,
+    uint8_t* y,
+    const union xnn_u8_output_params params[restrict static 1])
+{
+  assert(n != 0);
+
+  const uint8_t voutput_max = params->scalar.max;
+  const uint8_t voutput_min = params->scalar.min;
+
+  // Main loop, unrolled by 4 to expose instruction-level parallelism.
+  for (; n >= 4 * sizeof(uint8_t); n -= 4 * sizeof(uint8_t)) {
+    uint8_t vt0 = x[0];
+    uint8_t vt1 = x[1];
+    uint8_t vt2 = x[2];
+    uint8_t vt3 = x[3];
+    x += 4;
+
+    // Apply the lower bound...
+    vt0 = vt0 < voutput_min ? voutput_min : vt0;
+    vt1 = vt1 < voutput_min ? voutput_min : vt1;
+    vt2 = vt2 < voutput_min ? voutput_min : vt2;
+    vt3 = vt3 < voutput_min ? voutput_min : vt3;
+
+    // ...then the upper bound.
+    vt0 = vt0 > voutput_max ? voutput_max : vt0;
+    vt1 = vt1 > voutput_max ? voutput_max : vt1;
+    vt2 = vt2 > voutput_max ? voutput_max : vt2;
+    vt3 = vt3 > voutput_max ? voutput_max : vt3;
+
+    y[0] = vt0;
+    y[1] = vt1;
+    y[2] = vt2;
+    y[3] = vt3;
+    y += 4;
+  }
+
+  // Remainder: 1-3 trailing elements, one at a time.
+  if XNN_UNLIKELY(n != 0) {
+    do {
+      uint8_t vt = *x++;
+      vt = vt < voutput_min ? voutput_min : vt;
+      vt = vt > voutput_max ? voutput_max : vt;
+      *y++ = vt;
+
+      n -= sizeof(uint8_t);
+    } while (n != 0);
+  }
+}
diff --git a/src/u8-clamp/sse2.c b/src/u8-clamp/sse2.c
new file mode 100644
index 0000000..04179a2
--- /dev/null
+++ b/src/u8-clamp/sse2.c
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/clamp.h>
+
+
+// SSE2 version: clamps each of `n` uint8 elements of `x` into
+// [params->sse2.min, params->sse2.max] and writes the results to `y`.
+void xnn_u8_clamp_ukernel__sse2(
+    size_t n,
+    const uint8_t* x,
+    uint8_t* y,
+    const union xnn_u8_output_params params[restrict static 1])
+{
+  assert(n != 0);
+
+  // sse2.max/min are pre-broadcast 16-byte fields, loaded as whole vectors
+  // (aligned load: the params union is expected to be 16-byte aligned).
+  const __m128i voutput_max = _mm_load_si128((const __m128i*) &params->sse2.max);
+  const __m128i voutput_min = _mm_load_si128((const __m128i*) &params->sse2.min);
+  // Main loop: 64 elements (4 x 16-byte vectors) per iteration.
+  for (; n >= 64; n -= 64) {
+    const __m128i vx0 = _mm_loadu_si128((const __m128i*) x);
+    const __m128i vx1 = _mm_loadu_si128((const __m128i*) x + 1);
+    const __m128i vx2 = _mm_loadu_si128((const __m128i*) x + 2);
+    const __m128i vx3 = _mm_loadu_si128((const __m128i*) x + 3);
+    x += 64;
+
+    // clamp(v) = min(max(v, lo), hi)
+    const __m128i vy0 = _mm_min_epu8(_mm_max_epu8(vx0, voutput_min), voutput_max);
+    const __m128i vy1 = _mm_min_epu8(_mm_max_epu8(vx1, voutput_min), voutput_max);
+    const __m128i vy2 = _mm_min_epu8(_mm_max_epu8(vx2, voutput_min), voutput_max);
+    const __m128i vy3 = _mm_min_epu8(_mm_max_epu8(vx3, voutput_min), voutput_max);
+
+    // GCC/Clang builtin: prefetch input well ahead of the current position.
+    __builtin_prefetch(x + 640);
+
+    _mm_storeu_si128((__m128i*) y, vy0);
+    _mm_storeu_si128((__m128i*) y + 1, vy1);
+    _mm_storeu_si128((__m128i*) y + 2, vy2);
+    _mm_storeu_si128((__m128i*) y + 3, vy3);
+    y += 64;
+  }
+  // 8-element chunks for the medium remainder.
+  for (; n >= 8; n -= 8) {
+    __m128i vout = _mm_loadl_epi64((const __m128i*) x);
+    x += 8;
+    vout = _mm_min_epu8(vout, voutput_max);
+    vout = _mm_max_epu8(vout, voutput_min);
+    _mm_storel_epi64((__m128i*) y, vout);
+    y += 8;
+  }
+  // Final 1-7 elements: load 8 bytes, store only `n` of them.
+  // NOTE(review): reads up to 7 bytes past the last input element;
+  // presumably callers guarantee the over-read is safe — confirm.
+  if XNN_UNLIKELY(n != 0) {
+    __m128i vout = _mm_loadl_epi64((const __m128i*) x);
+    vout = _mm_min_epu8(vout, voutput_max);
+    vout = _mm_max_epu8(vout, voutput_min);
+    // Store 4/2/1-byte pieces, shifting consumed bytes out after each store.
+    if (n & 4) {
+      *((uint32_t*) y) = (uint32_t) _mm_cvtsi128_si32(vout);
+      y += 4;
+      vout = _mm_srli_epi64(vout, 32);
+    }
+    if (n & 2) {
+      *((uint16_t*) y) = (uint16_t) _mm_extract_epi16(vout, 0);
+      y += 2;
+      vout = _mm_srli_epi32(vout, 16);
+    }
+    if (n & 1) {
+      *((uint8_t*) y) = (uint8_t) _mm_cvtsi128_si32(vout);
+    }
+  }
+}
diff --git a/src/u8-lut32norm/scalar.c b/src/u8-lut32norm/scalar.c
new file mode 100644
index 0000000..3e54121
--- /dev/null
+++ b/src/u8-lut32norm/scalar.c
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <fxdiv.h>
+
+#include <xnnpack/lut.h>
+
+
+// Sums the lookup-table values t[x[i]] over the n input bytes.
+// NOTE(review): assumes the caller built `t` so that n entries cannot
+// overflow uint32 (the SoftArgMax table caps entries at UINT32_MAX / n) —
+// confirm for any other callers.
+static inline uint32_t compute_sum(
+    size_t n,
+    const uint8_t* x,
+    const uint32_t* t)
+{
+  assert(n != 0);
+
+  uint32_t vsum = 0;
+  do {
+    const size_t vx = *x++;
+    vsum += t[vx];
+  } while (--n != 0);
+  return vsum;
+}
+
+// Lookup-and-normalize kernel: computes y[i] = round(t[x[i]] * 256 / sum),
+// where sum = sum over all i of t[x[i]], clamped to 255. This is the
+// normalization step of the quantized SoftArgMax.
+void xnn_u8_lut32norm_ukernel__scalar(
+    size_t n,
+    const uint8_t* x,
+    const uint32_t* t,
+    uint8_t* y)
+{
+  assert(n != 0);
+
+  // First pass: total of the table values over the row (the softmax
+  // denominator). Must be non-zero for the division below.
+  const uint32_t vsum = compute_sum(n, x, t);
+  assert(vsum != 0);
+
+  // fxdiv precomputes magic constants so the per-element division by the
+  // runtime value vsum becomes multiply+shift.
+  struct fxdiv_divisor_uint32_t vsum_divisor = fxdiv_init_uint32_t(vsum);
+  // Adding vsum/2 before dividing implements round-to-nearest.
+  const uint32_t vrounding = (vsum >> 1);
+  // Second pass: re-walk the same input and emit normalized outputs.
+  // (vt << 8) scales by 256; table entries are capped below 2^23 at table
+  // construction, so the shifted value plus rounding fits in 32 bits.
+  do {
+    const size_t vx = *x++;
+    const uint32_t vt = t[vx];
+    const uint32_t vq = fxdiv_quotient_uint32_t((vt << 8) + vrounding, vsum_divisor);
+    const uint8_t vy = vq > 255 ? UINT8_C(255) : (uint8_t) vq;
+    *y++ = vy;
+  } while (--n != 0);
+}
diff --git a/src/u8-maxpool/9p8q-neon.c b/src/u8-maxpool/9p8q-neon.c
new file mode 100644
index 0000000..49b747e
--- /dev/null
+++ b/src/u8-maxpool/9p8q-neon.c
@@ -0,0 +1,233 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/maxpool.h>
+
+
+// NEON max-pooling microkernel for uint8 data ("9p8q": a primary pass over up
+// to 9 window elements, then additional passes of up to 8 each for larger
+// windows). Produces `n` output pixels; each pools `ks` input rows of `kc`
+// channels. Between passes the partial maximum lives in the output buffer
+// itself and is folded back in via `vo`.
+void xnn_u8_maxpool_ukernel_9p8q__neon(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const uint8_t** input,
+    uint8_t* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_u8_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(kc != 0);
+
+  const uint8x16_t voutput_max = vld1q_dup_u8(&params->neon.max);
+  const uint8x16_t voutput_min = vld1q_dup_u8(&params->neon.min);
+  do {
+    uint8_t* o = output;
+    // Primary pass: pool the first (up to) 9 window elements.
+    {
+      const uint8_t* i0 = *input++;
+      const uint8_t* i1 = *input++;
+      const uint8_t* i2 = *input++;
+      const uint8_t* i3 = *input++;
+      const uint8_t* i4 = *input++;
+      const uint8_t* i5 = *input++;
+      const uint8_t* i6 = *input++;
+      const uint8_t* i7 = *input++;
+      const uint8_t* i8 = *input++;
+      // For windows smaller than 9, alias the unused pointers to i0:
+      // max with duplicate data leaves the result unchanged.
+      if (ks < 2) {
+        i1 = i0;
+      }
+      if (ks <= 2) {
+        i2 = i0;
+      }
+      if (ks < 4) {
+        i3 = i0;
+      }
+      if (ks <= 4) {
+        i4 = i0;
+      }
+      if (ks < 6) {
+        i5 = i0;
+      }
+      if (ks <= 6) {
+        i6 = i0;
+      }
+      if (ks < 8) {
+        i7 = i0;
+      }
+      if (ks <= 8) {
+        i8 = i0;
+      }
+
+      size_t k = kc;
+      for (; k >= 16; k -= 16) {
+        const uint8x16_t vi0 = vld1q_u8(i0); i0 += 16;
+        const uint8x16_t vi1 = vld1q_u8(i1); i1 += 16;
+        const uint8x16_t vi2 = vld1q_u8(i2); i2 += 16;
+        const uint8x16_t vi3 = vld1q_u8(i3); i3 += 16;
+        const uint8x16_t vi4 = vld1q_u8(i4); i4 += 16;
+        const uint8x16_t vi5 = vld1q_u8(i5); i5 += 16;
+        const uint8x16_t vi6 = vld1q_u8(i6); i6 += 16;
+        const uint8x16_t vi7 = vld1q_u8(i7); i7 += 16;
+        const uint8x16_t vi8 = vld1q_u8(i8); i8 += 16;
+
+        // Reduce the 9 inputs with a balanced max tree.
+        const uint8x16_t vmax018 = vmaxq_u8(vmaxq_u8(vi0, vi1), vi8);
+        const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
+        const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
+        const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);
+
+        const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
+        const uint8x16_t vmax01678 = vmaxq_u8(vmax018, vmax67);
+        const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax01678);
+        const uint8x16_t vout = vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min);
+
+        vst1q_u8(o, vout); o += 16;
+      }
+      // Channel remainder (1-15): full 16-byte loads, partial stores.
+      // NOTE(review): loads read past the last channel; presumably callers
+      // guarantee the over-read is safe — confirm.
+      if (k != 0) {
+        const uint8x16_t vi0 = vld1q_u8(i0);
+        const uint8x16_t vi1 = vld1q_u8(i1);
+        const uint8x16_t vi2 = vld1q_u8(i2);
+        const uint8x16_t vi3 = vld1q_u8(i3);
+        const uint8x16_t vi4 = vld1q_u8(i4);
+        const uint8x16_t vi5 = vld1q_u8(i5);
+        const uint8x16_t vi6 = vld1q_u8(i6);
+        const uint8x16_t vi7 = vld1q_u8(i7);
+        const uint8x16_t vi8 = vld1q_u8(i8);
+
+        const uint8x16_t vmax018 = vmaxq_u8(vmaxq_u8(vi0, vi1), vi8);
+        const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
+        const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
+        const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);
+
+        const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
+        const uint8x16_t vmax01678 = vmaxq_u8(vmax018, vmax67);
+        const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax01678);
+        const uint8x16_t vout = vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min);
+
+        // Store 8/4/2/1-byte pieces of the clamped result.
+        uint8x8_t vout_lo = vget_low_u8(vout);
+        if (k & 8) {
+          vst1_u8(o, vout_lo); o += 8;
+          vout_lo = vget_high_u8(vout);
+        }
+        if (k & 4) {
+          vst1_lane_u32(__builtin_assume_aligned(o, 1), vreinterpret_u32_u8(vout_lo), 0); o += 4;
+          vout_lo = vext_u8(vout_lo, vout_lo, 4);
+        }
+        if (k & 2) {
+          vst1_lane_u16(__builtin_assume_aligned(o, 1), vreinterpret_u16_u8(vout_lo), 0); o += 2;
+          vout_lo = vext_u8(vout_lo, vout_lo, 2);
+        }
+        if (k & 1) {
+          vst1_lane_u8(o, vout_lo, 0); o += 1;
+        }
+      }
+    }
+
+    // Additional passes for windows larger than 9: each pass folds up to 8
+    // more inputs into the partial result already stored in the output row.
+    for (ptrdiff_t m = (ptrdiff_t) ks - 9; m > 0; m -= 8) {
+      const uint8_t* i0 = *input++;
+      const uint8_t* i1 = *input++;
+      const uint8_t* i2 = *input++;
+      const uint8_t* i3 = *input++;
+      const uint8_t* i4 = *input++;
+      const uint8_t* i5 = *input++;
+      const uint8_t* i6 = *input++;
+      const uint8_t* i7 = *input++;
+      // Same aliasing trick for a short final pass.
+      if (m < 2) {
+        i1 = i0;
+      }
+      if (m <= 2) {
+        i2 = i0;
+      }
+      if (m < 4) {
+        i3 = i0;
+      }
+      if (m <= 4) {
+        i4 = i0;
+      }
+      if (m < 6) {
+        i5 = i0;
+      }
+      if (m <= 6) {
+        i6 = i0;
+      }
+      if (m < 8) {
+        i7 = i0;
+      }
+
+      // Rewind to the start of this pixel's output row (the accumulator).
+      o = output;
+      size_t k = kc;
+      for (; k >= 16; k -= 16) {
+        const uint8x16_t vi0 = vld1q_u8(i0); i0 += 16;
+        const uint8x16_t vi1 = vld1q_u8(i1); i1 += 16;
+        const uint8x16_t vi2 = vld1q_u8(i2); i2 += 16;
+        const uint8x16_t vi3 = vld1q_u8(i3); i3 += 16;
+        const uint8x16_t vi4 = vld1q_u8(i4); i4 += 16;
+        const uint8x16_t vi5 = vld1q_u8(i5); i5 += 16;
+        const uint8x16_t vi6 = vld1q_u8(i6); i6 += 16;
+        const uint8x16_t vi7 = vld1q_u8(i7); i7 += 16;
+        // vo is the partial maximum from the previous pass(es).
+        const uint8x16_t vo = vld1q_u8(o);
+
+        const uint8x16_t vmax01 = vmaxq_u8(vmaxq_u8(vi0, vi1), vo);
+        const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
+        const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
+        const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);
+
+        const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
+        const uint8x16_t vmax0167 = vmaxq_u8(vmax01, vmax67);
+        const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax0167);
+        const uint8x16_t vout = vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min);
+
+        vst1q_u8(o, vout); o += 16;
+      }
+      if (k != 0) {
+        const uint8x16_t vi0 = vld1q_u8(i0);
+        const uint8x16_t vi1 = vld1q_u8(i1);
+        const uint8x16_t vi2 = vld1q_u8(i2);
+        const uint8x16_t vi3 = vld1q_u8(i3);
+        const uint8x16_t vi4 = vld1q_u8(i4);
+        const uint8x16_t vi5 = vld1q_u8(i5);
+        const uint8x16_t vi6 = vld1q_u8(i6);
+        const uint8x16_t vi7 = vld1q_u8(i7);
+        const uint8x16_t vo = vld1q_u8(o);
+
+        const uint8x16_t vmax01 = vmaxq_u8(vmaxq_u8(vi0, vi1), vo);
+        const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
+        const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
+        const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);
+
+        const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
+        const uint8x16_t vmax0167 = vmaxq_u8(vmax01, vmax67);
+        const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax0167);
+        const uint8x16_t vout = vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min);
+
+        uint8x8_t vout_lo = vget_low_u8(vout);
+        if (k & 8) {
+          vst1_u8(o, vout_lo); o += 8;
+          vout_lo = vget_high_u8(vout);
+        }
+        if (k & 4) {
+          vst1_lane_u32(__builtin_assume_aligned(o, 1), vreinterpret_u32_u8(vout_lo), 0); o += 4;
+          vout_lo = vext_u8(vout_lo, vout_lo, 4);
+        }
+        if (k & 2) {
+          vst1_lane_u16(__builtin_assume_aligned(o, 1), vreinterpret_u16_u8(vout_lo), 0); o += 2;
+          vout_lo = vext_u8(vout_lo, vout_lo, 2);
+        }
+        if (k & 1) {
+          vst1_lane_u8(o, vout_lo, 0); o += 1;
+        }
+      }
+    }
+    // Advance to the next output pixel: the increments compensate for the
+    // pointers consumed above and position o past this pixel's row.
+    input = (const uint8_t**) ((uintptr_t) input + input_increment);
+    output = (uint8_t*) ((uintptr_t) o + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/u8-maxpool/9p8q-scalar.c b/src/u8-maxpool/9p8q-scalar.c
new file mode 100644
index 0000000..7aa20a5
--- /dev/null
+++ b/src/u8-maxpool/9p8q-scalar.c
@@ -0,0 +1,158 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/maxpool.h>
+
+
+// Scalar reference max-pooling microkernel for uint8 data ("9p8q": a primary
+// pass over up to 9 window elements, then passes of up to 8 each for larger
+// windows). Produces `n` output pixels; each pools `ks` input rows of `kc`
+// channels. Between passes the partial maximum lives in the output buffer
+// itself and is re-read via `*o`.
+void xnn_u8_maxpool_ukernel_9p8q__scalar(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const uint8_t** input,
+    uint8_t* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_u8_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(kc != 0);
+
+  const uint8_t voutput_max = params->scalar.max;
+  const uint8_t voutput_min = params->scalar.min;
+  do {
+    uint8_t* o = output;
+    // Primary pass: pool the first (up to) 9 window elements.
+    {
+      const uint8_t* i0 = *input++;
+      const uint8_t* i1 = *input++;
+      const uint8_t* i2 = *input++;
+      const uint8_t* i3 = *input++;
+      const uint8_t* i4 = *input++;
+      const uint8_t* i5 = *input++;
+      const uint8_t* i6 = *input++;
+      const uint8_t* i7 = *input++;
+      const uint8_t* i8 = *input++;
+      // For windows smaller than 9, alias the unused pointers to i0:
+      // taking max with duplicate data leaves the result unchanged.
+      if (ks < 2) {
+        i1 = i0;
+      }
+      if (ks <= 2) {
+        i2 = i0;
+      }
+      if (ks < 4) {
+        i3 = i0;
+      }
+      if (ks <= 4) {
+        i4 = i0;
+      }
+      if (ks < 6) {
+        i5 = i0;
+      }
+      if (ks <= 6) {
+        i6 = i0;
+      }
+      if (ks < 8) {
+        i7 = i0;
+      }
+      if (ks <= 8) {
+        i8 = i0;
+      }
+
+      size_t k = kc;
+      do {
+        const uint8_t vi0 = *i0++;
+        const uint8_t vi1 = *i1++;
+        const uint8_t vi2 = *i2++;
+        const uint8_t vi3 = *i3++;
+        const uint8_t vi4 = *i4++;
+        const uint8_t vi5 = *i5++;
+        const uint8_t vi6 = *i6++;
+        const uint8_t vi7 = *i7++;
+        const uint8_t vi8 = *i8++;
+
+        // Reduce the 9 inputs with a balanced max tree.
+        const uint8_t vmax01 = vi0 > vi1 ? vi0 : vi1;
+        const uint8_t vmax23 = vi2 > vi3 ? vi2 : vi3;
+        const uint8_t vmax45 = vi4 > vi5 ? vi4 : vi5;
+        const uint8_t vmax67 = vi6 > vi7 ? vi6 : vi7;
+        const uint8_t vmax018 = vmax01 > vi8 ? vmax01 : vi8;
+
+        const uint8_t vmax2345 = vmax23 > vmax45 ? vmax23 : vmax45;
+        const uint8_t vmax01678 = vmax018 > vmax67 ? vmax018 : vmax67;
+
+        // Clamp the pooled maximum into the output range.
+        uint8_t vout = vmax2345 > vmax01678 ? vmax2345 : vmax01678;
+        vout = vout > voutput_max ? voutput_max : vout;
+        vout = vout < voutput_min ? voutput_min : vout;
+
+        *o++ = vout;
+      } while (--k != 0);
+    }
+
+    // Additional passes for windows larger than 9: each pass folds up to 8
+    // more inputs into the partial result already stored in the output row.
+    for (ptrdiff_t m = (ptrdiff_t) ks - 9; m > 0; m -= 8) {
+      const uint8_t* i0 = *input++;
+      const uint8_t* i1 = *input++;
+      const uint8_t* i2 = *input++;
+      const uint8_t* i3 = *input++;
+      const uint8_t* i4 = *input++;
+      const uint8_t* i5 = *input++;
+      const uint8_t* i6 = *input++;
+      const uint8_t* i7 = *input++;
+      // Same aliasing trick for a short final pass.
+      if (m < 2) {
+        i1 = i0;
+      }
+      if (m <= 2) {
+        i2 = i0;
+      }
+      if (m < 4) {
+        i3 = i0;
+      }
+      if (m <= 4) {
+        i4 = i0;
+      }
+      if (m < 6) {
+        i5 = i0;
+      }
+      if (m <= 6) {
+        i6 = i0;
+      }
+      if (m < 8) {
+        i7 = i0;
+      }
+
+      // Rewind to the start of this pixel's output row (the accumulator).
+      o = output;
+      size_t k = kc;
+      do {
+        const uint8_t vi0 = *i0++;
+        const uint8_t vi1 = *i1++;
+        const uint8_t vi2 = *i2++;
+        const uint8_t vi3 = *i3++;
+        const uint8_t vi4 = *i4++;
+        const uint8_t vi5 = *i5++;
+        const uint8_t vi6 = *i6++;
+        const uint8_t vi7 = *i7++;
+        // The partial maximum from the previous pass(es).
+        const uint8_t vi8 = *o;
+
+        const uint8_t vmax01 = vi0 > vi1 ? vi0 : vi1;
+        const uint8_t vmax23 = vi2 > vi3 ? vi2 : vi3;
+        const uint8_t vmax45 = vi4 > vi5 ? vi4 : vi5;
+        const uint8_t vmax67 = vi6 > vi7 ? vi6 : vi7;
+        const uint8_t vmax018 = vmax01 > vi8 ? vmax01 : vi8;
+
+        const uint8_t vmax2345 = vmax23 > vmax45 ? vmax23 : vmax45;
+        const uint8_t vmax01678 = vmax018 > vmax67 ? vmax018 : vmax67;
+
+        uint8_t vout = vmax2345 > vmax01678 ? vmax2345 : vmax01678;
+        vout = vout > voutput_max ? voutput_max : vout;
+        vout = vout < voutput_min ? voutput_min : vout;
+
+        *o++ = vout;
+      } while (--k != 0);
+    }
+    // Advance to the next output pixel: the increments compensate for the
+    // pointers consumed above and position o past this pixel's row.
+    input = (const uint8_t**) ((uintptr_t) input + input_increment);
+    output = (uint8_t*) ((uintptr_t) o + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/u8-maxpool/9p8q-sse2.c b/src/u8-maxpool/9p8q-sse2.c
new file mode 100644
index 0000000..02e84eb
--- /dev/null
+++ b/src/u8-maxpool/9p8q-sse2.c
@@ -0,0 +1,241 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/maxpool.h>
+
+
+// SSE2 max-pooling microkernel for uint8 data ("9p8q": a primary pass over up
+// to 9 window elements, then passes of up to 8 each for larger windows).
+// Produces `n` output pixels; each pools `ks` input rows of `kc` channels.
+// Between passes the partial maximum lives in the output buffer itself and is
+// folded back in via `vo`.
+void xnn_u8_maxpool_ukernel_9p8q__sse2(
+    size_t n,
+    size_t ks,
+    size_t kc,
+    const uint8_t** input,
+    uint8_t* output,
+    size_t input_increment,
+    size_t output_increment,
+    const union xnn_u8_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(ks != 0);
+  assert(kc != 0);
+
+  // params->sse2.max/min are loaded as whole 16-byte vectors; the fields are
+  // presumably 16-byte arrays (array-to-pointer decay, no & needed) — note
+  // the sibling clamp kernel spells this &params->sse2.max; both forms give
+  // the same address when the field is an array.
+  const __m128i voutput_max = _mm_load_si128((const __m128i*) params->sse2.max);
+  const __m128i voutput_min = _mm_load_si128((const __m128i*) params->sse2.min);
+
+  do {
+    uint8_t* o = output;
+    // Primary pass: pool the first (up to) 9 window elements.
+    {
+      const uint8_t* i0 = *input++;
+      const uint8_t* i1 = *input++;
+      const uint8_t* i2 = *input++;
+      const uint8_t* i3 = *input++;
+      const uint8_t* i4 = *input++;
+      const uint8_t* i5 = *input++;
+      const uint8_t* i6 = *input++;
+      const uint8_t* i7 = *input++;
+      const uint8_t* i8 = *input++;
+      // For windows smaller than 9, alias the unused pointers to i0:
+      // max with duplicate data leaves the result unchanged.
+      if (ks < 2) {
+        i1 = i0;
+      }
+      if (ks <= 2) {
+        i2 = i0;
+      }
+      if (ks < 4) {
+        i3 = i0;
+      }
+      if (ks <= 4) {
+        i4 = i0;
+      }
+      if (ks < 6) {
+        i5 = i0;
+      }
+      if (ks <= 6) {
+        i6 = i0;
+      }
+      if (ks < 8) {
+        i7 = i0;
+      }
+      if (ks <= 8) {
+        i8 = i0;
+      }
+
+      size_t k = kc;
+      for (; k >= 16; k -= 16) {
+        const __m128i vi0 = _mm_loadu_si128((const __m128i*) i0); i0 += 16;
+        const __m128i vi1 = _mm_loadu_si128((const __m128i*) i1); i1 += 16;
+        const __m128i vi2 = _mm_loadu_si128((const __m128i*) i2); i2 += 16;
+        const __m128i vi3 = _mm_loadu_si128((const __m128i*) i3); i3 += 16;
+        const __m128i vi4 = _mm_loadu_si128((const __m128i*) i4); i4 += 16;
+        const __m128i vi5 = _mm_loadu_si128((const __m128i*) i5); i5 += 16;
+        const __m128i vi6 = _mm_loadu_si128((const __m128i*) i6); i6 += 16;
+        const __m128i vi7 = _mm_loadu_si128((const __m128i*) i7); i7 += 16;
+        const __m128i vi8 = _mm_loadu_si128((const __m128i*) i8); i8 += 16;
+
+        // Reduce the 9 inputs with a balanced max tree.
+        const __m128i vmax018 = _mm_max_epu8(_mm_max_epu8(vi0, vi1), vi8);
+        const __m128i vmax23 = _mm_max_epu8(vi2, vi3);
+        const __m128i vmax45 = _mm_max_epu8(vi4, vi5);
+        const __m128i vmax67 = _mm_max_epu8(vi6, vi7);
+
+        const __m128i vmax2345 = _mm_max_epu8(vmax23, vmax45);
+        const __m128i vmax01678 = _mm_max_epu8(vmax018, vmax67);
+        const __m128i vmax = _mm_max_epu8(vmax2345, vmax01678);
+        const __m128i vout = _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min);
+
+        _mm_storeu_si128((__m128i*) o, vout); o += 16;
+      }
+      // Channel remainder (1-15): full 16-byte loads, partial stores.
+      // NOTE(review): loads read past the last channel; presumably callers
+      // guarantee the over-read is safe — confirm.
+      if (k != 0) {
+        const __m128i vi0 = _mm_loadu_si128((const __m128i*) i0);
+        const __m128i vi1 = _mm_loadu_si128((const __m128i*) i1);
+        const __m128i vi2 = _mm_loadu_si128((const __m128i*) i2);
+        const __m128i vi3 = _mm_loadu_si128((const __m128i*) i3);
+        const __m128i vi4 = _mm_loadu_si128((const __m128i*) i4);
+        const __m128i vi5 = _mm_loadu_si128((const __m128i*) i5);
+        const __m128i vi6 = _mm_loadu_si128((const __m128i*) i6);
+        const __m128i vi7 = _mm_loadu_si128((const __m128i*) i7);
+        const __m128i vi8 = _mm_loadu_si128((const __m128i*) i8);
+
+        const __m128i vmax018 = _mm_max_epu8(_mm_max_epu8(vi0, vi1), vi8);
+        const __m128i vmax23 = _mm_max_epu8(vi2, vi3);
+        const __m128i vmax45 = _mm_max_epu8(vi4, vi5);
+        const __m128i vmax67 = _mm_max_epu8(vi6, vi7);
+
+        const __m128i vmax2345 = _mm_max_epu8(vmax23, vmax45);
+        const __m128i vmax01678 = _mm_max_epu8(vmax018, vmax67);
+        const __m128i vmax = _mm_max_epu8(vmax2345, vmax01678);
+        __m128i vout = _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min);
+
+        // Store 8/4/2/1-byte pieces, shifting consumed bytes out each time.
+        if (k & 8) {
+          _mm_storel_epi64((__m128i*) o, vout);
+          vout = _mm_unpackhi_epi64(vout, vout);
+          o += 8;
+        }
+        if (k & 4) {
+          *((uint32_t*) o) = (uint32_t) _mm_cvtsi128_si32(vout);
+          vout = _mm_srli_epi64(vout, 32);
+          o += 4;
+        }
+        if (k & 2) {
+          *((uint16_t*) o) = (uint16_t) _mm_extract_epi16(vout, 0);
+          vout = _mm_srli_epi32(vout, 16);
+          o += 2;
+        }
+        if (k & 1) {
+          *((uint8_t*) o) = (uint8_t) _mm_cvtsi128_si32(vout);
+          o += 1;
+        }
+      }
+    }
+
+    // Additional passes for windows larger than 9: each pass folds up to 8
+    // more inputs into the partial result already stored in the output row.
+    for (ptrdiff_t m = (ptrdiff_t) ks - 9; m > 0; m -= 8) {
+      const uint8_t* i0 = *input++;
+      const uint8_t* i1 = *input++;
+      const uint8_t* i2 = *input++;
+      const uint8_t* i3 = *input++;
+      const uint8_t* i4 = *input++;
+      const uint8_t* i5 = *input++;
+      const uint8_t* i6 = *input++;
+      const uint8_t* i7 = *input++;
+      // Same aliasing trick for a short final pass.
+      if (m < 2) {
+        i1 = i0;
+      }
+      if (m <= 2) {
+        i2 = i0;
+      }
+      if (m < 4) {
+        i3 = i0;
+      }
+      if (m <= 4) {
+        i4 = i0;
+      }
+      if (m < 6) {
+        i5 = i0;
+      }
+      if (m <= 6) {
+        i6 = i0;
+      }
+      if (m < 8) {
+        i7 = i0;
+      }
+
+      // Rewind to the start of this pixel's output row (the accumulator).
+      o = output;
+      size_t k = kc;
+      for (; k >= 16; k -= 16) {
+        const __m128i vi0 = _mm_loadu_si128((const __m128i*) i0); i0 += 16;
+        const __m128i vi1 = _mm_loadu_si128((const __m128i*) i1); i1 += 16;
+        const __m128i vi2 = _mm_loadu_si128((const __m128i*) i2); i2 += 16;
+        const __m128i vi3 = _mm_loadu_si128((const __m128i*) i3); i3 += 16;
+        const __m128i vi4 = _mm_loadu_si128((const __m128i*) i4); i4 += 16;
+        const __m128i vi5 = _mm_loadu_si128((const __m128i*) i5); i5 += 16;
+        const __m128i vi6 = _mm_loadu_si128((const __m128i*) i6); i6 += 16;
+        const __m128i vi7 = _mm_loadu_si128((const __m128i*) i7); i7 += 16;
+        // vo is the partial maximum from the previous pass(es).
+        const __m128i vo = _mm_loadu_si128((const __m128i*) o);
+
+        const __m128i vmax01 = _mm_max_epu8(_mm_max_epu8(vi0, vi1), vo);
+        const __m128i vmax23 = _mm_max_epu8(vi2, vi3);
+        const __m128i vmax45 = _mm_max_epu8(vi4, vi5);
+        const __m128i vmax67 = _mm_max_epu8(vi6, vi7);
+
+        const __m128i vmax2345 = _mm_max_epu8(vmax23, vmax45);
+        const __m128i vmax0167 = _mm_max_epu8(vmax01, vmax67);
+        const __m128i vmax = _mm_max_epu8(vmax2345, vmax0167);
+        const __m128i vout = _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min);
+
+        _mm_storeu_si128((__m128i*) o, vout);
+        o += 16;
+      }
+      if (k != 0) {
+        const __m128i vi0 = _mm_loadu_si128((const __m128i*) i0);
+        const __m128i vi1 = _mm_loadu_si128((const __m128i*) i1);
+        const __m128i vi2 = _mm_loadu_si128((const __m128i*) i2);
+        const __m128i vi3 = _mm_loadu_si128((const __m128i*) i3);
+        const __m128i vi4 = _mm_loadu_si128((const __m128i*) i4);
+        const __m128i vi5 = _mm_loadu_si128((const __m128i*) i5);
+        const __m128i vi6 = _mm_loadu_si128((const __m128i*) i6);
+        const __m128i vi7 = _mm_loadu_si128((const __m128i*) i7);
+        const __m128i vo = _mm_loadu_si128((const __m128i*) o);
+
+        const __m128i vmax01 = _mm_max_epu8(_mm_max_epu8(vi0, vi1), vo);
+        const __m128i vmax23 = _mm_max_epu8(vi2, vi3);
+        const __m128i vmax45 = _mm_max_epu8(vi4, vi5);
+        const __m128i vmax67 = _mm_max_epu8(vi6, vi7);
+
+        const __m128i vmax2345 = _mm_max_epu8(vmax23, vmax45);
+        const __m128i vmax0167 = _mm_max_epu8(vmax01, vmax67);
+        const __m128i vmax = _mm_max_epu8(vmax2345, vmax0167);
+        __m128i vout = _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min);
+
+        if (k & 8) {
+          _mm_storel_epi64((__m128i*) o, vout);
+          vout = _mm_unpackhi_epi64(vout, vout);
+          o += 8;
+        }
+        if (k & 4) {
+          *((uint32_t*) o) = (uint32_t) _mm_cvtsi128_si32(vout);
+          vout = _mm_srli_epi64(vout, 32);
+          o += 4;
+        }
+        if (k & 2) {
+          *((uint16_t*) o) = (uint16_t) _mm_extract_epi16(vout, 0);
+          vout = _mm_srli_epi32(vout, 16);
+          o += 2;
+        }
+        if (k & 1) {
+          *((uint8_t*) o) = (uint8_t) _mm_cvtsi128_si32(vout);
+          o += 1;
+        }
+      }
+    }
+    // Advance to the next output pixel: the increments compensate for the
+    // pointers consumed above and position o past this pixel's row.
+    input = (const uint8_t**) ((uintptr_t) input + input_increment);
+    output = (uint8_t*) ((uintptr_t) o + output_increment);
+  } while (--n != 0);
+}
diff --git a/src/u8-rmax/neon.c b/src/u8-rmax/neon.c
new file mode 100644
index 0000000..12e5bbf
--- /dev/null
+++ b/src/u8-rmax/neon.c
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/rmax.h>
+
+
+// Reduction kernel: writes the maximum of the n uint8 elements of x to *y.
+void xnn_u8_rmax_ukernel__neon(
+    size_t n,
+    const uint8_t* x,
+    uint8_t* y)
+{
+  assert(n != 0);
+
+  if XNN_LIKELY(n >= 16) {
+    // Vector path: accumulate a lane-wise maximum 16 bytes at a time.
+    // Starting from zero is safe because uint8 values are never below 0.
+    uint8x16_t vmax = vmovq_n_u8(0);
+    do {
+      const uint8x16_t vx = vld1q_u8(x); x += 16;
+      vmax = vmaxq_u8(vmax, vx);
+      n -= 16;
+    } while (n >= 16);
+    if (n != 0) {
+      // Handle the 1-15 leftover bytes by stepping x BACK by (16 - n) and
+      // re-reading the last 16 bytes of the buffer; the overlap with already-
+      // processed data is harmless for a max reduction, and no out-of-bounds
+      // access occurs since the buffer holds at least 16 bytes.
+      const size_t x_increment = n - 16;
+      x = (const uint8_t*) ((uintptr_t) x + x_increment);
+      const uint8x16_t vx = vld1q_u8(x);
+      vmax = vmaxq_u8(vmax, vx);
+    }
+    // Horizontal reduction: 16 -> 8 -> 4 -> 2 -> 1 lanes via pairwise max.
+    uint8x8_t vmax8 = vmax_u8(vget_low_u8(vmax), vget_high_u8(vmax));
+    const uint8x8_t vmax4 = vpmax_u8(vmax8, vmax8);
+    const uint8x8_t vmax2 = vpmax_u8(vmax4, vmax4);
+    const uint8x8_t vmax1 = vpmax_u8(vmax2, vmax2);
+    vst1_lane_u8(y, vmax1, 0);
+  } else {
+    // Fewer than 16 elements: process one byte per iteration (broadcast load
+    // keeps everything in NEON registers).
+    uint8x8_t vmax = vmov_n_u8(0);
+    do {
+      const uint8x8_t vx = vld1_dup_u8(x); x += 1;
+      vmax = vmax_u8(vmax, vx);
+    } while (--n != 0);
+    vst1_lane_u8(y, vmax, 0);
+  }
+}
diff --git a/src/u8-rmax/scalar.c b/src/u8-rmax/scalar.c
new file mode 100644
index 0000000..a9a3298
--- /dev/null
+++ b/src/u8-rmax/scalar.c
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/rmax.h>
+
+
// Computes the maximum of the n (n != 0) uint8 elements in x[] and stores it to *y.
void xnn_u8_rmax_ukernel__scalar(
    size_t n,
    const uint8_t* x,
    uint8_t* y)
{
  assert(n != 0);

  // 0 is the identity element for unsigned max, so it is a safe seed.
  uint8_t result = 0;
  do {
    const uint8_t value = *x++;
    if (value > result) {
      result = value;
    }
  } while (--n != 0);
  *y = result;
}
diff --git a/src/u8-rmax/sse2.c b/src/u8-rmax/sse2.c
new file mode 100644
index 0000000..034962d
--- /dev/null
+++ b/src/u8-rmax/sse2.c
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/rmax.h>
+
+
// Computes the maximum of the n (n != 0) uint8 elements in x[] and stores it to *y.
void xnn_u8_rmax_ukernel__sse2(
    size_t n,
    const uint8_t* x,
    uint8_t* y)
{
  assert(n != 0);

  if XNN_LIKELY(n >= 16) {
    // Vector path: per-lane running maximum, 16 bytes per iteration.
    // 0 is the identity element for unsigned max, so it is a safe seed.
    __m128i vmax = _mm_setzero_si128();
    do {
      const __m128i vx = _mm_loadu_si128((const __m128i*) x);
      x += 16;
      vmax = _mm_max_epu8(vmax, vx);
      n -= 16;
    } while (n >= 16);
    if (n != 0) {
      // Tail: n - 16 wraps around as an unsigned size_t; adding it through
      // uintptr_t steps the pointer *back* so the final unaligned load
      // covers exactly the last 16 bytes. Re-processing overlapped
      // elements is harmless for a max reduction.
      const size_t x_increment = n - 16;
      x = (const uint8_t*) ((uintptr_t) x + x_increment);
      const __m128i vx = _mm_loadu_si128((const __m128i*) x);
      vmax = _mm_max_epu8(vmax, vx);
    }
    // Horizontal reduction: fold 16 lanes down to 1 by halving the width
    // each step (64-bit halves, then 32/16/8-bit shifts).
    vmax = _mm_max_epu8(vmax, _mm_unpackhi_epi64(vmax, vmax));
    vmax = _mm_max_epu8(vmax, _mm_srli_epi64(vmax, 32));
    vmax = _mm_max_epu8(vmax, _mm_srli_epi32(vmax, 16));
    vmax = _mm_max_epu8(vmax, _mm_srli_epi16(vmax, 8));
    *y = (uint8_t) _mm_cvtsi128_si32(vmax);
  } else {
    // Short input (n < 16): plain scalar loop.
    uint8_t vmax = 0;
    do {
      const uint8_t vx = *x++;
      vmax = vx > vmax ? vx : vmax;
    } while (--n != 0);
    *y = vmax;
  }
}
diff --git a/src/unpooling.c b/src/unpooling.c
new file mode 100644
index 0000000..6264f17
--- /dev/null
+++ b/src/unpooling.c
@@ -0,0 +1,228 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <math.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/log.h>
+#include <xnnpack/common.h>
+#include <xnnpack/math.h>
+#include <xnnpack/params.h>
+#include <xnnpack/indirection.h>
+
+
// Computes the unpooled output size along one spatial dimension:
// kernel * input, reduced by the total (top+bottom or left+right) padding.
// doz() is XNNPACK's difference-or-zero helper — presumably a saturating
// subtraction that clamps at 0 instead of underflowing; see <xnnpack/math.h>.
static inline size_t compute_output_dimension(
    size_t input_dimension,
    size_t input_padding_dimension,
    size_t kernel_dimension)
{
  return doz(kernel_dimension * input_dimension, input_padding_dimension);
}
+
// Creates a 2D Unpooling (max-unpooling) operator descriptor for NHWC layout
// with 32-bit elements.
//
// Validates all parameters, allocates a zero-initialized operator struct,
// records the pooling geometry and strides, and returns the operator via
// unpooling_op_out. Shape-dependent setup happens later in
// xnn_setup_unpooling2d_nhwc_x32.
//
// Returns xnn_status_success on success; on any failure the reason is logged,
// the partially-constructed operator (if any) is deleted, and the matching
// error status is returned.
enum xnn_status xnn_create_unpooling2d_nhwc_x32(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t pooling_height,
    uint32_t pooling_width,
    size_t channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    uint32_t flags,
    xnn_operator_t* unpooling_op_out)
{
  xnn_operator_t unpooling_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to create Unpooling operator: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_invalid_parameter;

  // NOTE(review): the PRIu32 macros below are declared in <inttypes.h>, which
  // this file does not include directly — presumably pulled in transitively
  // by one of the project headers; confirm.
  const uint32_t pooling_size = pooling_height * pooling_width;
  if (pooling_size == 0) {
    xnn_log_error(
      "failed to create Unpooling operator with %" PRIu32 "x%" PRIu32 " pooling size: "
      "pooling size dimensions must be non-zero",
      pooling_width, pooling_height);
    goto error;
  }

  // 1x1 unpooling would be the identity transform.
  if (pooling_size == 1) {
    xnn_log_error(
      "failed to create Unpooling operator with 1 pooling element: 1x1 unpooling is meaningless");
    goto error;
  }

  if (channels == 0) {
    xnn_log_error(
      "failed to create Unpooling operator with %zu channels: number of channels must be non-zero",
      channels);
    goto error;
  }

  // Pixel strides are in elements and must cover at least one full pixel.
  if (input_pixel_stride < channels) {
    xnn_log_error(
      "failed to create Unpooling operator with input pixel stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      input_pixel_stride, channels);
    goto error;
  }

  if (output_pixel_stride < channels) {
    xnn_log_error(
      "failed to create Unpooling operator with output pixel stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      output_pixel_stride, channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  unpooling_op = xnn_allocate_zero_memory(sizeof(struct xnn_operator));
  if (unpooling_op == NULL) {
    xnn_log_error("failed to allocate %zu bytes for Unpooling operator descriptor", sizeof(struct xnn_operator));
    goto error;
  }

  unpooling_op->padding_top = input_padding_top;
  unpooling_op->padding_right = input_padding_right;
  unpooling_op->padding_bottom = input_padding_bottom;
  unpooling_op->padding_left = input_padding_left;

  unpooling_op->kernel_height = pooling_height;
  unpooling_op->kernel_width = pooling_width;
  unpooling_op->channels = channels;
  unpooling_op->input_pixel_stride = input_pixel_stride;
  unpooling_op->output_pixel_stride = output_pixel_stride;

  unpooling_op->type = xnn_operator_type_unpooling_x32;
  unpooling_op->ukernel.type = xnn_ukernel_type_unpooling;

  // Operator is not runnable until setup assigns shapes and pointers.
  unpooling_op->state = xnn_run_state_invalid;

  *unpooling_op_out = unpooling_op;
  return xnn_status_success;

error:
  xnn_delete_operator(unpooling_op);
  return status;
}
+
// Configures a previously created Unpooling (X32) operator for a concrete
// batch size, input spatial shape, and input/index/output pointers, leaving
// the operator in the "ready" state for execution.
//
// The indirection buffer (a table of output pointers, one per pooling slot
// per input pixel) is rebuilt only for the part of the batch that was not
// already initialized for the same output pointer and input shape
// (valid_batch_size caching).
//
// NOTE(review): the threadpool parameter is accepted but not used here —
// presumably the parallelization fields filled in below are consumed when
// the operator is actually run; confirm.
enum xnn_status xnn_setup_unpooling2d_nhwc_x32(
    xnn_operator_t unpooling_op,
    size_t batch_size,
    size_t input_height,
    size_t input_width,
    const void* input,
    const uint32_t* index,
    void* output,
    pthreadpool_t threadpool)
{
  if (unpooling_op->type != xnn_operator_type_unpooling_x32) {
    xnn_log_error("failed to setup Unpooling (X32) operator: operator type mismatch");
    return xnn_status_invalid_parameter;
  }
  // Invalidate first: any early return below leaves the operator unrunnable.
  unpooling_op->state = xnn_run_state_invalid;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to setup Unpooling operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (input_width == 0 || input_height == 0) {
    xnn_log_error(
      "failed to setup Unpooling operator with %zux%zu input: input dimensions must be non-zero",
      input_width, input_height);
    return xnn_status_invalid_parameter;
  }

  // An empty batch is valid: running the operator becomes a no-op.
  if (batch_size == 0) {
    unpooling_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  unpooling_op->batch_size = batch_size;
  unpooling_op->input_height = input_height;
  unpooling_op->input_width = input_width;
  unpooling_op->input = input;

  // Output shape: kernel * input, shrunk by total padding (saturating).
  unpooling_op->output_height = compute_output_dimension(
      input_height, unpooling_op->padding_top + unpooling_op->padding_bottom,
      unpooling_op->kernel_height);
  unpooling_op->output_width = compute_output_dimension(
      input_width, unpooling_op->padding_left + unpooling_op->padding_right,
      unpooling_op->kernel_width);
  unpooling_op->output = output;

  // Fast path: if the output pointer and input shape match the previous
  // setup, the first valid_batch_size batches of the indirection buffer are
  // still correct and can be reused as-is.
  size_t valid_batch_size = 0;
  if (output == unpooling_op->last_output &&
      input_height == unpooling_op->last_input_height &&
      input_width == unpooling_op->last_input_width)
  {
    valid_batch_size = unpooling_op->valid_batch_size;
    if (batch_size <= valid_batch_size) {
      unpooling_op->compute.range[0] = batch_size * input_height;
      unpooling_op->state = xnn_run_state_ready;
      return xnn_status_success;
    }
  }

  const size_t pooling_height = unpooling_op->kernel_height;
  const size_t pooling_width = unpooling_op->kernel_width;
  const size_t pooling_size = pooling_height * pooling_width;

  // One output pointer per pooling slot per input pixel.
  const size_t indirection_buffer_size = sizeof(void*) * (batch_size * input_height * input_width * pooling_size);

  // realloc result goes to a temporary so the old buffer is not leaked or
  // lost on failure.
  void** indirection_buffer = (void**) realloc(unpooling_op->indirection_buffer, indirection_buffer_size);
  if (indirection_buffer == NULL) {
    xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
    return xnn_status_out_of_memory;
  }
  unpooling_op->indirection_buffer = (const void**) indirection_buffer;

  // Fill only the batches past the still-valid prefix.
  xnn_indirection_init_unpool2d(unpooling_op, valid_batch_size, 2 /* log2(sizeof(type32)) */);

  const size_t channels = unpooling_op->channels;
  // sizeof(float) == 4 is used as the x32 element size here.
  const size_t input_pixel_stride_in_bytes = unpooling_op->input_pixel_stride * sizeof(float);
  unpooling_op->context.unpooling = (struct unpooling_context) {
    .input = input,
    .input_height_stride = input_width * input_pixel_stride_in_bytes,
    .input_width_stride = input_pixel_stride_in_bytes,
    .index = index,
    .index_height_stride = input_width * channels * sizeof(uint32_t),
    .index_width_stride = channels * sizeof(uint32_t),
    .indirect_output = indirection_buffer,
    .indirect_output_height_stride = input_width * pooling_size * sizeof(void*),
    .indirect_output_width_stride = pooling_size * sizeof(void*),
    .pooling_size = pooling_size,
    .channels = channels,
    .fill_value = 0,
    .ukernel = xnn_params.x32.unpool,
  };
  // Parallelize over (batch * input rows) x (input columns).
  unpooling_op->compute.type = xnn_parallelization_type_2d;
  unpooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_unpooling;
  unpooling_op->compute.range[0] = batch_size * input_height;
  unpooling_op->compute.range[1] = input_width;
  unpooling_op->state = xnn_run_state_ready;

  // Remember this configuration for the caching fast path above.
  unpooling_op->last_output = output;
  unpooling_op->last_input_height = input_height;
  unpooling_op->last_input_width = input_width;
  unpooling_op->valid_batch_size = max(valid_batch_size, batch_size);

  return xnn_status_success;
}
diff --git a/src/wasm-stubs.c b/src/wasm-stubs.c
new file mode 100644
index 0000000..78790b5
--- /dev/null
+++ b/src/wasm-stubs.c
@@ -0,0 +1,19 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <stdint.h>
+
+#include <cpuinfo.h>
+#include <fp16.h>
+
// Scalar stand-in for the WAsm f32.sub instruction, operating on raw
// IEEE-754 bit patterns: decode both operands, subtract, re-encode.
uint32_t xnn_stub_wasm_f32_sub(uint32_t a, uint32_t b) {
  const float minuend = fp32_from_bits(a);
  const float subtrahend = fp32_from_bits(b);
  const float difference = minuend - subtrahend;
  return fp32_to_bits(difference);
}
+
#if CPUINFO_ARCH_WASM || CPUINFO_ARCH_WASMSIMD
// Scalar stand-in for the WAsm f32.min instruction on raw IEEE-754 bit
// patterns; compiled only for WAsm targets, where the builtin exists.
uint32_t xnn_stub_wasm_f32_min(uint32_t a, uint32_t b) {
  return fp32_to_bits(__builtin_wasm_min_f32(fp32_from_bits(a), fp32_from_bits(b)));
}
#endif /* CPUINFO_ARCH_WASM || CPUINFO_ARCH_WASMSIMD */
diff --git a/src/x32-packx/x2-scalar.c b/src/x32-packx/x2-scalar.c
new file mode 100644
index 0000000..91c9846
--- /dev/null
+++ b/src/x32-packx/x2-scalar.c
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/packx.h>
+
+
// Packs an (m x k) block of 32-bit elements, m <= 2, into y in k-major
// (interleaved) order: y[2*j] = row0[j], y[2*j+1] = row1[j].
// When m < 2, row 0 is replicated so the packed layout stays 2-wide.
// x_stride is the distance between rows of x, in bytes.
void xnn_x32_packx_ukernel_2x__scalar(
    size_t m,
    size_t k,
    const uint32_t* restrict x,
    size_t x_stride,
    uint32_t* restrict y)
{
  assert(m != 0);
  assert(k != 0);

  const uint32_t* row0 = x;
  // Byte-granular stride, hence the uintptr_t arithmetic.
  const uint32_t* row1 = (const uint32_t*) ((uintptr_t) row0 + x_stride);
  if (m != 2) {
    row1 = row0;
  }

  do {
    const uint32_t v0 = *row0++;
    const uint32_t v1 = *row1++;
    y[0] = v0;
    y[1] = v1;
    y += 2;
  } while (--k != 0);
}
diff --git a/src/x32-packx/x3-scalar.c b/src/x32-packx/x3-scalar.c
new file mode 100644
index 0000000..dc6d8ed
--- /dev/null
+++ b/src/x32-packx/x3-scalar.c
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/packx.h>
+
+
// Packs an (m x k) block of 32-bit elements, m <= 3, into y in k-major
// (interleaved) order: y[3*j + i] = row_i[j]. Rows past m replicate the
// last valid row so the packed layout stays 3-wide.
// x_stride is the distance between rows of x, in bytes.
void xnn_x32_packx_ukernel_3x__scalar(
    size_t m,
    size_t k,
    const uint32_t* restrict x,
    size_t x_stride,
    uint32_t* restrict y)
{
  assert(m != 0);
  assert(k != 0);

  const uint32_t* row0 = x;
  // Byte-granular stride, hence the uintptr_t arithmetic.
  const uint32_t* row1 = (const uint32_t*) ((uintptr_t) row0 + x_stride);
  if (m < 2) {
    row1 = row0;
  }
  const uint32_t* row2 = (const uint32_t*) ((uintptr_t) row1 + x_stride);
  if (m <= 2) {
    row2 = row1;
  }

  do {
    const uint32_t v0 = *row0++;
    const uint32_t v1 = *row1++;
    const uint32_t v2 = *row2++;
    y[0] = v0;
    y[1] = v1;
    y[2] = v2;
    y += 3;
  } while (--k != 0);
}
diff --git a/src/x32-packx/x4-neon-st4.c b/src/x32-packx/x4-neon-st4.c
new file mode 100644
index 0000000..97e5501
--- /dev/null
+++ b/src/x32-packx/x4-neon-st4.c
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/packx.h>
+
+
// Packs an (m x k) block of 32-bit elements, m <= 4, into y in k-major
// (interleaved) order: y[4*j + i] = row_i[j]. Rows past m replicate the
// last valid row so the packed layout stays 4-wide.
// x_stride is the distance between rows of x, in bytes.
void xnn_x32_packx_ukernel_4x__neon_st4(
    size_t m,
    size_t k,
    const uint32_t* restrict x,
    size_t x_stride,
    uint32_t* restrict y)
{
  assert(m != 0);
  assert(k != 0);

  const uint32_t* x0 = x;
  // Byte-granular stride, hence the uintptr_t arithmetic.
  const uint32_t* x1 = (const uint32_t*) ((uintptr_t) x0 + x_stride);
  if (m < 2) {
    x1 = x0;
  }
  const uint32_t* x2 = (const uint32_t*) ((uintptr_t) x1 + x_stride);
  if (m <= 2) {
    x2 = x1;
  }
  const uint32_t* x3 = (const uint32_t*) ((uintptr_t) x2 + x_stride);
  if (m != 4) {
    x3 = x2;
  }

  // Main loop: vst4q interleaves the four registers element-wise on store,
  // producing the 4-wide k-major layout directly.
  for (; k >= 4; k -= 4) {
    const uint32x4_t vx0 = vld1q_u32(x0); x0 += 4;
    const uint32x4_t vx1 = vld1q_u32(x1); x1 += 4;
    const uint32x4_t vx2 = vld1q_u32(x2); x2 += 4;
    const uint32x4_t vx3 = vld1q_u32(x3); x3 += 4;

    const uint32x4x4_t vy = { vx0, vx1, vx2, vx3 };
    vst4q_u32(y, vy); y += 16;
  }
  // Remainder: build one 4-element group per iteration by lane insertion.
  if XNN_UNLIKELY(k != 0) {
    do {
      const uint32x2_t vx00 = vld1_dup_u32(x0); x0 += 1;
      const uint32x2_t vx22 = vld1_dup_u32(x2); x2 += 1;
      const uint32x2_t vx01 = vld1_lane_u32(x1, vx00, 1); x1 += 1;
      const uint32x2_t vx23 = vld1_lane_u32(x3, vx22, 1); x3 += 1;
      const uint32x4_t vy = vcombine_u32(vx01, vx23);
      vst1q_u32(y, vy); y += 4;
    } while (--k != 0);
  }
}
diff --git a/src/x32-packx/x4-psimd.c b/src/x32-packx/x4-psimd.c
new file mode 100644
index 0000000..4d15a7e
--- /dev/null
+++ b/src/x32-packx/x4-psimd.c
@@ -0,0 +1,85 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/packx.h>
+
+
// Packs an (m x k) block of 32-bit elements, m <= 4, into y in k-major
// (interleaved) order: y[4*j + i] = row_i[j]. Rows past m replicate the
// last valid row so the packed layout stays 4-wide.
// x_stride is the distance between rows of x, in bytes.
void xnn_x32_packx_ukernel_4x__psimd(
    size_t m,
    size_t k,
    const uint32_t* restrict x,
    size_t x_stride,
    uint32_t* restrict y)
{
  assert(m != 0);
  assert(k != 0);

  const uint32_t* x0 = x;
  // Byte-granular stride, hence the uintptr_t arithmetic.
  const uint32_t* x1 = (const uint32_t*) ((uintptr_t) x0 + x_stride);
  if (m < 2) {
    x1 = x0;
  }
  const uint32_t* x2 = (const uint32_t*) ((uintptr_t) x1 + x_stride);
  if (m <= 2) {
    x2 = x1;
  }
  const uint32_t* x3 = (const uint32_t*) ((uintptr_t) x2 + x_stride);
  if (m != 4) {
    x3 = x2;
  }

  // Main loop: a 4x4 transpose per iteration, done as two rounds of
  // interleaves followed by low/high concatenations.
  for (; k >= 4; k -= 4) {
    const psimd_u32 vx0 = psimd_load_u32(x0);
    x0 += 4;
    const psimd_u32 vx1 = psimd_load_u32(x1);
    x1 += 4;
    const psimd_u32 vx2 = psimd_load_u32(x2);
    x2 += 4;
    const psimd_u32 vx3 = psimd_load_u32(x3);
    x3 += 4;

    const psimd_u32 vt0 = psimd_interleave_lo_u32(vx0, vx1);
    const psimd_u32 vt1 = psimd_interleave_hi_u32(vx0, vx1);
    const psimd_u32 vt2 = psimd_interleave_lo_u32(vx2, vx3);
    const psimd_u32 vt3 = psimd_interleave_hi_u32(vx2, vx3);

    const psimd_u32 vy0 = psimd_concat_lo_u32(vt0, vt2);
    psimd_store_u32(y, vy0);

    const psimd_u32 vy1 = psimd_concat_hi_u32(vt0, vt2);
    psimd_store_u32(y + 4, vy1);

    const psimd_u32 vy2 = psimd_concat_lo_u32(vt1, vt3);
    psimd_store_u32(y + 8, vy2);

    const psimd_u32 vy3 = psimd_concat_hi_u32(vt1, vt3);
    psimd_store_u32(y + 12, vy3);

    y += 16;
  }
  // Remainder: assemble one 4-element group per iteration from single-lane
  // loads of each row.
  if XNN_UNLIKELY(k != 0) {
    do {
      const psimd_u32 vx0 = psimd_load1_u32(x0);
      x0 += 1;
      const psimd_u32 vx1 = psimd_load1_u32(x1);
      x1 += 1;
      const psimd_u32 vx2 = psimd_load1_u32(x2);
      x2 += 1;
      const psimd_u32 vx3 = psimd_load1_u32(x3);
      x3 += 1;
      const psimd_u32 vx01 = psimd_interleave_lo_u32(vx0, vx1);
      const psimd_u32 vx23 = psimd_interleave_lo_u32(vx2, vx3);
      const psimd_u32 vy = psimd_concat_lo_u32(vx01, vx23);
      psimd_store_u32(y, vy);
      y += 4;
    } while (--k != 0);
  }
}
diff --git a/src/x32-packx/x4-scalar.c b/src/x32-packx/x4-scalar.c
new file mode 100644
index 0000000..17f7bb7
--- /dev/null
+++ b/src/x32-packx/x4-scalar.c
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/packx.h>
+
+
// Packs an (m x k) block of 32-bit elements, m <= 4, into y in k-major
// (interleaved) order: y[4*j + i] = row_i[j]. Rows past m replicate the
// last valid row so the packed layout stays 4-wide.
// x_stride is the distance between rows of x, in bytes.
void xnn_x32_packx_ukernel_4x__scalar(
    size_t m,
    size_t k,
    const uint32_t* restrict x,
    size_t x_stride,
    uint32_t* restrict y)
{
  assert(m != 0);
  assert(k != 0);

  const uint32_t* row0 = x;
  // Byte-granular stride, hence the uintptr_t arithmetic.
  const uint32_t* row1 = (const uint32_t*) ((uintptr_t) row0 + x_stride);
  if (m < 2) {
    row1 = row0;
  }
  const uint32_t* row2 = (const uint32_t*) ((uintptr_t) row1 + x_stride);
  if (m <= 2) {
    row2 = row1;
  }
  const uint32_t* row3 = (const uint32_t*) ((uintptr_t) row2 + x_stride);
  if (m != 4) {
    row3 = row2;
  }

  do {
    const uint32_t v0 = *row0++;
    const uint32_t v1 = *row1++;
    const uint32_t v2 = *row2++;
    const uint32_t v3 = *row3++;
    y[0] = v0;
    y[1] = v1;
    y[2] = v2;
    y[3] = v3;
    y += 4;
  } while (--k != 0);
}
diff --git a/src/x32-packx/x4-sse.c b/src/x32-packx/x4-sse.c
new file mode 100644
index 0000000..1ff64d7
--- /dev/null
+++ b/src/x32-packx/x4-sse.c
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xmmintrin.h>
+
+#include <xnnpack/packx.h>
+
+
// Packs an (m x k) block of 32-bit elements, m <= 4, into y in k-major
// (interleaved) order: y[4*j + i] = row_i[j]. Rows past m replicate the
// last valid row so the packed layout stays 4-wide.
// x_stride is the distance between rows of x, in bytes.
//
// NOTE(review): _mm_store_ps requires y to be 16-byte aligned — presumably
// guaranteed by XNNPACK's packed-buffer allocator; confirm.
void xnn_x32_packx_ukernel_4x__sse(
    size_t m,
    size_t k,
    const uint32_t* restrict x,
    size_t x_stride,
    uint32_t* restrict y)
{
  assert(m != 0);
  assert(k != 0);

  // The data is reinterpreted as float purely for the SSE register moves;
  // values are copied bit-for-bit.
  const float* x0 = (const float*) x;
  const float* x1 = (const float*) ((uintptr_t) x0 + x_stride);
  if (m < 2) {
    x1 = x0;
  }
  const float* x2 = (const float*) ((uintptr_t) x1 + x_stride);
  if (m <= 2) {
    x2 = x1;
  }
  const float* x3 = (const float*) ((uintptr_t) x2 + x_stride);
  if (m != 4) {
    x3 = x2;
  }

  float*restrict y_f32 = (float*) y;

  // Main loop: a 4x4 transpose per iteration via unpacklo/unpackhi and
  // movelh/movehl shuffles.
  for (; k >= 4; k -= 4) {
    const __m128 vx0 = _mm_loadu_ps(x0);
    x0 += 4;
    const __m128 vx1 = _mm_loadu_ps(x1);
    x1 += 4;
    const __m128 vx2 = _mm_loadu_ps(x2);
    x2 += 4;
    const __m128 vx3 = _mm_loadu_ps(x3);
    x3 += 4;

    const __m128 vt0 = _mm_unpacklo_ps(vx0, vx1);
    const __m128 vt1 = _mm_unpackhi_ps(vx0, vx1);
    const __m128 vt2 = _mm_unpacklo_ps(vx2, vx3);
    const __m128 vt3 = _mm_unpackhi_ps(vx2, vx3);

    const __m128 vy0 = _mm_movelh_ps(vt0, vt2);
    _mm_store_ps(y_f32, vy0);

    const __m128 vy1 = _mm_movehl_ps(vt2, vt0);
    _mm_store_ps(y_f32 + 4, vy1);

    const __m128 vy2 = _mm_movelh_ps(vt1, vt3);
    _mm_store_ps(y_f32 + 8, vy2);

    const __m128 vy3 = _mm_movehl_ps(vt3, vt1);
    _mm_store_ps(y_f32 + 12, vy3);

    y_f32 += 16;
  }
  // Remainder: assemble one 4-element group per iteration from single-lane
  // loads of each row.
  if XNN_UNLIKELY(k != 0) {
    do {
      const __m128 vx0 = _mm_load_ss(x0);
      x0 += 1;
      const __m128 vx1 = _mm_load_ss(x1);
      x1 += 1;
      const __m128 vx2 = _mm_load_ss(x2);
      x2 += 1;
      const __m128 vx3 = _mm_load_ss(x3);
      x3 += 1;

      const __m128 vx01 = _mm_unpacklo_ps(vx0, vx1);
      const __m128 vx23 = _mm_unpacklo_ps(vx2, vx3);
      const __m128 vy = _mm_movelh_ps(vx01, vx23);

      _mm_store_ps(y_f32, vy);
      y_f32 += 4;
    } while (--k != 0);
  }
}
diff --git a/src/x32-pad/x2-neon.c b/src/x32-pad/x2-neon.c
new file mode 100644
index 0000000..f1d9da0
--- /dev/null
+++ b/src/x32-pad/x2-neon.c
@@ -0,0 +1,93 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/pad.h>
+
+
// Pads up to 2 rows of 32-bit elements: for each row, writes l bytes of the
// constant c, copies n bytes from x, then writes r bytes of c.
// l, n, and r are byte counts (multiples of 4); x_stride / y_stride are the
// byte distances between the two rows. When m < 2 both row pointers alias
// row 0, so the second row's work is a redundant re-write of the first.
void xnn_x32_pad_x2__neon(
    size_t m,
    size_t n,
    size_t l,
    size_t r,
    uint32_t c,
    const void* x,
    size_t x_stride,
    void* y,
    size_t y_stride)
{
  assert(m <= 2);
  assert(l % 4 == 0);
  assert(n % 4 == 0);
  assert(r % 4 == 0);

  const uint32_t* x0 = x;
  uint32_t* y0 = y;

  const uint32_t* x1 = (const uint32_t*) ((uintptr_t) x0 + x_stride);
  uint32_t* y1 = (uint32_t*) ((uintptr_t) y0 + y_stride);
  if (m != 2) {
    x1 = x0;
    y1 = y0;
  }
  const uint32x4_t vc = vmovq_n_u32(c);

  /* Pre-pad input channels */
  for (; l >= 16; l -= 16) {
    vst1q_u32(y0, vc); y0 += 4;
    vst1q_u32(y1, vc); y1 += 4;
  }
  if (l & 8) {
    vst1_u32(y0, vget_low_u32(vc)); y0 += 2;
    vst1_u32(y1, vget_low_u32(vc)); y1 += 2;
  }
  if (l & 4) {
    vst1q_lane_u32(y0, vc, 0); y0 += 1;
    vst1q_lane_u32(y1, vc, 0); y1 += 1;
  }

  /* Copy input channels */
  for (; n >= 16; n -= 16) {
    const uint32x4_t vt0 = vld1q_u32(x0); x0 += 4;
    const uint32x4_t vt1 = vld1q_u32(x1); x1 += 4;
    vst1q_u32(y0, vt0); y0 += 4;
    vst1q_u32(y1, vt1); y1 += 4;
  }
  if (n != 0) {
    // NOTE(review): this tail loads a full 16 bytes even though only
    // n (< 16) bytes remain — presumably the caller guarantees readable
    // memory past the end of each input row; confirm. Stores below are
    // bounded by n, so the output is not over-written.
    const uint32x4_t vt0 = vld1q_u32(x0); x0 += 4;
    const uint32x4_t vt1 = vld1q_u32(x1); x1 += 4;
    uint32x2_t vt0lo = vget_low_u32(vt0);
    uint32x2_t vt1lo = vget_low_u32(vt1);
    if (n & 8) {
      vst1_u32(y0, vt0lo); y0 += 2;
      vst1_u32(y1, vt1lo); y1 += 2;
      vt0lo = vget_high_u32(vt0);
      vt1lo = vget_high_u32(vt1);
    }
    if (n & 4) {
      vst1_lane_u32(y0, vt0lo, 0); y0 += 1;
      vst1_lane_u32(y1, vt1lo, 0); y1 += 1;
    }
  }

  /* Post-pad input channels */
  for (; r >= 16; r -= 16) {
    vst1q_u32(y0, vc); y0 += 4;
    vst1q_u32(y1, vc); y1 += 4;
  }
  if (r & 8) {
    vst1_u32(y0, vget_low_u32(vc)); y0 += 2;
    vst1_u32(y1, vget_low_u32(vc)); y1 += 2;
  }
  if (r & 4) {
    vst1q_lane_u32(y0, vc, 0);
    vst1q_lane_u32(y1, vc, 0);
  }
}
diff --git a/src/x32-pad/x2-psimd.c b/src/x32-pad/x2-psimd.c
new file mode 100644
index 0000000..78b471a
--- /dev/null
+++ b/src/x32-pad/x2-psimd.c
@@ -0,0 +1,91 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/pad.h>
+
+
// Pads up to 2 rows of 32-bit elements: for each row, writes l bytes of the
// constant c, copies n bytes from x, then writes r bytes of c.
// l, n, and r are byte counts (multiples of 4); x_stride / y_stride are the
// byte distances between the two rows. When m < 2 both row pointers alias
// row 0, so the second row's work is a redundant re-write of the first.
void xnn_x32_pad_x2__psimd(
    size_t m,
    size_t n,
    size_t l,
    size_t r,
    uint32_t c,
    const void* x,
    size_t x_stride,
    void* y,
    size_t y_stride)
{
  assert(m <= 2);
  assert(l % 4 == 0);
  assert(n % 4 == 0);
  assert(r % 4 == 0);

  const uint32_t* x0 = x;
  uint32_t* y0 = y;

  const uint32_t* x1 = (const uint32_t*) ((uintptr_t) x0 + x_stride);
  uint32_t* y1 = (uint32_t*) ((uintptr_t) y0 + y_stride);
  if (m != 2) {
    x1 = x0;
    y1 = y0;
  }
  const psimd_u32 vc = psimd_splat_u32(c);

  /* Pre-pad input channels */
  for (; l >= 16; l -= 16) {
    psimd_store_u32(y0, vc); y0 += 4;
    psimd_store_u32(y1, vc); y1 += 4;
  }
  if (l & 8) {
    psimd_store2_u32(y0, vc); y0 += 2;
    psimd_store2_u32(y1, vc); y1 += 2;
  }
  if (l & 4) {
    psimd_store1_u32(y0, vc); y0 += 1;
    psimd_store1_u32(y1, vc); y1 += 1;
  }

  /* Copy input channels */
  for (; n >= 16; n -= 16) {
    const psimd_u32 vt0 = psimd_load_u32(x0); x0 += 4;
    const psimd_u32 vt1 = psimd_load_u32(x1); x1 += 4;
    psimd_store_u32(y0, vt0); y0 += 4;
    psimd_store_u32(y1, vt1); y1 += 4;
  }
  if (n != 0) {
    // NOTE(review): this tail loads a full vector even though only
    // n (< 16) bytes remain — presumably the caller guarantees readable
    // memory past the end of each input row; confirm. Stores below are
    // bounded by n, so the output is not over-written.
    psimd_u32 vt0 = psimd_load_u32(x0);
    psimd_u32 vt1 = psimd_load_u32(x1);
    if (n & 8) {
      psimd_store2_u32(y0, vt0); y0 += 2;
      psimd_store2_u32(y1, vt1); y1 += 2;
      // Shift the upper half down for the possible final single element.
      vt0 = psimd_concat_hi_u32(vt0, vt0);
      vt1 = psimd_concat_hi_u32(vt1, vt1);
    }
    if (n & 4) {
      psimd_store1_u32(y0, vt0); y0 += 1;
      psimd_store1_u32(y1, vt1); y1 += 1;
    }
  }

  /* Post-pad input channels */
  for (; r >= 16; r -= 16) {
    psimd_store_u32(y0, vc); y0 += 4;
    psimd_store_u32(y1, vc); y1 += 4;
  }
  if (r & 8) {
    psimd_store2_u32(y0, vc); y0 += 2;
    psimd_store2_u32(y1, vc); y1 += 2;
  }
  if (r & 4) {
    psimd_store1_u32(y0, vc);
    psimd_store1_u32(y1, vc);
  }
}
diff --git a/src/x32-pad/x2-scalar.c b/src/x32-pad/x2-scalar.c
new file mode 100644
index 0000000..bb17cc7
--- /dev/null
+++ b/src/x32-pad/x2-scalar.c
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/pad.h>
+
+
// Pads up to 2 rows of 32-bit elements: for each row, writes l bytes of the
// constant c, copies n bytes from x, then writes r bytes of c.
// l, n, and r are byte counts (multiples of 4); x_stride / y_stride are the
// byte distances between the two rows. When m < 2 both row pointers alias
// row 0, so the second row's work is a redundant re-write of the first.
void xnn_x32_pad_x2__scalar(
    size_t m,
    size_t n,
    size_t l,
    size_t r,
    uint32_t c,
    const void* x,
    size_t x_stride,
    void* y,
    size_t y_stride)
{
  assert(m <= 2);
  assert(l % 4 == 0);
  assert(n % 4 == 0);
  assert(r % 4 == 0);

  const uint32_t* src0 = x;
  uint32_t* dst0 = y;

  // Strides are byte-granular, hence the uintptr_t arithmetic.
  const uint32_t* src1 = (const uint32_t*) ((uintptr_t) src0 + x_stride);
  uint32_t* dst1 = (uint32_t*) ((uintptr_t) dst0 + y_stride);
  if (m != 2) {
    src1 = src0;
    dst1 = dst0;
  }

  // Left padding: l / 4 copies of the constant per row.
  while (l != 0) {
    *dst0++ = c;
    *dst1++ = c;
    l -= 4;
  }

  // Payload: copy n / 4 elements per row.
  while (n != 0) {
    *dst0++ = *src0++;
    *dst1++ = *src1++;
    n -= 4;
  }

  // Right padding: r / 4 copies of the constant per row.
  while (r != 0) {
    *dst0++ = c;
    *dst1++ = c;
    r -= 4;
  }
}
diff --git a/src/x32-pad/x2-sse2.c b/src/x32-pad/x2-sse2.c
new file mode 100644
index 0000000..49f7b3f
--- /dev/null
+++ b/src/x32-pad/x2-sse2.c
@@ -0,0 +1,91 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/pad.h>
+
+
// Pads up to 2 rows of 32-bit elements: for each row, writes l bytes of the
// constant c, copies n bytes from x, then writes r bytes of c.
// l, n, and r are byte counts (multiples of 4); x_stride / y_stride are the
// byte distances between the two rows. When m < 2 both row pointers alias
// row 0, so the second row's work is a redundant re-write of the first.
void xnn_x32_pad_x2__sse2(
    size_t m,
    size_t n,
    size_t l,
    size_t r,
    uint32_t c,
    const void* x,
    size_t x_stride,
    void* y,
    size_t y_stride)
{
  assert(m <= 2);
  assert(l % 4 == 0);
  assert(n % 4 == 0);
  assert(r % 4 == 0);

  const uint32_t* x0 = x;
  uint32_t* y0 = y;

  const uint32_t* x1 = (const uint32_t*) ((uintptr_t) x0 + x_stride);
  uint32_t* y1 = (uint32_t*) ((uintptr_t) y0 + y_stride);
  if (m != 2) {
    x1 = x0;
    y1 = y0;
  }
  const __m128i vc = _mm_set1_epi32((int) c);

  /* Pre-pad input channels */
  for (; l >= 16; l -= 16) {
    _mm_storeu_si128((__m128i*) y0, vc); y0 += 4;
    _mm_storeu_si128((__m128i*) y1, vc); y1 += 4;
  }
  if (l & 8) {
    _mm_storel_epi64((__m128i*) y0, vc); y0 += 2;
    _mm_storel_epi64((__m128i*) y1, vc); y1 += 2;
  }
  if (l & 4) {
    *((uint32_t*) y0) = (uint32_t) _mm_cvtsi128_si32(vc); y0 += 1;
    *((uint32_t*) y1) = (uint32_t) _mm_cvtsi128_si32(vc); y1 += 1;
  }

  /* Copy input channels */
  for (; n >= 16; n -= 16) {
    const __m128i vt0 = _mm_loadu_si128((const __m128i*) x0); x0 += 4;
    const __m128i vt1 = _mm_loadu_si128((const __m128i*) x1); x1 += 4;
    _mm_storeu_si128((__m128i*) y0, vt0); y0 += 4;
    _mm_storeu_si128((__m128i*) y1, vt1); y1 += 4;
  }
  if (n != 0) {
    // NOTE(review): this tail loads a full 16 bytes even though only
    // n (< 16) bytes remain — presumably the caller guarantees readable
    // memory past the end of each input row; confirm. Stores below are
    // bounded by n, so the output is not over-written.
    __m128i vt0 = _mm_loadu_si128((const __m128i*) x0);
    __m128i vt1 = _mm_loadu_si128((const __m128i*) x1);
    if (n & 8) {
      _mm_storel_epi64((__m128i*) y0, vt0); y0 += 2;
      _mm_storel_epi64((__m128i*) y1, vt1); y1 += 2;
      // Shift the upper half down for the possible final single element.
      vt0 = _mm_unpackhi_epi64(vt0, vt0);
      vt1 = _mm_unpackhi_epi64(vt1, vt1);
    }
    if (n & 4) {
      *((uint32_t*) y0) = (uint32_t) _mm_cvtsi128_si32(vt0); y0 += 1;
      *((uint32_t*) y1) = (uint32_t) _mm_cvtsi128_si32(vt1); y1 += 1;
    }
  }

  /* Post-pad input channels */
  for (; r >= 16; r -= 16) {
    _mm_storeu_si128((__m128i*) y0, vc); y0 += 4;
    _mm_storeu_si128((__m128i*) y1, vc); y1 += 4;
  }
  if (r & 8) {
    _mm_storel_epi64((__m128i*) y0, vc); y0 += 2;
    _mm_storel_epi64((__m128i*) y1, vc); y1 += 2;
  }
  if (r & 4) {
    *((uint32_t*) y0) = (uint32_t) _mm_cvtsi128_si32(vc);
    *((uint32_t*) y1) = (uint32_t) _mm_cvtsi128_si32(vc);
  }
}
diff --git a/src/x32-unpool/psimd.c b/src/x32-unpool/psimd.c
new file mode 100644
index 0000000..3501997
--- /dev/null
+++ b/src/x32-unpool/psimd.c
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/pad.h>
+
+
// Max-unpooling micro-kernel: first fills each of the p output rows
// (c 32-bit elements each) with the constant f, then scatters input element
// j to channel j of the output row selected by index[j].
// p and c must be non-zero (do/while loops).
void xnn_x32_unpool_ukernel__psimd(
    size_t p,
    size_t c,
    uint32_t f,
    const uint32_t* input,
    const uint32_t* index,
    uint32_t** output)
{
  /* Pre-initialize outputs with constant */
  const psimd_u32 vf = psimd_splat_u32(f);
  uint32_t** os = output;
  do {
    uint32_t* o = *os++;
    size_t k = c;
    // Vector fill, 4 elements at a time, then a 2/1-element tail.
    for (; k >= 4; k -= 4) {
      psimd_store_u32(o, vf);
      o += 4;
    }
    if (k != 0) {
      if (k & 2) {
        psimd_store2_u32(o, vf);
        o += 2;
      }
      if (k & 1) {
        psimd_store1_u32(o, vf);
      }
    }
  } while (--p != 0);

  /* Copy indexed elements to output */
  // index[j] selects the pooling slot (output row); offset tracks the
  // channel position in bytes.
  size_t offset = 0;
  do {
    const uint32_t i = *index++;
    *((uint32_t*) ((uintptr_t) output[i] + offset)) = *input++;
    offset += sizeof(uint32_t);
  } while (--c != 0);
}
diff --git a/src/x32-unpool/scalar.c b/src/x32-unpool/scalar.c
new file mode 100644
index 0000000..dd6abab
--- /dev/null
+++ b/src/x32-unpool/scalar.c
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/pad.h>
+
+
// Max-unpooling micro-kernel: first fills each of the p output rows
// (c 32-bit elements each) with the constant f, then scatters input element
// j to channel j of the output row selected by index[j].
// p and c must be non-zero (do/while loops).
void xnn_x32_unpool_ukernel__scalar(
    size_t p,
    size_t c,
    uint32_t f,
    const uint32_t* input,
    const uint32_t* index,
    uint32_t** output)
{
  // Fill every output row with the constant.
  uint32_t** row = output;
  do {
    uint32_t* dst = *row++;
    size_t remaining = c;
    do {
      *dst++ = f;
    } while (--remaining != 0);
  } while (--p != 0);

  // Scatter: input element j lands in channel j of output[index[j]].
  size_t channel = 0;
  do {
    output[*index++][channel] = *input++;
    channel += 1;
  } while (--c != 0);
}
diff --git a/src/x32-zip/x2-neon.c b/src/x32-zip/x2-neon.c
new file mode 100644
index 0000000..c88695b
--- /dev/null
+++ b/src/x32-zip/x2-neon.c
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/zip.h>
+
+
// Interleaves two contiguous groups of 32-bit elements: input holds group X
// followed by group Y (each n bytes); output receives x0 y0 x1 y1 ...
// n is a byte count, non-zero and a multiple of 4.
void xnn_x32_zip_x2_ukernel__neon(
    size_t n,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);

  const uint32_t* x = input;
  // The second group starts n bytes past the first.
  const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n);
  uint32_t* o = output;

  // Main loop: VST2 interleaves the two registers on store, 4 pairs at a time.
  while (n >= 16) {
    uint32x4x2_t vxy;
    vxy.val[0] = vld1q_u32(x); x += 4;
    vxy.val[1] = vld1q_u32(y); y += 4;
    vst2q_u32(o, vxy); o += 8;
    n -= 16;
  }
  if XNN_UNLIKELY(n != 0) {
    // 2-pair tail.
    if (n & 8) {
      uint32x2x2_t vxy;
      vxy.val[0] = vld1_u32(x); x += 2;
      vxy.val[1] = vld1_u32(y); y += 2;
      vst2_u32(o, vxy); o += 4;
    }
    // Final single pair, assembled by lane insertion.
    if (n & 4) {
      uint32x2_t vxy = vld1_dup_u32(x);
      vxy = vld1_lane_u32(y, vxy, 1);
      vst1_u32(o, vxy);
    }
  }
}
diff --git a/src/x32-zip/x2-psimd.c b/src/x32-zip/x2-psimd.c
new file mode 100644
index 0000000..4723f2b
--- /dev/null
+++ b/src/x32-zip/x2-psimd.c
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/zip.h>
+
+
// Interleaves two contiguous groups of 32-bit elements: input holds group X
// followed by group Y (each n bytes); output receives x0 y0 x1 y1 ...
// n is a byte count, non-zero and a multiple of 4.
void xnn_x32_zip_x2_ukernel__psimd(
    size_t n,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);

  const uint32_t* x = input;
  // The second group starts n bytes past the first.
  const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n);
  uint32_t* o = output;

  // Main loop: interleave 4 pairs at a time via lo/hi interleaves.
  while (n >= 16) {
    const psimd_u32 vx = psimd_load_u32(x);
    x += 4;
    const psimd_u32 vy = psimd_load_u32(y);
    y += 4;
    const psimd_u32 vxy_lo = psimd_interleave_lo_u32(vx, vy);
    const psimd_u32 vxy_hi = psimd_interleave_hi_u32(vx, vy);
    psimd_store_u32(o, vxy_lo);
    psimd_store_u32(o + 4, vxy_hi);
    o += 8;
    n -= 16;
  }
  if XNN_UNLIKELY(n != 0) {
    // 2-pair tail.
    if (n & 8) {
      const psimd_u32 vx = psimd_load2_u32(x);
      x += 2;
      const psimd_u32 vy = psimd_load2_u32(y);
      y += 2;
      const psimd_u32 vxy = psimd_interleave_lo_u32(vx, vy);
      psimd_store_u32((psimd_u32*) o, vxy);
      o += 4;
    }
    // Final single pair, written element-wise.
    if (n & 4) {
      const uint32_t vx = *x;
      const uint32_t vy = *y;
      o[0] = vx;
      o[1] = vy;
    }
  }
}
diff --git a/src/x32-zip/x2-scalar.c b/src/x32-zip/x2-scalar.c
new file mode 100644
index 0000000..06e58a3
--- /dev/null
+++ b/src/x32-zip/x2-scalar.c
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/zip.h>
+
+
// Scalar variant of the x32 x2 "zip" micro-kernel: interleaves two
// back-to-back streams of 32-bit values into x0 y0 x1 y1 ... order.
//
// n is the size of EACH input stream in BYTES (non-zero, multiple of 4);
// input holds the two streams contiguously, output receives 2*n bytes.
void xnn_x32_zip_x2_ukernel__scalar(
    size_t n,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);

  // n is a byte count; each stream holds n / 4 32-bit elements.
  const size_t elements = n / 4;
  const uint32_t* first = input;
  // The second stream starts n bytes past the first.
  const uint32_t* second = (const uint32_t*) ((uintptr_t) input + n);

  for (size_t i = 0; i < elements; i++) {
    *output++ = first[i];
    *output++ = second[i];
  }
}
diff --git a/src/x32-zip/x2-sse2.c b/src/x32-zip/x2-sse2.c
new file mode 100644
index 0000000..c61575d
--- /dev/null
+++ b/src/x32-zip/x2-sse2.c
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/zip.h>
+
+
// SSE2 variant of the x32 x2 "zip" micro-kernel: interleaves two
// back-to-back streams of 32-bit values into x0 y0 x1 y1 ... order.
//
// n is the size of EACH input stream in BYTES (non-zero, multiple of 4);
// output receives 2*n bytes of interleaved elements.
void xnn_x32_zip_x2_ukernel__sse2(
    size_t n,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);

  const uint32_t* x = input;
  // Stream y starts n bytes past the start of stream x.
  const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n);
  uint32_t* o = output;

  // Main loop: 4 elements from each stream (16 bytes) per iteration.
  while (n >= 16) {
    const __m128i vx = _mm_loadu_si128((const __m128i*) x);
    x += 4;
    const __m128i vy = _mm_loadu_si128((const __m128i*) y);
    y += 4;
    // unpacklo/hi interleave 32-bit lanes: (y1, x1, y0, x0), (y3, x3, y2, x2).
    const __m128i vxy_lo = _mm_unpacklo_epi32(vx, vy);
    const __m128i vxy_hi = _mm_unpackhi_epi32(vx, vy);
    _mm_storeu_si128((__m128i*) o, vxy_lo);
    _mm_storeu_si128((__m128i*) (o + 4), vxy_hi);
    o += 8;
    n -= 16;
  }
  // Remainder: 1-3 elements left in each stream.
  if XNN_UNLIKELY(n != 0) {
    if (n & 8) {  // two leftover elements per stream (64-bit loads)
      const __m128i vx = _mm_loadl_epi64((const __m128i*) x);
      x += 2;
      const __m128i vy = _mm_loadl_epi64((const __m128i*) y);
      y += 2;
      const __m128i vxy = _mm_unpacklo_epi32(vx, vy);
      _mm_storeu_si128((__m128i*) o, vxy);
      o += 4;
    }
    if (n & 4) {  // one leftover element per stream
      const uint32_t vx = *x;
      const uint32_t vy = *y;
      o[0] = vx;
      o[1] = vy;
    }
  }
}
diff --git a/src/x32-zip/x3-neon.c b/src/x32-zip/x3-neon.c
new file mode 100644
index 0000000..1522a2d
--- /dev/null
+++ b/src/x32-zip/x3-neon.c
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/zip.h>
+
+
// NEON variant of the x32 x3 "zip" micro-kernel: interleaves three
// back-to-back streams of 32-bit values into x0 y0 z0 x1 ... order.
//
// n is the size of EACH input stream in BYTES (non-zero, multiple of 4);
// output receives 3*n bytes of interleaved elements.
void xnn_x32_zip_x3_ukernel__neon(
    size_t n,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);

  // Streams are laid out back-to-back, n bytes apart.
  const uint32_t* x = input;
  const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n);
  const uint32_t* z = (const uint32_t*) ((uintptr_t) y + n);
  uint32_t* o = output;

  // Main loop: 4 elements per stream; vst3q interleaves on store.
  while (n >= 16) {
    uint32x4x3_t vxyz;
    vxyz.val[0] = vld1q_u32(x); x += 4;
    vxyz.val[1] = vld1q_u32(y); y += 4;
    vxyz.val[2] = vld1q_u32(z); z += 4;
    vst3q_u32(o, vxyz); o += 12;
    n -= 16;
  }
  // Remainder: 1-3 elements left in each stream.
  if XNN_UNLIKELY(n != 0) {
    if (n & 8) {  // two leftover elements per stream
      uint32x2x3_t vxyz;
      vxyz.val[0] = vld1_u32(x); x += 2;
      vxyz.val[1] = vld1_u32(y); y += 2;
      vxyz.val[2] = vld1_u32(z); z += 2;
      vst3_u32(o, vxyz); o += 6;
    }
    if (n & 4) {
      // One element per stream: pack x into lane 0 and y into lane 1,
      // store them as a pair, then store the single z element.
      uint32x2_t vxy = vld1_dup_u32(x);
      const uint32x2_t vz = vld1_dup_u32(z);
      vxy = vld1_lane_u32(y, vxy, 1);
      vst1_u32(o, vxy); o += 2;
      vst1_lane_u32(o, vz, 0);
    }
  }
}
diff --git a/src/x32-zip/x3-psimd.c b/src/x32-zip/x3-psimd.c
new file mode 100644
index 0000000..dea008c
--- /dev/null
+++ b/src/x32-zip/x3-psimd.c
@@ -0,0 +1,71 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/zip.h>
+
+
// PSIMD variant of the x32 x3 "zip" micro-kernel: interleaves three
// back-to-back streams of 32-bit values into x0 y0 z0 x1 ... order.
//
// n is the size of EACH input stream in BYTES (non-zero, multiple of 4).
// The vectorized loop builds the 3-way interleave out of even/odd
// concatenations and shuffles; lane layouts are noted inline
// (leftmost = highest lane).
void xnn_x32_zip_x3_ukernel__psimd(
    size_t n,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);

  // Streams are laid out back-to-back, n bytes apart.
  const uint32_t* x = (const uint32_t*) input;
  const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n);
  const uint32_t* z = (const uint32_t*) ((uintptr_t) y + n);
  uint32_t* o = (uint32_t*) output;

  // Main loop: 4 elements per stream, 12 outputs per iteration.
  while (n >= 16) {
    /* vx = ( x3, x2, x1, x0 ) */
    const psimd_u32 vx = psimd_load_u32(x);
    x += 4;
    /* vy = ( y3, y2, y1, y0 ) */
    const psimd_u32 vy = psimd_load_u32(y);
    y += 4;
    /* vz = ( z3, z2, z1, z0 ) */
    const psimd_u32 vz = psimd_load_u32(z);
    z += 4;

    /* vxy = ( y2, y0, x2, x0 ) */
    const psimd_u32 vxy = psimd_concat_even_u32(vx, vy);
    /* vyz = ( z3, z1, y3, y1 ) */
    const psimd_u32 vyz = psimd_concat_odd_u32(vy, vz);
    /* vzx = ( x3, x1, z2, z0 ) */
    const psimd_u32 vzx = __builtin_shufflevector(vz, vx, 0, 2, 4+1, 4+3);

    /* vxyz0 = ( x1, z0, y0, x0 ) */
    const psimd_u32 vxyz0 = psimd_concat_even_u32(vxy, vzx);
    /* vxyz1 = ( y2, x2, z1, y1 ) */
    const psimd_u32 vxyz1 = __builtin_shufflevector(vyz, vxy, 0, 2, 4+1, 4+3);
    /* vxyz2 = ( z3, y3, x3, z2 ) */
    const psimd_u32 vxyz2 = psimd_concat_odd_u32(vzx, vyz);

    psimd_store_u32(o, vxyz0);
    psimd_store_u32(o + 4, vxyz1);
    psimd_store_u32(o + 8, vxyz2);
    o += 12;
    n -= 16;
  }
  // Remainder: scalar interleave, one element per stream per iteration.
  if XNN_UNLIKELY(n != 0) {
    do {
      const uint32_t vx = *x++;
      const uint32_t vy = *y++;
      const uint32_t vz = *z++;
      o[0] = vx;
      o[1] = vy;
      o[2] = vz;
      o += 3;
      n -= 4;
    } while (n != 0);
  }
}
diff --git a/src/x32-zip/x3-scalar.c b/src/x32-zip/x3-scalar.c
new file mode 100644
index 0000000..bb25c3b
--- /dev/null
+++ b/src/x32-zip/x3-scalar.c
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/zip.h>
+
+
// Scalar variant of the x32 x3 "zip" micro-kernel: interleaves three
// back-to-back streams of 32-bit values into x0 y0 z0 x1 ... order.
//
// n is the size of EACH input stream in BYTES (non-zero, multiple of 4);
// output receives 3*n bytes of interleaved elements.
void xnn_x32_zip_x3_ukernel__scalar(
    size_t n,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);

  // n is a byte count; each stream holds n / 4 32-bit elements.
  const size_t elements = n / 4;
  const uint32_t* first = input;
  const uint32_t* second = (const uint32_t*) ((uintptr_t) first + n);
  const uint32_t* third = (const uint32_t*) ((uintptr_t) second + n);

  for (size_t i = 0; i < elements; i++) {
    output[0] = first[i];
    output[1] = second[i];
    output[2] = third[i];
    output += 3;
  }
}
diff --git a/src/x32-zip/x3-sse2.c b/src/x32-zip/x3-sse2.c
new file mode 100644
index 0000000..49925ae
--- /dev/null
+++ b/src/x32-zip/x3-sse2.c
@@ -0,0 +1,91 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/zip.h>
+
+
// SSE2 variant of the x32 x3 "zip" micro-kernel: interleaves three
// back-to-back streams of 32-bit values into x0 y0 z0 x1 ... order.
//
// n is the size of EACH input stream in BYTES (non-zero, multiple of 4).
// NOTE(review): the 32-bit data is moved through float (__m128)
// registers purely to use the SSE shuffle/move instructions; only lane
// contents are copied. Lane layouts are noted inline (leftmost =
// highest lane).
void xnn_x32_zip_x3_ukernel__sse2(
    size_t n,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);

  // Streams are laid out back-to-back, n bytes apart.
  const float* x = (const float*) input;
  const float* y = (const float*) ((uintptr_t) x + n);
  const float* z = (const float*) ((uintptr_t) y + n);
  float* o = (float*) output;

  // Main loop: 4 elements per stream, 12 outputs per iteration.
  while (n >= 16) {
    /* vx = ( x3, x2, x1, x0 ) */
    const __m128 vx = _mm_loadu_ps(x);
    x += 4;
    /* vy = ( y3, y2, y1, y0 ) */
    const __m128 vy = _mm_loadu_ps(y);
    y += 4;
    /* vz = ( z3, z2, z1, z0 ) */
    const __m128 vz = _mm_loadu_ps(z);
    z += 4;

    /* vxy = ( y2, y0, x2, x0 ) */
    const __m128 vxy = _mm_shuffle_ps(vx, vy, _MM_SHUFFLE(2, 0, 2, 0));
    /* vyz = ( z3, z1, y3, y1 ) */
    const __m128 vyz = _mm_shuffle_ps(vy, vz, _MM_SHUFFLE(3, 1, 3, 1));
    /* vzx = ( x3, x1, z2, z0 ) */
    const __m128 vzx = _mm_shuffle_ps(vz, vx, _MM_SHUFFLE(3, 1, 2, 0));

    /* vxyz0 = ( x1, z0, y0, x0 ) */
    const __m128 vxyz0 = _mm_shuffle_ps(vxy, vzx, _MM_SHUFFLE(2, 0, 2, 0));
    /* vxyz1 = ( y2, x2, z1, y1 ) */
    const __m128 vxyz1 = _mm_shuffle_ps(vyz, vxy, _MM_SHUFFLE(3, 1, 2, 0));
    /* vxyz2 = ( z3, y3, x3, z2 ) */
    const __m128 vxyz2 = _mm_shuffle_ps(vzx, vyz, _MM_SHUFFLE(3, 1, 3, 1));

    _mm_storeu_ps(o, vxyz0);
    _mm_storeu_ps(o + 4, vxyz1);
    _mm_storeu_ps(o + 8, vxyz2);
    o += 12;
    n -= 16;
  }
  // Remainder: 1-3 elements left in each stream.
  if XNN_UNLIKELY(n != 0) {
    if (n & 8) {  // two leftover elements per stream
      /* vx = ( -, -, x1, x0 ) */
      const __m128 vx = _mm_castpd_ps(_mm_load_sd((const double*) x));
      x += 2;
      /* vy = ( -, -, y1, y0 ) */
      const __m128 vy = _mm_castpd_ps(_mm_load_sd((const double*) y));
      y += 2;
      /* vz = ( -, -, z1, z0 ) */
      const __m128 vz = _mm_castpd_ps(_mm_load_sd((const double*) z));
      z += 2;

      /* vxy = ( y1, x1, y0, x0 ) */
      const __m128 vxy = _mm_unpacklo_ps(vx, vy);
      /* vzx = ( x1, z1, x0, z0 ) */
      const __m128 vzx = _mm_unpacklo_ps(vz, vx);
      /* vyz = ( z1, y1, z0, y0 ) */
      const __m128 vyz = _mm_unpacklo_ps(vy, vz);

      // First 4 outputs ( x1, z0, y0, x0 ), then the high pair ( z1, y1 ).
      _mm_storeu_ps(o, _mm_shuffle_ps(vxy, vzx, _MM_SHUFFLE(3, 0, 1, 0)));
      _mm_storeh_pi((__m64*) (o + 4), vyz);
      o += 6;
    }
    if (n & 4) {  // one leftover element per stream
      const __m128 vx = _mm_load_ss(x);
      const __m128 vy = _mm_load_ss(y);
      const __m128 vz = _mm_load_ss(z);
      _mm_store_ss(o, vx);
      _mm_store_ss(o + 1, vy);
      _mm_store_ss(o + 2, vz);
    }
  }
}
diff --git a/src/x32-zip/x4-neon.c b/src/x32-zip/x4-neon.c
new file mode 100644
index 0000000..45c3dc0
--- /dev/null
+++ b/src/x32-zip/x4-neon.c
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/zip.h>
+
+
// NEON variant of the x32 x4 "zip" micro-kernel: interleaves four
// back-to-back streams of 32-bit values into x0 y0 z0 w0 x1 ... order.
//
// n is the size of EACH input stream in BYTES (non-zero, multiple of 4);
// output receives 4*n bytes of interleaved elements.
void xnn_x32_zip_x4_ukernel__neon(
    size_t n,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);

  // Streams are laid out back-to-back, n bytes apart.
  const uint32_t* x = input;
  const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n);
  const uint32_t* z = (const uint32_t*) ((uintptr_t) y + n);
  const uint32_t* w = (const uint32_t*) ((uintptr_t) z + n);
  uint32_t* o = output;

  // Main loop: 4 elements per stream; vst4q interleaves on store.
  while (n >= 16) {
    uint32x4x4_t vxyzw;
    vxyzw.val[0] = vld1q_u32(x); x += 4;
    vxyzw.val[1] = vld1q_u32(y); y += 4;
    vxyzw.val[2] = vld1q_u32(z); z += 4;
    vxyzw.val[3] = vld1q_u32(w); w += 4;
    vst4q_u32(o, vxyzw); o += 16;
    n -= 16;
  }
  // Remainder: 1-3 elements left in each stream.
  if XNN_UNLIKELY(n != 0) {
    if (n & 8) {  // two leftover elements per stream
      uint32x2x4_t vxyzw;
      vxyzw.val[0] = vld1_u32(x); x += 2;
      vxyzw.val[1] = vld1_u32(y); y += 2;
      vxyzw.val[2] = vld1_u32(z); z += 2;
      vxyzw.val[3] = vld1_u32(w); w += 2;
      vst4_u32(o, vxyzw); o += 8;
    }
    if (n & 4) {
      // One element per stream: gather x/y/z/w into lanes 0-3 and store.
      uint32x4_t vxyzw = vld1q_dup_u32(x);
      vxyzw = vld1q_lane_u32(y, vxyzw, 1);
      vxyzw = vld1q_lane_u32(z, vxyzw, 2);
      vxyzw = vld1q_lane_u32(w, vxyzw, 3);
      vst1q_u32(o, vxyzw);
    }
  }
}
diff --git a/src/x32-zip/x4-psimd.c b/src/x32-zip/x4-psimd.c
new file mode 100644
index 0000000..e144f34
--- /dev/null
+++ b/src/x32-zip/x4-psimd.c
@@ -0,0 +1,88 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/zip.h>
+
+
// PSIMD variant of the x32 x4 "zip" micro-kernel: interleaves four
// back-to-back streams of 32-bit values into x0 y0 z0 w0 x1 ... order.
//
// n is the size of EACH input stream in BYTES (non-zero, multiple of 4);
// output receives 4*n bytes of interleaved elements.
void xnn_x32_zip_x4_ukernel__psimd(
    size_t n,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);

  // Streams are laid out back-to-back, n bytes apart.
  const uint32_t* x = input;
  const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n);
  const uint32_t* z = (const uint32_t*) ((uintptr_t) y + n);
  const uint32_t* w = (const uint32_t*) ((uintptr_t) z + n);
  uint32_t* o = output;

  // Main loop: 4 elements per stream, 16 outputs per iteration.
  while (n >= 16) {
    const psimd_u32 vx = psimd_load_u32(x);
    x += 4;
    const psimd_u32 vy = psimd_load_u32(y);
    y += 4;
    const psimd_u32 vz = psimd_load_u32(z);
    z += 4;
    const psimd_u32 vw = psimd_load_u32(w);
    w += 4;

    // Stage 1: interleave 32-bit lanes pairwise (x with y, z with w).
    const psimd_u32 vxy_lo = psimd_interleave_lo_u32(vx, vy);
    const psimd_u32 vxy_hi = psimd_interleave_hi_u32(vx, vy);
    const psimd_u32 vzw_lo = psimd_interleave_lo_u32(vz, vw);
    const psimd_u32 vzw_hi = psimd_interleave_hi_u32(vz, vw);

    // Stage 2: combine 64-bit halves into the final x y z w quadruples.
    const psimd_u32 vxyzw0 = psimd_concat_lo_u32(vxy_lo, vzw_lo);
    const psimd_u32 vxyzw1 = psimd_concat_hi_u32(vxy_lo, vzw_lo);
    const psimd_u32 vxyzw2 = psimd_concat_lo_u32(vxy_hi, vzw_hi);
    const psimd_u32 vxyzw3 = psimd_concat_hi_u32(vxy_hi, vzw_hi);

    psimd_store_u32(o, vxyzw0);
    psimd_store_u32(o + 4, vxyzw1);
    psimd_store_u32(o + 8, vxyzw2);
    psimd_store_u32(o + 12, vxyzw3);
    o += 16;
    n -= 16;
  }
  // Remainder: 1-3 elements left in each stream.
  if XNN_UNLIKELY(n != 0) {
    if (n & 8) {  // two leftover elements per stream
      const psimd_u32 vx = psimd_load2_u32(x);
      x += 2;
      const psimd_u32 vy = psimd_load2_u32(y);
      y += 2;
      const psimd_u32 vz = psimd_load2_u32(z);
      z += 2;
      const psimd_u32 vw = psimd_load2_u32(w);
      w += 2;

      const psimd_u32 vxy = psimd_interleave_lo_u32(vx, vy);
      const psimd_u32 vzw = psimd_interleave_lo_u32(vz, vw);

      const psimd_u32 vxyzw_lo = psimd_concat_lo_u32(vxy, vzw);
      const psimd_u32 vxyzw_hi = psimd_concat_hi_u32(vxy, vzw);

      psimd_store_u32(o, vxyzw_lo);
      psimd_store_u32(o + 4, vxyzw_hi);
      o += 8;
    }
    if (n & 4) {  // one leftover element per stream
      const uint32_t vx = *x;
      const uint32_t vy = *y;
      const uint32_t vz = *z;
      const uint32_t vw = *w;
      o[0] = vx;
      o[1] = vy;
      o[2] = vz;
      o[3] = vw;
    }
  }
}
diff --git a/src/x32-zip/x4-scalar.c b/src/x32-zip/x4-scalar.c
new file mode 100644
index 0000000..a1ca351
--- /dev/null
+++ b/src/x32-zip/x4-scalar.c
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/zip.h>
+
+
// Scalar variant of the x32 x4 "zip" micro-kernel: interleaves four
// back-to-back streams of 32-bit values into x0 y0 z0 w0 x1 ... order.
//
// n is the size of EACH input stream in BYTES (non-zero, multiple of 4);
// output receives 4*n bytes of interleaved elements.
void xnn_x32_zip_x4_ukernel__scalar(
    size_t n,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);

  // n is a byte count; each stream holds n / 4 32-bit elements.
  const size_t elements = n / 4;
  const uint32_t* first = input;
  const uint32_t* second = (const uint32_t*) ((uintptr_t) first + n);
  const uint32_t* third = (const uint32_t*) ((uintptr_t) second + n);
  const uint32_t* fourth = (const uint32_t*) ((uintptr_t) third + n);

  for (size_t i = 0; i < elements; i++) {
    output[0] = first[i];
    output[1] = second[i];
    output[2] = third[i];
    output[3] = fourth[i];
    output += 4;
  }
}
diff --git a/src/x32-zip/x4-sse2.c b/src/x32-zip/x4-sse2.c
new file mode 100644
index 0000000..0ffc400
--- /dev/null
+++ b/src/x32-zip/x4-sse2.c
@@ -0,0 +1,88 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/zip.h>
+
+
// SSE2 variant of the x32 x4 "zip" micro-kernel: interleaves four
// back-to-back streams of 32-bit values into x0 y0 z0 w0 x1 ... order.
//
// n is the size of EACH input stream in BYTES (non-zero, multiple of 4);
// output receives 4*n bytes of interleaved elements.
void xnn_x32_zip_x4_ukernel__sse2(
    size_t n,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);

  // Streams are laid out back-to-back, n bytes apart.
  const uint32_t* x = input;
  const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n);
  const uint32_t* z = (const uint32_t*) ((uintptr_t) y + n);
  const uint32_t* w = (const uint32_t*) ((uintptr_t) z + n);
  uint32_t* o = output;

  // Main loop: 4 elements per stream, 16 outputs per iteration.
  while (n >= 16) {
    const __m128i vx = _mm_loadu_si128((const __m128i*) x);
    x += 4;
    const __m128i vy = _mm_loadu_si128((const __m128i*) y);
    y += 4;
    const __m128i vz = _mm_loadu_si128((const __m128i*) z);
    z += 4;
    const __m128i vw = _mm_loadu_si128((const __m128i*) w);
    w += 4;

    // Stage 1: interleave 32-bit lanes pairwise (x with y, z with w).
    const __m128i vxy_lo = _mm_unpacklo_epi32(vx, vy);
    const __m128i vxy_hi = _mm_unpackhi_epi32(vx, vy);
    const __m128i vzw_lo = _mm_unpacklo_epi32(vz, vw);
    const __m128i vzw_hi = _mm_unpackhi_epi32(vz, vw);

    // Stage 2: combine 64-bit halves into the final x y z w quadruples.
    const __m128i vxyzw0 = _mm_unpacklo_epi64(vxy_lo, vzw_lo);
    const __m128i vxyzw1 = _mm_unpackhi_epi64(vxy_lo, vzw_lo);
    const __m128i vxyzw2 = _mm_unpacklo_epi64(vxy_hi, vzw_hi);
    const __m128i vxyzw3 = _mm_unpackhi_epi64(vxy_hi, vzw_hi);

    _mm_storeu_si128((__m128i*) o, vxyzw0);
    _mm_storeu_si128((__m128i*) (o + 4), vxyzw1);
    _mm_storeu_si128((__m128i*) (o + 8), vxyzw2);
    _mm_storeu_si128((__m128i*) (o + 12), vxyzw3);
    o += 16;
    n -= 16;
  }
  // Remainder: 1-3 elements left in each stream.
  if XNN_UNLIKELY(n != 0) {
    if (n & 8) {  // two leftover elements per stream (64-bit loads)
      const __m128i vx = _mm_loadl_epi64((const __m128i*) x);
      x += 2;
      const __m128i vy = _mm_loadl_epi64((const __m128i*) y);
      y += 2;
      const __m128i vz = _mm_loadl_epi64((const __m128i*) z);
      z += 2;
      const __m128i vw = _mm_loadl_epi64((const __m128i*) w);
      w += 2;

      const __m128i vxy = _mm_unpacklo_epi32(vx, vy);
      const __m128i vzw = _mm_unpacklo_epi32(vz, vw);

      const __m128i vxyzw_lo = _mm_unpacklo_epi64(vxy, vzw);
      const __m128i vxyzw_hi = _mm_unpackhi_epi64(vxy, vzw);

      _mm_storeu_si128((__m128i*) o, vxyzw_lo);
      _mm_storeu_si128((__m128i*) (o + 4), vxyzw_hi);
      o += 8;
    }
    if (n & 4) {  // one leftover element per stream
      const uint32_t vx = *x;
      const uint32_t vy = *y;
      const uint32_t vz = *z;
      const uint32_t vw = *w;
      o[0] = vx;
      o[1] = vy;
      o[2] = vz;
      o[3] = vw;
    }
  }
}
diff --git a/src/x32-zip/xm-neon.c b/src/x32-zip/xm-neon.c
new file mode 100644
index 0000000..2ab68a1
--- /dev/null
+++ b/src/x32-zip/xm-neon.c
@@ -0,0 +1,103 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <arm_neon.h>
+
+#include <xnnpack/zip.h>
+
+
// NEON variant of the x32 xm "zip" micro-kernel: interleaves m
// back-to-back streams of 32-bit values (n bytes each) so that the
// output holds, for each element index, one value from every stream.
//
// Streams are processed in groups of 4 (x, y, z, w). When m is not a
// multiple of 4, the pointers for the final group are clamped to the
// last stream, so that group overlaps the previous one and overwrites
// its output with identical values.
void xnn_x32_zip_xm_ukernel__neon(
    size_t n,
    size_t m,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);
  assert(m >= 4);

  const uint32_t* w = input;
  // Bytes from one output "row" (m elements) to the next.
  const size_t group_increment = m * 4;
  // Advance w by 3 streams so it points at the 4th stream of the group.
  const size_t input_increment = n * 3;
  // Rewind from the end of a column sweep to the next group's column,
  // 16 bytes (4 elements) to the right. May be "negative" (wraps as
  // size_t); the arithmetic is done through uintptr_t below.
  const size_t output_increment = 16 - m * n;
  const uint32_t* last_input = (const uint32_t*) ((uintptr_t) input + n * (m - 1));
  uint32_t* last_output = (uint32_t*) ((uintptr_t) output + (m * 4 - 16));

  for (size_t i = 0; i < m; i += 4) {
    // w tracks the LAST stream of the current group of 4; clamp so the
    // final (possibly partial) group reuses the last streams.
    w = (const uint32_t*) ((uintptr_t) w + input_increment);
    if (w >= last_input) {
      w = last_input;
    }
    const uint32_t* z = (const uint32_t*) ((uintptr_t) w - n);
    const uint32_t* y = (const uint32_t*) ((uintptr_t) z - n);
    const uint32_t* x = (const uint32_t*) ((uintptr_t) y - n);

    // Sweep down the 4 streams, 4 elements at a time.
    size_t k = n;
    while (k >= 16) {
      const uint32x4_t vx = vld1q_u32(x); x += 4;
      const uint32x4_t vy = vld1q_u32(y); y += 4;
      const uint32x4_t vz = vld1q_u32(z); z += 4;
      const uint32x4_t vw = vld1q_u32(w); w += 4;

      const uint32x4x2_t vxy = vzipq_u32(vx, vy);
      const uint32x4x2_t vzw = vzipq_u32(vz, vw);

      // Each store writes one (x, y, z, w) quadruple; output then jumps
      // a full row of m elements to the same columns of the next row.
      vst1_u32(output, vget_low_u32(vxy.val[0]));
      vst1_u32(output + 2, vget_low_u32(vzw.val[0]));
      output = (uint32_t*) ((uintptr_t) output + group_increment);

      vst1_u32(output, vget_high_u32(vxy.val[0]));
      vst1_u32(output + 2, vget_high_u32(vzw.val[0]));
      output = (uint32_t*) ((uintptr_t) output + group_increment);

      vst1_u32(output, vget_low_u32(vxy.val[1]));
      vst1_u32(output + 2, vget_low_u32(vzw.val[1]));
      output = (uint32_t*) ((uintptr_t) output + group_increment);

      vst1_u32(output, vget_high_u32(vxy.val[1]));
      vst1_u32(output + 2, vget_high_u32(vzw.val[1]));
      output = (uint32_t*) ((uintptr_t) output + group_increment);

      k -= 16;
    }
    // Remainder: 1-3 elements left in each stream of the group.
    if XNN_UNLIKELY(k != 0) {
      if (k & 8) {
        const uint32x2_t vx = vld1_u32(x); x += 2;
        const uint32x2_t vy = vld1_u32(y); y += 2;
        const uint32x2_t vz = vld1_u32(z); z += 2;
        const uint32x2_t vw = vld1_u32(w); w += 2;

        const uint32x2x2_t vxy = vzip_u32(vx, vy);
        const uint32x2x2_t vzw = vzip_u32(vz, vw);

        vst1_u32(output, vxy.val[0]);
        vst1_u32(output + 2, vzw.val[0]);
        output = (uint32_t*) ((uintptr_t) output + group_increment);

        vst1_u32(output, vxy.val[1]);
        vst1_u32(output + 2, vzw.val[1]);
        output = (uint32_t*) ((uintptr_t) output + group_increment);
      }
      if (k & 4) {
        // Final element of each stream; w must still advance so the
        // next group's clamp check sees the stream end.
        const uint32x2_t vx = vld1_dup_u32(x);
        const uint32x2_t vz = vld1_dup_u32(z);
        const uint32x2_t vxy = vld1_lane_u32(y, vx, 1);
        const uint32x2_t vzw = vld1_lane_u32(w, vz, 1); w += 1;

        vst1_u32(output, vxy);
        vst1_u32(output + 2, vzw);
        output = (uint32_t*) ((uintptr_t) output + group_increment);
      }
    }
    // Move to the next group of 4 output columns; clamp so a partial
    // final group overlaps the previous one.
    output = (uint32_t*) ((uintptr_t) output + output_increment);
    if (output > last_output) {
      output = last_output;
    }
  }
}
diff --git a/src/x32-zip/xm-psimd.c b/src/x32-zip/xm-psimd.c
new file mode 100644
index 0000000..d5b0c36
--- /dev/null
+++ b/src/x32-zip/xm-psimd.c
@@ -0,0 +1,117 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <psimd.h>
+
+#include <xnnpack/zip.h>
+
+
// PSIMD variant of the x32 xm "zip" micro-kernel: interleaves m
// back-to-back streams of 32-bit values (n bytes each) so that the
// output holds, for each element index, one value from every stream.
//
// Streams are processed in groups of 4 (x, y, z, w); when m is not a
// multiple of 4, the final group is clamped onto the last streams and
// overwrites the overlapping output with identical values.
void xnn_x32_zip_xm_ukernel__psimd(
    size_t n,
    size_t m,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);
  assert(m >= 4);

  const uint32_t* w = input;
  // Bytes from one output "row" (m elements) to the next.
  const size_t group_increment = m * 4;
  // Advance w by 3 streams: it points at the 4th stream of each group.
  const size_t input_increment = n * 3;
  // Rewind from the end of a column sweep to the next group's columns
  // (16 bytes right). May wrap as size_t; applied through uintptr_t.
  const size_t output_increment = 16 - m * n;
  const uint32_t* last_input = (const uint32_t*) ((uintptr_t) input + n * (m - 1));
  uint32_t* last_output = (uint32_t*) ((uintptr_t) output + (m * 4 - 16));

  for (size_t i = 0; i < m; i += 4) {
    // Clamp w so the final (possibly partial) group reuses the last
    // streams instead of reading past the input.
    w = (const uint32_t*) ((uintptr_t) w + input_increment);
    if (w >= last_input) {
      w = last_input;
    }
    const uint32_t* z = (const uint32_t*) ((uintptr_t) w - n);
    const uint32_t* y = (const uint32_t*) ((uintptr_t) z - n);
    const uint32_t* x = (const uint32_t*) ((uintptr_t) y - n);

    // Sweep down the 4 streams, 4 elements at a time.
    size_t k = n;
    while (k >= 16) {
      const psimd_u32 vx = psimd_load_u32((const psimd_u32*) x);
      x += 4;
      const psimd_u32 vy = psimd_load_u32((const psimd_u32*) y);
      y += 4;
      const psimd_u32 vz = psimd_load_u32((const psimd_u32*) z);
      z += 4;
      const psimd_u32 vw = psimd_load_u32((const psimd_u32*) w);
      w += 4;

      // 4x4 transpose: interleave pairwise, then combine 64-bit halves.
      const psimd_u32 vxy_lo = psimd_interleave_lo_u32(vx, vy);
      const psimd_u32 vxy_hi = psimd_interleave_hi_u32(vx, vy);
      const psimd_u32 vzw_lo = psimd_interleave_lo_u32(vz, vw);
      const psimd_u32 vzw_hi = psimd_interleave_hi_u32(vz, vw);

      const psimd_u32 vxyzw0 = psimd_concat_lo_u32(vxy_lo, vzw_lo);
      const psimd_u32 vxyzw1 = psimd_concat_hi_u32(vxy_lo, vzw_lo);
      const psimd_u32 vxyzw2 = psimd_concat_lo_u32(vxy_hi, vzw_hi);
      const psimd_u32 vxyzw3 = psimd_concat_hi_u32(vxy_hi, vzw_hi);

      // Each quadruple goes to the same 4 columns of consecutive rows.
      psimd_store_u32(output, vxyzw0);
      output = (uint32_t*) ((uintptr_t) output + group_increment);

      psimd_store_u32(output, vxyzw1);
      output = (uint32_t*) ((uintptr_t) output + group_increment);

      psimd_store_u32(output, vxyzw2);
      output = (uint32_t*) ((uintptr_t) output + group_increment);

      psimd_store_u32(output, vxyzw3);
      output = (uint32_t*) ((uintptr_t) output + group_increment);

      k -= 16;
    }
    // Remainder: 1-3 elements left in each stream of the group.
    if XNN_UNLIKELY(k != 0) {
      if (k & 8) {
        const psimd_u32 vx = psimd_load2_u32(x);
        x += 2;
        const psimd_u32 vy = psimd_load2_u32(y);
        y += 2;
        const psimd_u32 vz = psimd_load2_u32(z);
        z += 2;
        const psimd_u32 vw = psimd_load2_u32(w);
        w += 2;

        const psimd_u32 vxy = psimd_interleave_lo_u32(vx, vy);
        const psimd_u32 vzw = psimd_interleave_lo_u32(vz, vw);

        const psimd_u32 vxyzw_lo = psimd_concat_lo_u32(vxy, vzw);
        const psimd_u32 vxyzw_hi = psimd_concat_hi_u32(vxy, vzw);

        psimd_store_u32(output, vxyzw_lo);
        output = (uint32_t*) ((uintptr_t) output + group_increment);

        psimd_store_u32(output, vxyzw_hi);
        output = (uint32_t*) ((uintptr_t) output + group_increment);
      }
      if (k & 4) {
        // Final element of each stream; w still advances so the next
        // group's clamp check sees the stream end.
        const uint32_t vx = *x;
        const uint32_t vy = *y;
        const uint32_t vz = *z;
        const uint32_t vw = *w++;

        output[0] = vx;
        output[1] = vy;
        output[2] = vz;
        output[3] = vw;
        output = (uint32_t*) ((uintptr_t) output + group_increment);
      }
    }
    // Next group of 4 output columns; clamp so a partial final group
    // overlaps the previous one.
    output = (uint32_t*) ((uintptr_t) output + output_increment);
    if (output > last_output) {
      output = last_output;
    }
  }
}
diff --git a/src/x32-zip/xm-scalar.c b/src/x32-zip/xm-scalar.c
new file mode 100644
index 0000000..fa2ee80
--- /dev/null
+++ b/src/x32-zip/xm-scalar.c
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/zip.h>
+
+
// Scalar variant of the x32 xm "zip" micro-kernel: interleaves m
// back-to-back streams of 32-bit values (n bytes each) so that the
// output holds, for each element index, one value from every stream
// (i.e. it transposes an m x (n/4) matrix of 32-bit values).
//
// n: per-stream size in bytes (non-zero, multiple of 4).
// m: number of streams (at least 4).
void xnn_x32_zip_xm_ukernel__scalar(
    size_t n,
    size_t m,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);
  assert(m >= 4);

  size_t k = n;
  do {
    size_t l = m;
    // Walk one column: the same element index in each of the m streams.
    const uint32_t* input_column = input++;
    do {
      *output++ = *input_column;
      // Advance to the same column of the next stream; n is the stream
      // stride in bytes. (Fixed: previous cast dropped const.)
      input_column = (const uint32_t*) ((uintptr_t) input_column + n);
    } while (--l != 0);
    k -= 4;
  } while (k != 0);
}
diff --git a/src/x32-zip/xm-sse2.c b/src/x32-zip/xm-sse2.c
new file mode 100644
index 0000000..da06541
--- /dev/null
+++ b/src/x32-zip/xm-sse2.c
@@ -0,0 +1,117 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <emmintrin.h>
+
+#include <xnnpack/zip.h>
+
+
// SSE2 variant of the x32 xm "zip" micro-kernel: interleaves m
// back-to-back streams of 32-bit values (n bytes each) so that the
// output holds, for each element index, one value from every stream.
//
// Streams are processed in groups of 4 (x, y, z, w); when m is not a
// multiple of 4, the final group is clamped onto the last streams and
// overwrites the overlapping output with identical values.
void xnn_x32_zip_xm_ukernel__sse2(
    size_t n,
    size_t m,
    const uint32_t* input,
    uint32_t* output)
{
  assert(n != 0);
  assert(n % 4 == 0);
  assert(m >= 4);

  const uint32_t* w = input;
  // Bytes from one output "row" (m elements) to the next.
  const size_t group_increment = m * 4;
  // Advance w by 3 streams: it points at the 4th stream of each group.
  const size_t input_increment = n * 3;
  // Rewind from the end of a column sweep to the next group's columns
  // (16 bytes right). May wrap as size_t; applied through uintptr_t.
  const size_t output_increment = 16 - m * n;
  const uint32_t* last_input = (const uint32_t*) ((uintptr_t) input + n * (m - 1));
  uint32_t* last_output = (uint32_t*) ((uintptr_t) output + (m * 4 - 16));

  for (size_t i = 0; i < m; i += 4) {
    // Clamp w so the final (possibly partial) group reuses the last
    // streams instead of reading past the input.
    w = (const uint32_t*) ((uintptr_t) w + input_increment);
    if (w >= last_input) {
      w = last_input;
    }
    const uint32_t* z = (const uint32_t*) ((uintptr_t) w - n);
    const uint32_t* y = (const uint32_t*) ((uintptr_t) z - n);
    const uint32_t* x = (const uint32_t*) ((uintptr_t) y - n);

    // Sweep down the 4 streams, 4 elements at a time.
    size_t k = n;
    while (k >= 16) {
      const __m128i vx = _mm_loadu_si128((const __m128i*) x);
      x += 4;
      const __m128i vy = _mm_loadu_si128((const __m128i*) y);
      y += 4;
      const __m128i vz = _mm_loadu_si128((const __m128i*) z);
      z += 4;
      const __m128i vw = _mm_loadu_si128((const __m128i*) w);
      w += 4;

      // 4x4 transpose: interleave pairwise, then combine 64-bit halves.
      const __m128i vxy_lo = _mm_unpacklo_epi32(vx, vy);
      const __m128i vxy_hi = _mm_unpackhi_epi32(vx, vy);
      const __m128i vzw_lo = _mm_unpacklo_epi32(vz, vw);
      const __m128i vzw_hi = _mm_unpackhi_epi32(vz, vw);

      const __m128i vxyzw0 = _mm_unpacklo_epi64(vxy_lo, vzw_lo);
      const __m128i vxyzw1 = _mm_unpackhi_epi64(vxy_lo, vzw_lo);
      const __m128i vxyzw2 = _mm_unpacklo_epi64(vxy_hi, vzw_hi);
      const __m128i vxyzw3 = _mm_unpackhi_epi64(vxy_hi, vzw_hi);

      // Each quadruple goes to the same 4 columns of consecutive rows.
      _mm_storeu_si128((__m128i*) output, vxyzw0);
      output = (uint32_t*) ((uintptr_t) output + group_increment);

      _mm_storeu_si128((__m128i*) output, vxyzw1);
      output = (uint32_t*) ((uintptr_t) output + group_increment);

      _mm_storeu_si128((__m128i*) output, vxyzw2);
      output = (uint32_t*) ((uintptr_t) output + group_increment);

      _mm_storeu_si128((__m128i*) output, vxyzw3);
      output = (uint32_t*) ((uintptr_t) output + group_increment);

      k -= 16;
    }
    // Remainder: 1-3 elements left in each stream of the group.
    if XNN_UNLIKELY(k != 0) {
      if (k & 8) {
        const __m128i vx = _mm_loadl_epi64((const __m128i*) x);
        x += 2;
        const __m128i vy = _mm_loadl_epi64((const __m128i*) y);
        y += 2;
        const __m128i vz = _mm_loadl_epi64((const __m128i*) z);
        z += 2;
        const __m128i vw = _mm_loadl_epi64((const __m128i*) w);
        w += 2;

        const __m128i vxy = _mm_unpacklo_epi32(vx, vy);
        const __m128i vzw = _mm_unpacklo_epi32(vz, vw);

        const __m128i vxyzw_lo = _mm_unpacklo_epi64(vxy, vzw);
        const __m128i vxyzw_hi = _mm_unpackhi_epi64(vxy, vzw);

        _mm_storeu_si128((__m128i*) output, vxyzw_lo);
        output = (uint32_t*) ((uintptr_t) output + group_increment);

        _mm_storeu_si128((__m128i*) output, vxyzw_hi);
        output = (uint32_t*) ((uintptr_t) output + group_increment);
      }
      if (k & 4) {
        // Final element of each stream; w still advances so the next
        // group's clamp check sees the stream end.
        const uint32_t vx = *x;
        const uint32_t vy = *y;
        const uint32_t vz = *z;
        const uint32_t vw = *w++;

        output[0] = vx;
        output[1] = vy;
        output[2] = vz;
        output[3] = vw;
        output = (uint32_t*) ((uintptr_t) output + group_increment);
      }
    }
    // Next group of 4 output columns; clamp so a partial final group
    // overlaps the previous one.
    output = (uint32_t*) ((uintptr_t) output + output_increment);
    if (output > last_output) {
      output = last_output;
    }
  }
}
diff --git a/src/x8-lut/scalar.c b/src/x8-lut/scalar.c
new file mode 100644
index 0000000..6b6e8a8
--- /dev/null
+++ b/src/x8-lut/scalar.c
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/lut.h>
+
+
+void xnn_x8_lut_ukernel__scalar(
+ size_t n,
+ const uint8_t* x,
+ const uint8_t t[restrict static 256],
+ uint8_t* y)
+{
+ assert(n != 0);
+
+ while (n >= 4) {
+ const size_t vx0 = x[0];
+ const size_t vx1 = x[1];
+ const size_t vx2 = x[2];
+ const size_t vx3 = x[3];
+ x += 4;
+
+ const uint8_t vt0 = t[vx0];
+ const uint8_t vt1 = t[vx1];
+ const uint8_t vt2 = t[vx2];
+ const uint8_t vt3 = t[vx3];
+
+ y[0] = vt0;
+ y[1] = vt1;
+ y[2] = vt2;
+ y[3] = vt3;
+ y += 4;
+
+ n -= 4;
+ }
+ while (n != 0) {
+ const size_t vx = *x++;
+ const uint8_t vt = t[vx];
+ *y++ = vt;
+
+ n--;
+ };
+}
diff --git a/src/x8-zip/x2-neon.c b/src/x8-zip/x2-neon.c
new file mode 100644
index 0000000..3732e90
--- /dev/null
+++ b/src/x8-zip/x2-neon.c
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <arm_neon.h>
+
+#include <xnnpack/zip.h>
+
+
+void xnn_x8_zip_x2_ukernel__neon(
+ size_t n,
+ const uint8_t* input,
+ uint8_t* output)
+{
+ const uint8_t* x = input;
+ const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n);
+ uint8_t* o = output;
+
+ if (n >= 8) {
+ do {
+ uint8x8x2_t vxy;
+ vxy.val[0] = vld1_u8(x); x += 8;
+ vxy.val[1] = vld1_u8(y); y += 8;
+ vst2_u8(o, vxy); o += 16;;
+ n -= 8;
+ } while (n >= 8);
+ if (n != 0) {
+ const size_t address_increment = n - 8;
+ uint8x8x2_t vxy;
+ vxy.val[0] = vld1_u8((const uint8_t*) ((uintptr_t) x + address_increment));
+ vxy.val[1] = vld1_u8((const uint8_t*) ((uintptr_t) y + address_increment));
+ vst2_u8((uint8_t*) ((uintptr_t) o + address_increment * 2), vxy);
+ }
+ } else {
+ do {
+ const uint8_t vx = *x++;
+ const uint8_t vy = *y++;
+ o[0] = vx;
+ o[1] = vy;
+ o += 2;
+ } while (--n != 0);
+ }
+}
diff --git a/src/x8-zip/x2-scalar.c b/src/x8-zip/x2-scalar.c
new file mode 100644
index 0000000..e7906ca
--- /dev/null
+++ b/src/x8-zip/x2-scalar.c
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/zip.h>
+
+
// Scalar variant of the x8 x2 "zip" micro-kernel: interleaves two
// back-to-back byte streams of n bytes each into x0 y0 x1 y1 ... order.
//
// n is the size of EACH input stream in bytes (must be non-zero);
// output receives 2*n interleaved bytes.
void xnn_x8_zip_x2_ukernel__scalar(
    size_t n,
    const uint8_t* input,
    uint8_t* output)
{
  assert(n != 0);

  const uint8_t* first = input;
  // The second stream starts n bytes past the first.
  const uint8_t* second = (const uint8_t*) ((uintptr_t) input + n);

  for (size_t i = 0; i < n; i++) {
    *output++ = first[i];
    *output++ = second[i];
  }
}
diff --git a/src/x8-zip/x2-sse2.c b/src/x8-zip/x2-sse2.c
new file mode 100644
index 0000000..6ba0963
--- /dev/null
+++ b/src/x8-zip/x2-sse2.c
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <emmintrin.h>
+
+#include <xnnpack/zip.h>
+
+
// SSE2 variant of the x8 x2 "zip" micro-kernel: interleaves two
// back-to-back byte streams of n bytes each into x0 y0 x1 y1 ... order.
//
// NOTE(review): unlike the scalar variant, there is no assert(n != 0);
// n == 0 would wrap around in the scalar tail's --n, so callers must
// pass n >= 1.
void xnn_x8_zip_x2_ukernel__sse2(
    size_t n,
    const uint8_t* input,
    uint8_t* output)
{
  const uint8_t* x = input;
  // Stream y starts n bytes past the start of stream x.
  const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n);
  uint8_t* o = output;

  if (n >= 16) {
    // Main loop: 16 bytes per stream, interleaved via byte unpacks.
    do {
      const __m128i vx = _mm_loadu_si128((const __m128i*) x);
      x += 16;
      const __m128i vy = _mm_loadu_si128((const __m128i*) y);
      y += 16;
      const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy);
      const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy);
      _mm_storeu_si128((__m128i*) o, vxy_lo);
      _mm_storeu_si128((__m128i*) (o + 16), vxy_hi);
      o = (void*) ((uintptr_t) o + 32);
      n -= 16;
    } while (n >= 16);
    if (n != 0) {
      // Tail: back up so the last full 16-byte vector ends exactly at
      // the end of each stream (rewriting some bytes with identical
      // values).
      const size_t address_increment = n - 16;
      const __m128i vx = _mm_loadu_si128((const __m128i*) ((uintptr_t) x + address_increment));
      const __m128i vy = _mm_loadu_si128((const __m128i*) ((uintptr_t) y + address_increment));
      const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy);
      const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy);
      o = (void*) ((uintptr_t) o + address_increment * 2);
      _mm_storeu_si128((__m128i*) o, vxy_lo);
      _mm_storeu_si128((__m128i*) o + 1, vxy_hi);
    }
  } else {
    // Short inputs (n < 16): scalar interleave.
    do {
      const uint8_t vx = *x++;
      const uint8_t vy = *y++;
      o[0] = vx;
      o[1] = vy;
      o += 2;
    } while (--n != 0);
  }
}
diff --git a/src/x8-zip/x3-neon.c b/src/x8-zip/x3-neon.c
new file mode 100644
index 0000000..9348ecb
--- /dev/null
+++ b/src/x8-zip/x3-neon.c
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <arm_neon.h>
+
+#include <xnnpack/zip.h>
+
+
+// Interleaves three contiguous n-byte thirds of input into output:
+// input = [x... | y... | z...] -> output = [x0, y0, z0, x1, y1, z1, ...].
+// n is a byte count and must be non-zero (the scalar tail uses --n != 0).
+void xnn_x8_zip_x3_ukernel__neon(
+    size_t n,
+    const uint8_t* input,
+    uint8_t* output)
+{
+  const uint8_t* x = input;
+  const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n);
+  const uint8_t* z = (const uint8_t*) ((uintptr_t) y + n);
+  uint8_t* o = output;
+
+  if (n >= 8) {
+    do {
+      uint8x8x3_t vxyz;
+      vxyz.val[0] = vld1_u8(x); x += 8;
+      vxyz.val[1] = vld1_u8(y); y += 8;
+      vxyz.val[2] = vld1_u8(z); z += 8;
+      // vst3_u8 writes the three lanes interleaved: 24 output bytes per step.
+      vst3_u8(o, vxyz); o += 24;
+      n -= 8;
+    } while (n >= 8);
+    if (n != 0) {
+      // Remainder (< 8 bytes): reload the last full 8 bytes of each third,
+      // overlapping data already processed; overlapped outputs are rewritten
+      // with identical values. Safe because n >= 8 originally.
+      const size_t address_increment = n - 8;
+      uint8x8x3_t vxyz;
+      vxyz.val[0] = vld1_u8(x + address_increment);
+      vxyz.val[1] = vld1_u8(y + address_increment);
+      vxyz.val[2] = vld1_u8(z + address_increment);
+      vst3_u8((uint8_t*) ((uintptr_t) o + address_increment * 3), vxyz);
+    }
+  } else {
+    // Scalar path for n < 8.
+    do {
+      const uint8_t vx = *x++;
+      const uint8_t vy = *y++;
+      const uint8_t vz = *z++;
+      o[0] = vx;
+      o[1] = vy;
+      o[2] = vz;
+      o += 3;
+    } while (--n != 0);
+  }
+}
diff --git a/src/x8-zip/x3-scalar.c b/src/x8-zip/x3-scalar.c
new file mode 100644
index 0000000..b4319ef
--- /dev/null
+++ b/src/x8-zip/x3-scalar.c
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/zip.h>
+
+
+// Interleaves three contiguous n-byte thirds of input into output:
+// input = [x... | y... | z...] -> output = [x0, y0, z0, x1, y1, z1, ...].
+void xnn_x8_zip_x3_ukernel__scalar(
+    size_t n,
+    const uint8_t* input,
+    uint8_t* output)
+{
+  // n == 0 would underflow in the do-while below; the sibling
+  // xnn_x8_zip_x4_ukernel__scalar documents the same precondition with an
+  // assert, and this file already includes <assert.h> for that purpose.
+  assert(n != 0);
+
+  const uint8_t* x = input;
+  const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n);
+  const uint8_t* z = (const uint8_t*) ((uintptr_t) y + n);
+  uint8_t* o = output;
+
+  do {
+    const uint8_t vx = *x++;
+    const uint8_t vy = *y++;
+    const uint8_t vz = *z++;
+    o[0] = vx;
+    o[1] = vy;
+    o[2] = vz;
+    o += 3;
+
+    n -= sizeof(uint8_t);
+  } while (n != 0);
+}
diff --git a/src/x8-zip/x3-sse2.c b/src/x8-zip/x3-sse2.c
new file mode 100644
index 0000000..045fc20
--- /dev/null
+++ b/src/x8-zip/x3-sse2.c
@@ -0,0 +1,139 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <emmintrin.h>
+
+#include <xnnpack/zip.h>
+
+
+// Interleaves three contiguous n-byte thirds of input into output:
+// input = [x... | y... | z...] -> output = [x0, y0, z0, x1, y1, z1, ...].
+// SSE2 lacks a 3-way byte interleave, so the kernel builds it from masked
+// OR-merges at 8-bit and 16-bit granularity followed by 32-bit shuffles.
+// n is a byte count and must be non-zero (the scalar tail uses --n != 0).
+void xnn_x8_zip_x3_ukernel__sse2(
+    size_t n,
+    const uint8_t* input,
+    uint8_t* output)
+{
+  const uint8_t* x = input;
+  const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n);
+  const uint8_t* z = (const uint8_t*) ((uintptr_t) y + n);
+  uint8_t* o = output;
+
+  if (n >= 16) {
+    const __m128i vmask0x00FF00FF = _mm_set1_epi16(0x00FF);
+    const __m128i vmask0x0000FFFF = _mm_set1_epi32(0x0000FFFF);
+    do {
+      /* vx = ( x15, x14, x13, x12, x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1, x0 ) */
+      const __m128i vx = _mm_loadu_si128((const __m128i*) x);
+      x += 16;
+      /* vy = ( y15, y14, y13, y12, y11, y10, y9, y8, y7, y6, y5, y4, y3, y2, y1, y0 ) */
+      const __m128i vy = _mm_loadu_si128((const __m128i*) y);
+      y += 16;
+      /* vz = ( z15, z14, z13, z12, z11, z10, z9, z8, z7, z6, z5, z4, z3, z2, z1, z0 ) */
+      const __m128i vz = _mm_loadu_si128((const __m128i*) z);
+      z += 16;
+
+      /* vxeye = ( y14, x14, y12, x12, y10, x10, y8, x8, y6, x6, y4, x4, y2, x2, y0, x0 ) */
+      const __m128i vxeye = _mm_or_si128(_mm_and_si128(vx, vmask0x00FF00FF), _mm_slli_epi16(vy, 8));
+      /* vyozo = ( z15, y15, z13, y13, z11, y11, z9, y9, z7, y7, z5, y5, z3, y3, z1, y1 ) */
+      const __m128i vyozo = _mm_or_si128(_mm_andnot_si128(vmask0x00FF00FF, vz), _mm_srli_epi16(vy, 8));
+      /* vzexo = ( x15, z14, x13, z12, x11, z10, x9, z8, x7, z6, x5, z4, x3, z2, x1, z0 ) */
+      const __m128i vzexo = _mm_or_si128(_mm_and_si128(vz, vmask0x00FF00FF), _mm_andnot_si128(vmask0x00FF00FF, vx));
+
+      /* vxeyezexo = ( x13, z12, y12, x12, x9, z8, y8, x8, x5, z4, y4, x4, x1, z0, y0, x0 ) */
+      const __m128i vxeyezexo = _mm_or_si128(_mm_and_si128(vxeye, vmask0x0000FFFF), _mm_slli_epi32(vzexo, 16));
+      /* vyozoxeye = ( y14, x14, z13, y13, y10, x10, z9, y9, y6, x6, z5, y5, y2, x2, z1, y1 ) */
+      const __m128i vyozoxeye = _mm_or_si128(_mm_and_si128(vyozo, vmask0x0000FFFF), _mm_andnot_si128(vmask0x0000FFFF, vxeye));
+      /* vzexoyozo = ( z15, y15, x15, z14, z11, y11, x11, z10, z7, y7, x7, z6, z3, y3, x3, z2 ) */
+      const __m128i vzexoyozo = _mm_or_si128(_mm_andnot_si128(vmask0x0000FFFF, vyozo), _mm_srli_epi32(vzexo, 16));
+
+      /* vtemp0 = ( x13, z12, y12, x12, x5, z4, y4, x4, z11, y11, x11, z10, z3, y3, x3, z2 ) */
+      const __m128i vtemp0 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vzexoyozo), _mm_castsi128_ps(vxeyezexo), _MM_SHUFFLE(3, 1, 2, 0)));
+      /* vtemp1 = ( y10, x10, z9, y9, y2, x2, z1, y1, x9, z8, y8, x8, x1, z0, y0, x0 ) */
+      const __m128i vtemp1 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vxeyezexo), _mm_castsi128_ps(vyozoxeye), _MM_SHUFFLE(2, 0, 2, 0)));
+      /* vtemp2 = ( z15, y15, x15, z14, z7, y7, x7, z6, y14, x14, z13, y13, y6, x6, z5, y5 ) */
+      const __m128i vtemp2 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vyozoxeye), _mm_castsi128_ps(vzexoyozo), _MM_SHUFFLE(3, 1, 3, 1)));
+
+      /* vxyz0 = ( x5, z4, y4, x4, z3, y3, x3, z2, y2, x2, z1, y1, x1, z0, y0, x0 ) */
+      const __m128i vxyz0 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vtemp1), _mm_castsi128_ps(vtemp0), _MM_SHUFFLE(2, 0, 2, 0)));
+      /* vxyz1 = ( y10, x10, z9, y9, x9, z8, y8, x8, z7, y7, x7, z6, y6, x6, z5, y5 ) */
+      const __m128i vxyz1 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vtemp2), _mm_castsi128_ps(vtemp1), _MM_SHUFFLE(3, 1, 2, 0)));
+      /* vxyz2 = ( z15, y15, x15, z14, y14, x14, z13, y13, x13, z12, y12, x12, z11, y11, x11, z10 ) */
+      const __m128i vxyz2 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vtemp0), _mm_castsi128_ps(vtemp2), _MM_SHUFFLE(3, 1, 3, 1)));
+
+      _mm_storeu_si128((__m128i*) o, vxyz0);
+      _mm_storeu_si128((__m128i*) o + 1, vxyz1);
+      _mm_storeu_si128((__m128i*) o + 2, vxyz2);
+      o += 48;
+      n -= 16;
+    } while (n >= 16);
+    if (n != 0) {
+      /* Remainder (< 16 bytes): back up so the final full 16-byte vectors
+       * overlap data already processed; overlapped outputs are rewritten
+       * with identical values. Safe because n >= 16 originally. */
+      const size_t address_increment = n - 16;
+      /* vx = ( x15, x14, x13, x12, x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1, x0 ) */
+      const __m128i vx = _mm_loadu_si128((const __m128i*) ((uintptr_t) x + address_increment));
+      /* vy = ( y15, y14, y13, y12, y11, y10, y9, y8, y7, y6, y5, y4, y3, y2, y1, y0 ) */
+      const __m128i vy = _mm_loadu_si128((const __m128i*) ((uintptr_t) y + address_increment));
+      /* vz = ( z15, z14, z13, z12, z11, z10, z9, z8, z7, z6, z5, z4, z3, z2, z1, z0 ) */
+      const __m128i vz = _mm_loadu_si128((const __m128i*) ((uintptr_t) z + address_increment));
+
+      /* vxeye = ( y14, x14, y12, x12, y10, x10, y8, x8, y6, x6, y4, x4, y2, x2, y0, x0 ) */
+      const __m128i vxeye = _mm_or_si128(_mm_and_si128(vx, vmask0x00FF00FF), _mm_slli_epi16(vy, 8));
+      /* vyozo = ( z15, y15, z13, y13, z11, y11, z9, y9, z7, y7, z5, y5, z3, y3, z1, y1 ) */
+      const __m128i vyozo = _mm_or_si128(_mm_andnot_si128(vmask0x00FF00FF, vz), _mm_srli_epi16(vy, 8));
+      /* vzexo = ( x15, z14, x13, z12, x11, z10, x9, z8, x7, z6, x5, z4, x3, z2, x1, z0 ) */
+      const __m128i vzexo = _mm_or_si128(_mm_and_si128(vz, vmask0x00FF00FF), _mm_andnot_si128(vmask0x00FF00FF, vx));
+
+      /* vxeyezexo = ( x13, z12, y12, x12, x9, z8, y8, x8, x5, z4, y4, x4, x1, z0, y0, x0 ) */
+      const __m128i vxeyezexo = _mm_or_si128(_mm_and_si128(vxeye, vmask0x0000FFFF), _mm_slli_epi32(vzexo, 16));
+      /* vyozoxeye = ( y14, x14, z13, y13, y10, x10, z9, y9, y6, x6, z5, y5, y2, x2, z1, y1 ) */
+      const __m128i vyozoxeye = _mm_or_si128(_mm_and_si128(vyozo, vmask0x0000FFFF), _mm_andnot_si128(vmask0x0000FFFF, vxeye));
+      /* vzexoyozo = ( z15, y15, x15, z14, z11, y11, x11, z10, z7, y7, x7, z6, z3, y3, x3, z2 ) */
+      const __m128i vzexoyozo = _mm_or_si128(_mm_andnot_si128(vmask0x0000FFFF, vyozo), _mm_srli_epi32(vzexo, 16));
+
+      /* vtemp0 = ( x13, z12, y12, x12, x5, z4, y4, x4, z11, y11, x11, z10, z3, y3, x3, z2 ) */
+      const __m128i vtemp0 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vzexoyozo), _mm_castsi128_ps(vxeyezexo), _MM_SHUFFLE(3, 1, 2, 0)));
+      /* vtemp1 = ( y10, x10, z9, y9, y2, x2, z1, y1, x9, z8, y8, x8, x1, z0, y0, x0 ) */
+      const __m128i vtemp1 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vxeyezexo), _mm_castsi128_ps(vyozoxeye), _MM_SHUFFLE(2, 0, 2, 0)));
+      /* vtemp2 = ( z15, y15, x15, z14, z7, y7, x7, z6, y14, x14, z13, y13, y6, x6, z5, y5 ) */
+      const __m128i vtemp2 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vyozoxeye), _mm_castsi128_ps(vzexoyozo), _MM_SHUFFLE(3, 1, 3, 1)));
+
+      /* vxyz0 = ( x5, z4, y4, x4, z3, y3, x3, z2, y2, x2, z1, y1, x1, z0, y0, x0 ) */
+      const __m128i vxyz0 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vtemp1), _mm_castsi128_ps(vtemp0), _MM_SHUFFLE(2, 0, 2, 0)));
+      /* vxyz1 = ( y10, x10, z9, y9, x9, z8, y8, x8, z7, y7, x7, z6, y6, x6, z5, y5 ) */
+      const __m128i vxyz1 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vtemp2), _mm_castsi128_ps(vtemp1), _MM_SHUFFLE(3, 1, 2, 0)));
+      /* vxyz2 = ( z15, y15, x15, z14, y14, x14, z13, y13, x13, z12, y12, x12, z11, y11, x11, z10 ) */
+      const __m128i vxyz2 = _mm_castps_si128(
+        _mm_shuffle_ps(_mm_castsi128_ps(vtemp0), _mm_castsi128_ps(vtemp2), _MM_SHUFFLE(3, 1, 3, 1)));
+
+      o = (uint8_t*) ((uintptr_t) o + address_increment * 3);
+      _mm_storeu_si128((__m128i*) o, vxyz0);
+      _mm_storeu_si128((__m128i*) o + 1, vxyz1);
+      _mm_storeu_si128((__m128i*) o + 2, vxyz2);
+    }
+  } else {
+    /* Scalar path for n < 16. */
+    do {
+      const uint8_t vx = *x++;
+      const uint8_t vy = *y++;
+      const uint8_t vz = *z++;
+      o[0] = vx;
+      o[1] = vy;
+      o[2] = vz;
+      o += 3;
+    } while (--n != 0);
+  }
+}
diff --git a/src/x8-zip/x4-neon.c b/src/x8-zip/x4-neon.c
new file mode 100644
index 0000000..38ac597
--- /dev/null
+++ b/src/x8-zip/x4-neon.c
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <arm_neon.h>
+
+#include <xnnpack/zip.h>
+
+
+// Interleaves four contiguous n-byte quarters of input into output:
+// input = [x... | y... | z... | w...] -> output = [x0, y0, z0, w0, x1, ...].
+// n is a byte count and must be non-zero (the scalar tail uses --n != 0).
+void xnn_x8_zip_x4_ukernel__neon(
+    size_t n,
+    const uint8_t* input,
+    uint8_t* output)
+{
+  const uint8_t* x = input;
+  const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n);
+  const uint8_t* z = (const uint8_t*) ((uintptr_t) y + n);
+  const uint8_t* w = (const uint8_t*) ((uintptr_t) z + n);
+  uint8_t* o = output;
+
+  if (n >= 8) {
+    do {
+      uint8x8x4_t vxyzw;
+      vxyzw.val[0] = vld1_u8(x); x += 8;
+      vxyzw.val[1] = vld1_u8(y); y += 8;
+      vxyzw.val[2] = vld1_u8(z); z += 8;
+      vxyzw.val[3] = vld1_u8(w); w += 8;
+      // vst4_u8 writes the four lanes interleaved: 32 output bytes per step.
+      vst4_u8(o, vxyzw); o += 32;
+      n -= 8;
+    } while (n >= 8);
+    if (n != 0) {
+      // Remainder (< 8 bytes): reload the last full 8 bytes of each quarter,
+      // overlapping data already processed; overlapped outputs are rewritten
+      // with identical values. Safe because n >= 8 originally.
+      const size_t address_increment = n - 8;
+      uint8x8x4_t vxyzw;
+      vxyzw.val[0] = vld1_u8(x + address_increment);
+      vxyzw.val[1] = vld1_u8(y + address_increment);
+      vxyzw.val[2] = vld1_u8(z + address_increment);
+      vxyzw.val[3] = vld1_u8(w + address_increment);
+      vst4_u8((uint8_t*) ((uintptr_t) o + address_increment * 4), vxyzw);
+    }
+  } else {
+    // Scalar path for n < 8.
+    do {
+      const uint8_t vx = *x++;
+      const uint8_t vy = *y++;
+      const uint8_t vz = *z++;
+      const uint8_t vw = *w++;
+      o[0] = vx;
+      o[1] = vy;
+      o[2] = vz;
+      o[3] = vw;
+      o += 4;
+    } while (--n != 0);
+  }
+}
diff --git a/src/x8-zip/x4-scalar.c b/src/x8-zip/x4-scalar.c
new file mode 100644
index 0000000..b56c969
--- /dev/null
+++ b/src/x8-zip/x4-scalar.c
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/zip.h>
+
+
+// Interleaves four contiguous n-byte quarters of input into output:
+// input = [x... | y... | z... | w...] -> output = [x0, y0, z0, w0, x1, ...].
+void xnn_x8_zip_x4_ukernel__scalar(
+    size_t n,
+    const uint8_t* input,
+    uint8_t* output)
+{
+  // n == 0 would underflow in the do-while below.
+  assert(n != 0);
+
+  const uint8_t* x = input;
+  const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n);
+  const uint8_t* z = (const uint8_t*) ((uintptr_t) y + n);
+  const uint8_t* w = (const uint8_t*) ((uintptr_t) z + n);
+  uint8_t* o = output;
+
+  do {
+    const uint8_t vx = *x++;
+    const uint8_t vy = *y++;
+    const uint8_t vz = *z++;
+    const uint8_t vw = *w++;
+    o[0] = vx;
+    o[1] = vy;
+    o[2] = vz;
+    o[3] = vw;
+    o += 4;
+
+    n -= sizeof(uint8_t);
+  } while (n != 0);
+}
diff --git a/src/x8-zip/x4-sse2.c b/src/x8-zip/x4-sse2.c
new file mode 100644
index 0000000..292a981
--- /dev/null
+++ b/src/x8-zip/x4-sse2.c
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <emmintrin.h>
+
+#include <xnnpack/zip.h>
+
+
+// Interleaves four contiguous n-byte quarters of input into output:
+// input = [x... | y... | z... | w...] -> output = [x0, y0, z0, w0, x1, ...].
+// Built from 8-bit then 16-bit unpack steps: (x,y) and (z,w) pairs first,
+// then 16-bit interleave of the pairs yields full 4-byte groups.
+// n is a byte count and must be non-zero (the scalar tail uses --n != 0).
+void xnn_x8_zip_x4_ukernel__sse2(
+    size_t n,
+    const uint8_t* input,
+    uint8_t* output)
+{
+  const uint8_t* x = input;
+  const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n);
+  const uint8_t* z = (const uint8_t*) ((uintptr_t) y + n);
+  const uint8_t* w = (const uint8_t*) ((uintptr_t) z + n);
+  uint8_t* o = output;
+
+  if (n >= 16) {
+    do {
+      const __m128i vx = _mm_loadu_si128((const __m128i*) x);
+      x += 16;
+      const __m128i vy = _mm_loadu_si128((const __m128i*) y);
+      y += 16;
+      const __m128i vz = _mm_loadu_si128((const __m128i*) z);
+      z += 16;
+      const __m128i vw = _mm_loadu_si128((const __m128i*) w);
+      w += 16;
+      const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy);
+      const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy);
+      const __m128i vzw_lo = _mm_unpacklo_epi8(vz, vw);
+      const __m128i vzw_hi = _mm_unpackhi_epi8(vz, vw);
+      const __m128i vxyzw0 = _mm_unpacklo_epi16(vxy_lo, vzw_lo);
+      const __m128i vxyzw1 = _mm_unpackhi_epi16(vxy_lo, vzw_lo);
+      const __m128i vxyzw2 = _mm_unpacklo_epi16(vxy_hi, vzw_hi);
+      const __m128i vxyzw3 = _mm_unpackhi_epi16(vxy_hi, vzw_hi);
+      _mm_storeu_si128((__m128i*) o, vxyzw0);
+      _mm_storeu_si128((__m128i*) o + 1, vxyzw1);
+      _mm_storeu_si128((__m128i*) o + 2, vxyzw2);
+      _mm_storeu_si128((__m128i*) o + 3, vxyzw3);
+      o = (void*) ((uintptr_t) o + 64);
+      n -= 16;
+    } while (n >= 16);
+    if (n != 0) {
+      // Remainder (< 16 bytes): back up so the final full 16-byte vectors
+      // overlap data already processed; overlapped outputs are rewritten
+      // with identical values. Safe because n >= 16 originally.
+      const size_t address_increment = n - 16;
+      const __m128i vx = _mm_loadu_si128((const __m128i*) ((uintptr_t) x + address_increment));
+      const __m128i vy = _mm_loadu_si128((const __m128i*) ((uintptr_t) y + address_increment));
+      const __m128i vz = _mm_loadu_si128((const __m128i*) ((uintptr_t) z + address_increment));
+      const __m128i vw = _mm_loadu_si128((const __m128i*) ((uintptr_t) w + address_increment));
+      const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy);
+      const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy);
+      const __m128i vzw_lo = _mm_unpacklo_epi8(vz, vw);
+      const __m128i vzw_hi = _mm_unpackhi_epi8(vz, vw);
+      const __m128i vxyzw0 = _mm_unpacklo_epi16(vxy_lo, vzw_lo);
+      const __m128i vxyzw1 = _mm_unpackhi_epi16(vxy_lo, vzw_lo);
+      const __m128i vxyzw2 = _mm_unpacklo_epi16(vxy_hi, vzw_hi);
+      const __m128i vxyzw3 = _mm_unpackhi_epi16(vxy_hi, vzw_hi);
+      o = (void*) ((uintptr_t) o + address_increment * 4);
+      _mm_storeu_si128((__m128i*) o, vxyzw0);
+      _mm_storeu_si128((__m128i*) o + 1, vxyzw1);
+      _mm_storeu_si128((__m128i*) o + 2, vxyzw2);
+      _mm_storeu_si128((__m128i*) o + 3, vxyzw3);
+    }
+  } else {
+    // Scalar path for n < 16.
+    do {
+      const uint8_t vx = *x++;
+      const uint8_t vy = *y++;
+      const uint8_t vz = *z++;
+      const uint8_t vw = *w++;
+      o[0] = vx;
+      o[1] = vy;
+      o[2] = vz;
+      o[3] = vw;
+      o += 4;
+    } while (--n != 0);
+  }
+}
diff --git a/src/x8-zip/xm-neon.c b/src/x8-zip/xm-neon.c
new file mode 100644
index 0000000..2ac19eb
--- /dev/null
+++ b/src/x8-zip/xm-neon.c
@@ -0,0 +1,146 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <arm_neon.h>
+
+#include <xnnpack/zip.h>
+
+
+// Zips m channel arrays ("rows") of n bytes each into channel-interleaved
+// output: output[j*m + i] = input[i*n + j]. Rows are processed in groups of
+// four; when m is not a multiple of 4 the last group is clamped to overlap
+// the previous one, rewriting some outputs with identical values.
+// Presumably requires m >= 4 (the sibling scalar kernel asserts it) — the
+// (m - 4) in last_output would wrap otherwise; TODO(review) confirm.
+void xnn_x8_zip_xm_ukernel__neon(
+    size_t n,
+    size_t m,
+    const uint8_t* input,
+    uint8_t* output)
+{
+  const uint8_t* w = input;
+  // Advances w from the first row of one 4-row group to the 4th row of the next.
+  const size_t input_increment = n * 3;
+  // After a group writes n columns (output advances by m each column), this
+  // rewinds it by m*n - 4, i.e. net +4 bytes: the next group's column offset.
+  const size_t output_increment = 4 - m * n;
+  const uint8_t* last_input = w + n * (m - 1);
+  uint8_t* last_output = (uint8_t*) ((uintptr_t) output + (m - 4));
+
+  if (n >= 8) {
+    for (size_t i = 0; i < m; i += 4) {
+      size_t k = n;
+      w = (const uint8_t*) ((uintptr_t) w + input_increment);
+      // Clamp so a partial final group overlaps the previous group's rows.
+      if (w >= last_input) {
+        w = last_input;
+      }
+      // x, y, z are the three rows preceding w.
+      const uint8_t* z = (const uint8_t*) ((uintptr_t) w - n);
+      const uint8_t* y = (const uint8_t*) ((uintptr_t) z - n);
+      const uint8_t* x = (const uint8_t*) ((uintptr_t) y - n);
+      while (k >= 8) {
+        const uint8x8_t vx = vld1_u8(x); x += 8;
+        const uint8x8_t vy = vld1_u8(y); y += 8;
+        const uint8x8_t vz = vld1_u8(z); z += 8;
+        const uint8x8_t vw = vld1_u8(w); w += 8;
+
+        // Two zip levels produce 4-byte (x,y,z,w) groups in 32-bit lanes.
+        const uint8x8x2_t vxy = vzip_u8(vx, vy);
+        const uint8x8x2_t vzw = vzip_u8(vz, vw);
+        const uint16x4x2_t vxyzw_lo = vzip_u16(vreinterpret_u16_u8(vxy.val[0]), vreinterpret_u16_u8(vzw.val[0]));
+        const uint16x4x2_t vxyzw_hi = vzip_u16(vreinterpret_u16_u8(vxy.val[1]), vreinterpret_u16_u8(vzw.val[1]));
+
+        // Each 4-byte group is stored at stride m (one group per output column).
+        // __builtin_assume_aligned(output, 1) marks the store as unaligned.
+        vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u16(vxyzw_lo.val[0]), 0);
+        output = (uint8_t*) ((uintptr_t) output + m);
+
+        vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u16(vxyzw_lo.val[0]), 1);
+        output = (uint8_t*) ((uintptr_t) output + m);
+
+        vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u16(vxyzw_lo.val[1]), 0);
+        output = (uint8_t*) ((uintptr_t) output + m);
+
+        vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u16(vxyzw_lo.val[1]), 1);
+        output = (uint8_t*) ((uintptr_t) output + m);
+
+        vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u16(vxyzw_hi.val[0]), 0);
+        output = (uint8_t*) ((uintptr_t) output + m);
+
+        vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u16(vxyzw_hi.val[0]), 1);
+        output = (uint8_t*) ((uintptr_t) output + m);
+
+        vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u16(vxyzw_hi.val[1]), 0);
+        output = (uint8_t*) ((uintptr_t) output + m);
+
+        vst1_lane_u32(__builtin_assume_aligned(output, 1), vreinterpret_u32_u16(vxyzw_hi.val[1]), 1);
+        output = (uint8_t*) ((uintptr_t) output + m);
+
+        k -= 8;
+      }
+      if (k != 0) {
+        // Remainder (k < 8): back each row pointer up by (8 - k) bytes
+        // (address_increment wraps negative as size_t) and reload 8 bytes;
+        // the negative s64 shift count makes vshl_u64 shift RIGHT, moving
+        // the k fresh bytes into the low lanes and discarding stale ones.
+        const size_t address_increment = k - 8;
+        x = (const uint8_t*) ((uintptr_t) x + address_increment);
+        y = (const uint8_t*) ((uintptr_t) y + address_increment);
+        z = (const uint8_t*) ((uintptr_t) z + address_increment);
+        w = (const uint8_t*) ((uintptr_t) w + address_increment);
+        const int64x1_t vshift = vmov_n_s64(8 * address_increment);
+
+        const uint64x1_t vx = vshl_u64(vreinterpret_u64_u8(vld1_u8(x)), vshift);
+        const uint64x1_t vy = vshl_u64(vreinterpret_u64_u8(vld1_u8(y)), vshift);
+        const uint64x1_t vz = vshl_u64(vreinterpret_u64_u8(vld1_u8(z)), vshift);
+        const uint64x1_t vw = vshl_u64(vreinterpret_u64_u8(vld1_u8(w)), vshift); w += 8;
+        const uint8x8x2_t vxy = vzip_u8(vreinterpret_u8_u64(vx), vreinterpret_u8_u64(vy));
+        const uint8x8x2_t vzw = vzip_u8(vreinterpret_u8_u64(vz), vreinterpret_u8_u64(vw));
+        const uint16x4x2_t vxyzw_lo = vzip_u16(vreinterpret_u16_u8(vxy.val[0]), vreinterpret_u16_u8(vzw.val[0]));
+        const uint16x4x2_t vxyzw_hi = vzip_u16(vreinterpret_u16_u8(vxy.val[1]), vreinterpret_u16_u8(vzw.val[1]));
+
+        uint32x2_t vxyzw0 = vreinterpret_u32_u16(vxyzw_lo.val[0]);
+        uint32x2_t vxyzw1 = vreinterpret_u32_u16(vxyzw_lo.val[1]);
+        uint32x2_t vxyzw2 = vreinterpret_u32_u16(vxyzw_hi.val[0]);
+        uint32x2_t vxyzw3 = vreinterpret_u32_u16(vxyzw_hi.val[1]);
+
+        // Store the remaining k groups, decomposed as k = 4*(k&4) + 2*(k&2) + (k&1).
+        if (k & 4) {
+          vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw0, 0);
+          output = (uint8_t*) ((uintptr_t) output + m);
+
+          vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw0, 1);
+          output = (uint8_t*) ((uintptr_t) output + m);
+
+          vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw1, 0);
+          output = (uint8_t*) ((uintptr_t) output + m);
+
+          vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw1, 1);
+          output = (uint8_t*) ((uintptr_t) output + m);
+
+          vxyzw0 = vxyzw2;
+          vxyzw1 = vxyzw3;
+        }
+
+        if (k & 2) {
+          vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw0, 0);
+          output = (uint8_t*) ((uintptr_t) output + m);
+
+          vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw0, 1);
+          output = (uint8_t*) ((uintptr_t) output + m);
+
+          vxyzw0 = vxyzw1;
+        }
+        if (k & 1) {
+          vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw0, 0);
+          output = (uint8_t*) ((uintptr_t) output + m);
+        }
+      }
+      output = (uint8_t*) ((uintptr_t) output + output_increment);
+      // Clamp so a partial final group's writes land inside the buffer.
+      if (output > last_output) {
+        output = last_output;
+      }
+    }
+  } else {
+    // Scalar transpose for n < 8: column-by-column, one row element at a time.
+    const uint8_t* i = input;
+    uint8_t* o = output;
+    size_t k = n;
+    do {
+      size_t l = m;
+      const uint8_t* ii = i++;
+      do {
+        *o++ = *ii;
+        ii += n;
+      } while (--l != 0);
+    } while (--k != 0);
+  }
+}
diff --git a/src/x8-zip/xm-scalar.c b/src/x8-zip/xm-scalar.c
new file mode 100644
index 0000000..32a8ae2
--- /dev/null
+++ b/src/x8-zip/xm-scalar.c
@@ -0,0 +1,32 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <assert.h>
+
+#include <xnnpack/zip.h>
+
+
+// Zips m channel arrays ("rows") of n bytes each into channel-interleaved
+// output: output[j*m + i] = input[i*n + j] — a plain byte transpose.
+void xnn_x8_zip_xm_ukernel__scalar(
+    size_t n,
+    size_t m,
+    const uint8_t* input,
+    uint8_t* output)
+{
+  // Both loop counters are used in do-while form, so both must be non-zero;
+  // m >= 4 matches the contract of the SIMD variants of this kernel.
+  assert(n != 0);
+  assert(m >= 4);
+
+  size_t k = n;
+  do {
+    size_t l = m;
+    // Walk one output column: element j of every row, stride n apart.
+    const uint8_t* input_column = input++;
+    do {
+      *output++ = *input_column;
+      input_column = (uint8_t*) ((uintptr_t) input_column + n);
+    } while (--l != 0);
+    k -= sizeof(uint8_t);
+  } while (k != 0);
+}
diff --git a/src/x8-zip/xm-sse2.c b/src/x8-zip/xm-sse2.c
new file mode 100644
index 0000000..30971f9
--- /dev/null
+++ b/src/x8-zip/xm-sse2.c
@@ -0,0 +1,208 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <emmintrin.h>
+
+#include <xnnpack/zip.h>
+
+
+// Zips m channel arrays ("rows") of n bytes each into channel-interleaved
+// output: output[j*m + i] = input[i*n + j]. Rows are processed in groups of
+// four; when m is not a multiple of 4 the last group is clamped to overlap
+// the previous one, rewriting some outputs with identical values.
+// Presumably requires m >= 4 (the sibling scalar kernel asserts it) — the
+// (m - 4) in last_output would wrap otherwise; TODO(review) confirm.
+void xnn_x8_zip_xm_ukernel__sse2(
+    size_t n,
+    size_t m,
+    const uint8_t* input,
+    uint8_t* output)
+{
+  const uint8_t* w = input;
+  // Advances w from the first row of one 4-row group to the 4th row of the next.
+  const size_t input_increment = n * 3;
+  // After a group writes n columns (output advances by m each column), this
+  // rewinds it by m*n - 4, i.e. net +4 bytes: the next group's column offset.
+  const size_t output_increment = 4 - m * n;
+  const uint8_t* last_input = w + n * (m - 1);
+  uint8_t* last_output = (uint8_t*) ((uintptr_t) output + (m - 4));
+
+  if (n >= 8) {
+    for (size_t i = 0; i < m; i += 4) {
+      size_t k = n;
+      w = (const uint8_t*) ((uintptr_t) w + input_increment);
+      // Clamp so a partial final group overlaps the previous group's rows.
+      if (w >= last_input) {
+        w = last_input;
+      }
+      // x, y, z are the three rows preceding w.
+      const uint8_t* z = (const uint8_t*) ((uintptr_t) w - n);
+      const uint8_t* y = (const uint8_t*) ((uintptr_t) z - n);
+      const uint8_t* x = (const uint8_t*) ((uintptr_t) y - n);
+      while (k >= 16) {
+        const __m128i vx = _mm_loadu_si128((const __m128i*) x);
+        x += 16;
+        const __m128i vy = _mm_loadu_si128((const __m128i*) y);
+        y += 16;
+        const __m128i vz = _mm_loadu_si128((const __m128i*) z);
+        z += 16;
+        const __m128i vw = _mm_loadu_si128((const __m128i*) w);
+        w += 16;
+        // Two unpack levels produce 4-byte (x,y,z,w) groups in 32-bit lanes.
+        const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy);
+        const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy);
+        const __m128i vzw_lo = _mm_unpacklo_epi8(vz, vw);
+        const __m128i vzw_hi = _mm_unpackhi_epi8(vz, vw);
+        __m128i vxyzw0 = _mm_unpacklo_epi16(vxy_lo, vzw_lo);
+        __m128i vxyzw1 = _mm_unpackhi_epi16(vxy_lo, vzw_lo);
+        __m128i vxyzw2 = _mm_unpacklo_epi16(vxy_hi, vzw_hi);
+        __m128i vxyzw3 = _mm_unpackhi_epi16(vxy_hi, vzw_hi);
+
+        // Each 32-bit group is stored at stride m; between stores the vector
+        // is rotated (shufflelo / unpackhi) so lane 0 always holds the next
+        // group. NOTE(review): the uint32_t* stores are unaligned in general;
+        // x86 tolerates this but it is formally UB — confirm acceptability.
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2));
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw0 = _mm_unpackhi_epi64(vxyzw0, vxyzw0);
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2));
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+        output = (uint8_t*) ((uintptr_t) output + m);
+
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw1);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw1 = _mm_shufflelo_epi16(vxyzw1, _MM_SHUFFLE(3, 2, 3, 2));
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw1);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw1 = _mm_unpackhi_epi64(vxyzw1, vxyzw1);
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw1);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw1 = _mm_shufflelo_epi16(vxyzw1, _MM_SHUFFLE(3, 2, 3, 2));
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw1);
+        output = (uint8_t*) ((uintptr_t) output + m);
+
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw2);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw2 = _mm_shufflelo_epi16(vxyzw2, _MM_SHUFFLE(3, 2, 3, 2));
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw2);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw2 = _mm_unpackhi_epi64(vxyzw2, vxyzw2);
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw2);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw2 = _mm_shufflelo_epi16(vxyzw2, _MM_SHUFFLE(3, 2, 3, 2));
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw2);
+        output = (uint8_t*) ((uintptr_t) output + m);
+
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw3);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw3 = _mm_shufflelo_epi16(vxyzw3, _MM_SHUFFLE(3, 2, 3, 2));
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw3);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw3 = _mm_unpackhi_epi64(vxyzw3, vxyzw3);
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw3);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw3 = _mm_shufflelo_epi16(vxyzw3, _MM_SHUFFLE(3, 2, 3, 2));
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw3);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        k -= 16;
+      };
+      if (k >= 8) {
+        // Half-width step: 8 columns via 64-bit loads.
+        const __m128i vx = _mm_loadl_epi64((const __m128i*) x);
+        x += 8;
+        const __m128i vy = _mm_loadl_epi64((const __m128i*) y);
+        y += 8;
+        const __m128i vz = _mm_loadl_epi64((const __m128i*) z);
+        z += 8;
+        const __m128i vw = _mm_loadl_epi64((const __m128i*) w);
+        w += 8;
+        const __m128i vxy = _mm_unpacklo_epi8(vx, vy);
+        const __m128i vzw = _mm_unpacklo_epi8(vz, vw);
+        __m128i vxyzw0 = _mm_unpacklo_epi16(vxy, vzw);
+        __m128i vxyzw1 = _mm_unpackhi_epi16(vxy, vzw);
+
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2));
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw0 = _mm_unpackhi_epi64(vxyzw0, vxyzw0);
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2));
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+        output = (uint8_t*) ((uintptr_t) output + m);
+
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw1);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw1 = _mm_shufflelo_epi16(vxyzw1, _MM_SHUFFLE(3, 2, 3, 2));
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw1);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw1 = _mm_unpackhi_epi64(vxyzw1, vxyzw1);
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw1);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        vxyzw1 = _mm_shufflelo_epi16(vxyzw1, _MM_SHUFFLE(3, 2, 3, 2));
+        *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw1);
+        output = (uint8_t*) ((uintptr_t) output + m);
+        k -= 8;
+      }
+      if (k != 0) {
+        // Remainder (k < 8): back each row pointer up by (8 - k) bytes,
+        // reload 8 bytes, and right-shift so the k fresh bytes land in the
+        // low lanes; stale bytes are discarded.
+        const size_t address_decrement = 8 - k;
+        x -= address_decrement;
+        y -= address_decrement;
+        z -= address_decrement;
+        w -= address_decrement;
+        const __m128i vshift = _mm_cvtsi32_si128(8 * address_decrement);
+
+        const __m128i vx = _mm_srl_epi64(_mm_loadl_epi64((const __m128i*) x), vshift);
+        const __m128i vy = _mm_srl_epi64(_mm_loadl_epi64((const __m128i*) y), vshift);
+        const __m128i vz = _mm_srl_epi64(_mm_loadl_epi64((const __m128i*) z), vshift);
+        const __m128i vw = _mm_srl_epi64(_mm_loadl_epi64((const __m128i*) w), vshift);
+        w += 8;
+        const __m128i vxy = _mm_unpacklo_epi8(vx, vy);
+        const __m128i vzw = _mm_unpacklo_epi8(vz, vw);
+        __m128i vxyzw0 = _mm_unpacklo_epi16(vxy, vzw);
+        __m128i vxyzw1 = _mm_unpackhi_epi16(vxy, vzw);
+
+        // Store the remaining k groups, decomposed as k = 4*(k&4) + 2*(k&2) + (k&1).
+        if (k & 4) {
+          *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+          output = (uint8_t*) ((uintptr_t) output + m);
+          vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2));
+          *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+          output = (uint8_t*) ((uintptr_t) output + m);
+          vxyzw0 = _mm_unpackhi_epi64(vxyzw0, vxyzw0);
+          *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+          output = (uint8_t*) ((uintptr_t) output + m);
+          vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2));
+          *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+          output = (uint8_t*) ((uintptr_t) output + m);
+          vxyzw0 = vxyzw1;
+        }
+
+        if (k & 2) {
+          *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+          output = (uint8_t*) ((uintptr_t) output + m);
+          vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2));
+          *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+          output = (uint8_t*) ((uintptr_t) output + m);
+          vxyzw0 = _mm_unpackhi_epi64(vxyzw0, vxyzw0);
+        }
+        if (k & 1) {
+          *((uint32_t*) output) = _mm_cvtsi128_si32(vxyzw0);
+          output = (uint8_t*) ((uintptr_t) output + m);
+        }
+      }
+      output = (uint8_t*) ((uintptr_t) output + output_increment);
+      // Clamp so a partial final group's writes land inside the buffer.
+      if (output > last_output) {
+        output = last_output;
+      }
+    }
+  } else {
+    // Scalar transpose for n < 8: column-by-column, one row element at a time.
+    const uint8_t* i = input;
+    uint8_t* o = output;
+    size_t k = n;
+    do {
+      size_t l = m;
+      const uint8_t* ii = i++;
+      do {
+        *o++ = *ii;
+        ii += n;
+      } while (--l != 0);
+    } while (--k != 0);
+  }
+}
diff --git a/src/xnnpack/AlignedAllocator.h b/src/xnnpack/AlignedAllocator.h
new file mode 100644
index 0000000..ee12481
--- /dev/null
+++ b/src/xnnpack/AlignedAllocator.h
@@ -0,0 +1,104 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <cstddef>
+#include <limits>
+#include <type_traits>
+#include <utility>
+
+#include <stdlib.h>
+
+// STL-compatible allocator that returns storage aligned to `Alignment` bytes,
+// for containers holding SIMD-width data. Uses memalign on Android and
+// posix_memalign elsewhere.
+// NOTE(review): this header uses std::addressof (declared in <memory>) and
+// std::bad_alloc (declared in <new>) but includes neither directly —
+// presumably they arrive transitively; confirm the includes.
+template <typename T, size_t Alignment>
+class AlignedAllocator;
+
+// Specialization for void, which has no references or element size; provides
+// only the pointer typedefs and rebind required by the pre-C++17 allocator API.
+template <size_t Alignment>
+class AlignedAllocator<void, Alignment> {
+ public:
+  typedef void* pointer;
+  typedef const void* const_pointer;
+  typedef void value_type;
+
+  template <class U>
+  struct rebind {
+    typedef AlignedAllocator<U, Alignment> other;
+  };
+};
+
+template <typename T, size_t Alignment>
+class AlignedAllocator {
+ public:
+  typedef T value_type;
+  typedef T* pointer;
+  typedef const T* const_pointer;
+  typedef T& reference;
+  typedef const T& const_reference;
+  typedef size_t size_type;
+  typedef ptrdiff_t difference_type;
+
+#if __cplusplus >= 201402L
+  typedef std::true_type propagate_on_container_move_assignment;
+#endif
+
+  template <class U>
+  struct rebind {
+    typedef AlignedAllocator<U, Alignment> other;
+  };
+
+ public:
+  inline AlignedAllocator() noexcept {}
+
+  // Allocators of the same Alignment are interchangeable regardless of T.
+  template <class U>
+  inline AlignedAllocator(
+      const AlignedAllocator<U, Alignment>& other) noexcept {}
+
+  inline size_type max_size() const noexcept {
+    return (std::numeric_limits<size_type>::max() - size_type(Alignment)) /
+        sizeof(T);
+  }
+
+  inline pointer address(reference x) const noexcept {
+    return std::addressof(x);
+  }
+
+  inline const_pointer address(const_reference x) const noexcept {
+    return std::addressof(x);
+  }
+
+  // Allocates n * sizeof(T) bytes aligned to Alignment. Throws std::bad_alloc
+  // on failure when exceptions are enabled; with -fno-exceptions the null
+  // (Android) or uninitialized-on-failure pointer is returned to the caller.
+  inline pointer allocate(
+      size_type n,
+      typename AlignedAllocator<void, Alignment>::const_pointer hint = 0) {
+#if defined(__ANDROID__)
+    void* memory = memalign(Alignment, n * sizeof(T));
+    if (memory == 0) {
+#if !defined(__GNUC__) || defined(__EXCEPTIONS)
+      throw std::bad_alloc();
+#endif
+    }
+#else
+    void* memory = nullptr;
+    if (posix_memalign(&memory, Alignment, n * sizeof(T)) != 0) {
+#if !defined(__GNUC__) || defined(__EXCEPTIONS)
+      throw std::bad_alloc();
+#endif
+    }
+#endif
+    return static_cast<pointer>(memory);
+  }
+
+  // memalign/posix_memalign results may be released with free().
+  inline void deallocate(pointer p, size_type n) noexcept {
+    free(static_cast<void*>(p));
+  }
+
+  template <class U, class... Args>
+  inline void construct(U* p, Args&&... args) {
+    ::new (static_cast<void*>(p)) U(std::forward<Args>(args)...);
+  }
+
+  template <class U>
+  inline void destroy(U* p) {
+    p->~U();
+  }
+};
diff --git a/src/xnnpack/allocator.h b/src/xnnpack/allocator.h
new file mode 100644
index 0000000..303aa37
--- /dev/null
+++ b/src/xnnpack/allocator.h
@@ -0,0 +1,47 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdlib.h>
+#include <string.h>
+#ifdef __ANDROID__
+ #include <malloc.h>
+#endif
+
+#include <cpuinfo.h>
+
// Some toolchains do not expose posix_memalign in <stdlib.h> under strict
// standard modes; declare it explicitly so the fallback path below compiles.
extern int posix_memalign(void **memptr, size_t alignment, size_t size);


// Alignment, in bytes, of all memory blocks handed out by this allocator.
#define XNN_ALLOCATION_ALIGNMENT 16


// Allocates memory_size bytes aligned to XNN_ALLOCATION_ALIGNMENT.
// Returns NULL on failure. Release with xnn_release_memory().
inline static void* xnn_allocate_memory(size_t memory_size) {
#if CPUINFO_ARCH_ASMJS || CPUINFO_ARCH_WASM
  // Emscripten/WebAssembly builds rely on plain malloc — presumably its
  // alignment suffices there (no posix_memalign/memalign available).
  return malloc(memory_size);
#elif defined(__ANDROID__)
  // Android: memalign is always available, posix_memalign may not be.
  return memalign(XNN_ALLOCATION_ALIGNMENT, memory_size);
#else
  void* memory_ptr = NULL;
  if (posix_memalign(&memory_ptr, XNN_ALLOCATION_ALIGNMENT, memory_size) != 0) {
    return NULL;
  }
  return memory_ptr;
#endif
}
+
// Allocates an aligned block of memory_size bytes and fills it with zeroes.
// Returns NULL if the underlying allocation fails.
inline static void* xnn_allocate_zero_memory(size_t memory_size) {
  void* memory_ptr = xnn_allocate_memory(memory_size);
  if (memory_ptr == NULL) {
    return NULL;
  }
  // memset returns its destination argument, so this is the same pointer.
  return memset(memory_ptr, 0, memory_size);
}
+
// Releases a block previously obtained from xnn_allocate_memory or
// xnn_allocate_zero_memory. Passing NULL is a safe no-op (free semantics).
static inline void xnn_release_memory(void* memory_ptr) {
  free(memory_ptr);
}
diff --git a/src/xnnpack/argmaxpool.h b/src/xnnpack/argmaxpool.h
new file mode 100644
index 0000000..5b9776d
--- /dev/null
+++ b/src/xnnpack/argmaxpool.h
@@ -0,0 +1,60 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
// Declares a unipass ("up") F32 argmax-pooling micro-kernel: the whole
// pooling window is reduced in a single pass, writing pooled values to y and
// the winning element indices to i.
// NOTE(review): parameter meanings inferred from naming used across this
// library (n = output pixels, ks = window size, kc = channels, x =
// indirection buffer of input pointers) — confirm against the kernel sources.
#define DECLARE_F32_ARGMAXPOOL_UNIPASS_UKERNEL_FUNCTION(fn_name) \
  XNN_INTERNAL void fn_name( \
      size_t n, \
      size_t ks, \
      size_t kc, \
      const float** x, \
      float* y, \
      uint32_t* i, \
      size_t x_increment, \
      size_t y_increment, \
      const union xnn_f32_output_params* params);

DECLARE_F32_ARGMAXPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_argmaxpool_ukernel_up4__psimd)
DECLARE_F32_ARGMAXPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_argmaxpool_ukernel_up4__scalar)
DECLARE_F32_ARGMAXPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_argmaxpool_ukernel_up4__sse2)
DECLARE_F32_ARGMAXPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_argmaxpool_ukernel_up9__psimd)
DECLARE_F32_ARGMAXPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_argmaxpool_ukernel_up9__scalar)
DECLARE_F32_ARGMAXPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_argmaxpool_ukernel_up9__sse2)


// Declares a multipass ("mp") F32 argmax-pooling micro-kernel: large pooling
// windows are reduced in several passes using the accumulation buffers ab
// (running maxima) and ib (running indices) between passes.
// NOTE(review): buffer roles inferred from signature — confirm in sources.
#define DECLARE_F32_ARGMAXPOOL_MULTIPASS_UKERNEL_FUNCTION(fn_name) \
  XNN_INTERNAL void fn_name( \
      size_t n, \
      size_t ks, \
      size_t kc, \
      const float** x, \
      float* ab, \
      uint32_t* ib, \
      float* y, \
      uint32_t* i, \
      size_t x_increment, \
      size_t y_increment, \
      const union xnn_f32_output_params* params);

DECLARE_F32_ARGMAXPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_argmaxpool_ukernel_mp9p8q__psimd)
DECLARE_F32_ARGMAXPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_argmaxpool_ukernel_mp9p8q__scalar)
DECLARE_F32_ARGMAXPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_argmaxpool_ukernel_mp9p8q__sse2)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/assembly.h b/src/xnnpack/assembly.h
new file mode 100644
index 0000000..4ed7270
--- /dev/null
+++ b/src/xnnpack/assembly.h
@@ -0,0 +1,32 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
/*
 * Assembler macros for opening/closing function symbols portably across ELF
 * (Linux/Android) and Mach-O (Apple) object formats. Intended for inclusion
 * from .S files processed by the C preprocessor, hence C-style comments
 * (stripped before the assembler runs).
 */
#ifdef __ELF__
  /* Emit a 16-byte-aligned global function symbol with ELF type metadata. */
  .macro BEGIN_FUNCTION name
    .text
    .p2align 4
    .global \name
    .type \name, %function
    \name:
  .endm

  /* Record the symbol size (ELF-only directive). */
  .macro END_FUNCTION name
    .size \name, .-\name
  .endm
#elif defined(__MACH__)
  /* Mach-O symbols carry a leading underscore; private_extern keeps the
     symbol out of the exported symbol table. */
  .macro BEGIN_FUNCTION name
    .text
    .p2align 4
    .global _\name
    .private_extern _\name
    _\name:
  .endm

  /* Mach-O has no .size directive, so nothing to emit here. */
  .macro END_FUNCTION name
  .endm
#endif
diff --git a/src/xnnpack/avgpool.h b/src/xnnpack/avgpool.h
new file mode 100644
index 0000000..5fd51b9
--- /dev/null
+++ b/src/xnnpack/avgpool.h
@@ -0,0 +1,96 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
// Declares a multipass ("mp") F32 average-pooling micro-kernel: large
// windows are accumulated through `buffer` across several passes; `zero`
// points to a zero vector substituted for out-of-bounds input rows.
// NOTE(review): parameter meanings inferred from the naming used throughout
// this library — confirm against the kernel implementations.
#define DECLARE_F32_AVGPOOL_MULTIPASS_UKERNEL_FUNCTION(fn_name) \
  XNN_INTERNAL void fn_name( \
      size_t n, \
      size_t ks, \
      size_t kc, \
      const float** x, \
      const float* zero, \
      float* buffer, \
      float* y, \
      size_t x_increment, \
      size_t y_increment, \
      const union xnn_f32_avgpool_params* params);

DECLARE_F32_AVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_avgpool_ukernel_mp9p8q__neon)
DECLARE_F32_AVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_avgpool_ukernel_mp9p8q__psimd)
DECLARE_F32_AVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_avgpool_ukernel_mp9p8q__scalar)
DECLARE_F32_AVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_avgpool_ukernel_mp9p8q__sse)


// Declares a unipass ("up") F32 average-pooling micro-kernel: the whole
// window is reduced in a single pass (no intermediate buffer needed).
#define DECLARE_F32_AVGPOOL_UNIPASS_UKERNEL_FUNCTION(fn_name) \
  XNN_INTERNAL void fn_name( \
      size_t n, \
      size_t ks, \
      size_t kc, \
      const float** x, \
      const float* zero, \
      float* y, \
      size_t x_increment, \
      size_t y_increment, \
      const union xnn_f32_avgpool_params* params);

DECLARE_F32_AVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_avgpool_ukernel_up9__neon)
DECLARE_F32_AVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_avgpool_ukernel_up9__psimd)
DECLARE_F32_AVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_avgpool_ukernel_up9__scalar)
DECLARE_F32_AVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_avgpool_ukernel_up9__sse)


// Q8 (quantized uint8) multipass average pooling; the int32 buffer holds
// intermediate accumulator sums before requantization.
#define DECLARE_Q8_AVGPOOL_MULTIPASS_UKERNEL_FUNCTION(fn_name) \
  XNN_INTERNAL void fn_name( \
      size_t n, \
      size_t ks, \
      size_t kc, \
      const uint8_t** x, \
      const uint8_t* zero, \
      int32_t* buffer, \
      uint8_t* y, \
      size_t x_increment, \
      size_t y_increment, \
      const union xnn_q8_avgpool_params* params);

DECLARE_Q8_AVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_q8_avgpool_ukernel_mp9p8q__neon)
DECLARE_Q8_AVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_q8_avgpool_ukernel_mp9p8q__scalar)
DECLARE_Q8_AVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_q8_avgpool_ukernel_mp9p8q__sse2)


// Q8 (quantized uint8) unipass average pooling.
#define DECLARE_Q8_AVGPOOL_UNIPASS_UKERNEL_FUNCTION(fn_name) \
  XNN_INTERNAL void fn_name( \
      size_t n, \
      size_t ks, \
      size_t kc, \
      const uint8_t** x, \
      const uint8_t* zero, \
      uint8_t* y, \
      size_t x_increment, \
      size_t y_increment, \
      const union xnn_q8_avgpool_params* params);

DECLARE_Q8_AVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_q8_avgpool_ukernel_up9__neon)
DECLARE_Q8_AVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_q8_avgpool_ukernel_up9__scalar)
DECLARE_Q8_AVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_q8_avgpool_ukernel_up9__sse2)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/clamp.h b/src/xnnpack/clamp.h
new file mode 100644
index 0000000..db19d28
--- /dev/null
+++ b/src/xnnpack/clamp.h
@@ -0,0 +1,49 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
// Declares an F32 clamp micro-kernel: copies n bytes worth of elements from
// x to y, clamping each to the [min, max] range carried in *params.
// NOTE(review): whether n counts bytes or elements is not visible here —
// confirm against the kernel implementations.
#define DECLARE_F32_CLAMP_UKERNEL_FUNCTION(fn_name) \
  XNN_INTERNAL void fn_name( \
      size_t n, \
      const float* x, \
      float* y, \
      const union xnn_f32_output_params* params);

DECLARE_F32_CLAMP_UKERNEL_FUNCTION(xnn_f32_clamp_ukernel__psimd)
DECLARE_F32_CLAMP_UKERNEL_FUNCTION(xnn_f32_clamp_ukernel__neon)
DECLARE_F32_CLAMP_UKERNEL_FUNCTION(xnn_f32_clamp_ukernel__sse)
DECLARE_F32_CLAMP_UKERNEL_FUNCTION(xnn_f32_clamp_ukernel__scalar)


// Same contract for quantized uint8 data.
#define DECLARE_U8_CLAMP_UKERNEL_FUNCTION(fn_name) \
  XNN_INTERNAL void fn_name( \
      size_t n, \
      const uint8_t* x, \
      uint8_t* y, \
      const union xnn_u8_output_params* params);

DECLARE_U8_CLAMP_UKERNEL_FUNCTION(xnn_u8_clamp_ukernel__neon)
DECLARE_U8_CLAMP_UKERNEL_FUNCTION(xnn_u8_clamp_ukernel__sse2)
DECLARE_U8_CLAMP_UKERNEL_FUNCTION(xnn_u8_clamp_ukernel__scalar)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/common.h b/src/xnnpack/common.h
new file mode 100644
index 0000000..0fc7011
--- /dev/null
+++ b/src/xnnpack/common.h
@@ -0,0 +1,67 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+
// Compiler-abstraction macros shared by all XNNPACK translation units.

// XNN_UNREACHABLE: marks a code path asserted to never execute.
#if defined(__GNUC__)
  #if defined(__clang__) || (__GNUC__ > 4 || __GNUC__ == 4 && __GNUC_MINOR__ >= 5)
    #define XNN_UNREACHABLE do { __builtin_unreachable(); } while (0)
  #else
    // Pre-4.5 GCC lacks __builtin_unreachable; trap instead.
    #define XNN_UNREACHABLE do { __builtin_trap(); } while (0)
  #endif
#elif defined(_MSC_VER)
  #define XNN_UNREACHABLE __assume(0)
#else
  #define XNN_UNREACHABLE do { } while (0)
#endif

// XNN_ALIGN: aligns a declaration to `alignment` bytes.
// Fixed: the original used the GNU attribute unconditionally, which fails to
// compile with MSVC even though the other macros in this header support it.
#if defined(_MSC_VER)
  #define XNN_ALIGN(alignment) __declspec(align(alignment))
#else
  #define XNN_ALIGN(alignment) __attribute__((__aligned__(alignment)))
#endif

// XNN_COUNT_OF: element count of a true array (not valid on pointers).
#define XNN_COUNT_OF(array) (sizeof(array) / sizeof(0[array]))

// Branch-prediction hints; value-preserving no-ops on other compilers.
#if defined(__GNUC__)
  #define XNN_LIKELY(condition) (__builtin_expect(!!(condition), 1))
  #define XNN_UNLIKELY(condition) (__builtin_expect(!!(condition), 0))
#else
  #define XNN_LIKELY(condition) (!!(condition))
  #define XNN_UNLIKELY(condition) (!!(condition))
#endif

// Fixed: compilers without __has_builtin (e.g. GCC < 10) hit a preprocessor
// syntax error on the #if below even with the short-circuiting &&, because
// `__has_builtin(...)` does not parse; provide the standard fallback.
#ifndef __has_builtin
  #define __has_builtin(builtin) 0
#endif

// TODO - __builtin_expect_with_probability for GCC 9+
#if defined(__clang__) && (__has_builtin(__builtin_unpredictable))
  #define XNN_UNPREDICTABLE(condition) (__builtin_unpredictable(!!(condition)))
#else
  #define XNN_UNPREDICTABLE(condition) (!!(condition))
#endif

// XNN_INLINE: strongly request inlining of small helpers.
#if defined(__GNUC__)
  #define XNN_INLINE inline __attribute__((__always_inline__))
#else
  #define XNN_INLINE inline
#endif

// XNN_INTERNAL: symbols shared between XNNPACK objects but not exported
// from the final binary.
#ifndef XNN_INTERNAL
  #if defined(__ELF__)
    #define XNN_INTERNAL __attribute__((__visibility__("internal")))
  #elif defined(__MACH__)
    #define XNN_INTERNAL __attribute__((__visibility__("hidden")))
  #else
    #define XNN_INTERNAL
  #endif
#endif

// XNN_PRIVATE: symbols private to the library (hidden visibility).
#ifndef XNN_PRIVATE
  #if defined(__ELF__)
    #define XNN_PRIVATE __attribute__((__visibility__("hidden")))
  #elif defined(__MACH__)
    #define XNN_PRIVATE __attribute__((__visibility__("hidden")))
  #else
    #define XNN_PRIVATE
  #endif
#endif
diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h
new file mode 100644
index 0000000..fc8693a
--- /dev/null
+++ b/src/xnnpack/compute.h
@@ -0,0 +1,709 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack.h>
+#include <xnnpack/common.h>
+#include <xnnpack/math.h>
+#include <xnnpack/params.h>
+
+
// Identifies which pthreadpool parallelization primitive an operator uses;
// the value selects the matching task pointer in struct compute_parameters.
enum xnn_parallelization_type {
  xnn_parallelization_type_invalid = 0,
  xnn_parallelization_type_1d,
  xnn_parallelization_type_1d_tile_1d,
  xnn_parallelization_type_2d,
  xnn_parallelization_type_2d_tile_1d,
  xnn_parallelization_type_2d_tile_2d,
  xnn_parallelization_type_3d_tile_2d,
  xnn_parallelization_type_4d_tile_2d,
  xnn_parallelization_type_5d_tile_2d,
  xnn_parallelization_type_6d_tile_2d,
};

// Describes how an operator's computation is dispatched to the thread pool:
// which task callback to invoke (selected by `type`), the iteration space
// (`range`, up to 6 dimensions), and the tile sizes (`tile`, up to 2).
struct compute_parameters {
  enum xnn_parallelization_type type;
  union {
    pthreadpool_task_1d_t task_1d;
    pthreadpool_task_1d_tile_1d_t task_1d_tile_1d;
    pthreadpool_task_2d_t task_2d;
    pthreadpool_task_2d_tile_1d_t task_2d_tile_1d;
    pthreadpool_task_2d_tile_2d_t task_2d_tile_2d;
    pthreadpool_task_3d_tile_2d_t task_3d_tile_2d;
    pthreadpool_task_4d_tile_2d_t task_4d_tile_2d;
    pthreadpool_task_5d_tile_2d_t task_5d_tile_2d;
    pthreadpool_task_6d_tile_2d_t task_6d_tile_2d;
  };
  size_t range[6];
  size_t tile[2];
};
+
// Context for dense matrix multiplication (GEMM), optionally grouped.
// NOTE(review): field semantics inferred from naming conventions used
// throughout this header — confirm against the compute implementations:
// *_stride fields are in bytes; the g-prefixed strides step between groups;
// log2_csize is the log2 of the output element size.
struct gemm_context {
  size_t k_scaled;
  const void* a;
  size_t a_stride;
  const void* packed_w;
  size_t w_stride;
  size_t wg_stride;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t cg_stride;
  uint32_t log2_csize;
  xnn_gemm_ukernel_function ukernel;
  union {
    union xnn_q8_gemm_params q8;
    union xnn_f32_output_params f32;
  } params;
};

// The [restrict static 1] parameter syntax is C-only, so these task-function
// prototypes are hidden from C++ translation units.
#ifndef __cplusplus
  // Grouped GEMM task: one (group, M-tile, N-tile) block per invocation.
  XNN_PRIVATE void xnn_compute_ggemm(
      const struct gemm_context context[restrict static 1],
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  // Non-grouped GEMM task: one (M-tile, N-tile) block per invocation.
  XNN_PRIVATE void xnn_compute_gemm(
      const struct gemm_context context[restrict static 1],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);
#endif

// Context for Sparse Matrix-Dense Matrix Multiplication.
// C [MxN] := A [MxK] * B [KxN] + bias [N]
// A and C are dense matrices with row-major storage, B is a sparse matrix.
struct spmm_context {
  // N dimension of the B and C matrices.
  // Corresponds to number of output channels in 1x1 convolution.
  size_t n;
  // Input matrix A.
  const void* a;
  // Packed bias elements and non-zero filter elements.
  const void* packed_weights;
  // Input pointer increments, in bytes, after each processed non-zero weight.
  const int32_t* input_increments;
  // Number of non-zero filter elements per each N (output channel) dimension.
  const uint32_t* output_channel_nonzeros;
  // Output matrix C.
  void* c;
  // Stride, in bytes, between matrices A corresponding to different images in batched 1x1 Convolution
  size_t batched_a_stride;
  // Stride, in bytes, between matrices C corresponding to different images in batched 1x1 Convolution
  size_t batched_c_stride;
  // Micro-kernel function pointer.
  xnn_spmm_ukernel_function ukernel;
  // Output activation parameters.
  union {
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  // SPMM task: processes a slice of output rows for one image in the batch.
  XNN_PRIVATE void xnn_compute_spmm(
      const struct spmm_context context[restrict static 1],
      size_t batch_index,
      size_t mr_block_start,
      size_t mr_block_size);
#endif
+
// Context for indirect GEMM (convolution via an indirection buffer of input
// pointers). NOTE(review): stride semantics inferred from the naming scheme
// (g* = per-group, b* = per-batch-image, all in bytes) — confirm in sources.
struct igemm_context {
  size_t ks;
  size_t ks_scaled;
  size_t kc;
  size_t w_stride;
  const void** indirect_a;
  size_t a_offset;
  void* zero;
  const void* packed_w;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  xnn_igemm_ukernel_function ukernel;
  union {
    union xnn_q8_gemm_params q8;
    union xnn_f32_output_params f32;
  } params;
};

// C-only prototypes ([restrict static 1] is not valid C++).
#ifndef __cplusplus
  // Grouped indirect GEMM task.
  XNN_PRIVATE void xnn_compute_gigemm(
      const struct igemm_context context[restrict static 1],
      size_t batch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  // Non-grouped indirect GEMM task.
  XNN_PRIVATE void xnn_compute_igemm(
      const struct igemm_context context[restrict static 1],
      size_t batch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);
#endif

// Context for subconvolution (deconvolution decomposed into sub-kernels);
// per-subkernel parameters live in the subconvolution_params array.
struct subconv_context {
  const struct subconvolution_params* subconvolution_params;
  size_t kc;
  size_t a_offset;
  void* zero;
  size_t cx_stride;
  size_t cy_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  xnn_igemm_ukernel_function ukernel;
  union {
    union xnn_q8_gemm_params q8;
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  // Grouped subconvolution task.
  XNN_PRIVATE void xnn_compute_gsubconv2d(
      const struct subconv_context context[restrict static 1],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);

  // Non-grouped subconvolution task.
  XNN_PRIVATE void xnn_compute_subconv2d(
      const struct subconv_context context[restrict static 1],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);
#endif
+
// Context for direct 2D convolution (no im2col/indirection); currently only
// the HWC-to-SpCHW micro-kernel variant is wired in.
struct dconv2d_context {
  size_t input_height;
  size_t input_width;
  const void* input;
  size_t input_batch_stride;
  const void* zero;
  const void* packed_weights;
  void* output;
  size_t output_batch_stride;
  size_t input_padding_top;
  size_t output_channels;
  size_t output_height_stride;
  size_t output_channel_stride;
  union {
    xnn_conv_hwc2spchw_ukernel_function hwc2spchw_ukernel;
  };
  union {
    union xnn_f32_output_params f32;
  } params;
};

// C-only prototypes ([restrict static 1] is not valid C++).
#ifndef __cplusplus
  // Processes a slice of output rows for one image.
  XNN_PRIVATE void xnn_compute_dconv2d_hwc2spchw(
      const struct dconv2d_context context[restrict static 1],
      size_t batch_index,
      size_t output_y_start,
      size_t output_y_slice);
#endif

// Context for depthwise convolution driven through an indirection buffer.
struct dwconv_context {
  size_t groups;
  const void** indirection_buffer;
  size_t indirection_buffer_row_stride;
  size_t indirection_buffer_col_stride;
  const void* packed_weights;
  void* output;
  size_t output_width;
  size_t output_row_stride;
  size_t output_col_increment;
  union {
    union xnn_q8_gemm_params q8;
    union xnn_f32_output_params f32;
  } params;
  union {
    xnn_dwconv_up_ukernel_function unipass_ukernel;
  };
};

#ifndef __cplusplus
  // Processes one output row with the unipass depthwise kernel.
  XNN_PRIVATE void xnn_compute_dwconv_unipass(
      const struct dwconv_context context[restrict static 1],
      size_t output_y);
#endif

// Context for depthwise convolution in SpCHW layout, parallelized over
// (batch, channel) pairs.
struct dwconv2d_context {
  size_t output_height;
  size_t input_width;
  const void* input;
  size_t input_channel_stride;
  size_t input_batch_stride;
  const void* packed_weights;
  size_t weights_channel_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  size_t input_tuple_stride;
  size_t output_tuple_stride;
  size_t input_pixel_stride;
  size_t output_pixel_stride;
  union {
    union xnn_f32_spchw_params f32;
  } params;
  union {
    xnn_dwconv_spchw_ukernel_function spchw_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv2d_spchw(
      const struct dwconv2d_context context[restrict static 1],
      size_t batch_index,
      size_t channel);
#endif
+
// Context for max pooling; input rows are reached through an indirection
// buffer, and one task invocation handles one output row of one image.
struct max_pooling_context {
  const void** indirect_input;
  size_t indirect_input_batch_stride;
  size_t indirect_input_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_u8_output_params u8;
    union xnn_f32_output_params f32;
  } params;
  xnn_maxpool_ukernel_function ukernel;
};

// C-only prototypes ([restrict static 1] is not valid C++).
#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_max_pooling(
      const struct max_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);
#endif

// Context for max-unpooling: scatters input values to the output positions
// recorded in `index`; `fill_value` initializes positions not written to.
struct unpooling_context {
  const void* input;
  size_t input_height_stride;
  size_t input_width_stride;
  const uint32_t* index;
  size_t index_height_stride;
  size_t index_width_stride;
  void** indirect_output;
  size_t indirect_output_height_stride;
  size_t indirect_output_width_stride;
  size_t pooling_size;
  size_t channels;
  uint32_t fill_value;
  xnn_unpool_ukernel_function ukernel;
};

#ifndef __cplusplus
  // One task invocation per input pixel (y, x).
  XNN_PRIVATE void xnn_compute_unpooling(
      const struct unpooling_context context[restrict static 1],
      size_t input_y,
      size_t input_x);
#endif

// Context for argmax pooling: produces both the pooled values (`output`)
// and the index of the winning element within the window (`index`).
struct argmax_pooling_context {
  const void** indirect_input;
  size_t indirect_input_batch_stride;
  size_t indirect_input_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  uint32_t* index;
  size_t index_batch_stride;
  size_t index_height_stride;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_f32_output_params f32;
  } params;
  union {
    xnn_argmaxpool_up_ukernel_function unipass_ukernel;
    xnn_argmaxpool_mp_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  // Unipass variant: whole pooling window reduced in one pass.
  XNN_PRIVATE void xnn_compute_argmax_pooling_unipass(
      const struct argmax_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);

  // Multipass variant: large windows reduced with intermediate buffers.
  XNN_PRIVATE void xnn_compute_argmax_pooling_multipass(
      const struct argmax_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);
#endif
+
// Context for average pooling with a fixed divisor; `zero` is substituted
// for out-of-bounds (padding) input rows.
struct average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_batch_stride;
  size_t indirect_input_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_q8_avgpool_params q8;
    union xnn_f32_avgpool_params f32;
  } params;
  union {
    xnn_avgpool_up_ukernel_function unipass_ukernel;
    xnn_avgpool_mp_ukernel_function multipass_ukernel;
  };
};

// C-only prototypes ([restrict static 1] is not valid C++).
#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_average_pooling_unipass(
      const struct average_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_average_pooling_multipass(
      const struct average_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);
#endif

// Context for average pooling with per-pixel divisors (needed when padding
// makes the effective window size vary); divisors come from pixelwise_buffer.
struct pixelwise_average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_batch_stride;
  size_t indirect_input_height_stride;
  const void* pixelwise_buffer;
  size_t pixelwise_buffer_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_u8_output_params u8;
    union xnn_f32_output_params f32;
  } params;
  union {
    xnn_pavgpool_up_ukernel_function unipass_ukernel;
    xnn_pavgpool_mp_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_unipass(
      const struct pixelwise_average_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass(
      const struct pixelwise_average_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);
#endif

// Context for global average pooling (reduce all spatial positions per
// channel); one task invocation per image in the batch.
struct global_average_pooling_context {
  const void* input;
  const void* zero;
  size_t input_pixel_stride;
  size_t input_batch_stride;
  size_t input_elements;
  size_t channels;
  void* output;
  size_t output_batch_stride;
  union {
    union xnn_q8_avgpool_params q8;
    union xnn_f32_avgpool_params f32;
  } params;
  union {
    xnn_gavgpool_up_ukernel_function unipass_ukernel;
    xnn_gavgpool_mp_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_unipass(
      const struct global_average_pooling_context context[restrict static 1],
      size_t batch_index);

  XNN_PRIVATE void xnn_compute_global_average_pooling_multipass(
      const struct global_average_pooling_context context[restrict static 1],
      size_t batch_index);
#endif

// Global average pooling in SpNCHW layout, parallelized over channel slices.
struct global_average_pooling_spnchw_context {
  size_t input_elements;
  const void* input;
  size_t input_channel_stride;
  size_t input_batch_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  xnn_gavgpool_spchw_ukernel_function ukernel;
  union {
    union xnn_f32_gavgpool_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_spnchw(
      const struct global_average_pooling_spnchw_context context[restrict static 1],
      size_t batch_index,
      size_t channels_start,
      size_t channels_slice);
#endif
+
+struct add_strided_context {
+ size_t n;
+ const void* a;
+ size_t a_stride;
+ const void* b;
+ size_t b_stride;
+ const void* y;
+ size_t y_stride;
+ union {
+ union xnn_q8_add_params q8;
+ union xnn_f32_output_params f32;
+ } params;
+ xnn_vadd_ukernel_function ukernel;
+};
+
+#ifndef __cplusplus
+ XNN_PRIVATE void xnn_compute_add_strided(
+ const struct add_strided_context context[restrict static 1],
+ size_t batch_index,
+ size_t batch_range);
+#endif
+
// Context for element-wise addition over one contiguous span: y := a + b;
// tasks receive a byte offset and size into the span.
struct add_contiguous_context {
  const void* a;
  const void* b;
  void* y;
  union {
    union xnn_q8_add_params q8;
    union xnn_f32_output_params f32;
  } params;
  xnn_vadd_ukernel_function ukernel;
};

// C-only prototypes ([restrict static 1] is not valid C++).
#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_add_contiguous(
      const struct add_contiguous_context context[restrict static 1],
      size_t offset,
      size_t size);
#endif

// Context for channel shuffle; `fixed_ukernel` handles compile-time group
// counts, `variable_ukernel` the general case.
// NOTE(review): exact meaning of n/m (channels per group vs. group count)
// is not visible here — confirm against the operator implementation.
struct channel_shuffle_context {
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  size_t n;
  size_t m;
  union {
    xnn_zipc_ukernel_function fixed_ukernel;
    xnn_zipv_ukernel_function variable_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_channel_shuffle_fixed(
      const struct channel_shuffle_context context[restrict static 1],
      size_t index);

  XNN_PRIVATE void xnn_compute_channel_shuffle_variable(
      const struct channel_shuffle_context context[restrict static 1],
      size_t index);
#endif

// Context for 8-bit table lookup applied row by row (t = lookup table).
struct lut_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_strided(
      const struct lut_strided_context context[restrict static 1],
      size_t batch_index);
#endif

// Context for 8-bit table lookup over one contiguous span.
struct lut_contiguous_context {
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_contiguous(
      const struct lut_contiguous_context context[restrict static 1],
      size_t offset,
      size_t size);
#endif
+
// Context for a unary element-wise ("univector") operation applied row by
// row (rows of n elements, strided).
struct univector_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_univector_ukernel_function ukernel;
  union {
    union xnn_u8_output_params u8_output;
    union xnn_f32_output_params f32_output;
    union xnn_f32_hswish_params f32_hswish;
  } params;
};

// C-only prototypes ([restrict static 1] is not valid C++).
#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_strided(
      const struct univector_strided_context context[restrict static 1],
      size_t batch_index,
      size_t batch_range);
#endif

// Context for a unary element-wise operation over one contiguous span.
struct univector_contiguous_context {
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_univector_ukernel_function ukernel;
  union {
    union xnn_u8_output_params u8_output;
    union xnn_f32_output_params f32_output;
    union xnn_f32_hswish_params f32_hswish;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_contiguous(
      const struct univector_contiguous_context context[restrict static 1],
      size_t offset,
      size_t size);
#endif

// Context for PReLU: y = x for positive x, y = w * x otherwise
// (w = per-channel slopes).
// NOTE(review): semantics inferred from the PReLU operator name — confirm.
struct prelu_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_prelu_ukernel_function ukernel;
  union xnn_f32_output_params params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_prelu(
      const struct prelu_context context[restrict static 1],
      size_t batch_start,
      size_t batch_range);
#endif

// Context for fused multiply-add with per-channel constants
// (y = x * scale + bias, packed in w).
// NOTE(review): packing of w inferred from the vmulcaddc name — confirm.
struct vmulcaddc_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_vmulcaddc_ukernel_function ukernel;
  union {
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_vmulcaddc(
      const struct vmulcaddc_context context[restrict static 1],
      size_t batch_start,
      size_t batch_size);
#endif

// Context for padding along the channel dimension; l/r are presumably the
// leading/trailing pad counts and c the fill value — confirm against the
// pad micro-kernel.
struct channel_pad_context {
  size_t n;
  size_t l;
  size_t r;
  uint32_t c;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_pad_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_channel_pad(
      const struct channel_pad_context context[restrict static 1],
      size_t batch_start,
      size_t batch_range);
#endif

// Context for quantized soft(arg)max: a running-max pass (rmax_ukernel)
// followed by a table-lookup-and-normalize pass (lut_norm_ukernel) using the
// precomputed exponent table t.
struct u8_softargmax_context {
  size_t n;
  const uint8_t* x;
  size_t x_stride;
  const uint32_t* t;
  uint8_t* y;
  size_t y_stride;
  xnn_u8_rmax_ukernel_function rmax_ukernel;
  xnn_u8_lut32norm_ukernel_function lut_norm_ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_u8_softargmax(
      const struct u8_softargmax_context context[restrict static 1],
      size_t batch_index);
#endif
diff --git a/src/xnnpack/conv.h b/src/xnnpack/conv.h
new file mode 100644
index 0000000..c1bdec3
--- /dev/null
+++ b/src/xnnpack/conv.h
@@ -0,0 +1,63 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
// Declares a direct F32 convolution micro-kernel for HWC-layout input and
// output; processes output rows in [output_y_start, output_y_end), with
// `zero` substituted for rows inside the top padding.
// NOTE(review): parameter roles inferred from naming — confirm in kernels.
#define DECLARE_F32_CONV_HWC_UKERNEL_FUNCTION(fn_name) \
  XNN_INTERNAL void fn_name( \
      size_t input_height, \
      size_t input_width, \
      size_t output_y_start, \
      size_t output_y_end, \
      const float* input, \
      const float* zero, \
      const float* weights, \
      float* output, \
      size_t input_padding_top, \
      size_t output_channels, \
      size_t output_height_stride, \
      size_t output_width_stride, \
      const union xnn_f32_output_params* params);

DECLARE_F32_CONV_HWC_UKERNEL_FUNCTION(xnn_f32_conv_hwc_ukernel_3x3s2p1c3x8__neonfma_2x2)
DECLARE_F32_CONV_HWC_UKERNEL_FUNCTION(xnn_f32_conv_hwc_ukernel_3x3s2p1c3x4__neonfma_2x2)


// Same as above but producing SpCHW-layout output (note the last stride is
// per-channel instead of per-width-pixel).
#define DECLARE_F32_CONV_HWC2SPCHW_UKERNEL_FUNCTION(fn_name) \
  XNN_INTERNAL void fn_name( \
      size_t input_height, \
      size_t input_width, \
      size_t output_y_start, \
      size_t output_y_end, \
      const float* input, \
      const float* zero, \
      const float* weights, \
      float* output, \
      size_t input_padding_top, \
      size_t output_channels, \
      size_t output_height_stride, \
      size_t output_channel_stride, \
      const union xnn_f32_output_params* params);

DECLARE_F32_CONV_HWC2SPCHW_UKERNEL_FUNCTION(xnn_f32_conv_hwc2spchw_ukernel_3x3s2p1c3x4__neonfma_2x2)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/dwconv.h b/src/xnnpack/dwconv.h
new file mode 100644
index 0000000..dc52a61
--- /dev/null
+++ b/src/xnnpack/dwconv.h
@@ -0,0 +1,88 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t channels, \
+ size_t output_width, \
+ const float** input, \
+ const float* weights, \
+ float* output, \
+ size_t input_stride, \
+ size_t output_increment, \
+ const union xnn_f32_output_params* params);
+
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up1x25__scalar)
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up1x4__scalar)
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up1x9__scalar)
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up4x25__psimd)
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up4x25__sse)
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up4x4__psimd)
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up4x4__sse)
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up4x9__aarch64_neonfma)
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up4x9__aarch64_neonfma_cortex_a55)
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up4x9__neon)
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up4x9__neonfma)
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up4x9__psimd)
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up4x9__sse)
+DECLARE_F32_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_f32_dwconv_ukernel_up8x9__neonfma)
+
+
+#define DECLARE_Q8_DWCONV_UNIPASS_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t channels, \
+ size_t output_width, \
+ const uint8_t** input, \
+ const void* weights, \
+ uint8_t* output, \
+ size_t input_stride, \
+ size_t output_increment, \
+ const union xnn_q8_gemm_params* params);
+
+DECLARE_Q8_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_q8_dwconv_ukernel_up1x9__scalar)
+DECLARE_Q8_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_q8_dwconv_ukernel_up8x9__aarch32_neon)
+DECLARE_Q8_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_q8_dwconv_ukernel_up8x9__neon)
+DECLARE_Q8_DWCONV_UNIPASS_UKERNEL_FUNCTION(xnn_q8_dwconv_ukernel_up8x9__sse2)
+
+
+#define DECLARE_F32_DWCONV_SPCHW_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t m, \
+ size_t n, \
+ const float* input, \
+ const float* weights, \
+ float* output, \
+ size_t input_tuple_stride, \
+ size_t output_tuple_stride, \
+ size_t input_height_stride, \
+ size_t output_height_stride, \
+ const union xnn_f32_spchw_params* params);
+
+DECLARE_F32_DWCONV_SPCHW_UKERNEL_FUNCTION(xnn_f32_dwconv_spchw_ukernel_3x3p1__neonfma)
+DECLARE_F32_DWCONV_SPCHW_UKERNEL_FUNCTION(xnn_f32_dwconv_spchw_ukernel_5x5p2__neonfma)
+DECLARE_F32_DWCONV_SPCHW_UKERNEL_FUNCTION(xnn_f32_dwconv_spchw_ukernel_3x3p1__sse)
+DECLARE_F32_DWCONV_SPCHW_UKERNEL_FUNCTION(xnn_f32_dwconv_spchw_ukernel_3x3s2p1__neonfma)
+DECLARE_F32_DWCONV_SPCHW_UKERNEL_FUNCTION(xnn_f32_dwconv_spchw_ukernel_5x5s2p2__neonfma)
+DECLARE_F32_DWCONV_SPCHW_UKERNEL_FUNCTION(xnn_f32_dwconv_spchw_ukernel_3x3s2p1__sse)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/gavgpool.h b/src/xnnpack/gavgpool.h
new file mode 100644
index 0000000..b567196
--- /dev/null
+++ b/src/xnnpack/gavgpool.h
@@ -0,0 +1,99 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+#define DECLARE_F32_GAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t m, \
+ size_t n, \
+ const float* x, \
+ size_t x_stride, \
+ const float* zero, \
+ float* buffer, \
+ float* y, \
+ const union xnn_f32_avgpool_params* params);
+
+DECLARE_F32_GAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_ukernel_mp7p7q__neon)
+DECLARE_F32_GAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_ukernel_mp7p7q__psimd)
+DECLARE_F32_GAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_ukernel_mp7p7q__scalar)
+DECLARE_F32_GAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_ukernel_mp7p7q__sse)
+
+
+#define DECLARE_F32_GAVGPOOL_UNIPASS_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t m, \
+ size_t n, \
+ const float* x, \
+ size_t x_stride, \
+ const float* zero, \
+ float* y, \
+ const union xnn_f32_avgpool_params* params);
+
+DECLARE_F32_GAVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_ukernel_up7__neon)
+DECLARE_F32_GAVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_ukernel_up7__psimd)
+DECLARE_F32_GAVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_ukernel_up7__scalar)
+DECLARE_F32_GAVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_gavgpool_ukernel_up7__sse)
+
+#define DECLARE_Q8_GAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t m, \
+ size_t n, \
+ const uint8_t* x, \
+ size_t x_stride, \
+ const uint8_t* zero, \
+ int32_t* buffer, \
+ uint8_t* y, \
+ const union xnn_q8_avgpool_params* params);
+
+DECLARE_Q8_GAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_q8_gavgpool_ukernel_mp7p7q__neon)
+DECLARE_Q8_GAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_q8_gavgpool_ukernel_mp7p7q__scalar)
+DECLARE_Q8_GAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_q8_gavgpool_ukernel_mp7p7q__sse2)
+
+
+#define DECLARE_Q8_GAVGPOOL_UNIPASS_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t m, \
+ size_t n, \
+ const uint8_t* x, \
+ size_t x_stride, \
+ const uint8_t* zero, \
+ uint8_t* y, \
+ const union xnn_q8_avgpool_params* params);
+
+DECLARE_Q8_GAVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_q8_gavgpool_ukernel_up7__neon)
+DECLARE_Q8_GAVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_q8_gavgpool_ukernel_up7__scalar)
+DECLARE_Q8_GAVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_q8_gavgpool_ukernel_up7__sse2)
+
+
+#define DECLARE_F32_GAVGPOOL_SPCHW_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t elements, \
+ size_t channels, \
+ const float* input, \
+ float* output, \
+ const union xnn_f32_gavgpool_params* params);
+
+DECLARE_F32_GAVGPOOL_SPCHW_UKERNEL_FUNCTION(xnn_f32_gavgpool_spchw_ukernel__neon_x4)
+DECLARE_F32_GAVGPOOL_SPCHW_UKERNEL_FUNCTION(xnn_f32_gavgpool_spchw_ukernel__sse_x4)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h
new file mode 100644
index 0000000..27f591d
--- /dev/null
+++ b/src/xnnpack/gemm.h
@@ -0,0 +1,189 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+#define DECLARE_F32_GEMM_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t mr, \
+ size_t nr, \
+ size_t k, \
+ const float* a, \
+ size_t a_stride, \
+ const float* w, \
+ float* c, \
+ size_t cm_stride, \
+ size_t cn_stride, \
+ const union xnn_f32_output_params* params);
+
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_1x4__scalar)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a57)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_1x8__neon_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_1x8__neonfma_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_1x8__psimd_loadsplat)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_1x8__psimd_splat)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_1x8__sse_dup)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_1x8__sse_load1)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_1x8s4__psimd)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_1x8s4__sse)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_2x4__scalar)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x12__neon_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x12__neonfma_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x2__neon_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x2__neonfma_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x2__scalar)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x4__scalar)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a57)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a75)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8__neon_ld128)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8__neon_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8__neonfma_ld128)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8__neonfma_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8__psimd_loadsplat)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8__psimd_splat)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8__sse_dup)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8__sse_load1)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8s4__psimd)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_4x8s4__sse)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_5x8__aarch64_neonfma_cortex_a75)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_5x8__neon_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_5x8__neonfma_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a57)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a73)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a75)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_6x8__neon_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_6x8__neonfma_ld64)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_6x8__psimd_loadsplat)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_6x8__psimd_splat)
+DECLARE_F32_GEMM_UKERNEL_FUNCTION(xnn_f32_gemm_ukernel_6x8s4__psimd)
+
+#define DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t mr, \
+ size_t nr, \
+ size_t k, \
+ const float* a, \
+ size_t a_stride, \
+ const float* w, \
+ float* c, \
+ size_t cm_stride, \
+ size_t cn_stride, \
+ const float* acc, \
+ const union xnn_f32_output_params* params);
+
+
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_1x12__aarch64_neonfma_cortex_a53)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_1x4__scalar)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_1x8__aarch64_neonfma_cortex_a57)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_1x8__aarch64_neonfma_cortex_a75)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_1x8__neon_ld64)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_1x8__neonfma_ld64)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_1x8__psimd_loadsplat)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_1x8__psimd_splat)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_1x8__sse_dup)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_1x8__sse_load1)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_1x8s4__psimd)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_1x8s4__sse)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_2x4__scalar)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x12__aarch64_neonfma_cortex_a53)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x12__neon_ld64)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x12__neonfma_ld64)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x4__scalar)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a57)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_cortex_a75)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld128)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8__aarch64_neonfma_ld64)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8__neon_ld128)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8__neon_ld64)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8__neonfma_ld128)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8__neonfma_ld64)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8__psimd_loadsplat)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8__psimd_splat)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8__sse_dup)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8__sse_load1)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8s4__psimd)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_4x8s4__sse)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_5x8__aarch64_neonfma_cortex_a75)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_5x8__neon_ld64)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_5x8__neonfma_ld64)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a57)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a73)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_cortex_a75)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld64)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_6x8__aarch64_neonfma_ld128)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_6x8__neon_ld64)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_6x8__neonfma_ld64)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_6x8__psimd_loadsplat)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_6x8__psimd_splat)
+DECLARE_F32_GEMMINC_UKERNEL_FUNCTION(xnn_f32_gemminc_ukernel_6x8s4__psimd)
+
+
+#define DECLARE_F16_GEMM_UKERNEL_FUNCTION(fn_name) \
+ void fn_name( \
+ size_t mr, \
+ size_t nr, \
+ size_t k, \
+ const void* a, \
+ size_t a_stride, \
+ const void* w, \
+ void* c, \
+ size_t cm_stride, \
+ size_t cn_stride, \
+ const struct xnn_f16_output_params* params);
+
+DECLARE_F16_GEMM_UKERNEL_FUNCTION(xnn_f16_gemm_ukernel_4x8__neonfp16arith_ld64)
+DECLARE_F16_GEMM_UKERNEL_FUNCTION(xnn_f16_gemm_ukernel_6x8__neonfp16arith_ld64)
+DECLARE_F16_GEMM_UKERNEL_FUNCTION(xnn_f16_gemm_ukernel_8x8__neonfp16arith_ld64)
+
+
+#define DECLARE_Q8_GEMM_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t mr, \
+ size_t nr, \
+ size_t k, \
+ const uint8_t* a, \
+ size_t a_stride, \
+ const void* w, \
+ uint8_t* c, \
+ size_t cm_stride, \
+ size_t cn_stride, \
+ const union xnn_q8_gemm_params* params);
+
+DECLARE_Q8_GEMM_UKERNEL_FUNCTION(xnn_q8_gemm_ukernel_2x2__scalar)
+DECLARE_Q8_GEMM_UKERNEL_FUNCTION(xnn_q8_gemm_ukernel_2x4c8__neon)
+DECLARE_Q8_GEMM_UKERNEL_FUNCTION(xnn_q8_gemm_ukernel_2x4c8__sse2)
+DECLARE_Q8_GEMM_UKERNEL_FUNCTION(xnn_q8_gemm_ukernel_3x3c8__neon)
+DECLARE_Q8_GEMM_UKERNEL_FUNCTION(xnn_q8_gemm_ukernel_4x4c2__sse2)
+DECLARE_Q8_GEMM_UKERNEL_FUNCTION(xnn_q8_gemm_ukernel_4x8__aarch32_neon)
+DECLARE_Q8_GEMM_UKERNEL_FUNCTION(xnn_q8_gemm_ukernel_4x8__neon)
+DECLARE_Q8_GEMM_UKERNEL_FUNCTION(xnn_q8_gemm_ukernel_6x4__neon)
+DECLARE_Q8_GEMM_UKERNEL_FUNCTION(xnn_q8_gemm_ukernel_8x8__aarch64_neon)
+DECLARE_Q8_GEMM_UKERNEL_FUNCTION(xnn_q8_gemm_ukernel_8x8__neon)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/hswish.h b/src/xnnpack/hswish.h
new file mode 100644
index 0000000..8d0ab93
--- /dev/null
+++ b/src/xnnpack/hswish.h
@@ -0,0 +1,35 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+#define DECLARE_F32_HSWISH_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t n, \
+ const float* x, \
+ float* y, \
+ const union xnn_f32_hswish_params* params);
+
+DECLARE_F32_HSWISH_UKERNEL_FUNCTION(xnn_f32_hswish_ukernel__psimd)
+DECLARE_F32_HSWISH_UKERNEL_FUNCTION(xnn_f32_hswish_ukernel__neon)
+DECLARE_F32_HSWISH_UKERNEL_FUNCTION(xnn_f32_hswish_ukernel__neonfma)
+DECLARE_F32_HSWISH_UKERNEL_FUNCTION(xnn_f32_hswish_ukernel__sse)
+DECLARE_F32_HSWISH_UKERNEL_FUNCTION(xnn_f32_hswish_ukernel__scalar)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/igemm.h b/src/xnnpack/igemm.h
new file mode 100644
index 0000000..4d30c6f
--- /dev/null
+++ b/src/xnnpack/igemm.h
@@ -0,0 +1,105 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+#define DECLARE_F32_IGEMM_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t mr, \
+ size_t nr, \
+ size_t kc, \
+ size_t ks, \
+ const float** a, \
+ const float* w, \
+ float* c, \
+ size_t cm_stride, \
+ size_t cn_stride, \
+ size_t a_offset, \
+ const float* zero, \
+ const union xnn_f32_output_params* params);
+
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x12__aarch64_neonfma_cortex_a53)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x4__scalar)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a57)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x8__neon_ld64)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x8__psimd_loadsplat)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x8__psimd_splat)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x8__sse_dup)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x8__sse_load1)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x8s4__psimd)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_1x8s4__sse)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_2x4__scalar)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x12__aarch64_neonfma_cortex_a53)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x12__neon_ld64)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x12__neonfma_ld64)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x2__neon_ld64)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x2__neonfma_ld64)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x2__scalar)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x2c4__psimd)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x2c4__sse)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x4__neon_ld64)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x4__neonfma_ld64)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x4__scalar)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a75)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__neon_ld128)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__neon_ld64)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__neonfma_ld128)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__neonfma_ld64)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__psimd_loadsplat)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__psimd_splat)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__sse_dup)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8__sse_load1)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8s4__psimd)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_4x8s4__sse)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_5x8__aarch64_neonfma_cortex_a75)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a57)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a73)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a75)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_6x8__neon_ld64)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_6x8__neonfma_ld64)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_6x8__psimd_loadsplat)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_6x8__psimd_splat)
+DECLARE_F32_IGEMM_UKERNEL_FUNCTION(xnn_f32_igemm_ukernel_6x8s4__psimd)
+
+
+#define DECLARE_Q8_IGEMM_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t mr, \
+ size_t nr, \
+ size_t kc, \
+ size_t ks, \
+ const uint8_t** a, \
+ const void* w, \
+ uint8_t* c, \
+ size_t cm_stride, \
+ size_t cn_stride, \
+ size_t a_offset, \
+ const uint8_t* zero, \
+ const union xnn_q8_gemm_params* params);
+
+DECLARE_Q8_IGEMM_UKERNEL_FUNCTION(xnn_q8_igemm_ukernel_2x2__scalar)
+DECLARE_Q8_IGEMM_UKERNEL_FUNCTION(xnn_q8_igemm_ukernel_4x4c2__sse2)
+DECLARE_Q8_IGEMM_UKERNEL_FUNCTION(xnn_q8_igemm_ukernel_4x8__neon)
+DECLARE_Q8_IGEMM_UKERNEL_FUNCTION(xnn_q8_igemm_ukernel_8x8__neon)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/im2col.h b/src/xnnpack/im2col.h
new file mode 100644
index 0000000..07323e3
--- /dev/null
+++ b/src/xnnpack/im2col.h
@@ -0,0 +1,37 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+XNN_INTERNAL void xnn_im2col_conv2d(
+ size_t output_height,
+ size_t output_width,
+ size_t kernel_height,
+ size_t kernel_width,
+ size_t subsampling_height,
+ size_t subsampling_width,
+ size_t dilation_height,
+ size_t dilation_width,
+ size_t input_width,
+ size_t input_padding_top,
+ size_t input_padding_left,
+ size_t group_input_channels_in_bytes,
+ size_t input_pixel_stride_in_bytes,
+ const void* input,
+ void* output);
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/indirection.h b/src/xnnpack/indirection.h
new file mode 100644
index 0000000..60be1f6
--- /dev/null
+++ b/src/xnnpack/indirection.h
@@ -0,0 +1,57 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+XNN_INTERNAL void xnn_indirection_init_conv2d(
+ xnn_operator_t op,
+ size_t output_tile_size,
+ uint32_t log2_element_size);
+
+XNN_INTERNAL void xnn_indirection_init_dwconv2d(
+ xnn_operator_t op,
+ size_t batch_start,
+ size_t step_height,
+ size_t step_width,
+ uint32_t log2_element_size);
+
+XNN_INTERNAL void xnn_indirection_init_deconv2d(
+ xnn_operator_t op,
+ size_t output_tile_size,
+ uint32_t log2_element_size);
+
+XNN_INTERNAL void xnn_indirection_init_subconv2d(
+ xnn_operator_t op,
+ size_t output_tile_size,
+ uint32_t log2_element_size);
+
+XNN_INTERNAL void xnn_indirection_init_maxpool2d(
+ xnn_operator_t op,
+ size_t batch_start,
+ size_t step_height,
+ size_t step_width,
+ uint32_t log2_element_size);
+
+XNN_INTERNAL void xnn_indirection_init_unpool2d(
+ xnn_operator_t op,
+ size_t batch_start,
+ uint32_t log2_element_size);
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/isa-checks.h b/src/xnnpack/isa-checks.h
new file mode 100644
index 0000000..0bdf97c
--- /dev/null
+++ b/src/xnnpack/isa-checks.h
@@ -0,0 +1,79 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <cpuinfo.h>
+
+
+#if CPUINFO_ARCH_PNACL || CPUINFO_ARCH_WASMSIMD
+ #define TEST_REQUIRES_PSIMD
+#else
+ #define TEST_REQUIRES_PSIMD \
+ do { \
+ if (!cpuinfo_initialize() || !(cpuinfo_has_arm_neon() || cpuinfo_has_x86_sse2())) { \
+ GTEST_SKIP(); \
+ } \
+ } while (0)
+#endif
+
+#define TEST_REQUIRES_X86_SSE \
+ do { \
+ if (!cpuinfo_initialize() || !cpuinfo_has_x86_sse()) { \
+ GTEST_SKIP(); \
+ } \
+ } while (0)
+
+#define TEST_REQUIRES_X86_SSE2 \
+ do { \
+ if (!cpuinfo_initialize() || !cpuinfo_has_x86_sse2()) { \
+ GTEST_SKIP(); \
+ } \
+ } while (0)
+
+#define TEST_REQUIRES_X86_AVX \
+ do { \
+ if (!cpuinfo_initialize() || !cpuinfo_has_x86_avx()) { \
+ GTEST_SKIP(); \
+ } \
+ } while (0)
+
+#define TEST_REQUIRES_X86_AVX2 \
+ do { \
+ if (!cpuinfo_initialize() || !cpuinfo_has_x86_avx2()) { \
+ GTEST_SKIP(); \
+ } \
+ } while (0)
+
+#define TEST_REQUIRES_X86_AVX512F \
+ do { \
+ if (!cpuinfo_initialize() || !cpuinfo_has_x86_avx512f()) { \
+ GTEST_SKIP(); \
+ } \
+ } while (0)
+
+#define TEST_REQUIRES_ARM_NEON \
+ do { \
+ if (!cpuinfo_initialize() || !cpuinfo_has_arm_neon()) { \
+ GTEST_SKIP(); \
+ } \
+ } while (0)
+
+#define TEST_REQUIRES_ARM_NEON_FMA \
+ do { \
+ if (!cpuinfo_initialize() || !cpuinfo_has_arm_neon_fma()) { \
+ GTEST_SKIP(); \
+ } \
+ } while (0)
+
+#define TEST_REQUIRES_ARM_NEON_FP16_ARITH \
+ do { \
+ if (!cpuinfo_initialize() || !cpuinfo_has_arm_neon_fp16_arith()) { \
+ GTEST_SKIP(); \
+ } \
+ } while (0)
diff --git a/src/xnnpack/log.h b/src/xnnpack/log.h
new file mode 100644
index 0000000..9eb5abf
--- /dev/null
+++ b/src/xnnpack/log.h
@@ -0,0 +1,23 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <inttypes.h>
+
+#include <clog.h>
+
+#ifndef XNN_LOG_LEVEL
+#define XNN_LOG_LEVEL CLOG_DEBUG  // default verbosity; a build may predefine XNN_LOG_LEVEL to override
+#endif
+
+CLOG_DEFINE_LOG_DEBUG(xnn_log_debug, "XNNPACK", XNN_LOG_LEVEL);  // NOTE(review): presumably expands to the xnn_log_debug() helper tagged "XNNPACK" -- see clog.h
+CLOG_DEFINE_LOG_INFO(xnn_log_info, "XNNPACK", XNN_LOG_LEVEL);
+CLOG_DEFINE_LOG_WARNING(xnn_log_warning, "XNNPACK", XNN_LOG_LEVEL);
+CLOG_DEFINE_LOG_ERROR(xnn_log_error, "XNNPACK", XNN_LOG_LEVEL);
+CLOG_DEFINE_LOG_FATAL(xnn_log_fatal, "XNNPACK", XNN_LOG_LEVEL);
diff --git a/src/xnnpack/lut.h b/src/xnnpack/lut.h
new file mode 100644
index 0000000..49b0ec4
--- /dev/null
+++ b/src/xnnpack/lut.h
@@ -0,0 +1,44 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+#define DECLARE_X8_LUT_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t n, \
+ const uint8_t* x, \
+ const uint8_t* t, \
+ uint8_t* y);
+
+DECLARE_X8_LUT_UKERNEL_FUNCTION(xnn_x8_lut_ukernel__scalar)
+
+
+#define DECLARE_U8_LUT32NORM_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t n, \
+ const uint8_t* x, \
+ const uint32_t* t, \
+ uint8_t* y);
+
+DECLARE_U8_LUT32NORM_UKERNEL_FUNCTION(xnn_u8_lut32norm_ukernel__scalar)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/math.h b/src/xnnpack/math.h
new file mode 100644
index 0000000..60e46dc
--- /dev/null
+++ b/src/xnnpack/math.h
@@ -0,0 +1,64 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <assert.h>
+
+inline static size_t min(size_t a, size_t b) {  // smaller of a and b
+  return a < b ? a : b;
+}
+
+inline static size_t max(size_t a, size_t b) {  // larger of a and b
+  return a > b ? a : b;
+}
+
+inline static size_t doz(size_t a, size_t b) {  // "difference or zero": saturating subtraction
+  return a >= b ? a - b : 0;  // clamps at 0 instead of wrapping around when b > a
+}
+
+inline static size_t divide_round_up(size_t n, size_t q) {  // ceiling division: ceil(n / q); assumes q != 0 -- TODO confirm caller contract
+  return n % q == 0 ? n / q : n / q + 1;
+}
+
+inline static size_t round_up(size_t n, size_t q) {  // smallest multiple of q that is >= n
+  return divide_round_up(n, q) * q;
+}
+
+inline static size_t round_down_po2(size_t n, size_t q) {  // largest multiple of q that is <= n; q must be a power of 2
+  assert(q != 0);
+  assert((q & (q - 1)) == 0);  // power-of-2 check: required for the mask trick below
+  return n & -q;  // for po2 q, (size_t) -q has the low log2(q) bits cleared
+}
+
+inline static size_t round_up_po2(size_t n, size_t q) {  // smallest multiple of power-of-2 q that is >= n
+  return round_down_po2(n + q - 1, q);  // n + q - 1 wraps (well-defined, unsigned) if n is within q of SIZE_MAX
+}
+
+inline static size_t subtract_modulo(size_t a, size_t b, size_t m) {  // (a - b) mod m, for a and b already reduced to [0, m)
+  assert(a < m);
+  assert(b < m);
+  return a >= b ? a - b : a - b + m;  // when a < b, a - b wraps (unsigned) and + m restores the value to [0, m)
+}
+
+inline static float math_min_f32(float a, float b) {  // minimum of two floats
+  #if defined(__wasm__)
+    return __builtin_wasm_min_f32(a, b);  // presumably lowers to the wasm f32.min instruction -- TODO confirm NaN semantics match callers' needs
+  #else
+    return a < b ? a : b;  // NOTE(review): comparison-based, not IEEE fminf -- returns b when a is NaN
+  #endif
+}
+
+inline static float math_max_f32(float a, float b) {  // maximum of two floats
+  #if defined(__wasm__)
+    return __builtin_wasm_max_f32(a, b);  // presumably lowers to the wasm f32.max instruction -- TODO confirm NaN semantics match callers' needs
+  #else
+    return a > b ? a : b;  // NOTE(review): comparison-based, not IEEE fmaxf -- returns b when a is NaN
+  #endif
+}
diff --git a/src/xnnpack/maxpool.h b/src/xnnpack/maxpool.h
new file mode 100644
index 0000000..1c134d7
--- /dev/null
+++ b/src/xnnpack/maxpool.h
@@ -0,0 +1,56 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+#define DECLARE_F32_MAXPOOL_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t n, \
+ size_t ks, \
+ size_t kc, \
+ const float** x, \
+ float* y, \
+ size_t x_increment, \
+ size_t y_increment, \
+ const union xnn_f32_output_params* params);
+
+DECLARE_F32_MAXPOOL_UKERNEL_FUNCTION(xnn_f32_maxpool_ukernel_9p8q__psimd)
+DECLARE_F32_MAXPOOL_UKERNEL_FUNCTION(xnn_f32_maxpool_ukernel_9p8q__scalar)
+DECLARE_F32_MAXPOOL_UKERNEL_FUNCTION(xnn_f32_maxpool_ukernel_9p8q__sse)
+
+
+#define DECLARE_U8_MAXPOOL_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t n, \
+ size_t ks, \
+ size_t kc, \
+ const uint8_t** x, \
+ uint8_t* y, \
+ size_t x_increment, \
+ size_t y_increment, \
+ const union xnn_u8_output_params* params);
+
+DECLARE_U8_MAXPOOL_UKERNEL_FUNCTION(xnn_u8_maxpool_ukernel_9p8q__neon)
+DECLARE_U8_MAXPOOL_UKERNEL_FUNCTION(xnn_u8_maxpool_ukernel_9p8q__sse2)
+DECLARE_U8_MAXPOOL_UKERNEL_FUNCTION(xnn_u8_maxpool_ukernel_9p8q__scalar)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
new file mode 100644
index 0000000..a34d6fd
--- /dev/null
+++ b/src/xnnpack/operator.h
@@ -0,0 +1,275 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <pthreadpool.h>
+
+#include <xnnpack/requantization.h>
+#include <xnnpack/compute.h>
+
+
// Discriminates which micro-kernel family an operator was configured with;
// selects the setup/run path taken for that operator.
enum xnn_ukernel_type {
  xnn_ukernel_type_none = 0,
  xnn_ukernel_type_add,
  xnn_ukernel_type_argmax_pooling,
  xnn_ukernel_type_average_pooling,
  xnn_ukernel_type_channel_shuffle,
  xnn_ukernel_type_clamp,
  xnn_ukernel_type_igemm,
  xnn_ukernel_type_dconv2d_hwc2spchw,
  xnn_ukernel_type_dwconv,
  xnn_ukernel_type_gemm,
  xnn_ukernel_type_global_average_pooling,
  xnn_ukernel_type_hswish,
  xnn_ukernel_type_lut,
  xnn_ukernel_type_max_pooling,
  xnn_ukernel_type_pad,
  xnn_ukernel_type_pixelwise_average_pooling,
  xnn_ukernel_type_prelu,
  xnn_ukernel_type_softargmax,
  xnn_ukernel_type_spmm,
  xnn_ukernel_type_subconv2d,
  xnn_ukernel_type_unpooling,
  xnn_ukernel_type_vmulcaddc,
};
+
// Identifies a concrete operator flavor: the operation plus its data type.
// Suffix convention: f32/q8/u8 = compute type; x8/x32 = type-agnostic element
// width; "spnchw" presumably marks the sparse/NCHW fast path --
// NOTE(review): naming decoded from the enumerators, confirm.
enum xnn_operator_type {
  xnn_operator_type_none = 0,
  xnn_operator_type_add_f32,
  xnn_operator_type_add_q8,
  xnn_operator_type_argmax_pooling_f32,
  xnn_operator_type_average_pooling_f32,
  xnn_operator_type_average_pooling_q8,
  xnn_operator_type_channel_pad_x32,
  xnn_operator_type_channel_shuffle_x8,
  xnn_operator_type_channel_shuffle_x32,
  xnn_operator_type_clamp_f32,
  xnn_operator_type_clamp_u8,
  xnn_operator_type_convolution_f32,
  xnn_operator_type_convolution_spnchw_f32,
  xnn_operator_type_convolution_q8,
  xnn_operator_type_deconvolution_f32,
  xnn_operator_type_deconvolution_q8,
  xnn_operator_type_fully_connected_f32,
  xnn_operator_type_fully_connected_q8,
  xnn_operator_type_global_average_pooling_f32,
  xnn_operator_type_global_average_pooling_q8,
  xnn_operator_type_global_average_pooling_spnchw_f32,
  xnn_operator_type_hswish_f32,
  xnn_operator_type_leaky_relu_q8,
  xnn_operator_type_max_pooling_f32,
  xnn_operator_type_max_pooling_u8,
  xnn_operator_type_prelu_f32,
  xnn_operator_type_sigmoid_q8,
  xnn_operator_type_softargmax_q8,
  xnn_operator_type_unpooling_x32,
};
+
// Direct 2D convolution micro-kernel descriptor: one of the two union
// members is populated (HWC->SpCHW layout-converting variant, or plain HWC),
// together with the output tiling the function was built for.
struct xnn_ukernel_dconv2d {
  union {
    xnn_conv_hwc2spchw_ukernel_function hwc2spchw_function;
    xnn_conv_hwc_ukernel_function hwc_function;
  };
  uint8_t output_height_tile;
  uint8_t output_channel_tile;
};
+
// Depthwise convolution micro-kernel descriptor: either a unipass ("up") or
// multipass ("mp") function is populated. NOTE(review): mr/qr are presumably
// the primary and incremental kernel-size tiles -- confirm against usage.
struct xnn_ukernel_dwconv {
  union {
    xnn_dwconv_up_ukernel_function unipass_function;
    xnn_dwconv_mp_ukernel_function multipass_function;
  };
  uint8_t mr;
  uint8_t qr;
};
+
// Direct 2D Depthwise Convolution (SpCHW layout) micro-kernel descriptor
// with its input/output width tiling factors.
struct xnn_ukernel_dwconv2d {
  union {
    xnn_dwconv_spchw_ukernel_function spchw_function;
  };
  uint8_t input_width_tile;
  uint8_t output_width_tile;
};
+
// GEMM micro-kernel descriptor: the general tile function plus a variant
// (NOTE(review): presumably specialized for mr == 1, per its name -- confirm).
// mr/nr/kr are the tile dimensions the functions were built for.
struct xnn_ukernel_gemm {
  xnn_gemm_ukernel_function default_function;
  xnn_gemm_ukernel_function mr1_function;
  uint8_t mr;
  uint8_t nr;
  uint8_t kr;
};
+
// Indirect GEMM (IGEMM) micro-kernel descriptor; mirrors xnn_ukernel_gemm
// but the functions consume an indirection buffer instead of a dense matrix.
struct xnn_ukernel_igemm {
  xnn_igemm_ukernel_function default_function;
  xnn_igemm_ukernel_function mr1_function;
  uint8_t mr;
  uint8_t nr;
  uint8_t kr;
};
+
// Sparse matrix multiplication (SpMM) micro-kernel descriptor with its row
// tiling factor.
struct xnn_ukernel_spmm {
  xnn_spmm_ukernel_function function;
  uint8_t mr;
};
+
// Per-channel multiply-add micro-kernel descriptor (NOTE(review): "vmulcaddc"
// presumably computes y = x * scale + bias per channel -- confirm).
struct xnn_ukernel_vmulcaddc {
  xnn_vmulcaddc_ukernel_function function;
  uint8_t mr;
};
+
// Micro-kernel descriptor attached to an operator: `type` is the tag that
// selects which union member (if any) is valid.
struct xnn_ukernel {
  enum xnn_ukernel_type type;
  union {
    struct xnn_ukernel_dconv2d dconv2d;
    struct xnn_ukernel_dwconv dwconv;
    struct xnn_ukernel_dwconv2d dwconv2d;
    struct xnn_ukernel_gemm gemm;
    struct xnn_ukernel_igemm igemm;
    struct xnn_ukernel_spmm spmm;
    struct xnn_ukernel_vmulcaddc vmulcaddc;
  };
};
+
// Lifecycle state of an operator between setup and run.
enum xnn_run_state {
  // Not (successfully) set up yet.
  xnn_run_state_invalid = 0,
  // Set up and ready to run.
  xnn_run_state_ready,
  // Setup determined there is no work to do (NOTE(review): presumably e.g.
  // an empty batch -- confirm); run is a no-op.
  xnn_run_state_skip,
};
+
// Per-subconvolution bookkeeping for the subconv2d (deconvolution) path:
// a deconvolution is decomposed into stride_h * stride_w subconvolutions,
// one per output-stride phase (see the *_deconv_goki_w packing functions,
// which record the `weights` pointer of each entry).
struct subconvolution_params {
  void* weights;  // start of this subconvolution's packed weights
  size_t w_stride;
  const void** indirection_buffer;
  void* output;
  size_t slice_width;
  size_t slice_height;
  size_t indirection_y_stride;
  size_t indirection_x_stride;
  /* kernel_size * mr * sizeof(void*) */
  size_t scaled_kernel_size;
};
+
// A configured XNNPACK operator instance.
// This is a "fat" struct shared by every operator type: each creation
// function fills in only the fields its operator needs, and the matching
// setup/run code reads the same subset. Fields not used by a given operator
// type are meaningless for it.
struct xnn_operator {
  size_t batch_size;
  // Implicit input padding (convolution/pooling-style operators).
  uint32_t padding_top;
  uint32_t padding_right;
  uint32_t padding_bottom;
  uint32_t padding_left;
  // Output size adjustment (deconvolution).
  uint32_t adjustment_height;
  uint32_t adjustment_width;
  uint32_t kernel_height;
  uint32_t kernel_width;
  uint32_t stride_height;
  uint32_t stride_width;
  uint32_t dilation_height;
  uint32_t dilation_width;
  uint32_t groups;
  size_t group_channels;
  size_t group_input_channels;
  size_t group_output_channels;
  size_t channels;

  // Channel-pad operator: channel counts to pad before/after, and the value
  // used to fill them.
  size_t pad_before_channels;
  size_t pad_after_channels;
  uint32_t pad_value;

  size_t input_height;
  size_t input_width;
  size_t input_pixel_stride;
  const void* input;
  const void** indirection_buffer;  // per-output-pixel input pointers
  // NOTE(review): presumably per-row input sums used by Q8 kernels -- confirm.
  void* a_sum;

  // Second input of binary operators (e.g. Add).
  size_t input2_pixel_stride;
  const void* input2;

  size_t output_height;
  size_t output_width;
  size_t output_pixel_stride;
  void* output;

  void* packed_weights;
  // Total number of non-zero kernel elements when weights use sparse representation.
  size_t num_nonzero_values;
  // Total number of non-zero kernel blocks when weights use sparse representation.
  size_t num_nonzero_blocks;
  // Total number of output channel blocks when weights use sparse representation.
  size_t num_output_channel_blocks;
  // Input channel corresponding to the first non-zero kernel element.
  size_t first_input_channel;

  // Quantization parameters (Q8/U8 operators).
  float input_scale;
  float output_scale;
  uint8_t input_zero_point;
  uint8_t kernel_zero_point;
  uint8_t output_zero_point;
  uint8_t output_min;
  uint8_t output_max;

  // Dimensions seen by the previous setup call -- NOTE(review): presumably
  // cached to skip re-initialization when setup is repeated with identical
  // shapes; confirm against the setup functions.
  size_t valid_batch_size;
  size_t last_input_height;
  size_t last_input_width;
  const void* last_input;
  void* last_output;

  // Scratch buffers owned by the operator.
  void* zero_buffer;
  void* lookup_table;
  void* pixelwise_buffer;
  struct subconvolution_params* subconvolution_buffer;

  // Micro-kernel parameters for the active operator type (one member valid,
  // selected implicitly by `type`).
  union {
    union xnn_f32_avgpool_params f32_avgpool_params;
    union xnn_f32_gavgpool_params f32_gavgpool_params;
    union xnn_f32_hswish_params f32_hswish_params;
    union xnn_f32_output_params f32_output_params;
    union xnn_f32_spchw_params f32_spchw_params;
    union xnn_q8_add_params q8_add_params;
    union xnn_q8_avgpool_params q8_avgpool_params;
    union xnn_q8_gemm_params q8_gemm_params;
    union xnn_u8_output_params u8_output_params;
  };
  enum xnn_operator_type type;
  struct xnn_ukernel ukernel;

  // Parallel-dispatch descriptors and the per-run context they reference
  // (one context member valid, selected implicitly by `type`).
  struct compute_parameters compute;
  struct compute_parameters compute2;
  union {
    struct add_contiguous_context add_contiguous;
    struct add_strided_context add_strided;
    struct argmax_pooling_context argmax_pooling;
    struct average_pooling_context average_pooling;
    struct channel_pad_context channel_pad;
    struct channel_shuffle_context channel_shuffle;
    struct dconv2d_context dconv2d;
    struct dwconv2d_context dwconv2d;
    struct dwconv_context dwconv;
    struct gemm_context gemm;
    struct global_average_pooling_context global_average_pooling;
    struct global_average_pooling_spnchw_context global_average_pooling_spnchw;
    struct igemm_context igemm;
    struct lut_contiguous_context lut_contiguous;
    struct lut_strided_context lut_strided;
    struct max_pooling_context max_pooling;
    struct pixelwise_average_pooling_context pixelwise_average_pooling;
    struct prelu_context prelu;
    struct spmm_context spmm;
    struct subconv_context subconv;
    struct u8_softargmax_context u8_softargmax;
    struct univector_contiguous_context univector_contiguous;
    struct univector_strided_context univector_strided;
    struct unpooling_context unpooling;
    struct vmulcaddc_context vmulcaddc;
  } context;

  enum xnn_run_state state;
};
diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h
new file mode 100644
index 0000000..4bc31c2
--- /dev/null
+++ b/src/xnnpack/pack.h
@@ -0,0 +1,646 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stdint.h>
+#include <xnnpack/math.h>
+#include <xnnpack/operator.h>
+
+
// Packs quantized GEMM weights stored in G-O-I (group, output channel, input
// channel) order into the blocked layout used by the Q8 GEMM micro-kernels:
// for each nr-wide block of output channels, nr int32 bias slots followed by
// the kernel bytes in kr-wide runs.
// Every bias slot is pre-adjusted for the zero points:
//   b[n] + kc*izp*kzp - izp * sum_over_k(kernel[n][k])
// so the micro-kernel needs no zero-point correction at run time.
//   g        - number of groups
//   nc, kc   - output / input channels per group
//   nr, kr   - micro-kernel tile sizes
//   izp, kzp - input and kernel zero points
// Padding lanes (nc % nr, kc % kr remainders) are skipped, not written;
// NOTE(review): presumably the caller provides a zero-initialized buffer.
static inline void xnn_pack_q8_gemm_goi_w(
  size_t g,
  size_t nc,
  size_t kc,
  uint32_t nr,
  uint32_t kr,
  uint8_t izp,
  uint8_t kzp,
  const uint8_t* k,
  const int32_t* b,
  void* packed_w)
{
  // Constant part of the bias correction, identical for all output channels.
  const int32_t boff = (int32_t) kc * (int32_t) izp * (int32_t) kzp;
  do {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      int32_t* packed_b = (int32_t*) packed_w;  // bias slots, adjusted below
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
        *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
        packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
      }
      packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
      for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          int32_t ksum = 0;  // sum of kernel bytes in this kr run
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
            const uint8_t kv = k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
            ksum += (int32_t) kv;
            *((uint8_t*) packed_w) = kv;
            packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
          }
          // Fold -izp * sum(kernel) into this output channel's bias.
          packed_b[nr_block_offset] -= ksum * (int32_t) izp;
          packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
        }
        packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
      }
    }
    k += nc * kc;
    b += nc;
  } while (--g != 0);
}
+
// Packs quantized convolution weights stored in G-O-K-I order (group, output
// channel, kernel position, input channel). Same blocked layout and bias
// pre-adjustment as xnn_pack_q8_gemm_goi_w, except each output-channel block
// carries ks kernel positions of kr-wide runs (bias correction constant is
// ks*kc*izp*kzp accordingly). Padding lanes are skipped, not written.
static inline void xnn_pack_q8_conv_goki_w(
  size_t g,
  size_t nc,
  size_t ks,
  size_t kc,
  uint32_t nr,
  uint32_t kr,
  uint8_t izp,
  uint8_t kzp,
  const uint8_t* k,
  const int32_t* b,
  void* packed_w)
{
  const int32_t boff = (int32_t) ks * (int32_t) kc * (int32_t) izp * (int32_t) kzp;
  do {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      int32_t* packed_b = (int32_t*) packed_w;  // bias slots, adjusted below
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
        *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
        packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
      }
      packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
      for (size_t ki = 0; ki < ks; ki++) {
        for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
          const size_t kr_block_size = min(kc - kr_block_start, kr);
          for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
            int32_t ksum = 0;  // sum of kernel bytes in this kr run
            for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
              const uint8_t kv =
                k[((nr_block_start + nr_block_offset) * ks + ki) * kc + (kr_block_start + kr_block_offset)];
              ksum += (int32_t) kv;
              *((uint8_t*) packed_w) = kv;
              packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
            }
            packed_b[nr_block_offset] -= ksum * (int32_t) izp;
            packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
          }
          packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
        }
      }
    }
    k += ks * kc * nc;
    b += nc;
  } while (--g != 0);
}
+
// Packs quantized convolution weights stored in K-G-O order (kernel position
// major, then group, then output channel) -- the layout used when there is a
// single input channel per group (e.g. depthwise-like cases;
// NOTE(review): confirm the caller). Per nr-block: nr pre-adjusted int32
// biases, then for each kernel position one byte per output channel, each
// occupying the first lane of a kr-wide run (the remaining kr-1 bytes are
// skipped, not written -- presumably pre-zeroed by the caller).
static inline void xnn_pack_q8_conv_kgo_w(
  size_t g,
  size_t nc,
  size_t ks,
  uint32_t nr,
  uint32_t kr,
  uint8_t izp,
  uint8_t kzp,
  const uint8_t* k,
  const int32_t* b,
  void* packed_w)
{
  const int32_t boff = (int32_t) ks * (int32_t) izp * (int32_t) kzp;
  for (size_t i = 0; i < g; i++) {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      int32_t* packed_b = (int32_t*) packed_w;  // bias slots, adjusted below
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
        *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
        packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
      }
      packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
      for (size_t ki = 0; ki < ks; ki++) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          const uint8_t kv =
            k[ki * g * nc + (nr_block_start + nr_block_offset)];
          *((uint8_t*) packed_w) = kv;
          packed_b[nr_block_offset] -= (int32_t) kv * (int32_t) izp;
          packed_w = (void*) ((uintptr_t) packed_w + kr * sizeof(uint8_t));
        }
        packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
      }
    }
    k += nc;
    b += nc;
  }
}
+
// Packs quantized deconvolution (transposed convolution) weights by splitting
// the kh*kw kernel into sh*sw subconvolutions, one per (oy, ox) output-stride
// phase; each subconvolution packs only the taps congruent to its phase
// (ky = oy, oy+sh, ...; kx likewise). On the first group only, the start of
// each subconvolution's packed weights is recorded into params[].weights.
// Within a subconvolution the layout and bias pre-adjustment match
// xnn_pack_q8_conv_goki_w, with the bias constant scaled by the number of
// taps belonging to that phase (the divide_round_up terms).
// Padding lanes are skipped, not written.
static inline void xnn_pack_q8_deconv_goki_w(
  size_t g,
  size_t nc,
  size_t kh,
  size_t kw,
  size_t kc,
  size_t sh,
  size_t sw,
  size_t nr,
  size_t kr,
  uint8_t izp,
  uint8_t kzp,
  const uint8_t* k,
  const int32_t* b,
  void* packed_w,
  struct subconvolution_params* params)
{
  for (size_t i = 0; i < g; i++) {
    for (size_t oy = 0; oy < sh; oy++) {
      for (size_t ox = 0; ox < sw; ox++) {
        if (i == 0) {
          // Record where this subconvolution's weights begin.
          (*params++).weights = packed_w;
        }
        const int32_t boff = (int32_t) divide_round_up(kh - oy, sh) * (int32_t) divide_round_up(kw - ox, sw) * (int32_t) kc * (int32_t) izp * (int32_t) kzp;
        for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
          const size_t nr_block_size = min(nc - nr_block_start, nr);
          int32_t* packed_b = (int32_t*) packed_w;  // bias slots, adjusted below
          for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
            *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
            packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
          }
          packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
          for (size_t ky = oy; ky < kh; ky += sh) {
            for (size_t kx = ox; kx < kw; kx += sw) {
              for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
                const size_t kr_block_size = min(kc - kr_block_start, kr);
                for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
                  int32_t ksum = 0;  // sum of kernel bytes in this kr run
                  for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
                    const uint8_t kv =
                      k[(((nr_block_start + nr_block_offset) * kh + ky) * kw + kx) * kc + (kr_block_start + kr_block_offset)];
                    ksum += (int32_t) kv;
                    *((uint8_t*) packed_w) = kv;
                    packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
                  }
                  packed_b[nr_block_offset] -= ksum * (int32_t) izp;
                  packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
                }
                packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
              }
            }
          }
        }
      }
    }
    k += kh * kw * kc * nc;
    b += nc;
  }
}
+
// Packs quantized depthwise-convolution weights stored G-H-W (channel-major)
// into cr-wide channel blocks: cr int32 bias slots -- pre-adjusted with
// h*w*izp*kzp - izp*sum(taps), like the GEMM packing -- followed by one
// cr-wide run per kernel tap. Taps are ordered column-major (x outer,
// y inner). Padding lanes are skipped, not written.
static inline void xnn_pack_q8_dwconv_ghw_w(
  size_t h,
  size_t w,
  size_t c,
  size_t cr,
  uint8_t izp,
  uint8_t kzp,
  const uint8_t* k,
  const int32_t* b,
  void* packed_w)
{
  const int32_t boff = (int32_t) h * (int32_t) w * (int32_t) izp * (int32_t) kzp;
  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    int32_t* packed_b = (int32_t*) packed_w;  // bias slots, adjusted below
    for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
      *((int32_t*) packed_w) = b[cr_block_start + cr_block_offset] + boff;
      packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
    }
    packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(int32_t));
    for (size_t x = 0; x < w; x++) {
      for (size_t y = 0; y < h; y++) {
        for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
          const uint8_t kv = k[((cr_block_start + cr_block_offset) * h + y) * w + x];
          packed_b[cr_block_offset] -= (int32_t) kv * (int32_t) izp;
          *((uint8_t*) packed_w) = kv;
          packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
        }
        packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(uint8_t));
      }
    }
  }
}
+
// Same packed layout and bias pre-adjustment as xnn_pack_q8_dwconv_ghw_w,
// but reads the kernel stored H-W-G (row, column, channel innermost).
static inline void xnn_pack_q8_dwconv_hwg_w(
  size_t h,
  size_t w,
  size_t c,
  size_t cr,
  uint8_t izp,
  uint8_t kzp,
  const uint8_t* k,
  const int32_t* b,
  void* packed_w)
{
  const int32_t boff = (int32_t) h * (int32_t) w * (int32_t) izp * (int32_t) kzp;
  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    int32_t* packed_b = (int32_t*) packed_w;  // bias slots, adjusted below
    for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
      *((int32_t*) packed_w) = b[cr_block_start + cr_block_offset] + boff;
      packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
    }
    packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(int32_t));
    for (size_t x = 0; x < w; x++) {
      for (size_t y = 0; y < h; y++) {
        for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
          const uint8_t kv = k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
          packed_b[cr_block_offset] -= (int32_t) kv * (int32_t) izp;
          *((uint8_t*) packed_w) = kv;
          packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
        }
        packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(uint8_t));
      }
    }
  }
}
+
// Packs half-precision GEMM weights stored in G-O-I (group, output channel,
// input channel) order: per nr-wide block of output channels, nr bias slots
// followed by kr-wide runs of kernel values. Padding slots (nc % nr and
// kc % kr remainders) are skipped, not written.
static inline void xnn_pack_f16_gemm_goi_w(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  const uint16_t* k,
  const uint16_t* b,
  uint16_t* packed_w)
{
  do {
    for (size_t n0 = 0; n0 < nc; n0 += nr) {
      const size_t n_left = nc - n0;
      const size_t n_block = n_left < nr ? n_left : nr;
      // Bias slots for this block of output channels.
      for (size_t i = 0; i < n_block; i++) {
        packed_w[i] = b[n0 + i];
      }
      packed_w += nr;
      // Kernel values, interleaved in kr-wide runs per output channel.
      for (size_t k0 = 0; k0 < kc; k0 += kr) {
        const size_t k_left = kc - k0;
        const size_t k_block = k_left < kr ? k_left : kr;
        for (size_t i = 0; i < n_block; i++) {
          for (size_t j = 0; j < k_block; j++) {
            packed_w[j] = k[(n0 + i) * kc + (k0 + j)];
          }
          packed_w += kr;
        }
        packed_w += (nr - n_block) * kr;
      }
    }
    k += nc * kc;
    b += nc;
  } while (--g != 0);
}
+
// Packs f32 GEMM weights stored in G-O-I order: per nr-block, nr bias values
// followed by kr-wide runs of kernel values. When sr > 1, within every
// (sr*kr)-aligned span the kr-run selected for output channel offset `nro`
// is rotated by nro*kr (the `& sr_mask` term) -- a shuffled layout for
// micro-kernels that consume weights in that order (NOTE(review): exact sr
// semantics inferred from the index arithmetic -- confirm against the
// shuffled GEMM kernels). The kc % (sr*kr) tail is packed unshuffled.
// Padding lanes are skipped, not written.
static inline void xnn_pack_f32_gemm_goi_w(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const float* k,
  const float* b,
  float* packed_w)
{
  const size_t skr = sr * kr;
  const size_t skc = round_down_po2(kc, skr);  // shuffled portion of kc
  const size_t sr_mask = (sr - 1) * kr;
  do {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
        *packed_w++ = b[nr_block_start + nr_block_offset];
      }
      packed_w += nr - nr_block_size;

      // Shuffled region: full kr runs, rotated per output channel.
      for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
            *packed_w++ =
              k[(nr_block_start + nr_block_offset) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
          }
        }
        packed_w += (nr - nr_block_size) * kr;
      }

      // Unshuffled tail (kc not a multiple of sr*kr).
      for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
            *packed_w++ =
              k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
          }
          packed_w += kr - kr_block_size;
        }
        packed_w += (nr - nr_block_size) * kr;
      }
    }
    k += nc * kc;
    b += nc;
  } while (--g != 0);
}
+
// Identical layout to xnn_pack_f32_gemm_goi_w but without bias slots: the
// "inc" GEMM variants take their initial accumulators from a separate buffer
// (NOTE(review): inferred from the gemminc name -- confirm).
static inline void xnn_pack_f32_gemminc_goi_w(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const float* k,
  float* packed_w)
{
  const size_t skr = sr * kr;
  const size_t skc = round_down_po2(kc, skr);  // shuffled portion of kc
  const size_t sr_mask = (sr - 1) * kr;
  do {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);

      // Shuffled region: full kr runs, rotated per output channel.
      for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
            *packed_w++ =
              k[(nr_block_start + nr_block_offset) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
          }
        }
        packed_w += (nr - nr_block_size) * kr;
      }

      // Unshuffled tail (kc not a multiple of sr*kr).
      for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
            *packed_w++ =
              k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
          }
          packed_w += kr - kr_block_size;
        }
        packed_w += (nr - nr_block_size) * kr;
      }
    }
    k += nc * kc;
  } while (--g != 0);
}
+
// Packs f32 convolution weights stored in G-O-K-I order (group, output
// channel, kernel position, input channel). Same layout and sr-shuffling as
// xnn_pack_f32_gemm_goi_w, with the kc loop repeated for each of the ks
// kernel positions. Padding lanes are skipped, not written.
static inline void xnn_pack_f32_conv_goki_w(
  size_t g,
  size_t nc,
  size_t ks,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const float* k,
  const float* b,
  float* packed_w)
{
  const size_t skr = sr * kr;
  const size_t skc = round_down_po2(kc, skr);  // shuffled portion of kc
  const size_t sr_mask = (sr - 1) * kr;
  do {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
        *packed_w++ = b[nr_block_start + nr_block_offset];
      }
      packed_w += nr - nr_block_size;

      for (size_t ki = 0; ki < ks; ki++) {
        // Shuffled region: full kr runs, rotated per output channel.
        for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
          for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
            for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
              *packed_w++ =
                k[((nr_block_start + nr_block_offset) * ks + ki) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
            }
          }
          packed_w += (nr - nr_block_size) * kr;
        }

        // Unshuffled tail (kc not a multiple of sr*kr).
        for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
          const size_t kr_block_size = min(kc - kr_block_start, kr);
          for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
            for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
              *packed_w++ =
                k[((nr_block_start + nr_block_offset) * ks + ki) * kc + (kr_block_start + kr_block_offset)];
            }
            packed_w += kr - kr_block_size;
          }
          packed_w += (nr - nr_block_size) * kr;
        }
      }
    }
    k += ks * kc * nc;
    b += nc;
  } while (--g != 0);
}
+
// Packs f32 convolution weights stored in K-G-O order (kernel position major,
// then group, then output channel): per nr-wide block of output channels,
// nr bias slots, then for each kernel position one tap per output channel,
// each occupying the first lane of a kr-wide run. Skipped lanes (the
// remaining kr-1 slots and any nc % nr padding) are left untouched.
static inline void xnn_pack_f32_conv_kgo_w(
  size_t g,
  size_t nc,
  size_t ks,
  size_t nr,
  size_t kr,
  const float* k,
  const float* b,
  float* packed_w)
{
  for (size_t group = 0; group < g; group++) {
    for (size_t n0 = 0; n0 < nc; n0 += nr) {
      const size_t n_left = nc - n0;
      const size_t n_block = n_left < nr ? n_left : nr;
      for (size_t i = 0; i < n_block; i++) {
        packed_w[i] = b[n0 + i];
      }
      packed_w += nr;
      for (size_t ki = 0; ki < ks; ki++) {
        for (size_t i = 0; i < n_block; i++) {
          packed_w[i * kr] = k[ki * g * nc + (n0 + i)];
        }
        packed_w += nr * kr;
      }
    }
    k += nc;
    b += nc;
  }
}
+
// Packs f32 direct-convolution weights stored in O-K-I order (output channel,
// kernel position, input channel) for the dconv micro-kernels: per nr-block,
// nr biases, then taps ordered kx-major, then input channel, then ky, then
// output-channel lane. Unlike the other packing routines, padding lanes are
// filled by replicating the last valid output channel (the
// min(nr_block_offset, nr_block_size - 1) clamp) rather than being skipped.
static inline void xnn_pack_f32_dconv_oki_w(
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kh,
  size_t kw,
  const float* k,
  const float* b,
  float* packed_w)
{
  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
    const size_t nr_block_size = min(nc - nr_block_start, nr);
    for (size_t nr_block_offset = 0; nr_block_offset < nr; nr_block_offset++) {
      // Padding lanes replicate the last channel's bias.
      *packed_w++ = b[nr_block_start + min(nr_block_offset, nr_block_size - 1)];
    }

    for (size_t kx = 0; kx < kw; kx++) {
      for (size_t c = 0; c < kc; c++) {
        for (size_t ky = 0; ky < kh; ky++) {
          for (size_t nr_block_offset = 0; nr_block_offset < nr; nr_block_offset++) {
            // Padding lanes replicate the last channel's taps.
            *packed_w++ = k[(((nr_block_start + min(nr_block_offset, nr_block_size - 1)) * kh + ky) * kw + kx) * kc + c];
          }
        }
      }
    }
  }
}
+
// Packs f32 deconvolution (transposed convolution) weights by splitting the
// kh*kw kernel into sh*sw subconvolutions, one per (oy, ox) output-stride
// phase; each subconvolution packs only the taps congruent to its phase
// (ky = oy, oy+sh, ...; kx likewise). On the first group only, the start of
// each subconvolution's packed weights is recorded into params[].weights.
// Within a subconvolution the layout matches the regular conv packing:
// nr biases then kr-wide runs. Padding lanes are skipped, not written.
static inline void xnn_pack_f32_deconv_goki_w(
  size_t g,
  size_t nc,
  size_t kh,
  size_t kw,
  size_t kc,
  size_t sh,
  size_t sw,
  size_t nr,
  size_t kr,
  const float* k,
  const float* b,
  float* packed_w,
  struct subconvolution_params* params)
{
  for (size_t i = 0; i < g; i++) {
    for (size_t oy = 0; oy < sh; oy++) {
      for (size_t ox = 0; ox < sw; ox++) {
        if (i == 0) {
          // Record where this subconvolution's weights begin.
          (*params++).weights = packed_w;
        }
        for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
          const size_t nr_block_size = min(nc - nr_block_start, nr);
          for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
            *packed_w++ = b[nr_block_start + nr_block_offset];
          }
          packed_w += nr - nr_block_size;
          for (size_t ky = oy; ky < kh; ky += sh) {
            for (size_t kx = ox; kx < kw; kx += sw) {
              for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
                const size_t kr_block_size = min(kc - kr_block_start, kr);
                for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
                  for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
                    *packed_w++ =
                      k[(((nr_block_start + nr_block_offset) * kh + ky) * kw + kx) * kc + (kr_block_start + kr_block_offset)];
                  }
                  packed_w += kr - kr_block_size;
                }
                packed_w += (nr - nr_block_size) * kr;
              }
            }
          }
        }
      }
    }
    k += kh * kw * kc * nc;
    b += nc;
  }
}
+
// Packs f32 depthwise-convolution weights stored G-H-W (channel-major) into
// cr-wide channel blocks: cr bias slots, then one cr-wide slot per kernel
// tap, taps ordered column-major (x outer, y inner). Padding lanes
// (c % cr remainder) are skipped, not written.
static inline void xnn_pack_f32_dwconv_ghw_w(
  size_t h,
  size_t w,
  size_t c,
  size_t cr,
  const float* k,
  const float* b,
  float* packed_w)
{
  for (size_t c0 = 0; c0 < c; c0 += cr) {
    const size_t c_left = c - c0;
    const size_t c_block = c_left < cr ? c_left : cr;
    for (size_t i = 0; i < c_block; i++) {
      packed_w[i] = b[c0 + i];
    }
    packed_w += cr;
    for (size_t x = 0; x < w; x++) {
      for (size_t y = 0; y < h; y++) {
        for (size_t i = 0; i < c_block; i++) {
          packed_w[i] = k[((c0 + i) * h + y) * w + x];
        }
        packed_w += cr;
      }
    }
  }
}
+
// Same packed layout as xnn_pack_f32_dwconv_ghw_w, but reads the kernel
// stored H-W-G (row, column, channel innermost). Padding lanes are skipped,
// not written.
static inline void xnn_pack_f32_dwconv_hwg_w(
  size_t h,
  size_t w,
  size_t c,
  size_t cr,
  const float* k,
  const float* b,
  float* packed_w)
{
  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
      *packed_w++ = b[cr_block_start + cr_block_offset];
    }
    packed_w += cr - cr_block_size;
    for (size_t x = 0; x < w; x++) {
      for (size_t y = 0; y < h; y++) {
        for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
          const float kv = k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
          *packed_w++ = kv;
        }
        packed_w += cr - cr_block_size;
      }
    }
  }
}
+
// Packs f32 SpCHW depthwise-convolution weights: for each group (channel),
// its bias followed by all of its kernel taps, contiguously.
static inline void xnn_pack_f32_spchw_dwconv_ghw_w(
  size_t kernel_size,
  size_t groups,
  const float* kernel,
  const float* bias,
  float* packed_weights)
{
  for (size_t g = 0; g < groups; g++) {
    packed_weights[0] = bias[g];
    const float* group_kernel = kernel + g * kernel_size;
    for (size_t i = 0; i < kernel_size; i++) {
      packed_weights[i + 1] = group_kernel[i];
    }
    packed_weights += kernel_size + 1;
  }
}
+
// Packs per-channel scale (s) and bias (b) vectors for the vmulcaddc
// micro-kernel (NOTE(review): presumably y = x * s + b per channel --
// confirm): alternating cr-wide runs of scales and biases. Padding lanes
// (c % cr remainder) are skipped, not written.
static inline void xnn_pack_f32_vmulcaddc_w(
  size_t c,
  size_t cr,
  const float* s,
  const float* b,
  float* packed_w)
{
  for (size_t c0 = 0; c0 < c; c0 += cr) {
    const size_t c_left = c - c0;
    const size_t c_block = c_left < cr ? c_left : cr;
    for (size_t i = 0; i < c_block; i++) {
      packed_w[i] = s[c0 + i];
    }
    packed_w += cr;
    for (size_t i = 0; i < c_block; i++) {
      packed_w[i] = b[c0 + i];
    }
    packed_w += cr;
  }
}
diff --git a/src/xnnpack/packx.h b/src/xnnpack/packx.h
new file mode 100644
index 0000000..20b3bc1
--- /dev/null
+++ b/src/xnnpack/packx.h
@@ -0,0 +1,36 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
// Declares the prototype of an x32 "packx" micro-kernel: gathers an m-by-k
// tile of 32-bit elements from a strided source matrix x into a contiguous
// packed buffer y.
// NOTE(review): parameter semantics (in particular whether x_stride is in
// elements or bytes) inferred from convention -- confirm against the scalar
// implementation.
#define DECLARE_X32_PACKX_UKERNEL_FUNCTION(fn_name) \
  XNN_INTERNAL void fn_name( \
    size_t m, \
    size_t k, \
    const uint32_t* x, \
    size_t x_stride, \
    uint32_t* y);

// Implementations; "2x"/"3x"/"4x" is the row tile, and "st4" presumably
// denotes NEON ST4 interleaved stores -- NOTE(review): confirm.
DECLARE_X32_PACKX_UKERNEL_FUNCTION(xnn_x32_packx_ukernel_2x__scalar)
DECLARE_X32_PACKX_UKERNEL_FUNCTION(xnn_x32_packx_ukernel_3x__scalar)
DECLARE_X32_PACKX_UKERNEL_FUNCTION(xnn_x32_packx_ukernel_4x__neon_st4)
DECLARE_X32_PACKX_UKERNEL_FUNCTION(xnn_x32_packx_ukernel_4x__psimd)
DECLARE_X32_PACKX_UKERNEL_FUNCTION(xnn_x32_packx_ukernel_4x__scalar)
DECLARE_X32_PACKX_UKERNEL_FUNCTION(xnn_x32_packx_ukernel_4x__sse)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/pad.h b/src/xnnpack/pad.h
new file mode 100644
index 0000000..3cb8103
--- /dev/null
+++ b/src/xnnpack/pad.h
@@ -0,0 +1,39 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
// Declares the prototype of a padding micro-kernel.
// NOTE(review): parameter meanings inferred from the argument names and the
// channel-pad operator -- confirm against the implementations (in particular
// whether n/l/r are in bytes or elements):
//   m      - number of rows to process
//   n      - payload size per row
//   l, r   - amount of padding written before / after each row's payload
//   c      - 32-bit fill value used for the padding
//   input/input_stride, output/output_stride - row-strided source and
//          destination
#define DECLARE_PAD_UKERNEL_FUNCTION(fn_name) \
  XNN_INTERNAL void fn_name( \
    size_t m, \
    size_t n, \
    size_t l, \
    size_t r, \
    uint32_t c, \
    const void* input, \
    size_t input_stride, \
    void* output, \
    size_t output_stride);

DECLARE_PAD_UKERNEL_FUNCTION(xnn_x32_pad_x2__neon)
DECLARE_PAD_UKERNEL_FUNCTION(xnn_x32_pad_x2__psimd)
DECLARE_PAD_UKERNEL_FUNCTION(xnn_x32_pad_x2__scalar)
DECLARE_PAD_UKERNEL_FUNCTION(xnn_x32_pad_x2__sse2)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h
new file mode 100644
index 0000000..30e8393
--- /dev/null
+++ b/src/xnnpack/params.h
@@ -0,0 +1,1304 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include <cpuinfo.h>
+
+#include <xnnpack/common.h>
+
+#define XNN_INTERNAL_EXTRA_BYTES 32
+
+struct xnn_f16_output_params {
+ uint16_t scale;
+ uint16_t max;
+ uint16_t min;
+};
+
+union xnn_f32_output_params {
+ struct {
+ float max;
+ float min;
+ } scalar;
+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ struct {
+ XNN_ALIGN(16) float max[4];
+ XNN_ALIGN(16) float min[4];
+ } sse;
+#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */
+};
+
+union xnn_f32_spchw_params {
+ struct {
+ float max;
+ float min;
+ } scalar;
+#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+ struct {
+ XNN_ALIGN(16) uint32_t mask_even[4]; // used by stride 2 kernels
+ XNN_ALIGN(16) uint32_t mask_odd[4]; // used by stride 2 kernels
+ XNN_ALIGN(16) uint32_t mask[4]; // used by stride 1 kernels
+ float min;
+ float max;
+ } neon;
+#elif CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ struct {
+ XNN_ALIGN(16) uint32_t mask_even[4]; // used by stride 2 kernels
+ XNN_ALIGN(16) uint32_t mask_odd[4]; // used by stride 2 kernels
+ XNN_ALIGN(16) uint32_t mask[4]; // used by stride 1 kernels
+ XNN_ALIGN(16) float max[4];
+ XNN_ALIGN(16) float min[4];
+ } sse;
+#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 || CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */
+};
+
+union xnn_u8_output_params {
+ struct {
+ int32_t max;
+ int32_t min;
+ } scalar;
+#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+ struct {
+ uint8_t max;
+ uint8_t min;
+ } neon;
+#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */
+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ struct {
+ XNN_ALIGN(16) uint8_t max[16];
+ XNN_ALIGN(16) uint8_t min[16];
+ } sse2;
+#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */
+};
+
+union xnn_f32_avgpool_params {
+ struct {
+ float multiplier;
+ float output_min;
+ float output_max;
+ } scalar;
+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ struct {
+ XNN_ALIGN(16) float multiplier[4];
+ XNN_ALIGN(16) float output_max[4];
+ XNN_ALIGN(16) float output_min[4];
+ } sse2;
+#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */
+#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+ struct {
+ XNN_ALIGN(16) float multiplier;
+ XNN_ALIGN(16) float output_max;
+ XNN_ALIGN(16) float output_min;
+ } neon;
+#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */
+};
+
+union xnn_f32_gavgpool_params {
+ struct {
+ float multiplier;
+ float output_min;
+ float output_max;
+ } scalar;
+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ struct {
+ XNN_ALIGN(16) float multiplier[4];
+ XNN_ALIGN(16) float output_max[4];
+ XNN_ALIGN(16) float output_min[4];
+ XNN_ALIGN(16) uint32_t mask[4];
+ } sse;
+#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */
+#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+ struct {
+ XNN_ALIGN(16) float multiplier;
+ XNN_ALIGN(16) float output_max;
+ XNN_ALIGN(16) float output_min;
+ XNN_ALIGN(16) uint32_t mask[4];
+ } neon;
+#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */
+};
+
+union xnn_f32_hswish_params {
+ struct {
+ float sixth;
+ float half;
+ float one;
+ } scalar;
+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ struct {
+ XNN_ALIGN(16) float sixth[4];
+ XNN_ALIGN(16) float half[4];
+ XNN_ALIGN(16) float one[4];
+ } sse;
+#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */
+};
+
+union xnn_q8_gemm_params {
+ struct {
+ int32_t kernel_zero_point;
+ int32_t input_zero_point;
+ int32_t multiplier;
+ int32_t remainder_mask;
+ int32_t remainder_threshold;
+ uint32_t shift;
+ int32_t output_min_less_zero_point;
+ int32_t output_max_less_zero_point;
+ int32_t output_zero_point;
+ } scalar;
+#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+ struct {
+ int16_t kernel_zero_point;
+ int16_t input_zero_point;
+ int32_t multiplier;
+ int32_t right_shift;
+ int16_t output_zero_point;
+ uint8_t output_max;
+ uint8_t output_min;
+ } neon;
+#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */
+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ struct {
+ XNN_ALIGN(16) int16_t kernel_zero_point[8];
+ XNN_ALIGN(16) int16_t input_zero_point[8];
+ XNN_ALIGN(16) uint32_t multiplier[4];
+ XNN_ALIGN(16) uint64_t rounding[2];
+ XNN_ALIGN(16) int32_t remainder_mask[4];
+ XNN_ALIGN(16) int32_t remainder_threshold[4];
+ XNN_ALIGN(16) uint64_t shift[2];
+ XNN_ALIGN(16) int16_t output_zero_point[8];
+ XNN_ALIGN(16) uint8_t output_max[16];
+ XNN_ALIGN(16) uint8_t output_min[16];
+ } sse2;
+#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */
+};
+
+union xnn_q8_add_params {
+ struct {
+ int32_t zero_point_product;
+ uint32_t a_multiplier;
+ uint32_t b_multiplier;
+ uint32_t shift;
+ int32_t remainder_mask;
+ int32_t remainder_threshold;
+ int32_t y_zero_point;
+ int32_t y_max;
+ int32_t y_min;
+ } scalar;
+#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+ struct {
+ uint8_t a_zero_point;
+ uint8_t b_zero_point;
+ int16_t y_zero_point;
+ int32_t a_multiplier;
+ int32_t b_multiplier;
+ int32_t right_shift;
+ uint8_t y_max;
+ uint8_t y_min;
+ } neon;
+#endif
+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ struct {
+ XNN_ALIGN(16) int32_t zero_point_product[4];
+ XNN_ALIGN(16) uint16_t a_multiplier_lo[8];
+ XNN_ALIGN(16) uint16_t a_multiplier_hi[8];
+ XNN_ALIGN(16) uint16_t b_multiplier_lo[8];
+ XNN_ALIGN(16) uint16_t b_multiplier_hi[8];
+ XNN_ALIGN(16) int32_t remainder_mask[4];
+ XNN_ALIGN(16) int32_t remainder_threshold[4];
+ XNN_ALIGN(16) int16_t y_zero_point[8];
+ XNN_ALIGN(16) uint8_t y_max[16];
+ XNN_ALIGN(16) uint8_t y_min[16];
+ uint32_t shift;
+ uint32_t a_multiplier;
+ uint32_t b_multiplier;
+ } sse2;
+#endif
+};
+
+union xnn_q8_avgpool_params {
+ struct {
+ int32_t bias;
+ int32_t multiplier;
+ int64_t rounding;
+ uint32_t right_shift;
+ int32_t output_min_less_zero_point;
+ int32_t output_max_less_zero_point;
+ int32_t output_zero_point;
+ } scalar;
+#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+ struct {
+ int32_t bias;
+ int32_t multiplier;
+ int64_t left_shift;
+ int16_t output_zero_point;
+ uint8_t output_max;
+ uint8_t output_min;
+ } neon;
+#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */
+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ struct {
+ XNN_ALIGN(16) int32_t bias[4];
+ XNN_ALIGN(16) uint32_t multiplier[4];
+ XNN_ALIGN(16) uint64_t rounding[2];
+ XNN_ALIGN(16) uint64_t right_shift[2];
+ XNN_ALIGN(16) int16_t output_zero_point[8];
+ XNN_ALIGN(16) uint8_t output_max[16];
+ XNN_ALIGN(16) uint8_t output_min[16];
+ } sse2;
+#endif
+};
+
+union xnn_fp32_requantization_params {
+ struct {
+ float scale;
+ float min_less_zero_point;
+ float max_less_zero_point;
+ float magic;
+ int32_t magic_less_zero_point;
+ } scalar;
+ struct {
+ float scale;
+ float max;
+ float min;
+ float magic;
+ int32_t magic_less_zero_point;
+ } neon;
+ struct {
+ float scale;
+ int16_t zero_point;
+ uint8_t max;
+ uint8_t min;
+ } neonv8;
+ struct {
+ XNN_ALIGN(16) float scale[4];
+ XNN_ALIGN(16) int16_t zero_point[8];
+ XNN_ALIGN(16) uint8_t max[16];
+ XNN_ALIGN(16) uint8_t min[16];
+ } sse2;
+ struct {
+ XNN_ALIGN(16) float scale[4];
+ XNN_ALIGN(16) float min_less_zero_point[4];
+ XNN_ALIGN(16) float max_less_zero_point[4];
+ XNN_ALIGN(16) float magic[4];
+ XNN_ALIGN(16) int32_t magic_less_zero_point[4];
+ } psimd;
+};
+
+union xnn_precise_requantization_params {
+ struct {
+ uint32_t multiplier;
+ uint32_t rounding_lo;
+ uint32_t rounding_hi;
+ uint32_t shift_less_32;
+ int32_t min_less_zero_point;
+ int32_t max_less_zero_point;
+ int32_t zero_point;
+ } scalar;
+ struct {
+ int32_t multiplier;
+ int32_t right_shift;
+ int16_t zero_point;
+ uint8_t max;
+ uint8_t min;
+ } neon;
+ struct {
+ XNN_ALIGN(16) uint32_t multiplier[4];
+ XNN_ALIGN(16) uint64_t rounding[2];
+ XNN_ALIGN(16) uint32_t shift[4];
+ XNN_ALIGN(16) int16_t zero_point[8];
+ XNN_ALIGN(16) uint8_t max[16];
+ XNN_ALIGN(16) uint8_t min[16];
+ } sse2;
+};
+
+union xnn_q31_requantization_params {
+ struct {
+ int32_t multiplier;
+ int32_t remainder_mask;
+ int32_t remainder_threshold;
+ uint32_t shift;
+ int32_t min_less_zero_point;
+ int32_t max_less_zero_point;
+ int32_t zero_point;
+ } scalar;
+#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+ struct {
+ int32_t multiplier;
+ int32_t right_shift;
+ int16_t zero_point;
+ uint8_t max;
+ uint8_t min;
+ } neon;
+#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */
+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ struct {
+ XNN_ALIGN(16) uint32_t multiplier[4];
+ XNN_ALIGN(16) uint64_t rounding[2];
+ XNN_ALIGN(16) int32_t remainder_mask[4];
+ XNN_ALIGN(16) int32_t remainder_threshold[4];
+ XNN_ALIGN(16) uint64_t shift[2];
+ XNN_ALIGN(16) int16_t zero_point[8];
+ XNN_ALIGN(16) uint8_t max[16];
+ XNN_ALIGN(16) uint8_t min[16];
+ } sse2;
+#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */
+};
+
+union xnn_requantization_params {
+ union xnn_precise_requantization_params precise;
+ union xnn_fp32_requantization_params fp32;
+ union xnn_q31_requantization_params q31;
+};
+
+typedef void (*xnn_ppmm_ukernel_function)(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const void* a,
+ const void* w,
+ void* c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const void* params);
+
+typedef void (*xnn_f32_ppmm_ukernel_function)(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const float* a,
+ const float* w,
+ float* c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_f16_ppmm_ukernel_function)(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const void* a,
+ const void* w,
+ void* c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const struct xnn_f16_output_params* params);
+
+typedef void (*xnn_gemm_ukernel_function)(
+ size_t mr,
+ size_t nr,
+ size_t k,
+ const void* a,
+ size_t a_stride,
+ const void* w,
+ void* c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const void* params);
+
+typedef void (*xnn_f32_gemm_ukernel_function)(
+ size_t mr,
+ size_t nr,
+ size_t k,
+ const float* a,
+ size_t a_stride,
+ const float* w,
+ float* c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_f32_gemminc_ukernel_function)(
+ size_t mr,
+ size_t nr,
+ size_t k,
+ const float* a,
+ size_t a_stride,
+ const float* w,
+ float* c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const float* acc,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_f16_gemm_ukernel_function)(
+ size_t mr,
+ size_t nr,
+ size_t k,
+ const void* a,
+ size_t a_stride,
+ const void* w,
+ void* c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const struct xnn_f16_output_params* params);
+
+typedef void (*xnn_q8_gemm_ukernel_function)(
+ size_t mr,
+ size_t nr,
+ size_t k,
+ const uint8_t* a,
+ size_t a_stride,
+ const void* w,
+ uint8_t* c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_q8_gemm_params* params);
+
+typedef void (*xnn_igemm_ukernel_function)(
+ size_t mr,
+ size_t nr,
+ size_t kc,
+ size_t ks,
+ const void** a,
+ const void* w,
+ void* c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const void* zero,
+ const void* params);
+
+typedef void (*xnn_f32_igemm_ukernel_function)(
+ size_t mr,
+ size_t nr,
+ size_t kc,
+ size_t ks,
+ const float** a,
+ const float* w,
+ float* c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const float* zero,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_q8_igemm_ukernel_function)(
+ size_t mr,
+ size_t nr,
+ size_t kc,
+ size_t ks,
+ const uint8_t** a,
+ const void* w,
+ uint8_t* c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const uint8_t* zero,
+ const union xnn_q8_gemm_params* params);
+
+typedef void (*xnn_conv_hwc_ukernel_function)(
+ size_t input_height,
+ size_t input_width,
+ size_t output_y_start,
+ size_t output_y_end,
+ const void* input,
+ const void* zero,
+ const void* weights,
+ void* output,
+ size_t input_padding_top,
+ size_t output_channels,
+ size_t output_height_stride,
+ size_t output_width_stride,
+ const void* params);
+
+typedef void (*xnn_f32_conv_hwc_ukernel_function)(
+ size_t input_height,
+ size_t input_width,
+ size_t output_y_start,
+ size_t output_y_end,
+ const float* input,
+ const float* zero,
+ const float* weights,
+ float* output,
+ size_t input_padding_top,
+ size_t output_channels,
+ size_t output_height_stride,
+ size_t output_width_stride,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_conv_hwc2spchw_ukernel_function)(
+ size_t input_height,
+ size_t input_width,
+ size_t output_y_start,
+ size_t output_y_end,
+ const void* input,
+ const void* zero,
+ const void* weights,
+ void* output,
+ size_t input_padding_top,
+ size_t output_channels,
+ size_t output_height_stride,
+ size_t output_channel_stride,
+ const void* params);
+
+typedef void (*xnn_f32_conv_hwc2spchw_ukernel_function)(
+ size_t input_height,
+ size_t input_width,
+ size_t output_y_start,
+ size_t output_y_end,
+ const float* input,
+ const float* zero,
+ const float* weights,
+ float* output,
+ size_t input_padding_top,
+ size_t output_channels,
+ size_t output_height_stride,
+ size_t output_channel_stride,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_spmm_ukernel_function)(
+ uint32_t m,
+ uint32_t n,
+ const void* a,
+ const void* w,
+ const int32_t* dmap,
+ const uint32_t* nmap,
+ void* c,
+ const void* params);
+
+typedef void (*xnn_f32_spmm_ukernel_function)(
+ uint32_t m,
+ uint32_t n,
+ const float* a,
+ const float* w,
+ const int32_t* dmap,
+ const uint32_t* nmap,
+ float* c,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_packx_ukernel_function)(
+ size_t m,
+ size_t k,
+ const void* x,
+ size_t x_stride,
+ void* y);
+
+typedef void (*xnn_x32_packx_ukernel_function)(
+ size_t m,
+ size_t k,
+ const uint32_t* x,
+ size_t x_stride,
+ uint32_t* y);
+
+typedef void (*xnn_pad_ukernel_function)(
+ size_t m,
+ size_t n,
+ size_t l,
+ size_t r,
+ uint32_t c,
+ const void* x,
+ size_t x_stride,
+ void* y,
+ size_t y_stride);
+
+typedef void (*xnn_unpool_ukernel_function)(
+ size_t p,
+ size_t c,
+ uint32_t f,
+ const void* input,
+ const uint32_t* index,
+ void** output);
+
+typedef void (*xnn_x32_unpool_ukernel_function)(
+ size_t p,
+ size_t c,
+ uint32_t f,
+ const uint32_t* input,
+ const uint32_t* index,
+ uint32_t** output);
+
+typedef void (*xnn_zipc_ukernel_function)(
+ size_t n,
+ const void* x,
+ void* y);
+
+typedef void (*xnn_x8_zipc_ukernel_function)(
+ size_t n,
+ const uint8_t* x,
+ uint8_t* y);
+
+typedef void (*xnn_x32_zipc_ukernel_function)(
+ size_t n,
+ const uint32_t* x,
+ uint32_t* y);
+
+typedef void (*xnn_zipv_ukernel_function)(
+ size_t n,
+ size_t m,
+ const void* x,
+ void* y);
+
+typedef void (*xnn_x8_zipv_ukernel_function)(
+ size_t n,
+ size_t m,
+ const uint8_t* x,
+ uint8_t* y);
+
+typedef void (*xnn_x32_zipv_ukernel_function)(
+ size_t n,
+ size_t m,
+ const uint32_t* x,
+ uint32_t* y);
+
+typedef void (*xnn_x8_lut_ukernel_function)(
+ size_t n,
+ const uint8_t* x,
+ const uint8_t* t,
+ uint8_t* y);
+
+typedef void (*xnn_dwconv_spchw_ukernel_function)(
+ size_t output_height,
+ size_t input_width,
+ const void* input,
+ const void* weights,
+ void* output,
+ size_t input_tuple_stride,
+ size_t output_tuple_stride,
+ size_t input_height_stride,
+ size_t output_height_stride,
+ const void* params);
+
+typedef void (*xnn_f32_dwconv_spchw_ukernel_function)(
+ size_t output_height,
+ size_t input_width,
+ const float* input,
+ const float* weights,
+ float* output,
+ size_t input_tuple_stride,
+ size_t output_tuple_stride,
+ size_t input_height_stride,
+ size_t output_height_stride,
+ const union xnn_f32_spchw_params* params);
+
+typedef void (*xnn_dwconv_up_ukernel_function)(
+ size_t channels,
+ size_t output_width,
+ const void** input,
+ const void* weights,
+ void* output,
+ size_t input_stride,
+ size_t output_increment,
+ const void* params);
+
+typedef void (*xnn_f32_dwconv_up_ukernel_function)(
+ size_t channels,
+ size_t output_width,
+ const float** input,
+ const float* weights,
+ float* output,
+ size_t input_stride,
+ size_t output_increment,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_q8_dwconv_up_ukernel_function)(
+ size_t channels,
+ size_t output_width,
+ const uint8_t** input,
+ const void* weights,
+ uint8_t* output,
+ size_t input_stride,
+ size_t output_increment,
+ const union xnn_q8_gemm_params* params);
+
+typedef void (*xnn_dwconv_mp_ukernel_function)(
+ size_t channels,
+ size_t output_width,
+ const void** input,
+ const void* weights,
+ void* buffer,
+ void* output,
+ size_t input_stride,
+ size_t output_increment,
+ const void* params);
+
+typedef void (*xnn_gavgpool_up_ukernel_function)(
+ size_t m,
+ size_t n,
+ const void* x,
+ size_t x_stride,
+ const void* zero,
+ void* y,
+ const void* params);
+
+typedef void (*xnn_f32_gavgpool_up_ukernel_function)(
+ size_t m,
+ size_t n,
+ const float* x,
+ size_t x_stride,
+ const float* zero,
+ float* y,
+ const union xnn_f32_avgpool_params* params);
+
+typedef void (*xnn_gavgpool_spchw_ukernel_function)(
+ size_t elements,
+ size_t channels,
+ const float* input,
+ float* output,
+ const void* params);
+
+typedef void (*xnn_f32_gavgpool_spchw_ukernel_function)(
+ size_t elements,
+ size_t channels,
+ const float* input,
+ float* output,
+ const union xnn_f32_gavgpool_params* params);
+
+typedef void (*xnn_q8_gavgpool_up_ukernel_function)(
+ size_t m,
+ size_t n,
+ const uint8_t* x,
+ size_t x_stride,
+ const uint8_t* zero,
+ uint8_t* y,
+ const union xnn_q8_avgpool_params* params);
+
+typedef void (*xnn_gavgpool_mp_ukernel_function)(
+ size_t m,
+ size_t n,
+ const void* x,
+ size_t x_stride,
+ const void* zero,
+ void* buffer,
+ void* y,
+ const void* params);
+
+typedef void (*xnn_f32_gavgpool_mp_ukernel_function)(
+ size_t m,
+ size_t n,
+ const float* x,
+ size_t x_stride,
+ const float* zero,
+ float* buffer,
+ float* y,
+ const union xnn_f32_avgpool_params* params);
+
+typedef void (*xnn_q8_gavgpool_mp_ukernel_function)(
+ size_t m,
+ size_t n,
+ const uint8_t* x,
+ size_t x_stride,
+ const uint8_t* zero,
+ int32_t* buffer,
+ uint8_t* y,
+ const union xnn_q8_avgpool_params* params);
+
+typedef void (*xnn_avgpool_up_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const void** x,
+ const void* zero,
+ void* y,
+ size_t x_increment,
+ size_t y_increment,
+ const void* params);
+
+typedef void (*xnn_f32_avgpool_up_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** x,
+ const float* zero,
+ float* y,
+ size_t x_increment,
+ size_t y_increment,
+ const union xnn_f32_avgpool_params* params);
+
+typedef void (*xnn_q8_avgpool_up_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const uint8_t** x,
+ const uint8_t* zero,
+ uint8_t* y,
+ size_t x_increment,
+ size_t y_increment,
+ const union xnn_q8_avgpool_params* params);
+
+typedef void (*xnn_avgpool_mp_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const void** x,
+ const void* zero,
+ void* buffer,
+ void* y,
+ size_t x_increment,
+ size_t y_increment,
+ const void* params);
+
+typedef void (*xnn_f32_avgpool_mp_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** x,
+ const float* zero,
+ float* buffer,
+ float* y,
+ size_t x_increment,
+ size_t y_increment,
+ const union xnn_f32_avgpool_params* params);
+
+typedef void (*xnn_q8_avgpool_mp_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const uint8_t** x,
+ const uint8_t* zero,
+ int32_t* buffer,
+ uint8_t* y,
+ size_t x_increment,
+ size_t y_increment,
+ const union xnn_q8_avgpool_params* params);
+
+typedef void (*xnn_pavgpool_up_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const void** x,
+ const void* zero,
+ const void* multiplier,
+ void* y,
+ size_t x_increment,
+ size_t y_increment,
+ const void* params);
+
+typedef void (*xnn_f32_pavgpool_up_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** x,
+ const float* zero,
+ const float* multiplier,
+ float* y,
+ size_t x_increment,
+ size_t y_increment,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_pavgpool_mp_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const void** x,
+ const void* zero,
+ const void* multiplier,
+ void* buffer,
+ void* y,
+ size_t x_increment,
+ size_t y_increment,
+ const void* params);
+
+typedef void (*xnn_f32_pavgpool_mp_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** x,
+ const float* zero,
+ const float* multiplier,
+ float* buffer,
+ float* y,
+ size_t x_increment,
+ size_t y_increment,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_maxpool_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const void** x,
+ void* y,
+ size_t x_increment,
+ size_t y_increment,
+ const void* params);
+
+typedef void (*xnn_f32_maxpool_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** x,
+ float* y,
+ size_t x_increment,
+ size_t y_increment,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_u8_maxpool_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const uint8_t** x,
+ uint8_t* y,
+ size_t x_increment,
+ size_t y_increment,
+ const union xnn_u8_output_params* params);
+
+typedef void (*xnn_argmaxpool_up_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const void** x,
+ void* y,
+ uint32_t* i,
+ size_t x_increment,
+ size_t y_increment,
+ const void* params);
+
+typedef void (*xnn_f32_argmaxpool_up_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** x,
+ float* y,
+ uint32_t* i,
+ size_t x_increment,
+ size_t y_increment,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_argmaxpool_mp_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const void** x,
+ void* ab,
+ uint32_t* ib,
+ void* y,
+ uint32_t* i,
+ size_t x_increment,
+ size_t y_increment,
+ const void* params);
+
+typedef void (*xnn_f32_argmaxpool_mp_ukernel_function)(
+ size_t n,
+ size_t ks,
+ size_t kc,
+ const float** x,
+ float* ab,
+ uint32_t* ib,
+ float* y,
+ uint32_t* i,
+ size_t x_increment,
+ size_t y_increment,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_univector_ukernel_function)(
+ size_t n,
+ const void* x,
+ void* y,
+ const void* params);
+
+typedef void (*xnn_f32_clamp_ukernel_function)(
+ size_t n,
+ const float* x,
+ float* y,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_u8_clamp_ukernel_function)(
+ size_t n,
+ const uint8_t* x,
+ uint8_t* y,
+ const union xnn_u8_output_params* params);
+
+typedef void (*xnn_f32_hswish_ukernel_function)(
+ size_t n,
+ const float* x,
+ float* y,
+ const union xnn_f32_hswish_params* params);
+
+typedef void (*xnn_rmax_ukernel_function)(
+ size_t n,
+ const void* x,
+ void* y);
+
+typedef void (*xnn_u8_rmax_ukernel_function)(
+ size_t n,
+ const uint8_t* x,
+ uint8_t* y);
+
+typedef void (*xnn_f32_rmax_ukernel_function)(
+ size_t n,
+ const float* x,
+ float* y);
+
+typedef void (*xnn_u8_lut32norm_ukernel_function)(
+ size_t n,
+ const uint8_t* x,
+ const uint32_t* t,
+ uint8_t* y);
+
+typedef void (*xnn_vadd_ukernel_function)(
+ size_t n,
+ const void* a,
+ const void* b,
+ void* y,
+ const void* params);
+
+typedef void (*xnn_f32_vadd_ukernel_function)(
+ size_t n,
+ const float* a,
+ const float* b,
+ float* y,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_q8_vadd_ukernel_function)(
+ size_t n,
+ const uint8_t* a,
+ const uint8_t* b,
+ uint8_t* y,
+ const union xnn_q8_add_params* params);
+
+typedef void (*xnn_vmul_ukernel_function)(
+ size_t n,
+ const void* a,
+ const void* b,
+ void* y,
+ const void* params);
+
+typedef void (*xnn_f32_vmul_ukernel_function)(
+ size_t n,
+ const float* a,
+ const float* b,
+ float* y,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_vsub_ukernel_function)(
+ size_t n,
+ const void* a,
+ const void* b,
+ void* y,
+ const void* params);
+
+typedef void (*xnn_f32_vsub_ukernel_function)(
+ size_t n,
+ const float* a,
+ const float* b,
+ float* y,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_vmulcaddc_ukernel_function)(
+ size_t m,
+ size_t c,
+ const void* x,
+ size_t x_stride,
+ const void* w,
+ void* y,
+ size_t y_stride,
+ const void* params);
+
+typedef void (*xnn_f32_vmulcaddc_ukernel_function)(
+ size_t m,
+ size_t c,
+ const float* x,
+ size_t x_stride,
+ const float* w,
+ float* y,
+ size_t y_stride,
+ const union xnn_f32_output_params* params);
+
+typedef void (*xnn_prelu_ukernel_function)(
+ size_t mr,
+ size_t n,
+ const void* x,
+ size_t x_stride,
+ const void* w,
+ void* y,
+ size_t y_stride,
+ const void* params);
+
+typedef void (*xnn_f32_prelu_ukernel_function)(
+ size_t mr,
+ size_t n,
+ const float* x,
+ size_t x_stride,
+ const float* w,
+ float* y,
+ size_t y_stride,
+ const union xnn_f32_output_params* params);
+
+
+struct gemm_parameters {
+ xnn_gemm_ukernel_function gemm;
+ xnn_igemm_ukernel_function igemm;
+ /* Optional GEMM and IGEMM micro-kernels with MR=1 and the same NR and KR parameters */
+ xnn_gemm_ukernel_function gemm1;
+ xnn_igemm_ukernel_function igemm1;
+ uint8_t mr;
+ uint8_t nr;
+ uint8_t log2_kr;
+ uint8_t log2_sr;
+};
+
+struct spmm_parameters {
+ xnn_spmm_ukernel_function ukernel;
+ // Number of M-dimension elements in a tile.
+ // Corresponds to a block of pixels in 1x1 Convolution and a block of batch size in Fully Connected operator.
+ uint8_t mr;
+ // Number of N-dimension elements in a tile.
+ // Corresponds to a block of output channels/features in 1x1 Convolution and Fully Connected operator.
+ uint8_t nr;
+};
+
+struct hwc2spchw_dconv_parameters {
+ xnn_conv_hwc2spchw_ukernel_function ukernel_with_symm_padding;
+ // Number of output channels in a tile.
+ // This parameter must be passed as is to weight packing function.
+ uint8_t output_channel_tile;
+ // Number of output height pixels in a tile.
+ // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call.
+ uint8_t output_height_tile;
+  // Number of output width pixels in a tile.
+ uint8_t output_width_tile;
+};
+
+struct spchw_dwconv_parameters {
+ xnn_dwconv_spchw_ukernel_function ukernel;
+ // Number of input width pixels in a tile.
+ uint8_t input_width_tile;
+ // Number of output width pixels in a tile.
+ uint8_t output_width_tile;
+ // Number of output height pixels in a tile.
+ // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call.
+ uint8_t output_height_tile;
+};
+
+struct spchw_gavgpool_parameters {
+ xnn_gavgpool_spchw_ukernel_function ukernel;
+ // Number of channels in a tile.
+ // For best efficiency, micro-kernel must process a multiple of this number of channels in each call.
+ uint8_t channel_tile;
+};
+
+struct dwconv_parameters {
+ union {
+ xnn_dwconv_up_ukernel_function up;
+ xnn_dwconv_mp_ukernel_function mp;
+ };
+ uint8_t cr;
+ uint8_t mr;
+ uint8_t qr;
+};
+
+struct gavgpool_parameters {
+ xnn_gavgpool_up_ukernel_function up;
+ xnn_gavgpool_mp_ukernel_function mp;
+ uint8_t mr;
+};
+
+struct avgpool_parameters {
+ xnn_avgpool_up_ukernel_function up;
+ xnn_avgpool_mp_ukernel_function mp;
+ uint8_t mr;
+ uint8_t qr;
+};
+
+struct pavgpool_parameters {
+ xnn_pavgpool_up_ukernel_function up;
+ xnn_pavgpool_mp_ukernel_function mp;
+ uint8_t mr;
+ uint8_t qr;
+};
+
+struct argmaxpool_parameters {
+ union {
+ xnn_argmaxpool_up_ukernel_function up;
+ xnn_argmaxpool_mp_ukernel_function mp;
+ };
+ uint8_t mr;
+ uint8_t qr;
+};
+
+struct maxpool_parameters {
+ xnn_maxpool_ukernel_function ukernel;
+ uint8_t mr;
+ uint8_t qr;
+};
+
+struct zip_parameters {
+ xnn_zipc_ukernel_function x2;
+ xnn_zipc_ukernel_function x3;
+ xnn_zipc_ukernel_function x4;
+ xnn_zipv_ukernel_function xm;
+};
+
+struct prelu_parameters {
+ xnn_prelu_ukernel_function ukernel;
+ uint8_t mr;
+};
+
+struct pad_parameters {
+ xnn_pad_ukernel_function ukernel;
+ uint8_t mr;
+};
+
+struct vmulcaddc_parameters {
+ xnn_vmulcaddc_ukernel_function ukernel;
+ uint8_t cr;
+ uint8_t mr;
+};
+
+#define XNN_MAX_Q8_DWCONV_UKERNELS 1
+#define XNN_MAX_F32_DWCONV_UKERNELS 3
+#define XNN_MAX_F32_ARGMAXPOOL_UKERNELS 3
+
+struct xnn_parameters {
+ bool initialized;
+ struct {
+ struct gemm_parameters gemm;
+ struct dwconv_parameters dwconv[XNN_MAX_Q8_DWCONV_UKERNELS];
+ struct avgpool_parameters avgpool;
+ struct gavgpool_parameters gavgpool;
+ xnn_vadd_ukernel_function vadd;
+ } q8;
+ struct {
+ struct maxpool_parameters maxpool;
+ xnn_univector_ukernel_function clamp;
+ xnn_u8_lut32norm_ukernel_function lut32norm;
+ xnn_u8_rmax_ukernel_function rmax;
+ } u8;
+ struct {
+ xnn_x8_lut_ukernel_function lut;
+ struct zip_parameters zip;
+ } x8;
+ struct {
+ struct gemm_parameters gemm;
+ struct gemm_parameters gemm2;
+ struct dwconv_parameters dwconv[XNN_MAX_F32_DWCONV_UKERNELS];
+ struct avgpool_parameters avgpool;
+ struct pavgpool_parameters pavgpool;
+ struct gavgpool_parameters gavgpool;
+ struct maxpool_parameters maxpool;
+ struct argmaxpool_parameters argmaxpool[XNN_MAX_F32_ARGMAXPOOL_UKERNELS];
+ xnn_univector_ukernel_function clamp;
+ xnn_univector_ukernel_function hswish;
+ struct prelu_parameters prelu;
+ xnn_vadd_ukernel_function vadd;
+ struct vmulcaddc_parameters vmulcaddc;
+ // Sparse Matrix-Dense Matrix Multiplication (NR=1 block).
+ struct spmm_parameters spmm;
+ // Sparse Matrix-Dense Matrix Multiplication (NR=2 block).
+ struct spmm_parameters spmm2;
+ // Sparse Matrix-Dense Matrix Multiplication (NR=4 block).
+ struct spmm_parameters spmm4;
+ // Direct 3x3 stride-2 Convolution with 3 input channels and HWC->SpCHW layout conversion.
+ struct hwc2spchw_dconv_parameters hwc2spchw_dconv3x3c3s2;
+ // Direct 3x3 stride-1 Convolution with padding 1 on left and right in SpCHW layout.
+ struct spchw_dwconv_parameters spchw_dwconv3x3;
+ // Direct 3x3 stride-2 Convolution with padding 1 on left and right in SpCHW layout.
+ struct spchw_dwconv_parameters spchw_dwconv3x3s2;
+ // Global Average Pooling in SpCHW layout.
+ struct spchw_gavgpool_parameters spchw_gavgpool;
+ } f32;
+ struct {
+ struct pad_parameters pad;
+ xnn_unpool_ukernel_function unpool;
+ struct zip_parameters zip;
+ } x32;
+};
+
+extern XNN_INTERNAL struct xnn_parameters xnn_params;
diff --git a/src/xnnpack/pavgpool.h b/src/xnnpack/pavgpool.h
new file mode 100644
index 0000000..f124519
--- /dev/null
+++ b/src/xnnpack/pavgpool.h
@@ -0,0 +1,60 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+#define DECLARE_F32_PAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t n, \
+ size_t ks, \
+ size_t kc, \
+ const float** x, \
+ const float* zero, \
+ const float* multiplier, \
+ float* buffer, \
+ float* y, \
+ size_t x_increment, \
+ size_t y_increment, \
+ const union xnn_f32_output_params* params);
+
+DECLARE_F32_PAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_pavgpool_ukernel_mp9p8q__neon)
+DECLARE_F32_PAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_pavgpool_ukernel_mp9p8q__psimd)
+DECLARE_F32_PAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_pavgpool_ukernel_mp9p8q__scalar)
+DECLARE_F32_PAVGPOOL_MULTIPASS_UKERNEL_FUNCTION(xnn_f32_pavgpool_ukernel_mp9p8q__sse)
+
+
+#define DECLARE_F32_PAVGPOOL_UNIPASS_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t n, \
+ size_t ks, \
+ size_t kc, \
+ const float** x, \
+ const float* zero, \
+ const float* multiplier, \
+ float* y, \
+ size_t x_increment, \
+ size_t y_increment, \
+ const union xnn_f32_output_params* params);
+
+DECLARE_F32_PAVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_pavgpool_ukernel_up9__neon)
+DECLARE_F32_PAVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_pavgpool_ukernel_up9__psimd)
+DECLARE_F32_PAVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_pavgpool_ukernel_up9__scalar)
+DECLARE_F32_PAVGPOOL_UNIPASS_UKERNEL_FUNCTION(xnn_f32_pavgpool_ukernel_up9__sse)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/ppmm.h b/src/xnnpack/ppmm.h
new file mode 100644
index 0000000..1bf6941
--- /dev/null
+++ b/src/xnnpack/ppmm.h
@@ -0,0 +1,45 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+// Declares the XNN_INTERNAL prototype of an f32 PPMM (pre-packed matrix
+// multiply) micro-kernel computing an mr x nc output tile `c` from input `a`
+// and packed weights `w`.  The NxM in each kernel name is its tile size.
+// NOTE(review): operand packing details inferred from names -- confirm
+// against the kernel implementations.
+#define DECLARE_F32_PPMM_UKERNEL_FUNCTION(fn_name) \
+  XNN_INTERNAL void fn_name( \
+    size_t mr, \
+    size_t nc, \
+    size_t kc, \
+    const float* a, \
+    const float* w, \
+    float* c, \
+    size_t cm_stride, \
+    size_t cn_stride, \
+    const union xnn_f32_output_params* params);
+
+DECLARE_F32_PPMM_UKERNEL_FUNCTION(xnn_f32_ppmm_ukernel_2x4__scalar)
+DECLARE_F32_PPMM_UKERNEL_FUNCTION(xnn_f32_ppmm_ukernel_3x3__scalar)
+DECLARE_F32_PPMM_UKERNEL_FUNCTION(xnn_f32_ppmm_ukernel_4x2__scalar)
+DECLARE_F32_PPMM_UKERNEL_FUNCTION(xnn_f32_ppmm_ukernel_4x4__scalar)
+DECLARE_F32_PPMM_UKERNEL_FUNCTION(xnn_f32_ppmm_ukernel_4x8__neon)
+DECLARE_F32_PPMM_UKERNEL_FUNCTION(xnn_f32_ppmm_ukernel_4x8__neonfma)
+DECLARE_F32_PPMM_UKERNEL_FUNCTION(xnn_f32_ppmm_ukernel_4x8__psimd)
+DECLARE_F32_PPMM_UKERNEL_FUNCTION(xnn_f32_ppmm_ukernel_4x8__sse)
+DECLARE_F32_PPMM_UKERNEL_FUNCTION(xnn_f32_ppmm_ukernel_8x8__neon)
+DECLARE_F32_PPMM_UKERNEL_FUNCTION(xnn_f32_ppmm_ukernel_8x8__neonfma)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/prelu.h b/src/xnnpack/prelu.h
new file mode 100644
index 0000000..2a882a7
--- /dev/null
+++ b/src/xnnpack/prelu.h
@@ -0,0 +1,38 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+// Declares the XNN_INTERNAL prototype of an f32 PReLU micro-kernel: applies
+// per-channel parametric-ReLU weights `w` to an mr x n block of `x`, writes
+// `y`, and clamps the result with `clamping_params`.
+// NOTE(review): exact row/column roles inferred from names -- confirm
+// against the kernel implementations.
+#define DECLARE_F32_PRELU_UKERNEL_FUNCTION(fn_name) \
+  XNN_INTERNAL void fn_name( \
+    size_t mr, \
+    size_t n, \
+    const float* x, \
+    size_t x_stride, \
+    const float* w, \
+    float* y, \
+    size_t y_stride, \
+    const union xnn_f32_output_params* clamping_params);
+
+
+DECLARE_F32_PRELU_UKERNEL_FUNCTION(xnn_f32_prelu_ukernel_x4__psimd)
+DECLARE_F32_PRELU_UKERNEL_FUNCTION(xnn_f32_prelu_ukernel_x4__scalar)
+DECLARE_F32_PRELU_UKERNEL_FUNCTION(xnn_f32_prelu_ukernel_x4__sse)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/requantization-stubs.h b/src/xnnpack/requantization-stubs.h
new file mode 100644
index 0000000..ee6e86d
--- /dev/null
+++ b/src/xnnpack/requantization-stubs.h
@@ -0,0 +1,69 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stdint.h>
+#include <stddef.h>
+
+#include <xnnpack/params.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Signature shared by all requantization test/bench routines: converts n
+// 32-bit accumulators in `input` to uint8 in `output`, presumably applying
+// `scale` and `zero_point` and clamping to [qmin, qmax] -- confirm against
+// the implementations.
+typedef void (*requantization_function)(
+  size_t n,
+  const int32_t* input,
+  float scale,
+  uint8_t zero_point,
+  uint8_t qmin,
+  uint8_t qmax,
+  uint8_t* output);
+
+// Declares one concrete requantization routine with the signature above.
+#define DECLARE_REQUANTIZATION_FUNCTION(fn_name) \
+  void fn_name( \
+    size_t n, \
+    const int32_t* input, \
+    float scale, \
+    uint8_t zero_point, \
+    uint8_t qmin, \
+    uint8_t qmax, \
+    uint8_t* output);
+
+// "precise" requantization scheme, per-ISA implementations.
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_precise__scalar_unsigned32)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_precise__scalar_unsigned64)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_precise__scalar_signed64)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_precise__sse2)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_precise__ssse3)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_precise__sse4)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_precise__neon)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_precise__psimd)
+
+// "fp32" (float-rounding) requantization scheme.
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_fp32__scalar_lrintf)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_fp32__scalar_magic)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_fp32__sse2)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_fp32__neon)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_fp32__psimd)
+
+// "q31" (fixed-point) requantization scheme.
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_q31__scalar)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_q31__sse2)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_q31__ssse3)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_q31__sse4)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_q31__neon)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_q31__psimd)
+
+// gemmlowp-compatible requantization scheme.
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_gemmlowp__scalar)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_gemmlowp__sse2)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_gemmlowp__ssse3)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_gemmlowp__sse4)
+DECLARE_REQUANTIZATION_FUNCTION(xnn_requantize_gemmlowp__neon)
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/requantization.h b/src/xnnpack/requantization.h
new file mode 100644
index 0000000..bf3e100
--- /dev/null
+++ b/src/xnnpack/requantization.h
@@ -0,0 +1,1307 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#if defined(__cplusplus) && (__cplusplus >= 201103L)
+ #include <cstdint>
+ #include <cstddef>
+ #include <cassert>
+ #include <cmath>
+#else
+ #include <stdint.h>
+ #include <stddef.h>
+ #include <assert.h>
+ #include <math.h>
+#endif
+
+#include <fp16.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/scalar-utils.h>
+
+
+/*
+ * Computes Q8 GEMM requantization parameters for the portable scalar path.
+ * The requantization scale is decomposed into a fixed-point multiplier (the
+ * float's mantissa shifted into [0x40000000, 0x7FFFFF80]) and a right shift
+ * derived from the float's exponent; remainder mask/threshold implement
+ * round-to-nearest on the shifted product.
+ */
+static inline union xnn_q8_gemm_params xnn_compute_scalar_q8_gemm_params(
+  uint8_t input_zero_point,
+  uint8_t kernel_zero_point,
+  float scale,
+  uint8_t output_zero_point,
+  uint8_t output_min,
+  uint8_t output_max)
+{
+  /* Compute requantization parameters */
+  const uint32_t scale_bits = fp32_to_bits(scale);
+
+  /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
+  const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
+  assert(multiplier >= INT32_C(0x40000000));
+  assert(multiplier <= INT32_C(0x7FFFFF80));
+
+  /* Shift is in [0, 31] range; reuse scale_bits instead of re-evaluating fp32_to_bits(scale) */
+  const int32_t shift = 127 + 31 - 32 - (scale_bits >> 23);
+  assert(shift >= 0);
+  assert(shift < 32);
+
+  const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
+  const uint32_t remainder_threshold = remainder_mask >> 1;
+
+  union xnn_q8_gemm_params params;
+  params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point;
+  params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
+  params.scalar.multiplier = multiplier;
+  params.scalar.remainder_mask = (int32_t) remainder_mask;
+  params.scalar.remainder_threshold = (int32_t) remainder_threshold;
+  params.scalar.shift = (uint32_t) shift;
+  /* Clamping bounds are pre-biased by the output zero point so kernels can
+     clamp before adding it back. */
+  params.scalar.output_min_less_zero_point =
+    (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
+  params.scalar.output_max_less_zero_point =
+    (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
+  params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
+  return params;
+}
+
+/*
+ * Computes Q8 GEMM requantization parameters in the layout expected by the
+ * micro-kernels of the build architecture (SSE2, NEON, or portable scalar).
+ * The requantization scale is decomposed into a fixed-point multiplier (the
+ * float's mantissa shifted into [0x40000000, 0x7FFFFF80]) and a right shift
+ * derived from the float's exponent.
+ */
+static inline union xnn_q8_gemm_params xnn_compute_q8_gemm_params(
+  uint8_t input_zero_point,
+  uint8_t kernel_zero_point,
+  float scale,
+  uint8_t output_zero_point,
+  uint8_t output_min,
+  uint8_t output_max)
+{
+  /* Compute requantization parameters */
+  const uint32_t scale_bits = fp32_to_bits(scale);
+
+  /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
+  const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
+  assert(multiplier >= INT32_C(0x40000000));
+  assert(multiplier <= INT32_C(0x7FFFFF80));
+
+  /* Shift is in [0, 31] range; reuse scale_bits instead of re-evaluating fp32_to_bits(scale) */
+  const int32_t shift = 127 + 31 - 32 - (scale_bits >> 23);
+  assert(shift >= 0);
+  assert(shift < 32);
+
+  union xnn_q8_gemm_params params;
+  #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+    const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
+    const uint32_t remainder_threshold = remainder_mask >> 1;
+    /* Zero points are broadcast across all eight 16-bit SSE2 lanes. */
+    for (uint32_t i = 0; i < 8; i++) {
+      params.sse2.input_zero_point[i] = (int16_t) (uint16_t) input_zero_point;
+      params.sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
+    }
+    params.sse2.multiplier[0] = multiplier;
+    params.sse2.multiplier[1] = multiplier;
+    params.sse2.multiplier[2] = multiplier;
+    params.sse2.multiplier[3] = multiplier;
+    params.sse2.rounding[0] = UINT64_C(0x40000000);
+    params.sse2.rounding[1] = UINT64_C(0x40000000);
+    params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
+    params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
+    params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
+    params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
+    params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
+    params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
+    params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
+    params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
+    params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
+    params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
+    for (uint32_t i = 0; i < 8; i++) {
+      params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
+    }
+    for (uint32_t i = 0; i < 16; i++) {
+      params.sse2.output_max[i] = output_max;
+      params.sse2.output_min[i] = output_min;
+    }
+  #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+    params.neon.input_zero_point = (int16_t) (uint16_t) input_zero_point;
+    params.neon.kernel_zero_point = (int16_t) (uint16_t) kernel_zero_point;
+    params.neon.multiplier = multiplier;
+    /* Stored negated -- presumably consumed as a left-shift amount by the
+       NEON kernels; confirm against their implementations. */
+    params.neon.right_shift = -shift;
+    params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
+    params.neon.output_max = output_max;
+    params.neon.output_min = output_min;
+  #else
+    const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
+    const uint32_t remainder_threshold = remainder_mask >> 1;
+    params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point;
+    params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
+    params.scalar.multiplier = multiplier;
+    params.scalar.remainder_mask = (int32_t) remainder_mask;
+    params.scalar.remainder_threshold = (int32_t) remainder_threshold;
+    params.scalar.shift = (uint32_t) shift;
+    params.scalar.output_min_less_zero_point =
+      (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
+    params.scalar.output_max_less_zero_point =
+      (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
+    params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
+  #endif
+  return params;
+}
+
+/*
+ * Computes Q8 average-pooling requantization parameters in the layout for
+ * the build architecture.  The scale is split into a 24-bit fixed-point
+ * multiplier (the float mantissa) and a rounding right shift derived from
+ * the float exponent.
+ */
+static inline union xnn_q8_avgpool_params xnn_compute_q8_avgpool_params(
+  int32_t bias,
+  float scale,
+  uint8_t output_zero_point,
+  uint8_t output_min,
+  uint8_t output_max)
+{
+  /* Compute requantization parameters */
+  assert(scale >= 0x1.0p-32f);
+  assert(scale < 256.0f);
+  const uint32_t scale_bits = fp32_to_bits(scale);
+
+  /* Multiplier is in [0x00800000, 0x00FFFFFF] range */
+  const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
+  assert(multiplier >= INT32_C(0x00800000));
+  assert(multiplier <= INT32_C(0x00FFFFFF));
+
+  /* Shift is in [16, 55] range */
+  const int32_t shift = 127 + 23 - (scale_bits >> 23);
+  assert(shift >= 16);
+  assert(shift < 64);
+
+  union xnn_q8_avgpool_params params;
+  #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+    const uint32_t right_shift = (uint32_t) shift;
+    const uint64_t rounding = UINT64_C(1) << (right_shift - 1);
+    /* Broadcast per-lane values across the SSE2 vector fields. */
+    for (uint32_t i = 0; i < 4; i++) {
+      params.sse2.bias[i] = bias;
+      params.sse2.multiplier[i] = (uint32_t) multiplier;
+    }
+    for (uint32_t i = 0; i < 2; i++) {
+      params.sse2.rounding[i] = rounding;
+      params.sse2.right_shift[i] = (uint64_t) right_shift;
+    }
+    for (uint32_t i = 0; i < 8; i++) {
+      params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
+    }
+    for (uint32_t i = 0; i < 16; i++) {
+      params.sse2.output_max[i] = output_max;
+      params.sse2.output_min[i] = output_min;
+    }
+  #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+    params.neon.bias = bias;
+    params.neon.multiplier = multiplier;
+    params.neon.left_shift = (int64_t) -shift;
+    params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
+    params.neon.output_max = output_max;
+    params.neon.output_min = output_min;
+  #else
+    const uint32_t right_shift = (uint32_t) shift;
+    const int64_t rounding = INT64_C(1) << (right_shift - 1);
+    params.scalar.bias = bias;
+    params.scalar.multiplier = multiplier;
+    params.scalar.rounding = rounding;
+    params.scalar.right_shift = right_shift;
+    params.scalar.output_min_less_zero_point =
+      (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
+    params.scalar.output_max_less_zero_point =
+      (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
+    params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
+  #endif
+  return params;
+}
+
+/*
+ * Scalar-layout Q8 average-pooling requantization parameters.  The scale is
+ * split into a 24-bit fixed-point multiplier (the float mantissa) and a
+ * rounding right shift (from the float exponent).
+ */
+static inline union xnn_q8_avgpool_params xnn_compute_scalar_q8_avgpool_params(
+  int32_t bias,
+  float scale,
+  uint8_t output_zero_point,
+  uint8_t output_min,
+  uint8_t output_max)
+{
+  /* Compute requantization parameters */
+  assert(scale >= 0x1.0p-32f);
+  assert(scale < 256.0f);
+  const uint32_t scale_bits = fp32_to_bits(scale);
+
+  /* Multiplier is in [0x00800000, 0x00FFFFFF] range */
+  const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
+  assert(multiplier >= INT32_C(0x00800000));
+  assert(multiplier <= INT32_C(0x00FFFFFF));
+
+  /* Shift is in [16, 55] range */
+  const int32_t shift = 127 + 23 - (scale_bits >> 23);
+  assert(shift >= 16);
+  assert(shift < 64);
+
+  const uint32_t post_shift = (uint32_t) shift;
+  /* Rounding addend: half of the divisor implied by the right shift. */
+  const int64_t rounding_addend = INT64_C(1) << (post_shift - 1);
+
+  union xnn_q8_avgpool_params params;
+  params.scalar.bias = bias;
+  params.scalar.rounding = rounding_addend;
+  params.scalar.multiplier = multiplier;
+  params.scalar.right_shift = post_shift;
+  params.scalar.output_min_less_zero_point =
+    (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
+  params.scalar.output_max_less_zero_point =
+    (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
+  params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
+  return params;
+}
+
+/* Overwrites only the multiplier of already-initialized f32 average-pooling
+   parameters; clamping bounds are left untouched. */
+static inline void xnn_update_f32_avgpool_params(
+  union xnn_f32_avgpool_params* params,
+  float multiplier)
+{
+  #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+    params->sse2.multiplier[0] = multiplier;
+    params->sse2.multiplier[1] = multiplier;
+    params->sse2.multiplier[2] = multiplier;
+    params->sse2.multiplier[3] = multiplier;
+  #else
+    params->scalar.multiplier = multiplier;
+  #endif
+}
+
+/*
+ * Packs the average-pooling multiplier and output clamping bounds into the
+ * architecture-specific f32 avgpool parameter layout.
+ */
+static inline union xnn_f32_avgpool_params xnn_compute_f32_avgpool_params(
+  float multiplier,
+  float output_min,
+  float output_max)
+{
+  union xnn_f32_avgpool_params params;
+  #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+    /* Broadcast each value across all four SSE lanes. */
+    for (uint32_t i = 0; i < 4; i++) {
+      params.sse2.multiplier[i] = multiplier;
+      params.sse2.output_min[i] = output_min;
+      params.sse2.output_max[i] = output_max;
+    }
+  #else
+    params.scalar.multiplier = multiplier;
+    params.scalar.output_min = output_min;
+    params.scalar.output_max = output_max;
+  #endif
+  /* Fixed: `return params;` was mis-indented at column 0. */
+  return params;
+}
+
+/*
+ * Builds f32 global-average-pooling parameters in the architecture-specific
+ * layout: a scale factor, output clamping bounds and (for SIMD layouts) a
+ * lane mask covering the width % 4 leftover columns of the last block
+ * (all four lanes valid when the remainder is zero).
+ */
+static inline union xnn_f32_gavgpool_params xnn_compute_f32_gavgpool_params(
+  float multiplier,
+  float output_min,
+  float output_max,
+  uint32_t width)
+{
+  union xnn_f32_gavgpool_params params;
+  #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+    const uint32_t rem = width % 4;
+    const uint32_t valid_lanes = (rem == 0) ? 4 : rem;
+    for (uint32_t i = 0; i < 4; i++) {
+      params.sse.multiplier[i] = multiplier;
+      params.sse.output_min[i] = output_min;
+      params.sse.output_max[i] = output_max;
+      params.sse.mask[i] = (i < valid_lanes) ? UINT32_C(0xFFFFFFFF) : 0;
+    }
+  #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+    const uint32_t rem = width % 4;
+    const uint32_t valid_lanes = (rem == 0) ? 4 : rem;
+    for (uint32_t i = 0; i < 4; i++) {
+      params.neon.mask[i] = (i < valid_lanes) ? UINT32_C(0xFFFFFFFF) : 0;
+    }
+    params.neon.multiplier = multiplier;
+    params.neon.output_min = output_min;
+    params.neon.output_max = output_max;
+  #else
+    /* Scalar layout carries no mask; width is unused on this path. */
+    params.scalar.multiplier = multiplier;
+    params.scalar.output_min = output_min;
+    params.scalar.output_max = output_max;
+  #endif
+  return params;
+}
+
+/*
+ * Refreshes the multiplier and the width % 4 lane mask of existing f32
+ * global-average-pooling parameters; all other fields are left untouched.
+ * No-op for the scalar layout, which carries neither field.
+ */
+static inline void xnn_update_f32_gavgpool_params(
+  union xnn_f32_gavgpool_params* params,
+  float multiplier,
+  uint32_t width)
+{
+  #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+    const uint32_t rem = width % 4;
+    const uint32_t valid_lanes = (rem == 0) ? 4 : rem;
+    for (uint32_t i = 0; i < 4; i++) {
+      params->sse.multiplier[i] = multiplier;
+      params->sse.mask[i] = (i < valid_lanes) ? UINT32_C(0xFFFFFFFF) : 0;
+    }
+  #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+    const uint32_t rem = width % 4;
+    const uint32_t valid_lanes = (rem == 0) ? 4 : rem;
+    params->neon.multiplier = multiplier;
+    for (uint32_t i = 0; i < 4; i++) {
+      params->neon.mask[i] = (i < valid_lanes) ? UINT32_C(0xFFFFFFFF) : 0;
+    }
+  #endif
+}
+
+/* Scalar-layout f32 average-pooling parameters: scale plus clamping bounds. */
+static inline union xnn_f32_avgpool_params xnn_compute_scalar_f32_avgpool_params(
+  float multiplier,
+  float output_min,
+  float output_max)
+{
+  union xnn_f32_avgpool_params params;
+  params.scalar.output_min = output_min;
+  params.scalar.output_max = output_max;
+  params.scalar.multiplier = multiplier;
+  return params;
+}
+
+/*
+ * Scalar-layout f32 global-average-pooling parameters.  `width` is accepted
+ * for signature parity with the SIMD-aware variant but is unused here: the
+ * scalar layout carries no lane mask.
+ */
+static inline union xnn_f32_gavgpool_params xnn_compute_scalar_f32_gavgpool_params(
+  float multiplier,
+  float output_min,
+  float output_max,
+  uint32_t width)
+{
+  union xnn_f32_gavgpool_params params;
+  params.scalar.output_min = output_min;
+  params.scalar.output_max = output_max;
+  params.scalar.multiplier = multiplier;
+  return params;
+}
+
+/* Packs output clamping bounds into the architecture-specific layout. */
+static inline union xnn_f32_output_params xnn_compute_f32_output_params(
+  float output_min,
+  float output_max)
+{
+  union xnn_f32_output_params params;
+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+  /* Broadcast the bounds across all four SSE lanes. */
+  for (uint32_t i = 0; i < 4; i++) {
+    params.sse.min[i] = output_min;
+  }
+  for (uint32_t i = 0; i < 4; i++) {
+    params.sse.max[i] = output_max;
+  }
+#else
+  params.scalar.min = output_min;
+  params.scalar.max = output_max;
+#endif
+  return params;
+}
+
+/* Scalar-layout output clamping bounds. */
+static inline union xnn_f32_output_params xnn_compute_scalar_f32_output_params(
+  float output_min,
+  float output_max)
+{
+  union xnn_f32_output_params params;
+  params.scalar.max = output_max;
+  params.scalar.min = output_min;
+  return params;
+}
+
+/* HardSwish constants (1/6, 1/2, 1) in the architecture-specific layout. */
+static inline union xnn_f32_hswish_params xnn_compute_f32_hswish_params(void)
+{
+  const float sixth = 0x1.555556p-3f;  /* nearest float to 1/6 */
+  const float half = 0.5f;
+  const float one = 1.0f;
+  union xnn_f32_hswish_params params;
+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+  /* Broadcast each constant across all four SSE lanes. */
+  for (uint32_t i = 0; i < 4; i++) {
+    params.sse.sixth[i] = sixth;
+    params.sse.half[i] = half;
+    params.sse.one[i] = one;
+  }
+#else
+  params.scalar.sixth = sixth;
+  params.scalar.half = half;
+  params.scalar.one = one;
+#endif
+  return params;
+}
+
+/* Scalar-layout HardSwish constants: 1/6, 1/2 and 1. */
+static inline union xnn_f32_hswish_params xnn_compute_scalar_f32_hswish_params(void)
+{
+  union xnn_f32_hswish_params params;
+  params.scalar.one = 1.0f;
+  params.scalar.half = 0.5f;
+  params.scalar.sixth = 0x1.555556p-3f;
+  return params;
+}
+
+/*
+ * Builds f32 "spatial CHW" kernel parameters: output clamping bounds plus
+ * column masks for the last partial tile of a row.  `mask` covers the
+ * width % 4 leftover columns; `mask_even`/`mask_odd` cover the width % 8
+ * leftover columns after de-interleaving them into even/odd lanes.
+ */
+static inline union xnn_f32_spchw_params xnn_compute_f32_spchw_params(
+  uint32_t width,
+  float output_min,
+  float output_max)
+{
+  union xnn_f32_spchw_params params;
+#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+  /* Valid lanes in the last block of 4 columns (all 4 when width % 4 == 0). */
+  const uint32_t rem4 = width % 4;
+  const uint32_t valid4 = (rem4 == 0) ? 4 : rem4;
+  /* Valid columns in the last block of 8 (all 8 when width % 8 == 0).
+     Column k maps to even lane k/2 when k is even and odd lane k/2 when k
+     is odd, so the even mask covers ceil(valid8 / 2) lanes and the odd
+     mask covers floor(valid8 / 2) lanes. */
+  const uint32_t rem8 = width % 8;
+  const uint32_t valid8 = (rem8 == 0) ? 8 : rem8;
+  const uint32_t valid_even = (valid8 + 1) / 2;
+  const uint32_t valid_odd = valid8 / 2;
+  for (uint32_t i = 0; i < 4; i++) {
+    params.sse.mask[i] = (i < valid4) ? UINT32_C(0xFFFFFFFF) : 0;
+    params.sse.mask_even[i] = (i < valid_even) ? UINT32_C(0xFFFFFFFF) : 0;
+    params.sse.mask_odd[i] = (i < valid_odd) ? UINT32_C(0xFFFFFFFF) : 0;
+    params.sse.max[i] = output_max;
+    params.sse.min[i] = output_min;
+  }
+#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+  /* Same masking scheme: width % 4 lanes for mask, de-interleaved
+     width % 8 columns for mask_even/mask_odd. */
+  const uint32_t rem4 = width % 4;
+  const uint32_t valid4 = (rem4 == 0) ? 4 : rem4;
+  const uint32_t rem8 = width % 8;
+  const uint32_t valid8 = (rem8 == 0) ? 8 : rem8;
+  const uint32_t valid_even = (valid8 + 1) / 2;
+  const uint32_t valid_odd = valid8 / 2;
+  for (uint32_t i = 0; i < 4; i++) {
+    params.neon.mask[i] = (i < valid4) ? UINT32_C(0xFFFFFFFF) : 0;
+    params.neon.mask_even[i] = (i < valid_even) ? UINT32_C(0xFFFFFFFF) : 0;
+    params.neon.mask_odd[i] = (i < valid_odd) ? UINT32_C(0xFFFFFFFF) : 0;
+  }
+  params.neon.max = output_max;
+  params.neon.min = output_min;
+#else
+  /* Scalar layout carries no masks; width is unused on this path. */
+  params.scalar.max = output_max;
+  params.scalar.min = output_min;
+#endif
+  return params;
+}
+
+static inline void xnn_update_f32_spchw_params(
+ union xnn_f32_spchw_params* params,
+ uint32_t width)
+{
+ #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ switch (width % 4) {
+ case 0:
+ params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask[2] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask[3] = UINT32_C(0xFFFFFFFF);
+ break;
+ case 1:
+ params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask[1] = 0;
+ params->sse.mask[2] = 0;
+ params->sse.mask[3] = 0;
+ break;
+ case 2:
+ params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask[2] = 0;
+ params->sse.mask[3] = 0;
+ break;
+ case 3:
+ params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask[2] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask[3] = 0;
+ break;
+ }
+ switch (width % 8) {
+ case 0:
+ params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[2] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[3] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[2] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[3] = UINT32_C(0xFFFFFFFF);
+ break;
+ case 1:
+ params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[1] = 0;
+ params->sse.mask_even[2] = 0;
+ params->sse.mask_even[3] = 0;
+ params->sse.mask_odd[0] = 0;
+ params->sse.mask_odd[1] = 0;
+ params->sse.mask_odd[2] = 0;
+ params->sse.mask_odd[3] = 0;
+ break;
+ case 2:
+ params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[1] = 0;
+ params->sse.mask_even[2] = 0;
+ params->sse.mask_even[3] = 0;
+ params->sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[1] = 0;
+ params->sse.mask_odd[2] = 0;
+ params->sse.mask_odd[3] = 0;
+ break;
+ case 3:
+ params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[2] = 0;
+ params->sse.mask_even[3] = 0;
+ params->sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[1] = 0;
+ params->sse.mask_odd[2] = 0;
+ params->sse.mask_odd[3] = 0;
+ break;
+ case 4:
+ params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[2] = 0;
+ params->sse.mask_even[3] = 0;
+ params->sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[2] = 0;
+ params->sse.mask_odd[3] = 0;
+ break;
+ case 5:
+ params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[2] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[3] = 0;
+ params->sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[2] = 0;
+ params->sse.mask_odd[3] = 0;
+ break;
+ case 6:
+ params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[2] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[3] = 0;
+ params->sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[2] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[3] = 0;
+ break;
+ case 7:
+ params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[2] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_even[3] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[1] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[2] = UINT32_C(0xFFFFFFFF);
+ params->sse.mask_odd[3] = 0;
+ break;
+ }
+ #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+ switch (width % 4) {
+ case 0:
+ params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask[2] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask[3] = UINT32_C(0xFFFFFFFF);
+ break;
+ case 1:
+ params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask[1] = 0;
+ params->neon.mask[2] = 0;
+ params->neon.mask[3] = 0;
+ break;
+ case 2:
+ params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask[2] = 0;
+ params->neon.mask[3] = 0;
+ break;
+ case 3:
+ params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask[2] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask[3] = 0;
+ break;
+ }
+ switch (width % 8) {
+ case 0:
+ params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[2] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[3] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[2] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[3] = UINT32_C(0xFFFFFFFF);
+ break;
+ case 1:
+ params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[1] = 0;
+ params->neon.mask_even[2] = 0;
+ params->neon.mask_even[3] = 0;
+ params->neon.mask_odd[0] = 0;
+ params->neon.mask_odd[1] = 0;
+ params->neon.mask_odd[2] = 0;
+ params->neon.mask_odd[3] = 0;
+ break;
+ case 2:
+ params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[1] = 0;
+ params->neon.mask_even[2] = 0;
+ params->neon.mask_even[3] = 0;
+ params->neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[1] = 0;
+ params->neon.mask_odd[2] = 0;
+ params->neon.mask_odd[3] = 0;
+ break;
+ case 3:
+ params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[2] = 0;
+ params->neon.mask_even[3] = 0;
+ params->neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[1] = 0;
+ params->neon.mask_odd[2] = 0;
+ params->neon.mask_odd[3] = 0;
+ break;
+ case 4:
+ params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[2] = 0;
+ params->neon.mask_even[3] = 0;
+ params->neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[2] = 0;
+ params->neon.mask_odd[3] = 0;
+ break;
+ case 5:
+ params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[2] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[3] = 0;
+ params->neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[2] = 0;
+ params->neon.mask_odd[3] = 0;
+ break;
+ case 6:
+ params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[2] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[3] = 0;
+ params->neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[2] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[3] = 0;
+ break;
+ case 7:
+ params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[2] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_even[3] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[1] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[2] = UINT32_C(0xFFFFFFFF);
+ params->neon.mask_odd[3] = 0;
+ break;
+ }
+ #endif
+}
+
+static inline union xnn_f32_spchw_params xnn_compute_scalar_f32_spchw_params(
+ uint32_t width,
+ float output_min,
+ float output_max)
+{
+ // Scalar variant keeps only the clamping bounds; unlike the SIMD variant it
+ // needs no row-remainder masks, so the width argument is intentionally unused.
+ union xnn_f32_spchw_params result;
+ result.scalar.min = output_min;
+ result.scalar.max = output_max;
+ return result;
+}
+
+// Packs a [output_min, output_max] clamping range for U8 kernels into the
+// architecture-specific parameter layout selected at compile time.
+static inline union xnn_u8_output_params xnn_compute_u8_output_params(
+ uint8_t output_min,
+ uint8_t output_max)
+{
+ assert(output_min < output_max);
+
+ union xnn_u8_output_params params;
+ #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ // SSE2 layout broadcasts each bound across all 16 byte lanes of an XMM register.
+ for (uint32_t i = 0; i < 16; i++) {
+ params.sse2.max[i] = output_max;
+ params.sse2.min[i] = output_min;
+ }
+ #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+ // NEON layout stores each bound once; the kernel broadcasts it itself.
+ params.neon.max = output_max;
+ params.neon.min = output_min;
+ #else
+ // Scalar layout pre-widens the bounds to int32 so kernels can clamp without casts.
+ params.scalar.min = (int32_t) (uint32_t) output_min;
+ params.scalar.max = (int32_t) (uint32_t) output_max;
+ #endif
+ return params;
+}
+
+static inline union xnn_u8_output_params xnn_compute_scalar_u8_output_params(
+ uint8_t output_min,
+ uint8_t output_max)
+{
+ assert(output_min < output_max);
+
+ // Always fills the scalar layout, regardless of target architecture; the
+ // bounds are widened to int32 once here so kernels can clamp without casting.
+ union xnn_u8_output_params result;
+ result.scalar.max = (int32_t) (uint32_t) output_max;
+ result.scalar.min = (int32_t) (uint32_t) output_min;
+ return result;
+}
+
+// Precomputes fixed-point parameters for elementwise Q8 addition:
+//   acc = zero_point_product + a * a_multiplier + b * b_multiplier
+// followed by a rounding right shift by `shift` and clamping to
+// [output_min, output_max]. The parameter layout differs per architecture.
+static inline union xnn_q8_add_params xnn_compute_q8_add_params(
+ uint8_t a_zero_point,
+ uint8_t b_zero_point,
+ uint8_t output_zero_point,
+ float a_output_scale,
+ float b_output_scale,
+ uint8_t output_min,
+ uint8_t output_max)
+{
+ assert(a_output_scale >= 0x1.0p-14f);
+ assert(b_output_scale >= 0x1.0p-14f);
+ assert(a_output_scale < 0x1.0p+8f);
+ assert(b_output_scale < 0x1.0p+8f);
+
+ /* Compute requantization parameters */
+ const float max_output_scale = a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
+ assert(max_output_scale >= 0x1.0p-14f);
+ assert(max_output_scale < 0x1.0p+8f);
+ const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
+ // Extract the unbiased binary exponent of the larger scale from its IEEE 754 bits.
+ const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
+ /* Shift is in [13, 31] range */
+ const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
+ assert(shift < 32);
+ assert(shift >= 13);
+
+ // scale_multiplier == 2**(21 - max_scale_exponent), constructed directly in the exponent field.
+ const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
+
+ /* Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range */
+ const uint32_t a_multiplier = (uint32_t) (int32_t) __builtin_lrintf(a_output_scale * scale_multiplier);
+ const uint32_t b_multiplier = (uint32_t) (int32_t) __builtin_lrintf(b_output_scale * scale_multiplier);
+ assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
+ assert(a_multiplier < UINT32_C(0x00400000));
+ assert(b_multiplier < UINT32_C(0x00400000));
+
+ union xnn_q8_add_params params;
+ #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
+ const uint32_t remainder_threshold = remainder_mask >> 1;
+ // Constant term that folds both input zero points into the accumulator.
+ const int32_t zero_point_product =
+ (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
+ for (uint32_t i = 0; i < 4; i++) {
+ params.sse2.zero_point_product[i] = zero_point_product;
+ }
+ for (uint32_t i = 0; i < 8; i++) {
+ params.sse2.y_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
+ }
+ // Multipliers are additionally split into 16-bit low/high halves, broadcast
+ // across 8 lanes, for SSE2 16-bit multiply sequences.
+ for (uint32_t i = 0; i < 8; i++) {
+ params.sse2.a_multiplier_lo[i] = (uint16_t) (uint32_t) a_multiplier;
+ params.sse2.a_multiplier_hi[i] = (uint16_t) ((uint32_t) a_multiplier >> 16);
+ params.sse2.b_multiplier_lo[i] = (uint16_t) (uint32_t) b_multiplier;
+ params.sse2.b_multiplier_hi[i] = (uint16_t) ((uint32_t) b_multiplier >> 16);
+ }
+ params.sse2.a_multiplier = a_multiplier;
+ params.sse2.b_multiplier = b_multiplier;
+ for (uint32_t i = 0; i < 4; i++) {
+ params.sse2.remainder_mask[i] = remainder_mask;
+ params.sse2.remainder_threshold[i] = remainder_threshold;
+ }
+ params.sse2.shift = shift;
+ for (uint32_t i = 0; i < 16; i++) {
+ params.sse2.y_max[i] = output_max;
+ params.sse2.y_min[i] = output_min;
+ }
+ #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+ // NEON layout keeps the raw zero points and a negative right-shift amount;
+ // the kernel consumes these directly instead of a remainder mask/threshold.
+ params.neon.a_zero_point = a_zero_point;
+ params.neon.b_zero_point = b_zero_point;
+ params.neon.y_zero_point = (int16_t) (uint16_t) output_zero_point;
+ params.neon.a_multiplier = (int32_t) a_multiplier;
+ params.neon.b_multiplier = (int32_t) b_multiplier;
+ params.neon.right_shift = (int32_t) -shift;
+ params.neon.y_max = output_max;
+ params.neon.y_min = output_min;
+ #else
+ const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
+ const uint32_t remainder_threshold = remainder_mask >> 1;
+ params.scalar.zero_point_product =
+ (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
+ params.scalar.a_multiplier = a_multiplier;
+ params.scalar.b_multiplier = b_multiplier;
+ params.scalar.remainder_mask = (int32_t) remainder_mask;
+ params.scalar.remainder_threshold = (int32_t) remainder_threshold;
+ params.scalar.shift = shift;
+ params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
+ params.scalar.y_max = (int32_t) (uint32_t) output_max;
+ params.scalar.y_min = (int32_t) (uint32_t) output_min;
+ #endif
+ return params;
+}
+
+// Scalar-layout variant of xnn_compute_q8_add_params. Note the looser lower
+// bound on the scales (2**-10 vs 2**-14 in the generic variant), and that the
+// multipliers are computed by adding `shift` to the float's exponent field
+// (i.e. multiplying the scale by 2**shift) rather than via scale_multiplier.
+static inline union xnn_q8_add_params xnn_compute_scalar_q8_add_params(
+ uint8_t a_zero_point,
+ uint8_t b_zero_point,
+ uint8_t output_zero_point,
+ float a_output_scale,
+ float b_output_scale,
+ uint8_t output_min,
+ uint8_t output_max)
+{
+ assert(a_output_scale >= 0x1.0p-10f);
+ assert(b_output_scale >= 0x1.0p-10f);
+ assert(a_output_scale < 0x1.0p+8f);
+ assert(b_output_scale < 0x1.0p+8f);
+
+ /* Compute requantization parameters */
+ const float max_output_scale = a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
+ assert(max_output_scale >= 0x1.0p-10f);
+ assert(max_output_scale < 0x1.0p+8f);
+ const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
+ const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
+ /* Shift is in [13, 31] range */
+ const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
+ assert(shift < 32);
+ assert(shift >= 13);
+
+ /* Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range */
+ const uint32_t a_multiplier = (uint32_t) (int32_t) __builtin_lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
+ const uint32_t b_multiplier = (uint32_t) (int32_t) __builtin_lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
+ assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
+ assert(a_multiplier < UINT32_C(0x00400000));
+ assert(b_multiplier < UINT32_C(0x00400000));
+
+ union xnn_q8_add_params params;
+ const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
+ const uint32_t remainder_threshold = remainder_mask >> 1;
+ // Constant term that folds both input zero points into the accumulator.
+ params.scalar.zero_point_product =
+ (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
+ params.scalar.a_multiplier = a_multiplier;
+ params.scalar.b_multiplier = b_multiplier;
+ params.scalar.remainder_mask = (int32_t) remainder_mask;
+ params.scalar.remainder_threshold = (int32_t) remainder_threshold;
+ params.scalar.shift = shift;
+ params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
+ params.scalar.y_max = (int32_t) (uint32_t) output_max;
+ params.scalar.y_min = (int32_t) (uint32_t) output_min;
+ return params;
+}
+
+// Computes Q31 requantization parameters (scalar layout) for a scale in
+// [2**-32, 1): the multiplier is the 24-bit IEEE 754 mantissa (with implicit
+// leading 1) left-shifted into Q31, and the shift compensates for the exponent.
+static inline union xnn_q31_requantization_params xnn_compute_scalar_requantization_params(
+ float scale,
+ uint8_t zero_point,
+ uint8_t min,
+ uint8_t max)
+{
+ /* Compute requantization parameters */
+ assert(scale < 1.0f);
+ assert(scale >= 0x1.0p-32f);
+ const uint32_t scale_bits = fp32_to_bits(scale);
+
+ /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
+ const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
+ assert(multiplier >= INT32_C(0x40000000));
+ assert(multiplier <= INT32_C(0x7FFFFF80));
+
+ /* Shift is in [0, 31] range */
+ const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
+ assert(shift >= 0);
+ assert(shift < 32);
+
+ union xnn_q31_requantization_params params;
+ // Mask/threshold implement round-half-away-from-zero in the final shift.
+ const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
+ const uint32_t remainder_threshold = remainder_mask >> 1;
+ params.scalar.multiplier = multiplier;
+ params.scalar.remainder_mask = (int32_t) remainder_mask;
+ params.scalar.remainder_threshold = (int32_t) remainder_threshold;
+ params.scalar.shift = (uint32_t) shift;
+ // Clamp bounds are pre-offset by the zero point so the kernel clamps the
+ // un-biased value and adds the zero point last.
+ params.scalar.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
+ params.scalar.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
+ params.scalar.zero_point = (int32_t) (uint32_t) zero_point;
+ return params;
+}
+
+// Computes Q31 requantization parameters in the architecture-specific layout.
+// Same derivation as xnn_compute_scalar_requantization_params: multiplier is
+// the scale's 24-bit mantissa shifted into Q31; shift compensates the exponent.
+static inline union xnn_q31_requantization_params xnn_compute_requantization_params(
+ float scale,
+ uint8_t zero_point,
+ uint8_t min,
+ uint8_t max)
+{
+ /* Compute requantization parameters */
+ // Same preconditions as the scalar variant: the shift formula below only
+ // lands in [0, 31] when scale is in [2**-32, 1).
+ assert(scale < 1.0f);
+ assert(scale >= 0x1.0p-32f);
+ const uint32_t scale_bits = fp32_to_bits(scale);
+
+ /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
+ const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
+ assert(multiplier >= INT32_C(0x40000000));
+ assert(multiplier <= INT32_C(0x7FFFFF80));
+
+ /* Shift is in [0, 31] range */
+ const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
+ assert(shift >= 0);
+ assert(shift < 32);
+
+ union xnn_q31_requantization_params params;
+ #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
+ const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
+ const uint32_t remainder_threshold = remainder_mask >> 1;
+ params.sse2.multiplier[0] = multiplier;
+ params.sse2.multiplier[1] = multiplier;
+ params.sse2.multiplier[2] = multiplier;
+ params.sse2.multiplier[3] = multiplier;
+ // Q31 rounding constant (2**30), stored per 64-bit lane.
+ params.sse2.rounding[0] = UINT64_C(0x40000000);
+ params.sse2.rounding[1] = UINT64_C(0x40000000);
+ params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
+ params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
+ params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
+ params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
+ params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
+ params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
+ params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
+ params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
+ params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
+ params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
+ for (uint32_t i = 0; i < 8; i++) {
+ params.sse2.zero_point[i] = (int16_t) (uint16_t) zero_point;
+ }
+ for (uint32_t i = 0; i < 16; i++) {
+ params.sse2.max[i] = max;
+ params.sse2.min[i] = min;
+ }
+ #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
+ // NEON layout stores a negative right-shift amount consumed by the kernel.
+ params.neon.multiplier = multiplier;
+ params.neon.right_shift = -shift;
+ params.neon.zero_point = (int16_t) (uint16_t) zero_point;
+ params.neon.max = max;
+ params.neon.min = min;
+ #else
+ const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
+ const uint32_t remainder_threshold = remainder_mask >> 1;
+ params.scalar.multiplier = multiplier;
+ params.scalar.remainder_mask = (int32_t) remainder_mask;
+ params.scalar.remainder_threshold = (int32_t) remainder_threshold;
+ params.scalar.shift = (uint32_t) shift;
+ params.scalar.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
+ params.scalar.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
+ params.scalar.zero_point = (int32_t) (uint32_t) zero_point;
+ #endif
+ return params;
+}
+
+// Reference scalar Q31 requantization: rounds n * multiplier into Q31, then
+// applies a rounding arithmetic right shift, clamps to the pre-offset bounds,
+// and re-adds the output zero point.
+static inline uint8_t xnn_q31_requantize(
+ int32_t n,
+ union xnn_q31_requantization_params params)
+{
+ // Full 64-bit product; + 2**30 before >> 31 rounds the Q31 result.
+ const int64_t product = (int64_t) n * (int64_t) params.scalar.multiplier;
+ const int32_t q31product = (int32_t) (uint32_t) ((uint64_t) (product + INT64_C(0x40000000)) >> 31);
+ // Remainder correction makes the shift round ties away from zero; the
+ // (n < 0) term adjusts the threshold comparison for negative inputs.
+ const int32_t remainder = (q31product & params.scalar.remainder_mask) - (int32_t) (n < 0);
+ n = asr_s32(q31product, params.scalar.shift) + (int32_t) (remainder > params.scalar.remainder_threshold);
+ if (n < params.scalar.min_less_zero_point) {
+ n = params.scalar.min_less_zero_point;
+ }
+ if (n > params.scalar.max_less_zero_point) {
+ n = params.scalar.max_less_zero_point;
+ }
+
+ return (uint8_t) (n + params.scalar.zero_point);
+}
+
+// Reference scalar quantization for average pooling: scales the accumulator by
+// a fixed-point multiplier with a rounding right shift, clamps to the
+// pre-offset output bounds, and re-adds the output zero point.
+static inline uint8_t xnn_avgpool_quantize(
+ int32_t n,
+ union xnn_q8_avgpool_params params)
+{
+ const int64_t product = (int64_t) n * (int64_t) params.scalar.multiplier;
+ // Subtracting 1 for negative n before adding the rounding constant makes the
+ // shift round ties away from zero.
+ const int64_t adjusted_product = product - (int64_t) (n < 0);
+
+ n = (int32_t) asr_s64(adjusted_product + params.scalar.rounding, params.scalar.right_shift);
+ if (n < params.scalar.output_min_less_zero_point) {
+ n = params.scalar.output_min_less_zero_point;
+ }
+ if (n > params.scalar.output_max_less_zero_point) {
+ n = params.scalar.output_max_less_zero_point;
+ }
+
+ return (uint8_t) (n + params.scalar.output_zero_point);
+}
+
+// Reference scalar implementation of quantized addition, mirroring the math
+// set up by xnn_compute_q8_add_params (scalar layout).
+static inline uint8_t xnn_add_quantize(
+ uint8_t a, uint8_t b,
+ union xnn_q8_add_params params)
+{
+ /* Multiply by factors and accumulate products */
+ // zero_point_product already folds in -(a_zp*a_mult + b_zp*b_mult).
+ int32_t acc = params.scalar.zero_point_product +
+ (int32_t) ((uint32_t) a * params.scalar.a_multiplier) +
+ (int32_t) ((uint32_t) b * params.scalar.b_multiplier);
+
+ /* Shift right and round */
+ // Remainder correction rounds ties away from zero on the arithmetic shift.
+ const int32_t rem = (acc & params.scalar.remainder_mask) - (int32_t) (acc < 0);
+ acc = asr_s32(acc, params.scalar.shift) + (int32_t) (rem > params.scalar.remainder_threshold);
+
+ /* Clamp and add output zero point */
+ int32_t y = acc + params.scalar.y_zero_point;
+ if (y >= params.scalar.y_max) {
+ y = params.scalar.y_max;
+ }
+ if (y <= params.scalar.y_min) {
+ y = params.scalar.y_min;
+ }
+ return (uint8_t) y;
+}
diff --git a/src/xnnpack/rmax.h b/src/xnnpack/rmax.h
new file mode 100644
index 0000000..25f6e32
--- /dev/null
+++ b/src/xnnpack/rmax.h
@@ -0,0 +1,47 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+// Reduce-max microkernel over F32 input: writes the maximum of the input
+// array x into *y. NOTE(review): whether n counts bytes or elements is not
+// visible here - confirm against the ukernel implementations.
+#define DECLARE_F32_RMAX_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t n, \
+ const float* x, \
+ float* y);
+
+DECLARE_F32_RMAX_UKERNEL_FUNCTION(xnn_f32_rmax_ukernel__avx)
+DECLARE_F32_RMAX_UKERNEL_FUNCTION(xnn_f32_rmax_ukernel__avx512f)
+DECLARE_F32_RMAX_UKERNEL_FUNCTION(xnn_f32_rmax_ukernel__neon)
+DECLARE_F32_RMAX_UKERNEL_FUNCTION(xnn_f32_rmax_ukernel__scalar)
+DECLARE_F32_RMAX_UKERNEL_FUNCTION(xnn_f32_rmax_ukernel__sse)
+
+
+// Reduce-max microkernel over U8 input: writes the maximum of x into *y.
+#define DECLARE_U8_RMAX_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t n, \
+ const uint8_t* x, \
+ uint8_t* y);
+
+DECLARE_U8_RMAX_UKERNEL_FUNCTION(xnn_u8_rmax_ukernel__neon)
+DECLARE_U8_RMAX_UKERNEL_FUNCTION(xnn_u8_rmax_ukernel__scalar)
+DECLARE_U8_RMAX_UKERNEL_FUNCTION(xnn_u8_rmax_ukernel__sse2)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/scalar-utils.h b/src/xnnpack/scalar-utils.h
new file mode 100644
index 0000000..88d30c8
--- /dev/null
+++ b/src/xnnpack/scalar-utils.h
@@ -0,0 +1,121 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#if defined(__cplusplus) && (__cplusplus >= 201103L)
+ #include <climits>
+ #include <cstdint>
+ #include <cstdbool>
+ #include <cassert>
+#else
+ #include <limits.h>
+ #include <stdint.h>
+ #include <stdbool.h>
+ #include <assert.h>
+#endif
+
+#include <fp16.h>
+
+// XNN_IGNORE_SHIFT_BASE_UB suppresses UBSan's "shift-base" diagnostic on
+// functions that deliberately right-shift negative values. On compilers that
+// have UBSan but no per-function no_sanitize attribute (gcc 4.9..7), the
+// XNN_USE_SHIFT_BASE_UB_WORKAROUND macro is defined instead, and the shift
+// helpers below fall back to UB-free formulations.
+#if defined(__clang__) && !defined(__pnacl__)
+ #if __clang_major__ == 3 && __clang_minor__ >= 7 || __clang_major__ > 3
+ #define XNN_IGNORE_SHIFT_BASE_UB __attribute__((__no_sanitize__("shift-base")))
+ #else
+ #define XNN_IGNORE_SHIFT_BASE_UB
+ #endif
+#elif defined(__GNUC__)
+ #if __GNUC__ >= 8
+ #define XNN_IGNORE_SHIFT_BASE_UB __attribute__((__no_sanitize__("shift-base")))
+ #elif __GNUC__ == 4 && __GNUC_MINOR__ >= 9 || __GNUC__ > 4
+ /* 4.9 <= gcc < 8 support ubsan, but doesn't support no_sanitize attribute */
+ #define XNN_IGNORE_SHIFT_BASE_UB
+ #ifndef XNN_USE_SHIFT_BASE_UB_WORKAROUND
+ #define XNN_USE_SHIFT_BASE_UB_WORKAROUND 1
+ #endif
+ #else
+ #define XNN_IGNORE_SHIFT_BASE_UB
+ #endif
+#else
+ #define XNN_IGNORE_SHIFT_BASE_UB
+#endif
+
+// Arithmetic (sign-filling) shift right of a signed 32-bit value by n bits.
+XNN_IGNORE_SHIFT_BASE_UB
+inline static int32_t asr_s32(int32_t x, uint32_t n) {
+ #ifdef XNN_USE_SHIFT_BASE_UB_WORKAROUND
+ #if defined(__x86_64__) || defined(__aarch64__)
+ // Sign-extend to 64 bits first; a logical shift then matches the arithmetic shift.
+ return (int32_t) ((uint64_t) (int64_t) x >> n);
+ #else
+ if (x >= 0) {
+ return x >> n;
+ }
+ // Shift the complement (non-negative) and complement back: UB-free equivalent.
+ return ~(~x >> n);
+ #endif
+ #else
+ return x >> n;
+ #endif
+}
+
+// Arithmetic (sign-filling) shift right of a signed 64-bit value by n bits.
+XNN_IGNORE_SHIFT_BASE_UB
+inline static int64_t asr_s64(int64_t x, uint32_t n) {
+ #ifdef XNN_USE_SHIFT_BASE_UB_WORKAROUND
+ if (x >= 0) {
+ return x >> n;
+ }
+ // Shift the complement (non-negative) and complement back: UB-free equivalent.
+ return ~(~x >> n);
+ #else
+ return x >> n;
+ #endif
+}
+
+// Bit-exact reference requantization: computes round(value * scale) with ties
+// away from zero, clamps to [qmin, qmax] relative to zero_point, and returns
+// the biased uint8 result. scale must be in [2**-32, 1).
+inline static uint8_t scalar_requantize_precise(
+ int32_t value,
+ float scale,
+ uint8_t zero_point,
+ uint8_t qmin,
+ uint8_t qmax)
+{
+ assert(scale < 1.0f);
+ assert(scale >= 0x1.0p-32f);
+
+ // Decompose scale into a 24-bit integer mantissa and a right-shift amount.
+ const uint32_t scale_bits = fp32_to_bits(scale);
+ const uint32_t multiplier = (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000);
+ const uint32_t shift = 127 + 23 - (scale_bits >> 23);
+ assert(shift >= 24);
+ assert(shift < 56);
+
+ /*
+ * Compute absolute value of input as unsigned 32-bit int.
+ * All further computations will work with unsigned values to avoid undefined behaviour on signed operations.
+ */
+ const uint32_t abs_value = (value >= 0) ? (uint32_t) value : -(uint32_t) value;
+
+ /* Compute full 64-bit product of 32-bit factors */
+ const uint64_t product = (uint64_t) abs_value * (uint64_t) multiplier;
+
+ /*
+ * Shift the full 64-bit product right with rounding.
+ * Rounding is performed towards closest integer, with midpoints rounded up (same as away from zero).
+ */
+ const uint64_t rounding = UINT64_C(1) << (shift - 1);
+ const uint32_t abs_scaled_value = (uint32_t) ((product + rounding) >> shift);
+
+ /*
+ * Copy the sign of input to scaled absolute input value.
+ */
+ const int32_t scaled_value = (int32_t) (value >= 0 ? abs_scaled_value : -abs_scaled_value);
+
+ /* Clamp scaled value with zero point between smin and smax */
+ int32_t clamped_value = scaled_value;
+ const int32_t smin = (int32_t) (uint32_t) qmin - (int32_t) (uint32_t) zero_point;
+ if (clamped_value < smin) {
+ clamped_value = smin;
+ }
+ const int32_t smax = (int32_t) (uint32_t) qmax - (int32_t) (uint32_t) zero_point;
+ if (clamped_value > smax) {
+ clamped_value = smax;
+ }
+
+ /* Add zero point to clamped value */
+ const int32_t biased_value = clamped_value + (int32_t) (uint32_t) zero_point;
+
+ return biased_value;
+}
diff --git a/src/xnnpack/spmm.h b/src/xnnpack/spmm.h
new file mode 100644
index 0000000..7ea16bf
--- /dev/null
+++ b/src/xnnpack/spmm.h
@@ -0,0 +1,66 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+// Sparse matrix-dense matrix multiplication (SpMM) microkernel over F32 data,
+// with output clamping via xnn_f32_output_params. dmap/nmap are the sparsity
+// indirection tables packed alongside the non-zero weights w.
+// NOTE(review): exact semantics of m, n, dmap and nmap are not visible here -
+// confirm against the ukernel implementations.
+#define DECLARE_F32_SPMM_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ uint32_t m, \
+ uint32_t n, \
+ const float* a, \
+ const float* w, \
+ const int32_t* dmap, \
+ const uint32_t* nmap, \
+ float* c, \
+ const union xnn_f32_output_params* params);
+
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_12x1__neonfma)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_12x2__neonfma)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_12x4__neonfma)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_16x1__neonfma)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_16x2__neonfma)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_16x4__neonfma)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_16x1__neonfma_pipelined)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_16x1__neonfma_unroll2)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_1x1__scalar)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_1x1__scalar_pipelined)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_1x1__scalar_unroll2)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_2x1__scalar)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_2x1__scalar_pipelined)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_2x1__scalar_unroll2)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_4x1__neonfma)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_4x2__neonfma)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_4x4__neonfma)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_4x1__neonfma_pipelined)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_4x1__neonfma_unroll2)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_4x1__scalar)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_4x1__scalar_pipelined)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_4x1__scalar_unroll2)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_4x1__sse)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_8x1__neonfma)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_8x2__neonfma)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_8x4__neonfma)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_8x1__neonfma_pipelined)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_8x1__neonfma_unroll2)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_8x1__scalar)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_8x1__scalar_pipelined)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_8x1__scalar_unroll2)
+DECLARE_F32_SPMM_UKERNEL_FUNCTION(xnn_f32_spmm_ukernel_8x1__sse)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/unpool.h b/src/xnnpack/unpool.h
new file mode 100644
index 0000000..c02457a
--- /dev/null
+++ b/src/xnnpack/unpool.h
@@ -0,0 +1,34 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+// X32 max-unpooling microkernel: scatters input values to output locations
+// selected by index, via the per-pixel output pointer array.
+// NOTE(review): exact semantics of p, c and f are not visible here - confirm
+// against the ukernel implementations.
+#define DECLARE_X32_UNPOOL_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t p, \
+ size_t c, \
+ uint32_t f, \
+ const uint32_t* input, \
+ const uint32_t* index, \
+ uint32_t** output);
+
+DECLARE_X32_UNPOOL_UKERNEL_FUNCTION(xnn_x32_unpool_ukernel__psimd)
+DECLARE_X32_UNPOOL_UKERNEL_FUNCTION(xnn_x32_unpool_ukernel__scalar)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/vadd.h b/src/xnnpack/vadd.h
new file mode 100644
index 0000000..a66d171
--- /dev/null
+++ b/src/xnnpack/vadd.h
@@ -0,0 +1,51 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+// Elementwise F32 vector addition microkernel with output clamping params.
+// NOTE(review): whether n counts bytes or elements is not visible here -
+// confirm against the ukernel implementations.
+#define DECLARE_F32_VADD_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t n, \
+ const float* a, \
+ const float* b, \
+ float* y, \
+ const union xnn_f32_output_params* params);
+
+DECLARE_F32_VADD_UKERNEL_FUNCTION(xnn_f32_vadd_ukernel__neon)
+DECLARE_F32_VADD_UKERNEL_FUNCTION(xnn_f32_vadd_ukernel__psimd)
+DECLARE_F32_VADD_UKERNEL_FUNCTION(xnn_f32_vadd_ukernel__scalar)
+DECLARE_F32_VADD_UKERNEL_FUNCTION(xnn_f32_vadd_ukernel__sse)
+
+
+// Elementwise quantized (Q8) addition microkernel; the requantization math is
+// described by xnn_q8_add_params (see xnn_compute_q8_add_params).
+#define DECLARE_Q8_VADD_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t n, \
+ const uint8_t* a, \
+ const uint8_t* b, \
+ uint8_t* y, \
+ const union xnn_q8_add_params* params);
+
+DECLARE_Q8_VADD_UKERNEL_FUNCTION(xnn_q8_vadd_ukernel__neon)
+DECLARE_Q8_VADD_UKERNEL_FUNCTION(xnn_q8_vadd_ukernel__scalar)
+DECLARE_Q8_VADD_UKERNEL_FUNCTION(xnn_q8_vadd_ukernel__sse2)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/vmul.h b/src/xnnpack/vmul.h
new file mode 100644
index 0000000..9747de8
--- /dev/null
+++ b/src/xnnpack/vmul.h
@@ -0,0 +1,35 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+// Elementwise F32 vector multiplication microkernel with output clamping
+// params. NOTE(review): whether n counts bytes or elements is not visible
+// here - confirm against the ukernel implementations.
+#define DECLARE_F32_VMUL_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t n, \
+ const float* a, \
+ const float* b, \
+ float* y, \
+ const union xnn_f32_output_params* params);
+
+DECLARE_F32_VMUL_UKERNEL_FUNCTION(xnn_f32_vmul_ukernel__neon)
+DECLARE_F32_VMUL_UKERNEL_FUNCTION(xnn_f32_vmul_ukernel__psimd)
+DECLARE_F32_VMUL_UKERNEL_FUNCTION(xnn_f32_vmul_ukernel__scalar)
+DECLARE_F32_VMUL_UKERNEL_FUNCTION(xnn_f32_vmul_ukernel__sse)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/vmulcaddc.h b/src/xnnpack/vmulcaddc.h
new file mode 100644
index 0000000..a37e747
--- /dev/null
+++ b/src/xnnpack/vmulcaddc.h
@@ -0,0 +1,39 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+// Row-strided F32 multiply-add microkernel (vmulcaddc = "vector multiply,
+// constant add"): processes m rows of c channels from x (stride x_stride)
+// into y (stride y_stride) using packed weights w, then clamps via params.
+// NOTE(review): the exact layout of w is not visible here - confirm against
+// the ukernel implementations and the packing code.
+#define DECLARE_F32_VMULCADDC_UKERNEL_FUNCTION(fn_name) \
+ XNN_INTERNAL void fn_name( \
+ size_t m, \
+ size_t c, \
+ const float* x, \
+ size_t x_stride, \
+ const float* w, \
+ float* y, \
+ size_t y_stride, \
+ const union xnn_f32_output_params* params);
+
+DECLARE_F32_VMULCADDC_UKERNEL_FUNCTION(xnn_f32_vmulcaddc_ukernel_c1__scalar_x2)
+DECLARE_F32_VMULCADDC_UKERNEL_FUNCTION(xnn_f32_vmulcaddc_ukernel_c4__neon_x2)
+DECLARE_F32_VMULCADDC_UKERNEL_FUNCTION(xnn_f32_vmulcaddc_ukernel_c4__neonfma_x2)
+DECLARE_F32_VMULCADDC_UKERNEL_FUNCTION(xnn_f32_vmulcaddc_ukernel_c4__psimd_x2)
+DECLARE_F32_VMULCADDC_UKERNEL_FUNCTION(xnn_f32_vmulcaddc_ukernel_c4__sse_x2)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/vsub.h b/src/xnnpack/vsub.h
new file mode 100644
index 0000000..e444eb6
--- /dev/null
+++ b/src/xnnpack/vsub.h
@@ -0,0 +1,40 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+// Declares an f32 element-wise vector subtraction micro-kernel computing
+// y[i] = a[i] - b[i], with the result clamped according to params.
+// NOTE(review): whether n counts bytes or elements is not visible in this
+// header -- confirm against the micro-kernel implementations.
+#define DECLARE_F32_VSUB_UKERNEL_FUNCTION(fn_name) \
+  XNN_INTERNAL void fn_name(                       \
+    size_t n,                                      \
+    const float* a,                                \
+    const float* b,                                \
+    float* y,                                      \
+    const union xnn_f32_output_params* params);
+
+// Per-architecture implementations (NEON, PSIMD, scalar, SSE).
+DECLARE_F32_VSUB_UKERNEL_FUNCTION(xnn_f32_vsub_ukernel__neon)
+DECLARE_F32_VSUB_UKERNEL_FUNCTION(xnn_f32_vsub_ukernel__psimd)
+DECLARE_F32_VSUB_UKERNEL_FUNCTION(xnn_f32_vsub_ukernel__scalar)
+DECLARE_F32_VSUB_UKERNEL_FUNCTION(xnn_f32_vsub_ukernel__sse)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
diff --git a/src/xnnpack/zip.h b/src/xnnpack/zip.h
new file mode 100644
index 0000000..48b164e
--- /dev/null
+++ b/src/xnnpack/zip.h
@@ -0,0 +1,95 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <xnnpack/params.h>
+#include <xnnpack/common.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+// Declares a micro-kernel that interleaves ("zips") a fixed number of
+// uint8_t streams: the x2/x3/x4 in the kernel name is the number of input
+// streams packed in x, interleaved element-wise into y.
+// NOTE(review): whether n is in bytes or in elements per stream is not
+// visible in this header -- confirm against the implementations.
+#define DECLARE_X8_ZIPC_UKERNEL_FUNCTION(fn_name) \
+  XNN_INTERNAL void fn_name(                      \
+    size_t n,                                     \
+    const uint8_t* x,                             \
+    uint8_t* y);
+
+DECLARE_X8_ZIPC_UKERNEL_FUNCTION(xnn_x8_zip_x2_ukernel__neon)
+DECLARE_X8_ZIPC_UKERNEL_FUNCTION(xnn_x8_zip_x2_ukernel__sse2)
+DECLARE_X8_ZIPC_UKERNEL_FUNCTION(xnn_x8_zip_x2_ukernel__scalar)
+DECLARE_X8_ZIPC_UKERNEL_FUNCTION(xnn_x8_zip_x3_ukernel__neon)
+DECLARE_X8_ZIPC_UKERNEL_FUNCTION(xnn_x8_zip_x3_ukernel__sse2)
+DECLARE_X8_ZIPC_UKERNEL_FUNCTION(xnn_x8_zip_x3_ukernel__scalar)
+DECLARE_X8_ZIPC_UKERNEL_FUNCTION(xnn_x8_zip_x4_ukernel__neon)
+DECLARE_X8_ZIPC_UKERNEL_FUNCTION(xnn_x8_zip_x4_ukernel__sse2)
+DECLARE_X8_ZIPC_UKERNEL_FUNCTION(xnn_x8_zip_x4_ukernel__scalar)
+
+
+// Same fixed-count interleaving for uint32_t elements.
+#define DECLARE_X32_ZIPC_UKERNEL_FUNCTION(fn_name) \
+  XNN_INTERNAL void fn_name(                       \
+    size_t n,                                      \
+    const uint32_t* x,                             \
+    uint32_t* y);
+
+DECLARE_X32_ZIPC_UKERNEL_FUNCTION(xnn_x32_zip_x2_ukernel__neon)
+DECLARE_X32_ZIPC_UKERNEL_FUNCTION(xnn_x32_zip_x2_ukernel__psimd)
+DECLARE_X32_ZIPC_UKERNEL_FUNCTION(xnn_x32_zip_x2_ukernel__scalar)
+DECLARE_X32_ZIPC_UKERNEL_FUNCTION(xnn_x32_zip_x2_ukernel__sse2)
+DECLARE_X32_ZIPC_UKERNEL_FUNCTION(xnn_x32_zip_x3_ukernel__neon)
+DECLARE_X32_ZIPC_UKERNEL_FUNCTION(xnn_x32_zip_x3_ukernel__psimd)
+DECLARE_X32_ZIPC_UKERNEL_FUNCTION(xnn_x32_zip_x3_ukernel__scalar)
+DECLARE_X32_ZIPC_UKERNEL_FUNCTION(xnn_x32_zip_x3_ukernel__sse2)
+DECLARE_X32_ZIPC_UKERNEL_FUNCTION(xnn_x32_zip_x4_ukernel__neon)
+DECLARE_X32_ZIPC_UKERNEL_FUNCTION(xnn_x32_zip_x4_ukernel__psimd)
+DECLARE_X32_ZIPC_UKERNEL_FUNCTION(xnn_x32_zip_x4_ukernel__scalar)
+DECLARE_X32_ZIPC_UKERNEL_FUNCTION(xnn_x32_zip_x4_ukernel__sse2)
+
+
+// Variable-count variant: interleaves m uint8_t streams (m supplied at
+// run time, see "_xm" in the kernel names) rather than a count fixed
+// in the kernel name.
+#define DECLARE_X8_ZIPV_UKERNEL_FUNCTION(fn_name) \
+  XNN_INTERNAL void fn_name(                      \
+    size_t n,                                     \
+    size_t m,                                     \
+    const uint8_t* x,                             \
+    uint8_t* y);
+
+DECLARE_X8_ZIPV_UKERNEL_FUNCTION(xnn_x8_zip_xm_ukernel__neon)
+DECLARE_X8_ZIPV_UKERNEL_FUNCTION(xnn_x8_zip_xm_ukernel__sse2)
+DECLARE_X8_ZIPV_UKERNEL_FUNCTION(xnn_x8_zip_xm_ukernel__scalar)
+
+
+// Variable-count interleaving for uint32_t elements.
+#define DECLARE_X32_ZIPV_UKERNEL_FUNCTION(fn_name) \
+  XNN_INTERNAL void fn_name(                       \
+    size_t n,                                      \
+    size_t m,                                      \
+    const uint32_t* x,                             \
+    uint32_t* y);
+
+DECLARE_X32_ZIPV_UKERNEL_FUNCTION(xnn_x32_zip_xm_ukernel__neon)
+DECLARE_X32_ZIPV_UKERNEL_FUNCTION(xnn_x32_zip_xm_ukernel__psimd)
+DECLARE_X32_ZIPV_UKERNEL_FUNCTION(xnn_x32_zip_xm_ukernel__scalar)
+DECLARE_X32_ZIPV_UKERNEL_FUNCTION(xnn_x32_zip_xm_ukernel__sse2)
+
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif