SoftArgMax operator

SoftArgMax operator implementation using a three-pass algorithm (reduce-max,
reduce-add-and-store of exp(x - max), then scale) that reloads the stored
exponentials in the final pass rather than recomputing them.

PiperOrigin-RevId: 291778753
diff --git a/src/init.c b/src/init.c
index 2f9c728..2f7b608 100644
--- a/src/init.c
+++ b/src/init.c
@@ -37,6 +37,7 @@
 #include <xnnpack/params.h>
 #include <xnnpack/pavgpool.h>
 #include <xnnpack/prelu.h>
+#include <xnnpack/raddstoreexpminusmax.h>
 #include <xnnpack/rmax.h>
 #include <xnnpack/spmm.h>
 #include <xnnpack/unpool.h>
@@ -248,6 +249,8 @@
       .row_tile = 2,
       .channel_tile = 8,
     };
+    xnn_params.f32.raddstoreexpminusmax = xnn_f32_raddstoreexpminusmax_ukernel__neon_lut64_p2_x8;
+    xnn_params.f32.rmax = xnn_f32_rmax_ukernel__neon;
     xnn_params.f32.vadd = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__neon_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
@@ -552,6 +555,8 @@
       .row_tile = 2,
       .channel_tile = 8,
     };
+    xnn_params.f32.raddstoreexpminusmax = xnn_f32_raddstoreexpminusmax_ukernel__neonfma_lut64_p2_x16;
+    xnn_params.f32.rmax = xnn_f32_rmax_ukernel__neon;
     xnn_params.f32.vadd = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__neon_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
@@ -911,6 +916,8 @@
       .row_tile = 2,
       .channel_tile = 8,
     };
+    xnn_params.f32.raddstoreexpminusmax = xnn_f32_raddstoreexpminusmax_ukernel__sse2_p5_x20_acc2;
+    xnn_params.f32.rmax = xnn_f32_rmax_ukernel__sse;
     if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx512f()) {
       xnn_params.f32.vadd = (struct vbinary_parameters) {
         .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__avx512f_x32,
@@ -1218,6 +1225,8 @@
       .row_tile = 2,
       .channel_tile = 8,
     };
+    xnn_params.f32.raddstoreexpminusmax = xnn_f32_raddstoreexpminusmax_ukernel__psimd_p5_x16_acc2;
+    xnn_params.f32.rmax = xnn_f32_rmax_ukernel__psimd;
     xnn_params.f32.vadd = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__psimd_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__psimd_x8,
@@ -1424,6 +1433,8 @@
       .row_tile = 4,
       .channel_tile = 4,
     };
+    xnn_params.f32.raddstoreexpminusmax = xnn_f32_raddstoreexpminusmax_ukernel__scalar_p5_x4_acc2;
+    xnn_params.f32.rmax = xnn_f32_rmax_ukernel__scalar;
     xnn_params.f32.vadd = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__wasm_x4,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__wasm_x4,
diff --git a/src/operator-run.c b/src/operator-run.c
index 18ec2a5..f85d1ce 100644
--- a/src/operator-run.c
+++ b/src/operator-run.c
@@ -657,6 +657,27 @@
   context->lut_norm_ukernel(n, x, t, y);
 }
 
