F16 PReLU operator

PiperOrigin-RevId: 426323096
diff --git a/test/prelu-operator-tester.h b/test/prelu-operator-tester.h
index 47090a6..53565ce 100644
--- a/test/prelu-operator-tester.h
+++ b/test/prelu-operator-tester.h
@@ -7,6 +7,8 @@
 
 #include <gtest/gtest.h>
 
+#include <fp16.h>
+
 #include <algorithm>
 #include <cmath>
 #include <cstddef>
@@ -79,6 +81,69 @@
     return this->iterations_;
   }
 
+  void TestF16() const {
+    // Runs the F16 PReLU operator on random inputs and checks it against an FP32 reference.
+    std::random_device random_device;
+    auto rng = std::mt19937(random_device());
+    // x covers both signs so both PReLU branches run; w is a positive slope. Both are stored as FP16 bits.
+    auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
+    auto f16irng = std::bind(fp16_ieee_from_fp32_value, f32irng);
+    auto f32wrng = std::bind(std::uniform_real_distribution<float>(0.25f, 0.75f), rng);
+    auto f16wrng = std::bind(fp16_ieee_from_fp32_value, f32wrng);
+
+    std::vector<uint16_t> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
+    std::vector<uint16_t> w(channels());
+    std::vector<uint16_t> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
+    std::vector<float> y_ref(batch_size() * channels());
+    for (size_t iteration = 0; iteration < iterations(); iteration++) {
+      std::generate(x.begin(), x.end(), std::ref(f16irng));
+      std::generate(w.begin(), w.end(), std::ref(f16wrng));
+      std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);  // A checked output left unwritten stays NaN and fails.
+
+      // Compute reference results, without clamping: y = x * w[c] for negative x, y = x otherwise.
+      for (size_t i = 0; i < batch_size(); i++) {
+        for (size_t c = 0; c < channels(); c++) {
+          const float x_value = fp16_ieee_to_fp32_value(x[i * x_stride() + c]);
+          const float w_value = fp16_ieee_to_fp32_value(w[c]);
+          y_ref[i * channels() + c] = signbit(x_value) ? x_value * w_value : x_value;
+        }
+      }
+
+      // Create, setup, run, and destroy PReLU operator.
+      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
+      xnn_operator_t prelu_op = nullptr;
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_create_prelu_nc_f16(
+          channels(), x_stride(), y_stride(),
+          w.data(),
+          0, &prelu_op));
+      ASSERT_NE(nullptr, prelu_op);
+
+      // Smart pointer to automatically delete prelu_op.
+      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_setup_prelu_nc_f16(prelu_op, batch_size(), x.data(), y.data(), nullptr /* thread pool */));
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_run_operator(prelu_op, nullptr /* thread pool */));
+
+      // Verify results. The operator output is FP16, whose 11-bit significand limits rounding error to
+      // 2^-11 (~4.9e-4) of the value relative to the FP32 reference, so a 1e-4 relative bound rejects
+      // correctly rounded results; allow 1e-3 relative error with a 1e-4 absolute floor.
+      for (size_t i = 0; i < batch_size(); i++) {
+        for (size_t c = 0; c < channels(); c++) {
+          ASSERT_NEAR(
+              fp16_ieee_to_fp32_value(y[i * y_stride() + c]),
+              y_ref[i * channels() + c],
+              std::max(1.0e-4f, std::abs(y_ref[i * channels() + c]) * 1.0e-3f))
+            << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
+        }
+      }
+    }
+  }
+
   void TestF32() const {
     std::random_device random_device;
     auto rng = std::mt19937(random_device());
@@ -128,8 +193,11 @@
       // Verify results.
       for (size_t i = 0; i < batch_size(); i++) {
         for (size_t c = 0; c < channels(); c++) {
-          ASSERT_NEAR(y[i * y_stride() + c], y_ref[i * channels() + c], 1.0e-6f * std::abs(y_ref[i * channels() + c]))
-            << "i = " << i << ", c = " << c;
+          ASSERT_NEAR(
+              y[i * y_stride() + c],
+              y_ref[i * channels() + c],
+              std::max(1.0e-6f, std::abs(y_ref[i * channels() + c]) * 1.0e-6f))
+            << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
         }
       }
     }