arm_compute v19.02
Change-Id: I853a3ecf38f206da13c1b03640c8adf73c20477c
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp
index 2f103a6..fb7a6d6 100644
--- a/tests/validation/reference/ReductionOperation.cpp
+++ b/tests/validation/reference/ReductionOperation.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,19 +38,46 @@
{
namespace
{
-template <typename T>
-T reduce_operation(T *ptr, int reduce_elements, ReductionOperation op, int stride)
+template <typename T, typename OT>
+OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, int stride)
{
- using type = typename std::remove_cv<T>::type;
- auto res = type(0);
+ using type = typename std::remove_cv<OT>::type;
+ auto res = (op == ReductionOperation::PROD) ? type(1) : type(0);
if(std::is_integral<type>::value)
{
- uint32_t int_res = 0;
+ auto int_res = static_cast<uint32_t>(res);
for(int i = 0; i < reduce_elements; ++i)
{
- auto elem = static_cast<uint32_t>(*(ptr + stride * i));
- int_res += (op == ReductionOperation::SUM_SQUARE) ? elem * elem : elem;
+ auto elem = *(ptr + stride * i);
+
+ switch(op)
+ {
+ case ReductionOperation::ARG_IDX_MIN:
+ if(*(ptr + stride * static_cast<uint32_t>(int_res)) > elem)
+ {
+ int_res = static_cast<uint32_t>(i);
+ }
+ break;
+ case ReductionOperation::ARG_IDX_MAX:
+ if(*(ptr + stride * static_cast<uint32_t>(int_res)) < elem)
+ {
+ int_res = static_cast<uint32_t>(i);
+ }
+ break;
+ case ReductionOperation::SUM_SQUARE:
+ int_res += elem * elem;
+ break;
+ case ReductionOperation::MEAN_SUM:
+ case ReductionOperation::SUM:
+ int_res += elem;
+ break;
+ case ReductionOperation::PROD:
+ int_res *= elem;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Operation not supported");
+ }
}
if(op == ReductionOperation::MEAN_SUM && reduce_elements > 0)
{
@@ -63,23 +90,50 @@
for(int i = 0; i < reduce_elements; ++i)
{
auto elem = *(ptr + stride * i);
- res += (op == ReductionOperation::SUM_SQUARE) ? elem * elem : elem;
+ switch(op)
+ {
+ case ReductionOperation::ARG_IDX_MIN:
+ if(*(ptr + stride * static_cast<uint32_t>(res)) > elem)
+ {
+ res = static_cast<uint32_t>(i);
+ }
+ break;
+ case ReductionOperation::ARG_IDX_MAX:
+ if(*(ptr + stride * static_cast<uint32_t>(res)) < elem)
+ {
+ res = static_cast<uint32_t>(i);
+ }
+ break;
+ case ReductionOperation::SUM_SQUARE:
+ res += elem * elem;
+ break;
+ case ReductionOperation::MEAN_SUM:
+ case ReductionOperation::SUM:
+ res += elem;
+ break;
+ case ReductionOperation::PROD:
+ res *= elem;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Operation not supported");
+ }
}
if(op == ReductionOperation::MEAN_SUM && reduce_elements > 0)
{
res /= reduce_elements;
}
}
-
return res;
}
} // namespace
-template <typename T>
-SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
+template <typename T, typename OT>
+SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
{
// Create reference
- SimpleTensor<T> dst{ dst_shape, src.data_type(), 1, src.quantization_info() };
+ const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
+ DataType output_data_type = is_arg_min_max ? DataType::U32 : src.data_type();
+ SimpleTensor<OT> dst{ dst_shape, output_data_type, 1, src.quantization_info() };
const unsigned int src_width = src.shape().x();
const unsigned int src_height = src.shape().y();
const unsigned int src_depth = src.shape().z();
@@ -94,8 +148,7 @@
for(unsigned int du = 0; du < upper_dims; ++du)
{
const T *src_row_ptr = src.data() + du * reduce_elems;
- auto res = reduce_operation(src_row_ptr, reduce_elems, op, 1);
- dst[du] = res;
+ dst[du] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, 1);
}
}
break;
@@ -109,8 +162,7 @@
const int in_offset = du * src_height * src_width + x;
const int out_offset = du * src_width + x;
const T *src_row_ptr = src.data() + in_offset;
- auto res = reduce_operation(src_row_ptr, reduce_elems, op, src_width);
- dst[out_offset] = res;
+ dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width);
}
}
}
@@ -127,8 +179,7 @@
const int in_offset = du * src_depth * src_height * src_width + y * src_width + x;
const int out_offset = du * src_width * src_height + y * src_width + x;
const T *src_row_ptr = src.data() + in_offset;
- auto res = reduce_operation(src_row_ptr, reduce_elems, op, src_height * src_width);
- dst[out_offset] = res;
+ dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_height * src_width);
}
}
}
@@ -148,8 +199,7 @@
const int in_offset = du * src_batch * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
const int out_offset = du * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
const T *src_row_ptr = src.data() + in_offset;
- auto res = reduce_operation(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
- dst[out_offset] = res;
+ dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
}
}
}
@@ -163,9 +213,34 @@
return dst;
}
+template <typename T, typename OT>
+SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
+{
+ return compute_reduction_operation<T, OT>(src, dst_shape, axis, op); // Generic path: reduce the element values exactly as stored, no dequantization step.
+}
+
+template <>
+SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
+{
+ // Plain U8 data, and QASYMM8 MEAN_SUM, are reduced directly on the stored 8-bit values.
+ const bool dequantize_first = (src.data_type() == DataType::QASYMM8) && (op != ReductionOperation::MEAN_SUM);
+ if(!dequantize_first)
+ {
+ return compute_reduction_operation<uint8_t, uint8_t>(src, dst_shape, axis, op);
+ }
+ // Every other QASYMM8 reduction: dequantize, reduce in float, then requantize with the input's quantization info.
+ SimpleTensor<float> src_f = convert_from_asymmetric(src);
+ SimpleTensor<float> dst_f = reference::reduction_operation<float, float>(src_f, dst_shape, axis, op);
+ return convert_to_asymmetric(dst_f, src.quantization_info());
+}
+
template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
-template SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
+// uint8_t -> uint8_t now comes from the explicit specialization above; the U32 outputs below are the index tensors of ARG_IDX_MIN/ARG_IDX_MAX.
+template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
+template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
+template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
+
} // namespace reference
} // namespace validation
} // namespace test