Auto-switch to LINEAR GEMM/IGEMM/DWCONV micro-kernels
- Initialize pointers to LINEAR-activation micro-kernels on WebAssembly
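- Automatically detect operators without output clamping (output_min == -INFINITY and output_max == +INFINITY) and use LINEAR-activation micro-kernels for them when available, falling back to MINMAX otherwise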
- Update unit tests to test both MINMAX- and LINEAR-activation micro-kernels
PiperOrigin-RevId: 305792844
diff --git a/src/convolution-nhwc.c b/src/convolution-nhwc.c
index f8c73ce..daf5137 100644
--- a/src/convolution-nhwc.c
+++ b/src/convolution-nhwc.c
@@ -552,6 +552,7 @@
} else {
ukernel_type = xnn_ukernel_type_igemm;
}
+ const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
size_t zero_size = 0;
switch (ukernel_type) {
@@ -600,8 +601,12 @@
kernel, bias, convolution_op->packed_weights);
}
+ const union dwconv_fused_ukernels* ukernels = &dwconv_parameters->minmax;
+ if (linear_activation && dwconv_parameters->linear.unipass != NULL) {
+ ukernels = &dwconv_parameters->linear;
+ }
convolution_op->ukernel.dwconv = (struct xnn_ukernel_dwconv) {
- .unipass_function = dwconv_parameters->minmax.unipass,
+ .unipass_function = ukernels->unipass,
.primary_tile = dwconv_parameters->primary_tile,
.incremental_tile = dwconv_parameters->incremental_tile,
};
@@ -626,6 +631,10 @@
}
memset(convolution_op->packed_weights, 0, packed_group_weights_size * groups);
+ const struct gemm_fused_ukernels* ukernels = &xnn_params.f32.gemm.minmax;
+ if (linear_activation && xnn_params.f32.gemm.linear.gemm.function[XNN_UARCH_DEFAULT] != NULL) {
+ ukernels = &xnn_params.f32.gemm.linear;
+ }
switch (ukernel_type) {
case xnn_ukernel_type_gemm:
xnn_pack_f32_gemm_goi_w(
@@ -636,8 +645,8 @@
.mr = xnn_params.f32.gemm.mr,
.nr = nr,
.kr = kr,
- .general_case = xnn_params.f32.gemm.minmax.gemm,
- .mr1_case = xnn_params.f32.gemm.minmax.gemm1,
+ .general_case = ukernels->gemm,
+ .mr1_case = ukernels->gemm1,
};
break;
case xnn_ukernel_type_igemm:
@@ -656,8 +665,8 @@
.mr = xnn_params.f32.gemm.mr,
.nr = nr,
.kr = kr,
- .general_case = xnn_params.f32.gemm.minmax.igemm,
- .mr1_case = xnn_params.f32.gemm.minmax.igemm1,
+ .general_case = ukernels->igemm,
+ .mr1_case = ukernels->igemm1,
};
break;
default: