GEMM/IGEMM microkernels with alternative (linear and ReLU) activations in WAsm SIMD
PiperOrigin-RevId: 316715937
diff --git a/src/f32-igemm/wasmsimd-splat.c.in b/src/f32-igemm/wasmsimd-splat.c.in
index 56891d8..e94bf7b 100644
--- a/src/f32-igemm/wasmsimd-splat.c.in
+++ b/src/f32-igemm/wasmsimd-splat.c.in
@@ -12,7 +12,11 @@
#include <xnnpack/igemm.h>
-void xnn_f32_igemm_minmax_ukernel_${MR}x${NR}__wasmsimd_splat_${"x86" if X86 else "arm"}(
+$assert ACTIVATION in ["LINEAR", "RELU", "MINMAX"]
+$ACTIVATION_SUFFIX = {"LINEAR": ""}.get(ACTIVATION, "_" + ACTIVATION.lower())
+$ARCH_SUFFIX = "" if ACTIVATION == "LINEAR" else "_x86" if X86 else "_arm"
+$PARAMS = {"LINEAR": "xnn_f32_default_params", "RELU": "xnn_f32_relu_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION]
+void xnn_f32_igemm${ACTIVATION_SUFFIX}_ukernel_${MR}x${NR}__wasmsimd_splat${ARCH_SUFFIX}(
size_t mr,
size_t nc,
size_t kc,
@@ -24,7 +28,7 @@
size_t cn_stride,
size_t a_offset,
const float* zero,
- const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
+ const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)])
{
assert(mr != 0);
assert(mr <= ${MR});
@@ -54,6 +58,9 @@
c${M} = c${M-1};
}
+ $if ACTIVATION == "MINMAX" and not X86:
+ const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);
+ const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
do {
v128_t vacc0x${ABC[0:4]} = wasm_v128_load(w);
$for N in range(4, NR, 4):
@@ -113,21 +120,35 @@
p -= ${MR} * sizeof(void*);
} while (p != 0);
- const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);
- $for N in range(0, NR, 4):
- $for M in range(MR):
- $if X86:
- vacc${M}x${ABC[N:N+4]} = wasm_v128_bitselect(vmin, vacc${M}x${ABC[N:N+4]}, wasm_f32x4_lt(vacc${M}x${ABC[N:N+4]}, vmin));
- $else:
- vacc${M}x${ABC[N:N+4]} = wasm_f32x4_max(vacc${M}x${ABC[N:N+4]}, vmin);
+ $if ACTIVATION == "MINMAX":
+ $if X86:
+ const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${ABC[N:N+4]} = wasm_v128_bitselect(vmin, vacc${M}x${ABC[N:N+4]}, wasm_f32x4_lt(vacc${M}x${ABC[N:N+4]}, vmin));
- const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
- $for N in range(0, NR, 4):
- $for M in range(MR):
- $if X86:
- vacc${M}x${ABC[N:N+4]} = wasm_v128_bitselect(vacc${M}x${ABC[N:N+4]}, vmax, wasm_f32x4_le(vacc${M}x${ABC[N:N+4]}, vmax));
- $else:
- vacc${M}x${ABC[N:N+4]} = wasm_f32x4_min(vacc${M}x${ABC[N:N+4]}, vmax);
+ const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${ABC[N:N+4]} = wasm_v128_bitselect(vacc${M}x${ABC[N:N+4]}, vmax, wasm_f32x4_le(vacc${M}x${ABC[N:N+4]}, vmax));
+ $else:
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${ABC[N:N+4]} = wasm_f32x4_max(vacc${M}x${ABC[N:N+4]}, vmin);
+
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${ABC[N:N+4]} = wasm_f32x4_min(vacc${M}x${ABC[N:N+4]}, vmax);
+ $elif ACTIVATION == "RELU":
+ const v128_t vzero = wasm_f32x4_splat(0.0f);
+ $if X86:
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${ABC[N:N+4]} = wasm_v128_andnot(vacc${M}x${ABC[N:N+4]}, wasm_f32x4_le(vacc${M}x${ABC[N:N+4]}, vzero));
+ $else:
+ $for N in range(0, NR, 4):
+ $for M in range(MR):
+ vacc${M}x${ABC[N:N+4]} = wasm_f32x4_max(vacc${M}x${ABC[N:N+4]}, vzero);
if XNN_LIKELY(nc >= ${NR}) {
$for M in reversed(range(MR)):