SoftArgMax operator
SoftArgMax operator implementation using the Three-Pass algorithm with reloading
of computed exponentials.
PiperOrigin-RevId: 291778753
diff --git a/test/softargmax-operator-tester.h b/test/softargmax-operator-tester.h
index 906822c..becfc4c 100644
--- a/test/softargmax-operator-tester.h
+++ b/test/softargmax-operator-tester.h
@@ -177,6 +177,68 @@
}
}
+ void TestF32() const {  // Checks the F32 SoftArgMax operator against a double-precision reference.
+ std::random_device random_device;
+ auto rng = std::mt19937(random_device());
+ auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);  // Default-constructed distribution: inputs lie in [0, 1).
+
+ std::vector<float> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));  // Strided rows plus XNN_EXTRA_BYTES of padding (presumably to allow over-reads; confirm).
+ std::vector<float> output((batch_size() - 1) * output_stride() + channels());  // Strided rows, no padding.
+ std::vector<double> output_ref(batch_size() * channels());  // Densely packed: row i starts at i * channels().
+ for (size_t iteration = 0; iteration < iterations(); iteration++) {
+ std::generate(input.begin(), input.end(), std::ref(f32rng));
+ std::fill(output.begin(), output.end(), std::nanf(""));  // NaN pre-fill so stale output values cannot pass for results.
+
+ // Compute reference results in double precision, one batch row at a time.
+ for (size_t i = 0; i < batch_size(); i++) {
+ const double max_input = *std::max_element(
+ input.data() + i * input_stride(),
+ input.data() + i * input_stride() + channels());
+ double sum_exp = 0.0;
+ for (size_t c = 0; c < channels(); c++) {
+ sum_exp += std::exp(double(input[i * input_stride() + c]) - max_input);  // Row maximum is subtracted before exponentiating.
+ }
+ for (size_t c = 0; c < channels(); c++) {
+ output_ref[i * channels() + c] =
+ std::exp(double(input[i * input_stride() + c]) - max_input) / sum_exp;  // Each row is normalized by its own sum of exponentials.
+ }
+ }
+
+ // Create, set up, and run the SoftArgMax operator; the smart pointer below destroys it.
+ ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
+ xnn_operator_t soft_arg_max_op = nullptr;
+
+ ASSERT_EQ(xnn_status_success,
+ xnn_create_softargmax_nc_f32(
+ channels(), input_stride(), output_stride(),
+ 0, &soft_arg_max_op));  // NOTE(review): the literal 0 is presumably the flags argument; confirm against the API.
+ ASSERT_NE(nullptr, soft_arg_max_op);
+
+ // Smart pointer that deletes soft_arg_max_op on scope exit, including early ASSERT_* returns.
+ std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_soft_arg_max_op(soft_arg_max_op, xnn_delete_operator);
+
+ ASSERT_EQ(xnn_status_success,
+ xnn_setup_softargmax_nc_f32(
+ soft_arg_max_op,
+ batch_size(),
+ input.data(), output.data(),
+ nullptr /* thread pool */));
+
+ ASSERT_EQ(xnn_status_success,
+ xnn_run_operator(soft_arg_max_op, nullptr /* thread pool */));
+
+ // Verify results: each output must be within 1.0e-4 (relative) of its reference value.
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t c = 0; c < channels(); c++) {
+ ASSERT_NEAR(
+ double(output[i * output_stride() + c]),
+ output_ref[i * channels() + c],
+ output_ref[i * channels() + c] * 1.0e-4);  // Absolute error bound scaled by the reference value.
+ }
+ }
+ }
+ }
+
private:
size_t batch_size_{1};
size_t channels_{1};