blob: 7aa771428de6a2950aa73854f834354dffc4492e [file] [log] [blame]
/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/runtime/CL/functions/CLDeconvolutionLayer.h"
25
Anthony Barbierf45d5a92018-01-24 16:23:15 +000026#include "arm_compute/core/Utils.h"
27#include "arm_compute/core/Validate.h"
28#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Jenkins4ba87db2019-05-23 17:11:51 +010029#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
Jenkinsb9abeae2018-11-22 11:58:08 +000030#include "arm_compute/runtime/CL/CLScheduler.h"
Anthony Barbierf45d5a92018-01-24 16:23:15 +000031
Jenkins4ba87db2019-05-23 17:11:51 +010032#include <cmath>
Anthony Barbierf45d5a92018-01-24 16:23:15 +000033#include <memory>
34#include <tuple>
35
36using namespace arm_compute;
37using namespace arm_compute::misc::shape_calculator;
38
Jenkins4ba87db2019-05-23 17:11:51 +010039CLDeconvolutionLayer::CLDeconvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
40 : _memory_manager(std::move(memory_manager)), _function()
Anthony Barbierf45d5a92018-01-24 16:23:15 +000041{
42}
43
Jenkins4ba87db2019-05-23 17:11:51 +010044void CLDeconvolutionLayer::configure(ICLTensor *input, ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const PadStrideInfo &deconv_info,
Jenkins975dfe12019-09-02 11:47:54 +010045 const WeightsInfo &weights_info)
Jenkins4ba87db2019-05-23 17:11:51 +010046{
47 ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
Jenkins4ba87db2019-05-23 17:11:51 +010048
49 switch(CLDeconvolutionLayer::get_deconvolution_method(input->info(), weights->info(), nullptr, output->info(), deconv_info, weights_info))
50 {
51 case DeconvolutionMethod::DIRECT:
52 {
53 auto f = arm_compute::support::cpp14::make_unique<CLDirectDeconvolutionLayer>();
54 f->configure(input, weights, bias, output, deconv_info, weights_info);
55 _function = std::move(f);
56 break;
57 }
58 case DeconvolutionMethod::GEMM:
59 {
60 auto f = arm_compute::support::cpp14::make_unique<CLGEMMDeconvolutionLayer>(_memory_manager);
61 f->configure(input, weights, bias, output, deconv_info);
62 _function = std::move(f);
63 break;
64 }
65 default:
66 ARM_COMPUTE_ERROR("Not supported.");
67 break;
68 }
69}
70
71Status CLDeconvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias, ITensorInfo *output, const PadStrideInfo &deconv_info,
Jenkins975dfe12019-09-02 11:47:54 +010072 const WeightsInfo &weights_info)
Anthony Barbierf45d5a92018-01-24 16:23:15 +000073{
74 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
Jenkins4ba87db2019-05-23 17:11:51 +010075 switch(CLDeconvolutionLayer::get_deconvolution_method(input, weights, bias, output, deconv_info, weights_info))
76 {
77 case DeconvolutionMethod::DIRECT:
78 {
79 // Validate direct convolution layer
80 ARM_COMPUTE_RETURN_ON_ERROR(CLDirectDeconvolutionLayer::validate(input, weights, bias, output, deconv_info, weights_info));
81 break;
82 }
83 case DeconvolutionMethod::GEMM:
84 {
85 // Validate gemm-based convolution layer
86 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMDeconvolutionLayer::validate(input, weights, bias, output, deconv_info));
87 break;
88 }
89 default:
90 ARM_COMPUTE_ERROR("Not supported.");
91 break;
92 }
93
94 return Status{};
95}
96
97DeconvolutionMethod CLDeconvolutionLayer::get_deconvolution_method(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias, ITensorInfo *output, const PadStrideInfo &deconv_info,
98 const WeightsInfo &weights_info)
99{
100 ARM_COMPUTE_UNUSED(output, bias, weights_info);
Jenkinsb9abeae2018-11-22 11:58:08 +0000101
102 const DataLayout data_layout = input->data_layout();
103
104 const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
105 const size_t idx_h = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
Jenkinsb9abeae2018-11-22 11:58:08 +0000106
Jenkins4ba87db2019-05-23 17:11:51 +0100107 if(weights->dimension(idx_w) != deconv_info.stride().first || weights->dimension(idx_h) != deconv_info.stride().second)
Anthony Barbierf45d5a92018-01-24 16:23:15 +0000108 {
Jenkins4ba87db2019-05-23 17:11:51 +0100109 return DeconvolutionMethod::DIRECT;
Anthony Barbierf45d5a92018-01-24 16:23:15 +0000110 }
111
Jenkins4ba87db2019-05-23 17:11:51 +0100112 return DeconvolutionMethod::GEMM;
Anthony Barbierf45d5a92018-01-24 16:23:15 +0000113}
114
Anthony Barbierf45d5a92018-01-24 16:23:15 +0000115void CLDeconvolutionLayer::run()
116{
Jenkins52ba29e2018-08-29 15:32:11 +0000117 prepare();
Jenkins4ba87db2019-05-23 17:11:51 +0100118 _function->run();
Anthony Barbierf45d5a92018-01-24 16:23:15 +0000119}
Jenkins52ba29e2018-08-29 15:32:11 +0000120
121void CLDeconvolutionLayer::prepare()
122{
Jenkins4ba87db2019-05-23 17:11:51 +0100123 _function->prepare();
Jenkins52ba29e2018-08-29 15:32:11 +0000124}