Add N-dimensional (ND) F32 Maximum and Minimum operators (with broadcasting support)

PiperOrigin-RevId: 284056046
diff --git a/src/binary-elementwise-nd.c b/src/binary-elementwise-nd.c
index 1425c46..959bce7 100644
--- a/src/binary-elementwise-nd.c
+++ b/src/binary-elementwise-nd.c
@@ -86,6 +86,24 @@
     output_min, output_max, flags, xnn_operator_type_add_nd_f32, add_op_out);
 }
 
+enum xnn_status xnn_create_maximum_nd_f32(  // Creates the ND F32 maximum operator via the shared binary-elementwise creator.
+    uint32_t flags,  // Forwarded unchanged to the shared creator.
+    xnn_operator_t* maximum_op_out)  // Forwarded unchanged as the shared creator's last argument.
+{
+  return create_binary_elementwise_nd_f32(  // The helper's status is returned as-is.
+    -INFINITY /* output_min */, INFINITY /* output_max */,  // This API takes no bounds, so the full float range is passed (presumably disables clamping; confirm in the helper).
+    flags, xnn_operator_type_maximum_nd_f32, maximum_op_out);
+}
+
+enum xnn_status xnn_create_minimum_nd_f32(  // Creates the ND F32 minimum operator via the shared binary-elementwise creator.
+    uint32_t flags,  // Forwarded unchanged to the shared creator.
+    xnn_operator_t* minimum_op_out)  // Forwarded unchanged as the shared creator's last argument.
+{
+  return create_binary_elementwise_nd_f32(  // The helper's status is returned as-is.
+    -INFINITY /* output_min */, INFINITY /* output_max */,  // This API takes no bounds, so the full float range is passed (presumably disables clamping; confirm in the helper).
+    flags, xnn_operator_type_minimum_nd_f32, minimum_op_out);
+}
+
 enum xnn_status xnn_create_multiply_nd_f32(
     float output_min,
     float output_max,
@@ -295,6 +313,46 @@
     pthreadpool_get_threads_count(threadpool));
 }
 
+enum xnn_status xnn_setup_maximum_nd_f32(  // Sets up the ND F32 maximum operator by forwarding to the shared binary-elementwise setup.
+    xnn_operator_t maximum_op,
+    size_t num_input1_dims,
+    const size_t* input1_shape,
+    size_t num_input2_dims,
+    const size_t* input2_shape,
+    const float* input1,
+    const float* input2,
+    float* output,
+    pthreadpool_t threadpool)  // Used in this function only to query the thread count.
+{
+  return setup_binary_elementwise_nd_f32(  // The helper's status is returned as-is.
+    maximum_op, xnn_operator_type_maximum_nd_f32,  // Same operator type tag that xnn_create_maximum_nd_f32 passes at creation.
+    num_input1_dims, input1_shape,
+    num_input2_dims, input2_shape,
+    input1, input2, output,
+    &xnn_params.f32.vmax,  // Kernel table assigned in init.c (NEON, SSE, PSIMD and WASM variants).
+    pthreadpool_get_threads_count(threadpool));
+}
+
+enum xnn_status xnn_setup_minimum_nd_f32(  // Sets up the ND F32 minimum operator by forwarding to the shared binary-elementwise setup.
+    xnn_operator_t minimum_op,
+    size_t num_input1_dims,
+    const size_t* input1_shape,
+    size_t num_input2_dims,
+    const size_t* input2_shape,
+    const float* input1,
+    const float* input2,
+    float* output,
+    pthreadpool_t threadpool)  // Used in this function only to query the thread count.
+{
+  return setup_binary_elementwise_nd_f32(  // The helper's status is returned as-is.
+    minimum_op, xnn_operator_type_minimum_nd_f32,  // Same operator type tag that xnn_create_minimum_nd_f32 passes at creation.
+    num_input1_dims, input1_shape,
+    num_input2_dims, input2_shape,
+    input1, input2, output,
+    &xnn_params.f32.vmin,  // Kernel table assigned in init.c (NEON, SSE, PSIMD and WASM variants).
+    pthreadpool_get_threads_count(threadpool));
+}
+
 enum xnn_status xnn_setup_multiply_nd_f32(
     xnn_operator_t multiply_op,
     size_t num_input1_dims,
diff --git a/src/init.c b/src/init.c
index e189b8b..36bbbed 100644
--- a/src/init.c
+++ b/src/init.c
@@ -224,6 +224,18 @@
       .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
       .element_tile = 8,
     };
+    xnn_params.f32.vmax = (struct vbinary_parameters) {  // NEON kernel table read by xnn_setup_maximum_nd_f32.
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmax_ukernel__neon_x8,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__neon_x8,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__neon_x8,  // Same kernel as .opc_ukernel (presumably valid because max is commutative; confirm).
+      .element_tile = 8,
+    };
+    xnn_params.f32.vmin = (struct vbinary_parameters) {  // NEON kernel table read by xnn_setup_minimum_nd_f32.
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmin_ukernel__neon_x8,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__neon_x8,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__neon_x8,  // Same kernel as .opc_ukernel (presumably valid because min is commutative; confirm).
+      .element_tile = 8,
+    };
     xnn_params.f32.vmul = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__neon_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__neon_x8,
