SoftArgMax operator

SoftArgMax operator implementation using a three-pass algorithm (reduce-max,
exponentiate-and-accumulate, normalize) that reloads the exponentials computed
and stored during the second pass instead of recomputing them in the third.

PiperOrigin-RevId: 291778753
diff --git a/include/xnnpack.h b/include/xnnpack.h
index 4fdc300..ae924c0 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -497,6 +497,20 @@
     float* output,
     pthreadpool_t threadpool);
 
+enum xnn_status xnn_create_softargmax_nc_f32(  /* creates an F32 SoftArgMax operator (NC layout) */
+    size_t channels,              /* number of channels (classes); must be non-zero */
+    size_t input_stride,          /* element stride between input rows; must be >= channels */
+    size_t output_stride,         /* element stride between output rows; must be >= channels */
+    uint32_t flags,               /* binary features; not inspected by this implementation */
+    xnn_operator_t* softargmax_op_out);  /* receives the created operator on success */
+
+enum xnn_status xnn_setup_softargmax_nc_f32(  /* binds batch size and I/O buffers; run via xnn_run_operator */
+    xnn_operator_t softargmax_op,   /* operator created by xnn_create_softargmax_nc_f32 */
+    size_t batch_size,              /* number of rows to process; 0 makes the run a no-op */
+    const float* input,             /* F32 input, batch_size rows of input_stride elements */
+    float* output,                  /* F32 output, batch_size rows of output_stride elements */
+    pthreadpool_t threadpool);      /* thread pool for parallelization; may be NULL */
+
 enum xnn_status xnn_create_subtract_nd_f32(
     float output_min,
     float output_max,
diff --git a/src/init.c b/src/init.c
index 2f9c728..2f7b608 100644
--- a/src/init.c
+++ b/src/init.c
@@ -37,6 +37,7 @@
 #include <xnnpack/params.h>
 #include <xnnpack/pavgpool.h>
 #include <xnnpack/prelu.h>
+#include <xnnpack/raddstoreexpminusmax.h>
 #include <xnnpack/rmax.h>
 #include <xnnpack/spmm.h>
 #include <xnnpack/unpool.h>
@@ -248,6 +249,8 @@
       .row_tile = 2,
       .channel_tile = 8,
     };
+    xnn_params.f32.raddstoreexpminusmax = xnn_f32_raddstoreexpminusmax_ukernel__neon_lut64_p2_x8;
+    xnn_params.f32.rmax = xnn_f32_rmax_ukernel__neon;
     xnn_params.f32.vadd = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__neon_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
@@ -552,6 +555,8 @@
       .row_tile = 2,
       .channel_tile = 8,
     };
+    xnn_params.f32.raddstoreexpminusmax = xnn_f32_raddstoreexpminusmax_ukernel__neonfma_lut64_p2_x16;
+    xnn_params.f32.rmax = xnn_f32_rmax_ukernel__neon;
     xnn_params.f32.vadd = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__neon_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
@@ -911,6 +916,8 @@
       .row_tile = 2,
       .channel_tile = 8,
     };
+    xnn_params.f32.raddstoreexpminusmax = xnn_f32_raddstoreexpminusmax_ukernel__sse2_p5_x20_acc2;
+    xnn_params.f32.rmax = xnn_f32_rmax_ukernel__sse;
     if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx512f()) {
       xnn_params.f32.vadd = (struct vbinary_parameters) {
         .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__avx512f_x32,
@@ -1218,6 +1225,8 @@
       .row_tile = 2,
       .channel_tile = 8,
     };
+    xnn_params.f32.raddstoreexpminusmax = xnn_f32_raddstoreexpminusmax_ukernel__psimd_p5_x16_acc2;
+    xnn_params.f32.rmax = xnn_f32_rmax_ukernel__psimd;
     xnn_params.f32.vadd = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__psimd_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__psimd_x8,
@@ -1424,6 +1433,8 @@
       .row_tile = 4,
       .channel_tile = 4,
     };
