arm_compute v18.08
diff --git a/src/runtime/CL/functions/CLRNNLayer.cpp b/src/runtime/CL/functions/CLRNNLayer.cpp
index 4843ba6..1809e6e 100644
--- a/src/runtime/CL/functions/CLRNNLayer.cpp
+++ b/src/runtime/CL/functions/CLRNNLayer.cpp
@@ -36,7 +36,8 @@
 using namespace arm_compute::misc::shape_calculator;
 
 CLRNNLayer::CLRNNLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(std::move(memory_manager)), _gemm_state_f(), _add_kernel(), _activation_kernel(), _fully_connected_kernel(), _copy_kernel(), _fully_connected_out(), _gemm_output(), _add_output()
+    : _memory_group(std::move(memory_manager)), _gemm_state_f(), _add_kernel(), _activation_kernel(), _fully_connected_kernel(), _copy_kernel(), _fully_connected_out(), _gemm_output(), _add_output(),
+      _is_prepared(false)
 {
 }
 
@@ -57,7 +58,7 @@
     auto shape_info = TensorInfo(compute_rnn_shape(recurrent_weights, hidden_state->dimension(idx_height)), 1, input->data_type());
 
-    ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, weights, bias, &shape_info, true, false));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, weights, bias, &shape_info));
     ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(hidden_state, recurrent_weights, nullptr, &shape_info, 1.f, 0.f));
     ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAdditionKernel::validate(&shape_info, &shape_info, &shape_info, ConvertPolicy::SATURATE));
     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&shape_info, &shape_info, info));
 
@@ -74,12 +75,14 @@
     const int idx_height = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::HEIGHT);
     TensorShape shape = compute_rnn_shape(recurrent_weights->info(), hidden_state->info()->dimension(idx_height));
 
+    _is_prepared = false;
+
     _fully_connected_out.allocator()->init(TensorInfo(shape, 1, input->info()->data_type()));
     _gemm_output.allocator()->init(TensorInfo(shape, 1, input->info()->data_type()));
 
     // Manage intermediate buffers and configure
     _memory_group.manage(&_fully_connected_out);
-    _fully_connected_kernel.configure(input, weights, bias, &_fully_connected_out, true, false);
+    _fully_connected_kernel.configure(input, weights, bias, &_fully_connected_out);
 
     _memory_group.manage(&_gemm_output);
     _gemm_state_f.configure(hidden_state, recurrent_weights, nullptr, &_gemm_output, 1.f, 0.f);
@@ -100,7 +103,10 @@
 
 void CLRNNLayer::run()
 {
+    prepare();
+
     _memory_group.acquire();
+
     _fully_connected_kernel.run();
     _gemm_state_f.run();
     CLScheduler::get().enqueue(_add_kernel);
@@ -108,5 +114,17 @@
 
     // copy hidden out to output
     CLScheduler::get().enqueue(_copy_kernel);
+
     _memory_group.release();
-}
\ No newline at end of file
+}
+
+void CLRNNLayer::prepare()
+{
+    if(!_is_prepared)
+    {
+        _fully_connected_kernel.prepare();
+        _gemm_state_f.prepare();
+
+        _is_prepared = true;
+    }
+}
\ No newline at end of file
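
This patch gives CLRNNLayer the usual IFunction prepare()/run() split: the one-off weight preparation of the underlying CLFullyConnectedLayer and CLGEMM stages now runs once, on first use, guarded by the _is_prepared flag that configure() resets. The trailing "true, false" arguments dropped from the CLFullyConnectedLayer calls match that function's former defaults (transpose_weights = true, are_weights_reshaped = false), so behaviour is unchanged. Below is a minimal caller-side sketch, assuming the v18.08 API; the tensor shapes (16 input features, 32 hidden units, batch of 4) and the TANH activation are illustrative choices, not taken from the patch.

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/functions/CLRNNLayer.h"

#include <initializer_list>

using namespace arm_compute;

int main()
{
    CLScheduler::get().default_init(); // set up the OpenCL context and queue

    // Hypothetical sizes: 16 input features, 32 hidden units, batch of 4.
    CLTensor input, weights, recurrent_weights, bias, hidden_state, output;
    input.allocator()->init(TensorInfo(TensorShape(16U, 4U), 1, DataType::F32));
    weights.allocator()->init(TensorInfo(TensorShape(16U, 32U), 1, DataType::F32));
    recurrent_weights.allocator()->init(TensorInfo(TensorShape(32U, 32U), 1, DataType::F32));
    bias.allocator()->init(TensorInfo(TensorShape(32U), 1, DataType::F32));
    hidden_state.allocator()->init(TensorInfo(TensorShape(32U, 4U), 1, DataType::F32));
    output.allocator()->init(TensorInfo(TensorShape(32U, 4U), 1, DataType::F32));

    ActivationLayerInfo act_info(ActivationLayerInfo::ActivationFunction::TANH);

    CLRNNLayer rnn; // no memory manager: intermediate tensors allocate directly
    rnn.configure(&input, &weights, &recurrent_weights, &bias, &hidden_state, &output, act_info);

    // Allocate backing CL buffers after configure()
    for(auto *t : { &input, &weights, &recurrent_weights, &bias, &hidden_state, &output })
    {
        t->allocator()->allocate();
    }

    // ... map the tensors and fill input/weights/bias/hidden_state here ...

    rnn.run(); // first call: prepare() does the one-off weight preparation
    rnn.run(); // later calls: _is_prepared short-circuits prepare()

    CLScheduler::get().sync(); // wait for the queued OpenCL work to finish
    return 0;
}

Note that configure() resetting _is_prepared to false matters here: re-configuring the function with new weights forces the next run() to go through prepare() again.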