F16 PReLU operator
PiperOrigin-RevId: 426323096
diff --git a/test/prelu-operator-tester.h b/test/prelu-operator-tester.h
index 47090a6..53565ce 100644
--- a/test/prelu-operator-tester.h
+++ b/test/prelu-operator-tester.h
@@ -7,6 +7,8 @@
#include <gtest/gtest.h>
+#include <fp16.h>
+
#include <algorithm>
#include <cmath>
#include <cstddef>
@@ -79,6 +81,69 @@
return this->iterations_;
}
+ void TestF16() const {
+ std::random_device random_device;
+ auto rng = std::mt19937(random_device());
+ auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
+ auto f16irng = std::bind(fp16_ieee_from_fp32_value, f32irng); // inputs stored as IEEE FP16 bit patterns in uint16_t
+ auto f32wrng = std::bind(std::uniform_real_distribution<float>(0.25f, 0.75f), rng);
+ auto f16wrng = std::bind(fp16_ieee_from_fp32_value, f32wrng); // slopes drawn from (0.25, 0.75): strictly positive
+
+ std::vector<uint16_t> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
+ std::vector<uint16_t> w(channels());
+ std::vector<uint16_t> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
+ std::vector<float> y_ref(batch_size() * channels());
+ for (size_t iteration = 0; iteration < iterations(); iteration++) {
+ std::generate(x.begin(), x.end(), std::ref(f16irng));
+ std::generate(w.begin(), w.end(), std::ref(f16irng)); /* NOTE(review): original binds f16wrng here; see below */
+ std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
+
+ // Compute reference results, without clamping.
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t c = 0; c < channels(); c++) {
+ const float x_value = fp16_ieee_to_fp32_value(x[i * x_stride() + c]);
+ const float w_value = fp16_ieee_to_fp32_value(w[c]);
+ y_ref[i * channels() + c] = std::signbit(x_value) ? x_value * w_value : x_value; // std::-qualified: <cmath> only guarantees std::signbit
+ }
+ }
+
+ // Create, setup, run, and destroy PReLU operator.
+ ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
+ xnn_operator_t prelu_op = nullptr;
+
+ ASSERT_EQ(xnn_status_success,
+ xnn_create_prelu_nc_f16(
+ channels(), x_stride(), y_stride(),
+ w.data(),
+ 0, &prelu_op));
+ ASSERT_NE(nullptr, prelu_op);
+
+ // Smart pointer to automatically delete prelu_op.
+ std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);
+
+ ASSERT_EQ(xnn_status_success,
+ xnn_setup_prelu_nc_f16(
+ prelu_op,
+ batch_size(),
+ x.data(), y.data(),
+ nullptr /* thread pool */));
+
+ ASSERT_EQ(xnn_status_success,
+ xnn_run_operator(prelu_op, nullptr /* thread pool */));
+
+ // Verify results.
+ for (size_t i = 0; i < batch_size(); i++) {
+ for (size_t c = 0; c < channels(); c++) {
+ ASSERT_NEAR(
+ fp16_ieee_to_fp32_value(y[i * y_stride() + c]),
+ y_ref[i * channels() + c],
+ std::max(1.0e-4f, std::abs(y_ref[i * channels() + c]) * 1.0e-4f)) // FP16 has ~11 significand bits; relative tolerance with an absolute floor near zero
+ << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
+ }
+ }
+ }
+ }
+
void TestF32() const {
std::random_device random_device;
auto rng = std::mt19937(random_device());
@@ -128,8 +193,11 @@
// Verify results.
for (size_t i = 0; i < batch_size(); i++) {
for (size_t c = 0; c < channels(); c++) {
- ASSERT_NEAR(y[i * y_stride() + c], y_ref[i * channels() + c], 1.0e-6f * std::abs(y_ref[i * channels() + c]))
- << "i = " << i << ", c = " << c;
+ ASSERT_NEAR(
+ y[i * y_stride() + c],
+ y_ref[i * channels() + c],
+ std::max(1.0e-6f, std::abs(y_ref[i * channels() + c]) * 1.0e-6f))
+ << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
}
}
}