GEMM/IGEMM microkernels with alternative activations in WAsm SIMD

PiperOrigin-RevId: 316715937
diff --git a/src/f32-igemm/wasmsimd-splat.c.in b/src/f32-igemm/wasmsimd-splat.c.in
index 56891d8..e94bf7b 100644
--- a/src/f32-igemm/wasmsimd-splat.c.in
+++ b/src/f32-igemm/wasmsimd-splat.c.in
@@ -12,7 +12,11 @@
 #include <xnnpack/igemm.h>
 
 
-void xnn_f32_igemm_minmax_ukernel_${MR}x${NR}__wasmsimd_splat_${"x86" if X86 else "arm"}(
+$assert ACTIVATION in ["LINEAR", "RELU", "MINMAX"]
+$ACTIVATION_SUFFIX = {"LINEAR": ""}.get(ACTIVATION, "_" + ACTIVATION.lower())
+$ARCH_SUFFIX = "" if ACTIVATION == "LINEAR" else "_x86" if X86 else "_arm"
+$PARAMS = {"LINEAR": "xnn_f32_default_params", "RELU": "xnn_f32_relu_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION]
+void xnn_f32_igemm${ACTIVATION_SUFFIX}_ukernel_${MR}x${NR}__wasmsimd_splat${ARCH_SUFFIX}(
     size_t mr,
     size_t nc,
     size_t kc,
@@ -24,7 +28,7 @@
     size_t cn_stride,
     size_t a_offset,
     const float* zero,
-    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
+    const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)])
 {
   assert(mr != 0);
   assert(mr <= ${MR});
@@ -54,6 +58,9 @@
         c${M} = c${M-1};
       }
 
+  $if ACTIVATION == "MINMAX" and not X86:
+    const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);
+    const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
   do {
     v128_t vacc0x${ABC[0:4]} = wasm_v128_load(w);
     $for N in range(4, NR, 4):
@@ -113,21 +120,35 @@
       p -= ${MR} * sizeof(void*);
     } while (p != 0);
 
-    const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);
-    $for N in range(0, NR, 4):
-      $for M in range(MR):
-        $if X86:
-          vacc${M}x${ABC[N:N+4]} = wasm_v128_bitselect(vmin, vacc${M}x${ABC[N:N+4]}, wasm_f32x4_lt(vacc${M}x${ABC[N:N+4]}, vmin));
-        $else:
-          vacc${M}x${ABC[N:N+4]} = wasm_f32x4_max(vacc${M}x${ABC[N:N+4]}, vmin);
+    $if ACTIVATION == "MINMAX":
+      $if X86:
+        const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = wasm_v128_bitselect(vmin, vacc${M}x${ABC[N:N+4]}, wasm_f32x4_lt(vacc${M}x${ABC[N:N+4]}, vmin));
 
-    const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
-    $for N in range(0, NR, 4):
-      $for M in range(MR):
-        $if X86:
-          vacc${M}x${ABC[N:N+4]} = wasm_v128_bitselect(vacc${M}x${ABC[N:N+4]}, vmax, wasm_f32x4_le(vacc${M}x${ABC[N:N+4]}, vmax));
-        $else:
-          vacc${M}x${ABC[N:N+4]} = wasm_f32x4_min(vacc${M}x${ABC[N:N+4]}, vmax);
+        const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = wasm_v128_bitselect(vacc${M}x${ABC[N:N+4]}, vmax, wasm_f32x4_le(vacc${M}x${ABC[N:N+4]}, vmax));
+      $else:
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = wasm_f32x4_max(vacc${M}x${ABC[N:N+4]}, vmin);
+
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = wasm_f32x4_min(vacc${M}x${ABC[N:N+4]}, vmax);
+    $elif ACTIVATION == "RELU":
+      const v128_t vzero = wasm_f32x4_splat(0.0f);
+      $if X86:
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = wasm_v128_andnot(vacc${M}x${ABC[N:N+4]}, wasm_f32x4_le(vacc${M}x${ABC[N:N+4]}, vzero));
+      $else:
+        $for N in range(0, NR, 4):
+          $for M in range(MR):
+            vacc${M}x${ABC[N:N+4]} = wasm_f32x4_max(vacc${M}x${ABC[N:N+4]}, vzero);
 
     if XNN_LIKELY(nc >= ${NR}) {
       $for M in reversed(range(MR)):