Generalize create_binary_elementwise_nd across data types
PiperOrigin-RevId: 316847561
diff --git a/src/operators/binary-elementwise-nd.c b/src/operators/binary-elementwise-nd.c
index 3f9ee7e..7b3b2df 100644
--- a/src/operators/binary-elementwise-nd.c
+++ b/src/operators/binary-elementwise-nd.c
@@ -17,10 +17,12 @@
#include <xnnpack/params.h>
-static enum xnn_status create_binary_elementwise_nd_f32(
+static enum xnn_status create_binary_elementwise_nd(
float output_min,
float output_max,
uint32_t flags,
+ const void* params,
+ size_t params_size,
enum xnn_operator_type operator_type,
xnn_operator_t* binary_elementwise_op_out)
{
@@ -66,7 +68,9 @@
goto error;
}
- binary_elementwise_op->params.f32_minmax = xnn_init_f32_minmax_params(output_min, output_max);
+ if (params_size != 0) {
+ memcpy(&binary_elementwise_op->params, params, params_size);
+ }
binary_elementwise_op->type = operator_type;
binary_elementwise_op->ukernel.type = xnn_ukernel_type_binary_elementwise;
@@ -87,8 +91,11 @@
uint32_t flags,
xnn_operator_t* add_op_out)
{
- return create_binary_elementwise_nd_f32(
- output_min, output_max, flags, xnn_operator_type_add_nd_f32, add_op_out);
+ const union xnn_f32_minmax_params params = xnn_init_f32_minmax_params(output_min, output_max);
+
+ return create_binary_elementwise_nd(
+ output_min, output_max, flags, ¶ms, sizeof(params),
+ xnn_operator_type_add_nd_f32, add_op_out);
}
enum xnn_status xnn_create_divide_nd_f32(
@@ -97,26 +104,35 @@
uint32_t flags,
xnn_operator_t* divide_op_out)
{
- return create_binary_elementwise_nd_f32(
- output_min, output_max, flags, xnn_operator_type_divide_nd_f32, divide_op_out);
+ const union xnn_f32_minmax_params params = xnn_init_f32_minmax_params(output_min, output_max);
+
+ return create_binary_elementwise_nd(
+ output_min, output_max, flags, ¶ms, sizeof(params),
+ xnn_operator_type_divide_nd_f32, divide_op_out);
}
enum xnn_status xnn_create_maximum_nd_f32(
uint32_t flags,
xnn_operator_t* maximum_op_out)
{
- return create_binary_elementwise_nd_f32(
+ const union xnn_f32_minmax_params params = xnn_init_f32_minmax_params(-INFINITY /* output_min */, INFINITY /* output_max */);
+
+ return create_binary_elementwise_nd(
-INFINITY /* output_min */, INFINITY /* output_max */,
- flags, xnn_operator_type_maximum_nd_f32, maximum_op_out);
+ flags, ¶ms, sizeof(params),
+ xnn_operator_type_maximum_nd_f32, maximum_op_out);
}
enum xnn_status xnn_create_minimum_nd_f32(
uint32_t flags,
xnn_operator_t* minimum_op_out)
{
- return create_binary_elementwise_nd_f32(
+ const union xnn_f32_minmax_params params = xnn_init_f32_minmax_params(-INFINITY /* output_min */, INFINITY /* output_max */);
+
+ return create_binary_elementwise_nd(
-INFINITY /* output_min */, INFINITY /* output_max */,
- flags, xnn_operator_type_minimum_nd_f32, minimum_op_out);
+ flags, ¶ms, sizeof(params),
+ xnn_operator_type_minimum_nd_f32, minimum_op_out);
}
enum xnn_status xnn_create_multiply_nd_f32(
@@ -125,17 +141,23 @@
uint32_t flags,
xnn_operator_t* multiply_op_out)
{
- return create_binary_elementwise_nd_f32(
- output_min, output_max, flags, xnn_operator_type_multiply_nd_f32, multiply_op_out);
+ const union xnn_f32_minmax_params params = xnn_init_f32_minmax_params(output_min, output_max);
+
+ return create_binary_elementwise_nd(
+ output_min, output_max, flags, ¶ms, sizeof(params),
+ xnn_operator_type_multiply_nd_f32, multiply_op_out);
}
enum xnn_status xnn_create_squared_difference_nd_f32(
uint32_t flags,
xnn_operator_t* squared_difference_op_out)
{
- return create_binary_elementwise_nd_f32(
+ const union xnn_f32_minmax_params params = xnn_init_f32_minmax_params(-INFINITY /* output_min */, INFINITY /* output_max */);
+
+ return create_binary_elementwise_nd(
-INFINITY /* output_min */, INFINITY /* output_max */,
- flags, xnn_operator_type_squared_difference_nd_f32, squared_difference_op_out);
+ flags, ¶ms, sizeof(params),
+ xnn_operator_type_squared_difference_nd_f32, squared_difference_op_out);
}
enum xnn_status xnn_create_subtract_nd_f32(
@@ -144,8 +166,11 @@
uint32_t flags,
xnn_operator_t* subtract_op_out)
{
- return create_binary_elementwise_nd_f32(
- output_min, output_max, flags, xnn_operator_type_subtract_nd_f32, subtract_op_out);
+ const union xnn_f32_minmax_params params = xnn_init_f32_minmax_params(output_min, output_max);
+
+ return create_binary_elementwise_nd(
+ output_min, output_max, flags, ¶ms, sizeof(params),
+ xnn_operator_type_subtract_nd_f32, subtract_op_out);
}
static enum xnn_status setup_binary_elementwise_nd_f32(