+void xnn_compute_f32_three_pass_softargmax(
+    const struct f32_three_pass_softargmax_context context[restrict static 1],
+    size_t batch_index)
+{
+  const float* x = (const float*) ((uintptr_t) context->x + context->x_stride * batch_index);
+  float* y = (float*) ((uintptr_t) context->y + context->y_stride * batch_index);
+  const size_t n = context->n;
+
+  // First pass: reduce-max
+  float x_max;
+  context->rmax_ukernel(n, x, &x_max);
+
+  // Second pass: reduce-add & store exp(x-x_max)
+  float y_sum;
+  context->raddstoreexpminusmax_ukernel(n, x, y, &y_sum, x_max);
+
+  // Third pass: scale y
+  const float y_scale = 1.0f / y_sum;
+  context->vmulc_ukernel(n, y, &y_scale, y, &context->params);
+}
+
 void xnn_compute_vmulcaddc(
     const struct vmulcaddc_context context[restrict static 1],
     size_t batch_start,
diff --git a/src/softargmax-nc.c b/src/softargmax-nc.c
index c84915b..0be43b1 100644
--- a/src/softargmax-nc.c
+++ b/src/softargmax-nc.c
@@ -16,6 +16,7 @@
 #include <xnnpack/allocator.h>
 #include <xnnpack/operator.h>
 #include <xnnpack/log.h>
+#include <xnnpack/params-init.h>
 
 
 enum xnn_status xnn_create_softargmax_nc_q8(
@@ -46,7 +47,7 @@
 
   if (input_stride < channels) {
     xnn_log_error(
-      "failed to create Sigmoid operator with input element stride of %zu: "
+      "failed to create SoftArgMax operator with input element stride of %zu: "
       "stride must be at least as large as the number of channels (%zu)",
       input_stride, channels);
     goto error;
@@ -54,7 +55,7 @@
 
   if (output_stride < channels) {
     xnn_log_error(
-      "failed to create Sigmoid operator with output element stride of %zu: "
+      "failed to create SoftArgMax operator with output element stride of %zu: "
       "stride must be at least as large as the number of channels (%zu)",
       output_stride, channels);
     goto error;
@@ -173,3 +174,113 @@
 
   return xnn_status_success;
 }
+
+enum xnn_status xnn_create_softargmax_nc_f32(
+    size_t channels,
+    size_t input_stride,
+    size_t output_stride,
+    uint32_t flags,
+    xnn_operator_t* softargmax_op_out)
+{
+  xnn_operator_t softargmax_op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to create SoftArgMax operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  status = xnn_status_invalid_parameter;
+
+  if (channels == 0) {
+    xnn_log_error(
+      "failed to create SoftArgMax operator with %zu channels: number of channels must be non-zero", channels);
+    goto error;
+  }
+
+  if (input_stride < channels) {
+    xnn_log_error(
+      "failed to create SoftArgMax operator with input element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      input_stride, channels);
+    goto error;
+  }
+
+  if (output_stride < channels) {
+    xnn_log_error(
+      "failed to create SoftArgMax operator with output element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      output_stride, channels);
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  softargmax_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
+  if (softargmax_op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for SoftArgMax operator descriptor", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  softargmax_op->channels = channels;
+  softargmax_op->input_pixel_stride = input_stride;
+  softargmax_op->output_pixel_stride = output_stride;
+
+  softargmax_op->type = xnn_operator_type_softargmax_nc_f32;
+  softargmax_op->ukernel.type = xnn_ukernel_type_softargmax;
+
+  softargmax_op->state = xnn_run_state_invalid;
+
+  *softargmax_op_out = softargmax_op;
+  return xnn_status_success;
+
+error:
+  xnn_delete_operator(softargmax_op);
+  return status;
+}
+
+enum xnn_status xnn_setup_softargmax_nc_f32(
+    xnn_operator_t softargmax_op,
+    size_t batch_size,
+    const float* input,
+    float* output,
+    pthreadpool_t threadpool)
+{
+  if (softargmax_op->type != xnn_operator_type_softargmax_nc_f32) {
+    xnn_log_error("failed to setup SoftArgMax (NC, F32) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+  softargmax_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup SoftArgMax operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  if (batch_size == 0) {
+    softargmax_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  softargmax_op->batch_size = batch_size;
+  softargmax_op->input = input;
+  softargmax_op->output = output;
+
+  softargmax_op->context.f32_three_pass_softargmax = (struct f32_three_pass_softargmax_context) {
+    .n = softargmax_op->channels * sizeof(float),
+    .x = input,
+    .x_stride = softargmax_op->input_pixel_stride * sizeof(float),
+    .y = output,
+    .y_stride = softargmax_op->output_pixel_stride * sizeof(float),
+    .rmax_ukernel = xnn_params.f32.rmax,
+    .raddstoreexpminusmax_ukernel = xnn_params.f32.raddstoreexpminusmax,
+    .vmulc_ukernel = xnn_params.f32.vmul.opc_ukernel,
+    .params = xnn_init_f32_output_params(-INFINITY, INFINITY),
+  };
+  softargmax_op->compute.type = xnn_parallelization_type_1d;
+  softargmax_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_f32_three_pass_softargmax;
+  softargmax_op->compute.range[0] = batch_size;
+  softargmax_op->state = xnn_run_state_ready;
+
+  return xnn_status_success;
+}
diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h
index 58d3d9b..337ac36 100644
--- a/src/xnnpack/compute.h
+++ b/src/xnnpack/compute.h
@@ -761,3 +761,21 @@
       const struct u8_softargmax_context context[restrict static 1],
       size_t batch_index);
 #endif
+
+struct f32_three_pass_softargmax_context {
+  size_t n;
+  const void* x;
+  size_t x_stride;
+  void* y;
+  size_t y_stride;
+  xnn_f32_rmax_ukernel_function rmax_ukernel;
+  xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax_ukernel;
+  xnn_vbinary_ukernel_function vmulc_ukernel;
+  union xnn_f32_output_params params;
+};
+
+#ifndef __cplusplus
+  XNN_PRIVATE void xnn_compute_f32_three_pass_softargmax(
+      const struct f32_three_pass_softargmax_context context[restrict static 1],
+      size_t batch_index);
+#endif
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index 9ff2564..494444f 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -79,6 +79,7 @@
   xnn_operator_type_resize_bilinear_nhwc_f32,
   xnn_operator_type_sigmoid_nc_f32,
   xnn_operator_type_sigmoid_nc_q8,
+  xnn_operator_type_softargmax_nc_f32,
   xnn_operator_type_softargmax_nc_q8,
   xnn_operator_type_subtract_nd_f32,
   xnn_operator_type_unpooling_nhwc_x32,
@@ -276,6 +277,7 @@
     struct resize_bilinear_context resize_bilinear;
     struct spmm_context spmm;
     struct subconv_context subconv;
+    struct f32_three_pass_softargmax_context f32_three_pass_softargmax;
     struct u8_softargmax_context u8_softargmax;
     struct univector_contiguous_context univector_contiguous;
     struct univector_strided_context univector_strided;
diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h
index 8de5adc..8259291 100644
--- a/src/xnnpack/params.h
+++ b/src/xnnpack/params.h
@@ -1369,6 +1369,8 @@
     struct vbinary_parameters vmul;
     struct vbinary_parameters vsub;
     struct vmulcaddc_parameters vmulcaddc;
+    xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax;
+    xnn_f32_rmax_ukernel_function rmax;
     // Sparse Matrix-Dense Matrix Multiplication (NR=1 block).
     struct spmm_parameters spmm;
     // Sparse Matrix-Dense Matrix Multiplication (NR=2 block).