Add ND operator with broadcasting
- Generalize the Multiply implementation to arbitrary binary elementwise operators.
- The legacy Add NC operator will be maintained until Add ND gets support for
strides.
PiperOrigin-RevId: 283466005
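
Note: the initializers below imply roughly the following dispatch-table shape.
This is a sketch inferred from this diff, not the authoritative definition
(the real struct and ukernel typedef live in XNNPACK's headers, likely
src/xnnpack/params.h); the ukernel signature and field types are assumptions
modeled on the pre-existing f32 vadd kernels.

    #include <stddef.h>
    #include <stdint.h>

    /* Output-clamping parameters; the real union is defined by XNNPACK. */
    union xnn_f32_output_params;

    /* Assumed ukernel signature: n is the number of bytes of input to
       process (a multiple of sizeof(float)). */
    typedef void (*xnn_vbinary_ukernel_function)(
        size_t n,
        const float* a,
        const float* b,
        float* y,
        const union xnn_f32_output_params* params);

    struct vbinary_parameters {
      xnn_vbinary_ukernel_function op_ukernel;   /* y[i] = a[i] OP b[i] */
      xnn_vbinary_ukernel_function opc_ukernel;  /* y[i] = a[i] OP c (constant right operand) */
      xnn_vbinary_ukernel_function ropc_ukernel; /* y[i] = c OP a[i] (constant left operand) */
      uint8_t element_tile;                      /* elements handled per ukernel loop iteration */
    };

Because addition is commutative, every block below points ropc_ukernel at the
same xnn_f32_vaddc_* kernel as opc_ukernel; the two slots only diverge for
non-commutative operators.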
diff --git a/src/init.c b/src/init.c
index 60190bc..a82716a 100644
--- a/src/init.c
+++ b/src/init.c
@@ -218,7 +218,12 @@
.row_tile = 2,
.channel_tile = 8,
};
- xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__neon_x8;
+ xnn_params.f32.vadd = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__neon_x8,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
+ .element_tile = 8,
+ };
xnn_params.f32.vmul = (struct vbinary_parameters) {
.op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__neon_x8,
.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__neon_x8,
@@ -493,7 +498,12 @@
.row_tile = 2,
.channel_tile = 8,
};
- xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__neon_x8;
+ xnn_params.f32.vadd = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__neon_x8,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
+ .element_tile = 8,
+ };
xnn_params.f32.vmul = (struct vbinary_parameters) {
.op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__neon_x8,
.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__neon_x8,
@@ -796,7 +806,12 @@
.row_tile = 2,
.channel_tile = 8,
};
- xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__sse_x8;
+ xnn_params.f32.vadd = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__sse_x8,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__sse_x8,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__sse_x8,
+ .element_tile = 8,
+ };
xnn_params.f32.vmul = (struct vbinary_parameters) {
.op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__sse_x8,
.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__sse_x8,
@@ -997,7 +1012,12 @@
.row_tile = 2,
.channel_tile = 8,
};
- xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__psimd_x8;
+ xnn_params.f32.vadd = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__psimd_x8,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__psimd_x8,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__psimd_x8,
+ .element_tile = 8,
+ };
xnn_params.f32.vmul = (struct vbinary_parameters) {
.op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__psimd_x8,
.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__psimd_x8,
@@ -1173,7 +1193,12 @@
.row_tile = 4,
.channel_tile = 4,
};
- xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__scalar_x4;
+ xnn_params.f32.vadd = (struct vbinary_parameters) {
+ .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__scalar_x4,
+ .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__scalar_x4,
+ .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__scalar_x4,
+ .element_tile = 4,
+ };
xnn_params.f32.vmul = (struct vbinary_parameters) {
.op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__scalar_x4,
.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__scalar_x4,
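
For illustration, a hedged sketch (using the types sketched above) of how
setup code might pick among the three kernel variants once broadcasting has
reduced the innermost loop to one of three cases. The helper name and the
surrounding logic are hypothetical; the real ND operator setup additionally
handles strided broadcasting across outer dimensions, which is the part the
legacy Add NC path is being kept around for.

    /* Hypothetical helper: choose the ukernel for the innermost loop.
       a_elements/b_elements are the innermost-dimension extents after
       broadcast analysis; an extent of 1 means that operand is a single
       broadcast scalar along this dimension. */
    static xnn_vbinary_ukernel_function select_vbinary_ukernel(
        const struct vbinary_parameters* vbinary,
        size_t a_elements,
        size_t b_elements)
    {
      if (b_elements == 1) {
        return vbinary->opc_ukernel;   /* y[i] = a[i] OP b[0] */
      }
      if (a_elements == 1) {
        return vbinary->ropc_ukernel;  /* y[i] = a[0] OP b[i]; distinct from
                                          opc_ukernel only for non-commutative
                                          operators */
      }
      return vbinary->op_ukernel;      /* y[i] = a[i] OP b[i] */
    }

Carrying both opc and ropc slots in the table is presumably what lets future
non-commutative operators (e.g. Subtract, Divide) reuse this dispatch scheme
with a distinct reversed kernel, while Add and Multiply simply alias the two.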