arm_compute v17.09
Change-Id: I4bf8f4e6e5f84ce0d5b6f5ba570d276879f42a81
diff --git a/src/runtime/CL/functions/CLSoftmaxLayer.cpp b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
index 2a78c58..7505a2c 100644
--- a/src/runtime/CL/functions/CLSoftmaxLayer.cpp
+++ b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
@@ -25,29 +25,34 @@
#include "arm_compute/core/CL/kernels/CLSoftmaxLayerKernel.h"
#include "arm_compute/core/Helpers.h"
+#include "arm_compute/runtime/CL/CLMemoryGroup.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
using namespace arm_compute;
-CLSoftmaxLayer::CLSoftmaxLayer()
- : _max_kernel(), _shift_exp_sum_kernel(), _norm_kernel(), _max(), _sum(), _tmp()
+CLSoftmaxLayer::CLSoftmaxLayer(std::shared_ptr<IMemoryManager> memory_manager) // Constructor now takes a memory manager by value; it is moved into _memory_group below, and the kernels and intermediate tensors are value-initialized as before
+ : _memory_group(std::move(memory_manager)), _max_kernel(), _shift_exp_sum_kernel(), _norm_kernel(), _max(), _sum(), _tmp()
{
}
void CLSoftmaxLayer::configure(const ICLTensor *input, ICLTensor *output)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
// Create intermediate tensors shapes
- _tmp.allocator()->init(TensorInfo(input->info()->tensor_shape(), input->info()->num_channels(), input->info()->data_type()));
+ _tmp.allocator()->init(TensorInfo(input->info()->tensor_shape(), input->info()->num_channels(), input->info()->data_type(), input->info()->fixed_point_position()));
TensorShape shape = input->info()->tensor_shape();
shape.set(0, 1);
- TensorInfo tensor_info_max_sum(shape, input->info()->num_channels(), input->info()->data_type());
+ TensorInfo tensor_info_max_sum(shape, input->info()->num_channels(), input->info()->data_type(), input->info()->fixed_point_position());
_max.allocator()->init(tensor_info_max_sum);
_sum.allocator()->init(tensor_info_max_sum);
+ // Manage intermediate buffers
+ _memory_group.manage(&_tmp);
+ _memory_group.manage(&_max);
+ _memory_group.manage(&_sum);
+
// Configure Kernels
_max_kernel.configure(input, &_max);
_shift_exp_sum_kernel.configure(input, &_max, &_tmp, &_sum);
@@ -61,7 +66,11 @@
void CLSoftmaxLayer::run()
{
+ _memory_group.acquire(); // Acquire the memory group before any kernel is enqueued; presumably this backs _tmp, _max and _sum, which configure() registers via manage() - TODO confirm against CLMemoryGroup
+
CLScheduler::get().enqueue(_max_kernel, false);
CLScheduler::get().enqueue(_shift_exp_sum_kernel, false);
CLScheduler::get().enqueue(_norm_kernel);
+
+ _memory_group.release(); // Release the memory group only after all three kernels have been enqueued
}