Use GEMM/IGEMM micro-kernels with Swizzle on WAsm SIMD
PiperOrigin-RevId: 273783514
diff --git a/src/init.c b/src/init.c
index 1296eb6..da4ff1e 100644
--- a/src/init.c
+++ b/src/init.c
@@ -653,6 +653,13 @@
};
#elif XNN_ARCH_PNACL || XNN_ARCH_WASMSIMD
+ // Detect whether the CPU running this WAsm/PNaCl code is x86/x86-64. Unlike most other
+ // architectures, on x86/x86-64 a floating-point instruction that has no NaN arguments but
+ // produces a NaN output returns a NaN with the sign bit set. Subtracting an infinity from
+ // itself must produce NaN per IEEE 754, so the sign of (-inf) - (-inf) identifies x86/x86-64.
+ static volatile uint32_t minus_inf = UINT32_C(0xFF800000);  // binary32 -inf bits; volatile so the subtraction is not constant-folded
+ const bool is_wasm_x86 = (int32_t) xnn_stub_wasm_f32_sub(minus_inf, minus_inf) < 0;  // < 0 <=> sign bit set (assumes the stub returns raw f32 bits)
+
/**************************** Q8 micro-kernels ****************************/
xnn_params.q8.gemm = (struct gemm_parameters) {
.gemm = (xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_2x2__scalar,
@@ -698,14 +705,27 @@
};
/**************************** F32 micro-kernels ****************************/
- xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__psimd_splat,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__psimd_splat,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__psimd_loadsplat,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__psimd_loadsplat,
- .mr = 4,
- .nr = 8,
- };
+ if (is_wasm_x86) {  // x86/x86-64 host: 4x8 tiles with 4-element swizzle
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8s4__psimd,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8s4__psimd,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8s4__psimd,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8s4__psimd,
+ .mr = 4,
+ .nr = 8,
+ .log2_sr = 2,  // sr = 1 << 2 = 4, matching the "s4" kernels
+ };
+ } else {  // other hosts: 6x8 tiles with 4-element swizzle
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8s4__psimd,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8s4__psimd,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8s4__psimd,  // GEMM (not IGEMM) kernel, as in the x86 branch
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8s4__psimd,
+ .mr = 6,
+ .nr = 8,
+ .log2_sr = 2,  // sr = 1 << 2 = 4, matching the "s4" kernels
+ };
+ }
xnn_params.f32.gemm2 = (struct gemm_parameters) {
.gemm = NULL,
.igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2c4__psimd,