SoftArgMax operator

SoftArgMax operator implementation using a three-pass algorithm that reloads
the previously computed exponentials.

PiperOrigin-RevId: 291778753
diff --git a/test/softargmax-operator-tester.h b/test/softargmax-operator-tester.h
index 906822c..becfc4c 100644
--- a/test/softargmax-operator-tester.h
+++ b/test/softargmax-operator-tester.h
@@ -177,6 +177,80 @@
     }
   }
 
+  // Runs the F32 SoftArgMax operator on random inputs for iterations() rounds
+  // and compares every output against a double-precision reference.
+  void TestF32() const {
+    std::random_device random_device;
+    auto rng = std::mt19937(random_device());
+    // Draw inputs from [90, 100): expf(x) overflows single precision for every
+    // such x (ln(FLT_MAX) is about 88.7), while expf(x - max) stays within
+    // (e^-10, 1]. An F32 implementation that skips subtracting the row maximum
+    // therefore overflows and fails; a range such as [0, 1) would not catch it.
+    auto f32rng = std::bind(std::uniform_real_distribution<float>(90.0f, 100.0f), rng);
+
+    // NOTE(review): the trailing XNN_EXTRA_BYTES of padding presumably lets
+    // kernels read past the last input row; confirm against XNNPACK docs.
+    std::vector<float> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
+    std::vector<float> output((batch_size() - 1) * output_stride() + channels());
+    std::vector<double> output_ref(batch_size() * channels());
+    for (size_t iteration = 0; iteration < iterations(); iteration++) {
+      std::generate(input.begin(), input.end(), std::ref(f32rng));
+      // Pre-fill with NaN so any element the operator fails to write is caught.
+      std::fill(output.begin(), output.end(), std::nanf(""));
+
+      // Compute reference results in double precision:
+      //   y[c] = exp(x[c] - max(x)) / sum_k exp(x[k] - max(x)).
+      // Each exponential is computed once, stored, and then normalized in place.
+      for (size_t i = 0; i < batch_size(); i++) {
+        const double max_input = *std::max_element(
+          input.data() + i * input_stride(),
+          input.data() + i * input_stride() + channels());
+        double sum_exp = 0.0;
+        for (size_t c = 0; c < channels(); c++) {
+          const double exp_input = std::exp(double(input[i * input_stride() + c]) - max_input);
+          output_ref[i * channels() + c] = exp_input;
+          sum_exp += exp_input;
+        }
+        for (size_t c = 0; c < channels(); c++) {
+          output_ref[i * channels() + c] /= sum_exp;
+        }
+      }
+
+      // Create, setup, run, and destroy SoftArgMax operator.
+      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
+      xnn_operator_t soft_arg_max_op = nullptr;
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_create_softargmax_nc_f32(
+          channels(), input_stride(), output_stride(),
+          0, &soft_arg_max_op));
+      ASSERT_NE(nullptr, soft_arg_max_op);
+
+      // Smart pointer to automatically delete soft_arg_max_op.
+      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_soft_arg_max_op(soft_arg_max_op, xnn_delete_operator);
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_setup_softargmax_nc_f32(
+          soft_arg_max_op,
+          batch_size(),
+          input.data(), output.data(),
+          nullptr /* thread pool */));
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_run_operator(soft_arg_max_op, nullptr /* thread pool */));
+
+      // Verify results with a relative tolerance of 1e-4 of the reference value.
+      for (size_t i = 0; i < batch_size(); i++) {
+        for (size_t c = 0; c < channels(); c++) {
+          ASSERT_NEAR(
+            double(output[i * output_stride() + c]),
+            output_ref[i * channels() + c],
+            output_ref[i * channels() + c] * 1.0e-4);
+        }
+      }
+    }
+  }
+
  private:
   size_t batch_size_{1};
   size_t channels_{1};