blob: 8d460142e571e3086dc8f4af2a2cca6a1b817021 [file] [log] [blame]
/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/runtime/CL/functions/CLGEMM.h"
25
Jenkins4ba87db2019-05-23 17:11:51 +010026#include "arm_compute/core/CL/ICLGEMMKernelConfiguration.h"
Anthony Barbier871448e2017-03-24 14:54:29 +000027#include "arm_compute/core/CL/ICLTensor.h"
Jenkins4ba87db2019-05-23 17:11:51 +010028#include "arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
29#include "arm_compute/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
Anthony Barbier871448e2017-03-24 14:54:29 +000030#include "arm_compute/core/Error.h"
Jenkinsb3a371b2018-05-23 11:36:53 +010031#include "arm_compute/core/GPUTarget.h"
Anthony Barbier871448e2017-03-24 14:54:29 +000032#include "arm_compute/core/Helpers.h"
Jenkins975dfe12019-09-02 11:47:54 +010033#include "arm_compute/core/KernelDescriptors.h"
Anthony Barbier871448e2017-03-24 14:54:29 +000034#include "arm_compute/core/TensorInfo.h"
35#include "arm_compute/core/Types.h"
Jenkinsb3a371b2018-05-23 11:36:53 +010036#include "arm_compute/core/Utils.h"
Anthony Barbier871448e2017-03-24 14:54:29 +000037#include "arm_compute/core/Validate.h"
Jenkins975dfe12019-09-02 11:47:54 +010038#include "arm_compute/core/utils/helpers/float_ops.h"
Jenkins0e205f72019-11-28 16:53:35 +000039#include "arm_compute/core/utils/misc/Cast.h"
Jenkinsb3a371b2018-05-23 11:36:53 +010040#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Anthony Barbier871448e2017-03-24 14:54:29 +000041#include "arm_compute/runtime/CL/CLScheduler.h"
42#include "arm_compute/runtime/ITensorAllocator.h"
43
Jenkins514be652019-02-28 12:25:18 +000044namespace arm_compute
45{
Jenkinsb3a371b2018-05-23 11:36:53 +010046using namespace arm_compute::misc::shape_calculator;
Jenkins514be652019-02-28 12:25:18 +000047using namespace arm_compute::cl_gemm;
Jenkins0e205f72019-11-28 16:53:35 +000048using namespace arm_compute::utils::cast;
Anthony Barbier871448e2017-03-24 14:54:29 +000049
Jenkins0e205f72019-11-28 16:53:35 +000050CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
Jenkinsb9abeae2018-11-22 11:58:08 +000051 : _memory_group(std::move(memory_manager)),
Jenkins0e205f72019-11-28 16:53:35 +000052 _weights_manager(weights_manager),
Jenkinsb9abeae2018-11-22 11:58:08 +000053 _mm_kernel(),
Jenkins514be652019-02-28 12:25:18 +000054 _reshape_lhs_kernel(),
55 _reshape_rhs_kernel(),
Jenkins0e205f72019-11-28 16:53:35 +000056 _reshape_rhs_kernel_managed(),
Jenkins514be652019-02-28 12:25:18 +000057 _mm_reshaped_kernel(),
Jenkins4ba87db2019-05-23 17:11:51 +010058 _mm_reshaped_only_rhs_kernel(),
Jenkinsb9abeae2018-11-22 11:58:08 +000059 _tmp_a(),
60 _tmp_b(),
61 _original_b(nullptr),
Jenkinsb9abeae2018-11-22 11:58:08 +000062 _reshape_b_only_on_first_run(false),
Jenkins514be652019-02-28 12:25:18 +000063 _is_prepared(false),
Jenkins4ba87db2019-05-23 17:11:51 +010064 _gemm_type(GEMMType::NATIVE)
Anthony Barbier871448e2017-03-24 14:54:29 +000065{
66}
67
Jenkins4ba87db2019-05-23 17:11:51 +010068CLGEMM::GEMMType CLGEMM::select_gemm_type(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run, GPUTarget gpu_target)
Anthony Barbier871448e2017-03-24 14:54:29 +000069{
Jenkins4ba87db2019-05-23 17:11:51 +010070 GEMMType gemm_type = GEMMType::RESHAPED_V1;
Anthony Barbier871448e2017-03-24 14:54:29 +000071
Jenkins0e205f72019-11-28 16:53:35 +000072 if(gpu_target_is_in(gpu_target, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT,
73 GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72,
74 GPUTarget::G76, GPUTarget::G77))
Jenkins4ba87db2019-05-23 17:11:51 +010075 {
Jenkins0e205f72019-11-28 16:53:35 +000076 if(data_type == DataType::F32)
Jenkins4ba87db2019-05-23 17:11:51 +010077 {
Jenkins0e205f72019-11-28 16:53:35 +000078 if((m > 1) && (n < 16))
Jenkins4ba87db2019-05-23 17:11:51 +010079 {
Jenkins0e205f72019-11-28 16:53:35 +000080 gemm_type = GEMMType::RESHAPED_V1;
81 }
82 else if(m == 1)
83 {
84 gemm_type = GEMMType::RESHAPED_ONLY_RHS;
Jenkins4ba87db2019-05-23 17:11:51 +010085 }
86 else
87 {
Jenkins0e205f72019-11-28 16:53:35 +000088 // COMPMID-852
89 if((k > 256) && (m > 4) && reshape_b_only_on_first_run)
90 {
91 constexpr float alpha = 3.2f;
92 constexpr float fact0 = 1.51f;
93 constexpr float fact1 = 1.66f;
94 constexpr float ops = 12.0f;
95 const float scale = k > 1024 ? 1.07f : 1.0f;
96 gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
97 }
98 else
99 {
100 gemm_type = GEMMType::NATIVE;
101 }
102 }
103
104 const auto workload = static_cast<float>((m * n) / 20.0f);
105
106 gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type;
107 }
108 else
109 {
110 if((m == 1) || (!reshape_b_only_on_first_run))
111 {
112 gemm_type = GEMMType::RESHAPED_ONLY_RHS;
113 }
114 else
115 {
116 gemm_type = GEMMType::RESHAPED_V2;
Jenkins4ba87db2019-05-23 17:11:51 +0100117 }
118 }
Jenkins4ba87db2019-05-23 17:11:51 +0100119 }
120 else
121 {
122 // We reshape the matrices only if we do not have the vector-by-matrix case and we reshape the matrix B only once
123 gemm_type = ((m != 1) && reshape_b_only_on_first_run) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
124 }
Kaizen8938bd32017-09-28 14:38:23 +0100125
Jenkins4ba87db2019-05-23 17:11:51 +0100126 return gemm_type;
127}
128
129void CLGEMM::configure_native(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
130{
131 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
132 const unsigned int n = b->info()->dimension(0);
133 const unsigned int k = a->info()->dimension(0);
134 const GPUTarget gpu_target = CLScheduler::get().target();
Anthony Barbier06ea0482018-02-22 15:45:35 +0000135
136 // Set the target for the kernels
Anthony Barbier06ea0482018-02-22 15:45:35 +0000137 _mm_kernel.set_target(gpu_target);
138
Jenkins975dfe12019-09-02 11:47:54 +0100139 GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d(), gemm_info.broadcast_bias());
Jenkins4ba87db2019-05-23 17:11:51 +0100140
141 // Configure and tune matrix multiply kernel
Jenkins975dfe12019-09-02 11:47:54 +0100142 _mm_kernel.configure(a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Jenkins4ba87db2019-05-23 17:11:51 +0100143
144 // Tune kernel statically
145 CLScheduler::get().tune_kernel_static(_mm_kernel);
146}
147
148void CLGEMM::configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
149{
Jenkins514be652019-02-28 12:25:18 +0000150 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
151 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
152 const unsigned int n = b->info()->dimension(0);
153 const unsigned int k = a->info()->dimension(0);
Jenkins514be652019-02-28 12:25:18 +0000154 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Jenkins4ba87db2019-05-23 17:11:51 +0100155 const GPUTarget gpu_target = CLScheduler::get().target();
Jenkins514be652019-02-28 12:25:18 +0000156 int mult_transpose1xW_width = 1;
157 int mult_interleave4x4_height = 1;
Anthony Barbier06ea0482018-02-22 15:45:35 +0000158
Jenkins4ba87db2019-05-23 17:11:51 +0100159 // Set the target for the kernels
160 _reshape_lhs_kernel.set_target(gpu_target);
161 _mm_kernel.set_target(gpu_target);
162
Jenkinsb3a371b2018-05-23 11:36:53 +0100163 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
Anthony Barbier06ea0482018-02-22 15:45:35 +0000164 {
165 mult_transpose1xW_width = 4;
166 mult_interleave4x4_height = 2;
167 }
Jenkins4ba87db2019-05-23 17:11:51 +0100168
Jenkins514be652019-02-28 12:25:18 +0000169 GEMMRHSMatrixInfo rhs_info;
170 rhs_info.n0 = 16 / b->info()->element_size();
171 rhs_info.k0 = 1;
172 rhs_info.h0 = mult_transpose1xW_width;
173 rhs_info.interleave = false;
174 rhs_info.transpose = false;
175
176 GEMMLHSMatrixInfo lhs_info;
177 lhs_info.m0 = 4;
178 lhs_info.k0 = 4;
179 lhs_info.v0 = mult_interleave4x4_height;
180 lhs_info.interleave = true;
181 lhs_info.transpose = true;
Anthony Barbier06ea0482018-02-22 15:45:35 +0000182
Jenkins975dfe12019-09-02 11:47:54 +0100183 GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Anthony Barbier8140e1e2017-12-14 23:48:46 +0000184
Jenkins0e205f72019-11-28 16:53:35 +0000185 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
186
187 // Manage intermediate buffers
Jenkins4ba87db2019-05-23 17:11:51 +0100188 _memory_group.manage(&_tmp_a);
Jenkins0e205f72019-11-28 16:53:35 +0000189
190 if(!_reshape_b_only_on_first_run && use_mm_b)
Jenkins52ba29e2018-08-29 15:32:11 +0000191 {
Jenkins4ba87db2019-05-23 17:11:51 +0100192 _memory_group.manage(&_tmp_b);
Kaizen8938bd32017-09-28 14:38:23 +0100193 }
Anthony Barbier871448e2017-03-24 14:54:29 +0000194
Jenkins4ba87db2019-05-23 17:11:51 +0100195 // Configure interleave kernel
196 _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, reinterpret_input_as_3d);
Kaizen8938bd32017-09-28 14:38:23 +0100197
Jenkins4ba87db2019-05-23 17:11:51 +0100198 // Configure transpose kernel
Jenkins0e205f72019-11-28 16:53:35 +0000199 ICLTensor *reshaped_rhs = &_tmp_b;
200 if(_weights_manager && _weights_manager->are_weights_managed(b))
201 {
202 _reshape_rhs_kernel_managed.configure(b, rhs_info);
203 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
204 }
205 else
206 {
207 _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
208 }
Anthony Barbier871448e2017-03-24 14:54:29 +0000209
Jenkins4ba87db2019-05-23 17:11:51 +0100210 // Configure and tune matrix multiply kernel
Jenkins0e205f72019-11-28 16:53:35 +0000211 _mm_kernel.configure(&_tmp_a, reshaped_rhs, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Jenkins4ba87db2019-05-23 17:11:51 +0100212
213 CLScheduler::get().tune_kernel_static(_mm_kernel);
214
215 // Allocate intermediate tensors
216 _tmp_a.allocator()->allocate();
Jenkins0e205f72019-11-28 16:53:35 +0000217
218 if(!_reshape_b_only_on_first_run && use_mm_b)
Anthony Barbier871448e2017-03-24 14:54:29 +0000219 {
Jenkins4ba87db2019-05-23 17:11:51 +0100220 _tmp_b.allocator()->allocate();
Anthony Barbier871448e2017-03-24 14:54:29 +0000221 }
222}
223
Jenkins4ba87db2019-05-23 17:11:51 +0100224void CLGEMM::configure_reshaped_v2(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
225{
Jenkins4ba87db2019-05-23 17:11:51 +0100226 DataType data_type = a->info()->data_type();
227 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
228 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
229 const unsigned int n = b->info()->dimension(0);
230 const unsigned int k = a->info()->dimension(0);
231 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
232 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
233 const GPUTarget gpu_target = CLScheduler::get().target();
Jenkins975dfe12019-09-02 11:47:54 +0100234 bool broadcast_bias = gemm_info.broadcast_bias();
235
236 GEMMKernelInfo kernel_info;
237 kernel_info.m = m;
238 kernel_info.n = n;
239 kernel_info.k = k;
240 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
241 kernel_info.reinterpret_input_as_3d = false;
242 kernel_info.broadcast_bias = broadcast_bias;
243 kernel_info.activation_info = gemm_info.activation_info();
Jenkins4ba87db2019-05-23 17:11:51 +0100244
245 // Set the target for the kernels
246 _reshape_lhs_kernel.set_target(gpu_target);
247 _mm_kernel.set_target(gpu_target);
248
Jenkins0e205f72019-11-28 16:53:35 +0000249 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
250
Jenkins4ba87db2019-05-23 17:11:51 +0100251 // Manage intermediate buffers
252 _memory_group.manage(&_tmp_a);
Jenkins0e205f72019-11-28 16:53:35 +0000253
254 if(!_reshape_b_only_on_first_run && use_mm_b)
Jenkins4ba87db2019-05-23 17:11:51 +0100255 {
256 _memory_group.manage(&_tmp_b);
257 }
Jenkins0e205f72019-11-28 16:53:35 +0000258
Jenkins4ba87db2019-05-23 17:11:51 +0100259 // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
260
261 GEMMLHSMatrixInfo lhs_info{};
262 GEMMRHSMatrixInfo rhs_info{};
263
264 // Pick up the GEMM configuration
265 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
266 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
267
268 // Configure lhs_info and rhs_info
269 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
270
271 _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
Jenkins0e205f72019-11-28 16:53:35 +0000272
273 ICLTensor *reshaped_rhs = &_tmp_b;
274 if(_weights_manager && _weights_manager->are_weights_managed(b))
275 {
276 _reshape_rhs_kernel_managed.configure(b, rhs_info);
277 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
278 }
279 else
280 {
281 _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
282 }
Jenkins4ba87db2019-05-23 17:11:51 +0100283
284 // Configure and tune matrix multiply kernel
Jenkins0e205f72019-11-28 16:53:35 +0000285 _mm_reshaped_kernel.configure(&_tmp_a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Jenkins4ba87db2019-05-23 17:11:51 +0100286
287 // Allocate intermediate tensors
288 _tmp_a.allocator()->allocate();
Jenkins0e205f72019-11-28 16:53:35 +0000289
290 if(!_reshape_b_only_on_first_run && use_mm_b)
Jenkins4ba87db2019-05-23 17:11:51 +0100291 {
292 _tmp_b.allocator()->allocate();
293 }
294}
295
296void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
297{
Jenkins4ba87db2019-05-23 17:11:51 +0100298 DataType data_type = a->info()->data_type();
299 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
300 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
301 const unsigned int n = b->info()->dimension(0);
302 const unsigned int k = a->info()->dimension(0);
303 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
304 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
305 const GPUTarget gpu_target = CLScheduler::get().target();
Jenkins975dfe12019-09-02 11:47:54 +0100306 bool broadcast_bias = gemm_info.broadcast_bias();
307
308 GEMMKernelInfo kernel_info;
309 kernel_info.m = m;
310 kernel_info.n = n;
311 kernel_info.k = k;
312 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
313 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
314 kernel_info.broadcast_bias = broadcast_bias;
315 kernel_info.activation_info = gemm_info.activation_info();
Jenkins4ba87db2019-05-23 17:11:51 +0100316
317 // Set the target for the kernels
318 _mm_kernel.set_target(gpu_target);
319
Jenkins0e205f72019-11-28 16:53:35 +0000320 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
321
Jenkins4ba87db2019-05-23 17:11:51 +0100322 // Manage intermediate buffers
Jenkins0e205f72019-11-28 16:53:35 +0000323 if(!_reshape_b_only_on_first_run && use_mm_b)
Jenkins4ba87db2019-05-23 17:11:51 +0100324 {
325 _memory_group.manage(&_tmp_b);
326 }
327
328 GEMMLHSMatrixInfo lhs_info{};
329 GEMMRHSMatrixInfo rhs_info{};
330
331 // Pick up the GEMM configuration
332 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
333 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
334
335 // Configure lhs_info and rhs_info
336 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
337
Jenkins0e205f72019-11-28 16:53:35 +0000338 ICLTensor *reshaped_rhs = &_tmp_b;
339 if(_weights_manager && _weights_manager->are_weights_managed(b))
340 {
341 _reshape_rhs_kernel_managed.configure(b, rhs_info);
342 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
343 }
344 else
345 {
346 _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
347 }
Jenkins4ba87db2019-05-23 17:11:51 +0100348
349 // Configure and tune matrix multiply kernel
Jenkins0e205f72019-11-28 16:53:35 +0000350 _mm_reshaped_only_rhs_kernel.configure(a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Jenkins4ba87db2019-05-23 17:11:51 +0100351
Jenkins0e205f72019-11-28 16:53:35 +0000352 if(!_reshape_b_only_on_first_run && use_mm_b)
Jenkins4ba87db2019-05-23 17:11:51 +0100353 {
354 _tmp_b.allocator()->allocate();
355 }
356}
357
358Status CLGEMM::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Anthony Barbier06ea0482018-02-22 15:45:35 +0000359{
Jenkinsb3a371b2018-05-23 11:36:53 +0100360 ARM_COMPUTE_UNUSED(alpha);
Jenkins52ba29e2018-08-29 15:32:11 +0000361 ARM_COMPUTE_UNUSED(output);
Jenkinsb3a371b2018-05-23 11:36:53 +0100362
Jenkins4ba87db2019-05-23 17:11:51 +0100363 // Get the GPU target
364 const GPUTarget gpu_target = CLScheduler::get().target();
365 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
366 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
367 const unsigned int n = b->dimension(0);
368 const unsigned int k = a->dimension(0);
369 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Jenkinsb3a371b2018-05-23 11:36:53 +0100370
Jenkins975dfe12019-09-02 11:47:54 +0100371 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, gemm_info.broadcast_bias());
Jenkins4ba87db2019-05-23 17:11:51 +0100372
373 // Validate matrix multiply
Jenkins975dfe12019-09-02 11:47:54 +0100374 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, c, output, alpha, beta,
375 false, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Jenkins4ba87db2019-05-23 17:11:51 +0100376
377 return Status{};
378}
379
380Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
381{
382 ARM_COMPUTE_UNUSED(alpha);
383 ARM_COMPUTE_UNUSED(output);
Jenkinsb3a371b2018-05-23 11:36:53 +0100384
385 TensorInfo tmp_a_info{};
386 TensorInfo tmp_b_info{};
Jenkinsb3a371b2018-05-23 11:36:53 +0100387
388 // Get the GPU target
Jenkins4ba87db2019-05-23 17:11:51 +0100389 const GPUTarget gpu_target = CLScheduler::get().target();
390 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
Jenkins514be652019-02-28 12:25:18 +0000391 const unsigned int n = b->dimension(0);
392 const unsigned int k = a->dimension(0);
Jenkins514be652019-02-28 12:25:18 +0000393 int mult_transpose1xW_width = 1;
394 int mult_interleave4x4_height = 1;
395 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Jenkinsb3a371b2018-05-23 11:36:53 +0100396
397 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
398 {
399 mult_transpose1xW_width = 4;
400 mult_interleave4x4_height = 2;
401 }
402
Jenkins514be652019-02-28 12:25:18 +0000403 GEMMRHSMatrixInfo rhs_info;
404 rhs_info.n0 = 16 / b->element_size();
405 rhs_info.k0 = 1;
406 rhs_info.h0 = mult_transpose1xW_width;
407 rhs_info.interleave = false;
408 rhs_info.transpose = false;
409
410 GEMMLHSMatrixInfo lhs_info;
411 lhs_info.m0 = 4;
412 lhs_info.k0 = 4;
413 lhs_info.v0 = mult_interleave4x4_height;
414 lhs_info.interleave = true;
415 lhs_info.transpose = true;
416
Jenkins975dfe12019-09-02 11:47:54 +0100417 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Jenkinsb3a371b2018-05-23 11:36:53 +0100418
Jenkins4ba87db2019-05-23 17:11:51 +0100419 // Validate interleave kernel
420 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
421 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
Jenkins514be652019-02-28 12:25:18 +0000422
Jenkins4ba87db2019-05-23 17:11:51 +0100423 // Validate transpose kernel
424 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
425 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
Jenkins514be652019-02-28 12:25:18 +0000426
Jenkins4ba87db2019-05-23 17:11:51 +0100427 // Validate matrix multiply
Jenkins975dfe12019-09-02 11:47:54 +0100428 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta,
429 true, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Jenkinsb3a371b2018-05-23 11:36:53 +0100430
Anthony Barbier06ea0482018-02-22 15:45:35 +0000431 return Status{};
432}
433
Jenkins4ba87db2019-05-23 17:11:51 +0100434Status CLGEMM::validate_reshaped_v2(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
435{
436 ARM_COMPUTE_UNUSED(alpha);
437 ARM_COMPUTE_UNUSED(output);
438
439 TensorInfo tmp_a_info{};
440 TensorInfo tmp_b_info{};
441
442 // Get the GPU target
443 const GPUTarget gpu_target = CLScheduler::get().target();
444 DataType data_type = a->data_type();
445 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
446 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
447 const unsigned int n = b->dimension(0);
448 const unsigned int k = a->dimension(0);
449 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
450 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Jenkins975dfe12019-09-02 11:47:54 +0100451 const bool broadcast_bias = gemm_info.broadcast_bias();
Jenkins4ba87db2019-05-23 17:11:51 +0100452
Jenkins975dfe12019-09-02 11:47:54 +0100453 GEMMKernelInfo kernel_info;
454 kernel_info.m = m;
455 kernel_info.n = n;
456 kernel_info.k = k;
457 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
458 kernel_info.reinterpret_input_as_3d = false;
459 kernel_info.broadcast_bias = broadcast_bias;
460 kernel_info.activation_info = gemm_info.activation_info();
Jenkins4ba87db2019-05-23 17:11:51 +0100461
462 GEMMLHSMatrixInfo lhs_info;
463 GEMMRHSMatrixInfo rhs_info;
464
465 // Pick up the GEMM configuration
466 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
467 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
468
469 // Configure lhs_info and rhs_info
470 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
471
472 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
473 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
474
475 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
476 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
477
478 // Validate matrix multiply
Jenkins975dfe12019-09-02 11:47:54 +0100479 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Jenkins4ba87db2019-05-23 17:11:51 +0100480
481 return Status{};
482}
483
484Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
485{
486 ARM_COMPUTE_UNUSED(alpha);
487 ARM_COMPUTE_UNUSED(output);
488
489 TensorInfo tmp_b_info{};
490
491 // Get the GPU target
492 const GPUTarget gpu_target = CLScheduler::get().target();
493 const DataType data_type = a->data_type();
494 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
495 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
496 const unsigned int n = b->dimension(0);
497 const unsigned int k = a->dimension(0);
498 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
499 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Jenkins975dfe12019-09-02 11:47:54 +0100500 const bool broadcast_bias = gemm_info.broadcast_bias();
Jenkins4ba87db2019-05-23 17:11:51 +0100501
Jenkins975dfe12019-09-02 11:47:54 +0100502 GEMMKernelInfo kernel_info;
503 kernel_info.m = m;
504 kernel_info.n = n;
505 kernel_info.k = k;
506 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
507 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
508 kernel_info.broadcast_bias = broadcast_bias;
509 kernel_info.activation_info = gemm_info.activation_info();
Jenkins4ba87db2019-05-23 17:11:51 +0100510
511 GEMMLHSMatrixInfo lhs_info;
512 GEMMRHSMatrixInfo rhs_info;
513
514 // Pick up the GEMM configuration
515 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
516 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
517
518 // Configure lhs_info and rhs_info
519 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
520
521 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
522 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
523
524 // Validate matrix multiply
Jenkins975dfe12019-09-02 11:47:54 +0100525 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Jenkins4ba87db2019-05-23 17:11:51 +0100526
527 return Status{};
528}
529
530void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
531{
532 ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
533
534 // Perform validation step
535 ARM_COMPUTE_ERROR_THROW_ON(validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, output->info(), alpha, beta, gemm_info));
536
537 // Check if we need to reshape the matrix B only on the first run
538 _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
539 _is_prepared = gemm_info.retain_internal_weights();
540 _original_b = b;
541
542 // Get the GPU target
543 const GPUTarget gpu_target = CLScheduler::get().target();
544 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
545 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
546 const unsigned int n = b->info()->dimension(0);
547 const unsigned int k = a->info()->dimension(0);
548
549 // Select GEMMType
550 _gemm_type = select_gemm_type(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target);
551
Jenkins975dfe12019-09-02 11:47:54 +0100552 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
553
554 const ICLTensor *c_to_use = fuse_add_c ? c : nullptr;
Jenkins4ba87db2019-05-23 17:11:51 +0100555
556 switch(_gemm_type)
557 {
558 case GEMMType::NATIVE:
559 {
Jenkins975dfe12019-09-02 11:47:54 +0100560 configure_native(a, b, c_to_use, output, alpha, beta, gemm_info);
Jenkins4ba87db2019-05-23 17:11:51 +0100561 break;
562 }
563 case GEMMType::RESHAPED_V1:
564 {
Jenkins975dfe12019-09-02 11:47:54 +0100565 configure_reshaped_v1(a, b, c_to_use, output, alpha, beta, gemm_info);
Jenkins4ba87db2019-05-23 17:11:51 +0100566 break;
567 }
568 case GEMMType::RESHAPED_V2:
569 {
Jenkins975dfe12019-09-02 11:47:54 +0100570 configure_reshaped_v2(a, b, c_to_use, output, alpha, beta, gemm_info);
Jenkins4ba87db2019-05-23 17:11:51 +0100571 break;
572 }
573 case GEMMType::RESHAPED_ONLY_RHS:
574 {
Jenkins975dfe12019-09-02 11:47:54 +0100575 configure_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info);
Jenkins4ba87db2019-05-23 17:11:51 +0100576 break;
577 }
578 default:
579 {
580 ARM_COMPUTE_ERROR("GEMMType not supported");
581 }
582 }
Jenkins4ba87db2019-05-23 17:11:51 +0100583}
584
585Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
586{
587 // Get the GPU target
588 const GPUTarget gpu_target = CLScheduler::get().target();
589 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
590 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
591 const unsigned int n = b->dimension(0);
592 const unsigned int k = a->dimension(0);
593
594 // Select GEMMType
595 GEMMType gemm_type = select_gemm_type(m, n, k, a->data_type(), gemm_info.reshape_b_only_on_first_run(), gpu_target);
596
Jenkins975dfe12019-09-02 11:47:54 +0100597 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
598
599 const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
600
Jenkins4ba87db2019-05-23 17:11:51 +0100601 switch(gemm_type)
602 {
603 case GEMMType::NATIVE:
604 {
Jenkins975dfe12019-09-02 11:47:54 +0100605 ARM_COMPUTE_RETURN_ON_ERROR(validate_native(a, b, c_to_use, output, alpha, beta, gemm_info));
Jenkins4ba87db2019-05-23 17:11:51 +0100606 break;
607 }
608 case GEMMType::RESHAPED_V1:
609 {
Jenkins975dfe12019-09-02 11:47:54 +0100610 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
Jenkins4ba87db2019-05-23 17:11:51 +0100611 break;
612 }
613 case GEMMType::RESHAPED_V2:
614 {
Jenkins975dfe12019-09-02 11:47:54 +0100615 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_v2(a, b, c_to_use, output, alpha, beta, gemm_info));
Jenkins4ba87db2019-05-23 17:11:51 +0100616 break;
617 }
618 case GEMMType::RESHAPED_ONLY_RHS:
619 {
Jenkins975dfe12019-09-02 11:47:54 +0100620 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
Jenkins4ba87db2019-05-23 17:11:51 +0100621 break;
622 }
623 default:
624 {
625 ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
626 }
627 }
628
629 return Status{};
630}
631
Anthony Barbier871448e2017-03-24 14:54:29 +0000632void CLGEMM::run()
633{
Jenkinsb3a371b2018-05-23 11:36:53 +0100634 prepare();
635
Jenkins4ba87db2019-05-23 17:11:51 +0100636 MemoryGroupResourceScope scope_mg(_memory_group);
Anthony Barbier871448e2017-03-24 14:54:29 +0000637
638 // Run matrix multiply kernel
Jenkins4ba87db2019-05-23 17:11:51 +0100639 switch(_gemm_type)
Jenkins514be652019-02-28 12:25:18 +0000640 {
Jenkins4ba87db2019-05-23 17:11:51 +0100641 case GEMMType::NATIVE:
642 {
Jenkins975dfe12019-09-02 11:47:54 +0100643 CLScheduler::get().enqueue(_mm_kernel, true);
Jenkins4ba87db2019-05-23 17:11:51 +0100644 break;
645 }
646 case GEMMType::RESHAPED_V1:
647 {
648 // Run interleave kernel
649 CLScheduler::get().enqueue(_reshape_lhs_kernel, false);
650
651 if(!_reshape_b_only_on_first_run)
652 {
653 // Run transpose kernel
Jenkins0e205f72019-11-28 16:53:35 +0000654 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
655 {
656 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
657 }
658 else
659 {
660 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
661 }
Jenkins4ba87db2019-05-23 17:11:51 +0100662 }
663
Jenkins975dfe12019-09-02 11:47:54 +0100664 CLScheduler::get().enqueue(_mm_kernel, true);
Jenkins4ba87db2019-05-23 17:11:51 +0100665 break;
666 }
667 case GEMMType::RESHAPED_V2:
668 {
669 // Run interleave kernel
670 CLScheduler::get().enqueue(_reshape_lhs_kernel, false);
671
672 if(!_reshape_b_only_on_first_run)
673 {
674 // Run transpose kernel
Jenkins0e205f72019-11-28 16:53:35 +0000675 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
676 {
677 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
678 }
679 else
680 {
681 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
682 }
Jenkins4ba87db2019-05-23 17:11:51 +0100683 }
684
Jenkins975dfe12019-09-02 11:47:54 +0100685 CLScheduler::get().enqueue(_mm_reshaped_kernel, true);
Jenkins4ba87db2019-05-23 17:11:51 +0100686 break;
687 }
688 case GEMMType::RESHAPED_ONLY_RHS:
689 {
690 if(!_reshape_b_only_on_first_run)
691 {
692 // Run transpose kernel
Jenkins0e205f72019-11-28 16:53:35 +0000693 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
694 {
695 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
696 }
697 else
698 {
699 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
700 }
Jenkins4ba87db2019-05-23 17:11:51 +0100701 }
702
Jenkins975dfe12019-09-02 11:47:54 +0100703 CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, true);
Jenkins4ba87db2019-05-23 17:11:51 +0100704 break;
705 }
706 default:
707 {
708 ARM_COMPUTE_ERROR("GEMMType not supported");
709 }
Jenkins514be652019-02-28 12:25:18 +0000710 }
Anthony Barbier871448e2017-03-24 14:54:29 +0000711}
Jenkinsb3a371b2018-05-23 11:36:53 +0100712
713void CLGEMM::prepare()
714{
715 if(!_is_prepared)
716 {
Jenkins4ba87db2019-05-23 17:11:51 +0100717 if(_gemm_type != GEMMType::NATIVE && _reshape_b_only_on_first_run)
Jenkinsb3a371b2018-05-23 11:36:53 +0100718 {
Jenkins0e205f72019-11-28 16:53:35 +0000719 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
720 {
721 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
722 }
723 else
724 {
725 // Run transpose kernel and mark original weights tensor as unused
726 _tmp_b.allocator()->allocate();
727 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
728 _original_b->mark_as_unused();
729 }
Jenkinsb3a371b2018-05-23 11:36:53 +0100730 }
731 CLScheduler::get().queue().finish();
732 _is_prepared = true;
733 }
734}
Jenkins514be652019-02-28 12:25:18 +0000735} // namespace arm_compute