+    xnn_params.f32.raddstoreexpminusmax = xnn_f32_raddstoreexpminusmax_ukernel__scalar_p5_x4_acc2;
+    xnn_params.f32.rmax = xnn_f32_rmax_ukernel__scalar;
     xnn_params.f32.vadd = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__wasm_x4,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__wasm_x4,
diff --git a/src/operator-run.c b/src/operator-run.c
index 18ec2a5..f85d1ce 100644
--- a/src/operator-run.c
+++ b/src/operator-run.c
@@ -657,6 +657,27 @@
   context->lut_norm_ukernel(n, x, t, y);
 }
 
+void xnn_compute_f32_three_pass_softargmax(  // three-pass SoftArgMax over one batch element (row)
+    const struct f32_three_pass_softargmax_context context[restrict static 1],
+    size_t batch_index)  // selects the input/output row via the byte strides below
+{
+  const float* x = (const float*) ((uintptr_t) context->x + context->x_stride * batch_index);  // input row
+  float* y = (float*) ((uintptr_t) context->y + context->y_stride * batch_index);  // output row
+  const size_t n = context->n;  // row size in bytes (channels * sizeof(float), see setup)
+
+  // First pass: reduce-max
+  float x_max;
+  context->rmax_ukernel(n, x, &x_max);
+
+  // Second pass: reduce-add & store exp(x-x_max)
+  float y_sum;  // sum of the exponentials just written to y
+  context->raddstoreexpminusmax_ukernel(n, x, y, &y_sum, x_max);
+
+  // Third pass: scale y
+  const float y_scale = 1.0f / y_sum;  // normalizer; the scaled outputs sum to 1
+  context->vmulc_ukernel(n, y, &y_scale, y, &context->params);
+}
+
 void xnn_compute_vmulcaddc(
     const struct vmulcaddc_context context[restrict static 1],
     size_t batch_start,
diff --git a/src/softargmax-nc.c b/src/softargmax-nc.c
index c84915b..0be43b1 100644
--- a/src/softargmax-nc.c
+++ b/src/softargmax-nc.c
@@ -16,6 +16,7 @@
 #include <xnnpack/allocator.h>
 #include <xnnpack/operator.h>
 #include <xnnpack/log.h>
+#include <xnnpack/params-init.h>
 
 
 enum xnn_status xnn_create_softargmax_nc_q8(
@@ -46,7 +47,7 @@
 
   if (input_stride < channels) {
     xnn_log_error(
-      "failed to create Sigmoid operator with input element stride of %zu: "
+      "failed to create SoftArgMax operator with input element stride of %zu: "
       "stride must be at least as large as the number of channels (%zu)",
       input_stride, channels);
     goto error;
@@ -54,7 +55,7 @@
 
   if (output_stride < channels) {
     xnn_log_error(
-      "failed to create Sigmoid operator with output element stride of %zu: "
+      "failed to create SoftArgMax operator with output element stride of %zu: "
       "stride must be at least as large as the number of channels (%zu)",
       output_stride, channels);
     goto error;
@@ -173,3 +174,113 @@
 
   return xnn_status_success;
 }
+
+enum xnn_status xnn_create_softargmax_nc_f32(  // creates an F32 SoftArgMax operator descriptor (NC layout)
+    size_t channels,
+    size_t input_stride,
+    size_t output_stride,
+    uint32_t flags,
+    xnn_operator_t* softargmax_op_out)
+{
+  xnn_operator_t softargmax_op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {  // xnn_initialize() must have been called first
+    xnn_log_error("failed to create SoftArgMax operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  status = xnn_status_invalid_parameter;  // status reported by all validation failures below
+
+  if (channels == 0) {
+    xnn_log_error(
+      "failed to create SoftArgMax operator with %zu channels: number of channels must be non-zero", channels);
+    goto error;
+  }
+
+  if (input_stride < channels) {  // rows may not overlap the channel data
+    xnn_log_error(
+      "failed to create SoftArgMax operator with input element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      input_stride, channels);
+    goto error;
+  }
+
+  if (output_stride < channels) {
+    xnn_log_error(
+      "failed to create SoftArgMax operator with output element stride of %zu: "
+      "stride must be at least as large as the number of channels (%zu)",
+      output_stride, channels);
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  softargmax_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));  // zero-filled, SIMD-aligned
+  if (softargmax_op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for SoftArgMax operator descriptor", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  softargmax_op->channels = channels;
+  softargmax_op->input_pixel_stride = input_stride;
+  softargmax_op->output_pixel_stride = output_stride;
+
+  softargmax_op->type = xnn_operator_type_softargmax_nc_f32;
+  softargmax_op->ukernel.type = xnn_ukernel_type_softargmax;
+
+  softargmax_op->state = xnn_run_state_invalid;  // must be set up (xnn_setup_softargmax_nc_f32) before running
+
+  *softargmax_op_out = softargmax_op;
+  return xnn_status_success;
+
+error:
+  xnn_delete_operator(softargmax_op);  // releases the partially created operator, if any
+  return status;
+}
+
<parameter>+enum xnn_status xnn_setup_softargmax_nc_f32(  // binds batch size and I/O pointers; prepares the compute plan
+    xnn_operator_t softargmax_op,
+    size_t batch_size,
+    const float* input,
+    float* output,
+    pthreadpool_t threadpool)
+{
+  if (softargmax_op->type != xnn_operator_type_softargmax_nc_f32) {  // guard against mismatched operator handles
+    xnn_log_error("failed to setup SoftArgMax (NC, F32) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+  softargmax_op->state = xnn_run_state_invalid;  // invalidate until setup completes successfully
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup SoftArgMax operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  if (batch_size == 0) {  // empty batch: nothing to compute, running the operator is a no-op
+    softargmax_op->state = xnn_run_state_skip;
+    return xnn_status_success;
+  }
+
+  softargmax_op->batch_size = batch_size;
+  softargmax_op->input = input;
+  softargmax_op->output = output;
+
+  softargmax_op->context.f32_three_pass_softargmax = (struct f32_three_pass_softargmax_context) {
+    .n = softargmax_op->channels * sizeof(float),  // row size in BYTES, as the ukernels expect
+    .x = input,
+    .x_stride = softargmax_op->input_pixel_stride * sizeof(float),  // byte stride between input rows
+    .y = output,
+    .y_stride = softargmax_op->output_pixel_stride * sizeof(float),  // byte stride between output rows
+    .rmax_ukernel = xnn_params.f32.rmax,
+    .raddstoreexpminusmax_ukernel = xnn_params.f32.raddstoreexpminusmax,
+    .vmulc_ukernel = xnn_params.f32.vmul.opc_ukernel,  // multiply-by-constant variant for the 1/sum scaling
+    .params = xnn_init_f32_output_params(-INFINITY, INFINITY),  // [-inf, +inf]: no output clamping
+  };
+  softargmax_op->compute.type = xnn_parallelization_type_1d;
+  softargmax_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_f32_three_pass_softargmax;
+  softargmax_op->compute.range[0] = batch_size;  // one task per batch element
+  softargmax_op->state = xnn_run_state_ready;
+
+  return xnn_status_success;
+}</parameter>
diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h
index 58d3d9b..337ac36 100644
--- a/src/xnnpack/compute.h
+++ b/src/xnnpack/compute.h
@@ -761,3 +761,21 @@
       const struct u8_softargmax_context context[restrict static 1],
       size_t batch_index);
 #endif
+
+struct f32_three_pass_softargmax_context {  // per-operator state for xnn_compute_f32_three_pass_softargmax
+  size_t n;  // row size in bytes (channels * sizeof(float))
+  const void* x;  // input base pointer
+  size_t x_stride;  // byte stride between consecutive input rows
+  void* y;  // output base pointer
+  size_t y_stride;  // byte stride between consecutive output rows
+  xnn_f32_rmax_ukernel_function rmax_ukernel;  // pass 1: reduce-max over a row
+  xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax_ukernel;  // pass 2: store exp(x-max), accumulate sum
+  xnn_vbinary_ukernel_function vmulc_ukernel;  // pass 3: scale stored exponentials by 1/sum
+  union xnn_f32_output_params params;  // output clamping params (set to [-inf, +inf], i.e. no clamping)
+};
+
+#ifndef __cplusplus
+  XNN_PRIVATE void xnn_compute_f32_three_pass_softargmax(
+      const struct f32_three_pass_softargmax_context context[restrict static 1],
+      size_t batch_index);
+#endif
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index 9ff2564..494444f 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -79,6 +79,7 @@
   xnn_operator_type_resize_bilinear_nhwc_f32,
   xnn_operator_type_sigmoid_nc_f32,
   xnn_operator_type_sigmoid_nc_q8,
+  xnn_operator_type_softargmax_nc_f32,
   xnn_operator_type_softargmax_nc_q8,
   xnn_operator_type_subtract_nd_f32,
   xnn_operator_type_unpooling_nhwc_x32,
@@ -276,6 +277,7 @@
     struct resize_bilinear_context resize_bilinear;
     struct spmm_context spmm;
     struct subconv_context subconv;
+    struct f32_three_pass_softargmax_context f32_three_pass_softargmax;
     struct u8_softargmax_context u8_softargmax;
     struct univector_contiguous_context univector_contiguous;
     struct univector_strided_context univector_strided;
diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h
index 8de5adc..8259291 100644
--- a/src/xnnpack/params.h
+++ b/src/xnnpack/params.h
@@ -1369,6 +1369,8 @@
     struct vbinary_parameters vmul;
     struct vbinary_parameters vsub;
     struct vmulcaddc_parameters vmulcaddc;
+    xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax;
+    xnn_f32_rmax_ukernel_function rmax;
     // Sparse Matrix-Dense Matrix Multiplication (NR=1 block).
     struct spmm_parameters spmm;
     // Sparse Matrix-Dense Matrix Multiplication (NR=2 block).
diff --git a/test/softargmax-nc.cc b/test/softargmax-nc.cc
index a6857b0..288fea6 100644
--- a/test/softargmax-nc.cc
+++ b/test/softargmax-nc.cc
@@ -142,3 +142,109 @@
       .TestQ8();
   }
 }
+
+TEST(SOFTARGMAX_NC_F32, single_class) {
+  SoftArgMaxOperatorTester()
+    .batch_size(1)
+    .channels(1)  // degenerate case: softmax of a single element is exactly 1
+    .iterations(100)
+    .TestF32();
+}
+
+TEST(SOFTARGMAX_NC_F32, two_classes) {
+  SoftArgMaxOperatorTester()
+    .batch_size(1)
+    .channels(2)
+    .iterations(100)
+    .TestF32();
+}
+
+TEST(SOFTARGMAX_NC_F32, many_classes) {
+  for (size_t channels = 3; channels < 100; channels++) {  // every channel count in [3, 99]
+    SoftArgMaxOperatorTester()
+      .batch_size(1)
+      .channels(channels)
+      .iterations(1)
+      .TestF32();
+  }
+}
+
+TEST(SOFTARGMAX_NC_F32, cifar_classes) {
+  // CIFAR-10
+  SoftArgMaxOperatorTester()
+    .batch_size(1)
+    .channels(10)
+    .iterations(15)
+    .TestF32();
+  // CIFAR-100
+  SoftArgMaxOperatorTester()
+    .batch_size(1)
+    .channels(100)
+    .iterations(15)
+    .TestF32();
+}
+
+TEST(SOFTARGMAX_NC_F32, imagenet_classes) {
+  // ImageNet-1K
+  SoftArgMaxOperatorTester()
+    .batch_size(1)
+    .channels(1000)
+    .iterations(10)
+    .TestF32();
+  // ImageNet-1K+1
+  SoftArgMaxOperatorTester()
+    .batch_size(1)
+    .channels(1001)  // odd count exercises non-multiple-of-SIMD-width remainders
+    .iterations(10)
+    .TestF32();
+  // ImageNet-22K
+  SoftArgMaxOperatorTester()
+    .batch_size(1)
+    .channels(21841)
+    .iterations(10)
+    .TestF32();
+}
+
+TEST(SOFTARGMAX_NC_F32, small_batch) {
+  for (size_t channels = 1; channels < 100; channels += 5) {  // channels in {1, 6, ..., 96}
+    SoftArgMaxOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .iterations(3)
+      .TestF32();
+  }
+}
+
+TEST(SOFTARGMAX_NC_F32, small_batch_with_input_stride) {
+  for (size_t channels = 1; channels < 100; channels += 5) {
+    SoftArgMaxOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .input_stride(129)  // >= max tested channel count (96), as required by create
+      .iterations(3)
+      .TestF32();
+  }
+}
+
+TEST(SOFTARGMAX_NC_F32, small_batch_with_output_stride) {
+  for (size_t channels = 1; channels < 100; channels += 5) {
+    SoftArgMaxOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .output_stride(117)  // >= max tested channel count (96)
+      .iterations(3)
+      .TestF32();
+  }
+}
+
+TEST(SOFTARGMAX_NC_F32, strided_batch_with_input_and_output_stride) {
+  for (size_t channels = 1; channels < 100; channels += 5) {
+    SoftArgMaxOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .input_stride(129)
+      .output_stride(117)
+      .iterations(3)
+      .TestF32();
+  }
+}
diff --git a/test/softargmax-operator-tester.h b/test/softargmax-operator-tester.h
index 906822c..becfc4c 100644
--- a/test/softargmax-operator-tester.h
+++ b/test/softargmax-operator-tester.h
@@ -177,6 +177,68 @@
     }
   }
 
+  void TestF32() const {  // runs the F32 operator and checks it against a double-precision reference
+    std::random_device random_device;
+    auto rng = std::mt19937(random_device());
+    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
+
+    std::vector<float> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));  // extra tail padding; presumably for kernels that may read past the row — TODO confirm
+    std::vector<float> output((batch_size() - 1) * output_stride() + channels());
+    std::vector<double> output_ref(batch_size() * channels());
+    for (size_t iteration = 0; iteration < iterations(); iteration++) {
+      std::generate(input.begin(), input.end(), std::ref(f32rng));
+      std::fill(output.begin(), output.end(), std::nanf(""));  // poison outputs so unwritten elements fail ASSERT_NEAR
+
+      // Compute reference results.
+      for (size_t i = 0; i < batch_size(); i++) {
+        const double max_input = *std::max_element(  // max-subtraction trick, mirroring the operator's first pass
+          input.data() + i * input_stride(),
+          input.data() + i * input_stride() + channels());
+        double sum_exp = 0.0;
+        for (size_t c = 0; c < channels(); c++) {
+          sum_exp += std::exp(double(input[i * input_stride() + c]) - max_input);
+        }
+        for (size_t c = 0; c < channels(); c++) {
+          output_ref[i * channels() + c] =
+              std::exp(double(input[i * input_stride() + c]) - max_input) / sum_exp;
+        }
+      }
+
+      // Create, setup, run, and destroy SoftArgMax operator.
+      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
+      xnn_operator_t soft_arg_max_op = nullptr;
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_create_softargmax_nc_f32(
+          channels(), input_stride(), output_stride(),
+          0, &soft_arg_max_op));
+      ASSERT_NE(nullptr, soft_arg_max_op);
+
+      // Smart pointer to automatically delete soft_arg_max_op.
+      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_soft_arg_max_op(soft_arg_max_op, xnn_delete_operator);
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_setup_softargmax_nc_f32(
+          soft_arg_max_op,
+          batch_size(),
+          input.data(), output.data(),
+          nullptr /* thread pool */));
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_run_operator(soft_arg_max_op, nullptr /* thread pool */));
+
+      // Verify results.
+      for (size_t i = 0; i < batch_size(); i++) {
+        for (size_t c = 0; c < channels(); c++) {
+          ASSERT_NEAR(
+            double(output[i * output_stride() + c]),
+            output_ref[i * channels() + c],
+            output_ref[i * channels() + c] * 1.0e-4);  // relative tolerance of 1e-4
+        }
+      }
+    }
+  }
+
  private:
   size_t batch_size_{1};
   size_t channels_{1};