@@ -510,6 +522,18 @@
       .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__neon_x8,
       .element_tile = 8,
     };
+    xnn_params.f32.vmax = (struct vbinary_parameters) {  // NEON kernel table read by xnn_setup_maximum_nd_f32.
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmax_ukernel__neon_x8,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__neon_x8,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__neon_x8,  // Same kernel as .opc_ukernel (presumably valid because max is commutative; confirm).
+      .element_tile = 8,
+    };
+    xnn_params.f32.vmin = (struct vbinary_parameters) {  // NEON kernel table read by xnn_setup_minimum_nd_f32.
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmin_ukernel__neon_x8,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__neon_x8,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__neon_x8,  // Same kernel as .opc_ukernel (presumably valid because min is commutative; confirm).
+      .element_tile = 8,
+    };
     xnn_params.f32.vmul = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__neon_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__neon_x8,
@@ -824,6 +848,18 @@
       .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__sse_x8,
       .element_tile = 8,
     };
+    xnn_params.f32.vmax = (struct vbinary_parameters) {  // SSE kernel table read by xnn_setup_maximum_nd_f32.
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmax_ukernel__sse_x8,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__sse_x8,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__sse_x8,  // Same kernel as .opc_ukernel (presumably valid because max is commutative; confirm).
+      .element_tile = 8,
+    };
+    xnn_params.f32.vmin = (struct vbinary_parameters) {  // SSE kernel table read by xnn_setup_minimum_nd_f32.
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmin_ukernel__sse_x8,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__sse_x8,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__sse_x8,  // Same kernel as .opc_ukernel (presumably valid because min is commutative; confirm).
+      .element_tile = 8,
+    };
     xnn_params.f32.vmul = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__sse_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__sse_x8,
@@ -1036,6 +1072,18 @@
       .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__psimd_x8,
       .element_tile = 8,
     };
+    xnn_params.f32.vmax = (struct vbinary_parameters) {  // PSIMD kernel table read by xnn_setup_maximum_nd_f32.
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmax_ukernel__psimd_x8,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__psimd_x8,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__psimd_x8,  // Same kernel as .opc_ukernel (presumably valid because max is commutative; confirm).
+      .element_tile = 8,
+    };
+    xnn_params.f32.vmin = (struct vbinary_parameters) {  // PSIMD kernel table read by xnn_setup_minimum_nd_f32.
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmin_ukernel__psimd_x8,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__psimd_x8,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__psimd_x8,  // Same kernel as .opc_ukernel (presumably valid because min is commutative; confirm).
+      .element_tile = 8,
+    };
     xnn_params.f32.vmul = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__psimd_x8,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__psimd_x8,
@@ -1223,6 +1271,18 @@
       .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vaddc_ukernel__wasm_x4,
       .element_tile = 8,
     };
+    xnn_params.f32.vmax = (struct vbinary_parameters) {  // WASM kernel table read by xnn_setup_maximum_nd_f32.
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmax_ukernel__wasm_x4,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__wasm_x4,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmaxc_ukernel__wasm_x4,  // Same kernel as .opc_ukernel (presumably valid because max is commutative; confirm).
+      .element_tile = 8,  // NOTE(review): kernels are named _x4 but the tile is 8, same as the neighboring vadd entry; confirm this is intended.
+    };
+    xnn_params.f32.vmin = (struct vbinary_parameters) {  // WASM kernel table read by xnn_setup_minimum_nd_f32.
+      .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmin_ukernel__wasm_x4,
+      .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__wasm_x4,
+      .ropc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vminc_ukernel__wasm_x4,  // Same kernel as .opc_ukernel (presumably valid because min is commutative; confirm).
+      .element_tile = 8,  // NOTE(review): kernels are named _x4 but the tile is 8, same as the neighboring vadd entry; confirm this is intended.
+    };
     xnn_params.f32.vmul = (struct vbinary_parameters) {
       .op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmul_ukernel__wasm_x4,
       .opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vmulc_ukernel__wasm_x4,
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index 26d7db1..fed1165 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -71,6 +71,8 @@
   xnn_operator_type_leaky_relu_nc_q8,
   xnn_operator_type_max_pooling_nhwc_f32,
   xnn_operator_type_max_pooling_nhwc_u8,
+  xnn_operator_type_maximum_nd_f32,
+  xnn_operator_type_minimum_nd_f32,
   xnn_operator_type_multiply_nd_f32,
   xnn_operator_type_prelu_nc_f32,
   xnn_operator_type_resize_bilinear_nhwc_f32,
diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h
index d4404cf..a047d62 100644
--- a/src/xnnpack/params.h
+++ b/src/xnnpack/params.h
@@ -1363,6 +1363,8 @@
     xnn_univector_ukernel_function sigmoid;
     struct prelu_parameters prelu;
     struct vbinary_parameters vadd;
+    struct vbinary_parameters vmax;
+    struct vbinary_parameters vmin;
     struct vbinary_parameters vmul;
     struct vbinary_parameters vsub;
     struct vmulcaddc_parameters vmulcaddc;