arm_compute v18.11
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp
index 871a761..2f103a6 100644
--- a/tests/validation/reference/ReductionOperation.cpp
+++ b/tests/validation/reference/ReductionOperation.cpp
@@ -39,24 +39,39 @@
namespace
{
template <typename T>
-struct square
+T reduce_operation(T *ptr, int reduce_elements, ReductionOperation op, int stride) // Sums reduce_elements values read stride apart from ptr (squared for SUM_SQUARE, divided by the count for MEAN_SUM)
{
- T operator()(const T &lhs, const T &rhs) const
- {
- return (lhs + rhs * rhs);
- }
-};
+ using type = typename std::remove_cv<T>::type; // T is deduced as const when callers pass a const pointer; strip cv so the accumulator is assignable
+ auto res = type(0);
-template <typename T>
-T reduce_operation(T *ptr, int reduce_elements, ReductionOperation op)
-{
- switch(op)
+ if(std::is_integral<type>::value) // Plain runtime branch: both paths are compiled for every T
{
- case ReductionOperation::SUM_SQUARE:
- return std::accumulate(ptr, ptr + reduce_elements, static_cast<T>(0), square<T>());
- default:
- ARM_COMPUTE_ERROR("Unsupported reduction operation");
+ uint64_t int_res = 0; // 64-bit accumulator: a 32-bit sum of squared 8-bit values (up to 255*255 each) wraps after ~66k elements, defeating the saturation below
+ for(int i = 0; i < reduce_elements; ++i)
+ {
+ auto elem = static_cast<uint64_t>(*(ptr + stride * i));
+ int_res += (op == ReductionOperation::SUM_SQUARE) ? elem * elem : elem;
+ }
+ if(op == ReductionOperation::MEAN_SUM && reduce_elements > 0)
+ {
+ int_res /= reduce_elements; // Integer division: the mean is truncated
+ }
+ res = saturate_cast<type>(int_res); // Clamp to the output type's range instead of wrapping
}
+ else
+ {
+ for(int i = 0; i < reduce_elements; ++i)
+ {
+ auto elem = *(ptr + stride * i); // Non-integral types accumulate directly in the output type
+ res += (op == ReductionOperation::SUM_SQUARE) ? elem * elem : elem;
+ }
+ if(op == ReductionOperation::MEAN_SUM && reduce_elements > 0)
+ {
+ res /= reduce_elements;
+ }
+ }
+
+ return res;
}
} // namespace
@@ -64,23 +79,85 @@
SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
{
// Create reference
- SimpleTensor<T> dst{ dst_shape, src.data_type() };
+ SimpleTensor<T> dst{ dst_shape, src.data_type(), 1, src.quantization_info() }; // Output reuses the source data type and quantization info
+ const unsigned int src_width = src.shape().x();
+ const unsigned int src_height = src.shape().y();
+ const unsigned int src_depth = src.shape().z();
+ const unsigned int src_batch = src.shape()[3]; // Size of dimension 3; only used for the axis-3 input offset
+ const int reduce_elems = src.shape()[axis]; // Number of source elements collapsed into each output element
- // Compute reference
- const int reduce_elems = src.shape()[axis];
- const int upper_dims = src.shape().total_size_upper(axis + 1);
-
- for(int du = 0; du < upper_dims; ++du)
+ switch(axis) // Each case locates the first element of every reduced line and passes the distance between its elements as the stride
{
- if(axis == 0)
+ case 0:
{
- const T *src_row_ptr = src.data() + du * reduce_elems;
- dst[du] = reduce_operation(src_row_ptr, reduce_elems, op);
+ const unsigned int upper_dims = src.shape().total_size_upper(1);
+ for(unsigned int du = 0; du < upper_dims; ++du)
+ {
+ const T *src_row_ptr = src.data() + du * reduce_elems; // Rows of reduce_elems elements are laid out back to back
+ auto res = reduce_operation(src_row_ptr, reduce_elems, op, 1); // Elements along x are contiguous
+ dst[du] = res;
+ }
}
- else
+ break;
+ case 1:
{
+ const unsigned int upper_dims = src.shape().total_size_upper(2);
+ for(unsigned int du = 0; du < upper_dims; ++du)
+ {
+ for(unsigned int x = 0; x < src_width; ++x)
+ {
+ const int in_offset = du * src_height * src_width + x; // Column x of the du-th width*height plane
+ const int out_offset = du * src_width + x;
+ const T *src_row_ptr = src.data() + in_offset;
+ auto res = reduce_operation(src_row_ptr, reduce_elems, op, src_width); // Consecutive y elements are src_width apart
+ dst[out_offset] = res;
+ }
+ }
+ }
+ break;
+ case 2:
+ {
+ const unsigned int upper_dims = src.shape().total_size_upper(3);
+ for(unsigned int du = 0; du < upper_dims; ++du)
+ {
+ for(unsigned int x = 0; x < src_width; ++x)
+ {
+ for(unsigned int y = 0; y < src_height; ++y)
+ {
+ const int in_offset = du * src_depth * src_height * src_width + y * src_width + x; // Element (x, y) of the du-th width*height*depth volume
+ const int out_offset = du * src_width * src_height + y * src_width + x;
+ const T *src_row_ptr = src.data() + in_offset;
+ auto res = reduce_operation(src_row_ptr, reduce_elems, op, src_height * src_width); // Consecutive z elements are one width*height plane apart
+ dst[out_offset] = res;
+ }
+ }
+ }
+ }
+ break;
+ case 3:
+ {
+ const unsigned int upper_dims = src.shape().total_size_upper(4);
+ for(unsigned int du = 0; du < upper_dims; ++du)
+ {
+ for(unsigned int z = 0; z < src_depth; ++z)
+ {
+ for(unsigned int y = 0; y < src_height; ++y)
+ {
+ for(unsigned int x = 0; x < src_width; ++x)
+ {
+ const int in_offset = du * src_batch * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x; // Element (x, y, z) of the du-th block of src_batch volumes
+ const int out_offset = du * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
+ const T *src_row_ptr = src.data() + in_offset;
+ auto res = reduce_operation(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth); // Consecutive dimension-3 elements are one width*height*depth volume apart
+ dst[out_offset] = res;
+ }
+ }
+ }
+ }
+ }
+ break;
+ default: // Axes above 3 are rejected by this reference
ARM_COMPUTE_ERROR("Unsupported reduction axis");
- }
}
return dst;
@@ -88,6 +165,7 @@
template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
+template SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); // Integral instantiation: exercises the integer accumulation path of reduce_operation
} // namespace reference
} // namespace validation
} // namespace test