SQRDIFF (Squared Difference) microkernels

PiperOrigin-RevId: 314881351
diff --git a/src/f32-vbinary/vop-avx512f.c.in b/src/f32-vbinary/vop-avx512f.c.in
index f36b303..cafb5ec 100644
--- a/src/f32-vbinary/vop-avx512f.c.in
+++ b/src/f32-vbinary/vop-avx512f.c.in
@@ -6,7 +6,7 @@
 $assert BATCH_TILE % 16 == 0
 $assert BATCH_TILE >= 16
 $ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
-$assert OP in ["ADD", "DIV", "MAX", "MIN", "MUL", "SUB"]
+$assert OP in ["ADD", "DIV", "MAX", "MIN", "MUL", "SUB", "SQRDIFF"]
 $assert ACTIVATION in ["LINEAR", "MINMAX"]
 #include <assert.h>
 
@@ -24,6 +24,7 @@
 $  "MIN": lambda x, y: "_mm512_min_ps(%s, %s)" % (x, y),
 $  "MUL": lambda x, y: "_mm512_mul_ps(%s, %s)" % (x, y),
 $  "SUB": lambda x, y: "_mm512_sub_ps(%s, %s)" % (x, y),
+$  "SQRDIFF": lambda x, y: "_mm512_sub_ps(%s, %s)" % (x, y),
 $}[OP]
 $SUFFIX = {"LINEAR": "", "MINMAX": "_minmax"}[ACTIVATION]
 $PARAMS = {"LINEAR": "xnn_f32_default_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION]
@@ -55,6 +56,10 @@
     $for N in range(0, BATCH_TILE, 16):
       __m512 vy${ABC[N:N+16]} = ${_MM512_OP_PS("va" + ABC[N:N+16], "vb" + ABC[N:N+16])};
 
+    $if OP == "SQRDIFF":
+      $for N in range(0, BATCH_TILE, 16):
+        vy${ABC[N:N+16]} = _mm512_mul_ps(vy${ABC[N:N+16]}, vy${ABC[N:N+16]});
+
     $if ACTIVATION == "MINMAX":
       $for N in range(0, BATCH_TILE, 16):
         vy${ABC[N:N+16]} = _mm512_max_ps(vy${ABC[N:N+16]}, vy_min);
@@ -76,6 +81,8 @@
       b += 16;
 
       __m512 vy = ${_MM512_OP_PS("va", "vb")};
+      $if OP == "SQRDIFF":
+        vy = _mm512_mul_ps(vy, vy);
       $if ACTIVATION == "MINMAX":
         vy = _mm512_max_ps(vy, vy_min);
         vy = _mm512_min_ps(vy, vy_max);
@@ -93,6 +100,8 @@
     const __m512 vb = _mm512_maskz_loadu_ps(vmask, b);
 
     __m512 vy = ${_MM512_OP_PS("va", "vb")};
+    $if OP == "SQRDIFF":
+      vy = _mm512_mul_ps(vy, vy);
     $if ACTIVATION == "MINMAX":
       vy = _mm512_max_ps(vy, vy_min);
       vy = _mm512_min_ps(vy, vy_max);