AVX & AVX512F versions of binary elementwise micro-kernels
PiperOrigin-RevId: 284867789
diff --git a/src/init.c b/src/init.c
index 044472d..e8e8d69 100644
--- a/src/init.c
+++ b/src/init.c
@@ -877,42 +877,118 @@
.row_tile = 2,
.channel_tile = 8,
};
- xnn_params.f32.vadd = (struct vbinary_parameters) {
- .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__sse_x8,
- .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__sse_x8,
- .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__sse_x8,
- .element_tile = 8,
- };
- xnn_params.f32.vdiv = (struct vbinary_parameters) {
- .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vdiv_ukernel__sse_x8,
- .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vdivc_ukernel__sse_x8,
- .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vrdivc_ukernel__sse_x8,
- .element_tile = 8,
- };
- xnn_params.f32.vmax = (struct vbinary_parameters) {
- .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmax_ukernel__sse_x8,
- .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__sse_x8,
- .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__sse_x8,
- .element_tile = 8,
- };
- xnn_params.f32.vmin = (struct vbinary_parameters) {
- .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmin_ukernel__sse_x8,
- .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__sse_x8,
- .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__sse_x8,
- .element_tile = 8,
- };
- xnn_params.f32.vmul = (struct vbinary_parameters) {
- .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__sse_x8,
- .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__sse_x8,
- .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__sse_x8,
- .element_tile = 8,
- };
- xnn_params.f32.vsub = (struct vbinary_parameters) {
- .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vsub_ukernel__sse_x8,
- .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vsubc_ukernel__sse_x8,
- .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vrsubc_ukernel__sse_x8,
- .element_tile = 8,
- };
+ if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx512f()) {  // first choice: the avx512f_x32 kernels, taken only when XNN_PLATFORM_MOBILE is false and cpuinfo reports AVX512F
+ xnn_params.f32.vadd = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__avx512f_x32,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__avx512f_x32,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__avx512f_x32,  // same kernel as .opc_ukernel; NOTE(review): presumably because operand order does not matter for add - confirm
+ .element_tile = 32,  // agrees with the _x32 suffix of the kernels assigned above
+ };
+ xnn_params.f32.vdiv = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vdiv_ukernel__avx512f_x32,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vdivc_ukernel__avx512f_x32,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vrdivc_ukernel__avx512f_x32,  // differs from .opc_ukernel: vrdivc, not vdivc
+ .element_tile = 32,
+ };
+ xnn_params.f32.vmax = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmax_ukernel__avx512f_x32,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__avx512f_x32,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__avx512f_x32,
+ .element_tile = 32,
+ };
+ xnn_params.f32.vmin = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmin_ukernel__avx512f_x32,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__avx512f_x32,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__avx512f_x32,
+ .element_tile = 32,
+ };
+ xnn_params.f32.vmul = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__avx512f_x32,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__avx512f_x32,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__avx512f_x32,
+ .element_tile = 32,
+ };
+ xnn_params.f32.vsub = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vsub_ukernel__avx512f_x32,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vsubc_ukernel__avx512f_x32,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vrsubc_ukernel__avx512f_x32,  // differs from .opc_ukernel: vrsubc, not vsubc
+ .element_tile = 32,
+ };
+ } else if (!XNN_PLATFORM_MOBILE && cpuinfo_has_x86_avx()) {  // second choice: the avx_x16 kernels, taken when the AVX512F branch was not and cpuinfo reports AVX
+ xnn_params.f32.vadd = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__avx_x16,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__avx_x16,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__avx_x16,
+ .element_tile = 16,  // agrees with the _x16 suffix of the kernels assigned above
+ };
+ xnn_params.f32.vdiv = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vdiv_ukernel__avx_x16,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vdivc_ukernel__avx_x16,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vrdivc_ukernel__avx_x16,  // differs from .opc_ukernel: vrdivc, not vdivc
+ .element_tile = 16,
+ };
+ xnn_params.f32.vmax = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmax_ukernel__avx_x16,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__avx_x16,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__avx_x16,
+ .element_tile = 16,
+ };
+ xnn_params.f32.vmin = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmin_ukernel__avx_x16,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__avx_x16,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__avx_x16,
+ .element_tile = 16,
+ };
+ xnn_params.f32.vmul = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__avx_x16,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__avx_x16,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__avx_x16,
+ .element_tile = 16,
+ };
+ xnn_params.f32.vsub = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vsub_ukernel__avx_x16,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vsubc_ukernel__avx_x16,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vrsubc_ukernel__avx_x16,  // differs from .opc_ukernel: vrsubc, not vsubc
+ .element_tile = 16,
+ };
+ } else {  // fallback when neither branch above is taken: the sse_x8 kernels
+ xnn_params.f32.vadd = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__sse_x8,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__sse_x8,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__sse_x8,
+ .element_tile = 8,  // agrees with the _x8 suffix of the kernels assigned above
+ };
+ xnn_params.f32.vdiv = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vdiv_ukernel__sse_x8,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vdivc_ukernel__sse_x8,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vrdivc_ukernel__sse_x8,  // differs from .opc_ukernel: vrdivc, not vdivc
+ .element_tile = 8,
+ };
+ xnn_params.f32.vmax = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmax_ukernel__sse_x8,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__sse_x8,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__sse_x8,
+ .element_tile = 8,
+ };
+ xnn_params.f32.vmin = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmin_ukernel__sse_x8,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__sse_x8,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__sse_x8,
+ .element_tile = 8,
+ };
+ xnn_params.f32.vmul = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__sse_x8,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__sse_x8,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__sse_x8,
+ .element_tile = 8,
+ };
+ xnn_params.f32.vsub = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vsub_ukernel__sse_x8,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vsubc_ukernel__sse_x8,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vrsubc_ukernel__sse_x8,  // differs from .opc_ukernel: vrsubc, not vsubc
+ .element_tile = 8,
+ };
+ }
xnn_params.f32.vmulcaddc = (struct vmulcaddc_parameters) {
.ukernel = (xnn_vmulcaddc_ukernel_function) xnn_f32_vmulcaddc_ukernel_c4__sse_2x,
.channel_tile = 4,