FMA3 microkernels with 4-wide shuffle
- Use the new microkernels by default on 1st-generation Zen
PiperOrigin-RevId: 284641473
diff --git a/src/init.c b/src/init.c
index 009749f..8c4424c 100644
--- a/src/init.c
+++ b/src/init.c
@@ -701,14 +701,29 @@
.nr = 16,
};
} else if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_fma3()) {
- xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_5x16__fma3_broadcast,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_5x16__fma3_broadcast,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x16__fma3_broadcast,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x16__fma3_broadcast,
- .mr = 5,
- .nr = 16,
- };
+ switch (cpuinfo_get_core(0)->uarch) {
+ case cpuinfo_uarch_zen:
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x16s4__fma3_broadcast,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x16s4__fma3_broadcast,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x16s4__fma3_broadcast,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x16s4__fma3_broadcast,
+ .mr = 4,
+ .nr = 16,
+ .log2_sr = 2,
+ };
+ break;
+ default:
+ xnn_params.f32.gemm = (struct gemm_parameters) {
+ .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_5x16__fma3_broadcast,
+ .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_5x16__fma3_broadcast,
+ .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x16__fma3_broadcast,
+ .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x16__fma3_broadcast,
+ .mr = 5,
+ .nr = 16,
+ };
+ break;
+ }
} else if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx()) {
xnn_params.f32.gemm = (struct gemm_parameters) {
.gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_5x16__avx_broadcast,