arm_compute v18.08
diff --git a/src/runtime/NEON/functions/NELocallyConnectedLayer.cpp b/src/runtime/NEON/functions/NELocallyConnectedLayer.cpp
index 913acf8..80a2541 100644
--- a/src/runtime/NEON/functions/NELocallyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NELocallyConnectedLayer.cpp
@@ -73,7 +73,7 @@
NELocallyConnectedLayer::NELocallyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(std::move(memory_manager)), _input_im2col_kernel(), _weights_reshape_kernel(), _mm_kernel(), _output_col2im_kernel(), _input_im2col_reshaped(), _weights_reshaped(), _gemm_output(),
- _is_first_run(false), _original_weights(nullptr)
+ _is_prepared(false), _original_weights(nullptr) // false until prepare() has reshaped the weights
{
}
@@ -113,7 +113,7 @@
TensorInfo input_im2col_reshaped_info(shape_im2col, 1, input->data_type());
TensorInfo gemm_output_info(shape_gemm, 1, input->data_type());
- ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &input_im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, has_bias, false));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &input_im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, has_bias));
ARM_COMPUTE_RETURN_ON_ERROR(NEWeightsReshapeKernel::validate(weights, biases, &weights_reshaped_info));
ARM_COMPUTE_RETURN_ON_ERROR(NELocallyConnectedMatrixMultiplyKernel::validate(&input_im2col_reshaped_info, &weights_reshaped_info, &gemm_output_info));
ARM_COMPUTE_RETURN_ON_ERROR(NECol2ImKernel::validate(&gemm_output_info, output, Size2D(conv_w, conv_h)));
@@ -127,7 +127,7 @@
ARM_COMPUTE_ERROR_THROW_ON(NELocallyConnectedLayer::validate(input->info(), weights->info(), biases == nullptr ? nullptr : biases->info(), output->info(), conv_info));
bool _has_bias = (biases != nullptr);
- _is_first_run = true;
+ _is_prepared = false;
_original_weights = weights;
const unsigned int kernel_width = weights->info()->dimension(0);
@@ -160,24 +160,13 @@
_output_col2im_kernel.configure(&_gemm_output, output, Size2D(conv_w, conv_h));
// Allocate intermediate tensors
- _weights_reshaped.allocator()->allocate();
_input_im2col_reshaped.allocator()->allocate();
_gemm_output.allocator()->allocate();
}
void NELocallyConnectedLayer::run()
{
- // Run weights reshaping (Runs once for every configure)
- if(_is_first_run)
- {
- ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
-
- _is_first_run = false;
- NEScheduler::get().schedule(&_weights_reshape_kernel, 3);
-
- // Mark original weights tensor as unused
- _original_weights->mark_as_unused();
- }
+ prepare();
_memory_group.acquire();
@@ -192,3 +181,18 @@
_memory_group.release();
}
+
+void NELocallyConnectedLayer::prepare()
+{
+ if(!_is_prepared) // One-time weights setup; configure() resets _is_prepared, and run() calls prepare() first
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+ // _weights_reshaped is now allocated here, on first prepare(), instead of in configure()
+ // Run weights reshaping and mark original weights tensor as unused
+ _weights_reshaped.allocator()->allocate();
+ NEScheduler::get().schedule(&_weights_reshape_kernel, 3); // NOTE(review): 3 presumably selects the window dimension split across threads - confirm
+ _original_weights->mark_as_unused();
+ // Set last, so subsequent calls skip the reshape above
+ _is_prepared = true;
+ }
+}