F32 CLAMP micro-kernels in AVX and AVX512F implementations

PiperOrigin-RevId: 282845725
diff --git a/src/f32-clamp/avx.c b/src/f32-clamp/avx.c
new file mode 100644
index 0000000..9f3414c
--- /dev/null
+++ b/src/f32-clamp/avx.c
@@ -0,0 +1,47 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/clamp.h>
+
+
+static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};  // 8-lane load starting k elements before index 7 gives k all-ones lanes, then zeros.
+
+void xnn_f32_clamp_ukernel__avx(
+    size_t n,         // Byte count to process: non-zero multiple of sizeof(float).
+    const float* x,   // Input elements.
+    float* y,         // Output elements: min(max(x[i], min), max).
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+  // Broadcast the 128-bit clamp bounds at params->sse into both 256-bit halves.
+  const __m256 voutput_max = _mm256_broadcast_ps((const __m128*) params->sse.max);
+  const __m256 voutput_min = _mm256_broadcast_ps((const __m128*) params->sse.min);
+  // Main loop: clamp 8 floats (32 bytes) per iteration.
+  for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
+    const __m256 vx = _mm256_loadu_ps(x);
+    x += 8;
+
+    const __m256 vy = _mm256_min_ps(_mm256_max_ps(vx, voutput_min), voutput_max);
+
+    _mm256_storeu_ps(y, vy);
+    y += 8;
+  }
+  if (n != 0) {
+    assert(n >= 1 * sizeof(float));  // n is still a byte count: 1..7 floats remain.
+    assert(n <= 7 * sizeof(float));
+    __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &mask_table[7] - n));
+    // vmask enables the first n / sizeof(float) lanes; the rest load as zero and are not stored.
+    const __m256 vx = _mm256_maskload_ps(x, vmask);
+
+    const __m256 vy = _mm256_min_ps(_mm256_max_ps(vx, voutput_min), voutput_max);
+
+    _mm256_maskstore_ps(y, vmask, vy);
+  }
+}
diff --git a/src/f32-clamp/avx512f.c b/src/f32-clamp/avx512f.c
new file mode 100644
index 0000000..d1df6e4
--- /dev/null
+++ b/src/f32-clamp/avx512f.c
@@ -0,0 +1,47 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/clamp.h>
+
+
+void xnn_f32_clamp_ukernel__avx512f(
+    size_t n,         // Byte count to process: non-zero multiple of sizeof(float).
+    const float* x,   // Input elements.
+    float* y,         // Output elements: min(max(x[i], min), max).
+    const union xnn_f32_output_params params[restrict static 1])
+{
+  assert(n != 0);
+  assert(n % sizeof(float) == 0);
+  // Replicate the 4-float clamp bounds loaded from params->sse across all 512 bits.
+  const __m512 voutput_max = _mm512_broadcast_f32x4(_mm_load_ps(params->sse.max));
+  const __m512 voutput_min = _mm512_broadcast_f32x4(_mm_load_ps(params->sse.min));
+  // Main loop: clamp 16 floats (64 bytes) per iteration.
+  for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
+    const __m512 vx = _mm512_loadu_ps(x);
+    x += 16;
+
+    const __m512 vy = _mm512_min_ps(_mm512_max_ps(vx, voutput_min), voutput_max);
+
+    _mm512_storeu_ps(y, vy);
+    y += 16;
+  }
+  if (n != 0) {
+    assert(n >= 1 * sizeof(float));  // n is still a byte count: 1..15 floats remain.
+    assert(n <= 15 * sizeof(float));
+    // Prepare mask for valid 32-bit elements (depends on n).
+    n >>= 2 /* log2(sizeof(float)) */;
+    const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
+    // Lanes outside vmask are zeroed on load and left unwritten on store.
+    const __m512 vx = _mm512_maskz_loadu_ps(vmask, x);
+
+    const __m512 vy = _mm512_min_ps(_mm512_max_ps(vx, voutput_min), voutput_max);
+
+    _mm512_mask_storeu_ps(y, vmask, vy);
+  }
+}
diff --git a/src/init.c b/src/init.c
index 4b9e0ce..e7ce5b3 100644
--- a/src/init.c
+++ b/src/init.c
@@ -771,7 +771,13 @@
       .pixel_tile = 1,
       .channel_tile = 8,
     };
-    xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__sse;
+    if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx512f()) {
+      xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__avx512f;
+    } else if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx()) {
+      xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__avx;
+    } else {
+      xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__sse;
+    }
     xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__sse;
     xnn_params.f32.sigmoid = (xnn_univector_ukernel_function) xnn_f32_sigmoid_ukernel__sse2_p5_div_x16;
     xnn_params.f32.prelu = (struct prelu_parameters) {
diff --git a/src/xnnpack/clamp.h b/src/xnnpack/clamp.h
index 0cd59b4..7a9c67b 100644
--- a/src/xnnpack/clamp.h
+++ b/src/xnnpack/clamp.h
@@ -26,9 +26,11 @@
       float* y,                                       \
       const union xnn_f32_output_params* params);
 
-DECLARE_F32_CLAMP_UKERNEL_FUNCTION(xnn_f32_clamp_ukernel__psimd)
 DECLARE_F32_CLAMP_UKERNEL_FUNCTION(xnn_f32_clamp_ukernel__neon)
 DECLARE_F32_CLAMP_UKERNEL_FUNCTION(xnn_f32_clamp_ukernel__sse)
+DECLARE_F32_CLAMP_UKERNEL_FUNCTION(xnn_f32_clamp_ukernel__avx)
+DECLARE_F32_CLAMP_UKERNEL_FUNCTION(xnn_f32_clamp_ukernel__avx512f)
+DECLARE_F32_CLAMP_UKERNEL_FUNCTION(xnn_f32_clamp_ukernel__psimd)
 DECLARE_F32_CLAMP_UKERNEL_FUNCTION(xnn_f32_clamp_ukernel__scalar)