Add binary op micro-kernels with ReLU activation
PiperOrigin-RevId: 325607697
diff --git a/test/vbinaryc-microkernel-tester.h b/test/vbinaryc-microkernel-tester.h
index 0f9efa8..ac2729a 100644
--- a/test/vbinaryc-microkernel-tester.h
+++ b/test/vbinaryc-microkernel-tester.h
@@ -387,6 +387,79 @@
}
}
+  // Tests a vbinaryc micro-kernel with ReLU: y[i] = max(op(a[i], b), 0.0f), where
+  // the scalar b is drawn once per Test() call. Inputs come from [-1, 1) so that
+  // negative intermediate results exercise the clamp. Params are value-initialized,
+  // so `variant` is not consulted.
+  void Test(xnn_f32_vbinary_relu_ukernel_function vbinaryc_relu, OpType op_type, Variant variant = Variant::Native) const {
+    std::random_device random_device;
+    auto rng = std::mt19937(random_device());
+    auto f32rng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
+    std::vector<float> a(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
+    const float b = f32rng();
+    // In-place runs read from y, so y gets the same XNN_EXTRA_BYTES tail as a.
+    std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
+    std::vector<float> y_ref(batch_size());
+    for (size_t iteration = 0; iteration < iterations(); iteration++) {
+      std::generate(a.begin(), a.end(), std::ref(f32rng));
+      if (inplace()) {
+        std::generate(y.begin(), y.end(), std::ref(f32rng));
+      } else {
+        // NaN-poison the output so any element the kernel fails to write is caught.
+        std::fill(y.begin(), y.end(), nanf(""));
+      }
+      const float* a_data = inplace() ? y.data() : a.data();
+      // Compute reference results: apply the binary op, then the ReLU clamp.
+      for (size_t i = 0; i < batch_size(); i++) {
+        switch (op_type) {
+          case OpType::AddC:
+            y_ref[i] = a_data[i] + b;
+            break;
+          case OpType::DivC:
+            y_ref[i] = a_data[i] / b;
+            break;
+          case OpType::RDivC:
+            y_ref[i] = b / a_data[i];
+            break;
+          case OpType::MaxC:
+            y_ref[i] = std::max<float>(a_data[i], b);
+            break;
+          case OpType::MinC:
+            y_ref[i] = std::min<float>(a_data[i], b);
+            break;
+          case OpType::MulC:
+            y_ref[i] = a_data[i] * b;
+            break;
+          case OpType::SqrDiffC: {
+            const float diff = a_data[i] - b;
+            y_ref[i] = diff * diff;
+            break;
+          }
+          case OpType::SubC:
+            y_ref[i] = a_data[i] - b;
+            break;
+          case OpType::RSubC:
+            y_ref[i] = b - a_data[i];
+            break;
+        }
+        y_ref[i] = std::max(y_ref[i], 0.0f);
+      }
+
+      // Prepare parameters.
+      xnn_f32_relu_params params = { };
+      // Call optimized micro-kernel (the size argument is in bytes).
+      vbinaryc_relu(batch_size() * sizeof(float), a_data, &b, y.data(), &params);
+
+      // Verify results: outputs must be non-negative and match the reference.
+      for (size_t i = 0; i < batch_size(); i++) {
+        ASSERT_GE(y[i], 0.0f)
+          << "at " << i << " / " << batch_size();
+        ASSERT_NEAR(y[i], y_ref[i], std::abs(y_ref[i]) * 1.0e-6f)
+          << "at " << i << " / " << batch_size();
+      }
+    }
+  }
+
private:
size_t batch_size_{1};
bool inplace_{false};