Add ND operator with broadcasting

- Generalize the Multiply ND implementation to arbitrary binary elementwise
  operators and reuse it for the new Add ND operator (see the usage sketch
  below).
- The legacy Add NC operator will be maintained until Add ND gains support for
  strides.
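
Minimal usage sketch for the new Add ND operator, assuming the entry points
added in this change plus the existing xnn_initialize(), xnn_run_operator(),
and xnn_delete_operator() APIs; shapes, values, and error handling are
illustrative only:

  #include <math.h>    // INFINITY
  #include <stddef.h>
  #include <stdio.h>

  #include <xnnpack.h>

  int main(void) {
    // Broadcast a [1, 4] tensor against a [2, 3, 4] tensor; output is [2, 3, 4].
    float a[2 * 3 * 4];
    float b[1 * 4] = { 0.0f, 1.0f, 2.0f, 3.0f };
    float y[2 * 3 * 4];
    for (size_t i = 0; i < 2 * 3 * 4; i++) {
      a[i] = (float) i;
    }

    if (xnn_initialize() != xnn_status_success) {
      fprintf(stderr, "failed to initialize XNNPACK\n");
      return 1;
    }

    // No output clamping: pass the full [-inf, +inf] range and no flags.
    xnn_operator_t add_op = NULL;
    if (xnn_create_add_nd_f32(-INFINITY, INFINITY, 0 /* flags */, &add_op) != xnn_status_success) {
      fprintf(stderr, "failed to create Add ND operator\n");
      return 1;
    }

    const size_t a_shape[3] = { 2, 3, 4 };
    const size_t b_shape[2] = { 1, 4 };
    if (xnn_setup_add_nd_f32(
          add_op,
          3, a_shape, 2, b_shape,
          a, b, y,
          NULL /* thread pool: run single-threaded */) != xnn_status_success ||
        xnn_run_operator(add_op, NULL /* thread pool */) != xnn_status_success) {
      fprintf(stderr, "failed to run Add ND operator\n");
      return 1;
    }

    // y[i] == a[i] + b[i % 4]; e.g. y[23] == 23.0f + 3.0f == 26.0f.
    printf("y[0] = %.1f, y[23] = %.1f\n", y[0], y[23]);

    xnn_delete_operator(add_op);
    return 0;
  }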

PiperOrigin-RevId: 283466005
diff --git a/src/add-nc.c b/src/add-nc.c
index 8a86617..277369b 100644
--- a/src/add-nc.c
+++ b/src/add-nc.c
@@ -345,7 +345,7 @@
       .b = b,
       .y = sum,
       .params.f32 = add_op->f32_output_params,
-      .ukernel = xnn_params.f32.vadd,
+      .ukernel = xnn_params.f32.vadd.op_ukernel,
     };
     add_op->compute.type = xnn_parallelization_type_1d_tile_1d;
     add_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_add_contiguous;
@@ -361,7 +361,7 @@
       .y_stride = sum_stride * sizeof(float),
       .n = channels * sizeof(float),
       .params.f32 = add_op->f32_output_params,
-      .ukernel = xnn_params.f32.vadd,
+      .ukernel = xnn_params.f32.vadd.op_ukernel,
     };
     add_op->compute.type = xnn_parallelization_type_1d_tile_1d;
     add_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_add_strided;
diff --git a/src/binary-elementwise-nd.c b/src/binary-elementwise-nd.c
new file mode 100644
index 0000000..65fce5f
--- /dev/null
+++ b/src/binary-elementwise-nd.c
@@ -0,0 +1,304 @@
+// Copyright 2019 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+#include <math.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <xnnpack.h>
+#include <xnnpack/allocator.h>
+#include <xnnpack/log.h>
+#include <xnnpack/operator.h>
+#include <xnnpack/params-init.h>
+#include <xnnpack/params.h>
+
+
+static enum xnn_status create_binary_elementwise_nd_f32(
+    float output_min,
+    float output_max,
+    uint32_t flags,
+    enum xnn_operator_type operator_type,
+    xnn_operator_t* binary_elementwise_op_out)
+{
+  xnn_operator_t binary_elementwise_op = NULL;
+  enum xnn_status status = xnn_status_uninitialized;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to create Add/Multiply operator: XNNPACK is not initialized");
+    goto error;
+  }
+
+  status = xnn_status_invalid_parameter;
+
+  if (isnan(output_min)) {
+    xnn_log_error(
+      "failed to create Add/Multiply operator with NaN output lower bound: lower bound must be non-NaN");
+    goto error;
+  }
+
+  if (isnan(output_max)) {
+    xnn_log_error(
+      "failed to create Add/Multiply operator with NaN output upper bound: upper bound must be non-NaN");
+    goto error;
+  }
+
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create Add/Multiply operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
+      output_min, output_max);
+    goto error;
+  }
+
+  status = xnn_status_out_of_memory;
+
+  binary_elementwise_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
+  if (binary_elementwise_op == NULL) {
+    xnn_log_error("failed to allocate %zu bytes for Add/Multiply operator descriptor", sizeof(struct xnn_operator));
+    goto error;
+  }
+
+  binary_elementwise_op->f32_output_params = xnn_init_f32_output_params(output_min, output_max);
+
+  binary_elementwise_op->type = operator_type;
+  binary_elementwise_op->ukernel.type = xnn_ukernel_type_binary_elementwise;
+
+  binary_elementwise_op->state = xnn_run_state_invalid;
+
+  *binary_elementwise_op_out = binary_elementwise_op;
+  return xnn_status_success;
+
+error:
+  xnn_delete_operator(binary_elementwise_op);
+  return status;
+}
+
+enum xnn_status xnn_create_add_nd_f32(
+    float output_min,
+    float output_max,
+    uint32_t flags,
+    xnn_operator_t* add_op_out)
+{
+  return create_binary_elementwise_nd_f32(
+    output_min, output_max, flags, xnn_operator_type_add_nd_f32, add_op_out);
+}
+
+enum xnn_status xnn_create_multiply_nd_f32(
+    float output_min,
+    float output_max,
+    uint32_t flags,
+    xnn_operator_t* multiply_op_out)
+{
+  return create_binary_elementwise_nd_f32(
+    output_min, output_max, flags, xnn_operator_type_multiply_nd_f32, multiply_op_out);
+}
+
+static enum xnn_status setup_binary_elementwise_nd_f32(
+    xnn_operator_t binary_elementwise_op,
+    enum xnn_operator_type expected_operator_type,
+    size_t num_input1_dims,
+    const size_t* input1_shape,
+    size_t num_input2_dims,
+    const size_t* input2_shape,
+    const float* input1,
+    const float* input2,
+    float* output,
+    const struct vbinary_parameters vbinary[restrict static 1],
+    size_t num_threads)
+{
+  if (binary_elementwise_op->type != expected_operator_type) {
+    xnn_log_error("failed to setup Add/Multiply (ND, F32) operator: operator type mismatch");
+    return xnn_status_invalid_parameter;
+  }
+  binary_elementwise_op->state = xnn_run_state_invalid;
+
+  if (!xnn_params.initialized) {
+    xnn_log_error("failed to setup Add/Multiply operator: XNNPACK is not initialized");
+    return xnn_status_uninitialized;
+  }
+
+  if (max(num_input1_dims, num_input2_dims) > 4) {
+    xnn_log_error(
+      "failed to setup Add/Multiply operator with %zu and %zu dimensions in input shapes: "
+      "the number of input dimensions must not exceed 4",
+      num_input1_dims, num_input2_dims);
+    return xnn_status_unsupported_parameter;
+  }
+
+  for (size_t i = 0; i < num_input1_dims; i++) {
+    if (input1_shape[i] == 0) {
+      xnn_log_error("failed to setup Add/Multiply operator: shape dimension #%zu of input #1 is zero", i);
+      return xnn_status_invalid_parameter;
+    }
+  }
+
+  for (size_t i = 0; i < num_input2_dims; i++) {
+    if (input2_shape[i] == 0) {
+      xnn_log_error("failed to setup Add/Multiply operator: shape dimension #%zu of input #2 is zero", i);
+      return xnn_status_invalid_parameter;
+    }
+  }
+
+  size_t num_compressed_dims = 0;
+  size_t compressed_input1_shape[XNN_MAX_TENSOR_DIMS];
+  size_t compressed_input2_shape[XNN_MAX_TENSOR_DIMS];
+  size_t compressed_output_shape[XNN_MAX_TENSOR_DIMS];
+  for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
+    compressed_input1_shape[i] = 1;
+    compressed_input2_shape[i] = 1;
+    compressed_output_shape[i] = 1;
+  }
+  bool broadcast_input1 = false;
+  bool broadcast_input2 = false;
+  bool first_nonunit = true;
+  const size_t num_common_dims = min(num_input1_dims, num_input2_dims);
+  for (size_t i = 1; i <= num_common_dims; i++) {
+    const size_t input1_dim = input1_shape[num_input1_dims - i];
+    const size_t input2_dim = input2_shape[num_input2_dims - i];
+    if (input1_dim == 1 && input2_dim == 1) {
+      continue;
+    }
+    assert(!broadcast_input1 || !broadcast_input2);
+
+    if (input1_dim == 1) {
+      if (!broadcast_input1) {
+        broadcast_input1 = true;
+        broadcast_input2 = false;
+        num_compressed_dims++;
+      }
+      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
+      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
+    } else if (input2_dim == 1) {
+      if (!broadcast_input2) {
+        broadcast_input1 = false;
+        broadcast_input2 = true;
+        num_compressed_dims++;
+      }
+      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
+      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
+    } else if (input1_dim == input2_dim) {
+      if (broadcast_input1 || broadcast_input2 || first_nonunit) {
+        broadcast_input1 = false;
+        broadcast_input2 = false;
+        num_compressed_dims++;
+      }
+      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
+      compressed_input2_shape[num_compressed_dims - 1] *= input1_dim;
+      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
+    } else {
+      xnn_log_error("failed to setup Add/Multiply operator: "
+        "shape dimension #%zu of input1 (%zu) does not match shape dimension #%zu of input2 (%zu)",
+        num_input1_dims - i, input1_dim, num_input2_dims - i, input2_dim);
+      return xnn_status_invalid_parameter;
+    }
+    first_nonunit = false;
+  }
+  if (num_input1_dims > num_input2_dims) {
+    if (!broadcast_input2) {
+      num_compressed_dims++;
+    }
+    for (size_t i = 0; i < num_input1_dims - num_input2_dims; i++) {
+      const size_t input1_dim = input1_shape[i];
+      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
+      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
+    }
+  } else if (num_input2_dims > num_input1_dims) {
+    if (!broadcast_input1) {
+      num_compressed_dims++;
+    }
+    for (size_t i = 0; i < num_input2_dims - num_input1_dims; i++) {
+      const size_t input2_dim = input2_shape[i];
+      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
+      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
+    }
+  }
+  num_compressed_dims = max(num_compressed_dims, 1);
+
+  binary_elementwise_op->context.elementwise_binary = (struct elementwise_binary_context) {
+    .a = input1,
+    .b = input2,
+    .y = output,
+    .elements = compressed_output_shape[0] * sizeof(float),
+    .params.f32 = binary_elementwise_op->f32_output_params,
+  };
+  const size_t* compressed_a_shape = compressed_input1_shape;
+  const size_t* compressed_b_shape = compressed_input2_shape;
+  if (compressed_input1_shape[0] == 1) {
+    binary_elementwise_op->context.elementwise_binary.ukernel = vbinary->ropc_ukernel;
+    binary_elementwise_op->context.elementwise_binary.a = input2;
+    binary_elementwise_op->context.elementwise_binary.b = input1;
+    compressed_a_shape = compressed_input2_shape;
+    compressed_b_shape = compressed_input1_shape;
+  } else if (compressed_input2_shape[0] == 1) {
+    binary_elementwise_op->context.elementwise_binary.ukernel = vbinary->opc_ukernel;
+  } else if (compressed_input1_shape[0] == compressed_input2_shape[0]) {
+    binary_elementwise_op->context.elementwise_binary.ukernel = vbinary->op_ukernel;
+  }
+  size_t a_stride = compressed_a_shape[0], b_stride = compressed_b_shape[0], y_stride = compressed_output_shape[0];
+  for (size_t i = 1; i < num_compressed_dims; i++) {
+    if (compressed_a_shape[i] != 1) {
+      binary_elementwise_op->context.elementwise_binary.a_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = a_stride * sizeof(float);
+    }
+    if (compressed_b_shape[i] != 1) {
+      binary_elementwise_op->context.elementwise_binary.b_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = b_stride * sizeof(float);
+    }
+    binary_elementwise_op->context.elementwise_binary.y_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = y_stride * sizeof(float);
+    a_stride *= compressed_a_shape[i];
+    b_stride *= compressed_b_shape[i];
+    y_stride *= compressed_output_shape[i];
+  }
+
+  binary_elementwise_op->compute.type = xnn_parallelization_type_3d_tile_2d;
+  binary_elementwise_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_elementwise_binary_3d;
+  binary_elementwise_op->compute.range[0] = compressed_output_shape[3];
+  binary_elementwise_op->compute.range[1] = compressed_output_shape[2];
+  binary_elementwise_op->compute.range[2] = compressed_output_shape[1];
+  binary_elementwise_op->compute.tile[0] = 1;
+  binary_elementwise_op->compute.tile[1] = 1;
+  binary_elementwise_op->state = xnn_run_state_ready;
+
+  return xnn_status_success;
+}
+
+enum xnn_status xnn_setup_add_nd_f32(
+    xnn_operator_t add_op,
+    size_t num_input1_dims,
+    const size_t* input1_shape,
+    size_t num_input2_dims,
+    const size_t* input2_shape,
+    const float* input1,
+    const float* input2,
+    float* output,
+    pthreadpool_t threadpool)
+{
+  return setup_binary_elementwise_nd_f32(
+    add_op, xnn_operator_type_add_nd_f32,
+    num_input1_dims, input1_shape,
+    num_input2_dims, input2_shape,
+    input1, input2, output,
+    &xnn_params.f32.vadd,
+    pthreadpool_get_threads_count(threadpool));
+}
+
+enum xnn_status xnn_setup_multiply_nd_f32(
+    xnn_operator_t multiply_op,
+    size_t num_input1_dims,
+    const size_t* input1_shape,
+    size_t num_input2_dims,
+    const size_t* input2_shape,
+    const float* input1,
+    const float* input2,
+    float* output,
+    pthreadpool_t threadpool)
+{
+  return setup_binary_elementwise_nd_f32(
+    multiply_op, xnn_operator_type_multiply_nd_f32,
+    num_input1_dims, input1_shape,
+    num_input2_dims, input2_shape,
+    input1, input2, output,
+    &xnn_params.f32.vmul,
+    pthreadpool_get_threads_count(threadpool));
+}
diff --git a/src/init.c b/src/init.c
index 60190bc..a82716a 100644
--- a/src/init.c
+++ b/src/init.c
@@ -218,7 +218,12 @@
       .row_tile = 2,
       .channel_tile = 8,
     };
-    xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__neon_x8;
+    xnn_params.f32.vadd = (struct vbinary_parameters) {
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__neon_x8,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
+      .element_tile = 8,
+    };
     xnn_params.f32.vmul = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__neon_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__neon_x8,
@@ -493,7 +498,12 @@
       .row_tile = 2,
       .channel_tile = 8,
     };
-    xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__neon_x8;
+    xnn_params.f32.vadd = (struct vbinary_parameters) {
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__neon_x8,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
+      .element_tile = 8,
+    };
     xnn_params.f32.vmul = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__neon_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__neon_x8,
@@ -796,7 +806,12 @@
       .row_tile = 2,
       .channel_tile = 8,
     };
-    xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__sse_x8;
+    xnn_params.f32.vadd = (struct vbinary_parameters) {
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__sse_x8,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__sse_x8,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__sse_x8,
+      .element_tile = 8,
+    };
     xnn_params.f32.vmul = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__sse_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__sse_x8,
@@ -997,7 +1012,12 @@
       .row_tile = 2,
       .channel_tile = 8,
     };
-    xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__psimd_x8;
+    xnn_params.f32.vadd = (struct vbinary_parameters) {
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__psimd_x8,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__psimd_x8,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__psimd_x8,
+      .element_tile = 8,
+    };
     xnn_params.f32.vmul = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__psimd_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__psimd_x8,
@@ -1173,7 +1193,12 @@
       .row_tile = 4,
       .channel_tile = 4,
     };
-    xnn_params.f32.vadd = (xnn_vadd_ukernel_function) xnn_f32_vadd_ukernel__scalar_x4;
+    xnn_params.f32.vadd = (struct vbinary_parameters) {
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vadd_ukernel__scalar_x4,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__scalar_x4,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__scalar_x4,
+      .element_tile = 4,
+    };
     xnn_params.f32.vmul = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__scalar_x4,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__scalar_x4,
diff --git a/src/multiply-nd.c b/src/multiply-nd.c
deleted file mode 100644
index 27dc4cf..0000000
--- a/src/multiply-nd.c
+++ /dev/null
@@ -1,241 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// This source code is licensed under the BSD-style license found in the
-// LICENSE file in the root directory of this source tree.
-
-#include <assert.h>
-#include <math.h>
-#include <stddef.h>
-#include <stdint.h>
-#include <stdlib.h>
-
-#include <xnnpack.h>
-#include <xnnpack/allocator.h>
-#include <xnnpack/log.h>
-#include <xnnpack/operator.h>
-#include <xnnpack/params-init.h>
-#include <xnnpack/params.h>
-
-
-enum xnn_status xnn_create_multiply_nd_f32(
-    float output_min,
-    float output_max,
-    uint32_t flags,
-    xnn_operator_t* multiply_op_out)
-{
-  xnn_operator_t multiply_op = NULL;
-  enum xnn_status status = xnn_status_uninitialized;
-
-  if (!xnn_params.initialized) {
-    xnn_log_error("failed to create Multiply operator: XNNPACK is not initialized");
-    goto error;
-  }
-
-  status = xnn_status_invalid_parameter;
-
-  if (isnan(output_min)) {
-    xnn_log_error(
-      "failed to create Multiply operator with NaN output lower bound: lower bound must be non-NaN");
-    goto error;
-  }
-
-  if (isnan(output_max)) {
-    xnn_log_error(
-      "failed to create Multiply operator with NaN output upper bound: upper bound must be non-NaN");
-    goto error;
-  }
-
-  if (output_min >= output_max) {
-    xnn_log_error(
-      "failed to create Multiply operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
-      output_min, output_max);
-    goto error;
-  }
-
-  status = xnn_status_out_of_memory;
-
-  multiply_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
-  if (multiply_op == NULL) {
-    xnn_log_error("failed to allocate %zu bytes for Multiply operator descriptor", sizeof(struct xnn_operator));
-    goto error;
-  }
-
-  multiply_op->f32_output_params = xnn_init_f32_output_params(output_min, output_max);
-
-  multiply_op->type = xnn_operator_type_multiply_nd_f32;
-  multiply_op->ukernel.type = xnn_ukernel_type_multiply;
-
-  multiply_op->state = xnn_run_state_invalid;
-
-  *multiply_op_out = multiply_op;
-  return xnn_status_success;
-
-error:
-  xnn_delete_operator(multiply_op);
-  return status;
-}
-
-enum xnn_status xnn_setup_multiply_nd_f32(
-    xnn_operator_t multiply_op,
-    size_t num_input1_dims,
-    const size_t* input1_shape,
-    size_t num_input2_dims,
-    const size_t* input2_shape,
-    const float* input1,
-    const float* input2,
-    float* output,
-    pthreadpool_t threadpool)
-{
-  if (multiply_op->type != xnn_operator_type_multiply_nd_f32) {
-    xnn_log_error("failed to setup Multiply (ND, F32) operator: operator type mismatch");
-    return xnn_status_invalid_parameter;
-  }
-  multiply_op->state = xnn_run_state_invalid;
-
-  if (!xnn_params.initialized) {
-    xnn_log_error("failed to setup Multiply operator: XNNPACK is not initialized");
-    return xnn_status_uninitialized;
-  }
-
-  if (max(num_input1_dims, num_input2_dims) > 4) {
-    xnn_log_error(
-      "failed to setup Multiply operator with %zu and %zu dimensions in input shapes: "
-      "the number of input dimensions must not exceed 4",
-      num_input1_dims, num_input2_dims);
-    return xnn_status_unsupported_parameter;
-  }
-
-  for (size_t i = 0; i < num_input1_dims; i++) {
-    if (input1_shape[i] == 0) {
-      xnn_log_error("failed to setup Multiply operator: shape dimension #%zu of input #1 is zero", i);
-      return xnn_status_invalid_parameter;
-    }
-  }
-
-  for (size_t i = 0; i < num_input2_dims; i++) {
-    if (input2_shape[i] == 0) {
-      xnn_log_error("failed to setup Multiply operator: shape dimension #%zu of input #2 is zero", i);
-      return xnn_status_invalid_parameter;
-    }
-  }
-
-  size_t num_compressed_dims = 0;
-  size_t compressed_input1_shape[XNN_MAX_TENSOR_DIMS];
-  size_t compressed_input2_shape[XNN_MAX_TENSOR_DIMS];
-  size_t compressed_output_shape[XNN_MAX_TENSOR_DIMS];
-  for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
-    compressed_input1_shape[i] = 1;
-    compressed_input2_shape[i] = 1;
-    compressed_output_shape[i] = 1;
-  }
-  bool broadcast_input1 = false;
-  bool broadcast_input2 = false;
-  bool first_nonunit = true;
-  const size_t num_common_dims = min(num_input1_dims, num_input2_dims);
-  for (size_t i = 1; i <= num_common_dims; i++) {
-    const size_t input1_dim = input1_shape[num_input1_dims - i];
-    const size_t input2_dim = input2_shape[num_input2_dims - i];
-    if (input1_dim == 1 && input2_dim == 1) {
-      continue;
-    }
-    assert(!broadcast_input1 || !broadcast_input2);
-
-    if (input1_dim == 1) {
-      if (!broadcast_input1) {
-        broadcast_input1 = true;
-        broadcast_input2 = false;
-        num_compressed_dims++;
-      }
-      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
-      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
-    } else if (input2_dim == 1) {
-      if (!broadcast_input2) {
-        broadcast_input1 = false;
-        broadcast_input2 = true;
-        num_compressed_dims++;
-      }
-      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
-      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
-    } else if (input1_dim == input2_dim) {
-      if (broadcast_input1 || broadcast_input2 || first_nonunit) {
-        broadcast_input1 = false;
-        broadcast_input2 = false;
-        num_compressed_dims++;
-      }
-      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
-      compressed_input2_shape[num_compressed_dims - 1] *= input1_dim;
-      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
-    } else {
-      xnn_log_error("failed to setup Multiply operator: "
-        "shape dimension #%zu of input1 (%zu) does not match shape dimension #%zu of input2 (%zu)",
-        num_input1_dims - i, input1_dim, num_input2_dims - i, input2_dim);
-      return xnn_status_invalid_parameter;
-    }
-    first_nonunit = false;
-  }
-  if (num_input1_dims > num_input2_dims) {
-    if (!broadcast_input2) {
-      num_compressed_dims++;
-    }
-    for (size_t i = 0; i < num_input1_dims - num_input2_dims; i++) {
-      const size_t input1_dim = input1_shape[i];
-      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
-      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
-    }
-  } else if (num_input2_dims > num_input1_dims) {
-    if (!broadcast_input1) {
-      num_compressed_dims++;
-    }
-    for (size_t i = 0; i < num_input2_dims - num_input1_dims; i++) {
-      const size_t input2_dim = input2_shape[i];
-      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
-      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
-    }
-  }
-  num_compressed_dims = max(num_compressed_dims, 1);
-
-  multiply_op->context.elementwise_binary = (struct elementwise_binary_context) {
-    .a = input1,
-    .b = input2,
-    .y = output,
-    .elements = compressed_output_shape[0] * sizeof(float),
-    .params.f32 = multiply_op->f32_output_params,
-  };
-  const size_t* compressed_a_shape = compressed_input1_shape;
-  const size_t* compressed_b_shape = compressed_input2_shape;
-  if (compressed_input1_shape[0] == 1) {
-    multiply_op->context.elementwise_binary.ukernel = xnn_params.f32.vmul.ropc_ukernel;
-    multiply_op->context.elementwise_binary.a = input2;
-    multiply_op->context.elementwise_binary.b = input1;
-    compressed_a_shape = compressed_input2_shape;
-    compressed_b_shape = compressed_input1_shape;
-  } else if (compressed_input2_shape[0] == 1) {
-    multiply_op->context.elementwise_binary.ukernel = xnn_params.f32.vmul.opc_ukernel;
-  } else if (compressed_input1_shape[0] == compressed_input2_shape[0]) {
-    multiply_op->context.elementwise_binary.ukernel = xnn_params.f32.vmul.op_ukernel;
-  }
-  size_t a_stride = compressed_a_shape[0], b_stride = compressed_b_shape[0], y_stride = compressed_output_shape[0];
-  for (size_t i = 1; i < num_compressed_dims; i++) {
-    if (compressed_a_shape[i] != 1) {
-      multiply_op->context.elementwise_binary.a_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = a_stride * sizeof(float);
-    }
-    if (compressed_b_shape[i] != 1) {
-      multiply_op->context.elementwise_binary.b_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = b_stride * sizeof(float);
-    }
-    multiply_op->context.elementwise_binary.y_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = y_stride * sizeof(float);
-    a_stride *= compressed_a_shape[i];
-    b_stride *= compressed_b_shape[i];
-    y_stride *= compressed_output_shape[i];
-  }
-
-  multiply_op->compute.type = xnn_parallelization_type_3d_tile_2d;
-  multiply_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_elementwise_binary_3d;
-  multiply_op->compute.range[0] = compressed_output_shape[3];
-  multiply_op->compute.range[1] = compressed_output_shape[2];
-  multiply_op->compute.range[2] = compressed_output_shape[1];
-  multiply_op->compute.tile[0] = 1;
-  multiply_op->compute.tile[1] = 1;
-  multiply_op->state = xnn_run_state_ready;
-
-  return xnn_status_success;
-}
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index c65af56..b1b9901 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -22,6 +22,7 @@
   xnn_ukernel_type_add,
   xnn_ukernel_type_argmax_pooling,
   xnn_ukernel_type_average_pooling,
+  xnn_ukernel_type_binary_elementwise,
   xnn_ukernel_type_channel_shuffle,
   xnn_ukernel_type_clamp,
   xnn_ukernel_type_dconv2d_hwc2spchw,
@@ -32,7 +33,6 @@
   xnn_ukernel_type_igemm,
   xnn_ukernel_type_lut,
   xnn_ukernel_type_max_pooling,
-  xnn_ukernel_type_multiply,
   xnn_ukernel_type_pad,
   xnn_ukernel_type_pixelwise_average_pooling,
   xnn_ukernel_type_prelu,
@@ -47,6 +47,7 @@
 enum xnn_operator_type {
   xnn_operator_type_none = 0,
   xnn_operator_type_add_nc_f32,
   xnn_operator_type_add_nc_q8,
+  xnn_operator_type_add_nd_f32,
   xnn_operator_type_argmax_pooling_nhwc_f32,
   xnn_operator_type_average_pooling_nhwc_f32,
diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h
index 87625eb..d7dad8e 100644
--- a/src/xnnpack/params.h
+++ b/src/xnnpack/params.h
@@ -1362,7 +1362,7 @@
     xnn_univector_ukernel_function hswish;
     xnn_univector_ukernel_function sigmoid;
     struct prelu_parameters prelu;
-    xnn_vadd_ukernel_function vadd;
+    struct vbinary_parameters vadd;
     struct vbinary_parameters vmul;
     struct vmulcaddc_parameters vmulcaddc;
     // Sparse Matrix-Dense Matrix Multiplication (NR=1 block).