Refactor HardSwish micro-kernels
- Code-generate HSWISH micro-kernels
- Support unrolling in HSWISH implementation
- Add HSWISH micro-kernels for AVX, FMA3, and AVX512F
- Code-generate HSWISH unit tests
- Switch all platforms to the new unroll-suffixed micro-kernel variants (e.g. _x8), with runtime AVX512F/FMA3/AVX dispatch on x86
PiperOrigin-RevId: 284705773
diff --git a/src/init.c b/src/init.c
index 8c4424c..044472d 100644
--- a/src/init.c
+++ b/src/init.c
@@ -212,7 +212,7 @@
.channel_tile = 8,
};
xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__neon;
- xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__neon;
+ xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__neon_x8;
xnn_params.f32.prelu = (struct prelu_parameters) {
.ukernel = (xnn_prelu_ukernel_function) xnn_f32_prelu_ukernel__neon_2x8,
.row_tile = 2,
@@ -515,7 +515,7 @@
.channel_tile = 8,
};
xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__neon;
- xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__neonfma;
+ xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__neonfma_x8;
xnn_params.f32.sigmoid = (xnn_univector_ukernel_function) xnn_f32_sigmoid_ukernel__neon_frac_p9_p10_nr1recps_x16;
xnn_params.f32.prelu = (struct prelu_parameters) {
.ukernel = (xnn_prelu_ukernel_function) xnn_f32_prelu_ukernel__neon_2x8,
@@ -862,7 +862,15 @@
} else {
xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__sse;
}
- xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__sse;
+ if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx512f()) {
+ xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__avx512f_x32;
+ } else if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_fma3()) {
+ xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__fma3_x16;
+ } else if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx()) {
+ xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__avx_x16;
+ } else {
+ xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__sse_x8;
+ }
xnn_params.f32.sigmoid = (xnn_univector_ukernel_function) xnn_f32_sigmoid_ukernel__sse2_p5_div_x16;
xnn_params.f32.prelu = (struct prelu_parameters) {
.ukernel = (xnn_prelu_ukernel_function) xnn_f32_prelu_ukernel__sse2_2x8,
@@ -1093,7 +1101,7 @@
.channel_tile = 8,
};
xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__psimd;
- xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__psimd;
+ xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__psimd_x8;
xnn_params.f32.prelu = (struct prelu_parameters) {
.ukernel = (xnn_prelu_ukernel_function) xnn_f32_prelu_ukernel__psimd_2x8,
.row_tile = 2,
@@ -1298,7 +1306,7 @@
.channel_tile = 2,
};
xnn_params.f32.clamp = (xnn_univector_ukernel_function) xnn_f32_clamp_ukernel__wasm;
- xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__wasm;
+ xnn_params.f32.hswish = (xnn_univector_ukernel_function) xnn_f32_hswish_ukernel__wasm_x4;
xnn_params.f32.prelu = (struct prelu_parameters) {
.ukernel = (xnn_prelu_ukernel_function) xnn_f32_prelu_ukernel__wasm_2x4,
.row_tile = 4,