Dynamically choose micro-kernel depending on active core

The Bazel option xnn_enable_hmp=false disables these optimizations.

PiperOrigin-RevId: 303209003
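
The diff below references struct xnn_hmp_gemm_ukernel, XNN_UARCH_DEFAULT, XNN_MAX_UARCH_TYPES, and xnn_is_hmp_gemm_ukernel, whose declarations are not part of this excerpt. The following is a minimal sketch of how such a per-microarchitecture kernel table and the HMP check could look; the placeholder function signature and exact layout are assumptions, and the real declarations elsewhere in XNNPACK may differ.

/* Sketch (assumed, not taken from this patch) of the per-microarchitecture
   kernel table used by the dispatch logic below. */
#include <stdbool.h>
#include <stddef.h>

#ifndef XNN_MAX_UARCH_TYPES
  #define XNN_MAX_UARCH_TYPES 1  /* >1 only on heterogeneous (HMP) builds */
#endif
#define XNN_UARCH_DEFAULT 0

typedef void (*xnn_gemm_ukernel_function)(void);  /* placeholder signature */

/* One GEMM micro-kernel pointer per supported micro-architecture. */
struct xnn_hmp_gemm_ukernel {
  xnn_gemm_ukernel_function function[XNN_MAX_UARCH_TYPES];
};

/* A kernel is "HMP" if any non-default micro-architecture entry differs from
   the default one; only then is the *_with_uarch pthreadpool task needed. */
static inline bool xnn_is_hmp_gemm_ukernel(struct xnn_hmp_gemm_ukernel ukernel) {
  bool hmp = false;
  for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) {
    hmp |= (ukernel.function[i] != ukernel.function[XNN_UARCH_DEFAULT]);
  }
  return hmp;
}

With XNN_MAX_UARCH_TYPES == 1 the table degenerates to a single function pointer, which is why the #else branches in the diff keep the original non-uarch pthreadpool tasks.
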
diff --git a/src/convolution-nhwc.c b/src/convolution-nhwc.c
index d40a1db..ad60b3b 100644
--- a/src/convolution-nhwc.c
+++ b/src/convolution-nhwc.c
@@ -307,7 +307,7 @@
.mr = xnn_params.q8.gemm.mr,
.nr = nr,
.kr = kr,
- .default_function = xnn_params.q8.gemm.gemm,
+ .general_case = xnn_params.q8.gemm.gemm,
};
break;
case xnn_ukernel_type_igemm:
@@ -328,7 +328,7 @@
.mr = xnn_params.q8.gemm.mr,
.nr = nr,
.kr = kr,
- .default_function = xnn_params.q8.gemm.igemm,
+ .general_case = xnn_params.q8.gemm.igemm,
};
break;
default:
@@ -636,8 +636,8 @@
.mr = xnn_params.f32.gemm.mr,
.nr = nr,
.kr = kr,
- .default_function = xnn_params.f32.gemm.gemm,
- .mr1_function = xnn_params.f32.gemm.gemm1,
+ .general_case = xnn_params.f32.gemm.gemm,
+ .mr1_case = xnn_params.f32.gemm.gemm1,
};
break;
case xnn_ukernel_type_igemm:
@@ -656,8 +656,8 @@
.mr = xnn_params.f32.gemm.mr,
.nr = nr,
.kr = kr,
- .default_function = xnn_params.f32.gemm.igemm,
- .mr1_function = xnn_params.f32.gemm.igemm1,
+ .general_case = xnn_params.f32.gemm.igemm,
+ .mr1_case = xnn_params.f32.gemm.igemm1,
};
break;
default:
@@ -801,10 +801,10 @@
uint32_t mr = convolution_op->ukernel.gemm.mr;
const uint32_t nr = convolution_op->ukernel.gemm.nr;
- xnn_gemm_ukernel_function gemm_ukernel = convolution_op->ukernel.gemm.default_function;
- if (batch_output_size == 1 && convolution_op->ukernel.gemm.mr1_function != NULL) {
+ struct xnn_hmp_gemm_ukernel gemm_ukernel = convolution_op->ukernel.gemm.general_case;
+ if (batch_output_size == 1 && convolution_op->ukernel.gemm.mr1_case.function[XNN_UARCH_DEFAULT] != NULL) {
mr = 1;
- gemm_ukernel = convolution_op->ukernel.gemm.mr1_function;
+ gemm_ukernel = convolution_op->ukernel.gemm.mr1_case;
}
convolution_op->context.gemm = (struct gemm_context) {
@@ -833,15 +833,35 @@
}
}
if (groups == 1) {
- convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
- convolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
+ #if XNN_MAX_UARCH_TYPES > 1
+ if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
+ convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d_with_uarch;
+ convolution_op->compute.task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm;
+ } else {
+ convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
+ convolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
+ }
+ #else
+ convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
+ convolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
+ #endif
convolution_op->compute.range[0] = batch_output_size;
convolution_op->compute.range[1] = group_output_channels;
convolution_op->compute.tile[0] = mr;
convolution_op->compute.tile[1] = nc;
} else {
- convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
- convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm;
+ #if XNN_MAX_UARCH_TYPES > 1
+ if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d_with_uarch;
+ convolution_op->compute.task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t) xnn_compute_hmp_grouped_gemm;
+ } else {
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
+ convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm;
+ }
+ #else
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
+ convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm;
+ #endif
convolution_op->compute.range[0] = groups;
convolution_op->compute.range[1] = batch_output_size;
convolution_op->compute.range[2] = group_output_channels;
@@ -864,10 +884,10 @@
uint32_t mr = convolution_op->ukernel.igemm.mr;
const uint32_t nr = convolution_op->ukernel.igemm.nr;
- xnn_igemm_ukernel_function igemm_ukernel = convolution_op->ukernel.igemm.default_function;
- if (output_size == 1 && convolution_op->ukernel.igemm.mr1_function != NULL) {
+ struct xnn_hmp_igemm_ukernel igemm_ukernel = convolution_op->ukernel.igemm.general_case;
+ if (output_size == 1 && convolution_op->ukernel.igemm.mr1_case.function[XNN_UARCH_DEFAULT] != NULL) {
mr = 1;
- igemm_ukernel = convolution_op->ukernel.igemm.mr1_function;
+ igemm_ukernel = convolution_op->ukernel.igemm.mr1_case;
}
const size_t tiled_output_size = round_up(output_size, mr);
@@ -924,16 +944,36 @@
}
}
if (groups == 1) {
- convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
- convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_igemm;
+ #if XNN_MAX_UARCH_TYPES > 1
+ if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) {
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d_with_uarch;
+ convolution_op->compute.task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t) xnn_compute_hmp_igemm;
+ } else {
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
+ convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_igemm;
+ }
+ #else
+ convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
+ convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_igemm;
+ #endif
convolution_op->compute.range[0] = batch_size;
convolution_op->compute.range[1] = output_size;
convolution_op->compute.range[2] = group_output_channels;
convolution_op->compute.tile[0] = mr;
convolution_op->compute.tile[1] = nc;
} else {
- convolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
- convolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_igemm;
+ #if XNN_MAX_UARCH_TYPES > 1
+ if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) {
+ convolution_op->compute.type = xnn_parallelization_type_4d_tile_2d_with_uarch;
+ convolution_op->compute.task_4d_tile_2d_with_id = (pthreadpool_task_4d_tile_2d_with_id_t) xnn_compute_hmp_grouped_igemm;
+ } else {
+ convolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
+ convolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_igemm;
+ }
+ #else
+ convolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
+ convolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_igemm;
+ #endif
convolution_op->compute.range[0] = batch_size;
convolution_op->compute.range[1] = groups;
convolution_op->compute.range[2] = output_size;