blob: 3aa5a813b6604c165ed7d09bcc4aa949f1bfb8c1 [file] [log] [blame]
/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/runtime/CL/functions/CLReductionOperation.h"
25
26#include "arm_compute/core/CL/ICLTensor.h"
27#include "arm_compute/core/CL/kernels/CLReductionOperationKernel.h"
28#include "arm_compute/core/Error.h"
Jenkins0e205f72019-11-28 16:53:35 +000029#include "arm_compute/core/Helpers.h"
Kaizen8938bd32017-09-28 14:38:23 +010030#include "arm_compute/core/PixelValue.h"
31#include "arm_compute/core/TensorInfo.h"
32#include "arm_compute/core/Validate.h"
Jenkins0e205f72019-11-28 16:53:35 +000033#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Kaizen8938bd32017-09-28 14:38:23 +010034#include "arm_compute/runtime/CL/CLScheduler.h"
35#include "arm_compute/runtime/Tensor.h"
36#include "support/ToolchainSupport.h"
37
Jenkins0e205f72019-11-28 16:53:35 +000038namespace arm_compute
39{
Jenkinsb3a371b2018-05-23 11:36:53 +010040namespace
41{
Jenkinsb9abeae2018-11-22 11:58:08 +000042unsigned int calculate_number_of_stages(const ITensorInfo *input, unsigned int axis)
Jenkinsb3a371b2018-05-23 11:36:53 +010043{
Jenkinsb9abeae2018-11-22 11:58:08 +000044 // We need only 1 stage for all axis except x-axis and x-axis for QASYMM8.
45 if(axis != 0 || (axis == 0 && is_data_type_quantized(input->data_type())))
46 {
47 return 1;
48 }
Jenkinsb3a371b2018-05-23 11:36:53 +010049 // Calculate number of WGs. 16 elements per thread, 8 threads per WG
50 const unsigned int num_of_wg = ceil(input->dimension(0) / 128.f);
51
52 // Calculate number of stages. First stage performs op and the rest reduction sum
53 // depending on the size of the input. Last stage should have only 1 WG.
54 const unsigned int num_of_stages = num_of_wg / 128 + 2;
55
56 return num_of_stages;
57}
58} // namespace
59
Kaizen8938bd32017-09-28 14:38:23 +010060CLReductionOperation::CLReductionOperation(std::shared_ptr<IMemoryManager> memory_manager)
Jenkins0e205f72019-11-28 16:53:35 +000061 : _memory_group(std::move(memory_manager)), _results_vector(), _reduction_kernels_vector(), _border_handlers_vector(), _reshape_kernel(), _op(), _num_of_stages(), _reduction_axis(), _is_serial(),
62 _is_reshape_required(false)
Kaizen8938bd32017-09-28 14:38:23 +010063{
64}
65
Jenkins0e205f72019-11-28 16:53:35 +000066Status CLReductionOperation::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op, bool keep_dims)
Jenkinsb3a371b2018-05-23 11:36:53 +010067{
Jenkins0e205f72019-11-28 16:53:35 +000068 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
69 ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
70
71 const unsigned int num_of_stages = calculate_number_of_stages(input, axis);
72 const bool is_serial = needs_serialized_reduction(op, input->data_type(), axis);
73 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX) || (op == ReductionOperation::ARG_IDX_MIN);
74 const bool is_reshape_required = !keep_dims || is_arg_min_max;
75
76 if(is_reshape_required)
77 {
78 const TensorInfo expected_output_shape = output->clone()->set_tensor_shape(arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis, keep_dims));
79 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&expected_output_shape, output);
80 }
81
82 auto *output_internal = output;
83
84 TensorInfo output_before_reshape;
85 const auto input_shape = input->tensor_shape();
86 const auto input_data_type = input->data_type();
87 const auto input_num_channles = input->num_channels();
88 const auto input_qinfo = input->quantization_info();
89 const auto output_data_type = is_arg_min_max ? DataType::S32 : output->data_type();
90
91 auto initialize_tensorinfo = [](TensorInfo & ti, TensorShape shape, DataType data_type, int num_channels, QuantizationInfo qinfo)
92 {
93 ti.set_data_type(data_type).set_tensor_shape(shape).set_num_channels(num_channels).set_quantization_info(qinfo);
94 };
95
96 if(is_reshape_required)
97 {
98 auto shape_before_reshape = input_shape;
99 shape_before_reshape.set(axis, 1);
100 initialize_tensorinfo(output_before_reshape, shape_before_reshape, output_data_type, input_num_channles, input_qinfo);
101 output_internal = &output_before_reshape;
102 }
103
Jenkins514be652019-02-28 12:25:18 +0000104 if(is_serial)
105 {
Jenkins0e205f72019-11-28 16:53:35 +0000106 ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, output_internal, axis, op));
Jenkins514be652019-02-28 12:25:18 +0000107 }
108 else
Jenkinsb3a371b2018-05-23 11:36:53 +0100109 {
Jenkinsb9abeae2018-11-22 11:58:08 +0000110 // Create temporary tensor infos
Jenkins4ba87db2019-05-23 17:11:51 +0100111 std::vector<TensorInfo> sums_vector(num_of_stages - 1);
Jenkinsb9abeae2018-11-22 11:58:08 +0000112
113 // Create intermediate tensor info
Jenkins0e205f72019-11-28 16:53:35 +0000114 TensorShape shape{ input_shape };
115
116 shape.set(0, ceil(shape.x() / 128.f));
Jenkinsb9abeae2018-11-22 11:58:08 +0000117
118 for(unsigned int i = 0; i < num_of_stages - 1; i++)
119 {
Jenkins0e205f72019-11-28 16:53:35 +0000120 initialize_tensorinfo(sums_vector[i], shape, input_data_type, input_num_channles, input_qinfo);
Jenkinsb9abeae2018-11-22 11:58:08 +0000121 }
122
123 ReductionOperation first_kernel_op;
Jenkins514be652019-02-28 12:25:18 +0000124 ReductionOperation intermediate_kernel_op;
Jenkinsb9abeae2018-11-22 11:58:08 +0000125 ReductionOperation last_kernel_op;
126 switch(op)
127 {
128 case ReductionOperation::SUM:
129 case ReductionOperation::MEAN_SUM:
Jenkins514be652019-02-28 12:25:18 +0000130 first_kernel_op = ReductionOperation::SUM;
131 intermediate_kernel_op = ReductionOperation::SUM;
132 last_kernel_op = op;
Jenkinsb9abeae2018-11-22 11:58:08 +0000133 break;
134 case ReductionOperation::SUM_SQUARE:
Jenkins514be652019-02-28 12:25:18 +0000135 first_kernel_op = ReductionOperation::SUM_SQUARE;
136 intermediate_kernel_op = ReductionOperation::SUM;
137 last_kernel_op = ReductionOperation::SUM;
138 break;
139 case ReductionOperation::PROD:
140 first_kernel_op = ReductionOperation::PROD;
141 intermediate_kernel_op = ReductionOperation::PROD;
142 last_kernel_op = ReductionOperation::PROD;
Jenkinsb9abeae2018-11-22 11:58:08 +0000143 break;
Jenkins975dfe12019-09-02 11:47:54 +0100144 case ReductionOperation::MIN:
145 first_kernel_op = ReductionOperation::MIN;
146 intermediate_kernel_op = ReductionOperation::MIN;
147 last_kernel_op = ReductionOperation::MIN;
148 break;
149 case ReductionOperation::MAX:
150 first_kernel_op = ReductionOperation::MAX;
151 intermediate_kernel_op = ReductionOperation::MAX;
152 last_kernel_op = ReductionOperation::MAX;
153 break;
Jenkinsb9abeae2018-11-22 11:58:08 +0000154 default:
155 ARM_COMPUTE_ERROR("Not supported");
156 }
157
158 // Validate ReductionOperation only on first kernel
Jenkins4ba87db2019-05-23 17:11:51 +0100159 ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, &sums_vector[0], axis, first_kernel_op));
Jenkinsb9abeae2018-11-22 11:58:08 +0000160
161 // Validate ReductionOperation on intermediate stages
162 for(unsigned int i = 1; i < num_of_stages - 1; ++i)
163 {
Jenkins4ba87db2019-05-23 17:11:51 +0100164 ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(&sums_vector[i - 1], &sums_vector[i], axis, intermediate_kernel_op));
Jenkinsb9abeae2018-11-22 11:58:08 +0000165 }
166
167 // Validate ReductionOperation on the last stage
168 const unsigned int last_stage = num_of_stages - 1;
Jenkins0e205f72019-11-28 16:53:35 +0000169 ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(&sums_vector[last_stage - 1], output_internal, axis, last_kernel_op, input->dimension(0)));
170 }
171
172 if(is_reshape_required)
173 {
174 ARM_COMPUTE_RETURN_ON_ERROR(CLReshapeLayerKernel::validate(output_internal, output));
Jenkinsb3a371b2018-05-23 11:36:53 +0100175 }
Jenkinsb3a371b2018-05-23 11:36:53 +0100176
Jenkinsb3a371b2018-05-23 11:36:53 +0100177 return Status{};
178}
179
Jenkins0e205f72019-11-28 16:53:35 +0000180ICLTensor *CLReductionOperation::configure_intermediate_result_vector(ICLTensor *input, ICLTensor *output)
Kaizen8938bd32017-09-28 14:38:23 +0100181{
Jenkins0e205f72019-11-28 16:53:35 +0000182 if(!_is_reshape_required && _is_serial)
183 {
184 return output;
185 }
186
187 auto intermediate_result_vector_size = _is_serial ? 1 : _num_of_stages;
188 const auto is_arg_min_max = (_op == ReductionOperation::ARG_IDX_MAX || _op == ReductionOperation::ARG_IDX_MIN);
189
190 if(!_is_reshape_required)
191 {
192 --intermediate_result_vector_size;
193 }
194
195 _results_vector.resize(intermediate_result_vector_size);
196 auto shape = input->info()->tensor_shape();
197
198 shape.set(_reduction_axis, _is_serial ? 1 : ceil(shape.x() / 128.f));
199
200 for(auto &v : _results_vector)
201 {
202 if(&v == &_results_vector.back() && _is_reshape_required)
203 {
204 shape.set(_reduction_axis, 1);
205 }
206 v.allocator()->init(input->info()->clone()->set_tensor_shape(shape));
207 }
208
209 if(is_arg_min_max)
210 {
211 _results_vector.back().info()->set_data_type(DataType::S32).set_is_resizable(true).reset_padding();
212 }
213
214 return _is_reshape_required ? &_results_vector.back() : output;
215}
216
217void CLReductionOperation::configure(ICLTensor *input, ICLTensor *output, unsigned int axis, ReductionOperation op, bool keep_dims)
218{
219 _op = op;
220 _num_of_stages = calculate_number_of_stages(input->info(), axis);
221 _reduction_axis = axis;
222 _is_serial = needs_serialized_reduction(op, input->info()->data_type(), axis);
223 const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX) || (op == ReductionOperation::ARG_IDX_MIN);
224 _is_reshape_required = !keep_dims || is_arg_min_max;
225
226 auto *output_internal = configure_intermediate_result_vector(input, output);
227
228 // ArgMinMax might not give initialized output tensor, so initialize here.
229 if(_is_reshape_required)
230 {
231 const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->info()->tensor_shape(), axis, false);
232 const auto output_data_type = is_arg_min_max ? DataType::S32 : input->info()->data_type();
233 auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
234 }
Kaizen8938bd32017-09-28 14:38:23 +0100235
236 // Configure reduction operation kernels
Jenkins4ba87db2019-05-23 17:11:51 +0100237 _reduction_kernels_vector.resize(_num_of_stages);
Kaizen8938bd32017-09-28 14:38:23 +0100238
Jenkinsb9abeae2018-11-22 11:58:08 +0000239 // Create temporary tensors
Jenkins514be652019-02-28 12:25:18 +0000240 if(_is_serial)
241 {
Jenkins0e205f72019-11-28 16:53:35 +0000242 if(_is_reshape_required)
243 {
244 _memory_group.manage(&_results_vector.back());
245 }
246
247 _reduction_kernels_vector[0].configure(input, output_internal, axis, op, 0);
Jenkins514be652019-02-28 12:25:18 +0000248 }
249 else
Kaizen8938bd32017-09-28 14:38:23 +0100250 {
Jenkins4ba87db2019-05-23 17:11:51 +0100251 _border_handlers_vector.resize(_num_of_stages);
Jenkins4ba87db2019-05-23 17:11:51 +0100252 _memory_group.manage(&_results_vector[0]);
Jenkinsb9abeae2018-11-22 11:58:08 +0000253
254 ReductionOperation first_kernel_op;
Jenkins514be652019-02-28 12:25:18 +0000255 ReductionOperation intermediate_kernel_op;
Jenkinsb9abeae2018-11-22 11:58:08 +0000256 ReductionOperation last_kernel_op;
Jenkins514be652019-02-28 12:25:18 +0000257 PixelValue pixelValue;
Jenkinsb9abeae2018-11-22 11:58:08 +0000258 switch(op)
259 {
260 case ReductionOperation::SUM:
261 case ReductionOperation::MEAN_SUM:
Jenkins514be652019-02-28 12:25:18 +0000262 first_kernel_op = ReductionOperation::SUM;
263 intermediate_kernel_op = ReductionOperation::SUM;
264 last_kernel_op = op;
265 pixelValue = PixelValue();
Jenkinsb9abeae2018-11-22 11:58:08 +0000266 break;
267 case ReductionOperation::SUM_SQUARE:
Jenkins514be652019-02-28 12:25:18 +0000268 first_kernel_op = ReductionOperation::SUM_SQUARE;
269 intermediate_kernel_op = ReductionOperation::SUM;
270 last_kernel_op = ReductionOperation::SUM;
271 pixelValue = PixelValue();
272 break;
273 case ReductionOperation::PROD:
274 first_kernel_op = ReductionOperation::PROD;
275 intermediate_kernel_op = ReductionOperation::PROD;
276 last_kernel_op = ReductionOperation::PROD;
277 pixelValue = PixelValue(1, input->info()->data_type());
Jenkinsb9abeae2018-11-22 11:58:08 +0000278 break;
Jenkins975dfe12019-09-02 11:47:54 +0100279 case ReductionOperation::MIN:
280 first_kernel_op = ReductionOperation::MIN;
281 intermediate_kernel_op = ReductionOperation::MIN;
282 last_kernel_op = ReductionOperation::MIN;
283 switch(input->info()->data_type())
284 {
285 case DataType::F32:
286 {
287 pixelValue = PixelValue(std::numeric_limits<float>::max());
288 break;
289 }
290 case DataType::F16:
291 {
292 pixelValue = PixelValue(static_cast<half>(65504.0f));
293 break;
294 }
295 case DataType::QASYMM8:
296 {
297 pixelValue = PixelValue(255, input->info()->data_type(), input->info()->quantization_info());
298 break;
299 }
300 default:
301 {
302 ARM_COMPUTE_ERROR("Unsupported DataType");
303 }
304 }
305 break;
306 case ReductionOperation::MAX:
307 first_kernel_op = ReductionOperation::MAX;
308 intermediate_kernel_op = ReductionOperation::MAX;
309 last_kernel_op = ReductionOperation::MAX;
310 switch(input->info()->data_type())
311 {
312 case DataType::F32:
313 {
314 pixelValue = PixelValue(-std::numeric_limits<float>::max());
315 break;
316 }
317 case DataType::F16:
318 {
319 pixelValue = PixelValue(static_cast<half>(-65504.0f));
320 break;
321 }
322 case DataType::QASYMM8:
323 {
324 pixelValue = PixelValue(0, input->info()->data_type(), input->info()->quantization_info());
325 break;
326 }
327 default:
328 {
329 ARM_COMPUTE_ERROR("Unsupported DataType");
330 }
331 }
332 break;
Jenkinsb9abeae2018-11-22 11:58:08 +0000333 default:
334 ARM_COMPUTE_ERROR("Not supported");
335 }
336
Jenkins4ba87db2019-05-23 17:11:51 +0100337 _reduction_kernels_vector[0].configure(input, &_results_vector[0], axis, first_kernel_op);
Jenkins514be652019-02-28 12:25:18 +0000338 _border_handlers_vector[0].configure(input, _reduction_kernels_vector[0].border_size(), BorderMode::CONSTANT, pixelValue);
Jenkinsb9abeae2018-11-22 11:58:08 +0000339
340 // Apply ReductionOperation on intermediate stages
341 for(unsigned int i = 1; i < _num_of_stages - 1; ++i)
342 {
Jenkins4ba87db2019-05-23 17:11:51 +0100343 _memory_group.manage(&_results_vector[i]);
344 _reduction_kernels_vector[i].configure(&_results_vector[i - 1], &_results_vector[i], axis, intermediate_kernel_op);
345 _border_handlers_vector[i].configure(&_results_vector[i - 1], _reduction_kernels_vector[i].border_size(), BorderMode::CONSTANT, pixelValue);
Jenkins514be652019-02-28 12:25:18 +0000346 _results_vector[i - 1].allocator()->allocate();
Jenkinsb9abeae2018-11-22 11:58:08 +0000347 }
348
349 // Apply ReductionOperation on the last stage
350 const unsigned int last_stage = _num_of_stages - 1;
351 const unsigned int input_width = input->info()->dimension(0);
Jenkins0e205f72019-11-28 16:53:35 +0000352
353 if(_is_reshape_required)
354 {
355 _memory_group.manage(&_results_vector.back());
356 }
357
358 _reduction_kernels_vector[last_stage].configure(&_results_vector[last_stage - 1], output_internal, axis, last_kernel_op, input_width);
Jenkins4ba87db2019-05-23 17:11:51 +0100359 _border_handlers_vector[last_stage].configure(&_results_vector[last_stage - 1], _reduction_kernels_vector[last_stage].border_size(), BorderMode::CONSTANT, pixelValue);
Jenkins514be652019-02-28 12:25:18 +0000360 _results_vector[last_stage - 1].allocator()->allocate();
Kaizen8938bd32017-09-28 14:38:23 +0100361 }
Jenkins0e205f72019-11-28 16:53:35 +0000362
363 if(_is_reshape_required)
364 {
365 _reshape_kernel.configure(&_results_vector.back(), output);
366 _results_vector.back().allocator()->allocate();
367 }
Kaizen8938bd32017-09-28 14:38:23 +0100368}
369
370void CLReductionOperation::run()
371{
Jenkins4ba87db2019-05-23 17:11:51 +0100372 MemoryGroupResourceScope scope_mg(_memory_group);
Kaizen8938bd32017-09-28 14:38:23 +0100373
Jenkins514be652019-02-28 12:25:18 +0000374 if(_is_serial)
375 {
376 CLScheduler::get().enqueue(_reduction_kernels_vector[0], false);
377 }
378 else
Kaizen8938bd32017-09-28 14:38:23 +0100379 {
Jenkinsb9abeae2018-11-22 11:58:08 +0000380 for(unsigned int i = 0; i < _num_of_stages; ++i)
381 {
382 CLScheduler::get().enqueue(_border_handlers_vector[i], false);
383 CLScheduler::get().enqueue(_reduction_kernels_vector[i], false);
384 }
385 }
Jenkins0e205f72019-11-28 16:53:35 +0000386
387 if(_is_reshape_required)
388 {
389 CLScheduler::get().enqueue(_reshape_kernel, false);
390 }
Jenkinsb3a371b2018-05-23 11:36:53 +0100391}
Jenkins0e205f72019-11-28 16:53:35 +0000392} // namespace arm_compute