blob: 058b6027c2259945f134a5a213e659ba91695ea5 [file] [log] [blame]
Jenkinsb3a371b2018-05-23 11:36:53 +01001/*
Jenkins18b685f2020-08-21 10:26:22 +01002 * Copyright (c) 2018-2020 Arm Limited.
Jenkinsb3a371b2018-05-23 11:36:53 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/functions/CLLSTMLayer.h"
25
Jenkinsb3a371b2018-05-23 11:36:53 +010026#include "arm_compute/core/Utils.h"
27#include "arm_compute/core/Validate.h"
Jenkins6a7771e2020-05-28 11:28:36 +010028#include "arm_compute/core/utils/misc/InfoHelpers.h"
Jenkinsb3a371b2018-05-23 11:36:53 +010029#include "arm_compute/core/utils/misc/ShapeCalculator.h"
30#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
31#include "arm_compute/runtime/CL/CLScheduler.h"
32
Jenkins6a7771e2020-05-28 11:28:36 +010033namespace arm_compute
34{
Jenkinsb3a371b2018-05-23 11:36:53 +010035using namespace arm_compute::misc::shape_calculator;
Jenkins6a7771e2020-05-28 11:28:36 +010036using namespace arm_compute::utils::info_helpers;
Jenkinsb3a371b2018-05-23 11:36:53 +010037
/** Construct the LSTM layer function.
 *
 * Only value-initialises every internal sub-function and intermediate tensor; kernels,
 * tensor shapes and memory lifetimes are set up later by configure().
 * The boolean option flags all start as false; configure() sets those that describe
 * the selected LSTM variant (peephole, CIFG, cell clipping, projection, projection
 * clipping, layer normalisation).
 *
 * @param[in] memory_manager Memory manager moved into the internal memory group, which
 *                           configure() uses to manage the intermediate tensors.
 */
