Dynamically choose the micro-kernel depending on the microarchitecture of the active core.
The Bazel option xnn_enable_hmp=false disables these heterogeneous-multi-processing optimizations.
PiperOrigin-RevId: 303209003
diff --git a/src/convolution-nhwc.c b/src/convolution-nhwc.c
index d40a1db..ad60b3b 100644
--- a/src/convolution-nhwc.c
+++ b/src/convolution-nhwc.c
@@ -307,7 +307,7 @@
.mr = xnn_params.q8.gemm.mr,
.nr = nr,
.kr = kr,
- .default_function = xnn_params.q8.gemm.gemm,
+ .general_case = xnn_params.q8.gemm.gemm,
};
break;
case xnn_ukernel_type_igemm:
@@ -328,7 +328,7 @@
.mr = xnn_params.q8.gemm.mr,
.nr = nr,
.kr = kr,
- .default_function = xnn_params.q8.gemm.igemm,
+ .general_case = xnn_params.q8.gemm.igemm,
};
break;
default:
@@ -636,8 +636,8 @@
.mr = xnn_params.f32.gemm.mr,
.nr = nr,
.kr = kr,
- .default_function = xnn_params.f32.gemm.gemm,
- .mr1_function = xnn_params.f32.gemm.gemm1,
+ .general_case = xnn_params.f32.gemm.gemm,
+ .mr1_case = xnn_params.f32.gemm.gemm1,
};
break;
case xnn_ukernel_type_igemm:
@@ -656,8 +656,8 @@
.mr = xnn_params.f32.gemm.mr,
.nr = nr,
.kr = kr,
- .default_function = xnn_params.f32.gemm.igemm,
- .mr1_function = xnn_params.f32.gemm.igemm1,
+ .general_case = xnn_params.f32.gemm.igemm,
+ .mr1_case = xnn_params.f32.gemm.igemm1,
};
break;
default:
@@ -801,10 +801,10 @@
uint32_t mr = convolution_op->ukernel.gemm.mr;
const uint32_t nr = convolution_op->ukernel.gemm.nr;
- xnn_gemm_ukernel_function gemm_ukernel = convolution_op->ukernel.gemm.default_function;
- if (batch_output_size == 1 && convolution_op->ukernel.gemm.mr1_function != NULL) {
+ struct xnn_hmp_gemm_ukernel gemm_ukernel = convolution_op->ukernel.gemm.general_case;
+ if (batch_output_size == 1 && convolution_op->ukernel.gemm.mr1_case.function[XNN_UARCH_DEFAULT] != NULL) {
mr = 1;
- gemm_ukernel = convolution_op->ukernel.gemm.mr1_function;
+ gemm_ukernel = convolution_op->ukernel.gemm.mr1_case;
}
convolution_op->context.gemm = (struct gemm_context) {
@@ -833,15 +833,35 @@
}
}
if (groups == 1) {
- convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
- convolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
+ #if XNN_MAX_UARCH_TYPES > 1
+ if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
+ convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d_with_uarch;
+ convolution_op->compute.task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm;
+ } else {
+ convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
+ convolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
+ }
+ #else
+ convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
+ convolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
+ #endif
convolution_op->compute.range[0] = batch_output_size;
convolution_op->compute.range[1] = group_output_channels;
convolution_op->compute.tile[0] = mr;
convolution_op->compute.tile[1] = nc;
} else {
- convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
- convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm;
+ #if XNN_MAX_UARCH_TYPES > 1
+ if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d_with_uarch;
+ convolution_op->compute.task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t) xnn_compute_hmp_grouped_gemm;
+ } else {
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
+ convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm;
+ }
+ #else
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
+ convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm;
+ #endif
convolution_op->compute.range[0] = groups;
convolution_op->compute.range[1] = batch_output_size;
convolution_op->compute.range[2] = group_output_channels;
@@ -864,10 +884,10 @@
uint32_t mr = convolution_op->ukernel.igemm.mr;
const uint32_t nr = convolution_op->ukernel.igemm.nr;
- xnn_igemm_ukernel_function igemm_ukernel = convolution_op->ukernel.igemm.default_function;
- if (output_size == 1 && convolution_op->ukernel.igemm.mr1_function != NULL) {
+ struct xnn_hmp_igemm_ukernel igemm_ukernel = convolution_op->ukernel.igemm.general_case;
+ if (output_size == 1 && convolution_op->ukernel.igemm.mr1_case.function[XNN_UARCH_DEFAULT] != NULL) {
mr = 1;
- igemm_ukernel = convolution_op->ukernel.igemm.mr1_function;
+ igemm_ukernel = convolution_op->ukernel.igemm.mr1_case;
}
const size_t tiled_output_size = round_up(output_size, mr);
@@ -924,16 +944,36 @@
}
}
if (groups == 1) {
- convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
- convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_igemm;
+ #if XNN_MAX_UARCH_TYPES > 1
+ if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) {
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d_with_uarch;
+ convolution_op->compute.task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t) xnn_compute_hmp_igemm;
+ } else {
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
+ convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_igemm;
+ }
+ #else
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
+ convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_igemm;
+ #endif
convolution_op->compute.range[0] = batch_size;
convolution_op->compute.range[1] = output_size;
convolution_op->compute.range[2] = group_output_channels;
convolution_op->compute.tile[0] = mr;
convolution_op->compute.tile[1] = nc;
} else {
- convolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
- convolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_igemm;
+ #if XNN_MAX_UARCH_TYPES > 1
+ if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) {
+ convolution_op->compute.type = xnn_parallelization_type_4d_tile_2d_with_uarch;
+ convolution_op->compute.task_4d_tile_2d_with_id = (pthreadpool_task_4d_tile_2d_with_id_t) xnn_compute_hmp_grouped_igemm;
+ } else {
+ convolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
+ convolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_igemm;
+ }
+ #else
+ convolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
+ convolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_igemm;
+ #endif
convolution_op->compute.range[0] = batch_size;
convolution_op->compute.range[1] = groups;
convolution_op->compute.range[2] = output_size;
diff --git a/src/deconvolution-nhwc.c b/src/deconvolution-nhwc.c
index f58832f..c23c0cb 100644
--- a/src/deconvolution-nhwc.c
+++ b/src/deconvolution-nhwc.c
@@ -190,8 +190,8 @@
const uint32_t mr = xnn_params.q8.gemm.mr;
const uint32_t nr = xnn_params.q8.gemm.nr;
const uint32_t kr = UINT32_C(1) << xnn_params.q8.gemm.log2_kr;
- const xnn_igemm_ukernel_function igemm_ukernel_function = xnn_params.q8.gemm.igemm;
- const xnn_gemm_ukernel_function gemm_ukernel_function = xnn_params.q8.gemm.gemm;
+ const struct xnn_hmp_igemm_ukernel igemm_ukernel = xnn_params.q8.gemm.igemm;
+ const struct xnn_hmp_gemm_ukernel gemm_ukernel = xnn_params.q8.gemm.gemm;
const uint32_t n_stride = round_up(group_output_channels, nr);
const uint32_t k_stride = round_up_po2(group_input_channels, kr);
@@ -287,8 +287,8 @@
deconvolution_op->type = xnn_operator_type_deconvolution_nhwc_q8;
deconvolution_op->ukernel.type = ukernel_type;
deconvolution_op->ukernel.igemm = (struct xnn_ukernel_igemm) {
- .default_function = igemm_ukernel_function,
- .gemm_function = gemm_ukernel_function,
+ .general_case = igemm_ukernel,
+ .gemm_case = gemm_ukernel,
.mr = mr,
.nr = nr,
.kr = kr,
@@ -431,17 +431,17 @@
uint32_t nr = xnn_params.f32.gemm.nr;
uint32_t kr = UINT32_C(1) << xnn_params.f32.gemm.log2_kr;
uint32_t sr = UINT32_C(1) << xnn_params.f32.gemm.log2_sr;
- xnn_igemm_ukernel_function igemm_ukernel_function = xnn_params.f32.gemm.igemm;
- xnn_gemm_ukernel_function gemm_ukernel_function = xnn_params.f32.gemm.gemm;
+ struct xnn_hmp_igemm_ukernel igemm_ukernel = xnn_params.f32.gemm.igemm;
+ struct xnn_hmp_gemm_ukernel gemm_ukernel = xnn_params.f32.gemm.gemm;
if (nr > group_output_channels) {
// Default micro-kernel is suboptimal. Try to find a better micro-kernel.
- if (xnn_params.f32.gemm2.igemm != NULL) {
+ if (xnn_params.f32.gemm2.igemm.function[XNN_UARCH_DEFAULT] != NULL) {
mr = xnn_params.f32.gemm2.mr;
nr = xnn_params.f32.gemm2.nr;
kr = UINT32_C(1) << xnn_params.f32.gemm2.log2_kr;
sr = UINT32_C(1) << xnn_params.f32.gemm2.log2_sr;
- igemm_ukernel_function = xnn_params.f32.gemm2.igemm;
- gemm_ukernel_function = xnn_params.f32.gemm2.gemm;
+ igemm_ukernel = xnn_params.f32.gemm2.igemm;
+ gemm_ukernel = xnn_params.f32.gemm2.gemm;
}
}
@@ -531,8 +531,8 @@
deconvolution_op->type = xnn_operator_type_deconvolution_nhwc_f32;
deconvolution_op->ukernel.type = ukernel_type;
deconvolution_op->ukernel.igemm = (struct xnn_ukernel_igemm) {
- .default_function = igemm_ukernel_function,
- .gemm_function = gemm_ukernel_function,
+ .general_case = igemm_ukernel,
+ .gemm_case = gemm_ukernel,
.mr = mr,
.nr = nr,
.kr = kr,
@@ -615,10 +615,10 @@
.ba_stride = input_height * input_width * deconvolution_op->input_pixel_stride << log2_input_element_size,
.bc_stride = output_size * deconvolution_op->output_pixel_stride << log2_output_element_size,
.log2_csize = log2_output_element_size,
- .ukernel = deconvolution_op->ukernel.igemm.default_function,
+ .ukernel = deconvolution_op->ukernel.igemm.general_case,
};
- if (output_size == 1 && deconvolution_op->ukernel.igemm.mr1_function != NULL) {
- deconvolution_op->context.igemm.ukernel = deconvolution_op->ukernel.igemm.mr1_function;
+ if (output_size == 1 && deconvolution_op->ukernel.igemm.mr1_case.function[XNN_UARCH_DEFAULT] != NULL) {
+ deconvolution_op->context.igemm.ukernel = deconvolution_op->ukernel.igemm.mr1_case;
}
memcpy(&deconvolution_op->context.igemm.params, params, sizeof(deconvolution_op->context.igemm.params));
@@ -755,7 +755,7 @@
.ba_stride = input_height * input_width * input_pixel_stride,
.bc_stride = output_size * output_pixel_stride,
.log2_csize = log2_output_element_size,
- .ukernel = deconvolution_op->ukernel.igemm.gemm_function,
+ .ukernel = deconvolution_op->ukernel.igemm.gemm_case,
};
memcpy(&deconvolution_op->context.subgemm.params, params, sizeof(deconvolution_op->context.subgemm.params));
} else {
@@ -773,7 +773,7 @@
.ba_stride = input_height * input_width * input_pixel_stride,
.bc_stride = output_size * output_pixel_stride,
.log2_csize = log2_output_element_size,
- .ukernel = deconvolution_op->ukernel.igemm.default_function,
+ .ukernel = deconvolution_op->ukernel.igemm.general_case,
};
memcpy(&deconvolution_op->context.subconv.params, params, sizeof(deconvolution_op->context.subconv.params));
}
@@ -901,7 +901,7 @@
const bool use_gemm = no_padding && no_adjustment &&
deconvolution_op->kernel_height == deconvolution_op->stride_height &&
deconvolution_op->kernel_width == deconvolution_op->stride_width &&
- deconvolution_op->ukernel.igemm.gemm_function != NULL;
+ deconvolution_op->ukernel.igemm.gemm_case.function[XNN_UARCH_DEFAULT] != NULL;
return setup_subconv2d_path(
deconvolution_op,
batch_size,
diff --git a/src/fully-connected-nc.c b/src/fully-connected-nc.c
index 447ad17..159e448 100644
--- a/src/fully-connected-nc.c
+++ b/src/fully-connected-nc.c
@@ -175,7 +175,7 @@
fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
- .default_function = xnn_params.q8.gemm.gemm,
+ .general_case = xnn_params.q8.gemm.gemm,
.mr = xnn_params.q8.gemm.mr,
.nr = nr,
.kr = kr,
@@ -310,8 +310,8 @@
fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
- .default_function = xnn_params.f32.gemm.gemm,
- .mr1_function = xnn_params.f32.gemm.gemm1,
+ .general_case = xnn_params.f32.gemm.gemm,
+ .mr1_case = xnn_params.f32.gemm.gemm1,
.mr = xnn_params.f32.gemm.mr,
.nr = nr,
.kr = kr,
@@ -366,9 +366,9 @@
uint32_t mr = fully_connected_op->ukernel.gemm.mr;
const uint32_t nr = fully_connected_op->ukernel.gemm.nr;
- xnn_gemm_ukernel_function gemm_ukernel = fully_connected_op->ukernel.gemm.default_function;
- if (batch_size == 1 && fully_connected_op->ukernel.gemm.mr1_function != NULL) {
- gemm_ukernel = fully_connected_op->ukernel.gemm.mr1_function;
+ struct xnn_hmp_gemm_ukernel gemm_ukernel = fully_connected_op->ukernel.gemm.general_case;
+ if (batch_size == 1 && fully_connected_op->ukernel.gemm.mr1_case.function[XNN_UARCH_DEFAULT] != NULL) {
+ gemm_ukernel = fully_connected_op->ukernel.gemm.mr1_case;
mr = 1;
}
diff --git a/src/init.c b/src/init.c
index 18d7522..ac509a7 100644
--- a/src/init.c
+++ b/src/init.c
@@ -74,8 +74,8 @@
/**************************** Q8 micro-kernels ****************************/
#ifndef XNN_NO_Q8_OPERATORS
xnn_params.q8.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_4x8__neon,
- .igemm = (xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_4x8__neon,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_4x8__neon),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_4x8__neon),
.mr = 4,
.nr = 8,
};
@@ -133,14 +133,14 @@
/**************************** F32 micro-kernels ****************************/
#ifndef XNN_NO_F32_OPERATORS
#if XNN_ENABLE_ASSEMBLY
- switch (cpuinfo_get_core(0)->uarch) {
+ switch (cpuinfo_get_uarch(0)->uarch) {
case cpuinfo_uarch_cortex_a5:
case cpuinfo_uarch_cortex_a7:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_ld64,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_ld64),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_ld64),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64),
.mr = 4,
.nr = 8,
};
@@ -149,10 +149,10 @@
case cpuinfo_uarch_cortex_a53:
case cpuinfo_uarch_cortex_a55r0:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a53,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a53,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a53),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a53),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64),
.mr = 4,
.nr = 8,
};
@@ -160,10 +160,10 @@
case cpuinfo_uarch_cortex_a55:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a55,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a55,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a55),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a55),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64),
.mr = 4,
.nr = 8,
};
@@ -173,10 +173,10 @@
case cpuinfo_uarch_cortex_a72:
case cpuinfo_uarch_cortex_a73:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_pld_cortex_a75,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_pld_cortex_a75),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_pld_cortex_a75),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64),
.mr = 4,
.nr = 8,
};
@@ -185,28 +185,65 @@
case cpuinfo_uarch_krait:
default:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a75,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a75),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64),
.mr = 4,
.nr = 8,
};
break;
}
+ #if XNN_MAX_UARCH_TYPES > 1
+ {
+ /* Choose micro-kernels for little cores according to micro-kernel specification for the big core */
+ const uint32_t mr = xnn_params.f32.gemm.mr;
+ const uint32_t nr = xnn_params.f32.gemm.nr;
+ const uint32_t log2_sr = xnn_params.f32.gemm.log2_sr;
+ for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) {
+ const struct cpuinfo_uarch_info* uarch_info = cpuinfo_get_uarch(i);
+ if (uarch_info == NULL) {
+ /* No more microarchitectures in the system */
+ break;
+ }
+
+ switch (uarch_info->uarch) {
+ case cpuinfo_uarch_cortex_a53:
+ case cpuinfo_uarch_cortex_a55r0:
+ if (mr == 4 && nr == 8 && log2_sr == 0) {
+ xnn_params.f32.gemm.gemm.function[i] = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a53;
+ xnn_params.f32.gemm.igemm.function[i] = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a53;
+ xnn_params.f32.gemm.gemm1.function[i] = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64;
+ xnn_params.f32.gemm.igemm1.function[i] = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64;
+ }
+ break;
+ case cpuinfo_uarch_cortex_a55:
+ if (mr == 4 && nr == 8 && log2_sr == 0) {
+ xnn_params.f32.gemm.gemm.function[i] = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a55;
+ xnn_params.f32.gemm.igemm.function[i] = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a55;
+ xnn_params.f32.gemm.gemm1.function[i] = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64;
+ xnn_params.f32.gemm.igemm1.function[i] = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64;
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ #endif // XNN_MAX_UARCH_TYPES > 1
#else // XNN_ENABLE_ASSEMBLY
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__neon_lane_ld128,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__neon_lane_ld128,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__neon_lane_ld128),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__neon_lane_ld128),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neon_lane_ld64),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neon_lane_ld64),
.mr = 4,
.nr = 8,
};
#endif // XNN_ENABLE_ASSEMBLY
xnn_params.f32.gemm2 = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x2__neon_lane_ld64,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2__neon_lane_ld64,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x2__neon_lane_ld64),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2__neon_lane_ld64),
.mr = 4,
.nr = 2,
};
@@ -338,8 +375,8 @@
/**************************** Q8 micro-kernels ****************************/
#ifndef XNN_NO_Q8_OPERATORS
xnn_params.q8.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_8x8__neon,
- .igemm = (xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_8x8__neon,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_8x8__neon),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_8x8__neon),
.mr = 8,
.nr = 8,
};
@@ -390,19 +427,19 @@
#if XNN_PLATFORM_IOS
#if XNN_ENABLE_ASSEMBLY
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ios,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_ios,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a57,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a57,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ios),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_ios),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a57),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a57),
.mr = 6,
.nr = 8,
};
#else // !XNN_ENABLE_ASSEMBLY
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__neonfma_lane_ld64,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__neonfma_lane_ld64,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neonfma_lane_ld64,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neonfma_lane_ld64,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__neonfma_lane_ld64),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__neonfma_lane_ld64),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neonfma_lane_ld64),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neonfma_lane_ld64),
.mr = 6,
.nr = 8,
};
@@ -412,20 +449,20 @@
switch (cpuinfo_get_core(0)->uarch) {
case cpuinfo_uarch_cortex_a57:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a57,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a57,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a57,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a57,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a57),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a57),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a57),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a57),
.mr = 6,
.nr = 8,
};
break;
case cpuinfo_uarch_cortex_a72:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a75,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a75,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a75),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a75),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75),
.mr = 4,
.nr = 8,
};
@@ -435,10 +472,10 @@
case cpuinfo_uarch_exynos_m3:
case cpuinfo_uarch_exynos_m4:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a75,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a75,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a75),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a75),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75),
.mr = 6,
.nr = 8,
};
@@ -446,10 +483,10 @@
case cpuinfo_uarch_exynos_m1:
case cpuinfo_uarch_exynos_m2:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8s4__neonfma,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8s4__neonfma,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8s4__neonfma,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8s4__neonfma,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8s4__neonfma),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8s4__neonfma),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8s4__neonfma),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8s4__neonfma),
.mr = 6,
.nr = 8,
.log2_sr = 2,
@@ -459,30 +496,30 @@
case cpuinfo_uarch_cortex_a53:
case cpuinfo_uarch_cortex_a55r0:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a53,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a53),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53),
.mr = 6,
.nr = 8,
};
break;
case cpuinfo_uarch_cortex_a55:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a55,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a55,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a53,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a55),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a55),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a53),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53),
.mr = 6,
.nr = 8,
};
break;
case cpuinfo_uarch_cortex_a73:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a73,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a73,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a73),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a73),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75),
.mr = 6,
.nr = 8,
};
@@ -492,29 +529,76 @@
case cpuinfo_uarch_exynos_m5:
case cpuinfo_uarch_kryo:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a57,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a57,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a57,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a57,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a57),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a57),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a57),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a57),
.mr = 4,
.nr = 8,
};
break;
}
+ #if XNN_MAX_UARCH_TYPES > 1
+ {
+ /* Choose micro-kernels for little cores according to micro-kernel specification for the big core */
+ const uint32_t mr = xnn_params.f32.gemm.mr;
+ const uint32_t nr = xnn_params.f32.gemm.nr;
+ const uint32_t log2_sr = xnn_params.f32.gemm.log2_sr;
+ for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) {
+ const struct cpuinfo_uarch_info* uarch_info = cpuinfo_get_uarch(i);
+ if (uarch_info == NULL) {
+ /* No more microarchitectures in the system */
+ break;
+ }
+
+ switch (uarch_info->uarch) {
+ case cpuinfo_uarch_cortex_a53:
+ case cpuinfo_uarch_cortex_a55r0:
+ if (mr == 6 && nr == 8 && log2_sr == 0) {
+ xnn_params.f32.gemm.gemm.function[i] = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53;
+ xnn_params.f32.gemm.igemm.function[i] = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a53;
+ xnn_params.f32.gemm.gemm1.function[i] = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a53;
+ xnn_params.f32.gemm.igemm1.function[i] = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53;
+ } else if (mr == 4 && nr == 8 && log2_sr == 0) {
+ xnn_params.f32.gemm.gemm.function[i] = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a53;
+ xnn_params.f32.gemm.igemm.function[i] = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a53;
+ xnn_params.f32.gemm.gemm1.function[i] = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a53;
+ xnn_params.f32.gemm.igemm1.function[i] = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53;
+ }
+ break;
+ case cpuinfo_uarch_cortex_a55:
+ if (mr == 6 && nr == 8 && log2_sr == 0) {
+ xnn_params.f32.gemm.gemm.function[i] = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a55;
+ xnn_params.f32.gemm.igemm.function[i] = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__aarch64_neonfma_cortex_a55;
+ xnn_params.f32.gemm.gemm1.function[i] = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a53;
+ xnn_params.f32.gemm.igemm1.function[i] = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53;
+ } else if (mr == 4 && nr == 8 && log2_sr == 0) {
+ xnn_params.f32.gemm.gemm.function[i] = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a55;
+ xnn_params.f32.gemm.igemm.function[i] = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a55;
+ xnn_params.f32.gemm.gemm1.function[i] = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a53;
+ xnn_params.f32.gemm.igemm1.function[i] = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a53;
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ #endif // XNN_MAX_UARCH_TYPES > 1
#else // !XNN_ENABLE_ASSEMBLY
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__neonfma_lane_ld64,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__neonfma_lane_ld64,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neonfma_lane_ld64,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neonfma_lane_ld64,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8__neonfma_lane_ld64),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8__neonfma_lane_ld64),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__neonfma_lane_ld64),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__neonfma_lane_ld64),
.mr = 6,
.nr = 8,
};
#endif // XNN_ENABLE_ASSEMBLY
#endif // XNN_PLATFORM_IOS
xnn_params.f32.gemm2 = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x2__neonfma_lane_ld64,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2__neonfma_lane_ld64,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x2__neonfma_lane_ld64),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2__neonfma_lane_ld64),
.mr = 4,
.nr = 2,
};
@@ -732,8 +816,8 @@
/**************************** Q8 micro-kernels ****************************/
#ifndef XNN_NO_Q8_OPERATORS
xnn_params.q8.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_4x4c2__sse2,
- .igemm = (xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_4x4c2__sse2,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_4x4c2__sse2),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_4x4c2__sse2),
.mr = 4,
.nr = 4,
.log2_kr = 1,
@@ -784,10 +868,10 @@
#ifndef XNN_NO_F32_OPERATORS
if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx512f()) {
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_7x16__avx512f_broadcast,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_7x16__avx512f_broadcast,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x16__avx512f_broadcast,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x16__avx512f_broadcast,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_7x16__avx512f_broadcast),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_7x16__avx512f_broadcast),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x16__avx512f_broadcast),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x16__avx512f_broadcast),
.mr = 7,
.nr = 16,
};
@@ -796,10 +880,10 @@
case cpuinfo_uarch_zen:
case cpuinfo_uarch_dhyana:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x16s4__fma3_broadcast,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x16s4__fma3_broadcast,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x16s4__fma3_broadcast,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x16s4__fma3_broadcast,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x16s4__fma3_broadcast),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x16s4__fma3_broadcast),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x16s4__fma3_broadcast),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x16s4__fma3_broadcast),
.mr = 4,
.nr = 16,
.log2_sr = 2,
@@ -807,10 +891,10 @@
break;
default:
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_5x16__fma3_broadcast,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_5x16__fma3_broadcast,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x16__fma3_broadcast,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x16__fma3_broadcast,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_5x16__fma3_broadcast),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_5x16__fma3_broadcast),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x16__fma3_broadcast),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x16__fma3_broadcast),
.mr = 5,
.nr = 16,
};
@@ -818,26 +902,26 @@
}
} else if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx()) {
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_5x16__avx_broadcast,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_5x16__avx_broadcast,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x16__avx_broadcast,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x16__avx_broadcast,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_5x16__avx_broadcast),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_5x16__avx_broadcast),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x16__avx_broadcast),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x16__avx_broadcast),
.mr = 5,
.nr = 16,
};
} else {
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__sse_load1,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__sse_load1,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__sse_load1,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__sse_load1,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__sse_load1),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__sse_load1),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__sse_load1),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__sse_load1),
.mr = 4,
.nr = 8,
};
}
xnn_params.f32.gemm2 = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x2c4__sse,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2c4__sse,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x2c4__sse),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2c4__sse),
.mr = 4,
.nr = 2,
.log2_kr = 2,
@@ -1157,8 +1241,8 @@
/**************************** Q8 micro-kernels ****************************/
#ifndef XNN_NO_Q8_OPERATORS
xnn_params.q8.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_2x2__scalar,
- .igemm = (xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_2x2__scalar,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_2x2__scalar),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_2x2__scalar),
.mr = 2,
.nr = 2,
};
@@ -1208,27 +1292,27 @@
#ifndef XNN_NO_F32_OPERATORS
if (is_wasm_x86) {
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__psimd_splat,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__psimd_splat,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__psimd_splat,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__psimd_splat,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x8__psimd_splat),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x8__psimd_splat),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8__psimd_splat),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8__psimd_splat),
.mr = 4,
.nr = 8,
};
} else {
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8s4__psimd,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8s4__psimd,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8s4__psimd,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8s4__psimd,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_6x8s4__psimd),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_6x8s4__psimd),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x8s4__psimd),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x8s4__psimd),
.mr = 6,
.nr = 8,
.log2_sr = 2,
};
}
xnn_params.f32.gemm2 = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x2c4__psimd,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2c4__psimd,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x2c4__psimd),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2c4__psimd),
.mr = 4,
.nr = 2,
.log2_kr = 2,
@@ -1367,8 +1451,8 @@
/**************************** Q8 micro-kernels ****************************/
#ifndef XNN_NO_Q8_OPERATORS
xnn_params.q8.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_2x2__scalar,
- .igemm = (xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_2x2__scalar,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_q8_gemm_ukernel_2x2__scalar),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_q8_igemm_ukernel_2x2__scalar),
.mr = 2,
.nr = 2,
};
@@ -1418,26 +1502,26 @@
#ifndef XNN_NO_F32_OPERATORS
if (is_wasm_x86) {
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_2x4__scalar,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_2x4__scalar,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x4__wasm,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x4__wasm,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_2x4__scalar),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_2x4__scalar),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x4__wasm),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x4__wasm),
.mr = 2,
.nr = 4,
};
} else {
xnn_params.f32.gemm = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x4__wasm,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x4__wasm,
- .gemm1 = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x4__wasm,
- .igemm1 = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x4__wasm,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x4__wasm),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x4__wasm),
+ .gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_1x4__wasm),
+ .igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_1x4__wasm),
.mr = 4,
.nr = 4,
};
}
xnn_params.f32.gemm2 = (struct gemm_parameters) {
- .gemm = (xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x2__wasm,
- .igemm = (xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2__wasm,
+ .gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_f32_gemm_ukernel_4x2__wasm),
+ .igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_f32_igemm_ukernel_4x2__wasm),
.mr = 4,
.nr = 2,
};
diff --git a/src/operator-run.c b/src/operator-run.c
index 9305819..3f87633 100644
--- a/src/operator-run.c
+++ b/src/operator-run.c
@@ -32,7 +32,7 @@
const size_t a_stride = context->a_stride;
const size_t cm_stride = context->cm_stride;
- context->ukernel(
+ context->ukernel.function[XNN_UARCH_DEFAULT](
mr_block_size,
nr_block_size,
k_scaled,
@@ -55,7 +55,7 @@
const size_t a_stride = context->a_stride;
const size_t cm_stride = context->cm_stride;
- context->ukernel(
+ context->ukernel.function[XNN_UARCH_DEFAULT](
mr_block_size,
nr_block_size,
context->k_scaled,
@@ -97,7 +97,7 @@
const size_t ks = context->ks;
const size_t cm_stride = context->cm_stride;
- context->ukernel(
+ context->ukernel.function[XNN_UARCH_DEFAULT](
mr_block_size,
nr_block_size,
context->kc,
@@ -123,7 +123,7 @@
const size_t ks = context->ks;
const size_t cm_stride = context->cm_stride;
- context->ukernel(
+ context->ukernel.function[XNN_UARCH_DEFAULT](
mr_block_size,
nr_block_size,
context->kc,
@@ -163,7 +163,7 @@
const size_t ax_stride = context->ax_stride;
const size_t cx_stride = context->cx_stride;
- context->ukernel(
+ context->ukernel.function[XNN_UARCH_DEFAULT](
slice_x_size,
nc_block_size,
context->kc,
@@ -200,7 +200,7 @@
const size_t ax_stride = context->ax_stride;
const size_t cx_stride = context->cx_stride;
- context->ukernel(
+ context->ukernel.function[XNN_UARCH_DEFAULT](
slice_x_size,
nc_block_size,
context->kc,
@@ -237,7 +237,7 @@
const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
const size_t cx_stride = context->cx_stride;
- context->ukernel(
+ context->ukernel.function[XNN_UARCH_DEFAULT](
slice_x_size,
nc_block_size,
context->kc,
@@ -275,7 +275,7 @@
const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
const size_t cx_stride = context->cx_stride;
- context->ukernel(
+ context->ukernel.function[XNN_UARCH_DEFAULT](
slice_x_size,
nc_block_size,
context->kc,
@@ -773,6 +773,113 @@
&context->params);
}
+#if XNN_MAX_UARCH_TYPES > 1
+ void xnn_compute_hmp_grouped_gemm(
+ const struct gemm_context context[restrict static 1],
+ uint32_t uarch_index,
+ size_t group_index,
+ size_t mr_block_start,
+ size_t nr_block_start,
+ size_t mr_block_size,
+ size_t nr_block_size)
+ {
+ const size_t k_scaled = context->k_scaled;
+ const size_t a_stride = context->a_stride;
+ const size_t cm_stride = context->cm_stride;
+
+ context->ukernel.function[uarch_index](
+ mr_block_size,
+ nr_block_size,
+ k_scaled,
+ (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
+ a_stride,
+ (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
+ (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
+ cm_stride,
+ context->cn_stride,
+ &context->params);
+ }
+
+ void xnn_compute_hmp_gemm(
+ const struct gemm_context context[restrict static 1],
+ uint32_t uarch_index,
+ size_t mr_block_start,
+ size_t nr_block_start,
+ size_t mr_block_size,
+ size_t nr_block_size)
+ {
+ const size_t a_stride = context->a_stride;
+ const size_t cm_stride = context->cm_stride;
+
+ context->ukernel.function[uarch_index](
+ mr_block_size,
+ nr_block_size,
+ context->k_scaled,
+ (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
+ a_stride,
+ (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
+ (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
+ cm_stride,
+ context->cn_stride,
+ &context->params);
+ }
+
+ void xnn_compute_hmp_grouped_igemm(
+ const struct igemm_context context[restrict static 1],
+ uint32_t uarch_index,
+ size_t batch_index,
+ size_t group_index,
+ size_t mr_block_start,
+ size_t nr_block_start,
+ size_t mr_block_size,
+ size_t nr_block_size)
+ {
+ const size_t ks = context->ks;
+ const size_t cm_stride = context->cm_stride;
+
+ context->ukernel.function[uarch_index](
+ mr_block_size,
+ nr_block_size,
+ context->kc,
+ context->ks_scaled,
+ (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
+ (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
+ (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
+ cm_stride,
+ context->cn_stride,
+ context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
+ context->zero,
+ &context->params);
+ }
+
+ void xnn_compute_hmp_igemm(
+ const struct igemm_context context[restrict static 1],
+ uint32_t uarch_index,
+ size_t batch_index,
+ size_t mr_block_start,
+ size_t nr_block_start,
+ size_t mr_block_size,
+ size_t nr_block_size)
+ {
+ const size_t ks = context->ks;
+ const size_t cm_stride = context->cm_stride;
+
+ context->ukernel.function[uarch_index](
+ mr_block_size,
+ nr_block_size,
+ context->kc,
+ context->ks_scaled,
+ (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
+ (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
+ (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
+ cm_stride,
+ context->cn_stride,
+ context->a_offset + batch_index * context->ba_stride,
+ context->zero,
+ &context->params);
+ }
+#endif // XNN_MAX_UARCH_TYPES > 1
+
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
if (!xnn_params.initialized) {
@@ -909,6 +1016,53 @@
op->compute.tile[0], op->compute.tile[1],
PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
break;
+#if XNN_MAX_UARCH_TYPES > 1
+ case xnn_parallelization_type_2d_tile_2d_with_uarch:
+ assert(op->compute.range[0] != 0);
+ assert(op->compute.range[1] != 0);
+ assert(op->compute.tile[0] != 0);
+ assert(op->compute.tile[1] != 0);
+ pthreadpool_parallelize_2d_tile_2d_with_uarch(
+ threadpool,
+ op->compute.task_2d_tile_2d_with_id,
+ &op->context,
+ 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
+ op->compute.range[0], op->compute.range[1],
+ op->compute.tile[0], op->compute.tile[1],
+ PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+ break;
+ case xnn_parallelization_type_3d_tile_2d_with_uarch:
+ assert(op->compute.range[0] != 0);
+ assert(op->compute.range[1] != 0);
+ assert(op->compute.range[2] != 0);
+ assert(op->compute.tile[0] != 0);
+ assert(op->compute.tile[1] != 0);
+ pthreadpool_parallelize_3d_tile_2d_with_uarch(
+ threadpool,
+ op->compute.task_3d_tile_2d_with_id,
+ &op->context,
+ 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
+ op->compute.range[0], op->compute.range[1], op->compute.range[2],
+ op->compute.tile[0], op->compute.tile[1],
+ PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+ break;
+ case xnn_parallelization_type_4d_tile_2d_with_uarch:
+ assert(op->compute.range[0] != 0);
+ assert(op->compute.range[1] != 0);
+ assert(op->compute.range[2] != 0);
+ assert(op->compute.range[3] != 0);
+ assert(op->compute.tile[0] != 0);
+ assert(op->compute.tile[1] != 0);
+ pthreadpool_parallelize_4d_tile_2d_with_uarch(
+ threadpool,
+ op->compute.task_4d_tile_2d_with_id,
+ &op->context,
+ 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
+ op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
+ op->compute.tile[0], op->compute.tile[1],
+ PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
+ break;
+#endif // XNN_MAX_UARCH_TYPES > 1
default:
XNN_UNREACHABLE;
}
diff --git a/src/xnnpack/common.h b/src/xnnpack/common.h
index f145bc8..f108db3 100644
--- a/src/xnnpack/common.h
+++ b/src/xnnpack/common.h
@@ -91,6 +91,16 @@
#define XNN_PLATFORM_WEB 0
#endif
+#ifndef XNN_MAX_UARCH_TYPES
+ #if (XNN_ARCH_ARM || XNN_ARCH_ARM64) && !XNN_PLATFORM_IOS
+ #define XNN_MAX_UARCH_TYPES 3
+ #else
+ #define XNN_MAX_UARCH_TYPES 1
+ #endif
+#endif
+
+#define XNN_UARCH_DEFAULT 0
+
#if defined(__GNUC__)
#if defined(__clang__) || (__GNUC__ > 4 || __GNUC__ == 4 && __GNUC_MINOR__ >= 5)
#define XNN_UNREACHABLE do { __builtin_unreachable(); } while (0)
diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h
index dd16210..e6b4223 100644
--- a/src/xnnpack/compute.h
+++ b/src/xnnpack/compute.h
@@ -26,6 +26,11 @@
xnn_parallelization_type_4d_tile_2d,
xnn_parallelization_type_5d_tile_2d,
xnn_parallelization_type_6d_tile_2d,
+#if XNN_MAX_UARCH_TYPES > 1
+ xnn_parallelization_type_2d_tile_2d_with_uarch,
+ xnn_parallelization_type_3d_tile_2d_with_uarch,
+ xnn_parallelization_type_4d_tile_2d_with_uarch,
+#endif // XNN_MAX_UARCH_TYPES > 1
};
struct compute_parameters {
@@ -40,6 +45,11 @@
pthreadpool_task_4d_tile_2d_t task_4d_tile_2d;
pthreadpool_task_5d_tile_2d_t task_5d_tile_2d;
pthreadpool_task_6d_tile_2d_t task_6d_tile_2d;
+#if XNN_MAX_UARCH_TYPES > 1
+ pthreadpool_task_2d_tile_2d_with_id_t task_2d_tile_2d_with_id;
+ pthreadpool_task_3d_tile_2d_with_id_t task_3d_tile_2d_with_id;
+ pthreadpool_task_4d_tile_2d_with_id_t task_4d_tile_2d_with_id;
+#endif // XNN_MAX_UARCH_TYPES > 1
};
size_t range[6];
size_t tile[2];
@@ -57,7 +67,7 @@
size_t cn_stride;
size_t cg_stride;
uint32_t log2_csize;
- xnn_gemm_ukernel_function ukernel;
+ struct xnn_hmp_gemm_ukernel ukernel;
union {
union xnn_q8_gemm_params q8;
union xnn_f32_output_params f32;
@@ -79,6 +89,25 @@
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size);
+
+ #if XNN_MAX_UARCH_TYPES > 1
+ XNN_PRIVATE void xnn_compute_hmp_grouped_gemm(
+ const struct gemm_context context[restrict static 1],
+ uint32_t uarch_index,
+ size_t group_index,
+ size_t mr_block_start,
+ size_t nr_block_start,
+ size_t mr_block_size,
+ size_t nr_block_size);
+
+ XNN_PRIVATE void xnn_compute_hmp_gemm(
+ const struct gemm_context context[restrict static 1],
+ uint32_t uarch_index,
+ size_t mr_block_start,
+ size_t nr_block_start,
+ size_t mr_block_size,
+ size_t nr_block_size);
+ #endif // XNN_MAX_UARCH_TYPES > 1
#endif
// Context for Sparse Matrix-Dense Matrix Multiplication.
@@ -136,7 +165,7 @@
size_t ba_stride;
size_t bc_stride;
uint32_t log2_csize;
- xnn_igemm_ukernel_function ukernel;
+ struct xnn_hmp_igemm_ukernel ukernel;
union {
union xnn_q8_gemm_params q8;
union xnn_f32_output_params f32;
@@ -160,6 +189,27 @@
size_t nr_block_start,
size_t mr_block_size,
size_t nr_block_size);
+
+ #if XNN_MAX_UARCH_TYPES > 1
+ XNN_PRIVATE void xnn_compute_hmp_grouped_igemm(
+ const struct igemm_context context[restrict static 1],
+ uint32_t uarch_index,
+ size_t batch_index,
+ size_t group_index,
+ size_t mr_block_start,
+ size_t nr_block_start,
+ size_t mr_block_size,
+ size_t nr_block_size);
+
+ XNN_PRIVATE void xnn_compute_hmp_igemm(
+ const struct igemm_context context[restrict static 1],
+ uint32_t uarch_index,
+ size_t batch_index,
+ size_t mr_block_start,
+ size_t nr_block_start,
+ size_t mr_block_size,
+ size_t nr_block_size);
+ #endif // XNN_MAX_UARCH_TYPES > 1
#endif
struct subgemm_context {
@@ -177,7 +227,7 @@
size_t ba_stride;
size_t bc_stride;
uint32_t log2_csize;
- xnn_gemm_ukernel_function ukernel;
+ struct xnn_hmp_gemm_ukernel ukernel;
union {
union xnn_q8_gemm_params q8;
union xnn_f32_output_params f32;
@@ -221,7 +271,7 @@
size_t ba_stride;
size_t bc_stride;
uint32_t log2_csize;
- xnn_igemm_ukernel_function ukernel;
+ struct xnn_hmp_igemm_ukernel ukernel;
union {
union xnn_q8_gemm_params q8;
union xnn_f32_output_params f32;
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index c416cbc..6ecf25c 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -113,17 +113,17 @@
};
struct xnn_ukernel_gemm {
- xnn_gemm_ukernel_function default_function;
- xnn_gemm_ukernel_function mr1_function;
+ struct xnn_hmp_gemm_ukernel general_case;
+ struct xnn_hmp_gemm_ukernel mr1_case;
uint8_t mr;
uint8_t nr;
uint8_t kr;
};
struct xnn_ukernel_igemm {
- xnn_igemm_ukernel_function default_function;
- xnn_igemm_ukernel_function mr1_function;
- xnn_gemm_ukernel_function gemm_function;
+ struct xnn_hmp_igemm_ukernel general_case;
+ struct xnn_hmp_igemm_ukernel mr1_case;
+ struct xnn_hmp_gemm_ukernel gemm_case;
uint8_t mr;
uint8_t nr;
uint8_t kr;
diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h
index 458cb9d..dfd69cb 100644
--- a/src/xnnpack/params.h
+++ b/src/xnnpack/params.h
@@ -1200,13 +1200,62 @@
float scale_mantissa,
float scale_exponent);
+struct xnn_hmp_gemm_ukernel {
+ xnn_gemm_ukernel_function function[XNN_MAX_UARCH_TYPES];
+};
+
+static inline struct xnn_hmp_gemm_ukernel xnn_init_hmp_gemm_ukernel(xnn_gemm_ukernel_function function) {
+ struct xnn_hmp_gemm_ukernel ukernel = { function };
+ for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) {
+ ukernel.function[i] = function;
+ }
+ return ukernel;
+}
+
+static inline bool xnn_is_hmp_gemm_ukernel(struct xnn_hmp_gemm_ukernel ukernel) {
+#if XNN_MAX_UARCH_TYPES == 1
+ return false;
+#else
+ uintptr_t default_function = (uintptr_t) ukernel.function[XNN_UARCH_DEFAULT];
+ uintptr_t difference = 0;
+ for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) {
+ difference |= (default_function ^ (uintptr_t) ukernel.function[i]);
+ }
+ return difference != 0;
+#endif
+}
+
+struct xnn_hmp_igemm_ukernel {
+ xnn_igemm_ukernel_function function[XNN_MAX_UARCH_TYPES];
+};
+
+static inline struct xnn_hmp_igemm_ukernel xnn_init_hmp_igemm_ukernel(xnn_igemm_ukernel_function function) {
+ struct xnn_hmp_igemm_ukernel ukernel = { function };
+ for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) {
+ ukernel.function[i] = function;
+ }
+ return ukernel;
+}
+
+static inline bool xnn_is_hmp_igemm_ukernel(struct xnn_hmp_igemm_ukernel ukernel) {
+#if XNN_MAX_UARCH_TYPES == 1
+ return false;
+#else
+ uintptr_t default_function = (uintptr_t) ukernel.function[XNN_UARCH_DEFAULT];
+ uintptr_t difference = 0;
+ for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) {
+ difference |= (default_function ^ (uintptr_t) ukernel.function[i]);
+ }
+ return difference != 0;
+#endif
+}
struct gemm_parameters {
- xnn_gemm_ukernel_function gemm;
- xnn_igemm_ukernel_function igemm;
+ struct xnn_hmp_gemm_ukernel gemm;
+ struct xnn_hmp_igemm_ukernel igemm;
// Optional GEMM and IGEMM micro-kernels with MR=1 and the same NR and KR parameters.
- xnn_gemm_ukernel_function gemm1;
- xnn_igemm_ukernel_function igemm1;
+ struct xnn_hmp_gemm_ukernel gemm1;
+ struct xnn_hmp_igemm_ukernel igemm1;
uint8_t mr;
uint8_t nr;
uint8_t log2_kr;