FP32 Leaky ReLU operator

PiperOrigin-RevId: 315806661
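
Usage sketch for the new API (not part of the patch): xnn_create_leaky_relu_nc_f32 and xnn_setup_leaky_relu_nc_f32 are the entry points added in include/xnnpack.h, and xnn_initialize, xnn_run_operator, and xnn_delete_operator are the existing calls exercised by the new test. The wrapper function name, the 0.01f slope, the densely packed strides, and the reduced error handling are illustrative assumptions only.

#include <stddef.h>
#include <xnnpack.h>

// Apply Leaky ReLU (negative inputs scaled by negative_slope) to a dense
// [batch_size, channels] FP32 buffer. Minimal sketch; slope value and error
// handling are illustrative, not prescribed by the patch.
int apply_leaky_relu_f32(const float* input, float* output,
                         size_t batch_size, size_t channels) {
  if (xnn_initialize(NULL /* allocator */) != xnn_status_success) {
    return -1;
  }

  xnn_operator_t leaky_relu_op = NULL;
  if (xnn_create_leaky_relu_nc_f32(
        channels, channels /* input stride */, channels /* output stride */,
        0.01f /* negative slope */, 0 /* flags */,
        &leaky_relu_op) != xnn_status_success) {
    return -1;
  }

  if (xnn_setup_leaky_relu_nc_f32(leaky_relu_op, batch_size, input, output,
                                  NULL /* thread pool */) != xnn_status_success ||
      xnn_run_operator(leaky_relu_op, NULL /* thread pool */) != xnn_status_success) {
    xnn_delete_operator(leaky_relu_op);
    return -1;
  }

  return xnn_delete_operator(leaky_relu_op) == xnn_status_success ? 0 : -1;
}
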
diff --git a/README.md b/README.md
index 9d57be2..5af9241 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,7 @@
 - Copy
 - Floor (rounding to integer below)
 - HardSwish
+- Leaky ReLU
 - Negate
 - Sigmoid
 - Softmax
diff --git a/include/xnnpack.h b/include/xnnpack.h
index e0caab3..96a4386 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -1251,6 +1251,21 @@
   float* output,
   pthreadpool_t threadpool);
 
+enum xnn_status xnn_create_leaky_relu_nc_f32(
+  size_t channels,
+  size_t input_stride,
+  size_t output_stride,
+  float negative_slope,
+  uint32_t flags,
+  xnn_operator_t* leaky_relu_op_out);
+
+enum xnn_status xnn_setup_leaky_relu_nc_f32(
+  xnn_operator_t leaky_relu_op,
+  size_t batch_size,
+  const float* input,
+  float* output,
+  pthreadpool_t threadpool);
+
 enum xnn_status xnn_create_max_pooling2d_nhwc_f32(
   uint32_t input_padding_top,
   uint32_t input_padding_right,
diff --git a/src/init.c b/src/init.c
index 7c73a21..63f05c5 100644
--- a/src/init.c
+++ b/src/init.c
@@ -297,6 +297,7 @@
       xnn_params.f32.abs = (xnn_univector_ukernel_function) xnn_f32_vabs_ukernel__neon_x8;
       xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__neon_x8;
       xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__neon_x8;
+      xnn_params.f32.lrelu = (xnn_univector_ukernel_function) xnn_f32_vlrelu_ukernel__neon_x8;
       xnn_params.f32.neg = (xnn_univector_ukernel_function) xnn_f32_vneg_ukernel__neon_x8;
       if (cpuinfo_has_arm_neon_v8()) {
         xnn_params.f32.rndne = (xnn_univector_ukernel_function) xnn_f32_vrndne_ukernel__neonv8_x8;
@@ -516,6 +517,7 @@
       xnn_params.f32.abs = (xnn_univector_ukernel_function) xnn_f32_vabs_ukernel__scalar_x4;
       xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__scalar_x4;
       xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__scalar_x4;
+      xnn_params.f32.lrelu = (xnn_univector_ukernel_function) xnn_f32_vlrelu_ukernel__scalar_x4;
       xnn_params.f32.neg = (xnn_univector_ukernel_function) xnn_f32_vneg_ukernel__scalar_x4;
       xnn_params.f32.rndne = (xnn_univector_ukernel_function) xnn_f32_vrndne_ukernel__scalar_libm_x1;
       xnn_params.f32.rndz  = (xnn_univector_ukernel_function) xnn_f32_vrndz_ukernel__scalar_libm_x1;
@@ -936,6 +938,7 @@
     xnn_params.f32.abs = (xnn_univector_ukernel_function) xnn_f32_vabs_ukernel__neon_x8;
     xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__neon_x8;
     xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__neonfma_x8;
+    xnn_params.f32.lrelu = (xnn_univector_ukernel_function) xnn_f32_vlrelu_ukernel__neon_x8;
     xnn_params.f32.neg = (xnn_univector_ukernel_function) xnn_f32_vneg_ukernel__neon_x8;
     xnn_params.f32.rndne = (xnn_univector_ukernel_function) xnn_f32_vrndne_ukernel__neonv8_x8;
     xnn_params.f32.rndz = (xnn_univector_ukernel_function) xnn_f32_vrndz_ukernel__neonv8_x8;
@@ -1289,6 +1292,13 @@
       xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__sse_x8;
     }
     if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx512f()) {
+      xnn_params.f32.lrelu = (xnn_univector_ukernel_function) xnn_f32_vlrelu_ukernel__avx512f_x16;
+    } else if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx()) {
+      xnn_params.f32.lrelu = (xnn_univector_ukernel_function) xnn_f32_vlrelu_ukernel__avx_x16;
+    } else {
+      xnn_params.f32.lrelu = (xnn_univector_ukernel_function) xnn_f32_vlrelu_ukernel__sse_x8;
+    }
+    if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx512f()) {
       xnn_params.f32.neg = (xnn_univector_ukernel_function) xnn_f32_vneg_ukernel__avx512f_x16;
     } else if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx()) {
       xnn_params.f32.neg = (xnn_univector_ukernel_function) xnn_f32_vneg_ukernel__avx_x16;
@@ -1663,6 +1673,7 @@
     xnn_params.f32.abs = (xnn_univector_ukernel_function) xnn_f32_vabs_ukernel__psimd_x8;
     xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__psimd_x8;
     xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__psimd_x8;
+    xnn_params.f32.lrelu = (xnn_univector_ukernel_function) xnn_f32_vlrelu_ukernel__psimd_x8;
     xnn_params.f32.neg = (xnn_univector_ukernel_function) xnn_f32_vneg_ukernel__psimd_x8;
     xnn_params.f32.rndne = (xnn_univector_ukernel_function) xnn_f32_vrndne_ukernel__psimd_x8;
     xnn_params.f32.rndz  = (xnn_univector_ukernel_function) xnn_f32_vrndz_ukernel__psimd_x8;
@@ -1906,6 +1917,7 @@
     xnn_params.f32.abs = (xnn_univector_ukernel_function) xnn_f32_vabs_ukernel__scalar_x4;
     xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__wasm_x4;
     xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__wasm_x4;
+    xnn_params.f32.lrelu = (xnn_univector_ukernel_function) xnn_f32_vlrelu_ukernel__scalar_x4;
     xnn_params.f32.neg = (xnn_univector_ukernel_function) xnn_f32_vneg_ukernel__scalar_x4;
     xnn_params.f32.rndne = (xnn_univector_ukernel_function) xnn_f32_vrndne_ukernel__scalar_libm_x4;
     xnn_params.f32.rndz  = (xnn_univector_ukernel_function) xnn_f32_vrndz_ukernel__scalar_libm_x4;
diff --git a/src/operator-strings.c b/src/operator-strings.c
index 22a6d2e..581af83 100644
--- a/src/operator-strings.c
+++ b/src/operator-strings.c
@@ -70,6 +70,8 @@
       return "Global Average Pooling (NCW, F32)";
     case xnn_operator_type_hardswish_nc_f32:
       return "HardSwish (NC, F32)";
+    case xnn_operator_type_leaky_relu_nc_f32:
+      return "Leaky ReLU (NC, F32)";
     case xnn_operator_type_leaky_relu_nc_q8:
       return "Leaky ReLU (NC, Q8)";
     case xnn_operator_type_max_pooling_nhwc_f32:
diff --git a/src/operators/unary-elementwise-nc.c b/src/operators/unary-elementwise-nc.c
index 38ed68e..5fd0029 100644
--- a/src/operators/unary-elementwise-nc.c
+++ b/src/operators/unary-elementwise-nc.c
@@ -295,6 +295,30 @@
     hardswish_op_out);
 }
 
+enum xnn_status xnn_create_leaky_relu_nc_f32(
+  size_t channels,
+  size_t input_stride,
+  size_t output_stride,
+  float negative_slope,
+  uint32_t flags,
+  xnn_operator_t* leaky_relu_op_out)
+{
+  if (!isfinite(negative_slope)) {
+    xnn_log_error(
+      "failed to create %s operator with %f negative slope: finite number expected",
+      xnn_operator_type_to_string(xnn_operator_type_leaky_relu_nc_f32),
+      negative_slope);
+    return xnn_status_invalid_parameter;
+  }
+
+  const union xnn_f32_lrelu_params params = xnn_init_f32_lrelu_params(negative_slope);
+  return create_unary_elementwise_nc(
+    channels, input_stride, output_stride, flags,
+    &params, sizeof(params),
+    xnn_operator_type_leaky_relu_nc_f32,
+    leaky_relu_op_out);
+}
+
 enum xnn_status xnn_create_negate_nc_f32(
     size_t channels,
     size_t input_stride,
@@ -541,6 +565,29 @@
     &hardswish_op->params.f32_hswish, sizeof(hardswish_op->params.f32_hswish));
 }
 
+enum xnn_status xnn_setup_leaky_relu_nc_f32(
+  xnn_operator_t leaky_relu_op,
+  size_t batch_size,
+  const float* input,
+  float* output,
+  pthreadpool_t threadpool)
+{
+  if (leaky_relu_op->type != xnn_operator_type_leaky_relu_nc_f32) {
+    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
+      xnn_operator_type_to_string(xnn_operator_type_leaky_relu_nc_f32),
+      xnn_operator_type_to_string(leaky_relu_op->type));
+    return xnn_status_invalid_parameter;
+  }
+  leaky_relu_op->state = xnn_run_state_invalid;
+
+  return setup_unary_elementwise_nc(
+    leaky_relu_op,
+    batch_size, input, output,
+    xnn_params.f32.lrelu,
+    2 /* log2(sizeof(float)) */,
+    &leaky_relu_op->params.f32_lrelu, sizeof(leaky_relu_op->params.f32_lrelu));
+}
+
 enum xnn_status xnn_setup_negate_nc_f32(
     xnn_operator_t negate_op,
     size_t batch_size,
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index cb1a9d8..af281e4 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -70,6 +70,7 @@
   xnn_operator_type_global_average_pooling_nwc_q8,
   xnn_operator_type_global_average_pooling_ncw_f32,
   xnn_operator_type_hardswish_nc_f32,
+  xnn_operator_type_leaky_relu_nc_f32,
   xnn_operator_type_leaky_relu_nc_q8,
   xnn_operator_type_max_pooling_nhwc_f32,
   xnn_operator_type_max_pooling_nhwc_u8,
@@ -246,6 +247,7 @@
 
   union {
     union xnn_f32_abs_params f32_abs;
+    union xnn_f32_lrelu_params f32_lrelu;
     union xnn_f32_neg_params f32_neg;
     union xnn_f32_rnd_params f32_rnd;
     // Parameters for Global Average Pooling in CHW layout
diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h
index f6900eb..e945bab 100644
--- a/src/xnnpack/params.h
+++ b/src/xnnpack/params.h
@@ -1697,6 +1697,7 @@
     xnn_univector_ukernel_function abs;
     xnn_univector_ukernel_function clamp;
     xnn_univector_ukernel_function hswish;
+    xnn_univector_ukernel_function lrelu;
     xnn_univector_ukernel_function neg;
     xnn_univector_ukernel_function rndne;
     xnn_univector_ukernel_function rndz;
diff --git a/test/leaky-relu-nc.cc b/test/leaky-relu-nc.cc
index 1fd10c6..ac16818 100644
--- a/test/leaky-relu-nc.cc
+++ b/test/leaky-relu-nc.cc
@@ -151,3 +151,73 @@
       .TestQ8();
   }
 }
+
+
+TEST(LEAKY_RELU_NC_F32, unit_batch) {
+  for (size_t channels = 1; channels < 100; channels++) {
+    LeakyReLUOperatorTester()
+      .batch_size(1)
+      .channels(channels)
+      .iterations(3)
+      .TestF32();
+  }
+}
+
+TEST(LEAKY_RELU_NC_F32, small_batch) {
+  for (size_t channels = 1; channels < 100; channels++) {
+    LeakyReLUOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .iterations(3)
+      .TestF32();
+  }
+}
+
+TEST(LEAKY_RELU_NC_F32, small_batch_with_input_stride) {
+  for (size_t channels = 1; channels < 100; channels += 15) {
+    LeakyReLUOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .input_stride(129)
+      .iterations(3)
+      .TestF32();
+  }
+}
+
+TEST(LEAKY_RELU_NC_F32, small_batch_with_output_stride) {
+  for (size_t channels = 1; channels < 100; channels += 15) {
+    LeakyReLUOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .output_stride(117)
+      .iterations(3)
+      .TestF32();
+  }
+}
+
+TEST(LEAKY_RELU_NC_F32, small_batch_with_input_and_output_stride) {
+  for (size_t channels = 1; channels < 100; channels += 15) {
+    LeakyReLUOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .input_stride(129)
+      .output_stride(117)
+      .iterations(3)
+      .TestF32();
+  }
+}
+
+TEST(LEAKY_RELU_NC_F32, small_batch_with_negative_slope) {
+  for (size_t batch_size = 1; batch_size <= 3; batch_size += 2) {
+    for (size_t channels = 1; channels < 100; channels += 15) {
+      for (float negative_slope = 1.0e-4f; negative_slope < 1.0f; negative_slope *= 3.14159265f) {
+        LeakyReLUOperatorTester()
+          .batch_size(batch_size)
+          .channels(channels)
+          .negative_slope(negative_slope)
+          .iterations(1)
+          .TestF32();
+      }
+    }
+  }
+}
diff --git a/test/leaky-relu-operator-tester.h b/test/leaky-relu-operator-tester.h
index 65e695d..9f4c6df 100644
--- a/test/leaky-relu-operator-tester.h
+++ b/test/leaky-relu-operator-tester.h
@@ -153,6 +153,62 @@
     return this->iterations_;
   }
 
+  void TestF32() const {
+    std::random_device random_device;
+    auto rng = std::mt19937(random_device());
+    auto f32rng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), std::ref(rng));
+
+    std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + (batch_size() - 1) * input_stride() + channels());
+    std::vector<float> output((batch_size() - 1) * output_stride() + channels());
+    std::vector<float> output_ref(batch_size() * channels());
+    for (size_t iteration = 0; iteration < iterations(); iteration++) {
+      std::generate(input.begin(), input.end(), std::ref(f32rng));
+      std::fill(output.begin(), output.end(), std::nanf(""));
+
+      // Compute reference results.
+      for (size_t i = 0; i < batch_size(); i++) {
+        for (size_t c = 0; c < channels(); c++) {
+          const float x = input[i * input_stride() + c];
+          const float y = std::signbit(x) ? x * negative_slope() : x;
+          output_ref[i * channels() + c] = y;
+        }
+      }
+
+      // Create, setup, run, and destroy Leaky ReLU operator.
+      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
+      xnn_operator_t leaky_relu_op = nullptr;
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_create_leaky_relu_nc_f32(
+          channels(), input_stride(), output_stride(),
+          negative_slope(),
+          0, &leaky_relu_op));
+      ASSERT_NE(nullptr, leaky_relu_op);
+
+      // Smart pointer to automatically delete leaky_relu_op.
+      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_leaky_relu_op(leaky_relu_op, xnn_delete_operator);
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_setup_leaky_relu_nc_f32(
+          leaky_relu_op,
+          batch_size(),
+          input.data(), output.data(),
+          nullptr /* thread pool */));
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_run_operator(leaky_relu_op, nullptr /* thread pool */));
+
+      // Verify results.
+      for (size_t i = 0; i < batch_size(); i++) {
+        for (size_t c = 0; c < channels(); c++) {
+          ASSERT_EQ(output[i * output_stride() + c], output_ref[i * channels() + c])
+            << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels()
+            << ", input " << input[i * input_stride() + c] << ", negative slope " << negative_slope();
+        }
+      }
+    }
+  }
+
   void TestQ8() const {
     std::random_device random_device;
     auto rng = std::mt19937(random_device());
@@ -176,7 +232,7 @@
         }
       }
 
-      // Create, setup, run, and destroy LeakyReLU operator.
+      // Create, setup, run, and destroy Leaky ReLU operator.
       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
       xnn_operator_t leaky_relu_op = nullptr;