CLLSTMLayer::CLLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
    : _memory_group(std::move(memory_manager)), _fully_connected_input_gate(), _accum_input_gate1(), _subtract_input_gate(), _pixelwise_mul_input_gate(), _activation_input_gate(),
      _fully_connected_forget_gate(), _accum_forget_gate1(), _pixelwise_mul_forget_gate(), _activation_forget_gate(), _fully_connected_cell_state(), _gemm_cell_state1(), _transpose_cell_state(),
      _accum_cell_state1(), _accum_cell_state2(), _pixelwise_mul_cell_state1(), _activation_cell_state(), _cell_clip(), _pixelwise_mul_cell_state2(), _fully_connected_output(),
      _pixelwise_mul_output_state1(), _accum_output1(), _activation_output(), _activation_output_state(), _pixelwise_mul_output_state2(), _fully_connected_output_state(), _projection_clip(),
      _copy_cell_state(), _copy_output(), _concat_scratch_buffer(), _concat_inputs_forget_gate(), _concat_weights_forget_gate(), _concat_weights_input_gate(), _concat_weights_output(),
      _ones_memset_kernel(), _mean_std_norm_input_gate(), _pixelwise_mul_input_gate_coeff(), _accum_input_gate_bias(), _mean_std_norm_forget_gate(), _pixelwise_mul_forget_gate_coeff(),
      _accum_forget_gate_bias(), _mean_std_norm_cell_gate(), _pixelwise_mul_cell_gate_coeff(), _accum_cell_gate_bias(), _mean_std_norm_output_gate(), _pixelwise_mul_output_gate_coeff(),
      _accum_output_gate_bias(), _input_gate_out1(), _input_gate_out2(), _input_gate_out3(), _input_gate_out4(), _forget_gate_out1(), _forget_gate_out2(), _forget_gate_out3(), _forget_gate_out4(),
      _forget_gate_out5(), _forget_gate_out6(), _cell_state_out1(), _cell_state_out2(), _cell_state_out3(), _cell_state_out4(), _cell_state_out5(), _output1(), _output2(), _output3(), _output4(),
      _cell_state_activation(), _output_state1(), _ones(), _input_layer_norm_out1(), _input_layer_norm_out2(), _forget_layer_norm_out1(), _forget_layer_norm_out2(), _cell_layer_norm_out1(),
      // Option flags: all disabled until configure() inspects the LSTM parameters.
      _cell_layer_norm_out2(), _output_layer_norm_out1(), _output_layer_norm_out2(), _run_peephole_opt(false), _run_cifg_opt(false), _perform_cell_clipping(false), _has_projection_weights(false),
      _perform_projection_clipping(false), _is_prepared(false), _is_layer_norm_lstm(false)
{
}
53
Jenkins52ba29e2018-08-29 15:32:11 +000054void CLLSTMLayer::configure(const ICLTensor *input,
55 const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
Jenkinsb3a371b2018-05-23 11:36:53 +010056 const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
57 const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
Jenkins18b685f2020-08-21 10:26:22 +010058 const ICLTensor *output_state_in, ICLTensor *cell_state_in,
Jenkins52ba29e2018-08-29 15:32:11 +000059 ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
60 const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
Jenkinsb3a371b2018-05-23 11:36:53 +010061{
Jenkins6a7771e2020-05-28 11:28:36 +010062 configure(CLKernelLibrary::get().get_compile_context(), input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
63 recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, output_state_in, cell_state_in, scratch_buffer, output_state_out, cell_state_out, output, lstm_params, activation_info,
64 cell_threshold, projection_threshold);
65}
66
67void CLLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input,
68 const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
69 const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
70 const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
Jenkins18b685f2020-08-21 10:26:22 +010071 const ICLTensor *output_state_in, ICLTensor *cell_state_in,
Jenkins6a7771e2020-05-28 11:28:36 +010072 ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
73 const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
74{
Jenkins52ba29e2018-08-29 15:32:11 +000075 ARM_COMPUTE_ERROR_ON_NULLPTR(input,
76 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
77 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
78 forget_gate_bias, cell_bias, output_gate_bias,
79 output_state_in, cell_state_in,
80 scratch_buffer, output_state_out, cell_state_out, output);
81
Jenkins975dfe12019-09-02 11:47:54 +010082 _is_layer_norm_lstm = lstm_params.use_layer_norm();
83
Jenkins52ba29e2018-08-29 15:32:11 +000084 // Set lstm parameters
Jenkins6a7771e2020-05-28 11:28:36 +010085 LSTMParams<ITensorInfo> lstm_params_info{};
86 build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
Jenkins52ba29e2018-08-29 15:32:11 +000087
88 // Validate
Jenkinsb3a371b2018-05-23 11:36:53 +010089 ARM_COMPUTE_ERROR_THROW_ON(CLLSTMLayer::validate(input->info(), input_to_forget_weights->info(),
90 input_to_cell_weights->info(), input_to_output_weights->info(),
91 recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
92 forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
Jenkins52ba29e2018-08-29 15:32:11 +000093 output_state_in->info(), cell_state_in->info(),
94 scratch_buffer->info(), output_state_out->info(), cell_state_out->info(), output->info(),
95 lstm_params_info, activation_info, cell_threshold, projection_threshold));
Jenkinsb3a371b2018-05-23 11:36:53 +010096
Jenkins52ba29e2018-08-29 15:32:11 +000097 const TensorShape cell_state_shape = cell_state_in->info()->tensor_shape();
Jenkins52ba29e2018-08-29 15:32:11 +000098 // Configure block that calculates the forget gate
99 // forget_gate = Activation(input * input_to_forget_weights + output_state_in * recurrent_to_forget_weights + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
Jenkins4ba87db2019-05-23 17:11:51 +0100100 // We optimize this as follows:
101 // forget_gate = Activation( (input,output_state_in) * (input_to_forget_weights,recurrent_to_forget_weights) + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias
Jenkinsb3a371b2018-05-23 11:36:53 +0100102 _forget_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Jenkinsb3a371b2018-05-23 11:36:53 +0100103 _forget_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Jenkins52ba29e2018-08-29 15:32:11 +0000104 _forget_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Jenkinsb3a371b2018-05-23 11:36:53 +0100105
Jenkins4ba87db2019-05-23 17:11:51 +0100106 std::vector<const ICLTensor *> inputs_vector;
107 inputs_vector.emplace_back(input);
108 inputs_vector.emplace_back(output_state_in);
109 const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
110 _forget_gate_out2.allocator()->init(TensorInfo(concat_shape, 1, input->info()->data_type()));
111
Jenkinsb3a371b2018-05-23 11:36:53 +0100112 _memory_group.manage(&_forget_gate_out2);
Jenkins18b685f2020-08-21 10:26:22 +0100113 _concat_inputs_forget_gate.configure(compile_context, inputs_vector, &_forget_gate_out2, Window::DimX);
Jenkins4ba87db2019-05-23 17:11:51 +0100114
115 std::vector<const ICLTensor *> weights_vector;
116
117 weights_vector.emplace_back(input_to_forget_weights);
118 weights_vector.emplace_back(recurrent_to_forget_weights);
119 const TensorShape weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(weights_vector, 0);
120 _forget_gate_out6.allocator()->init(TensorInfo(weights_concat_shape, 1, input->info()->data_type()));
121
Jenkins18b685f2020-08-21 10:26:22 +0100122 _concat_weights_forget_gate.configure(compile_context, weights_vector, &_forget_gate_out6, Window::DimX);
Jenkins4ba87db2019-05-23 17:11:51 +0100123
Jenkins52ba29e2018-08-29 15:32:11 +0000124 _memory_group.manage(&_forget_gate_out5);
Jenkins6a7771e2020-05-28 11:28:36 +0100125 _fully_connected_forget_gate.configure(compile_context, &_forget_gate_out2, &_forget_gate_out6, (_is_layer_norm_lstm) ? nullptr : forget_gate_bias, &_forget_gate_out5);
Jenkins4ba87db2019-05-23 17:11:51 +0100126 _memory_group.manage(&_forget_gate_out1);
127 _memory_group.manage(&_forget_gate_out3);
128 _forget_gate_out6.allocator()->allocate();
129
Jenkins52ba29e2018-08-29 15:32:11 +0000130 CLTensor *forget_gate_out = &_forget_gate_out5;
Jenkinsb3a371b2018-05-23 11:36:53 +0100131 if(lstm_params.has_peephole_opt())
132 {
Jenkins52ba29e2018-08-29 15:32:11 +0000133 _forget_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Jenkinsb3a371b2018-05-23 11:36:53 +0100134
135 _run_peephole_opt = true;
136 _memory_group.manage(&_forget_gate_out4);
Jenkins6a7771e2020-05-28 11:28:36 +0100137 _pixelwise_mul_forget_gate.configure(compile_context, cell_state_in, lstm_params.cell_to_forget_weights(), &_forget_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
138 _accum_forget_gate1.configure(compile_context, &_forget_gate_out5, &_forget_gate_out4, &_forget_gate_out3, ConvertPolicy::SATURATE);
Jenkinsb3a371b2018-05-23 11:36:53 +0100139 _forget_gate_out4.allocator()->allocate();
Jenkinsb3a371b2018-05-23 11:36:53 +0100140 _forget_gate_out5.allocator()->allocate();
Jenkinsb3a371b2018-05-23 11:36:53 +0100141 forget_gate_out = &_forget_gate_out3;
142 }
143 else
144 {
145 _forget_gate_out3.allocator()->allocate();
146 }
Jenkins975dfe12019-09-02 11:47:54 +0100147 if(_is_layer_norm_lstm)
148 {
149 _forget_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
150 _forget_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
151 _memory_group.manage(&_forget_layer_norm_out1);
152 _memory_group.manage(&_forget_layer_norm_out2);
Jenkins6a7771e2020-05-28 11:28:36 +0100153 _mean_std_norm_forget_gate.configure(compile_context, forget_gate_out);
154 _pixelwise_mul_forget_gate_coeff.configure(compile_context, forget_gate_out, lstm_params.forget_layer_norm_weights(), &_forget_layer_norm_out1, 1, ConvertPolicy::SATURATE,
155 RoundingPolicy::TO_NEAREST_EVEN);
Jenkins975dfe12019-09-02 11:47:54 +0100156 // forget_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
157 forget_gate_out->allocator()->allocate();
Jenkins18b685f2020-08-21 10:26:22 +0100158 _accum_forget_gate_bias.configure(compile_context, &_forget_layer_norm_out1, forget_gate_bias, &_forget_layer_norm_out2, ConvertPolicy::SATURATE);
Jenkins975dfe12019-09-02 11:47:54 +0100159 _forget_layer_norm_out1.allocator()->allocate();
160 forget_gate_out = &_forget_layer_norm_out2;
161 }
Jenkins6a7771e2020-05-28 11:28:36 +0100162 _activation_forget_gate.configure(compile_context, forget_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Jenkinsb3a371b2018-05-23 11:36:53 +0100163
Jenkinsb3a371b2018-05-23 11:36:53 +0100164 // Configure block that calculates the input gate
Jenkins52ba29e2018-08-29 15:32:11 +0000165 // input_gate = Activation(input * input_to_input_weights + output_state * recurrent_to_input_weights + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
Jenkinsb3a371b2018-05-23 11:36:53 +0100166 // input_gate = 1 - forget_gate, with CIFG
Jenkins4ba87db2019-05-23 17:11:51 +0100167 // We optimize this as follows:
168 // input_gate = Activation((input,output_state) * (input_to_input_weights,recurrent_to_input_weights) + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
Jenkins52ba29e2018-08-29 15:32:11 +0000169 _input_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Jenkins514be652019-02-28 12:25:18 +0000170 CLTensor *input_gate_out = &_input_gate_out1;
Jenkinsb3a371b2018-05-23 11:36:53 +0100171 if(lstm_params.has_cifg_opt())
172 {
173 _memory_group.manage(&_input_gate_out1);
174 _ones.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Jenkins6a7771e2020-05-28 11:28:36 +0100175 _ones_memset_kernel.configure(compile_context, &_ones, PixelValue(1, _ones.info()->data_type()));
Jenkins18b685f2020-08-21 10:26:22 +0100176 _subtract_input_gate.configure(compile_context, &_ones, forget_gate_out, &_input_gate_out1, ConvertPolicy::SATURATE);
Jenkinsb3a371b2018-05-23 11:36:53 +0100177 _ones.allocator()->allocate();
178 _run_cifg_opt = true;
179 }
180 else
181 {
Jenkinsb3a371b2018-05-23 11:36:53 +0100182 _input_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Jenkins52ba29e2018-08-29 15:32:11 +0000183 _input_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Jenkins4ba87db2019-05-23 17:11:51 +0100184
185 std::vector<const ICLTensor *> lstm_weights;
186 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
187 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
188 TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
189 _input_gate_out2.allocator()->init(TensorInfo(lstm_weights_concat_shape, 1, input->info()->data_type()));
190
Jenkins18b685f2020-08-21 10:26:22 +0100191 _concat_weights_input_gate.configure(compile_context, lstm_weights, &_input_gate_out2, Window::DimX);
Jenkinsb3a371b2018-05-23 11:36:53 +0100192
193 _memory_group.manage(&_input_gate_out1);
Jenkins4ba87db2019-05-23 17:11:51 +0100194
Jenkinsb3a371b2018-05-23 11:36:53 +0100195 _memory_group.manage(&_input_gate_out3);
Jenkins6a7771e2020-05-28 11:28:36 +0100196 _fully_connected_input_gate.configure(compile_context, &_forget_gate_out2, &_input_gate_out2, (_is_layer_norm_lstm) ? nullptr : lstm_params.input_gate_bias(), &_input_gate_out3);
Jenkinsb3a371b2018-05-23 11:36:53 +0100197 _input_gate_out2.allocator()->allocate();
Jenkins4ba87db2019-05-23 17:11:51 +0100198
199 input_gate_out = &_input_gate_out3;
Jenkins52ba29e2018-08-29 15:32:11 +0000200 if(_run_peephole_opt)
201 {
Jenkins4ba87db2019-05-23 17:11:51 +0100202 _memory_group.manage(&_input_gate_out4);
Jenkins6a7771e2020-05-28 11:28:36 +0100203 _pixelwise_mul_input_gate.configure(compile_context, cell_state_in, lstm_params.cell_to_input_weights(), &_input_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
204 _accum_input_gate1.configure(compile_context, &_input_gate_out3, &_input_gate_out4, &_input_gate_out1, ConvertPolicy::SATURATE);
Jenkins4ba87db2019-05-23 17:11:51 +0100205 _input_gate_out3.allocator()->allocate();
Jenkins514be652019-02-28 12:25:18 +0000206 _input_gate_out4.allocator()->allocate();
Jenkins514be652019-02-28 12:25:18 +0000207 input_gate_out = &_input_gate_out1;
Jenkins52ba29e2018-08-29 15:32:11 +0000208 }
Jenkins514be652019-02-28 12:25:18 +0000209 else
210 {
211 _input_gate_out1.allocator()->allocate();
212 }
Jenkins975dfe12019-09-02 11:47:54 +0100213
214 if(_is_layer_norm_lstm)
215 {
216 _input_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
217 _input_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
218 _memory_group.manage(&_input_layer_norm_out1);
219 _memory_group.manage(&_input_layer_norm_out2);
Jenkins6a7771e2020-05-28 11:28:36 +0100220 _mean_std_norm_input_gate.configure(compile_context, input_gate_out);
221 _pixelwise_mul_input_gate_coeff.configure(compile_context, input_gate_out, lstm_params.input_layer_norm_weights(), &_input_layer_norm_out1, 1, ConvertPolicy::SATURATE,
222 RoundingPolicy::TO_NEAREST_EVEN);
Jenkins975dfe12019-09-02 11:47:54 +0100223 // input_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
224 input_gate_out->allocator()->allocate();
Jenkins18b685f2020-08-21 10:26:22 +0100225 _accum_input_gate_bias.configure(compile_context, &_input_layer_norm_out1, lstm_params.input_gate_bias(), &_input_layer_norm_out2, ConvertPolicy::SATURATE);
Jenkins975dfe12019-09-02 11:47:54 +0100226 _input_layer_norm_out1.allocator()->allocate();
227 input_gate_out = &_input_layer_norm_out2;
228 }
Jenkins6a7771e2020-05-28 11:28:36 +0100229 _activation_input_gate.configure(compile_context, input_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Jenkinsb3a371b2018-05-23 11:36:53 +0100230 }
231
Jenkins52ba29e2018-08-29 15:32:11 +0000232 // Configure block that calculates the cell state
233 // cell_state = Clip((PixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state_in * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold)
Jenkinsb3a371b2018-05-23 11:36:53 +0100234 TensorShape cell_state1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
235 _cell_state_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
236 _cell_state_out2.allocator()->init(TensorInfo(cell_state1_shape, 1, input->info()->data_type()));
237 _cell_state_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
238 _cell_state_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
239 _cell_state_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
240
Jenkinsb3a371b2018-05-23 11:36:53 +0100241 _memory_group.manage(&_cell_state_out1);
Jenkins6a7771e2020-05-28 11:28:36 +0100242 _fully_connected_cell_state.configure(compile_context, input, input_to_cell_weights, (_is_layer_norm_lstm) ? nullptr : cell_bias, &_cell_state_out1);
Jenkinsb3a371b2018-05-23 11:36:53 +0100243 _memory_group.manage(&_cell_state_out2);
Jenkins6a7771e2020-05-28 11:28:36 +0100244 _transpose_cell_state.configure(compile_context, recurrent_to_cell_weights, &_cell_state_out2);
Jenkinsb3a371b2018-05-23 11:36:53 +0100245 _memory_group.manage(&_cell_state_out3);
Jenkins6a7771e2020-05-28 11:28:36 +0100246 _gemm_cell_state1.configure(compile_context, output_state_in, &_cell_state_out2, nullptr, &_cell_state_out3, 1.f, 0.f);
Jenkinsb3a371b2018-05-23 11:36:53 +0100247 _cell_state_out2.allocator()->allocate();
248 _memory_group.manage(&_cell_state_out4);
Jenkins18b685f2020-08-21 10:26:22 +0100249 _accum_cell_state1.configure(compile_context, &_cell_state_out1, &_cell_state_out3, &_cell_state_out4, ConvertPolicy::SATURATE);
Jenkins975dfe12019-09-02 11:47:54 +0100250 CLTensor *cell_state_out_ptr = &_cell_state_out4;
251 if(_is_layer_norm_lstm)
252 {
253 _cell_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
254 _cell_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
255 _memory_group.manage(&_cell_layer_norm_out1);
256 _memory_group.manage(&_cell_layer_norm_out2);
Jenkins6a7771e2020-05-28 11:28:36 +0100257 _mean_std_norm_cell_gate.configure(compile_context, cell_state_out_ptr);
258 _pixelwise_mul_cell_gate_coeff.configure(compile_context, cell_state_out_ptr, lstm_params.cell_layer_norm_weights(), &_cell_layer_norm_out1, 1, ConvertPolicy::SATURATE,
259 RoundingPolicy::TO_NEAREST_EVEN);
Jenkins975dfe12019-09-02 11:47:54 +0100260 // cell_state_out_ptr is going to be reassigned, so allocate the tensor that it was assigned to before
261 cell_state_out_ptr->allocator()->allocate();
Jenkins18b685f2020-08-21 10:26:22 +0100262 _accum_cell_gate_bias.configure(compile_context, &_cell_layer_norm_out1, cell_bias, &_cell_layer_norm_out2, ConvertPolicy::SATURATE);
Jenkins975dfe12019-09-02 11:47:54 +0100263 _cell_layer_norm_out1.allocator()->allocate();
264 cell_state_out_ptr = &_cell_layer_norm_out2;
265 }
Jenkins6a7771e2020-05-28 11:28:36 +0100266 _activation_cell_state.configure(compile_context, cell_state_out_ptr, nullptr, activation_info);
Jenkinsb3a371b2018-05-23 11:36:53 +0100267 _memory_group.manage(&_cell_state_out5);
Jenkins6a7771e2020-05-28 11:28:36 +0100268 _pixelwise_mul_cell_state1.configure(compile_context, cell_state_out_ptr, input_gate_out, &_cell_state_out5, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Jenkins975dfe12019-09-02 11:47:54 +0100269 cell_state_out_ptr->allocator()->allocate();
Jenkins6a7771e2020-05-28 11:28:36 +0100270 _pixelwise_mul_cell_state2.configure(compile_context, forget_gate_out, cell_state_in, &_cell_state_out3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Jenkins18b685f2020-08-21 10:26:22 +0100271 _accum_cell_state2.configure(compile_context, &_cell_state_out5, &_cell_state_out3, &_cell_state_out1, ConvertPolicy::SATURATE);
Jenkinsb3a371b2018-05-23 11:36:53 +0100272 _cell_state_out3.allocator()->allocate();
273 _cell_state_out5.allocator()->allocate();
Jenkinsb3a371b2018-05-23 11:36:53 +0100274 // Perform clipping
275 if(cell_threshold != 0.f)
276 {
277 _perform_cell_clipping = true;
Jenkins6a7771e2020-05-28 11:28:36 +0100278 _cell_clip.configure(compile_context, &_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
Jenkinsb3a371b2018-05-23 11:36:53 +0100279 }
280
Jenkins52ba29e2018-08-29 15:32:11 +0000281 // Configure block that calculates the output
282 // output_state_out = Activation(input * input_to_output_weights + output_state_in * recurrent_to_output_weights + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
Jenkins4ba87db2019-05-23 17:11:51 +0100283 // We optimize this as follows:
284 // output_state_out = Activation( (input,output_state_in) * (input_to_output_weights, recurrent_to_output_weights) + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
Jenkinsb3a371b2018-05-23 11:36:53 +0100285 _output1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Jenkins4ba87db2019-05-23 17:11:51 +0100286 _output4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
287 std::vector<const ICLTensor *> in_out_weights;
288 in_out_weights.emplace_back(input_to_output_weights);
289 in_out_weights.emplace_back(recurrent_to_output_weights);
290 TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
291 _output2.allocator()->init(TensorInfo(in_out_weights_concat_shape, 1, input->info()->data_type()));
292
Jenkins18b685f2020-08-21 10:26:22 +0100293 _concat_weights_output.configure(compile_context, in_out_weights, &_output2, Window::DimX);
Jenkinsb3a371b2018-05-23 11:36:53 +0100294
Jenkinsb3a371b2018-05-23 11:36:53 +0100295 _memory_group.manage(&_output1);
Jenkins4ba87db2019-05-23 17:11:51 +0100296 _memory_group.manage(&_output4);
297
Jenkins6a7771e2020-05-28 11:28:36 +0100298 _fully_connected_output.configure(compile_context, &_forget_gate_out2, &_output2, (_is_layer_norm_lstm) ? nullptr : output_gate_bias, &_output4);
Jenkins4ba87db2019-05-23 17:11:51 +0100299
Jenkinsb3a371b2018-05-23 11:36:53 +0100300 _output2.allocator()->allocate();
Jenkins4ba87db2019-05-23 17:11:51 +0100301 _forget_gate_out2.allocator()->allocate();
302
303 CLTensor *output_gate_out = &_output4;
Jenkinsb3a371b2018-05-23 11:36:53 +0100304 if(lstm_params.has_peephole_opt())
305 {
Jenkins4ba87db2019-05-23 17:11:51 +0100306 _output3.allocator()->init(TensorInfo(_cell_state_out1.info()->tensor_shape(), 1, input->info()->data_type()));
Jenkinsb3a371b2018-05-23 11:36:53 +0100307
Jenkins4ba87db2019-05-23 17:11:51 +0100308 _memory_group.manage(&_output3);
Jenkins6a7771e2020-05-28 11:28:36 +0100309 _pixelwise_mul_output_state1.configure(compile_context, &_cell_state_out1, lstm_params.cell_to_output_weights(), &_output3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
310 _accum_output1.configure(compile_context, &_output4, &_output3, &_output1, ConvertPolicy::SATURATE);
Jenkins4ba87db2019-05-23 17:11:51 +0100311 _output4.allocator()->allocate();
Jenkinsb3a371b2018-05-23 11:36:53 +0100312 output_gate_out = &_output1;
313
314 // Allocate intermediate buffers
Jenkins4ba87db2019-05-23 17:11:51 +0100315 _output3.allocator()->allocate();
Jenkinsb3a371b2018-05-23 11:36:53 +0100316 }
317 else
318 {
319 _output1.allocator()->allocate();
320 }
Jenkins975dfe12019-09-02 11:47:54 +0100321 if(_is_layer_norm_lstm)
322 {
323 _output_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
324 _output_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
325 _memory_group.manage(&_output_layer_norm_out1);
326 _memory_group.manage(&_output_layer_norm_out2);
Jenkins6a7771e2020-05-28 11:28:36 +0100327 _mean_std_norm_output_gate.configure(compile_context, output_gate_out);
328 _pixelwise_mul_output_gate_coeff.configure(compile_context, output_gate_out, lstm_params.output_layer_norm_weights(), &_output_layer_norm_out1, 1, ConvertPolicy::SATURATE,
329 RoundingPolicy::TO_NEAREST_EVEN);
Jenkins975dfe12019-09-02 11:47:54 +0100330 // output_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
331 output_gate_out->allocator()->allocate();
Jenkins18b685f2020-08-21 10:26:22 +0100332 _accum_output_gate_bias.configure(compile_context, &_output_layer_norm_out1, output_gate_bias, &_output_layer_norm_out2, ConvertPolicy::SATURATE);
Jenkins975dfe12019-09-02 11:47:54 +0100333 _output_layer_norm_out1.allocator()->allocate();
334 output_gate_out = &_output_layer_norm_out2;
335 }
Jenkins6a7771e2020-05-28 11:28:36 +0100336 _activation_output.configure(compile_context, output_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Jenkinsb3a371b2018-05-23 11:36:53 +0100337
Jenkinsb3a371b2018-05-23 11:36:53 +0100338 // Configure block that calculates the output state
339 /** lstm_res = PixelwiseMul(output, Activation(cell_state))
340 *
341 * -- Clip(lstm_res * projection_weights + projection_bias, projection_threshold) , if there is a projection
342 * /
343 * output_state = --
344 * \
345 * -- lstm_res , otherwise
346 */
Jenkins52ba29e2018-08-29 15:32:11 +0000347 ICLTensor *output_state_out_tmp = lstm_params.has_projection() ? &_output_state1 : output_state_out;
348 _cell_state_activation.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
349 _output_state1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
350
Jenkinsb3a371b2018-05-23 11:36:53 +0100351 _memory_group.manage(&_cell_state_activation);
Jenkins6a7771e2020-05-28 11:28:36 +0100352 _activation_output_state.configure(compile_context, &_cell_state_out1, &_cell_state_activation, activation_info);
353 _pixelwise_mul_output_state2.configure(compile_context, &_cell_state_activation, output_gate_out, output_state_out_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Jenkinsb3a371b2018-05-23 11:36:53 +0100354 _cell_state_activation.allocator()->allocate();
355
356 if(lstm_params.has_projection())
357 {
358 _has_projection_weights = true;
Jenkins6a7771e2020-05-28 11:28:36 +0100359 _fully_connected_output_state.configure(compile_context, output_state_out_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out);
Jenkins52ba29e2018-08-29 15:32:11 +0000360 _output_state1.allocator()->allocate();
Jenkinsb3a371b2018-05-23 11:36:53 +0100361 // Perform clipping
362 if(projection_threshold != 0.f)
363 {
364 _perform_projection_clipping = true;
Jenkins6a7771e2020-05-28 11:28:36 +0100365 _projection_clip.configure(compile_context, output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
Jenkinsb3a371b2018-05-23 11:36:53 +0100366 }
Jenkinsb3a371b2018-05-23 11:36:53 +0100367 }
368
369 // Copy cell state and output
Jenkins6a7771e2020-05-28 11:28:36 +0100370 _copy_cell_state.configure(compile_context, &_cell_state_out1, cell_state_out);
371 _copy_output.configure(compile_context, output_state_out, output);
Jenkinsb3a371b2018-05-23 11:36:53 +0100372
373 // Vector for holding the tensors to store in scratch buffer
Jenkins18b685f2020-08-21 10:26:22 +0100374 std::vector<const ICLTensor *> scratch_inputs;
Jenkinsb9abeae2018-11-22 11:58:08 +0000375 if(!lstm_params.has_cifg_opt())
Jenkinsb3a371b2018-05-23 11:36:53 +0100376 {
Jenkins514be652019-02-28 12:25:18 +0000377 scratch_inputs.emplace_back(input_gate_out);
Jenkinsb3a371b2018-05-23 11:36:53 +0100378 }
379 scratch_inputs.emplace_back(&_cell_state_out1);
380 scratch_inputs.emplace_back(forget_gate_out);
381 scratch_inputs.emplace_back(output_gate_out);
Jenkins6a7771e2020-05-28 11:28:36 +0100382 _concat_scratch_buffer.configure(compile_context, scratch_inputs, scratch_buffer, Window::DimX);
Jenkins514be652019-02-28 12:25:18 +0000383 input_gate_out->allocator()->allocate();
Jenkinsb9abeae2018-11-22 11:58:08 +0000384 _cell_state_out1.allocator()->allocate();
385 forget_gate_out->allocator()->allocate();
386 output_gate_out->allocator()->allocate();
Jenkinsb3a371b2018-05-23 11:36:53 +0100387}
388
Jenkins52ba29e2018-08-29 15:32:11 +0000389Status CLLSTMLayer::validate(const ITensorInfo *input,
390 const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
Jenkinsb3a371b2018-05-23 11:36:53 +0100391 const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
392 const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
Jenkins52ba29e2018-08-29 15:32:11 +0000393 const ITensorInfo *output_state_in, const ITensorInfo *cell_state_in,
394 const ITensorInfo *scratch_buffer, const ITensorInfo *output_state_out, const ITensorInfo *cell_state_out, const ITensorInfo *output,
Jenkinsb3a371b2018-05-23 11:36:53 +0100395 const LSTMParams<ITensorInfo> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
396{
Jenkins52ba29e2018-08-29 15:32:11 +0000397 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input,
398 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
399 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
400 forget_gate_bias, cell_bias, output_gate_bias,
401 output_state_in, cell_state_in,
402 scratch_buffer, output_state_out, cell_state_out, output);
Jenkinsb3a371b2018-05-23 11:36:53 +0100403
Jenkins52ba29e2018-08-29 15:32:11 +0000404 // Check data types
405 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
406 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input,
407 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
408 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
409 forget_gate_bias, cell_bias, output_gate_bias,
410 output_state_in, cell_state_in,
411 scratch_buffer, output_state_out, cell_state_out, output);
412
413 // Check dimensions
414 ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
415 ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
416 ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
417 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
418 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
419 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
420 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
421 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
422 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
423 ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
424 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
425 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() > 2);
426 ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
427 ARM_COMPUTE_RETURN_ERROR_ON(output_state_out->num_dimensions() > 2);
428 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_out->num_dimensions() > 2);
429 ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
430 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0)
431 && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
432
433 const unsigned int num_batches = input->dimension(1);
434 const unsigned int num_cells = input_to_output_weights->dimension(1);
435
Jenkins975dfe12019-09-02 11:47:54 +0100436 if(lstm_params.use_layer_norm())
437 {
438 // If CIFG is used, input layer normalization weights tensor is omitted
439 if(lstm_params.has_cifg_opt())
440 {
441 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights() != nullptr);
442 }
443 else
444 {
445 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_layer_norm_weights());
446 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->num_dimensions() > 1);
Jenkins6a7771e2020-05-28 11:28:36 +0100447 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->dimension(0) != num_cells);
Jenkins975dfe12019-09-02 11:47:54 +0100448 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.input_layer_norm_weights());
449 }
450
451 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
452 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
453 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->num_dimensions() > 1);
454 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->num_dimensions() > 1);
455 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->num_dimensions() > 1);
Jenkins6a7771e2020-05-28 11:28:36 +0100456 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->dimension(0) != num_cells);
457 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->dimension(0) != num_cells);
458 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->dimension(0) != num_cells);
Jenkins975dfe12019-09-02 11:47:54 +0100459 }
460
Jenkins52ba29e2018-08-29 15:32:11 +0000461 // Check peephole optimization
Jenkinsb3a371b2018-05-23 11:36:53 +0100462 if(lstm_params.has_peephole_opt())
463 {
Jenkins52ba29e2018-08-29 15:32:11 +0000464 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
465 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
466 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
Jenkinsb3a371b2018-05-23 11:36:53 +0100467 }
468
469 TensorShape units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
Jenkinsb3a371b2018-05-23 11:36:53 +0100470 TensorShape num_units_transposed_shape = compute_transposed_shape(*forget_gate_bias);
471 const TensorInfo units_out_transposed_info = TensorInfo(units_out_transposed_shape, 1, input->data_type());
Jenkinsb3a371b2018-05-23 11:36:53 +0100472 const TensorInfo num_units_transposed_info = TensorInfo(num_units_transposed_shape, 1, input->data_type());
473
Jenkins52ba29e2018-08-29 15:32:11 +0000474 TensorInfo input_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
475 TensorInfo forget_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
476 TensorInfo output_gate_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
477 TensorInfo cell_state_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
478
Jenkinsb3a371b2018-05-23 11:36:53 +0100479 // Validate forget gate
Jenkins975dfe12019-09-02 11:47:54 +0100480 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_forget_weights, (lstm_params.use_layer_norm()) ? nullptr : forget_gate_bias, &forget_gate));
Jenkins4ba87db2019-05-23 17:11:51 +0100481
482 std::vector<const ITensorInfo *> inputs_vector;
483 inputs_vector.emplace_back(input);
484 inputs_vector.emplace_back(output_state_in);
485 const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
486 TensorInfo forget_gate_concat = TensorInfo(concat_shape, 1, input->data_type());
487
Jenkins18b685f2020-08-21 10:26:22 +0100488 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(inputs_vector, &forget_gate_concat, Window::DimX));
Jenkins4ba87db2019-05-23 17:11:51 +0100489
Jenkinsb3a371b2018-05-23 11:36:53 +0100490 if(lstm_params.has_peephole_opt())
491 {
Jenkins18b685f2020-08-21 10:26:22 +0100492 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
Jenkins52ba29e2018-08-29 15:32:11 +0000493 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Jenkinsb3a371b2018-05-23 11:36:53 +0100494 }
Jenkins975dfe12019-09-02 11:47:54 +0100495 if(lstm_params.use_layer_norm())
496 {
497 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&forget_gate));
Jenkins18b685f2020-08-21 10:26:22 +0100498 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE,
499 RoundingPolicy::TO_NEAREST_EVEN));
Jenkins975dfe12019-09-02 11:47:54 +0100500 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE));
501 }
Jenkins18b685f2020-08-21 10:26:22 +0100502 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Jenkinsb3a371b2018-05-23 11:36:53 +0100503
504 // Validate input gate
505 if(!lstm_params.has_cifg_opt())
506 {
Jenkins52ba29e2018-08-29 15:32:11 +0000507 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(),
508 lstm_params.recurrent_to_input_weights(),
509 lstm_params.input_gate_bias());
510 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
511 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
512 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);
513
Jenkins4ba87db2019-05-23 17:11:51 +0100514 std::vector<const ITensorInfo *> lstm_weights;
515 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
516 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
517 TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
518 TensorInfo lstm_gate_concat = TensorInfo(lstm_weights_concat_shape, 1, input->data_type());
Jenkins18b685f2020-08-21 10:26:22 +0100519 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(lstm_weights, &lstm_gate_concat, Window::DimX));
Jenkins4ba87db2019-05-23 17:11:51 +0100520
Jenkins975dfe12019-09-02 11:47:54 +0100521 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), (lstm_params.use_layer_norm()) ? nullptr : lstm_params.input_gate_bias(), &input_gate));
Jenkins4ba87db2019-05-23 17:11:51 +0100522
Jenkins52ba29e2018-08-29 15:32:11 +0000523 if(lstm_params.has_peephole_opt())
524 {
525 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
526 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
Jenkins18b685f2020-08-21 10:26:22 +0100527 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
Jenkins52ba29e2018-08-29 15:32:11 +0000528 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
529 }
Jenkins975dfe12019-09-02 11:47:54 +0100530
531 if(lstm_params.use_layer_norm())
532 {
533 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&input_gate));
Jenkins18b685f2020-08-21 10:26:22 +0100534 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
Jenkins975dfe12019-09-02 11:47:54 +0100535 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(), &input_gate, ConvertPolicy::SATURATE));
536 }
Jenkins18b685f2020-08-21 10:26:22 +0100537 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Jenkinsb3a371b2018-05-23 11:36:53 +0100538 }
539 else
540 {
Jenkins18b685f2020-08-21 10:26:22 +0100541 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticSubtraction::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Jenkinsb3a371b2018-05-23 11:36:53 +0100542 }
543
544 // Validate cell state
Jenkins975dfe12019-09-02 11:47:54 +0100545 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_cell_weights, (lstm_params.use_layer_norm()) ? nullptr : cell_bias, &cell_state_tmp));
Jenkins52ba29e2018-08-29 15:32:11 +0000546 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &cell_state_tmp, 1.f, 0.f, GEMMInfo()));
547 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
Jenkins975dfe12019-09-02 11:47:54 +0100548 if(lstm_params.use_layer_norm())
549 {
550 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&cell_state_tmp));
Jenkins18b685f2020-08-21 10:26:22 +0100551 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE,
552 RoundingPolicy::TO_NEAREST_EVEN));
Jenkins975dfe12019-09-02 11:47:54 +0100553 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE));
554 }
Jenkins18b685f2020-08-21 10:26:22 +0100555 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, activation_info));
556 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
557 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
Jenkins52ba29e2018-08-29 15:32:11 +0000558 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
Jenkinsb3a371b2018-05-23 11:36:53 +0100559 if(cell_threshold != 0.f)
560 {
Jenkins18b685f2020-08-21 10:26:22 +0100561 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold,
562 cell_threshold)));
Jenkinsb3a371b2018-05-23 11:36:53 +0100563 }
564
Jenkins4ba87db2019-05-23 17:11:51 +0100565 std::vector<const ITensorInfo *> in_out_weights;
566 in_out_weights.emplace_back(input_to_output_weights);
567 in_out_weights.emplace_back(recurrent_to_output_weights);
568 TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
569 TensorInfo in_out_gate_concat = TensorInfo(in_out_weights_concat_shape, 1, input->data_type());
Jenkins18b685f2020-08-21 10:26:22 +0100570 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(in_out_weights, &in_out_gate_concat, Window::DimX));
Jenkins52ba29e2018-08-29 15:32:11 +0000571 // Validate output gate tmp
Jenkins975dfe12019-09-02 11:47:54 +0100572 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_output_weights, (lstm_params.use_layer_norm()) ? nullptr : output_gate_bias, &output_gate_tmp));
Jenkins4ba87db2019-05-23 17:11:51 +0100573
Jenkinsb3a371b2018-05-23 11:36:53 +0100574 if(lstm_params.has_peephole_opt())
575 {
Jenkins18b685f2020-08-21 10:26:22 +0100576 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
577 RoundingPolicy::TO_NEAREST_EVEN));
Jenkins52ba29e2018-08-29 15:32:11 +0000578 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE));
Jenkinsb3a371b2018-05-23 11:36:53 +0100579 }
Jenkins975dfe12019-09-02 11:47:54 +0100580 if(lstm_params.use_layer_norm())
581 {
582 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&output_gate_tmp));
Jenkins18b685f2020-08-21 10:26:22 +0100583 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
584 RoundingPolicy::TO_NEAREST_EVEN));
Jenkins975dfe12019-09-02 11:47:54 +0100585 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp, ConvertPolicy::SATURATE));
586 }
Jenkins18b685f2020-08-21 10:26:22 +0100587 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Jenkinsb3a371b2018-05-23 11:36:53 +0100588
589 // Validate output state
Jenkins18b685f2020-08-21 10:26:22 +0100590 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
591 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
Jenkinsb3a371b2018-05-23 11:36:53 +0100592 if(lstm_params.has_projection())
593 {
Jenkins52ba29e2018-08-29 15:32:11 +0000594 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out));
Jenkinsb3a371b2018-05-23 11:36:53 +0100595 if(projection_threshold != 0.f)
596 {
Jenkins18b685f2020-08-21 10:26:22 +0100597 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(output_state_out, output_state_out,
598 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)));
Jenkinsb3a371b2018-05-23 11:36:53 +0100599 }
600 }
601
Jenkins52ba29e2018-08-29 15:32:11 +0000602 // Validate copy kernel
603 ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(&cell_state_tmp, cell_state_out));
604 ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(output_state_out, output));
605
606 // Validate scratch concatenation
Jenkins18b685f2020-08-21 10:26:22 +0100607 std::vector<const ITensorInfo *> inputs_vector_info_raw;
Jenkinsb9abeae2018-11-22 11:58:08 +0000608 if(!lstm_params.has_cifg_opt())
Jenkinsb3a371b2018-05-23 11:36:53 +0100609 {
Jenkins52ba29e2018-08-29 15:32:11 +0000610 inputs_vector_info_raw.push_back(&input_gate);
Jenkinsb3a371b2018-05-23 11:36:53 +0100611 }
Jenkins52ba29e2018-08-29 15:32:11 +0000612 inputs_vector_info_raw.push_back(&cell_state_tmp);
613 inputs_vector_info_raw.push_back(&forget_gate);
614 inputs_vector_info_raw.push_back(&output_gate_tmp);
Jenkinsb3a371b2018-05-23 11:36:53 +0100615
Jenkins975dfe12019-09-02 11:47:54 +0100616 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(inputs_vector_info_raw, scratch_buffer, Window::DimX));
Jenkinsb3a371b2018-05-23 11:36:53 +0100617 return Status{};
618}
619
620void CLLSTMLayer::run()
621{
Jenkins4ba87db2019-05-23 17:11:51 +0100622 prepare();
623
624 MemoryGroupResourceScope scope_mg(_memory_group);
625
Jenkins18b685f2020-08-21 10:26:22 +0100626 _concat_inputs_forget_gate.run();
Jenkinsb3a371b2018-05-23 11:36:53 +0100627
628 _fully_connected_forget_gate.run();
Jenkinsb3a371b2018-05-23 11:36:53 +0100629
630 if(_run_peephole_opt)
631 {
Jenkins18b685f2020-08-21 10:26:22 +0100632 _pixelwise_mul_forget_gate.run();
Jenkins975dfe12019-09-02 11:47:54 +0100633 _accum_forget_gate1.run();
634 }
635 if(_is_layer_norm_lstm)
636 {
637 _mean_std_norm_forget_gate.run();
Jenkins18b685f2020-08-21 10:26:22 +0100638 _pixelwise_mul_forget_gate_coeff.run();
639 _accum_forget_gate_bias.run();
Jenkinsb3a371b2018-05-23 11:36:53 +0100640 }
Jenkins18b685f2020-08-21 10:26:22 +0100641 _activation_forget_gate.run();
Jenkinsb3a371b2018-05-23 11:36:53 +0100642
643 if(_run_cifg_opt)
644 {
Jenkins4ba87db2019-05-23 17:11:51 +0100645 CLScheduler::get().enqueue(_ones_memset_kernel);
Jenkins18b685f2020-08-21 10:26:22 +0100646 _subtract_input_gate.run();
Jenkinsb3a371b2018-05-23 11:36:53 +0100647 }
648 else
649 {
650 _fully_connected_input_gate.run();
Jenkins4ba87db2019-05-23 17:11:51 +0100651
Jenkins52ba29e2018-08-29 15:32:11 +0000652 if(_run_peephole_opt)
653 {
Jenkins18b685f2020-08-21 10:26:22 +0100654 _pixelwise_mul_input_gate.run();
Jenkins975dfe12019-09-02 11:47:54 +0100655 _accum_input_gate1.run();
656 }
657
658 if(_is_layer_norm_lstm)
659 {
660 _mean_std_norm_input_gate.run();
Jenkins18b685f2020-08-21 10:26:22 +0100661 _pixelwise_mul_input_gate_coeff.run();
662 _accum_input_gate_bias.run();
Jenkins52ba29e2018-08-29 15:32:11 +0000663 }
Jenkins18b685f2020-08-21 10:26:22 +0100664 _activation_input_gate.run();
Jenkinsb3a371b2018-05-23 11:36:53 +0100665 }
666
667 _fully_connected_cell_state.run();
Jenkins52ba29e2018-08-29 15:32:11 +0000668 CLScheduler::get().enqueue(_transpose_cell_state);
Jenkinsb3a371b2018-05-23 11:36:53 +0100669 _gemm_cell_state1.run();
Jenkins18b685f2020-08-21 10:26:22 +0100670 _accum_cell_state1.run();
Jenkins975dfe12019-09-02 11:47:54 +0100671 if(_is_layer_norm_lstm)
672 {
673 _mean_std_norm_cell_gate.run();
Jenkins18b685f2020-08-21 10:26:22 +0100674 _pixelwise_mul_cell_gate_coeff.run();
675 _accum_cell_gate_bias.run();
Jenkins975dfe12019-09-02 11:47:54 +0100676 }
Jenkins18b685f2020-08-21 10:26:22 +0100677 _activation_cell_state.run();
678 _pixelwise_mul_cell_state1.run();
679 _pixelwise_mul_cell_state2.run();
680 _accum_cell_state2.run();
Jenkinsb3a371b2018-05-23 11:36:53 +0100681
682 if(_perform_cell_clipping)
683 {
Jenkins18b685f2020-08-21 10:26:22 +0100684 _cell_clip.run();
Jenkinsb3a371b2018-05-23 11:36:53 +0100685 }
686
687 _fully_connected_output.run();
Jenkinsb3a371b2018-05-23 11:36:53 +0100688
689 if(_run_peephole_opt)
690 {
Jenkins18b685f2020-08-21 10:26:22 +0100691 _pixelwise_mul_output_state1.run();
Jenkins975dfe12019-09-02 11:47:54 +0100692 _accum_output1.run();
693 }
694 if(_is_layer_norm_lstm)
695 {
696 _mean_std_norm_output_gate.run();
Jenkins18b685f2020-08-21 10:26:22 +0100697 _pixelwise_mul_output_gate_coeff.run();
698 _accum_output_gate_bias.run();
Jenkinsb3a371b2018-05-23 11:36:53 +0100699 }
Jenkins18b685f2020-08-21 10:26:22 +0100700 _activation_output.run();
Jenkinsb3a371b2018-05-23 11:36:53 +0100701
Jenkins18b685f2020-08-21 10:26:22 +0100702 _activation_output_state.run();
703 _pixelwise_mul_output_state2.run();
Jenkinsb3a371b2018-05-23 11:36:53 +0100704
705 if(_has_projection_weights)
706 {
707 _fully_connected_output_state.run();
708 if(_perform_projection_clipping)
709 {
Jenkins18b685f2020-08-21 10:26:22 +0100710 _projection_clip.run();
Jenkinsb3a371b2018-05-23 11:36:53 +0100711 }
712 }
713
714 CLScheduler::get().enqueue(_copy_cell_state);
715 CLScheduler::get().enqueue(_copy_output);
716
717 _concat_scratch_buffer.run();
Jenkins4ba87db2019-05-23 17:11:51 +0100718}
Jenkinsb3a371b2018-05-23 11:36:53 +0100719
Jenkins4ba87db2019-05-23 17:11:51 +0100720void CLLSTMLayer::prepare()
721{
722 if(!_is_prepared)
723 {
Jenkins18b685f2020-08-21 10:26:22 +0100724 _concat_weights_forget_gate.run();
Jenkins4ba87db2019-05-23 17:11:51 +0100725 if(!_run_cifg_opt)
726 {
Jenkins18b685f2020-08-21 10:26:22 +0100727 _concat_weights_input_gate.run();
Jenkins4ba87db2019-05-23 17:11:51 +0100728 }
Jenkins18b685f2020-08-21 10:26:22 +0100729 _concat_weights_output.run();
Jenkins4ba87db2019-05-23 17:11:51 +0100730 _is_prepared = true;
731 }
Jenkins514be652019-02-28 12:25:18 +0000732}
Jenkins6a7771e2020-05-28 11:28:36 +0100733} // namespace arm_compute