Add ND Maximum and Minimum operators (with broadcasting support)
PiperOrigin-RevId: 284056046
diff --git a/test/binary-elementwise-operator-tester.h b/test/binary-elementwise-operator-tester.h
index 33c24b0..9a45854 100644
--- a/test/binary-elementwise-operator-tester.h
+++ b/test/binary-elementwise-operator-tester.h
@@ -26,6 +26,8 @@
enum class OperationType {
Unknown,
Add,
+ Maximum,
+ Minimum,
Multiply,
Subtract,
};
@@ -116,6 +118,10 @@
switch (operation_type()) {
case OperationType::Add:
return a + b;
+ case OperationType::Maximum:
+ return std::max<float>(a, b);
+ case OperationType::Minimum:
+ return std::min<float>(a, b);
case OperationType::Multiply:
return a * b;
case OperationType::Subtract:
@@ -211,6 +217,16 @@
output_min, output_max,
0, &binary_elementwise_op));
break;
+ case OperationType::Maximum:
+ ASSERT_EQ(xnn_status_success,
+ xnn_create_maximum_nd_f32(
+ 0, &binary_elementwise_op));
+ break;
+ case OperationType::Minimum:
+ ASSERT_EQ(xnn_status_success,
+ xnn_create_minimum_nd_f32(
+ 0, &binary_elementwise_op));
+ break;
case OperationType::Multiply:
ASSERT_EQ(xnn_status_success,
xnn_create_multiply_nd_f32(
@@ -243,6 +259,28 @@
input1.data(), input2.data(), output.data(),
nullptr /* thread pool */));
break;
+ case OperationType::Maximum:
+ ASSERT_EQ(xnn_status_success,
+ xnn_setup_maximum_nd_f32(
+ binary_elementwise_op,
+ num_input1_dims(),
+ input1_shape().data(),
+ num_input2_dims(),
+ input2_shape().data(),
+ input1.data(), input2.data(), output.data(),
+ nullptr /* thread pool */));
+ break;
+ case OperationType::Minimum:
+ ASSERT_EQ(xnn_status_success,
+ xnn_setup_minimum_nd_f32(
+ binary_elementwise_op,
+ num_input1_dims(),
+ input1_shape().data(),
+ num_input2_dims(),
+ input2_shape().data(),
+ input1.data(), input2.data(), output.data(),
+ nullptr /* thread pool */));
+ break;
case OperationType::Multiply:
ASSERT_EQ(xnn_status_success,
xnn_setup_multiply_nd_f32(