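arm_compute v18.11
CLGEMM: add GPUTarget::G52 and GPUTarget::G52LIT to the GPU-target check that guards the
COMPMID-852 heuristic, and pass gemm_info.fp_mixed_precision() through to
CLGEMMMatrixMultiplyKernel configure() and validate().
Also reflows the CLGEMM constructor initializer list (one member per line) and merges two
adjacent if(_is_interleaved_transposed) blocks; neither of these changes behavior.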
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index f16d1c0..baa0cf4 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -44,8 +44,9 @@
{
bool flag = true;
- if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G76))
+ if(gpu_target_is_in(gpu_target, GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72, GPUTarget::G76))
{
+ // COMPMID-852
if(k > 256 && m > 4 && is_data_type_float(data_type) && reshape_b_only_on_first_run)
{
constexpr float alpha = 3.2f;
@@ -71,8 +72,18 @@
} // namespace
CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _ma_kernel(), _tmp_a(), _tmp_b(), _original_b(nullptr), _is_interleaved_transposed(false),
- _run_addition(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
+ : _memory_group(std::move(memory_manager)),
+ _interleave_kernel(),
+ _transpose_kernel(),
+ _mm_kernel(),
+ _ma_kernel(),
+ _tmp_a(),
+ _tmp_b(),
+ _original_b(nullptr),
+ _is_interleaved_transposed(false),
+ _run_addition(false),
+ _reshape_b_only_on_first_run(false),
+ _is_prepared(false)
{
}
@@ -122,10 +133,7 @@
if(_is_interleaved_transposed)
{
reinterpret_input_as_3d = false;
- }
- if(_is_interleaved_transposed)
- {
matrix_a = &_tmp_a;
matrix_b = &_tmp_b;
@@ -145,8 +153,10 @@
}
// Configure and tune matrix multiply kernel
- _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d,
- reinterpret_input_as_3d));
+ _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k,
+ mult_transpose1xW_width, mult_interleave4x4_height,
+ depth_output_gemm3d, reinterpret_input_as_3d),
+ gemm_info.fp_mixed_precision());
CLScheduler::get().tune_kernel_static(_mm_kernel);
if(_is_interleaved_transposed)
@@ -227,7 +237,7 @@
}
// Validate matrix multiply
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, alpha, run_interleave_transpose, reshape_info, gpu_target));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, alpha, run_interleave_transpose, reshape_info, gpu_target, gemm_info.fp_mixed_precision()));
if(beta != 0 && c != nullptr)
{