arm_compute v19.08
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp
index fb7a6d6..fe128cc 100644
--- a/tests/validation/reference/ReductionOperation.cpp
+++ b/tests/validation/reference/ReductionOperation.cpp
@@ -42,7 +42,25 @@
OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, int stride)
{
using type = typename std::remove_cv<OT>::type;
- auto res = (op == ReductionOperation::PROD) ? type(1) : type(0);
+ T res;
+ switch(op)
+ {
+ case ReductionOperation::PROD:
+ {
+ res = type(1);
+ }
+ break;
+ case ReductionOperation::MIN:
+ case ReductionOperation::MAX:
+ {
+ res = *ptr;
+ }
+ break;
+ default:
+ {
+ res = type(0);
+ }
+ }
if(std::is_integral<type>::value)
{
@@ -53,16 +71,16 @@
switch(op)
{
- case ReductionOperation::ARG_IDX_MIN:
- if(*(ptr + stride * static_cast<uint32_t>(int_res)) > elem)
+ case ReductionOperation::MIN:
+ if(static_cast<T>(int_res) > elem)
{
- int_res = static_cast<uint32_t>(i);
+ int_res = elem;
}
break;
- case ReductionOperation::ARG_IDX_MAX:
- if(*(ptr + stride * static_cast<uint32_t>(int_res)) < elem)
+ case ReductionOperation::MAX:
+ if(static_cast<T>(int_res) < elem)
{
- int_res = static_cast<uint32_t>(i);
+ int_res = elem;
}
break;
case ReductionOperation::SUM_SQUARE:
@@ -92,16 +110,16 @@
auto elem = *(ptr + stride * i);
switch(op)
{
- case ReductionOperation::ARG_IDX_MIN:
- if(*(ptr + stride * static_cast<uint32_t>(res)) > elem)
+ case ReductionOperation::MIN:
+ if(res > elem)
{
- res = static_cast<uint32_t>(i);
+ res = elem;
}
break;
- case ReductionOperation::ARG_IDX_MAX:
- if(*(ptr + stride * static_cast<uint32_t>(res)) < elem)
+ case ReductionOperation::MAX:
+ if(res < elem)
{
- res = static_cast<uint32_t>(i);
+ res = elem;
}
break;
case ReductionOperation::SUM_SQUARE:
@@ -125,6 +143,35 @@
}
return res;
}
+
+template <typename T, typename OT> // T: element type read through ptr; OT: type the resulting index is cast to
+OT reduce_operation_arg_min_max(const T *ptr, int reduce_elements, ReductionOperation op, int stride) // Returns the 0-based position of the min/max among reduce_elements elements spaced stride apart
+{
+ uint32_t res = 0; // Index of the best element found so far; element 0 is the initial candidate
+ for(int i = 0; i < reduce_elements; ++i)
+ {
+ auto elem = *(ptr + stride * i); // i-th element of the strided sequence
+ switch(op)
+ {
+ case ReductionOperation::ARG_IDX_MIN:
+ if(*(ptr + stride * res) > elem) // Strict comparison: on ties the earliest index is kept
+ {
+ res = static_cast<uint32_t>(i);
+ }
+ break;
+ case ReductionOperation::ARG_IDX_MAX:
+ if(*(ptr + stride * res) < elem) // Strict comparison: on ties the earliest index is kept
+ {
+ res = static_cast<uint32_t>(i);
+ }
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Operation not supported"); // Only the two ARG_IDX ops are handled here; NOTE(review): macro is defined elsewhere, presumably aborts/throws - confirm
+ }
+ }
+ return static_cast<OT>(res); // Stays 0 when reduce_elements <= 0 because the loop body never runs
+}
+
} // namespace
template <typename T, typename OT>
@@ -148,7 +195,9 @@
for(unsigned int du = 0; du < upper_dims; ++du)
{
const T *src_row_ptr = src.data() + du * reduce_elems;
- dst[du] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, 1);
+ dst[du] = is_arg_min_max ?
+ reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, 1) :
+ reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, 1);
}
}
break;
@@ -162,7 +211,9 @@
const int in_offset = du * src_height * src_width + x;
const int out_offset = du * src_width + x;
const T *src_row_ptr = src.data() + in_offset;
- dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width);
+ dst[out_offset] = is_arg_min_max ?
+ reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, src_width) :
+ reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width);
}
}
}
@@ -179,7 +230,9 @@
const int in_offset = du * src_depth * src_height * src_width + y * src_width + x;
const int out_offset = du * src_width * src_height + y * src_width + x;
const T *src_row_ptr = src.data() + in_offset;
- dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_height * src_width);
+ dst[out_offset] = is_arg_min_max ?
+ reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height) :
+ reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height);
}
}
}
@@ -199,7 +252,9 @@
const int in_offset = du * src_batch * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
const int out_offset = du * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
const T *src_row_ptr = src.data() + in_offset;
- dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
+ dst[out_offset] = is_arg_min_max ?
+ reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth) :
+ reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
}
}
}
@@ -238,6 +293,7 @@
template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
+template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<int32_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);