F16 PReLU operator
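
The new operator mirrors the existing F32 PReLU NC API, with fp16 values
passed as opaque 16-bit words. A minimal usage sketch (illustrative only:
the sizes, values, and error handling are placeholders, and the trailing
padding mirrors the test harness since the microkernels may read past the
logical end of a row):

  #include <stddef.h>
  #include <stdint.h>
  #include <xnnpack.h>

  void f16_prelu_example(void) {
    const size_t channels = 8;
    const size_t batch_size = 2;
    uint16_t slope[8] = {0};          // per-channel negative slopes, fp16 bits
    uint16_t input[2 * 8 + 8] = {0};  // batch x channels, fp16 bits, + padding
    uint16_t output[2 * 8 + 8];

    if (xnn_initialize(NULL /* allocator */) != xnn_status_success) return;

    xnn_operator_t prelu_op = NULL;
    // Dense layout: input_stride == output_stride == channels.
    if (xnn_create_prelu_nc_f16(channels, channels, channels,
                                slope, 0 /* flags */, &prelu_op) != xnn_status_success) {
      return;
    }
    if (xnn_setup_prelu_nc_f16(prelu_op, batch_size, input, output,
                               NULL /* thread pool */) == xnn_status_success) {
      xnn_run_operator(prelu_op, NULL /* thread pool */);
    }
    xnn_delete_operator(prelu_op);
  }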

PiperOrigin-RevId: 426323096
diff --git a/BUILD.bazel b/BUILD.bazel
index 33d191e..6100289 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -4425,6 +4425,7 @@
     "src/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c",
     "src/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c",
     "src/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c",
+    "src/f16-prelu/gen/neonfp16arith-2x16.c",
     "src/f16-vbinary/gen/vadd-minmax-neonfp16arith-x16.c",
     "src/f16-vbinary/gen/vaddc-minmax-neonfp16arith-x16.c",
     "src/f16-vbinary/gen/vmul-minmax-neonfp16arith-x16.c",
@@ -6016,6 +6017,7 @@
     "src/f16-f32-vcvt/gen/vcvt-f16c-x16.c",
     "src/f16-gavgpool/gen/7p7x-minmax-f16c-c8.c",
     "src/f16-gavgpool/gen/7x-minmax-f16c-c8.c",
+    "src/f16-prelu/gen/f16c-2x16.c",
     "src/f16-vbinary/gen/vadd-minmax-f16c-x16.c",
     "src/f16-vbinary/gen/vaddc-minmax-f16c-x16.c",
     "src/f16-vbinary/gen/vmul-minmax-f16c-x16.c",
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 56f02c6..2c233f8 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3181,6 +3181,7 @@
   src/f16-gemm/gen/6x16-minmax-neonfp16arith-ld64.c
   src/f16-igemm/gen/1x16-minmax-neonfp16arith-ld64.c
   src/f16-igemm/gen/6x16-minmax-neonfp16arith-ld64.c
+  src/f16-prelu/gen/neonfp16arith-2x16.c
   src/f16-vbinary/gen/vadd-minmax-neonfp16arith-x16.c
   src/f16-vbinary/gen/vaddc-minmax-neonfp16arith-x16.c
   src/f16-vbinary/gen/vmul-minmax-neonfp16arith-x16.c
@@ -4756,6 +4757,7 @@
   src/f16-f32-vcvt/gen/vcvt-f16c-x16.c
   src/f16-gavgpool/gen/7p7x-minmax-f16c-c8.c
   src/f16-gavgpool/gen/7x-minmax-f16c-c8.c
+  src/f16-prelu/gen/f16c-2x16.c
   src/f16-vbinary/gen/vadd-minmax-f16c-x16.c
   src/f16-vbinary/gen/vaddc-minmax-f16c-x16.c
   src/f16-vbinary/gen/vmul-minmax-f16c-x16.c
@@ -6706,7 +6708,7 @@
     CXX_STANDARD_REQUIRED YES
     CXX_EXTENSIONS NO)
   TARGET_INCLUDE_DIRECTORIES(prelu-nc-test PRIVATE src test)
-  TARGET_LINK_LIBRARIES(prelu-nc-test PRIVATE XNNPACK gtest gtest_main)
+  TARGET_LINK_LIBRARIES(prelu-nc-test PRIVATE XNNPACK fp16 gtest gtest_main)
   ADD_TEST(prelu-nc-test prelu-nc-test)
 
   ADD_EXECUTABLE(resize-bilinear-nhwc-test test/resize-bilinear-nhwc.cc)
diff --git a/include/xnnpack.h b/include/xnnpack.h
index 7647039..c29899a 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -2013,6 +2013,21 @@
   void* output,
   pthreadpool_t threadpool);
 
+enum xnn_status xnn_create_prelu_nc_f16(
+  size_t channels,
+  size_t input_stride,
+  size_t output_stride,
+  const void* negative_slope,
+  uint32_t flags,
+  xnn_operator_t* prelu_op_out);
+
+enum xnn_status xnn_setup_prelu_nc_f16(
+  xnn_operator_t prelu_op,
+  size_t batch_size,
+  const void* input,
+  void* output,
+  pthreadpool_t threadpool);
+
 #endif  // XNN_NO_F16_OPERATORS
 
 #ifndef XNN_NO_X16_OPERATORS
diff --git a/src/amalgam/avx2.c b/src/amalgam/avx2.c
index ab30cc7..42bc2b0 100644
--- a/src/amalgam/avx2.c
+++ b/src/amalgam/avx2.c
@@ -226,83 +226,83 @@
     vacc3x89ABCDEF = _mm256_min_ps(vacc3x89ABCDEF, vmax);
 
     if XNN_LIKELY(nc >= 16) {
-      _mm_storeu_si128((__m128i*) c3, _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC));
-      _mm_storeu_si128((__m128i*) (c3 + 8), _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC));
-      c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride);
-      _mm_storeu_si128((__m128i*) c2, _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC));
-      _mm_storeu_si128((__m128i*) (c2 + 8), _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC));
-      c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);
-      _mm_storeu_si128((__m128i*) c1, _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC));
-      _mm_storeu_si128((__m128i*) (c1 + 8), _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC));
-      c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
       _mm_storeu_si128((__m128i*) c0, _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC));
       _mm_storeu_si128((__m128i*) (c0 + 8), _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC));
       c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
+      _mm_storeu_si128((__m128i*) c1, _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC));
+      _mm_storeu_si128((__m128i*) (c1 + 8), _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC));
+      c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
+      _mm_storeu_si128((__m128i*) c2, _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC));
+      _mm_storeu_si128((__m128i*) (c2 + 8), _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC));
+      c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);
+      _mm_storeu_si128((__m128i*) c3, _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC));
+      _mm_storeu_si128((__m128i*) (c3 + 8), _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC));
+      c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride);
 
-      a3 = (const uint16_t*) ((uintptr_t) a3 - kc);
-      a2 = (const uint16_t*) ((uintptr_t) a2 - kc);
-      a1 = (const uint16_t*) ((uintptr_t) a1 - kc);
       a0 = (const uint16_t*) ((uintptr_t) a0 - kc);
+      a1 = (const uint16_t*) ((uintptr_t) a1 - kc);
+      a2 = (const uint16_t*) ((uintptr_t) a2 - kc);
+      a3 = (const uint16_t*) ((uintptr_t) a3 - kc);
 
       nc -= 16;
     } else {
-      __m128i vh3x01234567 = _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC);
-      __m128i vh2x01234567 = _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC);
-      __m128i vh1x01234567 = _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC);
       __m128i vh0x01234567 = _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC);
+      __m128i vh1x01234567 = _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC);
+      __m128i vh2x01234567 = _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC);
+      __m128i vh3x01234567 = _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC);
       if (nc & 8) {
-        _mm_storeu_si128((__m128i*) c3, vh3x01234567);
-        _mm_storeu_si128((__m128i*) c2, vh2x01234567);
-        _mm_storeu_si128((__m128i*) c1, vh1x01234567);
         _mm_storeu_si128((__m128i*) c0, vh0x01234567);
+        _mm_storeu_si128((__m128i*) c1, vh1x01234567);
+        _mm_storeu_si128((__m128i*) c2, vh2x01234567);
+        _mm_storeu_si128((__m128i*) c3, vh3x01234567);
 
-        vh3x01234567 = _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC);
-        vh2x01234567 = _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC);
-        vh1x01234567 = _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC);
         vh0x01234567 = _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC);
+        vh1x01234567 = _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC);
+        vh2x01234567 = _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC);
+        vh3x01234567 = _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC);
 
-        c3 += 8;
-        c2 += 8;
-        c1 += 8;
         c0 += 8;
+        c1 += 8;
+        c2 += 8;
+        c3 += 8;
       }
       if (nc & 4) {
-        _mm_storel_epi64((__m128i*) c3, vh3x01234567);
-        _mm_storel_epi64((__m128i*) c2, vh2x01234567);
-        _mm_storel_epi64((__m128i*) c1, vh1x01234567);
         _mm_storel_epi64((__m128i*) c0, vh0x01234567);
+        _mm_storel_epi64((__m128i*) c1, vh1x01234567);
+        _mm_storel_epi64((__m128i*) c2, vh2x01234567);
+        _mm_storel_epi64((__m128i*) c3, vh3x01234567);
 
-        vh3x01234567 = _mm_unpackhi_epi64(vh3x01234567, vh3x01234567);
-        vh2x01234567 = _mm_unpackhi_epi64(vh2x01234567, vh2x01234567);
-        vh1x01234567 = _mm_unpackhi_epi64(vh1x01234567, vh1x01234567);
         vh0x01234567 = _mm_unpackhi_epi64(vh0x01234567, vh0x01234567);
+        vh1x01234567 = _mm_unpackhi_epi64(vh1x01234567, vh1x01234567);
+        vh2x01234567 = _mm_unpackhi_epi64(vh2x01234567, vh2x01234567);
+        vh3x01234567 = _mm_unpackhi_epi64(vh3x01234567, vh3x01234567);
 
-        c3 += 4;
-        c2 += 4;
-        c1 += 4;
         c0 += 4;
+        c1 += 4;
+        c2 += 4;
+        c3 += 4;
       }
       if (nc & 2) {
-        _mm_storeu_si32(c3, vh3x01234567);
-        _mm_storeu_si32(c2, vh2x01234567);
-        _mm_storeu_si32(c1, vh1x01234567);
         _mm_storeu_si32(c0, vh0x01234567);
+        _mm_storeu_si32(c1, vh1x01234567);
+        _mm_storeu_si32(c2, vh2x01234567);
+        _mm_storeu_si32(c3, vh3x01234567);
 
-        vh3x01234567 = _mm_srli_epi64(vh3x01234567, 32);
-        vh2x01234567 = _mm_srli_epi64(vh2x01234567, 32);
-        vh1x01234567 = _mm_srli_epi64(vh1x01234567, 32);
         vh0x01234567 = _mm_srli_epi64(vh0x01234567, 32);
+        vh1x01234567 = _mm_srli_epi64(vh1x01234567, 32);
+        vh2x01234567 = _mm_srli_epi64(vh2x01234567, 32);
+        vh3x01234567 = _mm_srli_epi64(vh3x01234567, 32);
 
-        c3 += 2;
-        c2 += 2;
-        c1 += 2;
         c0 += 2;
+        c1 += 2;
+        c2 += 2;
+        c3 += 2;
       }
       if (nc & 1) {
-        *c3 = (uint16_t) _mm_extract_epi16(vh3x01234567, 0);
-        *c2 = (uint16_t) _mm_extract_epi16(vh2x01234567, 0);
-        *c1 = (uint16_t) _mm_extract_epi16(vh1x01234567, 0);
         *c0 = (uint16_t) _mm_extract_epi16(vh0x01234567, 0);
+        *c1 = (uint16_t) _mm_extract_epi16(vh1x01234567, 0);
+        *c2 = (uint16_t) _mm_extract_epi16(vh2x01234567, 0);
+        *c3 = (uint16_t) _mm_extract_epi16(vh3x01234567, 0);
       }
 
       nc = 0;
diff --git a/src/amalgam/f16c.c b/src/amalgam/f16c.c
index 7cbaef9..1e41f3a 100644
--- a/src/amalgam/f16c.c
+++ b/src/amalgam/f16c.c
@@ -11,6 +11,7 @@
 #include <xnnpack/gavgpool.h>
 #include <xnnpack/intrinsics-polyfill.h>
 #include <xnnpack/math.h>
+#include <xnnpack/prelu.h>
 #include <xnnpack/vbinary.h>
 #include <xnnpack/vcvt.h>
 #include <xnnpack/vunary.h>
@@ -357,6 +358,138 @@
   }
 }
 
+void xnn_f16_prelu_ukernel__f16c_2x16(
+    size_t rows,
+    size_t channels,
+    const void* restrict input,
+    size_t input_stride,
+    const void* restrict weights,
+    void* restrict output,
+    size_t output_stride) XNN_OOB_READS
+{
+  assert(rows != 0);
+  assert(channels != 0);
+  assert(channels % sizeof(uint16_t) == 0);
+
+  const uint16_t* i0 = (const uint16_t*) input;
+  uint16_t* o0 = (uint16_t*) output;
+  const uint16_t* i1 = (const uint16_t*) ((uintptr_t) i0 + input_stride);
+  uint16_t* o1 = (uint16_t*) ((uintptr_t) o0 + output_stride);
+
+  const size_t input_increment = input_stride * 2 - channels;
+  const size_t output_increment = output_stride * 2 - channels;
+
+  do {
+    if XNN_UNPREDICTABLE(rows < 2) {
+      i1 = i0;
+      o1 = o0;
+    }
+
+    const uint16_t* w = (const uint16_t*) weights;
+    size_t c = channels;
+    for (; c >= 16 * sizeof(uint16_t); c -= 16 * sizeof(uint16_t)) {
+      const __m256 vw01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w));
+      const __m256 vw89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 8)));
+      w += 16;
+
+      const __m256 vi0x001234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
+      const __m256 vi0x089ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i0 + 8)));
+      i0 += 16;
+      const __m256 vi1x001234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
+      const __m256 vi1x089ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i1 + 8)));
+      i1 += 16;
+
+      __m256 vacc0x001234567 = _mm256_mul_ps(vi0x001234567, vw01234567);
+      __m256 vacc0x089ABCDEF = _mm256_mul_ps(vi0x089ABCDEF, vw89ABCDEF);
+      __m256 vacc1x001234567 = _mm256_mul_ps(vi1x001234567, vw01234567);
+      __m256 vacc1x089ABCDEF = _mm256_mul_ps(vi1x089ABCDEF, vw89ABCDEF);
+
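+      // blendv keys off the sign bit of its mask (the input here): negative lanes take the slope*input product, non-negative lanes keep the input.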
+      vacc0x001234567 = _mm256_blendv_ps(vi0x001234567, vacc0x001234567, vi0x001234567);
+      vacc0x089ABCDEF = _mm256_blendv_ps(vi0x089ABCDEF, vacc0x089ABCDEF, vi0x089ABCDEF);
+      vacc1x001234567 = _mm256_blendv_ps(vi1x001234567, vacc1x001234567, vi1x001234567);
+      vacc1x089ABCDEF = _mm256_blendv_ps(vi1x089ABCDEF, vacc1x089ABCDEF, vi1x089ABCDEF);
+
+      _mm_storeu_si128((__m128i*) o0, _mm256_cvtps_ph(vacc0x001234567, _MM_FROUND_NO_EXC));
+      _mm_storeu_si128((__m128i*) (o0 + 8), _mm256_cvtps_ph(vacc0x089ABCDEF, _MM_FROUND_NO_EXC));
+      o0 += 16;
+      _mm_storeu_si128((__m128i*) o1, _mm256_cvtps_ph(vacc1x001234567, _MM_FROUND_NO_EXC));
+      _mm_storeu_si128((__m128i*) (o1 + 8), _mm256_cvtps_ph(vacc1x089ABCDEF, _MM_FROUND_NO_EXC));
+      o1 += 16;
+    }
+    for (; c >= 8 * sizeof(uint16_t); c -= 8 * sizeof(uint16_t)) {
+      const __m256 vw01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w));
+      w += 8;
+
+      const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
+      i0 += 8;
+      const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
+      i1 += 8;
+
+      __m256 vacc0x01234567 = _mm256_mul_ps(vi0x01234567, vw01234567);
+      __m256 vacc1x01234567 = _mm256_mul_ps(vi1x01234567, vw01234567);
+
+      vacc0x01234567 = _mm256_blendv_ps(vi0x01234567, vacc0x01234567, vi0x01234567);
+      vacc1x01234567 = _mm256_blendv_ps(vi1x01234567, vacc1x01234567, vi1x01234567);
+
+      _mm_storeu_si128((__m128i*) o0, _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC));
+      o0 += 8;
+      _mm_storeu_si128((__m128i*) o1, _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC));
+      o1 += 8;
+    }
+    if XNN_UNLIKELY(c != 0) {
+      const __m256 vw01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w));
+
+      const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
+      i0 = (const uint16_t*) ((uintptr_t) i0 + c);
+      const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
+      i1 = (const uint16_t*) ((uintptr_t) i1 + c);
+
+      __m256 vacc0x01234567 = _mm256_mul_ps(vi0x01234567, vw01234567);
+      __m256 vacc1x01234567 = _mm256_mul_ps(vi1x01234567, vw01234567);
+
+      vacc0x01234567 = _mm256_blendv_ps(vi0x01234567, vacc0x01234567, vi0x01234567);
+      vacc1x01234567 = _mm256_blendv_ps(vi1x01234567, vacc1x01234567, vi1x01234567);
+
+      __m128i vh0x01234567 = _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC);
+      __m128i vh1x01234567 = _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC);
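+      // Store the 1-7 leftover fp16 elements in chunks of 4, 2, and 1.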
+      if (c & (4 * sizeof(uint16_t))) {
+        _mm_storel_epi64((__m128i*) o0, vh0x01234567);
+        _mm_storel_epi64((__m128i*) o1, vh1x01234567);
+
+        vh0x01234567 = _mm_unpackhi_epi64(vh0x01234567, vh0x01234567);
+        vh1x01234567 = _mm_unpackhi_epi64(vh1x01234567, vh1x01234567);
+
+        o0 += 4;
+        o1 += 4;
+      }
+      if (c & (2 * sizeof(uint16_t))) {
+        *((uint32_t*) o0) = (uint32_t) _mm_cvtsi128_si32(vh0x01234567);
+        *((uint32_t*) o1) = (uint32_t) _mm_cvtsi128_si32(vh1x01234567);
+
+        vh0x01234567 = _mm_srli_epi64(vh0x01234567, 32);
+        vh1x01234567 = _mm_srli_epi64(vh1x01234567, 32);
+
+        o0 += 2;
+        o1 += 2;
+      }
+      if (c & (1 * sizeof(uint16_t))) {
+        *o0 = (uint16_t) _mm_extract_epi16(vh0x01234567, 0);
+        *o1 = (uint16_t) _mm_extract_epi16(vh1x01234567, 0);
+
+        o0 += 1;
+        o1 += 1;
+      }
+    }
+    i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment);
+    o0 = (uint16_t*) ((uintptr_t) o0 + output_increment);
+    i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment);
+    o1 = (uint16_t*) ((uintptr_t) o1 + output_increment);
+    rows = doz(rows, 2);
+  } while (rows != 0);
+}
+
 void xnn_f16_vadd_minmax_ukernel__f16c_x16(
     size_t n,
     const void* restrict a_ptr,
diff --git a/src/init.c b/src/init.c
index da6d2cc..fc167e8 100644
--- a/src/init.c
+++ b/src/init.c
@@ -2439,6 +2439,13 @@
         .row_tile = 7,
         .channel_tile = 8,
       };
+
+      xnn_params.f16.prelu = (struct prelu_parameters) {
+        .ukernel = (xnn_prelu_ukernel_function) xnn_f16_prelu_ukernel__neonfp16arith_2x16,
+        .row_tile = 2,
+        .channel_tile = 16,
+      };
+
       xnn_params.f16.vadd = (struct vbinary_parameters) {
         .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_f16_vadd_minmax_ukernel__neonfp16arith_x16,
         .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f16_vaddc_minmax_ukernel__neonfp16arith_x16,
@@ -2459,6 +2466,7 @@
         .channel_tile = 8,
         .row_tile = 2,
       };
+
       xnn_params.f16.hswish = (struct vunary_parameters) {
         .ukernel = (xnn_univector_ukernel_function) xnn_f16_vhswish_ukernel__neonfp16arith_x16,
         .init.f16_hswish = xnn_init_f16_hswish_neon_params,
@@ -3656,6 +3664,13 @@
         .row_tile = 7,
         .channel_tile = 8,
       };
+
+      xnn_params.f16.prelu = (struct prelu_parameters) {
+        .ukernel = (xnn_prelu_ukernel_function) xnn_f16_prelu_ukernel__f16c_2x16,
+        .row_tile = 2,
+        .channel_tile = 16,
+      };
+
       xnn_params.f16.vadd = (struct vbinary_parameters) {
         .minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_f16_vadd_minmax_ukernel__f16c_x16,
         .minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f16_vaddc_minmax_ukernel__f16c_x16,
diff --git a/src/operator-strings.c b/src/operator-strings.c
index c3eac26..9bb524b 100644
--- a/src/operator-strings.c
+++ b/src/operator-strings.c
@@ -144,6 +144,8 @@
       return "Multiply (ND, QU8)";
     case xnn_operator_type_negate_nc_f32:
       return "Negate (NC, F32)";
+    case xnn_operator_type_prelu_nc_f16:
+      return "PReLU (NC, F16)";
     case xnn_operator_type_prelu_nc_f32:
       return "PReLU (NC, F32)";
     case xnn_operator_type_resize_bilinear_nhwc_f32:
diff --git a/src/operators/prelu-nc.c b/src/operators/prelu-nc.c
index 3e77aaf..a2e46c0 100644
--- a/src/operators/prelu-nc.c
+++ b/src/operators/prelu-nc.c
@@ -17,20 +17,32 @@
 #include <xnnpack/params.h>
 
 
-enum xnn_status xnn_create_prelu_nc_f32(
+static enum xnn_status create_prelu_nc(
     size_t channels,
     size_t input_stride,
     size_t output_stride,
-    const float* negative_slope,
+    const void* negative_slope,
     uint32_t flags,
+    uint32_t datatype_init_flags,
+    enum xnn_operator_type operator_type,
+    uint32_t log2_weights_element_size,
     xnn_operator_t* prelu_op_out)
 {
   xnn_operator_t prelu_op = NULL;
   enum xnn_status status = xnn_status_uninitialized;
 
   if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
-    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
-      xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32));
+    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
+      xnn_operator_type_to_string(operator_type));
+    return xnn_status_uninitialized;
+  }
+
+  status = xnn_status_unsupported_hardware;
+
+  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
+    xnn_log_error(
+      "failed to create %s operator: operations on data type are not supported",
+      xnn_operator_type_to_string(operator_type));
     goto error;
   }
 
@@ -39,7 +51,7 @@
   if (channels == 0) {
     xnn_log_error(
       "failed to create %s operator with %zu channels: number of channels must be non-zero",
-      xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32), channels);
+      xnn_operator_type_to_string(operator_type), channels);
     goto error;
   }
 
@@ -47,7 +59,7 @@
     xnn_log_error(
       "failed to create %s operator with input element stride of %zu: "
       "stride must be at least as large as the number of channels (%zu)",
-      xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32), input_stride, channels);
+      xnn_operator_type_to_string(operator_type), input_stride, channels);
     goto error;
   }
 
@@ -55,7 +67,7 @@
     xnn_log_error(
       "failed to create %s operator with output element stride of %zu: "
       "stride must be at least as large as the number of channels (%zu)",
-      xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32), output_stride, channels);
+      xnn_operator_type_to_string(operator_type), output_stride, channels);
     goto error;
   }
 
@@ -65,25 +77,25 @@
   if (prelu_op == NULL) {
     xnn_log_error(
       "failed to allocate %zu bytes for %s operator descriptor",
-      sizeof(struct xnn_operator), xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32));
+      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
     goto error;
   }
 
-  const size_t packed_weights_size = channels * sizeof(float) + XNN_EXTRA_BYTES;
+  const size_t packed_weights_size = (channels << log2_weights_element_size) + XNN_EXTRA_BYTES;
   prelu_op->packed_weights = xnn_allocate_simd_memory(packed_weights_size);
   if (prelu_op->packed_weights == NULL) {
     xnn_log_error(
       "failed to allocate %zu bytes for %s operator packed weights",
-      packed_weights_size, xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32));
+      packed_weights_size, xnn_operator_type_to_string(operator_type));
     goto error;
   }
-  memcpy(prelu_op->packed_weights, negative_slope, channels * sizeof(float));
+  memcpy(prelu_op->packed_weights, negative_slope, channels << log2_weights_element_size);
 
   prelu_op->channels = channels;
   prelu_op->input_pixel_stride = input_stride;
   prelu_op->output_pixel_stride = output_stride;
 
-  prelu_op->type = xnn_operator_type_prelu_nc_f32;
+  prelu_op->type = operator_type;
   prelu_op->flags = flags;
 
   prelu_op->state = xnn_run_state_invalid;
@@ -96,16 +108,53 @@
   return status;
 }
 
-enum xnn_status xnn_setup_prelu_nc_f32(
+
+enum xnn_status xnn_create_prelu_nc_f16(
+    size_t channels,
+    size_t input_stride,
+    size_t output_stride,
+    const void* negative_slope,
+    uint32_t flags,
+    xnn_operator_t* prelu_op_out)
+{
+  return create_prelu_nc(
+    channels, input_stride, output_stride,
+    negative_slope, flags,
+    XNN_INIT_FLAG_F16, xnn_operator_type_prelu_nc_f16,
+    1 /* log2(sizeof(uint16_t)) */,
+    prelu_op_out);
+}
+
+enum xnn_status xnn_create_prelu_nc_f32(
+    size_t channels,
+    size_t input_stride,
+    size_t output_stride,
+    const float* negative_slope,
+    uint32_t flags,
+    xnn_operator_t* prelu_op_out)
+{
+  return create_prelu_nc(
+    channels, input_stride, output_stride,
+    negative_slope, flags,
+    XNN_INIT_FLAG_F32, xnn_operator_type_prelu_nc_f32,
+    2 /* log2(sizeof(float)) */,
+    prelu_op_out);
+}
+
+static enum xnn_status setup_prelu_nc(
     xnn_operator_t prelu_op,
+    enum xnn_operator_type expected_operator_type,
     size_t batch_size,
     const float* input,
     float* output,
-    pthreadpool_t threadpool)
+    uint32_t datatype_init_flags,
+    uint32_t log2_element_size,
+    const struct prelu_parameters prelu[restrict XNN_MIN_ELEMENTS(1)],
+    size_t num_threads)
 {
-  if (prelu_op->type != xnn_operator_type_prelu_nc_f32) {
+  if (prelu_op->type != expected_operator_type) {
     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
-      xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32),
+      xnn_operator_type_to_string(expected_operator_type),
       xnn_operator_type_to_string(prelu_op->type));
     return xnn_status_invalid_parameter;
   }
@@ -113,10 +162,16 @@
 
   if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
     xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
-      xnn_operator_type_to_string(xnn_operator_type_prelu_nc_f32));
+      xnn_operator_type_to_string(expected_operator_type));
     return xnn_status_uninitialized;
   }
 
+  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
+    xnn_log_error("failed to setup %s operator: operations on data type are not supported",
+      xnn_operator_type_to_string(expected_operator_type));
+    return xnn_status_unsupported_hardware;
+  }
+
   if (batch_size == 0) {
     prelu_op->state = xnn_run_state_skip;
     return xnn_status_success;
@@ -124,22 +179,21 @@
 
   const size_t channels = prelu_op->channels;
   prelu_op->context.prelu = (struct prelu_context) {
-    .n = channels * sizeof(float),
+    .n = channels << log2_element_size,
     .x = input,
-    .x_stride = prelu_op->input_pixel_stride * sizeof(float),
+    .x_stride = prelu_op->input_pixel_stride << log2_element_size,
     .w = prelu_op->packed_weights,
     .y = output,
-    .y_stride = prelu_op->output_pixel_stride * sizeof(float),
-    .ukernel = xnn_params.f32.prelu.ukernel,
+    .y_stride = prelu_op->output_pixel_stride << log2_element_size,
+    .ukernel = prelu->ukernel,
   };
 
   size_t batch_tile = batch_size;
-  const size_t num_threads = pthreadpool_get_threads_count(threadpool);
   if (num_threads > 1) {
     const size_t target_tiles_per_thread = 5;
     const size_t max_batch_tile = divide_round_up(batch_size, num_threads * target_tiles_per_thread);
     if (max_batch_tile < batch_tile) {
-      const uint32_t row_tile = xnn_params.f32.prelu.row_tile;
+      const uint32_t row_tile = prelu->row_tile;
       batch_tile = min(batch_tile, divide_round_up(batch_tile, max_batch_tile * row_tile) * row_tile);
     }
   }
@@ -151,3 +205,35 @@
 
   return xnn_status_success;
 }
+
+enum xnn_status xnn_setup_prelu_nc_f16(
+    xnn_operator_t prelu_op,
+    size_t batch_size,
+    const void* input,
+    void* output,
+    pthreadpool_t threadpool)
+{
+  return setup_prelu_nc(
+    prelu_op, xnn_operator_type_prelu_nc_f16,
+    batch_size, input, output,
+    XNN_INIT_FLAG_F16,
+    1 /* log2(sizeof(uint16_t)) */,
+    &xnn_params.f16.prelu,
+    pthreadpool_get_threads_count(threadpool));
+}
+
+enum xnn_status xnn_setup_prelu_nc_f32(
+    xnn_operator_t prelu_op,
+    size_t batch_size,
+    const float* input,
+    float* output,
+    pthreadpool_t threadpool)
+{
+  return setup_prelu_nc(
+    prelu_op, xnn_operator_type_prelu_nc_f32,
+    batch_size, input, output,
+    XNN_INIT_FLAG_F32,
+    2 /* log2(sizeof(float)) */,
+    &xnn_params.f32.prelu,
+    pthreadpool_get_threads_count(threadpool));
+}
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index cd00a24..b41a4eb 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -96,6 +96,7 @@
   xnn_operator_type_multiply_nd_qs8,
   xnn_operator_type_multiply_nd_qu8,
   xnn_operator_type_negate_nc_f32,
+  xnn_operator_type_prelu_nc_f16,
   xnn_operator_type_prelu_nc_f32,
   xnn_operator_type_resize_bilinear_nchw_f32,
   xnn_operator_type_resize_bilinear_nhwc_f32,
diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h
index 70d4623..53383c4 100644
--- a/src/xnnpack/params.h
+++ b/src/xnnpack/params.h
@@ -4097,6 +4097,7 @@
     struct gemm_parameters gemm2;
     struct dwconv_parameters dwconv[XNN_MAX_F16_DWCONV_UKERNELS];
     struct vunary_parameters hswish;
+    struct prelu_parameters prelu;
     struct vbinary_parameters vadd;
     struct vbinary_parameters vmul;
     struct vmulcaddc_parameters vmulcaddc;
diff --git a/test/prelu-nc.cc b/test/prelu-nc.cc
index a337726..351e450 100644
--- a/test/prelu-nc.cc
+++ b/test/prelu-nc.cc
@@ -10,6 +10,105 @@
 #include "prelu-operator-tester.h"
 
 
+TEST(PRELU_NC_F16, unit_batch) {
+  for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+    PReLUOperatorTester()
+      .batch_size(1)
+      .channels(channels)
+      .iterations(3)
+      .TestF16();
+  }
+}
+
+TEST(PRELU_NC_F16, small_batch) {
+  for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+    PReLUOperatorTester()
+      .batch_size(xnn_params.f16.prelu.row_tile)
+      .channels(channels)
+      .iterations(3)
+      .TestF16();
+  }
+}
+
+TEST(PRELU_NC_F16, small_batch_with_x_stride) {
+  for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+    PReLUOperatorTester()
+      .batch_size(xnn_params.f16.prelu.row_tile)
+      .channels(channels)
+      .x_stride(123)
+      .iterations(3)
+      .TestF16();
+  }
+}
+
+TEST(PRELU_NC_F16, small_batch_with_y_stride) {
+  for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+    PReLUOperatorTester()
+      .batch_size(xnn_params.f16.prelu.row_tile)
+      .channels(channels)
+      .y_stride(117)
+      .iterations(3)
+      .TestF16();
+  }
+}
+
+TEST(PRELU_NC_F16, small_batch_with_x_stride_and_y_stride) {
+  for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+    PReLUOperatorTester()
+      .batch_size(xnn_params.f16.prelu.row_tile)
+      .channels(channels)
+      .x_stride(123)
+      .y_stride(117)
+      .iterations(3)
+      .TestF16();
+  }
+}
+
+TEST(PRELU_NC_F16, large_batch) {
+  for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+    PReLUOperatorTester()
+      .batch_size(3 * xnn_params.f16.prelu.row_tile + 1)
+      .channels(channels)
+      .iterations(1)
+      .TestF16();
+  }
+}
+
+TEST(PRELU_NC_F16, large_batch_with_x_stride) {
+  for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+    PReLUOperatorTester()
+      .batch_size(3 * xnn_params.f16.prelu.row_tile + 1)
+      .channels(channels)
+      .x_stride(123)
+      .iterations(1)
+      .TestF16();
+  }
+}
+
+TEST(PRELU_NC_F16, large_batch_with_y_stride) {
+  for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+    PReLUOperatorTester()
+      .batch_size(3 * xnn_params.f16.prelu.row_tile + 1)
+      .channels(channels)
+      .y_stride(117)
+      .iterations(1)
+      .TestF16();
+  }
+}
+
+TEST(PRELU_NC_F16, large_batch_with_x_stride_and_y_stride) {
+  for (size_t channels = 1; channels < xnn_params.f16.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f16.prelu.channel_tile - 1)) {
+    PReLUOperatorTester()
+      .batch_size(3 * xnn_params.f16.prelu.row_tile + 1)
+      .channels(channels)
+      .x_stride(123)
+      .y_stride(117)
+      .iterations(1)
+      .TestF16();
+  }
+}
+
+
 TEST(PRELU_NC_F32, unit_batch) {
   for (size_t channels = 1; channels < xnn_params.f32.prelu.channel_tile * 10; channels += std::max<size_t>(1, xnn_params.f32.prelu.channel_tile - 1)) {
     PReLUOperatorTester()
diff --git a/test/prelu-operator-tester.h b/test/prelu-operator-tester.h
index 47090a6..53565ce 100644
--- a/test/prelu-operator-tester.h
+++ b/test/prelu-operator-tester.h
@@ -7,6 +7,8 @@
 
 #include <gtest/gtest.h>
 
+#include <fp16.h>
+
 #include <algorithm>
 #include <cmath>
 #include <cstddef>
@@ -79,6 +81,69 @@
     return this->iterations_;
   }
 
+  void TestF16() const {
+    std::random_device random_device;
+    auto rng = std::mt19937(random_device());
+    auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
+    auto f16irng = std::bind(fp16_ieee_from_fp32_value, f32irng);
+    auto f32wrng = std::bind(std::uniform_real_distribution<float>(0.25f, 0.75f), rng);
+    auto f16wrng = std::bind(fp16_ieee_from_fp32_value, f32wrng);
+
+    std::vector<uint16_t> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
+    std::vector<uint16_t> w(channels());
+    std::vector<uint16_t> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
+    std::vector<float> y_ref(batch_size() * channels());
+    for (size_t iteration = 0; iteration < iterations(); iteration++) {
+      std::generate(x.begin(), x.end(), std::ref(f16irng));
+      std::generate(w.begin(), w.end(), std::ref(f16wrng));
+      std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
+
+      // Compute reference results, without clamping.
+      for (size_t i = 0; i < batch_size(); i++) {
+        for (size_t c = 0; c < channels(); c++) {
+          const float x_value = fp16_ieee_to_fp32_value(x[i * x_stride() + c]);
+          const float w_value = fp16_ieee_to_fp32_value(w[c]);
+          y_ref[i * channels() + c] = std::signbit(x_value) ? x_value * w_value : x_value;
+        }
+      }
+
+      // Create, setup, run, and destroy PReLU operator.
+      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
+      xnn_operator_t prelu_op = nullptr;
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_create_prelu_nc_f16(
+          channels(), x_stride(), y_stride(),
+          w.data(),
+          0, &prelu_op));
+      ASSERT_NE(nullptr, prelu_op);
+
+      // Smart pointer to automatically delete prelu_op.
+      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_setup_prelu_nc_f16(
+          prelu_op,
+          batch_size(),
+          x.data(), y.data(),
+          nullptr /* thread pool */));
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_run_operator(prelu_op, nullptr /* thread pool */));
+
+      // Verify results.
+      for (size_t i = 0; i < batch_size(); i++) {
+        for (size_t c = 0; c < channels(); c++) {
+          ASSERT_NEAR(
+              fp16_ieee_to_fp32_value(y[i * y_stride() + c]),
+              y_ref[i * channels() + c],
+              std::max(1.0e-4f, std::abs(y_ref[i * channels() + c]) * 1.0e-4f))
+            << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
+        }
+      }
+    }
+  }
+
   void TestF32() const {
     std::random_device random_device;
     auto rng = std::mt19937(random_device());
@@ -128,8 +193,11 @@
       // Verify results.
       for (size_t i = 0; i < batch_size(); i++) {
         for (size_t c = 0; c < channels(); c++) {
-          ASSERT_NEAR(y[i * y_stride() + c], y_ref[i * channels() + c], 1.0e-6f * std::abs(y_ref[i * channels() + c]))
-            << "i = " << i << ", c = " << c;
+          ASSERT_NEAR(
+              y[i * y_stride() + c],
+              y_ref[i * channels() + c],
+              std::max(1.0e-6f, std::abs(y_ref[i * channels() + c]) * 1.0e-6f))
+            << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
         }
       }
     }