Enable QS8 4x8 dot product GEMM AArch32 microkernel for Cortex A55
PiperOrigin-RevId: 423875189
diff --git a/src/init.c b/src/init.c
index 96fced3..bc1ea37 100644
--- a/src/init.c
+++ b/src/init.c
@@ -213,14 +213,28 @@
#if XNN_ENABLE_ASSEMBLY
if (!XNN_PLATFORM_IOS && cpuinfo_has_arm_neon_dot()) {
- xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_4x8c4__aarch32_neondot_ld64);
- xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_4x8c4__aarch32_neondot_ld64);
- xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8c4__neondot);
- xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8c4__neondot);
- xnn_params.qs8.gemm.init.qs8 = xnn_init_qs8_conv_minmax_rndnu_neon_params;
- xnn_params.qs8.gemm.mr = 4;
- xnn_params.qs8.gemm.nr = 8;
- xnn_params.qs8.gemm.log2_kr = 2;
+ switch (cpuinfo_get_uarch(0)->uarch) {
+ case cpuinfo_uarch_cortex_a55:
+ xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_4x8c4__aarch32_neondot_cortex_a55);
+ xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_4x8c4__aarch32_neondot_ld64);
+ xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8c4__neondot);
+ xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8c4__neondot);
+ xnn_params.qs8.gemm.init.qs8 = xnn_init_qs8_conv_minmax_rndnu_neon_params;
+ xnn_params.qs8.gemm.mr = 4;
+ xnn_params.qs8.gemm.nr = 8;
+ xnn_params.qs8.gemm.log2_kr = 2;
+ break;
+ default:
+ xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_4x8c4__aarch32_neondot_ld64);
+ xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_4x8c4__aarch32_neondot_ld64);
+ xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8c4__neondot);
+ xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8c4__neondot);
+ xnn_params.qs8.gemm.init.qs8 = xnn_init_qs8_conv_minmax_rndnu_neon_params;
+ xnn_params.qs8.gemm.mr = 4;
+ xnn_params.qs8.gemm.nr = 8;
+ xnn_params.qs8.gemm.log2_kr = 2;
+ break;
+ }
} else {
switch (cpuinfo_get_uarch(0)->uarch) {
case cpuinfo_uarch_cortex_a53: