arm_compute v18.08
diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp
index cda29d6..9703b0f 100644
--- a/src/core/CL/CLHelpers.cpp
+++ b/src/core/CL/CLHelpers.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/core/Log.h"
#include "arm_compute/core/Types.h"
+#include <utility>
#include <vector>
namespace arm_compute
@@ -37,8 +38,6 @@
{
case DataType::U8:
return "uchar";
- case DataType::QS8:
- return "qs8";
case DataType::S8:
return "char";
case DataType::QASYMM8:
@@ -47,14 +46,10 @@
return "ushort";
case DataType::S16:
return "short";
- case DataType::QS16:
- return "qs16";
case DataType::U32:
return "uint";
case DataType::S32:
return "int";
- case DataType::QS32:
- return "qs32";
case DataType::U64:
return "ulong";
case DataType::S64:
@@ -74,13 +69,11 @@
switch(dt)
{
case DataType::U8:
- case DataType::QS8:
case DataType::S8:
case DataType::QASYMM8:
return "8";
case DataType::U16:
case DataType::S16:
- case DataType::QS16:
case DataType::F16:
return "16";
case DataType::U32:
@@ -98,20 +91,10 @@
std::string get_underlying_cl_type_from_data_type(const DataType &dt)
{
- switch(dt)
- {
- case DataType::QS8:
- return "char";
- case DataType::QS16:
- return "short";
- case DataType::QS32:
- return "int";
- default:
- return get_cl_type_from_data_type(dt);
- }
+ return get_cl_type_from_data_type(dt);
}
-GPUTarget get_target_from_device(cl::Device &device)
+GPUTarget get_target_from_device(const cl::Device &device)
{
// Query device name size
std::string device_name = device.getInfo<CL_DEVICE_NAME>();
@@ -129,6 +112,16 @@
return device_supports_extension(device, "cl_khr_fp16");
}
+bool dot8_supported(const cl::Device &device)
+{
+ return device_supports_extension(device, "cl_arm_integer_dot_product_int8");
+}
+
+bool dot8_acc_supported(const cl::Device &device)
+{
+ return device_supports_extension(device, "cl_arm_integer_dot_product_accumulate_int8");
+}
+
CLVersion get_cl_version(const cl::Device &device)
{
std::string version_str = device.getInfo<CL_DEVICE_VERSION>();
@@ -159,4 +152,47 @@
return (pos != std::string::npos);
}
+bool cl_winograd_convolution_layer_supported(const Size2D &output_tile, const Size2D &kernel_size, DataLayout data_layout)
+{
+ ARM_COMPUTE_ERROR_ON(data_layout == DataLayout::UNKNOWN);
+
+ using WinogradConfiguration = std::pair<std::pair<int, int>, std::pair<int, int>>; // ((output tile W, H), (kernel W, H))
+
+ static const std::vector<WinogradConfiguration> winograd_configs_nchw = // immutable table, built once instead of on every call
+ {
+ WinogradConfiguration(std::pair<int, int>(1, 2), std::pair<int, int>(1, 3)),
+ WinogradConfiguration(std::pair<int, int>(1, 4), std::pair<int, int>(1, 3)),
+ WinogradConfiguration(std::pair<int, int>(2, 1), std::pair<int, int>(3, 1)),
+ WinogradConfiguration(std::pair<int, int>(4, 1), std::pair<int, int>(3, 1)),
+ WinogradConfiguration(std::pair<int, int>(2, 2), std::pair<int, int>(3, 3)),
+ WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(3, 3)),
+ WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(5, 5)),
+ WinogradConfiguration(std::pair<int, int>(4, 1), std::pair<int, int>(5, 1)),
+ WinogradConfiguration(std::pair<int, int>(1, 4), std::pair<int, int>(1, 5))
+ };
+
+ static const std::vector<WinogradConfiguration> winograd_configs_nhwc = // immutable table, built once instead of on every call
+ {
+ WinogradConfiguration(std::pair<int, int>(2, 2), std::pair<int, int>(3, 3)),
+ WinogradConfiguration(std::pair<int, int>(1, 4), std::pair<int, int>(1, 3)),
+ WinogradConfiguration(std::pair<int, int>(4, 1), std::pair<int, int>(3, 1)),
+ WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(3, 3)),
+ WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(5, 5)),
+ WinogradConfiguration(std::pair<int, int>(4, 1), std::pair<int, int>(5, 1)),
+ WinogradConfiguration(std::pair<int, int>(1, 4), std::pair<int, int>(1, 5))
+ };
+
+ const auto p = std::make_pair(std::pair<int, int>(output_tile.width, output_tile.height),
+ std::pair<int, int>(kernel_size.width, kernel_size.height));
+
+ // Supported iff the (output tile, kernel size) pair is listed for the requested layout (non-NCHW uses the NHWC table)
+ if(data_layout == DataLayout::NCHW)
+ {
+ return (std::find(winograd_configs_nchw.begin(), winograd_configs_nchw.end(), p) != winograd_configs_nchw.end());
+ }
+ else
+ {
+ return (std::find(winograd_configs_nhwc.begin(), winograd_configs_nhwc.end(), p) != winograd_configs_nhwc.end());
+ }
+}
} // namespace arm_compute
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp
index bdb26f8..3c92257 100644
--- a/src/core/CL/CLKernelLibrary.cpp
+++ b/src/core/CL/CLKernelLibrary.cpp
@@ -149,8 +149,10 @@
{ "accumulate_weighted", "accumulate.cl" },
{ "activation_layer", "activation_layer.cl" },
{ "activation_layer_qa8", "activation_layer_qa8.cl" },
+ { "arithmetic_add_quantized", "arithmetic_op_quantized.cl" },
{ "arithmetic_add", "arithmetic_op.cl" },
{ "arithmetic_sub", "arithmetic_op.cl" },
+ { "arithmetic_div", "arithmetic_op.cl" },
{ "batchnormalization_layer_nchw", "batchnormalization_layer.cl" },
{ "batchnormalization_layer_nhwc", "batchnormalization_layer.cl" },
{ "bitwise_or", "bitwise_op.cl" },
@@ -195,9 +197,13 @@
{ "deconvolution_upsample", "deconvolution_layer.cl" },
{ "depthwise_convolution_3x3", "depthwise_convolution.cl" },
{ "depthwise_convolution_3x3_f16", "depthwise_convolution.cl" },
+ { "depthwise_convolution_3x3_nhwc", "depthwise_convolution.cl" },
+ { "depthwise_convolution_3x3_nhwc_stride1", "depthwise_convolution.cl" },
{ "depthwise_convolution_3x3_quantized_nchw", "depthwise_convolution_quantized.cl" },
+ { "depthwise_convolution_3x3_quantized_nhwc", "depthwise_convolution_quantized.cl" },
{ "depthwise_convolution_3x3_quantized_nhwc_stride1", "depthwise_convolution_quantized.cl" },
- { "depthwise_convolution_3x3_quantized_nhwc_stride2", "depthwise_convolution_quantized.cl" },
+ { "depthwise_convolution_3x3_quantized_dot8_nchw", "depthwise_convolution_quantized.cl" },
+ { "depthwise_convolution_3x3_quantized_dot8_nhwc_stride1", "depthwise_convolution_quantized.cl" },
{ "depthwise_convolution_3x3_stridex1_stridey1_bifrost_f16", "depthwise_convolution.cl" },
{ "depthwise_convolution_3x3_stridex2_stridey2_bifrost_f16", "depthwise_convolution.cl" },
{ "depthwise_convolution_3x3_stridex1_stridey1_bifrost_f32", "depthwise_convolution.cl" },
@@ -209,14 +215,18 @@
{ "derivative", "derivative.cl" },
{ "dilate", "dilate.cl" },
{ "direct_convolution1x1", "direct_convolution1x1.cl" },
+ { "direct_convolution1x1_nhwc", "direct_convolution1x1.cl" },
{ "direct_convolution1x1_f32_bifrost", "direct_convolution1x1.cl" },
{ "direct_convolution3x3", "direct_convolution3x3.cl" },
+ { "direct_convolution3x3_nhwc", "direct_convolution3x3.cl" },
{ "direct_convolution3x3_f32_bifrost", "direct_convolution3x3.cl" },
{ "direct_convolution5x5", "direct_convolution5x5.cl" },
+ { "direct_convolution5x5_nhwc", "direct_convolution5x5.cl" },
{ "direct_convolution5x5_f32_bifrost", "direct_convolution5x5.cl" },
{ "direct_convolution_1x1_3x3_5x5_quantized", "direct_convolution_1x1_3x3_5x5_quantized.cl" },
{ "erode", "erode.cl" },
{ "fast_corners", "fast_corners.cl" },
+ { "flatten", "flatten.cl" },
{ "fill_image_borders_constant", "fill_border.cl" },
{ "fill_image_borders_replicate", "fill_border.cl" },
{ "finalize", "optical_flow_pyramid_lk.cl" },
@@ -227,29 +237,25 @@
{ "gemm_interleave4x4", "gemm.cl" },
{ "gemm_ma_f16", "gemm.cl" },
{ "gemm_ma_f32", "gemm.cl" },
- { "gemm_ma_qs8", "gemm.cl" },
- { "gemm_ma_qs16", "gemm.cl" },
{ "gemm_mv", "gemv.cl" },
{ "gemm_mv_quantized", "gemv.cl" },
{ "gemm_mm_interleaved_transposed_f16", "gemm.cl" },
{ "gemm_mm_interleaved_transposed_f16_bifrost", "gemm.cl" },
{ "gemm_mm_interleaved_transposed_f32", "gemm.cl" },
{ "gemm_mm_interleaved_transposed_f32_bifrost", "gemm.cl" },
- { "gemm_mm_interleaved_transposed_qs8", "gemm.cl" },
- { "gemm_mm_interleaved_transposed_qs16", "gemm.cl" },
{ "gemm_mm_floating_point", "gemm.cl" },
{ "gemm_mm_floating_point_f16_bifrost", "gemm.cl" },
{ "gemm_mm_floating_point_f32_bifrost", "gemm.cl" },
{ "gemm_mm_floating_point_f32_bifrost_1000", "gemm.cl" },
- { "gemm_mm_qs8", "gemm.cl" },
- { "gemm_mm_qs16", "gemm.cl" },
{ "gemm_lc_vm_f32", "gemm.cl" },
{ "gemm_transpose1xW", "gemm.cl" },
{ "gemmlowp_matrix_a_reduction", "gemmlowp.cl" },
{ "gemmlowp_matrix_b_reduction", "gemmlowp.cl" },
{ "gemmlowp_mm_bifrost", "gemmlowp.cl" },
+ { "gemmlowp_mm_bifrost_dot8", "gemmlowp.cl" },
{ "gemmlowp_mm_midgard", "gemmlowp.cl" },
{ "gemmlowp_mm_interleaved_transposed_bifrost", "gemmlowp.cl" },
+ { "gemmlowp_mm_interleaved_transposed_bifrost_dot8", "gemmlowp.cl" },
{ "gemmlowp_mm_interleaved_transposed_midgard", "gemmlowp.cl" },
{ "gemmlowp_offset_contribution", "gemmlowp.cl" },
{ "gemmlowp_output_stage_quantize_down", "gemmlowp.cl" },
@@ -265,13 +271,14 @@
{ "hog_detector", "hog.cl" },
{ "hog_orientation_binning", "hog.cl" },
{ "hysteresis", "canny.cl" },
- { "im2col1x1_stridex1_dchw", "im2col.cl" },
- { "im2col3x3_dchw", "im2col.cl" },
- { "im2col5x5_dchw", "im2col.cl" },
- { "im2col11x11_padx0_pady0_dchw", "im2col.cl" },
- { "im2col_generic_dchw", "im2col.cl" },
- { "im2col_generic_padx0_pady0_dchw", "im2col.cl" },
- { "im2col_reduced_dchw", "im2col.cl" },
+ { "im2col1x1_stridex1_nchw", "im2col.cl" },
+ { "im2col3x3_nchw", "im2col.cl" },
+ { "im2col5x5_nchw", "im2col.cl" },
+ { "im2col11x11_padx0_pady0_nchw", "im2col.cl" },
+ { "im2col_generic_nchw", "im2col.cl" },
+ { "im2col_generic_padx0_pady0_nchw", "im2col.cl" },
+ { "im2col3x3_nhwc", "im2col.cl" },
+ { "im2col_generic_nhwc", "im2col.cl" },
{ "init_level", "optical_flow_pyramid_lk.cl" },
{ "init_level_max", "optical_flow_pyramid_lk.cl" },
{ "init_level_max_initial_estimate", "optical_flow_pyramid_lk.cl" },
@@ -336,8 +343,10 @@
{ "RGBA8888_to_RGB888_bt709", "color_convert.cl" },
{ "RGBA8888_to_YUV444_bt709", "color_convert.cl" },
{ "roi_pooling_layer", "roi_pooling_layer.cl" },
- { "scale_nearest_neighbour", "scale.cl" },
- { "scale_bilinear", "scale.cl" },
+ { "scale_nearest_neighbour_nchw", "scale.cl" },
+ { "scale_nearest_neighbour_nhwc", "scale.cl" },
+ { "scale_bilinear_nchw", "scale.cl" },
+ { "scale_bilinear_nhwc", "scale.cl" },
{ "scharr3x3", "scharr_filter.cl" },
{ "sobel3x3", "sobel_filter.cl" },
{ "sobel_separable5x1", "sobel_filter.cl" },
@@ -364,16 +373,54 @@
{ "warp_affine_bilinear", "warp_affine.cl" },
{ "warp_perspective_nearest_neighbour", "warp_perspective.cl" },
{ "warp_perspective_bilinear", "warp_perspective.cl" },
- { "winograd_filter_transform_2x2_3x3_nchw", "winograd.cl" },
- { "winograd_filter_transform_4x4_3x3_nchw", "winograd.cl" },
- { "winograd_filter_transform_4x4_5x5_nchw", "winograd.cl" },
- { "winograd_input_transform_4x4_5x5_stepz1_nchw", "winograd.cl" },
- { "winograd_input_transform_2x2_3x3_stepz1_nchw", "winograd.cl" },
- { "winograd_input_transform_2x2_3x3_stepz2_nchw", "winograd.cl" },
- { "winograd_input_transform_4x4_3x3_stepz1_nchw", "winograd.cl" },
- { "winograd_output_transform_2x2_3x3_nchw", "winograd.cl" },
- { "winograd_output_transform_4x4_3x3_nchw", "winograd.cl" },
- { "winograd_output_transform_4x4_5x5_nchw", "winograd.cl" },
+ { "winograd_filter_transform_2x2_3x3_nchw", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_2x1_3x1_nchw", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_1x2_1x3_nchw", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_4x4_3x3_nchw", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_4x1_3x1_nchw", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_1x4_1x3_nchw", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_4x4_5x5_nchw", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_4x1_5x1_nchw", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_1x4_1x5_nchw", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_4x1_3x1_nhwc", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_1x4_1x3_nhwc", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_4x4_3x3_nhwc", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_4x4_5x5_nhwc", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_4x1_5x1_nhwc", "winograd_filter_transform.cl" },
+ { "winograd_filter_transform_1x4_1x5_nhwc", "winograd_filter_transform.cl" },
+ { "winograd_input_transform_2x2_3x3_stepz1_nchw", "winograd_input_transform.cl" },
+ { "winograd_input_transform_2x2_3x3_stepz2_nchw", "winograd_input_transform.cl" },
+ { "winograd_input_transform_2x1_3x1_stepz1_nchw", "winograd_input_transform.cl" },
+ { "winograd_input_transform_2x1_3x1_stepz2_nchw", "winograd_input_transform.cl" },
+ { "winograd_input_transform_1x2_1x3_stepz1_nchw", "winograd_input_transform.cl" },
+ { "winograd_input_transform_1x2_1x3_stepz2_nchw", "winograd_input_transform.cl" },
+ { "winograd_input_transform_4x4_3x3_stepz1_nchw", "winograd_input_transform.cl" },
+ { "winograd_input_transform_4x1_3x1_stepz1_nchw", "winograd_input_transform.cl" },
+ { "winograd_input_transform_1x4_1x3_stepz1_nchw", "winograd_input_transform.cl" },
+ { "winograd_input_transform_4x4_5x5_stepz1_nchw", "winograd_input_transform.cl" },
+ { "winograd_input_transform_4x1_5x1_stepz1_nchw", "winograd_input_transform.cl" },
+ { "winograd_input_transform_1x4_1x5_stepz1_nchw", "winograd_input_transform.cl" },
+ { "winograd_input_transform_4x1_3x1_stepz1_nhwc", "winograd_input_transform.cl" },
+ { "winograd_input_transform_1x4_1x3_stepz1_nhwc", "winograd_input_transform.cl" },
+ { "winograd_input_transform_4x4_3x3_stepz1_nhwc", "winograd_input_transform.cl" },
+ { "winograd_input_transform_4x4_5x5_stepz1_nhwc", "winograd_input_transform.cl" },
+ { "winograd_input_transform_4x1_5x1_stepz1_nhwc", "winograd_input_transform.cl" },
+ { "winograd_input_transform_1x4_1x5_stepz1_nhwc", "winograd_input_transform.cl" },
+ { "winograd_output_transform_2x2_3x3_nchw", "winograd_output_transform.cl" },
+ { "winograd_output_transform_2x1_3x1_nchw", "winograd_output_transform.cl" },
+ { "winograd_output_transform_1x2_1x3_nchw", "winograd_output_transform.cl" },
+ { "winograd_output_transform_4x4_3x3_nchw", "winograd_output_transform.cl" },
+ { "winograd_output_transform_4x1_3x1_nchw", "winograd_output_transform.cl" },
+ { "winograd_output_transform_1x4_1x3_nchw", "winograd_output_transform.cl" },
+ { "winograd_output_transform_4x4_5x5_nchw", "winograd_output_transform.cl" },
+ { "winograd_output_transform_4x1_5x1_nchw", "winograd_output_transform.cl" },
+ { "winograd_output_transform_1x4_1x5_nchw", "winograd_output_transform.cl" },
+ { "winograd_output_transform_4x1_3x1_nhwc", "winograd_output_transform.cl" },
+ { "winograd_output_transform_1x4_1x3_nhwc", "winograd_output_transform.cl" },
+ { "winograd_output_transform_4x4_3x3_nhwc", "winograd_output_transform.cl" },
+ { "winograd_output_transform_4x4_5x5_nhwc", "winograd_output_transform.cl" },
+ { "winograd_output_transform_4x1_5x1_nhwc", "winograd_output_transform.cl" },
+ { "winograd_output_transform_1x4_1x5_nhwc", "winograd_output_transform.cl" },
{ "YUYV422_to_IYUV_bt709", "color_convert.cl" },
{ "YUYV422_to_NV12_bt709", "color_convert.cl" },
{ "YUYV422_to_RGB888_bt709", "color_convert.cl" },
@@ -404,6 +451,10 @@
#include "./cl_kernels/arithmetic_op.clembed"
},
{
+ "arithmetic_op_quantized.cl",
+#include "./cl_kernels/arithmetic_op_quantized.clembed"
+ },
+ {
"bitwise_op.cl",
#include "./cl_kernels/bitwise_op.clembed"
},
@@ -520,12 +571,12 @@
#include "./cl_kernels/fast_corners.clembed"
},
{
- "fill_border.cl",
-#include "./cl_kernels/fill_border.clembed"
+ "flatten.cl",
+#include "./cl_kernels/flatten.clembed"
},
{
- "fixed_point.h",
-#include "./cl_kernels/fixed_point.hembed"
+ "fill_border.cl",
+#include "./cl_kernels/fill_border.clembed"
},
{
"floor.cl",
@@ -712,8 +763,16 @@
#include "./cl_kernels/warp_perspective.clembed"
},
{
- "winograd.cl",
-#include "./cl_kernels/winograd.clembed"
+ "winograd_filter_transform.cl",
+#include "./cl_kernels/winograd_filter_transform.clembed"
+ },
+ {
+ "winograd_input_transform.cl",
+#include "./cl_kernels/winograd_input_transform.clembed"
+ },
+ {
+ "winograd_output_transform.cl",
+#include "./cl_kernels/winograd_output_transform.clembed"
},
#endif /* EMBEDDED_KERNELS */
};
@@ -741,11 +800,26 @@
}
std::string concat_str;
- if(fp16_supported(_device))
+#if defined(ARM_COMPUTE_DEBUG_ENABLED)
+ // Enable debug properties in CL kernels
+ concat_str += " -DARM_COMPUTE_DEBUG_ENABLED";
+#endif // defined(ARM_COMPUTE_DEBUG_ENABLED)
+
+ if(fp16_supported())
{
concat_str += " -DARM_COMPUTE_OPENCL_FP16_ENABLED=1 ";
}
+ if(dot8_supported(_device))
+ {
+ concat_str += " -DARM_COMPUTE_OPENCL_DOT8_ENABLED=1 ";
+ }
+
+ if(dot8_acc_supported(_device))
+ {
+ concat_str += " -DARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED=1 ";
+ }
+
if(get_cl_version(_device) == CLVersion::CL20)
{
concat_str += " -cl-std=CL2.0 ";
@@ -794,6 +868,16 @@
_built_programs_map.emplace(built_program_name, program);
}
+bool CLKernelLibrary::fp16_supported() const
+{
+ return ::fp16_supported(_device);
+}
+
+bool CLKernelLibrary::int64_base_atomics_supported() const
+{
+ return device_supports_extension(_device, "cl_khr_int64_base_atomics");
+}
+
const Program &CLKernelLibrary::load_program(const std::string &program_name) const
{
const auto program_it = _programs_map.find(program_name);
@@ -882,8 +966,7 @@
cl::NDRange CLKernelLibrary::default_ndrange() const
{
- cl::Device device = cl::Device::getDefault();
- GPUTarget _target = get_target_from_device(device);
+ GPUTarget _target = get_target_from_device(_device);
cl::NDRange default_range;
switch(_target)
diff --git a/src/core/CL/ICLSimple2DKernel.cpp b/src/core/CL/ICLSimple2DKernel.cpp
index 5dc3e6c..cf6c9c8 100644
--- a/src/core/CL/ICLSimple2DKernel.cpp
+++ b/src/core/CL/ICLSimple2DKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -42,7 +42,7 @@
unsigned int idx = 0;
add_2D_tensor_argument(idx, _input, slice);
add_2D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_2D(slice));
}
diff --git a/src/core/CL/ICLSimple3DKernel.cpp b/src/core/CL/ICLSimple3DKernel.cpp
index 0bd9d15..4197307 100644
--- a/src/core/CL/ICLSimple3DKernel.cpp
+++ b/src/core/CL/ICLSimple3DKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,7 +41,7 @@
unsigned int idx = 0;
add_3D_tensor_argument(idx, _input, slice);
add_3D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_3D(slice));
}
diff --git a/src/core/CL/ICLSimpleKernel.cpp b/src/core/CL/ICLSimpleKernel.cpp
index fec9d92..48e5a88 100644
--- a/src/core/CL/ICLSimpleKernel.cpp
+++ b/src/core/CL/ICLSimpleKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -50,5 +50,5 @@
output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size);
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
diff --git a/src/core/CL/OpenCL.cpp b/src/core/CL/OpenCL.cpp
index a8ed973..486bb6a 100644
--- a/src/core/CL/OpenCL.cpp
+++ b/src/core/CL/OpenCL.cpp
@@ -74,6 +74,7 @@
#define LOAD_FUNCTION_PTR(func_name, handle) \
func_name##_ptr = reinterpret_cast<decltype(func_name) *>(dlsym(handle, #func_name));
+ LOAD_FUNCTION_PTR(clCreateContext, handle);
LOAD_FUNCTION_PTR(clCreateContextFromType, handle);
LOAD_FUNCTION_PTR(clCreateCommandQueue, handle);
LOAD_FUNCTION_PTR(clGetContextInfo, handle);
@@ -254,6 +255,26 @@
}
}
+cl_context clCreateContext(
+ const cl_context_properties *properties,
+ cl_uint num_devices,
+ const cl_device_id *devices,
+ void (*pfn_notify)(const char *, const void *, size_t, void *),
+ void *user_data,
+ cl_int *errcode_ret)
+{
+ arm_compute::CLSymbols::get().load_default();
+ auto func = arm_compute::CLSymbols::get().clCreateContext_ptr;
+ if(func != nullptr)
+ {
+ return func(properties, num_devices, devices, pfn_notify, user_data, errcode_ret);
+ }
+ else
+ {
+ return nullptr;
+ }
+}
+
cl_context clCreateContextFromType(const cl_context_properties *properties,
cl_device_type device_type,
void (*pfn_notify)(const char *, const void *, size_t, void *),
diff --git a/src/core/CL/cl_kernels/activation_layer.cl b/src/core/CL/cl_kernels/activation_layer.cl
index a8ea738..373406a 100644
--- a/src/core/CL/cl_kernels/activation_layer.cl
+++ b/src/core/CL/cl_kernels/activation_layer.cl
@@ -25,23 +25,6 @@
#define TYPE VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
-#if defined(FIXED_POINT_POSITION)
-#include "fixed_point.h"
-
-#define CONST_ONE (1 << FIXED_POINT_POSITION)
-#define ABS_OP(a) ABS_SAT_OP_EXPAND((a), DATA_TYPE, VEC_SIZE)
-#define ADD_OP(a, b) ADD_SAT_OP_EXPAND((a), (b), DATA_TYPE, VEC_SIZE)
-#define SUB_OP(a, b) SUB_SAT_OP_EXPAND((a), (b), DATA_TYPE, VEC_SIZE)
-#define MUL_OP(a, b) MUL_SAT_OP_EXPAND((a), (b), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
-#define MLA_OP(a, b, c) MLA_SAT_OP_EXPAND((a), (b), (c), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
-#define DIV_OP(a, b) DIV_SAT_OP_VEC_EXPAND((a), (b), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
-#define EXP_OP(a) EXP_OP_EXPAND((a), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
-#define LOG_OP(a) LOG_OP_EXPAND((a), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
-#define SQRT_OP(a) DIV_OP(CONST_ONE, INVSQRT_OP_EXPAND((a), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION))
-#define TANH_OP(a) TANH_OP_EXPAND((a), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
-
-#else /* FIXED_POINT_POSITION */
-
#define CONST_ONE 1.f
#define ABS_OP(a) fabs((a))
#define ADD_OP(a, b) ((a) + (b))
@@ -54,8 +37,6 @@
#define SQRT_OP(a) sqrt((a))
#define TANH_OP(a) tanh((a))
-#endif /* FIXED_POINT_POSITION */
-
// Logistic Activation
inline TYPE logistic_op(TYPE x)
{
@@ -125,9 +106,8 @@
* @note Vector size should be given as a preprocessor argument using -DVEC_SIZE=size. e.g. -DVEC_SIZE=16
* @note Activation function should be given as a preprocessor argument using -DACT=name. e.g. -DACT=TANH
* @note A, B variables required by some activation functions are set using -DA_VAL= and -DB_VAL= respectively.
- * @note In case of fixed point calculations the fixed point position is passed using -DFIXED_POINT_POSITION=position. e.g. -DFIXED_POINT_POSITION=3.
*
- * @param[in] input_ptr Pointer to the source image. Supported data types: QS8/QS16/F16/F32
+ * @param[in] input_ptr Pointer to the source image. Supported data types: F16/F32
* @param[in] input_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the source image in Y dimension (in bytes)
diff --git a/src/core/CL/cl_kernels/activation_layer_qa8.cl b/src/core/CL/cl_kernels/activation_layer_qa8.cl
index 66e54ed..8f6a807 100644
--- a/src/core/CL/cl_kernels/activation_layer_qa8.cl
+++ b/src/core/CL/cl_kernels/activation_layer_qa8.cl
@@ -24,7 +24,18 @@
#include "helpers.h"
#define TYPE VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+#define VEC_FLOAT VEC_DATA_TYPE(float, VEC_SIZE)
+// Logistic Activation
+inline TYPE logistic_op(TYPE x)
+{
+ VEC_FLOAT x_flt = CONVERT(x, VEC_FLOAT);
+ x_flt = round(x_flt - (float)O1_VAL) * ((float)S1_VAL);
+ x_flt = 1.f / (1.f + exp(-x_flt));
+
+ const TYPE x_u8 = CONVERT_SAT(round(x_flt / ((float)S1_VAL)) + (float)O1_VAL, TYPE);
+ return x_u8;
+}
// RELU Activation
inline TYPE relu_op(TYPE x)
{
@@ -119,4 +130,4 @@
(data, 0, (__global DATA_TYPE *)output.ptr);
}
-#endif /* defined(ACT) */
\ No newline at end of file
+#endif /* defined(ACT) */
diff --git a/src/core/CL/cl_kernels/arithmetic_op.cl b/src/core/CL/cl_kernels/arithmetic_op.cl
index 1296347..9efb71b 100644
--- a/src/core/CL/cl_kernels/arithmetic_op.cl
+++ b/src/core/CL/cl_kernels/arithmetic_op.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,10 +23,6 @@
*/
#include "helpers.h"
-#if defined(FIXED_POINT_POSITION)
-#include "fixed_point.h"
-#endif /* FIXED_POINT_POSITION */
-
#ifdef SATURATE
#define ADD(x, y) add_sat((x), (y))
#define SUB(x, y) sub_sat((x), (y))
@@ -35,13 +31,15 @@
#define SUB(x, y) (x) - (y)
#endif /* SATURATE */
+#define DIV(x, y) ((x) / (y)) /* fully parenthesized so the expansion is safe inside larger expressions */
+
/** This function adds two tensors.
*
* @attention The input and output data_types need to be passed at compile time using -DDATA_TYPE_IN1, -DDATA_TYPE_IN2 and -DDATA_TYPE_OUT:
* e.g. -DDATA_TYPE_IN1=uchar -DDATA_TYPE_IN2=uchar -DDATA_TYPE_OUT=short
* @attention To perform saturating operation -DSATURATE has to be passed to the compiler otherwise wrapping policy will be used.
*
- * @param[in] in1_ptr Pointer to the source tensor. Supported data types: U8/QS8/QS16/S16/F16/F32
+ * @param[in] in1_ptr Pointer to the source tensor. Supported data types: U8/S16/F16/F32
* @param[in] in1_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] in1_step_x in1_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] in1_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -49,7 +47,7 @@
* @param[in] in1_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] in1_step_z in1_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] in1_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[in] in2_ptr Pointer to the source tensor. Supported data types: U8/QS8 (only if @p in1_ptr is QS8), QS16 (only if @p in1_ptr is QS16), S16/F16/F32
+ * @param[in] in2_ptr Pointer to the source tensor. Supported data types: U8/S16/F16/F32
* @param[in] in2_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] in2_step_x in2_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] in2_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -57,7 +55,7 @@
* @param[in] in2_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] in2_step_z in2_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] in2_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[out] out_ptr Pointer to the destination tensor. Supported data types: U8 (only if both inputs are U8), QS8 (only if both inputs are QS8), QS16 (only if both inputs are QS16), S16/F16/F32
+ * @param[out] out_ptr Pointer to the destination tensor. Supported data types: U8 (only if both inputs are U8), S16/F16/F32
* @param[in] out_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] out_step_x out_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] out_stride_y Stride of the destination tensor in Y dimension (in bytes)
@@ -86,7 +84,7 @@
vstore16(ADD(in_a, in_b), 0, (__global DATA_TYPE_OUT *)out.ptr);
}
-/** This function subtracts one tensors from another.
+/** This function subtracts one tensor from another.
*
* @attention The input and output data_types need to be passed at compile time using -DDATA_TYPE_IN1, -DDATA_TYPE_IN2 and -DDATA_TYPE_OUT:
* e.g. -DDATA_TYPE_IN1=uchar -DDATA_TYPE_IN2=uchar -DDATA_TYPE_OUT=short
@@ -136,3 +134,53 @@
// Calculate and store result
vstore16(SUB(in_a, in_b), 0, (__global DATA_TYPE_OUT *)out.ptr);
}
+
+/** This function divides one tensor by another (element-wise).
+ *
+ * @attention The input and output data_types need to be passed at compile time using -DDATA_TYPE_IN1, -DDATA_TYPE_IN2 and -DDATA_TYPE_OUT:
+ * e.g. -DDATA_TYPE_IN1=float -DDATA_TYPE_IN2=float -DDATA_TYPE_OUT=float
+ *
+ * @param[in] in1_ptr Pointer to the source tensor. Supported data types: F16/F32
+ * @param[in] in1_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] in1_step_x in1_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] in1_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] in1_step_y in1_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] in1_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] in1_step_z in1_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] in1_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in] in2_ptr Pointer to the source tensor. Supported data types: Same as @p in1_ptr
+ * @param[in] in2_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] in2_step_x in2_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] in2_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] in2_step_y in2_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] in2_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] in2_step_z in2_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] in2_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] out_ptr Pointer to the destination tensor. Supported data types: Same as @p in1_ptr
+ * @param[in] out_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] out_step_x out_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] out_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] out_step_y out_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] out_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] out_step_z out_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] out_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void arithmetic_div(
+ TENSOR3D_DECLARATION(in1),
+ TENSOR3D_DECLARATION(in2),
+ TENSOR3D_DECLARATION(out))
+{
+ // Get pixels pointer
+ Tensor3D in1 = CONVERT_TO_TENSOR3D_STRUCT(in1);
+ Tensor3D in2 = CONVERT_TO_TENSOR3D_STRUCT(in2);
+ Tensor3D out = CONVERT_TO_TENSOR3D_STRUCT(out);
+
+ // Load 16 values from each input and convert them to the output data type
+ VEC_DATA_TYPE(DATA_TYPE_OUT, 16)
+ in_a = CONVERT(vload16(0, (__global DATA_TYPE_IN1 *)in1.ptr), VEC_DATA_TYPE(DATA_TYPE_OUT, 16));
+ VEC_DATA_TYPE(DATA_TYPE_OUT, 16)
+ in_b = CONVERT(vload16(0, (__global DATA_TYPE_IN2 *)in2.ptr), VEC_DATA_TYPE(DATA_TYPE_OUT, 16));
+
+ // Divide in_a by in_b lane-wise and store the 16 results
+ vstore16(DIV(in_a, in_b), 0, (__global DATA_TYPE_OUT *)out.ptr);
+}
diff --git a/src/core/CL/cl_kernels/arithmetic_op_quantized.cl b/src/core/CL/cl_kernels/arithmetic_op_quantized.cl
new file mode 100644
index 0000000..082317b
--- /dev/null
+++ b/src/core/CL/cl_kernels/arithmetic_op_quantized.cl
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2016-2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "helpers.h"
+
+#ifdef SATURATE
+#define ADD(x, y) add_sat((x), (y))
+#define SUB(x, y) sub_sat((x), (y))
+#else /* SATURATE */
+#define ADD(x, y) ((x) + (y))
+#define SUB(x, y) ((x) - (y))
+#endif /* SATURATE */
+
+#if defined(OFFSET_IN1) && defined(OFFSET_IN2) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT)
+
+/** Adds two QASYMM8 tensors: each input is dequantized with its own offset/scale, the sum is computed in float and requantized to the output offset/scale.
+ *
+ * @attention The quantization parameters must be passed at compile time using -DOFFSET_IN1, -DOFFSET_IN2, -DOFFSET_OUT, -DSCALE_IN1, -DSCALE_IN2 and -DSCALE_OUT, e.g. -DOFFSET_IN1=10 -DSCALE_IN1=0.5
+ * @attention -DSATURATE selects saturating arithmetic (sub_sat) for the offset subtraction; the final result is always rounded to nearest even and saturated to [0, 255].
+ *
+ * @param[in]  in1_ptr                           Pointer to the source tensor. Supported data types: QASYMM8
+ * @param[in]  in1_stride_x                      Stride of the source tensor in X dimension (in bytes)
+ * @param[in]  in1_step_x                        in1_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  in1_stride_y                      Stride of the source tensor in Y dimension (in bytes)
+ * @param[in]  in1_step_y                        in1_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  in1_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in]  in1_step_z                        in1_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in]  in1_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in]  in2_ptr                           Pointer to the source tensor. Supported data types: same as @p in1_ptr
+ * @param[in]  in2_stride_x                      Stride of the source tensor in X dimension (in bytes)
+ * @param[in]  in2_step_x                        in2_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  in2_stride_y                      Stride of the source tensor in Y dimension (in bytes)
+ * @param[in]  in2_step_y                        in2_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  in2_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in]  in2_step_z                        in2_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in]  in2_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] out_ptr                           Pointer to the destination tensor. Supported data types: same as @p in1_ptr
+ * @param[in]  out_stride_x                      Stride of the destination tensor in X dimension (in bytes)
+ * @param[in]  out_step_x                        out_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  out_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in]  out_step_y                        out_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  out_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  out_step_z                        out_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in]  out_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void arithmetic_add_quantized(
+    TENSOR3D_DECLARATION(in1),
+    TENSOR3D_DECLARATION(in2),
+    TENSOR3D_DECLARATION(out))
+{
+    // Tensor3D views of the three tensors for this work-item (see CONVERT_TO_TENSOR3D_STRUCT)
+    Tensor3D in1 = CONVERT_TO_TENSOR3D_STRUCT(in1);
+    Tensor3D in2 = CONVERT_TO_TENSOR3D_STRUCT(in2);
+    Tensor3D out = CONVERT_TO_TENSOR3D_STRUCT(out);
+    // Load 16 QASYMM8 values per input, widen them to int and remove the zero-point offsets
+    int16 in_a = CONVERT(vload16(0, (__global uchar *)in1.ptr), int16);
+    int16 in_b = CONVERT(vload16(0, (__global uchar *)in2.ptr), int16);
+
+    in_a = SUB(in_a, (int16)((int)OFFSET_IN1));
+    in_b = SUB(in_b, (int16)((int)OFFSET_IN2));
+    // Dequantize, add in float, then requantize as sum / SCALE_OUT + OFFSET_OUT (round to nearest even, saturate to uchar)
+    const float16 in1f32  = convert_float16(in_a) * (float16)((float)SCALE_IN1);
+    const float16 in2f32  = convert_float16(in_b) * (float16)((float)SCALE_IN2);
+    const float16 qresf32 = (in1f32 + in2f32) / (float16)((float)SCALE_OUT) + (float16)((float)OFFSET_OUT);
+    const uchar16 res     = convert_uchar16_sat(convert_int16_rte(qresf32));
+
+    // Store the 16 requantized results
+    vstore16(res, 0, (__global uchar *)out.ptr);
+}
+#endif /* defined(OFFSET_IN1) && defined(OFFSET_IN2) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) */
diff --git a/src/core/CL/cl_kernels/batchnormalization_layer.cl b/src/core/CL/cl_kernels/batchnormalization_layer.cl
index 9c980da..5352af3 100644
--- a/src/core/CL/cl_kernels/batchnormalization_layer.cl
+++ b/src/core/CL/cl_kernels/batchnormalization_layer.cl
@@ -25,25 +25,12 @@
#if defined(VEC_SIZE) && defined(DATA_TYPE)
-#if defined(FIXED_POINT_POSITION)
-#include "fixed_point.h"
-
-#define ADD_OP(a, b) ADD_SAT_OP_EXPAND((a), (b), DATA_TYPE, VEC_SIZE)
-#define SUB_OP(a, b) SUB_SAT_OP_EXPAND((a), (b), DATA_TYPE, VEC_SIZE)
-#define MUL_OP(a, b) MUL_SAT_OP_EXPAND((a), (b), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
-#define INVSQRT_OP(a) INVSQRT_OP_EXPAND((a), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
-#define SQCVT_SAT(a) SQCVT_SAT_OP_EXPAND((a), DATA_TYPE, FIXED_POINT_POSITION)
-
-#else /* FIXED_POINT_POSITION */
-
#define ADD_OP(a, b) ((a) + (b))
#define SUB_OP(a, b) ((a) - (b))
#define MUL_OP(a, b) ((a) * (b))
#define INVSQRT_OP(a) rsqrt((a))
#define SQCVT_SAT(a) (a)
-#endif /* FIXED_POINT_POSITION */
-
#if defined(FUSED_ACTIVATION)
#include "activation_layer.cl"
#define ACTIVATION_FUNC(x) ACTIVATION_OP(FUSED_ACTIVATION, x)
@@ -53,7 +40,7 @@
/** Apply batch normalization.
*
- * @param[in] input_ptr Pointer to the first source tensor. Supported data types: QS8/QS16/F16/F32
+ * @param[in] input_ptr Pointer to the first source tensor. Supported data types: F16/F32
* @param[in] input_stride_x Stride of the first source tensor in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the first source tensor in Y dimension (in bytes)
@@ -163,7 +150,7 @@
/** Apply batch normalization on tensors with NHWC format.
*
- * @param[in] input_ptr Pointer to the first source tensor. Supported data types: QS8/QS16/F16/F32
+ * @param[in] input_ptr Pointer to the first source tensor. Supported data types: F16/F32
* @param[in] input_stride_x Stride of the first source tensor in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the first source tensor in Y dimension (in bytes)
diff --git a/src/core/CL/cl_kernels/canny.cl b/src/core/CL/cl_kernels/canny.cl
index 166d681..9bfa2f4 100644
--- a/src/core/CL/cl_kernels/canny.cl
+++ b/src/core/CL/cl_kernels/canny.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -77,10 +77,10 @@
m = CONVERT_SAT((abs(h) + abs(v)), VEC_DATA_TYPE(DATA_TYPE_OUT, 4));
/* Calculate the angle */
- float4 p = atan2pi(convert_float4(v), convert_float4(h));
+ float4 p = 180.0f * atan2pi(convert_float4(v), convert_float4(h));
/* Remap angle to range [0, 256) */
- p = select(p, p + 2, p < 0.0f) * 128.0f;
+ p = select(p, p + 180.0f, p < 0.0f);
/* Store results */
vstore4(m, 0, (__global DATA_TYPE_OUT *)grad.ptr);
@@ -138,29 +138,27 @@
float4 m = sqrt(h * h + v * v);
/* Calculate the angle */
- float4 p = atan2pi(v, h);
+ float4 p = 180.0f * atan2pi(v, h);
/* Remap angle to range [0, 256) */
- p = select(p, p + 2, p < 0.0f) * 128.0f;
+ p = select(p, p + 180.0f, p < 0.0f);
/* Store results */
vstore4(CONVERT_SAT_ROUND(m, VEC_DATA_TYPE(DATA_TYPE_OUT, 4), rte), 0, (__global DATA_TYPE_OUT *)grad.ptr);
vstore4(convert_uchar4_sat_rte(p), 0, angle.ptr);
}
+#define EDGE 255   // NOTE(review): presumably the output value for confirmed edge pixels (hysteresis stage) — confirm
+#define NO_EDGE 0  // Value for non-edge pixels; hysteresis copies it straight to the output when seen in the input
+
/** Array that holds the relative coordinates offset for the neighbouring pixels.
*/
__constant short4 neighbours_coords[] =
{
{ -1, 0, 1, 0 }, // 0
- { -1, 1, 1, -1 }, // 45
- { 0, 1, 0, -1 }, // 90
- { 1, 1, -1, -1 }, // 135
- { 1, 0, -1, 0 }, // 180
- { 1, -1, -1, 1 }, // 225
- { 0, 1, 0, -1 }, // 270
- { -1, -1, 1, 1 }, // 315
- { -1, 0, 1, 0 }, // 360
+ { -1, -1, 1, 1 }, // 45
+ { 0, -1, 0, 1 }, // 90
+ { 1, -1, -1, 1 }, // 135
};
/** Perform non maximum suppression.
@@ -199,18 +197,39 @@
Image angle = CONVERT_TO_IMAGE_STRUCT(angle);
Image non_max = CONVERT_TO_IMAGE_STRUCT(non_max);
+ // Index
+ const int x = get_global_id(0);
+ const int y = get_global_id(1);
+
// Get gradient and angle
DATA_TYPE_IN gradient = *((__global DATA_TYPE_IN *)grad.ptr);
- uchar an = convert_ushort(*angle.ptr);
+ uchar an = *((__global uchar *)angle.ptr);
+ // Early return if not greater than lower threshold
if(gradient <= lower_thr)
{
return;
}
- // Divide the whole round into 8 directions
- uchar ang = 127 - an;
- DATA_TYPE_OUT q_an = (ang + 16) >> 5;
+    // Quantize the gradient angle (degrees, in [0, 180]) into 4 directions: 0, 45, 90 and 135 degrees
+ DATA_TYPE_OUT q_an;
+
+ if(an < 22.5f || an >= 157.5f)
+ {
+ q_an = 0;
+ }
+ else if(an < 67.5f)
+ {
+ q_an = 1;
+ }
+ else if(an < 112.5f)
+ {
+ q_an = 2;
+ }
+ else
+ {
+ q_an = 3;
+ }
// Find the two pixels in the perpendicular direction
short2 x_p = neighbours_coords[q_an].s02;
@@ -220,11 +239,11 @@
if((gradient > g1) && (gradient > g2))
{
- *((global DATA_TYPE_OUT *)non_max.ptr) = gradient;
+ __global uchar *non_max_addr = non_max_ptr + non_max_offset_first_element_in_bytes + x * non_max_stride_x + y * non_max_stride_y;
+ *((global DATA_TYPE_OUT *)non_max_addr) = gradient;
}
}
-#define EDGE 255
#define hysteresis_local_stack_L1 8 // The size of level 1 stack. This has to agree with the host side
#define hysteresis_local_stack_L2 16 // The size of level 2 stack, adjust this can impact the match rate with VX implementation
@@ -330,10 +349,16 @@
// Load value
DATA_TYPE_IN val = *((__global DATA_TYPE_IN *)offset(&src, x, y));
- // If less than upper threshold set to NO_EDGE and return
+ // If the pixel has already been marked as NO_EDGE, store that value in the output and return
+ if(val == NO_EDGE)
+ {
+ *offset(&out, x, y) = NO_EDGE;
+ return;
+ }
+
+ // Return if it is a MAYBE pixel. Such pixels will become edges if near a strong edge
if(val <= up_thr)
{
- *offset(&out, x, y) = 0;
return;
}
@@ -372,7 +397,7 @@
// Get direction pixel indices
int N = max(y - 1, 0), S = min(y + 1, height - 2), W = max(x - 1, 0), E = min(x + 1, width - 2);
- // Check 8 pixels around for week edges where low_thr < val <= up_thr
+ // Check 8 pixels around for weak edges where low_thr < val <= up_thr
x_tmp = vload4(0, (__global DATA_TYPE_IN *)offset(&src, W, N));
v_tmp = vload4(0, (__global uint *)offset(&visited, W, N));
check_pixel(((x_tmp.s0 <= low_thr) || v_tmp.s0 || (x_tmp.s0 > up_thr)), W, N, x, y); // NW
diff --git a/src/core/CL/cl_kernels/channel_shuffle.cl b/src/core/CL/cl_kernels/channel_shuffle.cl
index 26cee9c..23962e1 100644
--- a/src/core/CL/cl_kernels/channel_shuffle.cl
+++ b/src/core/CL/cl_kernels/channel_shuffle.cl
@@ -38,7 +38,7 @@
* @note The number of channels in each group should be given as a preprocessor argument using -DK=num. e.g. -DK=1
* K is equal to num_channels / num_groups.
*
- * @param[in] src_ptr Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
+ * @param[in] src_ptr Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
* @param[in] src_stride_x Stride of the first source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the first source tensor in Y dimension (in bytes)
diff --git a/src/core/CL/cl_kernels/col2im.cl b/src/core/CL/cl_kernels/col2im.cl
index 9b5a7b5..5e52127 100644
--- a/src/core/CL/cl_kernels/col2im.cl
+++ b/src/core/CL/cl_kernels/col2im.cl
@@ -23,12 +23,7 @@
*/
#include "helpers.h"
-#if defined(FIXED_POINT_POSITION)
-#include "fixed_point.h"
-#endif // FIXED_POINT_POSITION
-
#if defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT)
-#if !defined(FIXED_POINT_POSITION)
#if ELEMENT_SIZE == 1
#define COND_DATA_TYPE char
@@ -46,6 +41,7 @@
* @note The width of the input tensor must be passed at compile time using -DWIDTH_INPUT: e.g. -DWIDTH_INPUT=320
* @note The width of the output tensor must be passed at compile time using -DWIDTH_OUTPUT: e.g. -DWIDTH_OUTPUT=600
* @note The element size must be passed at compile time using -DELEMENT_SIZE: e.g. -DELEMENT_SIZE=4
+ * @note In case of grouping the GROUPING flag must be passed at compile time using -DGROUPING
*
* @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8/F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
@@ -72,6 +68,9 @@
{
Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ const uint xd = get_global_id(1) % WIDTH_OUTPUT; // x coordinate of the destination tensor
+ const uint yd = get_global_id(1) / WIDTH_OUTPUT; // y coordinate of the destination tensor
+
VEC_DATA_TYPE(DATA_TYPE, 8)
data = vload8(0, (__global DATA_TYPE *)src.ptr);
@@ -89,8 +88,16 @@
__global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes;
- // Compute output offset
- int idx = (get_global_id(1) / WIDTH_OUTPUT) * dst_stride_y + (get_global_id(1) % WIDTH_OUTPUT) * dst_stride_x + get_global_id(2) * dst_stride_w;
+#if defined(GROUPING)
+ // Compute output offset (batches on 4th dimension, no need to compute manually)
+ int idx = yd * dst_stride_y + xd * dst_stride_x;
+
+ const uint group = get_global_id(2); // group ID
+ x_clamped += group * WIDTH_INPUT;
+#else /* defined(GROUPING) */
+ // Compute output offset (batches on 3rd dimension)
+ int idx = yd * dst_stride_y + xd * dst_stride_x + get_global_id(2) * dst_stride_w;
+#endif /* GROUPING */
// Store value
*((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s0 * dst_stride_z)) = data.s0;
@@ -102,43 +109,4 @@
*((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s6 * dst_stride_z)) = data.s6;
*((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s7 * dst_stride_z)) = data.s7;
}
-#else // !defined(FIXED_POINT_POSITION)
-/** This kernel performs a reshaping of the output of the convolution layer.
- *
- * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=qs8
- * @note The width of the output tensor must be passed at compile time using -DWIDTH_OUTPUT: e.g. -DWIDTH_OUTPUT=320
- *
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/QS16
- * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
- * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
- * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
- * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
- */
-__kernel void col2im(
- TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst),
- uint dst_stride_w)
-{
- Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
- Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(dst);
-
- // Compute output offset
- int idx = get_global_id(0) * dst.stride_z + (get_global_id(1) / WIDTH_OUTPUT) * dst_stride_y + (get_global_id(1) % WIDTH_OUTPUT) * dst_stride_x + get_global_id(2) * dst_stride_w;
-
- // Store value
- *((__global DATA_TYPE *)(dst.ptr + idx)) = *((__global DATA_TYPE *)(src.ptr));
-}
-#endif // !defined(FIXED_POINT_POSITION)
-#endif // defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT)
\ No newline at end of file
+#endif // defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT)
diff --git a/src/core/CL/cl_kernels/color_convert.cl b/src/core/CL/cl_kernels/color_convert.cl
index 01d8b90..02a0c8e 100644
--- a/src/core/CL/cl_kernels/color_convert.cl
+++ b/src/core/CL/cl_kernels/color_convert.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -135,13 +135,23 @@
char8 cb = (char8)(uyvy.s0, uyvy.s0, uyvy.s4, uyvy.s4, uyvy.s8, uyvy.s8, uyvy.sc, uyvy.sc) - (char8)(128);
char8 cr = (char8)(uyvy.s2, uyvy.s2, uyvy.s6, uyvy.s6, uyvy.sa, uyvy.sa, uyvy.se, uyvy.se) - (char8)(128);
- float8 f_r = convert_float8(luma) + (float8)(0.0000f) * convert_float8(cb) + (float8)(1.5748f) * convert_float8(cr);
- float8 f_g = convert_float8(luma) - (float8)(0.1873f) * convert_float8(cb) - (float8)(0.4681f) * convert_float8(cr);
- float8 f_b = convert_float8(luma) + (float8)(1.8556f) * convert_float8(cb) + (float8)(0.0000f) * convert_float8(cr);
+ float8 red_coef_bt709 = (float8)(1.5748f);
+ float8 green_coef_bt709 = (float8)(-0.1873f);
+ float8 green_coef2_bt709 = (float8)(-0.4681f);
+ float8 blue_coef_bt709 = (float8)(1.8556f);
+ float8 lumav = convert_float8(luma);
- uchar8 r_0 = convert_uchar8_rtz(f_r);
- uchar8 g_0 = convert_uchar8_rtz(f_g);
- uchar8 b_0 = convert_uchar8_rtz(f_b);
+ float8 f_r = red_coef_bt709 * convert_float8(cr);
+ float8 f_g = green_coef_bt709 * convert_float8(cb) + green_coef2_bt709 * convert_float8(cr);
+ float8 f_b = blue_coef_bt709 * convert_float8(cb);
+
+ f_r += lumav;
+ f_g += lumav;
+ f_b += lumav;
+
+ uchar8 r_0 = convert_uchar8_sat_rtz(f_r);
+ uchar8 g_0 = convert_uchar8_sat_rtz(f_g);
+ uchar8 b_0 = convert_uchar8_sat_rtz(f_b);
uchar16 rgb_0 = (uchar16)(r_0.s0, g_0.s0, b_0.s0, r_0.s1, g_0.s1, b_0.s1, r_0.s2, g_0.s2, b_0.s2,
r_0.s3, g_0.s3, b_0.s3, r_0.s4, g_0.s4, b_0.s4, r_0.s5);
@@ -183,13 +193,23 @@
char8 cb = (char8)(uyvy.s0, uyvy.s0, uyvy.s4, uyvy.s4, uyvy.s8, uyvy.s8, uyvy.sc, uyvy.sc) - (char8)(128);
char8 cr = (char8)(uyvy.s2, uyvy.s2, uyvy.s6, uyvy.s6, uyvy.sa, uyvy.sa, uyvy.se, uyvy.se) - (char8)(128);
- float8 f_r = convert_float8(luma) + (float8)(0.0000f) * convert_float8(cb) + (float8)(1.5748f) * convert_float8(cr);
- float8 f_g = convert_float8(luma) - (float8)(0.1873f) * convert_float8(cb) - (float8)(0.4681f) * convert_float8(cr);
- float8 f_b = convert_float8(luma) + (float8)(1.8556f) * convert_float8(cb) + (float8)(0.0000f) * convert_float8(cr);
+ float8 red_coef_bt709 = (float8)(1.5748f);
+ float8 green_coef_bt709 = (float8)(-0.1873f);
+ float8 green_coef2_bt709 = (float8)(-0.4681f);
+ float8 blue_coef_bt709 = (float8)(1.8556f);
+ float8 lumav = convert_float8(luma);
- uchar8 r_0 = convert_uchar8_rtz(f_r);
- uchar8 g_0 = convert_uchar8_rtz(f_g);
- uchar8 b_0 = convert_uchar8_rtz(f_b);
+ float8 f_r = red_coef_bt709 * convert_float8(cr);
+ float8 f_g = green_coef_bt709 * convert_float8(cb) + green_coef2_bt709 * convert_float8(cr);
+ float8 f_b = blue_coef_bt709 * convert_float8(cb);
+
+ f_r += lumav;
+ f_g += lumav;
+ f_b += lumav;
+
+ uchar8 r_0 = convert_uchar8_sat_rtz(f_r);
+ uchar8 g_0 = convert_uchar8_sat_rtz(f_g);
+ uchar8 b_0 = convert_uchar8_sat_rtz(f_b);
uchar16 rgba_0 = (uchar16)(r_0.s0, g_0.s0, b_0.s0, 255, r_0.s1, g_0.s1, b_0.s1, 255,
r_0.s2, g_0.s2, b_0.s2, 255, r_0.s3, g_0.s3, b_0.s3, 255);
@@ -232,13 +252,23 @@
char8 cb = (char8)(uyvy.s1, uyvy.s1, uyvy.s5, uyvy.s5, uyvy.s9, uyvy.s9, uyvy.sd, uyvy.sd) - (char8)(128);
char8 cr = (char8)(uyvy.s3, uyvy.s3, uyvy.s7, uyvy.s7, uyvy.sb, uyvy.sb, uyvy.sf, uyvy.sf) - (char8)(128);
- float8 f_r = convert_float8(luma) + (float8)(0.0000f) * convert_float8(cb) + (float8)(1.5748f) * convert_float8(cr);
- float8 f_g = convert_float8(luma) - (float8)(0.1873f) * convert_float8(cb) - (float8)(0.4681f) * convert_float8(cr);
- float8 f_b = convert_float8(luma) + (float8)(1.8556f) * convert_float8(cb) + (float8)(0.0000f) * convert_float8(cr);
+ float8 red_coef_bt709 = (float8)(1.5748f);
+ float8 green_coef_bt709 = (float8)(-0.1873f);
+ float8 green_coef2_bt709 = (float8)(-0.4681f);
+ float8 blue_coef_bt709 = (float8)(1.8556f);
+ float8 lumav = convert_float8(luma);
- uchar8 r_0 = convert_uchar8_rtz(f_r);
- uchar8 g_0 = convert_uchar8_rtz(f_g);
- uchar8 b_0 = convert_uchar8_rtz(f_b);
+ float8 f_r = red_coef_bt709 * convert_float8(cr);
+ float8 f_g = green_coef_bt709 * convert_float8(cb) + green_coef2_bt709 * convert_float8(cr);
+ float8 f_b = blue_coef_bt709 * convert_float8(cb);
+
+ f_r += lumav;
+ f_g += lumav;
+ f_b += lumav;
+
+ uchar8 r_0 = convert_uchar8_sat_rtz(f_r);
+ uchar8 g_0 = convert_uchar8_sat_rtz(f_g);
+ uchar8 b_0 = convert_uchar8_sat_rtz(f_b);
uchar16 rgb_0 = (uchar16)(r_0.s0, g_0.s0, b_0.s0, r_0.s1, g_0.s1, b_0.s1, r_0.s2, g_0.s2, b_0.s2,
r_0.s3, g_0.s3, b_0.s3, r_0.s4, g_0.s4, b_0.s4, r_0.s5);
@@ -280,13 +310,23 @@
char8 cb = (char8)(uyvy.s1, uyvy.s1, uyvy.s5, uyvy.s5, uyvy.s9, uyvy.s9, uyvy.sd, uyvy.sd) - (char8)(128);
char8 cr = (char8)(uyvy.s3, uyvy.s3, uyvy.s7, uyvy.s7, uyvy.sb, uyvy.sb, uyvy.sf, uyvy.sf) - (char8)(128);
- float8 f_r = convert_float8(luma) + (float8)(0.0000f) * convert_float8(cb) + (float8)(1.5748f) * convert_float8(cr);
- float8 f_g = convert_float8(luma) - (float8)(0.1873f) * convert_float8(cb) - (float8)(0.4681f) * convert_float8(cr);
- float8 f_b = convert_float8(luma) + (float8)(1.8556f) * convert_float8(cb) + (float8)(0.0000f) * convert_float8(cr);
+ float8 red_coef_bt709 = (float8)(1.5748f);
+ float8 green_coef_bt709 = (float8)(-0.1873f);
+ float8 green_coef2_bt709 = (float8)(-0.4681f);
+ float8 blue_coef_bt709 = (float8)(1.8556f);
+ float8 lumav = convert_float8(luma);
- uchar8 r_0 = convert_uchar8_rtz(f_r);
- uchar8 g_0 = convert_uchar8_rtz(f_g);
- uchar8 b_0 = convert_uchar8_rtz(f_b);
+ float8 f_r = red_coef_bt709 * convert_float8(cr);
+ float8 f_g = green_coef_bt709 * convert_float8(cb) + green_coef2_bt709 * convert_float8(cr);
+ float8 f_b = blue_coef_bt709 * convert_float8(cb);
+
+ f_r += lumav;
+ f_g += lumav;
+ f_b += lumav;
+
+ uchar8 r_0 = convert_uchar8_sat_rtz(f_r);
+ uchar8 g_0 = convert_uchar8_sat_rtz(f_g);
+ uchar8 b_0 = convert_uchar8_sat_rtz(f_b);
uchar16 rgba_0 = (uchar16)(r_0.s0, g_0.s0, b_0.s0, 255, r_0.s1, g_0.s1, b_0.s1, 255,
r_0.s2, g_0.s2, b_0.s2, 255, r_0.s3, g_0.s3, b_0.s3, 255);
@@ -431,9 +471,9 @@
float4 f_g = convert_float4(luma_0) + temp1;
float4 f_b = convert_float4(luma_0) + temp2;
- uchar4 r_0 = convert_uchar4_rtz(f_r);
- uchar4 g_0 = convert_uchar4_rtz(f_g);
- uchar4 b_0 = convert_uchar4_rtz(f_b);
+ uchar4 r_0 = convert_uchar4_sat_rtz(f_r);
+ uchar4 g_0 = convert_uchar4_sat_rtz(f_g);
+ uchar4 b_0 = convert_uchar4_sat_rtz(f_b);
uchar8 rgb_0 = (uchar8)(r_0.s0, g_0.s0, b_0.s0, r_0.s1, g_0.s1, b_0.s1, r_0.s2, g_0.s2);
uchar4 rgb_1 = (uchar4)(b_0.s2, r_0.s3, g_0.s3, b_0.s3);
@@ -444,9 +484,9 @@
f_g = convert_float4(luma_1) + temp1;
f_b = convert_float4(luma_1) + temp2;
- r_0 = convert_uchar4_rtz(f_r);
- g_0 = convert_uchar4_rtz(f_g);
- b_0 = convert_uchar4_rtz(f_b);
+ r_0 = convert_uchar4_sat_rtz(f_r);
+ g_0 = convert_uchar4_sat_rtz(f_g);
+ b_0 = convert_uchar4_sat_rtz(f_b);
rgb_0 = (uchar8)(r_0.s0, g_0.s0, b_0.s0, r_0.s1, g_0.s1, b_0.s1, r_0.s2, g_0.s2);
rgb_1 = (uchar4)(b_0.s2, r_0.s3, g_0.s3, b_0.s3);
@@ -887,9 +927,9 @@
float4 f_g = convert_float4(luma_0) + temp1;
float4 f_b = convert_float4(luma_0) + temp2;
- uchar4 r_0 = convert_uchar4_rtz(f_r);
- uchar4 g_0 = convert_uchar4_rtz(f_g);
- uchar4 b_0 = convert_uchar4_rtz(f_b);
+ uchar4 r_0 = convert_uchar4_sat_rtz(f_r);
+ uchar4 g_0 = convert_uchar4_sat_rtz(f_g);
+ uchar4 b_0 = convert_uchar4_sat_rtz(f_b);
uchar8 rgb_0 = (uchar8)(r_0.s0, g_0.s0, b_0.s0, 255, r_0.s1, g_0.s1, b_0.s1, 255);
uchar8 rgb_1 = (uchar8)(r_0.s2, g_0.s2, b_0.s2, 255, r_0.s3, g_0.s3, b_0.s3, 255);
@@ -900,9 +940,9 @@
f_g = convert_float4(luma_1) + temp1;
f_b = convert_float4(luma_1) + temp2;
- r_0 = convert_uchar4_rtz(f_r);
- g_0 = convert_uchar4_rtz(f_g);
- b_0 = convert_uchar4_rtz(f_b);
+ r_0 = convert_uchar4_sat_rtz(f_r);
+ g_0 = convert_uchar4_sat_rtz(f_g);
+ b_0 = convert_uchar4_sat_rtz(f_b);
rgb_0 = (uchar8)(r_0.s0, g_0.s0, b_0.s0, 255, r_0.s1, g_0.s1, b_0.s1, 255);
rgb_1 = (uchar8)(r_0.s2, g_0.s2, b_0.s2, 255, r_0.s3, g_0.s3, b_0.s3, 255);
@@ -1086,9 +1126,9 @@
float4 f_g = convert_float4(luma_0) + temp1;
float4 f_b = convert_float4(luma_0) + temp2;
- uchar4 r_0 = convert_uchar4_rtz(f_r);
- uchar4 g_0 = convert_uchar4_rtz(f_g);
- uchar4 b_0 = convert_uchar4_rtz(f_b);
+ uchar4 r_0 = convert_uchar4_sat_rtz(f_r);
+ uchar4 g_0 = convert_uchar4_sat_rtz(f_g);
+ uchar4 b_0 = convert_uchar4_sat_rtz(f_b);
uchar8 rgb_0 = (uchar8)(r_0.s0, g_0.s0, b_0.s0, r_0.s1, g_0.s1, b_0.s1, r_0.s2, g_0.s2);
uchar4 rgb_1 = (uchar4)(b_0.s2, r_0.s3, g_0.s3, b_0.s3);
@@ -1099,9 +1139,9 @@
f_g = convert_float4(luma_1) + temp1;
f_b = convert_float4(luma_1) + temp2;
- r_0 = convert_uchar4_rtz(f_r);
- g_0 = convert_uchar4_rtz(f_g);
- b_0 = convert_uchar4_rtz(f_b);
+ r_0 = convert_uchar4_sat_rtz(f_r);
+ g_0 = convert_uchar4_sat_rtz(f_g);
+ b_0 = convert_uchar4_sat_rtz(f_b);
rgb_0 = (uchar8)(r_0.s0, g_0.s0, b_0.s0, r_0.s1, g_0.s1, b_0.s1, r_0.s2, g_0.s2);
rgb_1 = (uchar4)(b_0.s2, r_0.s3, g_0.s3, b_0.s3);
@@ -1157,9 +1197,9 @@
float4 f_g = convert_float4(luma_0) + temp1;
float4 f_b = convert_float4(luma_0) + temp2;
- uchar4 r_0 = convert_uchar4_rtz(f_r);
- uchar4 g_0 = convert_uchar4_rtz(f_g);
- uchar4 b_0 = convert_uchar4_rtz(f_b);
+ uchar4 r_0 = convert_uchar4_sat_rtz(f_r);
+ uchar4 g_0 = convert_uchar4_sat_rtz(f_g);
+ uchar4 b_0 = convert_uchar4_sat_rtz(f_b);
uchar8 rgb_0 = (uchar8)(r_0.s0, g_0.s0, b_0.s0, 255, r_0.s1, g_0.s1, b_0.s1, 255);
uchar8 rgb_1 = (uchar8)(r_0.s2, g_0.s2, b_0.s2, 255, r_0.s3, g_0.s3, b_0.s3, 255);
@@ -1170,9 +1210,9 @@
f_g = convert_float4(luma_1) + temp1;
f_b = convert_float4(luma_1) + temp2;
- r_0 = convert_uchar4_rtz(f_r);
- g_0 = convert_uchar4_rtz(f_g);
- b_0 = convert_uchar4_rtz(f_b);
+ r_0 = convert_uchar4_sat_rtz(f_r);
+ g_0 = convert_uchar4_sat_rtz(f_g);
+ b_0 = convert_uchar4_sat_rtz(f_b);
rgb_0 = (uchar8)(r_0.s0, g_0.s0, b_0.s0, 255, r_0.s1, g_0.s1, b_0.s1, 255);
rgb_1 = (uchar8)(r_0.s2, g_0.s2, b_0.s2, 255, r_0.s3, g_0.s3, b_0.s3, 255);
@@ -1485,9 +1525,9 @@
float4 f_g = convert_float4(luma_0) + temp1;
float4 f_b = convert_float4(luma_0) + temp2;
- uchar4 r_0 = convert_uchar4_rtz(f_r);
- uchar4 g_0 = convert_uchar4_rtz(f_g);
- uchar4 b_0 = convert_uchar4_rtz(f_b);
+ uchar4 r_0 = convert_uchar4_sat_rtz(f_r);
+ uchar4 g_0 = convert_uchar4_sat_rtz(f_g);
+ uchar4 b_0 = convert_uchar4_sat_rtz(f_b);
uchar8 rgb_0 = (uchar8)(r_0.s0, g_0.s0, b_0.s0, r_0.s1, g_0.s1, b_0.s1, r_0.s2, g_0.s2);
uchar4 rgb_1 = (uchar4)(b_0.s2, r_0.s3, g_0.s3, b_0.s3);
@@ -1498,9 +1538,9 @@
f_g = convert_float4(luma_1) + temp1;
f_b = convert_float4(luma_1) + temp2;
- r_0 = convert_uchar4_rtz(f_r);
- g_0 = convert_uchar4_rtz(f_g);
- b_0 = convert_uchar4_rtz(f_b);
+ r_0 = convert_uchar4_sat_rtz(f_r);
+ g_0 = convert_uchar4_sat_rtz(f_g);
+ b_0 = convert_uchar4_sat_rtz(f_b);
rgb_0 = (uchar8)(r_0.s0, g_0.s0, b_0.s0, r_0.s1, g_0.s1, b_0.s1, r_0.s2, g_0.s2);
rgb_1 = (uchar4)(b_0.s2, r_0.s3, g_0.s3, b_0.s3);
@@ -1564,9 +1604,9 @@
float4 f_g = convert_float4(luma_0) + temp1;
float4 f_b = convert_float4(luma_0) + temp2;
- uchar4 r_0 = convert_uchar4_rtz(f_r);
- uchar4 g_0 = convert_uchar4_rtz(f_g);
- uchar4 b_0 = convert_uchar4_rtz(f_b);
+ uchar4 r_0 = convert_uchar4_sat_rtz(f_r);
+ uchar4 g_0 = convert_uchar4_sat_rtz(f_g);
+ uchar4 b_0 = convert_uchar4_sat_rtz(f_b);
uchar8 rgb_0 = (uchar8)(r_0.s0, g_0.s0, b_0.s0, 255, r_0.s1, g_0.s1, b_0.s1, 255);
uchar8 rgb_1 = (uchar8)(r_0.s2, g_0.s2, b_0.s2, 255, r_0.s3, g_0.s3, b_0.s3, 255);
@@ -1577,9 +1617,9 @@
f_g = convert_float4(luma_1) + temp1;
f_b = convert_float4(luma_1) + temp2;
- r_0 = convert_uchar4_rtz(f_r);
- g_0 = convert_uchar4_rtz(f_g);
- b_0 = convert_uchar4_rtz(f_b);
+ r_0 = convert_uchar4_sat_rtz(f_r);
+ g_0 = convert_uchar4_sat_rtz(f_g);
+ b_0 = convert_uchar4_sat_rtz(f_b);
rgb_0 = (uchar8)(r_0.s0, g_0.s0, b_0.s0, 255, r_0.s1, g_0.s1, b_0.s1, 255);
rgb_1 = (uchar8)(r_0.s2, g_0.s2, b_0.s2, 255, r_0.s3, g_0.s3, b_0.s3, 255);
diff --git a/src/core/CL/cl_kernels/concatenate.cl b/src/core/CL/cl_kernels/concatenate.cl
index f97ae13..16c4363 100644
--- a/src/core/CL/cl_kernels/concatenate.cl
+++ b/src/core/CL/cl_kernels/concatenate.cl
@@ -23,9 +23,14 @@
*/
#include "helpers.h"
+#if defined(DATA_TYPE)
+#if defined(WIDTH_OFFSET)
/** This kernel concatenates the input tensor into the output tensor along the first dimension
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8, QASYMM8, QS16, F16, F32
+ * @note The data type has to be passed at compile time using -DDATA_TYPE, e.g. -DDATA_TYPE=float
+ * @note The offset (in elements) for the first spatial dimension has to be passed at compile time using -DWIDTH_OFFSET, e.g. -DWIDTH_OFFSET=128
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -45,8 +50,7 @@
*/
__kernel void concatenate_width(
TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst),
- int offset)
+ TENSOR3D_DECLARATION(dst))
{
Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
@@ -55,12 +59,13 @@
source_values = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)src.ptr);
VSTORE(VEC_SIZE)
- (source_values, 0, (__global DATA_TYPE *)(dst.ptr + offset));
+ (source_values, 0, (__global DATA_TYPE *)(dst.ptr) + WIDTH_OFFSET);
}
+#endif // defined(WIDTH_OFFSET)
/** This kernel concatenates the input tensor into the output tensor along the third dimension
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8, QS16, F16, F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16, F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -92,3 +97,4 @@
VSTORE(VEC_SIZE)
(source_values, 0, (__global DATA_TYPE *)(dst.ptr + offsets.z));
}
+#endif // defined(DATA_TYPE)
\ No newline at end of file
diff --git a/src/core/CL/cl_kernels/convert_fc_weights.cl b/src/core/CL/cl_kernels/convert_fc_weights.cl
index 3c3e8b0..d47b733 100644
--- a/src/core/CL/cl_kernels/convert_fc_weights.cl
+++ b/src/core/CL/cl_kernels/convert_fc_weights.cl
@@ -32,7 +32,7 @@
* @attention Data type can be passed using the -DDATA_TYPE compile flag, e.g. -DDATA_TYPE=float
* @attention Original input tensor width*height and depth should be given as a preprocessor argument using -DFACTOR_1=size and -DFACTOR_2=size for NCHW and vice versa for NHWC. e.g. -DFACTOR_1=256 and -DFACTOR_2=128
*
- * @param[in] src_ptr Pointer to the source image. Supported data types: U8, S8, QS8, QASYMM8, U16, S16, QS16, U32, S32, QS32, F16, F32
+ * @param[in] src_ptr Pointer to the source image. Supported data types: U8, S8, QASYMM8, U16, S16, U32, S32, F16, F32
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
diff --git a/src/core/CL/cl_kernels/convolution_layer.cl b/src/core/CL/cl_kernels/convolution_layer.cl
index f8e0c27..2b75b45 100644
--- a/src/core/CL/cl_kernels/convolution_layer.cl
+++ b/src/core/CL/cl_kernels/convolution_layer.cl
@@ -23,14 +23,11 @@
*/
#include "helpers.h"
-#if defined(FIXED_POINT_POSITION)
-#include "fixed_point.h"
-#endif // FIXED_POINT_POSITION
-
-#if defined(DATA_TYPE)
+#if defined(DATA_TYPE) && defined(NUM_GROUPS)
/** This kernel reshapes the tensor's low three dimensions to single column
*
* @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
+ * @note The number of groups should be given as a preprocessor argument using -DNUM_GROUPS=number. e.g. -DNUM_GROUPS=2
*
* @param[in] src_ptr Pointer to the source tensor. Supported data types: F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
@@ -54,6 +51,7 @@
* @param[in] height The height of the input tensor
* @param[in] depth The depth of the input tensor
* @param[in] total_filters Total number of filters. 4th dimension of the weights matrix
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
*/
__kernel void reshape_to_columns(
TENSOR3D_DECLARATION(src),
@@ -61,7 +59,7 @@
#ifdef HAS_BIAS
VECTOR_DECLARATION(bias),
#endif /* HAS_BIAS */
- uint width, uint height, uint depth, uint total_filters)
+ uint width, uint height, uint depth, uint total_filters, uint dst_stride_z)
{
Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
bool is_last_thread = (get_global_id(0) == (get_global_size(0) - 1) && get_global_id(1) == (get_global_size(1) - 1) && get_global_id(2) == (get_global_size(2) - 1));
@@ -75,26 +73,40 @@
if(is_last_thread)
{
- for(uint i = 0; i < total_filters; ++i)
+ for(uint g = 0; g < NUM_GROUPS; ++g)
{
- *((__global DATA_TYPE *)tmp_dst_ptr) = *((__global DATA_TYPE *)tmp_src_ptr);
+ __global uchar *curr_group_dst = tmp_dst_ptr;
+
+ for(uint i = 0; i < total_filters / NUM_GROUPS; ++i)
+ {
+ *((__global DATA_TYPE *)curr_group_dst) = *((__global DATA_TYPE *)tmp_src_ptr);
#ifdef HAS_BIAS
- *((__global DATA_TYPE *)(tmp_dst_ptr + dst_stride_y)) = *((__global DATA_TYPE *)(tmp_bias_ptr));
- tmp_bias_ptr += bias_stride_x;
+ *((__global DATA_TYPE *)(curr_group_dst + dst_stride_y)) = *((__global DATA_TYPE *)(tmp_bias_ptr));
+ tmp_bias_ptr += bias_stride_x;
#endif /* HAS_BIAS */
- tmp_src_ptr += depth * src_stride_z;
- tmp_dst_ptr += dst_stride_x;
+ tmp_src_ptr += depth * src_stride_z;
+ curr_group_dst += dst_stride_x;
+ }
+
+ tmp_dst_ptr += dst_stride_z;
}
}
else
{
- for(uint i = 0; i < total_filters; ++i)
+ for(uint g = 0; g < NUM_GROUPS; ++g)
{
- *((__global DATA_TYPE *)tmp_dst_ptr) = *((__global DATA_TYPE *)tmp_src_ptr);
- tmp_src_ptr += depth * src_stride_z;
- tmp_dst_ptr += dst_stride_x;
+ __global uchar *curr_group_dst = tmp_dst_ptr;
+
+ for(uint i = 0; i < total_filters / NUM_GROUPS; ++i)
+ {
+ *((__global DATA_TYPE *)curr_group_dst) = *((__global DATA_TYPE *)tmp_src_ptr);
+ tmp_src_ptr += depth * src_stride_z;
+ curr_group_dst += dst_stride_x;
+ }
+
+ tmp_dst_ptr += dst_stride_z;
}
}
}
-#endif // defined(DATA_TYPE)
\ No newline at end of file
+#endif // defined(DATA_TYPE)
diff --git a/src/core/CL/cl_kernels/copy_tensor.cl b/src/core/CL/cl_kernels/copy_tensor.cl
index 4b37dec..930a676 100644
--- a/src/core/CL/cl_kernels/copy_tensor.cl
+++ b/src/core/CL/cl_kernels/copy_tensor.cl
@@ -25,24 +25,35 @@
/** Performs a copy of input tensor to the output tensor.
*
- * @param[in] in_ptr Pointer to the source image. Supported data types: U8.
- * @param[in] in_stride_x Stride of the source image in X dimension (in bytes)
- * @param[in] in_step_x in_stride_x * number of elements along X processed per work item (in bytes)
- * @param[in] in_offset_first_element_in_bytes Offset of the first element in the source image
- * @param[out] out_ptr Pointer to the destination image. Supported data types: U8.
- * @param[in] out_stride_x Stride of the destination image in X dimension (in bytes)
- * @param[in] out_step_x out_stride_x * number of elements along X processed per work item (in bytes)
- * @param[in] out_offset_first_element_in_bytes Offset of the first element in the destination image
+ * @param[in] in_ptr Pointer to the source tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
+ * @param[in] in_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] in_step_x in_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] in_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] in_step_y in_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] in_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] in_step_z in_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] in_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] out_ptr Pointer to the destination tensor. Supported data types: same as @p in_ptr
+ * @param[in] out_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] out_step_x out_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] out_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] out_step_y out_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] out_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] out_step_z out_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] out_offset_first_element_in_bytes The offset of the first element in the destination tensor
*/
__kernel void copy_tensor(
- VECTOR_DECLARATION(in),
- VECTOR_DECLARATION(out))
+ TENSOR3D_DECLARATION(in),
+ TENSOR3D_DECLARATION(out))
{
- Vector in = CONVERT_TO_VECTOR_STRUCT(in);
- Vector out = CONVERT_TO_VECTOR_STRUCT(out);
+ Tensor3D in = CONVERT_TO_TENSOR3D_STRUCT(in);
+ Tensor3D out = CONVERT_TO_TENSOR3D_STRUCT(out);
- VEC_DATA_TYPE(DATA_TYPE, 16)
- data = vload16(0, (__global DATA_TYPE *)in.ptr);
+ // Load VEC_SIZE elements of DATA_TYPE (both are expected as compile-time defines, e.g. -DDATA_TYPE=float -DVEC_SIZE=16)
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+ data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)in.ptr);
- vstore16(data, 0, (__global DATA_TYPE *)out.ptr);
+ // Store the same VEC_SIZE elements at this work-item's destination position
+ VSTORE(VEC_SIZE)
+ (data, 0, (__global DATA_TYPE *)out.ptr);
}
\ No newline at end of file
diff --git a/src/core/CL/cl_kernels/deconvolution_layer.cl b/src/core/CL/cl_kernels/deconvolution_layer.cl
index 2514ddc..e15482c 100644
--- a/src/core/CL/cl_kernels/deconvolution_layer.cl
+++ b/src/core/CL/cl_kernels/deconvolution_layer.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017, 2018 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,26 +25,30 @@
/** This function applies upsample on an input image.
*
- * @param[in] src_ptr Pointer to the source image. Supported data types: F32
+ * @param[in] src_ptr Pointer to the source image. Supported data types: F16/F32 (element type passed via -DDATA_TYPE, e.g. -DDATA_TYPE=float)
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
- * @param[out] dst_ptr Pointer to the destination image. Supported data types: F32
+ * @param[out] dst_ptr Pointer to the destination image. Supported data types: F16/F32
* @param[in] dst_stride_x Stride of the destination image in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination image in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination image
*/
__kernel void deconvolution_upsample(
- IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
{
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
// Store result
- *((__global float *)dst.ptr) = *((__global float *)src.ptr);
+ *((__global DATA_TYPE *)dst.ptr) = *((__global DATA_TYPE *)src.ptr); // Copies exactly one DATA_TYPE element per work-item
}
diff --git a/src/core/CL/cl_kernels/depth_convert.cl b/src/core/CL/cl_kernels/depth_convert.cl
index a9b7284..611449e 100644
--- a/src/core/CL/cl_kernels/depth_convert.cl
+++ b/src/core/CL/cl_kernels/depth_convert.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,47 +23,31 @@
*/
#include "helpers.h"
-#if defined(FIXED_POINT_POSITION)
-
-#include "fixed_point.h"
-
#ifdef SATURATE
-#define CONVERT_DOWN(x, in_type, out_type, fixed_point_position) CONVERT_DOWN1_SAT(x, in_type, out_type, fixed_point_position)
-#define CONVERT_DOWN1_SAT(x, in_type, out_type, fixed_point_position) convert_##out_type##_##in_type##_sat(x, fixed_point_position)
-#else /* SATURATE */
-#define CONVERT_DOWN(x, in_type, out_type, fixed_point_position) CONVERT_DOWN1(x, in_type, out_type, fixed_point_position)
-#define CONVERT_DOWN1(x, in_type, out_type, fixed_point_position) convert_##out_type##_##in_type(x, fixed_point_position)
-#endif /* SATURATE */
-
-#define CONVERT_UP(x, in_type, out_type, fixed_point_position) CONVERT_UP1(x, in_type, out_type, fixed_point_position)
-#define CONVERT_UP1(x, in_type, out_type, fixed_point_position) convert_##out_type##_##in_type(x, fixed_point_position)
-
-#else /* FIXED_POINT_POSITION */
-
-#ifdef SATURATE
+#if defined(IS_DATA_TYPE_FLOAT)
+#define CONVERT_RTE(x, type) (convert_##type##_rte((x)))
+#define CONVERT_DOWN(x, type) CONVERT_RTE(x, type)
+#else /* defined(IS_DATA_TYPE_FLOAT) */
#define CONVERT_DOWN(x, type) CONVERT_SAT(x, type)
-#else /* SATURATE */
+#endif /* defined(IS_DATA_TYPE_FLOAT) */
+#else /* SATURATE */
#define CONVERT_DOWN(x, type) CONVERT(x, type)
#endif /* SATURATE */
#define CONVERT_UP(x, type) CONVERT(x, type)
-#endif /* FIXED_POINT_POSITION */
-
/** This function performs a down-scaling depth conversion.
*
* @attention The input and output data_types need to be passed at compile time using -DDATA_TYPE_IN and -DDATA_TYPE_OUT:
* e.g. -DDATA_TYPE_IN=uchar -DDATA_TYPE_OUT=short
*
- * @note In case of fixed-point operation -DFIXED_POINT_POSITION=fixed_point_position must be provided: e.g. -DFIXED_POINT_POSITION=3
- *
- * @param[in] in_ptr Pointer to the source image. Supported data types: U8, U16, S16, U32, S32, F16, F32
+ * @param[in] in_ptr Pointer to the source image. Supported data types: U8/U16/S16/U32/S32/F16/F32
* @param[in] in_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] in_step_x in_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] in_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] in_step_y in_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] in_offset_first_element_in_bytes The offset of the first element in the source image
- * @param[out] out_ptr Pointer to the destination image. Supported data types: QS8, U8, QS16, U16, S16, U32, S32
+ * @param[out] out_ptr Pointer to the destination image. Supported data types: U8/U16/S16/U32/S32/F16/F32
* @param[in] out_stride_x Stride of the destination image in X dimension (in bytes)
* @param[in] out_step_x out_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] out_stride_y Stride of the destination image in Y dimension (in bytes)
@@ -84,11 +68,12 @@
VEC_DATA_TYPE(DATA_TYPE_IN, 16)
in_data = vload16(0, (__global DATA_TYPE_IN *)in.ptr);
-#if defined(FIXED_POINT_POSITION)
- vstore16(CONVERT_DOWN(in_data, VEC_DATA_TYPE(DATA_TYPE_IN, 16), VEC_DATA_TYPE(DATA_TYPE_OUT, 16), FIXED_POINT_POSITION), 0, (__global DATA_TYPE_OUT *)out.ptr);
-#else /* FIXED_POINT_POSITION */
+#if defined(IS_DATA_TYPE_FLOAT)
+ const DATA_TYPE_IN scale = (DATA_TYPE_IN)(1 << shift);
+ vstore16(CONVERT_DOWN(in_data / scale, VEC_DATA_TYPE(DATA_TYPE_OUT, 16)), 0, (__global DATA_TYPE_OUT *)out.ptr);
+#else /* defined(IS_DATA_TYPE_FLOAT) */
vstore16(CONVERT_DOWN(in_data >> shift, VEC_DATA_TYPE(DATA_TYPE_OUT, 16)), 0, (__global DATA_TYPE_OUT *)out.ptr);
-#endif /* FIXED_POINT_POSITION */
+#endif /* defined(IS_DATA_TYPE_FLOAT) */
}
/** This function performs a up-scaling depth conversion.
@@ -96,15 +81,13 @@
* @attention The input and output data_types need to be passed at compile time using -DDATA_TYPE_IN and -DDATA_TYPE_OUT:
* e.g. -DDATA_TYPE_IN=uchar -DDATA_TYPE_OUT=short
*
- * @note In case of fixed-point operation -DFIXED_POINT_POSITION=fixed_point_position must be provided: e.g. -DFIXED_POINT_POSITION=3
- *
- * @param[in] in_ptr Pointer to the source image. Supported data types: U8, QS8, U16, S16, QS16, U32 or S32
+ * @param[in] in_ptr Pointer to the source image. Supported data types: U8/U16/S16/U32/S32/F16/F32
* @param[in] in_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] in_step_x in_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] in_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] in_step_y in_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] in_offset_first_element_in_bytes The offset of the first element in the source image
- * @param[out] out_ptr Pointer to the destination image. Supported data types: U8, U16, S16, U32, S32, F16 or F32
+ * @param[out] out_ptr Pointer to the destination image. Supported data types: U8/U16/S16/U32/S32/F16/F32
* @param[in] out_stride_x Stride of the destination image in X dimension (in bytes)
* @param[in] out_step_x out_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] out_stride_y Stride of the destination image in Y dimension (in bytes)
@@ -125,9 +108,10 @@
VEC_DATA_TYPE(DATA_TYPE_IN, 16)
in_data = vload16(0, (__global DATA_TYPE_IN *)in.ptr);
-#if defined(FIXED_POINT_POSITION)
- vstore16(CONVERT_UP(in_data, VEC_DATA_TYPE(DATA_TYPE_IN, 16), VEC_DATA_TYPE(DATA_TYPE_OUT, 16), FIXED_POINT_POSITION), 0, (__global DATA_TYPE_OUT *)out.ptr);
-#else /* FIXED_POINT_POSITION */
+#if defined(IS_DATA_TYPE_FLOAT)
+ const DATA_TYPE_OUT scale = (DATA_TYPE_OUT)(1 << shift);
+ vstore16(CONVERT_UP(in_data, VEC_DATA_TYPE(DATA_TYPE_OUT, 16)) * scale, 0, (__global DATA_TYPE_OUT *)out.ptr);
+#else /* defined(IS_DATA_TYPE_FLOAT) */
vstore16(CONVERT_UP(in_data, VEC_DATA_TYPE(DATA_TYPE_OUT, 16)) << shift, 0, (__global DATA_TYPE_OUT *)out.ptr);
-#endif /* FIXED_POINT_POSITION */
+#endif /* defined(IS_DATA_TYPE_FLOAT) */
}
diff --git a/src/core/CL/cl_kernels/depthwise_convolution.cl b/src/core/CL/cl_kernels/depthwise_convolution.cl
index 5f4247e..77a76b6 100644
--- a/src/core/CL/cl_kernels/depthwise_convolution.cl
+++ b/src/core/CL/cl_kernels/depthwise_convolution.cl
@@ -451,6 +451,22 @@
#endif // defined(DEPTH_MULTIPLIER)
+#if defined(NCHW) // in_* / out_* name the W (x), H (y) and C (z) byte strides for either data layout
+#define in_stride_x src_stride_x
+#define in_stride_y src_stride_y
+#define in_stride_z src_stride_z
+#define out_stride_x dst_stride_x
+#define out_stride_y dst_stride_y
+#define out_stride_z dst_stride_z
+#else // !defined(NCHW): NHWC, where tensor dim 0 is C, dim 1 is W and dim 2 is H
+#define in_stride_x src_stride_y
+#define in_stride_y src_stride_z
+#define in_stride_z src_stride_x
+#define out_stride_x dst_stride_y
+#define out_stride_y dst_stride_z
+#define out_stride_z dst_stride_x
+#endif //defined(NCHW)
+// NOTE(review): object-like macros - every later in_stride_* / out_stride_* token in this file is rewritten, including ones built by ## pasting (e.g. TENSOR3D_DECLARATION(in)).
#if defined(SRC_WIDTH) && defined(DATA_TYPE)
/** This kernel reshapes each of the tensor's low three dimensions to single rows.
*
@@ -484,17 +500,16 @@
#endif /* HAS_BIAS */
)
{
- Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
#ifdef HAS_BIAS
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
#endif /* HAS_BIAS */
- __global DATA_TYPE *input_ptr = (__global DATA_TYPE *)src.ptr;
- __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + get_global_id(1) * SRC_WIDTH * dst_stride_x + get_global_id(2) * dst_stride_y;
+ __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + get_global_id(1) * in_stride_y + get_global_id(2) * in_stride_z;
+ __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + get_global_id(1) * SRC_WIDTH * dst_stride_x + get_global_id(2) * dst_stride_y;
- for(int i = 0; i < SRC_WIDTH; ++i, ++input_ptr)
+ for(int i = 0; i < SRC_WIDTH; ++i, input_ptr += in_stride_x)
{
- *((__global DATA_TYPE *)(output_ptr + i * dst_stride_x)) = *input_ptr;
+ *((__global DATA_TYPE *)(output_ptr + i * dst_stride_x)) = *((__global DATA_TYPE *)input_ptr);
}
#if defined(HAS_BIAS)
@@ -512,7 +527,7 @@
* @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
* @note The convolution information must be passed at compile time using -DSTRIDE_X, -DSTRIDE_Y, -DPAD_LEFT, -DPAD_TOP, -DPAD_RIGHT, -DPAD_BOTTOM, -DKERNEL_WIDHT, -DKERNEL_HEIGHT, -DSRC_WIDTH, -DSRC_HEIGHT, -DDEPTH_MULTIPLIER
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/QS16/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -541,7 +556,7 @@
const int src_y = -PAD_TOP + src_pixel_linear / max_initial_x * STRIDE_Y;
const int src_z = get_global_id(2) / DEPTH_MULTIPLIER;
- __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + src_z * src_stride_z;
+ __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + src_z * in_stride_z;
__global DATA_TYPE *output_ptr = ((__global DATA_TYPE *)(dst.ptr));
for(int y = src_y; y < src_y + KERNEL_HEIGHT; ++y)
@@ -554,7 +569,7 @@
}
else
{
- *output_ptr = *((__global DATA_TYPE *)(input_ptr + x * src_stride_x + y * src_stride_y));
+ *output_ptr = *((__global DATA_TYPE *)(input_ptr + x * in_stride_x + y * in_stride_y));
}
}
}
@@ -572,7 +587,7 @@
* @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
* @note The convolution information must be passed at compile time using -DCONV_WIDTH, -DCONV_HEIGHT, e.g -DCONV_WIDTH=32, -DCONV_HEIGHT=42
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/QS16/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
@@ -596,7 +611,7 @@
const int z = id0 / patch_size;
const int index2D = id0 - z * patch_size;
- __global uchar *out_ptr = dst_ptr + dst_offset_first_element_in_bytes + index2D % CONV_WIDTH * dst_stride_x + index2D / CONV_WIDTH * dst_stride_y + z * dst_stride_z;
+ __global uchar *out_ptr = dst_ptr + dst_offset_first_element_in_bytes + index2D % CONV_WIDTH * out_stride_x + index2D / CONV_WIDTH * out_stride_y + z * out_stride_z;
*((__global DATA_TYPE *)out_ptr) = *((__global DATA_TYPE *)src.ptr);
}
@@ -980,3 +995,335 @@
vstore4(pixels1, 0, (__global half *)(dst.ptr + 1 * dst_stride_y));
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(DEPTH_MULTIPLIER)
+
+#if defined(VEC_SIZE) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT)
+
+#define VEC_FLOAT VEC_DATA_TYPE(float, VEC_SIZE)
+
+#if defined(CONV_STRIDE_X) && defined(CONV_STRIDE_Y)
+/** This function computes the depthwise convolution for NHWC data layout when the stride along the width or height is not 1.
+ *
+ * @note The number of elements read per thread must be passed at compile time using -DVEC_SIZE (e.g. -DVEC_SIZE=2)
+ * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
+ * @note The convolution pad top must be passed at compile time using -DCONV_PAD_TOP (e.g. -DCONV_PAD_TOP=1)
+ * @note The convolution pad left must be passed at compile time using -DCONV_PAD_LEFT (e.g. -DCONV_PAD_LEFT=1)
+ * @note The convolution stride along the width must be passed at compile time using -DCONV_STRIDE_X (e.g. -DCONV_STRIDE_X=2)
+ * @note The convolution stride along the height must be passed at compile time using -DCONV_STRIDE_Y (e.g. -DCONV_STRIDE_Y=1)
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] weights_ptr Pointer to the weights tensor. Supported data types: same as src_ptr
+ * @param[in] weights_stride_x Stride of the weights tensor in X dimension (in bytes)
+ * @param[in] weights_step_x weights_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] weights_stride_y Stride of the weights tensor in Y dimension (in bytes)
+ * @param[in] weights_step_y weights_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] weights_stride_z Stride of the weights tensor in Z dimension (in bytes)
+ * @param[in] weights_step_z weights_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
+ * @param[in] biases_ptr (Optional, only with -DHAS_BIAS) Pointer to the biases vector. Supported data types: same as src_ptr
+ * @param[in] biases_stride_x (Optional) Stride of the biases vector in X dimension (in bytes)
+ * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases vector
+ * @param[in] max_offset Maximum byte offset (relative to this work-item's first input channel) that input reads are clamped to
+ */
+__kernel void depthwise_convolution_3x3_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst),
+ TENSOR3D_DECLARATION(weights),
+#if defined(HAS_BIAS)
+ VECTOR_DECLARATION(biases),
+#endif /* defined(HAS_BIAS) */
+ int max_offset)
+{
+ int x = get_global_id(0); // channels (in blocks of VEC_SIZE)
+ int y = get_global_id(1); // spatial coordinate x
+ int z = get_global_id(2); // spatial coordinate y
+
+ Vector weights = CONVERT_TO_VECTOR_STRUCT(weights);
+
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(float) * VEC_SIZE;
+
+ // y_offset: byte offsets of the input columns (W) read by the 3-wide filter; only lanes s0-s2 are used
+ int z_coord = 0;
+ int4 offset = 0;
+ int4 y_offset = ((int4)(y * CONV_STRIDE_X) + (int4)(0, 1, 2, 3) - CONV_PAD_LEFT) * (int4)src_stride_y;
+ // NOTE(review): y_offset lanes are negative when y * CONV_STRIDE_X < CONV_PAD_LEFT and only the upper bound is clamped (max_offset); presumably relies on input left padding - confirm
+ // We compute VEC_SIZEx1x1 [C,W,H] elements
+ VEC_FLOAT acc = 0;
+
+ // Load the 3x3 weights for these VEC_SIZE channels (weights Y = filter column, weights Z = filter row)
+ VEC_FLOAT w0 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 0 * weights_stride_y + 0 * weights_stride_z));
+ VEC_FLOAT w1 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 1 * weights_stride_y + 0 * weights_stride_z));
+ VEC_FLOAT w2 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 2 * weights_stride_y + 0 * weights_stride_z));
+ VEC_FLOAT w3 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 0 * weights_stride_y + 1 * weights_stride_z));
+ VEC_FLOAT w4 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 1 * weights_stride_y + 1 * weights_stride_z));
+ VEC_FLOAT w5 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 2 * weights_stride_y + 1 * weights_stride_z));
+ VEC_FLOAT w6 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 0 * weights_stride_y + 2 * weights_stride_z));
+ VEC_FLOAT w7 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 1 * weights_stride_y + 2 * weights_stride_z));
+ VEC_FLOAT w8 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 2 * weights_stride_y + 2 * weights_stride_z));
+
+ // Load input values
+ // Filter row 0 (input row z * CONV_STRIDE_Y - CONV_PAD_TOP)
+ // Clamp z_coord, which is negative when z * CONV_STRIDE_Y < CONV_PAD_TOP
+ // z_coord is cast to unsigned int in order to use just a min() operation
+ // A "-1" 32 bit signed variable converted to unsigned gives 4294967295
+ z_coord = z * CONV_STRIDE_Y - (int)CONV_PAD_TOP;
+ z_coord = min((uint)z_coord, (uint)SRC_DIM_2);
+ offset = y_offset + (int4)(z_coord * src_stride_z);
+ offset = min(offset, (int4)max_offset);
+
+ VEC_FLOAT values0 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
+ VEC_FLOAT values1 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
+ VEC_FLOAT values2 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
+
+ // Filter row 1
+ // Row 1 is neither clamped nor offset-limited: z * CONV_STRIDE_Y - CONV_PAD_TOP + 1 is assumed to be a valid input row
+ // NOTE(review): that holds for pads <= 1 with an output height consistent with the convolution; confirm for larger pads
+ z_coord = z * CONV_STRIDE_Y - (int)CONV_PAD_TOP + 1;
+ offset = y_offset + (int4)(z_coord * src_stride_z);
+ VEC_FLOAT values3 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
+ VEC_FLOAT values4 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
+ VEC_FLOAT values5 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
+
+ // Filter row 2
+ // After row 1 we can simply add src_stride_z to offset without updating z_coord
+ // However offset can be out-of-bound so we need to check if it is greater than max_offset
+ offset += (int4)src_stride_z;
+ offset = min(offset, (int4)max_offset);
+ VEC_FLOAT values6 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
+ VEC_FLOAT values7 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
+ VEC_FLOAT values8 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
+
+ acc = fma(values0, w0, acc);
+ acc = fma(values1, w1, acc);
+ acc = fma(values2, w2, acc);
+
+ acc = fma(values3, w3, acc);
+ acc = fma(values4, w4, acc);
+ acc = fma(values5, w5, acc);
+
+ acc = fma(values6, w6, acc);
+ acc = fma(values7, w7, acc);
+ acc = fma(values8, w8, acc);
+
+#if defined(HAS_BIAS)
+ Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
+ VEC_FLOAT bias_values = VLOAD(VEC_SIZE)(0, (__global float *)biases.ptr);
+ acc += bias_values;
+#endif // defined(HAS_BIAS)
+
+ Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
+ VSTORE(VEC_SIZE)
+ (acc, 0, (__global float *)(dst.ptr));
+}
+#endif // defined(CONV_STRIDE_X) && defined(CONV_STRIDE_Y)
+
+#if defined(NUM_ROWS_PROCESSED) && defined(NUM_PLANES_PROCESSED)
+/** This function computes the depthwise convolution for NHWC data layout when the stride along the width and height is 1.
+ *
+ * @note The number of elements read per thread must be passed at compile time using -DVEC_SIZE (e.g. -DVEC_SIZE=2)
+ * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM2 (e.g. -DSRC_DIM_2=112)
+ * @note The number of rows processed per thread must be passed at compile time using -DNUM_ROWS_PROCESSED (i.e. -DNUM_ROWS_PROCESSED=2)
+ * @note The number of planes processed per thread must be passed at compile time using -DNUM_PLANES_PROCESSED (i.e. -DNUM_PLANES_PROCESSED=2)
+ * @note The convolution pad top must be passed at compile time using -DCONV_PAD_TOP (e.g. -DCONV_PAD_TOP=1)
+ * @note The convolution pad top must be passed at compile time using -DCONV_PAD_LEFT (e.g. -DCONV_PAD_LEFT=1)
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: FP32
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: same as src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] weights_ptr Pointer to the weights tensor. Supported data types: same as src_ptr
+ * @param[in] weights_stride_x Stride of the weights tensor in X dimension (in bytes)
+ * @param[in] weights_step_x weights_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] weights_stride_y Stride of the weights tensor in Y dimension (in bytes)
+ * @param[in] weights_step_y weights_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] weights_stride_z Stride of the weights tensor in Z dimension (in bytes)
+ * @param[in] weights_step_z weights_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
+ * @param[in] max_offset Max offset for the input tensor
+ * @param[in] biases_ptr (Optional) Pointer to the biases vector. Supported data types: same as src_ptr
+ * @param[in] biases_stride_x (Optional) Stride of the biases vector in X dimension (in bytes)
+ * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases vector
+ */
+__kernel void depthwise_convolution_3x3_nhwc_stride1( // FP32 3x3 depthwise convolution (NHWC, stride 1): each work-item writes VEC_SIZE channels x 2 W positions x 2 H positions
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst),
+ TENSOR3D_DECLARATION(weights),
+#if defined(HAS_BIAS)
+ VECTOR_DECLARATION(biases),
+#endif /* defined(HAS_BIAS) */
+ int max_offset) // every input read offset (in bytes, relative to src_addr) is clamped to at most this value
+{
+ int x = get_global_id(0); // channels: index of the VEC_SIZE-wide channel block
+ int y = get_global_id(1); // spatial coordinate x (W), in units of NUM_ROWS_PROCESSED outputs
+ int z = get_global_id(2); // spatial coordinate y (H), in units of NUM_PLANES_PROCESSED outputs
+
+ Vector weights = CONVERT_TO_VECTOR_STRUCT(weights);
+
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(float) * VEC_SIZE; // this work-item's channel block at W = 0, H = 0
+
+ int z_coord = 0;
+ int4 offset = 0;
+ int4 y_offset = ((int4)(y * NUM_ROWS_PROCESSED) + (int4)(0, 1, 2, 3) - (int)CONV_PAD_LEFT) * (int4)src_stride_y; // byte offsets of the 4 input columns read; NOTE(review): negative at the left border, presumably covered by tensor padding - confirm
+
+ // We compute VEC_SIZE x 2 x 2 [C,W,H] elements: acc0/acc1 = output columns 0/1 at the first H position, acc2/acc3 = the same columns at the next H position
+ VEC_FLOAT acc0 = 0;
+ VEC_FLOAT acc1 = 0;
+ VEC_FLOAT acc2 = 0;
+ VEC_FLOAT acc3 = 0;
+
+ // Load the 3x3 weights: wN is the tap at (W = N % 3, H = N / 3), one VEC_SIZE channel vector each
+ VEC_FLOAT w0 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 0 * weights_stride_y + 0 * weights_stride_z));
+ VEC_FLOAT w1 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 1 * weights_stride_y + 0 * weights_stride_z));
+ VEC_FLOAT w2 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 2 * weights_stride_y + 0 * weights_stride_z));
+ VEC_FLOAT w3 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 0 * weights_stride_y + 1 * weights_stride_z));
+ VEC_FLOAT w4 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 1 * weights_stride_y + 1 * weights_stride_z));
+ VEC_FLOAT w5 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 2 * weights_stride_y + 1 * weights_stride_z));
+ VEC_FLOAT w6 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 0 * weights_stride_y + 2 * weights_stride_z));
+ VEC_FLOAT w7 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 1 * weights_stride_y + 2 * weights_stride_z));
+ VEC_FLOAT w8 = VLOAD(VEC_SIZE)(0, (__global float *)(weights.ptr + 2 * weights_stride_y + 2 * weights_stride_z));
+
+ // Load input values: valuesN is input column N % 4 of input plane N / 4 (both relative to the window origin)
+ // z == 0 (here "z" is the input plane inside the window, not get_global_id(2))
+ // Clamp z_coord as for z = 0, it can be negative
+ // z_coord is cast to unsigned int in order to use just a min() operation
+ // A "-1" 32 bit signed variable converted to unsigned gives 4294967295
+ z_coord = z * (int)NUM_PLANES_PROCESSED - (int)CONV_PAD_TOP;
+ z_coord = min((uint)z_coord, (uint)SRC_DIM_2);
+ offset = y_offset + (int4)(z_coord * src_stride_z);
+ offset = min(offset, (int4)max_offset);
+
+ VEC_FLOAT values0 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
+ VEC_FLOAT values1 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
+ VEC_FLOAT values2 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
+ VEC_FLOAT values3 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s3));
+
+ // z == 1
+ // z_coord can be only negative for z = 0 so we do not need to clamp it
+ // Moreover z_coord cannot be out-of-bound for z = 1 so we do not need to clamp the offset
+ z_coord = z * (int)NUM_PLANES_PROCESSED - (int)CONV_PAD_TOP + 1;
+ offset = y_offset + (int4)(z_coord * src_stride_z);
+ VEC_FLOAT values4 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
+ VEC_FLOAT values5 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
+ VEC_FLOAT values6 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
+ VEC_FLOAT values7 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s3));
+
+ // z == 2
+ // After z = 1 we can simply add src_stride_z to offset without updating z_coord
+ // However offset can be out-of-bound so we need to check if it is greater than max_offset
+ offset += (int4)src_stride_z;
+ offset = min(offset, (int4)max_offset);
+ VEC_FLOAT values8 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
+ VEC_FLOAT values9 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
+ VEC_FLOAT values10 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
+ VEC_FLOAT values11 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s3));
+
+ // z == 3
+ // As for z = 2, we simply add src_stride_z to offset without updating z_coord
+ // However offset can be out-of-bound so we need to check if it is greater than max_offset
+ offset += (int4)src_stride_z;
+ offset = min(offset, (int4)max_offset);
+ VEC_FLOAT values12 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s0));
+ VEC_FLOAT values13 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s1));
+ VEC_FLOAT values14 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s2));
+ VEC_FLOAT values15 = VLOAD(VEC_SIZE)(0, (__global float *)(src_addr + offset.s3));
+ // Accumulate the 3x3 window of each output. Keep this fma order: reordering it changes the float rounding of the results
+ acc0 = fma(values0, w0, acc0);
+ acc0 = fma(values1, w1, acc0);
+ acc0 = fma(values2, w2, acc0);
+ acc1 = fma(values1, w0, acc1);
+ acc1 = fma(values2, w1, acc1);
+ acc1 = fma(values3, w2, acc1);
+
+ acc0 = fma(values4, w3, acc0);
+ acc0 = fma(values5, w4, acc0);
+ acc0 = fma(values6, w5, acc0);
+ acc1 = fma(values5, w3, acc1);
+ acc1 = fma(values6, w4, acc1);
+ acc1 = fma(values7, w5, acc1);
+
+ acc0 = fma(values8, w6, acc0);
+ acc0 = fma(values9, w7, acc0);
+ acc0 = fma(values10, w8, acc0);
+ acc1 = fma(values9, w6, acc1);
+ acc1 = fma(values10, w7, acc1);
+ acc1 = fma(values11, w8, acc1);
+
+ acc2 = fma(values4, w0, acc2);
+ acc2 = fma(values5, w1, acc2);
+ acc2 = fma(values6, w2, acc2);
+ acc3 = fma(values5, w0, acc3);
+ acc3 = fma(values6, w1, acc3);
+ acc3 = fma(values7, w2, acc3);
+
+ acc2 = fma(values8, w3, acc2);
+ acc2 = fma(values9, w4, acc2);
+ acc2 = fma(values10, w5, acc2);
+ acc3 = fma(values9, w3, acc3);
+ acc3 = fma(values10, w4, acc3);
+ acc3 = fma(values11, w5, acc3);
+
+ acc2 = fma(values12, w6, acc2);
+ acc2 = fma(values13, w7, acc2);
+ acc2 = fma(values14, w8, acc2);
+ acc3 = fma(values13, w6, acc3);
+ acc3 = fma(values14, w7, acc3);
+ acc3 = fma(values15, w8, acc3);
+
+#if defined(HAS_BIAS)
+ Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
+
+ VEC_FLOAT bias_values = VLOAD(VEC_SIZE)(0, (__global float *)biases.ptr);
+
+ acc0 += bias_values;
+ acc1 += bias_values;
+ acc2 += bias_values;
+ acc3 += bias_values;
+#endif // defined(HAS_BIAS)
+
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * dst_step_x + y * dst_step_y + (z * NUM_PLANES_PROCESSED) * dst_step_z;
+
+ VSTORE(VEC_SIZE)
+ (acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+ VSTORE(VEC_SIZE)
+ (acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+ // The second H position (acc2/acc3) is skipped when it falls past DST_DIM_2; NOTE(review): the second W position (acc1/acc3) is stored unchecked, presumably relying on output padding along W
+#if((DST_DIM_2 % NUM_PLANES_PROCESSED) != 0)
+ if((z * NUM_PLANES_PROCESSED + 1) < DST_DIM_2)
+#endif // ((DST_DIM_2 % NUM_PLANES_PROCESSED) != 0)
+ {
+ VSTORE(VEC_SIZE)
+ (acc2, 0, (__global float *)(dst_addr + 0 * dst_stride_y + 1 * dst_stride_z));
+ VSTORE(VEC_SIZE)
+ (acc3, 0, (__global float *)(dst_addr + 1 * dst_stride_y + 1 * dst_stride_z));
+ }
+}
+
+#endif // defined(NUM_ROWS_PROCESSED) && defined(NUM_PLANES_PROCESSED)
+#endif // defined(VEC_SIZE) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT)
\ No newline at end of file
diff --git a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
index ccb3a1f..fe902ed 100644
--- a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
+++ b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
@@ -37,12 +37,22 @@
#define ACTIVATION_FUNC(x) (x)
#endif /* defined(FUSED_ACTIVATION) */
-#if defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X)
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) // ARM_DOT(x0..x3, y0..y3, val): val += dot((uchar4)(x0..x3), (uchar4)(y0..y3)); the expansion already ends with ';'
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) // single fused dot-and-accumulate call
+#define ARM_DOT(x0, x1, x2, x3, y0, y1, y2, y3, val) val = arm_dot_acc((uchar4)(x0, x1, x2, x3), (uchar4)(y0, y1, y2, y3), val);
+#else // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED): plain dot product followed by a separate add
+#define ARM_DOT(x0, x1, x2, x3, y0, y1, y2, y3, val) val += arm_dot((uchar4)(x0, x1, x2, x3), (uchar4)(y0, y1, y2, y3));
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED)
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+
+#if defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER)
#if CONV_STRIDE_X > 3
#error "Stride X not supported"
#endif /* CONV_STRIDE_X > 3 */
+#if !defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+
#if CONV_STRIDE_X == 1
#define GET_VALUES(first_value, left, middle, right) \
({ \
@@ -250,34 +260,40 @@
#endif /* CONV_STRIDE_Y == 1 */
}
-#endif /* defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) */
+#else // !defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
-#if defined(VEC_SIZE) && defined(SRC_DEPTH) && defined(CONV_PAD_TOP) && defined(ROWS_READ)
-
-#define asymm_mult_by_quant_multiplier_less_than_one(x, y, z) ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(x, y, z, VEC_SIZE)
-
-#define VEC_INT VEC_DATA_TYPE(int, VEC_SIZE)
-#define VEC_UCHAR VEC_DATA_TYPE(uchar, VEC_SIZE)
-
-#define BIFROST_MAD_4(acc, x, y) \
- ({ \
- acc.s0 += (ushort)x.s0 * (ushort)y.s0; \
- acc.s1 += (ushort)x.s1 * (ushort)y.s1; \
- acc.s2 += (ushort)x.s2 * (ushort)y.s2; \
- acc.s3 += (ushort)x.s3 * (ushort)y.s3; \
+#if CONV_STRIDE_X == 1 // GET_VALUES(p, left, middle, right): for outputs i = 0..7, left/middle/right[i] = p[CONV_STRIDE_X * i + 0/1/2]; stride 1 reads 10 bytes
+#define GET_VALUES(first_value, left, middle, right) \
+ ({ \
+ uchar8 temp0 = vload8(0, first_value); \
+ uchar2 temp1 = vload2(0, (first_value + 8 * sizeof(uchar))); \
+ \
+ left = temp0.s01234567; \
+ middle = (uchar8)(temp0.s1234, temp0.s567, temp1.s0); \
+ right = (uchar8)(temp0.s2345, temp0.s67, temp1.s01); \
})
-
-#if WEIGHTS_OFFSET != 0
-#define BIFROST_MAD_ACC_4(acc, sum, x, y) \
- ({ \
- sum += CONVERT(x, VEC_INT); \
- BIFROST_MAD_4(acc, x, y); \
+#elif CONV_STRIDE_X == 2 // stride 2 reads 17 bytes (taps 0..16)
+#define GET_VALUES(first_value, left, middle, right) \
+ ({ \
+ uchar16 temp0 = vload16(0, first_value); \
+ uchar temp1 = *(first_value + 16 * sizeof(uchar)); \
+ \
+ left = temp0.s02468ace; \
+ middle = temp0.s13579bdf; \
+ right = (uchar8)(temp0.s2468, temp0.sace, temp1); \
})
-#else /* WEIGHTS_OFFSET != 0 */
-#define BIFROST_MAD_ACC_4(acc, sum, x, y) BIFROST_MAD_4(acc, x, y)
-#endif /* WEIGHTS_OFFSET != 0 */
-
-/** This function computes the depthwise convolution quantized.
+#else /* CONV_STRIDE_X */ // CONV_STRIDE_X == 3 (larger strides hit the #error): reads 24 bytes (taps 0..23)
+#define GET_VALUES(first_value, left, middle, right) \
+ ({ \
+ uchar16 temp0 = vload16(0, first_value); \
+ uchar8 temp1 = vload8(0, (first_value + 16 * sizeof(uchar))); \
+ \
+ left = (uchar8)(temp0.s0369, temp0.scf, temp1.s25); \
+ middle = (uchar8)(temp0.s147a, temp0.sd, temp1.s036); \
+ right = (uchar8)(temp0.s258b, temp0.se, temp1.s147); \
+ })
+#endif /* CONV_STRIDE_X */
+/** This function computes the depthwise convolution quantized using dot product when the data layout is NCHW.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: QASYMM8
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
@@ -309,162 +325,568 @@
* @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases vector
*/
+__kernel void depthwise_convolution_3x3_quantized_dot8_nchw( // QASYMM8 3x3 depthwise conv (NCHW) via ARM_DOT: 8 outputs along X per work-item, plus a second output row when CONV_STRIDE_Y == 1
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst),
+ TENSOR3D_DECLARATION(weights)
+#if defined(HAS_BIAS)
+ ,
+ VECTOR_DECLARATION(biases)
+#endif //defined(HAS_BIAS)
+)
+{
+ Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
+ Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+#if defined(HAS_BIAS)
+ Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
+
+ const int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2)))); // one int32 bias per output channel (global id 2)
+#endif //defined(HAS_BIAS)
+ // NOTE(review): presumably src.ptr starts at plane get_global_id(2); this rewinds it to input plane get_global_id(2) / DEPTH_MULTIPLIER
+ src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+ // Load the three 3-tap weight rows (w0 = top row, w2 = bottom row)
+ uchar3 w0 = vload3(0, weights.ptr + 0 * weights_stride_y);
+ uchar3 w1 = vload3(0, weights.ptr + 1 * weights_stride_y);
+ uchar3 w2 = vload3(0, weights.ptr + 2 * weights_stride_y);
+ // leftN/middleN/rightN: taps 0/1/2 of input row N for the 8 adjacent outputs (see GET_VALUES)
+ uchar8 left0, middle0, right0;
+ uchar8 left1, middle1, right1;
+ uchar8 left2, middle2, right2;
+
+ int8 values0 = 0; // int32 accumulators for output row 0
+ int8 sum0 = 0; // per-output sum of the 9 window inputs (WEIGHTS_OFFSET term)
+
+ GET_VALUES(src.ptr + 0 * src_stride_y, left0, middle0, right0);
+ GET_VALUES(src.ptr + 1 * src_stride_y, left1, middle1, right1);
+ GET_VALUES(src.ptr + 2 * src_stride_y, left2, middle2, right2);
+ // With a non-zero WEIGHTS_OFFSET, each output also needs the sum of its window inputs
+#if WEIGHTS_OFFSET != 0
+ sum0 += convert_int8(left0) + convert_int8(middle0) + convert_int8(right0);
+ sum0 += convert_int8(left1) + convert_int8(middle1) + convert_int8(right1);
+ sum0 += convert_int8(left2) + convert_int8(middle2) + convert_int8(right2);
+#endif /* WEIGHTS_OFFSET != 0 */
+
+#if CONV_STRIDE_Y == 1
+ // If CONV_STRIDE_Y is equal to 1, we compute two output rows (the second reuses input rows 1 and 2)
+
+ uchar8 left3, middle3, right3;
+ int8 values1 = 0;
+ int8 sum1 = 0;
+
+ GET_VALUES(src.ptr + 3 * src_stride_y, left3, middle3, right3);
+
+#if WEIGHTS_OFFSET != 0
+ sum1 += convert_int8(left1) + convert_int8(middle1) + convert_int8(right1);
+ sum1 += convert_int8(left2) + convert_int8(middle2) + convert_int8(right2);
+ sum1 += convert_int8(left3) + convert_int8(middle3) + convert_int8(right3);
+#endif /* WEIGHTS_OFFSET != 0 */
+#endif // CONV_STRIDE_Y == 1
+ // Per output lane: two 4-way ARM_DOT products cover 8 taps; the 9th tap (right2 * w2.s2) is added separately
+ ARM_DOT(left0.s0, middle0.s0, right0.s0, left1.s0, w0.s0, w0.s1, w0.s2, w1.s0, values0.s0);
+ ARM_DOT(middle1.s0, right1.s0, left2.s0, middle2.s0, w1.s1, w1.s2, w2.s0, w2.s1, values0.s0);
+ values0.s0 += right2.s0 * w2.s2;
+
+ ARM_DOT(left0.s1, middle0.s1, right0.s1, left1.s1, w0.s0, w0.s1, w0.s2, w1.s0, values0.s1);
+ ARM_DOT(middle1.s1, right1.s1, left2.s1, middle2.s1, w1.s1, w1.s2, w2.s0, w2.s1, values0.s1);
+ values0.s1 += right2.s1 * w2.s2;
+
+ ARM_DOT(left0.s2, middle0.s2, right0.s2, left1.s2, w0.s0, w0.s1, w0.s2, w1.s0, values0.s2);
+ ARM_DOT(middle1.s2, right1.s2, left2.s2, middle2.s2, w1.s1, w1.s2, w2.s0, w2.s1, values0.s2);
+ values0.s2 += right2.s2 * w2.s2;
+
+ ARM_DOT(left0.s3, middle0.s3, right0.s3, left1.s3, w0.s0, w0.s1, w0.s2, w1.s0, values0.s3);
+ ARM_DOT(middle1.s3, right1.s3, left2.s3, middle2.s3, w1.s1, w1.s2, w2.s0, w2.s1, values0.s3);
+ values0.s3 += right2.s3 * w2.s2;
+
+ ARM_DOT(left0.s4, middle0.s4, right0.s4, left1.s4, w0.s0, w0.s1, w0.s2, w1.s0, values0.s4);
+ ARM_DOT(middle1.s4, right1.s4, left2.s4, middle2.s4, w1.s1, w1.s2, w2.s0, w2.s1, values0.s4);
+ values0.s4 += right2.s4 * w2.s2;
+
+ ARM_DOT(left0.s5, middle0.s5, right0.s5, left1.s5, w0.s0, w0.s1, w0.s2, w1.s0, values0.s5);
+ ARM_DOT(middle1.s5, right1.s5, left2.s5, middle2.s5, w1.s1, w1.s2, w2.s0, w2.s1, values0.s5);
+ values0.s5 += right2.s5 * w2.s2;
+
+ ARM_DOT(left0.s6, middle0.s6, right0.s6, left1.s6, w0.s0, w0.s1, w0.s2, w1.s0, values0.s6);
+ ARM_DOT(middle1.s6, right1.s6, left2.s6, middle2.s6, w1.s1, w1.s2, w2.s0, w2.s1, values0.s6);
+ values0.s6 += right2.s6 * w2.s2;
+
+ ARM_DOT(left0.s7, middle0.s7, right0.s7, left1.s7, w0.s0, w0.s1, w0.s2, w1.s0, values0.s7);
+ ARM_DOT(middle1.s7, right1.s7, left2.s7, middle2.s7, w1.s1, w1.s2, w2.s0, w2.s1, values0.s7);
+ values0.s7 += right2.s7 * w2.s2;
+ // When CONV_STRIDE_Y == 1, output row 1 uses input rows 1..3 with the same weights
+#if CONV_STRIDE_Y == 1
+ ARM_DOT(left1.s0, middle1.s0, right1.s0, left2.s0, w0.s0, w0.s1, w0.s2, w1.s0, values1.s0);
+ ARM_DOT(middle2.s0, right2.s0, left3.s0, middle3.s0, w1.s1, w1.s2, w2.s0, w2.s1, values1.s0);
+ values1.s0 += right3.s0 * w2.s2;
+
+ ARM_DOT(left1.s1, middle1.s1, right1.s1, left2.s1, w0.s0, w0.s1, w0.s2, w1.s0, values1.s1);
+ ARM_DOT(middle2.s1, right2.s1, left3.s1, middle3.s1, w1.s1, w1.s2, w2.s0, w2.s1, values1.s1);
+ values1.s1 += right3.s1 * w2.s2;
+
+ ARM_DOT(left1.s2, middle1.s2, right1.s2, left2.s2, w0.s0, w0.s1, w0.s2, w1.s0, values1.s2);
+ ARM_DOT(middle2.s2, right2.s2, left3.s2, middle3.s2, w1.s1, w1.s2, w2.s0, w2.s1, values1.s2);
+ values1.s2 += right3.s2 * w2.s2;
+
+ ARM_DOT(left1.s3, middle1.s3, right1.s3, left2.s3, w0.s0, w0.s1, w0.s2, w1.s0, values1.s3);
+ ARM_DOT(middle2.s3, right2.s3, left3.s3, middle3.s3, w1.s1, w1.s2, w2.s0, w2.s1, values1.s3);
+ values1.s3 += right3.s3 * w2.s2;
+
+ ARM_DOT(left1.s4, middle1.s4, right1.s4, left2.s4, w0.s0, w0.s1, w0.s2, w1.s0, values1.s4);
+ ARM_DOT(middle2.s4, right2.s4, left3.s4, middle3.s4, w1.s1, w1.s2, w2.s0, w2.s1, values1.s4);
+ values1.s4 += right3.s4 * w2.s2;
+
+ ARM_DOT(left1.s5, middle1.s5, right1.s5, left2.s5, w0.s0, w0.s1, w0.s2, w1.s0, values1.s5);
+ ARM_DOT(middle2.s5, right2.s5, left3.s5, middle3.s5, w1.s1, w1.s2, w2.s0, w2.s1, values1.s5);
+ values1.s5 += right3.s5 * w2.s2;
+
+ ARM_DOT(left1.s6, middle1.s6, right1.s6, left2.s6, w0.s0, w0.s1, w0.s2, w1.s0, values1.s6);
+ ARM_DOT(middle2.s6, right2.s6, left3.s6, middle3.s6, w1.s1, w1.s2, w2.s0, w2.s1, values1.s6);
+ values1.s6 += right3.s6 * w2.s2;
+
+ ARM_DOT(left1.s7, middle1.s7, right1.s7, left2.s7, w0.s0, w0.s1, w0.s2, w1.s0, values1.s7);
+ ARM_DOT(middle2.s7, right2.s7, left3.s7, middle3.s7, w1.s1, w1.s2, w2.s0, w2.s1, values1.s7);
+ values1.s7 += right3.s7 * w2.s2;
+#endif // CONV_STRIDE_Y == 1
+ // Add bias and the offset correction terms (WEIGHTS_OFFSET, INPUT_OFFSET, K_OFFSET)
+#if defined(HAS_BIAS)
+ values0 += (int8)(bias_value);
+#if CONV_STRIDE_Y == 1
+ values1 += (int8)(bias_value);
+#endif /* CONV_STRIDE_Y == 1 */
+#endif //defined(HAS_BIAS)
+
+#if WEIGHTS_OFFSET != 0
+ values0 += sum0 * (int8)(WEIGHTS_OFFSET);
+#if CONV_STRIDE_Y == 1
+ values1 += sum1 * (int8)(WEIGHTS_OFFSET);
+#endif /* CONV_STRIDE_Y == 1 */
+#endif /* WEIGHTS_OFFSET != 0 */
+
+#if INPUT_OFFSET != 0
+ ushort sum_weights = 0; // sum of the 9 weights (at most 9 * 255, fits in ushort)
+ ushort3 tmp_we = convert_ushort3(w0) + convert_ushort3(w1) + convert_ushort3(w2);
+ sum_weights += tmp_we.s0 + tmp_we.s1 + tmp_we.s2;
+ values0 += sum_weights * (int8)(INPUT_OFFSET);
+#if CONV_STRIDE_Y == 1
+ values1 += sum_weights * (int8)(INPUT_OFFSET);
+#endif /* CONV_STRIDE_Y == 1 */
+#endif /* INPUT_OFFSET != 0 */
+
+#if K_OFFSET != 0
+ values0 += (int8)(K_OFFSET);
+#if CONV_STRIDE_Y == 1
+ values1 += (int8)(K_OFFSET);
+#endif /* CONV_STRIDE_Y == 1 */
+#endif /* K_OFFSET != 0 */
+ // Requantize, add OUTPUT_OFFSET, saturate to uchar and apply the fused activation
+ values0 = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(values0, OUTPUT_MULTIPLIER, OUTPUT_SHIFT, 8);
+ values0 += (int8)OUTPUT_OFFSET;
+ uchar8 res0 = convert_uchar8_sat(values0);
+ res0 = max(res0, (uchar8)0); // no-op: convert_uchar8_sat already saturates to [0, 255]
+ res0 = min(res0, (uchar8)255);
+
+ vstore8(ACTIVATION_FUNC(res0), 0, dst.ptr);
+#if CONV_STRIDE_Y == 1
+ // Same for output row 1, stored one dst row below
+ values1 = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(values1, OUTPUT_MULTIPLIER, OUTPUT_SHIFT, 8);
+ values1 += (int8)OUTPUT_OFFSET;
+ uchar8 res1 = convert_uchar8_sat(values1);
+ res1 = max(res1, (uchar8)0);
+ res1 = min(res1, (uchar8)255);
+
+ vstore8(ACTIVATION_FUNC(res1), 0, dst.ptr + dst_stride_y);
+#endif /* CONV_STRIDE_Y == 1 */
+}
+
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
+
+#endif /* defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) */
+
+#if defined(VEC_SIZE) && defined(SRC_DIM_1) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT)
+
+#define asymm_mult_by_quant_multiplier_less_than_one(x, y, z) ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(x, y, z, VEC_SIZE) // VEC_SIZE-wide shorthand
+
+#define VEC_INT VEC_DATA_TYPE(int, VEC_SIZE)
+#define VEC_UCHAR VEC_DATA_TYPE(uchar, VEC_SIZE)
+#define VEC_USHORT VEC_DATA_TYPE(ushort, VEC_SIZE)
+
+#define MULTIPLY_ADD(x, y, acc) acc += CONVERT(CONVERT(x, VEC_USHORT) * CONVERT(y, VEC_USHORT), VEC_INT) // acc += x * y: uchar lanes multiplied as ushort (255 * 255 fits), then widened to int
+
+#if WEIGHTS_OFFSET != 0 // MULTIPLY_ADD_ACCUMULATE also adds x to sum, for the WEIGHTS_OFFSET * sum(inputs) term
+#define MULTIPLY_ADD_ACCUMULATE(x, y, acc, sum) \
+ ({ \
+ sum += CONVERT(x, VEC_INT); \
+ MULTIPLY_ADD(x, y, acc); \
+ })
+#else /* WEIGHTS_OFFSET != 0 */
+#define MULTIPLY_ADD_ACCUMULATE(x, y, acc, sum) MULTIPLY_ADD(x, y, acc)
+#endif /* WEIGHTS_OFFSET != 0 */
+
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) // DOT_PRODUCT: per lane, acc += sum of val_k * w_k over the 9 taps; only lanes s0..s3 are handled, so it presumably requires VEC_SIZE == 4 (confirm)
+#define DOT_PRODUCT(acc, val0, val1, val2, val3, val4, val5, val6, val7, val8, w0, w1, w2, w3, w4, w5, w6, w7, w8) \
+ ({ \
+ ARM_DOT(val0.s0, val1.s0, val2.s0, val3.s0, w0.s0, w1.s0, w2.s0, w3.s0, acc.s0); \
+ ARM_DOT(val4.s0, val5.s0, val6.s0, val7.s0, w4.s0, w5.s0, w6.s0, w7.s0, acc.s0); \
+ acc.s0 += val8.s0 * w8.s0; \
+ \
+ ARM_DOT(val0.s1, val1.s1, val2.s1, val3.s1, w0.s1, w1.s1, w2.s1, w3.s1, acc.s1); \
+ ARM_DOT(val4.s1, val5.s1, val6.s1, val7.s1, w4.s1, w5.s1, w6.s1, w7.s1, acc.s1); \
+ acc.s1 += val8.s1 * w8.s1; \
+ \
+ ARM_DOT(val0.s2, val1.s2, val2.s2, val3.s2, w0.s2, w1.s2, w2.s2, w3.s2, acc.s2); \
+ ARM_DOT(val4.s2, val5.s2, val6.s2, val7.s2, w4.s2, w5.s2, w6.s2, w7.s2, acc.s2); \
+ acc.s2 += val8.s2 * w8.s2; \
+ \
+ ARM_DOT(val0.s3, val1.s3, val2.s3, val3.s3, w0.s3, w1.s3, w2.s3, w3.s3, acc.s3); \
+ ARM_DOT(val4.s3, val5.s3, val6.s3, val7.s3, w4.s3, w5.s3, w6.s3, w7.s3, acc.s3); \
+ acc.s3 += val8.s3 * w8.s3; \
+ })
+
+#if WEIGHTS_OFFSET != 0 // DOT_PRODUCT_ACCUMULATE also adds the 9 window inputs to sum, for the WEIGHTS_OFFSET term
+#define DOT_PRODUCT_ACCUMULATE(acc, sum, val0, val1, val2, val3, val4, val5, val6, val7, val8, w0, w1, w2, w3, w4, w5, w6, w7, w8) \
+ ({ \
+ sum += CONVERT(val0, VEC_INT) + CONVERT(val1, VEC_INT) + CONVERT(val2, VEC_INT) + CONVERT(val3, VEC_INT) + CONVERT(val4, VEC_INT) + CONVERT(val5, VEC_INT) + CONVERT(val6, VEC_INT) + CONVERT(val7, VEC_INT) + CONVERT(val8, VEC_INT); \
+ DOT_PRODUCT(acc, val0, val1, val2, val3, val4, val5, val6, val7, val8, w0, w1, w2, w3, w4, w5, w6, w7, w8); \
+ })
+#else /* WEIGHTS_OFFSET != 0 */
+#define DOT_PRODUCT_ACCUMULATE(acc, sum, val0, val1, val2, val3, val4, val5, val6, val7, val8, w0, w1, w2, w3, w4, w5, w6, w7, w8) DOT_PRODUCT(acc, val0, val1, val2, val3, val4, val5, val6, val7, val8, w0, w1, w2, w3, w4, w5, w6, w7, w8)
+#endif /* WEIGHTS_OFFSET != 0 */
+
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+
+#if defined(CONV_STRIDE_X) && defined(CONV_STRIDE_Y)
+/** This function computes the depthwise convolution quantized for NHWC data layout when the stride along the width or height is not 1.
+ *
+ * @note The number of elements read per thread must be passed at compile time using -DVEC_SIZE (e.g. -DVEC_SIZE=2)
+ * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
+ * @note The convolution pad top must be passed at compile time using -DCONV_PAD_TOP (e.g. -DCONV_PAD_TOP=1)
+ * @note The convolution pad left must be passed at compile time using -DCONV_PAD_LEFT (e.g. -DCONV_PAD_LEFT=1)
+ * @note The convolution stride along the width must be passed at compile time using -DCONV_STRIDE_X (e.g. -DCONV_STRIDE_X=1)
+ * @note The convolution stride along the height must be passed at compile time using -DCONV_STRIDE_Y (e.g. -DCONV_STRIDE_Y=1)
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: QASYMM8
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] weights_ptr Pointer to the weights tensor. Supported data types: same as @p src_ptr
+ * @param[in] weights_stride_x Stride of the weights tensor in X dimension (in bytes)
+ * @param[in] weights_step_x weights_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] weights_stride_y Stride of the weights tensor in Y dimension (in bytes)
+ * @param[in] weights_step_y weights_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] weights_stride_z Stride of the weights tensor in Z dimension (in bytes)
+ * @param[in] weights_step_z weights_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
+ * @param[in] biases_ptr (Optional) Pointer to the biases vector. Supported data types: S32
+ * @param[in] biases_stride_x (Optional) Stride of the biases vector in X dimension (in bytes)
+ * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases vector
+ * @param[in] max_offset Max offset for the input tensor
+ */
+__kernel void depthwise_convolution_3x3_quantized_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst),
+ TENSOR3D_DECLARATION(weights),
+#if defined(HAS_BIAS)
+ VECTOR_DECLARATION(biases),
+#endif /* defined(HAS_BIAS) */
+ int max_offset)
+{
+ const int x = get_global_id(0); // channels: index of the VEC_SIZE-wide channel block
+ const int y = get_global_id(1); // output spatial coordinate x (W)
+ const int z = get_global_id(2); // output spatial coordinate y (H)
+
+ Vector weights = CONVERT_TO_VECTOR_STRUCT(weights);
+
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * VEC_SIZE; // this work-item's channel block at W = 0, H = 0
+
+ int z_coord = 0;
+ int4 offset = 0;
+ const int4 y_offset = ((int4)(y * CONV_STRIDE_X) + (int4)(0, 1, 2, 3) - (int)CONV_PAD_LEFT) * (int4)src_stride_y; // byte offsets of the window's input columns (only s0..s2 are used)
+
+ // We compute VEC_SIZE x 1 x 1 [C,W,H] elements
+ VEC_INT acc = 0, sum = 0; // sum: window inputs, only used for the WEIGHTS_OFFSET term
+
+ // Load the 3x3 weights: wN is the tap at (W = N % 3, H = N / 3)
+ VEC_UCHAR w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y + 0 * weights_stride_z);
+ VEC_UCHAR w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y + 0 * weights_stride_z);
+ VEC_UCHAR w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y + 0 * weights_stride_z);
+ VEC_UCHAR w3 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y + 1 * weights_stride_z);
+ VEC_UCHAR w4 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y + 1 * weights_stride_z);
+ VEC_UCHAR w5 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y + 1 * weights_stride_z);
+ VEC_UCHAR w6 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y + 2 * weights_stride_z);
+ VEC_UCHAR w7 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y + 2 * weights_stride_z);
+ VEC_UCHAR w8 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y + 2 * weights_stride_z);
+
+#if INPUT_OFFSET != 0
+ VEC_INT sum_we = CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT) // sum of the 9 weights, for the INPUT_OFFSET term
+ + CONVERT(w3, VEC_INT) + CONVERT(w4, VEC_INT) + CONVERT(w5, VEC_INT)
+ + CONVERT(w6, VEC_INT) + CONVERT(w7, VEC_INT) + CONVERT(w8, VEC_INT);
+#endif /* INPUT_OFFSET != 0 */
+
+ // Load input values: valuesN is input column N % 3 of window row N / 3
+ // z == 0 (here "z" is the row inside the window, not get_global_id(2))
+ // Clamp z_coord as for z = 0, it can be negative
+ // z_coord is cast to unsigned int in order to use just a min() operation
+ // A "-1" 32 bit signed variable converted to unsigned gives 4294967295
+ z_coord = z * (int)CONV_STRIDE_Y - (int)CONV_PAD_TOP;
+ z_coord = min((uint)z_coord, (uint)SRC_DIM_2);
+ offset = y_offset + (int4)(z_coord * src_stride_z);
+ offset = min(offset, (int4)max_offset);
+
+ VEC_UCHAR values0 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0);
+ VEC_UCHAR values1 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1);
+ VEC_UCHAR values2 = VLOAD(VEC_SIZE)(0, src_addr + offset.s2);
+
+ // z == 1
+ // z_coord can be only negative for z = 0 so we do not need to clamp it
+ // Moreover z_coord cannot be out-of-bound for z = 1 so we do not need to clamp the offset
+ z_coord = z * (int)CONV_STRIDE_Y - (int)CONV_PAD_TOP + 1;
+ offset = y_offset + (int4)(z_coord * src_stride_z);
+ VEC_UCHAR values3 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0);
+ VEC_UCHAR values4 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1);
+ VEC_UCHAR values5 = VLOAD(VEC_SIZE)(0, src_addr + offset.s2);
+
+ // z == 2
+ // After z = 1 we can simply add src_stride_z to offset without updating z_coord
+ // However offset can be out-of-bound so we need to check if it is greater than max_offset
+ offset += (int4)src_stride_z;
+ offset = min(offset, (int4)max_offset);
+ VEC_UCHAR values6 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0);
+ VEC_UCHAR values7 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1);
+ VEC_UCHAR values8 = VLOAD(VEC_SIZE)(0, src_addr + offset.s2);
+ // Accumulate the 9 products (and, with a non-zero WEIGHTS_OFFSET, the input sum)
+ MULTIPLY_ADD_ACCUMULATE(values0, w0, acc, sum);
+ MULTIPLY_ADD_ACCUMULATE(values1, w1, acc, sum);
+ MULTIPLY_ADD_ACCUMULATE(values2, w2, acc, sum);
+
+ MULTIPLY_ADD_ACCUMULATE(values3, w3, acc, sum);
+ MULTIPLY_ADD_ACCUMULATE(values4, w4, acc, sum);
+ MULTIPLY_ADD_ACCUMULATE(values5, w5, acc, sum);
+
+ MULTIPLY_ADD_ACCUMULATE(values6, w6, acc, sum);
+ MULTIPLY_ADD_ACCUMULATE(values7, w7, acc, sum);
+ MULTIPLY_ADD_ACCUMULATE(values8, w8, acc, sum);
+ // Add bias and the offset correction terms (WEIGHTS_OFFSET, INPUT_OFFSET, K_OFFSET)
+#if defined(HAS_BIAS)
+ Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
+ VEC_INT bias_values = VLOAD(VEC_SIZE)(0, (__global int *)biases.ptr);
+ acc += bias_values;
+#endif // defined(HAS_BIAS)
+
+#if WEIGHTS_OFFSET != 0
+ acc += WEIGHTS_OFFSET * sum;
+#endif /* WEIGHTS_OFFSET != 0 */
+
+#if INPUT_OFFSET != 0
+ acc += INPUT_OFFSET * sum_we;
+#endif /* INPUT_OFFSET != 0 */
+
+#if K_OFFSET != 0
+ acc += (VEC_INT)K_OFFSET;
+#endif /* K_OFFSET != 0 */
+ // Requantize, add OUTPUT_OFFSET and saturate to uchar
+ acc = asymm_mult_by_quant_multiplier_less_than_one(acc, OUTPUT_MULTIPLIER, OUTPUT_SHIFT);
+ acc += (VEC_INT)OUTPUT_OFFSET;
+
+ VEC_UCHAR res = CONVERT_SAT(acc, VEC_UCHAR);
+ res = CLAMP(res, (VEC_UCHAR)0, (VEC_UCHAR)255); // presumably redundant after CONVERT_SAT
+ // Apply the fused activation like the NCHW kernels do (ACTIVATION_FUNC is the identity when FUSED_ACTIVATION is not defined)
+ Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
+ VSTORE(VEC_SIZE)
+ (ACTIVATION_FUNC(res), 0, dst.ptr);
+}
+#endif // defined(CONV_STRIDE_X) && defined(CONV_STRIDE_Y)
+
+#if defined(NUM_ROWS_PROCESSED) && defined(NUM_PLANES_PROCESSED)
+/** This function computes the depthwise convolution quantized for NHWC data layout when the stride along the width and height is 1
+ *
+ * @note The number of elements read per thread must be passed at compile time using -DVEC_SIZE (e.g. -DVEC_SIZE=2)
+ * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
+ * @note The number of rows processed per thread must be passed at compile time using -DNUM_ROWS_PROCESSED (e.g. -DNUM_ROWS_PROCESSED=2)
+ * @note The number of planes processed per thread must be passed at compile time using -DNUM_PLANES_PROCESSED (e.g. -DNUM_PLANES_PROCESSED=2)
+ * @note The convolution pad top must be passed at compile time using -DCONV_PAD_TOP (e.g. -DCONV_PAD_TOP=1)
+ * @note The convolution pad left must be passed at compile time using -DCONV_PAD_LEFT (e.g. -DCONV_PAD_LEFT=1).
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: QASYMM8
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] weights_ptr Pointer to the weights tensor. Supported data types: same as @p src_ptr
+ * @param[in] weights_stride_x Stride of the weights tensor in X dimension (in bytes)
+ * @param[in] weights_step_x weights_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] weights_stride_y Stride of the weights tensor in Y dimension (in bytes)
+ * @param[in] weights_step_y weights_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] weights_stride_z Stride of the weights tensor in Z dimension (in bytes)
+ * @param[in] weights_step_z weights_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
+ * @param[in] biases_ptr (Optional) Pointer to the biases vector. Supported data types: same as @p src_ptr
+ * @param[in] biases_stride_x (Optional) Stride of the biases vector in X dimension (in bytes)
+ * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases vector
+ * @param[in] max_offset Max offset for the input tensor
+ */
+
__kernel void depthwise_convolution_3x3_quantized_nhwc_stride1(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
TENSOR3D_DECLARATION(weights),
#if defined(HAS_BIAS)
- VECTOR_DECLARATION(biases)
+ VECTOR_DECLARATION(biases),
#endif /* defined(HAS_BIAS) */
-)
+ int max_offset)
{
- Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
+ int x = get_global_id(0);
+ int y = get_global_id(1);
+ int z = get_global_id(2);
+
Vector weights = CONVERT_TO_VECTOR_STRUCT(weights);
+
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * VEC_SIZE;
+
+ int z_coord = 0;
+ int4 offset = 0;
+ int4 y_offset = ((int4)(y * NUM_ROWS_PROCESSED) + (int4)(0, 1, 2, 3) - (int)CONV_PAD_LEFT) * (int4)src_stride_y;
+
+ // We compute 2x2x2 [C,W,H] elements
+ VEC_INT acc0 = 0, sum0 = 0;
+ VEC_INT acc1 = 0, sum1 = 0;
+ VEC_INT acc2 = 0, sum2 = 0;
+ VEC_INT acc3 = 0, sum3 = 0;
+
+ // Load weights
+ VEC_UCHAR w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y + 0 * weights_stride_z);
+ VEC_UCHAR w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y + 0 * weights_stride_z);
+ VEC_UCHAR w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y + 0 * weights_stride_z);
+ VEC_UCHAR w3 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y + 1 * weights_stride_z);
+ VEC_UCHAR w4 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y + 1 * weights_stride_z);
+ VEC_UCHAR w5 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y + 1 * weights_stride_z);
+ VEC_UCHAR w6 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y + 2 * weights_stride_z);
+ VEC_UCHAR w7 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y + 2 * weights_stride_z);
+ VEC_UCHAR w8 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y + 2 * weights_stride_z);
+
+#if INPUT_OFFSET != 0
+ VEC_INT sum_we = CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT)
+ + CONVERT(w3, VEC_INT) + CONVERT(w4, VEC_INT) + CONVERT(w5, VEC_INT)
+ + CONVERT(w6, VEC_INT) + CONVERT(w7, VEC_INT) + CONVERT(w8, VEC_INT);
+#endif /* INPUT_OFFSET != 0 */
+
+ // Load input values
+ // z == 0
+ // Clamp z_coord as for z = 0, it can be negative
+ // z_coord is casted to unsigned int in order to use just a min() operation
+ // A "-1" 32 bit signed variable converted to unsigned gives 4294967295
+ z_coord = z * (int)NUM_PLANES_PROCESSED - (int)CONV_PAD_TOP;
+ z_coord = min((uint)z_coord, (uint)SRC_DIM_2);
+ offset = y_offset + (int4)(z_coord * src_stride_z);
+ offset = min(offset, (int4)max_offset);
+
+ VEC_UCHAR values0 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0);
+ VEC_UCHAR values1 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1);
+ VEC_UCHAR values2 = VLOAD(VEC_SIZE)(0, src_addr + offset.s2);
+ VEC_UCHAR values3 = VLOAD(VEC_SIZE)(0, src_addr + offset.s3);
+
+ // z == 1
+ // z_coord can be only negative for z = 0 so we do not need to clamp it
+ // Moreover z_coord cannot be out-of-bound for z = 1 so we do not need to clamp the offset
+ z_coord = z * (int)NUM_PLANES_PROCESSED - (int)CONV_PAD_TOP + 1;
+ offset = y_offset + (int4)(z_coord * src_stride_z);
+ VEC_UCHAR values4 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0);
+ VEC_UCHAR values5 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1);
+ VEC_UCHAR values6 = VLOAD(VEC_SIZE)(0, src_addr + offset.s2);
+ VEC_UCHAR values7 = VLOAD(VEC_SIZE)(0, src_addr + offset.s3);
+
+ // z == 2
+ // After z = 1 we can simply add src_stride_z to offset without updating z_coord
+ // However offset can be out-of-bound so we need to check if it is greater than max_offset
+ offset += (int4)src_stride_z;
+ offset = min(offset, (int4)max_offset);
+ VEC_UCHAR values8 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0);
+ VEC_UCHAR values9 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1);
+ VEC_UCHAR values10 = VLOAD(VEC_SIZE)(0, src_addr + offset.s2);
+ VEC_UCHAR values11 = VLOAD(VEC_SIZE)(0, src_addr + offset.s3);
+
+ // z == 3
+ // After z = 1 we can simply add src_stride_z to offset without updating z_coord
+ // However offset can be out-of-bound so we need to check if it is greater than max_offset
+ offset += (int4)(src_stride_z);
+ offset = min(offset, (int4)max_offset);
+ VEC_UCHAR values12 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0);
+ VEC_UCHAR values13 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1);
+ VEC_UCHAR values14 = VLOAD(VEC_SIZE)(0, src_addr + offset.s2);
+ VEC_UCHAR values15 = VLOAD(VEC_SIZE)(0, src_addr + offset.s3);
+
+ MULTIPLY_ADD_ACCUMULATE(values0, w0, acc0, sum0);
+ MULTIPLY_ADD_ACCUMULATE(values1, w1, acc0, sum0);
+ MULTIPLY_ADD_ACCUMULATE(values2, w2, acc0, sum0);
+ MULTIPLY_ADD_ACCUMULATE(values1, w0, acc1, sum1);
+ MULTIPLY_ADD_ACCUMULATE(values2, w1, acc1, sum1);
+ MULTIPLY_ADD_ACCUMULATE(values3, w2, acc1, sum1);
+
+ MULTIPLY_ADD_ACCUMULATE(values4, w3, acc0, sum0);
+ MULTIPLY_ADD_ACCUMULATE(values5, w4, acc0, sum0);
+ MULTIPLY_ADD_ACCUMULATE(values6, w5, acc0, sum0);
+ MULTIPLY_ADD_ACCUMULATE(values5, w3, acc1, sum1);
+ MULTIPLY_ADD_ACCUMULATE(values6, w4, acc1, sum1);
+ MULTIPLY_ADD_ACCUMULATE(values7, w5, acc1, sum1);
+
+ MULTIPLY_ADD_ACCUMULATE(values8, w6, acc0, sum0);
+ MULTIPLY_ADD_ACCUMULATE(values9, w7, acc0, sum0);
+ MULTIPLY_ADD_ACCUMULATE(values10, w8, acc0, sum0);
+ MULTIPLY_ADD_ACCUMULATE(values9, w6, acc1, sum1);
+ MULTIPLY_ADD_ACCUMULATE(values10, w7, acc1, sum1);
+ MULTIPLY_ADD_ACCUMULATE(values11, w8, acc1, sum1);
+
+ MULTIPLY_ADD_ACCUMULATE(values4, w0, acc2, sum2);
+ MULTIPLY_ADD_ACCUMULATE(values5, w1, acc2, sum2);
+ MULTIPLY_ADD_ACCUMULATE(values6, w2, acc2, sum2);
+ MULTIPLY_ADD_ACCUMULATE(values5, w0, acc3, sum3);
+ MULTIPLY_ADD_ACCUMULATE(values6, w1, acc3, sum3);
+ MULTIPLY_ADD_ACCUMULATE(values7, w2, acc3, sum3);
+
+ MULTIPLY_ADD_ACCUMULATE(values8, w3, acc2, sum2);
+ MULTIPLY_ADD_ACCUMULATE(values9, w4, acc2, sum2);
+ MULTIPLY_ADD_ACCUMULATE(values10, w5, acc2, sum2);
+ MULTIPLY_ADD_ACCUMULATE(values9, w3, acc3, sum3);
+ MULTIPLY_ADD_ACCUMULATE(values10, w4, acc3, sum3);
+ MULTIPLY_ADD_ACCUMULATE(values11, w5, acc3, sum3);
+
+ MULTIPLY_ADD_ACCUMULATE(values12, w6, acc2, sum2);
+ MULTIPLY_ADD_ACCUMULATE(values13, w7, acc2, sum2);
+ MULTIPLY_ADD_ACCUMULATE(values14, w8, acc2, sum2);
+ MULTIPLY_ADD_ACCUMULATE(values13, w6, acc3, sum3);
+ MULTIPLY_ADD_ACCUMULATE(values14, w7, acc3, sum3);
+ MULTIPLY_ADD_ACCUMULATE(values15, w8, acc3, sum3);
+
#if defined(HAS_BIAS)
Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
VEC_INT bias_values = VLOAD(VEC_SIZE)(0, (__global int *)biases.ptr);
-#endif /* defined(HAS_BIAS) */
- __global uchar *first_elem = src_ptr + src_offset_first_element_in_bytes;
-
- const int z = get_global_id(2);
- const int pad_offs = -ROWS_READ * src_stride_y;
- const int src_offs0 = get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + z * src_step_z - CONV_PAD_TOP * src_stride_z;
- const int src_offs1 = src_offs0 + src_stride_z;
- const int src_offs2 = src_offs1 + src_stride_z;
-
- const int cond_top = z - CONV_PAD_TOP < 0;
- const int cond_bottom = z * (src_step_z / src_stride_z) + 2 > SRC_DEPTH;
-
- __global uchar *src_addr0 = first_elem + select(src_offs0, pad_offs, cond_top);
- __global uchar *src_addr1 = first_elem + src_offs1;
- __global uchar *src_addr2 = first_elem + select(src_offs2, pad_offs, cond_bottom);
-
- VEC_INT sum_we = 0;
- VEC_INT acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
- VEC_INT sum0 = 0, sum1 = 0, sum2 = 0, sum3 = 0;
-
- // z == 0
- VEC_UCHAR w0, w1, w2;
- w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y);
- w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y);
- w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y);
-
-#if INPUT_OFFSET != 0
- sum_we += CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT);
-#endif /* INPUT_OFFSET != 0 */
-
- VEC_UCHAR values = VLOAD(VEC_SIZE)(0, src_addr0);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w0);
-
- src_addr0 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr0);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w1);
- BIFROST_MAD_ACC_4(acc1, sum1, values, w0);
-
- src_addr0 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr0);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w2);
- BIFROST_MAD_ACC_4(acc1, sum1, values, w1);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w0);
-
- src_addr0 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr0);
- BIFROST_MAD_ACC_4(acc1, sum1, values, w2);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w1);
- BIFROST_MAD_ACC_4(acc3, sum3, values, w0);
-
- src_addr0 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr0);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w2);
- BIFROST_MAD_ACC_4(acc3, sum3, values, w1);
-
- src_addr0 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr0);
- BIFROST_MAD_ACC_4(acc3, sum3, values, w2);
-
- weights.ptr += weights_stride_z;
-
- // z == 1
- w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y);
- w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y);
- w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y);
-
-#if INPUT_OFFSET != 0
- sum_we += CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT);
-#endif /* INPUT_OFFSET != 0 */
-
- values = VLOAD(VEC_SIZE)(0, src_addr1);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w0);
-
- src_addr1 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr1);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w1);
- BIFROST_MAD_ACC_4(acc1, sum1, values, w0);
-
- src_addr1 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr1);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w2);
- BIFROST_MAD_ACC_4(acc1, sum1, values, w1);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w0);
-
- src_addr1 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr1);
- BIFROST_MAD_ACC_4(acc1, sum1, values, w2);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w1);
- BIFROST_MAD_ACC_4(acc3, sum3, values, w0);
-
- src_addr1 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr1);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w2);
- BIFROST_MAD_ACC_4(acc3, sum3, values, w1);
-
- src_addr1 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr1);
- BIFROST_MAD_ACC_4(acc3, sum3, values, w2);
-
- weights.ptr += weights_stride_z;
-
- // z == 2
- w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y);
- w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y);
- w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y);
-
-#if INPUT_OFFSET != 0
- sum_we += CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT);
-#endif /* INPUT_OFFSET != 0 */
-
- values = VLOAD(VEC_SIZE)(0, src_addr2);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w0);
-
- src_addr2 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr2);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w1);
- BIFROST_MAD_ACC_4(acc1, sum1, values, w0);
-
- src_addr2 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr2);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w2);
- BIFROST_MAD_ACC_4(acc1, sum1, values, w1);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w0);
-
- src_addr2 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr2);
- BIFROST_MAD_ACC_4(acc1, sum1, values, w2);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w1);
- BIFROST_MAD_ACC_4(acc3, sum3, values, w0);
-
- src_addr2 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr2);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w2);
- BIFROST_MAD_ACC_4(acc3, sum3, values, w1);
-
- src_addr2 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr2);
- BIFROST_MAD_ACC_4(acc3, sum3, values, w2);
-
-#if defined(HAS_BIAS)
acc0 += bias_values;
acc1 += bias_values;
acc2 += bias_values;
@@ -514,17 +936,33 @@
res2 = CLAMP(res2, (VEC_UCHAR)0, (VEC_UCHAR)255);
res3 = CLAMP(res3, (VEC_UCHAR)0, (VEC_UCHAR)255);
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * dst_step_x + y * dst_step_y + (z * NUM_PLANES_PROCESSED) * dst_step_z;
+
VSTORE(VEC_SIZE)
- (res0, 0, dst.ptr + 0 * dst_stride_y);
+ (res0, 0, dst_addr + 0 * dst_stride_y);
VSTORE(VEC_SIZE)
- (res1, 0, dst.ptr + 1 * dst_stride_y);
- VSTORE(VEC_SIZE)
- (res2, 0, dst.ptr + 2 * dst_stride_y);
- VSTORE(VEC_SIZE)
- (res3, 0, dst.ptr + 3 * dst_stride_y);
+ (res1, 0, dst_addr + 1 * dst_stride_y);
+
+#if((DST_DIM_2 % NUM_PLANES_PROCESSED) != 0)
+ if((z * NUM_PLANES_PROCESSED + 1) < DST_DIM_2)
+#endif // ((DST_DIM_2 % NUM_PLANES_PROCESSED) != 0)
+ {
+ VSTORE(VEC_SIZE)
+ (res2, 0, dst_addr + 0 * dst_stride_y + 1 * dst_stride_z);
+ VSTORE(VEC_SIZE)
+ (res3, 0, dst_addr + 1 * dst_stride_y + 1 * dst_stride_z);
+ }
}
-/** This function computes the depthwise convolution quantized.
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+/** This function computes the depthwise convolution quantized for NHWC data layout when the stride along the width and height is 1 using dot product
+ *
+ * @note The number of elements read per thread must be passed at compile time using -DVEC_SIZE (e.g. -DVEC_SIZE=2)
+ * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
+ * @note The number of rows processed per thread must be passed at compile time using -DNUM_ROWS_PROCESSED (e.g. -DNUM_ROWS_PROCESSED=2)
+ * @note The number of planes processed per thread must be passed at compile time using -DNUM_PLANES_PROCESSED (e.g. -DNUM_PLANES_PROCESSED=2)
+ * @note The convolution pad top must be passed at compile time using -DCONV_PAD_TOP (e.g. -DCONV_PAD_TOP=1)
+ * @note The convolution pad left must be passed at compile time using -DCONV_PAD_LEFT (e.g. -DCONV_PAD_LEFT=1).
*
* @param[in] src_ptr Pointer to the source image. Supported data types: QASYMM8
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
@@ -556,171 +994,175 @@
* @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases vector
*/
-__kernel void depthwise_convolution_3x3_quantized_nhwc_stride2(
+__kernel void depthwise_convolution_3x3_quantized_dot8_nhwc_stride1(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
TENSOR3D_DECLARATION(weights),
#if defined(HAS_BIAS)
- VECTOR_DECLARATION(biases)
+ VECTOR_DECLARATION(biases),
#endif /* defined(HAS_BIAS) */
-)
+ int max_offset) // load offsets below are clamped to this value with min()
{
- Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
+ int x = get_global_id(0);
+ int y = get_global_id(1);
+ int z = get_global_id(2);
+
Vector weights = CONVERT_TO_VECTOR_STRUCT(weights);
+
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * VEC_SIZE; // row/plane byte offsets are added per load via offset
+
+ int z_coord = 0;
+ int4 offset = 0;
+ int4 y_offset = ((int4)(y * NUM_ROWS_PROCESSED) + (int4)(0, 1, 2, 3) - (int)CONV_PAD_LEFT) * (int4)src_stride_y; // byte offsets of the 4 input rows (along Y) read by this work-item
+
+ // We compute VEC_SIZE x 2 x 2 [C,W,H] elements: acc0/acc1 -> dst Y offsets 0/1 of the first Z plane, acc2/acc3 -> dst Y offsets 0/1 of the next Z plane
+ VEC_INT acc0 = 0, sum0 = 0;
+ VEC_INT acc1 = 0, sum1 = 0;
+ VEC_INT acc2 = 0, sum2 = 0;
+ VEC_INT acc3 = 0, sum3 = 0;
+
+ // Load the 3x3 weights: wK is read at Y = K % 3, Z = K / 3 of the weights tensor
+ VEC_UCHAR w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y + 0 * weights_stride_z);
+ VEC_UCHAR w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y + 0 * weights_stride_z);
+ VEC_UCHAR w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y + 0 * weights_stride_z);
+ VEC_UCHAR w3 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y + 1 * weights_stride_z);
+ VEC_UCHAR w4 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y + 1 * weights_stride_z);
+ VEC_UCHAR w5 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y + 1 * weights_stride_z);
+ VEC_UCHAR w6 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y + 2 * weights_stride_z);
+ VEC_UCHAR w7 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y + 2 * weights_stride_z);
+ VEC_UCHAR w8 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y + 2 * weights_stride_z);
+
+#if INPUT_OFFSET != 0
+ VEC_INT sum_we = CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT)
+ + CONVERT(w3, VEC_INT) + CONVERT(w4, VEC_INT) + CONVERT(w5, VEC_INT)
+ + CONVERT(w6, VEC_INT) + CONVERT(w7, VEC_INT) + CONVERT(w8, VEC_INT);
+#endif /* INPUT_OFFSET != 0 */
+
+ // Load input values
+ // z == 0
+ // Clamp z_coord as for z = 0, it can be negative
+ // z_coord is cast to unsigned int in order to use just a min() operation
+ // A "-1" 32 bit signed variable converted to unsigned gives 4294967295
+ z_coord = z * (int)NUM_PLANES_PROCESSED - (int)CONV_PAD_TOP;
+ z_coord = min((uint)z_coord, (uint)SRC_DIM_2);
+ offset = y_offset + (int4)(z_coord * src_stride_z);
+ offset = min(offset, (int4)max_offset);
+
+ VEC_UCHAR values0 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0);
+ VEC_UCHAR values1 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1);
+ VEC_UCHAR values2 = VLOAD(VEC_SIZE)(0, src_addr + offset.s2);
+ VEC_UCHAR values3 = VLOAD(VEC_SIZE)(0, src_addr + offset.s3);
+
+ // z == 1
+ // z_coord can be only negative for z = 0 so we do not need to clamp it
+ // Moreover z_coord cannot be out-of-bound for z = 1 so we do not need to clamp the offset
+ z_coord = z * (int)NUM_PLANES_PROCESSED - (int)CONV_PAD_TOP + 1;
+ offset = y_offset + (int4)(z_coord * src_stride_z);
+ VEC_UCHAR values4 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0);
+ VEC_UCHAR values5 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1);
+ VEC_UCHAR values6 = VLOAD(VEC_SIZE)(0, src_addr + offset.s2);
+ VEC_UCHAR values7 = VLOAD(VEC_SIZE)(0, src_addr + offset.s3);
+
+ // z == 2
+ // After z = 1 we can simply add src_stride_z to offset without updating z_coord
+ // However offset can be out-of-bound so we need to check if it is greater than max_offset
+ offset += (int4)src_stride_z;
+ offset = min(offset, (int4)max_offset);
+ VEC_UCHAR values8 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0);
+ VEC_UCHAR values9 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1);
+ VEC_UCHAR values10 = VLOAD(VEC_SIZE)(0, src_addr + offset.s2);
+ VEC_UCHAR values11 = VLOAD(VEC_SIZE)(0, src_addr + offset.s3);
+
+ // z == 3
+ // As for z = 2, we can simply add src_stride_z to offset without updating z_coord
+ // However offset can be out-of-bound so we need to check if it is greater than max_offset
+ offset += (int4)(src_stride_z);
+ offset = min(offset, (int4)max_offset);
+ VEC_UCHAR values12 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0);
+ VEC_UCHAR values13 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1);
+ VEC_UCHAR values14 = VLOAD(VEC_SIZE)(0, src_addr + offset.s2);
+ VEC_UCHAR values15 = VLOAD(VEC_SIZE)(0, src_addr + offset.s3);
+
+ DOT_PRODUCT_ACCUMULATE(acc0, sum0, values0, values1, values2, values4, values5, values6, values8, values9, values10, w0, w1, w2, w3, w4, w5, w6, w7, w8); // z == 0..2, rows 0..2
+ DOT_PRODUCT_ACCUMULATE(acc1, sum1, values1, values2, values3, values5, values6, values7, values9, values10, values11, w0, w1, w2, w3, w4, w5, w6, w7, w8); // z == 0..2, rows 1..3
+ DOT_PRODUCT_ACCUMULATE(acc2, sum2, values4, values5, values6, values8, values9, values10, values12, values13, values14, w0, w1, w2, w3, w4, w5, w6, w7, w8); // z == 1..3, rows 0..2
+ DOT_PRODUCT_ACCUMULATE(acc3, sum3, values5, values6, values7, values9, values10, values11, values13, values14, values15, w0, w1, w2, w3, w4, w5, w6, w7, w8); // z == 1..3, rows 1..3
+
#if defined(HAS_BIAS)
Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
VEC_INT bias_values = VLOAD(VEC_SIZE)(0, (__global int *)biases.ptr);
-#endif /* defined(HAS_BIAS) */
- __global uchar *first_elem = src_ptr + src_offset_first_element_in_bytes;
-
- const int z = get_global_id(2);
- const int pad_offs = -ROWS_READ * src_stride_y;
- const int src_offs0 = get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + z * src_step_z - CONV_PAD_TOP * src_stride_z;
- const int src_offs1 = src_offs0 + src_stride_z;
- const int src_offs2 = src_offs1 + src_stride_z;
-
- const int cond_top = z - CONV_PAD_TOP < 0;
- const int cond_bottom = z * (src_step_z / src_stride_z) + 2 > SRC_DEPTH;
-
- __global uchar *src_addr0 = first_elem + select(src_offs0, pad_offs, cond_top);
- __global uchar *src_addr1 = first_elem + src_offs1;
- __global uchar *src_addr2 = first_elem + select(src_offs2, pad_offs, cond_bottom);
-
- VEC_INT sum_we = 0;
- VEC_INT acc0 = 0, acc2 = 0;
- VEC_INT sum0 = 0, sum2 = 0;
-
- // z == 0
- VEC_UCHAR w0, w1, w2;
- w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y);
- w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y);
- w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y);
-
-#if INPUT_OFFSET != 0
- sum_we += CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT);
-#endif /* INPUT_OFFSET != 0 */
-
- VEC_UCHAR values = VLOAD(VEC_SIZE)(0, src_addr0);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w0);
-
- src_addr0 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr0);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w1);
-
- src_addr0 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr0);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w2);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w0);
-
- src_addr0 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr0);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w1);
-
- src_addr0 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr0);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w2);
-
- weights.ptr += weights_stride_z;
-
- // z == 1
- w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y);
- w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y);
- w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y);
-
-#if INPUT_OFFSET != 0
- sum_we += CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT);
-#endif /* INPUT_OFFSET != 0 */
-
- values = VLOAD(VEC_SIZE)(0, src_addr1);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w0);
-
- src_addr1 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr1);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w1);
-
- src_addr1 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr1);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w2);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w0);
-
- src_addr1 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr1);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w1);
-
- src_addr1 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr1);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w2);
-
- weights.ptr += weights_stride_z;
-
- // z == 2
- w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y);
- w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y);
- w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y);
-
-#if INPUT_OFFSET != 0
- sum_we += CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT);
-#endif /* INPUT_OFFSET != 0 */
-
- values = VLOAD(VEC_SIZE)(0, src_addr2);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w0);
-
- src_addr2 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr2);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w1);
-
- src_addr2 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr2);
- BIFROST_MAD_ACC_4(acc0, sum0, values, w2);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w0);
-
- src_addr2 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr2);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w1);
-
- src_addr2 += src_stride_y;
- values = VLOAD(VEC_SIZE)(0, src_addr2);
- BIFROST_MAD_ACC_4(acc2, sum2, values, w2);
-
-#if defined(HAS_BIAS)
acc0 += bias_values;
+ acc1 += bias_values;
acc2 += bias_values;
+ acc3 += bias_values;
#endif /* defined(HAS_BIAS) */
#if WEIGHTS_OFFSET != 0
acc0 += WEIGHTS_OFFSET * sum0;
+ acc1 += WEIGHTS_OFFSET * sum1;
acc2 += WEIGHTS_OFFSET * sum2;
+ acc3 += WEIGHTS_OFFSET * sum3;
#endif /* WEIGHTS_OFFSET != 0 */
#if INPUT_OFFSET != 0
VEC_INT offs = INPUT_OFFSET * sum_we;
acc0 += offs;
+ acc1 += offs;
acc2 += offs;
+ acc3 += offs;
#endif /* INPUT_OFFSET != 0 */
#if K_OFFSET != 0
acc0 += (VEC_INT)K_OFFSET;
+ acc1 += (VEC_INT)K_OFFSET;
acc2 += (VEC_INT)K_OFFSET;
+ acc3 += (VEC_INT)K_OFFSET;
#endif /* K_OFFSET != 0 */
acc0 = asymm_mult_by_quant_multiplier_less_than_one(acc0, OUTPUT_MULTIPLIER, OUTPUT_SHIFT);
+ acc1 = asymm_mult_by_quant_multiplier_less_than_one(acc1, OUTPUT_MULTIPLIER, OUTPUT_SHIFT);
acc2 = asymm_mult_by_quant_multiplier_less_than_one(acc2, OUTPUT_MULTIPLIER, OUTPUT_SHIFT);
+ acc3 = asymm_mult_by_quant_multiplier_less_than_one(acc3, OUTPUT_MULTIPLIER, OUTPUT_SHIFT);
+
acc0 += (VEC_INT)OUTPUT_OFFSET;
+ acc1 += (VEC_INT)OUTPUT_OFFSET;
acc2 += (VEC_INT)OUTPUT_OFFSET;
+ acc3 += (VEC_INT)OUTPUT_OFFSET;
+
VEC_UCHAR res0 = CONVERT_SAT(acc0, VEC_UCHAR);
+ VEC_UCHAR res1 = CONVERT_SAT(acc1, VEC_UCHAR);
VEC_UCHAR res2 = CONVERT_SAT(acc2, VEC_UCHAR);
- res0 = CLAMP(res0, (VEC_UCHAR)0, (VEC_UCHAR)255);
- res2 = CLAMP(res2, (VEC_UCHAR)0, (VEC_UCHAR)255);
+ VEC_UCHAR res3 = CONVERT_SAT(acc3, VEC_UCHAR);
+
+ res0 = CLAMP(res0, (VEC_UCHAR)0, (VEC_UCHAR)255);
+ res1 = CLAMP(res1, (VEC_UCHAR)0, (VEC_UCHAR)255);
+ res2 = CLAMP(res2, (VEC_UCHAR)0, (VEC_UCHAR)255);
+ res3 = CLAMP(res3, (VEC_UCHAR)0, (VEC_UCHAR)255);
+
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * dst_step_x + y * dst_step_y + (z * NUM_PLANES_PROCESSED) * dst_step_z;
VSTORE(VEC_SIZE)
- (res0, 0, dst.ptr + 0 * dst_stride_y);
+ (res0, 0, dst_addr + 0 * dst_stride_y);
VSTORE(VEC_SIZE)
- (res2, 0, dst.ptr + 1 * dst_stride_y);
+ (res1, 0, dst_addr + 1 * dst_stride_y);
+
+#if((DST_DIM_2 % NUM_PLANES_PROCESSED) != 0)
+ if((z * NUM_PLANES_PROCESSED + 1) < DST_DIM_2) // store the second Z plane only when it lies inside DST_DIM_2
+#endif // ((DST_DIM_2 % NUM_PLANES_PROCESSED) != 0)
+ {
+ VSTORE(VEC_SIZE)
+ (res2, 0, dst_addr + 0 * dst_stride_y + 1 * dst_stride_z);
+ VSTORE(VEC_SIZE)
+ (res3, 0, dst_addr + 1 * dst_stride_y + 1 * dst_stride_z);
+ }
}
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
-#endif /* defined(VEC_SIZE) && defined(SRC_DEPTH) && defined(CONV_PAD_TOP) && defined(ROWS_READ) */
+#endif // defined(NUM_ROWS_PROCESSED) && defined(NUM_PLANES_PROCESSED)
-#endif /* defined(WEIGHTS_OFFSET) && defined(INPUT_OFFSET) && defined(K_OFFSET) && defined(OUTPUT_OFFSET) && defined(OUTPUT_MULTIPLIER) && defined(OUTPUT_SHIFT) */
+#endif // defined(VEC_SIZE) && defined(SRC_DIM_1) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT)
+
+#endif // defined(WEIGHTS_OFFSET) && defined(INPUT_OFFSET) && defined(K_OFFSET) && defined(OUTPUT_OFFSET) && defined(OUTPUT_MULTIPLIER) && defined(OUTPUT_SHIFT)
diff --git a/src/core/CL/cl_kernels/dequantization_layer.cl b/src/core/CL/cl_kernels/dequantization_layer.cl
index 21e9c87..4908bb0 100644
--- a/src/core/CL/cl_kernels/dequantization_layer.cl
+++ b/src/core/CL/cl_kernels/dequantization_layer.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,7 +25,7 @@
/** This performs the dequantization of 8-bit unsigned integers to floating point.
*
- * @param[in] input_ptr Pointer to the source image. Supported data types: QS8/QS16/F16/F32
+ * @param[in] input_ptr Pointer to the source image. Supported data types: F16/F32
* @param[in] input_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the source image in Y dimension (in bytes)
diff --git a/src/core/CL/cl_kernels/direct_convolution1x1.cl b/src/core/CL/cl_kernels/direct_convolution1x1.cl
index 817c261..cceeb0f 100644
--- a/src/core/CL/cl_kernels/direct_convolution1x1.cl
+++ b/src/core/CL/cl_kernels/direct_convolution1x1.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,26 +23,130 @@
*/
#include "helpers.h"
-#if defined(FIXED_POINT_POSITION)
-#include "fixed_point.h"
-
-#define ADD_OP(a, b) ADD_SAT_OP_EXPAND((a), (b), DATA_TYPE_PROMOTED, 8)
-#define MUL_OP(a, b) MUL_SAT_OP_EXPAND(CONVERT((a), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)), CONVERT((b), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)), DATA_TYPE_PROMOTED, 8, FIXED_POINT_POSITION)
-
-// There is no need to have a larger intermediate type for qs32 because all the arguments are already promoted
-MULQ_SAT_IMPL(qs32x8, qs32x8)
-
-#else /* FIXED_POINT_POSITION */
#undef CONVERT_SAT
#define ADD_OP(a, b) ((a) + (b))
#define MUL_OP(a, b) ((a) * (b))
#define CONVERT_SAT(a, b) ((a))
-#endif /* FIXED_POINT_POSITION */
-
#if defined(DATA_TYPE) && defined(DATA_SIZE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
+#if defined(DATA_LAYOUT_NHWC)
+
+#define PTR_TO_VALUE(PTR, DATA_TYPE) (*((__global DATA_TYPE *)(PTR))) // Load one DATA_TYPE element from a byte address; fully parenthesized so postfix operators ([], .sN) apply to the loaded value
+
+/** This kernel performs a direct convolution to convolve the low three dimensions of a tensor with data layout NHWC
+ *
+ * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
+ * @note The data size must be passed at compile time using -DDATA_SIZE e.g. -DDATA_SIZE=32
+ * @note The convolution stride x must be passed at compile time using -DSTRIDE_X e.g. -DSTRIDE_X=1, and the convolution stride y using -DSTRIDE_Y e.g. -DSTRIDE_Y=1
+ * @note The depth of the weights tensor must be passed at compile time using -DWEIGHTS_DEPTH
+ * @note In case biases will be added to the convolution, -DHAS_BIAS has to be passed at compile time.
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16/F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] weights_ptr Pointer to the weights tensor. Supported data types: same as @p src_ptr
+ * @param[in] weights_stride_x Stride of the weights tensor in X dimension (in bytes)
+ * @param[in] weights_step_x weights_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] weights_stride_y Stride of the weights tensor in Y dimension (in bytes)
+ * @param[in] weights_step_y weights_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] weights_stride_z Stride of the weights tensor in Z dimension (in bytes)
+ * @param[in] weights_step_z weights_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
+ * @param[in] biases_ptr Pointer to the biases tensor. Same as @p src_ptr
+ * @param[in] biases_stride_x Stride of the biases tensor in X dimension (in bytes)
+ * @param[in] biases_step_x biases_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] biases_offset_first_element_in_bytes The offset of the first element in the biases tensor
+ * @param[in] weights_stride_w Stride of the weights tensor in the 4th dimension
+ */
+__kernel void direct_convolution1x1_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst),
+ TENSOR3D_DECLARATION(weights),
+#ifdef HAS_BIAS
+ VECTOR_DECLARATION(biases),
+#endif /* defined(HAS_BIAS) */
+ unsigned int weights_stride_w)
+{
+ Image src = CONVERT_TO_IMAGE_STRUCT(src);
+ Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
+ Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
+
+#ifdef HAS_BIAS
+ Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
+#endif /* defined(HAS_BIAS) */
+
+ VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)
+ values = 0; // 8 accumulators: values.sN is stored at dst.ptr + N * dst_stride_y
+ const int id0 = get_global_id(0); // selects the filter (via weights_stride_w) and the bias element
+ const int id1 = get_global_id(1); // NOTE(review): unused in this kernel
+ const int id2 = get_global_id(2); // advances the input by STRIDE_Y planes along Z
+ weights.ptr += id0 * weights_stride_w;
+ __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0) - src_stride_x * id0 + id2 * STRIDE_Y * (int)src_stride_z; // rewinds the X step so reads start at X = 0 (assumes src_step_x == src_stride_x — confirm); STRIDE_Y is not checked by the #if guard above, so it must be defined at build time
+
+ for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d) // each iteration steps src and weights one element along X
+ {
+ DATA_TYPE weight = *(__global DATA_TYPE *)weights.ptr;
+#if STRIDE_X == 1
+ VEC_DATA_TYPE(DATA_TYPE, 8)
+ col0 = (VEC_DATA_TYPE(DATA_TYPE, 8))(
+ PTR_TO_VALUE(src_addr + 0 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 1 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 2 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 3 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 4 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 5 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 6 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 7 * src_stride_y, DATA_TYPE));
+#elif STRIDE_X == 2 /* STRIDE_X == 1 */
+ VEC_DATA_TYPE(DATA_TYPE, 8)
+ col0 = (VEC_DATA_TYPE(DATA_TYPE, 8))(
+ PTR_TO_VALUE(src_addr + 0 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 2 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 4 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 6 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 8 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 10 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 12 * src_stride_y, DATA_TYPE),
+ PTR_TO_VALUE(src_addr + 14 * src_stride_y, DATA_TYPE));
+#else /* STRIDE_X not equals 1 or 2 */
+#error "STRIDE_X larger than 2 is not supported"
+#endif /* STRIDE_X == 2 */
+ values = ADD_OP(values, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))weight, col0)); // broadcast the scalar weight across the 8 rows
+
+ src_addr += src_stride_x;
+ weights.ptr += weights_stride_x;
+ }
+
+#ifdef HAS_BIAS
+ values = ADD_OP(values, (VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, id0))));
+#endif /* defined(HAS_BIAS) */
+
+ *((__global DATA_TYPE *)dst.ptr) = values.s0;
+ *((__global DATA_TYPE *)(dst.ptr + 1 * dst_stride_y)) = values.s1;
+ *((__global DATA_TYPE *)(dst.ptr + 2 * dst_stride_y)) = values.s2;
+ *((__global DATA_TYPE *)(dst.ptr + 3 * dst_stride_y)) = values.s3;
+ *((__global DATA_TYPE *)(dst.ptr + 4 * dst_stride_y)) = values.s4;
+ *((__global DATA_TYPE *)(dst.ptr + 5 * dst_stride_y)) = values.s5;
+ *((__global DATA_TYPE *)(dst.ptr + 6 * dst_stride_y)) = values.s6;
+ *((__global DATA_TYPE *)(dst.ptr + 7 * dst_stride_y)) = values.s7;
+}
+#endif // defined(DATA_LAYOUT_NHWC)
+
#if STRIDE_X == 3
#define INPUT_PIXEL_STR(data_size) extract_input_stride3_##data_size
#define INPUT_PIXEL(data_size) INPUT_PIXEL_STR(data_size)
@@ -58,7 +162,7 @@
*
* @param[in] input_pixel Pointer to the first pixel.
*
- * @return extracted input pixels.
+ * @return extracted input values.
*/
inline VEC_DATA_TYPE(DATA_TYPE, 8) extract_input_stride1(__global const DATA_TYPE *input_pixel)
{
@@ -69,7 +173,7 @@
*
* @param[in] input_pixel Pointer to the first pixel.
*
- * @return extracted input pixels.
+ * @return extracted input values.
*/
inline VEC_DATA_TYPE(DATA_TYPE, 8) extract_input_stride2(__global const DATA_TYPE *input_pixel)
{
@@ -82,7 +186,7 @@
*
* @param[in] input_pixel Pointer to the first pixel.
*
- * @return extracted input pixels.
+ * @return extracted input values.
*/
inline VEC_DATA_TYPE(DATA_TYPE, 8) extract_input_stride3_32(__global const DATA_TYPE *input_pixel)
{
@@ -101,7 +205,7 @@
*
* @param[in] input_pixel Pointer to the first pixel.
*
- * @return extracted input pixels.
+ * @return extracted input values.
*/
inline VEC_DATA_TYPE(DATA_TYPE, 8) extract_input_stride3_16(__global const DATA_TYPE *input_pixel)
{
@@ -118,7 +222,7 @@
*
* @param[in] input_pixel Pointer to the first pixel.
*
- * @return extracted input pixels.
+ * @return extracted input values.
*/
inline VEC_DATA_TYPE(DATA_TYPE, 8) extract_input_stride3_8(__global const DATA_TYPE *input_pixel)
{
@@ -185,27 +289,26 @@
#endif /* defined(HAS_BIAS) */
VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)
- pixels = 0;
+ values = 0;
const uint z_index = get_global_id(2);
weights.ptr += z_index * weights_stride_w;
-
for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d)
{
DATA_TYPE weight = *(__global DATA_TYPE *)weights.ptr;
VEC_DATA_TYPE(DATA_TYPE, 8)
input_pixel = INPUT_PIXEL(DATA_SIZE)((__global DATA_TYPE *)src.ptr);
- pixels = ADD_OP(pixels, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))weight, input_pixel));
+ values = ADD_OP(values, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))weight, input_pixel));
src.ptr += src_stride_z;
weights.ptr += weights_stride_z;
}
#ifdef HAS_BIAS
- pixels = ADD_OP(pixels, (VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, z_index))));
+ values = ADD_OP(values, (VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, z_index))));
#endif /* defined(HAS_BIAS) */
- vstore8(CONVERT_SAT(pixels, VEC_DATA_TYPE(DATA_TYPE, 8)), 0, (__global DATA_TYPE *)dst.ptr);
+ vstore8(CONVERT_SAT(values, VEC_DATA_TYPE(DATA_TYPE, 8)), 0, (__global DATA_TYPE *)dst.ptr);
}
#endif // defined(DATA_TYPE) && defined(DATA_SIZE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
@@ -326,4 +429,4 @@
vstore4(acc2, 0, (__global float *)(dst.ptr + 2 * dst_stride_y));
vstore4(acc3, 0, (__global float *)(dst.ptr + 3 * dst_stride_y));
}
-#endif // defined(WEIGHTS_DEPTH)
\ No newline at end of file
+#endif // defined(WEIGHTS_DEPTH)
diff --git a/src/core/CL/cl_kernels/direct_convolution3x3.cl b/src/core/CL/cl_kernels/direct_convolution3x3.cl
index a7abc9f..08d25f6 100644
--- a/src/core/CL/cl_kernels/direct_convolution3x3.cl
+++ b/src/core/CL/cl_kernels/direct_convolution3x3.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,25 +23,12 @@
*/
#include "helpers.h"
-#if defined(FIXED_POINT_POSITION)
-#include "fixed_point.h"
-
-#define ADD_OP(a, b) ADD_SAT_OP_EXPAND((a), (b), DATA_TYPE_PROMOTED, 8)
-#define MUL_OP(a, b) MUL_SAT_OP_EXPAND(CONVERT((a), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)), CONVERT((b), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)), DATA_TYPE_PROMOTED, 8, FIXED_POINT_POSITION)
-
-// There is no need to have a larger intermediate type for qs32 because all the arguments are already promoted
-MULQ_SAT_IMPL(qs32x8, qs32x8)
-
-#else /* FIXED_POINT_POSITION */
-
#undef CONVERT_SAT
#define ADD_OP(a, b) ((a) + (b))
#define MUL_OP(a, b) ((a) * (b))
#define CONVERT_SAT(a, b) ((a))
-#endif /* FIXED_POINT_POSITION */
-
#if defined(DATA_TYPE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
#if STRIDE_X == 1
@@ -79,6 +66,76 @@
acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s2468, src0.sACE, src1), (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2)); \
})
+#if defined(DATA_LAYOUT_NHWC)
+// NHWC helpers: in[k] is the value at row_ptr + k * src_stride_y and w[k] the value at weights_ptr + k * weights_stride_y (both strides are taken from the calling kernel's scope, not passed as arguments).
+#define PTR_TO_VALUE(PTR, DATA_TYPE) *((__global DATA_TYPE *)(PTR)) // reads the DATA_TYPE value stored at byte address PTR
+// CONVOLUTION1x3_NHWC(acc, row_ptr, weights_ptr) adds one 1x3 kernel row to the 8-wide accumulator acc; only STRIDE_X 1 and 2 are implemented.
+#if STRIDE_X == 1
+#define CONVOLUTION1x3_NHWC(acc, row_ptr, weights_ptr) CONVOLUTION1x3_STRIDE_NHWC_STRIDE1(acc, row_ptr, weights_ptr)
+#elif STRIDE_X == 2 /* STRIDE_X == 1 */
+#define CONVOLUTION1x3_NHWC(acc, row_ptr, weights_ptr) CONVOLUTION1x3_STRIDE_NHWC_STRIDE2(acc, row_ptr, weights_ptr)
+#else /* STRIDE_X not equals 1 or 2 */
+#error "STRIDE_X larger than 2 is not supported"
+#endif /* STRIDE_X == 2 */
+// Stride 1: in[0..9] (src0 = in[0..7], src1 = in[8..9]) produce 8 outputs, acc[i] += in[i] * w[0] + in[i + 1] * w[1] + in[i + 2] * w[2].
+#define CONVOLUTION1x3_STRIDE_NHWC_STRIDE1(acc, row_ptr, weights_ptr)                                                                        \
+    {                                                                                                                                         \
+        VEC_DATA_TYPE(DATA_TYPE, 8)                                                                                                           \
+        src0 = (VEC_DATA_TYPE(DATA_TYPE, 8))(                                                                                                 \
+                PTR_TO_VALUE(row_ptr + 0 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 1 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 2 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 3 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 4 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 5 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 6 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 7 * src_stride_y, DATA_TYPE));                                                                         \
+        VEC_DATA_TYPE(DATA_TYPE, 2)                                                                                                           \
+        src1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(                                                                                                 \
+                PTR_TO_VALUE(row_ptr + 8 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 9 * src_stride_y, DATA_TYPE));                                                                         \
+        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                                                           \
+        weights = (VEC_DATA_TYPE(DATA_TYPE, 3))(                                                                                              \
+                  PTR_TO_VALUE((weights_ptr) + 0 * weights_stride_y, DATA_TYPE),                                                              \
+                  PTR_TO_VALUE((weights_ptr) + 1 * weights_stride_y, DATA_TYPE),                                                              \
+                  PTR_TO_VALUE((weights_ptr) + 2 * weights_stride_y, DATA_TYPE));                                                             \
+        acc = ADD_OP(acc, MUL_OP(src0, (VEC_DATA_TYPE(DATA_TYPE, 8))weights.s0));                                                             \
+        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1234, src0.s567, src1.s0), (VEC_DATA_TYPE(DATA_TYPE, 8))weights.s1));  \
+        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s234, src0.s567, src1.s01), (VEC_DATA_TYPE(DATA_TYPE, 8))weights.s2));  \
+    }
+// Stride 2: in[0..16] (src0 = in[0..15], src1 = in[16]) produce 8 outputs, acc[i] += in[2i] * w[0] + in[2i + 1] * w[1] + in[2i + 2] * w[2].
+#define CONVOLUTION1x3_STRIDE_NHWC_STRIDE2(acc, row_ptr, weights_ptr)                                                                        \
+    {                                                                                                                                         \
+        VEC_DATA_TYPE(DATA_TYPE, 16)                                                                                                          \
+        src0 = (VEC_DATA_TYPE(DATA_TYPE, 16))(                                                                                                \
+                PTR_TO_VALUE(row_ptr + 0 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 1 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 2 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 3 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 4 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 5 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 6 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 7 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 8 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 9 * src_stride_y, DATA_TYPE),                                                                          \
+                PTR_TO_VALUE(row_ptr + 10 * src_stride_y, DATA_TYPE),                                                                         \
+                PTR_TO_VALUE(row_ptr + 11 * src_stride_y, DATA_TYPE),                                                                         \
+                PTR_TO_VALUE(row_ptr + 12 * src_stride_y, DATA_TYPE),                                                                         \
+                PTR_TO_VALUE(row_ptr + 13 * src_stride_y, DATA_TYPE),                                                                         \
+                PTR_TO_VALUE(row_ptr + 14 * src_stride_y, DATA_TYPE),                                                                         \
+                PTR_TO_VALUE(row_ptr + 15 * src_stride_y, DATA_TYPE));                                                                        \
+        DATA_TYPE src1 = PTR_TO_VALUE(row_ptr + 16 * src_stride_y, DATA_TYPE);                                                                \
+        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                                                           \
+        weights = (VEC_DATA_TYPE(DATA_TYPE, 3))(                                                                                              \
+                  PTR_TO_VALUE((weights_ptr) + 0 * weights_stride_y, DATA_TYPE),                                                              \
+                  PTR_TO_VALUE((weights_ptr) + 1 * weights_stride_y, DATA_TYPE),                                                              \
+                  PTR_TO_VALUE((weights_ptr) + 2 * weights_stride_y, DATA_TYPE));                                                             \
+                                                                                                                                              \
+        acc = ADD_OP(acc, MUL_OP(src0.s02468ACE, (VEC_DATA_TYPE(DATA_TYPE, 8))weights.s0));                                                   \
+        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1357, src0.s9BDF), (VEC_DATA_TYPE(DATA_TYPE, 8))weights.s1));          \
+        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s2468, src0.sACE, src1), (VEC_DATA_TYPE(DATA_TYPE, 8))weights.s2));      \
+    }
+
/** This kernel performs a direct convolution to convolve the low three dimensions.
*
* @note This OpenCL kernel works with stride_x = 1 and 2
@@ -116,6 +173,115 @@
* @param[in] biases_offset_first_element_in_bytes The offset of the first element in the biases tensor
* @param[in] weights_stride_w Stride of the weights tensor in the 4th dimension
*/
+__kernel void direct_convolution3x3_nhwc(
+    TENSOR3D_DECLARATION(src),
+    TENSOR3D_DECLARATION(dst),
+    TENSOR3D_DECLARATION(weights),
+#ifdef HAS_BIAS
+    VECTOR_DECLARATION(biases),
+#endif /* defined(HAS_BIAS) */
+    unsigned int weights_stride_w)
+{
+    Image src = CONVERT_TO_IMAGE_STRUCT(src);
+    Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
+    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
+    // Accumulator for 8 outputs spaced dst_stride_y apart, all computed with the kernel selected by id0.
+    VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)
+    values0 = 0;
+    const int id0 = get_global_id(0); // kernel index: selects the weights (weights_stride_w) and the bias
+    const int id1 = get_global_id(1); // NOTE(review): not used in this kernel
+    const int id2 = get_global_id(2); // sets the first input row along Z (id2 * STRIDE_Y - PAD_TOP)
+    // Source start: undo the X offset (presumably added by CONVERT_TO_IMAGE_STRUCT) so reads begin at X = 0, then step along Z to the window's first input row.
+    __global uchar *weights_addr = (__global uchar *)tensor3D_offset(&weights, 0, 0, 0);
+    __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0) - src_stride_x * id0 + ((id2 * STRIDE_Y) - PAD_TOP) * (int)src_stride_z;
+    // Select the weights of kernel id0.
+    weights_addr += id0 * weights_stride_w;
+    // Input row (Z) of the first kernel row, negative inside the top padding. The PAD_TOP branches below skip at most one kernel row (presumably PAD_TOP == PAD_BOTTOM == 1 — TODO confirm).
+    const int coordy = ((id2 * STRIDE_Y) - PAD_TOP);
+    for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d) // one step along X (channels in NHWC) per iteration
+    {
+#if PAD_TOP > 0
+        if(coordy < 0) // special case: the first kernel row lies in the top padding (Z = -1 does not exist)
+        {
+            // skip the first kernel row and accumulate the two following ones
+            CONVOLUTION1x3_NHWC(values0, src_addr + 1 * (int)src_stride_z, (weights_addr + 1 * (int)weights_stride_z));
+            CONVOLUTION1x3_NHWC(values0, src_addr + 2 * (int)src_stride_z, (weights_addr + 2 * (int)weights_stride_z));
+        }
+        else if(coordy == (SRC_HEIGHT - PAD_TOP - 1))
+        {
+            // special case for the last output row: the third kernel row would start at Z = SRC_HEIGHT (when PAD_TOP == 1), below the input,
+            // and the Z axis has no padding, so only the first two kernel rows are accumulated.
+            CONVOLUTION1x3_NHWC(values0, src_addr, (weights_addr + 0 * (int)weights_stride_z));
+            CONVOLUTION1x3_NHWC(values0, src_addr + 1 * (int)src_stride_z, (weights_addr + 1 * (int)weights_stride_z));
+        }
+        else
+        {
+            CONVOLUTION1x3_NHWC(values0, src_addr, (weights_addr + 0 * (int)weights_stride_z));
+            CONVOLUTION1x3_NHWC(values0, src_addr + 1 * (int)src_stride_z, (weights_addr + 1 * (int)weights_stride_z));
+            CONVOLUTION1x3_NHWC(values0, src_addr + 2 * (int)src_stride_z, (weights_addr + 2 * (int)weights_stride_z));
+        }
+#else  // PAD_TOP > 0
+        CONVOLUTION1x3_NHWC(values0, src_addr, (weights_addr + 0 * (int)weights_stride_z));
+        CONVOLUTION1x3_NHWC(values0, src_addr + 1 * (int)src_stride_z, (weights_addr + 1 * (int)weights_stride_z));
+        CONVOLUTION1x3_NHWC(values0, src_addr + 2 * (int)src_stride_z, (weights_addr + 2 * (int)weights_stride_z));
+#endif // PAD_TOP > 0
+        src_addr += src_stride_x;
+        weights_addr += weights_stride_x;
+    }
+    // Optional bias of kernel id0, broadcast to all 8 outputs.
+#ifdef HAS_BIAS
+    Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
+    values0 = ADD_OP(values0, (VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, id0))));
+#endif /* defined(HAS_BIAS) */
+    // The 8 outputs are dst_stride_y apart, so they are stored one by one instead of with vstore8.
+    *((__global DATA_TYPE *)(dst.ptr + 0 * dst_stride_y)) = values0.s0;
+    *((__global DATA_TYPE *)(dst.ptr + 1 * dst_stride_y)) = values0.s1;
+    *((__global DATA_TYPE *)(dst.ptr + 2 * dst_stride_y)) = values0.s2;
+    *((__global DATA_TYPE *)(dst.ptr + 3 * dst_stride_y)) = values0.s3;
+    *((__global DATA_TYPE *)(dst.ptr + 4 * dst_stride_y)) = values0.s4;
+    *((__global DATA_TYPE *)(dst.ptr + 5 * dst_stride_y)) = values0.s5;
+    *((__global DATA_TYPE *)(dst.ptr + 6 * dst_stride_y)) = values0.s6;
+    *((__global DATA_TYPE *)(dst.ptr + 7 * dst_stride_y)) = values0.s7;
+}
+#endif // defined(DATA_LAYOUT_NHWC)
+
+/** This kernel performs a direct convolution to convolve the low three dimensions.
+ *
+ * @note This OpenCL kernel works with stride_x = 1 and 2
+ * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
+ * @note The third dimensions of the weights tensors must be passed at compile time using -DWEIGHTS_DEPTH
+ * @note If biases are used then -DHAS_BIAS has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16/F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] weights_ptr Pointer to the weights tensor. Supported data types: same as @p src_ptr
+ * @param[in] weights_stride_x Stride of the weights tensor in X dimension (in bytes)
+ * @param[in] weights_step_x weights_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] weights_stride_y Stride of the weights tensor in Y dimension (in bytes)
+ * @param[in] weights_step_y weights_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] weights_stride_z Stride of the weights tensor in Z dimension (in bytes)
+ * @param[in] weights_step_z weights_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
+ * @param[in] biases_ptr Pointer to the biases tensor. Same as @p src_ptr
+ * @param[in] biases_stride_x Stride of the biases tensor in X dimension (in bytes)
+ * @param[in] biases_step_x biases_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] biases_offset_first_element_in_bytes The offset of the first element in the biases tensor
+ * @param[in] weights_stride_w Stride of the weights tensor in the 4th dimension
+ */
__kernel void direct_convolution3x3(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
@@ -130,7 +296,7 @@
Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)
- pixels0 = 0;
+ values0 = 0;
__global uchar *weights_addr = (__global uchar *)tensor3D_offset(&weights, 0, 0, 0);
__global uchar *src_addr = (__global uchar *)offset(&src, 0, 0);
@@ -140,9 +306,9 @@
for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d)
{
- CONVOLUTION1x3(pixels0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 0 * weights_stride_y));
- CONVOLUTION1x3(pixels0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 1 * weights_stride_y));
- CONVOLUTION1x3(pixels0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 2 * weights_stride_y));
+ CONVOLUTION1x3(values0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 0 * weights_stride_y));
+ CONVOLUTION1x3(values0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 1 * weights_stride_y));
+ CONVOLUTION1x3(values0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 2 * weights_stride_y));
src_addr += src_stride_z;
weights_addr += weights_stride_z;
@@ -151,10 +317,10 @@
#ifdef HAS_BIAS
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- pixels0 = ADD_OP(pixels0, (VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, kernel_index))));
+ values0 = ADD_OP(values0, (VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, kernel_index))));
#endif /* defined(HAS_BIAS) */
- vstore8(CONVERT_SAT(pixels0, VEC_DATA_TYPE(DATA_TYPE, 8)), 0, (__global DATA_TYPE *)dst.ptr);
+ vstore8(CONVERT_SAT(values0, VEC_DATA_TYPE(DATA_TYPE, 8)), 0, (__global DATA_TYPE *)dst.ptr);
}
#endif //defined(DATA_TYPE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
@@ -227,9 +393,9 @@
Image src = CONVERT_TO_IMAGE_STRUCT(src);
Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
- float4 pixels0 = 0;
- float4 pixels1 = 0;
- float4 pixels2 = 0;
+ float4 values0 = 0;
+ float4 values1 = 0;
+ float4 values2 = 0;
__global uchar *weights_addr = (__global uchar *)(weights_ptr + weights_offset_first_element_in_bytes + kernel_index * weights_stride_w);
__global uchar *src_addr = (__global uchar *)offset(&src, 0, 0);
@@ -249,39 +415,39 @@
src0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
src1 = vload2(0, (__global float *)(src_addr + 0 * src_stride_y) + 4);
- CONVOLUTION1x3_BIFROST(pixels0, src0, src1, weights_row0);
+ CONVOLUTION1x3_BIFROST(values0, src0, src1, weights_row0);
// Load values from row1 of input tensor
src0 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
src1 = vload2(0, (__global float *)(src_addr + 1 * src_stride_y) + 4);
// Accumulate
- CONVOLUTION1x3_BIFROST(pixels0, src0, src1, weights_row1);
- CONVOLUTION1x3_BIFROST(pixels1, src0, src1, weights_row0);
+ CONVOLUTION1x3_BIFROST(values0, src0, src1, weights_row1);
+ CONVOLUTION1x3_BIFROST(values1, src0, src1, weights_row0);
// Load values from row2 of input tensor
src0 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
src1 = vload2(0, (__global float *)(src_addr + 2 * src_stride_y) + 4);
// Accumulate
- CONVOLUTION1x3_BIFROST(pixels0, src0, src1, weights_row2);
- CONVOLUTION1x3_BIFROST(pixels1, src0, src1, weights_row1);
- CONVOLUTION1x3_BIFROST(pixels2, src0, src1, weights_row0);
+ CONVOLUTION1x3_BIFROST(values0, src0, src1, weights_row2);
+ CONVOLUTION1x3_BIFROST(values1, src0, src1, weights_row1);
+ CONVOLUTION1x3_BIFROST(values2, src0, src1, weights_row0);
// Load values from row3 of input tensor
src0 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
src1 = vload2(0, (__global float *)(src_addr + 3 * src_stride_y) + 4);
// Accumulate
- CONVOLUTION1x3_BIFROST(pixels1, src0, src1, weights_row2);
- CONVOLUTION1x3_BIFROST(pixels2, src0, src1, weights_row1);
+ CONVOLUTION1x3_BIFROST(values1, src0, src1, weights_row2);
+ CONVOLUTION1x3_BIFROST(values2, src0, src1, weights_row1);
// Row4
src0 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y));
src1 = vload2(0, (__global float *)(src_addr + 4 * src_stride_y) + 4);
// Accumulate
- CONVOLUTION1x3_BIFROST(pixels2, src0, src1, weights_row2);
+ CONVOLUTION1x3_BIFROST(values2, src0, src1, weights_row2);
src_addr += src_stride_z;
weights_addr += weights_stride_z;
@@ -292,13 +458,13 @@
float bias = (float) * ((__global float *)(vector_offset(&biases, kernel_index)));
- pixels0 += (float4)bias;
- pixels1 += (float4)bias;
- pixels2 += (float4)bias;
+ values0 += (float4)bias;
+ values1 += (float4)bias;
+ values2 += (float4)bias;
#endif /* defined(HAS_BIAS) */
- vstore4(pixels0, 0, (__global float *)(dst.ptr + 0 * dst_stride_y));
- vstore4(pixels1, 0, (__global float *)(dst.ptr + 1 * dst_stride_y));
- vstore4(pixels2, 0, (__global float *)(dst.ptr + 2 * dst_stride_y));
+ vstore4(values0, 0, (__global float *)(dst.ptr + 0 * dst_stride_y));
+ vstore4(values1, 0, (__global float *)(dst.ptr + 1 * dst_stride_y));
+ vstore4(values2, 0, (__global float *)(dst.ptr + 2 * dst_stride_y));
}
#endif // defined(WEIGHTS_DEPTH)
diff --git a/src/core/CL/cl_kernels/direct_convolution5x5.cl b/src/core/CL/cl_kernels/direct_convolution5x5.cl
index e678f6f..70be058 100644
--- a/src/core/CL/cl_kernels/direct_convolution5x5.cl
+++ b/src/core/CL/cl_kernels/direct_convolution5x5.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -69,6 +69,190 @@
acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s468a, src0.sCE, src1.s02) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_value1; \
})
+#if defined(DATA_LAYOUT_NHWC)
+// NHWC helpers: in[k] is the value at row_ptr + k * src_stride_y and w[k] the value at weights_ptr + k * weights_stride_y (both strides are taken from the calling kernel's scope, not passed as arguments).
+#define PTR_TO_VALUE(PTR, DATA_TYPE) *((__global DATA_TYPE *)(PTR))
+// CONVOLUTION1x5_NHWC(acc, row_ptr, weights_ptr) adds one 1x5 kernel row to the 8-wide accumulator acc; only STRIDE_X 1 and 2 are implemented.
+#if STRIDE_X == 1
+#define CONVOLUTION1x5_NHWC(acc, row_ptr, weights_ptr) CONVOLUTION1x5_STRIDE1_NHWC(acc, row_ptr, weights_ptr)
+#elif STRIDE_X == 2 /* STRIDE_X == 1 */
+#define CONVOLUTION1x5_NHWC(acc, row_ptr, weights_ptr) CONVOLUTION1x5_STRIDE2_NHWC(acc, row_ptr, weights_ptr)
+#else /* STRIDE_X not equals 1 or 2 */
+#error "STRIDE_X larger than 2 is not supported"
+#endif /* STRIDE_X == 2 */
+// Stride 1: in[0..11] (src0 = in[0..7], src1 = in[8..11]) produce 8 outputs, acc[i] += in[i] * w[0] + in[i + 1] * w[1] + ... + in[i + 4] * w[4].
+#define CONVOLUTION1x5_STRIDE1_NHWC(acc, row_ptr, weights_ptr)                                                                                  \
+    ({                                                                                                                                          \
+        VEC_DATA_TYPE(DATA_TYPE, 8)                                                                                                             \
+        src0 = (VEC_DATA_TYPE(DATA_TYPE, 8))(                                                                                                   \
+                PTR_TO_VALUE((row_ptr) + 0 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 1 * src_stride_y, DATA_TYPE),                  \
+                PTR_TO_VALUE((row_ptr) + 2 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 3 * src_stride_y, DATA_TYPE),                  \
+                PTR_TO_VALUE((row_ptr) + 4 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 5 * src_stride_y, DATA_TYPE),                  \
+                PTR_TO_VALUE((row_ptr) + 6 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 7 * src_stride_y, DATA_TYPE));                 \
+        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                                                             \
+        src1 = (VEC_DATA_TYPE(DATA_TYPE, 4))(                                                                                                   \
+                PTR_TO_VALUE((row_ptr) + 8 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 9 * src_stride_y, DATA_TYPE),                  \
+                PTR_TO_VALUE((row_ptr) + 10 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 11 * src_stride_y, DATA_TYPE));               \
+        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                                                             \
+        weights_values0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(                                                                                        \
+                          PTR_TO_VALUE((weights_ptr) + 0 * weights_stride_y, DATA_TYPE), PTR_TO_VALUE((weights_ptr) + 1 * weights_stride_y, DATA_TYPE), \
+                          PTR_TO_VALUE((weights_ptr) + 2 * weights_stride_y, DATA_TYPE), PTR_TO_VALUE((weights_ptr) + 3 * weights_stride_y, DATA_TYPE)); \
+        DATA_TYPE weights_value1 = PTR_TO_VALUE((weights_ptr) + 4 * weights_stride_y, DATA_TYPE);                                               \
+        acc += src0 * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0;                                                                          \
+        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1234, src0.s567, src1.s0) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1;                 \
+        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s234, src0.s567, src1.s01) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2;                 \
+        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s345, src0.s67, src1.s012) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s3;                 \
+        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s45, src0.s67, src1.s0123) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_value1;                     \
+    })
+// Stride 2: in[0..18] (src0 = in[0..15], src1 = in[16..18]) produce 8 outputs, acc[i] += in[2i] * w[0] + ... + in[2i + 4] * w[4]; in[19] is never needed, so it is not read.
+#define CONVOLUTION1x5_STRIDE2_NHWC(acc, row_ptr, weights_ptr)                                                                                  \
+    ({                                                                                                                                          \
+        VEC_DATA_TYPE(DATA_TYPE, 16)                                                                                                            \
+        src0 = (VEC_DATA_TYPE(DATA_TYPE, 16))(                                                                                                  \
+                PTR_TO_VALUE((row_ptr) + 0 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 1 * src_stride_y, DATA_TYPE),                  \
+                PTR_TO_VALUE((row_ptr) + 2 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 3 * src_stride_y, DATA_TYPE),                  \
+                PTR_TO_VALUE((row_ptr) + 4 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 5 * src_stride_y, DATA_TYPE),                  \
+                PTR_TO_VALUE((row_ptr) + 6 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 7 * src_stride_y, DATA_TYPE),                  \
+                PTR_TO_VALUE((row_ptr) + 8 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 9 * src_stride_y, DATA_TYPE),                  \
+                PTR_TO_VALUE((row_ptr) + 10 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 11 * src_stride_y, DATA_TYPE),                \
+                PTR_TO_VALUE((row_ptr) + 12 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 13 * src_stride_y, DATA_TYPE),                \
+                PTR_TO_VALUE((row_ptr) + 14 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 15 * src_stride_y, DATA_TYPE));               \
+        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                                                             \
+        src1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(                                                                                                   \
+                PTR_TO_VALUE((row_ptr) + 16 * src_stride_y, DATA_TYPE), PTR_TO_VALUE((row_ptr) + 17 * src_stride_y, DATA_TYPE),                \
+                PTR_TO_VALUE((row_ptr) + 18 * src_stride_y, DATA_TYPE));                                                                        \
+        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                                                             \
+        weights_values0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(                                                                                        \
+                          PTR_TO_VALUE((weights_ptr) + 0 * weights_stride_y, DATA_TYPE), PTR_TO_VALUE((weights_ptr) + 1 * weights_stride_y, DATA_TYPE), \
+                          PTR_TO_VALUE((weights_ptr) + 2 * weights_stride_y, DATA_TYPE), PTR_TO_VALUE((weights_ptr) + 3 * weights_stride_y, DATA_TYPE)); \
+        DATA_TYPE weights_value1 = PTR_TO_VALUE((weights_ptr) + 4 * weights_stride_y, DATA_TYPE);                                               \
+        acc += src0.s02468ACE * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0;                                                                \
+        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1357, src0.s9BDF) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1;                         \
+        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s2468, src0.sACE, src1.s0) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2;                 \
+                                                                                                                                                \
+        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s3579, src0.sBDF, src1.s1) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s3;                 \
+        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s468a, src0.sCE, src1.s02) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_value1;                     \
+    })
+
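+/** This kernel performs a direct convolution to convolve the low three dimensions in a tensor with the NHWC data layout (built only with -DDATA_LAYOUT_NHWC; also requires -DSTRIDE_X, -DSTRIDE_Y, -DPAD_TOP and, when PAD_TOP is non-zero, -DDST_HEIGHT)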
+ *
+ * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
+ * @note The third dimensions of the weights tensors must be passed at compile time using -DWEIGHTS_DEPTH
+ * @note If biases are used then -DHAS_BIAS has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16/F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] weights_ptr Pointer to the weights tensor. Supported data types: same as @p src_ptr
+ * @param[in] weights_stride_x Stride of the weights tensor in X dimension (in bytes)
+ * @param[in] weights_step_x weights_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] weights_stride_y Stride of the weights tensor in Y dimension (in bytes)
+ * @param[in] weights_step_y weights_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] weights_stride_z Stride of the weights tensor in Z dimension (in bytes)
+ * @param[in] weights_step_z weights_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
+ * @param[in] biases_ptr Pointer to the biases tensor. Same as @p src_ptr
+ * @param[in] biases_stride_x Stride of the biases tensor in X dimension (in bytes)
+ * @param[in] biases_step_x biases_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] biases_offset_first_element_in_bytes The offset of the first element in the biases tensor
+ * @param[in] weights_stride_w Stride of the weights tensor in the 4th dimension
+ */
+__kernel void direct_convolution5x5_nhwc(
+    TENSOR3D_DECLARATION(src),
+    TENSOR3D_DECLARATION(dst),
+    TENSOR3D_DECLARATION(weights),
+#ifdef HAS_BIAS
+    VECTOR_DECLARATION(biases),
+#endif /* defined(HAS_BIAS) */
+    unsigned int weights_stride_w)
+{
+    Image src = CONVERT_TO_IMAGE_STRUCT(src);
+    Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
+    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
+    // Accumulator for 8 outputs spaced dst_stride_y apart, all computed with the kernel selected by id0.
+    VEC_DATA_TYPE(DATA_TYPE, 8)
+    values0 = 0;
+    // Work-item ids: id0 selects the kernel (weights_stride_w offset, bias index), id2 the first input row along Z; id1 is not used below.
+    const int id0 = get_global_id(0);
+    const int id1 = get_global_id(1);
+    const int id2 = get_global_id(2);
+    // Source start: undo the X offset (presumably added by CONVERT_TO_IMAGE_STRUCT) so reads begin at X = 0, then step along Z to input row id2 * STRIDE_Y - PAD_TOP.
+    __global uchar *weights_addr = (__global uchar *)tensor3D_offset(&weights, 0, 0, 0);
+    __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0) - src_stride_x * id0 + ((id2 * STRIDE_Y) - PAD_TOP) * (int)src_stride_z;
+    // Select the weights of kernel id0.
+    weights_addr += id0 * weights_stride_w;
+    const int coordy = id2 - PAD_TOP; // NOTE(review): ignores STRIDE_Y, unlike src_addr; with the DST_HEIGHT test below it flags the last output row — confirm that skip is right for STRIDE_Y == 2
+    // NOTE(review): each PAD_TOP branch skips exactly one kernel row, so a top/bottom padding > 1 presumably reads rows outside the input — confirm the host rejects it.
+    for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d) // one step along X (channels in NHWC) per iteration
+    {
+#if(PAD_TOP)
+        if(coordy < 0) // special case: the first kernel row lies in the top padding (Z = -1 does not exist)
+        {
+            // skip the first kernel row and accumulate the four following ones
+            CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z));
+            CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z));
+            CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z));
+            CONVOLUTION1x5_NHWC(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z));
+        }
+        else if(coordy == (DST_HEIGHT - PAD_TOP - 1))
+        {
+            // special case for the last output row (id2 == DST_HEIGHT - 1): the fifth kernel row is presumably below the input
+            // and the Z axis has no padding, so only the first four kernel rows are accumulated.
+            CONVOLUTION1x5_NHWC(values0, src_addr, weights_addr);
+            CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z));
+            CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z));
+            CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z));
+        }
+        else
+        {
+            CONVOLUTION1x5_NHWC(values0, src_addr, weights_addr);
+            CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z));
+            CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z));
+            CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z));
+            CONVOLUTION1x5_NHWC(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z));
+        }
+#else //PAD_TOP > 0
+        CONVOLUTION1x5_NHWC(values0, src_addr, weights_addr);
+        CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z));
+        CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z));
+        CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z));
+        CONVOLUTION1x5_NHWC(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z));
+#endif // PAD_TOP > 0
+        // Advance one step along X for both the input and the weights.
+        src_addr += src_stride_x;
+        weights_addr += weights_stride_x;
+    }
+    // Optional bias of kernel id0, broadcast to all 8 outputs.
+#ifdef HAS_BIAS
+    Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
+    values0 += (VEC_DATA_TYPE(DATA_TYPE, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, id0)));
+#endif /* defined(HAS_BIAS) */
+    // The 8 outputs are dst_stride_y apart, so they are stored one by one instead of with vstore8.
+    *((__global DATA_TYPE *)(dst.ptr + 0 * dst_stride_y)) = values0.s0;
+    *((__global DATA_TYPE *)(dst.ptr + 1 * dst_stride_y)) = values0.s1;
+    *((__global DATA_TYPE *)(dst.ptr + 2 * dst_stride_y)) = values0.s2;
+    *((__global DATA_TYPE *)(dst.ptr + 3 * dst_stride_y)) = values0.s3;
+    *((__global DATA_TYPE *)(dst.ptr + 4 * dst_stride_y)) = values0.s4;
+    *((__global DATA_TYPE *)(dst.ptr + 5 * dst_stride_y)) = values0.s5;
+    *((__global DATA_TYPE *)(dst.ptr + 6 * dst_stride_y)) = values0.s6;
+    *((__global DATA_TYPE *)(dst.ptr + 7 * dst_stride_y)) = values0.s7;
+}
+
+#endif // defined(DATA_LAYOUT_NHWC)
+
/** This kernel performs a direct convolution to convolve the low three dimensions.
*
* @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
@@ -119,7 +303,7 @@
Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
VEC_DATA_TYPE(DATA_TYPE, 8)
- pixels0 = 0;
+ values0 = 0;
__global uchar *weights_addr = (__global uchar *)tensor3D_offset(&weights, 0, 0, 0);
__global uchar *src_addr = (__global uchar *)offset(&src, 0, 0);
@@ -129,11 +313,11 @@
for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d)
{
- CONVOLUTION1x5(pixels0, (__global DATA_TYPE *)src_addr, (__global DATA_TYPE *)weights_addr);
- CONVOLUTION1x5(pixels0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 1 * weights_stride_y));
- CONVOLUTION1x5(pixels0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 2 * weights_stride_y));
- CONVOLUTION1x5(pixels0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 3 * weights_stride_y));
- CONVOLUTION1x5(pixels0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 4 * weights_stride_y));
+ CONVOLUTION1x5(values0, (__global DATA_TYPE *)src_addr, (__global DATA_TYPE *)weights_addr);
+ CONVOLUTION1x5(values0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 1 * weights_stride_y));
+ CONVOLUTION1x5(values0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 2 * weights_stride_y));
+ CONVOLUTION1x5(values0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 3 * weights_stride_y));
+ CONVOLUTION1x5(values0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 4 * weights_stride_y));
src_addr += src_stride_z;
weights_addr += weights_stride_z;
@@ -142,10 +326,10 @@
#ifdef HAS_BIAS
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
- pixels0 += (VEC_DATA_TYPE(DATA_TYPE, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, kernel_index)));
+ values0 += (VEC_DATA_TYPE(DATA_TYPE, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, kernel_index)));
#endif /* defined(HAS_BIAS) */
- vstore8(pixels0, 0, (__global DATA_TYPE *)dst.ptr);
+ vstore8(values0, 0, (__global DATA_TYPE *)dst.ptr);
}
#endif // defined(DATA_TYPE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
@@ -226,8 +410,8 @@
Image src = CONVERT_TO_IMAGE_STRUCT(src);
Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
- float4 pixels0 = 0.0f;
- float4 pixels1 = 0.0f;
+ float4 values0 = 0.0f;
+ float4 values1 = 0.0f;
__global uchar *weights_addr = (__global uchar *)(weights_ptr + weights_offset_first_element_in_bytes + kernel_index * weights_stride_w);
__global uchar *src_addr = (__global uchar *)offset(&src, 0, 0);
@@ -247,14 +431,14 @@
src0 = vload8(0, (__global float *)(src_addr + 0 * src_stride_y));
// Accumulate
- CONVOLUTION1x5_BIFROST(pixels0, src0, weights_row00, weights_row01);
+ CONVOLUTION1x5_BIFROST(values0, src0, weights_row00, weights_row01);
// Load values from row1 of input tensor
src0 = vload8(0, (__global float *)(src_addr + 1 * src_stride_y));
// Accumulate
- CONVOLUTION1x5_BIFROST(pixels0, src0, weights_row10, weights_row11);
- CONVOLUTION1x5_BIFROST(pixels1, src0, weights_row00, weights_row01);
+ CONVOLUTION1x5_BIFROST(values0, src0, weights_row10, weights_row11);
+ CONVOLUTION1x5_BIFROST(values1, src0, weights_row00, weights_row01);
// Load values from row2 of input tensor
src0 = vload8(0, (__global float *)(src_addr + 2 * src_stride_y));
@@ -264,8 +448,8 @@
weights_row01 = *((__global float *)(weights_addr + 2 * weights_stride_y) + 4);
// Accumulate
- CONVOLUTION1x5_BIFROST(pixels0, src0, weights_row00, weights_row01);
- CONVOLUTION1x5_BIFROST(pixels1, src0, weights_row10, weights_row11);
+ CONVOLUTION1x5_BIFROST(values0, src0, weights_row00, weights_row01);
+ CONVOLUTION1x5_BIFROST(values1, src0, weights_row10, weights_row11);
// Load values from row3 of input tensor
src0 = vload8(0, (__global float *)(src_addr + 3 * src_stride_y));
@@ -275,8 +459,8 @@
weights_row11 = *((__global float *)(weights_addr + 3 * weights_stride_y) + 4);
// Accumulate
- CONVOLUTION1x5_BIFROST(pixels0, src0, weights_row10, weights_row11);
- CONVOLUTION1x5_BIFROST(pixels1, src0, weights_row00, weights_row01);
+ CONVOLUTION1x5_BIFROST(values0, src0, weights_row10, weights_row11);
+ CONVOLUTION1x5_BIFROST(values1, src0, weights_row00, weights_row01);
// Load values from row4 of input tensor
src0 = vload8(0, (__global float *)(src_addr + 4 * src_stride_y));
@@ -285,14 +469,14 @@
weights_row00 = vload4(0, (__global float *)(weights_addr + 4 * weights_stride_y));
weights_row01 = *((__global float *)(weights_addr + 4 * weights_stride_y) + 4);
- CONVOLUTION1x5_BIFROST(pixels0, src0, weights_row00, weights_row01);
- CONVOLUTION1x5_BIFROST(pixels1, src0, weights_row10, weights_row11);
+ CONVOLUTION1x5_BIFROST(values0, src0, weights_row00, weights_row01);
+ CONVOLUTION1x5_BIFROST(values1, src0, weights_row10, weights_row11);
// Load values from row5 of input tensor
src0 = vload8(0, (__global float *)(src_addr + 5 * src_stride_y));
// Accumulate
- CONVOLUTION1x5_BIFROST(pixels1, src0, weights_row00, weights_row01);
+ CONVOLUTION1x5_BIFROST(values1, src0, weights_row00, weights_row01);
src_addr += src_stride_z;
weights_addr += weights_stride_z;
@@ -303,11 +487,11 @@
float4 bias = (float4) * ((__global float *)(vector_offset(&biases, kernel_index)));
- pixels0 += bias;
- pixels1 += bias;
+ values0 += bias;
+ values1 += bias;
#endif /* defined(HAS_BIAS) */
- vstore4(pixels0, 0, (__global float *)(dst.ptr + 0 * dst_stride_y));
- vstore4(pixels1, 0, (__global float *)(dst.ptr + 1 * dst_stride_y));
+ vstore4(values0, 0, (__global float *)(dst.ptr + 0 * dst_stride_y));
+ vstore4(values1, 0, (__global float *)(dst.ptr + 1 * dst_stride_y));
}
#endif // defined(WEIGHTS_DEPTH)
diff --git a/src/core/CL/cl_kernels/direct_convolution_1x1_3x3_5x5_quantized.cl b/src/core/CL/cl_kernels/direct_convolution_1x1_3x3_5x5_quantized.cl
index b58dc7a..83da767 100644
--- a/src/core/CL/cl_kernels/direct_convolution_1x1_3x3_5x5_quantized.cl
+++ b/src/core/CL/cl_kernels/direct_convolution_1x1_3x3_5x5_quantized.cl
@@ -248,6 +248,12 @@
}
#endif // defined(DATA_TYPE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
+#if defined(VEC_SIZE)
+
+#define VEC_INT VEC_DATA_TYPE(int, VEC_SIZE)
+#define CONVERT_SAT_UCHAR_STR(x, size) (convert_uchar##size##_sat((x)))
+#define CONVERT_SAT_UCHAR(x, size) CONVERT_SAT_UCHAR_STR(x, size)
+
/** This function computes the output stage of a depthwise convolution.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: QASYMM8
@@ -274,7 +280,6 @@
* @param[in] output_multiplier Output scale multiplier
* @param[in] output_shift Output scale divisor exponent
*/
-
__kernel void output_stage_quantized(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
@@ -292,17 +297,29 @@
#endif //defined(HAS_BIAS)
// Load input
- int16 vals = vload16(0, (__global int *)(src.ptr));
+ VEC_INT vals = VLOAD(VEC_SIZE)(0, (__global int *)(src.ptr));
#if defined(HAS_BIAS)
// Load and add bias
+#if defined(NCHW)
int bias_value = *((__global int *)(vector_offset(&bias, get_global_id(2))));
- vals += (int16)(bias_value);
+#else // defined(NCHW)
+ VEC_INT bias_value = VLOAD(VEC_SIZE)(0, ((__global int *)(vector_offset(&bias, get_global_id(0) * VEC_SIZE))));
+#endif // defined(NCHW)
+
+ vals += (VEC_INT)(bias_value);
#endif //defined(HAS_BIAS)
- vals = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(vals, output_multiplier, output_shift, 16);
+ vals = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(vals, output_multiplier, output_shift, VEC_SIZE);
vals = vals + output_offset;
// Store result in dst
- vstore16(convert_uchar16_sat(vals), 0, (__global uchar *)dst.ptr);
+ VSTORE(VEC_SIZE)
+ (CONVERT_SAT_UCHAR(vals, VEC_SIZE), 0, (__global uchar *)dst.ptr);
}
+
+#undef VEC_INT
+#undef CONVERT_SAT_UCHAR_STR
+#undef CONVERT_SAT_UCHAR
+
+#endif // defined(VEC_SIZE)
diff --git a/src/core/CL/cl_kernels/fill_border.cl b/src/core/CL/cl_kernels/fill_border.cl
index 33a9495..9d6a2b8 100644
--- a/src/core/CL/cl_kernels/fill_border.cl
+++ b/src/core/CL/cl_kernels/fill_border.cl
@@ -23,10 +23,6 @@
*/
#include "helpers.h"
-#if defined(FIXED_POINT_POSITION)
-#include "fixed_point.h"
-#endif /* FIXED_POINT_POSITION */
-
/** Fill N pixel of the padding edge of a single channel image by replicating the closest valid pixel.
*
* @attention The DATA_TYPE needs to be passed at the compile time.
diff --git a/src/core/CL/cl_kernels/fixed_point.h b/src/core/CL/cl_kernels/fixed_point.h
deleted file mode 100644
index 46fa645..0000000
--- a/src/core/CL/cl_kernels/fixed_point.h
+++ /dev/null
@@ -1,518 +0,0 @@
-/*
- * Copyright (c) 2017-2018 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef ARM_COMPUTE_FIXED_POINT_H
-#define ARM_COMPUTE_FIXED_POINT_H
-
-#define TYPE_ALIAS(type, alias) \
- typedef type alias; \
- typedef type alias##x##1; \
- typedef type##2 alias##x##2; \
- typedef type##3 alias##x##3; \
- typedef type##4 alias##x##4; \
- typedef type##8 alias##x##8; \
- typedef type##16 alias##x##16;
-
-TYPE_ALIAS(char, qs8)
-TYPE_ALIAS(short, qs16)
-TYPE_ALIAS(int, qs32)
-
-#define qs8_MIN ((char)CHAR_MIN)
-#define qs8_MAX ((char)CHAR_MAX)
-#define qs16_MIN ((short)SHRT_MIN)
-#define qs16_MAX ((short)SHRT_MAX)
-#define qs32_MIN ((int)INT_MIN)
-#define qs32_MAX ((int)INT_MAX)
-
-#define qu8_MIN ((uchar)0)
-#define qu8_MAX ((uchar)UCHAR_MAX)
-#define qu16_MIN ((ushort)0)
-#define qu16_MAX ((ushort)USHRT_MAX)
-#define qu32_MIN ((uint)0)
-#define qu32_MAX ((uint)UINT_MAX)
-
-#define qs8_TYPE char
-#define qs8x1_TYPE char
-#define qs8x2_TYPE char2
-#define qs8x3_TYPE char3
-#define qs8x4_TYPE char4
-#define qs8x8_TYPE char8
-#define qs8x16_TYPE char16
-
-#define qs16_TYPE short
-#define qs16x1_TYPE short
-#define qs16x2_TYPE short2
-#define qs16x3_TYPE short3
-#define qs16x4_TYPE short4
-#define qs16x8_TYPE short8
-#define qs16x16_TYPE short16
-
-#define qs32_TYPE int
-#define qs32x1_TYPE int
-#define qs32x2_TYPE int2
-#define qs32x3_TYPE int3
-#define qs32x4_TYPE int4
-#define qs32x8_TYPE int8
-#define qs32x16_TYPE int16
-
-/* All internal constants are represented in the maximum supported fixed point format (QS16),
- * thus we define an additional shift parameter required to convert the constant
- * from the maximum supported format to the require one.
- */
-#define qs8_SHIFT 8
-#define qs16_SHIFT 0
-
-#undef VEC_DATA_TYPE_STR
-#undef VEC_DATA_TYPE
-#undef CONVERT_STR
-#undef CONVERT
-#undef CONVERT_SAT_STR
-#undef CONVERT_SAT
-
-#define VEC_DATA_TYPE_STR(type, size) type##x##size
-#define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size)
-
-#define CONVERT_STR3(x, type, rtype) (convert_##rtype((x)))
-#define CONVERT_STR2(x, type, rtype) CONVERT_STR3(x, type, rtype)
-#define CONVERT_STR(x, type) CONVERT_STR2(x, type, type##_TYPE)
-#define CONVERT(x, type) CONVERT_STR(x, type)
-
-#define CONVERT_SAT_STR3(x, type, rtype) (convert_##rtype##_sat((x)))
-#define CONVERT_SAT_STR2(x, type, rtype) CONVERT_SAT_STR3(x, type, rtype)
-#define CONVERT_SAT_STR(x, type) CONVERT_SAT_STR2(x, type, type##_TYPE)
-#define CONVERT_SAT(x, type) CONVERT_SAT_STR(x, type)
-
-/** Computes saturating absolute value of fixed point vector.
- *
- * @param[in] type the actual data type.
- *
- * @return The result of the fixed point absolute value.
- */
-#define ABSQ_SAT_IMPL(type) \
- inline type abs_##type##_sat(type VopA) \
- { \
- return CONVERT_SAT(abs(VopA), type); \
- }
-
-ABSQ_SAT_IMPL(qs8x16)
-ABSQ_SAT_IMPL(qs16x8)
-
-#define ABS_SAT_OP_EXPAND_STR(a, type, size) abs_##type##x##size##_sat((a))
-#define ABS_SAT_OP_EXPAND(a, type, size) ABS_SAT_OP_EXPAND_STR(a, type, size)
-
-/** Computes max of fixed point types.
- *
- * @param[in] type the actual data type.
- *
- * @return The result of the fixed point maximum.
- */
-#define MAXQ_IMPL(type) \
- inline type max_##type(type VopA, type VopB) \
- { \
- return max(VopA, VopB); \
- }
-
-MAXQ_IMPL(qs8x1)
-MAXQ_IMPL(qs8x2)
-MAXQ_IMPL(qs8x4)
-MAXQ_IMPL(qs8x8)
-MAXQ_IMPL(qs8x16)
-MAXQ_IMPL(qs16x1)
-MAXQ_IMPL(qs16x2)
-MAXQ_IMPL(qs16x4)
-MAXQ_IMPL(qs16x8)
-MAXQ_IMPL(qs16x16)
-
-#define MAX_OP_EXPAND_STR(a, b, type, size) max_##type##x##size((a), (b))
-#define MAX_OP_EXPAND(a, b, type, size) MAX_OP_EXPAND_STR(a, b, type, size)
-
-/** Computes saturated addition of fixed point types.
- *
- * @param[in] type the actual data type.
- *
- * @return The result of the fixed point addition. The result is saturated in case of overflow
- */
-#define ADDQ_SAT_IMPL(type) \
- inline type add_sat_##type(type VopA, type VopB) \
- { \
- return add_sat(VopA, VopB); \
- }
-
-ADDQ_SAT_IMPL(qs8x1)
-ADDQ_SAT_IMPL(qs8x2)
-ADDQ_SAT_IMPL(qs8x4)
-ADDQ_SAT_IMPL(qs8x8)
-ADDQ_SAT_IMPL(qs8x16)
-ADDQ_SAT_IMPL(qs16x1)
-ADDQ_SAT_IMPL(qs16x2)
-ADDQ_SAT_IMPL(qs16x4)
-ADDQ_SAT_IMPL(qs16x8)
-ADDQ_SAT_IMPL(qs16x16)
-ADDQ_SAT_IMPL(qs32x1)
-ADDQ_SAT_IMPL(qs32x2)
-ADDQ_SAT_IMPL(qs32x4)
-ADDQ_SAT_IMPL(qs32x8)
-ADDQ_SAT_IMPL(qs32x16)
-
-#define ADD_SAT_OP_EXPAND_STR(a, b, type, size) add_sat_##type##x##size((a), (b))
-#define ADD_SAT_OP_EXPAND(a, b, type, size) ADD_SAT_OP_EXPAND_STR(a, b, type, size)
-
-/** Computes saturated subtraction of fixed point types.
- *
- * @param[in] type the actual data type.
- *
- * @return The result of the fixed point subtraction. The result is saturated in case of overflow
- */
-#define SUBQ_SAT_IMPL(type) \
- inline type sub_sat_##type(type VopA, type VopB) \
- { \
- return sub_sat(VopA, VopB); \
- }
-
-SUBQ_SAT_IMPL(qs8x1)
-SUBQ_SAT_IMPL(qs8x2)
-SUBQ_SAT_IMPL(qs8x4)
-SUBQ_SAT_IMPL(qs8x8)
-SUBQ_SAT_IMPL(qs8x16)
-SUBQ_SAT_IMPL(qs16x1)
-SUBQ_SAT_IMPL(qs16x2)
-SUBQ_SAT_IMPL(qs16x4)
-SUBQ_SAT_IMPL(qs16x8)
-SUBQ_SAT_IMPL(qs16x16)
-
-#define SUB_SAT_OP_EXPAND_STR(a, b, type, size) sub_sat_##type##x##size((a), (b))
-#define SUB_SAT_OP_EXPAND(a, b, type, size) SUB_SAT_OP_EXPAND_STR(a, b, type, size)
-
-/* Multiply of two fixed point numbers
- *
- * @param[in] type the actual data type.
- * @param[in] itype the intermediate data type.
- *
- * @return The result of the fixed point multiplication.
- */
-#define MULQ_IMPL(type, itype) \
- inline type mul_##type(type VopA, type VopB, int fixed_point_position) \
- { \
- itype round_val = (itype)(1 << (fixed_point_position - 1)); \
- itype res = CONVERT((VopA), itype) * CONVERT((VopB), itype) + round_val; \
- return CONVERT((res >> (itype)fixed_point_position), type); \
- }
-
-MULQ_IMPL(qs8x8, qs16x8)
-MULQ_IMPL(qs16x8, qs32x8)
-MULQ_IMPL(qs8x16, qs16x16)
-MULQ_IMPL(qs16x16, qs32x16)
-
-#define MUL_OP_EXPAND_STR(a, b, type, size, position) mul_##type##x##size((a), (b), (position))
-#define MUL_OP_EXPAND(a, b, type, size, position) MUL_OP_EXPAND_STR(a, b, type, size, position)
-
-/* Saturate multiply of two fixed point numbers
- *
- * @param[in] type the actual data type.
- * @param[in] itype the intermediate data type.
- *
- * @return The result of the fixed point multiplication. The result is saturated in case of overflow
- */
-#define MULQ_SAT_IMPL(type, itype) \
- inline type mul_sat_##type(type VopA, type VopB, int fixed_point_position) \
- { \
- itype round_val = (itype)(1 << (fixed_point_position - 1)); \
- itype res = mad_sat(CONVERT((VopA), itype), CONVERT((VopB), itype), round_val); \
- return CONVERT_SAT((res >> (itype)fixed_point_position), type); \
- }
-
-MULQ_SAT_IMPL(qs8x1, qs16x1)
-MULQ_SAT_IMPL(qs8x2, qs16x2)
-MULQ_SAT_IMPL(qs8x3, qs16x3)
-MULQ_SAT_IMPL(qs8x4, qs16x4)
-MULQ_SAT_IMPL(qs8x8, qs16x8)
-MULQ_SAT_IMPL(qs8x16, qs16x16)
-MULQ_SAT_IMPL(qs16x1, qs32x1)
-MULQ_SAT_IMPL(qs16x2, qs32x2)
-MULQ_SAT_IMPL(qs16x3, qs32x3)
-MULQ_SAT_IMPL(qs16x4, qs32x4)
-MULQ_SAT_IMPL(qs16x8, qs32x8)
-MULQ_SAT_IMPL(qs16x16, qs32x16)
-
-#define MUL_SAT_OP_EXPAND_STR(a, b, type, size, position) mul_sat_##type##x##size((a), (b), (position))
-#define MUL_SAT_OP_EXPAND(a, b, type, size, position) MUL_SAT_OP_EXPAND_STR(a, b, type, size, position)
-
-/** Saturate multiply-accumulate
- *
- * @param[in] type the actual data type.
- * @param[in] itype the intermediate data type.
- *
- * @return The result of the fixed point multiply-accumulate. The result is saturated in case of overflow
- */
-#define MLAQ_SAT_IMPL(type, itype) \
- type mla_sat_##type(type VopA, type VopB, type VopC, int fixed_point_position) \
- { \
- itype res = mad_sat(CONVERT(VopB, itype), CONVERT(VopC, itype), (itype)(1 << (fixed_point_position - 1))); \
- return add_sat(VopA, CONVERT_SAT(res >> (itype)fixed_point_position, type)); \
- }
-
-MLAQ_SAT_IMPL(qs8x8, qs16x8)
-MLAQ_SAT_IMPL(qs8x16, qs16x16)
-MLAQ_SAT_IMPL(qs16x8, qs32x8)
-
-#define MLA_SAT_OP_EXPAND_STR(a, b, c, type, size, position) mla_sat_##type##x##size((a), (b), (c), (position))
-#define MLA_SAT_OP_EXPAND(a, b, c, type, size, position) MLA_SAT_OP_EXPAND_STR(a, b, c, type, size, position)
-
-/** Saturate multiply-accumulate long
- *
- * @param[in] type the actual data type.
- * @param[in] itype the intermediate data type.
- *
- * @return The result of the fixed point multiply-accumulate long. The result is saturated in case of overflow
- */
-#define MLALQ_SAT_IMPL(type, itype) \
- itype mlal_sat_##type(itype VopA, type VopB, type VopC, int fixed_point_position) \
- { \
- itype res = mad_sat(CONVERT(VopB, itype), CONVERT(VopC, itype), (itype)(1 << (fixed_point_position - 1))); \
- return add_sat(VopA, res >> (itype)fixed_point_position); \
- }
-
-MLALQ_SAT_IMPL(qs8x8, qs16x8)
-MLALQ_SAT_IMPL(qs16x8, qs32x8)
-
-#define MLAL_SAT_OP_EXPAND_STR(a, b, c, type, size, position) mlal_sat_##type##x##size((a), (b), (c), (position))
-#define MLAL_SAT_OP_EXPAND(a, b, c, type, size, position) MLAL_SAT_OP_EXPAND_STR(a, b, c, type, size, position)
-
-/** Saturate division of two fixed point vectors
- *
- * @param[in] stype the actual scalar data type.
- * @param[in] type the actual data type.
- * @param[in] itype the intermediate data type.
- *
- * @return The result of the fixed point division. The result is saturated in case of overflow
- */
-#define DIVQ_SAT_IMPL(stype, type, itype) \
- inline type div_sat_##type(type VopA, type VopB, int fixed_point_position) \
- { \
- itype conv_a = CONVERT((VopA), itype); \
- itype denominator = CONVERT((VopB), itype); \
- itype numerator = conv_a << (itype)(fixed_point_position); \
- itype res = select((itype)(numerator / denominator), select((itype)stype##_MAX, (itype)stype##_MIN, (itype)(conv_a < (itype)0)), (itype)(denominator == (itype)0)); \
- return CONVERT_SAT((res), type); \
- }
-
-DIVQ_SAT_IMPL(qs8, qs8x16, qs16x16)
-DIVQ_SAT_IMPL(qs16, qs16x8, qs32x8)
-DIVQ_SAT_IMPL(qs16, qs16x16, qs32x16)
-DIVQ_SAT_IMPL(qs8, qs8, qs16)
-DIVQ_SAT_IMPL(qs16, qs16, qs32)
-
-#define DIV_SAT_OP_EXPAND_STR(a, b, type, position) div_sat_##type((a), (b), (position))
-#define DIV_SAT_OP_EXPAND(a, b, type, position) DIV_SAT_OP_EXPAND_STR(a, b, type, position)
-
-#define DIV_SAT_OP_VEC_EXPAND_STR(a, b, type, size, position) div_sat_##type##x##size((a), (b), (position))
-#define DIV_SAT_OP_VEC_EXPAND(a, b, type, size, position) DIV_SAT_OP_VEC_EXPAND_STR(a, b, type, size, position)
-
-/** Saturate exponential of a fixed point vector
- *
- * @note Implemented approach uses taylor polynomial to approximate the exponential function.
- *
- * @param[in] stype the actual scalar data type.
- * @param[in] type the actual data type.
- * @param[in] size the number of the calculated elements.
- *
- * @return The result of the fixed point exponential. The result is saturated in case of overflow
- */
-#define EXPQ_IMPL(stype, type, size) \
- inline type exp_sat_##type(type VopA, int fixed_point_position) \
- { \
- type const_one = (type)(1 << (fixed_point_position)); \
- type ln2 = (type)((((0x58B9 >> (14 - fixed_point_position))) + 1) >> 1); \
- type inv_ln2 = (type)((((0x38AA >> (14 - fixed_point_position)) + 1) >> 1)) | const_one; \
- type A = (type)(((0x7FBA >> (14 - fixed_point_position)) + 1) >> 1); \
- type B = (type)(((0x3FE9 >> (14 - fixed_point_position)) + 1) >> 1); \
- type C = (type)(((0x1693 >> (14 - fixed_point_position)) + 1) >> 1); \
- type D = (type)(((0x0592 >> (14 - fixed_point_position)) + 1) >> 1); \
- type m = MUL_SAT_OP_EXPAND(VopA, inv_ln2, stype, size, fixed_point_position); \
- type dec_m = m >> (type)fixed_point_position; \
- type alpha = MUL_SAT_OP_EXPAND(dec_m << (type)fixed_point_position, ln2, stype, size, fixed_point_position); \
- alpha = CONVERT(abs_diff(VopA, alpha), type); \
- type sum = add_sat(MUL_SAT_OP_EXPAND(alpha, D, stype, size, fixed_point_position), C); \
- sum = add_sat(MUL_SAT_OP_EXPAND(alpha, sum, stype, size, fixed_point_position), B); \
- sum = add_sat(MUL_SAT_OP_EXPAND(alpha, sum, stype, size, fixed_point_position), A); \
- sum = add_sat(MUL_SAT_OP_EXPAND(alpha, sum, stype, size, fixed_point_position), const_one); \
- return select((type)stype##_MAX, select(sum << dec_m, sum >> -dec_m, dec_m < (type)0), clz(sum) > dec_m); /* Saturate result if needed */ \
- }
-
-EXPQ_IMPL(qs8, qs8x2, 2)
-EXPQ_IMPL(qs8, qs8x4, 4)
-EXPQ_IMPL(qs8, qs8x8, 8)
-EXPQ_IMPL(qs8, qs8x16, 16)
-EXPQ_IMPL(qs16, qs16x2, 2)
-EXPQ_IMPL(qs16, qs16x4, 4)
-EXPQ_IMPL(qs16, qs16x8, 8)
-EXPQ_IMPL(qs16, qs16x16, 16)
-
-#define EXP_OP_EXPAND_STR(a, type, size, position) exp_sat_##type##x##size((a), (position))
-#define EXP_OP_EXPAND(a, type, size, position) EXP_OP_EXPAND_STR(a, type, size, position)
-
-/** Saturate logarithm of a fixed point vector
- *
- * @note Implemented approach uses taylor polynomial to approximate the logarithm function.
- *
- * @param[in] stype the actual scalar data type.
- * @param[in] type the actual data type.
- * @param[in] size the number of the calculated elements.
- *
- * @return The result of the fixed point logarithm. The result is saturated in case of overflow
- */
-#define LOGQ_IMPL(stype, type, size) \
- inline type log_sat_##type(type VopA, int fixed_point_position) \
- { \
- type const_one = (type)(1 << (fixed_point_position)); \
- type ln2 = (type)(0x58B9 >> (15 - fixed_point_position)); /* 1.4384189 */ \
- type A = (type)(0x5C0F >> (14 - fixed_point_position)); /* 1.4384189 */ \
- type B = -(type)(0x56AE >> (15 - fixed_point_position)); /* -0.6771900 */ \
- type C = (type)(0x2933 >> (15 - fixed_point_position)); /* 0.3218538 */ \
- type D = -(type)(0x0AA7 >> (15 - fixed_point_position)); /* -0.0832229 */ \
- type inter_a = select(VopA, DIV_SAT_OP_VEC_EXPAND(const_one, VopA, stype, size, fixed_point_position), VopA < const_one); \
- type shift_val = (type)(15 - stype##_SHIFT) - clz(inter_a >> (type)fixed_point_position); \
- inter_a = inter_a >> shift_val; \
- inter_a = sub_sat(inter_a, const_one); \
- type sum = add_sat(MUL_SAT_OP_EXPAND(inter_a, D, stype, size, fixed_point_position), C); \
- sum = add_sat(MUL_SAT_OP_EXPAND(inter_a, sum, stype, size, fixed_point_position), B); \
- sum = add_sat(MUL_SAT_OP_EXPAND(inter_a, sum, stype, size, fixed_point_position), A); \
- sum = MUL_SAT_OP_EXPAND(inter_a, sum, stype, size, fixed_point_position); \
- sum = MUL_SAT_OP_EXPAND(add_sat(sum, shift_val << (type)fixed_point_position), ln2, stype, size, fixed_point_position); \
- return select(select(sum, -sum, VopA < const_one), (type)0, VopA < (type)0); /* Saturate result if needed */ \
- }
-
-LOGQ_IMPL(qs8, qs8x16, 16)
-LOGQ_IMPL(qs16, qs16x8, 8)
-LOGQ_IMPL(qs16, qs16x16, 16)
-
-#define LOG_OP_EXPAND_STR(a, type, size, position) log_sat_##type##x##size((a), (position))
-#define LOG_OP_EXPAND(a, type, size, position) LOG_OP_EXPAND_STR(a, type, size, position)
-
-/** Saturate inverse square root of a fixed point vector
- *
- * @note Implemented approach uses Newton's method to approximate the inverse square root function.
- *
- * @param[in] stype the actual scalar data type.
- * @param[in] type the actual data type.
- * @param[in] size the number of the calculated elements.
- *
- * @return The result of the fixed point inverse square root. The result is saturated in case of overflow
- */
-#define INVSQRTQ_IMPL(stype, type, size) \
- inline type invsqrt_sat_##type(type VopA, int fixed_point_position) \
- { \
- type const_three = (type)(3 << (fixed_point_position)); \
- type shift_value = (type)(16 - stype##_SHIFT) - (clz(VopA) + (type)fixed_point_position); \
- type temp = select((type)(VopA >> shift_value), select((type)stype##_MAX, (type)(VopA << (-shift_value)), (type)(clz(VopA) > (-shift_value))), (type)(shift_value < (type)0)); \
- type x = temp; \
- x = MUL_SAT_OP_EXPAND(x, sub_sat(const_three, MUL_SAT_OP_EXPAND(MUL_SAT_OP_EXPAND(x, x, stype, size, fixed_point_position), temp, stype, size, fixed_point_position)), stype, size, fixed_point_position) >> 1; \
- x = MUL_SAT_OP_EXPAND(x, sub_sat(const_three, MUL_SAT_OP_EXPAND(MUL_SAT_OP_EXPAND(x, x, stype, size, fixed_point_position), temp, stype, size, fixed_point_position)), stype, size, fixed_point_position) >> 1; \
- x = MUL_SAT_OP_EXPAND(x, sub_sat(const_three, MUL_SAT_OP_EXPAND(MUL_SAT_OP_EXPAND(x, x, stype, size, fixed_point_position), temp, stype, size, fixed_point_position)), stype, size, fixed_point_position) >> 1; \
- if(sizeof((stype)(1)) > 1) /* Perform more iterations if datatype is QS16 */ \
- { \
- x = MUL_SAT_OP_EXPAND(x, sub_sat(const_three, MUL_SAT_OP_EXPAND(MUL_SAT_OP_EXPAND(x, x, stype, size, fixed_point_position), temp, stype, size, fixed_point_position)), stype, size, fixed_point_position) >> 1; \
- x = MUL_SAT_OP_EXPAND(x, sub_sat(const_three, MUL_SAT_OP_EXPAND(MUL_SAT_OP_EXPAND(x, x, stype, size, fixed_point_position), temp, stype, size, fixed_point_position)), stype, size, fixed_point_position) >> 1; \
- } \
- type shift_value2 = select(shift_value >> 1, (-shift_value) >> 1, shift_value < (type)0); \
- return select((type)(x >> shift_value2), select((type)stype##_MAX, (type)(x << shift_value2), (type)(clz(x) > shift_value2)), (type)(shift_value < (type)0)); /* Saturate result if needed */ \
- }
-
-INVSQRTQ_IMPL(qs8, qs8x1, 1)
-INVSQRTQ_IMPL(qs16, qs16x1, 1)
-INVSQRTQ_IMPL(qs8, qs8x16, 16)
-INVSQRTQ_IMPL(qs16, qs16x8, 8)
-
-#define INVSQRT_OP_EXPAND_STR(a, type, size, position) invsqrt_sat_##type##x##size((a), (position))
-#define INVSQRT_OP_EXPAND(a, type, size, position) INVSQRT_OP_EXPAND_STR(a, type, size, position)
-
-/** Saturate hyperbolic tangent of a fixed point vector
- *
- * tanh(x) = (e^2x - 1)/(e^2x + 1)
- *
- * @param[in] stype the actual scalar data type.
- * @param[in] type the actual data type.
- * @param[in] size the number of the calculated elements.
- *
- * @return The result of the fixed point hyperbolic tangent. The result is saturated in case of overflow
- */
-#define TANHQ_IMPL(stype, type, size) \
- inline type tanh_sat_##type(type VopA, int fixed_point_position) \
- { \
- type const_one = (type)(1 << (fixed_point_position)); \
- type const_two = (type)(2 << (fixed_point_position)); \
- type exp2x = EXP_OP_EXPAND(MUL_SAT_OP_EXPAND(const_two, VopA, stype, size, fixed_point_position), stype, size, fixed_point_position); \
- type num = SUB_SAT_OP_EXPAND(exp2x, const_one, stype, size); \
- type den = ADD_SAT_OP_EXPAND(exp2x, const_one, stype, size); \
- return DIV_SAT_OP_VEC_EXPAND(num, den, stype, size, fixed_point_position); \
- }
-
-TANHQ_IMPL(qs8, qs8x16, 16)
-TANHQ_IMPL(qs16, qs16x8, 8)
-
-#define TANH_OP_EXPAND_STR(a, type, size, position) tanh_sat_##type##x##size((a), (position))
-#define TANH_OP_EXPAND(a, type, size, position) TANH_OP_EXPAND_STR(a, type, size, position)
-
-#define floatx16 float16
-#define float16_TYPE float16
-
-#define CONVERTQ_DOWN_IMPL(in_type, out_type) \
- inline out_type convert_##out_type##_##in_type(in_type a, int fixed_point_position) \
- { \
- return CONVERT(a * (1 << fixed_point_position) + select((in_type)-0.5f, (in_type)0.5f, isgreater(a, (in_type)0)), out_type); \
- }
-
-CONVERTQ_DOWN_IMPL(float16, qs8x16)
-CONVERTQ_DOWN_IMPL(float16, qs16x16)
-
-#define CONVERTQ_DOWN_SAT_IMPL(in_type, out_type) \
- inline out_type convert_##out_type##_##in_type##_sat(in_type a, int fixed_point_position) \
- { \
- return CONVERT_SAT(a * (1 << fixed_point_position) + select((in_type)-0.5f, (in_type)0.5f, isgreater(a, (in_type)0)), out_type); \
- }
-
-CONVERTQ_DOWN_SAT_IMPL(float16, qs8x16)
-CONVERTQ_DOWN_SAT_IMPL(float16, qs16x16)
-
-#define CONVERTQ_UP_IMPL(in_type, out_type) \
- inline out_type convert_##out_type##_##in_type(in_type a, int fixed_point_position) \
- { \
- return CONVERT(a, out_type) / (1 << fixed_point_position); \
- }
-
-CONVERTQ_UP_IMPL(qs8x16, float16)
-CONVERTQ_UP_IMPL(qs16x16, float16)
-
-#define SQCVT_SAT_IMPL(type) \
- inline type sqcvt_##type##_sat(float a, int fixed_point_position) \
- { \
- return CONVERT_SAT((a * (1 << fixed_point_position) + ((a < 0) ? -0.5f : 0.5f)), type); \
- }
-
-SQCVT_SAT_IMPL(qs8)
-SQCVT_SAT_IMPL(qs16)
-
-#define SQCVT_SAT_OP_EXPAND_STR(a, type, position) sqcvt_##type##_sat((a), (position))
-#define SQCVT_SAT_OP_EXPAND(a, type, position) SQCVT_SAT_OP_EXPAND_STR((a), type, position)
-
-#endif // ARM_COMPUTE_FIXED_POINT_H
diff --git a/src/core/CL/cl_kernels/flatten.cl b/src/core/CL/cl_kernels/flatten.cl
new file mode 100644
index 0000000..df0f9c4
--- /dev/null
+++ b/src/core/CL/cl_kernels/flatten.cl
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "helpers.h"
+
+#if defined(DATA_TYPE) && defined(SRC_WIDTH) && defined(SRC_HEIGHT)
+
+/** This OpenCL kernel flattens the first 3 dimensions of the input tensor into a 1D output vector
+ *
+ * @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=float
+ * @note The width and height of the input tensor must be passed at compile time using -DSRC_WIDTH and -DSRC_HEIGHT. e.g. -DSRC_WIDTH=24, -DSRC_HEIGHT=24
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void flatten(
+ TENSOR3D_DECLARATION(src),
+ VECTOR_DECLARATION(dst))
+{
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+
+ __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) + get_global_id(1) * (int)SRC_WIDTH + get_global_id(2) * (int)(SRC_WIDTH * SRC_HEIGHT)) * sizeof(
+ DATA_TYPE); // Element (x, y, z) maps to linear index x + y * SRC_WIDTH + z * SRC_WIDTH * SRC_HEIGHT, scaled to bytes
+
+ *((__global DATA_TYPE *)output_ptr) = *((__global DATA_TYPE *)src.ptr);
+}
+#endif // defined(DATA_TYPE) && defined(SRC_WIDTH) && defined(SRC_HEIGHT)
\ No newline at end of file
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 9ed3af8..932e0d6 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -23,10 +23,6 @@
*/
#include "helpers.h"
-#ifdef FIXED_POINT_POSITION
-#include "fixed_point.h"
-#endif // FIXED_POINT_POSITION
-
#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
#if ELEMENT_SIZE == 1
@@ -44,7 +40,7 @@
* @note The transposition width must be passed at compile time using -DTRANSPOSE_W (i.e. -DTRANSPOSE_W)
* @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
*
- * @param[in] src_ptr Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
+ * @param[in] src_ptr Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
* @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
@@ -92,8 +88,13 @@
*
* @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
* @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
+ * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
*
- * @param[in] src_ptr Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
+ * @param[in] src_ptr Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
* @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
@@ -109,9 +110,15 @@
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
*/
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst))
+ TENSOR3D_DECLARATION(dst)
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
+ )
{
// Compute source and destination addresses
uint x = get_global_id(0);
@@ -128,6 +135,45 @@
// Add offset for batched GEMM
dst_addr_in_bytes += z * dst_stride_z;
+#if defined(REINTERPRET_INPUT_AS_3D)
+ __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * 4 * sizeof(DATA_TYPE) + y * 4 * src_stride_y;
+
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated dividing M (y * 4) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(y * 4)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (cross_plane_pad * src_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src_stride_z by DEPTH_GEMM3D
+ input_ptr += z * src_stride_z * DEPTH_GEMM3D;
+
+ // Load values from Matrix A
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0));
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1));
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2));
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3));
+#else // defined(REINTERPRET_INPUT_AS_3D)
__global uchar *input_ptr = src.ptr;
// Load values from Matrix A
@@ -139,6 +185,7 @@
a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
VEC_DATA_TYPE(DATA_TYPE, 4)
val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
@@ -165,6 +212,12 @@
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -183,13 +236,22 @@
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint cross_plane_pad
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -273,6 +335,40 @@
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the cross plane paddings
+ zout *= (cross_plane_pad * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store 4x4 block
+ vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+ vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+ vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+ vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
@@ -281,6 +377,7 @@
vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
@@ -293,6 +390,12 @@
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -311,13 +414,22 @@
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint cross_plane_pad
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -533,6 +645,40 @@
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the cross plane paddings
+ zout *= (cross_plane_pad * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store 4x4 block
+ vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+ vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+ vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+ vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
@@ -541,6 +687,7 @@
vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
// Undefine local defines
@@ -556,6 +703,12 @@
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -574,13 +727,22 @@
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint cross_plane_pad
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -664,6 +826,40 @@
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the cross plane paddings
+ zout *= (cross_plane_pad * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store 4x8 block
+ vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+ vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+ vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+ vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
@@ -672,6 +868,7 @@
vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
@@ -683,6 +880,12 @@
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -701,13 +904,19 @@
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint cross_plane_pad
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -873,6 +1082,40 @@
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the cross plane paddings
+ zout *= (cross_plane_pad * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store 4x8 block
+ vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+ vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+ vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+ vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
@@ -881,6 +1124,7 @@
vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
// Undefine local defines
@@ -888,242 +1132,6 @@
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
-#if defined(FIXED_POINT_POSITION)
-/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 8 bit fixed point precision
- * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_8bit and @ref gemm_transpose1x16 before running the matrix multiplication
- *
- * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
- * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
- * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
- * @note:ALPHA must be passed in 8 bit fixed point format
- *
- * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8
- * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- */
-__kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
- IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst),
- uint src0_stride_z,
- uint src1_stride_z,
- uint dst_stride_z)
-{
- int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
- int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
- int z = get_global_id(2);
-
- // Offset
- const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
- const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 16;
-
- // src_addr_a = address of matrix A
- // src_addr_b = address of matrix B
- int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
- int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
-
-#if defined(MATRIX_B_DEPTH)
- // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
- src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
-#else // defined(MATRIX_B_DEPTH)
- src1_addr_in_bytes += z * src1_stride_z;
-#endif // defined(MATRIX_B_DEPTH)
-
- __global char *src_addr_a = (__global char *)(src0_ptr + src0_addr_in_bytes);
- __global char *src_addr_b = (__global char *)(src1_ptr + src1_addr_in_bytes);
-
- // Compute end row address for matrix B
- __global char *src_end_addr_b = src_addr_b + COLS_B;
-
- src_addr_a += offset_row_a;
- src_addr_b += offset_row_b;
-
- // Reset accumulators
- short8 c00 = 0.0f;
- short8 c10 = 0.0f;
- short8 c20 = 0.0f;
- short8 c30 = 0.0f;
- short8 c01 = 0.0f;
- short8 c11 = 0.0f;
- short8 c21 = 0.0f;
- short8 c31 = 0.0f;
-
- // This for loop performs 1 accumulation for each iteration
- for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
- {
- // Load values from matrix A (interleaved) and matrix B (transposed)
- char4 a0 = vload4(0, src_addr_a);
- char16 b0 = vload16(0, src_addr_b);
-
- c00 = mlal_sat_qs8x8(c00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
- c10 = mlal_sat_qs8x8(c10, (char8)a0.s1, b0.s01234567, FIXED_POINT_POSITION);
- c20 = mlal_sat_qs8x8(c20, (char8)a0.s2, b0.s01234567, FIXED_POINT_POSITION);
- c30 = mlal_sat_qs8x8(c30, (char8)a0.s3, b0.s01234567, FIXED_POINT_POSITION);
-
- c01 = mlal_sat_qs8x8(c01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
- c11 = mlal_sat_qs8x8(c11, (char8)a0.s1, b0.s89ABCDEF, FIXED_POINT_POSITION);
- c21 = mlal_sat_qs8x8(c21, (char8)a0.s2, b0.s89ABCDEF, FIXED_POINT_POSITION);
- c31 = mlal_sat_qs8x8(c31, (char8)a0.s3, b0.s89ABCDEF, FIXED_POINT_POSITION);
- }
-
- // Compute destination address
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
- // Multiply by the weight of matrix product
- char16 c00_qs8 = convert_char16_sat((short16)(c00, c01));
- char16 c10_qs8 = convert_char16_sat((short16)(c10, c11));
- char16 c20_qs8 = convert_char16_sat((short16)(c20, c21));
- char16 c30_qs8 = convert_char16_sat((short16)(c30, c31));
-
-#if defined(ALPHA)
- c00_qs8 = mul_sat_qs8x16(c00_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
- c10_qs8 = mul_sat_qs8x16(c10_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
- c20_qs8 = mul_sat_qs8x16(c20_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
- c30_qs8 = mul_sat_qs8x16(c30_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
-#endif // defined(ALPHA)
-
- // Compute dst address
- __global uchar *dst_addr = offset(&dst, 0, 0);
-
- // Add offset for batched GEMM
- dst_addr += z * dst_stride_z;
-
- // Store 16x4 block
- vstore16(c00_qs8, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
- vstore16(c10_qs8, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
- vstore16(c20_qs8, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
- vstore16(c30_qs8, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
-}
-
-/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 16 bit fixed point precision
- * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
- *
- * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
- * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
- * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
- * @note:ALPHA must be passed in 16 bit fixed point format
- *
- * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS16
- * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- */
-__kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
- IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst),
- uint src0_stride_z,
- uint src1_stride_z,
- uint dst_stride_z)
-{
- int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
- int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
- int z = get_global_id(2);
-
- // Offset
- const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
- const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
-
- // src_addr_a = address of matrix A
- // src_addr_b = address of matrix B
- int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
- int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
-
-#if defined(MATRIX_B_DEPTH)
- // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
- src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
-#else // defined(MATRIX_B_DEPTH)
- src1_addr_in_bytes += z * src1_stride_z;
-#endif // defined(MATRIX_B_DEPTH)
-
- __global short *src_addr_a = (__global short *)(src0_ptr + src0_addr_in_bytes);
- __global short *src_addr_b = (__global short *)(src1_ptr + src1_addr_in_bytes);
-
- // Compute end row address for matrix B
- __global short *src_end_addr_b = src_addr_b + COLS_B;
-
- src_addr_a += offset_row_a;
- src_addr_b += offset_row_b;
-
- // Reset accumulators
- int8 c00 = 0.0f;
- int8 c10 = 0.0f;
- int8 c20 = 0.0f;
- int8 c30 = 0.0f;
-
- // This for loop performs 1 accumulation for each iteration
- for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
- {
- /* Load values from matrix A (interleaved) and matrix B (transposed) */
- short4 a0 = vload4(0, src_addr_a);
- short8 b0 = vload8(0, src_addr_b);
-
- c00 = mlal_sat_qs16x8(c00, (short8)a0.s0, b0, FIXED_POINT_POSITION);
- c10 = mlal_sat_qs16x8(c10, (short8)a0.s1, b0, FIXED_POINT_POSITION);
- c20 = mlal_sat_qs16x8(c20, (short8)a0.s2, b0, FIXED_POINT_POSITION);
- c30 = mlal_sat_qs16x8(c30, (short8)a0.s3, b0, FIXED_POINT_POSITION);
- }
-
- // Compute destination address
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
- // Multiply by the weight of matrix product
- short8 c00_qs16 = convert_short8_sat(c00);
- short8 c10_qs16 = convert_short8_sat(c10);
- short8 c20_qs16 = convert_short8_sat(c20);
- short8 c30_qs16 = convert_short8_sat(c30);
-
-#if defined(ALPHA)
- c00_qs16 = mul_sat_qs16x8(c00_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
- c10_qs16 = mul_sat_qs16x8(c10_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
- c20_qs16 = mul_sat_qs16x8(c20_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
- c30_qs16 = mul_sat_qs16x8(c30_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
-#endif // defined(ALPHA)
-
- // Compute dst address
- __global uchar *dst_addr = offset(&dst, 0, 0);
-
- // Add offset for batched GEMM
- dst_addr += z * dst_stride_z;
-
- // Store 8x4 block
- vstore8(c00_qs16, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
- vstore8(c10_qs16, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
- vstore8(c20_qs16, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
- vstore8(c30_qs16, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
-}
-#endif // defined(FIXED_POINT_POSITION)
#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
@@ -1138,6 +1146,13 @@
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the input/output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the input/output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped (M)
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1156,13 +1171,27 @@
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the first source matrix (matrix A) in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the second source matrix (matrix B) in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] src_cross_plane_pad (Optional) Bottom padding between planes, in rows, for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom padding between planes, in rows, for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint dst_cross_plane_pad
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1175,9 +1204,40 @@
// Update address for the matrix B
src_addr.s1 += idx * sizeof(DATA_TYPE);
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (src_cross_plane_pad * src0_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src0_stride_z by DEPTH_GEMM3D
+ src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -1200,6 +1260,23 @@
for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
VEC_DATA_TYPE(DATA_TYPE, 2)
a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
@@ -1215,6 +1292,8 @@
VEC_DATA_TYPE(DATA_TYPE, 2)
a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
@@ -1238,6 +1317,19 @@
for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1249,6 +1341,8 @@
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
@@ -1271,36 +1365,85 @@
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
- // Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
-
// Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
acc0 = acc0 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+ acc1 = acc1 * (VECTOR_TYPE)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+ acc2 = acc2 * (VECTOR_TYPE)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+ acc3 = acc3 * (VECTOR_TYPE)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+ int z = get_global_id(2);
+
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the cross plane paddings
+ zout *= (dst_cross_plane_pad * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store output block
+ VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+ (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+ (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+ (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+ (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
+ // Store output block
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
(acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
- acc1 = acc1 * (VECTOR_TYPE)ALPHA;
-#endif // defined(ALPHA)
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
(acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
- acc2 = acc2 * (VECTOR_TYPE)ALPHA;
-#endif // defined(ALPHA)
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
(acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
- acc3 = acc3 * (VECTOR_TYPE)ALPHA;
-#endif // defined(ALPHA)
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
(acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
#endif // defined(DATA_TYPE)
@@ -1314,6 +1457,13 @@
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the input/output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the input/output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped (M)
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1332,13 +1482,27 @@
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the first source matrix (matrix A) in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the second source matrix (matrix B) in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] src_cross_plane_pad (Optional) Bottom padding between planes, in rows, for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom padding between planes, in rows, for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint dst_cross_plane_pad
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1351,9 +1515,40 @@
// Update address for matrix B
src_addr.s1 += idx * sizeof(float);
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (src_cross_plane_pad * src0_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src0_stride_z by DEPTH_GEMM3D
+ src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -1392,6 +1587,19 @@
int i = 0;
for(; i <= ((int)COLS_A - 4); i += 4)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A and matrix B
+ float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A and matrix B
float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1403,6 +1611,8 @@
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
src_addr.s1 += src1_stride_y;
@@ -1543,8 +1753,21 @@
for(; i < (int)COLS_A; ++i)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
- float a0 = *((__global float *)(src0_ptr + src_addr.s0));
+ float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1554,6 +1777,8 @@
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
src_addr.s1 += src1_stride_y;
@@ -1585,6 +1810,8 @@
src_addr.s0 += sizeof(float);
}
+ int z = get_global_id(2);
+
// Compute destination address
Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
@@ -1595,46 +1822,83 @@
acc02 = acc02 * ALPHA;
acc03 = acc03 * ALPHA;
#endif // defined(ALPHA)
-
- // Compute dst address
- __global uchar *dst_addr = offset(&dst, 0, 0);
-
- // Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
-
- float4 acc0 = ((float4)(acc00, acc01, acc02, acc03));
- vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
acc10 = acc10 * ALPHA;
acc11 = acc11 * ALPHA;
acc12 = acc12 * ALPHA;
acc13 = acc13 * ALPHA;
-#endif // defined(ALPHA)
- float4 acc1 = ((float4)(acc10, acc11, acc12, acc13));
- vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
acc20 = acc20 * ALPHA;
acc21 = acc21 * ALPHA;
acc22 = acc22 * ALPHA;
acc23 = acc23 * ALPHA;
-#endif // defined(ALPHA)
- float4 acc2 = ((float4)(acc20, acc21, acc22, acc23));
- vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
acc30 = acc30 * ALPHA;
acc31 = acc31 * ALPHA;
acc32 = acc32 * ALPHA;
acc33 = acc33 * ALPHA;
-#endif // defined(ALPHA)
- float4 acc3 = ((float4)(acc30, acc31, acc32, acc33));
- vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the cross plane paddings
+ zout *= (dst_cross_plane_pad * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store the output block
+ vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
+ // Store the output block
+ vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
@@ -1648,6 +1912,13 @@
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the input/output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the input/output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped (M)
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1666,13 +1937,27 @@
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the first source matrix (matrix A) in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the second source matrix (matrix B) in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] src_cross_plane_pad (Optional) Bottom padding between planes, in rows, for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom padding between planes, in rows, for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint dst_cross_plane_pad
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
// Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1686,9 +1971,40 @@
// Update address for the matrix B
src_addr.s1 += idx * sizeof(float);
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (src_cross_plane_pad * src0_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src0_stride_z by DEPTH_GEMM3D
+ src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -1717,8 +2033,13 @@
int i = 0;
for(; i <= ((int)COLS_A - 8); i += 8)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix B
float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
@@ -1758,7 +2079,11 @@
acc01 = fma(a0.s7, b7.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+#if defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#else // defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
acc10 = fma(a0.s0, b0.s0, acc10);
acc10 = fma(a0.s1, b1.s0, acc10);
acc10 = fma(a0.s2, b2.s0, acc10);
@@ -1778,7 +2103,11 @@
acc11 = fma(a0.s7, b7.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+#if defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#else // defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
acc20 = fma(a0.s0, b0.s0, acc20);
acc20 = fma(a0.s1, b1.s0, acc20);
acc20 = fma(a0.s2, b2.s0, acc20);
@@ -1798,7 +2127,11 @@
acc21 = fma(a0.s7, b7.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+#if defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#else // defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
acc30 = fma(a0.s0, b0.s0, acc30);
acc30 = fma(a0.s1, b1.s0, acc30);
acc30 = fma(a0.s2, b2.s0, acc30);
@@ -1823,6 +2156,19 @@
// float size increment
for(; i < (int)COLS_A; ++i)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1834,6 +2180,8 @@
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
src_addr.s1 += src1_stride_y;
@@ -1857,48 +2205,90 @@
src_addr.s0 += sizeof(float);
}
+ // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+ acc00 = acc00 * ALPHA;
+ acc01 = acc01 * ALPHA;
+#endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+ acc10 = acc10 * ALPHA;
+ acc11 = acc11 * ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+ acc20 = acc20 * ALPHA;
+ acc21 = acc21 * ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+ acc30 = acc30 * ALPHA;
+ acc31 = acc31 * ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+ int z = get_global_id(2);
+
// Compute destination address
Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
- // Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
- // Multiply by the weight of matrix-matrix product and store the result
-#if defined(ALPHA)
- acc00 = acc00 * ALPHA;
- acc01 = acc01 * ALPHA;
-#endif // defined(ALPHA)
- float2 acc0 = ((float2)(acc00, acc01));
- vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+ // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the cross plane paddings
+ zout *= (dst_cross_plane_pad * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store the output block
+ vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
- acc10 = acc10 * ALPHA;
- acc11 = acc11 * ALPHA;
-#endif // defined(ALPHA)
- float2 acc1 = ((float2)(acc10, acc11));
- vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+ vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
- acc20 = acc20 * ALPHA;
- acc21 = acc21 * ALPHA;
-#endif // defined(ALPHA)
- float2 acc2 = ((float2)(acc20, acc21));
- vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+ vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
- acc30 = acc30 * ALPHA;
- acc31 = acc31 * ALPHA;
-#endif // defined(ALPHA)
- float2 acc3 = (float2)(acc30, acc31);
- vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+ vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
+ // Store the output block
+ vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
+#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
*
* @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
@@ -1909,6 +2299,13 @@
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the input/output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the input/output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped (i.e. M)
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1927,13 +2324,27 @@
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint dst_cross_plane_pad
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1946,9 +2357,40 @@
// Update address for the matrix B
src_addr.s1 += idx * sizeof(half);
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (src_cross_plane_pad * src0_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src0_stride_z by DEPTH_GEMM3D
+ src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -1970,6 +2412,19 @@
int i = 0;
for(; i <= ((int)COLS_A - 4); i += 4)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1981,6 +2436,8 @@
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
src_addr.s1 += src1_stride_y;
@@ -2041,6 +2498,19 @@
for(; i < (int)COLS_A; ++i)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -2052,6 +2522,8 @@
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
@@ -2070,393 +2542,86 @@
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
- // Compute destination address
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
- // Compute dst address
- __global uchar *dst_addr = offset(&dst, 0, 0);
-
- // Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
-
// Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
acc0 = acc0 * (half8)ALPHA;
#endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+ acc1 = acc1 * (half8)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+ acc2 = acc2 * (half8)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+ acc3 = acc3 * (half8)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+ int z = get_global_id(2);
+
+ // Compute destination address
+ Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the cross plane paddings
+ zout *= (dst_cross_plane_pad * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store the output block
+ vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
+ // Store the output block
vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
- acc1 = acc1 * (half8)ALPHA;
-#endif // defined(ALPHA)
vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
- acc2 = acc2 * (half8)ALPHA;
-#endif // defined(ALPHA)
vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
- acc3 = acc3 * (half8)ALPHA;
-#endif // defined(ALPHA)
vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // REINTERPRET_OUTPUT_AS_3D
}
+#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
-#if defined(FIXED_POINT_POSITION)
-/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
- *
- * @note This OpenCL kernel works with fixed point data types QS8
- * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
- * @note The number matrix A columns, the number of elements processed per thread along the Y direction and the alpha's value need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
- * @note The fixed point position need to be passed at compile time using -DFIXED_POINT_POSITION
- * @note The optional alpha value must be passed in 8 bit fixed point format using -DALPHA
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
- *
- * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8/QS16
- * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- */
-__kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
- IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst),
- uint src0_stride_z,
- uint src1_stride_z,
- uint dst_stride_z)
-{
- int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
-
- // Compute starting address for matrix A and Matrix B
- int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
-
- // Update address for the matrix A
- src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
-
- // Update address for the matrix B
- src_addr.s1 += idx * sizeof(char);
-
- // Add offset for batched GEMM
- src_addr.s0 += get_global_id(2) * src0_stride_z;
-
-#if defined(MATRIX_B_DEPTH)
- // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
- src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
-#else // defined(MATRIX_B_DEPTH)
- src_addr.s1 += get_global_id(2) * src1_stride_z;
-#endif // defined(MATRIX_B_DEPTH)
-
- int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(char));
-
- short8 acc00 = 0;
- short8 acc01 = 0;
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- short8 acc10 = 0;
- short8 acc11 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- short8 acc20 = 0;
- short8 acc21 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- short8 acc30 = 0;
- short8 acc31 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- // This for loop performs 4 accumulations per iteration
- for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
- {
- char2 a0 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- char2 a1 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- char2 a2 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- char2 a3 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
- char16 b1 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
-
- acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
- acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s1, b1.s01234567, FIXED_POINT_POSITION);
- acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
- acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s0, b0.s01234567, FIXED_POINT_POSITION);
- acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s1, b1.s01234567, FIXED_POINT_POSITION);
- acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
- acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s0, b0.s01234567, FIXED_POINT_POSITION);
- acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s1, b1.s01234567, FIXED_POINT_POSITION);
- acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
- acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s0, b0.s01234567, FIXED_POINT_POSITION);
- acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s1, b1.s01234567, FIXED_POINT_POSITION);
- acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
- acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- }
-
- // Left-over accumulations
- for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
- {
- char a0 = *((__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- char a1 = *((__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- char a2 = *((__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- char a3 = *((__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1));
-
- acc00 = mlal_sat_qs8x8(acc00, (char8)a0, b0.s01234567, FIXED_POINT_POSITION);
- acc01 = mlal_sat_qs8x8(acc01, (char8)a0, b0.s89ABCDEF, FIXED_POINT_POSITION);
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- acc10 = mlal_sat_qs8x8(acc10, (char8)a1, b0.s01234567, FIXED_POINT_POSITION);
- acc11 = mlal_sat_qs8x8(acc11, (char8)a1, b0.s89ABCDEF, FIXED_POINT_POSITION);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- acc20 = mlal_sat_qs8x8(acc20, (char8)a2, b0.s01234567, FIXED_POINT_POSITION);
- acc21 = mlal_sat_qs8x8(acc21, (char8)a2, b0.s89ABCDEF, FIXED_POINT_POSITION);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- acc30 = mlal_sat_qs8x8(acc30, (char8)a3, b0.s01234567, FIXED_POINT_POSITION);
- acc31 = mlal_sat_qs8x8(acc31, (char8)a3, b0.s89ABCDEF, FIXED_POINT_POSITION);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- }
-
- // Compute destination address
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
- // Compute dst address
- __global uchar *dst_addr = offset(&dst, 0, 0);
-
- // Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
-
- // Multiply by the weight of matrix product and store the result
- char16 acc_qs8;
- acc_qs8 = convert_char16_sat((short16)(acc00, acc01));
-#if defined(ALPHA)
- acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
-#endif // defined(ALPHA)
- vstore16(acc_qs8, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- acc_qs8 = convert_char16_sat((short16)(acc10, acc11));
-#if defined(ALPHA)
- acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
-#endif // defined(ALPHA)
- vstore16(acc_qs8, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- acc_qs8 = convert_char16_sat((short16)(acc20, acc21));
-#if defined(ALPHA)
- acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
-#endif // defined(ALPHA)
- vstore16(acc_qs8, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- acc_qs8 = convert_char16_sat((short16)(acc30, acc31));
-#if defined(ALPHA)
- acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
-#endif // defined(ALPHA)
- vstore16(acc_qs8, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-}
-
-/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
- *
- * @note This OpenCL kernel works with fixed point data types QS16
- * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
- * @note The number of matrix A columns, the number of elements processed per thread along the Y direction and the alpha's value need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
- * @note The fixed point position need to be passed at compile time using -DFIXED_POINT_POSITION
- * @note The optional alpha value must be passed in 16 bit fixed point format using -DALPHA
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
- *
- * @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8/QS16
- * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- */
-__kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
- IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst),
- uint src0_stride_z,
- uint src1_stride_z,
- uint dst_stride_z)
-{
- int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
-
- // Compute starting address for matrix A and Matrix B
- int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
-
- // Update address for the matrix A
- src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
-
- // Update address for the matrix B
- src_addr.s1 += idx * sizeof(short);
-
- // Add offset for batched GEMM
- src_addr.s0 += get_global_id(2) * src0_stride_z;
-
-#if defined(MATRIX_B_DEPTH)
- // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
- src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
-#else // defined(MATRIX_B_DEPTH)
- src_addr.s1 += get_global_id(2) * src1_stride_z;
-#endif // defined(MATRIX_B_DEPTH)
-
- int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(short));
-
- int8 acc0 = 0;
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- int8 acc1 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- int8 acc2 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- int8 acc3 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- // This for loop performs 4 accumulations per iteration
- for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(short)); src_addr += (int2)(2 * sizeof(short), 2 * src1_stride_y))
- {
- short2 a0 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- short2 a1 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- short2 a2 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- short2 a3 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
- short8 b1 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
-
- acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s0, b0, FIXED_POINT_POSITION);
- acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s1, b1, FIXED_POINT_POSITION);
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s0, b0, FIXED_POINT_POSITION);
- acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s1, b1, FIXED_POINT_POSITION);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s0, b0, FIXED_POINT_POSITION);
- acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s1, b1, FIXED_POINT_POSITION);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s0, b0, FIXED_POINT_POSITION);
- acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s1, b1, FIXED_POINT_POSITION);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- }
-
- // Left-over accumulations
- for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(short), src1_stride_y))
- {
- short a0 = *((__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- short a1 = *((__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- short a2 = *((__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- short a3 = *((__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1));
-
- acc0 = mlal_sat_qs16x8(acc0, (short8)a0, b0, FIXED_POINT_POSITION);
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- acc1 = mlal_sat_qs16x8(acc1, (short8)a1, b0, FIXED_POINT_POSITION);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- acc2 = mlal_sat_qs16x8(acc2, (short8)a2, b0, FIXED_POINT_POSITION);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- acc3 = mlal_sat_qs16x8(acc3, (short8)a3, b0, FIXED_POINT_POSITION);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- }
-
- // Compute destination address
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
- // Compute dst address
- __global uchar *dst_addr = offset(&dst, 0, 0);
-
- // Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
-
- // Multiply by the weight of matrix product and store the result
- short8 acc_qs16;
- acc_qs16 = convert_short8_sat(acc0);
-#if defined(ALPHA)
- acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
-#endif // defined(ALPHA)
- vstore8(acc_qs16, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- acc_qs16 = convert_short8_sat(acc1);
-#if defined(ALPHA)
- acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
-#endif // defined(ALPHA)
- vstore8(acc_qs16, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- acc_qs16 = convert_short8_sat(acc2);
-#if defined(ALPHA)
- acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
-#endif // defined(ALPHA)
- vstore8(acc_qs16, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- acc_qs16 = convert_short8_sat(acc3);
-#if defined(ALPHA)
- acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
-#endif // defined(ALPHA)
- vstore8(acc_qs16, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-}
-#endif // defined(FIXED_POINT_POSITION)
#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(BETA)
@@ -2469,20 +2634,24 @@
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
*/
-__kernel void gemm_ma_f32(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
{
// Compute source and destination addresses
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
// Load values from A x B
float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
@@ -2497,6 +2666,7 @@
vstore4(out, 0, (__global float *)dst.ptr);
}
+#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
*
* @note The beta's value need to be passed at compile time using -DBETA
@@ -2506,20 +2676,24 @@
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
*/
-__kernel void gemm_ma_f16(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
{
// Compute source and destination addresses
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
// Load values from A x B
half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
@@ -2533,86 +2707,7 @@
// Store final result in axb matrix
vstore8(out, 0, (__global half *)dst.ptr);
}
-
-#if defined(FIXED_POINT_POSITION)
-/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 8 bit fixed point taking into account that the second matrix might be weighted by a scalar value beta:
- *
- * @note The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
- *
- * @note: BETA must be passed in 8 bit fixed point format
- *
- * @param[in] src_ptr Pointer to the source matrix. Supported data types: QS8
- * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- */
-__kernel void gemm_ma_qs8(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
-{
- // Compute source and destination addresses
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
- // Load values from A x B
- char16 alpha_ab = vload16(0, (__global char *)dst.ptr);
-
- // Load values from Matrix C
- char16 c = vload16(0, (__global char *)src.ptr);
-
- // Computes alpha * axb + beta * c
- char16 out = mla_sat_qs8x16(alpha_ab, (char16)BETA, c, FIXED_POINT_POSITION);
-
- // Store final result in axb matrix
- vstore16(out, 0, (__global char *)dst.ptr);
-}
-
-/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 16 bit fixed point taking into account that the second matrix might be weighted by a scalar value beta:
- *
- * @note The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
- *
- * @note: BETA must be passed in 16 bit fixed point format
- *
- * @param[in] src_ptr Pointer to the source matrix. Supported data types: QS16
- * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- */
-__kernel void gemm_ma_qs16(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
-{
- // Compute source and destination addresses
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
- // Load values from A x B
- short8 alpha_ab = vload8(0, (__global short *)dst.ptr);
-
- // Load values from Matrix C
- short8 c = vload8(0, (__global short *)src.ptr);
-
- // Computes alpha * axb + beta * c
- short8 out = mla_sat_qs16x8(alpha_ab, (short8)BETA, c, FIXED_POINT_POSITION);
-
- // Store final result in axb matrix
- vstore8(out, 0, (__global short *)dst.ptr);
-}
-#endif // defined(FIXED_POINT_POSITION)
+#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
#endif // defined(BETA)
#if defined(WIDTH_VECTOR_A)
@@ -2688,7 +2783,7 @@
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
* @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
*
- * @param[in, out] accum_ptr Pointer to the accumulate tensor. Supported data type: U8/S8/QS8/U16/S16/F16/U32/S32/F32
+ * @param[in, out] accum_ptr Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
* @param[in] accum_stride_x Stride of the accmulate tensor in X dimension (in bytes)
* @param[in] accum_step_x accum_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] accum_stride_y Stride of the accumlulate tensor in Y dimension (in bytes)
@@ -2712,11 +2807,7 @@
accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr);
VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
biases_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);
-#ifdef FIXED_POINT_POSITION
- accum_value = ADD_SAT_OP_EXPAND(biases_value, accum_value, DATA_TYPE, VECTOR_SIZE);
-#else // FIXED_POINT_POSITION
- accum_value = biases_value + accum_value;
-#endif // FIXED_POINT_POSITION
+ accum_value = biases_value + accum_value;
// Store result in the accumulate buffer
VSTORE(VECTOR_SIZE)
(accum_value, 0, (__global DATA_TYPE *)accum.ptr);
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index 5e144d7..cd8b269 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -24,6 +24,14 @@
#include "helpers.h"
#include "helpers_asymm.h"
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) // ARM_DOT(x0..x3, y0..y3, val): accumulates dot((uchar4)(x0..x3), (uchar4)(y0..y3)) into val; the caller supplies the ';'
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED)
+#define ARM_DOT(x0, x1, x2, x3, y0, y1, y2, y3, val) (val) = arm_dot_acc((uchar4)(x0, x1, x2, x3), (uchar4)(y0, y1, y2, y3), (val))
+#else // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED)
+#define ARM_DOT(x0, x1, x2, x3, y0, y1, y2, y3, val) (val) += arm_dot((uchar4)(x0, x1, x2, x3), (uchar4)(y0, y1, y2, y3))
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED)
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+
#if defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
* Matrix A and matrix B must be reshaped respectively with @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel before running the matrix multiplication
@@ -414,6 +422,173 @@
vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(offset(&dst, 0, 2)));
vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(offset(&dst, 0, 3)));
}
+
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+/** This OpenCL kernel is optimized for Bifrost and computes the matrix multiplication between matrix A (src0) and matrix B (src1) using the ARM_DOT dot-product macro (requires -DARM_COMPUTE_OPENCL_DOT8_ENABLED)
+ * Matrix A and matrix B must be reshaped respectively with @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel before running the matrix multiplication
+ *
+ * @attention The number of matrix B columns needs to be passed at compile time using -DCOLS_B
+ * @note The transposition width step (mult_transpose1xW_width * 4) must be passed at compile time using -DTRANSPOSE1XW_WIDTH_STEP (e.g. -DTRANSPOSE1XW_WIDTH_STEP=2)
+ * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ *
+ * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
+ * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
+ * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
+ * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
+ * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
+ * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
+ * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
+ * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ */
+__kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION(src0),
+ IMAGE_DECLARATION(src1),
+ IMAGE_DECLARATION(dst))
+{
+ int x = get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP;
+ int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
+
+ // Offset
+ const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
+ const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4;
+
+ // src_addr_a = address of matrix A
+ // src_addr_b = address of matrix B
+ __global uchar *src_addr_a = (__global uchar *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
+ __global uchar *src_addr_b = (__global uchar *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
+
+ // Compute end row address for matrix B
+ __global uchar *src_end_addr_b = src_addr_b + COLS_B;
+
+ src_addr_a += offset_row_a;
+ src_addr_b += offset_row_b;
+
+ // Reset accumulators
+ uint c00 = 0;
+ uint c01 = 0;
+ uint c02 = 0;
+ uint c03 = 0;
+ uint c10 = 0;
+ uint c11 = 0;
+ uint c12 = 0;
+ uint c13 = 0;
+ uint c20 = 0;
+ uint c21 = 0;
+ uint c22 = 0;
+ uint c23 = 0;
+ uint c30 = 0;
+ uint c31 = 0;
+ uint c32 = 0;
+ uint c33 = 0;
+ // Unrolled main loop (only when MULT_INTERLEAVE4X4_HEIGHT == 1): 8 K-steps per iteration, 4 K-steps per ARM_DOT
+#if MULT_INTERLEAVE4X4_HEIGHT == 1
+ for(; src_addr_b <= (src_end_addr_b - (int)(32 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += (32 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (32 * TRANSPOSE1XW_WIDTH_STEP))
+ {
+ // Load values from matrix A (interleaved) and matrix B (transposed)
+ uchar16 a0 = vload16(0, src_addr_a);
+ uchar4 b0 = vload4(0, src_addr_b);
+ uchar4 b1 = vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP);
+ uchar4 b2 = vload4(0, src_addr_b + 8 * TRANSPOSE1XW_WIDTH_STEP);
+ uchar4 b3 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);
+
+ // Accumulate
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s0, b1.s0, b2.s0, b3.s0, c00);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s1, b1.s1, b2.s1, b3.s1, c01);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s2, b1.s2, b2.s2, b3.s2, c02);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s3, b1.s3, b2.s3, b3.s3, c03);
+
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s0, b1.s0, b2.s0, b3.s0, c10);
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s1, b1.s1, b2.s1, b3.s1, c11);
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s2, b1.s2, b2.s2, b3.s2, c12);
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s3, b1.s3, b2.s3, b3.s3, c13);
+
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s0, b1.s0, b2.s0, b3.s0, c20);
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s1, b1.s1, b2.s1, b3.s1, c21);
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s2, b1.s2, b2.s2, b3.s2, c22);
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s3, b1.s3, b2.s3, b3.s3, c23);
+
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s0, b1.s0, b2.s0, b3.s0, c30);
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s1, b1.s1, b2.s1, b3.s1, c31);
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s2, b1.s2, b2.s2, b3.s2, c32);
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s3, b1.s3, b2.s3, b3.s3, c33);
+
+ // Load values from matrix A (interleaved) and matrix B (transposed)
+ a0 = vload16(0, src_addr_a + 16);
+ b0 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP);
+ b1 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP);
+ b2 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP);
+ b3 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);
+
+ // Accumulate
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s0, b1.s0, b2.s0, b3.s0, c00);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s1, b1.s1, b2.s1, b3.s1, c01);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s2, b1.s2, b2.s2, b3.s2, c02);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s3, b1.s3, b2.s3, b3.s3, c03);
+
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s0, b1.s0, b2.s0, b3.s0, c10);
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s1, b1.s1, b2.s1, b3.s1, c11);
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s2, b1.s2, b2.s2, b3.s2, c12);
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s3, b1.s3, b2.s3, b3.s3, c13);
+
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s0, b1.s0, b2.s0, b3.s0, c20);
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s1, b1.s1, b2.s1, b3.s1, c21);
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s2, b1.s2, b2.s2, b3.s2, c22);
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s3, b1.s3, b2.s3, b3.s3, c23);
+
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s0, b1.s0, b2.s0, b3.s0, c30);
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s1, b1.s1, b2.s1, b3.s1, c31);
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s2, b1.s2, b2.s2, b3.s2, c32);
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s3, b1.s3, b2.s3, b3.s3, c33);
+ }
+#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
+ // Leftover loop: one K-step per iteration using scalar multiply-accumulate
+ for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * TRANSPOSE1XW_WIDTH_STEP))
+ {
+ // Load values from matrix A (interleaved) and matrix B (transposed)
+ uchar4 a0 = vload4(0, src_addr_a);
+ uchar4 b0 = vload4(0, src_addr_b);
+
+ c00 += (ushort)a0.s0 * b0.s0;
+ c01 += (ushort)a0.s0 * b0.s1;
+ c02 += (ushort)a0.s0 * b0.s2;
+ c03 += (ushort)a0.s0 * b0.s3;
+
+ c10 += (ushort)a0.s1 * b0.s0;
+ c11 += (ushort)a0.s1 * b0.s1;
+ c12 += (ushort)a0.s1 * b0.s2;
+ c13 += (ushort)a0.s1 * b0.s3;
+
+ c20 += (ushort)a0.s2 * b0.s0;
+ c21 += (ushort)a0.s2 * b0.s1;
+ c22 += (ushort)a0.s2 * b0.s2;
+ c23 += (ushort)a0.s2 * b0.s3;
+
+ c30 += (ushort)a0.s3 * b0.s0;
+ c31 += (ushort)a0.s3 * b0.s1;
+ c32 += (ushort)a0.s3 * b0.s2;
+ c33 += (ushort)a0.s3 * b0.s3;
+ }
+
+ // Compute destination address
+ Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+
+ // Store 4x4 block (uint accumulators converted to int)
+ vstore4((int4)(c00, c01, c02, c03), 0, (__global int *)(offset(&dst, 0, 0)));
+ vstore4((int4)(c10, c11, c12, c13), 0, (__global int *)(offset(&dst, 0, 1)));
+ vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(offset(&dst, 0, 2)));
+ vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(offset(&dst, 0, 3)));
+}
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+
#endif // defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
#if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
@@ -918,6 +1093,254 @@
vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(offset(&dst, 0, 4)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
}
+
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+/** OpenCL kernel optimized to use dot product that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
+ *
+ * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A; -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y are also required (the kernel loads/stores 4 columns per work-item, so X is expected to be 4; Y rows are handled for 1 to 5)
+ *
+ * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
+ * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
+ * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
+ * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
+ * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
+ * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
+ * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
+ * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ */
+__kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0),
+ IMAGE_DECLARATION(src1),
+ IMAGE_DECLARATION(dst))
+{
+ int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
+
+ // Compute starting address for matrix A and Matrix B
+ int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
+
+ // Update address for the matrix A
+ src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
+
+ // Update address for the matrix B
+ src_addr.s1 += idx;
+
+ int end_row_vec_a = src_addr.s0 + COLS_A;
+
+ uint acc00 = 0;
+ uint acc01 = 0;
+ uint acc02 = 0;
+ uint acc03 = 0;
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ uint acc10 = 0;
+ uint acc11 = 0;
+ uint acc12 = 0;
+ uint acc13 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ uint acc20 = 0;
+ uint acc21 = 0;
+ uint acc22 = 0;
+ uint acc23 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ uint acc30 = 0;
+ uint acc31 = 0;
+ uint acc32 = 0;
+ uint acc33 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ uint acc40 = 0;
+ uint acc41 = 0;
+ uint acc42 = 0;
+ uint acc43 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ // Main loop: 4 columns of A (4 rows of B) per iteration, reduced with ARM_DOT
+ for(; src_addr.s0 <= (end_row_vec_a - 4); src_addr += (int2)(4, 4 * src1_stride_y))
+ {
+ // Load values from matrix A
+ uchar4 a0 = vload4(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ uchar4 a1 = vload4(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ uchar4 a2 = vload4(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ uchar4 a3 = vload4(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ uchar4 a4 = vload4(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ // Load values from matrix B
+ uchar4 b0 = vload4(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
+ uchar4 b1 = vload4(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
+ uchar4 b2 = vload4(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
+ uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
+
+ {
+ // Accumulate
+ ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a0.s0, a0.s1, a0.s2, a0.s3, acc00);
+ ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a0.s0, a0.s1, a0.s2, a0.s3, acc01);
+ ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a0.s0, a0.s1, a0.s2, a0.s3, acc02);
+ ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a0.s0, a0.s1, a0.s2, a0.s3, acc03);
+ }
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ {
+ // Accumulate
+ ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a1.s0, a1.s1, a1.s2, a1.s3, acc10);
+ ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a1.s0, a1.s1, a1.s2, a1.s3, acc11);
+ ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a1.s0, a1.s1, a1.s2, a1.s3, acc12);
+ ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a1.s0, a1.s1, a1.s2, a1.s3, acc13);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ {
+ // Accumulate
+ ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a2.s0, a2.s1, a2.s2, a2.s3, acc20);
+ ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a2.s0, a2.s1, a2.s2, a2.s3, acc21);
+ ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a2.s0, a2.s1, a2.s2, a2.s3, acc22);
+ ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a2.s0, a2.s1, a2.s2, a2.s3, acc23);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ {
+ // Accumulate
+ ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a3.s0, a3.s1, a3.s2, a3.s3, acc30);
+ ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a3.s0, a3.s1, a3.s2, a3.s3, acc31);
+ ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a3.s0, a3.s1, a3.s2, a3.s3, acc32);
+ ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a3.s0, a3.s1, a3.s2, a3.s3, acc33);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ {
+ // Accumulate
+ ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a4.s0, a4.s1, a4.s2, a4.s3, acc40);
+ ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a4.s0, a4.s1, a4.s2, a4.s3, acc41);
+ ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a4.s0, a4.s1, a4.s2, a4.s3, acc42);
+ ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a4.s0, a4.s1, a4.s2, a4.s3, acc43);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ }
+ // Leftover loop: remaining columns of A (fewer than 4), one per iteration
+ for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
+ {
+ // Load values from matrix A
+ uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ // Load values from matrix B
+ uchar4 b0 = vload4(0, src1_ptr + src_addr.s1);
+
+ // Accumulate
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a0;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a0;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a0;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a0;
+
+ acc00 += ((uint)tmp0);
+ acc01 += ((uint)tmp1);
+ acc02 += ((uint)tmp2);
+ acc03 += ((uint)tmp3);
+ }
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a1;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a1;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a1;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a1;
+
+ acc10 += ((uint)tmp0);
+ acc11 += ((uint)tmp1);
+ acc12 += ((uint)tmp2);
+ acc13 += ((uint)tmp3);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a2;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a2;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a2;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a2;
+
+ acc20 += ((uint)tmp0);
+ acc21 += ((uint)tmp1);
+ acc22 += ((uint)tmp2);
+ acc23 += ((uint)tmp3);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a3;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a3;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a3;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a3;
+
+ acc30 += ((uint)tmp0);
+ acc31 += ((uint)tmp1);
+ acc32 += ((uint)tmp2);
+ acc33 += ((uint)tmp3);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a4;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a4;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a4;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a4;
+
+ acc40 += ((uint)tmp0);
+ acc41 += ((uint)tmp1);
+ acc42 += ((uint)tmp2);
+ acc43 += ((uint)tmp3);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ }
+
+ // Compute destination address
+ Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+
+ // Store the result (uint accumulators converted to int)
+ vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(offset(&dst, 0, 0)));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(offset(&dst, 0, 1)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(offset(&dst, 0, 2)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(offset(&dst, 0, 3)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(offset(&dst, 0, 4)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+}
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+
#endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
#if defined(COLS_A)
diff --git a/src/core/CL/cl_kernels/helpers.h b/src/core/CL/cl_kernels/helpers.h
index 615c518..3f7a2a5 100644
--- a/src/core/CL/cl_kernels/helpers.h
+++ b/src/core/CL/cl_kernels/helpers.h
@@ -28,6 +28,14 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED)
+#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED)
+
#if defined(ARM_COMPUTE_DEBUG_ENABLED)
#if defined(cl_arm_printf)
#pragma OPENCL EXTENSION cl_arm_printf : enable
diff --git a/src/core/CL/cl_kernels/helpers_asymm.h b/src/core/CL/cl_kernels/helpers_asymm.h
index c314d17..a69bcc1 100644
--- a/src/core/CL/cl_kernels/helpers_asymm.h
+++ b/src/core/CL/cl_kernels/helpers_asymm.h
@@ -62,7 +62,6 @@
b_64 = convert_long##size(b); \
VEC_DATA_TYPE(long, size) \
ab_64 = a_64 * b_64; \
- /* COMPMID-907 */ \
VEC_DATA_TYPE(int, size) \
ab_x2_high32 = convert_int##size(((ab_64 + (1 << 30)) >> 31)); \
return select(ab_x2_high32, INT_MAX, overflow); \
@@ -367,4 +366,4 @@
ASYMM_RESCALE_IMPL(8)
ASYMM_RESCALE_IMPL(16)
-#endif // ARM_COMPUTE_HELPERS_ASYMM_H
\ No newline at end of file
+#endif // ARM_COMPUTE_HELPERS_ASYMM_H
diff --git a/src/core/CL/cl_kernels/im2col.cl b/src/core/CL/cl_kernels/im2col.cl
index 1e85e1b..186d5a8 100644
--- a/src/core/CL/cl_kernels/im2col.cl
+++ b/src/core/CL/cl_kernels/im2col.cl
@@ -23,12 +23,7 @@
*/
#include "helpers.h"
-#if defined(FIXED_POINT_POSITION)
-#include "fixed_point.h"
-#endif // FIXED_POINT_POSITION
-
#if defined(DATA_TYPE) && defined(ELEMENT_SIZE)
-#if !defined(FIXED_POINT_POSITION)
#if ELEMENT_SIZE == 1
#define COND_DATA_TYPE char
@@ -40,17 +35,17 @@
#error "Element size not support"
#endif // ELEMENT_SIZE
-#if defined(CONVOLVED_WIDTH) && defined(STRIDE_Y) && defined(KERNEL_DEPTH)
-/** This kernel performs a reshaping of the input tensor to a tensor used to perform convolution using GEMM when the kernel size is 1x1 and the stride_x = 1
+#if defined(CONVOLVED_WIDTH) && defined(STRIDE_Y) && defined(SRC_DEPTH)
+/** This opencl kernel performs im2col when the kernel size is 1x1, stride_x = 1 and the data layout is NCHW. Each work-item processes 4 consecutive elements along X
*
- * @note This kernel computes 4 elements
* @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
* @note The width of output tensor after matrix multiplication must be passed at compile time using -DCONVOLVED_WIDTH: e.g. -DCONVOLVED_WIDTH=34
- * @note The kernel depth must be passed at compile time using -DKERNEL_DEPTH: e.g. -DKERNEL_DEPTH=3
+ * @note The number of input channels must be passed at compile time using -DSRC_DEPTH: e.g. -DSRC_DEPTH=3
* @note The stride along the Y direction must be passed at compile time using -DSTRIDE_Y: e.g. -DSTRIDE_Y=1
* @note In case biases will be added to the convolution -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
+ * @note In case grouping is performed, the number of groups must be passed at compile time using -DNUM_GROUPS: e.g. -DNUM_GROUPS=4
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/QASYMM8/QS16/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8/F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -63,20 +58,26 @@
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes).
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes).
*/
-__kernel void im2col1x1_stridex1_dchw(
+__kernel void im2col1x1_stridex1_nchw(
TENSOR3D_DECLARATION(src),
+#if defined(NUM_GROUPS)
+ TENSOR3D_DECLARATION(dst),
+#else // defined(NUM_GROUPS)
IMAGE_DECLARATION(dst),
+#endif // defined(NUM_GROUPS)
uint src_stride_w,
uint dst_stride_w)
{
- const uint xc = get_global_id(0) * 4; // x coordinate in the convolved tensor
- const uint yc = get_global_id(1); // y coordinate in the convolved tensor
- const uint ch = get_global_id(2) % KERNEL_DEPTH; // input feature map
- const uint batch = get_global_id(2) / KERNEL_DEPTH; // batch size
+ const uint xc = get_global_id(0) * 4; // x coordinate in the convolved tensor
+ const uint yc = get_global_id(1); // y coordinate in the convolved tensor
+ const uint ch = get_global_id(2) % SRC_DEPTH; // input feature map
+ const uint batch = get_global_id(2) / SRC_DEPTH; // batch size
// Clamp xc
// The strategy clamps at "xc" as it will be a valid value for sure
@@ -92,13 +93,22 @@
const uint yi = yc * STRIDE_Y;
// Calculate output indices
- const uint xo = ch;
+
+#if defined(NUM_GROUPS)
+ const uint xo = ch % (SRC_DEPTH / NUM_GROUPS);
+ const uint zo = ch / (SRC_DEPTH / NUM_GROUPS);
+#else // defined(NUM_GROUPS)
+ const uint xo = ch;
+#endif // defined(NUM_GROUPS)
const uint4 yo = xc_clamped + yc * CONVOLVED_WIDTH; // Index of the convolution
// Get input and output address
__global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + xi * src_stride_x + yi * src_stride_y + ch * src_stride_z + batch * src_stride_w;
-
+#if defined(NUM_GROUPS)
+ __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + xo * dst_stride_x + zo * dst_stride_z + batch * dst_stride_w;
+#else // defined(NUM_GROUPS)
__global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + xo * dst_stride_x + batch * dst_stride_w;
+#endif // defined(NUM_GROUPS)
VEC_DATA_TYPE(DATA_TYPE, 4)
data = vload4(0, (__global DATA_TYPE *)input_ptr);
@@ -112,7 +122,11 @@
*(__global DATA_TYPE *)(output_ptr + yo.s3 * dst_stride_y) = data.s3;
#ifdef HAS_BIAS
- if(ch == (KERNEL_DEPTH - 1))
+#if defined(NUM_GROUPS)
+ if(xo == (SRC_DEPTH / NUM_GROUPS - 1))
+#else // defined(NUM_GROUPS)
+ if(ch == (SRC_DEPTH - 1))
+#endif // defined(NUM_GROUPS)
{
*((__global DATA_TYPE *)(output_ptr + yo.s0 * dst_stride_y) + 1) = 1.0f;
*((__global DATA_TYPE *)(output_ptr + yo.s1 * dst_stride_y) + 1) = 1.0f;
@@ -121,21 +135,24 @@
}
#endif // HAS_BIAS
}
-#endif // defined(CONVOLVED_WIDTH) && defined(STRIDE_Y) && defined(KERNEL_DEPTH)
+#endif // defined(CONVOLVED_WIDTH) && defined(STRIDE_Y) && defined(SRC_DEPTH)
-#if defined(CONVOLVED_WIDTH) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(KERNEL_DEPTH) && defined(PAD_LEFT) && defined(PAD_RIGHT) && defined(PAD_TOP) && defined(PAD_BOTTOM) && defined(PAD_VALUE)
-/** This kernel performs a reshaping of the input tensor to a tensor used to perform convolution using GEMM when the kernel size is 3x3
+#if defined(CONVOLVED_WIDTH) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(SRC_DEPTH) && defined(PAD_LEFT) && defined(PAD_RIGHT) && defined(PAD_TOP) && defined(PAD_BOTTOM) && defined(PAD_VALUE)
+#if defined(DILATION_X) && defined(DILATION_Y)
+/** This opencl kernel performs a generic im2col implementation when the data layout is NCHW
*
* @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
* @note The width and height of the input tensor must be passed at compile time using -DSRC_WIDTH and -DSRC_HEIGHT: e.g. -DSRC_WIDTH=128 and -DSRC_HEIGHT=128
* @note The width of output tensor after matrix multiplication must be passed at compile time using -DCONVOLVED_WIDTH: e.g. -DCONVOLVED_WIDTH=34
- * @note The kernel depth must be passed at compile time using -DKERNEL_DEPTH: e.g. -DKERNEL_DEPTH=3
+ * @note The kernel width, height and depth must be passed at compile time using -DKERNEL_WIDTH, -DKERNEL_HEIGHT and -DSRC_DEPTH: e.g. -DKERNEL_WIDTH=3, -DKERNEL_HEIGHT=3 and -DSRC_DEPTH=64
* @note The pad_left, pad_right, pad_top and pad_bottom must be passed at compile time using -DPAD_LEFT, -DPAD_RIGHT, -DPAD_TOP and -DPAD_BOTTOM: e.g. -DPAD_LEFT=1, -DPAD_RIGHT=2, -DPAD_TOP=3 and -DPAD_BOTTOM=2
* @note The zero value to store in case we load values out-of-bounds must be passed at compile time using -DPAD_VALUE: e.g. -DPAD_VALUE=0.0
* @note The stride along the X and Y directions must be passed at compile time using -DSTRIDE_X and -DSTRIDE_Y: e.g. -DSTRIDE_X=1 and -DSTRIDE_Y=1
+ * @note The dilation_x and dilation_y must be passed at compile time using -DDILATION_X and -DDILATION_Y: e.g. -DDILATION_X=1, -DDILATION_Y=1
* @note In case biases will be added to the convolution -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
+ * @note In case grouping is performed, the number of groups must be passed at compile time using -DNUM_GROUPS: e.g. -DNUM_GROUPS=4
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/QASYMM8/QS16/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8/F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -148,33 +165,147 @@
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes).
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes).
*/
-__kernel void im2col3x3_dchw(
+__kernel void im2col_generic_nchw( // each work-item copies one dilated KERNEL_WIDTH x KERNEL_HEIGHT patch of input channel ch into output row yo
TENSOR3D_DECLARATION(src),
+#if defined(NUM_GROUPS)
+ TENSOR3D_DECLARATION(dst), // grouped output: one Z plane per group (see zo below)
+#else // defined(NUM_GROUPS)
IMAGE_DECLARATION(dst),
+#endif // defined(NUM_GROUPS)
uint src_stride_w,
uint dst_stride_w)
{
- const int xc = get_global_id(0); // x coordinate in the convolved tensor
- const int yc = get_global_id(1); // y coordinate in the convolved tensor
- const int ch = get_global_id(2) % KERNEL_DEPTH; // input feature map
- const int batch = get_global_id(2) / KERNEL_DEPTH; // batch size
+ const int xc = get_global_id(0); // x coordinate in the convolved tensor
+ const int yc = get_global_id(1); // y coordinate in the convolved tensor
+ const int ch = get_global_id(2) % SRC_DEPTH; // input feature map
+ const int batch = get_global_id(2) / SRC_DEPTH; // batch index
// Calculate input indices
const int xi = xc * STRIDE_X - PAD_LEFT;
const int yi = yc * STRIDE_Y - PAD_TOP;
// Calculate output indices
- const int xo = ch * 9; // 3x3
+#if defined(NUM_GROUPS)
+ const int xo = (ch % (SRC_DEPTH / NUM_GROUPS)) * KERNEL_WIDTH * KERNEL_HEIGHT; // column offset of this channel's patch within its group
+ const int zo = ch / (SRC_DEPTH / NUM_GROUPS); // group index, selects the output Z plane
+#else // defined(NUM_GROUPS)
+ const int xo = ch * KERNEL_WIDTH * KERNEL_HEIGHT; // column offset of this channel's patch
+#endif // defined(NUM_GROUPS)
+ const int yo = xc + yc * CONVOLVED_WIDTH; // Index of the convolution
+
+ __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + ch * src_stride_z + batch * src_stride_w; // channel ch of this batch; x/y offsets are added per tap
+#if defined(NUM_GROUPS)
+ __global DATA_TYPE *output_ptr = ((__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + yo * dst_stride_y + zo * dst_stride_z + batch * dst_stride_w)) + xo;
+#else // defined(NUM_GROUPS)
+ __global DATA_TYPE *output_ptr = ((__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + yo * dst_stride_y + batch * dst_stride_w)) + xo;
+#endif // defined(NUM_GROUPS)
+
+ // Linearize convolution elements row by row; output_ptr advances by one element per kernel tap
+ for(int yk = 0; yk < KERNEL_HEIGHT; ++yk)
+ {
+ int y = yi + yk * DILATION_Y; // input row of this tap (dilated)
+ for(int xk = 0; xk < KERNEL_WIDTH; ++xk, ++output_ptr)
+ {
+ int x = xi + xk * DILATION_X; // input column of this tap (dilated)
+#if PAD_LEFT == 0 && PAD_TOP == 0 && PAD_RIGHT == 0 && PAD_BOTTOM == 0 // no padding: bounds check is skipped
+ *output_ptr = *((__global DATA_TYPE *)(input_ptr + x * src_stride_x + y * src_stride_y));
+#else // PAD_LEFT == 0 && PAD_TOP == 0 && PAD_RIGHT == 0 && PAD_BOTTOM == 0
+ if(x < 0 || x >= SRC_WIDTH || y < 0 || y >= SRC_HEIGHT) // tap lies in the padding region
+ {
+ *output_ptr = PAD_VALUE;
+ }
+ else
+ {
+ *output_ptr = *((__global DATA_TYPE *)(input_ptr + x * src_stride_x + y * src_stride_y));
+ }
+#endif // PAD_LEFT == 0 && PAD_TOP == 0 && PAD_RIGHT == 0 && PAD_BOTTOM == 0
+ }
+ }
+
+#ifdef HAS_BIAS
+#if defined(NUM_GROUPS)
+ if((xo / (KERNEL_WIDTH * KERNEL_HEIGHT)) == (SRC_DEPTH / NUM_GROUPS - 1)) // last channel of this group
+#else // defined(NUM_GROUPS)
+ if(ch == (SRC_DEPTH - 1)) // last input channel
+#endif // defined(NUM_GROUPS)
+ {
+ *output_ptr = 1.0f; // output_ptr is one past the last tap here: write the bias column
+ }
+#endif // HAS_BIAS
+}
+#endif // defined(DILATION_X) && defined(DILATION_Y)
+
+/** This opencl kernel performs im2col when the kernel size is 3x3 and the data layout is NCHW
+ *
+ * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
+ * @note The width and height of the input tensor must be passed at compile time using -DSRC_WIDTH and -DSRC_HEIGHT: e.g. -DSRC_WIDTH=128 and -DSRC_HEIGHT=128
+ * @note The width of output tensor after matrix multiplication must be passed at compile time using -DCONVOLVED_WIDTH: e.g. -DCONVOLVED_WIDTH=34
+ * @note The number of input channels must be passed at compile time using -DSRC_DEPTH: e.g. -DSRC_DEPTH=3
+ * @note The pad_left, pad_right, pad_top and pad_bottom must be passed at compile time using -DPAD_LEFT, -DPAD_RIGHT, -DPAD_TOP and -DPAD_BOTTOM: e.g. -DPAD_LEFT=1, -DPAD_RIGHT=2, -DPAD_TOP=3 and -DPAD_BOTTOM=2
+ * @note The zero value to store in case we load values out-of-bounds must be passed at compile time using -DPAD_VALUE: e.g. -DPAD_VALUE=0.0
+ * @note The stride along the X and Y directions must be passed at compile time using -DSTRIDE_X and -DSTRIDE_Y: e.g. -DSTRIDE_X=1 and -DSTRIDE_Y=1
+ * @note In case biases will be added to the convolution -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
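+ * @note In case grouping is performed, the number of groups must be passed at compile time using -DNUM_GROUPS: e.g. -DNUM_GROUPS=4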
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8/F16/F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes).
+ * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes).
+ */
+__kernel void im2col3x3_nchw(
+ TENSOR3D_DECLARATION(src),
+#if defined(NUM_GROUPS)
+ TENSOR3D_DECLARATION(dst),
+#else // defined(NUM_GROUPS)
+ IMAGE_DECLARATION(dst),
+#endif // defined(NUM_GROUPS)
+ uint src_stride_w,
+ uint dst_stride_w)
+{
+ const int xc = get_global_id(0); // x coordinate in the convolved tensor
+ const int yc = get_global_id(1); // y coordinate in the convolved tensor
+ const int ch = get_global_id(2) % SRC_DEPTH; // input feature map
+ const int batch = get_global_id(2) / SRC_DEPTH; // batch size
+
+ // Calculate input indices
+ const int xi = xc * STRIDE_X - PAD_LEFT;
+ const int yi = yc * STRIDE_Y - PAD_TOP;
+
+ // Calculate output indices
+#if defined(NUM_GROUPS)
+ const int xo = (ch % (SRC_DEPTH / NUM_GROUPS)) * 9; // 3x3
+ const int zo = ch / (SRC_DEPTH / NUM_GROUPS);
+#else // defined(NUM_GROUPS)
+ const int xo = ch * 9; // 3x3
+#endif // defined(NUM_GROUPS)
const int yo = xc + yc * CONVOLVED_WIDTH; // Index of the convolution
// Get input and output address
__global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + xi * (int)src_stride_x + yi * (int)src_stride_y + ch * src_stride_z + batch * src_stride_w;
-
+#if defined(NUM_GROUPS)
+ __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + xo * dst_stride_x + yo * dst_stride_y + zo * dst_stride_z + batch * dst_stride_w;
+#else // defined(NUM_GROUPS)
__global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + xo * dst_stride_x + yo * dst_stride_y + batch * dst_stride_w;
+#endif // defined(NUM_GROUPS)
VEC_DATA_TYPE(DATA_TYPE, 3)
row0 = vload3(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
@@ -204,25 +335,30 @@
*((__global DATA_TYPE *)output_ptr + 8) = row2.s2;
#ifdef HAS_BIAS
- if(ch == (KERNEL_DEPTH - 1))
+#if defined(NUM_GROUPS)
+ if((xo / 9) == (SRC_DEPTH / NUM_GROUPS - 1))
+#else // defined(NUM_GROUPS)
+ if(ch == (SRC_DEPTH - 1))
+#endif // defined(NUM_GROUPS)
{
*((__global DATA_TYPE *)output_ptr + 9) = 1.0f;
}
#endif // HAS_BIAS
}
-/** This kernel performs a reshaping of the input tensor to a tensor used to perform convolution using GEMM when the kernel size is 5x5
+/** This opencl kernel performs im2col when the kernel size is 5x5 and the data layout is NCHW
*
* @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
* @note The width and height of the input tensor must be passed at compile time using -DSRC_WIDTH and -DSRC_HEIGHT: e.g. -DSRC_WIDTH=128 and -DSRC_HEIGHT=128
* @note The width of output tensor after matrix multiplication must be passed at compile time using -DCONVOLVED_WIDTH: e.g. -DCONVOLVED_WIDTH=34
- * @note The kernel depth must be passed at compile time using -DKERNEL_DEPTH: e.g. -DKERNEL_DEPTH=3
+ * @note The number of input channels must be passed at compile time using -DSRC_DEPTH: e.g. -DSRC_DEPTH=3
* @note The pad_left, pad_right, pad_top and pad_bottom must be passed at compile time using -DPAD_LEFT, -DPAD_RIGHT, -DPAD_TOP and -DPAD_BOTTOM: e.g. -DPAD_LEFT=1, -DPAD_RIGHT=2, -DPAD_TOP=3 and -DPAD_BOTTOM=2
* @note The zero value to store in case we load values out-of-bounds must be passed at compile time using -DPAD_VALUE: e.g. -DPAD_VALUE=0.0
* @note The stride along the X and Y directions must be passed at compile time using -DSTRIDE_X and -DSTRIDE_Y: e.g. -DSTRIDE_X=1 and -DSTRIDE_Y=1
* @note In case biases will be added to the convolution -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
+ * @note In case grouping is performed, the number of groups must be passed at compile time using -DNUM_GROUPS: e.g. -DNUM_GROUPS=4
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/QASYMM8/QS16/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8/F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -235,27 +371,38 @@
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes).
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes).
*/
-__kernel void im2col5x5_dchw(
+__kernel void im2col5x5_nchw(
TENSOR3D_DECLARATION(src),
+#if defined(NUM_GROUPS)
+ TENSOR3D_DECLARATION(dst),
+#else // defined(NUM_GROUPS)
IMAGE_DECLARATION(dst),
+#endif // defined(NUM_GROUPS)
uint src_stride_w,
uint dst_stride_w)
{
- const int xc = get_global_id(0); // x coordinate in the convolved tensor
- const int yc = get_global_id(1); // y coordinate in the convolved tensor
- const int ch = get_global_id(2) % KERNEL_DEPTH; // input feature map
- const int batch = get_global_id(2) / KERNEL_DEPTH; // batch size
+ const int xc = get_global_id(0); // x coordinate in the convolved tensor
+ const int yc = get_global_id(1); // y coordinate in the convolved tensor
+ const int ch = get_global_id(2) % SRC_DEPTH; // input feature map
+ const int batch = get_global_id(2) / SRC_DEPTH; // batch size
// Calculate input indices
const int xi = xc * STRIDE_X - PAD_LEFT;
const int yi = yc * STRIDE_Y - PAD_TOP;
// Calculate output indices
- const int xo = ch * 25; // 5x5
+#if defined(NUM_GROUPS)
+ const int xo = (ch % (SRC_DEPTH / NUM_GROUPS)) * 25; // 5x5
+ const int zo = ch / (SRC_DEPTH / NUM_GROUPS);
+#else // defined(NUM_GROUPS)
+ const int xo = ch * 25; // 5x5
+#endif // defined(NUM_GROUPS)
const int yo = xc + yc * CONVOLVED_WIDTH; // Index of the convolution
#if PAD_LEFT != 0 || PAD_TOP != 0 || PAD_RIGHT != 0 || PAD_BOTTOM != 0
@@ -276,8 +423,11 @@
// Get input and output address
__global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + xi * (int)src_stride_x + yi * (int)src_stride_y + ch * src_stride_z + batch * src_stride_w;
-
+#if defined(NUM_GROUPS)
+ __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + xo * dst_stride_x + yo * dst_stride_y + zo * dst_stride_z + batch * dst_stride_w;
+#else // defined(NUM_GROUPS)
__global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + xo * dst_stride_x + yo * dst_stride_y + batch * dst_stride_w;
+#endif // defined(NUM_GROUPS)
{
VEC_DATA_TYPE(DATA_TYPE, 4)
@@ -378,24 +528,29 @@
}
#ifdef HAS_BIAS
- if(ch == (KERNEL_DEPTH - 1))
+#if defined(NUM_GROUPS)
+ if((xo / 25) == (SRC_DEPTH / NUM_GROUPS - 1))
+#else // defined(NUM_GROUPS)
+ if(ch == (SRC_DEPTH - 1))
+#endif // defined(NUM_GROUPS)
{
*((__global DATA_TYPE *)output_ptr) = 1.0f;
}
#endif // HAS_BIAS
}
-#endif // defined(CONVOLVED_WIDTH) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(KERNEL_DEPTH) && defined(PAD_LEFT) && defined(PAD_RIGHT) && defined(PAD_TOP) && defined(PAD_BOTTOM) && defined(PAD_VALUE)
+#endif // defined(CONVOLVED_WIDTH) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(SRC_DEPTH) && defined(PAD_LEFT) && defined(PAD_RIGHT) && defined(PAD_TOP) && defined(PAD_BOTTOM) && defined(PAD_VALUE)
-#if defined(CONVOLVED_WIDTH) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(KERNEL_DEPTH)
-/** This kernel performs a reshaping of the input tensor to a tensor used to perform convolution using GEMM when the kernel size is 11x11
+#if defined(CONVOLVED_WIDTH) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(SRC_DEPTH)
+/** This opencl kernel performs im2col when the kernel size is 11x11, there is no padding and the data layout is NCHW
*
* @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
* @note The width of output tensor after matrix multiplication must be passed at compile time using -DCONVOLVED_WIDTH: e.g. -DCONVOLVED_WIDTH=34
- * @note The kernel depth must be passed at compile time using -DKERNEL_DEPTH: e.g. -DKERNEL_DEPTH=3
+ * @note The number of input channels must be passed at compile time using -DSRC_DEPTH: e.g. -DSRC_DEPTH=3
* @note The stride along the X and Y directions must be passed at compile time using -DSTRIDE_X and -DSTRIDE_Y: e.g. -DSTRIDE_X=1 and -DSTRIDE_Y=1
* @note In case biases will be added to the convolution -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
+ * @note In case grouping is performed, the number of groups must be passed at compile time using -DNUM_GROUPS: e.g. -DNUM_GROUPS=4
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/QASYMM8/QS16/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8/F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -408,33 +563,48 @@
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes).
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes).
*/
-__kernel void im2col11x11_padx0_pady0_dchw(
+__kernel void im2col11x11_padx0_pady0_nchw(
TENSOR3D_DECLARATION(src),
+#if defined(NUM_GROUPS)
+ TENSOR3D_DECLARATION(dst),
+#else // defined(NUM_GROUPS)
IMAGE_DECLARATION(dst),
+#endif // defined(NUM_GROUPS)
uint src_stride_w,
uint dst_stride_w)
{
- const int xc = get_global_id(0); // x coordinate in the convolved tensor
- const int yc = get_global_id(1); // y coordinate in the convolved tensor
- const int ch = get_global_id(2) % KERNEL_DEPTH; // input feature map
- const int batch = get_global_id(2) / KERNEL_DEPTH; // batch size
+ const int xc = get_global_id(0); // x coordinate in the convolved tensor
+ const int yc = get_global_id(1); // y coordinate in the convolved tensor
+ const int ch = get_global_id(2) % SRC_DEPTH; // input feature map
+ const int batch = get_global_id(2) / SRC_DEPTH; // batch size
// Calculate input indices
const int xi = xc * STRIDE_X;
const int yi = yc * STRIDE_Y;
// Calculate output indices
- const int xo = ch * 121; // 11x11
+#if defined(NUM_GROUPS)
+ const int xo = (ch % (SRC_DEPTH / NUM_GROUPS)) * 121; // 11x11
+ const int zo = ch / (SRC_DEPTH / NUM_GROUPS);
+#else // defined(NUM_GROUPS)
+ const int xo = ch * 121; // 11x11
+#endif // defined(NUM_GROUPS)
const int yo = xc + yc * CONVOLVED_WIDTH; // Index of the convolution
// Get input and output address
__global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + xi * src_stride_x + yi * src_stride_y + ch * src_stride_z + batch * src_stride_w;
-
+#if defined(NUM_GROUPS)
+ __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + xo * dst_stride_x + yo * dst_stride_y + zo * dst_stride_z + batch * dst_stride_w;
+#else // defined(NUM_GROUPS)
__global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + xo * dst_stride_x + yo * dst_stride_y + batch * dst_stride_w;
+#endif // defined(NUM_GROUPS)
+
{
VEC_DATA_TYPE(DATA_TYPE, 8)
row00 = vload8(0, (__global DATA_TYPE *)(input_ptr));
@@ -578,25 +748,29 @@
}
#ifdef HAS_BIAS
- if(ch == (KERNEL_DEPTH - 1))
+#if defined(NUM_GROUPS)
+ if((xo / 121) == (SRC_DEPTH / NUM_GROUPS - 1))
+#else // defined(NUM_GROUPS)
+ if(ch == (SRC_DEPTH - 1))
+#endif // defined(NUM_GROUPS)
{
*((__global DATA_TYPE *)output_ptr) = 1.0f;
}
#endif // HAS_BIAS
}
-#endif // defined(CONVOLVED_WIDTH) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(KERNEL_DEPTH)
-#endif // !defined(FIXED_POINT_POSITION)
+#endif // defined(CONVOLVED_WIDTH) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(SRC_DEPTH)
-#if defined(CONVOLVED_WIDTH) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(KERNEL_WIDTH) && defined(KERNEL_HEIGHT) && defined(KERNEL_DEPTH) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(VECTOR_SIZE) && defined(WIDTH_MOD_VECTOR_SIZE)
-/** This kernel reshapes the input tensor to a tensor used to perform convolution using GEMM when
- * the kernel width is greater than 1 (except when the kernel size is 3x3) and pad_x == pad_y == 0.
+#if defined(CONVOLVED_WIDTH) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(KERNEL_WIDTH) && defined(KERNEL_HEIGHT) && defined(SRC_DEPTH) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(VECTOR_SIZE) && defined(WIDTH_MOD_VECTOR_SIZE)
+/** This opencl kernel performs im2col when the kernel size is greater than 1x1, there is no padding and the data layout is NCHW
*
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float.
* @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=4.
* @note The width modulo vector size must be passed at compile time using -DWIDTH_MOD_VECTOR_SIZE e.g. -DWIDTH_MOD_VECTOR_SIZE=3.
+ * @note The stride along the X and Y directions must be passed at compile time using -DSTRIDE_X and -DSTRIDE_Y: e.g. -DSTRIDE_X=1 and -DSTRIDE_Y=1
* @note In case biases will be added to the convolution -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
+ * @note In case grouping is performed, the number of groups must be passed at compile time using -DNUM_GROUPS: e.g. -DNUM_GROUPS=4
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/QS16/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -609,29 +783,47 @@
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes).
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes).
*/
-__kernel void im2col_generic_padx0_pady0_dchw(
+__kernel void im2col_generic_padx0_pady0_nchw( // NCHW im2col without padding; with NUM_GROUPS each group is written to its own Z plane of dst
TENSOR3D_DECLARATION(src),
+#if defined(NUM_GROUPS)
+ TENSOR3D_DECLARATION(dst),
+#else // defined(NUM_GROUPS)
IMAGE_DECLARATION(dst),
+#endif // defined(NUM_GROUPS)
uint src_stride_w,
uint dst_stride_w)
{
- const int xc = get_global_id(0); // x coordinate in the convolved tensor
- const int yc = get_global_id(1); // y coordinate in the convolved tensor
- const int ch = get_global_id(2) % KERNEL_DEPTH; // input feature map
- const int batch = get_global_id(2) / KERNEL_DEPTH; // batch size
+ const int xc = get_global_id(0); // x coordinate in the convolved tensor
+ const int yc = get_global_id(1); // y coordinate in the convolved tensor
+ const int ch = get_global_id(2) % SRC_DEPTH; // input feature map
+ const int batch = get_global_id(2) / SRC_DEPTH; // batch size
// Calculate input indices
const int xi = xc * STRIDE_X;
const int yi = yc * STRIDE_Y;
+
// Calculate output indices
+#if defined(NUM_GROUPS)
+ const int xo = (ch % (SRC_DEPTH / NUM_GROUPS)) * KERNEL_WIDTH * KERNEL_HEIGHT; // Column offset of this channel's patch inside its group
+ const int zo = ch / (SRC_DEPTH / NUM_GROUPS); // Group index: selects the Z plane of dst
+#else // defined(NUM_GROUPS)
const int xo = ch * KERNEL_WIDTH * KERNEL_HEIGHT;
- const int yo = xc + yc * CONVOLVED_WIDTH; // Index of the convolution
- __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + ch * src_stride_z + batch * src_stride_w;
+#endif // defined(NUM_GROUPS)
+ const int yo = xc + yc * CONVOLVED_WIDTH; // Index of the convolution
+
+ __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + ch * src_stride_z + batch * src_stride_w;
+#if defined(NUM_GROUPS)
+ __global DATA_TYPE *output_ptr = ((__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + yo * dst_stride_y + zo * dst_stride_z + batch * dst_stride_w)) + xo; // Row yo of group plane zo
+#else // defined(NUM_GROUPS)
__global DATA_TYPE *output_ptr = ((__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + yo * dst_stride_y + batch * dst_stride_w)) + xo;
+#endif // defined(NUM_GROUPS)
+
// Linearize convolution elements
for(int y = yi, y_e = yi + KERNEL_HEIGHT; y < y_e; ++y)
{
@@ -658,32 +850,32 @@
} /* End of loop over KERNEL_HEIGHT */
#ifdef HAS_BIAS
- if(ch == (KERNEL_DEPTH - 1))
+#if defined(NUM_GROUPS)
+ if((xo / (KERNEL_WIDTH * KERNEL_HEIGHT)) == (SRC_DEPTH / NUM_GROUPS - 1)) // Last channel of this group appends the bias 1
+#else // defined(NUM_GROUPS)
+ if(ch == (SRC_DEPTH - 1)) // Last input channel appends the bias 1
+#endif // defined(NUM_GROUPS)
{
-#ifdef FIXED_POINT_POSITION
- *output_ptr = (DATA_TYPE)(1 << FIXED_POINT_POSITION);
-#else // FIXED_POINT_POSITION
*output_ptr = 1.0f;
-#endif // FIXED_POINT_POSITION
}
#endif // HAS_BIAS
}
-#endif //defined(CONVOLVED_WIDTH) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(PAD_RIGHT) && defined(PAD_BOTTOM) && defined(KERNEL_WIDTH) && defined(KERNEL_HEIGHT) && defined(KERNEL_DEPTH) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(VECTOR_SIZE) && defined(WIDTH_MOD_VECTOR_SIZE)
+#endif //defined(CONVOLVED_WIDTH) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(PAD_RIGHT) && defined(PAD_BOTTOM) && defined(KERNEL_WIDTH) && defined(KERNEL_HEIGHT) && defined(SRC_DEPTH) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(VECTOR_SIZE) && defined(WIDTH_MOD_VECTOR_SIZE)
-#if defined(CONVOLVED_WIDTH) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(KERNEL_WIDTH) && defined(KERNEL_HEIGHT) && defined(KERNEL_DEPTH) && defined(PAD_LEFT) && defined(PAD_RIGHT) && defined(PAD_TOP) && defined(PAD_BOTTOM) && defined(PAD_VALUE)
-/** This kernel performs a reshaping of the input tensor to a tensor used to perform convolution using GEMM.
+#if defined(CONVOLVED_WIDTH) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(KERNEL_WIDTH) && defined(KERNEL_HEIGHT) && defined(SRC_DEPTH) && defined(PAD_LEFT) && defined(PAD_RIGHT) && defined(PAD_TOP) && defined(PAD_BOTTOM) && defined(PAD_VALUE) && defined(VECTOR_SIZE) && defined(LAST_ACCESSED)
+
+#define VECTOR_N VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
+
+/** This kernel performs im2col when the kernel size is 3x3 and the data layout is NHWC
*
+ * @note This kernel computes VECTOR_SIZE elements
* @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
- * @note The width and height of the input tensor must be passed at compile time using -DSRC_WIDTH and -DSRC_HEIGHT: e.g. -DSRC_WIDTH=128 and -DSRC_HEIGHT=128
* @note The width of output tensor after matrix multiplication must be passed at compile time using -DCONVOLVED_WIDTH: e.g. -DCONVOLVED_WIDTH=34
- * @note The kernel width, height and depth must be passed at compile time using -DKERNEL_WIDTH, -DKERNEL_HEIGHT and -DKERNEL_DEPTH: e.g. -DKERNEL_WIDTH=3, -DKERNEL_HEIGHT=3 and -DKERNEL_DEPTH=64
- * @note The pad_left, pad_right, pad_top and pad_bottom must be passed at compile time using -DPAD_LEFT, -DPAD_RIGHT, -DPAD_TOP and -DPAD_BOTTOM: e.g. -DPAD_LEFT=1, -DPAD_RIGHT=2, -DPAD_TOP=3 and -DPAD_BOTTOM=2
- * @note The zero value to store in case we load values out-of-bounds must be passed at compile time using -DPAD_VALUE: e.g. -DPAD_VALUE=0.0
- * @note The stride along the X and Y directions must be passed at compile time using -DSTRIDE_X and -DSTRIDE_Y: e.g. -DSTRIDE_X=1 and -DSTRIDE_Y=1
- * @note The dilation_x and dilation_y must be passed at compile time using -DDILATION_X and -DDILATION_Y: e.g. -DDILATION_X=1, -DDILATION_Y=1
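+ * @note The kernel depth (number of input channels) must be passed at compile time using -DSRC_DEPTH: e.g. -DSRC_DEPTH=3
+ * @note The source size, strides, dilations, paddings and padding value must be passed at compile time using -DSRC_WIDTH, -DSRC_HEIGHT, -DSTRIDE_X, -DSTRIDE_Y, -DDILATION_X, -DDILATION_Y, -DPAD_LEFT, -DPAD_RIGHT, -DPAD_TOP, -DPAD_BOTTOM and -DPAD_VALUE; -DLAST_ACCESSED bounds the first channel processed by a work-item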
* @note In case biases will be added to the convolution -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/QASYMM8/QS16/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8/F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -700,108 +892,230 @@
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes).
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes).
*/
-__kernel void im2col_generic_dchw(
+__kernel void im2col3x3_nhwc(
TENSOR3D_DECLARATION(src),
IMAGE_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
- const int xc = get_global_id(0); // x coordinate in the convolved tensor
- const int yc = get_global_id(1); // y coordinate in the convolved tensor
- const int ch = get_global_id(2) % KERNEL_DEPTH; // input feature map
- const int batch = get_global_id(2) / KERNEL_DEPTH; // batch size
+ const int ch = min((int)(get_global_id(0) * VECTOR_SIZE), LAST_ACCESSED); // input feature map (first of VECTOR_SIZE channels)
+ const int yo = get_global_id(1); // Index of the convolution (row of dst)
+ const int batch = get_global_id(2); // batch size
// Calculate input indices
- const int xi = xc * STRIDE_X - PAD_LEFT;
- const int yi = yc * STRIDE_Y - PAD_TOP;
+ const int xi = (get_global_id(1) % CONVOLVED_WIDTH) * STRIDE_X;
+ const int yi = (get_global_id(1) / (int)CONVOLVED_WIDTH) * STRIDE_Y;
- // Calculate output indices
- const int xo = ch * KERNEL_WIDTH * KERNEL_HEIGHT;
- const int yo = xc + yc * CONVOLVED_WIDTH; // Index of the convolution
+ // Get input and output address
+ __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + ch * sizeof(DATA_TYPE) + batch * (int)src_stride_w;
+ __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + ch * sizeof(DATA_TYPE) + yo * (int)dst_stride_y + batch * (int)dst_stride_w;
- __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + ch * src_stride_z + batch * src_stride_w;
- __global DATA_TYPE *output_ptr = ((__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + yo * dst_stride_y + batch * dst_stride_w)) + xo;
+ int yi_coord = 0;
+ int3 offset = 0;
- // Linearize convolution elements
- for(int yk = 0; yk < KERNEL_HEIGHT; ++yk)
- {
- int y = yi + yk * DILATION_Y;
- for(int xk = 0; xk < KERNEL_WIDTH; ++xk, ++output_ptr)
- {
- int x = xi + xk * DILATION_X;
-#if PAD_LEFT == 0 && PAD_TOP == 0 && PAD_RIGHT == 0 && PAD_BOTTOM == 0
- *output_ptr = *((__global DATA_TYPE *)(input_ptr + x * src_stride_x + y * src_stride_y));
-#else // PAD_LEFT == 0 && PAD_TOP == 0 && PAD_RIGHT == 0 && PAD_BOTTOM == 0
- if(x < 0 || x >= SRC_WIDTH || y < 0 || y >= SRC_HEIGHT)
- {
- *output_ptr = PAD_VALUE;
- }
- else
- {
- *output_ptr = *((__global DATA_TYPE *)(input_ptr + x * src_stride_x + y * src_stride_y));
- }
-#endif // PAD_LEFT == 0 && PAD_TOP == 0 && PAD_RIGHT == 0 && PAD_BOTTOM == 0
- }
- }
+ // Clamp the 3 horizontal input coordinates whenever any padding exists (CLAMP is defined outside the #if so later code can always use it)
+ int3 xi_offset = ((int3)xi + (int3)(0, 1, 2) * DILATION_X - (int3)PAD_LEFT);
+#define CLAMP(x, min_val, max_val) min(max(x, min_val), max_val)
+#if PAD_TOP != 0 || PAD_LEFT != 0 || PAD_BOTTOM != 0 || PAD_RIGHT != 0
+ xi_offset = CLAMP(xi_offset, (int3)0, (int3)(SRC_WIDTH - 1));
+#endif // PAD_TOP != 0 || PAD_LEFT != 0 || PAD_BOTTOM != 0 || PAD_RIGHT != 0
+ xi_offset *= (int3)src_stride_y;
+
+ // Out-of-bound condition for X
+ int3 x_cond = (((int3)xi + (int3)(0, 1, 2) * DILATION_X - (int3)PAD_LEFT) < (int3)0) || (((int3)xi + (int3)(0, 1, 2) * DILATION_X - (int3)PAD_LEFT) >= (int3)SRC_WIDTH);
+
+ // Kernel row 0 (yi == 0)
+ // Clamp yi
+ // yi_coord is cast to unsigned int so that a single min() also clamps negative values
+ // A "-1" 32 bit signed variable converted to unsigned gives 4294967295
+ yi_coord = yi - (int)PAD_TOP;
+
+ // Clamp only if PAD_TOP or PAD_BOTTOM is not equal to 0
+#if PAD_TOP != 0 || PAD_BOTTOM != 0
+ yi_coord = min((uint)yi_coord, (uint)(SRC_HEIGHT - 1));
+#endif // PAD_TOP != 0 || PAD_BOTTOM != 0
+
+ // Compute offset
+ offset = xi_offset + (yi_coord * (int)src_stride_z);
+
+ // Load input values
+ VECTOR_N values0 = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(input_ptr + offset.s0));
+ VECTOR_N values1 = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(input_ptr + offset.s1));
+ VECTOR_N values2 = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(input_ptr + offset.s2));
+
+#if PAD_TOP != 0 || PAD_LEFT != 0 || PAD_BOTTOM != 0 || PAD_RIGHT != 0
+ // Replace invalid values with PAD_VALUE
+ int y_cond = (int)((uint)(yi - (int)PAD_TOP) >= (uint)(SRC_HEIGHT));
+ values0 = select(values0, (VECTOR_N)PAD_VALUE, (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))y_cond || (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))(x_cond.s0));
+ values1 = select(values1, (VECTOR_N)PAD_VALUE, (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))y_cond || (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))(x_cond.s1));
+ values2 = select(values2, (VECTOR_N)PAD_VALUE, (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))y_cond || (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))(x_cond.s2));
+#endif // PAD_TOP != 0 || PAD_LEFT != 0 || PAD_BOTTOM != 0 || PAD_RIGHT != 0
+
+ // Kernel row 1 (yi == 1)
+ // Clamp yi_coord (it can be negative if PAD_TOP > DILATION_Y)
+ yi_coord = yi - (int)PAD_TOP + 1 * DILATION_Y;
+
+ // Clamp only if PAD_TOP or PAD_BOTTOM is not equal to 0
+#if PAD_TOP != 0 || PAD_BOTTOM != 0
+ yi_coord = min((uint)yi_coord, (uint)(SRC_HEIGHT - 1));
+#endif // PAD_TOP != 0 || PAD_BOTTOM != 0
+
+ // Compute offset
+ offset = xi_offset + (yi_coord * (int)src_stride_z);
+
+ // Load input values
+ VECTOR_N values3 = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(input_ptr + offset.s0));
+ VECTOR_N values4 = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(input_ptr + offset.s1));
+ VECTOR_N values5 = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(input_ptr + offset.s2));
+
+#if PAD_TOP != 0 || PAD_LEFT != 0 || PAD_BOTTOM != 0 || PAD_RIGHT != 0
+ // Replace invalid values with PAD_VALUE
+ y_cond = (int)((uint)(yi - (int)PAD_TOP + 1 * DILATION_Y) >= (uint)(SRC_HEIGHT));
+ values3 = select(values3, (VECTOR_N)PAD_VALUE, (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))y_cond || (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))(x_cond.s0));
+ values4 = select(values4, (VECTOR_N)PAD_VALUE, (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))y_cond || (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))(x_cond.s1));
+ values5 = select(values5, (VECTOR_N)PAD_VALUE, (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))y_cond || (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))(x_cond.s2));
+#endif // PAD_TOP != 0 || PAD_LEFT != 0 || PAD_BOTTOM != 0 || PAD_RIGHT != 0
+
+ // Kernel row 2 (yi == 2)
+ // Clamp yi_coord
+ yi_coord = yi - (int)PAD_TOP + 2 * DILATION_Y;
+
+ // Clamp only if PAD_TOP or PAD_BOTTOM is not equal to 0
+#if PAD_TOP != 0 || PAD_BOTTOM != 0
+ yi_coord = min((uint)yi_coord, (uint)(SRC_HEIGHT - 1));
+#endif // PAD_TOP != 0 || PAD_BOTTOM != 0
+
+ // Compute offset
+ offset = xi_offset + (yi_coord * (int)src_stride_z);
+
+ // Load input values
+ VECTOR_N values6 = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(input_ptr + offset.s0));
+ VECTOR_N values7 = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(input_ptr + offset.s1));
+ VECTOR_N values8 = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(input_ptr + offset.s2));
+
+#if PAD_TOP != 0 || PAD_LEFT != 0 || PAD_BOTTOM != 0 || PAD_RIGHT != 0
+ // Replace invalid values with PAD_VALUE
+ y_cond = (int)((uint)(yi - (int)PAD_TOP + 2 * DILATION_Y) >= (uint)(SRC_HEIGHT));
+ values6 = select(values6, (VECTOR_N)PAD_VALUE, (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))y_cond || (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))(x_cond.s0));
+ values7 = select(values7, (VECTOR_N)PAD_VALUE, (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))y_cond || (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))(x_cond.s1));
+ values8 = select(values8, (VECTOR_N)PAD_VALUE, (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))y_cond || (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))(x_cond.s2));
+#endif // PAD_TOP != 0 || PAD_LEFT != 0 || PAD_BOTTOM != 0 || PAD_RIGHT != 0
+
+ // Store the 9 kernel elements, each a block of SRC_DEPTH channels in the dst row
+ VSTORE(VECTOR_SIZE)
+ (values0, 0, (__global DATA_TYPE *)(output_ptr) + 0 * SRC_DEPTH);
+ VSTORE(VECTOR_SIZE)
+ (values1, 0, (__global DATA_TYPE *)(output_ptr) + 1 * SRC_DEPTH);
+ VSTORE(VECTOR_SIZE)
+ (values2, 0, (__global DATA_TYPE *)(output_ptr) + 2 * SRC_DEPTH);
+ VSTORE(VECTOR_SIZE)
+ (values3, 0, (__global DATA_TYPE *)(output_ptr) + 3 * SRC_DEPTH);
+ VSTORE(VECTOR_SIZE)
+ (values4, 0, (__global DATA_TYPE *)(output_ptr) + 4 * SRC_DEPTH);
+ VSTORE(VECTOR_SIZE)
+ (values5, 0, (__global DATA_TYPE *)(output_ptr) + 5 * SRC_DEPTH);
+ VSTORE(VECTOR_SIZE)
+ (values6, 0, (__global DATA_TYPE *)(output_ptr) + 6 * SRC_DEPTH);
+ VSTORE(VECTOR_SIZE)
+ (values7, 0, (__global DATA_TYPE *)(output_ptr) + 7 * SRC_DEPTH);
+ VSTORE(VECTOR_SIZE)
+ (values8, 0, (__global DATA_TYPE *)(output_ptr) + 8 * SRC_DEPTH);
#ifdef HAS_BIAS
- if(ch == (KERNEL_DEPTH - 1))
+ if((ch + VECTOR_SIZE) >= SRC_DEPTH) // Work-item(s) covering the last channels also write the trailing bias 1
{
-#ifdef FIXED_POINT_POSITION
- *output_ptr = (DATA_TYPE)(1 << FIXED_POINT_POSITION);
-#else // FIXED_POINT_POSITION
- *output_ptr = 1.0f;
-#endif // FIXED_POINT_POSITION
+ *((__global DATA_TYPE *)(output_ptr) - ch + SRC_DEPTH * 9) = 1.0f;
}
#endif // HAS_BIAS
}
-#endif // defined(CONVOLVED_WIDTH) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(KERNEL_WIDTH) && defined(KERNEL_HEIGHT) && defined(KERNEL_DEPTH) && defined(PAD_LEFT) && defined(PAD_RIGHT) && defined(PAD_TOP) && defined(PAD_BOTTOM) && defined(PAD_VALUE)
-/**This kernel reshapes the input tensor to a tensor used to perform convolution using GEMM when
- * the kernel width and height are the same of width and height of the input tensor
+/** This OpenCL kernel performs a generic im2col implementation when the data layout is NHWC
*
- * @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=float
- * @note In case biases will be added in late stage, -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
+ * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
+ * @note The width and height of the input tensor must be passed at compile time using -DSRC_WIDTH and -DSRC_HEIGHT: e.g. -DSRC_WIDTH=128 and -DSRC_HEIGHT=128
+ * @note The width of output tensor after matrix multiplication must be passed at compile time using -DCONVOLVED_WIDTH: e.g. -DCONVOLVED_WIDTH=34
+ * @note The kernel width, height and depth must be passed at compile time using -DKERNEL_WIDTH, -DKERNEL_HEIGHT and -DSRC_DEPTH: e.g. -DKERNEL_WIDTH=3, -DKERNEL_HEIGHT=3 and -DSRC_DEPTH=64
+ * @note The pad_left, pad_right, pad_top and pad_bottom must be passed at compile time using -DPAD_LEFT, -DPAD_RIGHT, -DPAD_TOP and -DPAD_BOTTOM: e.g. -DPAD_LEFT=1, -DPAD_RIGHT=2, -DPAD_TOP=3 and -DPAD_BOTTOM=2
+ * @note The value to store in case we load values out-of-bounds must be passed at compile time using -DPAD_VALUE: e.g. -DPAD_VALUE=0.0
+ * @note The strides and dilations along X and Y must be passed at compile time using -DSTRIDE_X, -DSTRIDE_Y, -DDILATION_X and -DDILATION_Y: e.g. -DSTRIDE_X=1, -DSTRIDE_Y=1, -DDILATION_X=1 and -DDILATION_Y=1
+ * @note The number of channels processed per work-item must be passed using -DVECTOR_SIZE (e.g. -DVECTOR_SIZE=4); the first channel of a work-item is clamped to -DLAST_ACCESSED
+ * @note In case biases will be added to the convolution -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/QASYMM8/QS16/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8/F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[out] dst_ptr Pointer to the destination tensor. Same as @p src_ptr
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- * @param[in] width The width of the input tensor
- * @param[in] height The height of the input tensor
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes).
+ * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes).
*/
-__kernel void im2col_reduced_dchw(
+__kernel void im2col_generic_nhwc(
TENSOR3D_DECLARATION(src),
- VECTOR_DECLARATION(dst),
- uint width, uint height)
+ IMAGE_DECLARATION(dst),
+ uint src_stride_w,
+ uint dst_stride_w)
{
- Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ const int ch = min((int)(get_global_id(0) * VECTOR_SIZE), LAST_ACCESSED); // input feature map (first of VECTOR_SIZE channels)
+ const int yo = get_global_id(1); // Index of the convolution (row of dst)
+ const int batch = get_global_id(2); // batch size
- const uint image_size = width * height;
+ // Calculate input indices
+ const int xi = (get_global_id(1) % CONVOLVED_WIDTH) * STRIDE_X;
+ const int yi = (get_global_id(1) / (int)CONVOLVED_WIDTH) * STRIDE_Y;
- __global uchar *tmp_out_ptr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) + get_global_id(1) * width + get_global_id(2) * image_size) * dst_stride_x;
+ // Get input and output address
+ __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + ch * sizeof(DATA_TYPE) + batch * (int)src_stride_w;
+ __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + ch * sizeof(DATA_TYPE) + yo * (int)dst_stride_y + batch * (int)dst_stride_w;
- *((__global DATA_TYPE *)tmp_out_ptr) = *((__global DATA_TYPE *)src.ptr);
+ int i = 0; // Kernel element index (row-major over yk, xk); element i occupies channels [i * SRC_DEPTH, (i + 1) * SRC_DEPTH) of the dst row
+ for(int yk = 0; yk < KERNEL_HEIGHT; ++yk)
+ {
+ // Clamp yi_coord with the OpenCL builtin clamp() so the load stays in bounds for any padding configuration
+ int yi_coord = yi + yk * DILATION_Y - (int)PAD_TOP;
+ yi_coord = clamp(yi_coord, (int)0, (int)(SRC_HEIGHT - 1));
+
+ // Out-of-bound condition for Y
+ int y_border_condition = ((yi + yk * DILATION_Y - (int)PAD_TOP) < (int)0) || ((yi + yk * DILATION_Y - (int)PAD_TOP) >= (int)SRC_HEIGHT);
+
+ for(int xk = 0; xk < KERNEL_WIDTH; ++xk)
+ {
+ // Clamp xi_coord with the OpenCL builtin clamp(); out-of-bound values are replaced by PAD_VALUE below
+ int xi_coord = (xi + xk * DILATION_X - (int)PAD_LEFT);
+ xi_coord = clamp(xi_coord, (int)0, (int)(SRC_WIDTH - 1));
+
+ // Out-of-bound condition for X
+ int x_border_condition = ((xi + xk * DILATION_X - (int)PAD_LEFT) < (int)0) || ((xi + xk * DILATION_X - (int)PAD_LEFT) >= (int)SRC_WIDTH);
+
+ int offset = xi_coord * (int)src_stride_y + (yi_coord * (int)src_stride_z);
+
+ VECTOR_N values0 = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(input_ptr + offset));
+
+ // Replace with PAD_VALUE if the value is out-of-bound (vector || yields all-ones lanes, as select() requires)
+ values0 = select(values0, (VECTOR_N)PAD_VALUE, (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))x_border_condition || (VEC_DATA_TYPE(COND_DATA_TYPE, VECTOR_SIZE))(y_border_condition));
+
+ // Store
+ VSTORE(VECTOR_SIZE)
+ (values0, 0, (__global DATA_TYPE *)(output_ptr) + i * (int)SRC_DEPTH);
+
+ i++;
+ }
+ }
#ifdef HAS_BIAS
- // If it is the last thread in the 3 dimensional workgroup
- if(get_global_id(0) == (get_global_size(0) - 1) && get_global_id(1) == (get_global_size(1) - 1) && get_global_id(2) == (get_global_size(2) - 1))
+ if((ch + VECTOR_SIZE) >= SRC_DEPTH) // Work-item(s) covering the last channels also write the trailing bias 1
{
- tmp_out_ptr += dst_stride_x;
-#ifdef FIXED_POINT_POSITION
- *((__global DATA_TYPE *)tmp_out_ptr) = (DATA_TYPE)(1 << FIXED_POINT_POSITION);
-#else // FIXED_POINT_POSITION
- *((__global DATA_TYPE *)tmp_out_ptr) = (DATA_TYPE)1.0f;
-#endif // FIXED_POINT_POSITION
+ *((__global DATA_TYPE *)(output_ptr) - ch + SRC_DEPTH * KERNEL_WIDTH * KERNEL_HEIGHT) = 1.0f;
}
#endif // HAS_BIAS
}
-#endif // defined(DATA_TYPE) && defined(ELEMENT_SIZE)
\ No newline at end of file
+#endif // defined(CONVOLVED_WIDTH) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(STRIDE_X) && defined(STRIDE_Y) && defined(KERNEL_WIDTH) && defined(KERNEL_HEIGHT) && defined(SRC_DEPTH) && defined(PAD_LEFT) && defined(PAD_RIGHT) && defined(PAD_TOP) && defined(PAD_BOTTOM) && defined(PAD_VALUE) && defined(VECTOR_SIZE) && defined(LAST_ACCESSED)
+#endif // defined(DATA_TYPE) && defined(ELEMENT_SIZE)
diff --git a/src/core/CL/cl_kernels/l2_normalize.cl b/src/core/CL/cl_kernels/l2_normalize.cl
index 8d47631..f58e98b 100644
--- a/src/core/CL/cl_kernels/l2_normalize.cl
+++ b/src/core/CL/cl_kernels/l2_normalize.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -28,11 +28,11 @@
* @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
* @note The data size must be passed at compile time using -DDATA_SIZE e.g. -DDATA_SIZE=32
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: QS8/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[in] sum_ptr Pointer to the source tensor. Supported data types: QS8/F16/F32
+ * @param[in] sum_ptr Pointer to the source tensor. Supported data types: F16/F32
* @param[in] sum_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] sum_step_x sum_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] sum_offset_first_element_in_bytes The offset of the first element in the source tensor
diff --git a/src/core/CL/cl_kernels/mean_stddev.cl b/src/core/CL/cl_kernels/mean_stddev.cl
index 7c29d2f..74d6b0b 100644
--- a/src/core/CL/cl_kernels/mean_stddev.cl
+++ b/src/core/CL/cl_kernels/mean_stddev.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,7 +24,6 @@
#include "helpers.h"
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
-#pragma OPENCL EXTENSION cl_khr_int64_extended_atomics : enable
/** This function calculates the sum and sum of squares of a given input image.
*
@@ -81,4 +80,3 @@
}
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : disable
-#pragma OPENCL EXTENSION cl_khr_int64_extended_atomics : disable
diff --git a/src/core/CL/cl_kernels/normalization_layer.cl b/src/core/CL/cl_kernels/normalization_layer.cl
index bc00252..dbdad27 100644
--- a/src/core/CL/cl_kernels/normalization_layer.cl
+++ b/src/core/CL/cl_kernels/normalization_layer.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,22 +23,6 @@
*/
#include "helpers.h"
-#if defined(FIXED_POINT_POSITION)
-
-#include "fixed_point.h"
-#define MUL_OP(x, y) MUL_SAT_OP_EXPAND((x), (y), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
-#define ADD_OP(x, y) ADD_SAT_OP_EXPAND((x), (y), DATA_TYPE, VEC_SIZE)
-#define DIV_OP(x, y) DIV_SAT_OP_VEC_EXPAND((x), (y), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
-#define EXP_OP(x) EXP_OP_EXPAND((x), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
-#define LOG_OP(x) LOG_OP_EXPAND((x), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
-#define POW_OP(x, y) EXP_OP(MUL_OP(LOG_OP((x)), (y)))
-#define SQCVT_SAT(a) SQCVT_SAT_OP_EXPAND((a), DATA_TYPE, FIXED_POINT_POSITION)
-
-#define LOAD_OP(offset, ptr) vload16(offset, ptr)
-#define STORE_OP(data, offset, ptr) vstore16(data, offset, ptr)
-
-#else // FIXED_POINT_POSITION
-
#define MUL_OP(x, y) ((x) * (y))
#define ADD_OP(x, y) ((x) + (y))
#define DIV_OP(x, y) ((x) / (y))
@@ -48,18 +32,15 @@
#define LOAD_OP(offset, ptr) vload4(offset, ptr)
#define STORE_OP(data, offset, ptr) vstore4(data, offset, ptr)
-#endif // FIXED_POINT_POSITION
-
/** Apply cross-map normalization.
*
* @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
* @note Vector size should be given as a preprocessor argument using -DVEC_SIZE=size, e.g. -DVEC_SIZE=16
* @note The radius should be given as a preprocessor argument using -DRADIUS=size. e.g. -DRADIUS=5
* @note The number of slices should be given as a preprocessor argument using -DNUM_SLICES=size. e.g. -DNUM_SLICES=192
- * @note In case of fixed-point operation -DFIXED_POINT_POSITION=fixed_point_position must be provided: e.g. -DFIXED_POINT_POSITION=3
* @note Scaling coefficient (= alpha/norm_size), beta and kappa need to be passed at compile time using -DCOEFF, -DALPHA and -DKAPPA
*
- * @param[in] input_ptr Pointer to the first source tensor. Supported data types: QS8/QS16/F16/F32
+ * @param[in] input_ptr Pointer to the first source tensor. Supported data types: F16/F32
* @param[in] input_stride_x Stride of the first source tensor in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the first source tensor in Y dimension (in bytes)
@@ -116,10 +97,9 @@
* @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
* @note Vector size should be given as a preprocessor argument using -DVEC_SIZE=size, e.g. -DVEC_SIZE=16
* @note The radius should be given as a preprocessor argument using -DRADIUS=size. e.g. -DRADIUS=5
- * @note In case of fixed-point operation -DFIXED_POINT_POSITION=fixed_point_position must be provided: e.g. -DFIXED_POINT_POSITION=3
* @note Scaling coefficient (= alpha/norm_size), beta and kappa need to be passed at compile time using -DCOEFF, -DALPHA and -DKAPPA
*
- * @param[in] input_ptr Pointer to the first source tensor. Supported data types: QS8/F16/F32
+ * @param[in] input_ptr Pointer to the first source tensor. Supported data types: F16/F32
* @param[in] input_stride_x Stride of the first source tensor in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the first source tensor in Y dimension (in bytes)
diff --git a/src/core/CL/cl_kernels/permute.cl b/src/core/CL/cl_kernels/permute.cl
index 6f978c9..03fc15e 100644
--- a/src/core/CL/cl_kernels/permute.cl
+++ b/src/core/CL/cl_kernels/permute.cl
@@ -29,7 +29,7 @@
* @attention Data type can be passed using the -DDATA_TYPE compile flag, e.g. -DDATA_TYPE=float
* @attention Input tensor depth should be given as a preprocessor argument using -DDEPTH_IN=size. e.g. -DDEPTH_IN=16
*
- * @param[in] input_ptr Pointer to the source image. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
+ * @param[in] input_ptr Pointer to the source image. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
* @param[in] input_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the source image in Y dimension (in bytes)
@@ -63,7 +63,7 @@
* @attention Data type can be passed using the -DDATA_TYPE compile flag, e.g. -DDATA_TYPE=float
* @attention Input tensor depth should be given as a preprocessor argument using -DDEPTH_IN=size. e.g. -DDEPTH_IN=16
*
- * @param[in] input_ptr Pointer to the source image. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
+ * @param[in] input_ptr Pointer to the source image. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
* @param[in] input_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the source image in Y dimension (in bytes)
@@ -97,7 +97,7 @@
* @attention Data type can be passed using the -DDATA_TYPE compile flag, e.g. -DDATA_TYPE=float
* @attention Input tensor depth should be given as a preprocessor argument using -DDEPTH_IN=size. e.g. -DDEPTH_IN=16
*
- * @param[in] input_ptr Pointer to the source image. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
+ * @param[in] input_ptr Pointer to the source image. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
* @param[in] input_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the source image in Y dimension (in bytes)
diff --git a/src/core/CL/cl_kernels/pixelwise_mul_int.cl b/src/core/CL/cl_kernels/pixelwise_mul_int.cl
index b5734a3..c99a08a 100644
--- a/src/core/CL/cl_kernels/pixelwise_mul_int.cl
+++ b/src/core/CL/cl_kernels/pixelwise_mul_int.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,18 +23,6 @@
*/
#include "helpers.h"
-#if defined(FIXED_POINT_POSITION)
-
-#include "fixed_point.h"
-
-#if defined(SATURATE)
-#define MUL_OP(x, y, scale, type, size) MUL_SAT_OP_EXPAND((x), (y), type, size, FIXED_POINT_POSITION)
-#else // SATURATE
-#define MUL_OP(x, y, scale, type, size) MUL_OP_EXPAND((x), (y), type, size, FIXED_POINT_POSITION)
-#endif // SATURATE
-
-#else // FIXED_POINT_POSITION
-
#if defined(SATURATE)
#define CONVERT_OP_INT_STR(x, type, size) (convert_##type##size##_sat(x))
#else // SATURATE
@@ -44,17 +32,14 @@
#define MUL_OP(x, y, scale, type, size) CONVERT_OP_INT((x) * (y) >> scale, type, size)
-#endif // FIXED_POINT_POSITION
-
/** Performs a pixelwise multiplication with integer scale of integer inputs.
*
* @attention The inputs and output data types need to be passed at compile time using -DDATA_TYPE_IN1, -DDATA_TYPE_IN2 and -DDATA_TYPE_OUT:
* e.g. -DDATA_TYPE_IN1=uchar -DDATA_TYPE_IN2=ushort -DDATA_TYPE_OUT=short
* @attention The data_type of the intermediate result of the multiplication should passed as well using -DDATA_TYPE_RES.
* e.g. If one of inputs is S16 -DDATA_TYPE_RES=int should be passed else -DDATA_TYPE_RES=short.
- * @note In case of fixed-point operation -DFIXED_POINT_POSITION=fixed_point_position must be provided: e.g. -DFIXED_POINT_POSITION=3
*
- * @param[in] in1_ptr Pointer to the source image. Supported data types: U8/QS8/QS16/S16
+ * @param[in] in1_ptr Pointer to the source image. Supported data types: U8/S16
* @param[in] in1_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] in1_step_x in1_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] in1_stride_y Stride of the source image in Y dimension (in bytes)
@@ -78,7 +63,7 @@
* @param[in] out_stride_z Stride of the destination image in Y dimension (in bytes)
* @param[in] out_step_z out_stride_z * number of elements along Y processed per workitem(in bytes)
* @param[in] out_offset_first_element_in_bytes The offset of the first element in the destination image
- * @param[in] scale Integer scaling factor. Supported data types: S32 (ignored for QS8 and QS16 as the assumption is scale = 1).
+ * @param[in] scale Integer scaling factor. Supported data types: S32.
*/
__kernel void pixelwise_mul_int(
TENSOR3D_DECLARATION(in1),
diff --git a/src/core/CL/cl_kernels/pooling_layer.cl b/src/core/CL/cl_kernels/pooling_layer.cl
index 2c7ddfd..0808353 100644
--- a/src/core/CL/cl_kernels/pooling_layer.cl
+++ b/src/core/CL/cl_kernels/pooling_layer.cl
@@ -23,28 +23,6 @@
*/
#include "helpers.h"
-#ifdef FIXED_POINT_POSITION
-
-#include "fixed_point.h"
-
-#if defined(POOL_AVG)
-#define POOL_OP(x, y) add_sat(x, y)
-#else /* POOL_AVG */
-#define POOL_OP(x, y) (max((x), (y)))
-#endif /* POOL_AVG */
-
-#define DIV_OP1(x, y) DIV_SAT_OP_EXPAND((x), (y), DATA_TYPE, FIXED_POINT_POSITION)
-#define DIV_OP(x, y) DIV_OP1(x, y << FIXED_POINT_POSITION)
-#define SQRT_OP(x) DIV_OP1((1 << FIXED_POINT_POSITION), (INVSQRT_OP_EXPAND((x), DATA_TYPE, 1, FIXED_POINT_POSITION)))
-
-#if defined(POOL_L2)
-#define POW2_OP(x, vec_size) MUL_SAT_OP_EXPAND((x), (x), DATA_TYPE, vec_size, FIXED_POINT_POSITION)
-#else /* defined(POOL_L2) */
-#define POW2_OP(x, vec_size) (x)
-#endif /* defined(POOL_L2) */
-
-#else /* FIXED_POINT_POSITION */
-
#if defined(POOL_AVG) || defined(POOL_L2)
#define POOL_OP(x, y) ((x) + (y))
#else /* defined(POOL_AVG) || defined(POOL_L2) */
@@ -60,8 +38,6 @@
#define DIV_OP(x, y) (x * (1.f / y))
#define SQRT_OP(x) sqrt((x))
-#endif /* FIXED_POINT_POSITION */
-
#define DIV_OP_NHWC(x, y) (x * (VEC_DATA_TYPE(DATA_TYPE, 8))(1.f / y))
#if STRIDE_X == 1
@@ -201,14 +177,14 @@
/** Performs a pooling function of pool size equal to 2.
*
- * @note Datatype must be passed using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types are QS8/QS16/F16/F32;
+ * @note Datatype must be passed using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types are F16/F32;
* @note In case of average pooling the following information must be passed at compile time:
* -DPOOL_AVG or -DPOOL_L2 must be provided otherwise max pooling will be performed.
* -DMAX_WIDTH and -DMAX_HEIGHT which are the maximum accessible indeces in x and y dimensions (width + pad)
* -DSTRIDE_X and -DSTRIDE_Y which are the steps of the window along the x and y directions
* -DPAD_X and -DPAD_Y which are the pooling paddings in x and y dimension
*
- * @param[in] input_ptr Pointer to the source image. Supported data types: QS8/QS16/F16/F32
+ * @param[in] input_ptr Pointer to the source image. Supported data types: F16/F32
* @param[in] input_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the source image in Y dimension (in bytes)
@@ -265,14 +241,14 @@
/** Performs a pooling function of pool size equal to 3
*
- * @note Datatype must be passed using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types are QS8/QS16/F16/F32;
+ * @note Datatype must be passed using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types are F16/F32;
* @note In case of average pooling the following information must be passed at compile time:
* -DPOOL_AVG or -DPOOL_L2 must be provided otherwise max pooling will be performed.
* -DMAX_WIDTH and -DMAX_HEIGHT which are the maximum accessible indeces in x and y dimensions (width + pad)
* -DSTRIDE_X and -DSTRIDE_Y which are the steps of the window along the x and y directions
* -DPAD_X and -DPAD_Y which are the pooling paddings in x and y dimension
*
- * @param[in] input_ptr Pointer to the source image. Supported data types: QS8/QS16/F16/F32
+ * @param[in] input_ptr Pointer to the source image. Supported data types: F16/F32
* @param[in] input_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the source image in Y dimension (in bytes)
@@ -331,7 +307,7 @@
*(__global DATA_TYPE *)output.ptr = res;
}
-#if defined(POOLING3x3) && !defined(FIXED_POINT_POSITION)
+#if defined(POOLING3x3)
#define CONVERT_OP(data_type) convert_##data_type##4
#define CONVERT_VECTOR4(data_type) CONVERT_OP(data_type)
@@ -353,7 +329,7 @@
/** Performs an optimized pooling function of pool size equal to 3 when the stride_x is less equal than 3
*
- * @note Datatype must be passed using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types are QS8/QS16/F16/F32;
+ * @note Datatype must be passed using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types are F16/F32;
* @note In case of average pooling the following information must be passed at compile time:
* -DPOOL_AVG or -DPOOL_L2 must be provided otherwise max pooling will be performed.
* -DMAX_WIDTH and -DMAX_HEIGHT which are the maximum accessible indeces in x and y dimensions (width + pad)
@@ -403,7 +379,7 @@
vstore4(res, 0, (__global DATA_TYPE *)output.ptr);
}
-#endif // defined(POOLING3x3) && !defined(FIXED_POINT_POSITION)
+#endif // defined(POOLING3x3)
#if defined(POOL_SIZE_X) && defined(POOL_SIZE_Y)
@@ -411,23 +387,17 @@
#if defined(POOL_AVG) || defined(POOL_L2)
#define INITIAL_VALUE 0
#else /* defined(POOL_AVG) || defined(POOL_L2) */
-#ifdef FIXED_POINT_POSITION
-#define MIN_VAL_EXPAND(type) type##_MIN
-#define MIN_VAL(type) MIN_VAL_EXPAND(type)
-#define INITIAL_VALUE MIN_VAL(DATA_TYPE)
-#else // FIXED_POINT_POSITION
#if FP16
#define INITIAL_VALUE -HALF_MAX
#else // FP16
#define INITIAL_VALUE -FLT_MAX
#endif // FP16
-#endif // FIXED_POINT_POSITION
#endif // POOL_AVG
/** Performs a pooling function of pool size equal to N (NCHW)
*
- * @note Datatype must be passed using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types are QS8/QS16/F16/F32;
+ * @note Datatype must be passed using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types are F16/F32;
* @note -DFP16 must be passed at compile time if half float data type is used
* @note Pool sizes must be passed using -DPOOL_SIZE_X and -DPOOL_SIZE_Y e.g. -DPOOL_SIZE_X=13;
* @note In case of average pooling the following information must be passed at compile time:
@@ -436,7 +406,7 @@
* -DSTRIDE_X and -DSTRIDE_Y which are the steps of the window along the x and y directions
* -DPAD_X and -DPAD_Y which are the pooling paddings in x and y dimension
*
- * @param[in] input_ptr Pointer to the source image. Supported data types: QS8/QS16/F16/F32
+ * @param[in] input_ptr Pointer to the source image. Supported data types: F16/F32
* @param[in] input_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the source image in Y dimension (in bytes)
@@ -579,10 +549,10 @@
for(int y = 0; y < POOL_SIZE_Y; ++y)
{
- int y1 = select(y, PAD_Y - idx_height, y + idx_height < PAD_Y || y + idx_height > MAX_HEIGHT);
+ int y1 = select(y, PAD_Y - idx_height, y + idx_height - PAD_Y < 0 || y + idx_height - PAD_Y >= MAX_HEIGHT);
for(int x = 0; x < POOL_SIZE_X; ++x)
{
- int x1 = select(x, PAD_X - idx_width - 1, x + idx_width < PAD_X || x + idx_width > MAX_WIDTH);
+ int x1 = select(x, PAD_X - idx_width - 1, x + idx_width - PAD_X < 0 || x + idx_width - PAD_X >= MAX_WIDTH);
x1 = select(x1, PAD_X - idx_width - 1, y != y1);
VEC_DATA_TYPE(DATA_TYPE, 8)
diff --git a/src/core/CL/cl_kernels/reshape_layer.cl b/src/core/CL/cl_kernels/reshape_layer.cl
index 23eccbf..11393d2 100644
--- a/src/core/CL/cl_kernels/reshape_layer.cl
+++ b/src/core/CL/cl_kernels/reshape_layer.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,7 +27,7 @@
*
* @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
*
- * @param[in] input_ptr Pointer to the first source tensor. Supported data types: U8/S8/QS8/U16/S16/QS16/U32/S32/F16/F32
+ * @param[in] input_ptr Pointer to the first source tensor. Supported data types: U8/S8/U16/S16/U32/S32/F16/F32
* @param[in] input_stride_x Stride of the first source tensor in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the first source tensor in Y dimension (in bytes)
diff --git a/src/core/CL/cl_kernels/scale.cl b/src/core/CL/cl_kernels/scale.cl
index a2ae8c4..744f28a 100644
--- a/src/core/CL/cl_kernels/scale.cl
+++ b/src/core/CL/cl_kernels/scale.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -83,7 +83,7 @@
* @param[in] scale_x The scale factor along x dimension
* @param[in] scale_y The scale factor along y dimension
*/
-__kernel void scale_nearest_neighbour(
+__kernel void scale_nearest_neighbour_nchw(
IMAGE_DECLARATION(in),
IMAGE_DECLARATION(out),
const float input_width,
@@ -119,7 +119,7 @@
* @param[in] scale_x The scale factor along x dimension
* @param[in] scale_y The scale factor along y dimension
*/
-__kernel void scale_bilinear(
+__kernel void scale_bilinear_nchw(
IMAGE_DECLARATION(in),
IMAGE_DECLARATION(out),
const float input_width,
@@ -133,3 +133,124 @@
const float8 tc = transform_bilinear(get_current_coords(), r);
vstore4(bilinear_interpolate_with_border(&in, tc, input_width, input_height, BORDER_SIZE), 0, (__global DATA_TYPE *)out.ptr);
}
+
+/** Performs scale on an NHWC tensor interpolating with the NEAREST NEIGHBOUR method; each work-item copies one element. (NHWC)
+ *
+ * @note -DSAMPLING_POLICY_(TYPE) is not read by this kernel: the source index is always floor(clamp((dst index + 0.5f) * scale, 0, size - 1)). NOTE(review): confirm TOP_LEFT callers expect this.
+ *
+ * @param[in] in_ptr Pointer to the source image. Supported data types: U8/S16/F16/F32.
+ * @param[in] in_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] in_step_x in_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] in_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] in_step_y in_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] in_stride_z Stride of the source image in Z dimension (in bytes)
+ * @param[in] in_step_z in_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] in_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[out] out_ptr Pointer to the destination image. Supported data types: same as @p in_ptr
+ * @param[in] out_stride_x Stride of the destination image in X dimension (in bytes)
+ * @param[in] out_step_x out_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] out_stride_y Stride of the destination image in Y dimension (in bytes)
+ * @param[in] out_step_y out_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] out_stride_z Stride of the destination image in Z dimension (in bytes)
+ * @param[in] out_step_z out_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] out_offset_first_element_in_bytes The offset of the first element in the destination image
+ * @param[in] input_width Input image width (source column indices are clamped to input_width - 1)
+ * @param[in] input_height Input image height (source row indices are clamped to input_height - 1)
+ * @param[in] scale_x The scale factor along x dimension (source columns per output column)
+ * @param[in] scale_y The scale factor along y dimension (source rows per output row)
+ */
+__kernel void scale_nearest_neighbour_nhwc(
+ TENSOR3D_DECLARATION(in),
+ TENSOR3D_DECLARATION(out),
+ const float input_width,
+ const float input_height,
+ const float scale_x,
+ const float scale_y)
+{
+ Tensor3D in = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(in); // Source element is addressed explicitly via tensor3D_offset below
+ Tensor3D out = CONVERT_TO_TENSOR3D_STRUCT(out); // Points at this work-item's output element
+
+ const float new_x = (get_global_id(1) + 0.5f) * scale_x; // Global id 1 = output column (width), mapped through its centre
+ const float new_y = (get_global_id(2) + 0.5f) * scale_y; // Global id 2 = output row (height), mapped through its centre
+ const float clamped_x = clamp(new_x, 0.0f, input_width - 1); // Clamp so the read never leaves the source tensor
+ const float clamped_y = clamp(new_y, 0.0f, input_height - 1);
+
+ *((__global DATA_TYPE *)out.ptr) = *((__global DATA_TYPE *)tensor3D_offset(&in, get_global_id(0), convert_int(clamped_x), convert_int(clamped_y))); // Global id 0 = channel; convert_int truncates, which floors the non-negative clamped coordinates
+}
+
+/** Performs scale on an NHWC tensor interpolating with the BILINEAR method; each work-item writes one output element. (NHWC)
+ *
+ * @note The sampling policy must be passed as -DSAMPLING_POLICY_TOP_LEFT or -DSAMPLING_POLICY_CENTER; anything else hits the #error below.
+ * @note If border mode replicate is used, it should be passed as -DBORDER_MODE_REPLICATE; otherwise -DBORDER_SIZE must be set and out-of-range taps read width index -BORDER_SIZE (assumed to hold the constant border).
+ *
+ * @param[in] in_ptr Pointer to the source image. Supported data types: U8/S16/F16/F32.
+ * @param[in] in_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] in_step_x in_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] in_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] in_step_y in_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] in_stride_z Stride of the source image in Z dimension (in bytes)
+ * @param[in] in_step_z in_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] in_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[out] out_ptr Pointer to the destination image. Supported data types: same as @p in_ptr
+ * @param[in] out_stride_x Stride of the destination image in X dimension (in bytes)
+ * @param[in] out_step_x out_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] out_stride_y Stride of the destination image in Y dimension (in bytes)
+ * @param[in] out_step_y out_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] out_stride_z Stride of the destination image in Z dimension (in bytes)
+ * @param[in] out_step_z out_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] out_offset_first_element_in_bytes The offset of the first element in the destination image
+ * @param[in] input_width Input image width (upper clamp bound for source columns is input_width - 1)
+ * @param[in] input_height Input image height (upper clamp bound for source rows is input_height - 1)
+ * @param[in] scale_x The scale factor along x dimension
+ * @param[in] scale_y The scale factor along y dimension
+ */
+__kernel void scale_bilinear_nhwc(
+ TENSOR3D_DECLARATION(in),
+ TENSOR3D_DECLARATION(out),
+ const float input_width,
+ const float input_height,
+ const float scale_x,
+ const float scale_y)
+{
+ Tensor3D in = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(in); // Source taps are addressed explicitly via tensor3D_offset
+ Tensor3D out = CONVERT_TO_TENSOR3D_STRUCT(out);
+
+#ifdef SAMPLING_POLICY_TOP_LEFT
+ const float new_x = get_global_id(1) * scale_x; // Global id 1 = output column (width)
+ const float new_y = get_global_id(2) * scale_y; // Global id 2 = output row (height)
+#elif defined(SAMPLING_POLICY_CENTER) // Test definedness (as #ifdef above), not the macro's value
+ const float new_x = (get_global_id(1) + 0.5f) * scale_x - 0.5f;
+ const float new_y = (get_global_id(2) + 0.5f) * scale_y - 0.5f;
+#else /* SAMPLING_POLICY */
+#error("Unsupported sampling policy");
+#endif /* SAMPLING_POLICY */
+
+ const float new_xf = floor(new_x);
+ const float new_yf = floor(new_y);
+ float clamped_x = clamp(new_xf, 0.0f, input_width - 1); // Left column, top row
+ float clamped_x1 = clamp(new_xf + 1, 0.0f, input_width - 1); // Right column, bottom row
+ float clamped_x_ = clamped_x; // Left column, bottom row
+ float clamped_x1_ = clamped_x1; // Right column, top row
+ const float clamped_y = clamp(new_yf, 0.0f, input_height - 1);
+ const float clamped_y1 = clamp(new_yf + 1, 0.0f, input_height - 1);
+
+#ifndef BORDER_MODE_REPLICATE
+ clamped_x1 = select(clamped_x1, 0.0f - BORDER_SIZE, new_yf + 1 < 0.f || new_yf + 1 > input_height - 1 || new_xf + 1 < 0.f || new_xf + 1 > input_width - 1); // Bottom-right tap
+ clamped_x_ = select(clamped_x_, 0.0f - BORDER_SIZE, new_yf + 1 < 0.f || new_yf + 1 > input_height - 1 || new_xf < 0.f || new_xf > input_width - 1); // Bottom-left tap; same row checks as bottom-right
+ clamped_x = select(clamped_x, 0.0f - BORDER_SIZE, new_yf < 0.f || new_yf > input_height - 1 || new_xf < 0.f || new_xf > input_width - 1); // Top-left tap
+ clamped_x1_ = select(clamped_x1_, 0.0f - BORDER_SIZE, new_xf + 1 < 0.f || new_xf + 1 > input_width - 1 || new_yf < 0.f || new_yf > input_height - 1); // Top-right tap
+#endif /* BORDER_MODE_REPLICATE */
+
+ float4 ins = (float4)(*((__global DATA_TYPE *)tensor3D_offset(&in, get_global_id(0), convert_int(clamped_x), convert_int(clamped_y))), // s0: (x, y)
+ *((__global DATA_TYPE *)tensor3D_offset(&in, get_global_id(0), convert_int(clamped_x1_), convert_int(clamped_y))), // s1: (x + 1, y)
+ *((__global DATA_TYPE *)tensor3D_offset(&in, get_global_id(0), convert_int(clamped_x_), convert_int(clamped_y1))), // s2: (x, y + 1)
+ *((__global DATA_TYPE *)tensor3D_offset(&in, get_global_id(0), convert_int(clamped_x1), convert_int(clamped_y1)))); // s3: (x + 1, y + 1)
+
+ const float a = new_x - new_xf; // Weight of the right-hand column (s1, s3)
+ const float b = 1.f - a;
+ const float a1 = new_y - new_yf; // Weight of the bottom row (s2, s3)
+ const float b1 = 1.f - a1;
+ const float fr = ((ins.s0 * b * b1) + (ins.s1 * a * b1) + (ins.s2 * b * a1) + (ins.s3 * a * a1));
+
+ *((__global DATA_TYPE *)out.ptr) = CONVERT(fr, DATA_TYPE); // Convert the float result back to DATA_TYPE
+}
diff --git a/src/core/CL/cl_kernels/softmax_layer.cl b/src/core/CL/cl_kernels/softmax_layer.cl
index 7fed879..4ad8180 100644
--- a/src/core/CL/cl_kernels/softmax_layer.cl
+++ b/src/core/CL/cl_kernels/softmax_layer.cl
@@ -23,23 +23,6 @@
*/
#include "helpers.h"
-#ifdef FIXED_POINT_POSITION
-
-#include "fixed_point.h"
-#define MAX_OP(x, y, type, size) MAX_OP_EXPAND(x, y, type, size)
-#define ADD_OP(x, y, type, size) ADD_SAT_OP_EXPAND((x), (y), type, size)
-#define SUB_OP(x, y, type, size) SUB_SAT_OP_EXPAND((x), (y), type, size)
-#define MUL_OP(x, y, type, size) MUL_SAT_OP_EXPAND((x), (y), type, size, FIXED_POINT_POSITION)
-#define DIV_OP(x, y, type, size) DIV_SAT_OP_VEC_EXPAND((x), (y), type, size, FIXED_POINT_POSITION)
-#define EXP_OP(x, type, size) EXP_OP_EXPAND((x), type, size, FIXED_POINT_POSITION)
-
-#define MIN_VAL_EXPAND(type) type##_MIN
-#define MIN_VAL(type) MIN_VAL_EXPAND(type)
-#define MINVAL MIN_VAL(DATA_TYPE)
-#define SELECT_DATA_TYPE EXPAND(DATA_TYPE)
-
-#else /* FIXED_POINT_POSITION */
-
#define MAX_OP(x, y, type, size) max((x), (y))
#define ADD_OP(x, y, type, size) ((x) + (y))
#define SUB_OP(x, y, type, size) ((x) - (y))
@@ -55,8 +38,6 @@
#define SELECT_DATA_TYPE int
#endif /* USE_F16 */
-#endif /* FIXED_POINT_POSITION */
-
/* Number of workitems in dimension 0. */
#if !defined(GRID_SIZE)
#define GRID_SIZE 1
@@ -90,9 +71,8 @@
/** Divides all the values of the input tensor by the sum calculated from softmax_layer_shift_exp_sum kernel.
*
* @note Datatype must be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
- * @note Fixed point position must be given as a preprocessor argument using -DFIXED_POINT_POSITION=pos. e.g. DFIXED_POINT_POSITION=4
*
- * @param[in] src_ptr Pointer to the source tensor slice. Supported data types: QS8/QS16/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor slice. Supported data types: F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -137,11 +117,10 @@
* then gets the exponent of each element as sums all elements across each row.
*
* @note Datatype must be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
- * @note Fixed point position must be given as a preprocessor argument using -DFIXED_POINT_POSITION=pos. e.g. DFIXED_POINT_POSITION=4
* @note In case the input is not a multiple of VECTOR_SIZE (2,4,8,16) -DNON_MULTIPLE_OF_VECTOR_SIZE must be passed.
* @note Beta can be optionally passed at compile time using -DBETA (by default, it is 1.0).
*
- * @param[in] src_ptr Pointer to the source tensor slice. Supported data types: QS8/QS16/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor slice. Supported data types: F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -287,11 +266,10 @@
* then gets the exponent of each element as sums all elements across each row.
*
* @note Datatype must be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
- * @note Fixed point position must be given as a preprocessor argument using -DFIXED_POINT_POSITION=pos. e.g. DFIXED_POINT_POSITION=4
* @note In case the input is not a multiple of VECTOR_SIZE (2,4,8,16) -DNON_MULTIPLE_OF_VECTOR_SIZE must be passed.
* @note Beta can be optionally passed at compile time using -DBETA (by default, it is 1.0).
*
- * @param[in] src_ptr Pointer to the source tensor slice. Supported data types: QS8/QS16/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor slice. Supported data types: F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
diff --git a/src/core/CL/cl_kernels/softmax_layer_quantized.cl b/src/core/CL/cl_kernels/softmax_layer_quantized.cl
index cbcde4e..fcd1ec5 100644
--- a/src/core/CL/cl_kernels/softmax_layer_quantized.cl
+++ b/src/core/CL/cl_kernels/softmax_layer_quantized.cl
@@ -230,10 +230,9 @@
* then gets the exponent of each element as sums all elements across each row.
*
* @note Datatype must be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
- * @note Fixed point position must be given as a preprocessor argument using -DFIXED_POINT_POSITION=pos. e.g. DFIXED_POINT_POSITION=4
* @note In case the input is not a multiple of VECTOR_SIZE (2,4,8,16) -DNON_MULTIPLE_OF_VECTOR_SIZE must be passed.
*
- * @param[in] src_ptr Pointer to the source tensor slice. Supported data types: QS8/QS16/F16/F32
+ * @param[in] src_ptr Pointer to the source tensor slice. Supported data types: F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -517,7 +516,6 @@
/** Divides all the values of the input tensor by the sum calculated from softmax_layer_shift_exp_sum kernel.
*
- * @note Fixed point position must be given as a preprocessor argument using -DFIXED_POINT_POSITION=pos. e.g. DFIXED_POINT_POSITION=4
* @note Quantized beta can be optionally passed at compile time using -DINPUT_BETA_MULTIPLIER and -DINPUT_BETA_LEFT_SHIFT (if undefined, assume beta equals 1.0)
* @note -DDIFF_MIN must be passed at compile time. It is threshold difference between maximum value of input data and current processed value, it defines whether the value will be taken into account or not.
*
diff --git a/src/core/CL/cl_kernels/winograd.cl b/src/core/CL/cl_kernels/winograd.cl
deleted file mode 100644
index 0458e53..0000000
--- a/src/core/CL/cl_kernels/winograd.cl
+++ /dev/null
@@ -1,1611 +0,0 @@
-/*
- * Copyright (c) 2018 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "helpers.h"
-
-#if defined(NUM_CHANNELS)
-
-/** This OpenCL kernel performs Winograd filter transform 3x3 when the data format is NCHW and the output tile is 2x2
- *
- * @note The number of channels must be passed at compile time using -DNUM_CHANNELS: e.g. -DNUM_CHANNELS=64
- *
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
- * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
- * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
- * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
- * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
- * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
- * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- */
-__kernel void winograd_filter_transform_2x2_3x3_nchw(
- TENSOR4D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst))
-{
- Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, NUM_CHANNELS);
-
- const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
-
- // Load the values from the input tensor
- float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
- float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
- float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));
-
- // Transform the 3x3 tile in a 4x4 tile
- float4 out0 = 0.0f;
- float4 out1 = 0.0f;
- float4 out2 = 0.0f;
- float4 out3 = 0.0f;
-
- // Row 0
- out0.s0 = (w0.s0);
- out0.s1 = (w0.s0 + w0.s1 + w0.s2) * 0.5f;
- out0.s2 = (w0.s0 + w0.s2 - w0.s1) * 0.5f;
- out0.s3 = (w0.s2);
-
- // Row 1
- out1.s0 = (w0.s0 + w1.s0 + w2.s0) * 0.5f;
- out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) * 0.25f;
- out1.s2 = (w0.s0 + w1.s0 + w2.s0 + w0.s2 + w1.s2 + w2.s2 - w0.s1 - w1.s1 - w2.s1) * 0.25f;
- out1.s3 = (w0.s2 + w1.s2 + w2.s2) * 0.5f;
-
- // Row 2
- out2.s0 = (w0.s0 + w2.s0 - w1.s0) * 0.5f;
- out2.s1 = (w0.s0 + w2.s0 + w0.s1 + w2.s1 + w0.s2 + w2.s2 - w1.s0 - w1.s1 - w1.s2) * 0.25f;
- out2.s2 = (w0.s0 + w2.s0 + w1.s1 + w0.s2 + w2.s2 - w1.s0 - w0.s1 - w2.s1 - w1.s2) * 0.25f;
- out2.s3 = (w0.s2 + w2.s2 - w1.s2) * 0.5f;
-
- // Row 3
- out3.s0 = (w2.s0);
- out3.s1 = (w2.s0 + w2.s1 + w2.s2) * 0.5f;
- out3.s2 = (w2.s0 + w2.s2 - w2.s1) * 0.5f;
- out3.s3 = (w2.s2);
-
- int z = get_global_id(2);
- int x0 = z / NUM_CHANNELS; // idx filter
- int y0 = z % NUM_CHANNELS; // idx channel
-
- // Get output address
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
-
- // Store the 16 values across the 16 channels
- *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
- *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
- *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
- *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
- *(__global float *)(dst_addr + 4 * dst_stride_z) = out1.s0;
- *(__global float *)(dst_addr + 5 * dst_stride_z) = out1.s1;
- *(__global float *)(dst_addr + 6 * dst_stride_z) = out1.s2;
- *(__global float *)(dst_addr + 7 * dst_stride_z) = out1.s3;
- *(__global float *)(dst_addr + 8 * dst_stride_z) = out2.s0;
- *(__global float *)(dst_addr + 9 * dst_stride_z) = out2.s1;
- *(__global float *)(dst_addr + 10 * dst_stride_z) = out2.s2;
- *(__global float *)(dst_addr + 11 * dst_stride_z) = out2.s3;
- *(__global float *)(dst_addr + 12 * dst_stride_z) = out3.s0;
- *(__global float *)(dst_addr + 13 * dst_stride_z) = out3.s1;
- *(__global float *)(dst_addr + 14 * dst_stride_z) = out3.s2;
- *(__global float *)(dst_addr + 15 * dst_stride_z) = out3.s3;
-}
-
-/** This OpenCL kernel performs Winograd filter transform 3x3 when the data format is NCHW and the output tile is 4x4
- *
- * @note The number of channels must be passed at compile time using -DNUM_CHANNELS: e.g. -DNUM_CHANNELS=64
- *
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
- * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
- * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
- * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
- * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
- * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
- * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- */
-__kernel void winograd_filter_transform_4x4_3x3_nchw(
- TENSOR4D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst))
-{
- Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, NUM_CHANNELS);
-
- const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
-
- // Load the values from the input tensor
- float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
- float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
- float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));
-
- // Transform the 3x3 tile in a 6x6 tile
- float8 out0 = 0.0f;
- float8 out1 = 0.0f;
- float8 out2 = 0.0f;
- float8 out3 = 0.0f;
- float8 out4 = 0.0f;
- float8 out5 = 0.0f;
-
- // Row 0
- out0.s0 = (w0.s0) / 16.f;
- out0.s1 = (-w0.s0 - w0.s1 - w0.s2) / 24.f;
- out0.s2 = (-w0.s0 + w0.s1 - w0.s2) / 24.f;
- out0.s3 = (w0.s0 + 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
- out0.s4 = (w0.s0 - 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
- out0.s5 = (w0.s2) / 4.f;
-
- // Row 1
- out1.s0 = (-w0.s0 - w1.s0 - w2.s0) / 24.f;
- out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
- out1.s2 = (w0.s0 + w1.s0 + w2.s0 - w0.s1 - w1.s1 - w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
- out1.s3 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (-w0.s1 - w1.s1 - w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
- out1.s4 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (w0.s1 + w1.s1 + w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
- out1.s5 = (-w0.s2 - w1.s2 - w2.s2) / 6.f;
-
- // Row 2
- out2.s0 = (-w0.s0 + w1.s0 - w2.s0) / 24.f;
- out2.s1 = (w0.s0 - w1.s0 + w2.s0 + w0.s1 - w1.s1 + w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
- out2.s2 = (w0.s0 - w1.s0 + w2.s0 - w0.s1 + w1.s1 - w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
- out2.s3 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (-w0.s1 + w1.s1 - w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
- out2.s4 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (w0.s1 - w1.s1 + w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
- out2.s5 = (-w0.s2 + w1.s2 - w2.s2) / 6.f;
-
- // Row 3
- out3.s0 = (w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
- out3.s1 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 - 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
- out3.s2 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 + 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
- out3.s3 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 + 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
- out3.s4 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 - 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
- out3.s5 = (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2) / 24.f;
-
- // Row 4
- out4.s0 = (w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
- out4.s1 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 + 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
- out4.s2 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 - 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
- out4.s3 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 - 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
- out4.s4 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 + 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
- out4.s5 = (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2) / 24.f;
-
- // Row 5
- out5.s0 = (w2.s0) / 4.f;
- out5.s1 = (-w2.s0 - w2.s1 - w2.s2) / 6.f;
- out5.s2 = (-w2.s0 + w2.s1 - w2.s2) / 6.f;
- out5.s3 = (w2.s0 + 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
- out5.s4 = (w2.s0 - 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
- out5.s5 = (w2.s2);
-
- int z = get_global_id(2);
- int x0 = z / NUM_CHANNELS; // idx filter
- int y0 = z % NUM_CHANNELS; // idx channel
-
- // Get output address
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
-
- // Store the 36 values across the 36 channels
- *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
- *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
- *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
- *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
- *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
- *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
- *(__global float *)(dst_addr + 6 * dst_stride_z) = out1.s0;
- *(__global float *)(dst_addr + 7 * dst_stride_z) = out1.s1;
- *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s2;
- *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s3;
- *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s4;
- *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s5;
- *(__global float *)(dst_addr + 12 * dst_stride_z) = out2.s0;
- *(__global float *)(dst_addr + 13 * dst_stride_z) = out2.s1;
- *(__global float *)(dst_addr + 14 * dst_stride_z) = out2.s2;
- *(__global float *)(dst_addr + 15 * dst_stride_z) = out2.s3;
- *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s4;
- *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s5;
- *(__global float *)(dst_addr + 18 * dst_stride_z) = out3.s0;
- *(__global float *)(dst_addr + 19 * dst_stride_z) = out3.s1;
- *(__global float *)(dst_addr + 20 * dst_stride_z) = out3.s2;
- *(__global float *)(dst_addr + 21 * dst_stride_z) = out3.s3;
- *(__global float *)(dst_addr + 22 * dst_stride_z) = out3.s4;
- *(__global float *)(dst_addr + 23 * dst_stride_z) = out3.s5;
- *(__global float *)(dst_addr + 24 * dst_stride_z) = out4.s0;
- *(__global float *)(dst_addr + 25 * dst_stride_z) = out4.s1;
- *(__global float *)(dst_addr + 26 * dst_stride_z) = out4.s2;
- *(__global float *)(dst_addr + 27 * dst_stride_z) = out4.s3;
- *(__global float *)(dst_addr + 28 * dst_stride_z) = out4.s4;
- *(__global float *)(dst_addr + 29 * dst_stride_z) = out4.s5;
- *(__global float *)(dst_addr + 30 * dst_stride_z) = out5.s0;
- *(__global float *)(dst_addr + 31 * dst_stride_z) = out5.s1;
- *(__global float *)(dst_addr + 32 * dst_stride_z) = out5.s2;
- *(__global float *)(dst_addr + 33 * dst_stride_z) = out5.s3;
- *(__global float *)(dst_addr + 34 * dst_stride_z) = out5.s4;
- *(__global float *)(dst_addr + 35 * dst_stride_z) = out5.s5;
-}
-
-/** This OpenCL kernel performs Winograd filter transform 5x5 when the data format is NCHW and the output tile is 4x4
- *
- * @note The number of channels must be passed at compile time using -DNUM_CHANNELS: e.g. -DNUM_CHANNELS=64
- *
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
- * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
- * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
- * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
- * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
- * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
- * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- */
-__kernel void winograd_filter_transform_4x4_5x5_nchw(
- TENSOR4D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst))
-{
- Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, NUM_CHANNELS);
-
- const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
-
- // Load the values from the input tensor
- const char stride_x = 4 * sizeof(float); // Used for accessing the last value in each row
- const uchar8 stride_y = (uchar8)(0, 1, 2, 3, 4, 0, 0, 0) * (uchar8)src_stride_y;
-
- float4 w00 = vload4(0, (__global float *)(src_addr + stride_y.s0));
- float w01 = *((__global float *)(src_addr + stride_y.s0 + stride_x));
- float4 w10 = vload4(0, (__global float *)(src_addr + stride_y.s1));
- float w11 = *((__global float *)(src_addr + stride_y.s1 + stride_x));
- float4 w20 = vload4(0, (__global float *)(src_addr + stride_y.s2));
- float w21 = *((__global float *)(src_addr + stride_y.s2 + stride_x));
- float4 w30 = vload4(0, (__global float *)(src_addr + stride_y.s3));
- float w31 = *((__global float *)(src_addr + stride_y.s3 + stride_x));
- float4 w40 = vload4(0, (__global float *)(src_addr + stride_y.s4));
- float w41 = *((__global float *)(src_addr + stride_y.s4 + stride_x));
-
- // Transform the 3x3 tile in a 8x8 tile
- float8 out0 = 0.0f;
- float8 out1 = 0.0f;
- float8 out2 = 0.0f;
- float8 out3 = 0.0f;
- float8 out4 = 0.0f;
- float8 out5 = 0.0f;
- float8 out6 = 0.0f;
- float8 out7 = 0.0f;
-
- // Row 0
- out0.s0 = w00.s0;
- out0.s1 = -2.f * (w00.s0 + w00.s1 + w00.s2 + w00.s3 + w01) / 9.f;
- out0.s2 = -2.f * (w00.s0 - w00.s1 + w00.s2 - w00.s3 + w01) / 9.f;
- out0.s3 = (w00.s0 + 2.f * w00.s1 + 4.f * w00.s2 + 8.f * w00.s3 + 16.f * w01) / 90.f;
- out0.s4 = (w00.s0 - 2.f * w00.s1 + 4.f * w00.s2 - 8.f * w00.s3 + 16.f * w01) / 90.f;
- out0.s5 = (16.f * w00.s0 + 8.f * w00.s1 + 4.f * w00.s2 + 2.f * w00.s3 + w01) / 180.f;
- out0.s6 = (16.f * w00.s0 - 8.f * w00.s1 + 4.f * w00.s2 - 2.f * w00.s3 + w01) / 180.f;
- out0.s7 = w01;
-
- // Row 1
- out1.s0 = -2.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) / 9.f;
- out1.s1 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) +
- (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
- out1.s2 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) -
- (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
- out1.s3 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 8.f *
- (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
- out1.s4 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 8.f *
- (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
- out1.s5 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 2.f *
- (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
- out1.s6 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 2.f *
- (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
- out1.s7 = -2.f * (w01 + w11 + w21 + w31 + w41) / 9.f;
-
- // Row 2
- out2.s0 = -2.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) / 9.f;
- out2.s1 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) +
- (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
- out2.s2 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) -
- (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
- out2.s3 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 8.f *
- (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
- out2.s4 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 8.f *
- (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
- out2.s5 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 2.f *
- (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
- out2.s6 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 2.f *
- (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
- out2.s7 = -2.f * (w01 - w11 + w21 - w31 + w41) / 9.f;
-
- // Row 3
- out3.s0 = (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
- out3.s1 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
- (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
- (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
- out3.s2 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
- (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
- (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
- out3.s3 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
- (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
- (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
- out3.s4 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
- (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
- (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
- out3.s5 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
- (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
- (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
- out3.s6 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
- (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
- (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
- out3.s7 = (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) / 90.f;
-
- // Row 4
- out4.s0 = (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
- out4.s1 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
- (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
- (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
- out4.s2 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
- (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
- (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
- out4.s3 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
- (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
- (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
- out4.s4 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
- (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
- (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
- out4.s5 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
- (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
- (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
- out4.s6 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
- (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
- (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
- out4.s7 = (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) / 90.f;
-
- // Row 5
- out5.s0 = (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) / 180.f;
- out5.s1 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
- (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
- (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
- out5.s2 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
- (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
- (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
- out5.s3 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
- (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
- (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
- out5.s4 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
- (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
- (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
- out5.s5 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
- (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
- (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
- out5.s6 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
- (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
- (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
- out5.s7 = (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) / 180.f;
-
- // Row 6
- out6.s0 = (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) / 180.f;
- out6.s1 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
- (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
- (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
- out6.s2 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
- (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
- (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
- out6.s3 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
- (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
- (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
- out6.s4 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
- (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
- (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
- out6.s5 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
- (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
- (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
- out6.s6 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
- (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
- (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
- out6.s7 = (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) / 180.f;
-
- // Row 7
- out7.s0 = w40.s0;
- out7.s1 = -2.f * (w40.s0 + w40.s1 + w40.s2 + w40.s3 + w41) / 9.f;
- out7.s2 = -2.f * (w40.s0 - w40.s1 + w40.s2 - w40.s3 + w41) / 9.f;
- out7.s3 = (w40.s0 + 2.f * w40.s1 + 4.f * w40.s2 + 8.f * w40.s3 + 16.f * w41) / 90.f;
- out7.s4 = (w40.s0 - 2.f * w40.s1 + 4.f * w40.s2 - 8.f * w40.s3 + 16.f * w41) / 90.f;
- out7.s5 = (16.f * w40.s0 + 8.f * w40.s1 + 4.f * w40.s2 + 2.f * w40.s3 + w41) / 180.f;
- out7.s6 = (16.f * w40.s0 - 8.f * w40.s1 + 4.f * w40.s2 - 2.f * w40.s3 + w41) / 180.f;
- out7.s7 = w41;
-
- int z = get_global_id(2);
- int x0 = z / NUM_CHANNELS; // idx filter
- int y0 = z % NUM_CHANNELS; // idx channel
-
- // Get output address
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
-
- // Store the 64 values across the 64 channels
- *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
- *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
- *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
- *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
- *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
- *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
- *(__global float *)(dst_addr + 6 * dst_stride_z) = out0.s6;
- *(__global float *)(dst_addr + 7 * dst_stride_z) = out0.s7;
- *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s0;
- *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s1;
- *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
- *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
- *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
- *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
- *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
- *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
- *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
- *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
- *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
- *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
- *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
- *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
- *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
- *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
- *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
- *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
- *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
- *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
- *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
- *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
- *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
- *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
- *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
- *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
- *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
- *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
- *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
- *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
- *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
- *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
- *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
- *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
- *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
- *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
- *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
- *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
- *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
- *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
- *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
- *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
- *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
- *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
- *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
- *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
- *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
- *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
- *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
- *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
- *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
- *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
- *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
- *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
- *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
- *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
-}
-#endif // defined(NUM_CHANNELS)
-
-#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP)
-/** This OpenCL kernel computes the input transform when the kernel size is 3x3 and the output tile is 2x2
- *
- * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
- * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
- *
- * @param[in] src_ptr Pointer to the source image. Supported data types: F32
- * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
- * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
- * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
- * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- */
-__kernel void winograd_input_transform_2x2_3x3_stepz1_nchw(
- TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst))
-{
- int x = get_global_id(0);
- int y = get_global_id(1);
- int z = get_global_id(2);
-
- // Compute input address
- __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 2 * src_stride_x + y * 2 * src_stride_y + z * src_stride_z;
-
- src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
-
- float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
- float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
- float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
- float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
-
- float4 tmp0 = in_row0 - in_row2;
- float4 tmp1 = in_row1 + in_row2;
- float4 tmp2 = in_row2 - in_row1;
- float4 tmp3 = in_row1 - in_row3;
-
- float out00 = tmp0.s0 - tmp0.s2;
- float out01 = tmp0.s1 + tmp0.s2;
- float out02 = tmp0.s2 - tmp0.s1;
- float out03 = tmp0.s1 - tmp0.s3;
-
- float out10 = tmp1.s0 - tmp1.s2;
- float out11 = tmp1.s1 + tmp1.s2;
- float out12 = tmp1.s2 - tmp1.s1;
- float out13 = tmp1.s1 - tmp1.s3;
-
- float out20 = tmp2.s0 - tmp2.s2;
- float out21 = tmp2.s1 + tmp2.s2;
- float out22 = tmp2.s2 - tmp2.s1;
- float out23 = tmp2.s1 - tmp2.s3;
-
- float out30 = tmp3.s0 - tmp3.s2;
- float out31 = tmp3.s1 + tmp3.s2;
- float out32 = tmp3.s2 - tmp3.s1;
- float out33 = tmp3.s1 - tmp3.s3;
-
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y;
-
- *((__global float *)(dst_addr + 0 * dst_stride_z)) = out00;
- *((__global float *)(dst_addr + 1 * dst_stride_z)) = out01;
- *((__global float *)(dst_addr + 2 * dst_stride_z)) = out02;
- *((__global float *)(dst_addr + 3 * dst_stride_z)) = out03;
- *((__global float *)(dst_addr + 4 * dst_stride_z)) = out10;
- *((__global float *)(dst_addr + 5 * dst_stride_z)) = out11;
- *((__global float *)(dst_addr + 6 * dst_stride_z)) = out12;
- *((__global float *)(dst_addr + 7 * dst_stride_z)) = out13;
- *((__global float *)(dst_addr + 8 * dst_stride_z)) = out20;
- *((__global float *)(dst_addr + 9 * dst_stride_z)) = out21;
- *((__global float *)(dst_addr + 10 * dst_stride_z)) = out22;
- *((__global float *)(dst_addr + 11 * dst_stride_z)) = out23;
- *((__global float *)(dst_addr + 12 * dst_stride_z)) = out30;
- *((__global float *)(dst_addr + 13 * dst_stride_z)) = out31;
- *((__global float *)(dst_addr + 14 * dst_stride_z)) = out32;
- *((__global float *)(dst_addr + 15 * dst_stride_z)) = out33;
-}
-
-/** This OpenCL kernel computes the input transform when the kernel size is 3x3, the output tile is 2x2 and the number of channels is multiple of 2
- *
- * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
- * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
- *
- * @param[in] src_ptr Pointer to the source image. Supported data types: F32
- * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
- * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
- * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
- * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- */
-__kernel void winograd_input_transform_2x2_3x3_stepz2_nchw(
- TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst))
-{
- int x = get_global_id(0);
- int y = get_global_id(1);
- int z = get_global_id(2) * 2;
-
- // Compute input address
- __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 2 * src_stride_x + y * 2 * src_stride_y + z * src_stride_z;
-
- src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
-
- float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
- float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
- float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
- float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
-
- src_addr += src_stride_z;
- float4 in_row4 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
- float4 in_row5 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
- float4 in_row6 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
- float4 in_row7 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
-
- float4 tmp0 = in_row0 - in_row2;
- float4 tmp1 = in_row1 + in_row2;
- float4 tmp2 = in_row2 - in_row1;
- float4 tmp3 = in_row1 - in_row3;
-
- float4 tmp4 = in_row4 - in_row6;
- float4 tmp5 = in_row5 + in_row6;
- float4 tmp6 = in_row6 - in_row5;
- float4 tmp7 = in_row5 - in_row7;
-
- float2 out00 = (float2)(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
- float2 out01 = (float2)(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
- float2 out02 = (float2)(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
- float2 out03 = (float2)(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);
-
- float2 out10 = (float2)(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
- float2 out11 = (float2)(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
- float2 out12 = (float2)(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
- float2 out13 = (float2)(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);
-
- float2 out20 = (float2)(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
- float2 out21 = (float2)(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
- float2 out22 = (float2)(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
- float2 out23 = (float2)(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);
-
- float2 out30 = (float2)(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
- float2 out31 = (float2)(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
- float2 out32 = (float2)(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
- float2 out33 = (float2)(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
-
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y;
-
- vstore2(out00, 0, (__global float *)(dst_addr + 0 * dst_stride_z));
- vstore2(out01, 0, (__global float *)(dst_addr + 1 * dst_stride_z));
- vstore2(out02, 0, (__global float *)(dst_addr + 2 * dst_stride_z));
- vstore2(out03, 0, (__global float *)(dst_addr + 3 * dst_stride_z));
- vstore2(out10, 0, (__global float *)(dst_addr + 4 * dst_stride_z));
- vstore2(out11, 0, (__global float *)(dst_addr + 5 * dst_stride_z));
- vstore2(out12, 0, (__global float *)(dst_addr + 6 * dst_stride_z));
- vstore2(out13, 0, (__global float *)(dst_addr + 7 * dst_stride_z));
- vstore2(out20, 0, (__global float *)(dst_addr + 8 * dst_stride_z));
- vstore2(out21, 0, (__global float *)(dst_addr + 9 * dst_stride_z));
- vstore2(out22, 0, (__global float *)(dst_addr + 10 * dst_stride_z));
- vstore2(out23, 0, (__global float *)(dst_addr + 11 * dst_stride_z));
- vstore2(out30, 0, (__global float *)(dst_addr + 12 * dst_stride_z));
- vstore2(out31, 0, (__global float *)(dst_addr + 13 * dst_stride_z));
- vstore2(out32, 0, (__global float *)(dst_addr + 14 * dst_stride_z));
- vstore2(out33, 0, (__global float *)(dst_addr + 15 * dst_stride_z));
-}
-
-/** This OpenCL kernel computes the input transform when the output tile is 4x4, the filter size 3x3 and the data format is NCHW
- *
- * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
- * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
- *
- * @param[in] src_ptr Pointer to the source image. Supported data types: F32
- * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
- * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
- * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
- * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- */
-__kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
- TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst))
-{
- int x = get_global_id(0);
- int y = get_global_id(1);
- int z = get_global_id(2);
-
- // Compute input address
- __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 4 * src_stride_x + y * 4 * src_stride_y + z * src_stride_z;
-
- src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
-
- // Row4
- float4 d40 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y));
- float2 d41 = vload2(2, (__global float *)(src_addr + 4 * src_stride_y));
-
- float k0 = d41.s0;
- float k1 = d41.s0;
- float k2 = d41.s0;
- float k3 = d41.s0;
- float k4 = d41.s0;
- float k5 = 0.0f;
-
- k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
- k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
- k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
- k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
- k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
- k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
-
- // Row0
- float4 d00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
- float2 d01 = vload2(2, (__global float *)(src_addr + 0 * src_stride_y));
-
- // Row2
- float4 d20 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
- float2 d21 = vload2(2, (__global float *)(src_addr + 2 * src_stride_y));
-
- // Compute destination address
- __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y);
-
- uint dst_plane_stride = dst_stride_z / sizeof(float);
-
- float out0 = k0;
- float out1 = k1;
- float out2 = k2;
- float out3 = k3;
- float out4 = k4;
- float out5 = k5;
- float out6 = k0;
- float out7 = k1;
- float out8 = k2;
- float out9 = k3;
- float out10 = k4;
- float out11 = k5;
- float out12 = k0;
- float out13 = k1;
- float out14 = k2;
- float out15 = k3;
- float out16 = k4;
- float out17 = k5;
- float out18 = k0;
- float out19 = k1;
- float out20 = k2;
- float out21 = k3;
- float out22 = k4;
- float out23 = k5;
- float out24 = k0;
- float out25 = k1;
- float out26 = k2;
- float out27 = k3;
- float out28 = k4;
- float out29 = k5;
-
- // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
- out0 += 16.0f * d00.s0 - 20.0f * d00.s2 - 20.0f * d20.s0 + 25.0f * d20.s2 + 4.0f * d01.s0 - 5.0f * d21.s0;
- out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
- out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 - 20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
- out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
- out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 - 10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
- out5 += 16.0f * d00.s1 - 20.0f * d00.s3 - 20.0f * d20.s1 + 4.0f * d01.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
-
- *(dst_addr) = out0;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out1;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out2;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out3;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out4;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out5;
- dst_addr += dst_plane_stride;
-
- // Row1
- float4 d10 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
- float2 d11 = vload2(2, (__global float *)(src_addr + 1 * src_stride_y));
-
- // Row3
- float4 d30 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
- float2 d31 = vload2(2, (__global float *)(src_addr + 3 * src_stride_y));
-
- // Compute common parts for the channels between [6, 29]
- // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
- // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
- float part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
- float part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
- float part2 = 16.0f * d20.s2 - 4.0f * d21.s0;
- float part3 = 16.0f * d20.s1 - 4.0f * d20.s3;
- float part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
- float part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
- float part6 = 4.0f * d20.s2 - 4.0f * d21.s0;
- float part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
- float part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
- float part9 = 8.0f * d20.s1 - 8.0f * d20.s3;
- float part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
- float part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
-
- // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
- // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
- float part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
- float part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
- float part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
- float part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
- float part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
- float part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
- float part18 = part6 * 0.25f; // d20.s2 - d21.s0
- float part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
- float part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
- float part21 = part9 * 0.25f; // 2.0f * (d20.s1 - d20.s3)
- float part22 = part10 * 0.25f; // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
- float part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;
-
- out6 += part0 - part1;
- out12 += part0 + part1;
- out7 += part2 + part3 + part4 + part5;
- out8 += part2 - part3 + part4 - part5;
- out13 += part2 + part3 - part4 - part5;
- out14 += part2 - part3 - part4 + part5;
- out9 += part6 + part7 + part8 + part9;
- out10 += part6 - part7 + part8 - part9;
- out15 += part6 - part7 - part8 + part9;
- out16 += part6 + part7 - part8 - part9;
- out11 += part10 + part11;
- out17 += part10 - part11;
-
- out18 += part13 - part12;
- out24 += part13 + part12;
- out19 += part14 + part15 + part16 + part17;
- out20 += part14 - part15 + part16 - part17;
- out25 += part14 - part15 - part16 + part17;
- out26 += part14 + part15 - part16 - part17;
- out21 += part18 + part19 + part20 + part21;
- out22 += part18 - part19 + part20 - part21;
- out27 += part18 - part19 - part20 + part21;
- out28 += part18 + part19 - part20 - part21;
- out23 += part22 + part23;
- out29 += part22 - part23;
-
- *(dst_addr) = out6;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out7;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out8;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out9;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out10;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out11;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out12;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out13;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out14;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out15;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out16;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out17;
- dst_addr += dst_plane_stride;
-
- *(dst_addr) = out18;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out19;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out20;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out21;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out22;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out23;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out24;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out25;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out26;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out27;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out28;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out29;
- dst_addr += dst_plane_stride;
-
- // Row5
- float4 d50 = vload4(0, (__global float *)(src_addr + 5 * src_stride_y));
- float2 d51 = vload2(2, (__global float *)(src_addr + 5 * src_stride_y));
-
- // Channels [30, 35]
- out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
- out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
- out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
- out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
- out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
- out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
-
- *(dst_addr) = out0;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out1;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out2;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out3;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out4;
- dst_addr += dst_plane_stride;
- *(dst_addr) = out5;
- dst_addr += dst_plane_stride;
-}
-
-#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) \
- ({ \
- comm_fact.s0 = tmp.s2 - 4.25f * tmp.s4 + tmp.s6; \
- comm_fact.s1 = tmp.s1 - 4.25f * tmp.s3 + tmp.s5; \
- comm_fact.s2 = 2.5f * tmp.s3; \
- comm_fact.s3 = 0.5f * tmp.s1 + 2.f * tmp.s5 - comm_fact.s2; \
- comm_fact.s4 = 0.25f * tmp.s2 - 1.25f * tmp.s4 + tmp.s6; \
- comm_fact.s5 = 4.f * tmp.s2 + tmp.s6 - 5.f * tmp.s4; \
- comm_fact.s6 = 2.f * tmp.s1 + 0.5f * tmp.s5 - comm_fact.s2; \
- \
- out.s0 = tmp.s0 - tmp.s6 + 5.25f * tmp.s4 - 5.25f * tmp.s2; \
- out.s1 = comm_fact.s0 + comm_fact.s1; \
- out.s2 = comm_fact.s0 - comm_fact.s1; \
- out.s3 = comm_fact.s3 + comm_fact.s4; \
- out.s4 = comm_fact.s4 - comm_fact.s3; \
- out.s5 = comm_fact.s5 + comm_fact.s6; \
- out.s6 = comm_fact.s5 - comm_fact.s6; \
- out.s7 = tmp.s7 - tmp.s1 + 5.25f * tmp.s3 - 5.25f * tmp.s5; \
- })
-
-/** This OpenCL kernel computes the input transform when the kernel size is 5x5 and the output tile is 4x4
- *
- * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
- * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
- *
- * @param[in] src_ptr Pointer to the source image. Supported data types: F32
- * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
- * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
- * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
- * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- */
-__kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
- TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst))
-{
- int x = get_global_id(0);
- int y = get_global_id(1);
- int z = get_global_id(2);
-
- // Compute input address
- __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 4 * src_stride_x + y * 4 * src_stride_y + z * src_stride_z;
-
- src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
-
- // Load 8x8 input tile
- const float8 in_row0 = vload8(0, (__global float *)(src_addr + 0 * src_stride_y));
- const float8 in_row1 = vload8(0, (__global float *)(src_addr + 1 * src_stride_y));
- const float8 in_row2 = vload8(0, (__global float *)(src_addr + 2 * src_stride_y));
- const float8 in_row3 = vload8(0, (__global float *)(src_addr + 3 * src_stride_y));
- const float8 in_row4 = vload8(0, (__global float *)(src_addr + 4 * src_stride_y));
- const float8 in_row5 = vload8(0, (__global float *)(src_addr + 5 * src_stride_y));
- const float8 in_row6 = vload8(0, (__global float *)(src_addr + 6 * src_stride_y));
- const float8 in_row7 = vload8(0, (__global float *)(src_addr + 7 * src_stride_y));
-
- // Calculate common factors for intermediate tensor
- float8 comm_fact0 = in_row2 + in_row6 - 4.25f * in_row4;
- float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3;
- float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6;
-
- // Calculate intermediate tensor and reuse common factor vectors
- const float8 tmp0 = in_row0 - in_row6 + 5.25f * in_row4 - 5.25f * in_row2;
- const float8 tmp1 = comm_fact0 + comm_fact1;
- const float8 tmp2 = comm_fact0 - comm_fact1;
-
- comm_fact0 = 2.5f * in_row3;
- comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5;
-
- const float8 tmp3 = comm_fact1 + comm_fact2;
- const float8 tmp4 = comm_fact2 - comm_fact1;
-
- comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5;
- comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6;
-
- const float8 tmp5 = comm_fact1 + comm_fact2;
- const float8 tmp6 = comm_fact2 - comm_fact1;
- const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5;
-
- // Calculate output rows (reuse comm_fact0 vector)
- float8 out0, out1, out2, out3, out4, out5, out6, out7;
-
- OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
-
- // Store values across the 64 channels
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y;
-
- *((__global float *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
- *((__global float *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
- *((__global float *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
- *((__global float *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
- *((__global float *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
- *((__global float *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
- *((__global float *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
- *((__global float *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
- *((__global float *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
- *((__global float *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
- *((__global float *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
- *((__global float *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
- *((__global float *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
- *((__global float *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
- *((__global float *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
- *((__global float *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
- *((__global float *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
- *((__global float *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
- *((__global float *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
- *((__global float *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
- *((__global float *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
- *((__global float *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
- *((__global float *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
- *((__global float *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
- *((__global float *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
- *((__global float *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
- *((__global float *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
- *((__global float *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
- *((__global float *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
- *((__global float *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
- *((__global float *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
- *((__global float *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
- *((__global float *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
- *((__global float *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
- *((__global float *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
- *((__global float *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
- *((__global float *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
- *((__global float *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
- *((__global float *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
- *((__global float *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
- *((__global float *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
- *((__global float *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
- *((__global float *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
- *((__global float *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
- *((__global float *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
- *((__global float *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
- *((__global float *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
- *((__global float *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
- *((__global float *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
- *((__global float *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
- *((__global float *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
- *((__global float *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
- *((__global float *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
- *((__global float *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
- *((__global float *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
- *((__global float *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
- *((__global float *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
- *((__global float *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
- *((__global float *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
- *((__global float *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
- *((__global float *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
- *((__global float *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
- *((__global float *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
- *((__global float *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
-}
-#endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP)
-
-#if defined(NUM_TILES_X)
-/** This OpenCL kernel performs Winograd output transform when the output tile is 2x2, the filter size 3x3 and the data format is NCHW
- *
- * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
- *
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
- * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
- * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
- * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
- * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- */
-__kernel void winograd_output_transform_2x2_3x3_nchw(
- TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst)
-#if defined(HAS_BIAS)
- ,
- VECTOR_DECLARATION(bias)
-#endif // defined(HAS_BIAS)
-)
-{
- // Each thread stores a 2x2 tile
- Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
-
- const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
-
- // Load the values across the 16 channels to compose the 4x4 tile
- float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
- float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
- float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
- float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
-
- float d10 = *((__global float *)(src_addr + 4 * src_stride_z));
- float d11 = *((__global float *)(src_addr + 5 * src_stride_z));
- float d12 = *((__global float *)(src_addr + 6 * src_stride_z));
- float d13 = *((__global float *)(src_addr + 7 * src_stride_z));
-
- float d20 = *((__global float *)(src_addr + 8 * src_stride_z));
- float d21 = *((__global float *)(src_addr + 9 * src_stride_z));
- float d22 = *((__global float *)(src_addr + 10 * src_stride_z));
- float d23 = *((__global float *)(src_addr + 11 * src_stride_z));
-
- float d30 = *((__global float *)(src_addr + 12 * src_stride_z));
- float d31 = *((__global float *)(src_addr + 13 * src_stride_z));
- float d32 = *((__global float *)(src_addr + 14 * src_stride_z));
- float d33 = *((__global float *)(src_addr + 15 * src_stride_z));
-
- // Compute the 2x2 output tile
- float k0 = d01 + d11 + d21;
- float k1 = d02 + d12 + d22;
- float k2 = d11 - d21 - d31;
- float k3 = d12 - d22 - d32;
-
- // out00 = d00 + d10 + d20 + d01 + d11 + d21 + d02 + d12 + d22
- // out01 = d01 + d11 + d21 - (d02 + d12 + d22) - (d03 + d13 + d23)
- // out10 = d10 - d20 - d30 + (d11 - d21 - d31) + (d12 - d22 - d32)
- // out11 = d11 - d21 - d31 - (d12 - d22 - d32) - (d13 - d23 - d33)
-
- float out00 = d10;
- float out01 = -d13;
- float out10 = d10;
- float out11 = -d13;
-
- out00 += d00 + d20 + k0 + k1;
- out01 += k0 - k1 - (d03 + d23);
- out10 += -d20 - d30 + k2 + k3;
- out11 += k2 - k3 + d23 + d33;
-
- int y_in = get_global_id(1);
- int x_out = (y_in % NUM_TILES_X) * 2;
- int y_out = (y_in / NUM_TILES_X) * 2;
- int z_out = get_global_id(0);
-
-#if defined(HAS_BIAS)
- // Add bias
- Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
-
- float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
-
- out00 += (float)b;
- out01 += (float)b;
- out10 += (float)b;
- out11 += (float)b;
-#endif // defined(HAS_BIAS)
-
- // Get output address
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;
-
- // Store the 2x2 output tile
- vstore2((float2)(out00, out01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
- vstore2((float2)(out10, out11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
-}
-
-/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 3x3 and the data format is NCHW
- *
- * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
- *
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
- * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
- * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
- * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
- * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- */
-__kernel void winograd_output_transform_4x4_3x3_nchw(
- TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst)
-#if defined(HAS_BIAS)
- ,
- VECTOR_DECLARATION(bias)
-#endif // defined(HAS_BIAS)
-)
-{
- // Each thread stores a 4x4 tile
- Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
-
- const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
-
- // Load the values across the 36 channels to compose the 6x6 tile
- float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
- float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
- float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
- float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
- float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
- float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
-
- float d10 = *((__global float *)(src_addr + 6 * src_stride_z));
- float d11 = *((__global float *)(src_addr + 7 * src_stride_z));
- float d12 = *((__global float *)(src_addr + 8 * src_stride_z));
- float d13 = *((__global float *)(src_addr + 9 * src_stride_z));
- float d14 = *((__global float *)(src_addr + 10 * src_stride_z));
- float d15 = *((__global float *)(src_addr + 11 * src_stride_z));
-
- float d20 = *((__global float *)(src_addr + 12 * src_stride_z));
- float d21 = *((__global float *)(src_addr + 13 * src_stride_z));
- float d22 = *((__global float *)(src_addr + 14 * src_stride_z));
- float d23 = *((__global float *)(src_addr + 15 * src_stride_z));
- float d24 = *((__global float *)(src_addr + 16 * src_stride_z));
- float d25 = *((__global float *)(src_addr + 17 * src_stride_z));
-
- float d30 = *((__global float *)(src_addr + 18 * src_stride_z));
- float d31 = *((__global float *)(src_addr + 19 * src_stride_z));
- float d32 = *((__global float *)(src_addr + 20 * src_stride_z));
- float d33 = *((__global float *)(src_addr + 21 * src_stride_z));
- float d34 = *((__global float *)(src_addr + 22 * src_stride_z));
- float d35 = *((__global float *)(src_addr + 23 * src_stride_z));
-
- float d40 = *((__global float *)(src_addr + 24 * src_stride_z));
- float d41 = *((__global float *)(src_addr + 25 * src_stride_z));
- float d42 = *((__global float *)(src_addr + 26 * src_stride_z));
- float d43 = *((__global float *)(src_addr + 27 * src_stride_z));
- float d44 = *((__global float *)(src_addr + 28 * src_stride_z));
- float d45 = *((__global float *)(src_addr + 29 * src_stride_z));
-
- float d50 = *((__global float *)(src_addr + 30 * src_stride_z));
- float d51 = *((__global float *)(src_addr + 31 * src_stride_z));
- float d52 = *((__global float *)(src_addr + 32 * src_stride_z));
- float d53 = *((__global float *)(src_addr + 33 * src_stride_z));
- float d54 = *((__global float *)(src_addr + 34 * src_stride_z));
- float d55 = *((__global float *)(src_addr + 35 * src_stride_z));
-
- // Compute out00, out01, out02 and out03
- float out00 = d01 + d21 + d41 + d11 + d31;
- float out01 = d01 + d21 + d41 + d11 + d31;
- float out02 = d01 + d21 + d41 + d11 + d31;
- float out03 = d01 + d21 + d41 + d11 + d31;
-
- float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
- float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
-
- out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
- out01 += k1 - d02 - d12 - d22 - d32 - d42;
- out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
- out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
-
- // Compute out10, out11, out12 and out13
- float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
- float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
- float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
- float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
-
- k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
- k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
-
- out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
- out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
- out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
- out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
-
- // Compute out20, out21, out22 and out23
- float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
- float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
- float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
- float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
-
- k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
- k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
-
- out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
- out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
- out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
- out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
-
- // Compute out30, out31, out32 and out33
- float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
- float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
- float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
- float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
-
- k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
- k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
-
- out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
- out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
- out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
- out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
-
- int y_in = get_global_id(1);
- int x_out = (y_in % NUM_TILES_X) * 4;
- int y_out = (y_in / NUM_TILES_X) * 4;
- int z_out = get_global_id(0);
-
-#if defined(HAS_BIAS)
- // Add bias
- Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
-
- float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
-
- out00 += (float)b;
- out01 += (float)b;
- out02 += (float)b;
- out03 += (float)b;
-
- out10 += (float)b;
- out11 += (float)b;
- out12 += (float)b;
- out13 += (float)b;
-
- out20 += (float)b;
- out21 += (float)b;
- out22 += (float)b;
- out23 += (float)b;
-
- out30 += (float)b;
- out31 += (float)b;
- out32 += (float)b;
- out33 += (float)b;
-
-#endif // defined(HAS_BIAS)
-
- // Get output address
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;
-
- // Store the 4x4 output tile
- vstore4((float4)(out00, out01, out02, out03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
- vstore4((float4)(out10, out11, out12, out13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
- vstore4((float4)(out20, out21, out22, out23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
- vstore4((float4)(out30, out31, out32, out33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
-}
-
-#define COMPUTE_TMP_COL(col, d0, d1, d2, d3, d4, d5, d6, d7, comm_fact) \
- ({ \
- comm_fact.s0 = d1 + d2; \
- comm_fact.s1 = d3 + d4; \
- comm_fact.s2 = d5 + d6; \
- \
- col.s0 = comm_fact.s0 + comm_fact.s1 + 8.f * comm_fact.s2 + d0; \
- col.s2 = comm_fact.s0 + 4.f * comm_fact.s1 + 2.f * comm_fact.s2; \
- \
- comm_fact.s0 = d1 - d2; \
- comm_fact.s1 = d3 - d4; \
- comm_fact.s2 = d5 - d6; \
- \
- col.s1 = comm_fact.s0 + 2.f * comm_fact.s1 + 4.f * comm_fact.s2; \
- col.s3 = comm_fact.s0 + 8.f * comm_fact.s1 + comm_fact.s2 + d7; \
- })
-
-/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 5x5 and the data format is NCHW
- *
- * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
- *
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
- * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
- * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
- * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
- * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
- * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- */
-__kernel void winograd_output_transform_4x4_5x5_nchw(
- TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst)
-#if defined(HAS_BIAS)
- ,
- VECTOR_DECLARATION(bias)
-#endif // defined(HAS_BIAS)
-)
-{
- // Each thread stores a 4x4 tile
- Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
-
- const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
-
- // Load the values across the 64 channels to compose the 8x8 input tile
- float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
- float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
- float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
- float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
- float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
- float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
- float d06 = *((__global float *)(src_addr + 6 * src_stride_z));
- float d07 = *((__global float *)(src_addr + 7 * src_stride_z));
-
- float d10 = *((__global float *)(src_addr + 8 * src_stride_z));
- float d11 = *((__global float *)(src_addr + 9 * src_stride_z));
- float d12 = *((__global float *)(src_addr + 10 * src_stride_z));
- float d13 = *((__global float *)(src_addr + 11 * src_stride_z));
- float d14 = *((__global float *)(src_addr + 12 * src_stride_z));
- float d15 = *((__global float *)(src_addr + 13 * src_stride_z));
- float d16 = *((__global float *)(src_addr + 14 * src_stride_z));
- float d17 = *((__global float *)(src_addr + 15 * src_stride_z));
-
- float d20 = *((__global float *)(src_addr + 16 * src_stride_z));
- float d21 = *((__global float *)(src_addr + 17 * src_stride_z));
- float d22 = *((__global float *)(src_addr + 18 * src_stride_z));
- float d23 = *((__global float *)(src_addr + 19 * src_stride_z));
- float d24 = *((__global float *)(src_addr + 20 * src_stride_z));
- float d25 = *((__global float *)(src_addr + 21 * src_stride_z));
- float d26 = *((__global float *)(src_addr + 22 * src_stride_z));
- float d27 = *((__global float *)(src_addr + 23 * src_stride_z));
-
- float d30 = *((__global float *)(src_addr + 24 * src_stride_z));
- float d31 = *((__global float *)(src_addr + 25 * src_stride_z));
- float d32 = *((__global float *)(src_addr + 26 * src_stride_z));
- float d33 = *((__global float *)(src_addr + 27 * src_stride_z));
- float d34 = *((__global float *)(src_addr + 28 * src_stride_z));
- float d35 = *((__global float *)(src_addr + 29 * src_stride_z));
- float d36 = *((__global float *)(src_addr + 30 * src_stride_z));
- float d37 = *((__global float *)(src_addr + 31 * src_stride_z));
-
- float d40 = *((__global float *)(src_addr + 32 * src_stride_z));
- float d41 = *((__global float *)(src_addr + 33 * src_stride_z));
- float d42 = *((__global float *)(src_addr + 34 * src_stride_z));
- float d43 = *((__global float *)(src_addr + 35 * src_stride_z));
- float d44 = *((__global float *)(src_addr + 36 * src_stride_z));
- float d45 = *((__global float *)(src_addr + 37 * src_stride_z));
- float d46 = *((__global float *)(src_addr + 38 * src_stride_z));
- float d47 = *((__global float *)(src_addr + 39 * src_stride_z));
-
- float d50 = *((__global float *)(src_addr + 40 * src_stride_z));
- float d51 = *((__global float *)(src_addr + 41 * src_stride_z));
- float d52 = *((__global float *)(src_addr + 42 * src_stride_z));
- float d53 = *((__global float *)(src_addr + 43 * src_stride_z));
- float d54 = *((__global float *)(src_addr + 44 * src_stride_z));
- float d55 = *((__global float *)(src_addr + 45 * src_stride_z));
- float d56 = *((__global float *)(src_addr + 46 * src_stride_z));
- float d57 = *((__global float *)(src_addr + 47 * src_stride_z));
-
- float d60 = *((__global float *)(src_addr + 48 * src_stride_z));
- float d61 = *((__global float *)(src_addr + 49 * src_stride_z));
- float d62 = *((__global float *)(src_addr + 50 * src_stride_z));
- float d63 = *((__global float *)(src_addr + 51 * src_stride_z));
- float d64 = *((__global float *)(src_addr + 52 * src_stride_z));
- float d65 = *((__global float *)(src_addr + 53 * src_stride_z));
- float d66 = *((__global float *)(src_addr + 54 * src_stride_z));
- float d67 = *((__global float *)(src_addr + 55 * src_stride_z));
-
- float d70 = *((__global float *)(src_addr + 56 * src_stride_z));
- float d71 = *((__global float *)(src_addr + 57 * src_stride_z));
- float d72 = *((__global float *)(src_addr + 58 * src_stride_z));
- float d73 = *((__global float *)(src_addr + 59 * src_stride_z));
- float d74 = *((__global float *)(src_addr + 60 * src_stride_z));
- float d75 = *((__global float *)(src_addr + 61 * src_stride_z));
- float d76 = *((__global float *)(src_addr + 62 * src_stride_z));
- float d77 = *((__global float *)(src_addr + 63 * src_stride_z));
-
- // Compute the 8x4 intermediate tensor
- float4 comm_fact0, comm_fact1, comm_fact2;
- float4 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
-
- COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
- COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0);
- COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0);
- COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0);
- COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0);
- COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0);
- COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0);
- COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0);
-
- // Compute the 4x4 output tile
- comm_fact0 = tmp_col1 + tmp_col2;
- comm_fact1 = tmp_col3 + tmp_col4;
- comm_fact2 = tmp_col5 + tmp_col6;
-
- float4 out_col0 = comm_fact0 + comm_fact1 + 8.f * comm_fact2 + tmp_col0;
- float4 out_col2 = comm_fact0 + 4.f * comm_fact1 + 2.f * comm_fact2;
-
- comm_fact0 = tmp_col1 - tmp_col2;
- comm_fact1 = tmp_col3 - tmp_col4;
- comm_fact2 = tmp_col5 - tmp_col6;
-
- float4 out_col1 = comm_fact0 + 2.f * comm_fact1 + 4.f * comm_fact2;
- float4 out_col3 = comm_fact0 + 8.f * comm_fact1 + comm_fact2 + tmp_col7;
-
- int y_in = get_global_id(1);
- int x_out = (y_in % NUM_TILES_X) * 4;
- int y_out = (y_in / NUM_TILES_X) * 4;
- int z_out = get_global_id(0);
-
-#if defined(HAS_BIAS)
- // Add bias
- Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
-
- float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
-
- out_col0 += (float4)b;
- out_col1 += (float4)b;
- out_col2 += (float4)b;
- out_col3 += (float4)b;
-#endif // defined(HAS_BIAS)
-
- // Get output address
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;
-
- // Store the 4x4 output tile
- *(__global float *)(dst_addr + 0 * dst_stride_x + 0 * dst_stride_y) = out_col0.s0;
- *(__global float *)(dst_addr + 1 * dst_stride_x + 0 * dst_stride_y) = out_col1.s0;
- *(__global float *)(dst_addr + 2 * dst_stride_x + 0 * dst_stride_y) = out_col2.s0;
- *(__global float *)(dst_addr + 3 * dst_stride_x + 0 * dst_stride_y) = out_col3.s0;
- *(__global float *)(dst_addr + 0 * dst_stride_x + 1 * dst_stride_y) = out_col0.s1;
- *(__global float *)(dst_addr + 1 * dst_stride_x + 1 * dst_stride_y) = out_col1.s1;
- *(__global float *)(dst_addr + 2 * dst_stride_x + 1 * dst_stride_y) = out_col2.s1;
- *(__global float *)(dst_addr + 3 * dst_stride_x + 1 * dst_stride_y) = out_col3.s1;
- *(__global float *)(dst_addr + 0 * dst_stride_x + 2 * dst_stride_y) = out_col0.s2;
- *(__global float *)(dst_addr + 1 * dst_stride_x + 2 * dst_stride_y) = out_col1.s2;
- *(__global float *)(dst_addr + 2 * dst_stride_x + 2 * dst_stride_y) = out_col2.s2;
- *(__global float *)(dst_addr + 3 * dst_stride_x + 2 * dst_stride_y) = out_col3.s2;
- *(__global float *)(dst_addr + 0 * dst_stride_x + 3 * dst_stride_y) = out_col0.s3;
- *(__global float *)(dst_addr + 1 * dst_stride_x + 3 * dst_stride_y) = out_col1.s3;
- *(__global float *)(dst_addr + 2 * dst_stride_x + 3 * dst_stride_y) = out_col2.s3;
- *(__global float *)(dst_addr + 3 * dst_stride_x + 3 * dst_stride_y) = out_col3.s3;
-}
-#endif // defined(NUM_TILES_X)
diff --git a/src/core/CL/cl_kernels/winograd_filter_transform.cl b/src/core/CL/cl_kernels/winograd_filter_transform.cl
new file mode 100644
index 0000000..73da005
--- /dev/null
+++ b/src/core/CL/cl_kernels/winograd_filter_transform.cl
@@ -0,0 +1,1484 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "helpers.h"
+
+#if defined(SRC_DIM_Z)
+
+/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 2x2/2x1/1x2
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_2x2_3x3_nchw(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z); // SRC_DIM_Z is needed to split the input into batches (see @note)
+
+ const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
+
+ // Load the filter weights from the input tensor (w0 is the first row, or the only row/column for 3x1/1x3)
+#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+ float3 w0 = vload3(0, (__global float *)(src_addr)); // 3x1: 3 contiguous values along X
+#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ float3 w0 = (float3)(*((__global float *)(src_addr + 0 * src_stride_y)), // 1x3: one value from each of 3 rows (Y)
+ *((__global float *)(src_addr + 1 * src_stride_y)),
+ *((__global float *)(src_addr + 2 * src_stride_y)));
+#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y)); // 3x3: filter row 0
+ float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y)); // filter row 1
+ float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y)); // filter row 2
+#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+
+ // Row 0: G * w0, where G = [1 0 0; 1/2 1/2 1/2; 1/2 -1/2 1/2; 0 0 1]
+ float4 out0 = 0.0f;
+ out0.s0 = (w0.s0);
+ out0.s1 = (w0.s0 + w0.s1 + w0.s2) * 0.5f;
+ out0.s2 = (w0.s0 + w0.s2 - w0.s1) * 0.5f;
+ out0.s3 = (w0.s2);
+
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ // Row 1: G * ((w0 + w1 + w2) / 2)
+ float4 out1 = 0.0f;
+ out1.s0 = (w0.s0 + w1.s0 + w2.s0) * 0.5f;
+ out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) * 0.25f;
+ out1.s2 = (w0.s0 + w1.s0 + w2.s0 + w0.s2 + w1.s2 + w2.s2 - w0.s1 - w1.s1 - w2.s1) * 0.25f;
+ out1.s3 = (w0.s2 + w1.s2 + w2.s2) * 0.5f;
+
+ // Row 2: G * ((w0 - w1 + w2) / 2)
+ float4 out2 = 0.0f;
+ out2.s0 = (w0.s0 + w2.s0 - w1.s0) * 0.5f;
+ out2.s1 = (w0.s0 + w2.s0 + w0.s1 + w2.s1 + w0.s2 + w2.s2 - w1.s0 - w1.s1 - w1.s2) * 0.25f;
+ out2.s2 = (w0.s0 + w2.s0 + w1.s1 + w0.s2 + w2.s2 - w1.s0 - w0.s1 - w2.s1 - w1.s2) * 0.25f;
+ out2.s3 = (w0.s2 + w2.s2 - w1.s2) * 0.5f;
+
+ // Row 3: G * w2
+ float4 out3 = 0.0f;
+ out3.s0 = (w2.s0);
+ out3.s1 = (w2.s0 + w2.s1 + w2.s2) * 0.5f;
+ out3.s2 = (w2.s0 + w2.s2 - w2.s1) * 0.5f;
+ out3.s3 = (w2.s2);
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+
+ int z = get_global_id(2); // flattened index: z = filter * SRC_DIM_Z + channel
+ int x0 = z / SRC_DIM_Z; // idx filter -> destination X
+ int y0 = z % SRC_DIM_Z; // idx channel -> destination Y
+
+ // Get output address
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
+
+ // Store the values across the channels: transformed element k goes to destination Z slice k
+ // 16 channels for 3x3 kernels
+ // 4 channels for 3x1 or 1x3 kernels
+ *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
+ *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
+ *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
+ *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
+
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ *(__global float *)(dst_addr + 4 * dst_stride_z) = out1.s0;
+ *(__global float *)(dst_addr + 5 * dst_stride_z) = out1.s1;
+ *(__global float *)(dst_addr + 6 * dst_stride_z) = out1.s2;
+ *(__global float *)(dst_addr + 7 * dst_stride_z) = out1.s3;
+ *(__global float *)(dst_addr + 8 * dst_stride_z) = out2.s0;
+ *(__global float *)(dst_addr + 9 * dst_stride_z) = out2.s1;
+ *(__global float *)(dst_addr + 10 * dst_stride_z) = out2.s2;
+ *(__global float *)(dst_addr + 11 * dst_stride_z) = out2.s3;
+ *(__global float *)(dst_addr + 12 * dst_stride_z) = out3.s0;
+ *(__global float *)(dst_addr + 13 * dst_stride_z) = out3.s1;
+ *(__global float *)(dst_addr + 14 * dst_stride_z) = out3.s2;
+ *(__global float *)(dst_addr + 15 * dst_stride_z) = out3.s3;
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+}
+
+/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 4x4/4x1/1x4
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_4x4_3x3_nchw(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
+
+ const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
+
+ // Load the values from the input tensor
+#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+ float3 w0 = vload3(0, (__global float *)(src_addr));
+#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ float3 w0 = (float3)(*((__global float *)(src_addr + 0 * src_stride_y)),
+ *((__global float *)(src_addr + 1 * src_stride_y)),
+ *((__global float *)(src_addr + 2 * src_stride_y)));
+#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
+ float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
+ float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));
+#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+
+ // Row 0
+ float8 out0 = 0.0f;
+ out0.s0 = (w0.s0) / 16.f;
+ out0.s1 = (-w0.s0 - w0.s1 - w0.s2) / 24.f;
+ out0.s2 = (-w0.s0 + w0.s1 - w0.s2) / 24.f;
+ out0.s3 = (w0.s0 + 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
+ out0.s4 = (w0.s0 - 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
+ out0.s5 = (w0.s2) / 4.f;
+
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ // Row 1
+ float8 out1 = 0.0f;
+ out1.s0 = (-w0.s0 - w1.s0 - w2.s0) / 24.f;
+ out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
+ out1.s2 = (w0.s0 + w1.s0 + w2.s0 - w0.s1 - w1.s1 - w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
+ out1.s3 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (-w0.s1 - w1.s1 - w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
+ out1.s4 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (w0.s1 + w1.s1 + w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
+ out1.s5 = (-w0.s2 - w1.s2 - w2.s2) / 6.f;
+
+ // Row 2
+ float8 out2 = 0.0f;
+ out2.s0 = (-w0.s0 + w1.s0 - w2.s0) / 24.f;
+ out2.s1 = (w0.s0 - w1.s0 + w2.s0 + w0.s1 - w1.s1 + w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
+ out2.s2 = (w0.s0 - w1.s0 + w2.s0 - w0.s1 + w1.s1 - w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
+ out2.s3 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (-w0.s1 + w1.s1 - w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
+ out2.s4 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (w0.s1 - w1.s1 + w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
+ out2.s5 = (-w0.s2 + w1.s2 - w2.s2) / 6.f;
+
+ // Row 3
+ float8 out3 = 0.0f;
+ out3.s0 = (w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
+ out3.s1 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 - 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
+ out3.s2 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 + 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
+ out3.s3 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 + 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
+ out3.s4 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 - 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
+ out3.s5 = (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2) / 24.f;
+
+ // Row 4
+ float8 out4 = 0.0f;
+ out4.s0 = (w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
+ out4.s1 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 + 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
+ out4.s2 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 - 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
+ out4.s3 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 - 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
+ out4.s4 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 + 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
+ out4.s5 = (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2) / 24.f;
+
+ // Row 5
+ float8 out5 = 0.0f;
+ out5.s0 = (w2.s0) / 4.f;
+ out5.s1 = (-w2.s0 - w2.s1 - w2.s2) / 6.f;
+ out5.s2 = (-w2.s0 + w2.s1 - w2.s2) / 6.f;
+ out5.s3 = (w2.s0 + 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
+ out5.s4 = (w2.s0 - 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
+ out5.s5 = (w2.s2);
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+
+ int z = get_global_id(2);
+ int x0 = z / SRC_DIM_Z; // idx filter
+ int y0 = z % SRC_DIM_Z; // idx channel
+
+ // Get output address
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
+
+ // Store the values across the channels
+ // 36 channels for 3x3 kernels
+ // 6 channels for 3x1 or 1x3 kernels
+ *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
+ *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
+ *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
+ *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
+ *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
+ *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
+
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ *(__global float *)(dst_addr + 6 * dst_stride_z) = out1.s0;
+ *(__global float *)(dst_addr + 7 * dst_stride_z) = out1.s1;
+ *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s2;
+ *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s3;
+ *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s4;
+ *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s5;
+ *(__global float *)(dst_addr + 12 * dst_stride_z) = out2.s0;
+ *(__global float *)(dst_addr + 13 * dst_stride_z) = out2.s1;
+ *(__global float *)(dst_addr + 14 * dst_stride_z) = out2.s2;
+ *(__global float *)(dst_addr + 15 * dst_stride_z) = out2.s3;
+ *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s4;
+ *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s5;
+ *(__global float *)(dst_addr + 18 * dst_stride_z) = out3.s0;
+ *(__global float *)(dst_addr + 19 * dst_stride_z) = out3.s1;
+ *(__global float *)(dst_addr + 20 * dst_stride_z) = out3.s2;
+ *(__global float *)(dst_addr + 21 * dst_stride_z) = out3.s3;
+ *(__global float *)(dst_addr + 22 * dst_stride_z) = out3.s4;
+ *(__global float *)(dst_addr + 23 * dst_stride_z) = out3.s5;
+ *(__global float *)(dst_addr + 24 * dst_stride_z) = out4.s0;
+ *(__global float *)(dst_addr + 25 * dst_stride_z) = out4.s1;
+ *(__global float *)(dst_addr + 26 * dst_stride_z) = out4.s2;
+ *(__global float *)(dst_addr + 27 * dst_stride_z) = out4.s3;
+ *(__global float *)(dst_addr + 28 * dst_stride_z) = out4.s4;
+ *(__global float *)(dst_addr + 29 * dst_stride_z) = out4.s5;
+ *(__global float *)(dst_addr + 30 * dst_stride_z) = out5.s0;
+ *(__global float *)(dst_addr + 31 * dst_stride_z) = out5.s1;
+ *(__global float *)(dst_addr + 32 * dst_stride_z) = out5.s2;
+ *(__global float *)(dst_addr + 33 * dst_stride_z) = out5.s3;
+ *(__global float *)(dst_addr + 34 * dst_stride_z) = out5.s4;
+ *(__global float *)(dst_addr + 35 * dst_stride_z) = out5.s5;
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+}
+
+/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NHWC and the output tile is 4x4/4x1/1x4
+ *
+ * @note Unlike the NCHW variant, this kernel does not read -DSRC_DIM_Z: each work-item derives its source address directly from get_global_id(0) (input channel), get_global_id(1) (Y offset) and get_global_id(2) (filter index)
+ * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_4x4_3x3_nhwc(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ // Base address of this work-item's filter window: X = input channel (id 0), Y offset = id 1, W = filter index (id 2)
+
+ const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;
+
+ // Load the filter taps: wRC is kernel row R (read along Z) and kernel column C (read along Y); the 1x3 variant reads its 3 taps along Z
+#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ float w00 = *((__global float *)(src_addr + 0 * src_stride_z));
+ float w01 = *((__global float *)(src_addr + 1 * src_stride_z));
+ float w02 = *((__global float *)(src_addr + 2 * src_stride_z));
+#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ float w00 = *((__global float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
+ float w01 = *((__global float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
+ float w02 = *((__global float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+ float w10 = *((__global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
+ float w11 = *((__global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
+ float w12 = *((__global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
+ float w20 = *((__global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
+ float w21 = *((__global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
+ float w22 = *((__global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+
+ // Row 0 of the 6x6 transformed tile (the only row produced for the 3x1/1x3 variants)
+ float out00, out01, out02, out03, out04, out05;
+ out00 = (w00) / 16.f;
+ out01 = (-w00 - w01 - w02) / 24.f;
+ out02 = (-w00 + w01 - w02) / 24.f;
+ out03 = (w00 + 2.f * w01 + 4.f * w02) / 96.f;
+ out04 = (w00 - 2.f * w01 + 4.f * w02) / 96.f;
+ out05 = (w02) / 4.f;
+
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ // Row 1
+ float out10, out11, out12, out13, out14, out15;
+ out10 = (-w00 - w10 - w20) / 24.f;
+ out11 = (w00 + w10 + w20 + w01 + w11 + w21 + w02 + w12 + w22) / 36.f;
+ out12 = (w00 + w10 + w20 - w01 - w11 - w21 + w02 + w12 + w22) / 36.f;
+ out13 = (-w00 - w10 - w20 + 2.f * (-w01 - w11 - w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
+ out14 = (-w00 - w10 - w20 + 2.f * (w01 + w11 + w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
+ out15 = (-w02 - w12 - w22) / 6.f;
+
+ // Row 2
+ float out20, out21, out22, out23, out24, out25;
+ out20 = (-w00 + w10 - w20) / 24.f;
+ out21 = (w00 - w10 + w20 + w01 - w11 + w21 + w02 - w12 + w22) / 36.f;
+ out22 = (w00 - w10 + w20 - w01 + w11 - w21 + w02 - w12 + w22) / 36.f;
+ out23 = (-w00 + w10 - w20 + 2.f * (-w01 + w11 - w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
+ out24 = (-w00 + w10 - w20 + 2.f * (w01 - w11 + w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
+ out25 = (-w02 + w12 - w22) / 6.f;
+
+ // Row 3
+ float out30, out31, out32, out33, out34, out35;
+ out30 = (w00 + 2.f * w10 + 4.f * w20) / 96.f;
+ out31 = (-w00 - 2.f * w10 - 4.f * w20 - w01 - 2.f * w11 - 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
+ out32 = (-w00 - 2.f * w10 - 4.f * w20 + w01 + 2.f * w11 + 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
+ out33 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (w01 + 2.f * w11 + 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
+ out34 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (-w01 - 2.f * w11 - 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
+ out35 = (w02 + 2.f * w12 + 4.f * w22) / 24.f;
+
+ // Row 4
+ float out40, out41, out42, out43, out44, out45;
+ out40 = (w00 - 2.f * w10 + 4.f * w20) / 96.f;
+ out41 = (-w00 + 2.f * w10 - 4.f * w20 - w01 + 2.f * w11 - 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
+ out42 = (-w00 + 2.f * w10 - 4.f * w20 + w01 - 2.f * w11 + 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
+ out43 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (w01 - 2.f * w11 + 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
+ out44 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (-w01 + 2.f * w11 - 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
+ out45 = (w02 - 2.f * w12 + 4.f * w22) / 24.f;
+
+ // Row 5
+ float out50, out51, out52, out53, out54, out55;
+ out50 = (w20) / 4.f;
+ out51 = (-w20 - w21 - w22) / 6.f;
+ out52 = (-w20 + w21 - w22) / 6.f;
+ out53 = (w20 + 2.f * w21 + 4.f * w22) / 24.f;
+ out54 = (w20 - 2.f * w21 + 4.f * w22) / 24.f;
+ out55 = (w22);
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+
+ int x0 = get_global_id(2); // idx filter
+ int y0 = get_global_id(0); // idx channel
+
+ // Get output address
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(float) + y0 * dst_stride_y;
+
+ // Store each transformed value in its own Z plane (dst_stride_z apart)
+ // 36 channels for 3x3 kernels
+ // 6 channels for 3x1 or 1x3 kernels
+ *(__global float *)(dst_addr + 0 * dst_stride_z) = out00;
+ *(__global float *)(dst_addr + 1 * dst_stride_z) = out01;
+ *(__global float *)(dst_addr + 2 * dst_stride_z) = out02;
+ *(__global float *)(dst_addr + 3 * dst_stride_z) = out03;
+ *(__global float *)(dst_addr + 4 * dst_stride_z) = out04;
+ *(__global float *)(dst_addr + 5 * dst_stride_z) = out05;
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ *(__global float *)(dst_addr + 6 * dst_stride_z) = out10;
+ *(__global float *)(dst_addr + 7 * dst_stride_z) = out11;
+ *(__global float *)(dst_addr + 8 * dst_stride_z) = out12;
+ *(__global float *)(dst_addr + 9 * dst_stride_z) = out13;
+ *(__global float *)(dst_addr + 10 * dst_stride_z) = out14;
+ *(__global float *)(dst_addr + 11 * dst_stride_z) = out15;
+ *(__global float *)(dst_addr + 12 * dst_stride_z) = out20;
+ *(__global float *)(dst_addr + 13 * dst_stride_z) = out21;
+ *(__global float *)(dst_addr + 14 * dst_stride_z) = out22;
+ *(__global float *)(dst_addr + 15 * dst_stride_z) = out23;
+ *(__global float *)(dst_addr + 16 * dst_stride_z) = out24;
+ *(__global float *)(dst_addr + 17 * dst_stride_z) = out25;
+ *(__global float *)(dst_addr + 18 * dst_stride_z) = out30;
+ *(__global float *)(dst_addr + 19 * dst_stride_z) = out31;
+ *(__global float *)(dst_addr + 20 * dst_stride_z) = out32;
+ *(__global float *)(dst_addr + 21 * dst_stride_z) = out33;
+ *(__global float *)(dst_addr + 22 * dst_stride_z) = out34;
+ *(__global float *)(dst_addr + 23 * dst_stride_z) = out35;
+ *(__global float *)(dst_addr + 24 * dst_stride_z) = out40;
+ *(__global float *)(dst_addr + 25 * dst_stride_z) = out41;
+ *(__global float *)(dst_addr + 26 * dst_stride_z) = out42;
+ *(__global float *)(dst_addr + 27 * dst_stride_z) = out43;
+ *(__global float *)(dst_addr + 28 * dst_stride_z) = out44;
+ *(__global float *)(dst_addr + 29 * dst_stride_z) = out45;
+ *(__global float *)(dst_addr + 30 * dst_stride_z) = out50;
+ *(__global float *)(dst_addr + 31 * dst_stride_z) = out51;
+ *(__global float *)(dst_addr + 32 * dst_stride_z) = out52;
+ *(__global float *)(dst_addr + 33 * dst_stride_z) = out53;
+ *(__global float *)(dst_addr + 34 * dst_stride_z) = out54;
+ *(__global float *)(dst_addr + 35 * dst_stride_z) = out55;
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+}
+
+/** This OpenCL kernel performs Winograd filter transform 5x5/5x1 or 1x5 when the data layout is NCHW and the output tile is 4x4/4x1 or 1x4
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ *
+ * @note If this kernel is used to perform Winograd filter transform 5x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd filter transform 1x5, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_4x4_5x5_nchw(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
+
+ const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
+
+ // Load the values from the input tensor
+#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+ float4 w00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
+ float w01 = *((__global float *)(src_addr + 0 * src_stride_y) + 4);
+#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ float4 w00 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
+ *((__global float *)(src_addr + 1 * src_stride_y)),
+ *((__global float *)(src_addr + 2 * src_stride_y)),
+ *((__global float *)(src_addr + 3 * src_stride_y)));
+ float w01 = *((__global float *)(src_addr + 4 * src_stride_y));
+#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ float4 w00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
+ float w01 = *((__global float *)(src_addr + 0 * src_stride_y) + 4);
+ float4 w10 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
+ float w11 = *((__global float *)(src_addr + 1 * src_stride_y) + 4);
+ float4 w20 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
+ float w21 = *((__global float *)(src_addr + 2 * src_stride_y) + 4);
+ float4 w30 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
+ float w31 = *((__global float *)(src_addr + 3 * src_stride_y) + 4);
+ float4 w40 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y));
+ float w41 = *((__global float *)(src_addr + 4 * src_stride_y) + 4);
+#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+
+ // Transform the input tile
+
+ // Row 0
+ float8 out0 = 0.0f;
+ out0.s0 = w00.s0;
+ out0.s1 = -2.f * (w00.s0 + w00.s1 + w00.s2 + w00.s3 + w01) / 9.f;
+ out0.s2 = -2.f * (w00.s0 - w00.s1 + w00.s2 - w00.s3 + w01) / 9.f;
+ out0.s3 = (w00.s0 + 2.f * w00.s1 + 4.f * w00.s2 + 8.f * w00.s3 + 16.f * w01) / 90.f;
+ out0.s4 = (w00.s0 - 2.f * w00.s1 + 4.f * w00.s2 - 8.f * w00.s3 + 16.f * w01) / 90.f;
+ out0.s5 = (16.f * w00.s0 + 8.f * w00.s1 + 4.f * w00.s2 + 2.f * w00.s3 + w01) / 180.f;
+ out0.s6 = (16.f * w00.s0 - 8.f * w00.s1 + 4.f * w00.s2 - 2.f * w00.s3 + w01) / 180.f;
+ out0.s7 = w01;
+
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ // Row 1
+ float8 out1 = 0.0f;
+ out1.s0 = -2.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) / 9.f;
+ out1.s1 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) +
+ (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
+ out1.s2 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) -
+ (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
+ out1.s3 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 8.f *
+ (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
+ out1.s4 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 8.f *
+ (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
+ out1.s5 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 2.f *
+ (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
+ out1.s6 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 2.f *
+ (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
+ out1.s7 = -2.f * (w01 + w11 + w21 + w31 + w41) / 9.f;
+
+ // Row 2
+ float8 out2 = 0.0f;
+ out2.s0 = -2.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) / 9.f;
+ out2.s1 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) +
+ (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
+ out2.s2 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) -
+ (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
+ out2.s3 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 8.f *
+ (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
+ out2.s4 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 8.f *
+ (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
+ out2.s5 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 2.f *
+ (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
+ out2.s6 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 2.f *
+ (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
+ out2.s7 = -2.f * (w01 - w11 + w21 - w31 + w41) / 9.f;
+
+ // Row 3
+ float8 out3 = 0.0f;
+ out3.s0 = (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
+ out3.s1 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
+ (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
+ (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
+ out3.s2 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
+ (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
+ (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
+ out3.s3 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
+ (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
+ (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
+ out3.s4 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
+ (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
+ (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
+ out3.s5 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
+ (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
+ (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
+ out3.s6 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
+ (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
+ (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
+ out3.s7 = (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) / 90.f;
+
+ // Row 4
+ float8 out4 = 0.0f;
+ out4.s0 = (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
+ out4.s1 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
+ (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
+ (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
+ out4.s2 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
+ (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
+ (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
+ out4.s3 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
+ (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
+ (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
+ out4.s4 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
+ (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
+ (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
+ out4.s5 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
+ (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
+ (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
+ out4.s6 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
+ (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
+ (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
+ out4.s7 = (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) / 90.f;
+
+ // Row 5
+ float8 out5 = 0.0f;
+ out5.s0 = (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) / 180.f;
+ out5.s1 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
+ (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
+ (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
+ out5.s2 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
+ (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
+ (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
+ out5.s3 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
+ (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
+ (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
+ out5.s4 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
+ (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
+ (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
+ out5.s5 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
+ (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
+ (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
+ out5.s6 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
+ (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
+ (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
+ out5.s7 = (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) / 180.f;
+
+ // Row 6
+ float8 out6 = 0.0f;
+ out6.s0 = (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) / 180.f;
+ out6.s1 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
+ (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
+ (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
+ out6.s2 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
+ (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
+ (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
+ out6.s3 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
+ (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
+ (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
+ out6.s4 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
+ (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
+ (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
+ out6.s5 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
+ (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
+ (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
+ out6.s6 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
+ (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
+ (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
+ out6.s7 = (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) / 180.f;
+
+ // Row 7
+ float8 out7 = 0.0f;
+ out7.s0 = w40.s0;
+ out7.s1 = -2.f * (w40.s0 + w40.s1 + w40.s2 + w40.s3 + w41) / 9.f;
+ out7.s2 = -2.f * (w40.s0 - w40.s1 + w40.s2 - w40.s3 + w41) / 9.f;
+ out7.s3 = (w40.s0 + 2.f * w40.s1 + 4.f * w40.s2 + 8.f * w40.s3 + 16.f * w41) / 90.f;
+ out7.s4 = (w40.s0 - 2.f * w40.s1 + 4.f * w40.s2 - 8.f * w40.s3 + 16.f * w41) / 90.f;
+ out7.s5 = (16.f * w40.s0 + 8.f * w40.s1 + 4.f * w40.s2 + 2.f * w40.s3 + w41) / 180.f;
+ out7.s6 = (16.f * w40.s0 - 8.f * w40.s1 + 4.f * w40.s2 - 2.f * w40.s3 + w41) / 180.f;
+ out7.s7 = w41;
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+
+ int z = get_global_id(2);
+ int x0 = z / SRC_DIM_Z; // idx filter
+ int y0 = z % SRC_DIM_Z; // idx channel
+
+ // Get output address
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(float) + y0 * dst_stride_y;
+
+ // Store the values across the channels
+ *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
+ *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
+ *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
+ *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
+ *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
+ *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
+ *(__global float *)(dst_addr + 6 * dst_stride_z) = out0.s6;
+ *(__global float *)(dst_addr + 7 * dst_stride_z) = out0.s7;
+
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s0;
+ *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s1;
+ *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
+ *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
+ *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
+ *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
+ *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
+ *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
+ *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
+ *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
+ *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
+ *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
+ *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
+ *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
+ *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
+ *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
+ *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
+ *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
+ *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
+ *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
+ *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
+ *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
+ *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
+ *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
+ *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
+ *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
+ *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
+ *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
+ *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
+ *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
+ *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
+ *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
+ *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
+ *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
+ *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
+ *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
+ *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
+ *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
+ *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
+ *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
+ *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
+ *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
+ *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
+ *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
+ *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
+ *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
+ *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
+ *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
+ *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
+ *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
+ *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
+ *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
+ *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
+ *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
+ *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
+ *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+}
+
+/** This OpenCL kernel performs Winograd filter transform 5x5/5x1 or 1x5 when the data layout is NHWC and the output tile is 4x4/4x1 or 1x4
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note If this kernel is used to perform Winograd filter transform 5x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd filter transform 1x5, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_4x4_5x5_nhwc(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z); // NOTE(review): appears unused - src_addr below is built from src_ptr directly
+
+ const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * sizeof(float) + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;
+
+#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ // 1x5 filter: load the 5 taps along Z (the filter height for NHWC)
+ float w00 = *((__global float *)(src_addr + 0 * src_stride_z));
+ float w01 = *((__global float *)(src_addr + 1 * src_stride_z));
+ float w02 = *((__global float *)(src_addr + 2 * src_stride_z));
+ float w03 = *((__global float *)(src_addr + 3 * src_stride_z));
+ float w04 = *((__global float *)(src_addr + 4 * src_stride_z));
+#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ // 5x5 / 5x1 filter: load the 5 taps of the first filter row (Z offset 0) along Y
+ float w00 = *((__global float *)(src_addr + 0 * src_stride_y));
+ float w01 = *((__global float *)(src_addr + 1 * src_stride_y));
+ float w02 = *((__global float *)(src_addr + 2 * src_stride_y));
+ float w03 = *((__global float *)(src_addr + 3 * src_stride_y));
+ float w04 = *((__global float *)(src_addr + 4 * src_stride_y));
+#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ float w10 = *((__global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
+ float w11 = *((__global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
+ float w12 = *((__global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
+ float w13 = *((__global float *)(src_addr + 1 * src_stride_z + 3 * src_stride_y));
+ float w14 = *((__global float *)(src_addr + 1 * src_stride_z + 4 * src_stride_y));
+ float w20 = *((__global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
+ float w21 = *((__global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
+ float w22 = *((__global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
+ float w23 = *((__global float *)(src_addr + 2 * src_stride_z + 3 * src_stride_y));
+ float w24 = *((__global float *)(src_addr + 2 * src_stride_z + 4 * src_stride_y));
+ float w30 = *((__global float *)(src_addr + 3 * src_stride_z + 0 * src_stride_y));
+ float w31 = *((__global float *)(src_addr + 3 * src_stride_z + 1 * src_stride_y));
+ float w32 = *((__global float *)(src_addr + 3 * src_stride_z + 2 * src_stride_y));
+ float w33 = *((__global float *)(src_addr + 3 * src_stride_z + 3 * src_stride_y));
+ float w34 = *((__global float *)(src_addr + 3 * src_stride_z + 4 * src_stride_y));
+ float w40 = *((__global float *)(src_addr + 4 * src_stride_z + 0 * src_stride_y));
+ float w41 = *((__global float *)(src_addr + 4 * src_stride_z + 1 * src_stride_y));
+ float w42 = *((__global float *)(src_addr + 4 * src_stride_z + 2 * src_stride_y));
+ float w43 = *((__global float *)(src_addr + 4 * src_stride_z + 3 * src_stride_y));
+ float w44 = *((__global float *)(src_addr + 4 * src_stride_z + 4 * src_stride_y));
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+
+ // Row 0 of the transformed tile (the only row computed and stored for 5x1 / 1x5)
+ float8 out0 = 0.0f;
+ out0.s0 = w00;
+ out0.s1 = -2.f * (w00 + w01 + w02 + w03 + w04) / 9.f;
+ out0.s2 = -2.f * (w00 - w01 + w02 - w03 + w04) / 9.f;
+ out0.s3 = (w00 + 2.f * w01 + 4.f * w02 + 8.f * w03 + 16.f * w04) / 90.f;
+ out0.s4 = (w00 - 2.f * w01 + 4.f * w02 - 8.f * w03 + 16.f * w04) / 90.f;
+ out0.s5 = (16.f * w00 + 8.f * w01 + 4.f * w02 + 2.f * w03 + w04) / 180.f;
+ out0.s6 = (16.f * w00 - 8.f * w01 + 4.f * w02 - 2.f * w03 + w04) / 180.f;
+ out0.s7 = w04;
+
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ // Row 1
+ float8 out1 = 0.0f;
+ out1.s0 = -2.f * (w00 + w10 + w20 + w30 + w40) / 9.f;
+ out1.s1 = 4.f * ((w00 + w10 + w20 + w30 + w40) + (w01 + w11 + w21 + w31 + w41) + (w02 + w12 + w22 + w32 + w42) + (w03 + w13 + w23 + w33 + w43) + (w04 + w14 + w24 + w34 + w44)) / 81.f;
+ out1.s2 = 4.f * ((w00 + w10 + w20 + w30 + w40) - (w01 + w11 + w21 + w31 + w41) + (w02 + w12 + w22 + w32 + w42) - (w03 + w13 + w23 + w33 + w43) + (w04 + w14 + w24 + w34 + w44)) / 81.f;
+ out1.s3 = -((w00 + w10 + w20 + w30 + w40) + 2.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) + 8.f * (w03 + w13 + w23 + w33 + w43) + 16.f *
+ (w04 + w14 + w24 + w34 + w44)) / 405.f;
+ out1.s4 = -((w00 + w10 + w20 + w30 + w40) - 2.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) - 8.f * (w03 + w13 + w23 + w33 + w43) + 16.f *
+ (w04 + w14 + w24 + w34 + w44)) / 405.f;
+ out1.s5 = -(16.f * (w00 + w10 + w20 + w30 + w40) + 8.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) + 2.f * (w03 + w13 + w23 + w33 + w43) +
+ (w04 + w14 + w24 + w34 + w44)) / 810.f;
+ out1.s6 = -(16.f * (w00 + w10 + w20 + w30 + w40) - 8.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) - 2.f * (w03 + w13 + w23 + w33 + w43) +
+ (w04 + w14 + w24 + w34 + w44)) / 810.f;
+ out1.s7 = -2.f * (w04 + w14 + w24 + w34 + w44) / 9.f;
+
+ // Row 2
+ float8 out2 = 0.0f;
+ out2.s0 = -2.f * (w00 - w10 + w20 - w30 + w40) / 9.f;
+ out2.s1 = 4.f * ((w00 - w10 + w20 - w30 + w40) + (w01 - w11 + w21 - w31 + w41) + (w02 - w12 + w22 - w32 + w42) + (w03 - w13 + w23 - w33 + w43) + (w04 - w14 + w24 - w34 + w44)) / 81.f;
+ out2.s2 = 4.f * ((w00 - w10 + w20 - w30 + w40) - (w01 - w11 + w21 - w31 + w41) + (w02 - w12 + w22 - w32 + w42) - (w03 - w13 + w23 - w33 + w43) + (w04 - w14 + w24 - w34 + w44)) / 81.f;
+ out2.s3 = -((w00 - w10 + w20 - w30 + w40) + 2.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) + 8.f * (w03 - w13 + w23 - w33 + w43) + 16.f *
+ (w04 - w14 + w24 - w34 + w44)) / 405.f;
+ out2.s4 = -((w00 - w10 + w20 - w30 + w40) - 2.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) - 8.f * (w03 - w13 + w23 - w33 + w43) + 16.f *
+ (w04 - w14 + w24 - w34 + w44)) / 405.f;
+ out2.s5 = -(16.f * (w00 - w10 + w20 - w30 + w40) + 8.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) + 2.f * (w03 - w13 + w23 - w33 + w43) +
+ (w04 - w14 + w24 - w34 + w44)) / 810.f;
+ out2.s6 = -(16.f * (w00 - w10 + w20 - w30 + w40) - 8.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) - 2.f * (w03 - w13 + w23 - w33 + w43) +
+ (w04 - w14 + w24 - w34 + w44)) / 810.f;
+ out2.s7 = -2.f * (w04 - w14 + w24 - w34 + w44) / 9.f;
+
+ // Row 3
+ float8 out3 = 0.0f;
+ out3.s0 = (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) / 90.f;
+ out3.s1 = -((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) +
+ (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 405.f;
+ out3.s2 = -((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) -
+ (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 405.f;
+ out3.s3 = ((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + 2.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
+ (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) + 8.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + 16.f * (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 8100.f;
+ out3.s4 = ((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - 2.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
+ (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) - 8.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + 16.f * (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 8100.f;
+ out3.s5 = (16.f * (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + 8.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
+ (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) + 2.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 16200.f;
+ out3.s6 = (16.f * (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - 8.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
+ (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) - 2.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 16200.f;
+ out3.s7 = (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44) / 90.f;
+
+ // Row 4
+ float8 out4 = 0.0f;
+ out4.s0 = (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) / 90.f;
+ out4.s1 = -((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) +
+ (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 405.f;
+ out4.s2 = -((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) -
+ (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 405.f;
+ out4.s3 = ((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + 2.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
+ (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) + 8.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + 16.f * (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 8100.f;
+ out4.s4 = ((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - 2.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
+ (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) - 8.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + 16.f * (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 8100.f;
+ out4.s5 = (16.f * (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + 8.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
+ (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) + 2.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 16200.f;
+ out4.s6 = (16.f * (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - 8.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
+ (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) - 2.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 16200.f;
+ out4.s7 = (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44) / 90.f;
+
+ // Row 5
+ float8 out5 = 0.0f;
+ out5.s0 = (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) / 180.f;
+ out5.s1 = -((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) +
+ (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 810.f;
+ out5.s2 = -((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) -
+ (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 810.f;
+ out5.s3 = ((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + 2.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
+ (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) + 8.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + 16.f * (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 16200.f;
+ out5.s4 = ((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - 2.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
+ (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) - 8.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + 16.f * (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 16200.f;
+ out5.s5 = (16.f * (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + 8.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
+ (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) + 2.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 32400.f;
+ out5.s6 = (16.f * (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - 8.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
+ (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) - 2.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 32400.f;
+ out5.s7 = (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44) / 180.f;
+
+ // Row 6
+ float8 out6 = 0.0f;
+ out6.s0 = (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) / 180.f;
+ out6.s1 = -((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) +
+ (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 810.f;
+ out6.s2 = -((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) -
+ (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 810.f;
+ out6.s3 = ((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + 2.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
+ (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) + 8.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + 16.f * (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 16200.f;
+ out6.s4 = ((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - 2.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
+ (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) - 8.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + 16.f * (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 16200.f;
+ out6.s5 = (16.f * (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + 8.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
+ (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) + 2.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 32400.f;
+ out6.s6 = (16.f * (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - 8.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
+ (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) - 2.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 32400.f;
+ out6.s7 = (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44) / 180.f;
+
+ // Row 7
+ float8 out7 = 0.0f;
+ out7.s0 = w40;
+ out7.s1 = -2.f * (w40 + w41 + w42 + w43 + w44) / 9.f;
+ out7.s2 = -2.f * (w40 - w41 + w42 - w43 + w44) / 9.f;
+ out7.s3 = (w40 + 2.f * w41 + 4.f * w42 + 8.f * w43 + 16.f * w44) / 90.f;
+ out7.s4 = (w40 - 2.f * w41 + 4.f * w42 - 8.f * w43 + 16.f * w44) / 90.f;
+ out7.s5 = (16.f * w40 + 8.f * w41 + 4.f * w42 + 2.f * w43 + w44) / 180.f;
+ out7.s6 = (16.f * w40 - 8.f * w41 + 4.f * w42 - 2.f * w43 + w44) / 180.f;
+ out7.s7 = w44;
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+
+ int x0 = get_global_id(2); // idx filter
+ int y0 = get_global_id(0); // idx channel
+
+ // Output address: X selects the filter (get_global_id(2)), Y the input channel (get_global_id(0))
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(float) + y0 * dst_stride_y;
+
+ // Element k of the transformed tile goes to Z-plane k of dst (8 planes for 5x1/1x5, 64 for 5x5)
+ *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
+ *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
+ *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
+ *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
+ *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
+ *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
+ *(__global float *)(dst_addr + 6 * dst_stride_z) = out0.s6;
+ *(__global float *)(dst_addr + 7 * dst_stride_z) = out0.s7;
+
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+ *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s0;
+ *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s1;
+ *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
+ *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
+ *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
+ *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
+ *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
+ *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
+ *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
+ *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
+ *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
+ *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
+ *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
+ *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
+ *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
+ *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
+ *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
+ *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
+ *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
+ *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
+ *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
+ *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
+ *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
+ *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
+ *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
+ *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
+ *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
+ *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
+ *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
+ *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
+ *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
+ *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
+ *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
+ *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
+ *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
+ *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
+ *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
+ *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
+ *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
+ *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
+ *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
+ *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
+ *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
+ *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
+ *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
+ *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
+ *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
+ *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
+ *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
+ *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
+ *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
+ *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
+ *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
+ *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
+ *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
+ *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+}
+#endif // defined(SRC_DIM_Z)
+
+#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 2x1
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time; this kernel only forwards its arguments to winograd_filter_transform_2x2_3x3_nchw
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_2x1_3x1_nchw(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_filter_transform_2x2_3x3_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_stride_w,
+ src_step_w,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 4x1
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time; this kernel only forwards its arguments to winograd_filter_transform_4x4_3x3_nchw
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_4x1_3x1_nchw(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_filter_transform_4x4_3x3_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_stride_w,
+ src_step_w,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel performs Winograd filter transform 5x1 when the data layout is NCHW and the output tile is 4x1 (it forwards all arguments to winograd_filter_transform_4x4_5x5_nchw)
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_4x1_5x1_nchw(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_filter_transform_4x4_5x5_nchw(src_ptr, // NOTE(review): presumably -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL makes the callee perform the 5x1 variant - confirm in its definition
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_stride_w,
+ src_step_w,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NHWC and the output tile is 4x1 (it forwards all arguments to winograd_filter_transform_4x4_3x3_nhwc)
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_4x1_3x1_nhwc(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_filter_transform_4x4_3x3_nhwc(src_ptr, // NOTE(review): presumably -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL makes the callee perform the 3x1 variant - confirm in its definition
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_stride_w,
+ src_step_w,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel performs Winograd filter transform 5x1 when the data layout is NHWC and the output tile is 4x1 (it forwards all arguments to winograd_filter_transform_4x4_5x5_nhwc)
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_4x1_5x1_nhwc(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_filter_transform_4x4_5x5_nhwc(src_ptr, // NOTE(review): presumably -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL makes the callee perform the 5x1 variant - confirm in its definition
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_stride_w,
+ src_step_w,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+
+#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x2 (it forwards all arguments to winograd_filter_transform_2x2_3x3_nchw)
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_1x2_1x3_nchw(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_filter_transform_2x2_3x3_nchw(src_ptr, // NOTE(review): presumably -DWINOGRAD_FILTER_TRANSFORM_VERTICAL makes the callee perform the 1x3 variant - confirm in its definition
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_stride_w,
+ src_step_w,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x4 (it forwards all arguments to winograd_filter_transform_4x4_3x3_nchw)
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_1x4_1x3_nchw(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_filter_transform_4x4_3x3_nchw(src_ptr, // NOTE(review): presumably -DWINOGRAD_FILTER_TRANSFORM_VERTICAL makes the callee perform the 1x3 variant - confirm in its definition
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_stride_w,
+ src_step_w,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel performs Winograd filter transform 1x5 when the data layout is NCHW and the output tile is 1x4 (it forwards all arguments to winograd_filter_transform_4x4_5x5_nchw)
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_1x4_1x5_nchw(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_filter_transform_4x4_5x5_nchw(src_ptr, // NOTE(review): presumably -DWINOGRAD_FILTER_TRANSFORM_VERTICAL makes the callee perform the 1x5 variant - confirm in its definition
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_stride_w,
+ src_step_w,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NHWC and the output tile is 1x4 (it forwards all arguments to winograd_filter_transform_4x4_3x3_nhwc)
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_1x4_1x3_nhwc(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_filter_transform_4x4_3x3_nhwc(src_ptr, // NOTE(review): presumably -DWINOGRAD_FILTER_TRANSFORM_VERTICAL makes the callee perform the 1x3 variant - confirm in its definition
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_stride_w,
+ src_step_w,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel performs Winograd filter transform 1x5 when the data layout is NHWC and the output tile is 1x4 (it forwards all arguments to winograd_filter_transform_4x4_5x5_nhwc)
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_1x4_1x5_nhwc(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_filter_transform_4x4_5x5_nhwc(src_ptr, // NOTE(review): presumably -DWINOGRAD_FILTER_TRANSFORM_VERTICAL makes the callee perform the 1x5 variant - confirm in its definition
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_stride_w,
+ src_step_w,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
diff --git a/src/core/CL/cl_kernels/winograd_input_transform.cl b/src/core/CL/cl_kernels/winograd_input_transform.cl
new file mode 100644
index 0000000..da18e4a
--- /dev/null
+++ b/src/core/CL/cl_kernels/winograd_input_transform.cl
@@ -0,0 +1,2048 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "helpers.h"
+
+#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) /* Transforms one 8-element row for the 4x4/5x5 input transform: out.s0..s7 become fixed linear combinations of tmp.s0..s7. comm_fact.s0..s6 is scratch and is overwritten. The body is a GNU statement expression ({ ... }). */ \
+ ({ \
+ comm_fact.s0 = tmp.s2 - 4.25f * tmp.s4 + tmp.s6; \
+ comm_fact.s1 = tmp.s1 - 4.25f * tmp.s3 + tmp.s5; \
+ comm_fact.s2 = 2.5f * tmp.s3; \
+ comm_fact.s3 = 0.5f * tmp.s1 + 2.f * tmp.s5 - comm_fact.s2; \
+ comm_fact.s4 = 0.25f * tmp.s2 - 1.25f * tmp.s4 + tmp.s6; \
+ comm_fact.s5 = 4.f * tmp.s2 + tmp.s6 - 5.f * tmp.s4; \
+ comm_fact.s6 = 2.f * tmp.s1 + 0.5f * tmp.s5 - comm_fact.s2; \
+ /* out.s0 and out.s7 read tmp directly; out.s1..s6 are sums/differences of comm_fact pairs */ \
+ out.s0 = tmp.s0 - tmp.s6 + 5.25f * tmp.s4 - 5.25f * tmp.s2; \
+ out.s1 = comm_fact.s0 + comm_fact.s1; \
+ out.s2 = comm_fact.s0 - comm_fact.s1; \
+ out.s3 = comm_fact.s3 + comm_fact.s4; \
+ out.s4 = comm_fact.s4 - comm_fact.s3; \
+ out.s5 = comm_fact.s5 + comm_fact.s6; \
+ out.s6 = comm_fact.s5 - comm_fact.s6; \
+ out.s7 = tmp.s7 - tmp.s1 + 5.25f * tmp.s3 - 5.25f * tmp.s5; \
+ })
+
+#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
+/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3 and the output tile is 2x2/2x1 or 1x2; each work-item transforms one input tile (x, y) of one channel (z)
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
+ * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_2x2_3x3_stepz1_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ int x = get_global_id(0);
+ int y = get_global_id(1);
+ int z = get_global_id(2);
+
+ // Compute the source address of this work-item's tile: tile (x, y) of channel z
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
+
+ src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y); // Step back by the padding so the read starts at the tile's padded top-left corner
+
+#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) // 3x1: four consecutive elements along X
+ float4 in_row0 = vload4(0, (__global float *)(src_addr));
+#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // 1x3: four elements gathered down one column (along Y)
+ float4 in_row0 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
+ *((__global float *)(src_addr + 1 * src_stride_y)),
+ *((__global float *)(src_addr + 2 * src_stride_y)),
+ *((__global float *)(src_addr + 3 * src_stride_y)));
+#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
+ float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
+ float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
+ float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+
+ float4 tmp0 = in_row0; // Row combination for output row 0 (in_row0 - in_row2 in the 2D case)
+
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ tmp0 -= in_row2;
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+
+ float out00 = tmp0.s0 - tmp0.s2; // Per-row transform of (d0, d1, d2, d3): (d0 - d2, d1 + d2, d2 - d1, d1 - d3)
+ float out01 = tmp0.s1 + tmp0.s2;
+ float out02 = tmp0.s2 - tmp0.s1;
+ float out03 = tmp0.s1 - tmp0.s3;
+
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ float4 tmp1 = in_row1 + in_row2; // Row combinations for output rows 1..3, same pattern as the per-row transform
+ float4 tmp2 = in_row2 - in_row1;
+ float4 tmp3 = in_row1 - in_row3;
+
+ float out10 = tmp1.s0 - tmp1.s2;
+ float out11 = tmp1.s1 + tmp1.s2;
+ float out12 = tmp1.s2 - tmp1.s1;
+ float out13 = tmp1.s1 - tmp1.s3;
+
+ float out20 = tmp2.s0 - tmp2.s2;
+ float out21 = tmp2.s1 + tmp2.s2;
+ float out22 = tmp2.s2 - tmp2.s1;
+ float out23 = tmp2.s1 - tmp2.s3;
+
+ float out30 = tmp3.s0 - tmp3.s2;
+ float out31 = tmp3.s1 + tmp3.s2;
+ float out32 = tmp3.s2 - tmp3.s1;
+ float out33 = tmp3.s1 - tmp3.s3;
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y; // Channel z along X, linear tile index along Y; transformed values are spread along Z
+
+ *((__global float *)(dst_addr + 0 * dst_stride_z)) = out00; // Elements 0..3 are written in every mode
+ *((__global float *)(dst_addr + 1 * dst_stride_z)) = out01;
+ *((__global float *)(dst_addr + 2 * dst_stride_z)) = out02;
+ *((__global float *)(dst_addr + 3 * dst_stride_z)) = out03;
+
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ *((__global float *)(dst_addr + 4 * dst_stride_z)) = out10;
+ *((__global float *)(dst_addr + 5 * dst_stride_z)) = out11;
+ *((__global float *)(dst_addr + 6 * dst_stride_z)) = out12;
+ *((__global float *)(dst_addr + 7 * dst_stride_z)) = out13;
+ *((__global float *)(dst_addr + 8 * dst_stride_z)) = out20;
+ *((__global float *)(dst_addr + 9 * dst_stride_z)) = out21;
+ *((__global float *)(dst_addr + 10 * dst_stride_z)) = out22;
+ *((__global float *)(dst_addr + 11 * dst_stride_z)) = out23;
+ *((__global float *)(dst_addr + 12 * dst_stride_z)) = out30;
+ *((__global float *)(dst_addr + 13 * dst_stride_z)) = out31;
+ *((__global float *)(dst_addr + 14 * dst_stride_z)) = out32;
+ *((__global float *)(dst_addr + 15 * dst_stride_z)) = out33;
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+}
+
+/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3, the output tile is 2x2/2x1 or 1x2 and the number of channels is multiple of 2
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
+ * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_2x2_3x3_stepz2_nchw( // One work-item transforms one input tile (4x4, or 4 elements in the 1D cases) for two consecutive channels z and z + 1
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ int x = get_global_id(0); // Tile index along the source X axis
+ int y = get_global_id(1); // Tile index along the source Y axis
+ int z = get_global_id(2) * 2; // First of the two channels handled by this work-item
+
+ // Compute input address
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
+ // Shift back by the padding. NOTE(review): the loads below are not bounds-checked, so the source is assumed to provide at least PAD_LEFT/PAD_TOP elements of border - confirm against the host-side setup
+ src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);
+ // Load the tile of channel z: 4 rows of 4 elements (2D), a single row (HORIZONTAL) or a single column (VERTICAL)
+#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
+ float4 in_row0 = vload4(0, (__global float *)(src_addr));
+#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
+ float4 in_row0 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
+ *((__global float *)(src_addr + 1 * src_stride_y)),
+ *((__global float *)(src_addr + 2 * src_stride_y)),
+ *((__global float *)(src_addr + 3 * src_stride_y)));
+#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
+ float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
+ float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
+ float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ // Load the same tile from channel z + 1
+ src_addr += src_stride_z;
+#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
+ float4 in_row4 = vload4(0, (__global float *)(src_addr));
+#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
+ float4 in_row4 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
+ *((__global float *)(src_addr + 1 * src_stride_y)),
+ *((__global float *)(src_addr + 2 * src_stride_y)),
+ *((__global float *)(src_addr + 3 * src_stride_y)));
+#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ float4 in_row4 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
+ float4 in_row5 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
+ float4 in_row6 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
+ float4 in_row7 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ // Row combination for the first output row: r0 - r2 (2D); in the 1D cases the single loaded row/column is used as is
+ float4 tmp0 = in_row0;
+ float4 tmp4 = in_row4;
+
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ tmp0 -= in_row2;
+ tmp4 -= in_row6;
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ // Column combinations c0 - c2, c1 + c2, c2 - c1, c1 - c3; component .s0 is channel z, .s1 is channel z + 1
+ float2 out00 = (float2)(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
+ float2 out01 = (float2)(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
+ float2 out02 = (float2)(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
+ float2 out03 = (float2)(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);
+ // Remaining output rows (2D case only): r1 + r2, r2 - r1 and r1 - r3, followed by the same column combinations
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ float4 tmp1 = in_row1 + in_row2;
+ float4 tmp2 = in_row2 - in_row1;
+ float4 tmp3 = in_row1 - in_row3;
+
+ float4 tmp5 = in_row5 + in_row6;
+ float4 tmp6 = in_row6 - in_row5;
+ float4 tmp7 = in_row5 - in_row7;
+
+ float2 out10 = (float2)(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
+ float2 out11 = (float2)(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
+ float2 out12 = (float2)(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
+ float2 out13 = (float2)(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);
+
+ float2 out20 = (float2)(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
+ float2 out21 = (float2)(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
+ float2 out22 = (float2)(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
+ float2 out23 = (float2)(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);
+
+ float2 out30 = (float2)(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
+ float2 out31 = (float2)(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
+ float2 out32 = (float2)(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
+ float2 out33 = (float2)(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ // Destination: channels (z, z + 1) are contiguous along X, tile index x + y * NUM_TILES_X selects the Y row, each transformed value goes to its own Z plane
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
+ // Planes 0-3: first output row (the only one in the 1D cases)
+ vstore2(out00, 0, (__global float *)(dst_addr + 0 * dst_stride_z));
+ vstore2(out01, 0, (__global float *)(dst_addr + 1 * dst_stride_z));
+ vstore2(out02, 0, (__global float *)(dst_addr + 2 * dst_stride_z));
+ vstore2(out03, 0, (__global float *)(dst_addr + 3 * dst_stride_z));
+ // Planes 4-15: output rows 1-3 (2D case only)
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ vstore2(out10, 0, (__global float *)(dst_addr + 4 * dst_stride_z));
+ vstore2(out11, 0, (__global float *)(dst_addr + 5 * dst_stride_z));
+ vstore2(out12, 0, (__global float *)(dst_addr + 6 * dst_stride_z));
+ vstore2(out13, 0, (__global float *)(dst_addr + 7 * dst_stride_z));
+ vstore2(out20, 0, (__global float *)(dst_addr + 8 * dst_stride_z));
+ vstore2(out21, 0, (__global float *)(dst_addr + 9 * dst_stride_z));
+ vstore2(out22, 0, (__global float *)(dst_addr + 10 * dst_stride_z));
+ vstore2(out23, 0, (__global float *)(dst_addr + 11 * dst_stride_z));
+ vstore2(out30, 0, (__global float *)(dst_addr + 12 * dst_stride_z));
+ vstore2(out31, 0, (__global float *)(dst_addr + 13 * dst_stride_z));
+ vstore2(out32, 0, (__global float *)(dst_addr + 14 * dst_stride_z));
+ vstore2(out33, 0, (__global float *)(dst_addr + 15 * dst_stride_z));
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+}
+
+/** This OpenCL kernel computes the input transform when the output tile is 4x4/4x1 or 1x4, the filter size 3x3/3x1 or 1x3 and the data layout is NCHW
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_4x4_3x3_stepz1_nchw( // One work-item transforms one input tile of one channel
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ int x = get_global_id(0); // Tile index along the source X axis
+ int y = get_global_id(1); // Tile index along the source Y axis
+ int z = get_global_id(2); // Channel index
+
+ // Compute input address
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
+ // Shift back by the padding. NOTE(review): the loads below are not bounds-checked, so the source is assumed to provide at least PAD_LEFT/PAD_TOP elements of border - confirm against the host-side setup
+ src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);
+ // Load the first 6 input elements: along row 0, or down the column in the VERTICAL (1x3) case; d00 = elements 0-3, d01 = elements 4-5
+#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ // Row0
+ float4 d00 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
+ *((__global float *)(src_addr + 1 * src_stride_y)),
+ *((__global float *)(src_addr + 2 * src_stride_y)),
+ *((__global float *)(src_addr + 3 * src_stride_y)));
+ float2 d01 = (float2)(*((__global float *)(src_addr + 4 * src_stride_y)),
+ *((__global float *)(src_addr + 5 * src_stride_y)));
+#else // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ // Row0
+ float4 d00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
+ float2 d01 = vload2(2, (__global float *)(src_addr + 0 * src_stride_y));
+#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ // 1D transform of 6 elements a0..a5 used throughout: [4a0 - 5a2 + a4, -4a1 - 4a2 + a3 + a4, 4a1 - 4a2 - a3 + a4, -2a1 - a2 + 2a3 + a4, 2a1 - a2 - 2a3 + a4, 4a1 - 5a3 + a5]
+ float out0 = 0.0f;
+ float out1 = 0.0f;
+ float out2 = 0.0f;
+ float out3 = 0.0f;
+ float out4 = 0.0f;
+ float out5 = 0.0f;
+ // Row 0 contributes 4 x its 1D transform. NOTE(review): in the 1D (HORIZONTAL/VERTICAL) builds nothing else is added, so those outputs keep this factor of 4 - presumably compensated by the matching filter transform; confirm
+ // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
+ out0 += 16.0f * d00.s0 - 20.0f * d00.s2 + 4.0f * d01.s0;
+ out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 4.0f * d01.s0;
+ out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 + 4.0f * d01.s0;
+ out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 4.0f * d01.s0;
+ out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 + 4.0f * d01.s0;
+ out5 += 16.0f * d00.s1 - 20.0f * d00.s3 + 4.0f * d01.s1;
+
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ // Row4
+ float4 d40 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y));
+ float2 d41 = vload2(2, (__global float *)(src_addr + 4 * src_stride_y));
+ // Row 4 contributes its 1D transform (weight 1) to channels [0, 29] and nothing to channels [30, 35]
+ // k0, k1, k2, k3, k4, k5 are common terms for row0, row1, row2, row3 and row4
+ float k0 = d41.s0;
+ float k1 = d41.s0;
+ float k2 = d41.s0;
+ float k3 = d41.s0;
+ float k4 = d41.s0;
+ float k5 = 0.0f;
+
+ k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
+ k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
+ k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
+ k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
+ k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
+ k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
+
+ out0 += k0;
+ out1 += k1;
+ out2 += k2;
+ out3 += k3;
+ out4 += k4;
+ out5 += k5;
+ // Row 2 contributes -5 x its 1D transform to channels [0, 5]
+ // Row2
+ float4 d20 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
+ float2 d21 = vload2(2, (__global float *)(src_addr + 2 * src_stride_y));
+
+ out0 += -20.0f * d20.s0 + 25.0f * d20.s2 - 5.0f * d21.s0;
+ out1 += +20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 - 5.0f * d21.s0;
+ out2 += -20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 - 5.0f * d21.s0;
+ out3 += +10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 - 5.0f * d21.s0;
+ out4 += -10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 - 5.0f * d21.s0;
+ out5 += -20.0f * d20.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
+#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+
+ // Compute destination address: channel z along X, tile index x + y * NUM_TILES_X along Y, one Z plane per transformed value
+ __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y);
+
+ uint dst_plane_stride = dst_stride_z / sizeof(float); // Plane-to-plane distance in floats
+
+ *(dst_addr) = out0;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out1;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out2;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out3;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out4;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out5;
+ dst_addr += dst_plane_stride;
+ // Channels [6, 35] are produced only in the 2D (3x3) case
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ float out6 = k0;
+ float out7 = k1;
+ float out8 = k2;
+ float out9 = k3;
+ float out10 = k4;
+ float out11 = k5;
+ float out12 = k0;
+ float out13 = k1;
+ float out14 = k2;
+ float out15 = k3;
+ float out16 = k4;
+ float out17 = k5;
+ float out18 = k0;
+ float out19 = k1;
+ float out20 = k2;
+ float out21 = k3;
+ float out22 = k4;
+ float out23 = k5;
+ float out24 = k0;
+ float out25 = k1;
+ float out26 = k2;
+ float out27 = k3;
+ float out28 = k4;
+ float out29 = k5;
+ // Channels [6, 29] start from the row-4 terms k0..k5; rows 1, 2 and 3 are added through part0..part23 below
+ // Row1
+ float4 d10 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
+ float2 d11 = vload2(2, (__global float *)(src_addr + 1 * src_stride_y));
+
+ // Row3
+ float4 d30 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
+ float2 d31 = vload2(2, (__global float *)(src_addr + 3 * src_stride_y));
+
+ // Compute common parts for the channels between [6, 29]
+ // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
+ // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
+ float part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
+ float part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
+ float part2 = 16.0f * d20.s2 - 4.0f * d21.s0;
+ float part3 = 16.0f * d20.s1 - 4.0f * d20.s3;
+ float part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
+ float part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
+ float part6 = 4.0f * d20.s2 - 4.0f * d21.s0;
+ float part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
+ float part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
+ float part9 = 8.0f * d20.s1 - 8.0f * d20.s3;
+ float part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
+ float part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
+
+ // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
+ // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
+ float part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
+ float part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
+ float part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
+ float part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
+ float part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
+ float part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
+ float part18 = part6 * 0.25f; // d20.s2 - d21.s0
+ float part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
+ float part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
+ float part21 = part9 * 0.25f; // 2.0f * (d20.s1 - d20.s3)
+ float part22 = part10 * 0.25f; // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
+ float part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;
+
+ out6 += part0 - part1;
+ out12 += part0 + part1;
+ out7 += part2 + part3 + part4 + part5;
+ out8 += part2 - part3 + part4 - part5;
+ out13 += part2 + part3 - part4 - part5;
+ out14 += part2 - part3 - part4 + part5;
+ out9 += part6 + part7 + part8 + part9;
+ out10 += part6 - part7 + part8 - part9;
+ out15 += part6 - part7 - part8 + part9;
+ out16 += part6 + part7 - part8 - part9;
+ out11 += part10 + part11;
+ out17 += part10 - part11;
+
+ out18 += part13 - part12;
+ out24 += part13 + part12;
+ out19 += part14 + part15 + part16 + part17;
+ out20 += part14 - part15 + part16 - part17;
+ out25 += part14 - part15 - part16 + part17;
+ out26 += part14 + part15 - part16 - part17;
+ out21 += part18 + part19 + part20 + part21;
+ out22 += part18 - part19 + part20 - part21;
+ out27 += part18 - part19 - part20 + part21;
+ out28 += part18 + part19 - part20 - part21;
+ out23 += part22 + part23;
+ out29 += part22 - part23;
+
+ *(dst_addr) = out6;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out7;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out8;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out9;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out10;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out11;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out12;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out13;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out14;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out15;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out16;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out17;
+ dst_addr += dst_plane_stride;
+
+ *(dst_addr) = out18;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out19;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out20;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out21;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out22;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out23;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out24;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out25;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out26;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out27;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out28;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out29;
+ dst_addr += dst_plane_stride;
+ // Channels [30, 35]: 4 x (row 1) - 5 x (row 3) + (row 5), each row taken through the 1D transform; rows 0, 2 and 4 do not contribute
+ // Row5
+ float4 d50 = vload4(0, (__global float *)(src_addr + 5 * src_stride_y));
+ float2 d51 = vload2(2, (__global float *)(src_addr + 5 * src_stride_y));
+
+ // Channels [30, 35]
+ out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+ out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+ out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+ out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+ out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+ out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
+
+ *(dst_addr) = out0;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out1;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out2;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out3;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out4;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out5;
+ dst_addr += dst_plane_stride; // Last increment; dst_addr is not used afterwards
+#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+}
+
+#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
+/** This OpenCL kernel computes the input transform when the output tile is 4x4, 4x1 or 1x4, the filter size 3x3, 3x1 or 1x3 and the data layout is NHWC
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
+ * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ int x = get_global_id(0);
+ int y = get_global_id(1);
+ int z = get_global_id(2);
+
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(float);
+
+ // Clamp coordinates. This clamp is valid for all rows
+ int4 y_coord0 = (int4)(y * OUTPUT_TILE_W) + (int4)(0, 1, 2, 3) - (int4)PAD_LEFT;
+ int2 y_coord1 = (int2)(y * OUTPUT_TILE_W) + (int2)(4, 5) - (int2)PAD_LEFT;
+ y_coord0 = clamp(y_coord0, (int4) - 1, (int4)SRC_DIM_1);
+ y_coord1 = clamp(y_coord1, (int2) - 1, (int2)SRC_DIM_1);
+
+ int z_coord;
+ int4 valid_y0;
+ int2 valid_y1;
+
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ // Row4
+ z_coord = (z * 4) - (int)PAD_TOP + 4;
+
+ // If z < 0, set y to -1
+ valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
+ valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
+ // If z >= SRC_DIM_2, set y to SRC_DIM_2
+ valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
+ valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
+
+ // Clamp z coordinate
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+
+ float d40 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ float d41 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ float d42 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ float d43 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ float d44 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ float d45 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
+
+ float k0 = d44;
+ float k1 = d44;
+ float k2 = d44;
+ float k3 = d44;
+ float k4 = d44;
+ float k5 = (float)0.0f;
+
+ k0 += 4.0f * d40 - 5.0f * d42;
+ k1 += -4.0f * d41 - 4.0f * d42 + d43;
+ k2 += 4.0f * d41 - 4.0f * d42 - d43;
+ k3 += -2.0f * d41 + 2.0f * d43 - d42;
+ k4 += 2.0f * d41 - 2.0f * d43 - d42;
+ k5 += 4.0f * d41 - 5.0f * d43 + d45;
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+
+#if !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ // Row0
+ z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;
+
+#if PAD_TOP != 0
+ valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
+ valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
+ valid_y0 = select(valid_y0, (int)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
+ valid_y1 = select(valid_y1, (int)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+#else // PAD_TOP != 0
+ valid_y0 = y_coord0;
+ valid_y1 = y_coord1;
+#endif // if PAD_TOP == 0, we cannot read out of bound
+
+ float d00 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ float d01 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ float d02 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ float d03 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ float d04 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ float d05 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
+#else // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ int4 z_coords0 = (int4)(z * OUTPUT_TILE_H) + (int4)(0, 1, 2, 3) - (int4)PAD_TOP;
+ int2 z_coords1 = (int2)(z * OUTPUT_TILE_H) + (int2)(4, 5) - (int2)PAD_TOP;
+
+ valid_y0 = select((int4)y_coord0.s0, (int4) - 1, z_coords0 < (int4)0);
+ valid_y1 = select((int2)y_coord0.s0, (int2) - 1, z_coords1 < (int2)0);
+ valid_y0 = select(valid_y0, (int4)SRC_DIM_1, z_coords0 >= (int4)SRC_DIM_2);
+ valid_y1 = select(valid_y1, (int2)SRC_DIM_1, z_coords1 >= (int2)SRC_DIM_2);
+
+ z_coords0 = clamp((int4)z_coords0, (int4)0, (int4)((int)SRC_DIM_2 - 1));
+ z_coords1 = clamp((int2)z_coords1, (int2)0, (int2)((int)SRC_DIM_2 - 1));
+
+ float d00 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coords0.s0 * src_stride_z);
+ float d01 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coords0.s1 * src_stride_z);
+ float d02 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coords0.s2 * src_stride_z);
+ float d03 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coords0.s3 * src_stride_z);
+ float d04 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coords1.s0 * src_stride_z);
+ float d05 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coords1.s1 * src_stride_z);
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+
+ float out0 = 16.0f * d00 - 20.0f * d02 + 4.0f * d04;
+ float out1 = -16.0f * d01 - 16.0f * d02 + 4.0f * d03 + 4.0f * d04;
+ float out2 = 16.0f * d01 - 16.0f * d02 - 4.0f * d03 + 4.0f * d04;
+ float out3 = -8.0f * d01 - 4.0f * d02 + 8.0f * d03 + 4.0f * d04;
+ float out4 = 8.0f * d01 - 4.0f * d02 - 8.0f * d03 + 4.0f * d04;
+ float out5 = 16.0f * d01 - 20.0f * d03 + 4.0f * d05;
+
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ // Row2
+ z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2;
+ valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
+ valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
+ valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
+ valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+
+ float d20 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ float d21 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ float d22 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ float d23 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ float d24 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ float d25 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
+
+ out0 += k0;
+ out1 += k1;
+ out2 += k2;
+ out3 += k3;
+ out4 += k4;
+ out5 += k5;
+ float out6 = k0;
+ float out7 = k1;
+ float out8 = k2;
+ float out9 = k3;
+ float out10 = k4;
+ float out11 = k5;
+ float out12 = k0;
+ float out13 = k1;
+ float out14 = k2;
+ float out15 = k3;
+ float out16 = k4;
+ float out17 = k5;
+ float out18 = k0;
+ float out19 = k1;
+ float out20 = k2;
+ float out21 = k3;
+ float out22 = k4;
+ float out23 = k5;
+ float out24 = k0;
+ float out25 = k1;
+ float out26 = k2;
+ float out27 = k3;
+ float out28 = k4;
+ float out29 = k5;
+
+ // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
+ out0 += -20.0f * d20 + 25.0f * d22 - 5.0f * d24;
+ out1 += 20.0f * d21 + 20.0f * d22 - 5.0f * d23 - 5.0f * d24;
+ out2 += -20.0f * d21 + 20.0f * d22 + 5.0f * d23 - 5.0f * d24;
+ out3 += 10.0f * d21 + 5.0f * d22 - 10.0f * d23 - 5.0f * d24;
+ out4 += -10.0f * d21 + 5.0f * d22 + 10.0f * d23 - 5.0f * d24;
+ out5 += -20.0f * d21 + 25.0f * d23 - 5.0f * d25;
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+
+ // Compute destination address
+ __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(float) + (y + z * (int)NUM_TILES_X) * dst_stride_y);
+ uint dst_plane_stride = dst_stride_z / sizeof(float);
+
+ *((__global float *)dst_addr) = out0;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out1;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out2;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out3;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out4;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out5;
+ dst_addr += dst_plane_stride;
+
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ // Row1
+ z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1;
+ // Row1 can never be out of bounds
+ valid_y0 = y_coord0;
+ valid_y1 = y_coord1;
+
+ float d10 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ float d11 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ float d12 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ float d13 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ float d14 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ float d15 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row3
+ z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3;
+ valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
+ valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
+ valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
+ valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+
+ float d30 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ float d31 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ float d32 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ float d33 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ float d34 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ float d35 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Compute common parts for the channels between [6, 29]
+ // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
+ // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
+ float part0 = -16.0f * d20 + 20.0f * d22 - 4.0f * d24;
+ float part1 = 16.0f * d10 - 20.0f * d12 + 4.0f * d14 - 4.0f * d30 + 5.0f * d32 - d34;
+ float part2 = 16.0f * d22 - 4.0f * d24;
+ float part3 = 16.0f * d21 - 4.0f * d23;
+ float part4 = 16.0f * d12 - 4.0f * d14 - 4.0f * d32 + d34;
+ float part5 = 16.0f * d11 - 4.0f * d13 - 4.0f * d31 + d33;
+ float part6 = 4.0f * d22 - 4.0f * d24;
+ float part7 = 8.0f * d11 - 8.0f * d13 - 2.0f * d31 + 2.0f * d33;
+ float part8 = 4.0f * d12 - 4.0f * d14 - d32 + d34;
+ float part9 = 8.0f * d21 - 8.0f * d23;
+ float part10 = -16.0f * d21 + 20.0f * d23 - 4.0f * d25;
+ float part11 = -16.0f * d11 + 20.0f * d13 - 4.0f * d15 + 4.0f * d31 - 5.0f * d33 + d35;
+
+ // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
+ // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
+ float part12 = 8.0f * d10 - 10.0f * d12 + 2.0f * d14 - 8.0f * d30 + 10.0f * d32 - 2.0f * d34;
+ float part13 = part0 * 0.25f; // -4.0f * d20 + 5.0f * d22 - d24
+ float part14 = part2 * 0.25f; // 4.0f * d22 - d24
+ float part15 = 8.0f * d11 - 2.0f * d13 - 8.0f * d31 + 2.0f * d33;
+ float part16 = 8.0f * d12 - 2.0f * d14 - 8.0f * d32 + 2.0f * d34;
+ float part17 = part3 * 0.25f; // 4.0f * d21 - d23
+ float part18 = part6 * 0.25f; // d22 - d24
+ float part19 = 4.0f * d11 - 4.0f * d13 - 4.0f * d31 + 4.0f * d33;
+ float part20 = 2.0f * d12 - 2.0f * d14 - 2.0f * d32 + 2.0f * d34;
+ float part21 = part9 * 0.25f; // 2.0f * (d21 - d23)
+ float part22 = part10 * 0.25f; // - 4.0f * d21 + 5.0f * d23 - d25
+ float part23 = part11 * 0.5f + 6.0f * d31 - 7.5f * d33 + 1.5f * d35; // - 8.0f * d11 + 10.0f * d13 - 2.0f * d15 + 8.0f * d31 - 10.0f * d33 + 2.0f * d35;
+
+ out6 += part0 - part1;
+ out12 += part0 + part1;
+ out7 += part2 + part3 + part4 + part5;
+ out8 += part2 - part3 + part4 - part5;
+ out13 += part2 + part3 - part4 - part5;
+ out14 += part2 - part3 - part4 + part5;
+ out9 += part6 + part7 + part8 + part9;
+ out10 += part6 - part7 + part8 - part9;
+ out15 += part6 - part7 - part8 + part9;
+ out16 += part6 + part7 - part8 - part9;
+ out11 += part10 + part11;
+ out17 += part10 - part11;
+
+ out18 += part13 - part12;
+ out24 += part13 + part12;
+ out19 += part14 + part15 + part16 + part17;
+ out20 += part14 - part15 + part16 - part17;
+ out25 += part14 - part15 - part16 + part17;
+ out26 += part14 + part15 - part16 - part17;
+ out21 += part18 + part19 + part20 + part21;
+ out22 += part18 - part19 + part20 - part21;
+ out27 += part18 - part19 - part20 + part21;
+ out28 += part18 + part19 - part20 - part21;
+ out23 += part22 + part23;
+ out29 += part22 - part23;
+
+ *((__global float *)dst_addr) = out6;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out7;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out8;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out9;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out10;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out11;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out12;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out13;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out14;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out15;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out16;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out17;
+ dst_addr += dst_plane_stride;
+
+ *((__global float *)dst_addr) = out18;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out19;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out20;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out21;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out22;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out23;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out24;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out25;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out26;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out27;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out28;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out29;
+ dst_addr += dst_plane_stride;
+
+ // Row5
+ z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5;
+ valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
+ valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
+ valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
+ valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+
+ float d50 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ float d51 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ float d52 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ float d53 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ float d54 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ float d55 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Channels [30, 35]
+ out0 = 16.0f * d10 - 20.0f * d12 - 20.0f * d30 + 25.0f * d32 + 4.0f * d50 - 5.0f * d52 + d54 + 4.0f * d14 - 5.0f * d34;
+ out1 = -16.0f * d11 - 16.0f * d12 + 4.0f * d13 + 20.0f * d31 + 20.0f * d32 - 5.0f * d33 - 4.0f * d51 - 4.0f * d52 + d53 + d54 + 4.0f * d14 - 5.0f * d34;
+ out2 = 16.0f * d11 - 16.0f * d12 - 4.0f * d13 - 20.0f * d31 + 20.0f * d32 + 5.0f * d33 + 4.0f * d51 - 4.0f * d52 - d53 + d54 + 4.0f * d14 - 5.0f * d34;
+ out3 = -8.0f * d11 - 4.0f * d12 + 8.0f * d13 + 10.0f * d31 - 10.0f * d33 + 5.0f * d32 - 2.0f * d51 + 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
+ out4 = 8.0f * d11 - 4.0f * d12 - 8.0f * d13 - 10.0f * d31 + 5.0f * d32 + 10.0f * d33 + 2.0f * d51 - 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
+ out5 = 16.0f * d11 - 20.0f * d13 + 4.0f * d15 - 20.0f * d31 + 25.0f * d33 - 5.0f * d35 + 4.0f * d51 - 5.0f * d53 + d55;
+
+ *((__global float *)dst_addr) = out0;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out1;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out2;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out3;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out4;
+ dst_addr += dst_plane_stride;
+ *((__global float *)dst_addr) = out5;
+ dst_addr += dst_plane_stride;
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+}
+
+/** This OpenCL kernel computes the input transform when the kernel size is 5x5/5x1 or 1x5 and the output tile is 4x4/4x1 or 1x4 when the data layout is NHWC
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
+ * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ int x = get_global_id(0); // Index along dimension 0 (channels for NHWC)
+ int y = get_global_id(1); // Tile index along the width (dimension 1, SRC_DIM_1)
+ int z = get_global_id(2); // Tile index along the height (dimension 2, SRC_DIM_2)
+
+ // Compute input address of channel x; the width/height offsets are added per load below
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(float);
+
+#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
+ // Clamp coordinates to [-1, SRC_DIM_1]. This clamp is valid for all rows. NOTE(review): reads at -1/SRC_DIM_1 presumably land in zero-filled border padding - confirm
+ int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
+ y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
+
+ // Row0
+ // We can skip the border clamping along the z dimension as we cannot read out-of-bound in case of 5x1 kernels
+ int z_coord = z * OUTPUT_TILE_H;
+
+ // Load the input tile (8 consecutive positions along the width)
+ float8 in_row0;
+ in_row0.s0 = *(__global float *)(src_addr + y_coord.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s1 = *(__global float *)(src_addr + y_coord.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s2 = *(__global float *)(src_addr + y_coord.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s3 = *(__global float *)(src_addr + y_coord.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s4 = *(__global float *)(src_addr + y_coord.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s5 = *(__global float *)(src_addr + y_coord.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s6 = *(__global float *)(src_addr + y_coord.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s7 = *(__global float *)(src_addr + y_coord.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Single row: no vertical pass is needed, the loaded row feeds OUTPUT_ROW_4x4_5x5 directly
+ float8 comm_fact0 = 0.0f;
+ float8 tmp0 = in_row0;
+
+ float8 out0 = (float8)0.0f;
+
+ OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
+
+#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
+ // We can skip the border clamping along the y dimension as we cannot read out-of-bound in case of 1x5 kernels
+ int y_coord = y * (int)OUTPUT_TILE_W;
+
+ // Column of 8 input rows: z_coord holds the 8 consecutive positions along the height
+ // Border handling along z IS required here: out-of-range rows are redirected to y = -1 / SRC_DIM_1 and z is clamped
+ int8 z_coord = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP;
+ int8 valid_y = select((int8)y_coord, (int8) - 1, z_coord < (int8)0); // If z < 0, set y to -1
+ valid_y = select(valid_y, (int8)SRC_DIM_1, z_coord >= (int8)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_1
+ z_coord = clamp(z_coord, (int8)0, (int8)SRC_DIM_2 - 1); // Clamp z coordinate
+
+ // Load the input tile
+ float8 in_row0;
+ in_row0.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord.s0 * src_stride_z);
+ in_row0.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord.s1 * src_stride_z);
+ in_row0.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord.s2 * src_stride_z);
+ in_row0.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord.s3 * src_stride_z);
+ in_row0.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord.s4 * src_stride_z);
+ in_row0.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord.s5 * src_stride_z);
+ in_row0.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord.s6 * src_stride_z);
+ in_row0.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord.s7 * src_stride_z);
+
+ // Single column: the loaded values feed OUTPUT_ROW_4x4_5x5 directly (no second pass)
+ float8 comm_fact0 = 0.0f;
+ float8 tmp0 = in_row0;
+
+ float8 out0 = (float8)0.0f;
+
+ OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
+#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ float8 in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7;
+
+ // Clamp coordinates to [-1, SRC_DIM_1]. This clamp is valid for all rows. NOTE(review): reads at -1/SRC_DIM_1 presumably land in zero-filled border padding - confirm
+ int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
+ y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
+
+ // Row0
+ int z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;
+ int8 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); // If z < 0, set y to -1
+ valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_1
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); // Clamp z coordinate
+
+ // Load the input tile
+ in_row0.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row1 (rows 1..7 repeat the Row0 border handling with their own z offset)
+ z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1;
+ valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
+ valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+
+ in_row1.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row2
+ z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2;
+ valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
+ valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+
+ in_row2.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row3
+ z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3;
+ valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
+ valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+
+ in_row3.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row4
+ z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 4;
+ valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
+ valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+
+ in_row4.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row5
+ z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5;
+ valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
+ valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+
+ in_row5.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row6
+ z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 6;
+ valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
+ valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+
+ in_row6.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row7
+ z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 7;
+ valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
+ valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
+
+ in_row7.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ float8 comm_fact0 = in_row2 + in_row6 - 4.25f * in_row4;
+ float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3;
+ float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6;
+
+ // Calculate intermediate tensor (vertical pass: each tmpN is a fixed linear combination of the 8 input rows) and reuse common factor vectors
+ const float8 tmp0 = in_row0 - in_row6 + 5.25f * in_row4 - 5.25f * in_row2;
+ const float8 tmp1 = comm_fact0 + comm_fact1;
+ const float8 tmp2 = comm_fact0 - comm_fact1;
+
+ comm_fact0 = 2.5f * in_row3;
+ comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5;
+
+ const float8 tmp3 = comm_fact1 + comm_fact2;
+ const float8 tmp4 = comm_fact2 - comm_fact1;
+
+ comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5;
+ comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6;
+
+ const float8 tmp5 = comm_fact1 + comm_fact2;
+ const float8 tmp6 = comm_fact2 - comm_fact1;
+ const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5;
+
+ // Calculate output rows with OUTPUT_ROW_4x4_5x5 (defined elsewhere; comm_fact0 is reused, presumably as scratch)
+ float8 out0, out1, out2, out3, out4, out5, out6, out7;
+ OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+
+ // Store values across the channels: element i of the transformed tile is written to plane i (offset i * dst_stride_z)
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(float) + (y + z * (int)NUM_TILES_X) * dst_stride_y; // row = linear tile index y + z * NUM_TILES_X
+
+ *((__global float *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
+ *((__global float *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
+ *((__global float *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
+ *((__global float *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
+ *((__global float *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
+ *((__global float *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
+ *((__global float *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
+ *((__global float *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
+
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // 2D transform only: the 5x1/1x5 variants produce just the 8 values of out0
+ *((__global float *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
+ *((__global float *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
+ *((__global float *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
+ *((__global float *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
+ *((__global float *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
+ *((__global float *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
+ *((__global float *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
+ *((__global float *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
+ *((__global float *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
+ *((__global float *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
+ *((__global float *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
+ *((__global float *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
+ *((__global float *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
+ *((__global float *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
+ *((__global float *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
+ *((__global float *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
+ *((__global float *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
+ *((__global float *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
+ *((__global float *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
+ *((__global float *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
+ *((__global float *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
+ *((__global float *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
+ *((__global float *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
+ *((__global float *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
+ *((__global float *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
+ *((__global float *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
+ *((__global float *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
+ *((__global float *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
+ *((__global float *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
+ *((__global float *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
+ *((__global float *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
+ *((__global float *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
+ *((__global float *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
+ *((__global float *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
+ *((__global float *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
+ *((__global float *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
+ *((__global float *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
+ *((__global float *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
+ *((__global float *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
+ *((__global float *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
+ *((__global float *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
+ *((__global float *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
+ *((__global float *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
+ *((__global float *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
+ *((__global float *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
+ *((__global float *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
+ *((__global float *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
+ *((__global float *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
+ *((__global float *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
+ *((__global float *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
+ *((__global float *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
+ *((__global float *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
+ *((__global float *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
+ *((__global float *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
+ *((__global float *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
+ *((__global float *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+}
+#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
+
+/** This OpenCL kernel computes the input transform when the kernel size is 5x5/5x1 or 1x5 and the output tile is 4x4/4x1 or 1x4 when the data layout is NCHW
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ int x = get_global_id(0); // tile index along X
+ int y = get_global_id(1); // tile index along Y
+ int z = get_global_id(2); // src Z plane (channel for NCHW)
+
+ // Compute input address of this work-item's tile: element (x * OUTPUT_TILE_W, y * OUTPUT_TILE_H) of plane z
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
+
+ src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y); // shift back by PAD_LEFT columns and PAD_TOP rows
+
+ // Load input tile: 8 elements of one row (horizontal), 8 elements of one column (vertical), or the full 8x8 tile
+#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
+ const float8 in_row0 = vload8(0, (__global float *)(src_addr));
+#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
+ const float8 in_row0 = (float8)(*((__global float *)(src_addr + 0 * src_stride_y)),
+ *((__global float *)(src_addr + 1 * src_stride_y)),
+ *((__global float *)(src_addr + 2 * src_stride_y)),
+ *((__global float *)(src_addr + 3 * src_stride_y)),
+ *((__global float *)(src_addr + 4 * src_stride_y)),
+ *((__global float *)(src_addr + 5 * src_stride_y)),
+ *((__global float *)(src_addr + 6 * src_stride_y)),
+ *((__global float *)(src_addr + 7 * src_stride_y)));
+#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ const float8 in_row0 = vload8(0, (__global float *)(src_addr + 0 * src_stride_y));
+ const float8 in_row1 = vload8(0, (__global float *)(src_addr + 1 * src_stride_y));
+ const float8 in_row2 = vload8(0, (__global float *)(src_addr + 2 * src_stride_y));
+ const float8 in_row3 = vload8(0, (__global float *)(src_addr + 3 * src_stride_y));
+ const float8 in_row4 = vload8(0, (__global float *)(src_addr + 4 * src_stride_y));
+ const float8 in_row5 = vload8(0, (__global float *)(src_addr + 5 * src_stride_y));
+ const float8 in_row6 = vload8(0, (__global float *)(src_addr + 6 * src_stride_y));
+ const float8 in_row7 = vload8(0, (__global float *)(src_addr + 7 * src_stride_y));
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+
+ // Calculate common factors for intermediate tensor (column pass of the 8x8 transform; the 1D cases only use tmp0 = in_row0)
+ float8 tmp0 = in_row0;
+ float8 comm_fact0 = 0.0f;
+
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ comm_fact0 += in_row2 + in_row6 - 4.25f * in_row4;
+ tmp0 += -in_row6 + 5.25f * in_row4 - 5.25f * in_row2;
+
+ float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3;
+ float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6;
+
+ const float8 tmp1 = comm_fact0 + comm_fact1;
+ const float8 tmp2 = comm_fact0 - comm_fact1;
+
+ comm_fact0 = 2.5f * in_row3; // shared 2.5 * in_row3 term of the next two comm_fact1 values
+ comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5;
+
+ const float8 tmp3 = comm_fact1 + comm_fact2;
+ const float8 tmp4 = comm_fact2 - comm_fact1;
+
+ comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5;
+ comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6;
+
+ const float8 tmp5 = comm_fact1 + comm_fact2;
+ const float8 tmp6 = comm_fact2 - comm_fact1;
+ const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5;
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+
+ // Calculate output rows with the OUTPUT_ROW_4x4_5x5 macro (defined elsewhere), reusing comm_fact0 as its helper vector
+ float8 out0;
+
+ OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
+
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ float8 out1, out2, out3, out4, out5, out6, out7;
+
+ OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+
+ // Store values across the channels: value i goes to dst Z plane i, at dst X = z and dst Y = linear tile index (x + y * NUM_TILES_X)
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
+
+ *((__global float *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
+ *((__global float *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
+ *((__global float *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
+ *((__global float *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
+ *((__global float *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
+ *((__global float *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
+ *((__global float *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
+ *((__global float *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
+
+#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ *((__global float *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
+ *((__global float *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
+ *((__global float *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
+ *((__global float *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
+ *((__global float *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
+ *((__global float *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
+ *((__global float *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
+ *((__global float *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
+ *((__global float *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
+ *((__global float *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
+ *((__global float *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
+ *((__global float *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
+ *((__global float *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
+ *((__global float *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
+ *((__global float *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
+ *((__global float *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
+ *((__global float *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
+ *((__global float *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
+ *((__global float *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
+ *((__global float *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
+ *((__global float *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
+ *((__global float *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
+ *((__global float *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
+ *((__global float *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
+ *((__global float *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
+ *((__global float *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
+ *((__global float *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
+ *((__global float *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
+ *((__global float *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
+ *((__global float *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
+ *((__global float *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
+ *((__global float *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
+ *((__global float *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
+ *((__global float *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
+ *((__global float *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
+ *((__global float *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
+ *((__global float *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
+ *((__global float *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
+ *((__global float *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
+ *((__global float *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
+ *((__global float *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
+ *((__global float *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
+ *((__global float *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
+ *((__global float *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
+ *((__global float *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
+ *((__global float *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
+ *((__global float *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
+ *((__global float *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
+ *((__global float *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
+ *((__global float *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
+ *((__global float *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
+ *((__global float *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
+ *((__global float *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
+ *((__global float *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
+ *((__global float *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
+ *((__global float *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
+#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+}
+
+#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
+/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 2x1 (NCHW). It forwards to winograd_input_transform_2x2_3x3_stepz1_nchw.
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
+ * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time (this kernel is not compiled otherwise)
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_2x1_3x1_stepz1_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel computes the input transform when the kernel size is 3x1, the output tile is 2x1 and the number of channels is a multiple of 2 (NCHW). It forwards to winograd_input_transform_2x2_3x3_stepz2_nchw.
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
+ * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time (this kernel is not compiled otherwise)
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_2x1_3x1_stepz2_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1 (NCHW). It forwards to winograd_input_transform_4x4_3x3_stepz1_nchw.
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
+ * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time (this kernel is not compiled otherwise)
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_4x1_3x1_stepz1_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel computes the input transform when the kernel size is 5x1 and the output tile is 4x1 when the data layout is NCHW. It forwards to winograd_input_transform_4x4_5x5_stepz1_nchw.
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
+ * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time (this kernel is not compiled otherwise)
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_4x1_5x1_stepz1_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_input_transform_4x4_5x5_stepz1_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
+/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1 for data layout NHWC. It forwards to winograd_input_transform_4x4_3x3_stepz1_nhwc.
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
+ * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
+ * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time (this kernel is not compiled otherwise)
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_4x1_3x1_stepz1_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_input_transform_4x4_3x3_stepz1_nhwc(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel computes the input transform when the kernel size is 5x1 and the output tile is 4x1 for data layout NHWC. It forwards to winograd_input_transform_4x4_5x5_stepz1_nhwc.
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
+ * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
+ * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time (this kernel is not compiled otherwise)
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_4x1_5x1_stepz1_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_input_transform_4x4_5x5_stepz1_nhwc(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
+#endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
+
+#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x2 (NCHW). It forwards to winograd_input_transform_2x2_3x3_stepz1_nchw.
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
+ * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time (this kernel is not compiled otherwise)
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_1x2_1x3_stepz1_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel computes the input transform when the kernel size is 1x3, the output tile is 1x2 and the number of channels is a multiple of 2 (NCHW). It forwards to winograd_input_transform_2x2_3x3_stepz2_nchw.
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
+ * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time (this kernel is not compiled otherwise)
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_1x2_1x3_stepz2_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4 (NCHW). It forwards to winograd_input_transform_4x4_3x3_stepz1_nchw.
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time (this kernel is not compiled otherwise)
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_1x4_1x3_stepz1_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel computes the input transform when the kernel size is 1x5 and the output tile is 1x4
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_1x4_1x5_stepz1_nchw(
+    TENSOR3D_DECLARATION(src),
+    TENSOR3D_DECLARATION(dst))
+{
+    // 1x4 output tile / 1x5 filter, NCHW: reuse the 4x4/5x5 NCHW input transform.
+    // NOTE(review): the 1x4 behavior presumably comes from WINOGRAD_INPUT_TRANSFORM_VERTICAL,
+    // which this section requires at compile time - confirm against the callee.
+    winograd_input_transform_4x4_5x5_stepz1_nchw(
+        // source tensor: ptr, (stride, step) for x/y/z, first-element offset
+        src_ptr,
+        src_stride_x, src_step_x,
+        src_stride_y, src_step_y,
+        src_stride_z, src_step_z,
+        src_offset_first_element_in_bytes,
+        // destination tensor: same argument layout
+        dst_ptr,
+        dst_stride_x, dst_step_x,
+        dst_stride_y, dst_step_y,
+        dst_stride_z, dst_step_z,
+        dst_offset_first_element_in_bytes);
+}
+
+#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
+/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4 for data layout NHWC
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
+ * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_1x4_1x3_stepz1_nhwc(
+    TENSOR3D_DECLARATION(src),
+    TENSOR3D_DECLARATION(dst))
+{
+    // 1x4 output tile / 1x3 filter, NHWC: reuse the 4x4/3x3 NHWC input transform.
+    // NOTE(review): the 1x4 behavior presumably comes from WINOGRAD_INPUT_TRANSFORM_VERTICAL;
+    // this kernel is only built when SRC_DIM_1 and SRC_DIM_2 are defined - confirm against the callee.
+    winograd_input_transform_4x4_3x3_stepz1_nhwc(
+        // source tensor: ptr, (stride, step) for x/y/z, first-element offset
+        src_ptr,
+        src_stride_x, src_step_x,
+        src_stride_y, src_step_y,
+        src_stride_z, src_step_z,
+        src_offset_first_element_in_bytes,
+        // destination tensor: same argument layout
+        dst_ptr,
+        dst_stride_x, dst_step_x,
+        dst_stride_y, dst_step_y,
+        dst_stride_z, dst_step_z,
+        dst_offset_first_element_in_bytes);
+}
+
+/** This OpenCL kernel computes the input transform when the kernel size is 1x5 and the output tile is 1x4 for data layout NHWC
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
+ * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_1x4_1x5_stepz1_nhwc(
+    TENSOR3D_DECLARATION(src),
+    TENSOR3D_DECLARATION(dst))
+{
+    // 1x4 output tile / 1x5 filter, NHWC: reuse the 4x4/5x5 NHWC input transform.
+    // NOTE(review): the 1x4 behavior presumably comes from WINOGRAD_INPUT_TRANSFORM_VERTICAL;
+    // this kernel is only built when SRC_DIM_1 and SRC_DIM_2 are defined - confirm against the callee.
+    winograd_input_transform_4x4_5x5_stepz1_nhwc(
+        // source tensor: ptr, (stride, step) for x/y/z, first-element offset
+        src_ptr,
+        src_stride_x, src_step_x,
+        src_stride_y, src_step_y,
+        src_stride_z, src_step_z,
+        src_offset_first_element_in_bytes,
+        // destination tensor: same argument layout
+        dst_ptr,
+        dst_stride_x, dst_step_x,
+        dst_stride_y, dst_step_y,
+        dst_stride_z, dst_step_z,
+        dst_offset_first_element_in_bytes);
+}
+#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
+#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+#endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
\ No newline at end of file
diff --git a/src/core/CL/cl_kernels/winograd_output_transform.cl b/src/core/CL/cl_kernels/winograd_output_transform.cl
new file mode 100644
index 0000000..a1e7b3e
--- /dev/null
+++ b/src/core/CL/cl_kernels/winograd_output_transform.cl
@@ -0,0 +1,1601 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "helpers.h"
+
+#if defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
+/** This OpenCL kernel performs Winograd output transform when the output tile is 2x2/2x1 or 1x2, the filter size 3x3/3x1 or 1x3 and the data layout is NCHW
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
+ * @note If this kernel is used to perform Winograd output transform 3x1, -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd output transform 1x3, -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_2x2_3x3_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst)
+#if defined(HAS_BIAS)
+ ,
+ VECTOR_DECLARATION(bias) // optional bias: one float read at index z_out and added to the whole tile
+#endif // defined(HAS_BIAS)
+)
+{
+ // Each work-item writes one output tile: 2x2 by default, 2x1 (one row) with WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL, 1x2 (one column) with WINOGRAD_OUTPUT_TRANSFORM_VERTICAL
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+
+ const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
+
+ // The transformed 4x4 (or 4x1) tile is spread over Z planes of src: element k is read from src_addr + k * src_stride_z
+ float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
+ float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
+ float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
+ float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
+
+#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ // Compute the 2x1 or 1x2 output tile: out = A^T * d with A^T = [1 1 1 0; 0 1 -1 -1]
+ // out00 = d00 + d01 + d02
+ // out01 = d01 - d02 - d03
+
+ float out00 = d00 + d01 + d02;
+ float out01 = d01 - d02 - d03;
+#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ float d10 = *((__global float *)(src_addr + 4 * src_stride_z));
+ float d11 = *((__global float *)(src_addr + 5 * src_stride_z));
+ float d12 = *((__global float *)(src_addr + 6 * src_stride_z));
+ float d13 = *((__global float *)(src_addr + 7 * src_stride_z));
+
+ float d20 = *((__global float *)(src_addr + 8 * src_stride_z));
+ float d21 = *((__global float *)(src_addr + 9 * src_stride_z));
+ float d22 = *((__global float *)(src_addr + 10 * src_stride_z));
+ float d23 = *((__global float *)(src_addr + 11 * src_stride_z));
+
+ float d30 = *((__global float *)(src_addr + 12 * src_stride_z));
+ float d31 = *((__global float *)(src_addr + 13 * src_stride_z));
+ float d32 = *((__global float *)(src_addr + 14 * src_stride_z));
+ float d33 = *((__global float *)(src_addr + 15 * src_stride_z));
+
+ // Compute the 2x2 output tile out = A^T * D * A; k0..k3 are partial sums each shared by two outputs
+ float k0 = d01 + d11 + d21;
+ float k1 = d02 + d12 + d22;
+ float k2 = d11 - d21 - d31;
+ float k3 = d12 - d22 - d32;
+
+ // out00 = d00 + d10 + d20 + d01 + d11 + d21 + d02 + d12 + d22
+ // out01 = d01 + d11 + d21 - (d02 + d12 + d22) - (d03 + d13 + d23)
+ // out10 = d10 - d20 - d30 + (d11 - d21 - d31) + (d12 - d22 - d32)
+ // out11 = d11 - d21 - d31 - (d12 - d22 - d32) - (d13 - d23 - d33)
+ // Seed each accumulator with one term of its sum, then add the remaining terms and the shared partial sums
+ float out00 = d10;
+ float out01 = -d13;
+ float out10 = d10;
+ float out11 = -d13;
+
+ out00 += d00 + d20 + k0 + k1;
+ out01 += k0 - k1 - (d03 + d23);
+ out10 += -d20 - d30 + k2 + k3;
+ out11 += k2 - k3 + d23 + d33;
+#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+
+ int y_in = get_global_id(1); // linear tile index, split into tile column/row using NUM_TILES_X
+ int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
+ int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
+ int z_out = get_global_id(0); // output Z plane; also the bias index
+
+#if defined(HAS_BIAS)
+ // Add bias (the same value is added to every element of this tile)
+ Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
+
+ float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
+
+ out00 += (float)b;
+ out01 += (float)b;
+#endif // defined(HAS_BIAS)
+
+ // Get output address: top-left element of the tile (x counted in floats, y and z via byte strides)
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z;
+
+ // Store the first output row; with VERTICAL the two values go into one column on consecutive rows instead
+#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ *((__global float *)(dst_addr + 0 * dst_stride_y)) = out00;
+ *((__global float *)(dst_addr + 1 * dst_stride_y)) = out01;
+#else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ vstore2((float2)(out00, out01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+
+#if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+#if defined(HAS_BIAS)
+ // Add bias to the second row (only computed in the full 2x2 case)
+ out10 += (float)b;
+ out11 += (float)b;
+#endif // defined(HAS_BIAS)
+
+ vstore2((float2)(out10, out11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+}
+
+/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 3x3 and the data layout is NCHW
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note If this kernel is used to perform Winograd output transform 3x1, -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd output transform 1x3, -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_4x4_3x3_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst)
+#if defined(HAS_BIAS)
+ ,
+ VECTOR_DECLARATION(bias) // optional bias: one float read at index z_out and added to the whole tile
+#endif // defined(HAS_BIAS)
+)
+{
+ // Each work-item writes one output tile: 4x4 by default, 4x1 (one row) with WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL, 1x4 (one column) with WINOGRAD_OUTPUT_TRANSFORM_VERTICAL
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+
+ const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
+
+ // The transformed 6x6 (or 6x1) tile is spread over Z planes of src: element k is read from src_addr + k * src_stride_z
+ float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
+ float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
+ float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
+ float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
+ float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
+ float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
+
+#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ // 1D case: out = A^T * d with A^T = [1 1 1 1 1 0; 0 1 -1 2 -2 0; 0 1 1 4 4 0; 0 1 -1 8 -8 1]
+ float out00 = d00 + d01 + d02 + d03 + d04;
+ float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04;
+ float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04;
+ float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05;
+#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ float d10 = *((__global float *)(src_addr + 6 * src_stride_z));
+ float d11 = *((__global float *)(src_addr + 7 * src_stride_z));
+ float d12 = *((__global float *)(src_addr + 8 * src_stride_z));
+ float d13 = *((__global float *)(src_addr + 9 * src_stride_z));
+ float d14 = *((__global float *)(src_addr + 10 * src_stride_z));
+ float d15 = *((__global float *)(src_addr + 11 * src_stride_z));
+
+ float d20 = *((__global float *)(src_addr + 12 * src_stride_z));
+ float d21 = *((__global float *)(src_addr + 13 * src_stride_z));
+ float d22 = *((__global float *)(src_addr + 14 * src_stride_z));
+ float d23 = *((__global float *)(src_addr + 15 * src_stride_z));
+ float d24 = *((__global float *)(src_addr + 16 * src_stride_z));
+ float d25 = *((__global float *)(src_addr + 17 * src_stride_z));
+
+ float d30 = *((__global float *)(src_addr + 18 * src_stride_z));
+ float d31 = *((__global float *)(src_addr + 19 * src_stride_z));
+ float d32 = *((__global float *)(src_addr + 20 * src_stride_z));
+ float d33 = *((__global float *)(src_addr + 21 * src_stride_z));
+ float d34 = *((__global float *)(src_addr + 22 * src_stride_z));
+ float d35 = *((__global float *)(src_addr + 23 * src_stride_z));
+
+ float d40 = *((__global float *)(src_addr + 24 * src_stride_z));
+ float d41 = *((__global float *)(src_addr + 25 * src_stride_z));
+ float d42 = *((__global float *)(src_addr + 26 * src_stride_z));
+ float d43 = *((__global float *)(src_addr + 27 * src_stride_z));
+ float d44 = *((__global float *)(src_addr + 28 * src_stride_z));
+ float d45 = *((__global float *)(src_addr + 29 * src_stride_z));
+
+ float d50 = *((__global float *)(src_addr + 30 * src_stride_z));
+ float d51 = *((__global float *)(src_addr + 31 * src_stride_z));
+ float d52 = *((__global float *)(src_addr + 32 * src_stride_z));
+ float d53 = *((__global float *)(src_addr + 33 * src_stride_z));
+ float d54 = *((__global float *)(src_addr + 34 * src_stride_z));
+ float d55 = *((__global float *)(src_addr + 35 * src_stride_z));
+
+ // Row 0 of out = A^T * D * A (row weights 1, 1, 1, 1, 1, 0): seed with column 1; k0/k1 hold the column 3 and 4 terms
+ float out00 = d01 + d21 + d41 + d11 + d31;
+ float out01 = d01 + d21 + d41 + d11 + d31;
+ float out02 = d01 + d21 + d41 + d11 + d31;
+ float out03 = d01 + d21 + d41 + d11 + d31;
+
+ float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
+ float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
+
+ out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
+ out01 += k1 - d02 - d12 - d22 - d32 - d42;
+ out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
+ out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
+
+ // Row 1 (row weights 0, 1, -1, 2, -2, 0); k0/k1 are recomputed with these weights
+ float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+
+ k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
+ k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
+
+ out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
+ out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
+ out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
+ out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
+
+ // Row 2 (row weights 0, 1, 1, 4, 4, 0)
+ float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+
+ k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
+ k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
+
+ out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
+ out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
+ out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
+ out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
+
+ // Row 3 (row weights 0, 1, -1, 8, -8, 1)
+ float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+
+ k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
+ k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
+
+ out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
+ out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
+ out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
+ out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
+#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+
+ int y_in = get_global_id(1); // linear tile index, split into tile column/row using NUM_TILES_X
+ int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
+ int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
+ int z_out = get_global_id(0); // output Z plane; also the bias index
+
+#if defined(HAS_BIAS)
+ // Add bias to row 0 here; rows 1-3 receive it just before they are stored
+ Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
+
+ float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
+
+ out00 += (float)b;
+ out01 += (float)b;
+ out02 += (float)b;
+ out03 += (float)b;
+#endif // defined(HAS_BIAS)
+
+ // Get output address: top-left element of the tile (x counted in floats, y and z via byte strides)
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z;
+
+ // Store row 0; with VERTICAL the four values go into one column on consecutive rows instead
+#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ *((__global float *)(dst_addr + 0 * dst_stride_y)) = out00;
+ *((__global float *)(dst_addr + 1 * dst_stride_y)) = out01;
+ *((__global float *)(dst_addr + 2 * dst_stride_y)) = out02;
+ *((__global float *)(dst_addr + 3 * dst_stride_y)) = out03;
+#else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ vstore4((float4)(out00, out01, out02, out03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+
+#if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+#if defined(HAS_BIAS)
+ // Add bias to rows 1-3 (only computed in the full 4x4 case)
+ out10 += (float)b;
+ out11 += (float)b;
+ out12 += (float)b;
+ out13 += (float)b;
+
+ out20 += (float)b;
+ out21 += (float)b;
+ out22 += (float)b;
+ out23 += (float)b;
+
+ out30 += (float)b;
+ out31 += (float)b;
+ out32 += (float)b;
+ out33 += (float)b;
+#endif // defined(HAS_BIAS)
+ vstore4((float4)(out10, out11, out12, out13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+ vstore4((float4)(out20, out21, out22, out23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+ vstore4((float4)(out30, out31, out32, out33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+}
+
+/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, 4x1 or 1x4, the filter size 3x3, 3x1 or 1x3 and the data layout is NHWC
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note If this kernel is used to perform Winograd output transform 3x1, -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd output transform 1x3, -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] dst_size Size of the destination tensor, minus the last padding
+ */
+__kernel void winograd_output_transform_4x4_3x3_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst),
+#if defined(HAS_BIAS)
+ VECTOR_DECLARATION(bias),
+#endif // defined(HAS_BIAS)
+ int dst_size)
+{
+ // Each thread stores a 4x4/4x1 or 1x4 tile
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+
+ const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
+
+ // Load the values across the 36 channels to compose the 6x6 or 6x1 tile
+ float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
+ float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
+ float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
+ float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
+ float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
+ float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
+
+#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ // Compute out00, out01, out02 and out03
+ float out00 = d00 + d01 + d02 + d03 + d04;
+ float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04;
+ float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04;
+ float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05;
+#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+
+ float d10 = *((__global float *)(src_addr + 6 * src_stride_z));
+ float d11 = *((__global float *)(src_addr + 7 * src_stride_z));
+ float d12 = *((__global float *)(src_addr + 8 * src_stride_z));
+ float d13 = *((__global float *)(src_addr + 9 * src_stride_z));
+ float d14 = *((__global float *)(src_addr + 10 * src_stride_z));
+ float d15 = *((__global float *)(src_addr + 11 * src_stride_z));
+
+ float d20 = *((__global float *)(src_addr + 12 * src_stride_z));
+ float d21 = *((__global float *)(src_addr + 13 * src_stride_z));
+ float d22 = *((__global float *)(src_addr + 14 * src_stride_z));
+ float d23 = *((__global float *)(src_addr + 15 * src_stride_z));
+ float d24 = *((__global float *)(src_addr + 16 * src_stride_z));
+ float d25 = *((__global float *)(src_addr + 17 * src_stride_z));
+
+ float d30 = *((__global float *)(src_addr + 18 * src_stride_z));
+ float d31 = *((__global float *)(src_addr + 19 * src_stride_z));
+ float d32 = *((__global float *)(src_addr + 20 * src_stride_z));
+ float d33 = *((__global float *)(src_addr + 21 * src_stride_z));
+ float d34 = *((__global float *)(src_addr + 22 * src_stride_z));
+ float d35 = *((__global float *)(src_addr + 23 * src_stride_z));
+
+ float d40 = *((__global float *)(src_addr + 24 * src_stride_z));
+ float d41 = *((__global float *)(src_addr + 25 * src_stride_z));
+ float d42 = *((__global float *)(src_addr + 26 * src_stride_z));
+ float d43 = *((__global float *)(src_addr + 27 * src_stride_z));
+ float d44 = *((__global float *)(src_addr + 28 * src_stride_z));
+ float d45 = *((__global float *)(src_addr + 29 * src_stride_z));
+
+ float d50 = *((__global float *)(src_addr + 30 * src_stride_z));
+ float d51 = *((__global float *)(src_addr + 31 * src_stride_z));
+ float d52 = *((__global float *)(src_addr + 32 * src_stride_z));
+ float d53 = *((__global float *)(src_addr + 33 * src_stride_z));
+ float d54 = *((__global float *)(src_addr + 34 * src_stride_z));
+ float d55 = *((__global float *)(src_addr + 35 * src_stride_z));
+
+ // Compute out00, out01, out02 and out03
+ float out00 = d01 + d21 + d41 + d11 + d31;
+ float out01 = d01 + d21 + d41 + d11 + d31;
+ float out02 = d01 + d21 + d41 + d11 + d31;
+ float out03 = d01 + d21 + d41 + d11 + d31;
+
+ float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
+ float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
+
+ out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
+ out01 += k1 - d02 - d12 - d22 - d32 - d42;
+ out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
+ out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
+
+ // Compute out10, out11, out12 and out13
+ float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+
+ k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
+ k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
+
+ out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
+ out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
+ out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
+ out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
+
+ // Compute out20, out21, out22 and out23
+ float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+
+ k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
+ k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
+
+ out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
+ out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
+ out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
+ out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
+
+ // Compute out30, out31, out32 and out33
+ float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+
+ k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
+ k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
+
+ out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
+ out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
+ out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
+ out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
+#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+
+ int y_in = get_global_id(1);
+ int x_out = get_global_id(0);
+ int y_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
+ int z_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
+
+#if defined(HAS_BIAS)
+ // Add bias
+ Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
+
+ float b = (float) * ((__global float *)(vector_offset(&bias, x_out)));
+
+ out00 += (float)b;
+ out01 += (float)b;
+ out02 += (float)b;
+ out03 += (float)b;
+#if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) & !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ out10 += (float)b;
+ out11 += (float)b;
+ out12 += (float)b;
+ out13 += (float)b;
+
+ out20 += (float)b;
+ out21 += (float)b;
+ out22 += (float)b;
+ out23 += (float)b;
+
+ out30 += (float)b;
+ out31 += (float)b;
+ out32 += (float)b;
+ out33 += (float)b;
+#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) & !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+
+#endif // defined(HAS_BIAS)
+
+#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ int4 offset = (int4)(dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z);
+ offset = min(offset + (int4)(0, 1, 2, 3) * (int4)dst_stride_z, (int4)dst_size); // If address is beyond the last plane, clamp it to dst_size (which points to the last padding).
+
+ // Store the 1x4 output tile
+ *((__global float *)(dst_ptr + offset.s0)) = out00;
+ *((__global float *)(dst_ptr + offset.s1)) = out01;
+ *((__global float *)(dst_ptr + offset.s2)) = out02;
+ *((__global float *)(dst_ptr + offset.s3)) = out03;
+#elif defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
+ // Store the 4x1 output tile
+ int offset = dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z;
+ int mult_y = min(dst_size - offset, 1);
+
+ *((__global float *)(dst_ptr + mult_y * 0 * dst_stride_y + offset)) = out00;
+ *((__global float *)(dst_ptr + mult_y * 1 * dst_stride_y + offset)) = out01;
+ *((__global float *)(dst_ptr + mult_y * 2 * dst_stride_y + offset)) = out02;
+ *((__global float *)(dst_ptr + mult_y * 3 * dst_stride_y + offset)) = out03;
+#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
+ // Get output address
+ int4 offset = (int4)(dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z);
+ offset = min(offset + (int4)(0, 1, 2, 3) * (int4)dst_stride_z, (int4)dst_size); // If address is beyond the last plane, clamp it to dst_size (which points to the last padding).
+ int4 mult_y = min((int4)dst_size - offset, (int4)1); // If out of bound, we don't want to increase dst_stride_y, so we set the multiplier to 0. It will be 1 otherwise.
+
+ // Store the 4x4 output tile
+ *((__global float *)(dst_ptr + mult_y.s0 * 0 * dst_stride_y + offset.s0)) = out00;
+ *((__global float *)(dst_ptr + mult_y.s0 * 1 * dst_stride_y + offset.s0)) = out01;
+ *((__global float *)(dst_ptr + mult_y.s0 * 2 * dst_stride_y + offset.s0)) = out02;
+ *((__global float *)(dst_ptr + mult_y.s0 * 3 * dst_stride_y + offset.s0)) = out03;
+ *((__global float *)(dst_ptr + mult_y.s1 * 0 * dst_stride_y + offset.s1)) = out10;
+ *((__global float *)(dst_ptr + mult_y.s1 * 1 * dst_stride_y + offset.s1)) = out11;
+ *((__global float *)(dst_ptr + mult_y.s1 * 2 * dst_stride_y + offset.s1)) = out12;
+ *((__global float *)(dst_ptr + mult_y.s1 * 3 * dst_stride_y + offset.s1)) = out13;
+ *((__global float *)(dst_ptr + mult_y.s2 * 0 * dst_stride_y + offset.s2)) = out20;
+ *((__global float *)(dst_ptr + mult_y.s2 * 1 * dst_stride_y + offset.s2)) = out21;
+ *((__global float *)(dst_ptr + mult_y.s2 * 2 * dst_stride_y + offset.s2)) = out22;
+ *((__global float *)(dst_ptr + mult_y.s2 * 3 * dst_stride_y + offset.s2)) = out23;
+ *((__global float *)(dst_ptr + mult_y.s3 * 0 * dst_stride_y + offset.s3)) = out30;
+ *((__global float *)(dst_ptr + mult_y.s3 * 1 * dst_stride_y + offset.s3)) = out31;
+ *((__global float *)(dst_ptr + mult_y.s3 * 2 * dst_stride_y + offset.s3)) = out32;
+ *((__global float *)(dst_ptr + mult_y.s3 * 3 * dst_stride_y + offset.s3)) = out33;
+
+#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
+}
+
+#define COMPUTE_TMP_COL(res, v0, v1, v2, v3, v4, v5, v6, v7, scratch) \
+ ({ \
+ /* Outputs 0 and 2 are built from the pairwise sums v1+v2, v3+v4, v5+v6 */ \
+ scratch.s0 = v1 + v2; \
+ scratch.s1 = v3 + v4; \
+ scratch.s2 = v5 + v6; \
+ res.s0 = scratch.s0 + scratch.s1 + 8.f * scratch.s2 + v0; \
+ res.s2 = scratch.s0 + 4.f * scratch.s1 + 2.f * scratch.s2; \
+ \
+ /* Outputs 1 and 3 are built from the pairwise differences; scratch keeps them on exit */ \
+ scratch.s0 = v1 - v2; \
+ scratch.s1 = v3 - v4; \
+ scratch.s2 = v5 - v6; \
+ res.s1 = scratch.s0 + 2.f * scratch.s1 + 4.f * scratch.s2; \
+ res.s3 = scratch.s0 + 8.f * scratch.s1 + scratch.s2 + v7; \
+ })
+
+/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4/4x1 or 1x4, the filter size 5x5/5x1 or 1x5 and the data layout is NCHW
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width and height of the output tile must be passed at compile time using -DOUTPUT_TILE_W and -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_W=4 -DOUTPUT_TILE_H=4
+ * @note If this kernel is used to perform Winograd output transform 5x1, -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd output transform 1x5, -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ * @note If -DHAS_BIAS is passed, the kernel takes an extra bias vector argument (after dst) and adds bias[get_global_id(0)] to every element of the output tile
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_4x4_5x5_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst)
+#if defined(HAS_BIAS)
+ ,
+ VECTOR_DECLARATION(bias)
+#endif // defined(HAS_BIAS)
+)
+{
+ // Each work-item computes and stores one 4x4 (or 4x1 / 1x4) output tile
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+
+ const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
+
+ // Compute output address: get_global_id(1) is the linear tile index (NUM_TILES_X tiles per row), get_global_id(0) the destination plane
+ int y_in = get_global_id(1);
+ int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
+ int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
+ int z_out = get_global_id(0);
+
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z;
+
+ // Load the values of this tile; consecutive values live in consecutive src planes (src_stride_z bytes apart)
+ float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
+ float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
+ float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
+ float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
+ float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
+ float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
+ float d06 = *((__global float *)(src_addr + 6 * src_stride_z));
+ float d07 = *((__global float *)(src_addr + 7 * src_stride_z));
+
+#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ // 1D case: compute out00, out01, out02 and out03 from the 8 loaded values
+ float out00 = d00 + d01 + d02 + d03 + d04 + 8.0f * d05 + 8.0f * d06;
+ float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04 + 4.0f * d05 - 4.0f * d06;
+ float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04 + 2.0f * d05 + 2.0f * d06;
+ float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05 - d06 + d07;
+
+#if defined(HAS_BIAS)
+ // Add bias (one value per destination plane z_out)
+ Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
+
+ float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
+
+ out00 += (float)b;
+ out01 += (float)b;
+ out02 += (float)b;
+ out03 += (float)b;
+#endif // defined(HAS_BIAS)
+
+ // Store the output tile: a 1x4 column along Y (vertical) or a 4x1 row along X (horizontal)
+#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ *((__global float *)(dst_addr + 0 * dst_stride_y)) = out00;
+ *((__global float *)(dst_addr + 1 * dst_stride_y)) = out01;
+ *((__global float *)(dst_addr + 2 * dst_stride_y)) = out02;
+ *((__global float *)(dst_addr + 3 * dst_stride_y)) = out03;
+#else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ vstore4((float4)(out00, out01, out02, out03), 0, (__global float *)(dst_addr));
+#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+
+#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ float d10 = *((__global float *)(src_addr + 8 * src_stride_z));
+ float d11 = *((__global float *)(src_addr + 9 * src_stride_z));
+ float d12 = *((__global float *)(src_addr + 10 * src_stride_z));
+ float d13 = *((__global float *)(src_addr + 11 * src_stride_z));
+ float d14 = *((__global float *)(src_addr + 12 * src_stride_z));
+ float d15 = *((__global float *)(src_addr + 13 * src_stride_z));
+ float d16 = *((__global float *)(src_addr + 14 * src_stride_z));
+ float d17 = *((__global float *)(src_addr + 15 * src_stride_z));
+
+ float d20 = *((__global float *)(src_addr + 16 * src_stride_z));
+ float d21 = *((__global float *)(src_addr + 17 * src_stride_z));
+ float d22 = *((__global float *)(src_addr + 18 * src_stride_z));
+ float d23 = *((__global float *)(src_addr + 19 * src_stride_z));
+ float d24 = *((__global float *)(src_addr + 20 * src_stride_z));
+ float d25 = *((__global float *)(src_addr + 21 * src_stride_z));
+ float d26 = *((__global float *)(src_addr + 22 * src_stride_z));
+ float d27 = *((__global float *)(src_addr + 23 * src_stride_z));
+
+ float d30 = *((__global float *)(src_addr + 24 * src_stride_z));
+ float d31 = *((__global float *)(src_addr + 25 * src_stride_z));
+ float d32 = *((__global float *)(src_addr + 26 * src_stride_z));
+ float d33 = *((__global float *)(src_addr + 27 * src_stride_z));
+ float d34 = *((__global float *)(src_addr + 28 * src_stride_z));
+ float d35 = *((__global float *)(src_addr + 29 * src_stride_z));
+ float d36 = *((__global float *)(src_addr + 30 * src_stride_z));
+ float d37 = *((__global float *)(src_addr + 31 * src_stride_z));
+
+ float d40 = *((__global float *)(src_addr + 32 * src_stride_z));
+ float d41 = *((__global float *)(src_addr + 33 * src_stride_z));
+ float d42 = *((__global float *)(src_addr + 34 * src_stride_z));
+ float d43 = *((__global float *)(src_addr + 35 * src_stride_z));
+ float d44 = *((__global float *)(src_addr + 36 * src_stride_z));
+ float d45 = *((__global float *)(src_addr + 37 * src_stride_z));
+ float d46 = *((__global float *)(src_addr + 38 * src_stride_z));
+ float d47 = *((__global float *)(src_addr + 39 * src_stride_z));
+
+ float d50 = *((__global float *)(src_addr + 40 * src_stride_z));
+ float d51 = *((__global float *)(src_addr + 41 * src_stride_z));
+ float d52 = *((__global float *)(src_addr + 42 * src_stride_z));
+ float d53 = *((__global float *)(src_addr + 43 * src_stride_z));
+ float d54 = *((__global float *)(src_addr + 44 * src_stride_z));
+ float d55 = *((__global float *)(src_addr + 45 * src_stride_z));
+ float d56 = *((__global float *)(src_addr + 46 * src_stride_z));
+ float d57 = *((__global float *)(src_addr + 47 * src_stride_z));
+
+ float d60 = *((__global float *)(src_addr + 48 * src_stride_z));
+ float d61 = *((__global float *)(src_addr + 49 * src_stride_z));
+ float d62 = *((__global float *)(src_addr + 50 * src_stride_z));
+ float d63 = *((__global float *)(src_addr + 51 * src_stride_z));
+ float d64 = *((__global float *)(src_addr + 52 * src_stride_z));
+ float d65 = *((__global float *)(src_addr + 53 * src_stride_z));
+ float d66 = *((__global float *)(src_addr + 54 * src_stride_z));
+ float d67 = *((__global float *)(src_addr + 55 * src_stride_z));
+
+ float d70 = *((__global float *)(src_addr + 56 * src_stride_z));
+ float d71 = *((__global float *)(src_addr + 57 * src_stride_z));
+ float d72 = *((__global float *)(src_addr + 58 * src_stride_z));
+ float d73 = *((__global float *)(src_addr + 59 * src_stride_z));
+ float d74 = *((__global float *)(src_addr + 60 * src_stride_z));
+ float d75 = *((__global float *)(src_addr + 61 * src_stride_z));
+ float d76 = *((__global float *)(src_addr + 62 * src_stride_z));
+ float d77 = *((__global float *)(src_addr + 63 * src_stride_z));
+
+ // Compute the 8x4 intermediate tensor: tmp_colN = 1D transform of (d0N, d1N, ..., d7N)
+ float4 comm_fact0, comm_fact1, comm_fact2;
+ float4 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
+
+ COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0);
+
+ // Compute the 4x4 output tile by transforming across tmp_col0..tmp_col7: out_colJ.sI is the output at row I, column J
+ comm_fact0 = tmp_col1 + tmp_col2;
+ comm_fact1 = tmp_col3 + tmp_col4;
+ comm_fact2 = tmp_col5 + tmp_col6;
+
+ float4 out_col0 = comm_fact0 + comm_fact1 + 8.f * comm_fact2 + tmp_col0;
+ float4 out_col2 = comm_fact0 + 4.f * comm_fact1 + 2.f * comm_fact2;
+
+ comm_fact0 = tmp_col1 - tmp_col2;
+ comm_fact1 = tmp_col3 - tmp_col4;
+ comm_fact2 = tmp_col5 - tmp_col6;
+
+ float4 out_col1 = comm_fact0 + 2.f * comm_fact1 + 4.f * comm_fact2;
+ float4 out_col3 = comm_fact0 + 8.f * comm_fact1 + comm_fact2 + tmp_col7;
+
+#if defined(HAS_BIAS)
+ // Add bias (one value per destination plane z_out)
+ Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
+
+ float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
+
+ out_col0 += (float4)b;
+ out_col1 += (float4)b;
+ out_col2 += (float4)b;
+ out_col3 += (float4)b;
+#endif // defined(HAS_BIAS)
+
+ // Store the output tile one row at a time (row I gathers component sI of every out_col)
+ vstore4((float4)(out_col0.s0, out_col1.s0, out_col2.s0, out_col3.s0), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+ vstore4((float4)(out_col0.s1, out_col1.s1, out_col2.s1, out_col3.s1), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+ vstore4((float4)(out_col0.s2, out_col1.s2, out_col2.s2, out_col3.s2), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+ vstore4((float4)(out_col0.s3, out_col1.s3, out_col2.s3, out_col3.s3), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+}
+
+/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4/4x1 or 1x4, the filter size 5x5/5x1 or 1x5 and the data layout is NHWC
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width and height of the output tile must be passed at compile time using -DOUTPUT_TILE_W and -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_W=4 -DOUTPUT_TILE_H=4
+ * @note For Winograd output transform 5x1 pass -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL, for Winograd output transform 1x5 pass -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL at compile time
+ * @note If -DHAS_BIAS is passed, the kernel takes an extra bias vector argument (after dst) and adds bias[get_global_id(0)] to every element of the output tile
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] dst_size Byte offset (relative to dst_ptr) that out-of-range store addresses are clamped to; it points to the last padding of the destination
+ */
+__kernel void winograd_output_transform_4x4_5x5_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst),
+#if defined(HAS_BIAS)
+ VECTOR_DECLARATION(bias),
+#endif // defined(HAS_BIAS)
+ int dst_size)
+{
+ // Each work-item computes and stores one 4x4 (or 4x1 / 1x4) output tile
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+
+ const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
+
+ int y_in = get_global_id(1); // linear tile index (NUM_TILES_X tiles per row)
+ int x_out = get_global_id(0); // element along X; also selects the bias value
+ int y_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
+ int z_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
+
+ // Load the values of this tile; consecutive values live in consecutive src planes (src_stride_z bytes apart)
+ float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
+ float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
+ float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
+ float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
+ float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
+ float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
+ float d06 = *((__global float *)(src_addr + 6 * src_stride_z));
+ float d07 = *((__global float *)(src_addr + 7 * src_stride_z));
+
+#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ // 1D case: compute out00, out01, out02 and out03 from the 8 loaded values
+ float out00 = d00 + d01 + d02 + d03 + d04 + 8.0f * d05 + 8.0f * d06;
+ float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04 + 4.0f * d05 - 4.0f * d06;
+ float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04 + 2.0f * d05 + 2.0f * d06;
+ float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05 - d06 + d07;
+
+#if defined(HAS_BIAS)
+ // Add bias (indexed by x_out)
+ Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
+
+ float b = (float) * ((__global float *)(vector_offset(&bias, x_out)));
+
+ out00 += (float)b;
+ out01 += (float)b;
+ out02 += (float)b;
+ out03 += (float)b;
+#endif // defined(HAS_BIAS)
+
+ // Store the output tile: a 1x4 tile across Z planes (vertical) or a 4x1 tile along Y (horizontal)
+#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ // Get output address; rows past the last plane are clamped to dst_size
+ int4 offset = (int4)(dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z);
+ offset = min(offset + (int4)(0, 1, 2, 3) * (int4)dst_stride_z, (int4)dst_size); // If address is beyond the last plane, clamp it to dst_size (which points to the last padding).
+
+ *(__global float *)(dst_ptr + offset.s0) = out00;
+ *(__global float *)(dst_ptr + offset.s1) = out01;
+ *(__global float *)(dst_ptr + offset.s2) = out02;
+ *(__global float *)(dst_ptr + offset.s3) = out03;
+#else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+ // Get output address (no dst_size clamping on this path)
+ int offset = dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z;
+
+ *(__global float *)(dst_ptr + 0 * dst_stride_y + offset) = out00;
+ *(__global float *)(dst_ptr + 1 * dst_stride_y + offset) = out01;
+ *(__global float *)(dst_ptr + 2 * dst_stride_y + offset) = out02;
+ *(__global float *)(dst_ptr + 3 * dst_stride_y + offset) = out03;
+#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+
+#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+
+ float d10 = *((__global float *)(src_addr + 8 * src_stride_z));
+ float d11 = *((__global float *)(src_addr + 9 * src_stride_z));
+ float d12 = *((__global float *)(src_addr + 10 * src_stride_z));
+ float d13 = *((__global float *)(src_addr + 11 * src_stride_z));
+ float d14 = *((__global float *)(src_addr + 12 * src_stride_z));
+ float d15 = *((__global float *)(src_addr + 13 * src_stride_z));
+ float d16 = *((__global float *)(src_addr + 14 * src_stride_z));
+ float d17 = *((__global float *)(src_addr + 15 * src_stride_z));
+
+ float d20 = *((__global float *)(src_addr + 16 * src_stride_z));
+ float d21 = *((__global float *)(src_addr + 17 * src_stride_z));
+ float d22 = *((__global float *)(src_addr + 18 * src_stride_z));
+ float d23 = *((__global float *)(src_addr + 19 * src_stride_z));
+ float d24 = *((__global float *)(src_addr + 20 * src_stride_z));
+ float d25 = *((__global float *)(src_addr + 21 * src_stride_z));
+ float d26 = *((__global float *)(src_addr + 22 * src_stride_z));
+ float d27 = *((__global float *)(src_addr + 23 * src_stride_z));
+
+ float d30 = *((__global float *)(src_addr + 24 * src_stride_z));
+ float d31 = *((__global float *)(src_addr + 25 * src_stride_z));
+ float d32 = *((__global float *)(src_addr + 26 * src_stride_z));
+ float d33 = *((__global float *)(src_addr + 27 * src_stride_z));
+ float d34 = *((__global float *)(src_addr + 28 * src_stride_z));
+ float d35 = *((__global float *)(src_addr + 29 * src_stride_z));
+ float d36 = *((__global float *)(src_addr + 30 * src_stride_z));
+ float d37 = *((__global float *)(src_addr + 31 * src_stride_z));
+
+ float d40 = *((__global float *)(src_addr + 32 * src_stride_z));
+ float d41 = *((__global float *)(src_addr + 33 * src_stride_z));
+ float d42 = *((__global float *)(src_addr + 34 * src_stride_z));
+ float d43 = *((__global float *)(src_addr + 35 * src_stride_z));
+ float d44 = *((__global float *)(src_addr + 36 * src_stride_z));
+ float d45 = *((__global float *)(src_addr + 37 * src_stride_z));
+ float d46 = *((__global float *)(src_addr + 38 * src_stride_z));
+ float d47 = *((__global float *)(src_addr + 39 * src_stride_z));
+
+ float d50 = *((__global float *)(src_addr + 40 * src_stride_z));
+ float d51 = *((__global float *)(src_addr + 41 * src_stride_z));
+ float d52 = *((__global float *)(src_addr + 42 * src_stride_z));
+ float d53 = *((__global float *)(src_addr + 43 * src_stride_z));
+ float d54 = *((__global float *)(src_addr + 44 * src_stride_z));
+ float d55 = *((__global float *)(src_addr + 45 * src_stride_z));
+ float d56 = *((__global float *)(src_addr + 46 * src_stride_z));
+ float d57 = *((__global float *)(src_addr + 47 * src_stride_z));
+
+ float d60 = *((__global float *)(src_addr + 48 * src_stride_z));
+ float d61 = *((__global float *)(src_addr + 49 * src_stride_z));
+ float d62 = *((__global float *)(src_addr + 50 * src_stride_z));
+ float d63 = *((__global float *)(src_addr + 51 * src_stride_z));
+ float d64 = *((__global float *)(src_addr + 52 * src_stride_z));
+ float d65 = *((__global float *)(src_addr + 53 * src_stride_z));
+ float d66 = *((__global float *)(src_addr + 54 * src_stride_z));
+ float d67 = *((__global float *)(src_addr + 55 * src_stride_z));
+
+ float d70 = *((__global float *)(src_addr + 56 * src_stride_z));
+ float d71 = *((__global float *)(src_addr + 57 * src_stride_z));
+ float d72 = *((__global float *)(src_addr + 58 * src_stride_z));
+ float d73 = *((__global float *)(src_addr + 59 * src_stride_z));
+ float d74 = *((__global float *)(src_addr + 60 * src_stride_z));
+ float d75 = *((__global float *)(src_addr + 61 * src_stride_z));
+ float d76 = *((__global float *)(src_addr + 62 * src_stride_z));
+ float d77 = *((__global float *)(src_addr + 63 * src_stride_z));
+
+ // Compute the 8x4 intermediate tensor: tmp_colN = 1D transform of (d0N, d1N, ..., d7N)
+ float4 comm_fact0, comm_fact1, comm_fact2;
+ float4 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
+
+ COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0);
+ COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0);
+
+ // Compute the 4x4 output tile: out_colJ.sI is the output at Z-row I, Y-column J
+ comm_fact0 = tmp_col1 + tmp_col2;
+ comm_fact1 = tmp_col3 + tmp_col4;
+ comm_fact2 = tmp_col5 + tmp_col6;
+
+ float4 out_col0 = comm_fact0 + comm_fact1 + 8.f * comm_fact2 + tmp_col0;
+ float4 out_col2 = comm_fact0 + 4.f * comm_fact1 + 2.f * comm_fact2;
+
+ comm_fact0 = tmp_col1 - tmp_col2;
+ comm_fact1 = tmp_col3 - tmp_col4;
+ comm_fact2 = tmp_col5 - tmp_col6;
+
+ float4 out_col1 = comm_fact0 + 2.f * comm_fact1 + 4.f * comm_fact2;
+ float4 out_col3 = comm_fact0 + 8.f * comm_fact1 + comm_fact2 + tmp_col7;
+
+#if defined(HAS_BIAS)
+ // Add bias (indexed by x_out)
+ Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
+
+ float b = (float) * ((__global float *)(vector_offset(&bias, x_out)));
+
+ out_col0 += (float4)b;
+ out_col1 += (float4)b;
+ out_col2 += (float4)b;
+ out_col3 += (float4)b;
+#endif // defined(HAS_BIAS)
+ // Get output addresses, one per Z row; rows beyond the last plane are clamped to dst_size
+ int4 offset = (int4)(dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z);
+ offset = min(offset + (int4)(0, 1, 2, 3) * (int4)dst_stride_z, (int4)dst_size); // If address is beyond the last plane, clamp it to dst_size (which points to the last padding).
+ int4 mult_y = min((int4)dst_size - offset, (int4)1); // If out of bound, we don't want to increase dst_stride_y, so we set the multiplier to 0. It will be 1 otherwise.
+
+ // Store the output tile (mult_y is 0 for clamped rows, so all four stores of such a row hit the same padding address)
+ *(__global float *)(dst_ptr + mult_y.s0 * 0 * (int)dst_stride_y + offset.s0) = out_col0.s0;
+ *(__global float *)(dst_ptr + mult_y.s0 * 1 * (int)dst_stride_y + offset.s0) = out_col1.s0;
+ *(__global float *)(dst_ptr + mult_y.s0 * 2 * (int)dst_stride_y + offset.s0) = out_col2.s0;
+ *(__global float *)(dst_ptr + mult_y.s0 * 3 * (int)dst_stride_y + offset.s0) = out_col3.s0;
+ *(__global float *)(dst_ptr + mult_y.s1 * 0 * (int)dst_stride_y + offset.s1) = out_col0.s1;
+ *(__global float *)(dst_ptr + mult_y.s1 * 1 * (int)dst_stride_y + offset.s1) = out_col1.s1;
+ *(__global float *)(dst_ptr + mult_y.s1 * 2 * (int)dst_stride_y + offset.s1) = out_col2.s1;
+ *(__global float *)(dst_ptr + mult_y.s1 * 3 * (int)dst_stride_y + offset.s1) = out_col3.s1;
+ *(__global float *)(dst_ptr + mult_y.s2 * 0 * (int)dst_stride_y + offset.s2) = out_col0.s2;
+ *(__global float *)(dst_ptr + mult_y.s2 * 1 * (int)dst_stride_y + offset.s2) = out_col1.s2;
+ *(__global float *)(dst_ptr + mult_y.s2 * 2 * (int)dst_stride_y + offset.s2) = out_col2.s2;
+ *(__global float *)(dst_ptr + mult_y.s2 * 3 * (int)dst_stride_y + offset.s2) = out_col3.s2;
+ *(__global float *)(dst_ptr + mult_y.s3 * 0 * (int)dst_stride_y + offset.s3) = out_col0.s3;
+ *(__global float *)(dst_ptr + mult_y.s3 * 1 * (int)dst_stride_y + offset.s3) = out_col1.s3;
+ *(__global float *)(dst_ptr + mult_y.s3 * 2 * (int)dst_stride_y + offset.s3) = out_col2.s3;
+ *(__global float *)(dst_ptr + mult_y.s3 * 3 * (int)dst_stride_y + offset.s3) = out_col3.s3;
+#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+}
+
+#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
+/** This OpenCL kernel performs Winograd output transform when the output tile is 2x1, the filter size 3x1 and the data layout is NCHW
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
+ * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note -DHAS_BIAS can be passed at compile time to add the bias vector arguments (bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes); all arguments are forwarded unchanged to winograd_output_transform_2x2_3x3_nchw
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_2x1_3x1_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst)
+#if defined(HAS_BIAS)
+ ,
+ VECTOR_DECLARATION(bias)
+#endif // defined(HAS_BIAS)
+)
+{
+ winograd_output_transform_2x2_3x3_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes
+#if defined(HAS_BIAS)
+ ,
+ bias_ptr,
+ bias_stride_x,
+ bias_step_x,
+ bias_offset_first_element_in_bytes
+#endif // defined(HAS_BIAS)
+ );
+}
+
+/** This OpenCL kernel performs Winograd output transform when the output tile is 4x1, the filter size 3x1 and the data layout is NCHW
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
+ * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note -DHAS_BIAS can be passed at compile time to add the bias vector arguments (bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes); all arguments are forwarded unchanged to winograd_output_transform_4x4_3x3_nchw
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_4x1_3x1_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst)
+#if defined(HAS_BIAS)
+ ,
+ VECTOR_DECLARATION(bias)
+#endif // defined(HAS_BIAS)
+)
+{
+ winograd_output_transform_4x4_3x3_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes
+#if defined(HAS_BIAS)
+ ,
+ bias_ptr,
+ bias_stride_x,
+ bias_step_x,
+ bias_offset_first_element_in_bytes
+#endif // defined(HAS_BIAS)
+ );
+}
+
+/** This OpenCL kernel performs Winograd output transform when the output tile is 4x1, the filter size 5x1 and the data layout is NCHW
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
+ * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note -DHAS_BIAS can be passed at compile time to add the bias vector arguments (bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes); all arguments are forwarded unchanged to winograd_output_transform_4x4_5x5_nchw
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_4x1_5x1_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst)
+#if defined(HAS_BIAS)
+ ,
+ VECTOR_DECLARATION(bias)
+#endif // defined(HAS_BIAS)
+)
+{
+ winograd_output_transform_4x4_5x5_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes
+#if defined(HAS_BIAS)
+ ,
+ bias_ptr,
+ bias_stride_x,
+ bias_step_x,
+ bias_offset_first_element_in_bytes
+#endif // defined(HAS_BIAS)
+ );
+}
+
+/** This OpenCL kernel performs Winograd output transform when the output tile is 4x1, the filter size 3x1 and the data layout is NHWC
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
+ * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note -DHAS_BIAS can be passed at compile time to add the bias vector arguments (bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes) before dst_size; all arguments, including the trailing int dst_size, are forwarded unchanged to winograd_output_transform_4x4_3x3_nhwc
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_4x1_3x1_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst),
+#if defined(HAS_BIAS)
+ VECTOR_DECLARATION(bias),
+#endif // defined(HAS_BIAS)
+ int dst_size)
+{
+ winograd_output_transform_4x4_3x3_nhwc(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes,
+#if defined(HAS_BIAS)
+ bias_ptr,
+ bias_stride_x,
+ bias_step_x,
+ bias_offset_first_element_in_bytes,
+#endif // defined(HAS_BIAS)
+ dst_size);
+}
+
+/** This OpenCL kernel performs Winograd output transform when the output tile is 4x1, the filter size 5x1 and the data layout is NHWC
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
+ * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note -DHAS_BIAS can be passed at compile time to add the bias vector arguments (bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes) before dst_size; all arguments, including the trailing int dst_size, are forwarded unchanged to winograd_output_transform_4x4_5x5_nhwc
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_4x1_5x1_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst),
+#if defined(HAS_BIAS)
+ VECTOR_DECLARATION(bias),
+#endif // defined(HAS_BIAS)
+ int dst_size)
+{
+ winograd_output_transform_4x4_5x5_nhwc(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes,
+#if defined(HAS_BIAS)
+ bias_ptr,
+ bias_stride_x,
+ bias_step_x,
+ bias_offset_first_element_in_bytes,
+#endif // defined(HAS_BIAS)
+ dst_size);
+}
+#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
+
+#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+/** This OpenCL kernel performs Winograd output transform when the output tile is 1x2, the filter size 1x3 and the data layout is NCHW
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
+ * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ * @note -DHAS_BIAS can be passed at compile time to add the bias vector arguments (bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes); all arguments are forwarded unchanged to winograd_output_transform_2x2_3x3_nchw
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_1x2_1x3_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst)
+#if defined(HAS_BIAS)
+ ,
+ VECTOR_DECLARATION(bias)
+#endif // defined(HAS_BIAS)
+)
+{
+ winograd_output_transform_2x2_3x3_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes
+#if defined(HAS_BIAS)
+ ,
+ bias_ptr,
+ bias_stride_x,
+ bias_step_x,
+ bias_offset_first_element_in_bytes
+#endif // defined(HAS_BIAS)
+ );
+}
+
+/** This OpenCL kernel performs Winograd output transform when the output tile is 1x4, the filter size 1x3 and the data layout is NCHW
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ * @note -DHAS_BIAS can be passed at compile time to add the bias vector arguments (bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes); all arguments are forwarded unchanged to winograd_output_transform_4x4_3x3_nchw
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_1x4_1x3_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst)
+#if defined(HAS_BIAS)
+ ,
+ VECTOR_DECLARATION(bias)
+#endif // defined(HAS_BIAS)
+)
+{
+ winograd_output_transform_4x4_3x3_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes
+#if defined(HAS_BIAS)
+ ,
+ bias_ptr,
+ bias_stride_x,
+ bias_step_x,
+ bias_offset_first_element_in_bytes
+#endif // defined(HAS_BIAS)
+ );
+}
+
+/** This OpenCL kernel performs Winograd output transform when the output tile is 1x4, the filter size 1x5 and the data layout is NCHW
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ * @note -DHAS_BIAS can be passed at compile time to add the bias vector arguments (bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes); all arguments are forwarded unchanged to winograd_output_transform_4x4_5x5_nchw
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_1x4_1x5_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst)
+#if defined(HAS_BIAS)
+ ,
+ VECTOR_DECLARATION(bias)
+#endif // defined(HAS_BIAS)
+)
+{
+ winograd_output_transform_4x4_5x5_nchw(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes
+#if defined(HAS_BIAS)
+ ,
+ bias_ptr,
+ bias_stride_x,
+ bias_step_x,
+ bias_offset_first_element_in_bytes
+#endif // defined(HAS_BIAS)
+ );
+}
+
+/** This OpenCL kernel performs Winograd output transform when the output tile is 1x4, the filter size 1x3 and the data layout is NHWC
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ * @note -DHAS_BIAS can be passed at compile time to add the bias vector arguments (bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes) before dst_size; all arguments, including the trailing int dst_size, are forwarded unchanged to winograd_output_transform_4x4_3x3_nhwc
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_1x4_1x3_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst),
+#if defined(HAS_BIAS)
+ VECTOR_DECLARATION(bias),
+#endif // defined(HAS_BIAS)
+ int dst_size)
+{
+ winograd_output_transform_4x4_3x3_nhwc(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes,
+#if defined(HAS_BIAS)
+ bias_ptr,
+ bias_stride_x,
+ bias_step_x,
+ bias_offset_first_element_in_bytes,
+#endif // defined(HAS_BIAS)
+ dst_size);
+}
+
+/** This OpenCL kernel performs Winograd output transform when the output tile is 1x4, the filter size 1x5 and the data layout is NHWC
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
+ * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
+ * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
+ * @note -DHAS_BIAS can be passed at compile time to add the bias vector arguments (bias_ptr, bias_stride_x, bias_step_x, bias_offset_first_element_in_bytes) before dst_size; all arguments, including the trailing int dst_size, are forwarded unchanged to winograd_output_transform_4x4_5x5_nhwc
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_1x4_1x5_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst),
+#if defined(HAS_BIAS)
+ VECTOR_DECLARATION(bias),
+#endif // defined(HAS_BIAS)
+ int dst_size)
+{
+ winograd_output_transform_4x4_5x5_nhwc(src_ptr,
+ src_stride_x,
+ src_step_x,
+ src_stride_y,
+ src_step_y,
+ src_stride_z,
+ src_step_z,
+ src_offset_first_element_in_bytes,
+ dst_ptr,
+ dst_stride_x,
+ dst_step_x,
+ dst_stride_y,
+ dst_step_y,
+ dst_stride_z,
+ dst_step_z,
+ dst_offset_first_element_in_bytes,
+#if defined(HAS_BIAS)
+ bias_ptr,
+ bias_stride_x,
+ bias_step_x,
+ bias_offset_first_element_in_bytes,
+#endif // defined(HAS_BIAS)
+ dst_size);
+}
+#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
+#endif // defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
diff --git a/src/core/CL/kernels/CLAbsoluteDifferenceKernel.cpp b/src/core/CL/kernels/CLAbsoluteDifferenceKernel.cpp
index 685b8e2..0c1206a 100644
--- a/src/core/CL/kernels/CLAbsoluteDifferenceKernel.cpp
+++ b/src/core/CL/kernels/CLAbsoluteDifferenceKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -81,7 +81,7 @@
output_access.set_valid_region(win, valid_region);
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLAbsoluteDifferenceKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLActivationLayerKernel.cpp b/src/core/CL/kernels/CLActivationLayerKernel.cpp
index a78b3e1..a15e99b 100644
--- a/src/core/CL/kernels/CLActivationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLActivationLayerKernel.cpp
@@ -25,13 +25,12 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
-#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/CL/CLHelpers.h"
@@ -46,18 +45,19 @@
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ActivationLayerInfo &act_info)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG((input->data_type() == DataType::QASYMM8) && (act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU)
&& (act_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU)
- && (act_info.activation() != ActivationLayerInfo::ActivationFunction::RELU),
- "For QASYMM8 only relu, lower bounded relu and lower-upper bounded relu are supported");
+ && (act_info.activation() != ActivationLayerInfo::ActivationFunction::RELU)
+ && (act_info.activation() != ActivationLayerInfo::ActivationFunction::LOGISTIC),
+ "For QASYMM8 only logistic, relu, lower bounded relu and lower-upper bounded relu are supported");
// Checks performed when output is configured
if((output != nullptr) && (output->total_size() != 0))
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
@@ -117,7 +117,6 @@
const unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size();
const DataType dt = input->info()->data_type();
- const int fixed_point_position = input->info()->fixed_point_position();
float a_const = act_info.a();
float b_const = act_info.b();
int a_const_int = 0;
@@ -126,16 +125,8 @@
// Create quantized version of constants a, b if needed
if(is_data_type_quantized(dt))
{
- if(is_data_type_fixed_point(dt))
- {
- a_const_int = static_cast<int>(lround(a_const * (1 << fixed_point_position)));
- b_const_int = static_cast<int>(lround(b_const * (1 << fixed_point_position)));
- }
- else
- {
- a_const_int = input->info()->quantization_info().quantize(a_const, RoundingPolicy::TO_NEAREST_UP);
- b_const_int = input->info()->quantization_info().quantize(b_const, RoundingPolicy::TO_NEAREST_UP);
- }
+ a_const_int = input->info()->quantization_info().quantize(a_const, RoundingPolicy::TO_NEAREST_UP);
+ b_const_int = input->info()->quantization_info().quantize(b_const, RoundingPolicy::TO_NEAREST_UP);
}
// Set build options
@@ -149,22 +140,22 @@
build_opts.emplace(("-DA_VAL=" + support::cpp11::to_string(a_const_int)));
build_opts.emplace(("-DB_VAL=" + support::cpp11::to_string(b_const_int)));
- const int o1 = input->info()->quantization_info().offset;
+ const int o1 = input->info()->quantization_info().offset;
+ const float s1 = input->info()->quantization_info().scale;
// Quantized value of 0 corresponds to the offset o1
build_opts.emplace(("-DCONST_0=" + support::cpp11::to_string(o1)));
+ build_opts.emplace(("-DS1_VAL=" + float_to_string_with_full_precision(s1)));
+ build_opts.emplace(("-DO1_VAL=" + support::cpp11::to_string(o1)));
// Set scale and offset of the input and output if they have different quantization info
if(is_data_type_quantized_asymmetric(dt) && output != nullptr)
{
- const float s1 = input->info()->quantization_info().scale;
const float s2 = output->info()->quantization_info().scale;
const int o2 = output->info()->quantization_info().offset;
if(o1 != o2 || s1 != s2)
{
- build_opts.emplace(("-DS1_VAL=" + float_to_string_with_full_precision(s1)));
build_opts.emplace(("-DS2_VAL=" + float_to_string_with_full_precision(s2)));
- build_opts.emplace(("-DO1_VAL=" + support::cpp11::to_string(o1)));
build_opts.emplace(("-DO2_VAL=" + support::cpp11::to_string(o2)));
}
}
@@ -176,10 +167,6 @@
}
build_opts.emplace((_run_in_place) ? "-DIN_PLACE" : "");
- if(is_data_type_fixed_point(dt))
- {
- build_opts.emplace(("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(fixed_point_position)));
- }
// Create kernel
std::string kernel_name = is_data_type_quantized_asymmetric(dt) ? std::string("activation_layer_qa8") : std::string("activation_layer");
@@ -192,7 +179,7 @@
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), (_run_in_place) ? nullptr : output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
// Set config_id for enabling LWS tuning
_config_id = "activation_layer_";
@@ -228,7 +215,7 @@
{
add_3D_tensor_argument(idx, _output, slice);
}
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(collapsed.slide_window_slice_3D(slice));
}
diff --git a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp
index c4904ec..2372d45 100644
--- a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp
+++ b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp
@@ -24,6 +24,7 @@
#include "arm_compute/core/CL/kernels/CLArithmeticAdditionKernel.h"
#include "arm_compute/core/CL/CLHelpers.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
using namespace arm_compute;
@@ -35,24 +36,34 @@
Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output, ConvertPolicy policy)
{
ARM_COMPUTE_UNUSED(policy);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&input2);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ const bool is_qasymm = is_data_type_quantized_asymmetric(input1.data_type()) || is_data_type_quantized_asymmetric(input2.data_type());
+ if(is_qasymm)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &input2);
+ }
const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(&input1, &input2);
// Validate in case of configured output
if(output.total_size() > 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG((output.data_type() == DataType::U8) && ((input1.data_type() != DataType::U8) || (input2.data_type() != DataType::U8)),
"Output can only be U8 if both inputs are U8");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
"Wrong shape for output");
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(&input1, &output);
+ if(is_qasymm)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &output);
+ }
}
return Status{};
@@ -121,21 +132,29 @@
const bool has_float_out = is_data_type_float(output->info()->data_type());
+ std::string kernel_name = "arithmetic_add";
+
// Set kernel build options
std::set<std::string> build_opts;
build_opts.emplace((policy == ConvertPolicy::WRAP || has_float_out) ? "-DWRAP" : "-DSATURATE");
build_opts.emplace("-DDATA_TYPE_IN1=" + get_cl_type_from_data_type(input1->info()->data_type()));
build_opts.emplace("-DDATA_TYPE_IN2=" + get_cl_type_from_data_type(input2->info()->data_type()));
build_opts.emplace("-DDATA_TYPE_OUT=" + get_cl_type_from_data_type(output->info()->data_type()));
- if(is_data_type_fixed_point(input1->info()->data_type()))
+ if(is_data_type_quantized_asymmetric(input1->info()->data_type()))
{
- build_opts.emplace("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input1->info()->fixed_point_position()));
+ build_opts.emplace("-DOFFSET_IN1=" + support::cpp11::to_string(input1->info()->quantization_info().offset));
+ build_opts.emplace("-DOFFSET_IN2=" + support::cpp11::to_string(input2->info()->quantization_info().offset));
+ build_opts.emplace("-DOFFSET_OUT=" + support::cpp11::to_string(output->info()->quantization_info().offset));
+ build_opts.emplace("-DSCALE_IN1=" + float_to_string_with_full_precision(input1->info()->quantization_info().scale)); // full precision: std::to_string keeps only 6 decimals and turns small scales into 0
+ build_opts.emplace("-DSCALE_IN2=" + float_to_string_with_full_precision(input2->info()->quantization_info().scale));
+ build_opts.emplace("-DSCALE_OUT=" + float_to_string_with_full_precision(output->info()->quantization_info().scale));
+ kernel_name += "_quantized";
}
// Create kernel
- _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("arithmetic_add", build_opts));
+ _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts));
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
}
Status CLArithmeticAdditionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
diff --git a/src/core/CL/kernels/CLArithmeticDivisionKernel.cpp b/src/core/CL/kernels/CLArithmeticDivisionKernel.cpp
new file mode 100644
index 0000000..e995ba1
--- /dev/null
+++ b/src/core/CL/kernels/CLArithmeticDivisionKernel.cpp
@@ -0,0 +1,185 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/CL/kernels/CLArithmeticDivisionKernel.h"
+
+#include "arm_compute/core/CL/CLHelpers.h"
+#include "arm_compute/core/CL/CLValidate.h"
+#include "arm_compute/core/CL/ICLTensor.h"
+
+using namespace arm_compute;
+
+namespace
+{
+constexpr unsigned int num_elems_processed_per_iteration = 16;
+
+Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
+
+ const TensorShape out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
+
+ // Validate in case of configured output
+ if(output->total_size() > 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0),
+ "Wrong shape for output");
+ }
+
+ return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
+{
+ const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
+ const TensorShape &out_shape = broadcast_pair.first;
+ const ValidRegion &valid_region = broadcast_pair.second;
+
+ // Auto initialize output if not initialized
+ {
+ set_shape_if_empty(*output, out_shape);
+
+ if(input1->data_type() == DataType::F16 && input2->data_type() == DataType::F16)
+ {
+ set_format_if_unknown(*output, Format::F16);
+ }
+ else if(input1->data_type() == DataType::F32 || input2->data_type() == DataType::F32)
+ {
+ set_format_if_unknown(*output, Format::F32);
+ }
+ }
+
+ Window win = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration));
+ Window win_input1 = win.broadcast_if_dimension_le_one(*input1);
+ Window win_input2 = win.broadcast_if_dimension_le_one(*input2);
+
+ AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration);
+ AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration);
+ AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+
+ bool window_changed = update_window_and_padding(win_input1, input1_access)
+ || update_window_and_padding(win_input2, input2_access)
+ || update_window_and_padding(win, output_access);
+
+ output_access.set_valid_region(win, valid_region);
+
+ Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, win);
+}
+} // namespace
+
+CLArithmeticDivisionKernel::CLArithmeticDivisionKernel()
+ : _input1(nullptr), _input2(nullptr), _output(nullptr)
+{
+}
+
+void CLArithmeticDivisionKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info()));
+
+ // Configure kernel window
+ auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+
+ _input1 = input1;
+ _input2 = input2;
+ _output = output;
+
+ // Set kernel build options
+ std::set<std::string> build_opts;
+ build_opts.emplace("-DDATA_TYPE_IN1=" + get_cl_type_from_data_type(input1->info()->data_type()));
+ build_opts.emplace("-DDATA_TYPE_IN2=" + get_cl_type_from_data_type(input2->info()->data_type()));
+ build_opts.emplace("-DDATA_TYPE_OUT=" + get_cl_type_from_data_type(output->info()->data_type()));
+
+ // Create kernel
+ _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("arithmetic_div", build_opts));
+
+ ICLKernel::configure_internal(win_config.second);
+}
+
+Status CLArithmeticDivisionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first);
+
+ return Status{};
+}
+
+void CLArithmeticDivisionKernel::run(const Window &window, cl::CommandQueue &queue)
+{
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
+
+ const TensorShape &in_shape1 = _input1->info()->tensor_shape();
+ const TensorShape &in_shape2 = _input2->info()->tensor_shape();
+ const TensorShape &out_shape = _output->info()->tensor_shape();
+
+ bool can_collapse = true;
+ if(std::min(in_shape1.total_size(), in_shape2.total_size()) > 1)
+ {
+ can_collapse = (std::min(in_shape1.num_dimensions(), in_shape2.num_dimensions()) > Window::DimZ);
+ for(size_t d = Window::DimZ; can_collapse && (d < out_shape.num_dimensions()); d++)
+ {
+ can_collapse = (in_shape1[d] == in_shape2[d]);
+ }
+ }
+
+ bool has_collapsed = false;
+ Window collapsed = can_collapse ? window.collapse_if_possible(ICLKernel::window(), Window::DimZ, &has_collapsed) : window;
+
+ const TensorShape &in_shape1_collapsed = has_collapsed ? in_shape1.collapsed_from(Window::DimZ) : in_shape1;
+ const TensorShape &in_shape2_collapsed = has_collapsed ? in_shape2.collapsed_from(Window::DimZ) : in_shape2;
+
+ Window slice = collapsed.first_slice_window_3D();
+ Window slice_input1 = slice.broadcast_if_dimension_le_one(in_shape1_collapsed);
+ Window slice_input2 = slice.broadcast_if_dimension_le_one(in_shape2_collapsed);
+
+ do
+ {
+ unsigned int idx = 0;
+
+ add_3D_tensor_argument(idx, _input1, slice_input1);
+ add_3D_tensor_argument(idx, _input2, slice_input2);
+ add_3D_tensor_argument(idx, _output, slice);
+
+ enqueue(queue, *this, slice);
+
+ collapsed.slide_window_slice_3D(slice_input1);
+ collapsed.slide_window_slice_3D(slice_input2);
+ }
+ while(collapsed.slide_window_slice_3D(slice));
+}
+
+BorderSize CLArithmeticDivisionKernel::border_size() const
+{
+ const unsigned int replicateSize = _output->info()->dimension(0) - std::min(_input1->info()->dimension(0), _input2->info()->dimension(0));
+ const unsigned int border = std::min<unsigned int>(num_elems_processed_per_iteration - 1U, replicateSize);
+ return BorderSize(0, border, 0, 0);
+}
diff --git a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp
index 8308aa0..299ac55 100644
--- a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp
+++ b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,12 +25,12 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/TensorInfo.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include <set>
@@ -43,19 +43,20 @@
Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
{
ARM_COMPUTE_UNUSED(policy);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input2);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2);
// Validate in case of configured output
if((output != nullptr) && (output->total_size() != 0))
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
"Output can only be U8 if both inputs are U8");
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output);
}
return Status{};
@@ -119,10 +120,6 @@
build_opts.emplace("-DDATA_TYPE_IN1=" + get_cl_type_from_data_type(input1->info()->data_type()));
build_opts.emplace("-DDATA_TYPE_IN2=" + get_cl_type_from_data_type(input2->info()->data_type()));
build_opts.emplace("-DDATA_TYPE_OUT=" + get_cl_type_from_data_type(output->info()->data_type()));
- if(is_data_type_fixed_point(input1->info()->data_type()))
- {
- build_opts.emplace("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input1->info()->fixed_point_position()));
- }
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("arithmetic_sub", build_opts));
@@ -130,7 +127,7 @@
// Configure kernel window
auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
}
Status CLArithmeticSubtractionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
diff --git a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
index 293361b..d4a7207 100644
--- a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
@@ -25,12 +25,11 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
-#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "support/ToolchainSupport.h"
@@ -45,29 +44,28 @@
float epsilon, ActivationLayerInfo act_info)
{
ARM_COMPUTE_UNUSED(epsilon);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var);
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)) != mean->dimension(0));
if(beta != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, beta);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, beta);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, beta);
}
if(gamma != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, gamma);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, gamma);
}
if(act_info.enabled())
{
ActivationLayerInfo::ActivationFunction act = act_info.activation();
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32 && input->data_type() != DataType::F16);
- ARM_COMPUTE_RETURN_ERROR_ON(act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::RELU && act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU
+ ARM_COMPUTE_RETURN_ERROR_ON(act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::RELU
+ && act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU
&& act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
ARM_COMPUTE_RETURN_ERROR_ON(act_info.b() > act_info.a());
}
@@ -77,7 +75,6 @@
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
@@ -167,7 +164,6 @@
build_opts.add_option_if(act_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(act_info.a()));
build_opts.add_option_if(act_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(act_info.b()));
build_opts.add_option_if(_run_in_place, "-DIN_PLACE");
- build_opts.add_option_if(is_data_type_fixed_point(input->info()->data_type()), "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position()));
build_opts.add_option_if(beta == nullptr, "-DUSE_DEFAULT_BETA");
build_opts.add_option_if(gamma == nullptr, "-DUSE_DEFAULT_GAMMA");
@@ -193,7 +189,7 @@
(beta != nullptr) ? beta->info() : nullptr,
(gamma != nullptr) ? gamma->info() : nullptr);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
_config_id = "batch_normalization_layer_";
_config_id += string_from_data_layout(input->info()->data_layout());
@@ -205,6 +201,8 @@
_config_id += support::cpp11::to_string(input->info()->dimension(1));
_config_id += "_";
_config_id += support::cpp11::to_string(input->info()->dimension(2));
+ _config_id += "_";
+ _config_id += lower_string(string_from_data_layout(input->info()->data_layout()));
}
Status CLBatchNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output,
@@ -254,7 +252,7 @@
{
add_3D_tensor_argument(idx, _output, slice);
}
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_3D(slice));
}
diff --git a/src/core/CL/kernels/CLBitwiseAndKernel.cpp b/src/core/CL/kernels/CLBitwiseAndKernel.cpp
index 5ea4a86..dd301cd 100644
--- a/src/core/CL/kernels/CLBitwiseAndKernel.cpp
+++ b/src/core/CL/kernels/CLBitwiseAndKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -66,7 +66,7 @@
output_access.set_valid_region(win, valid_region);
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLBitwiseAndKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLBitwiseOrKernel.cpp b/src/core/CL/kernels/CLBitwiseOrKernel.cpp
index 2eeef0a..aa84618 100644
--- a/src/core/CL/kernels/CLBitwiseOrKernel.cpp
+++ b/src/core/CL/kernels/CLBitwiseOrKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -67,7 +67,7 @@
output_access.set_valid_region(win, valid_region);
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLBitwiseOrKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLBitwiseXorKernel.cpp b/src/core/CL/kernels/CLBitwiseXorKernel.cpp
index c19a78e..ad1f923 100644
--- a/src/core/CL/kernels/CLBitwiseXorKernel.cpp
+++ b/src/core/CL/kernels/CLBitwiseXorKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -67,7 +67,7 @@
output_access.set_valid_region(win, valid_region);
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLBitwiseXorKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLBox3x3Kernel.cpp b/src/core/CL/kernels/CLBox3x3Kernel.cpp
index 0299f62..b81697f 100644
--- a/src/core/CL/kernels/CLBox3x3Kernel.cpp
+++ b/src/core/CL/kernels/CLBox3x3Kernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -73,5 +73,5 @@
output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
diff --git a/src/core/CL/kernels/CLCannyEdgeKernel.cpp b/src/core/CL/kernels/CLCannyEdgeKernel.cpp
index 5d06d34..94e5e23 100644
--- a/src/core/CL/kernels/CLCannyEdgeKernel.cpp
+++ b/src/core/CL/kernels/CLCannyEdgeKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -77,7 +77,7 @@
mag_access.set_valid_region(win, _gx->info()->valid_region());
phase_access.set_valid_region(win, _gx->info()->valid_region());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLGradientKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -145,7 +145,7 @@
output_access.set_valid_region(win, _magnitude->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLEdgeNonMaxSuppressionKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -230,7 +230,7 @@
l1_stack_access.set_valid_region(win, _input->info()->valid_region());
l1_stack_counter_access.set_valid_region(win, _input->info()->valid_region());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLEdgeTraceKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLChannelCombineKernel.cpp b/src/core/CL/kernels/CLChannelCombineKernel.cpp
index 6e55e66..c7b1da4 100644
--- a/src/core/CL/kernels/CLChannelCombineKernel.cpp
+++ b/src/core/CL/kernels/CLChannelCombineKernel.cpp
@@ -128,7 +128,7 @@
}
output_access.set_valid_region(win, ValidRegion(valid_region.anchor, output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLChannelCombineKernel::configure(const ICLImage *plane0, const ICLImage *plane1, const ICLImage *plane2, ICLMultiImage *output)
@@ -232,7 +232,7 @@
output_plane1_access.set_valid_region(win, ValidRegion(output_plane1_region.anchor, output->plane(1)->info()->tensor_shape()));
output_plane2_access.set_valid_region(win, ValidRegion(plane2->info()->valid_region().anchor, output->plane(2)->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLChannelCombineKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLChannelExtractKernel.cpp b/src/core/CL/kernels/CLChannelExtractKernel.cpp
index 65843b8..8bddba8 100644
--- a/src/core/CL/kernels/CLChannelExtractKernel.cpp
+++ b/src/core/CL/kernels/CLChannelExtractKernel.cpp
@@ -101,7 +101,7 @@
ValidRegion input_valid_region = input->info()->valid_region();
output_access.set_valid_region(win, ValidRegion(input_valid_region.anchor, output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLChannelExtractKernel::configure(const ICLMultiImage *input, Channel channel, ICLImage *output)
@@ -162,7 +162,7 @@
output_access.set_valid_region(win, input_plane->info()->valid_region());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLChannelExtractKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLChannelShuffleLayerKernel.cpp b/src/core/CL/kernels/CLChannelShuffleLayerKernel.cpp
index a667119..be4d687 100644
--- a/src/core/CL/kernels/CLChannelShuffleLayerKernel.cpp
+++ b/src/core/CL/kernels/CLChannelShuffleLayerKernel.cpp
@@ -25,11 +25,11 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
namespace arm_compute
@@ -38,8 +38,9 @@
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int num_groups)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8,
- DataType::U16, DataType::S16, DataType::QS16,
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8,
+ DataType::U16, DataType::S16,
DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups < 2, "Channel shuffling with less than 2 groups would be inefficient");
@@ -124,7 +125,7 @@
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
}
Status CLChannelShuffleLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int num_groups)
diff --git a/src/core/CL/kernels/CLCol2ImKernel.cpp b/src/core/CL/kernels/CLCol2ImKernel.cpp
index 91c0430..40032f9 100644
--- a/src/core/CL/kernels/CLCol2ImKernel.cpp
+++ b/src/core/CL/kernels/CLCol2ImKernel.cpp
@@ -25,12 +25,12 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include <cmath>
@@ -40,30 +40,31 @@
namespace
{
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, std::pair<unsigned int, unsigned int> convolved_dims)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, std::pair<unsigned int, unsigned int> convolved_dims, unsigned int num_groups)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
// Checks performed when output is configured
if(output->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_col2im_shape(*input, convolved_dims));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_col2im_shape(*input, convolved_dims, num_groups));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_layout() != DataLayout::NCHW, "Col2Im output's data layout must always be NCHW");
}
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, std::pair<unsigned int, unsigned int> convolved_dims)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, std::pair<unsigned int, unsigned int> convolved_dims, unsigned int num_groups)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output, input->clone()->set_tensor_shape(compute_col2im_shape(*input, convolved_dims)));
+ auto_init_if_empty(*output, input->clone()->set_tensor_shape(compute_col2im_shape(*input, convolved_dims, num_groups)).set_data_layout(DataLayout::NCHW));
- const unsigned int num_elems_read_per_iteration = is_data_type_fixed_point(input->data_type()) ? 1 : 8;
+ const unsigned int num_elems_read_per_iteration = 8;
// Configure window
Window win = calculate_max_window(*input, Steps(num_elems_read_per_iteration));
@@ -86,12 +87,12 @@
{
}
-void CLCol2ImKernel::configure(const ICLTensor *input, ICLTensor *output, std::pair<unsigned int, unsigned int> convolved_dims)
+void CLCol2ImKernel::configure(const ICLTensor *input, ICLTensor *output, std::pair<unsigned int, unsigned int> convolved_dims, unsigned int num_groups)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Perform validation step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), convolved_dims));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), convolved_dims, num_groups));
_input = input;
_output = output;
@@ -105,34 +106,21 @@
build_opts.add_option("-DELEMENT_SIZE=" + support::cpp11::to_string(input->info()->element_size()));
build_opts.add_option("-DWIDTH_INPUT=" + support::cpp11::to_string(input->info()->dimension(0)));
build_opts.add_option("-DWIDTH_OUTPUT=" + support::cpp11::to_string(_convolved_dims.first));
- build_opts.add_option_if(is_data_type_fixed_point(data_type), "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position()));
+ build_opts.add_option_if(num_groups > 1, "-DGROUPING");
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("col2im", build_opts.options()));
- // Configure the local work size for Bifrost with a value obtained
- // via exhaustive autotuning over 30 representative tensor shapes.
- const GPUTarget gpu_target = get_target();
- if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
- {
- if((_convolved_dims.first == 7) || (_convolved_dims.first == 14))
- {
- _lws_hint = cl::NDRange(1, 7, 1);
- }
- else
- {
- _lws_hint = cl::NDRange(1, 8, 1);
- }
- }
-
// Configure kernel window
- auto win_config = validate_and_configure_window(input->info(), output->info(), _convolved_dims);
+ auto win_config = validate_and_configure_window(input->info(), output->info(), _convolved_dims, num_groups);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
// Set config_id for enabling LWS tuning
_config_id = "col2im_";
_config_id += lower_string(string_from_data_type(input->info()->data_type()));
_config_id += "_";
+ _config_id += support::cpp11::to_string(num_groups);
+ _config_id += "_";
_config_id += support::cpp11::to_string(input->info()->dimension(0));
_config_id += "_";
_config_id += support::cpp11::to_string(input->info()->dimension(1));
@@ -142,11 +130,11 @@
_config_id += support::cpp11::to_string(output->info()->dimension(1));
}
-Status CLCol2ImKernel::validate(const ITensorInfo *input, const ITensorInfo *output, std::pair<unsigned int, unsigned int> convolved_dims)
+Status CLCol2ImKernel::validate(const ITensorInfo *input, const ITensorInfo *output, std::pair<unsigned int, unsigned int> convolved_dims, unsigned int num_groups)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, convolved_dims));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), convolved_dims).first);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, convolved_dims, num_groups));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), convolved_dims, num_groups).first);
return Status{};
}
@@ -154,13 +142,13 @@
{
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(ICLKernel::window(), window);
- // The collapse method rely on the assumption that the third dimension of input buffer is 1
- ARM_COMPUTE_ERROR_ON(window.z().end() != 1);
- Window collapsed_window = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
- Window slice = collapsed_window.first_slice_window_3D();
+ Window out_window;
+ out_window.use_tensor_dimensions(_output->info()->tensor_shape());
- // Set static kernel arguments
+ Window slice = window.first_slice_window_3D();
+ Window slice_out = out_window.first_slice_window_3D();
+
unsigned int idx = 2 * num_arguments_per_3D_tensor();
_kernel.setArg<cl_uint>(idx++, _output->info()->strides_in_bytes()[3]);
@@ -169,8 +157,8 @@
// Set inputs
unsigned int idx = 0;
add_3D_tensor_argument(idx, _input, slice);
- add_3D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice, _lws_hint);
+ add_3D_tensor_argument(idx, _output, slice_out);
+ enqueue(queue, *this, slice, lws_hint());
}
- while(collapsed_window.slide_window_slice_3D(slice));
+ while(window.slide_window_slice_3D(slice) && out_window.slide_window_slice_3D(slice_out));
}
diff --git a/src/core/CL/kernels/CLColorConvertKernel.cpp b/src/core/CL/kernels/CLColorConvertKernel.cpp
index ead2b8f..e79019e 100644
--- a/src/core/CL/kernels/CLColorConvertKernel.cpp
+++ b/src/core/CL/kernels/CLColorConvertKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -120,7 +120,7 @@
output_access.set_valid_region(win, input->info()->valid_region());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLColorConvertKernel::configure(const ICLMultiImage *input, ICLImage *output)
@@ -189,7 +189,7 @@
input->plane(2)->info()->valid_region());
output_access.set_valid_region(win, ValidRegion(intersect_region.anchor, output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLColorConvertKernel::configure(const ICLImage *input, ICLMultiImage *output)
@@ -198,6 +198,7 @@
ARM_COMPUTE_ERROR_ON(output == nullptr);
unsigned int num_elems_processed_per_iteration = 0;
+ unsigned int num_elems_read_per_iteration_x = 0;
bool has_two_planes = (output->info()->format() == Format::NV12) || (output->info()->format() == Format::NV21);
float sub_sampling = (has_two_planes || (output->info()->format() == Format::IYUV)) ? 0.5f : 1;
@@ -212,9 +213,11 @@
case Format::NV12:
case Format::IYUV:
num_elems_processed_per_iteration = 2;
+ num_elems_read_per_iteration_x = 8;
break;
case Format::YUV444:
num_elems_processed_per_iteration = 4;
+ num_elems_read_per_iteration_x = 16;
break;
default:
break;
@@ -229,6 +232,7 @@
case Format::NV12:
case Format::IYUV:
num_elems_processed_per_iteration = 8;
+ num_elems_read_per_iteration_x = 8;
break;
default:
break;
@@ -238,6 +242,7 @@
default:
break;
}
+
ARM_COMPUTE_ERROR_ON_MSG(num_elems_processed_per_iteration == 0, "Conversion from %s to %s not supported",
string_from_format(input->info()->format()).c_str(),
string_from_format(output->info()->format()).c_str());
@@ -248,7 +253,6 @@
kernel_name << "_to_";
kernel_name << string_from_format(output->info()->format());
kernel_name << "_bt709";
-
_input = input;
_multi_output = output;
@@ -267,8 +271,10 @@
AccessWindowRectangle output_plane2_access(has_two_planes ? nullptr : output->plane(2)->info(), 0, 0,
num_elems_processed_per_iteration, 1, sub_sampling, sub_sampling);
+ AccessWindowHorizontal input_access(input->info(), 0, num_elems_read_per_iteration_x);
+
update_window_and_padding(win,
- AccessWindowHorizontal(input->info(), 0, num_elems_processed_per_iteration),
+ input_access,
output_plane0_access,
output_plane1_access,
output_plane2_access);
@@ -279,7 +285,7 @@
output_plane1_access.set_valid_region(win, ValidRegion(input_region.anchor, output->plane(1)->info()->tensor_shape()));
output_plane2_access.set_valid_region(win, ValidRegion(input_region.anchor, output->plane(2)->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLColorConvertKernel::configure(const ICLMultiImage *input, ICLMultiImage *output)
@@ -363,7 +369,7 @@
output_plane1_access.set_valid_region(win, ValidRegion(intersect_region.anchor, output->plane(1)->info()->tensor_shape()));
output_plane2_access.set_valid_region(win, ValidRegion(intersect_region.anchor, output->plane(2)->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLColorConvertKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.cpp b/src/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.cpp
index 1b211b0..ace3fd5 100644
--- a/src/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.cpp
+++ b/src/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.cpp
@@ -25,6 +25,7 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
@@ -40,47 +41,61 @@
DataLayout data_layout)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+
+ // Output tensor auto initialisation if not yet initialized
+ auto_init_if_empty(*output->info(), *input->info()->clone());
+
ARM_COMPUTE_ERROR_THROW_ON(CLConvertFullyConnectedWeightsKernel::validate(input->info(), output->info(), original_input_shape, data_layout));
_input = input;
_output = output;
- const unsigned int num_elems_per_input_plane = original_input_shape.x() * original_input_shape.y();
- const unsigned int num_channels = original_input_shape.z();
+ const DataLayout input_data_layout = (data_layout == DataLayout::NCHW) ? DataLayout::NHWC : DataLayout::NCHW;
+
+ const int width_idx = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::WIDTH);
+ const int height_idx = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::HEIGHT);
+ const int channel_idx = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::CHANNEL);
+
+ const unsigned int num_elems_per_input_plane = original_input_shape[width_idx] * original_input_shape[height_idx];
+ const unsigned int num_channels = original_input_shape[channel_idx];
+
+ const unsigned int factor_1 = (data_layout == DataLayout::NCHW) ? num_elems_per_input_plane : num_channels;
+ const unsigned int factor_2 = (data_layout == DataLayout::NCHW) ? num_channels : num_elems_per_input_plane;
// Set build options
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
- if(data_layout == DataLayout::NCHW)
- {
- build_opts.add_option("-DFACTOR_1=" + support::cpp11::to_string(num_elems_per_input_plane));
- build_opts.add_option("-DFACTOR_2=" + support::cpp11::to_string(num_channels));
- }
- else
- {
- build_opts.add_option("-DFACTOR_1=" + support::cpp11::to_string(num_channels));
- build_opts.add_option("-DFACTOR_2=" + support::cpp11::to_string(num_elems_per_input_plane));
- }
+ build_opts.add_option("-DFACTOR_1=" + support::cpp11::to_string(factor_1));
+ build_opts.add_option("-DFACTOR_2=" + support::cpp11::to_string(factor_2));
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("convert_fc_weights", build_opts.options()));
// Configure kernel window
Window win = calculate_max_window(*input->info(), Steps());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
Status CLConvertFullyConnectedWeightsKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape,
DataLayout data_layout)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8, DataType::U16, DataType::S16, DataType::QS16, DataType::U32, DataType::S32,
- DataType::QS32, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1,
+ DataType::U8, DataType::S8, DataType::QASYMM8,
+ DataType::U16, DataType::S16,
+ DataType::U32, DataType::S32,
+ DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() != 2);
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) != original_input_shape.total_size_lower(3));
ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::UNKNOWN);
+ // Checks performed when output is configured
+ if((output != nullptr) && (output->total_size() != 0))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ }
+
return Status{};
}
@@ -94,4 +109,4 @@
add_2D_tensor_argument(idx, _output, window);
enqueue(queue, *this, window);
}
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/core/CL/kernels/CLConvolutionKernel.cpp b/src/core/CL/kernels/CLConvolutionKernel.cpp
index 2b08c8d..e677793 100644
--- a/src/core/CL/kernels/CLConvolutionKernel.cpp
+++ b/src/core/CL/kernels/CLConvolutionKernel.cpp
@@ -105,7 +105,7 @@
output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
/****************************************************************************************\
@@ -167,7 +167,7 @@
output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
template <unsigned int matrix_size>
@@ -226,7 +226,7 @@
output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
/****************************************************************************************\
@@ -298,7 +298,7 @@
output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLConvolutionRectangleKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLCopyKernel.cpp b/src/core/CL/kernels/CLCopyKernel.cpp
index 4f00ef9..2da67d2 100644
--- a/src/core/CL/kernels/CLCopyKernel.cpp
+++ b/src/core/CL/kernels/CLCopyKernel.cpp
@@ -33,10 +33,44 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
-#include <algorithm>
-
using namespace arm_compute;
+namespace
+{
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+
+ // Validate output if initialized
+ if(output->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(input->tensor_shape(), output->tensor_shape());
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ }
+
+ return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output)
+{
+ // Output auto initialisation if not yet initialized
+ auto_init_if_empty(*output, *input);
+
+ // Configure window
+ const unsigned int num_elems_processed_per_iteration = 16 / input->element_size();
+
+ Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
+
+ AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
+ AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+
+ bool window_changed = update_window_and_padding(win, input_access, output_access);
+
+ Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, win);
+}
+} // namespace
+
CLCopyKernel::CLCopyKernel()
: _input(nullptr), _output(nullptr)
{
@@ -44,28 +78,32 @@
void CLCopyKernel::configure(const ICLTensor *input, ICLTensor *output)
{
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(input->info()->tensor_shape(), output->info()->tensor_shape());
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info()));
_input = input;
_output = output;
+ const unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size();
+
// Create kernel
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
+ build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration));
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("copy_tensor", build_opts.options()));
- // Configure window
- constexpr unsigned int num_elems_processed_per_iteration = 16;
+ // Configure kernel window
+ auto win_config = validate_and_configure_window(input->info(), output->info());
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+ ICLKernel::configure_internal(win_config.second);
+}
- Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
+Status CLCopyKernel::validate(const arm_compute::ITensorInfo *input, const arm_compute::ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get()).first);
- AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
- AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
-
- update_window_and_padding(win, input_access, output_access);
-
- ICLKernel::configure(win);
+ return Status{};
}
void CLCopyKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -73,15 +111,15 @@
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
- Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimX);
- Window slice = collapsed.first_slice_window_1D();
+ Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
+ Window slice = collapsed.first_slice_window_3D();
do
{
unsigned int idx = 0;
- add_1D_tensor_argument(idx, _input, slice);
- add_1D_tensor_argument(idx, _output, slice);
+ add_3D_tensor_argument(idx, _input, slice);
+ add_3D_tensor_argument(idx, _output, slice);
enqueue(queue, *this, slice);
}
- while(collapsed.slide_window_slice_1D(slice));
+ while(collapsed.slide_window_slice_3D(slice));
}
diff --git a/src/core/CL/kernels/CLDeconvolutionLayerUpsampleKernel.cpp b/src/core/CL/kernels/CLDeconvolutionLayerUpsampleKernel.cpp
index 650c5b8..c6a0031 100644
--- a/src/core/CL/kernels/CLDeconvolutionLayerUpsampleKernel.cpp
+++ b/src/core/CL/kernels/CLDeconvolutionLayerUpsampleKernel.cpp
@@ -43,7 +43,7 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) == 0);
ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) == 0);
@@ -74,7 +74,9 @@
ARM_COMPUTE_ERROR_THROW_ON(CLDeconvolutionLayerUpsampleKernel::validate(input->info(), output->info(), inner_border, info));
// Create kernel
- _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("deconvolution_upsample"));
+ CLBuildOptions build_opts;
+ build_opts.add_option(("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())));
+ _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("deconvolution_upsample", build_opts.options()));
constexpr unsigned int num_elems_processed_per_iteration = 1;
@@ -83,7 +85,7 @@
AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLDeconvolutionLayerUpsampleKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -99,18 +101,20 @@
const int out_end_y = _output->info()->dimension(1) - _info.pad().second + _info.stride().second - 1;
const int out_step_y = _info.stride().second;
- Window slice_out = window.first_slice_window_2D();
+ Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
+
+ Window slice_out = collapsed.first_slice_window_3D();
slice_out.set(Window::DimX, Window::Dimension(out_start_x, out_end_x, out_step_x));
slice_out.set(Window::DimY, Window::Dimension(out_start_y, out_end_y, out_step_y));
- Window slice_in = window.first_slice_window_2D();
+ Window slice_in = collapsed.first_slice_window_3D();
do
{
unsigned int idx = 0;
- add_2D_tensor_argument(idx, _input, slice_in);
- add_2D_tensor_argument(idx, _output, slice_out);
+ add_3D_tensor_argument(idx, _input, slice_in);
+ add_3D_tensor_argument(idx, _output, slice_out);
enqueue(queue, *this, slice_out);
}
- while(window.slide_window_slice_2D(slice_in) && window.slide_window_slice_2D(slice_out));
+ while(collapsed.slide_window_slice_3D(slice_in) && collapsed.slide_window_slice_3D(slice_out));
}
diff --git a/src/core/CL/kernels/CLDepthConcatenateLayerKernel.cpp b/src/core/CL/kernels/CLDepthConcatenateLayerKernel.cpp
index 9b30c64..4002394 100644
--- a/src/core/CL/kernels/CLDepthConcatenateLayerKernel.cpp
+++ b/src/core/CL/kernels/CLDepthConcatenateLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017, 2018 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,6 +25,7 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
@@ -32,7 +33,6 @@
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "support/ToolchainSupport.h"
@@ -41,6 +41,53 @@
using namespace arm_compute;
+namespace
+{
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, unsigned int depth_offset, ITensorInfo *output)
+{
+ ARM_COMPUTE_UNUSED(depth_offset);
+
+ // Configure kernel window
+ const int left_right = (output->dimension(0) - input->dimension(0)) / 2;
+ const int top_bottom = (output->dimension(1) - input->dimension(1)) / 2;
+
+ const unsigned int num_elems_processed_per_iteration = 16 / input->element_size();
+ const unsigned int num_elems_read_per_iteration = 16 / input->element_size();
+ const unsigned int num_rows_read_per_iteration = 1;
+
+ // The window is based on output, but its Z extent is the input depth as we copy all the depths of input
+ Window win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration));
+ win.set(Window::DimZ, Window::Dimension(0, input->tensor_shape().z(), 1));
+
+ AccessWindowRectangle input_access(input, -left_right, -top_bottom, num_elems_read_per_iteration, num_rows_read_per_iteration);
+ AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+ bool window_changed = update_window_and_padding(win, input_access, output_access);
+ output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
+
+ Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, win);
+}
+Status validate_arguments(const ITensorInfo *input, unsigned int depth_offset, const ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) + depth_offset > output->dimension(2));
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) > output->dimension(0));
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) > output->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(3, input, output);
+
+ // The gaps between the two lowest dimensions of input and output need to be divisible by 2
+ // Otherwise it is not clear how the padding should be added onto the input tensor
+ ARM_COMPUTE_RETURN_ERROR_ON((output->dimension(0) - input->dimension(0)) % 2);
+ ARM_COMPUTE_RETURN_ERROR_ON((output->dimension(1) - input->dimension(1)) % 2);
+
+ return Status{};
+}
+} // namespace
+
CLDepthConcatenateLayerKernel::CLDepthConcatenateLayerKernel()
: _input(nullptr), _output(nullptr), _top_bottom(0), _left_right(0), _depth_offset(0)
{
@@ -53,59 +100,41 @@
void CLDepthConcatenateLayerKernel::configure(const ICLTensor *input, unsigned int depth_offset, ICLTensor *output)
{
- static std::map<int, std::pair<std::string, int>> configs_map =
- {
- { 1, { "uchar", 16 } },
- { 2, { "ushort", 8 } },
- { 4, { "uint", 4 } },
- { 8, { "ulong", 2 } },
- };
-
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
- ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) + depth_offset > output->info()->dimension(2));
- ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) > output->info()->dimension(0));
- ARM_COMPUTE_ERROR_ON(input->info()->dimension(1) > output->info()->dimension(1));
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(3, input, output);
- ARM_COMPUTE_ERROR_ON(configs_map.find(input->info()->element_size()) == configs_map.end());
-
- // The gaps between the two lowest dimensions of input and output need to be divisible by 2
- // Otherwise it is not clear how the padding should be added onto the input tensor
- ARM_COMPUTE_ERROR_ON((output->info()->dimension(0) - input->info()->dimension(0)) % 2);
- ARM_COMPUTE_ERROR_ON((output->info()->dimension(1) - input->info()->dimension(1)) % 2);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), depth_offset, output->info()));
_input = input;
_output = output;
_depth_offset = depth_offset;
+ const unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size();
+
// Add build options
- auto config = configs_map.find(static_cast<int>(input->info()->element_size()));
- std::set<std::string> build_opts;
- build_opts.emplace(("-DDATA_TYPE=" + config->second.first));
- build_opts.emplace(("-DVEC_SIZE=" + support::cpp11::to_string(config->second.second)));
+ CLBuildOptions build_opts;
+ build_opts.add_option("-DDATA_TYPE=" + get_underlying_cl_type_from_data_type(input->info()->data_type()));
+ build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration));
// Create kernel
- _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("concatenate_depth", build_opts));
+ _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("concatenate_depth", build_opts.options()));
// Configure kernel window
_left_right = (output->info()->dimension(0) - input->info()->dimension(0)) / 2;
_top_bottom = (output->info()->dimension(1) - input->info()->dimension(1)) / 2;
- const unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size();
- const unsigned int num_elems_read_per_iteration = 16 / input->info()->element_size();
- const unsigned int num_rows_read_per_iteration = 1;
+ // Configure kernel window
+ auto win_config = validate_and_configure_window(input->info(), depth_offset, output->info());
+ ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
- // The window needs to be based on input as we copy all the depths of input
- Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration));
- win.set(Window::DimZ, Window::Dimension(0, input->info()->tensor_shape().z(), 1));
+ ICLKernel::configure_internal(std::get<1>(win_config));
+}
- AccessWindowRectangle input_access(input->info(), -_left_right, -_top_bottom, num_elems_read_per_iteration, num_rows_read_per_iteration);
- AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
- update_window_and_padding(win, input_access, output_access);
- output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
-
- ICLKernel::configure(win);
+Status CLDepthConcatenateLayerKernel::validate(const arm_compute::ITensorInfo *input,
+ unsigned int depth_offset,
+ const arm_compute::ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, depth_offset, output));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), depth_offset, output->clone().get()).first);
+ return Status{};
}
void CLDepthConcatenateLayerKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLDepthConvertLayerKernel.cpp b/src/core/CL/kernels/CLDepthConvertLayerKernel.cpp
index 83908a1..ffbd295 100644
--- a/src/core/CL/kernels/CLDepthConvertLayerKernel.cpp
+++ b/src/core/CL/kernels/CLDepthConvertLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,6 +25,7 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
@@ -38,74 +39,83 @@
using namespace arm_compute;
-void CLDepthConvertLayerKernel::configure(const ICLTensor *input, ICLTensor *output, ConvertPolicy policy, uint32_t shift)
+namespace
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::U8, DataType::S16, DataType::QS16,
- DataType::U16, DataType::U32, DataType::S32, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::S16, DataType::QS16,
- DataType::U16, DataType::U32, DataType::S32, DataType::F32);
- ARM_COMPUTE_ERROR_ON(input == output);
- ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == output->info()->data_type(), "Input and output data types must be different");
- ARM_COMPUTE_ERROR_ON(shift >= 8);
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, ConvertPolicy policy, uint32_t shift)
+{
+ ARM_COMPUTE_UNUSED(policy);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON(input == output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S16,
+ DataType::U16, DataType::U32, DataType::S32,
+ DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16,
+ DataType::U16, DataType::U32, DataType::S32,
+ DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == output->data_type(), "Input and output data types must be different");
+ ARM_COMPUTE_RETURN_ERROR_ON(shift >= 8);
// Check if convertion is supported
- ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::QS8 && output->info()->data_type() != DataType::F32,
- "Only data types supported [in] QS8 -> [out] F32");
- ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::QS16 && (output->info()->data_type() != DataType::F32),
- "Only data types supported [in] QS16 -> [out] F32");
- ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::F32 && ((output->info()->data_type() != DataType::QS8) && output->info()->data_type() != DataType::QS16),
- "Only data types supported [in] F32 -> [out] QS8, QS16");
- ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::U8 && (output->info()->data_type() != DataType::U16 && output->info()->data_type() != DataType::S16
- && output->info()->data_type() != DataType::U32 && output->info()->data_type() != DataType::S32),
- "Only data types supported [in] U8 -> [out] U16, S16, U32, S32");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::U8 && (output->data_type() != DataType::U16 && output->data_type() != DataType::S16
+ && output->data_type() != DataType::U32 && output->data_type() != DataType::S32),
+ "Only data types supported [in] U8 -> [out] U16, S16, U32, S32");
- ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::U16 && (output->info()->data_type() != DataType::U8 && output->info()->data_type() != DataType::U32
- && output->info()->data_type() != DataType::S32),
- "Only data types supported [in] U16 -> [out] U8, U32, S32");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::U16 && (output->data_type() != DataType::U8 && output->data_type() != DataType::U32
+ && output->data_type() != DataType::S32),
+ "Only data types supported [in] U16 -> [out] U8, U32, S32");
- ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::S16 && (output->info()->data_type() != DataType::U8 && output->info()->data_type() != DataType::U32
- && output->info()->data_type() != DataType::S32),
- "Only data types supported [in] S16 -> [out] U8, U32, S32");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S16 && (output->data_type() != DataType::U8 && output->data_type() != DataType::U32
+ && output->data_type() != DataType::S32),
+ "Only data types supported [in] S16 -> [out] U8, U32, S32");
- ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::U32 && (output->info()->data_type() != DataType::U8 && output->info()->data_type() != DataType::U16
- && output->info()->data_type() != DataType::S16),
- "Only data types supported [in] U32 -> [out] U8, U16, S16");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::U32 && (output->data_type() != DataType::U8 && output->data_type() != DataType::U16
+ && output->data_type() != DataType::S16),
+ "Only data types supported [in] U32 -> [out] U8, U16, S16");
- ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::S32 && (output->info()->data_type() != DataType::U8 && output->info()->data_type() != DataType::U16
- && output->info()->data_type() != DataType::S16),
- "Only data types supported [in] S32 -> [out] U8, U16, S16");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S32 && (output->data_type() != DataType::U8 && output->data_type() != DataType::U16
+ && output->data_type() != DataType::S16),
+ "Only data types supported [in] S32 -> [out] U8, U16, S16");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::F16 && output->data_type() != DataType::F32,
+ "Only data types supported [in] F16 -> [out] F32");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::F32 && output->data_type() != DataType::F16,
+ "Only data types supported [in] F32 -> [out] F16");
+
+ // Validate in case of configured output
+ if(output->total_size() > 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ }
+
+ return Status{};
+}
+} // namespace
+
+void CLDepthConvertLayerKernel::configure(const ICLTensor *input, ICLTensor *output, ConvertPolicy policy, uint32_t shift)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Auto initialize output shape if not initialized (We can only auto-configure the shape, datatype must be given)
set_shape_if_empty(*output->info(), input->info()->tensor_shape());
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), policy, shift));
// Get data sizes
const size_t input_size = data_size_from_type(input->info()->data_type());
const size_t output_size = data_size_from_type(output->info()->data_type());
- // Construct kernel name and build options
- std::string kernel_name = "convert_depth";
- std::set<std::string> build_opts;
- if(input_size > output_size)
- {
- kernel_name += "_down";
- // Down conversions from float always SATURATE as out-of-bounds conversion from float->integer is implementation defined
- build_opts.insert(((policy == ConvertPolicy::WRAP) && !is_data_type_float(input->info()->data_type())) ? "-DWRAP" : "-DSATURATE");
- }
- else
- {
- kernel_name += "_up";
- }
- build_opts.emplace("-DDATA_TYPE_IN=" + get_cl_type_from_data_type(input->info()->data_type()));
- build_opts.emplace("-DDATA_TYPE_OUT=" + get_cl_type_from_data_type(output->info()->data_type()));
- if(is_data_type_fixed_point(input->info()->data_type()) || is_data_type_fixed_point(output->info()->data_type()))
- {
- build_opts.emplace("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position()));
- }
+ // Set build options
+ CLBuildOptions build_opts;
+ build_opts.add_option("-DDATA_TYPE_IN=" + get_cl_type_from_data_type(input->info()->data_type()));
+ build_opts.add_option("-DDATA_TYPE_OUT=" + get_cl_type_from_data_type(output->info()->data_type()));
+ // Down conversions from float always SATURATE as out-of-bounds conversion from float->integer is implementation defined
+ build_opts.add_option_if(input_size > output_size, ((policy == ConvertPolicy::WRAP) && !is_data_type_float(input->info()->data_type())) ? "-DWRAP" : "-DSATURATE");
+ build_opts.add_option_if(is_data_type_float(input->info()->data_type()), "-DIS_DATA_TYPE_FLOAT");
// Create kernel
- _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts));
+ const std::string kernel_name = (input_size > output_size) ? "convert_depth_down" : "convert_depth_up";
+ _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
// Set shift arg
unsigned int idx = 2 * num_arguments_per_2D_tensor(); //Skip the input and output parameters
@@ -115,3 +125,10 @@
constexpr unsigned int num_elems_processed_per_iteration = 16;
ICLSimple2DKernel::configure(input, output, num_elems_processed_per_iteration);
}
+
+Status CLDepthConvertLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, ConvertPolicy policy, uint32_t shift)
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, policy, shift));
+
+ return Status{};
+}
diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp
index e4ad97f..a40aa28 100644
--- a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLKernel.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/Error.h"
@@ -44,14 +45,15 @@
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier,
const ActivationLayerInfo &act_info)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(act_info.enabled() && ((input->data_type() != DataType::QASYMM8) || ((act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU)
&& (act_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU)
- && (act_info.activation() != ActivationLayerInfo::ActivationFunction::RELU))),
- "For QASYMM8 only relu, lower bounded relu and lower-upper bounded relu are supported");
+ && (act_info.activation() != ActivationLayerInfo::ActivationFunction::RELU)
+ && (act_info.activation() != ActivationLayerInfo::ActivationFunction::LOGISTIC))),
+ "For QASYMM8 only logistic, relu, lower bounded relu and lower-upper bounded relu are supported");
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(0) != 3 || weights->dimension(1) != 3);
- ARM_COMPUTE_RETURN_ERROR_ON((input->dimension(2) * depth_multiplier) != output->dimension(2));
ARM_COMPUTE_RETURN_ERROR_ON(conv_info.stride().first < 1 || conv_info.stride().first > 3);
const bool is_qasymm = is_data_type_quantized_asymmetric(input->data_type());
@@ -66,7 +68,7 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
}
- ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(2));
+ ARM_COMPUTE_RETURN_ERROR_ON((biases->dimension(0) != weights->dimension(2)) && (weights->dimension(2) != 1 || biases->dimension(0) != weights->dimension(3)));
ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
}
@@ -167,9 +169,11 @@
}
else
{
- kernel_name = is_qasymm ? "depthwise_convolution_3x3_quantized_nchw" : "depthwise_convolution_3x3";
+ const bool is_dot8_supported = dot8_supported(CLKernelLibrary::get().get_device());
+
+ kernel_name = is_qasymm ? (std::string("depthwise_convolution_3x3_quantized") + (is_dot8_supported ? "_dot8" : "") + "_nchw") : "depthwise_convolution_3x3";
num_elems_written_per_iteration_x = 8 / data_size_from_type(input->data_type());
- num_elems_written_per_iteration_y = (is_qasymm && conv_stride_y < 3) ? (2 / conv_stride_y) : 1;
+ num_elems_written_per_iteration_y = (is_qasymm && conv_stride_y == 1) ? 2 : 1;
num_elems_read_per_iteration_x = 3 + (num_elems_written_per_iteration_x - 1) * conv_stride_x;
num_elems_read_per_iteration_y = num_elems_written_per_iteration_y + 2;
}
@@ -193,7 +197,7 @@
} // namespace
CLDepthwiseConvolutionLayer3x3NCHWKernel::CLDepthwiseConvolutionLayer3x3NCHWKernel()
- : _conv_stride_x(0), _conv_pad_top(0)
+ : _conv_stride_x(0), _conv_pad_top(0), _conv_pad_left(0)
{
}
@@ -207,6 +211,7 @@
ActivationLayerInfo act_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), conv_info, depth_multiplier, act_info));
bool is_qasymm = is_data_type_quantized_asymmetric(input->info()->data_type());
@@ -275,7 +280,7 @@
auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, depth_multiplier, gpu_target, kernel_name);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
@@ -340,7 +345,7 @@
add_3D_tensor_argument(idx, _output, slice_out);
add_3D_tensor_argument(idx, _weights, slice_weights);
- enqueue(queue, *this, slice_out, _lws_hint);
+ enqueue(queue, *this, slice_out, lws_hint());
}
while(window.slide_window_slice_3D(slice_out) && win_in.slide_window_slice_3D(slice_in));
}
diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp
index a54e92c..50f17d5 100644
--- a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.cpp
@@ -44,18 +44,28 @@
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier,
const ActivationLayerInfo &act_info)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((act_info.enabled()) && (act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU)
- && (act_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU)
- && (act_info.activation() != ActivationLayerInfo::ActivationFunction::RELU),
- "For QASYMM8 only relu, lower bounded relu and lower-upper bounded relu are supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32, DataType::QASYMM8);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((act_info.enabled()) && (input->data_type() == DataType::F32 || ((act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU)
+ && (act_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU)
+ && (act_info.activation() != ActivationLayerInfo::ActivationFunction::RELU)
+ && (act_info.activation() != ActivationLayerInfo::ActivationFunction::LOGISTIC))),
+ "For QASYMM8 only logistic, relu, lower bounded relu and lower-upper bounded relu are supported");
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
- ARM_COMPUTE_RETURN_ERROR_ON(depth_multiplier > 1); // COMPMID-1071 Add depth multiplier support for NHWC
+ ARM_COMPUTE_RETURN_ERROR_ON(depth_multiplier > 1);
ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(1) != 3 || weights->dimension(2) != 3);
+ const bool is_qasymm = is_data_type_quantized_asymmetric(input->data_type());
+
if(biases != nullptr)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
+ if(is_qasymm)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
+ }
ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(0));
ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
}
@@ -72,12 +82,26 @@
std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *bias, ITensorInfo *output,
const PadStrideInfo &conv_info)
{
- const unsigned int num_rows_processed_per_iteration = 4;
- const unsigned int num_elems_accessed_per_iteration = 4;
- const unsigned int num_rows_read_per_iteration = num_rows_processed_per_iteration + 2;
- const unsigned int num_rows_written_per_iteration = num_rows_processed_per_iteration / conv_info.stride().first;
+ // Get convolved dimensions
+ const TensorShape output_shape = compute_depthwise_convolution_shape(*input, *weights, conv_info, 1 /* depth_multiplier */);
- const BorderSize border_size(conv_info.pad_left() + num_rows_read_per_iteration * std::max(conv_info.pad_top(), conv_info.pad_bottom()), 0, conv_info.pad_right(), 0);
+ // Output auto initialization if not yet initialized
+ auto_init_if_empty(*output,
+ output_shape,
+ 1,
+ input->data_type(),
+ input->quantization_info());
+
+ const bool is_qasymm = is_data_type_quantized_asymmetric(input->data_type());
+ const bool is_stride_1 = ((conv_info.stride().first == conv_info.stride().second) && (conv_info.stride().first == 1));
+
+ const unsigned int num_rows_processed_per_iteration = is_stride_1 ? 2 : 1;
+ const unsigned int num_elems_accessed_per_iteration = is_qasymm ? 4 : 2;
+ const unsigned int num_rows_read_per_iteration = num_rows_processed_per_iteration + 2;
+ const unsigned int num_rows_written_per_iteration = std::ceil(num_rows_processed_per_iteration / static_cast<float>(conv_info.stride().first));
+
+ BorderSize border_size;
+ border_size = BorderSize(conv_info.pad_left(), 0, std::max(std::max(conv_info.pad_right(), conv_info.pad_bottom()), conv_info.pad_top()), 0);
// Configure kernel window
Window win = calculate_max_window(*output, Steps(num_elems_accessed_per_iteration, num_rows_written_per_iteration));
@@ -103,7 +127,7 @@
} // namespace
CLDepthwiseConvolutionLayer3x3NHWCKernel::CLDepthwiseConvolutionLayer3x3NHWCKernel()
- : _num_rows_processed_per_iteration(1)
+ : _num_rows_processed_per_iteration(1), _num_planes_processed_per_iteration(1)
{
}
@@ -126,7 +150,6 @@
output_shape,
1,
input->info()->data_type(),
- input->info()->fixed_point_position(),
input->info()->quantization_info());
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), conv_info, depth_multiplier, act_info));
@@ -135,72 +158,93 @@
ARM_COMPUTE_ERROR_ON(conv_stride_x < 1 || conv_stride_x > 2);
ARM_COMPUTE_ERROR_ON(std::max(conv_info.pad_top(), conv_info.pad_bottom()) > 1);
- _input = input;
- _output = output;
- _weights = weights;
- _biases = biases;
- _conv_stride_y = conv_info.stride().second;
- _conv_pad_left = conv_info.pad_left();
- _num_rows_processed_per_iteration = 4;
+ const bool is_qasymm = is_data_type_quantized_asymmetric(input->info()->data_type());
+ const bool is_stride_1 = ((conv_info.stride().first == conv_info.stride().second) && (conv_info.stride().first == 1));
- const unsigned int num_elems_accessed_per_iteration = 4;
- const unsigned int num_rows_read_per_iteration = _num_rows_processed_per_iteration + 2;
+ _input = input;
+ _output = output;
+ _weights = weights;
+ _biases = biases;
+ _conv_stride_y = conv_info.stride().second;
+ _num_rows_processed_per_iteration = is_stride_1 ? 2 : 1;
+ _num_planes_processed_per_iteration = is_stride_1 ? 2 : 1;
+ _border_size = BorderSize(conv_info.pad_left(), 0, std::max(std::max(conv_info.pad_right(), conv_info.pad_bottom()), conv_info.pad_top()), 0);
- _border_size = BorderSize(_conv_pad_left + num_rows_read_per_iteration * std::max(conv_info.pad_top(), conv_info.pad_bottom()), 0, conv_info.pad_right(), 0);
-
- float multiplier = _input->info()->quantization_info().scale * _weights->info()->quantization_info().scale / _output->info()->quantization_info().scale;
- int output_multiplier = 0;
- int output_shift = 0;
- quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+ const unsigned int num_elems_accessed_per_iteration = is_qasymm ? 4 : 2;
CLBuildOptions build_opts;
build_opts.add_option_if(_biases != nullptr, "-DHAS_BIAS");
- build_opts.add_option("-DINPUT_OFFSET=" + support::cpp11::to_string(-_input->info()->quantization_info().offset));
- build_opts.add_option("-DWEIGHTS_OFFSET=" + support::cpp11::to_string(-_weights->info()->quantization_info().offset));
- build_opts.add_option("-DOUTPUT_OFFSET=" + support::cpp11::to_string(_output->info()->quantization_info().offset));
- build_opts.add_option("-DK_OFFSET=" + support::cpp11::to_string(9 * input->info()->quantization_info().offset * weights->info()->quantization_info().offset));
- build_opts.add_option("-DOUTPUT_MULTIPLIER=" + support::cpp11::to_string(output_multiplier));
- build_opts.add_option("-DOUTPUT_SHIFT=" + support::cpp11::to_string(output_shift));
build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_accessed_per_iteration));
- build_opts.add_option("-DSRC_DEPTH=" + support::cpp11::to_string(_input->info()->dimension(2)));
+ build_opts.add_option("-DSRC_DIM_2=" + support::cpp11::to_string(_input->info()->dimension(2)));
build_opts.add_option("-DCONV_PAD_TOP=" + support::cpp11::to_string(conv_info.pad_top()));
- build_opts.add_option("-DROWS_READ=" + support::cpp11::to_string(num_rows_read_per_iteration));
+ build_opts.add_option("-DCONV_PAD_LEFT=" + support::cpp11::to_string(conv_info.pad_left()));
- if(act_info.enabled())
+ if(is_qasymm)
{
- const int a_val = input->info()->quantization_info().quantize(act_info.a(), RoundingPolicy::TO_NEAREST_UP);
- const int b_val = input->info()->quantization_info().quantize(act_info.b(), RoundingPolicy::TO_NEAREST_UP);
- const int o1 = input->info()->quantization_info().offset;
+ float multiplier = _input->info()->quantization_info().scale * _weights->info()->quantization_info().scale / _output->info()->quantization_info().scale;
+ int output_multiplier = 0;
+ int output_shift = 0;
+ quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
- build_opts.add_option("-DFUSED_ACTIVATION=" + lower_string(string_from_activation_func(act_info.activation())));
- build_opts.add_option("-DA_VAL=" + support::cpp11::to_string(a_val));
- build_opts.add_option("-DB_VAL=" + support::cpp11::to_string(b_val));
- build_opts.add_option("-DCONST_0=" + support::cpp11::to_string(o1));
+ build_opts.add_option("-DSRC_DIM_1=" + support::cpp11::to_string(_input->info()->dimension(1)));
+ build_opts.add_option("-DINPUT_OFFSET=" + support::cpp11::to_string(-_input->info()->quantization_info().offset));
+ build_opts.add_option("-DWEIGHTS_OFFSET=" + support::cpp11::to_string(-_weights->info()->quantization_info().offset));
+ build_opts.add_option("-DOUTPUT_OFFSET=" + support::cpp11::to_string(_output->info()->quantization_info().offset));
+ build_opts.add_option("-DK_OFFSET=" + support::cpp11::to_string(9 * input->info()->quantization_info().offset * weights->info()->quantization_info().offset));
+ build_opts.add_option("-DOUTPUT_MULTIPLIER=" + support::cpp11::to_string(output_multiplier));
+ build_opts.add_option("-DOUTPUT_SHIFT=" + support::cpp11::to_string(output_shift));
- if(output != nullptr)
+ if(act_info.enabled())
{
- const float s1 = input->info()->quantization_info().scale;
- const float s2 = output->info()->quantization_info().scale;
- const int o2 = output->info()->quantization_info().offset;
+ const int a_val = input->info()->quantization_info().quantize(act_info.a(), RoundingPolicy::TO_NEAREST_UP);
+ const int b_val = input->info()->quantization_info().quantize(act_info.b(), RoundingPolicy::TO_NEAREST_UP);
+ const int o1 = input->info()->quantization_info().offset;
- if(o1 != o2 || s1 != s2)
+ build_opts.add_option("-DFUSED_ACTIVATION=" + lower_string(string_from_activation_func(act_info.activation())));
+ build_opts.add_option("-DA_VAL=" + support::cpp11::to_string(a_val));
+ build_opts.add_option("-DB_VAL=" + support::cpp11::to_string(b_val));
+ build_opts.add_option("-DCONST_0=" + support::cpp11::to_string(o1));
+
+ if(output != nullptr)
{
- build_opts.add_option("-DS1_VAL=" + float_to_string_with_full_precision(s1));
- build_opts.add_option("-DS2_VAL=" + float_to_string_with_full_precision(s2));
- build_opts.add_option("-DO1_VAL=" + support::cpp11::to_string(o1));
- build_opts.add_option("-DO2_VAL=" + support::cpp11::to_string(o2));
+ const float s1 = input->info()->quantization_info().scale;
+ const float s2 = output->info()->quantization_info().scale;
+ const int o2 = output->info()->quantization_info().offset;
+
+ if(o1 != o2 || s1 != s2)
+ {
+ build_opts.add_option("-DS1_VAL=" + float_to_string_with_full_precision(s1));
+ build_opts.add_option("-DS2_VAL=" + float_to_string_with_full_precision(s2));
+ build_opts.add_option("-DO1_VAL=" + support::cpp11::to_string(o1));
+ build_opts.add_option("-DO2_VAL=" + support::cpp11::to_string(o2));
+ }
}
}
}
+ if(is_stride_1)
+ {
+ build_opts.add_option("-DNUM_ROWS_PROCESSED=" + support::cpp11::to_string(_num_rows_processed_per_iteration));
+ build_opts.add_option("-DNUM_PLANES_PROCESSED=" + support::cpp11::to_string(_num_planes_processed_per_iteration));
+ build_opts.add_option("-DDST_DIM_2=" + support::cpp11::to_string(_output->info()->dimension(2)));
+ }
+ else
+ {
+ build_opts.add_option("-DCONV_STRIDE_X=" + support::cpp11::to_string(conv_stride_x));
+ build_opts.add_option("-DCONV_STRIDE_Y=" + support::cpp11::to_string(_conv_stride_y));
+ }
+
// Create kernel
- std::string kernel_name = std::string("depthwise_convolution_3x3_quantized_nhwc_stride") + support::cpp11::to_string(conv_stride_x);
- _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
+ const bool is_dot8_supported = dot8_supported(CLKernelLibrary::get().get_device());
+ std::string kernel_name = std::string("depthwise_convolution_3x3") + (is_qasymm ? std::string("_quantized") + ((is_dot8_supported
+ && is_stride_1 ) ? "_dot8" : "") : "") + "_nhwc" + (is_stride_1 ? "_stride1" : "");
+
+ _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), conv_info);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
// Set config_id for enabling LWS tuning
_config_id = kernel_name;
@@ -214,6 +258,8 @@
_config_id += support::cpp11::to_string(output->info()->dimension(0));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(1));
+ _config_id += "_";
+ _config_id += string_from_data_type(input->info()->data_type());
}
Status CLDepthwiseConvolutionLayer3x3NHWCKernel::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
@@ -234,26 +280,33 @@
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
+ Window win = window;
+ win.set(Window::DimZ, Window::Dimension(0, std::ceil(_output->info()->dimension(2) / static_cast<float>(_num_planes_processed_per_iteration)), 1));
+
// Create input window and adjust
- Window win_in = window;
- win_in.adjust(Window::DimY, -_conv_pad_left, true);
+ Window win_in = win;
win_in.set_dimension_step(Window::DimY, _num_rows_processed_per_iteration);
win_in.set_dimension_step(Window::DimZ, _conv_stride_y);
ARM_COMPUTE_ERROR_ON((win_in.y().step() < window.y().step()) || (win_in.z().step() < window.z().step()));
Window slice_in = win_in.first_slice_window_3D();
- Window slice_out = window.first_slice_window_3D();
+ Window slice_out = win.first_slice_window_3D();
+
+ unsigned int idx = 3 * num_arguments_per_3D_tensor();
if(_biases != nullptr)
{
- unsigned int idx = 3 * num_arguments_per_3D_tensor();
- Window win_biases;
+ Window win_biases;
win_biases.use_tensor_dimensions(_biases->info()->tensor_shape());
win_biases.set_dimension_step(Window::DimX, window.x().step());
add_1D_tensor_argument(idx, _biases, win_biases);
}
+ const int max_offset = _input->info()->strides_in_bytes().z() * _input->info()->dimension(2) - (_input->info()->padding().bottom + _input->info()->padding().top) *
+ _input->info()->strides_in_bytes().y();
+ _kernel.setArg(idx, max_offset);
+
do
{
unsigned int idx = 0;
@@ -261,7 +314,7 @@
add_3D_tensor_argument(idx, _output, slice_out);
add_3D_tensor_argument(idx, _weights, slice_out);
- enqueue(queue, *this, slice_out, _lws_hint);
+ enqueue(queue, *this, slice_out, lws_hint());
}
while(window.slide_window_slice_3D(slice_out) && win_in.slide_window_slice_3D(slice_in));
}
diff --git a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp
index f44f08b..d5c333a 100644
--- a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp
@@ -25,6 +25,7 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
@@ -46,12 +47,14 @@
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier)
{
+ const size_t idx_c = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
+
ARM_COMPUTE_UNUSED(conv_info);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input->data_type()) && has_bias);
- ARM_COMPUTE_RETURN_ERROR_ON((input->dimension(2) * depth_multiplier) != output->dimension(2));
+ ARM_COMPUTE_RETURN_ERROR_ON((input->dimension(idx_c) * depth_multiplier) != output->dimension(2));
ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != (kernel_dims.width * kernel_dims.height + ((has_bias) ? 1 : 0)));
return Status{};
@@ -66,6 +69,10 @@
_input = input;
_output = output;
+ const DataLayout data_layout = input->info()->data_layout();
+ const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+
// Create kernel
CLBuildOptions build_opts;
@@ -76,11 +83,12 @@
build_opts.add_option("-DPAD_TOP=" + support::cpp11::to_string(conv_info.pad_top()));
build_opts.add_option("-DPAD_RIGHT=" + support::cpp11::to_string(conv_info.pad_right()));
build_opts.add_option("-DPAD_BOTTOM=" + support::cpp11::to_string(conv_info.pad_bottom()));
- build_opts.add_option("-DSRC_WIDTH=" + support::cpp11::to_string(input->info()->dimension(0)));
- build_opts.add_option("-DSRC_HEIGHT=" + support::cpp11::to_string(input->info()->dimension(1)));
+ build_opts.add_option("-DSRC_WIDTH=" + support::cpp11::to_string(input->info()->dimension(idx_w)));
+ build_opts.add_option("-DSRC_HEIGHT=" + support::cpp11::to_string(input->info()->dimension(idx_h)));
build_opts.add_option("-DKERNEL_WIDTH=" + support::cpp11::to_string(kernel_dims.width));
build_opts.add_option("-DKERNEL_HEIGHT=" + support::cpp11::to_string(kernel_dims.height));
build_opts.add_option("-DDEPTH_MULTIPLIER=" + support::cpp11::to_string(depth_multiplier));
+ build_opts.add_option("-D" + string_from_data_layout(input->info()->data_layout()));
build_opts.add_option_if(has_bias, "-DHAS_BIAS");
build_opts.add_option_if_else(is_data_type_quantized_asymmetric(input->info()->data_type()),
"-DPAD_VALUE=" + support::cpp11::to_string(input->info()->quantization_info().offset),
@@ -88,21 +96,12 @@
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("depthwise_im2col", build_opts.options()));
- // Configure the local work size for Bifrost with a value obtained
- // via exhaustive autotuning for the MobileNets tensor shapes.
- const GPUTarget gpu_target = get_target();
-
- if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
- {
- _lws_hint = cl::NDRange(1, 2, 1);
- }
-
// Configure kernel window
Window win = calculate_max_window(*output->info(), Steps());
// CLDepthwiseIm2ColKernel doesn't need padding so update_window_and_padding() can be skipped
output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
Status CLDepthwiseIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier)
@@ -136,7 +135,7 @@
unsigned int idx = 0;
add_3D_tensor_argument(idx, _input, slice_in);
add_3D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_3D(slice) && window.slide_window_slice_3D(slice_in));
}
diff --git a/src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp b/src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp
index 26336eb..cdc27e8 100644
--- a/src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp
@@ -25,37 +25,30 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "support/ToolchainSupport.h"
using namespace arm_compute;
+using namespace arm_compute::misc::shape_calculator;
namespace
{
-TensorShape compute_output_shape(const TensorShape &input, size_t conv_w, size_t conv_h)
-{
- TensorShape output_shape(input);
- output_shape.set(0, conv_w);
- output_shape.set(1, conv_h);
- output_shape.set(2, input.x() / (conv_w * conv_h));
-
- return output_shape;
-}
-
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, size_t conv_w, size_t conv_h)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
if(output->total_size() != 0)
{
- TensorShape output_shape = compute_output_shape(input->tensor_shape(), conv_w, conv_h);
+ TensorShape output_shape = compute_vector_to_tensor_output_shape(input->tensor_shape(), conv_w, conv_h, output->data_layout());
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
@@ -72,7 +65,7 @@
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Output auto inizialitation if not yet initialized
- TensorShape output_shape = compute_output_shape(input->info()->tensor_shape(), conv_w, conv_h);
+ TensorShape output_shape = compute_vector_to_tensor_output_shape(input->info()->tensor_shape(), conv_w, conv_h, output->info()->data_layout());
auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), conv_w, conv_h));
@@ -85,6 +78,7 @@
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
build_opts.add_option("-DCONV_WIDTH=" + support::cpp11::to_string(conv_w));
build_opts.add_option("-DCONV_HEIGHT=" + support::cpp11::to_string(conv_h));
+ build_opts.add_option("-D" + string_from_data_layout(output->info()->data_layout()));
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("depthwise_vector_to_tensor", build_opts.options()));
@@ -93,7 +87,7 @@
// The CLDepthwisevectorToTensorKernel doesn't need padding so update_window_and_padding() can be skipped
output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
Status CLDepthwiseVectorToTensorKernel::validate(const ITensorInfo *input, const ITensorInfo *output, size_t conv_w, size_t conv_h)
diff --git a/src/core/CL/kernels/CLDepthwiseWeightsReshapeKernel.cpp b/src/core/CL/kernels/CLDepthwiseWeightsReshapeKernel.cpp
index b5a607d..683dda8 100644
--- a/src/core/CL/kernels/CLDepthwiseWeightsReshapeKernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseWeightsReshapeKernel.cpp
@@ -25,6 +25,7 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
@@ -38,18 +39,21 @@
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *biases)
{
+ const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
+ const size_t idx_c = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input->data_type()) && (biases != nullptr));
- ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != output->dimension(1));
- ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != (input->dimension(0) * input->dimension(1) + ((biases != nullptr) ? 1 : 0)));
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(idx_c) != output->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != (input->dimension(idx_w) * input->dimension(idx_h) + ((biases != nullptr) ? 1 : 0)));
if(biases != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, biases);
- ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != input->dimension(2));
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != input->dimension(idx_c));
ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
}
@@ -71,11 +75,14 @@
_biases = biases;
_output = output;
+ const size_t idx_w = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::WIDTH);
+
// Create kernel
std::set<std::string> build_opts;
build_opts.emplace("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
- build_opts.emplace("-DSRC_WIDTH=" + support::cpp11::to_string(input->info()->dimension(0)));
+ build_opts.emplace("-DSRC_WIDTH=" + support::cpp11::to_string(input->info()->dimension(idx_w)));
+ build_opts.emplace("-D" + string_from_data_layout(input->info()->data_layout()));
if(_biases != nullptr)
{
build_opts.emplace("-DHAS_BIAS");
@@ -88,7 +95,7 @@
// The CLDepthwiseWeightsReshapeKernel doesn't need padding so update_window_and_padding() can be skipped
output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
Status CLDepthwiseWeightsReshapeKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *biases)
@@ -105,10 +112,14 @@
Window slice = window.first_slice_window_3D();
Window slice_out = window.first_slice_window_2D();
+ const size_t idx_w = get_data_layout_dimension_index(_input->info()->data_layout(), DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(_input->info()->data_layout(), DataLayoutDimension::HEIGHT);
+ const size_t idx_c = get_data_layout_dimension_index(_input->info()->data_layout(), DataLayoutDimension::CHANNEL);
+
// Setup slice
- slice.set(Window::DimX, Window::Dimension(0, _input->info()->dimension(0), _input->info()->dimension(0)));
- slice.set(Window::DimY, Window::Dimension(0, _input->info()->dimension(1), 1));
- slice.set(Window::DimZ, Window::Dimension(0, _input->info()->dimension(2), 1));
+ slice.set(Window::DimX, Window::Dimension(0, _input->info()->dimension(idx_w), _input->info()->dimension(idx_w)));
+ slice.set(Window::DimY, Window::Dimension(0, _input->info()->dimension(idx_h), 1));
+ slice.set(Window::DimZ, Window::Dimension(0, _input->info()->dimension(idx_c), 1));
// Setup output slice
// The first two dimensions of the output are increased by the inner loops
diff --git a/src/core/CL/kernels/CLDequantizationLayerKernel.cpp b/src/core/CL/kernels/CLDequantizationLayerKernel.cpp
index fa982d6..d4c1bec 100644
--- a/src/core/CL/kernels/CLDequantizationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLDequantizationLayerKernel.cpp
@@ -54,7 +54,7 @@
std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, ITensorInfo *min_max)
{
// Output tensor auto initialization if not yet initialized
- auto_init_if_empty(*output, input->tensor_shape(), 1, DataType::F32, 0);
+ auto_init_if_empty(*output, input->tensor_shape(), 1, DataType::F32);
constexpr unsigned int num_elems_processed_per_iteration = 4;
@@ -96,7 +96,7 @@
ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
- ICLKernel::configure(std::get<1>(win_config));
+ ICLKernel::configure_internal(std::get<1>(win_config));
}
Status CLDequantizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *min_max)
diff --git a/src/core/CL/kernels/CLDerivativeKernel.cpp b/src/core/CL/kernels/CLDerivativeKernel.cpp
index da02227..f51628f 100644
--- a/src/core/CL/kernels/CLDerivativeKernel.cpp
+++ b/src/core/CL/kernels/CLDerivativeKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -115,7 +115,7 @@
output_x_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
output_y_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLDerivativeKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLDilateKernel.cpp b/src/core/CL/kernels/CLDilateKernel.cpp
index 3abd747..89853d7 100644
--- a/src/core/CL/kernels/CLDilateKernel.cpp
+++ b/src/core/CL/kernels/CLDilateKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -61,5 +61,5 @@
output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
diff --git a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
index 7c6c7de..c8da7ac 100644
--- a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
+++ b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
@@ -33,7 +34,6 @@
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "support/ToolchainSupport.h"
@@ -44,21 +44,23 @@
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(0) != weights->dimension(1),
- "Weights should have same width as length");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(0) != 1 && weights->dimension(0) != 3 && weights->dimension(0) != 5,
+
+ const DataLayout data_layout = input->data_layout();
+ const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(width_idx) != weights->dimension(height_idx), "Weights should have same width and height");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(width_idx) != 1 && weights->dimension(width_idx) != 3 && weights->dimension(width_idx) != 5,
"Kernel sizes other than 1x1, 3x3 or 5x5 are not supported");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(2) != input->dimension(2),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(channel_idx) != input->dimension(channel_idx),
"Weights feature map dimension should match the respective input's one");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(0) != weights->dimension(1),
- "Only rectangular weights are supported!");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->num_dimensions() > 4,
- "Weights can be at most 4 dimensional");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(0) == 1) && std::get<0>(conv_info.stride()) > 3,
- "Strides larger than 3 not supported for 1x1 convolution.");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(0) == 3 || weights->dimension(0) == 5) && std::get<0>(conv_info.stride()) > 2,
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->num_dimensions() > 4, "Weights can be at most 4 dimensional");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 1) && std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported for 1x1 convolution.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 3 || weights->dimension(width_idx) == 5) && std::get<0>(conv_info.stride()) > 2,
"Strides larger than 2 not supported for 3x3 convolution.");
if(biases != nullptr)
@@ -83,42 +85,32 @@
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(),
misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, const GPUTarget target)
+inline bool can_run_optimized_kernel_for_bifrost(GPUTarget gpu_target, unsigned int conv_stride_x, unsigned int conv_stride_y, unsigned int kernel_size,
+ DataType data_type, DataLayout data_layout)
{
- const unsigned int kernel_size = weights->dimension(0);
- const DataType data_type = input->data_type();
+ return gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::G76) && (kernel_size <= 5)
+ && (conv_stride_x == 1) && (conv_stride_y == 1) && (data_type == DataType::F32) && (data_layout == DataLayout::NCHW);
+}
- // Get convolved dimensions
- TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info);
+inline void setup_num_elems(unsigned int &num_elems_read_per_iteration_x, unsigned int &num_elems_read_per_iteration_y,
+ unsigned int &num_elems_written_per_iteration_x, unsigned int &num_elems_written_per_iteration_y,
+ unsigned int kernel_size, const PadStrideInfo &conv_info, const GPUTarget target, ITensorInfo *input)
+{
+ const DataType data_type = input->data_type();
+ const DataLayout data_layout = input->data_layout();
+ unsigned int conv_stride_x = std::get<0>(conv_info.stride());
+ unsigned int conv_stride_y = std::get<1>(conv_info.stride());
- // Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output, output_shape,
- 1,
- input->data_type(),
- input->fixed_point_position(),
- input->quantization_info());
+ const bool run_optimized_bifrost = can_run_optimized_kernel_for_bifrost(target, conv_stride_x, conv_stride_y, kernel_size, data_type, data_layout);
- unsigned int conv_stride_x = std::get<0>(conv_info.stride());
- unsigned int conv_stride_y = std::get<1>(conv_info.stride());
- unsigned int conv_pad_left = conv_info.pad_left();
- unsigned int conv_pad_top = conv_info.pad_top();
-
- unsigned int num_elems_read_per_iteration_x = 0;
- unsigned int num_elems_read_per_iteration_y = 0;
- unsigned int num_elems_written_per_iteration_x = 0;
- unsigned int num_elems_written_per_iteration_y = 0;
-
- if(gpu_target_is_in(target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && (kernel_size <= 5) && (conv_stride_x == 1)
- && (conv_stride_y == 1) && (data_type == DataType::F32))
+ if(run_optimized_bifrost)
{
// Configure kernel window
-
switch(kernel_size)
{
case 1:
@@ -218,22 +210,123 @@
}
}
+ if(data_layout == DataLayout::NHWC)
+ {
+ num_elems_written_per_iteration_x = 1;
+ num_elems_read_per_iteration_x = 1;
+ switch(kernel_size)
+ {
+ case 1:
+ switch(conv_stride_x)
+ {
+ case 1:
+ num_elems_read_per_iteration_y = 8;
+ num_elems_written_per_iteration_y = 8;
+ break;
+ case 2:
+ num_elems_read_per_iteration_y = 16;
+ num_elems_written_per_iteration_y = 8;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Invalid convolution stride X");
+ }
+ break;
+ case 3:
+ switch(conv_stride_x)
+ {
+ case 1:
+ num_elems_read_per_iteration_y = 10;
+ num_elems_written_per_iteration_y = 8;
+ break;
+ case 2:
+ num_elems_read_per_iteration_y = 17;
+ num_elems_written_per_iteration_y = 8;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Invalid convolution stride X");
+ }
+ break;
+ case 5:
+ switch(conv_stride_x)
+ {
+ case 1:
+ num_elems_read_per_iteration_y = 12;
+ num_elems_written_per_iteration_y = 8;
+ break;
+ case 2:
+ num_elems_read_per_iteration_y = 20;
+ num_elems_written_per_iteration_y = 8;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Invalid convolution stride X");
+ }
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not implemented.");
+ break;
+ }
+ }
+}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, const GPUTarget target)
+{
+ const DataLayout data_layout = input->data_layout();
+ const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const unsigned int kernel_size = weights->dimension(width_idx);
+
+ // Get convolved dimensions
+ TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info);
+
+ // Output auto inizialitation if not yet initialized
+ auto_init_if_empty(*output, output_shape,
+ 1,
+ input->data_type(),
+ input->quantization_info());
+
+ unsigned int num_elems_read_per_iteration_x = 0;
+ unsigned int num_elems_read_per_iteration_y = 0;
+ unsigned int num_elems_written_per_iteration_x = 0;
+ unsigned int num_elems_written_per_iteration_y = 0;
+
+ unsigned int conv_pad_left = conv_info.pad_left();
+ unsigned int conv_pad_top = conv_info.pad_top();
+ unsigned int conv_stride_x = std::get<0>(conv_info.stride());
+ unsigned int conv_stride_y = std::get<1>(conv_info.stride());
+
+ setup_num_elems(num_elems_read_per_iteration_x, num_elems_read_per_iteration_y,
+ num_elems_written_per_iteration_x, num_elems_written_per_iteration_y,
+ kernel_size, conv_info, target, input);
+
// Create window and update padding
bool window_changed = false;
Window win = calculate_max_window(*output, Steps(num_elems_written_per_iteration_x, num_elems_written_per_iteration_y));
- AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top,
- num_elems_read_per_iteration_x, num_elems_read_per_iteration_y,
- conv_stride_x, conv_stride_y);
- AccessWindowStatic weights_access(weights, 0, 0, kernel_size, kernel_size);
- AccessWindowRectangle output_access(output, 0, 0, num_elems_written_per_iteration_x, num_elems_written_per_iteration_y);
-
- window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
-
- output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
-
- Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
- return std::make_pair(err, win);
+ if(data_layout == DataLayout::NHWC)
+ {
+ AccessWindowStatic input_access(input, 0, -conv_pad_left,
+ num_elems_read_per_iteration_x,
+ ceil_to_multiple(input->dimension(1) + conv_info.pad_right(), num_elems_read_per_iteration_y));
+ AccessWindowStatic weights_access(weights, 0, 0, weights->dimension(0), weights->dimension(1));
+ AccessWindowRectangle output_access(output, 0, 0, num_elems_written_per_iteration_x, num_elems_written_per_iteration_y);
+ window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
+ output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
+ Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, win);
+ }
+ else if(data_layout == DataLayout::NCHW)
+ {
+ AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top, num_elems_read_per_iteration_x, num_elems_read_per_iteration_y, conv_stride_x, conv_stride_y);
+ AccessWindowStatic weights_access(weights, 0, 0, kernel_size, kernel_size);
+ AccessWindowRectangle output_access(output, 0, 0, num_elems_written_per_iteration_x, num_elems_written_per_iteration_y);
+ window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
+ output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
+ Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, win);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not supported");
+ }
}
} // namespace
@@ -251,7 +344,12 @@
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
- const unsigned int kernel_size = weights->info()->dimension(0);
+ const DataLayout data_layout = input->info()->data_layout();
+ const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
+
+ const unsigned int kernel_size = weights->info()->dimension(width_idx);
const DataType data_type = input->info()->data_type();
// Get convolved dimensions
@@ -262,7 +360,6 @@
output_shape,
1,
input->info()->data_type(),
- input->info()->fixed_point_position(),
input->info()->quantization_info());
// Perform validation step
@@ -274,7 +371,19 @@
_conv_stride_x = std::get<0>(conv_info.stride());
_conv_stride_y = std::get<1>(conv_info.stride());
- _border_size = BorderSize(conv_info.pad_top(), conv_info.pad_right(), conv_info.pad_bottom(), conv_info.pad_left());
+
+ if(data_layout == DataLayout::NHWC)
+ {
+ _border_size = BorderSize(conv_info.pad_left(), 0, conv_info.pad_right(), 0);
+ }
+ else if(data_layout == DataLayout::NCHW)
+ {
+ _border_size = BorderSize(conv_info.pad_top(), conv_info.pad_right(), conv_info.pad_bottom(), conv_info.pad_left());
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not supported");
+ }
_input = input;
_weights = weights;
@@ -285,33 +394,44 @@
std::stringstream kernel_name;
kernel_name << "direct_convolution" << kernel_size << "x" << kernel_size;
+ if(data_layout == DataLayout::NHWC)
+ {
+ kernel_name << "_" << lower_string(string_from_data_layout(data_layout));
+ }
CLBuildOptions build_options;
build_options.add_option_if(_biases != nullptr, std::string("-DHAS_BIAS"));
- if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && (kernel_size <= 5) && (_conv_stride_x == 1)
- && (_conv_stride_y == 1) && (data_type == DataType::F32))
+ const bool run_optimized_for_bifrost = can_run_optimized_kernel_for_bifrost(gpu_target, _conv_stride_x, _conv_stride_y, kernel_size, data_type, data_layout);
+
+ if(run_optimized_for_bifrost)
{
- build_options.add_option(std::string("-DWEIGHTS_DEPTH=" + support::cpp11::to_string(_weights->info()->dimension(2))));
+ build_options.add_option(std::string("-DWEIGHTS_DEPTH=" + support::cpp11::to_string(_weights->info()->dimension(channel_idx))));
kernel_name << "_f32_bifrost";
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name.str(), build_options.options()));
}
else
{
- bool is_quantized_fixed_point = is_data_type_fixed_point(data_type);
- bool is_quantized_asymm = is_data_type_quantized_asymmetric(data_type);
- DataType promoted_type = (is_quantized_fixed_point) ? get_promoted_data_type(data_type) : data_type;
+ bool is_quantized_asymm = is_data_type_quantized_asymmetric(data_type);
build_options.add_option_if(is_quantized_asymm, std::string("-DKERNEL_SIZE=" + support::cpp11::to_string(kernel_size)));
build_options.add_option(std::string("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type)));
build_options.add_option(std::string("-DDATA_SIZE=" + get_data_size_from_data_type(data_type)));
- build_options.add_option(std::string("-DWEIGHTS_DEPTH=" + support::cpp11::to_string(_weights->info()->dimension(2))));
+ build_options.add_option(std::string("-DWEIGHTS_DEPTH=" + support::cpp11::to_string(_weights->info()->dimension(channel_idx))));
build_options.add_option(std::string("-DSTRIDE_X=" + support::cpp11::to_string(_conv_stride_x)));
- build_options.add_option_if(is_quantized_fixed_point,
- std::string("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position())));
- build_options.add_option(std::string("-DDATA_TYPE_PROMOTED=" + get_cl_type_from_data_type(promoted_type)));
-
+ if(data_layout == DataLayout::NHWC)
+ {
+ build_options.add_option(std::string("-DDATA_LAYOUT_NHWC=1"));
+ build_options.add_option(std::string("-DDST_HEIGHT=" + support::cpp11::to_string(_output->info()->dimension(height_idx))));
+ build_options.add_option(std::string("-DDST_WIDTH=" + support::cpp11::to_string(_output->info()->dimension(width_idx))));
+ build_options.add_option(std::string("-DSRC_HEIGHT=" + support::cpp11::to_string(_input->info()->dimension(height_idx))));
+ build_options.add_option(std::string("-DSRC_WIDTH=" + support::cpp11::to_string(_input->info()->dimension(width_idx))));
+ build_options.add_option(std::string("-DPAD_LEFT=" + support::cpp11::to_string(conv_info.pad_left())));
+ build_options.add_option(std::string("-DPAD_TOP=" + support::cpp11::to_string(conv_info.pad_top())));
+ build_options.add_option(std::string("-DSTRIDE_Y=" + support::cpp11::to_string(_conv_stride_y)));
+ }
+ build_options.add_option(std::string("-DDATA_TYPE_PROMOTED=" + get_cl_type_from_data_type(data_type)));
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(is_quantized_asymm ? "direct_convolution_1x1_3x3_5x5_quantized" : kernel_name.str(),
build_options.options()));
@@ -320,7 +440,7 @@
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, gpu_target);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
// Set static kernel arguments
if(is_data_type_quantized_asymmetric(data_type))
@@ -357,9 +477,11 @@
_config_id += "_";
_config_id += support::cpp11::to_string(_conv_stride_y);
_config_id += "_";
- _config_id += support::cpp11::to_string(output->info()->dimension(0));
+ _config_id += support::cpp11::to_string(output->info()->dimension(width_idx));
_config_id += "_";
- _config_id += support::cpp11::to_string(output->info()->dimension(1));
+ _config_id += support::cpp11::to_string(output->info()->dimension(height_idx));
+ _config_id += "_";
+ _config_id += lower_string(string_from_data_layout(data_layout));
}
Status CLDirectConvolutionLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
@@ -382,12 +504,16 @@
win_in.adjust(Window::DimX, -_border_size.left, true);
win_in.adjust(Window::DimY, -_border_size.top, true);
- win_in.set_dimension_step(Window::DimX, window.x().step() * _conv_stride_x);
- win_in.set_dimension_step(Window::DimY, window.y().step() * _conv_stride_y);
- Window slice_in = win_in.first_slice_window_3D();
+ const DataLayout data_layout = _input->info()->data_layout();
+ const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
- unsigned int idx1 = 2 * num_arguments_per_3D_tensor();
+ win_in.set_dimension_step(width_idx, window[width_idx].step() * _conv_stride_x);
+ win_in.set_dimension_step(height_idx, window[height_idx].step() * _conv_stride_y);
+
+ Window slice_in = win_in.first_slice_window_3D();
+ unsigned int idx1 = 2 * num_arguments_per_3D_tensor();
add_3D_tensor_argument(idx1, _weights, slice);
if(_biases != nullptr)
@@ -404,8 +530,7 @@
unsigned int idx = 0;
add_3D_tensor_argument(idx, _input, slice_in);
add_3D_tensor_argument(idx, _output, slice);
-
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_3D(slice) && win_in.slide_window_slice_3D(slice_in));
}
diff --git a/src/core/CL/kernels/CLDirectConvolutionOutputStageKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionOutputStageKernel.cpp
index f23ecf3..5f4dacb 100644
--- a/src/core/CL/kernels/CLDirectConvolutionOutputStageKernel.cpp
+++ b/src/core/CL/kernels/CLDirectConvolutionOutputStageKernel.cpp
@@ -24,11 +24,11 @@
#include "arm_compute/core/CL/kernels/CLDirectConvolutionLayerOutputStageKernel.h"
#include "arm_compute/core/AccessWindowStatic.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include <cstddef>
@@ -41,11 +41,13 @@
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16,
DataType::F32);
if(bias != nullptr)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(bias);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::S32, DataType::F16, DataType::F32);
if(is_data_type_quantized_asymmetric(input->data_type()))
@@ -88,44 +90,29 @@
bool window_changed = false;
unsigned int num_elems_processed_per_iteration = 16 / element_size_from_data_type(input->data_type());
- // Update processed elements when input is S32 (comes from quantization input)
- if(input->data_type() == DataType::S32)
+ // Configure kernel window
+ Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
+
+ // Input window
+ AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
+ window_changed = window_changed || update_window_and_padding(win, input_access);
+
+ // Bias window
+ if(bias != nullptr)
{
- num_elems_processed_per_iteration = 16;
+ AccessWindowStatic bias_access(bias, 0, 0, ceil_to_multiple(bias->dimension(0), num_elems_processed_per_iteration), bias->dimension(1));
+ window_changed = window_changed || update_window_and_padding(win, bias_access);
}
- // Configure kernel window
- Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
- AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
-
+ // Output window
if(output != nullptr && (output->total_size() != 0))
{
AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
-
- if(bias == nullptr)
- {
- window_changed = update_window_and_padding(win, input_access, output_access);
- }
- else
- {
- AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
- window_changed = update_window_and_padding(win, input_access, output_access, bias_access);
- }
-
+ window_changed = window_changed || update_window_and_padding(win, output_access);
output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
}
else
{
- if(bias == nullptr)
- {
- window_changed = update_window_and_padding(win, input_access);
- }
- else
- {
- AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
- window_changed = update_window_and_padding(win, input_access, bias_access);
- }
-
input_access.set_valid_region(win, ValidRegion(Coordinates(), input->tensor_shape()));
}
@@ -163,9 +150,13 @@
_result_shift = result_shift;
_result_offset_after_shift = result_offset_after_shift;
+ const unsigned int num_elems_accessed_per_iteration = 16 / element_size_from_data_type(input->info()->data_type());
+
// Create kernel
CLBuildOptions build_opts;
build_opts.add_option_if(bias != nullptr, "-DHAS_BIAS");
+ build_opts.add_option("-D" + string_from_data_layout(input->info()->data_layout()));
+ build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_accessed_per_iteration));
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("output_stage_quantized", build_opts.options()));
// Set static kernel arguments
@@ -177,13 +168,13 @@
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
}
Status CLDirectConvolutionLayerOutputStageKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), bias->clone().get(), output == nullptr ? nullptr : output->clone().get()).first);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), bias == nullptr ? nullptr : bias->clone().get(), output == nullptr ? nullptr : output->clone().get()).first);
return Status{};
}
@@ -211,7 +202,7 @@
unsigned int idx = 0;
add_3D_tensor_argument(idx, _input, slice);
add_3D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_3D(slice));
}
diff --git a/src/core/CL/kernels/CLErodeKernel.cpp b/src/core/CL/kernels/CLErodeKernel.cpp
index a7aa88f..e56b71a 100644
--- a/src/core/CL/kernels/CLErodeKernel.cpp
+++ b/src/core/CL/kernels/CLErodeKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -61,5 +61,5 @@
output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
diff --git a/src/core/CL/kernels/CLFastCornersKernel.cpp b/src/core/CL/kernels/CLFastCornersKernel.cpp
index 616e41b..782ab7a 100644
--- a/src/core/CL/kernels/CLFastCornersKernel.cpp
+++ b/src/core/CL/kernels/CLFastCornersKernel.cpp
@@ -87,7 +87,7 @@
output_access.set_valid_region(win, input->info()->valid_region(), border_mode == BorderMode::UNDEFINED, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLFastCornersKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -148,7 +148,7 @@
Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
update_window_and_padding(win,
AccessWindowHorizontal(input->info(), 0, num_elems_processed_per_iteration));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLCopyToArrayKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLFillBorderKernel.cpp b/src/core/CL/kernels/CLFillBorderKernel.cpp
index 66504e6..baf6bb6 100644
--- a/src/core/CL/kernels/CLFillBorderKernel.cpp
+++ b/src/core/CL/kernels/CLFillBorderKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -91,10 +91,6 @@
build_opts.emplace(("-DBORDER_SIZE_BOTTOM=" + support::cpp11::to_string(border_size.bottom)));
build_opts.emplace(("-DBORDER_SIZE_LEFT=" + support::cpp11::to_string(border_size.left)));
build_opts.emplace(("-DBORDER_SIZE_RIGHT=" + support::cpp11::to_string(border_size.right)));
- if(is_data_type_fixed_point(tensor->info()->data_type()))
- {
- build_opts.emplace("-DFIXED_POINT_POSITION");
- }
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts));
@@ -125,14 +121,12 @@
case DataType::QASYMM8:
set_constant_border<uint8_t>(idx, constant_border_value);
break;
- case DataType::QS8:
case DataType::S8:
set_constant_border<int8_t>(idx, constant_border_value);
break;
case DataType::U16:
set_constant_border<uint16_t>(idx, constant_border_value);
break;
- case DataType::QS16:
case DataType::S16:
set_constant_border<int16_t>(idx, constant_border_value);
break;
@@ -160,7 +154,7 @@
win.set(Window::DimX, Window::Dimension(0, total_valid_width + valid_height));
win.set(Window::DimY, Window::Dimension(0, 1, 1));
win.use_tensor_dimensions(tensor->info()->tensor_shape(), Window::DimZ);
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLFillBorderKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLFlattenLayerKernel.cpp b/src/core/CL/kernels/CLFlattenLayerKernel.cpp
new file mode 100644
index 0000000..1718914
--- /dev/null
+++ b/src/core/CL/kernels/CLFlattenLayerKernel.cpp
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/CL/kernels/CLFlattenLayerKernel.h"
+
+#include "arm_compute/core/CL/CLHelpers.h"
+#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
+#include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/CL/OpenCL.h"
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/IAccessWindow.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "support/ToolchainSupport.h"
+
+using namespace arm_compute::misc::shape_calculator;
+
+namespace arm_compute
+{
+namespace
+{
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output) // Static argument checks shared by configure() and validate()
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input); // presumably rejects F16 input when the device lacks FP16 support (CLValidate.h) - confirm
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8,
+ DataType::U16, DataType::S16,
+ DataType::U32, DataType::S32,
+ DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); // the output info must exist even if it is still empty
+ // An output with total_size() == 0 is treated as not yet initialized; it gets auto-initialized in validate_and_configure_window()
+ // Checks performed when output is configured
+ if(output->total_size() != 0)
+ {
+ const TensorInfo tensor_info_output = input->clone()->set_tensor_shape(compute_flatten_shape(input)); // expected output: the input info with the flattened shape
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ }
+
+ return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output) // Auto-inits output if empty and returns the input-based window; the Status is always OK
+{
+ // Output tensor auto initialization if not yet initialized
+ auto_init_if_empty(*output, input->clone()->set_tensor_shape(compute_flatten_shape(input)));
+ // No access window is registered with update_window_and_padding(), so no padding is requested on either tensor
+ Window win = calculate_max_window(*input, Steps()); // Flatten does not need paddings
+ // The whole output tensor is marked as valid
+ output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
+
+ return std::make_pair(Status{}, win);
+}
+} // namespace
+
+CLFlattenLayerKernel::CLFlattenLayerKernel() // Tensor pointers stay null until configure() binds them
+ : _input(nullptr), _output(nullptr)
+{
+}
+
+void CLFlattenLayerKernel::configure(const ICLTensor *input, ICLTensor *output) // Validates, builds the "flatten" OpenCL kernel and sets its execution window
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info())); // an empty output info passes here and is auto-initialized below
+
+ _input = input;
+ _output = output;
+ // DATA_TYPE and the source width/height (dimensions 0 and 1) are passed to the OpenCL program as -D build options
+ CLBuildOptions build_opts;
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
+ build_opts.add_option("-DSRC_WIDTH=" + support::cpp11::to_string(input->info()->dimension(0)));
+ build_opts.add_option("-DSRC_HEIGHT=" + support::cpp11::to_string(input->info()->dimension(1)));
+
+ // Create kernel
+ _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("flatten", build_opts.options()));
+
+ // Configure kernel window
+ auto win_config = validate_and_configure_window(input->info(), output->info()); // also auto-initializes output to the flattened shape if it is empty
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+ ICLKernel::configure_internal(win_config.second);
+
+ // Set config_id for enabling LWS tuning
+ _config_id = "flatten"; // key layout: flatten_<dtype>_<in d0>_<in d1>_<in d2>_<out d0>_<out d1>
+ _config_id += "_";
+ _config_id += lower_string(string_from_data_type(input->info()->data_type()));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(input->info()->dimension(0));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(input->info()->dimension(1));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(input->info()->dimension(2));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(output->info()->dimension(0));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(output->info()->dimension(1));
+}
+
+Status CLFlattenLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output) // Static check mirroring configure() without creating a kernel
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get()).first); // clones keep the caller's infos from being auto-initialized
+ return Status{};
+}
+
+void CLFlattenLayerKernel::run(const Window &window, cl::CommandQueue &queue) // One enqueue per 3D input slice, each paired with one 1D output slice
+{
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(ICLKernel::window(), window); // presumably requires 'window' to match the configured window - confirm
+ // The output is walked with its own window built from the output shape, independently of 'window'
+ Window out_window;
+ out_window.use_tensor_dimensions(_output->info()->tensor_shape());
+
+ Window out_slice = out_window.first_slice_window_1D();
+ Window in_slice = window.first_slice_window_3D();
+
+ // Run kernel
+ do
+ {
+ // Set arguments
+ unsigned int idx = 0;
+ add_3D_tensor_argument(idx, _input, in_slice);
+ add_1D_tensor_argument(idx, _output, out_slice);
+ enqueue(queue, *this, in_slice, lws_hint()); // the launch uses the input slice as its window
+ }
+ while(window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_1D(out_slice)); // stops as soon as either window runs out of slices
+}
+} // namespace arm_compute
diff --git a/src/core/CL/kernels/CLFloorKernel.cpp b/src/core/CL/kernels/CLFloorKernel.cpp
index 11f8e33..20e3a3a 100644
--- a/src/core/CL/kernels/CLFloorKernel.cpp
+++ b/src/core/CL/kernels/CLFloorKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,7 +45,7 @@
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Auto initialize output
- auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type());
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
@@ -69,7 +69,7 @@
update_window_and_padding(win, input_access, output_access);
output_access.set_valid_region(win, input->info()->valid_region());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLFloorKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp b/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
index 8f669a9..ae54e77 100644
--- a/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
+++ b/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
@@ -23,15 +23,17 @@
*/
#include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h"
+#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
@@ -40,34 +42,40 @@
namespace
{
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
{
ARM_COMPUTE_RETURN_ERROR_ON(mult_interleave4x4_height < 1);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::U8, DataType::S8,
- DataType::QS16, DataType::U16, DataType::S16, DataType::U32, DataType::S32,
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::U8, DataType::S8,
+ DataType::U16, DataType::S16, DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
if(output->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_interleaved_shape(*input, mult_interleave4x4_height));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_interleaved_shape(*input, mult_interleave4x4_height, reinterpret_input_as_3d));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, int mult_interleave4x4_height)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
{
constexpr unsigned int num_elems_processed_per_iteration_x = 4;
constexpr unsigned int num_elems_processed_per_iteration_y = 4;
const unsigned int num_elems_written_per_iteration = num_elems_processed_per_iteration_x * num_elems_processed_per_iteration_y * mult_interleave4x4_height;
bool window_changed = false;
- // Configure kernel window
- Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- AccessWindowRectangle input_access(input, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
- window_changed = window_changed || update_window_and_padding(win, input_access);
+ TensorInfo tmp_info(*input);
+
+ if(reinterpret_input_as_3d)
+ {
+ // Since the input tensor has to be reinterpreted as 3D and the execute window is based on a 2D interleave,
+ // the window needs to be constructed on the 2D collapsed version of the tensor
+ TensorShape tmp_shape(input->tensor_shape());
+ tmp_shape.collapse(2U, 1U);
+ tmp_info.set_tensor_shape(tmp_shape);
+ }
// Output auto inizialitation if not yet initialized
auto_init_if_empty(*output, input->clone()->set_tensor_shape(compute_interleaved_shape(*input, mult_interleave4x4_height)));
@@ -76,9 +84,22 @@
const float scale_x = 4.0f * static_cast<float>(mult_interleave4x4_height);
const float scale_y = 1.0f / (scale_x);
+ // Note: bottom paddings are calculated manually as the input can be reinterpreted as 3D tensor
+ // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic
+ const int m = reinterpret_input_as_3d ? input->tensor_shape()[1] * input->tensor_shape()[2] : input->tensor_shape()[1];
+ const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
+
+ Window win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ Window win_in = calculate_max_window(*input, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+
+ AccessWindowStatic input_access(input, 0, 0,
+ ceil_to_multiple(input->dimension(0), num_elems_processed_per_iteration_x),
+ input->dimension(1) + bottom_pad);
AccessWindowRectangle output_access(output, 0, 0, num_elems_written_per_iteration, 1, scale_x, scale_y);
- window_changed = window_changed || update_window_and_padding(win, output_access);
- output_access.set_valid_region(win, input->valid_region());
+
+ window_changed = update_window_and_padding(win_in, input_access) || // window used by the execute_window_loop
+ update_window_and_padding(win, output_access); // window used to update the padding requirements of output tensor
+ output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
// Collapse along the Z direction
// This collapse needs to be here in order to tune the Z dimension of LWS
@@ -90,26 +111,31 @@
} // namespace
CLGEMMInterleave4x4Kernel::CLGEMMInterleave4x4Kernel()
- : _input(nullptr), _output(nullptr)
+ : _input(nullptr), _output(nullptr), _reinterpret_input_as_3d(false)
{
}
-void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *output, int mult_interleave4x4_height)
+void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(compute_interleaved_shape(*input->info(), mult_interleave4x4_height)));
+ auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(compute_interleaved_shape(*input->info(), mult_interleave4x4_height, reinterpret_input_as_3d)));
// Perform validate step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), mult_interleave4x4_height));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), mult_interleave4x4_height, reinterpret_input_as_3d));
- _input = input;
- _output = output;
+ _input = input;
+ _output = output;
+ _reinterpret_input_as_3d = reinterpret_input_as_3d;
// Create build options
CLBuildOptions build_opts;
build_opts.add_option("-DMULT_INTERLEAVE4X4_HEIGHT=" + support::cpp11::to_string(mult_interleave4x4_height));
+ build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
+ build_opts.add_option_if(_reinterpret_input_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(input->info()->dimension(1)));
+ build_opts.add_option_if(_reinterpret_input_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(input->info()->dimension(2)));
+
switch(input->info()->element_size())
{
case 1:
@@ -129,12 +155,13 @@
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_interleave4x4", build_opts.options()));
// Configure kernel window
- auto win_config = validate_and_configure_window(input->info(), output->info(), mult_interleave4x4_height);
+ auto win_config = validate_and_configure_window(input->info(), output->info(), mult_interleave4x4_height, reinterpret_input_as_3d);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
// Set config_id for enabling LWS tuning
_config_id = "interleave4x4_";
+ _config_id += (_reinterpret_input_as_3d ? "3d_" : "");
_config_id += lower_string(string_from_data_type(input->info()->data_type()));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(0));
@@ -146,10 +173,10 @@
_config_id += support::cpp11::to_string(output->info()->dimension(3));
}
-Status CLGEMMInterleave4x4Kernel::validate(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height)
+Status CLGEMMInterleave4x4Kernel::validate(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mult_interleave4x4_height));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), mult_interleave4x4_height).first);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mult_interleave4x4_height, reinterpret_input_as_3d));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), mult_interleave4x4_height, reinterpret_input_as_3d).first);
return Status{};
}
@@ -170,12 +197,20 @@
*/
Window slice = window.first_slice_window_3D();
+ if(_reinterpret_input_as_3d)
+ {
+ // Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor
+ const unsigned int idx0 = 2 * num_arguments_per_3D_tensor();
+ const unsigned int total_cross_plane_pad = _input->info()->padding().top + _input->info()->padding().bottom;
+ _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
+ }
+
do
{
unsigned int idx = 0;
add_3D_tensor_argument(idx, _input, slice);
add_3D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_3D(slice));
}
diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
index 3f705ac..9adf95f 100644
--- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
@@ -172,7 +172,7 @@
tensor_shape.set(0, is_interleaved_transposed ? reshape_info.n() : input1->info()->dimension(0));
tensor_shape.set(1, is_interleaved_transposed ? reshape_info.m() : input0->info()->dimension(1));
- auto_init_if_empty(*output->info(), tensor_shape, 1, DataType::S32, 1, QuantizationInfo());
+ auto_init_if_empty(*output->info(), tensor_shape, 1, DataType::S32, QuantizationInfo());
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info));
@@ -188,7 +188,9 @@
// Configure kernel window
auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), is_interleaved_transposed, num_elements_processed);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
+
+ const bool is_dot8_supported = dot8_supported(CLKernelLibrary::get().get_device());
// Create build options
CLBuildOptions build_opts;
@@ -206,15 +208,17 @@
build_opts.add_option("-DTRANSPOSE1XW_WIDTH_STEP=" + support::cpp11::to_string(4 * mult_transpose1xW_width));
build_opts.add_option("-DMULT_INTERLEAVE4X4_HEIGHT=" + support::cpp11::to_string(mult_interleave4x4_height));
- kernel_name = "gemmlowp_mm_interleaved_transposed_" + string_from_target(arch_target);
+ kernel_name = "gemmlowp_mm_interleaved_transposed_" + string_from_target(arch_target) + (is_dot8_supported ? "_dot8" : "");
}
else
{
build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(input0->info()->dimension(0)));
build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elements_processed.x()));
build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=" + support::cpp11::to_string(num_elements_processed.y()));
- kernel_name = "gemmlowp_mm_" + string_from_target(arch_target);
+
+ kernel_name = "gemmlowp_mm_" + string_from_target(arch_target) + (is_dot8_supported ? "_dot8" : "");
}
+
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
@@ -269,7 +273,7 @@
add_2D_tensor_argument(idx, _input0, slice);
add_2D_tensor_argument(idx, _input1, slice_b);
add_2D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_2D(slice));
}
diff --git a/src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.cpp b/src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.cpp
index 221a156..aa954ab 100644
--- a/src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.cpp
@@ -159,7 +159,7 @@
vector_sum_row != nullptr ? vector_sum_row->info() : nullptr,
a_offset, b_offset); // NOLINT
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
// Set config_id for enabling LWS tuning
_config_id = "gemmlowp_offset_contribution_";
diff --git a/src/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.cpp b/src/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.cpp
index ff2fc64..875e26d 100644
--- a/src/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -146,7 +146,7 @@
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), (bias != nullptr) ? bias->info() : nullptr, output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
}
void CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -174,4 +174,4 @@
enqueue(queue, *this, slice);
}
while(collapsed.slide_window_slice_3D(slice));
-}
\ No newline at end of file
+}
diff --git a/src/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.cpp b/src/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.cpp
index 151a658..5789113 100644
--- a/src/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -145,7 +145,7 @@
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), (bias != nullptr) ? bias->info() : nullptr, output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
}
void CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp b/src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp
index 6951512..cd26cd1 100644
--- a/src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -121,7 +121,7 @@
// Configure kernel window
auto win_config = validate_and_configure_window_matrix_a_reduction(_input->info(), _output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
}
Status CLGEMMLowpMatrixAReductionKernel::validate(const ITensorInfo *mtx_a, const ITensorInfo *vector_sum_row)
@@ -175,7 +175,7 @@
// Configure kernel window
auto win_config = validate_and_configure_window_matrix_b_reduction(_input->info(), _output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
}
Status CLGEMMLowpMatrixBReductionKernel::validate(const ITensorInfo *mtx_b, const ITensorInfo *vector_sum_col)
diff --git a/src/core/CL/kernels/CLGEMMMatrixAccumulateBiasesKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixAccumulateBiasesKernel.cpp
index d409fdb..2f1f1bf 100644
--- a/src/core/CL/kernels/CLGEMMMatrixAccumulateBiasesKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixAccumulateBiasesKernel.cpp
@@ -26,13 +26,13 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
using namespace arm_compute;
@@ -40,9 +40,9 @@
{
Status validate_arguments(const ITensorInfo *accum, const ITensorInfo *biases)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(accum, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(accum);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(accum, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(biases, accum);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(biases, accum);
ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() != 1);
return Status{};
@@ -52,7 +52,7 @@
unsigned int &num_elems_processed_per_iteration)
{
// Select the vector size to use (8 for Bifrost; 16 for Midgard).
- num_elems_processed_per_iteration = gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) ? 8 : 16;
+ num_elems_processed_per_iteration = gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::G76) ? 8 : 16;
// Configure kernel window
Window win = calculate_max_window(*accum, Steps(num_elems_processed_per_iteration));
@@ -88,14 +88,12 @@
// Configure kernel window
auto win_config = validate_and_configure_window(accum->info(), biases->info(), gpu_target, vector_size);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
// Add build options
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(accum->info()->data_type()));
build_opts.add_option("-DVECTOR_SIZE=" + support::cpp11::to_string(vector_size));
- build_opts.add_option_if(is_data_type_fixed_point(accum->info()->data_type()),
- "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(accum->info()->fixed_point_position()));
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_accumulate_biases", build_opts.options()));
@@ -128,7 +126,7 @@
add_2D_tensor_argument(idx, _accum, accum_slice);
add_1D_tensor_argument(idx, _biases, biases_slice);
- enqueue(queue, *this, accum_slice, _lws_hint);
+ enqueue(queue, *this, accum_slice, lws_hint());
}
while(window.slide_window_slice_2D(accum_slice));
}
diff --git a/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp
index 4538812..0c65bb4 100644
--- a/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp
@@ -25,13 +25,12 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
-#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
using namespace arm_compute;
@@ -63,7 +62,8 @@
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
ARM_COMPUTE_UNUSED(input, output, beta);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
@@ -87,19 +87,7 @@
_output = output;
std::ostringstream ma_arguments;
- if(is_data_type_fixed_point(input->info()->data_type()))
- {
- ma_arguments << "-DBETA=" << (input->info()->data_type() == DataType::QS8 ?
- sqcvt_qs8_f32(beta, input->info()->fixed_point_position()) :
- sqcvt_qs16_f32(beta, input->info()->fixed_point_position()))
- << " ";
- ma_arguments << "-DFIXED_POINT_POSITION=" << input->info()->fixed_point_position();
- }
- else
- {
- ma_arguments << "-DBETA=" << beta;
- }
-
+ ma_arguments << "-DBETA=" << beta;
std::set<std::string> build_opts;
build_opts.emplace(ma_arguments.str());
@@ -110,7 +98,7 @@
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
}
Status CLGEMMMatrixAdditionKernel::validate(const ITensorInfo *input, const ITensorInfo *output, float beta)
@@ -125,14 +113,14 @@
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
- Window slice = window.first_slice_window_2D();
+ Window slice = window.first_slice_window_3D();
do
{
unsigned int idx = 0;
- add_2D_tensor_argument(idx, _input, slice);
- add_2D_tensor_argument(idx, _output, slice);
+ add_3D_tensor_argument(idx, _input, slice);
+ add_3D_tensor_argument(idx, _output, slice);
enqueue(queue, *this, slice);
}
- while(window.slide_window_slice_2D(slice));
+ while(window.slide_window_slice_3D(slice));
}
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index cc9ae27..8530ed2 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -27,15 +27,14 @@
#include "arm_compute/core/AccessWindowTranspose.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
-#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
@@ -52,22 +51,17 @@
inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 2 && reshape_info.reinterpret_input_as_3d(), "The input1 tensor cannot have more than 2 dimensions if input0 has to be reinterpreted as 3D");
if(!is_interleaved_transposed)
{
ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));
-
- if(output->total_size() != 0)
- {
- ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != output->dimension(0));
- ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != output->dimension(1));
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
- }
}
else
{
@@ -93,14 +87,13 @@
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
+ }
- if(output->total_size() != 0)
- {
- ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != static_cast<size_t>(n));
- ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != static_cast<size_t>(m));
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
- }
+ if(output->total_size() != 0)
+ {
+ const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
}
return Status{};
@@ -112,31 +105,65 @@
{
bool window_changed = false;
Window win{};
+ Window win_out{};
const DataType data_type = input0->data_type();
unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0];
unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1];
+ bool reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d();
+ bool reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 1);
+
+ // If input and output both have to be reinterpreted as 3D tensors (or neither does),
+ // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false.
+ if(reinterpret_input_as_3d == reinterpret_output_as_3d)
+ {
+ reinterpret_input_as_3d = false;
+ reinterpret_output_as_3d = false;
+ }
// Output tensor auto inizialitation if not yet initialized
auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info)));
+ TensorInfo tmp_info(*output);
+
+ if(reinterpret_output_as_3d)
+ {
+ // Since the output tensor has to be reinterpreted as 3D and the execute window is based on a 2D GEMM,
+ // the window needs to be constructed on the 2D collapsed version of the tensor
+ TensorShape tmp_shape(output->tensor_shape());
+ tmp_shape.collapse(2U, 1U);
+ tmp_info.set_tensor_shape(tmp_shape);
+ }
+
if(is_interleaved_transposed)
{
+ // reinterpret_input_as_3d is not supported if is_interleaved_transposed is set
+ ARM_COMPUTE_ERROR_ON(reshape_info.reinterpret_input_as_3d());
+
// Configure kernel window
num_elems_processed_per_iteration_x = max_cl_vector_width / data_size_from_type(data_type);
num_elems_processed_per_iteration_y = 4;
- win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ // Note: bottom paddings are calculated manually as the output can be reinterpreted as a 3D tensor.
+ // The only way to set the paddings properly is to set them explicitly through AccessWindowStatic.
+ const int m = reshape_info.m();
+ const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y; // rows needed to round m up to a multiple of the Y step
+
+ win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
AccessWindowRectangle input0_access(input0, 0, 0, num_elems_processed_per_iteration_y, 1, 1.f, 0.25f);
AccessWindowStatic input1_access(input1, 0, 0,
ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x),
ceil_to_multiple(input1->dimension(1), num_elems_processed_per_iteration_y));
- AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
+ AccessWindowStatic output_access(output, 0, 0,
+ ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
+ output->dimension(1) + bottom_pad);
- window_changed = update_window_and_padding(win, input0_access, input1_access, output_access);
+ window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
+ update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor
- output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
+ output_access.set_valid_region(win_out, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
}
else // The input tensors have not been reshaped
{
@@ -144,6 +171,11 @@
num_elems_processed_per_iteration_x = max_cl_vector_width / data_size_from_type(data_type);
num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->dimension(1)), 4);
+ // Note: bottom paddings are calculated manually as the output can be reinterpreted as a 3D tensor.
+ // The only way to set the paddings properly is to set them explicitly through AccessWindowStatic.
+ const int m = reinterpret_input_as_3d ? input0->tensor_shape()[1] * input0->tensor_shape()[2] : input0->tensor_shape()[1];
+ const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y; // rows needed to round m up to a multiple of the Y step
+
// Create kernels according to the architecture, data type and input size.
GPUTarget arch_target = get_arch_from_target(gpu_target);
if(arch_target == GPUTarget::BIFROST && data_type == DataType::F32)
@@ -152,17 +184,21 @@
}
// Configure window
- win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), ceil_to_multiple(input0->dimension(1), num_elems_processed_per_iteration_y));
- AccessWindowStatic input1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
- AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
+ AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), input0->dimension(1) + bottom_pad);
+ AccessWindowStatic input1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
+ AccessWindowStatic output_access(output, 0, 0,
+ ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
+ output->dimension(1) + bottom_pad);
- window_changed = update_window_and_padding(win, input0_access, input1_access, output_access);
+ window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
+ update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor
Coordinates coord;
coord.set_num_dimensions(output->num_dimensions());
- output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape()));
+ output_access.set_valid_region(win_out, ValidRegion(coord, output->tensor_shape()));
}
// Collapse along the Z direction
@@ -177,7 +213,7 @@
} // namespace
CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel()
- : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true)
+ : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false)
{
}
@@ -188,75 +224,49 @@
// Perform validate step
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info));
- _input0 = input0;
- _input1 = input1;
- _output = output;
- _slide_matrix_b = _input1->info()->num_dimensions() >= _input0->info()->num_dimensions();
+ _input0 = input0;
+ _input1 = input1;
+ _output = output;
+ _reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d();
+ _reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 1);
+
+ // If input and output both have to be reinterpreted as 3D tensors (or neither does),
+ // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false.
+ if(_reinterpret_input_as_3d == _reinterpret_output_as_3d)
+ {
+ _reinterpret_input_as_3d = false;
+ _reinterpret_output_as_3d = false;
+ }
+
+ // Check if we need to slide the matrix B
+ const unsigned int num_dimensions_input0 = _reinterpret_input_as_3d ? _input0->info()->num_dimensions() - 1 : _input0->info()->num_dimensions();
+
+ _slide_matrix_b = (_input1->info()->num_dimensions() >= num_dimensions_input0);
const DataType data_type = input0->info()->data_type();
- const int fp_pos = input0->info()->fixed_point_position();
// Get target architecture
GPUTarget gpu_target = get_target();
- // Configure LWS hint
- switch(gpu_target)
- {
- case GPUTarget::MIDGARD:
- case GPUTarget::T600:
- case GPUTarget::T700:
- case GPUTarget::T800:
- if(output->info()->dimension(1) == 196)
- {
- _lws_hint = cl::NDRange(1, 7);
- }
- else
- {
- _lws_hint = cl::NDRange(8, 8);
- }
- break;
- case GPUTarget::G71:
- case GPUTarget::G72:
- case GPUTarget::G51:
- case GPUTarget::G51BIG:
- case GPUTarget::G51LIT:
- case GPUTarget::TNOX:
- if(input1->info()->dimension(1) == 24)
- {
- // LWS optimized for the 11x11 AlexNet convolution on Bifrost.
- _lws_hint = cl::NDRange(2, 2);
- }
- else if(output->info()->dimension(1) == 196)
- {
- _lws_hint = cl::NDRange(1, 7);
- }
- else
- {
- _lws_hint = cl::NDRange(8, 8);
- }
- break;
- default:
- _lws_hint = cl::NullRange;
- }
-
ElementsProcessed num_elements_processed{};
// Configure kernel window
auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, gpu_target, num_elements_processed);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
// Create build options
CLBuildOptions build_opts;
- build_opts.add_option_if(is_data_type_fixed_point(data_type), "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(fp_pos));
// Only define ALPHA when alpha is not 1.0f. This avoids performing unnecessary multiplications.
if(std::abs(1.0f - alpha) > 0.00001f)
{
- build_opts.add_option_if_else(is_data_type_fixed_point(data_type),
- "-DALPHA=" + support::cpp11::to_string((data_type == DataType::QS8 ? sqcvt_qs8_f32(alpha, fp_pos) : sqcvt_qs16_f32(alpha, fp_pos))),
- "-DALPHA=" + float_to_string_with_full_precision(alpha));
+ build_opts.add_option("-DALPHA=" + float_to_string_with_full_precision(alpha));
}
+ build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
+ build_opts.add_option_if(_reinterpret_output_as_3d, "-DREINTERPRET_OUTPUT_AS_3D");
+ build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
+ build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
// Do not slide matrix B if _slide_matrix_b = false
build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2)));
@@ -306,11 +316,7 @@
// The work-group size equal to the Bifrost quad size has been proved to be optimal for these kernels
// via exhaustive autotuning over a range of representative layer configurations.
- _lws_hint = cl::NDRange(4);
- }
- else if(is_data_type_fixed_point(data_type))
- {
- kernel_name = "gemm_mm_" + lower_string(string_from_data_type(data_type));
+ set_lws_hint(cl::NDRange(4));
}
else // (MIDGARD and F32) or (F16)
{
@@ -326,6 +332,8 @@
// Set config_id for enabling LWS tuning
_config_id = "gemm_";
_config_id += (is_interleaved_transposed ? "reshaped_" : "");
+ _config_id += (_reinterpret_input_as_3d ? "3di_" : "");
+ _config_id += (_reinterpret_output_as_3d ? "3do_" : "");
_config_id += lower_string(string_from_data_type(input0->info()->data_type()));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(1));
@@ -375,6 +383,22 @@
slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1));
slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1));
+ if(_reinterpret_input_as_3d)
+ {
+ // Pass the input's total cross-plane padding (top + bottom) to the kernel when the input is reinterpreted as a 3D tensor
+ const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3;
+ const unsigned int total_cross_plane_pad = _input0->info()->padding().top + _input0->info()->padding().bottom;
+ _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
+ }
+
+ if(_reinterpret_output_as_3d)
+ {
+ // Pass the output's total cross-plane padding (top + bottom) to the kernel when the output is reinterpreted as a 3D tensor
+ const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0);
+ const unsigned int total_cross_plane_pad = _output->info()->padding().top + _output->info()->padding().bottom;
+ _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
+ }
+
do
{
Window slice_b = slice;
@@ -392,7 +416,7 @@
_kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input0->info()->strides_in_bytes()[2]));
_kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input1->info()->strides_in_bytes()[2]));
_kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_output->info()->strides_in_bytes()[2]));
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_3D(slice));
}
diff --git a/src/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.cpp
index b2ea95b..11a4292 100644
--- a/src/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
@@ -38,9 +39,9 @@
{
Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);
ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input0->data_type()) && (output->data_type() != DataType::S32));
ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(2) != input1->dimension(1));
@@ -108,14 +109,6 @@
_kernel.setArg<int>(idx++, -_input1->info()->quantization_info().offset);
}
- // Configure the local work size for Bifrost with a value obtained
- // via exhaustive autotuning for the MobileNets tensor shapes.
- const GPUTarget gpu_target = get_target();
- if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
- {
- _lws_hint = cl::NDRange(1, 1, 1);
- }
-
// Configure kernel window
const unsigned int num_elems_read_per_iteration = 4;
@@ -128,7 +121,7 @@
auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
}
Status CLGEMMMatrixVectorMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
@@ -172,7 +165,7 @@
unsigned int idx_2 = num_arguments_per_3D_tensor() + num_arguments_per_2D_tensor();
add_3D_tensor_argument(idx_0, _input0, slice_in);
add_1D_tensor_argument(idx_2, _output, slice_out);
- enqueue(queue, *this, slice_in, _lws_hint);
+ enqueue(queue, *this, slice_in, lws_hint());
}
while(window.slide_window_slice_3D(slice_in) && window.slide_window_slice_3D(slice_out));
}
diff --git a/src/core/CL/kernels/CLGEMMTranspose1xWKernel.cpp b/src/core/CL/kernels/CLGEMMTranspose1xWKernel.cpp
index 05a20fd..5b29905 100644
--- a/src/core/CL/kernels/CLGEMMTranspose1xWKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMTranspose1xWKernel.cpp
@@ -27,12 +27,12 @@
#include "arm_compute/core/AccessWindowTranspose.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
@@ -46,8 +46,9 @@
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, int mult_transpose1xW_width)
{
ARM_COMPUTE_RETURN_ERROR_ON(mult_transpose1xW_width < 1);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::U8, DataType::S8,
- DataType::QS16, DataType::U16, DataType::S16, DataType::U32, DataType::S32,
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::U8, DataType::S8,
+ DataType::U16, DataType::S16, DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
if(output->total_size() != 0)
@@ -55,7 +56,6 @@
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(),
compute_transpose1xW_with_element_size_shape(*input, mult_transpose1xW_width));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
@@ -107,7 +107,7 @@
unsigned int num_elems_processed_per_iteration = 1;
auto win_config = validate_and_configure_window(input->info(), output->info(), num_elems_processed_per_iteration, mult_transpose1xW_width);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
// Create build options
CLBuildOptions build_opts;
@@ -157,7 +157,7 @@
unsigned int idx = 0;
add_3D_tensor_argument(idx, _input, in_slice);
add_3D_tensor_argument(idx, _output, out_slice);
- enqueue(queue, *this, in_slice, _lws_hint);
+ enqueue(queue, *this, in_slice, lws_hint());
}
while(window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_3D(out_slice));
}
diff --git a/src/core/CL/kernels/CLGaussian3x3Kernel.cpp b/src/core/CL/kernels/CLGaussian3x3Kernel.cpp
index e5bc3f9..7e8f313 100644
--- a/src/core/CL/kernels/CLGaussian3x3Kernel.cpp
+++ b/src/core/CL/kernels/CLGaussian3x3Kernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -72,5 +72,5 @@
output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
diff --git a/src/core/CL/kernels/CLGaussianPyramidKernel.cpp b/src/core/CL/kernels/CLGaussianPyramidKernel.cpp
index a4fda36..6b729c8 100644
--- a/src/core/CL/kernels/CLGaussianPyramidKernel.cpp
+++ b/src/core/CL/kernels/CLGaussianPyramidKernel.cpp
@@ -95,7 +95,7 @@
output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLGaussianPyramidHorKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -177,7 +177,7 @@
output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLGaussianPyramidVertKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLHOGDescriptorKernel.cpp b/src/core/CL/kernels/CLHOGDescriptorKernel.cpp
index a15aab1..26c3b81 100644
--- a/src/core/CL/kernels/CLHOGDescriptorKernel.cpp
+++ b/src/core/CL/kernels/CLHOGDescriptorKernel.cpp
@@ -91,7 +91,7 @@
output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLHOGOrientationBinningKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -174,7 +174,7 @@
output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLHOGBlockNormalizationKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLHOGDetectorKernel.cpp b/src/core/CL/kernels/CLHOGDetectorKernel.cpp
index caca498..12bbbaf 100644
--- a/src/core/CL/kernels/CLHOGDetectorKernel.cpp
+++ b/src/core/CL/kernels/CLHOGDetectorKernel.cpp
@@ -110,7 +110,7 @@
update_window_and_padding(win, AccessWindowRectangle(input->info(), 0, 0, num_elems_read_per_iteration, num_rows_read_per_iteration));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLHOGDetectorKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLHarrisCornersKernel.cpp b/src/core/CL/kernels/CLHarrisCornersKernel.cpp
index 1f757fe..5320b6b 100644
--- a/src/core/CL/kernels/CLHarrisCornersKernel.cpp
+++ b/src/core/CL/kernels/CLHarrisCornersKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -106,7 +106,7 @@
ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(), input2->info()->valid_region());
output_access.set_valid_region(win, valid_region, border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLHarrisScoreKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLHistogramKernel.cpp b/src/core/CL/kernels/CLHistogramKernel.cpp
index fa39ce6..ee39c71 100644
--- a/src/core/CL/kernels/CLHistogramKernel.cpp
+++ b/src/core/CL/kernels/CLHistogramKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -107,7 +107,7 @@
update_window_and_padding(win, AccessWindowHorizontal(input->info(), 0, pixels_per_item));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLHistogramKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -197,7 +197,7 @@
win.set(0, Window::Dimension(start_position, _input->info()->dimension(0)));
win.set(1, Window::Dimension(0, _input->info()->dimension(1)));
update_window_and_padding(win, AccessWindowHorizontal(input->info(), 0, 1));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLHistogramBorderKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLIm2ColKernel.cpp b/src/core/CL/kernels/CLIm2ColKernel.cpp
index d04c1dc..0ba0d0e 100644
--- a/src/core/CL/kernels/CLIm2ColKernel.cpp
+++ b/src/core/CL/kernels/CLIm2ColKernel.cpp
@@ -23,323 +23,388 @@
*/
#include "arm_compute/core/CL/kernels/CLIm2ColKernel.h"
+#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/Size2D.h"
+#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "support/ToolchainSupport.h"
#include <cmath>
#include <tuple>
+#include <utility>
using namespace arm_compute;
+using namespace arm_compute::misc::shape_calculator;
namespace
{
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, bool has_bias, const Size2D &dilation)
+struct Im2ColConfiguration
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
+ std::string kernel_name{};
+ std::set<std::string> build_options{};
+ unsigned int num_elems_processed_per_iteration{};
+ bool is_padding_required_nchw{};
+};
+
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const Size2D &dilation,
+ unsigned int num_groups)
+{
+ const unsigned int channel_idx = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::QASYMM8 && has_bias);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
ARM_COMPUTE_RETURN_ERROR_ON((dilation.x() < 1) || (dilation.y() < 1));
+ ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
+ ARM_COMPUTE_RETURN_ERROR_ON(num_groups == 0);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::NHWC && num_groups > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON((input->dimension(channel_idx) % num_groups) != 0);
- // Checks performed when output is configured
- if(output->total_size() != 0)
+ if(output->total_size() > 0)
{
+ const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(compute_im2col_conv_shape(input, kernel_dims, conv_info, has_bias, dilation, num_groups == 1, num_groups));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
}
-} // namespace
-CLIm2ColKernel::CLIm2ColKernel()
- : _input(nullptr), _output(nullptr), _convolved_dims(), _num_elems_processed_per_iteration(1), _run_func(nullptr), _kernel_dims()
-{
-}
-
-void CLIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const Size2D &dilation)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const Size2D &dilation,
+ unsigned int num_elems_processed_per_iteration, bool is_padding_required_nchw, unsigned int num_groups)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
- // Perform validation step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), has_bias, dilation));
+ // Output tensor auto initialization if not yet initialized
+ TensorShape expected_output_shape = compute_im2col_conv_shape(input, kernel_dims, conv_info, has_bias, dilation, num_groups == 1, num_groups);
- _input = input;
- _output = output;
- _kernel_dims = kernel_dims;
+ auto_init_if_empty(*output, input->clone()->set_tensor_shape(expected_output_shape));
- const DataType data_type = input->info()->data_type();
- const GPUTarget gpu_target = get_target();
+ const DataLayout data_layout = input->data_layout();
+ const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const unsigned int input_width = input->dimension(width_idx);
+ const unsigned int input_height = input->dimension(height_idx);
- // Create kernel
- CLBuildOptions build_opts;
- build_opts.add_option(("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type)));
- build_opts.add_option("-DELEMENT_SIZE=" + support::cpp11::to_string(input->info()->element_size()));
- build_opts.add_option_if(has_bias, "-DHAS_BIAS");
- build_opts.add_option_if(is_data_type_fixed_point(data_type), "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position()));
+ // Configure the execute window based on the selected optimal OpenCL kernel
+ bool window_changed = false;
+ Window win;
- int stride_x = 0;
- int stride_y = 0;
-
- std::tie(stride_x, stride_y) = conv_info.stride();
-
- const bool run_img2col_reduced = (output->info()->dimension(0) == (input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2))) && (TensorShape::num_max_dimensions >= 4)
- && (std::equal(input->info()->tensor_shape().cbegin() + 3,
- input->info()->tensor_shape().cend(),
- output->info()->tensor_shape().cbegin() + 1))
- && ((stride_x == 1) && (stride_y == 1) && !conv_info.has_padding());
-
- bool is_optimized_path = false;
-
- _num_elems_processed_per_iteration = 1;
-
- std::string kernel_name;
- if(!run_img2col_reduced)
+ if(data_layout == DataLayout::NHWC)
{
- // Default kernel name
- kernel_name = "im2col_generic_dchw";
+ win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
- _convolved_dims = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1),
- kernel_dims.width, kernel_dims.height,
- conv_info, dilation);
+ const int xin_start = 0;
+ const int xin_end = input->dimension(0) < num_elems_processed_per_iteration ? ceil_to_multiple(input->dimension(0), num_elems_processed_per_iteration) : input->dimension(0);
+ const int yin_start = 0;
+ const int yin_end = input->dimension(1);
- build_opts.add_option("-DKERNEL_WIDTH=" + support::cpp11::to_string(kernel_dims.width));
- build_opts.add_option("-DKERNEL_HEIGHT=" + support::cpp11::to_string(kernel_dims.height));
- build_opts.add_option("-DKERNEL_DEPTH=" + support::cpp11::to_string(input->info()->dimension(2)));
- build_opts.add_option("-DCONVOLVED_WIDTH=" + support::cpp11::to_string(_convolved_dims.first));
- build_opts.add_option("-DCONVOLVED_HEIGHT=" + support::cpp11::to_string(_convolved_dims.second));
- build_opts.add_option("-DSTRIDE_X=" + support::cpp11::to_string(conv_info.stride().first));
- build_opts.add_option("-DSTRIDE_Y=" + support::cpp11::to_string(conv_info.stride().second));
- build_opts.add_option("-DPAD_LEFT=" + support::cpp11::to_string(conv_info.pad_left()));
- build_opts.add_option("-DPAD_TOP=" + support::cpp11::to_string(conv_info.pad_top()));
- build_opts.add_option("-DPAD_RIGHT=" + support::cpp11::to_string(conv_info.pad_right()));
- build_opts.add_option("-DPAD_BOTTOM=" + support::cpp11::to_string(conv_info.pad_bottom()));
- build_opts.add_option("-DSRC_WIDTH=" + support::cpp11::to_string(input->info()->dimension(0)));
- build_opts.add_option("-DSRC_HEIGHT=" + support::cpp11::to_string(input->info()->dimension(1)));
- build_opts.add_option("-DDILATION_X=" + support::cpp11::to_string(dilation.x()));
- build_opts.add_option("-DDILATION_Y=" + support::cpp11::to_string(dilation.y()));
- build_opts.add_option_if_else(is_data_type_quantized(data_type), "-DPAD_VALUE=" + support::cpp11::to_string(input->info()->quantization_info().offset), "-DPAD_VALUE=0");
+ const int xout_start = 0;
+ const int xout_end = input->dimension(0) < num_elems_processed_per_iteration ? ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration) : output->dimension(0);
+ const int yout_start = 0;
+ const int yout_end = output->dimension(1);
- const bool squared_im2col = kernel_dims.width == kernel_dims.height;
+ AccessWindowStatic input_access(input, xin_start, yin_start, xin_end, yin_end);
+ AccessWindowStatic output_access(output, xout_start, yout_start, xout_end, yout_end);
+ window_changed = window_changed || update_window_and_padding(win, input_access, output_access);
+ }
+ else
+ {
+ if(is_padding_required_nchw)
+ {
+ const BorderSize border(conv_info.pad_top(), conv_info.pad_right(), conv_info.pad_bottom(), conv_info.pad_left());
+ win = calculate_max_window(*input,
+ Steps(num_elems_processed_per_iteration * conv_info.stride().first, conv_info.stride().second));
+ AccessWindowStatic input_access(input,
+ -border.left,
+ -border.top,
+ ceil_to_multiple(input_width + border.right, kernel_dims.width * num_elems_processed_per_iteration),
+ input_height + border.bottom);
+ window_changed = window_changed || update_window_and_padding(win, input_access);
+ }
+ else
+ {
+ // For the generic case, CLIm2ColKernel doesn't need padding (we do not read out-of-bounds elements) so
+ // update_window_and_padding() can be skipped
+ win = calculate_max_window(*input, Steps());
+ }
+ }
+ output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
+ // set the Z dimension's step same size as the whole dimension so that one can't split across the Z dimension
+ win.set_dimension_step(Window::DimZ, win[Window::DimZ].end() - win[Window::DimZ].start());
+
+ Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, win);
+}
+
+Im2ColConfiguration configure_opencl_kernel(const ITensorInfo *input, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const Size2D &dilation, unsigned int num_groups)
+{
+ const DataLayout data_layout = input->data_layout();
+ const DataType data_type = input->data_type();
+ const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const unsigned int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
+ const unsigned int input_width = input->dimension(width_idx);
+ const unsigned int input_height = input->dimension(height_idx);
+ const unsigned int input_channel = input->dimension(channel_idx);
+
+ const std::pair<unsigned int, unsigned int> convolved_dims = scaled_dimensions(input_width, input_height, kernel_dims.width, kernel_dims.height, conv_info, dilation);
+
+ // Im2Col configuration
+ std::string kernel_name = "im2col_generic_";
+ CLBuildOptions build_opts;
+ unsigned int num_elems_processed_per_iteration = 1;
+ bool is_padding_required_nchw = false;
+
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
+ build_opts.add_option("-DELEMENT_SIZE=" + support::cpp11::to_string(input->element_size()));
+ build_opts.add_option("-DKERNEL_WIDTH=" + support::cpp11::to_string(kernel_dims.width));
+ build_opts.add_option("-DKERNEL_HEIGHT=" + support::cpp11::to_string(kernel_dims.height));
+ build_opts.add_option("-DCONVOLVED_WIDTH=" + support::cpp11::to_string(convolved_dims.first));
+ build_opts.add_option("-DCONVOLVED_HEIGHT=" + support::cpp11::to_string(convolved_dims.second));
+ build_opts.add_option("-DSTRIDE_X=" + support::cpp11::to_string(conv_info.stride().first));
+ build_opts.add_option("-DSTRIDE_Y=" + support::cpp11::to_string(conv_info.stride().second));
+ build_opts.add_option("-DPAD_LEFT=" + support::cpp11::to_string(conv_info.pad_left()));
+ build_opts.add_option("-DPAD_TOP=" + support::cpp11::to_string(conv_info.pad_top()));
+ build_opts.add_option("-DPAD_RIGHT=" + support::cpp11::to_string(conv_info.pad_right()));
+ build_opts.add_option("-DPAD_BOTTOM=" + support::cpp11::to_string(conv_info.pad_bottom()));
+ build_opts.add_option("-DSRC_WIDTH=" + support::cpp11::to_string(input_width));
+ build_opts.add_option("-DSRC_HEIGHT=" + support::cpp11::to_string(input_height));
+ build_opts.add_option("-DSRC_DEPTH=" + support::cpp11::to_string(input_channel));
+ build_opts.add_option("-DDILATION_X=" + support::cpp11::to_string(dilation.x()));
+ build_opts.add_option("-DDILATION_Y=" + support::cpp11::to_string(dilation.y()));
+ build_opts.add_option_if(num_groups > 1, "-DNUM_GROUPS=" + support::cpp11::to_string(num_groups));
+ build_opts.add_option_if_else(is_data_type_quantized(data_type), "-DPAD_VALUE=" + support::cpp11::to_string(input->quantization_info().offset), "-DPAD_VALUE=0");
+ build_opts.add_option_if(has_bias, "-DHAS_BIAS");
+
+ if(data_layout == DataLayout::NHWC)
+ {
+ num_elems_processed_per_iteration = 2;
+ is_padding_required_nchw = false;
+
+ // Only the 3x3 case is optimized for NHWC
+ if(kernel_dims == Size2D(3U, 3U))
+ {
+ kernel_name = "im2col3x3_";
+ }
+
+ build_opts.add_option("-DVECTOR_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration));
+ build_opts.add_option("-DLAST_ACCESSED=" + support::cpp11::to_string(std::max(static_cast<int>(input_channel - num_elems_processed_per_iteration), 0)));
+ }
+ else
+ {
if(dilation == Size2D(1U, 1U))
{
- if(squared_im2col && !is_data_type_fixed_point(data_type))
+ const bool squared_im2col = kernel_dims.width == kernel_dims.height;
+ if(squared_im2col)
{
- // Check if we can run an optimized im2col
+ // Check if we can run an optimized im2col for NCHW
switch(kernel_dims.width)
{
case 1:
// Optimized im2col1x1 if stride_x = 1 and conv_info.has_padding() = false
if(conv_info.stride().first == 1 && !conv_info.has_padding())
{
- // Set hint for LWS
- _lws_hint = cl::NDRange(1, 1, 8);
- _num_elems_processed_per_iteration = 4;
- is_optimized_path = true;
- kernel_name = "im2col1x1_stridex1_dchw";
+ kernel_name = "im2col1x1_stridex1_";
+ num_elems_processed_per_iteration = 4;
+ is_padding_required_nchw = true;
}
break;
case 3:
- _lws_hint = cl::NDRange(1, 1, 8);
- _num_elems_processed_per_iteration = 1;
- is_optimized_path = true;
- kernel_name = "im2col3x3_dchw";
+ kernel_name = "im2col3x3_";
+ num_elems_processed_per_iteration = 1;
+ is_padding_required_nchw = true;
break;
case 5:
- _num_elems_processed_per_iteration = 1;
- is_optimized_path = true;
- kernel_name = "im2col5x5_dchw";
+ kernel_name = "im2col5x5_";
+ num_elems_processed_per_iteration = 1;
+ is_padding_required_nchw = true;
break;
case 11:
// Optimized im2col11x11 if pad_x = pad_y = 0
if(!conv_info.has_padding())
{
- _num_elems_processed_per_iteration = 1;
- is_optimized_path = true;
- kernel_name = "im2col11x11_padx0_pady0_dchw";
+ kernel_name = "im2col11x11_padx0_pady0_";
+ num_elems_processed_per_iteration = 1;
+ is_padding_required_nchw = true;
}
break;
default:
- is_optimized_path = false;
+ kernel_name = "im2col_generic_";
+ num_elems_processed_per_iteration = 1;
+ is_padding_required_nchw = false;
break;
}
}
else if(kernel_dims.width > 1 && !conv_info.has_padding())
{
- _num_elems_processed_per_iteration = 1;
- kernel_name = "im2col_generic_padx0_pady0_dchw";
+ kernel_name = "im2col_generic_padx0_pady0_";
+ num_elems_processed_per_iteration = 1;
+ is_padding_required_nchw = false;
// Optimized im2col is performed using one or more vector operations with the specified vector size
// and a remainder. For example, for 5x5 convolutions, im2col is performed using vectors of size 4
// and scalars; for 7x7 convolutions, using vectors of size 4 and vectors of size 3.
// Using the vector size of 4 is always safe since OpenCL supports vectors of size 2 and 3.
// Using the vector size of 8, however, may be faster.
- size_t vector_size = 4;
// For 2x2 convolutions, use vectors of size 2. (For 3x3 convolutions, im2col_kernel3x3_padx0_pady0
// is used instead.)
- if(kernel_dims.width < vector_size)
- {
- vector_size = kernel_dims.width;
- }
- // Local work size and vector size optimized for the 11x11 AlexNet convolution on Bifrost.
- if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && kernel_dims.width == 11)
- {
- _lws_hint = cl::NDRange(1, 1, 1);
- vector_size = 8;
- }
+ const size_t vector_size = std::min(static_cast<size_t>(4), kernel_dims.width);
const size_t width_mod_vector_size = kernel_dims.width % vector_size;
build_opts.add_option("-DVECTOR_SIZE=" + support::cpp11::to_string(vector_size));
build_opts.add_option("-DWIDTH_MOD_VECTOR_SIZE=" + support::cpp11::to_string(width_mod_vector_size));
}
}
- _run_func = &CLIm2ColKernel::run_generic;
}
- else
- {
- _num_elems_processed_per_iteration = 1;
- kernel_name = "im2col_reduced_dchw";
- _run_func = &CLIm2ColKernel::run_reduced;
- }
+
+ // Append the data layout to the kernel_name
+ kernel_name += lower_string(string_from_data_layout(data_layout));
+
+ Im2ColConfiguration im2col_config;
+ im2col_config.kernel_name = kernel_name;
+ im2col_config.build_options = build_opts.options();
+ im2col_config.num_elems_processed_per_iteration = num_elems_processed_per_iteration;
+ im2col_config.is_padding_required_nchw = is_padding_required_nchw;
+
+ return im2col_config;
+}
+} // namespace
+
+CLIm2ColKernel::CLIm2ColKernel()
+ : _input(nullptr), _output(nullptr), _convolved_dims(), _num_elems_processed_per_iteration(1), _kernel_dims(), _conv_info(), _num_groups()
+{
+}
+
+void CLIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const Size2D &dilation,
+ unsigned int num_groups)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, dilation, num_groups));
+
+ const DataLayout data_layout = input->info()->data_layout();
+ const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const unsigned int input_width = input->info()->dimension(width_idx);
+ const unsigned int input_height = input->info()->dimension(height_idx);
+
+ // Select and configure the optimal OpenCL kernel to run.
+ // This function returns the OpenCL kernel's name, the arguments to pass at compile time, the number of elements processed per iteration
+ // and the padding requirement flag
+ Im2ColConfiguration im2col_config = configure_opencl_kernel(input->info(), kernel_dims, conv_info, has_bias, dilation, num_groups);
// Create kernel
- _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
+ _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(im2col_config.kernel_name, im2col_config.build_options));
+
+ _input = input;
+ _output = output;
+ _convolved_dims = scaled_dimensions(input_width, input_height, kernel_dims.width, kernel_dims.height, conv_info, dilation);
+ _num_elems_processed_per_iteration = im2col_config.num_elems_processed_per_iteration;
+ _kernel_dims = kernel_dims; // Only needed by the Tuner
+ _conv_info = conv_info; // Only needed by the Tuner
+ _num_groups = num_groups;
// Configure kernel window
- Window win;
- if(is_optimized_path)
- {
- win = calculate_max_window(*input->info(),
- Steps(_num_elems_processed_per_iteration),
- false,
- BorderSize(conv_info.pad_top(), conv_info.pad_right(), conv_info.pad_bottom(), conv_info.pad_left()));
-
- const int x = -conv_info.pad_left();
- const int y = -conv_info.pad_top();
- const int w = kernel_dims.width * _num_elems_processed_per_iteration;
- const int h = kernel_dims.height;
-
- AccessWindowRectangle input_access(input->info(), x, y, w, h);
-
- update_window_and_padding(win, input_access);
- }
- else
- {
- // For the generic case, CLIm2ColKernel doesn't need padding (we do not read out-of-bounds elements) so
- // update_window_and_padding() can be skipped
- win = calculate_max_window(*input->info(), Steps());
- }
-
- output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
- if(!run_img2col_reduced)
- {
- // set the Z dimension's step same size as the whole dimension so that one can't split across the Z dimension
- win.set_dimension_step(Window::DimZ, win[Window::DimZ].end() - win[Window::DimZ].start());
- }
-
- ICLKernel::configure(win);
+ auto win_config = validate_and_configure_window(input->info(), output->info(), kernel_dims, conv_info, has_bias, dilation, im2col_config.num_elems_processed_per_iteration,
+ im2col_config.is_padding_required_nchw, num_groups);
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+ ICLKernel::configure_internal(win_config.second);
// Set config_id for enabling LWS tuning
- _config_id = kernel_name;
+ _config_id = im2col_config.kernel_name;
_config_id += "_";
_config_id += lower_string(string_from_data_type(input->info()->data_type()));
_config_id += "_";
+ _config_id += support::cpp11::to_string(num_groups);
+ _config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(0));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(1));
+ _config_id += "_";
+ _config_id += lower_string(string_from_data_layout(input->info()->data_layout()));
}
-Status CLIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const Size2D &dilation)
+Status CLIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const Size2D &dilation,
+ unsigned int num_groups)
{
- ARM_COMPUTE_UNUSED(kernel_dims);
- ARM_COMPUTE_UNUSED(conv_info);
- ARM_COMPUTE_UNUSED(has_bias);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, has_bias, dilation));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias, dilation, num_groups));
+ Im2ColConfiguration im2col_config = configure_opencl_kernel(input, kernel_dims, conv_info, has_bias, dilation, num_groups);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), kernel_dims, conv_info, has_bias, dilation, im2col_config.num_elems_processed_per_iteration,
+ im2col_config.is_padding_required_nchw, num_groups)
+ .first);
return Status{};
}
void CLIm2ColKernel::run(const Window &window, cl::CommandQueue &queue)
{
- ARM_COMPUTE_ERROR_ON(_run_func == nullptr);
- (this->*_run_func)(window, queue);
-}
-
-void CLIm2ColKernel::run_generic(const Window &window, cl::CommandQueue &queue)
-{
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(ICLKernel::window(), window);
// Get initial windows
+ // Collapse in order to have (SRC_DEPTH * BATCH_SIZE) on the 3rd dimension
Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
- // Change the Z dimension's step back to 1
window_collapsed.set_dimension_step(Window::DimZ, 1);
- Window slice = window_collapsed.first_slice_window_3D();
- Window slice_in = window_collapsed.first_slice_window_3D();
- Window slice_out = window_collapsed.first_slice_window_3D();
+ Window window_output;
+ window_output.use_tensor_dimensions(_output->info()->tensor_shape());
- // Setup slice if stride_x != 0 or stride_y != 0
- if(_convolved_dims.first != _input->info()->dimension(0) || _convolved_dims.second != _input->info()->dimension(1))
+ const Window first_slice_3d = window_collapsed.first_slice_window_3D();
+
+ Window slice = first_slice_3d;
+ Window slice_in = first_slice_3d;
+ Window slice_out = window_output.first_slice_window_2D();
+
+ if(_input->info()->data_layout() == DataLayout::NHWC)
{
- // If the stride_x or stride_y are not 1, the output tensor of matrix multiply (Convolved tensor) will not
- // have the same shape of the im2col input tensor
- // In this case we need to re-compute the window using the shape of the tensor after matrix multiply (convolved_dims)
- slice.set(Window::DimX, Window::Dimension(0, static_cast<int>(_convolved_dims.first), 1));
- slice.set(Window::DimY, Window::Dimension(0, static_cast<int>(_convolved_dims.second), 1));
+ const Window tmp_win = window.collapse_if_possible(ICLKernel::window(), 3);
+ const int num_batches = tmp_win[3].end();
+
+ slice.set(1, Window::Dimension(0, static_cast<int>(_output->info()->tensor_shape()[1]), 1));
+ slice.set(2, Window::Dimension(0, static_cast<int>(num_batches), 1));
+ }
+ else
+ {
+ slice.set(0, Window::Dimension(0, static_cast<int>(ceil_to_multiple(_convolved_dims.first, _num_elems_processed_per_iteration)), _num_elems_processed_per_iteration));
+ slice.set(1, Window::Dimension(0, static_cast<int>(_convolved_dims.second), 1));
+ // Note: In case of NCHW the 3rd dimension is already set collapsing the input window
}
// Setup input slice
- // The first three dimensions of the input are increased by the inner loops
+ // The dimensions of the input are increased within the OpenCL kernel
slice_in.set(Window::DimX, Window::Dimension(0, 0, 0));
slice_in.set(Window::DimY, Window::Dimension(0, 0, 0));
slice_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
// Setup output slice
- slice_out.set(Window::DimX, Window::Dimension(0, _output->info()->dimension(0), _kernel_dims.area()));
- slice_out.set(Window::DimY, Window::Dimension(0, _output->info()->dimension(1), 1));
- slice_out.set(Window::DimZ, Window::Dimension(0, 1, 1));
+ // The dimensions of the output are increased within the OpenCL kernel
+ slice_out.set(Window::DimX, Window::Dimension(0, 0, 0));
+ slice_out.set(Window::DimY, Window::Dimension(0, 0, 0));
+ unsigned int idx = num_arguments_per_3D_tensor() + (_num_groups == 1 ? num_arguments_per_2D_tensor() : num_arguments_per_3D_tensor());
+ _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input->info()->strides_in_bytes()[3]));
+ _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_output->info()->strides_in_bytes()[((_num_groups == 1) ? 2 : 3)]));
do
{
unsigned int idx = 0;
add_3D_tensor_argument(idx, _input, slice_in);
- add_2D_tensor_argument(idx, _output, slice_out);
- _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input->info()->strides_in_bytes()[3]));
- _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_output->info()->strides_in_bytes()[3]));
- enqueue(queue, *this, slice, _lws_hint);
+ if(_num_groups == 1)
+ {
+ add_2D_tensor_argument(idx, _output, slice_out);
+ }
+ else
+ {
+ add_3D_tensor_argument(idx, _output, slice_out);
+ }
+ enqueue(queue, *this, slice, lws_hint());
}
- while(window_collapsed.slide_window_slice_3D(slice) && window_collapsed.slide_window_slice_3D(slice_out) && window_collapsed.slide_window_slice_3D(slice_in));
-}
-
-void CLIm2ColKernel::run_reduced(const Window &window, cl::CommandQueue &queue)
-{
- ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(ICLKernel::window(), window);
-
- Window out_window;
- out_window.use_tensor_dimensions(_output->info()->tensor_shape());
-
- Window out_slice = out_window.first_slice_window_1D();
- Window in_slice = window.first_slice_window_3D();
-
- // Run kernel
- do
- {
- // Set arguments
- unsigned int idx = 0;
- add_3D_tensor_argument(idx, _input, in_slice);
- add_1D_tensor_argument(idx, _output, out_slice);
-
- _kernel.setArg<cl_uint>(idx++, _input->info()->dimension(0));
- _kernel.setArg<cl_uint>(idx++, _input->info()->dimension(1));
- enqueue(queue, *this, in_slice, _lws_hint);
- }
- while(window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_1D(out_slice));
+ while(window_collapsed.slide_window_slice_3D(slice) && window_output.slide_window_slice_2D(slice_out) && window_collapsed.slide_window_slice_3D(slice_in));
}
diff --git a/src/core/CL/kernels/CLIntegralImageKernel.cpp b/src/core/CL/kernels/CLIntegralImageKernel.cpp
index 69ede45..6fb39ff 100644
--- a/src/core/CL/kernels/CLIntegralImageKernel.cpp
+++ b/src/core/CL/kernels/CLIntegralImageKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -60,7 +60,7 @@
output_access.set_valid_region(win, input->info()->valid_region());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
CLIntegralImageVertKernel::CLIntegralImageVertKernel()
@@ -89,7 +89,7 @@
in_out_access.set_valid_region(win, in_out->info()->valid_region());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLIntegralImageVertKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLL2NormalizeLayerKernel.cpp b/src/core/CL/kernels/CLL2NormalizeLayerKernel.cpp
index 3d30350..54ed51e 100644
--- a/src/core/CL/kernels/CLL2NormalizeLayerKernel.cpp
+++ b/src/core/CL/kernels/CLL2NormalizeLayerKernel.cpp
@@ -26,7 +26,6 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/ICLTensor.h"
-#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
@@ -78,7 +77,7 @@
Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
// Output tensor auto initialization if not yet initialized
- auto_init_if_empty(*output, input->tensor_shape(), 1, input->data_type(), input->fixed_point_position());
+ auto_init_if_empty(*output, input->tensor_shape(), 1, input->data_type());
AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
@@ -121,7 +120,7 @@
auto win_config = validate_and_configure_window(_input->info(), _output->info());
ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
- ICLKernel::configure(std::get<1>(win_config));
+ ICLKernel::configure_internal(std::get<1>(win_config));
}
Status CLL2NormalizeLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, unsigned int axis, float epsilon)
diff --git a/src/core/CL/kernels/CLLKTrackerKernel.cpp b/src/core/CL/kernels/CLLKTrackerKernel.cpp
index 078d18e..40ed630 100644
--- a/src/core/CL/kernels/CLLKTrackerKernel.cpp
+++ b/src/core/CL/kernels/CLLKTrackerKernel.cpp
@@ -75,7 +75,7 @@
Window window;
window.set(Window::DimX, Window::Dimension(0, old_points->num_values(), 1));
window.set(Window::DimY, Window::Dimension(0, 1, 1));
- ICLKernel::configure(window);
+ ICLKernel::configure_internal(window);
}
void CLLKTrackerInitKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -104,7 +104,7 @@
Window window;
window.set(Window::DimX, Window::Dimension(0, new_points_internal->num_values(), 1));
window.set(Window::DimY, Window::Dimension(0, 1, 1));
- ICLKernel::configure(window);
+ ICLKernel::configure_internal(window);
}
void CLLKTrackerFinalizeKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -156,7 +156,7 @@
AccessWindowStatic(old_scharr_gy->info(), valid_region.start(0), valid_region.start(1),
valid_region.end(0), valid_region.end(1)));
- ICLKernel::configure(window);
+ ICLKernel::configure_internal(window);
// Initialize required variables
const int level0 = (level == 0) ? 1 : 0;
@@ -232,7 +232,7 @@
AccessWindowStatic(new_input->info(), valid_region.start(0), valid_region.start(1),
valid_region.end(0), valid_region.end(1)));
- ICLKernel::configure(window);
+ ICLKernel::configure_internal(window);
// Initialize required variables
const int level0 = (level == 0) ? 1 : 0;
diff --git a/src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp
index 84f2e0c..ad2f3a4 100644
--- a/src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp
@@ -26,13 +26,13 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include <set>
@@ -51,6 +51,7 @@
Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32);
@@ -89,13 +90,14 @@
_input1 = input1;
_output = output;
+ cl::NDRange lws_hint;
if(output->info()->dimension(1) == 196)
{
- _lws_hint = cl::NDRange(1, 7);
+ lws_hint = cl::NDRange(1, 7);
}
else
{
- _lws_hint = cl::NDRange(8, 8);
+ lws_hint = cl::NDRange(8, 8);
}
std::ostringstream mm_arguments;
@@ -113,7 +115,7 @@
ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
- ICLKernel::configure(std::get<1>(win_config));
+ ICLKernel::configure_internal(std::get<1>(win_config), lws_hint);
}
Status CLLocallyConnectedMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
@@ -141,7 +143,7 @@
add_2D_tensor_argument(idx, _input0, slice);
add_3D_tensor_argument(idx, _input1, slice_matrix_b);
add_2D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_2D(slice));
}
diff --git a/src/core/CL/kernels/CLMagnitudePhaseKernel.cpp b/src/core/CL/kernels/CLMagnitudePhaseKernel.cpp
index c504189..0b34c59 100644
--- a/src/core/CL/kernels/CLMagnitudePhaseKernel.cpp
+++ b/src/core/CL/kernels/CLMagnitudePhaseKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -137,7 +137,7 @@
output_magnitude_access.set_valid_region(win, valid_region);
output_phase_access.set_valid_region(win, valid_region);
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLMagnitudePhaseKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLMeanStdDevKernel.cpp b/src/core/CL/kernels/CLMeanStdDevKernel.cpp
index 1bf831b..0cde9c5 100644
--- a/src/core/CL/kernels/CLMeanStdDevKernel.cpp
+++ b/src/core/CL/kernels/CLMeanStdDevKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,14 +23,15 @@
*/
#include "arm_compute/core/CL/kernels/CLMeanStdDevKernel.h"
+#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include <cmath>
@@ -49,13 +50,24 @@
return _border_size;
}
+Status CLMeanStdDevKernel::validate(const ITensorInfo *input, float *mean, cl::Buffer *global_sum, float *stddev, cl::Buffer *global_sum_squared)
+{
+ ARM_COMPUTE_UNUSED(mean);
+ ARM_COMPUTE_UNUSED(stddev);
+ ARM_COMPUTE_UNUSED(global_sum);
+ ARM_COMPUTE_UNUSED(global_sum_squared);
+ ARM_COMPUTE_RETURN_ERROR_ON_INT64_BASE_ATOMICS_UNSUPPORTED();
+ ARM_COMPUTE_RETURN_ERROR_ON_TENSOR_NOT_2D(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
+
+ return Status{};
+}
+
void CLMeanStdDevKernel::configure(const ICLImage *input, float *mean, cl::Buffer *global_sum, float *stddev, cl::Buffer *global_sum_squared)
{
- ARM_COMPUTE_ERROR_ON_TENSOR_NOT_2D(input);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
- ARM_COMPUTE_ERROR_ON(nullptr == mean);
- ARM_COMPUTE_ERROR_ON(nullptr == global_sum);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, global_sum);
ARM_COMPUTE_ERROR_ON(stddev && nullptr == global_sum_squared);
+ ARM_COMPUTE_ERROR_THROW_ON(CLMeanStdDevKernel::validate(input->info(), mean, global_sum, stddev, global_sum_squared));
_input = input;
_mean = mean;
@@ -94,7 +106,7 @@
AccessWindowRectangle input_access(input->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
update_window_and_padding(win, input_access);
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLMeanStdDevKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLMedian3x3Kernel.cpp b/src/core/CL/kernels/CLMedian3x3Kernel.cpp
index 3b9fb1f..b93179d 100644
--- a/src/core/CL/kernels/CLMedian3x3Kernel.cpp
+++ b/src/core/CL/kernels/CLMedian3x3Kernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -62,5 +62,5 @@
output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
diff --git a/src/core/CL/kernels/CLMinMaxLayerKernel.cpp b/src/core/CL/kernels/CLMinMaxLayerKernel.cpp
index 60dd5e7..fa7b678 100644
--- a/src/core/CL/kernels/CLMinMaxLayerKernel.cpp
+++ b/src/core/CL/kernels/CLMinMaxLayerKernel.cpp
@@ -62,7 +62,7 @@
TensorShape output_shape = compute_min_max_shape(input);
// Output auto initialization if not yet initialized
- auto_init_if_empty(*output, output_shape, 1, input->data_type(), input->fixed_point_position());
+ auto_init_if_empty(*output, output_shape, 1, input->data_type());
const unsigned int num_elems_processed_per_iteration = 1;
@@ -105,7 +105,7 @@
ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
- ICLKernel::configure(std::get<1>(win_config));
+ ICLKernel::configure_internal(std::get<1>(win_config));
}
Status CLMinMaxLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
diff --git a/src/core/CL/kernels/CLMinMaxLocationKernel.cpp b/src/core/CL/kernels/CLMinMaxLocationKernel.cpp
index 5636592..0c7f3bc 100644
--- a/src/core/CL/kernels/CLMinMaxLocationKernel.cpp
+++ b/src/core/CL/kernels/CLMinMaxLocationKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -118,7 +118,7 @@
// Configure kernel window
Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
update_window_and_padding(win, AccessWindowHorizontal(input->info(), 0, ceil_to_multiple(num_elems_processed_per_iteration, 16)));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLMinMaxKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -209,7 +209,7 @@
constexpr unsigned int num_elems_processed_per_iteration = 1;
Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
update_window_and_padding(win, AccessWindowHorizontal(input->info(), 0, num_elems_processed_per_iteration));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLMinMaxLocationKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLNonLinearFilterKernel.cpp b/src/core/CL/kernels/CLNonLinearFilterKernel.cpp
index 6afa582..5e41974 100644
--- a/src/core/CL/kernels/CLNonLinearFilterKernel.cpp
+++ b/src/core/CL/kernels/CLNonLinearFilterKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -94,5 +94,5 @@
output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
diff --git a/src/core/CL/kernels/CLNonMaximaSuppression3x3Kernel.cpp b/src/core/CL/kernels/CLNonMaximaSuppression3x3Kernel.cpp
index 6a96b0e..4e41f0d 100644
--- a/src/core/CL/kernels/CLNonMaximaSuppression3x3Kernel.cpp
+++ b/src/core/CL/kernels/CLNonMaximaSuppression3x3Kernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -68,5 +68,5 @@
output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
diff --git a/src/core/CL/kernels/CLNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLNormalizationLayerKernel.cpp
index df2104a..eb1ad68 100644
--- a/src/core/CL/kernels/CLNormalizationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLNormalizationLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,12 +25,11 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
-#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
using namespace arm_compute;
@@ -39,24 +38,19 @@
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, NormalizationLayerInfo norm_info)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_layout() == DataLayout::NHWC && norm_info.type() == NormType::IN_MAP_2D,
+ "Only Cross-map and 1D In-map normalization is supported for NHWC layout");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(norm_info.norm_size() % 2), "Normalization size should be odd");
- if(is_data_type_fixed_point(input->data_type()))
- {
- ARM_COMPUTE_RETURN_ERROR_ON_VALUE_NOT_REPRESENTABLE_IN_FIXED_POINT(norm_info.beta(), input);
- ARM_COMPUTE_RETURN_ERROR_ON_VALUE_NOT_REPRESENTABLE_IN_FIXED_POINT(norm_info.kappa(), input);
- ARM_COMPUTE_RETURN_ERROR_ON_VALUE_NOT_REPRESENTABLE_IN_FIXED_POINT(norm_info.scale_coeff(), input);
- }
-
// Checks performed when output is configured
if(output->total_size() != 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
@@ -67,14 +61,15 @@
// Output tensor auto initialization if not yet initialized
auto_init_if_empty(*output, *input->clone());
- const unsigned int norm_size = norm_info.norm_size();
- bool is_in_map = norm_info.is_in_map();
+ const unsigned int norm_idx = get_normalization_dimension_index(input->data_layout(), norm_info);
+ const unsigned int norm_size = norm_info.norm_size();
+ bool is_norm_accross_width = norm_idx == 0;
- const unsigned int border_width = is_in_map ? std::min(norm_size / 2, 3U) : 0;
+ const unsigned int border_width = is_norm_accross_width ? std::min(norm_size / 2, 3U) : 0;
const BorderSize border_size = BorderSize(0, border_width);
- const unsigned int num_elems_processed_per_iteration = (is_data_type_fixed_point(input->data_type())) ? 16 : 4;
- const unsigned int num_elems_read_per_iteration = is_in_map ? (num_elems_processed_per_iteration + 2 * (norm_size / 2)) : num_elems_processed_per_iteration;
+ const unsigned int num_elems_processed_per_iteration = 4;
+ const unsigned int num_elems_read_per_iteration = is_norm_accross_width ? (num_elems_processed_per_iteration + 2 * (norm_size / 2)) : num_elems_processed_per_iteration;
Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
@@ -92,7 +87,7 @@
} // namespace
CLNormalizationLayerKernel::CLNormalizationLayerKernel()
- : _input(nullptr), _output(nullptr), _border_size(0), _is_in_map(false)
+ : _input(nullptr), _output(nullptr), _border_size(0), _is_norm_across_width(false)
{
}
@@ -114,18 +109,17 @@
_input = input;
_output = output;
- _is_in_map = norm_info.is_in_map();
- const unsigned int border_width = _is_in_map ? std::min(norm_info.norm_size() / 2, 3U) : 0;
+ const unsigned int norm_idx = get_normalization_dimension_index(input->info()->data_layout(), norm_info);
+ _is_norm_across_width = norm_idx == 0;
+ const unsigned int border_width = _is_norm_across_width ? std::min(norm_info.norm_size() / 2, 3U) : 0;
_border_size = BorderSize(0, border_width);
- const unsigned int num_elems_processed_per_iteration = (is_data_type_fixed_point(input->info()->data_type())) ? 16 : 4;
+ const unsigned int num_elems_processed_per_iteration = 4;
const bool is_in_map_2D = (norm_info.type() == NormType::IN_MAP_2D);
// Set build options
CLBuildOptions build_opts;
build_opts.add_option(("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())));
- build_opts.add_option_if(is_data_type_fixed_point(input->info()->data_type()),
- "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position()));
build_opts.add_option(("-DCOEFF=" + float_to_string_with_full_precision(norm_info.scale_coeff())));
build_opts.add_option(("-DBETA=" + float_to_string_with_full_precision(norm_info.beta())));
build_opts.add_option(("-DKAPPA=" + float_to_string_with_full_precision(norm_info.kappa())));
@@ -135,13 +129,13 @@
build_opts.add_option_if(is_in_map_2D, "-DIN_MAP_2D");
// Create kernel
- std::string kernel_name = _is_in_map ? "normalization_layer_in_map" : "normalization_layer_cross_map";
+ std::string kernel_name = _is_norm_across_width ? "normalization_layer_in_map" : "normalization_layer_cross_map";
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), output->info(), norm_info);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
// Set config_id for enabling LWS tuning
_config_id = "normalization_layer_";
@@ -169,7 +163,7 @@
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
- const int collapsed_dimension = _is_in_map ? Window::DimZ : 4;
+ const int collapsed_dimension = _is_norm_across_width ? Window::DimZ : 4;
Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), collapsed_dimension);
Window slice = window_collapsed.first_slice_window_3D();
@@ -178,7 +172,7 @@
unsigned int idx = 0;
add_3D_tensor_argument(idx, _input, slice);
add_3D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window_collapsed.slide_window_slice_3D(slice));
}
diff --git a/src/core/CL/kernels/CLPermuteKernel.cpp b/src/core/CL/kernels/CLPermuteKernel.cpp
index d20bee1..c6f0f4b 100644
--- a/src/core/CL/kernels/CLPermuteKernel.cpp
+++ b/src/core/CL/kernels/CLPermuteKernel.cpp
@@ -25,6 +25,7 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
@@ -50,8 +51,9 @@
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const PermutationVector &perm)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8,
- DataType::U16, DataType::S16, DataType::QS16,
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8,
+ DataType::U16, DataType::S16,
DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG((perm != PermutationVector{ 2U, 0U, 1U })
@@ -66,7 +68,6 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
}
@@ -119,7 +120,7 @@
coord.set_num_dimensions(output->info()->num_dimensions());
output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
Status CLPermuteKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const PermutationVector &perm)
diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
index f30ba61..4ca2ef8 100644
--- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
+++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
@@ -25,12 +25,12 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include <cmath>
@@ -50,34 +50,24 @@
ARM_COMPUTE_UNUSED(overflow_policy);
ARM_COMPUTE_UNUSED(rounding_policy);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input2);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative.");
const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2);
-
- if(is_data_type_fixed_point(input1->data_type()))
- {
- // All data types must be all QS8 or all QS16
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale != 1, "Unsupported scaling factor for QS8/QS16. Scale must be 1.");
- }
// Validate in case of configured output
if(output->total_size() > 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
"Output can only be U8 if both inputs are U8");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output);
- if(is_data_type_fixed_point(input1->data_type()))
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output);
- }
}
return Status{};
@@ -171,14 +161,6 @@
{
compute_type = "int";
}
- else if(input1->info()->data_type() == DataType::QS8)
- {
- compute_type = "qs8";
- }
- else if(input1->info()->data_type() == DataType::QS16)
- {
- compute_type = "qs16";
- }
else
{
compute_type = "ushort";
@@ -194,10 +176,6 @@
std::set<std::string> build_opts;
build_opts.emplace((overflow_policy == ConvertPolicy::WRAP || is_data_type_float(output->info()->data_type())) ? "-DWRAP" : "-DSATURATE");
build_opts.emplace((rounding_policy == RoundingPolicy::TO_ZERO) ? "-DROUND=_rtz" : "-DROUND=_rte");
- if(is_data_type_fixed_point(input1->info()->data_type()))
- {
- build_opts.emplace("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input1->info()->fixed_point_position()));
- }
build_opts.emplace("-DDATA_TYPE_IN1=" + get_cl_type_from_data_type(input1->info()->data_type()));
build_opts.emplace("-DDATA_TYPE_IN2=" + get_cl_type_from_data_type(input2->info()->data_type()));
build_opts.emplace("-DDATA_TYPE_OUT=" + get_cl_type_from_data_type(output->info()->data_type()));
@@ -219,7 +197,7 @@
_kernel.setArg(idx++, scale);
}
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
}
Status CLPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale,
diff --git a/src/core/CL/kernels/CLPoolingLayerKernel.cpp b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
index 02fa283..df13068 100644
--- a/src/core/CL/kernels/CLPoolingLayerKernel.cpp
+++ b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
@@ -26,13 +26,13 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLKernel.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
@@ -58,10 +58,11 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
DataLayout data_layout = input->data_layout();
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
switch(data_layout)
{
case DataLayout::NCHW:
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
break;
case DataLayout::NHWC:
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
@@ -77,8 +78,7 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
- TensorInfo out_info(TensorInfo(compute_pool_shape(*input, pool_info), 1, output->data_type(), output->fixed_point_position()));
+ TensorInfo out_info(TensorInfo(compute_pool_shape(*input, pool_info), 1, output->data_type()));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &out_info);
}
@@ -154,7 +154,9 @@
num_elems_processed_per_iteration = 8;
win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration));
- AccessWindowRectangle input_access(input, 0, -pool_pad_left, num_elems_processed_per_iteration, pool_size_x);
+ AccessWindowStatic input_access(input,
+ 0, -1,
+ ceil_to_multiple(input->dimension(0), num_elems_processed_per_iteration), input->dimension(1));
AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
window_changed = update_window_and_padding(win, input_access, output_access);
output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
@@ -207,15 +209,12 @@
_output = output;
_pool_info = pool_info;
- const GPUTarget gpu_target = get_target();
- const DataType data_type = input->info()->data_type();
+ const DataType data_type = input->info()->data_type();
// Set build options
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
build_opts.add_option("-DPOOL_" + string_from_pooling_type(pool_type));
- build_opts.add_option_if(is_data_type_fixed_point(data_type),
- "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position()));
build_opts.add_option("-DSTRIDE_X=" + support::cpp11::to_string(pool_stride_x));
build_opts.add_option("-DSTRIDE_Y=" + support::cpp11::to_string(pool_stride_y));
build_opts.add_option("-DPAD_X=" + support::cpp11::to_string(pool_pad_left));
@@ -240,7 +239,7 @@
{
// Check if we have pool3x3 with stride_x less equal than 3. In these cases, run an optimized OpenCL kernel where
// each thread computes 4 output elements
- const bool is_pool3x3_stride_le3 = (pool_size_x == 3) && (pool_size_y == 3) && (pool_stride_x <= 3) && !is_data_type_fixed_point(data_type);
+ const bool is_pool3x3_stride_le3 = (pool_size_x == 3) && (pool_size_y == 3) && (pool_stride_x <= 3);
std::string kernel_name = ((is_pool3x3_stride_le3) ? "pooling_layer_optimized_" : "pooling_layer_")
+ support::cpp11::to_string(pool_size_x);
@@ -270,22 +269,13 @@
auto win_config = validate_and_configure_window(input->info(), output->info(), pool_info);
ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
- ICLKernel::configure(std::get<1>(win_config));
+ ICLKernel::configure_internal(std::get<1>(win_config));
- // Configure the local work size (hint) from the first two dimensions of the global work size.
- // On Bifrost, this works for up to 35x35xC filters, for which the pooling_layer_3_optimized
- // kernel is launched with gws=(9, 33, C). In any case, the hint will be ignored if it is
- // invalid (e.g. exceeds the maximum workgroup size that the kernel can be launched with).
if(data_layout == DataLayout::NCHW)
{
CLPoolingConfig pooling_config = std::get<2>(win_config);
_num_elems_processed_per_iteration = pooling_config.first;
_border_size = pooling_config.second;
- if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
- {
- cl::NDRange gws = ICLKernel::gws_from_window(std::get<1>(win_config));
- _lws_hint = cl::NDRange(gws[0], gws[1], 1);
- }
}
else
{
@@ -304,6 +294,8 @@
_config_id += support::cpp11::to_string(output->info()->dimension(idx_height));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(idx_channel));
+ _config_id += "_";
+ _config_id += lower_string(string_from_data_layout(input->info()->data_layout()));
}
Status CLPoolingLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const PoolingLayerInfo &pool_info)
@@ -344,7 +336,7 @@
unsigned int idx = 0;
add_3D_tensor_argument(idx, _input, in_slice);
add_3D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window_collapsed.slide_window_slice_3D(slice));
break;
@@ -363,7 +355,7 @@
unsigned int idx = 0;
add_3D_tensor_argument(idx, _input, in_slice);
add_3D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_3D(slice) && window.slide_window_slice_3D(in_slice));
break;
diff --git a/src/core/CL/kernels/CLQuantizationLayerKernel.cpp b/src/core/CL/kernels/CLQuantizationLayerKernel.cpp
index 028e508..9028b0f 100644
--- a/src/core/CL/kernels/CLQuantizationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLQuantizationLayerKernel.cpp
@@ -54,7 +54,7 @@
std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, ITensorInfo *min_max)
{
// Output tensor auto initialization if not yet initialized
- auto_init_if_empty(*output, input->tensor_shape(), 1, DataType::U8, 0);
+ auto_init_if_empty(*output, input->tensor_shape(), 1, DataType::U8);
constexpr unsigned int num_elems_processed_per_iteration = 4;
@@ -96,7 +96,7 @@
ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
- ICLKernel::configure(std::get<1>(win_config));
+ ICLKernel::configure_internal(std::get<1>(win_config));
}
Status CLQuantizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *min_max)
diff --git a/src/core/CL/kernels/CLROIPoolingLayerKernel.cpp b/src/core/CL/kernels/CLROIPoolingLayerKernel.cpp
index a07a424..2367694 100644
--- a/src/core/CL/kernels/CLROIPoolingLayerKernel.cpp
+++ b/src/core/CL/kernels/CLROIPoolingLayerKernel.cpp
@@ -26,13 +26,13 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLArray.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include <cmath>
@@ -49,13 +49,14 @@
void CLROIPoolingLayerKernel::configure(const ICLTensor *input, const ICLROIArray *rois, ICLTensor *output, const ROIPoolingLayerInfo &pool_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, rois, output);
+ ARM_COMPUTE_ERROR_ON_F16_UNSUPPORTED(input);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON((pool_info.pooled_width() == 0) || (pool_info.pooled_height() == 0));
ARM_COMPUTE_ERROR_ON(rois->num_values() == 0);
// Output auto inizialitation if not yet initialized
TensorShape output_shape(pool_info.pooled_width(), pool_info.pooled_height(), input->info()->dimension(2), rois->num_values());
- auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type());
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_ERROR_ON((output->info()->dimension(0) != pool_info.pooled_width()) || (output->info()->dimension(1) != pool_info.pooled_height()));
@@ -100,7 +101,7 @@
update_window_and_padding(window, input_access, output_access);
output_access.set_valid_region(window, ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(window);
+ ICLKernel::configure_internal(window);
}
void CLROIPoolingLayerKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLReductionOperationKernel.cpp b/src/core/CL/kernels/CLReductionOperationKernel.cpp
index 25b756b..bf36ae2 100644
--- a/src/core/CL/kernels/CLReductionOperationKernel.cpp
+++ b/src/core/CL/kernels/CLReductionOperationKernel.cpp
@@ -27,7 +27,6 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/ICLTensor.h"
-#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
@@ -40,12 +39,15 @@
namespace
{
+// OpenCL kernel requires input width to be a power of 2.
+constexpr unsigned int border_val = 64;
+
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
ARM_COMPUTE_UNUSED(op);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() != DataLayout::NCHW);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
@@ -65,12 +67,12 @@
// Output tensor auto initialization if not yet initialized
TensorShape output_shape{ input->tensor_shape() };
output_shape.set(axis, 1);
- auto_init_if_empty(*output, output_shape, 1, input->data_type(), input->fixed_point_position());
+ auto_init_if_empty(*output, output_shape, 1, input->data_type());
const unsigned int num_elems_processed_per_iteration = 16;
Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
- const unsigned int border_width = ((input->dimension(0) % 128) != 0) ? 128 - input->dimension(0) % 128 : 0;
+ const unsigned int border_width = ((input->dimension(0) % border_val) != 0) ? border_val - input->dimension(0) % border_val : 0;
AccessWindowStatic input_access(input, 0, 0, input->dimension(0) + border_width, 1);
AccessWindowHorizontal output_access(output, 0, 1);
@@ -101,23 +103,24 @@
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
const unsigned int num_elems_processed_per_iteration = 16;
- const unsigned int border_width = ((input->info()->dimension(0) % 128) != 0) ? 128 - input->info()->dimension(0) % 128 : 0;
+ const unsigned int width_leftover = input->info()->dimension(0) % border_val;
+ const unsigned int border_width = (width_leftover != 0) ? border_val - width_leftover : 0;
+ const unsigned int num_of_threads = ((input->info()->dimension(0) + border_width) / 16);
_input = input;
_output = output;
_reduction_axis = axis;
_op = op;
- _lws_hint = cl::NDRange(8);
- _border_size = BorderSize(0, border_width, 0, 0);
+
+ // Set the local work-group size based on the input size. If the padded input
+ // width is below 128 (i.e. input width <= 64) fewer than 8 threads are needed.
+ cl::NDRange lws_hint = cl::NDRange(std::min(8U, num_of_threads));
+ _border_size = BorderSize(0, border_width, 0, 0);
// Set build options
std::set<std::string> build_opts;
build_opts.emplace(("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())));
build_opts.emplace(("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration)));
- if(is_data_type_fixed_point(input->info()->data_type()))
- {
- build_opts.emplace("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position()));
- }
switch(op)
{
@@ -139,7 +142,7 @@
ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
- ICLKernel::configure(std::get<1>(win_config));
+ ICLKernel::configure_internal(std::get<1>(win_config), lws_hint);
}
Status CLReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
@@ -164,11 +167,11 @@
Window out_slice = out_window.first_slice_window_2D();
// Reshape window
- const unsigned int border_width = ((in_slice.x().end() % 128) != 0) ? 128 - in_slice.x().end() % 128 : 0;
+ const unsigned int border_width = ((in_slice.x().end() % border_val) != 0) ? border_val - in_slice.x().end() % border_val : 0;
in_slice.set(Window::DimX, Window::Dimension(in_slice.x().start(), in_slice.x().end() + border_width, in_slice.x().step()));
// Set local sums buffer
- unsigned int local_sum_size = _lws_hint[0] * _input->info()->element_size();
+ unsigned int local_sum_size = lws_hint()[0] * _input->info()->element_size();
_kernel.setArg(num_arguments_per_2D_tensor() * 2, local_sum_size, nullptr);
do
@@ -176,7 +179,7 @@
unsigned int idx = 0;
add_2D_tensor_argument(idx, _input, in_slice);
add_2D_tensor_argument(idx, _output, out_slice);
- enqueue(queue, *this, in_slice, _lws_hint);
+ enqueue(queue, *this, in_slice, lws_hint());
}
while(window.slide_window_slice_2D(in_slice) && window.slide_window_slice_2D(out_slice));
}
diff --git a/src/core/CL/kernels/CLRemapKernel.cpp b/src/core/CL/kernels/CLRemapKernel.cpp
index b46bb30..33c5f2d 100644
--- a/src/core/CL/kernels/CLRemapKernel.cpp
+++ b/src/core/CL/kernels/CLRemapKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -83,7 +83,7 @@
output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
// Set static arguments
unsigned int idx = 4 * num_arguments_per_2D_tensor(); //Skip the input and output parameters
diff --git a/src/core/CL/kernels/CLReshapeLayerKernel.cpp b/src/core/CL/kernels/CLReshapeLayerKernel.cpp
index 95f980f..c7efa9a 100644
--- a/src/core/CL/kernels/CLReshapeLayerKernel.cpp
+++ b/src/core/CL/kernels/CLReshapeLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,13 +26,13 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include <string>
@@ -46,12 +46,12 @@
void CLReshapeLayerKernel::configure(const ICLTensor *input, ICLTensor *output)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8,
- DataType::U16, DataType::S16, DataType::QS16,
+ ARM_COMPUTE_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8,
+ DataType::U16, DataType::S16,
DataType::U32, DataType::S32, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_NULLPTR(output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
ARM_COMPUTE_ERROR_ON(input->info()->tensor_shape().total_size() != output->info()->tensor_shape().total_size());
@@ -92,7 +92,7 @@
output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLReshapeLayerKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLScaleKernel.cpp b/src/core/CL/kernels/CLScaleKernel.cpp
index 9b8a582..d56d6f7 100644
--- a/src/core/CL/kernels/CLScaleKernel.cpp
+++ b/src/core/CL/kernels/CLScaleKernel.cpp
@@ -26,94 +26,234 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLKernel.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
-#include "arm_compute/core/Validate.h"
#include <set>
#include <string>
using namespace arm_compute;
+namespace
+{
+inline std::pair<float, float> calculate_scale_factors(const ITensorInfo &input, const ITensorInfo &output)
+{
+ DataLayout data_layout = input.data_layout();
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+
+ // Compute the ratio between source width/height and destination width/height
+ const unsigned int input_width = input.dimension(idx_width);
+ const unsigned int input_height = input.dimension(idx_height);
+ const unsigned int output_width = output.dimension(idx_width);
+ const unsigned int output_height = output.dimension(idx_height);
+
+ float wr = static_cast<float>(input_width) / static_cast<float>(output_width);
+ float hr = static_cast<float>(input_height) / static_cast<float>(output_height);
+
+ return std::make_pair(wr, hr);
+}
+
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, InterpolationPolicy policy)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON(output == input);
+
+ float wr = 0.f;
+ float hr = 0.f;
+ std::tie(wr, hr) = calculate_scale_factors(*input, *output);
+
+ ARM_COMPUTE_RETURN_ERROR_ON(policy == InterpolationPolicy::AREA && (wr > 1.f || hr > 1.f));
+
+ return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy, BorderSize &border)
+{
+ Window win{};
+ bool window_changed{};
+ unsigned int num_elems_processed_per_iteration = 0;
+ DataLayout data_layout = input->data_layout();
+
+ switch(data_layout)
+ {
+ case DataLayout::NCHW:
+ {
+ if(border_mode == BorderMode::UNDEFINED)
+ {
+ border = BorderSize(0);
+ }
+
+ num_elems_processed_per_iteration = 4;
+ // Configure kernel window
+ win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration));
+ const ValidRegion &input_valid_region = input->valid_region();
+
+ // Reads can occur within the valid region of the input
+ AccessWindowStatic input_access(input,
+ input_valid_region.anchor[0] - border.left, input_valid_region.anchor[1] - border.top,
+ input_valid_region.anchor[0] + input_valid_region.shape[0] + border.right,
+ input_valid_region.anchor[1] + input_valid_region.shape[1] + border.bottom);
+ AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+
+ output_access.set_valid_region(win, calculate_valid_region_scale(*(input),
+ output->tensor_shape(),
+ policy,
+ sampling_policy,
+ border_mode == BorderMode::UNDEFINED));
+
+ window_changed = update_window_and_padding(win, input_access, output_access);
+ }
+ break;
+ case DataLayout::NHWC:
+ {
+ num_elems_processed_per_iteration = 1;
+ // Configure kernel window
+ win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration));
+ AccessWindowRectangle input_access(input, -border.left, -border.top, num_elems_processed_per_iteration, num_elems_processed_per_iteration);
+ AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+ window_changed = update_window_and_padding(win, input_access, output_access);
+ output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
+ }
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data layout not supported");
+ }
+
+ Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, win);
+}
+} // namespace
+
BorderSize CLScaleKernel::border_size() const
{
return BorderSize(1);
}
-void CLScaleKernel::configure(const ICLTensor *input, ICLTensor *output, InterpolationPolicy policy, bool border_undefined, SamplingPolicy sampling_policy)
+Status CLScaleKernel::validate(const ITensorInfo *input, const ITensorInfo *output, InterpolationPolicy policy,
+ BorderMode border_mode, SamplingPolicy sampling_policy)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_NULLPTR(output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON(output == input);
+ BorderSize border = BorderSize(1);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, policy));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), policy, border_mode, sampling_policy, border).first);
- _input = input;
- _output = output;
+ return Status{};
+}
+
+const ICLTensor *CLScaleKernel::input() const
+{
+ return _input;
+}
+
+const ICLTensor *CLScaleKernel::output() const
+{
+ return _output;
+}
+
+void CLScaleKernel::configure(const ICLTensor *input, ICLTensor *output, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy)
+{
+ _input = input;
+ _output = output;
+ _interpolationPolicy = policy;
+
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), policy));
+
+ float wr = 0.f;
+ float hr = 0.f;
+ std::tie(wr, hr) = calculate_scale_factors(*input->info(), *output->info());
+
+ DataLayout data_layout = input->info()->data_layout();
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
// Compute the ratio between source width/height and destination width/height
- const auto wr = static_cast<float>(input->info()->dimension(0)) / static_cast<float>(output->info()->dimension(0));
- const auto hr = static_cast<float>(input->info()->dimension(1)) / static_cast<float>(output->info()->dimension(1));
+ const unsigned int input_width = input->info()->dimension(idx_width);
+ const unsigned int input_height = input->info()->dimension(idx_height);
+ const unsigned int output_width = output->info()->dimension(idx_width);
+ const unsigned int output_height = output->info()->dimension(idx_height);
// Compute actual border size
- BorderSize border = border_undefined ? BorderSize(0) : border_size();
+ BorderSize border = border_size();
// Area interpolation behaves as Nearest Neighbour in case of up-sampling
if(policy == InterpolationPolicy::AREA && wr <= 1.f && hr <= 1.f)
{
policy = InterpolationPolicy::NEAREST_NEIGHBOR;
}
- else
- {
- ARM_COMPUTE_ERROR_ON(policy == InterpolationPolicy::AREA);
- }
+
+ // Configure kernel window
+ auto win_config = validate_and_configure_window(input->info(), output->info(), policy, border_mode, sampling_policy, border);
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+ ICLKernel::configure_internal(win_config.second);
// Create kernel
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
build_opts.add_option("-DBORDER_SIZE=" + support::cpp11::to_string(border.right));
+ build_opts.add_option_if(border_mode == BorderMode::REPLICATE, "-DBORDER_MODE_REPLICATE");
build_opts.add_option_if_else(sampling_policy == SamplingPolicy::CENTER, "-DSAMPLING_POLICY_CENTER", "-DSAMPLING_POLICY_TOP_LEFT");
std::string interpolation_name = string_from_interpolation_policy(policy);
std::transform(interpolation_name.begin(), interpolation_name.end(), interpolation_name.begin(), ::tolower);
- std::string kernel_name = "scale_" + interpolation_name;
+ std::string kernel_name = "scale_" + interpolation_name + "_" + lower_string(string_from_data_layout(data_layout));
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
- // Configure kernel window
- constexpr unsigned int num_elems_processed_per_iteration = 4;
-
- Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration));
-
- const ValidRegion &input_valid_region = input->info()->valid_region();
-
- // Reads can occur within the valid region of the input
- AccessWindowStatic input_access(input->info(),
- input_valid_region.anchor[0] - border.left, input_valid_region.anchor[1] - border.top,
- input_valid_region.anchor[0] + input_valid_region.shape[0] + border.right,
- input_valid_region.anchor[1] + input_valid_region.shape[1] + border.bottom);
-
- AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
-
- update_window_and_padding(win, input_access, output_access);
-
- output_access.set_valid_region(win, calculate_valid_region_scale(*(input->info()),
- output->info()->tensor_shape(),
- policy,
- sampling_policy,
- border_undefined));
-
- ICLKernel::configure(win);
+ unsigned int idx = data_layout == DataLayout::NHWC ? 2 * num_arguments_per_3D_tensor() : 2 * num_arguments_per_2D_tensor(); //Skip the input and output parameters
// Set static kernel arguments
- const float scale_x = static_cast<float>(input->info()->dimension(0)) / output->info()->dimension(0);
- const float scale_y = static_cast<float>(input->info()->dimension(1)) / output->info()->dimension(1);
+ const float scale_x = static_cast<float>(input_width) / output_width;
+ const float scale_y = static_cast<float>(input_height) / output_height;
- unsigned int idx = 2 * num_arguments_per_2D_tensor(); //Skip the input and output parameters
- _kernel.setArg<float>(idx++, input->info()->dimension(0));
- _kernel.setArg<float>(idx++, input->info()->dimension(1));
+ _kernel.setArg<float>(idx++, input_width);
+ _kernel.setArg<float>(idx++, input_height);
_kernel.setArg<float>(idx++, scale_x);
_kernel.setArg<float>(idx++, scale_y);
}
+
+void CLScaleKernel::run(const Window &window, cl::CommandQueue &queue)
+{
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
+
+ switch(_input->info()->data_layout())
+ {
+ case DataLayout::NCHW:
+ {
+ Window slice = window.first_slice_window_2D();
+
+ do
+ {
+ unsigned int idx = 0;
+ add_2D_tensor_argument(idx, _input, slice);
+ add_2D_tensor_argument(idx, _output, slice);
+ enqueue(queue, *this, slice, lws_hint());
+ }
+ while(window.slide_window_slice_2D(slice));
+ break;
+ }
+ case DataLayout::NHWC:
+ {
+ Window slice = window.first_slice_window_3D();
+
+ do
+ {
+ unsigned int idx = 0;
+ add_3D_tensor_argument(idx, _input, slice);
+ add_3D_tensor_argument(idx, _output, slice);
+ enqueue(queue, *this, slice, lws_hint());
+ }
+ while(window.slide_window_slice_3D(slice));
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Data layout not supported");
+ }
+}
diff --git a/src/core/CL/kernels/CLScharr3x3Kernel.cpp b/src/core/CL/kernels/CLScharr3x3Kernel.cpp
index 913ef59..5182390 100644
--- a/src/core/CL/kernels/CLScharr3x3Kernel.cpp
+++ b/src/core/CL/kernels/CLScharr3x3Kernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -102,7 +102,7 @@
output_x_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
output_y_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLScharr3x3Kernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLSobel3x3Kernel.cpp b/src/core/CL/kernels/CLSobel3x3Kernel.cpp
index 436aaa4..b4bfe28 100644
--- a/src/core/CL/kernels/CLSobel3x3Kernel.cpp
+++ b/src/core/CL/kernels/CLSobel3x3Kernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -102,7 +102,7 @@
output_x_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
output_y_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLSobel3x3Kernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLSobel5x5Kernel.cpp b/src/core/CL/kernels/CLSobel5x5Kernel.cpp
index 4c0316f..46aa074 100644
--- a/src/core/CL/kernels/CLSobel5x5Kernel.cpp
+++ b/src/core/CL/kernels/CLSobel5x5Kernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -102,7 +102,7 @@
output_x_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
output_y_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLSobel5x5HorKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -201,7 +201,7 @@
output_x_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
output_y_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLSobel5x5VertKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLSobel7x7Kernel.cpp b/src/core/CL/kernels/CLSobel7x7Kernel.cpp
index a477953..0c94e88 100644
--- a/src/core/CL/kernels/CLSobel7x7Kernel.cpp
+++ b/src/core/CL/kernels/CLSobel7x7Kernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -105,7 +105,7 @@
output_x_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
output_y_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLSobel7x7HorKernel::run(const Window &window, cl::CommandQueue &queue)
@@ -204,7 +204,7 @@
output_x_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
output_y_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
void CLSobel7x7VertKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
index 447d6ee..403256b 100644
--- a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
+++ b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
@@ -26,12 +26,12 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
@@ -81,11 +81,11 @@
Status validate_arguments_1DMaxShiftExpSum(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(max, sum, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, max);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, max);
const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input->data_type());
@@ -101,7 +101,6 @@
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
}
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
}
// Checks performed when sum is configured
@@ -116,7 +115,6 @@
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(max, sum);
}
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(max, sum);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(max, sum);
}
return Status{};
@@ -124,10 +122,10 @@
Status validate_arguments_1DNorm(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::S32, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::S32, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(sum, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum);
// Note: output should always have a scale of 1/256 and offset 0
const QuantizationInfo allowed_quantization_info = QuantizationInfo(1.f / 256, 0);
@@ -137,7 +135,6 @@
if(output->total_size() != 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
if(!is_quantized_asymmetric)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
@@ -237,19 +234,15 @@
const DataType dt = input->info()->data_type();
const size_t reduction_dim_size = input->info()->dimension(0);
- auto beta_int = static_cast<int>(lround(beta * (1 << input->info()->fixed_point_position())));
// Set build options
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dt));
- build_opts.add_option_if(is_data_type_fixed_point(dt),
- "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position()));
build_opts.add_option_if(dt == DataType::F16, "-DUSE_F16");
- build_opts.add_option_if(is_data_type_fixed_point(dt) && (beta != 1.0f), "-DBETA=" + support::cpp11::to_string(beta_int));
build_opts.add_option_if(is_data_type_float(dt) && (beta != 1.0f), "-DBETA=" + float_to_string_with_full_precision(beta));
build_opts.add_options_if(is_data_type_quantized_asymmetric(dt), prepare_quantized_softmax_build_options(input->info()->quantization_info().scale, beta).options());
- _lws_hint = cl::NullRange;
+ cl::NDRange lws_hint(cl::NullRange);
std::string kernel_name = is_data_type_quantized_asymmetric(dt) ? std::string("softmax_layer_max_shift_exp_sum_quantized_serial") :
std::string("softmax_layer_max_shift_exp_sum_serial");
ParallelReductionInfo parallel_reduction_info = is_parallel_reduction(reduction_dim_size);
@@ -271,7 +264,7 @@
build_opts.add_option_if((multiple_grid_size != 0) || ((reduction_dim_size % vector_size) != 0), "-DNON_MULTIPLE_OF_GRID_SIZE");
// Setting _lws_hint in this way can also communicate grid_size to CLLogits1DMaxShiftExpSumKernel::run().
// A single workgroup performs reduction in dimension 0 in the parallel case, hence lws[0]==gws[0].
- _lws_hint = cl::NDRange(_grid_size);
+ lws_hint = cl::NDRange(_grid_size);
}
// Create kernel.
@@ -284,7 +277,7 @@
// Configure window
auto win_config = validate_and_configure_window_1DMaxShiftExpSum(input->info(), max->info(), output->info(), sum->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second, lws_hint);
}
Status CLLogits1DMaxShiftExpSumKernel::validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum)
@@ -329,7 +322,7 @@
add_3D_tensor_argument(idx, _max, slice);
add_3D_tensor_argument(idx, _output, slice);
add_3D_tensor_argument(idx, _sum, slice);
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window_collapsed.slide_window_slice_3D(slice));
}
@@ -362,8 +355,6 @@
// Set build options
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
- build_opts.add_option_if(is_data_type_fixed_point(input->info()->data_type()),
- "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position()));
build_opts.add_options_if(is_quantized_asymmetric,
prepare_quantized_softmax_build_options(input->info()->quantization_info().scale, beta).options());
@@ -374,7 +365,7 @@
// Configure window
auto win_config = validate_and_configure_window_1DNorm(input->info(), output->info(), sum->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
}
Status CLLogits1DNormKernel::validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
@@ -403,7 +394,7 @@
add_3D_tensor_argument(idx, _input, slice);
add_3D_tensor_argument(idx, _sum, sum_slice);
add_3D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window_collapsed.slide_window_slice_3D(slice));
}
diff --git a/src/core/CL/kernels/CLTransposeKernel.cpp b/src/core/CL/kernels/CLTransposeKernel.cpp
index b80a612..94e15f3 100644
--- a/src/core/CL/kernels/CLTransposeKernel.cpp
+++ b/src/core/CL/kernels/CLTransposeKernel.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
@@ -55,8 +56,9 @@
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8,
- DataType::U16, DataType::S16, DataType::QS16,
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8,
+ DataType::U16, DataType::S16,
DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
@@ -66,7 +68,6 @@
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
@@ -115,9 +116,8 @@
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info()));
- _input = input;
- _output = output;
- _lws_hint = cl::NDRange(2, 8);
+ _input = input;
+ _output = output;
std::set<std::string> build_opts;
std::ostringstream data_type_in_bytes;
@@ -129,5 +129,5 @@
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second, cl::NDRange(2, 8));
}
diff --git a/src/core/CL/kernels/CLWarpAffineKernel.cpp b/src/core/CL/kernels/CLWarpAffineKernel.cpp
index be095f2..1fae2b1 100644
--- a/src/core/CL/kernels/CLWarpAffineKernel.cpp
+++ b/src/core/CL/kernels/CLWarpAffineKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -42,9 +42,9 @@
namespace
{
-void options_add_matrix(std::set<std::string> &options, const float *matrix, size_t size)
+void options_add_matrix(std::set<std::string> &options, const std::array<float, 9> &matrix)
{
- for(size_t i = 0; i < size; ++i)
+ for(size_t i = 0; i < 6; ++i)
{
std::stringstream mat_str;
mat_str << "-DMAT" << i << "=" << matrix[i] << " ";
@@ -58,7 +58,7 @@
return BorderSize(1);
}
-void CLWarpAffineKernel::configure(const ICLTensor *input, ICLTensor *output, const float *matrix, InterpolationPolicy policy)
+void CLWarpAffineKernel::configure(const ICLTensor *input, ICLTensor *output, const std::array<float, 9> &matrix, InterpolationPolicy policy)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
@@ -69,7 +69,7 @@
// Create build options
std::set<std::string> options;
- options_add_matrix(options, matrix, 6);
+ options_add_matrix(options, matrix);
options.emplace(("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())));
// Create kernel
@@ -98,5 +98,5 @@
output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
diff --git a/src/core/CL/kernels/CLWarpPerspectiveKernel.cpp b/src/core/CL/kernels/CLWarpPerspectiveKernel.cpp
index a47952f..e537aec 100644
--- a/src/core/CL/kernels/CLWarpPerspectiveKernel.cpp
+++ b/src/core/CL/kernels/CLWarpPerspectiveKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -42,9 +42,9 @@
namespace
{
-inline void options_add_matrix(std::set<std::string> &options, const float *matrix, size_t size)
+inline void options_add_matrix(std::set<std::string> &options, const std::array<float, 9> &matrix)
{
- for(size_t i = 0; i < size; ++i)
+ for(size_t i = 0; i < 9; ++i)
{
std::stringstream mat_str;
mat_str << "-DMAT" << i << "=" << matrix[i] << " ";
@@ -58,7 +58,7 @@
return BorderSize(1);
}
-void CLWarpPerspectiveKernel::configure(const ICLTensor *input, ICLTensor *output, const float *matrix, InterpolationPolicy policy)
+void CLWarpPerspectiveKernel::configure(const ICLTensor *input, ICLTensor *output, const std::array<float, 9> &matrix, InterpolationPolicy policy)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
@@ -69,7 +69,7 @@
// Create build options
std::set<std::string> options;
- options_add_matrix(options, matrix, 9);
+ options_add_matrix(options, matrix);
options.emplace(("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())));
// Create kernel
@@ -95,5 +95,5 @@
output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
diff --git a/src/core/CL/kernels/CLWeightsReshapeKernel.cpp b/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
index f5eaa5a..7639a48 100644
--- a/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
+++ b/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
@@ -25,12 +25,12 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
using namespace arm_compute;
@@ -38,16 +38,20 @@
namespace
{
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *biases, const ITensorInfo *output)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *biases, const ITensorInfo *output, unsigned int num_groups)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON(num_groups == 0);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::NHWC && num_groups > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4 && num_groups > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON((input->dimension(3) % num_groups) != 0);
if(biases != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input->data_type()));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, biases);
ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() == 4) && (biases->num_dimensions() != 1));
ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() == 5) && (biases->num_dimensions() != 2));
ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() == 4) && (biases->dimension(0) != input->tensor_shape()[3]));
@@ -57,9 +61,8 @@
// Checks performed when output is configured
if(output->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_weights_reshaped_shape(*input, biases != nullptr));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_weights_reshaped_shape(*input, biases != nullptr, num_groups));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
}
@@ -72,17 +75,17 @@
{
}
-void CLWeightsReshapeKernel::configure(const ICLTensor *input, const ICLTensor *biases, ICLTensor *output)
+void CLWeightsReshapeKernel::configure(const ICLTensor *input, const ICLTensor *biases, ICLTensor *output, unsigned int num_groups)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Output tensor auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(compute_weights_reshaped_shape(*input->info(), (biases != nullptr))));
+ auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(compute_weights_reshaped_shape(*input->info(), (biases != nullptr), num_groups)));
// Perform validation step
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(),
(biases != nullptr) ? biases->info() : nullptr,
- output->info()));
+ output->info(), num_groups));
const DataType data_type = input->info()->data_type();
@@ -93,30 +96,22 @@
// Create build options
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
+ build_opts.add_option("-DNUM_GROUPS=" + support::cpp11::to_string(num_groups));
build_opts.add_option_if(biases != nullptr, "-DHAS_BIAS");
- build_opts.add_option_if(is_data_type_fixed_point(data_type), "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position()));
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("reshape_to_columns", build_opts.options()));
- // Set static arguments
- unsigned int idx = num_arguments_per_3D_tensor() + num_arguments_per_2D_tensor();
- idx += (biases != nullptr) ? num_arguments_per_1D_tensor() : 0;
- _kernel.setArg<cl_uint>(idx++, _input->info()->dimension(0));
- _kernel.setArg<cl_uint>(idx++, _input->info()->dimension(1));
- _kernel.setArg<cl_uint>(idx++, _input->info()->dimension(2));
- _kernel.setArg<cl_uint>(idx++, _input->info()->dimension(3));
-
// Configure window
Window win = calculate_max_window(*input->info(), Steps());
// The CLWeightsReshapeKernel doesn't need padding so update_window_and_padding() can be skipped
output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
- ICLKernel::configure(win);
+ ICLKernel::configure_internal(win);
}
-Status CLWeightsReshapeKernel::validate(const ITensorInfo *input, const ITensorInfo *biases, const ITensorInfo *output)
+Status CLWeightsReshapeKernel::validate(const ITensorInfo *input, const ITensorInfo *biases, const ITensorInfo *output, unsigned int num_groups)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, biases, output));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, biases, output, num_groups));
return Status{};
}
@@ -134,6 +129,14 @@
Window biases_window;
Window biases_slice;
+ unsigned int idx = num_arguments_per_3D_tensor() + num_arguments_per_2D_tensor();
+ idx += (_biases != nullptr) ? num_arguments_per_1D_tensor() : 0;
+ _kernel.setArg<cl_uint>(idx++, _input->info()->dimension(0));
+ _kernel.setArg<cl_uint>(idx++, _input->info()->dimension(1));
+ _kernel.setArg<cl_uint>(idx++, _input->info()->dimension(2));
+ _kernel.setArg<cl_uint>(idx++, _input->info()->dimension(3));
+ _kernel.setArg<cl_uint>(idx++, _output->info()->strides_in_bytes().z());
+
if(_biases != nullptr)
{
biases_window.use_tensor_dimensions(_biases->info()->tensor_shape());
diff --git a/src/core/CL/kernels/CLWidthConcatenateLayerKernel.cpp b/src/core/CL/kernels/CLWidthConcatenateLayerKernel.cpp
index b8bce38..e5ab8d2 100644
--- a/src/core/CL/kernels/CLWidthConcatenateLayerKernel.cpp
+++ b/src/core/CL/kernels/CLWidthConcatenateLayerKernel.cpp
@@ -25,6 +25,7 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
@@ -32,7 +33,6 @@
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
@@ -59,10 +59,10 @@
Status validate_arguments(const ITensorInfo *input, unsigned int width_offset, const ITensorInfo *output)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8, DataType::U16, DataType::S16, DataType::QS16, DataType::F16, DataType::U32,
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::U16, DataType::S16, DataType::F16, DataType::U32,
DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) + width_offset > output->dimension(0));
for(size_t i = 1; i < Coordinates::num_max_dimensions; ++i)
@@ -102,20 +102,16 @@
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_underlying_cl_type_from_data_type(input->info()->data_type()));
build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration));
+ build_opts.add_option("-DWIDTH_OFFSET=" + support::cpp11::to_string(_width_offset));
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("concatenate_width", build_opts.options()));
- const int offset_to_first_elements_in_bytes = _width_offset * _output->info()->strides_in_bytes()[0];
-
- unsigned int idx = 2 * num_arguments_per_3D_tensor(); // Skip the input and output parameters
- _kernel.setArg<cl_int>(idx, offset_to_first_elements_in_bytes);
-
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), width_offset, output->info());
ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
- ICLKernel::configure(std::get<1>(win_config));
+ ICLKernel::configure_internal(std::get<1>(win_config));
}
void CLWidthConcatenateLayerKernel::run(const Window &window, cl::CommandQueue &queue)
diff --git a/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp b/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp
index 41b3ac5..818638c 100644
--- a/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp
+++ b/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp
@@ -25,7 +25,6 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
-#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/Helpers.h"
@@ -47,7 +46,6 @@
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() != DataLayout::NCHW);
const Size2D kernel_size = winograd_info.kernel_size;
const Size2D output_tile_size = winograd_info.output_tile_size;
@@ -55,11 +53,7 @@
const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U), "Winograd filter transform only supports 3x3 and 5x5 kernels");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size != Size2D(2U, 2U)
- && output_tile_size != Size2D(4U, 4U),
- "Winograd filter transform only supports 2x2 or 4x4 output tile for 3x3 kernels");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(5U, 5U) && output_tile_size != Size2D(4U, 4U), "Winograd filter transform only supports 4x4 output tile for 5x5 kernels");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!cl_winograd_convolution_layer_supported(output_tile_size, kernel_size, input->data_layout()), "Winograd filter transform not supported");
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(idx_w) != kernel_size.width || input->dimension(idx_h) != kernel_size.height);
ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
@@ -79,10 +73,11 @@
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
- const unsigned int num_elems_processed_per_iteration_x = input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH));
- const unsigned int num_elems_processed_per_iteration_y = input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT));
+ const unsigned int num_elems_processed_per_iteration_x = input->data_layout() == DataLayout::NCHW ? input->dimension(0) : 1;
+ const unsigned int num_elems_processed_per_iteration_y = input->dimension(1);
+ const unsigned int num_elems_read_per_iteration_z = input->data_layout() == DataLayout::NCHW ? 1 : input->dimension(2);
- Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y, num_elems_read_per_iteration_z));
bool window_changed = false;
AccessWindowRectangle input_access(input, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
@@ -111,17 +106,17 @@
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), winograd_info));
- const size_t idx_c = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::CHANNEL);
-
// Set build options
CLBuildOptions build_opts;
- build_opts.add_option("-DNUM_CHANNELS=" + support::cpp11::to_string(input->info()->dimension(idx_c)));
+ build_opts.add_option("-DSRC_DIM_Z=" + support::cpp11::to_string(input->info()->dimension(2)));
+ build_opts.add_option_if(winograd_info.kernel_size.height == 1, "-DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL");
+ build_opts.add_option_if(winograd_info.kernel_size.width == 1, "-DWINOGRAD_FILTER_TRANSFORM_VERTICAL");
const Size2D kernel_size = winograd_info.kernel_size;
const Size2D output_tile_size = winograd_info.output_tile_size;
// Create kernel
- std::string kernel_name = "winograd_filter_transform_" + output_tile_size.to_string() + "_" + kernel_size.to_string() + "_nchw";
+ std::string kernel_name = "winograd_filter_transform_" + output_tile_size.to_string() + "_" + kernel_size.to_string() + "_" + lower_string(string_from_data_layout(input->info()->data_layout()));
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
_input = input;
@@ -130,7 +125,7 @@
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
}
Status CLWinogradFilterTransformKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info)
diff --git a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
index febd22b..c4e472a 100644
--- a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
+++ b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_compute/core/CL/kernels/CLWinogradInputTransformKernel.h"
+#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/ICLTensor.h"
@@ -30,6 +31,7 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/Utils.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "support/ToolchainSupport.h"
@@ -40,17 +42,13 @@
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() != DataLayout::NCHW);
const PadStrideInfo conv_info = winograd_info.convolution_info;
const Size2D output_tile_size = winograd_info.output_tile_size;
const Size2D kernel_size = winograd_info.kernel_size;
ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.stride().first != 1 || conv_info.stride().second != 1, "Winograd input transform only supports unit strides");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U), "Winograd input transform only supports 3x3 and 5x5 kernels");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size != Size2D(2U, 2U)
- && output_tile_size != Size2D(4U, 4U),
- "Winograd input transform only supports 2x2 or 4x4 output tile for 3x3 kernels");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(5U, 5U) && output_tile_size != Size2D(4U, 4U), "Winograd input transform only supports 4x4 output tile for 5x5 kernels");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!cl_winograd_convolution_layer_supported(output_tile_size, kernel_size, input->data_layout()), "Winograd input transform not supported");
+
ARM_COMPUTE_UNUSED(conv_info);
ARM_COMPUTE_UNUSED(output_tile_size);
ARM_COMPUTE_UNUSED(kernel_size);
@@ -71,18 +69,27 @@
{
ARM_COMPUTE_UNUSED(output);
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
- const PadStrideInfo conv_info = winograd_info.convolution_info;
- const Size2D output_tile_size = winograd_info.output_tile_size;
- const Size2D kernel_size = winograd_info.kernel_size;
- const unsigned int num_elems_read_per_iteration_x = output_tile_size.width + kernel_size.width - 1;
- const unsigned int num_elems_read_per_iteration_y = output_tile_size.height + kernel_size.height - 1;
+ bool window_changed = false;
+ Window win = calculate_max_window(*input, Steps(1, 1));
- Window win = calculate_max_window(*input, Steps(1, 1));
+ if(input->data_layout() == DataLayout::NCHW)
+ {
+ const PadStrideInfo conv_info = winograd_info.convolution_info;
+ const Size2D output_tile_size = winograd_info.output_tile_size;
+ const Size2D kernel_size = winograd_info.kernel_size;
- AccessWindowRectangle input_access(input, -conv_info.pad_left(), -conv_info.pad_top(), num_elems_read_per_iteration_x, num_elems_read_per_iteration_y);
+ unsigned int num_elems_read_per_iteration_x = output_tile_size.width + kernel_size.width - 1;
+ unsigned int num_elems_read_per_iteration_y = output_tile_size.height + kernel_size.height - 1;
- bool window_changed = update_window_and_padding(win, input_access);
+ AccessWindowRectangle input_access(input, -conv_info.pad_left(), -conv_info.pad_top(), num_elems_read_per_iteration_x, num_elems_read_per_iteration_y);
+ window_changed = update_window_and_padding(win, input_access);
+ }
+ else
+ {
+ AccessWindowStatic input_access(input, 0, -1, input->dimension(0), input->dimension(1) + 1);
+ window_changed = update_window_and_padding(win, input_access);
+ }
Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
return std::make_pair(err, win);
@@ -108,19 +115,36 @@
const Size2D output_tile_size = winograd_info.output_tile_size;
const Size2D kernel_size = winograd_info.kernel_size;
- // Compute number of elements to process in the X and Y direction
- const int num_elements_x = input->info()->dimension(0) - (kernel_size.width - 1) + conv_info.pad_left() + conv_info.pad_right();
- const int num_elements_y = input->info()->dimension(1) - (kernel_size.height - 1) + conv_info.pad_top() + conv_info.pad_bottom();
+ const size_t idx_w = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::HEIGHT);
- // Check if we need to extend the right or bottom border
- const unsigned int extra_border_right = ((num_elements_x % output_tile_size.width) == 0) ? 0u : static_cast<unsigned int>(output_tile_size.width - 1);
- const unsigned int extra_border_bottom = ((num_elements_y % output_tile_size.height) == 0) ? 0u : static_cast<unsigned int>(output_tile_size.height - 1);
+ // Compute number of elements to process in the X and Y direction
+ const int num_elements_x = input->info()->dimension(idx_w) - (kernel_size.width - 1) + conv_info.pad_left() + conv_info.pad_right();
+ const int num_elements_y = input->info()->dimension(idx_h) - (kernel_size.height - 1) + conv_info.pad_top() + conv_info.pad_bottom();
+
+ if(input->info()->data_layout() == DataLayout::NCHW)
+ {
+ // Check if we need to extend the right or bottom border
+ const unsigned int extra_border_right = ((num_elements_x % output_tile_size.width) == 0) ? 0u : static_cast<unsigned int>(output_tile_size.width - 1);
+ const unsigned int extra_border_bottom = ((num_elements_y % output_tile_size.height) == 0) ? 0u : static_cast<unsigned int>(output_tile_size.height - 1);
+
+ _border_size = BorderSize(conv_info.pad_top(), conv_info.pad_right() + extra_border_right, conv_info.pad_bottom() + extra_border_bottom, conv_info.pad_left());
+ }
+ else
+ {
+ _border_size = BorderSize(1U, 0U, 1U, 0);
+ }
+
+ // Compute the number of output tiles along the x and y direction of size "output_tile_size"
+ const Size2D num_tiles = compute_winograd_convolution_tiles(Size2D(input->info()->dimension(idx_w), input->info()->dimension(idx_h)),
+ kernel_size,
+ output_tile_size,
+ conv_info);
_input = input;
_output = output;
- _border_size = BorderSize(conv_info.pad_top(), conv_info.pad_right() + extra_border_right, conv_info.pad_bottom() + extra_border_bottom, conv_info.pad_left());
- _num_tiles_x = std::ceil(num_elements_x / static_cast<float>(output_tile_size.width));
- _num_tiles_y = std::ceil(num_elements_y / static_cast<float>(output_tile_size.height));
+ _num_tiles_x = num_tiles.width;
+ _num_tiles_y = num_tiles.height;
const TensorShape output_shape = misc::shape_calculator::compute_winograd_input_transform_shape(*input->info(), winograd_info);
@@ -133,29 +157,40 @@
build_opts.add_option("-DNUM_TILES_X=" + support::cpp11::to_string(_num_tiles_x));
build_opts.add_option("-DPAD_LEFT=" + support::cpp11::to_string(conv_info.pad_left()));
build_opts.add_option("-DPAD_TOP=" + support::cpp11::to_string(conv_info.pad_top()));
+ build_opts.add_option("-DOUTPUT_TILE_W=" + support::cpp11::to_string(output_tile_size.width));
+ build_opts.add_option("-DOUTPUT_TILE_H=" + support::cpp11::to_string(output_tile_size.height));
+ build_opts.add_option_if(winograd_info.kernel_size.height == 1, "-DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL");
+ build_opts.add_option_if(winograd_info.kernel_size.width == 1, "-DWINOGRAD_INPUT_TRANSFORM_VERTICAL");
+
+ if(input->info()->data_layout() == DataLayout::NHWC)
+ {
+ build_opts.add_option("-DSRC_DIM_1=" + support::cpp11::to_string(_input->info()->dimension(1)));
+ build_opts.add_option("-DSRC_DIM_2=" + support::cpp11::to_string(_input->info()->dimension(2)));
+ }
// Create kernel
std::string kernel_name = "winograd_input_transform_" + output_tile_size.to_string() + "_" + kernel_size.to_string();
+ // Get the maximum dimension from the tile size
+ const unsigned int tile_max_dim = std::max(output_tile_size.width, output_tile_size.height);
+
// Check optimized kernel if output_dims == 2x2
- if(output_tile_size == Size2D(2U, 2U))
+ if((tile_max_dim == 2) && (input->info()->data_layout() == DataLayout::NCHW))
{
_step_z = (_input->info()->dimension(2) % 2) != 0 ? 1 : 2;
}
- _lws_hint = cl::NDRange(1, 1, 8);
-
// Append stepz and data layout
kernel_name += "_stepz";
kernel_name += support::cpp11::to_string(_step_z);
- kernel_name += "_nchw";
+ kernel_name += "_" + lower_string(string_from_data_layout(input->info()->data_layout()));
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
// Create window and update padding
auto win_config = validate_and_configure_window(input->info(), output->info(), winograd_info);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second, cl::NDRange(1, 1, 8));
_config_id = kernel_name;
_config_id += support::cpp11::to_string(input->info()->dimension(0));
@@ -167,6 +202,8 @@
_config_id += support::cpp11::to_string(conv_info.pad_left());
_config_id += "_";
_config_id += support::cpp11::to_string(conv_info.pad_top());
+ _config_id += "_";
+ _config_id += lower_string(string_from_data_layout(input->info()->data_layout()));
}
Status CLWinogradInputTransformKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info)
@@ -183,12 +220,16 @@
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
- Window slice = window.first_slice_window_3D();
- slice.set(Window::DimX, Window::Dimension(0, _num_tiles_x, 1));
- slice.set(Window::DimY, Window::Dimension(0, _num_tiles_y, 1));
+ const size_t idx_w = get_data_layout_dimension_index(_input->info()->data_layout(), DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(_input->info()->data_layout(), DataLayoutDimension::HEIGHT);
+ const size_t idx_c = get_data_layout_dimension_index(_input->info()->data_layout(), DataLayoutDimension::CHANNEL);
- ARM_COMPUTE_ERROR_ON(((slice.z().end() - slice.z().start()) % _step_z) != 0);
- slice.set(Window::DimZ, Window::Dimension(slice.z().start(), slice.z().end(), _step_z));
+ Window slice = window.first_slice_window_3D();
+ slice.set(idx_w, Window::Dimension(0, _num_tiles_x, 1));
+ slice.set(idx_h, Window::Dimension(0, _num_tiles_y, 1));
+
+ ARM_COMPUTE_ERROR_ON(((slice[idx_c].end() - slice[idx_c].start()) % _step_z) != 0);
+ slice.set(idx_c, Window::Dimension(slice[idx_c].start(), slice[idx_c].end(), _step_z));
do
{
@@ -196,7 +237,7 @@
add_3D_tensor_argument(idx, _input, slice);
add_3D_tensor_argument(idx, _output, slice);
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_3D(slice));
}
diff --git a/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp b/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp
index 5c0a735..fa42596 100644
--- a/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp
+++ b/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp
@@ -48,25 +48,26 @@
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const WinogradInfo &winograd_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON(winograd_info.output_data_layout != DataLayout::NCHW);
+
+ ARM_COMPUTE_RETURN_ERROR_ON(output->data_layout() != winograd_info.output_data_layout);
const PadStrideInfo conv_info = winograd_info.convolution_info;
const Size2D output_tile_size = winograd_info.output_tile_size;
const Size2D kernel_size = winograd_info.kernel_size;
const Size2D input_dimensions = winograd_info.input_dimensions;
+ const unsigned int num_channels = (winograd_info.kernel_size.width + winograd_info.output_tile_size.width - 1) * (winograd_info.kernel_size.height + winograd_info.output_tile_size.height - 1);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U), "Only 3x3 and 5x5 kernels are supported");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size == Size2D(2U, 2U) && input->dimension(2) != 16, "Wrong number of batches");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size == Size2D(4U, 4U) && input->dimension(2) != 36, "Wrong number of batches");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(5U, 5U) && output_tile_size == Size2D(4U, 4U) && input->dimension(2) != 64, "Wrong number of batches");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!cl_winograd_convolution_layer_supported(output_tile_size, kernel_size, winograd_info.output_data_layout), "Winograd output transform not supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->dimension(2) != num_channels, "Wrong number of channels");
// Compute number of elements to process in the X and Y direction
- const int num_elements_x = input_dimensions.width - (kernel_size.width - 1) + conv_info.pad_left() + conv_info.pad_right();
- const int num_elements_y = input_dimensions.height - (kernel_size.height - 1) + conv_info.pad_top() + conv_info.pad_bottom();
- const int num_tiles_x = std::ceil(num_elements_x / static_cast<float>(output_tile_size.width));
- const int num_tiles_y = std::ceil(num_elements_y / static_cast<float>(output_tile_size.height));
+ // Compute the number of output tiles along the x and y direction of size "output_tile_size"
+ const Size2D num_tiles = compute_winograd_convolution_tiles(input_dimensions,
+ kernel_size,
+ output_tile_size,
+ conv_info);
- ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) != static_cast<unsigned int>((num_tiles_x * num_tiles_y)));
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) != static_cast<unsigned int>((num_tiles.area())));
if(bias != nullptr)
{
@@ -95,19 +96,30 @@
Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
bool window_changed = false;
+ int output_static_window_end_x = 0;
+ int output_static_window_end_y = 0;
+
+ if(output->data_layout() == DataLayout::NCHW)
+ {
+ output_static_window_end_x = ceil_to_multiple(output->dimension(0), output_tile_size.width);
+ output_static_window_end_y = ceil_to_multiple(output->dimension(1), output_tile_size.height);
+ }
+ else
+ {
+ output_static_window_end_x = output->dimension(0);
+ output_static_window_end_y = std::max(ceil_to_multiple(output->dimension(1), output_tile_size.width), output->dimension(1) + 1 /* For out of bound reads towards the z axis */);
+ }
+
AccessWindowRectangle input_access(input, 0, 0, num_elems_processed_per_iteration, num_elems_processed_per_iteration);
- AccessWindowStatic output_access(output, 0, 0, ceil_to_multiple(output->dimension(0), output_tile_size.width), ceil_to_multiple(output->dimension(1), output_tile_size.height));
+ AccessWindowStatic output_access(output, 0, 0, output_static_window_end_x, output_static_window_end_y);
+ window_changed = update_window_and_padding(win, input_access, output_access);
+ output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
if(bias != nullptr)
{
AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
- window_changed = update_window_and_padding(win, input_access, bias_access, output_access);
+ window_changed = window_changed || update_window_and_padding(win, bias_access);
}
- else
- {
- window_changed = update_window_and_padding(win, input_access, output_access);
- }
- output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
return std::make_pair(err, win);
@@ -137,22 +149,30 @@
const Size2D kernel_size = winograd_info.kernel_size;
const Size2D output_tile_size = winograd_info.output_tile_size;
const PadStrideInfo conv_info = winograd_info.convolution_info;
- const int num_elements_x = input_dimensions.width - (kernel_size.width - 1) + conv_info.pad_left() + conv_info.pad_right();
- const int num_tiles_x = std::ceil(num_elements_x / static_cast<float>(output_tile_size.width));
+
+ // Compute the number of output tiles along the x and y direction of size "output_tile_size"
+ const Size2D num_tiles = compute_winograd_convolution_tiles(input_dimensions,
+ kernel_size,
+ output_tile_size,
+ conv_info);
// Set build options
CLBuildOptions build_opts;
build_opts.add_option_if(_bias != nullptr, std::string("-DHAS_BIAS"));
- build_opts.add_option("-DNUM_TILES_X=" + support::cpp11::to_string(num_tiles_x));
+ build_opts.add_option("-DNUM_TILES_X=" + support::cpp11::to_string(num_tiles.width));
+ build_opts.add_option("-DOUTPUT_TILE_W=" + support::cpp11::to_string(output_tile_size.width));
+ build_opts.add_option("-DOUTPUT_TILE_H=" + support::cpp11::to_string(output_tile_size.height));
+ build_opts.add_option_if(winograd_info.kernel_size.height == 1, "-DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL");
+ build_opts.add_option_if(winograd_info.kernel_size.width == 1, "-DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL");
// Create kernel
- std::string kernel_name = "winograd_output_transform_" + output_tile_size.to_string() + "_" + kernel_size.to_string() + "_nchw";
+ std::string kernel_name = "winograd_output_transform_" + output_tile_size.to_string() + "_" + kernel_size.to_string() + "_" + lower_string(string_from_data_layout(winograd_info.output_data_layout));
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), (bias != nullptr ? bias->info() : nullptr), output->info(), winograd_info.output_tile_size);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure(win_config.second);
+ ICLKernel::configure_internal(win_config.second);
// Set config_id for enabling LWS tuning
_config_id = kernel_name;
@@ -166,6 +186,8 @@
_config_id += support::cpp11::to_string(output->info()->dimension(0));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(1));
+ _config_id += "_";
+ _config_id += lower_string(string_from_data_layout(winograd_info.output_data_layout));
}
Status CLWinogradOutputTransformKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const WinogradInfo &winograd_info)
@@ -198,12 +220,18 @@
add_1D_tensor_argument(idx1, _bias, slice_biases);
}
+ if(_output->info()->data_layout() == DataLayout::NHWC)
+ {
+ unsigned int idx2 = 2 * num_arguments_per_3D_tensor() + ((_bias != nullptr) ? num_arguments_per_1D_tensor() : 0);
+ _kernel.setArg(idx2, static_cast<int>(_output->info()->total_size() - _output->info()->strides_in_bytes().y()));
+ }
+
do
{
unsigned int idx = 0;
add_3D_tensor_argument(idx, _input, slice);
add_3D_tensor_argument(idx, _output, slice_out);
- enqueue(queue, *this, slice, _lws_hint);
+ enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_3D(slice) && window.slide_window_slice_3D(slice_out));
}
diff --git a/src/core/CPP/CPPTypes.cpp b/src/core/CPP/CPPTypes.cpp
index 9c2b41b..e4c3b77 100644
--- a/src/core/CPP/CPPTypes.cpp
+++ b/src/core/CPP/CPPTypes.cpp
@@ -51,6 +51,10 @@
}
}
+unsigned int CPUInfo::get_cpu_num() const
+{
+ return static_cast<unsigned int>(_percpu.size()); // count of _percpu entries
+}
bool CPUInfo::has_fp16() const
{
return _fp16;
diff --git a/src/core/CPP/kernels/CPPPermuteKernel.cpp b/src/core/CPP/kernels/CPPPermuteKernel.cpp
index 5c93f3e..17eaec2 100644
--- a/src/core/CPP/kernels/CPPPermuteKernel.cpp
+++ b/src/core/CPP/kernels/CPPPermuteKernel.cpp
@@ -40,8 +40,8 @@
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const PermutationVector &perm)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8,
- DataType::U16, DataType::S16, DataType::QS16,
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8,
+ DataType::U16, DataType::S16,
DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(perm.num_dimensions() > 4, "Only up to 4D permutation vectors are supported");
@@ -53,7 +53,6 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
diff --git a/src/core/GLES_COMPUTE/kernels/GCActivationLayerKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCActivationLayerKernel.cpp
index 8287823..874c336 100644
--- a/src/core/GLES_COMPUTE/kernels/GCActivationLayerKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCActivationLayerKernel.cpp
@@ -55,11 +55,10 @@
if(output != nullptr)
{
// Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type());
ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
_output = output;
}
diff --git a/src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp
index 9a592df..c745f3f 100644
--- a/src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp
@@ -48,27 +48,23 @@
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var);
ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(mean, var);
if(output->total_size() != 0)
{
ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
if(beta != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, beta);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, beta);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, beta);
}
if(gamma != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, gamma);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, gamma);
}
if(act_info.enabled())
{
@@ -86,7 +82,7 @@
ITensorInfo *beta, ITensorInfo *gamma)
{
// Output tensor auto initialization if not yet initialized
- auto_init_if_empty(*output, input->tensor_shape(), 1, input->data_type(), input->fixed_point_position());
+ auto_init_if_empty(*output, input->tensor_shape(), 1, input->data_type());
unsigned int num_elems_processed_per_iteration = 1;
if(input->data_type() == DataType::F16)
diff --git a/src/core/GLES_COMPUTE/kernels/GCDepthwiseConvolutionLayer3x3Kernel.cpp b/src/core/GLES_COMPUTE/kernels/GCDepthwiseConvolutionLayer3x3Kernel.cpp
index c237409..a0d1876 100644
--- a/src/core/GLES_COMPUTE/kernels/GCDepthwiseConvolutionLayer3x3Kernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCDepthwiseConvolutionLayer3x3Kernel.cpp
@@ -69,8 +69,7 @@
auto_init_if_empty(*output->info(),
output_shape,
1,
- input->info()->data_type(),
- input->info()->fixed_point_position());
+ input->info()->data_type());
ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
ARM_COMPUTE_ERROR_ON(output->info()->dimension(2) != weights->info()->dimension(2));
diff --git a/src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp
index 6b16def..8b0d41f 100644
--- a/src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp
@@ -78,12 +78,11 @@
output_shape.set(2, weights->info()->dimension(3));
// Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type());
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
ARM_COMPUTE_ERROR_ON(!conv_info.padding_is_symmetric());
_conv_stride_x = std::get<0>(conv_info.stride());
diff --git a/src/core/GLES_COMPUTE/kernels/GCGEMMInterleave4x4Kernel.cpp b/src/core/GLES_COMPUTE/kernels/GCGEMMInterleave4x4Kernel.cpp
index 171fbad..efd5747 100644
--- a/src/core/GLES_COMPUTE/kernels/GCGEMMInterleave4x4Kernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCGEMMInterleave4x4Kernel.cpp
@@ -51,7 +51,7 @@
output_shape.set(1, std::ceil(input->info()->dimension(1) / 4.0f));
// Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type());
ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
diff --git a/src/core/GLES_COMPUTE/kernels/GCGEMMMatrixMultiplyKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCGEMMMatrixMultiplyKernel.cpp
index d576c30..8ead05a 100644
--- a/src/core/GLES_COMPUTE/kernels/GCGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCGEMMMatrixMultiplyKernel.cpp
@@ -97,7 +97,6 @@
ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != static_cast<size_t>(n));
ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != static_cast<size_t>(m));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
}
}
diff --git a/src/core/GLES_COMPUTE/kernels/GCGEMMTranspose1xWKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCGEMMTranspose1xWKernel.cpp
index 5d9f9c2..dfbd021 100644
--- a/src/core/GLES_COMPUTE/kernels/GCGEMMTranspose1xWKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCGEMMTranspose1xWKernel.cpp
@@ -49,7 +49,7 @@
output_shape.set(1, static_cast<size_t>(std::ceil((input->info()->dimension(0) / static_cast<float>(transpose_w)))));
// Output tensor auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type());
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
diff --git a/src/core/GLES_COMPUTE/kernels/GCIm2ColKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCIm2ColKernel.cpp
index 6c89616..2197190 100644
--- a/src/core/GLES_COMPUTE/kernels/GCIm2ColKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCIm2ColKernel.cpp
@@ -53,7 +53,6 @@
if(output->total_size() != 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
@@ -253,7 +252,7 @@
if(_input->info()->data_type() == DataType::F16)
{
(dynamic_cast<TensorInfo *>(_input->info()))->init(_input->info()->tensor_shape(), _input->info()->num_channels(), _input->info()->data_type(), _input->info()->strides_in_bytes(), 0,
- _input->info()->total_size(), _input->info()->fixed_point_position());
+ _input->info()->total_size());
}
_kernel.use();
diff --git a/src/core/GLES_COMPUTE/kernels/GCPoolingLayerKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCPoolingLayerKernel.cpp
index 3a0944c..f225ebd 100644
--- a/src/core/GLES_COMPUTE/kernels/GCPoolingLayerKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCPoolingLayerKernel.cpp
@@ -75,7 +75,6 @@
if(output->total_size() != 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
unsigned int pooled_w = 0;
unsigned int pooled_h = 0;
@@ -118,8 +117,7 @@
auto_init(input, output, pooled_w, pooled_h);
- BorderSize border_size = BorderSize(pool_pad_y, pool_pad_x);
- const DataType data_type = input->data_type();
+ BorderSize border_size = BorderSize(pool_pad_y, pool_pad_x);
const int input_width = input->dimension(0);
const int input_height = input->dimension(1);
@@ -131,7 +129,7 @@
{
// Check if we have pool3x3 with stride_x less equal than 3. In these cases, run an optimized OpenGLES kernel where
// each thread computes 4 output elements
- const bool is_pool3x3_stride_le3 = (pool_size == 3) && (pool_stride_x <= 3) && !is_data_type_fixed_point(data_type);
+ const bool is_pool3x3_stride_le3 = (pool_size == 3) && (pool_stride_x <= 3);
int num_elems_read_per_iteration = pool_size;
@@ -261,8 +259,6 @@
_output = output;
_pool_info = pool_info;
- const DataType data_type = input->info()->data_type();
-
// Set build options
std::set<std::string> build_opts;
build_opts.emplace("#define LOCAL_SIZE_X " + support::cpp11::to_string(1));
@@ -293,7 +289,7 @@
{
// Check if we have pool3x3 with stride_x less equal than 3. In these cases, run an optimized OpenGLES kernel where
// each thread computes 4 output elements
- const bool is_pool3x3_stride_le3 = (pool_size == 3) && (pool_stride_x <= 3) && !is_data_type_fixed_point(data_type);
+ const bool is_pool3x3_stride_le3 = (pool_size == 3) && (pool_stride_x <= 3);
std::string kernel_name = "pooling_layer_" + support::cpp11::to_string(pool_size);
if(is_pool3x3_stride_le3)
diff --git a/src/core/GLES_COMPUTE/kernels/GCSoftmaxLayerKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCSoftmaxLayerKernel.cpp
index 040a663..7ae2fc9 100644
--- a/src/core/GLES_COMPUTE/kernels/GCSoftmaxLayerKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCSoftmaxLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -49,7 +49,7 @@
output_shape.set(0, 1);
// Output auto initialization if not yet initialized
- auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type());
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
@@ -110,8 +110,8 @@
ARM_COMPUTE_ERROR_ON_NULLPTR(max, sum, output);
// Output auto initialization if not yet initialized
- auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
- auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type());
+ auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type());
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, max, sum);
ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
@@ -204,10 +204,9 @@
ARM_COMPUTE_ERROR_ON_NULLPTR(sum, output);
// Output auto initialization if not yet initialized
- auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type());
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
_input = input;
diff --git a/src/core/GLES_COMPUTE/kernels/GCTransposeKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCTransposeKernel.cpp
index bda08e4..7248891 100644
--- a/src/core/GLES_COMPUTE/kernels/GCTransposeKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCTransposeKernel.cpp
@@ -49,7 +49,7 @@
output_shape.set(1, h_out);
// Output tensor auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type());
ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
diff --git a/src/core/GPUTarget.cpp b/src/core/GPUTarget.cpp
index 575d858..a14a9c9 100644
--- a/src/core/GPUTarget.cpp
+++ b/src/core/GPUTarget.cpp
@@ -51,9 +51,9 @@
{
return arm_compute::GPUTarget::G51LIT;
}
- else if(version == "TNOX")
+ else if(version == "G76")
{
- return arm_compute::GPUTarget::TNOX;
+ return arm_compute::GPUTarget::G76;
}
else if(version == "TTRX")
{
@@ -106,7 +106,7 @@
{ GPUTarget::G51, "g51" },
{ GPUTarget::G51BIG, "g51big" },
{ GPUTarget::G51LIT, "g51lit" },
- { GPUTarget::TNOX, "tnox" },
+ { GPUTarget::G76, "g76" },
{ GPUTarget::TTRX, "ttrx" },
{ GPUTarget::TBOX, "tbox" }
};
@@ -122,8 +122,8 @@
if(!found_mali)
{
- ARM_COMPUTE_LOG_INFO_MSG_CORE("Can't find valid Mali GPU. Target is set to UNKNOWN.");
- return GPUTarget::UNKNOWN;
+ ARM_COMPUTE_LOG_INFO_MSG_CORE("Can't find valid Mali GPU. Target is set to default.");
+ return GPUTarget::MIDGARD;
}
const char target = name_parts.str(1)[0];
diff --git a/src/core/Helpers.cpp b/src/core/Helpers.cpp
index e336331..c0af3bb 100644
--- a/src/core/Helpers.cpp
+++ b/src/core/Helpers.cpp
@@ -59,6 +59,13 @@
++n;
}
+ if(anchor.num_dimensions() > 2)
+ {
+ window.set(2, Window::Dimension(anchor[2], std::max<size_t>(1, shape[2]), steps[2]));
+
+ ++n;
+ }
+
for(; n < anchor.num_dimensions(); ++n)
{
window.set(n, Window::Dimension(anchor[n], std::max<size_t>(1, shape[n])));
diff --git a/src/core/ITensor.cpp b/src/core/ITensor.cpp
index eb5f072..3dffcd0 100644
--- a/src/core/ITensor.cpp
+++ b/src/core/ITensor.cpp
@@ -62,7 +62,7 @@
Iterator src_it(&src, win_src);
Iterator dst_it(this, win_dst);
- const size_t line_size = src_info->num_channels() * src_info->element_size() * src_info->dimension(0);
+ const size_t line_size = src_info->element_size() * src_info->dimension(0);
execute_window_loop(win_src, [&](const Coordinates & id)
{
diff --git a/src/core/NEON/kernels/NEActivationLayerKernel.cpp b/src/core/NEON/kernels/NEActivationLayerKernel.cpp
index ec12515..7a92c6b 100644
--- a/src/core/NEON/kernels/NEActivationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEActivationLayerKernel.cpp
@@ -23,7 +23,7 @@
*/
#include "arm_compute/core/NEON/kernels/NEActivationLayerKernel.h"
-#include "arm_compute/core/FixedPoint.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEAsymm.h"
@@ -45,15 +45,14 @@
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::QASYMM8, DataType::F16, DataType::F32);
// Checks performed when output is configured
if((output != nullptr) && (output->total_size() != 0))
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
@@ -139,6 +138,7 @@
{ ActivationFunction::RELU, &NEActivationLayerKernel::activation<ActivationFunction::RELU, float16_t> },
{ ActivationFunction::BOUNDED_RELU, &NEActivationLayerKernel::activation<ActivationFunction::BOUNDED_RELU, float16_t> },
{ ActivationFunction::LU_BOUNDED_RELU, &NEActivationLayerKernel::activation<ActivationFunction::LU_BOUNDED_RELU, float16_t> },
+ { ActivationFunction::LEAKY_RELU, &NEActivationLayerKernel::activation<ActivationFunction::LEAKY_RELU, float16_t> },
{ ActivationFunction::SOFT_RELU, &NEActivationLayerKernel::activation<ActivationFunction::SOFT_RELU, float16_t> },
{ ActivationFunction::SQRT, &NEActivationLayerKernel::activation<ActivationFunction::SQRT, float16_t> },
{ ActivationFunction::SQUARE, &NEActivationLayerKernel::activation<ActivationFunction::SQUARE, float16_t> },
@@ -146,36 +146,6 @@
};
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC*/
- // Activation functions : QS8
- static std::map<ActivationFunction, ActivationFunctionExecutorPtr> act_map_qs8 =
- {
- { ActivationFunction::ABS, &NEActivationLayerKernel::activation<ActivationFunction::ABS, qint8_t> },
- { ActivationFunction::LINEAR, &NEActivationLayerKernel::activation<ActivationFunction::LINEAR, qint8_t> },
- { ActivationFunction::LOGISTIC, &NEActivationLayerKernel::activation<ActivationFunction::LOGISTIC, qint8_t> },
- { ActivationFunction::RELU, &NEActivationLayerKernel::activation<ActivationFunction::RELU, qint8_t> },
- { ActivationFunction::BOUNDED_RELU, &NEActivationLayerKernel::activation<ActivationFunction::BOUNDED_RELU, qint8_t> },
- { ActivationFunction::LU_BOUNDED_RELU, &NEActivationLayerKernel::activation<ActivationFunction::LU_BOUNDED_RELU, qint8_t> },
- { ActivationFunction::LEAKY_RELU, &NEActivationLayerKernel::activation<ActivationFunction::LEAKY_RELU, qint8_t> },
- { ActivationFunction::SOFT_RELU, &NEActivationLayerKernel::activation<ActivationFunction::SOFT_RELU, qint8_t> },
- { ActivationFunction::SQRT, &NEActivationLayerKernel::activation<ActivationFunction::SQRT, qint8_t> },
- { ActivationFunction::SQUARE, &NEActivationLayerKernel::activation<ActivationFunction::SQUARE, qint8_t> },
- { ActivationFunction::TANH, &NEActivationLayerKernel::activation<ActivationFunction::TANH, qint8_t> },
- };
- // Activation functions : QS16
- static std::map<ActivationFunction, ActivationFunctionExecutorPtr> act_map_qs16 =
- {
- { ActivationFunction::ABS, &NEActivationLayerKernel::activation<ActivationFunction::ABS, qint16_t> },
- { ActivationFunction::LINEAR, &NEActivationLayerKernel::activation<ActivationFunction::LINEAR, qint16_t> },
- { ActivationFunction::LOGISTIC, &NEActivationLayerKernel::activation<ActivationFunction::LOGISTIC, qint16_t> },
- { ActivationFunction::RELU, &NEActivationLayerKernel::activation<ActivationFunction::RELU, qint16_t> },
- { ActivationFunction::BOUNDED_RELU, &NEActivationLayerKernel::activation<ActivationFunction::BOUNDED_RELU, qint16_t> },
- { ActivationFunction::LU_BOUNDED_RELU, &NEActivationLayerKernel::activation<ActivationFunction::LU_BOUNDED_RELU, qint16_t> },
- { ActivationFunction::LEAKY_RELU, &NEActivationLayerKernel::activation<ActivationFunction::LEAKY_RELU, qint16_t> },
- { ActivationFunction::SOFT_RELU, &NEActivationLayerKernel::activation<ActivationFunction::SOFT_RELU, qint16_t> },
- { ActivationFunction::SQRT, &NEActivationLayerKernel::activation<ActivationFunction::SQRT, qint16_t> },
- { ActivationFunction::SQUARE, &NEActivationLayerKernel::activation<ActivationFunction::SQUARE, qint16_t> },
- { ActivationFunction::TANH, &NEActivationLayerKernel::activation<ActivationFunction::TANH, qint16_t> },
- };
// Activation functions : QASYMM8
static std::map<ActivationFunction, ActivationFunctionExecutorPtr> act_map_qasymm8 =
{
@@ -188,12 +158,6 @@
case DataType::QASYMM8:
_func = act_map_qasymm8[activation_info.activation()];
break;
- case DataType::QS8:
- _func = act_map_qs8[activation_info.activation()];
- break;
- case DataType::QS16:
- _func = act_map_qs16[activation_info.activation()];
- break;
case DataType::F32:
_func = act_map_f32[activation_info.activation()];
break;
@@ -219,11 +183,14 @@
Iterator input(_input, window);
Iterator output(_output, window);
- static const float16x8_t CONST_0 = vdupq_n_f16(0.f);
- static const float16x8_t CONST_1 = vdupq_n_f16(1.f);
+ static const float16x8_t CONST_0 = vdupq_n_f16(0.f);
+ static const float16x4_t CONST_1_H = vdup_n_f16(1.f);
- const float16x8_t a = vdupq_n_f16(_act_info.a());
- const float16x8_t b = vdupq_n_f16(_act_info.b());
+ static const float32x4_t CONST_1_F32 = vdupq_n_f32(1.f);
+
+ const float16x8_t a = vdupq_n_f16(_act_info.a());
+ const float16x4_t a_h = vdup_n_f16(_act_info.a());
+ const float16x8_t b = vdupq_n_f16(_act_info.b());
execute_window_loop(window, [&](const Coordinates &)
{
@@ -272,14 +239,28 @@
};
break;
case ActivationFunction::LOGISTIC:
+ {
+ const float16x4x2_t in0 =
+ {
+ vinv_f16(vadd_f16(CONST_1_H, vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vneg_f16(vget_low_f16(in.val[0]))))))),
+ vinv_f16(vadd_f16(CONST_1_H, vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vneg_f16(vget_high_f16(in.val[0]))))))),
+ };
+
+ const float16x4x2_t in1 =
+ {
+ vinv_f16(vadd_f16(CONST_1_H, vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vneg_f16(vget_low_f16(in.val[1]))))))),
+ vinv_f16(vadd_f16(CONST_1_H, vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vneg_f16(vget_high_f16(in.val[1]))))))),
+ };
+
tmp =
{
{
- vinvq_f16(vaddq_f16(CONST_1, vexpq_f16(vnegq_f16(in.val[0])))),
- vinvq_f16(vaddq_f16(CONST_1, vexpq_f16(vnegq_f16(in.val[1])))),
+ vcombine_f16(in0.val[0], in0.val[1]),
+ vcombine_f16(in1.val[0], in1.val[1]),
}
};
- break;
+ }
+ break;
case ActivationFunction::RELU:
tmp =
{
@@ -299,14 +280,28 @@
};
break;
case ActivationFunction::SOFT_RELU:
+ {
+ const float16x4x2_t in0 =
+ {
+ vcvt_f16_f32(vlogq_f32(vaddq_f32(CONST_1_F32, vexpq_f32(vcvt_f32_f16(vget_low_f16(in.val[0])))))),
+ vcvt_f16_f32(vlogq_f32(vaddq_f32(CONST_1_F32, vexpq_f32(vcvt_f32_f16(vget_high_f16(in.val[0])))))),
+ };
+
+ const float16x4x2_t in1 =
+ {
+ vcvt_f16_f32(vlogq_f32(vaddq_f32(CONST_1_F32, vexpq_f32(vcvt_f32_f16(vget_low_f16(in.val[1])))))),
+ vcvt_f16_f32(vlogq_f32(vaddq_f32(CONST_1_F32, vexpq_f32(vcvt_f32_f16(vget_high_f16(in.val[1])))))),
+ };
+
tmp =
{
{
- vlogq_f16(vaddq_f16(CONST_1, vexpq_f16(in.val[0]))),
- vlogq_f16(vaddq_f16(CONST_1, vexpq_f16(in.val[1]))),
+ vcombine_f16(in0.val[0], in0.val[1]),
+ vcombine_f16(in1.val[0], in1.val[1]),
}
};
- break;
+ }
+ break;
case ActivationFunction::SQRT:
tmp =
{
@@ -326,14 +321,33 @@
};
break;
case ActivationFunction::TANH:
+ {
+ const float16x8x2_t mul =
+ {
+ vmulq_f16(b, in.val[0]),
+ vmulq_f16(b, in.val[1])
+ };
+ const float16x4x2_t in0 =
+ {
+ vmul_f16(a_h, vcvt_f16_f32(vtanhq_f32(vcvt_f32_f16(vget_low_f16(mul.val[0]))))),
+ vmul_f16(a_h, vcvt_f16_f32(vtanhq_f32(vcvt_f32_f16(vget_high_f16(mul.val[0]))))),
+ };
+
+ const float16x4x2_t in1 =
+ {
+ vmul_f16(a_h, vcvt_f16_f32(vtanhq_f32(vcvt_f32_f16(vget_low_f16(mul.val[1]))))),
+ vmul_f16(a_h, vcvt_f16_f32(vtanhq_f32(vcvt_f32_f16(vget_high_f16(mul.val[1]))))),
+ };
+
tmp =
{
{
- vmulq_f16(a, vtanhq_f16(vmulq_f16(b, in.val[0]))),
- vmulq_f16(a, vtanhq_f16(vmulq_f16(b, in.val[1]))),
+ vcombine_f16(in0.val[0], in0.val[1]),
+ vcombine_f16(in1.val[0], in1.val[1]),
}
};
- break;
+ }
+ break;
default:
ARM_COMPUTE_ERROR("Not implemented");
break;
@@ -508,70 +522,6 @@
}
template <ActivationLayerInfo::ActivationFunction F, typename T>
-typename std::enable_if<std::is_same<T, int8_t>::value, void>::type NEActivationLayerKernel::activation(const Window &window)
-{
- Iterator input(_input, window);
- Iterator output(_output, window);
- const int fixed_point_position = _input->info()->fixed_point_position();
-
- static const qint8x16_t CONST_0 = vdupq_n_qs8(0);
- const qint8x16_t CONST_1 = vdupq_n_qs8(sqcvt_qs8_f32(1.f, fixed_point_position));
- const qint8x16_t a = vdupq_n_qs8(sqcvt_qs8_f32(_act_info.a(), fixed_point_position));
- const qint8x16_t b = vdupq_n_qs8(sqcvt_qs8_f32(_act_info.b(), fixed_point_position));
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const auto input_ptr = reinterpret_cast<const int8_t *>(input.ptr());
- const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());
-
- const qint8x16_t in = vld1q_qs8(input_ptr);
- qint8x16_t tmp = {};
-
- switch(F)
- {
- case ActivationFunction::ABS:
- tmp = vqabsq_qs8(in);
- break;
- case ActivationFunction::LINEAR:
- tmp = vqmlaq_qs8(b, a, in, fixed_point_position);
- break;
- case ActivationFunction::LOGISTIC:
- tmp = vqrecipq_qs8(vqaddq_qs8(CONST_1, vqexpq_qs8(vnegq_s8(in), fixed_point_position)), fixed_point_position);
- break;
- case ActivationFunction::RELU:
- tmp = vmaxq_qs8(CONST_0, in);
- break;
- case ActivationFunction::BOUNDED_RELU:
- tmp = vminq_qs8(a, vmaxq_qs8(CONST_0, in));
- break;
- case ActivationFunction::LU_BOUNDED_RELU:
- tmp = vminq_qs8(a, vmaxq_qs8(b, in));
- break;
- case ActivationFunction::LEAKY_RELU:
- tmp = vbslq_s8(vcgtq_s8(in, CONST_0), in, vmulq_qs8(a, in, fixed_point_position));
- break;
- case ActivationFunction::SOFT_RELU:
- tmp = vlogq_qs8(vqaddq_qs8(CONST_1, vqexpq_qs8(in, fixed_point_position)), fixed_point_position);
- break;
- case ActivationFunction::SQRT:
- tmp = vqrecipq_qs8(vqinvsqrtq_qs8(in, fixed_point_position), fixed_point_position);
- break;
- case ActivationFunction::SQUARE:
- tmp = vqmulq_qs8(in, in, fixed_point_position);
- break;
- case ActivationFunction::TANH:
- tmp = vqmulq_qs8(a, vqtanhq_qs8(vqmulq_qs8(b, in, fixed_point_position), fixed_point_position), fixed_point_position);
- break;
- default:
- break;
- }
-
- vst1q_qs8(output_ptr, tmp);
- },
- input, output);
-}
-
-template <ActivationLayerInfo::ActivationFunction F, typename T>
typename std::enable_if<std::is_same<T, qasymm8_t>::value, void>::type NEActivationLayerKernel::activation(const Window &window)
{
Iterator input(_input, window);
@@ -620,137 +570,6 @@
input, output);
}
-template <ActivationLayerInfo::ActivationFunction F, typename T>
-typename std::enable_if<std::is_same<T, qint16_t>::value, void>::type NEActivationLayerKernel::activation(const Window &window)
-{
- Iterator input(_input, window);
- Iterator output(_output, window);
- const int fixed_point_position = _input->info()->fixed_point_position();
-
- static const qint16x8_t CONST_0 = vdupq_n_qs16(0);
- const qint16x8_t CONST_1 = vdupq_n_qs16(sqcvt_qs16_f32(1.f, fixed_point_position));
- const qint16x8_t a = vdupq_n_qs16(sqcvt_qs16_f32(_act_info.a(), fixed_point_position));
- const qint16x8_t b = vdupq_n_qs16(sqcvt_qs16_f32(_act_info.b(), fixed_point_position));
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const auto input_ptr = reinterpret_cast<const int16_t *>(input.ptr());
- const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
-
- const qint16x8x2_t in = vld2q_s16(input_ptr);
- qint16x8x2_t tmp = { {} };
-
- switch(F)
- {
- case ActivationFunction::ABS:
- tmp =
- {
- {
- vqabsq_qs16(in.val[0]),
- vqabsq_qs16(in.val[1]),
- }
- };
- break;
- case ActivationFunction::LINEAR:
- tmp =
- {
- {
- vqmlaq_qs16(b, a, in.val[0], fixed_point_position),
- vqmlaq_qs16(b, a, in.val[1], fixed_point_position),
- }
- };
- break;
- case ActivationFunction::LOGISTIC:
- tmp =
- {
- {
- vqrecipq_qs16(vqaddq_qs16(CONST_1, vqexpq_qs16(vnegq_s16(in.val[0]), fixed_point_position)), fixed_point_position),
- vqrecipq_qs16(vqaddq_qs16(CONST_1, vqexpq_qs16(vnegq_s16(in.val[1]), fixed_point_position)), fixed_point_position),
- }
- };
- break;
- case ActivationFunction::RELU:
- tmp =
- {
- {
- vmaxq_qs16(CONST_0, in.val[0]),
- vmaxq_qs16(CONST_0, in.val[1]),
- }
- };
- break;
- case ActivationFunction::BOUNDED_RELU:
- tmp =
- {
- {
- vminq_qs16(a, vmaxq_qs16(CONST_0, in.val[0])),
- vminq_qs16(a, vmaxq_qs16(CONST_0, in.val[1])),
- }
- };
- break;
- case ActivationFunction::LU_BOUNDED_RELU:
- tmp =
- {
- {
- vminq_qs16(a, vmaxq_qs16(b, in.val[0])),
- vminq_qs16(a, vmaxq_qs16(b, in.val[1])),
- }
- };
- break;
- case ActivationFunction::LEAKY_RELU:
- tmp =
- {
- {
- vbslq_s16(vcgtq_s16(in.val[0], CONST_0), in.val[0], vmulq_qs16(a, in.val[0], fixed_point_position)),
- vbslq_s16(vcgtq_s16(in.val[1], CONST_0), in.val[1], vmulq_qs16(a, in.val[1], fixed_point_position)),
- }
- };
- break;
- case ActivationFunction::SOFT_RELU:
- tmp =
- {
- {
- vlogq_qs16(vqaddq_qs16(CONST_1, vqexpq_qs16(in.val[0], fixed_point_position)), fixed_point_position),
- vlogq_qs16(vqaddq_qs16(CONST_1, vqexpq_qs16(in.val[1], fixed_point_position)), fixed_point_position),
- }
- };
- break;
- case ActivationFunction::SQRT:
- tmp =
- {
- {
- vqrecipq_qs16(vqinvsqrtq_qs16(in.val[0], fixed_point_position), fixed_point_position),
- vqrecipq_qs16(vqinvsqrtq_qs16(in.val[1], fixed_point_position), fixed_point_position),
- }
- };
- break;
- case ActivationFunction::SQUARE:
- tmp =
- {
- {
- vqmulq_qs16(in.val[0], in.val[0], fixed_point_position),
- vqmulq_qs16(in.val[1], in.val[1], fixed_point_position),
- }
- };
- break;
- case ActivationFunction::TANH:
- tmp =
- {
- {
- vqmulq_qs16(a, vqtanhq_qs16(vqmulq_qs16(b, in.val[0], fixed_point_position), fixed_point_position), fixed_point_position),
- vqmulq_qs16(a, vqtanhq_qs16(vqmulq_qs16(b, in.val[1], fixed_point_position), fixed_point_position), fixed_point_position),
- }
- };
- break;
- default:
- ARM_COMPUTE_ERROR("Function not implemented");
- break;
- }
-
- vst2q_qs16(output_ptr, tmp);
- },
- input, output);
-}
-
Status NEActivationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ActivationLayerInfo &act_info)
{
ARM_COMPUTE_UNUSED(act_info);
diff --git a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
index a487090..a6102b1 100644
--- a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_compute/core/NEON/kernels/NEArithmeticAdditionKernel.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
@@ -48,38 +49,6 @@
{
constexpr unsigned int num_elems_processed_per_iteration = 16;
-void add_wrap_QS8_QS8_QS8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
-{
- Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
- Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
- Iterator output(out, window);
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const qint8x16_t a = vld1q_qs8(reinterpret_cast<const qint8_t *>(input1.ptr()));
- const qint8x16_t b = vld1q_qs8(reinterpret_cast<const qint8_t *>(input2.ptr()));
-
- vst1q_qs8(reinterpret_cast<qint8_t *>(output.ptr()), vaddq_qs8(a, b));
- },
- input1, input2, output);
-}
-
-void add_saturate_QS8_QS8_QS8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
-{
- Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
- Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
- Iterator output(out, window);
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const qint8x16_t a = vld1q_qs8(reinterpret_cast<const qint8_t *>(input1.ptr()));
- const qint8x16_t b = vld1q_qs8(reinterpret_cast<const qint8_t *>(input2.ptr()));
-
- vst1q_qs8(reinterpret_cast<qint8_t *>(output.ptr()), vqaddq_qs8(a, b));
- },
- input1, input2, output);
-}
-
void add_wrap_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
@@ -362,28 +331,22 @@
{
ARM_COMPUTE_UNUSED(policy);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
- if(is_data_type_fixed_point(input1.data_type()) || is_data_type_fixed_point(input2.data_type()))
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(&input1, &input2);
- }
-
// Validate in case of configured output
if(output.total_size() > 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(input1.data_type() == DataType::QS8 && input2.data_type() == DataType::QS8 && output.data_type() == DataType::QS8)
- && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::U8)
+ !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::U8)
&& !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
&& !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
&& !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
- && !(input1.data_type() == DataType::QS16 && input2.data_type() == DataType::QS16 && output.data_type() == DataType::QS16)
&& !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
&& !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32 && output.data_type() == DataType::F32)
&& !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16 && output.data_type() == DataType::F16),
@@ -391,11 +354,6 @@
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
"Wrong shape for output");
-
- if(is_data_type_fixed_point(input1.data_type()) || is_data_type_fixed_point(output.data_type()))
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(&input1, &output);
- }
}
return Status{};
@@ -460,8 +418,6 @@
static std::map<std::string, AddFunction *> map_function =
{
- { "add_wrap_QS8_QS8_QS8", &add_wrap_QS8_QS8_QS8 },
- { "add_saturate_QS8_QS8_QS8", &add_saturate_QS8_QS8_QS8 },
{ "add_wrap_U8_U8_U8", &add_wrap_U8_U8_U8 },
{ "add_saturate_U8_U8_U8", &add_saturate_U8_U8_U8 },
{ "add_wrap_S16_U8_S16", &add_wrap_S16_U8_S16 },
@@ -470,8 +426,6 @@
{ "add_saturate_U8_S16_S16", &add_saturate_U8_S16_S16 },
{ "add_wrap_U8_U8_S16", &add_wrap_U8_U8_S16 },
{ "add_saturate_U8_U8_S16", &add_saturate_U8_U8_S16 },
- { "add_wrap_QS16_QS16_QS16", &add_wrap_S16_S16_S16 },
- { "add_saturate_QS16_QS16_QS16", &add_saturate_S16_S16_S16 },
{ "add_wrap_S16_S16_S16", &add_wrap_S16_S16_S16 },
{ "add_saturate_S16_S16_S16", &add_saturate_S16_S16_S16 },
{ "add_wrap_F32_F32_F32", &add_F32_F32_F32 },
diff --git a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
index 3db8028..3c76548 100644
--- a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,7 @@
*/
#include "arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
@@ -45,38 +46,6 @@
namespace
{
-void sub_wrap_QS8_QS8_QS8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
-{
- Iterator input1(in1, window);
- Iterator input2(in2, window);
- Iterator output(out, window);
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const qint8x16_t a = vld1q_qs8(reinterpret_cast<const qint8_t *>(input1.ptr()));
- const qint8x16_t b = vld1q_qs8(reinterpret_cast<const qint8_t *>(input2.ptr()));
-
- vst1q_qs8(reinterpret_cast<qint8_t *>(output.ptr()), vsubq_qs8(a, b));
- },
- input1, input2, output);
-}
-
-void sub_saturate_QS8_QS8_QS8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
-{
- Iterator input1(in1, window);
- Iterator input2(in2, window);
- Iterator output(out, window);
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const qint8x16_t a = vld1q_qs8(reinterpret_cast<const qint8_t *>(input1.ptr()));
- const qint8x16_t b = vld1q_qs8(reinterpret_cast<const qint8_t *>(input2.ptr()));
-
- vst1q_qs8(reinterpret_cast<qint8_t *>(output.ptr()), vqsubq_qs8(a, b));
- },
- input1, input2, output);
-}
-
void sub_wrap_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
Iterator input1(in1, window);
@@ -352,24 +321,17 @@
inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
{
ARM_COMPUTE_UNUSED(policy);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
-
- if(is_data_type_fixed_point(input1->data_type()) || is_data_type_fixed_point(input2->data_type()) || is_data_type_fixed_point(output->data_type()))
- {
- // Check that all data types are the same and all fixed-point positions are the same
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
- }
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(input1->data_type() == DataType::QS8 && input2->data_type() == DataType::QS8 && output->data_type() == DataType::QS8)
- && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::U8)
+ !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::U8)
&& !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
&& !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
&& !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
- && !(input1->data_type() == DataType::QS16 && input2->data_type() == DataType::QS16 && output->data_type() == DataType::QS16)
&& !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
&& !(input1->data_type() == DataType::F32 && input2->data_type() == DataType::F32 && output->data_type() == DataType::F32)
&& !(input1->data_type() == DataType::F16 && input2->data_type() == DataType::F16 && output->data_type() == DataType::F16),
@@ -432,8 +394,6 @@
static std::map<std::string, NEArithmeticSubtractionKernel::SubFunction *> map_function =
{
- { "sub_wrap_QS8_QS8_QS8", &sub_wrap_QS8_QS8_QS8 },
- { "sub_saturate_QS8_QS8_QS8", &sub_saturate_QS8_QS8_QS8 },
{ "sub_wrap_U8_U8_U8", &sub_wrap_U8_U8_U8 },
{ "sub_wrap_U8_U8_S16", &sub_wrap_U8_U8_S16 },
{ "sub_saturate_U8_U8_U8", &sub_saturate_U8_U8_U8 },
@@ -442,8 +402,6 @@
{ "sub_wrap_S16_U8_S16", &sub_wrap_S16_U8_S16 },
{ "sub_saturate_U8_S16_S16", &sub_saturate_U8_S16_S16 },
{ "sub_saturate_S16_U8_S16", &sub_saturate_S16_U8_S16 },
- { "sub_wrap_QS16_QS16_QS16", &sub_wrap_S16_S16_S16 },
- { "sub_saturate_QS16_QS16_QS16", &sub_saturate_S16_S16_S16 },
{ "sub_wrap_S16_S16_S16", &sub_wrap_S16_S16_S16 },
{ "sub_saturate_S16_S16_S16", &sub_saturate_S16_S16_S16 },
{ "sub_wrap_F32_F32_F32", &sub_F32_F32_F32 },
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
index 6be50fd..ac1fc39 100644
--- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_compute/core/NEON/kernels/NEBatchNormalizationLayerKernel.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/NEMath.h"
@@ -43,14 +44,16 @@
const ITensorInfo *beta, const ITensorInfo *gamma, float epsilon, ActivationLayerInfo act_info)
{
ARM_COMPUTE_UNUSED(epsilon);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16,
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16,
DataType::F32);
if(act_info.enabled())
{
ActivationLayerInfo::ActivationFunction act = act_info.activation();
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON(act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::RELU && act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU
+ ARM_COMPUTE_RETURN_ERROR_ON(act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::RELU
+ && act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU
&& act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
ARM_COMPUTE_RETURN_ERROR_ON(act_info.b() > act_info.a());
}
@@ -60,22 +63,18 @@
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var);
if(beta != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, beta);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, beta);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, beta);
}
if(gamma != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, gamma);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, gamma);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma);
}
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)) != mean->dimension(0));
@@ -104,112 +103,6 @@
} //namespace
template <bool fused_activation>
-void NEBatchNormalizationLayerKernel::batch_normalization_qs8(const Window &window)
-{
- static_assert(!fused_activation, "Activation is not supported for QS8");
-
- Iterator input(_input, window);
- Iterator output(_output, window);
-
- // Hold information about the current feature map we are iterating.
- // Only compute denominator and NEON vectors once per feature map.
- int slice = -1;
-
- const int fixed_point_position = _input->info()->fixed_point_position();
- const auto input_mean = reinterpret_cast<const qint8_t *>(_mean->ptr_to_element(Coordinates(0, 0)));
- const auto input_var = reinterpret_cast<const qint8_t *>(_var->ptr_to_element(Coordinates(0, 0)));
- const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const qint8_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
- const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const qint8_t *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
-
- qint8x16_t mean_vec = vdupq_n_qs8(0);
- qint8x16_t var_vec = vdupq_n_qs8(0);
- qint8x16_t gamma_vec = vdupq_n_qs8(sqcvt_qs8_f32(1, fixed_point_position));
- qint8x16_t beta_vec = vdupq_n_qs8(sqcvt_qs8_f32(0, fixed_point_position));
- qint8x16_t denominator = vdupq_n_qs8(0);
- const qint8x16_t epsilon_vec = vdupq_n_qs8(sqcvt_qs8_f32(_epsilon, fixed_point_position));
- execute_window_loop(window, [&](const Coordinates & id)
- {
- if(slice != id.z())
- {
- // Conctruct vectors
- mean_vec = vdupq_n_qs8(*(input_mean + id.z()));
- var_vec = vdupq_n_qs8(*(input_var + id.z()));
- if(input_gamma != nullptr)
- {
- gamma_vec = vdupq_n_qs8(*(input_gamma + id.z()));
- }
- if(input_beta != nullptr)
- {
- beta_vec = vdupq_n_qs8(*(input_beta + id.z()));
- }
-
- // Calculate denominator
- denominator = vqinvsqrtq_qs8(vqaddq_qs8(var_vec, epsilon_vec), fixed_point_position);
- slice = id.z();
- }
-
- // Calculate x bar and store results
- const qint8x16_t numerator = vqsubq_qs8(vld1q_qs8(reinterpret_cast<const qint8_t *>(input.ptr())), mean_vec);
- const qint8x16_t x_bar = vqmulq_qs8(numerator, denominator, fixed_point_position);
- vst1q_qs8(reinterpret_cast<qint8_t *>(output.ptr()), vqmlaq_qs8(beta_vec, x_bar, gamma_vec, fixed_point_position));
- },
- input, output);
-}
-
-template <bool fused_activation>
-void NEBatchNormalizationLayerKernel::batch_normalization_qs16(const Window &window)
-{
- static_assert(!fused_activation, "Activation is not supported for QS16");
-
- Iterator input(_input, window);
- Iterator output(_output, window);
-
- // Hold information about the current feature map we are iterating.
- // Only compute denominator and NEON vectors once per feature map.
- int slice = -1;
-
- const int fixed_point_position = _input->info()->fixed_point_position();
- const auto input_mean = reinterpret_cast<const qint16_t *>(_mean->ptr_to_element(Coordinates(0, 0)));
- const auto input_var = reinterpret_cast<const qint16_t *>(_var->ptr_to_element(Coordinates(0, 0)));
- const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const qint16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
- const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const qint16_t *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
-
- qint16x8_t mean_vec = vdupq_n_qs16(0);
- qint16x8_t var_vec = vdupq_n_qs16(0);
- qint16x8_t gamma_vec = vdupq_n_qs16(sqcvt_qs16_f32(1, fixed_point_position));
- qint16x8_t beta_vec = vdupq_n_qs16(sqcvt_qs16_f32(0, fixed_point_position));
- qint16x8_t denominator = vdupq_n_qs16(0);
- const qint16x8_t epsilon_vec = vdupq_n_qs16(sqcvt_qs16_f32(_epsilon, fixed_point_position));
- execute_window_loop(window, [&](const Coordinates & id)
- {
- if(slice != id.z())
- {
- // Conctruct vectors
- mean_vec = vdupq_n_qs16(*(input_mean + id.z()));
- var_vec = vdupq_n_qs16(*(input_var + id.z()));
- if(input_gamma != nullptr)
- {
- gamma_vec = vdupq_n_qs16(*(input_gamma + id.z()));
- }
- if(input_beta != nullptr)
- {
- beta_vec = vdupq_n_qs16(*(input_beta + id.z()));
- }
-
- // Calculate denominator
- denominator = vqinvsqrtq_qs16(vqaddq_qs16(var_vec, epsilon_vec), fixed_point_position);
- slice = id.z();
- }
-
- // Calculate x bar and store results
- const qint16x8_t numerator = vqsubq_qs16(vld1q_qs16(reinterpret_cast<const qint16_t *>(input.ptr())), mean_vec);
- const qint16x8_t x_bar = vqmulq_qs16(numerator, denominator, fixed_point_position);
- vst1q_qs16(reinterpret_cast<qint16_t *>(output.ptr()), vqmlaq_qs16(beta_vec, x_bar, gamma_vec, fixed_point_position));
- },
- input, output);
-}
-
-template <bool fused_activation>
void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw(const Window &window)
{
static_assert(!fused_activation, "Activation is not supported for FP16");
@@ -406,12 +299,6 @@
const bool is_nhwc = _input->info()->data_layout() == DataLayout::NHWC;
switch(_input->info()->data_type())
{
- case DataType::QS8:
- _func = &NEBatchNormalizationLayerKernel::batch_normalization_qs8<false>;
- break;
- case DataType::QS16:
- _func = &NEBatchNormalizationLayerKernel::batch_normalization_qs16<false>;
- break;
case DataType::F16:
_func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<false> : &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<false>;
break;
diff --git a/src/core/NEON/kernels/NECannyEdgeKernel.cpp b/src/core/NEON/kernels/NECannyEdgeKernel.cpp
index 9dfd580..dc37452 100644
--- a/src/core/NEON/kernels/NECannyEdgeKernel.cpp
+++ b/src/core/NEON/kernels/NECannyEdgeKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -567,29 +567,29 @@
const uint32x4_t mk0_0 = vld1q_u32(in - 1);
const uint32x4_t mk0_1 = vld1q_u32(in + 1);
uint32x4_t mask0 = vceqq_u32(pc32, vdupq_n_u32(0));
- mask0 = vandq_u32(mask0, vcgeq_u32(mc, mk0_0));
- mask0 = vandq_u32(mask0, vcgeq_u32(mc, mk0_1));
+ mask0 = vandq_u32(mask0, vcgtq_u32(mc, mk0_0));
+ mask0 = vandq_u32(mask0, vcgtq_u32(mc, mk0_1));
// 45 degree
const uint32x4_t mk45_0 = vld1q_u32(in - stride_mag - 1);
const uint32x4_t mk45_1 = vld1q_u32(in + stride_mag + 1);
uint32x4_t mask1 = vceqq_u32(pc32, vdupq_n_u32(1));
- mask1 = vandq_u32(mask1, vcgeq_u32(mc, mk45_0));
- mask1 = vandq_u32(mask1, vcgeq_u32(mc, mk45_1));
+ mask1 = vandq_u32(mask1, vcgtq_u32(mc, mk45_0));
+ mask1 = vandq_u32(mask1, vcgtq_u32(mc, mk45_1));
// 90 degree
const uint32x4_t mk90_0 = vld1q_u32(in - stride_mag);
const uint32x4_t mk90_1 = vld1q_u32(in + stride_mag);
uint32x4_t mask2 = vceqq_u32(pc32, vdupq_n_u32(2));
- mask2 = vandq_u32(mask2, vcgeq_u32(mc, mk90_0));
- mask2 = vandq_u32(mask2, vcgeq_u32(mc, mk90_1));
+ mask2 = vandq_u32(mask2, vcgtq_u32(mc, mk90_0));
+ mask2 = vandq_u32(mask2, vcgtq_u32(mc, mk90_1));
// 135 degree
const uint32x4_t mk135_0 = vld1q_u32(in - stride_mag + 1);
const uint32x4_t mk135_1 = vld1q_u32(in + stride_mag - 1);
uint32x4_t mask3 = vceqq_u32(pc32, vdupq_n_u32(3));
- mask3 = vandq_u32(mask3, vcgeq_u32(mc, mk135_0));
- mask3 = vandq_u32(mask3, vcgeq_u32(mc, mk135_1));
+ mask3 = vandq_u32(mask3, vcgtq_u32(mc, mk135_0));
+ mask3 = vandq_u32(mask3, vcgtq_u32(mc, mk135_1));
// Merge masks
mask0 = vorrq_u32(mask0, mask1);
@@ -1338,29 +1338,29 @@
const uint16x8_t mk0_0 = vld1q_u16(magnitude - 1);
const uint16x8_t mk0_1 = vld1q_u16(magnitude + 1);
uint16x8_t mask0 = vceqq_u16(pc16, vdupq_n_u16(0));
- mask0 = vandq_u16(mask0, vcgeq_u16(mc, mk0_0));
- mask0 = vandq_u16(mask0, vcgeq_u16(mc, mk0_1));
+ mask0 = vandq_u16(mask0, vcgtq_u16(mc, mk0_0));
+ mask0 = vandq_u16(mask0, vcgtq_u16(mc, mk0_1));
// 45 degree
const uint16x8_t mk45_0 = vld1q_u16(magnitude - stride_mag - 1);
const uint16x8_t mk45_1 = vld1q_u16(magnitude + stride_mag + 1);
uint16x8_t mask1 = vceqq_u16(pc16, vdupq_n_u16(1));
- mask1 = vandq_u16(mask1, vcgeq_u16(mc, mk45_0));
- mask1 = vandq_u16(mask1, vcgeq_u16(mc, mk45_1));
+ mask1 = vandq_u16(mask1, vcgtq_u16(mc, mk45_0));
+ mask1 = vandq_u16(mask1, vcgtq_u16(mc, mk45_1));
// 90 degree
const uint16x8_t mk90_0 = vld1q_u16(magnitude - stride_mag);
const uint16x8_t mk90_1 = vld1q_u16(magnitude + stride_mag);
uint16x8_t mask2 = vceqq_u16(pc16, vdupq_n_u16(2));
- mask2 = vandq_u16(mask2, vcgeq_u16(mc, mk90_0));
- mask2 = vandq_u16(mask2, vcgeq_u16(mc, mk90_1));
+ mask2 = vandq_u16(mask2, vcgtq_u16(mc, mk90_0));
+ mask2 = vandq_u16(mask2, vcgtq_u16(mc, mk90_1));
// 135 degree
const uint16x8_t mk135_0 = vld1q_u16(magnitude - stride_mag + 1);
const uint16x8_t mk135_1 = vld1q_u16(magnitude + stride_mag - 1);
uint16x8_t mask3 = vceqq_u16(pc16, vdupq_n_u16(3));
- mask3 = vandq_u16(mask3, vcgeq_u16(mc, mk135_0));
- mask3 = vandq_u16(mask3, vcgeq_u16(mc, mk135_1));
+ mask3 = vandq_u16(mask3, vcgtq_u16(mc, mk135_0));
+ mask3 = vandq_u16(mask3, vcgtq_u16(mc, mk135_1));
// Merge masks
mask0 = vorrq_u16(mask0, mask1);
@@ -1399,29 +1399,29 @@
const uint32x4_t mk0_0 = vld1q_u32(input - 1);
const uint32x4_t mk0_1 = vld1q_u32(input + 1);
uint32x4_t mask0 = vceqq_u32(pc32, vdupq_n_u32(0));
- mask0 = vandq_u32(mask0, vcgeq_u32(mc, mk0_0));
- mask0 = vandq_u32(mask0, vcgeq_u32(mc, mk0_1));
+ mask0 = vandq_u32(mask0, vcgtq_u32(mc, mk0_0));
+ mask0 = vandq_u32(mask0, vcgtq_u32(mc, mk0_1));
// 45 degree
const uint32x4_t mk45_0 = vld1q_u32(input - stride_mag - 1);
const uint32x4_t mk45_1 = vld1q_u32(input + stride_mag + 1);
uint32x4_t mask1 = vceqq_u32(pc32, vdupq_n_u32(1));
- mask1 = vandq_u32(mask1, vcgeq_u32(mc, mk45_0));
- mask1 = vandq_u32(mask1, vcgeq_u32(mc, mk45_1));
+ mask1 = vandq_u32(mask1, vcgtq_u32(mc, mk45_0));
+ mask1 = vandq_u32(mask1, vcgtq_u32(mc, mk45_1));
// 90 degree
const uint32x4_t mk90_0 = vld1q_u32(input - stride_mag);
const uint32x4_t mk90_1 = vld1q_u32(input + stride_mag);
uint32x4_t mask2 = vceqq_u32(pc32, vdupq_n_u32(2));
- mask2 = vandq_u32(mask2, vcgeq_u32(mc, mk90_0));
- mask2 = vandq_u32(mask2, vcgeq_u32(mc, mk90_1));
+ mask2 = vandq_u32(mask2, vcgtq_u32(mc, mk90_0));
+ mask2 = vandq_u32(mask2, vcgtq_u32(mc, mk90_1));
// 135 degree
const uint32x4_t mk135_0 = vld1q_u32(input - stride_mag + 1);
const uint32x4_t mk135_1 = vld1q_u32(input + stride_mag - 1);
uint32x4_t mask3 = vceqq_u32(pc32, vdupq_n_u32(3));
- mask3 = vandq_u32(mask3, vcgeq_u32(mc, mk135_0));
- mask3 = vandq_u32(mask3, vcgeq_u32(mc, mk135_1));
+ mask3 = vandq_u32(mask3, vcgtq_u32(mc, mk135_0));
+ mask3 = vandq_u32(mask3, vcgtq_u32(mc, mk135_1));
// Merge masks
mask0 = vorrq_u32(mask0, mask1);
diff --git a/src/core/NEON/kernels/NECol2ImKernel.cpp b/src/core/NEON/kernels/NECol2ImKernel.cpp
index 9fda65f..bb8e758 100644
--- a/src/core/NEON/kernels/NECol2ImKernel.cpp
+++ b/src/core/NEON/kernels/NECol2ImKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -44,14 +44,16 @@
output_shape.set(0, convolved_dims.width);
output_shape.set(1, convolved_dims.height);
output_shape.set(2, input->tensor_shape()[0]);
+ output_shape.set(3, input->tensor_shape()[3]); // For NEON the batch size is on the fourth dimension of the input tensor
return output_shape;
}
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &convolved_dims)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8,
- DataType::U16, DataType::S16, DataType::QS16,
+ //Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use NEON FP16 instructions.
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8,
+ DataType::U16, DataType::S16,
DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
@@ -60,7 +62,6 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), get_output_shape(input, convolved_dims));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
diff --git a/src/core/NEON/kernels/NEColorConvertKernel.cpp b/src/core/NEON/kernels/NEColorConvertKernel.cpp
index 347aeae..4582c88 100644
--- a/src/core/NEON/kernels/NEColorConvertKernel.cpp
+++ b/src/core/NEON/kernels/NEColorConvertKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
diff --git a/src/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.cpp b/src/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.cpp
index b3746bd..b6d166d 100644
--- a/src/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.cpp
+++ b/src/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.cpp
@@ -37,25 +37,26 @@
DataLayout data_layout)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+
+ // Output tensor auto initialisation if not yet initialized
+ auto_init_if_empty(*output->info(), *input->info()->clone());
+
ARM_COMPUTE_ERROR_THROW_ON(NEConvertFullyConnectedWeightsKernel::validate(input->info(), output->info(), original_input_shape, data_layout));
_input = input;
_output = output;
- const unsigned int num_elems_per_input_plane = original_input_shape.x() * original_input_shape.y();
- const unsigned int num_channels = original_input_shape.z();
+ const DataLayout input_data_layout = (data_layout == DataLayout::NCHW) ? DataLayout::NHWC : DataLayout::NCHW;
- // Set build options
- if(data_layout == DataLayout::NCHW)
- {
- _factor1 = num_elems_per_input_plane;
- _factor2 = num_channels;
- }
- else
- {
- _factor1 = num_channels;
- _factor2 = num_elems_per_input_plane;
- }
+ const int width_idx = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::WIDTH);
+ const int height_idx = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::HEIGHT);
+ const int channel_idx = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::CHANNEL);
+
+ const unsigned int num_elems_per_input_plane = original_input_shape[width_idx] * original_input_shape[height_idx];
+ const unsigned int num_channels = original_input_shape[channel_idx];
+
+ _factor1 = (data_layout == DataLayout::NCHW) ? num_elems_per_input_plane : num_channels;
+ _factor2 = (data_layout == DataLayout::NCHW) ? num_channels : num_elems_per_input_plane;
// Configure kernel window
Window win = calculate_max_window(*input->info(), Steps());
@@ -65,14 +66,23 @@
Status NEConvertFullyConnectedWeightsKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape,
DataLayout data_layout)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8, DataType::U16, DataType::S16, DataType::QS16, DataType::U32, DataType::S32,
- DataType::QS32, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ //Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use NEON FP16 instructions.
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1,
+ DataType::U8, DataType::S8, DataType::QASYMM8,
+ DataType::U16, DataType::S16,
+ DataType::U32, DataType::S32,
+ DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() != 2);
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) != original_input_shape.total_size_lower(3));
ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::UNKNOWN);
+ // Checks performed when output is configured
+ if((output != nullptr) && (output->total_size() != 0))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ }
+
return Status{};
}
diff --git a/src/core/NEON/kernels/NECopyKernel.cpp b/src/core/NEON/kernels/NECopyKernel.cpp
new file mode 100644
index 0000000..20496ad
--- /dev/null
+++ b/src/core/NEON/kernels/NECopyKernel.cpp
@@ -0,0 +1,92 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/NEON/kernels/NECopyKernel.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/Window.h"
+
+#include <cstring>
+
+using namespace arm_compute;
+
+NECopyKernel::NECopyKernel()
+    : _input(nullptr), _output(nullptr)
+{
+}
+
+// Store the tensors and configure a window spanning the whole output tensor.
+void NECopyKernel::configure(const ITensor *input, ITensor *output)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+    // Fail at configuration time rather than letting run() copy between incompatible tensors.
+    ARM_COMPUTE_ERROR_THROW_ON(NECopyKernel::validate(input->info(), output->info()));
+
+    _input  = input;
+    _output = output;
+
+    INEKernel::configure(calculate_max_window(*output->info()));
+}
+
+Status NECopyKernel::validate(const arm_compute::ITensorInfo *input, const arm_compute::ITensorInfo *output)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+    return Status{};
+}
+
+// Copy the input tensor into the output tensor, one memcpy per X row.
+void NECopyKernel::run(const Window &window, const ThreadInfo &info)
+{
+    ARM_COMPUTE_UNUSED(info);
+    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+
+    // Use the full row width as the X step so each loop iteration copies one complete row.
+    Window output_window{ window };
+    output_window.set(Window::DimX, Window::Dimension(output_window.x().start(), output_window.x().end(), _input->info()->dimension(0)));
+
+    // Bytes per row; loop-invariant, so computed once instead of per iteration.
+    const size_t row_size_in_bytes = _output->info()->dimension(0) * _output->info()->element_size();
+
+    Window out_slice = output_window.first_slice_window_1D();
+
+    do
+    {
+        Iterator input_it(_input, out_slice);
+        Iterator output_it(_output, out_slice);
+
+        execute_window_loop(out_slice, [&](const Coordinates &)
+        {
+            std::memcpy(output_it.ptr(), input_it.ptr(), row_size_in_bytes);
+        },
+        input_it, output_it);
+    }
+    while(output_window.slide_window_slice_1D(out_slice));
+}
diff --git a/src/core/NEON/kernels/NEDepthConcatenateLayerKernel.cpp b/src/core/NEON/kernels/NEDepthConcatenateLayerKernel.cpp
index 891a03c..8c875cd 100644
--- a/src/core/NEON/kernels/NEDepthConcatenateLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthConcatenateLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017, 2018 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -28,45 +28,18 @@
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
-#include <arm_neon.h>
#include <cstdint>
using namespace arm_compute;
namespace
{
-// Overloads of 128-bit vector loads
-uint8x16_t loadq(const uint8_t *ptr)
-{
- return vld1q_u8(ptr);
-}
-uint16x8_t loadq(const uint16_t *ptr)
-{
- return vld1q_u16(ptr);
-}
-uint32x4_t loadq(const uint32_t *ptr)
-{
- return vld1q_u32(ptr);
-}
-// Overloads of 128-bit vector stores
-void storeq(uint8_t *ptr, uint8x16_t val)
-{
- return vst1q_u8(ptr, val);
-}
-void storeq(uint16_t *ptr, uint16x8_t val)
-{
- return vst1q_u16(ptr, val);
-}
-void storeq(uint32_t *ptr, uint32x4_t val)
-{
- return vst1q_u32(ptr, val);
-}
-
template <typename T>
void depth_concat(const ITensor *in, ITensor *out, std::pair<int, int> start_xy, int depth_offset, const Window &window)
{
@@ -89,10 +62,55 @@
const auto in_ptr = reinterpret_cast<const T *>(input_ptr + input.offset());
const auto out_ptr = reinterpret_cast<T *>(output_ptr + output.offset());
- storeq(out_ptr, loadq(in_ptr));
+ wrapper::vstore(out_ptr, wrapper::vloadq(in_ptr));
},
input, output);
}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, unsigned int depth_offset, ITensorInfo *output) // Builds the execution window and requests padding; reports "Insufficient Padding!" if update_window_and_padding changed the window
+{
+ ARM_COMPUTE_UNUSED(depth_offset);
+ // depth_offset does not affect the window: DimZ below always spans all input depths
+ // Per-side X/Y margins of the input inside the output (validate_arguments requires the gaps to be even)
+ const int left_right = (output->dimension(0) - input->dimension(0)) / 2;
+ const int top_bottom = (output->dimension(1) - input->dimension(1)) / 2;
+ // 16 bytes per iteration, i.e. 16 / element_size elements of a single row
+ const unsigned int num_elems_processed_per_iteration = 16 / input->element_size();
+ const unsigned int num_elems_read_per_iteration = 16 / input->element_size();
+ const unsigned int num_rows_read_per_iteration = 1;
+
+ // X/Y follow the output shape while Z is limited to the input depth, because every input depth slice is copied
+ Window win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration));
+ win.set(Window::DimZ, Window::Dimension(0, input->tensor_shape().z(), 1));
+ // Input reads are offset by (-left_right, -top_bottom) relative to the output position being written
+ AccessWindowRectangle input_access(input, -left_right, -top_bottom, num_elems_read_per_iteration, num_rows_read_per_iteration);
+ AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+ bool window_changed = update_window_and_padding(win, input_access, output_access);
+ output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
+
+ Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, win);
+}
+
+Status validate_arguments(const ITensorInfo *input, unsigned int depth_offset, const ITensorInfo *output) // Static argument checks shared by configure() and validate()
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+ //Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use NEON FP16 instructions.
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ // NOTE(review): QASYMM8 data is copied as raw bytes and quantization info is not compared here; confirm callers guarantee input and output share it
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) + depth_offset > output->dimension(2)); // all input depths must fit starting at depth_offset
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) > output->dimension(0));
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) > output->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(3, input, output); // shapes must match from dimension 3 upwards
+
+ // The gaps between the two lowest dimensions of input and output need to be divisible by 2
+ // Otherwise it is not clear how the padding should be added onto the input tensor
+ ARM_COMPUTE_RETURN_ERROR_ON((output->dimension(0) - input->dimension(0)) % 2);
+ ARM_COMPUTE_RETURN_ERROR_ON((output->dimension(1) - input->dimension(1)) % 2);
+
+ return Status{};
+}
} // namespace
NEDepthConcatenateLayerKernel::NEDepthConcatenateLayerKernel()
@@ -107,18 +125,8 @@
void NEDepthConcatenateLayerKernel::configure(const ITensor *input, unsigned int depth_offset, ITensor *output)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
- ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) + depth_offset > output->info()->dimension(2));
- ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) > output->info()->dimension(0));
- ARM_COMPUTE_ERROR_ON(input->info()->dimension(1) > output->info()->dimension(1));
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(3, input, output);
-
- // The gaps between the two lowest dimensions of input and output need to be divisible by 2
- // Otherwise it is not clear how the padding should be added onto the input tensor
- ARM_COMPUTE_ERROR_ON((output->info()->dimension(0) - input->info()->dimension(0)) % 2);
- ARM_COMPUTE_ERROR_ON((output->info()->dimension(1) - input->info()->dimension(1)) % 2);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), depth_offset, output->info()));
_func = nullptr;
_input = input;
@@ -129,10 +137,9 @@
switch(input->info()->data_type())
{
- case DataType::QS8:
+ case DataType::QASYMM8:
_func = &depth_concat<uint8_t>;
break;
- case DataType::QS16:
case DataType::F16:
_func = &depth_concat<uint16_t>;
break;
@@ -143,20 +150,20 @@
ARM_COMPUTE_ERROR("Unsupported data type.");
}
- const unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size();
- const unsigned int num_elems_read_per_iteration = 16 / input->info()->element_size();
- const unsigned int num_rows_read_per_iteration = 1;
+ // Configure kernel window
+ auto win_config = validate_and_configure_window(input->info(), depth_offset, output->info());
+ ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
- // The window needs to be based on input as we copy all the depths of input
- Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration));
- win.set(Window::DimZ, Window::Dimension(0, input->info()->tensor_shape().z(), 1));
+ INEKernel::configure(std::get<1>(win_config));
+}
- AccessWindowRectangle input_access(input->info(), -_left_right, -_top_bottom, num_elems_read_per_iteration, num_rows_read_per_iteration);
- AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
- update_window_and_padding(win, input_access, output_access);
- output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
-
- INEKernel::configure(win);
+Status NEDepthConcatenateLayerKernel::validate(const arm_compute::ITensorInfo *input, // Static counterpart of configure(): same argument and window checks, no tensor info is modified
+ unsigned int depth_offset,
+ const arm_compute::ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, depth_offset, output));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), depth_offset, output->clone().get()).first); // clones absorb the padding requests made while building the window
+ return Status{};
}
void NEDepthConcatenateLayerKernel::run(const Window &window, const ThreadInfo &info)
diff --git a/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp b/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
index c29cb57..8280b52 100644
--- a/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -40,13 +40,13 @@
} // namespace arm_compute
NEDepthConvertLayerKernel::NEDepthConvertLayerKernel()
- : _input(nullptr), _output(nullptr), _policy(), _shift(0), _fixed_point_position_input(0), _fixed_point_position_output(0)
+ : _input(nullptr), _output(nullptr), _policy(), _shift(0)
{
}
void NEDepthConvertLayerKernel::configure(ITensor *input, ITensor *output, ConvertPolicy policy, uint32_t shift)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::U16, DataType::QS16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S16, DataType::U16);
_input = input;
_output = input;
@@ -58,48 +58,26 @@
// Auto initialize output shape if not initialized (We can only auto-configure the shape, datatype must be given)
set_shape_if_empty(*output->info(), input->info()->tensor_shape());
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::U16, DataType::QS16, DataType::U32, DataType::S32, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::U16, DataType::U32, DataType::S32, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
// Set output
_output = output;
}
- // Set initial fixed point position of input and output
- _fixed_point_position_input = input->info()->fixed_point_position();
- _fixed_point_position_output = _output->info()->fixed_point_position();
-
- // Set the fixed point position to the output tensor if needed
- if(is_data_type_fixed_point(input->info()->data_type()) && is_data_type_fixed_point(_output->info()->data_type()))
- {
- // If in-place set the fixed point position of the output tensor to be equal to shift
- _fixed_point_position_output = (_input == _output) ? static_cast<int>(_shift) : _fixed_point_position_output;
- // Set fixed point position to output tensor
- _output->info()->set_fixed_point_position(_fixed_point_position_output);
- }
-
- ARM_COMPUTE_ERROR_ON(shift >= 8 && (!is_data_type_fixed_point(input->info()->data_type()) && !is_data_type_fixed_point(output->info()->data_type())));
+ ARM_COMPUTE_ERROR_ON(shift >= 8);
ARM_COMPUTE_ERROR_ON(input == output && (data_size_from_type(input->info()->data_type()) != data_size_from_type(output->info()->data_type())));
ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::U8 && (output->info()->data_type() != DataType::S16 && output->info()->data_type() != DataType::U16
&& output->info()->data_type() != DataType::S32),
"Only data_types supported [in] U8 -> [out] U16, S16, S32");
- ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::QS8 && (output->info()->data_type() != DataType::QS8 && output->info()->data_type() != DataType::F32),
- "Only data_types supported [in] QS8 -> [out] QS8, F32");
-
ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::U16 && (output->info()->data_type() != DataType::U8 && output->info()->data_type() != DataType::U32),
"Only data_types supported [in] U16 -> [out] U8, U32");
ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::S16 && (output->info()->data_type() != DataType::U8 && output->info()->data_type() != DataType::S32),
"Only data_types supported [in] S16 -> [out] U8, S32");
- ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::QS16 && (output->info()->data_type() != DataType::QS16 && output->info()->data_type() != DataType::F32),
- "Only data_types supported [in] QS16 -> [out] QS16, F32");
-
- ARM_COMPUTE_ERROR_ON_MSG(input->info()->data_type() == DataType::F32 && (output->info()->data_type() != DataType::QS8 && output->info()->data_type() != DataType::QS16),
- "Only data_types supported [in] F32 -> [out] QS8, QS16");
-
constexpr unsigned int num_elems_processed_per_iteration = 16;
// Configure kernel window
@@ -132,8 +110,6 @@
Iterator input(_input, window);
Iterator output(_output, window);
- bool in_place = (_input == _output);
-
switch(_input->info()->data_type())
{
case DataType::U8:
@@ -212,49 +188,6 @@
}
break;
}
- case DataType::QS8:
- {
- switch(_output->info()->data_type())
- {
- case DataType::QS8:
- {
- const int relative_shift = _fixed_point_position_output - _fixed_point_position_input;
- /* Fixed point position conversion QS8 -> QS8 */
- if(relative_shift != 0 || !in_place)
- {
- const auto relative_shift_vec = vdupq_n_qs8(relative_shift);
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const qint8x16_t texels_qs8 = vld1q_qs8(reinterpret_cast<const qint8_t *>(input.ptr()));
- vst1q_qs8(reinterpret_cast<qint8_t *>(output.ptr()), vqrshlq_s8(texels_qs8, relative_shift_vec));
- },
- input, output);
- }
- break;
- }
- case DataType::F32:
- {
- /* Up-conversion QS8 -> F32 */
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const qint8x16_t texels_qs8 = vld1q_qs8(reinterpret_cast<const qint8_t *>(input.ptr()));
-
- float32x4x2_t texels_low = vcvt_f32_qs8(vget_low_s8(texels_qs8), _fixed_point_position_input);
- float32x4x2_t texels_high = vcvt_f32_qs8(vget_high_s8(texels_qs8), _fixed_point_position_input);
-
- vst1q_f32(reinterpret_cast<float *>(output.ptr()), texels_low.val[0]);
- vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 4, texels_low.val[1]);
- vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 8, texels_high.val[0]);
- vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 12, texels_high.val[1]);
- },
- input, output);
- break;
- }
- default:
- ARM_COMPUTE_ERROR("Output data type not supported");
- }
- break;
- }
case DataType::S16:
{
switch(_output->info()->data_type())
@@ -408,116 +341,6 @@
}
break;
}
- case DataType::QS16:
- {
- switch(_output->info()->data_type())
- {
- case DataType::QS16:
- {
- const int relative_shift = _fixed_point_position_output - _fixed_point_position_input;
- /* Fixed point position conversion QS16 -> QS16 */
- if(relative_shift != 0 || !in_place)
- {
- const auto relative_shift_vec = vdupq_n_qs16(relative_shift);
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const qint16x8x2_t texels_qs16 =
- {
- {
- vld1q_qs16(reinterpret_cast<qint16_t *>(input.ptr())),
- vld1q_qs16(reinterpret_cast<qint16_t *>(input.ptr()) + 8)
- }
- };
- vst1q_qs16(reinterpret_cast<qint16_t *>(output.ptr()), vqrshlq_s16(texels_qs16.val[0], relative_shift_vec));
- vst1q_qs16(reinterpret_cast<qint16_t *>(output.ptr()) + 8, vqrshlq_s16(texels_qs16.val[1], relative_shift_vec));
- },
- input, output);
- }
- break;
- }
- case DataType::F32:
- {
- /* Up-conversion QS16 -> F32 */
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const int16x8x2_t texels_qs16 =
- {
- {
- vld1q_s16(reinterpret_cast<qint16_t *>(input.ptr())),
- vld1q_s16(reinterpret_cast<qint16_t *>(input.ptr()) + 8)
- }
- };
-
- vst1q_f32(reinterpret_cast<float *>(output.ptr()), vcvt_f32_qs16(vget_low_s16(texels_qs16.val[0]), _fixed_point_position_input));
- vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 4, vcvt_f32_qs16(vget_high_s16(texels_qs16.val[0]), _fixed_point_position_input));
- vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 8, vcvt_f32_qs16(vget_low_s16(texels_qs16.val[1]), _fixed_point_position_input));
- vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 12, vcvt_f32_qs16(vget_high_s16(texels_qs16.val[1]), _fixed_point_position_input));
- },
- input, output);
- break;
- }
- default:
- ARM_COMPUTE_ERROR("Output data type not supported");
- }
- break;
- }
- case DataType::F32:
- {
- switch(_output->info()->data_type())
- {
- case DataType::QS8:
- {
- /* Down-conversion F32 -> QS8 */
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const float32x4x4_t texels_f32 =
- {
- {
- vld1q_f32(reinterpret_cast<const float *>(input.ptr())),
- vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 4),
- vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 8),
- vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 12)
- }
- };
-
- const qint8x16_t texels_s8 = vqcvtq_qs8_f32(texels_f32, _fixed_point_position_output);
-
- vst1q_s8(reinterpret_cast<int8_t *>(output.ptr()), texels_s8);
- },
- input, output);
- break;
- }
- case DataType::QS16:
- {
- /* Down-conversion F32 -> QS16 */
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const float32x4x2_t texels_f32_1 =
- {
- {
- vld1q_f32(reinterpret_cast<const float *>(input.ptr())),
- vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 4),
- }
- };
- const float32x4x2_t texels_f32_2 =
- {
- {
- vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 8),
- vld1q_f32(reinterpret_cast<const float *>(input.ptr()) + 12)
- }
- };
-
- vst1q_s16(reinterpret_cast<qint16_t *>(output.ptr()), vqcvtq_qs16_f32(texels_f32_1, _fixed_point_position_output));
- vst1q_s16(reinterpret_cast<qint16_t *>(output.ptr()) + 8, vqcvtq_qs16_f32(texels_f32_2, _fixed_point_position_output));
- },
- input, output);
- break;
- }
- default:
- ARM_COMPUTE_ERROR("Output data type not supported");
- }
- break;
- }
default:
ARM_COMPUTE_ERROR("Not supported");
}
diff --git a/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp b/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp
index 8cdf175..09e4acd 100644
--- a/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp
+++ b/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp
@@ -115,7 +115,7 @@
in_top += delta_input, in_mid += delta_input, in_low += delta_input,
p_out += num_elems_written_per_iteration)
{
- auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vw_r0, vw_r1, vw_r2, 0, input_offset);
+ auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vw_r0, vw_r1, vw_r2, input_offset);
store_results<stridex>(p_out, vres);
}
}
@@ -144,6 +144,113 @@
ARM_COMPUTE_ERROR("Not implemented");
}
}
+
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier, bool is_optimized) // Argument checks shared by configure_generic(), configure_optimized() and validate()
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
+ // Weights must be 3x3; their width/height indices are taken from the input's data layout
+ const DataLayout data_layout = input->data_layout();
+ const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+
+ ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != 3 || weights->dimension(height_idx) != 3);
+ // The generic path only supports X strides from 1 to 3
+ if(!is_optimized)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(conv_info.stride().first < 1 || conv_info.stride().first > 3);
+ }
+ // Output checks apply only once the output has been initialised
+ if(output->total_size() != 0)
+ {
+ const TensorShape output_shape = compute_depthwise_convolution_shape(*input, *weights, conv_info, depth_multiplier);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
+ // QASYMM8 input requires an S32 output; float input requires an F32 output
+ ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input->data_type()) && (output->data_type() != DataType::S32));
+ ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_float(input->data_type()) && (output->data_type() != DataType::F32));
+ }
+ // NOTE(review): the optimized branch of validate_and_configure_window auto-initialises the output with the input data type, which for QASYMM8 would fail the S32 check above; confirm QASYMM8 can never take that path
+ return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier, bool is_optimized, // Builds the kernel window, auto-initialises output and requests padding
+ IDepthwiseConvolution *convolver = nullptr)
+{
+ Window win;
+ bool window_changed = false;
+ // Optimized path: the window comes from the convolver; without one (validate() passes none) nothing is configured
+ if(is_optimized)
+ {
+ if(convolver != nullptr)
+ {
+ auto win_last = convolver->get_window();
+ win.set(Window::DimX, Window::Dimension(0, win_last, 1));
+
+ // Auto-configure output
+ bool same_padding = conv_info.has_padding();
+ TensorShape output_shape{ input->tensor_shape() };
+
+ output_shape.set(1, convolver->output_size(output_shape.y(), same_padding)); // Set width
+ output_shape.set(2, convolver->output_size(output_shape.z(), same_padding)); // Set height
+
+ // Output auto initialisation if not yet initialized (keeps the input data type)
+ auto_init_if_empty(*output, input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape));
+
+ // Add 4 elements of right padding to dimension 0 (channels) of input, weights and output
+ // when there are at least 128 channels and the count is a multiple of 16
+ const int num_channels = weights->dimension(0);
+ if((num_channels >= 128) && (num_channels % 16 == 0))
+ {
+ input->extend_padding(PaddingSize(0, 4, 0, 0));
+ weights->extend_padding(PaddingSize(0, 4, 0, 0));
+ output->extend_padding(PaddingSize(0, 4, 0, 0));
+ }
+ }
+ }
+ else
+ {
+ // Get convolved dimensions
+ const TensorShape output_shape = compute_depthwise_convolution_shape(*input, *weights, conv_info, depth_multiplier);
+ const DataType output_dt = (input->data_type() == DataType::QASYMM8) ? DataType::S32 : input->data_type(); // QASYMM8 -> S32, otherwise the input type
+
+ // Output auto initialisation if not yet initialized
+ auto_init_if_empty(*output, input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape).set_data_type(output_dt));
+
+ // Configure kernel window (generic)
+ const unsigned int conv_stride_x = conv_info.stride().first;
+ const unsigned int conv_stride_y = conv_info.stride().second;
+ const unsigned int conv_pad_top = conv_info.pad_top();
+ const unsigned int conv_pad_left = conv_info.pad_left();
+ // Each iteration writes 16 >> stride_x outputs: 8, 4 or 2 for strides 1, 2 and 3
+ unsigned int num_elems_written_per_iteration = 16 >> conv_stride_x;
+ unsigned int num_elems_read_per_iteration = 0;
+ // The default case is expected to be unreachable: validate_arguments() only admits QASYMM8 and F32
+ switch(input->data_type())
+ {
+ case DataType::QASYMM8:
+ num_elems_read_per_iteration = 16;
+ break;
+ case DataType::F32:
+ num_elems_read_per_iteration = 12;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not supported.");
+ }
+
+ // Configure kernel window
+ win = calculate_max_window(*output, Steps(num_elems_written_per_iteration));
+ // Input reads start at (-pad_left, -pad_top) and cover 3 rows using the convolution strides; weights are read as a static 3x3 block
+ AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top, num_elems_read_per_iteration, 3, conv_stride_x, conv_stride_y);
+ AccessWindowStatic weights_access(weights, 0, 0, 3, 3);
+ AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration);
+
+ window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
+ output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
+ }
+ // Only the generic path can set window_changed, so the optimized path always reports success here
+ Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, win);
+}
} // namespace
NEDepthwiseConvolutionLayer3x3Kernel::NEDepthwiseConvolutionLayer3x3Kernel()
@@ -159,8 +266,7 @@
void NEDepthwiseConvolutionLayer3x3Kernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier,
DataLayout data_layout)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
_input = input;
_output = output;
@@ -177,6 +283,17 @@
(_run_optimized) ? configure_optimized() : configure_generic();
}
+Status NEDepthwiseConvolutionLayer3x3Kernel::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
+    // The argument checks differ depending on whether the optimized implementation is usable
+    const bool optimized = is_optimized_execution_possible(input->tensor_shape(), conv_info, input->data_type(), depth_multiplier, input->data_layout());
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, weights, output, conv_info, depth_multiplier, optimized));
+    // Window/padding configuration mutates its arguments, so it runs on clones of the (const) infos
+    const auto window_config = validate_and_configure_window(input->clone().get(), weights->clone().get(), output->clone().get(), conv_info, depth_multiplier, optimized);
+    return window_config.first;
+}
+
void NEDepthwiseConvolutionLayer3x3Kernel::run(const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
@@ -227,90 +344,26 @@
void NEDepthwiseConvolutionLayer3x3Kernel::configure_generic()
{
- ARM_COMPUTE_ERROR_ON(_weights->info()->dimension(0) != 3 || _weights->info()->dimension(1) != 3);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _weights->info(), _output->info(), _conv_info, _depth_multiplier, _run_optimized));
- // Get convolved dimensions
- const TensorShape output_shape = compute_depthwise_convolution_shape(*_input->info(), *_weights->info(), _conv_info, _depth_multiplier);
- const DataType output_dt = (_input->info()->data_type() == DataType::QASYMM8) ? DataType::S32 : _input->info()->data_type();
+ _num_elems_written_per_iteration = 16 >> _conv_info.stride().first;
+ _border_size = BorderSize(_conv_info.pad_top(), _conv_info.pad_right(), _conv_info.pad_bottom(), _conv_info.pad_left());
- // Output auto inizialitation if not yet initialized
- auto_init_if_empty(*_output->info(),
- _input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape).set_data_type(output_dt));
-
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(_output->info()->tensor_shape(), output_shape);
-
- const unsigned int conv_stride_x = _conv_info.stride().first;
- const unsigned int conv_stride_y = _conv_info.stride().second;
- const unsigned int conv_pad_top = _conv_info.pad_top();
- const unsigned int conv_pad_right = _conv_info.pad_right();
- const unsigned int conv_pad_bottom = _conv_info.pad_bottom();
- const unsigned int conv_pad_left = _conv_info.pad_left();
-
- ARM_COMPUTE_ERROR_ON(conv_stride_x < 1 || conv_stride_x > 3);
-
- unsigned int num_elems_read_per_iteration = 0;
- switch(_input->info()->data_type())
- {
- case DataType::QASYMM8:
- num_elems_read_per_iteration = 16;
- _num_elems_written_per_iteration = 16 >> conv_stride_x;
- break;
- case DataType::F32:
- num_elems_read_per_iteration = 12;
- _num_elems_written_per_iteration = 16 >> conv_stride_x;
- break;
- default:
- ARM_COMPUTE_ERROR("Data type not supported.");
- }
- _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);
-
- // Configure kernel window
- Window win = calculate_max_window(*_output->info(), Steps(_num_elems_written_per_iteration));
-
- AccessWindowRectangle input_access(_input->info(), -conv_pad_left, -conv_pad_top,
- num_elems_read_per_iteration, 3,
- conv_stride_x, conv_stride_y);
- AccessWindowStatic weights_access(_weights->info(), 0, 0, 3, 3);
- AccessWindowHorizontal output_access(_output->info(), 0, _num_elems_written_per_iteration);
-
- update_window_and_padding(win, input_access, weights_access, output_access);
- output_access.set_valid_region(win, ValidRegion(Coordinates(), _output->info()->tensor_shape()));
-
- INEKernel::configure(win);
+ auto win_config = validate_and_configure_window(_input->info(), _weights->info(), _output->info(), _conv_info, _depth_multiplier, false);
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+ INEKernel::configure(win_config.second);
}
void NEDepthwiseConvolutionLayer3x3Kernel::configure_optimized()
{
- ARM_COMPUTE_ERROR_ON(_weights->info()->dimension(1) != 3 || _weights->info()->dimension(2) != 3);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _weights->info(), _output->info(), _conv_info, _depth_multiplier, _run_optimized));
_border_size = BorderSize(0, 0);
_convolver = create_convolver_object(_conv_info, _weights, _input, _output);
- // Auto-configure output
- bool same_padding = _conv_info.has_padding();
- TensorShape output_shape{ _input->info()->tensor_shape() };
-
- output_shape.set(1, _convolver->output_size(output_shape.y(), same_padding)); // Set width
- output_shape.set(2, _convolver->output_size(output_shape.z(), same_padding)); // Set height
-
- // Output auto inizialitation if not yet initialized
- auto_init_if_empty(*_output->info(),
- _input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape));
-
- // Set padding in channels
- const int num_channels = _weights->info()->dimension(0);
- if((num_channels >= 128) && (num_channels % 16 == 0))
- {
- _input->info()->extend_padding(PaddingSize(0, 4, 0, 0));
- _weights->info()->extend_padding(PaddingSize(0, 4, 0, 0));
- _output->info()->extend_padding(PaddingSize(0, 4, 0, 0));
- }
-
- // Configure window
- Window win;
- auto win_last = _convolver->get_window();
- win.set(Window::DimX, Window::Dimension(0, win_last, 1));
- INEKernel::configure(win);
+ auto win_config = validate_and_configure_window(_input->info(), _weights->info(), _output->info(), _conv_info, _depth_multiplier, true, _convolver.get());
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+ INEKernel::configure(win_config.second);
}
void NEDepthwiseConvolutionLayer3x3Kernel::run_generic(const Window &window, const ThreadInfo &info)
diff --git a/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp b/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp
index cfd8eac..92ee8d5 100644
--- a/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp
@@ -37,6 +37,22 @@
using namespace arm_compute;
+namespace
+{
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier) // Argument checks shared by configure() and validate()
+{
+ ARM_COMPUTE_UNUSED(conv_info); // conv_info is not checked here
+ //Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use NEON FP16 instructions.
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input->data_type()) && has_bias); // a bias entry is rejected for QASYMM8
+ ARM_COMPUTE_RETURN_ERROR_ON((input->dimension(2) * depth_multiplier) != output->dimension(2)); // output depth must be input depth * depth_multiplier
+ ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != (kernel_dims.width * kernel_dims.height + ((has_bias) ? 1 : 0))); // one value per kernel element, plus one when has_bias
+
+ return Status{};
+}
+} // namespace
+
template <typename T>
void NEDepthwiseIm2ColKernel::run_generic(const Window &window)
{
@@ -120,12 +136,9 @@
void NEDepthwiseIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
- ARM_COMPUTE_ERROR_ON(is_data_type_quantized_asymmetric(input->info()->data_type()) && has_bias);
- ARM_COMPUTE_ERROR_ON((input->info()->dimension(2) * depth_multiplier) != output->info()->dimension(2));
- ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) != (kernel_dims.width * kernel_dims.height + ((has_bias) ? 1 : 0)));
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, depth_multiplier));
_input = input;
_output = output;
@@ -159,6 +172,13 @@
INEKernel::configure(win);
}
+Status NEDepthwiseIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+    // Every other check lives in validate_arguments(); its status (OK or the first failure) is returned as-is
+    return validate_arguments(input, output, kernel_dims, conv_info, has_bias, depth_multiplier);
+}
+
void NEDepthwiseIm2ColKernel::run(const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
diff --git a/src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp b/src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp
index 8960d8a..2d17c23 100644
--- a/src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp
@@ -24,6 +24,7 @@
#include "arm_compute/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.h"
#include "arm_compute/core/AccessWindowTranspose.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Coordinates.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
@@ -34,8 +35,28 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
using namespace arm_compute;
+using namespace arm_compute::misc::shape_calculator;
+
+namespace
+{
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, size_t conv_w, size_t conv_h)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
+
+ if(output->total_size() != 0)
+ {
+ TensorShape output_shape = compute_vector_to_tensor_output_shape(input->tensor_shape(), conv_w, conv_h, output->data_layout());
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ }
+
+ return Status{};
+}
+} // namespace
template <typename T>
void NEDepthwiseVectorToTensorKernel::vector_to_tensor(const Window &window)
@@ -76,20 +97,13 @@
void NEDepthwiseVectorToTensorKernel::configure(const ITensor *input, ITensor *output, size_t conv_w, size_t conv_h)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_NULLPTR(output);
-
- TensorShape output_shape = input->info()->tensor_shape();
- output_shape.set(0, conv_w);
- output_shape.set(1, conv_h);
- output_shape.set(2, input->info()->tensor_shape()[0] / (conv_w * conv_h));
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Output auto inizialitation if not yet initialized
+ TensorShape output_shape = compute_vector_to_tensor_output_shape(input->info()->tensor_shape(), conv_w, conv_h, output->info()->data_layout());
auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), conv_w, conv_h));
_input = input;
_output = output;
@@ -122,6 +136,13 @@
INEKernel::configure(win);
}
+Status NEDepthwiseVectorToTensorKernel::validate(const ITensorInfo *input, const ITensorInfo *output, size_t conv_w, size_t conv_h)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, conv_w, conv_h));
+ return Status{};
+}
+
void NEDepthwiseVectorToTensorKernel::run(const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
diff --git a/src/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.cpp b/src/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.cpp
index 36b17bf..22a2cf8 100644
--- a/src/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.cpp
@@ -77,6 +77,25 @@
},
in, out);
}
+
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *biases)
+{
+ //Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use NEON FP16 instructions.
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input->data_type()) && (biases != nullptr));
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != output->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != (input->dimension(0) * input->dimension(1) + ((biases != nullptr) ? 1 : 0)));
+
+ if(biases != nullptr)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != input->dimension(2));
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
+ }
+
+ return Status{};
+}
} // namespace
NEDepthwiseWeightsReshapeKernel::NEDepthwiseWeightsReshapeKernel()
@@ -86,20 +105,9 @@
void NEDepthwiseWeightsReshapeKernel::configure(const ITensor *input, ITensor *output, const ITensor *biases)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
- ARM_COMPUTE_ERROR_ON(is_data_type_quantized_asymmetric(input->info()->data_type()) && (biases != nullptr));
- ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != output->info()->dimension(1));
- ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) != (input->info()->dimension(0) * input->info()->dimension(1) + ((biases != nullptr) ? 1 : 0)));
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
- if(biases != nullptr)
- {
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, biases);
- ARM_COMPUTE_ERROR_ON(biases->info()->dimension(0) != input->info()->dimension(2));
- ARM_COMPUTE_ERROR_ON(biases->info()->num_dimensions() > 1);
- }
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), (biases != nullptr) ? biases->info() : nullptr));
_input = input;
_output = output;
@@ -137,6 +145,13 @@
INEKernel::configure(win);
}
+Status NEDepthwiseWeightsReshapeKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *biases)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, biases));
+ return Status{};
+}
+
void NEDepthwiseWeightsReshapeKernel::run(const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
diff --git a/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp b/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp
index 4120e5f..47c895c 100644
--- a/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp
@@ -54,7 +54,7 @@
std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, ITensorInfo *min_max)
{
// Output tensor auto initialization if not yet initialized
- auto_init_if_empty(*output, input->tensor_shape(), 1, DataType::F32, 0);
+ auto_init_if_empty(*output, input->tensor_shape(), 1, DataType::F32);
constexpr unsigned int num_elems_processed_per_iteration = 8;
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
index 5eafdf0..f525d93 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
@@ -25,6 +25,7 @@
#include "arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h"
#include "arm_compute/core/AccessWindowStatic.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
@@ -43,34 +44,6 @@
namespace
{
-template <unsigned int stridex>
-qint16x8_t internal_vld1q(const qint16_t *in);
-
-template <>
-qint16x8_t internal_vld1q<1>(const qint16_t *in)
-{
- return vld1q_qs16(in);
-}
-
-template <>
-qint16x8_t internal_vld1q<2>(const qint16_t *in)
-{
- const int16x8x2_t tmp = vld2q_s16(in);
- return tmp.val[0];
-}
-
-template <>
-qint16x8_t internal_vld1q<3>(const qint16_t *in)
-{
- const int16x8x3_t tmp = vld3q_s16(in);
- return tmp.val[0];
-}
-
-inline qint16x8_t internal_vdupq_n(qint16_t v)
-{
- return vdupq_n_qs16(v);
-}
-
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);
@@ -105,15 +78,13 @@
vst1q_f16(p, v);
}
-float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y, int fixed_point_position)
+float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y)
{
- ARM_COMPUTE_UNUSED(fixed_point_position);
return vmulq_f16(x, y);
}
-inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z, int fixed_point_position)
+inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z)
{
- ARM_COMPUTE_UNUSED(fixed_point_position);
return vaddq_f16(x, vmulq_f16(y, z));
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -151,107 +122,16 @@
vst1q_f32(p, v);
}
-float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y, int fixed_point_position)
+float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y)
{
- ARM_COMPUTE_UNUSED(fixed_point_position);
return vmulq_f32(x, y);
}
-inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z, int fixed_point_position)
+inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z)
{
- ARM_COMPUTE_UNUSED(fixed_point_position);
return vmlaq_f32(x, y, z);
}
-template <unsigned int stridex>
-qint8x8_t internal_vld1q(const qint8_t *in);
-
-template <>
-qint8x8_t internal_vld1q<1>(const qint8_t *in)
-{
- return vld1_qs8(in);
-}
-
-template <>
-qint8x8_t internal_vld1q<2>(const qint8_t *in)
-{
- const qint8x8x2_t tmp = vld2_s8(in);
- return tmp.val[0];
-}
-
-template <>
-qint8x8_t internal_vld1q<3>(const qint8_t *in)
-{
- const qint8x8x3_t tmp = vld3_s8(in);
- return tmp.val[0];
-}
-
-inline qint8x8_t internal_vdupq_n(qint8_t v)
-{
- return vdup_n_qs8(v);
-}
-
-inline qint16x8_t internal_vmull(const qint8x8_t &x, const qint8x8_t &y, int fixed_point_position)
-{
- return vmull_qs8(x, y, fixed_point_position);
-}
-
-inline qint16x8_t internal_vmlal(const qint16x8_t &x, const qint8x8_t &y, const qint8x8_t &z, int fixed_point_position)
-{
- return vqmlal_qs8(x, y, z, fixed_point_position);
-}
-
-inline void internal_vst1q(qint16_t *p, const qint16x8_t &v)
-{
- vst1q_qs16(p, v);
-}
-
-inline void internal_vst1q(int32_t *p, const qint32x4x2_t &v)
-{
- vst1q_s32(p, v.val[0]);
- vst1q_s32(p + 4, v.val[1]);
-}
-
-template <unsigned int stridex>
-qint32x4x2_t internal_vld1q(const qint32_t *in);
-
-template <>
-qint32x4x2_t internal_vld1q<1>(const qint32_t *in)
-{
- const qint32x4x2_t r =
- {
- {
- vld1q_s32(in),
- vld1q_s32(in + 4)
- }
- };
- return r;
-}
-
-inline qint32x4x2_t internal_vmull(const qint16x8_t &x, const qint16x8_t &y, int fixed_point_position)
-{
- const qint32x4x2_t r =
- {
- {
- vmull_qs16(vget_low_s16(x), vget_low_s16(y), fixed_point_position),
- vmull_qs16(vget_high_s16(x), vget_high_s16(y), fixed_point_position),
- }
- };
- return r;
-}
-
-inline qint32x4x2_t internal_vmlal(const qint32x4x2_t &x, const qint16x8_t &y, const qint16x8_t &z, int fixed_point_position)
-{
- const qint32x4x2_t r =
- {
- {
- vqmlal_qs16(x.val[0], vget_low_s16(y), vget_low_s16(z), fixed_point_position),
- vqmlal_qs16(x.val[1], vget_high_s16(y), vget_high_s16(z), fixed_point_position)
- }
- };
- return r;
-}
-
constexpr int small_tensor_size_optim = 8;
inline bool run_optim_small_tensor_info(const ITensorInfo *t)
{
@@ -355,21 +235,20 @@
static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
- const int input_stride_x = input->info()->strides_in_bytes().x();
- const int input_stride_y = input->info()->strides_in_bytes().y();
- const int input_stride_z = input->info()->strides_in_bytes().z();
- const int output_stride_y = output->info()->strides_in_bytes().y();
- const int output_stride_z = output->info()->strides_in_bytes().z();
- const int kernel_stride_z = weights->info()->strides_in_bytes().z();
- const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
- const int output_w = output->info()->dimension(0);
- const int output_h = output->info()->dimension(1);
- const int range_z = window.z().end() - window.z().start();
- const int kernel_depth = weights->info()->dimension(Window::DimZ);
- const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
- const unsigned int conv_pad_left = conv_info.pad_left();
- const unsigned int conv_pad_top = conv_info.pad_top();
- const int fixed_point_position = input->info()->fixed_point_position();
+ const int input_stride_x = input->info()->strides_in_bytes().x();
+ const int input_stride_y = input->info()->strides_in_bytes().y();
+ const int input_stride_z = input->info()->strides_in_bytes().z();
+ const int output_stride_y = output->info()->strides_in_bytes().y();
+ const int output_stride_z = output->info()->strides_in_bytes().z();
+ const int kernel_stride_z = weights->info()->strides_in_bytes().z();
+ const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
+ const int output_w = output->info()->dimension(0);
+ const int output_h = output->info()->dimension(1);
+ const int range_z = window.z().end() - window.z().start();
+ const int kernel_depth = weights->info()->dimension(Window::DimZ);
+ const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
+ const unsigned int conv_pad_left = conv_info.pad_left();
+ const unsigned int conv_pad_top = conv_info.pad_top();
// setup output window for the iterator
Window window_out = window;
@@ -414,7 +293,7 @@
auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
{
- internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val), fixed_point_position));
+ internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val)));
}
}
}
@@ -431,7 +310,7 @@
auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
{
- internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val), fixed_point_position));
+ internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val)));
}
}
}
@@ -469,7 +348,7 @@
template <unsigned int stridex>
float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
- const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position);
+ const float *m0, const float *m1, const float *m2, const float *m3, const float *m4);
inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
{
@@ -511,9 +390,8 @@
template <>
inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
- const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
+ const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
- ARM_COMPUTE_UNUSED(fixed_point_position);
const float32x4x3_t vin0 = load_input(in_0);
const float32x4x3_t vin1 = load_input(in_1);
const float32x4x3_t vin2 = load_input(in_2);
@@ -601,10 +479,9 @@
template <>
inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
- const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
+ const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
- ARM_COMPUTE_UNUSED(fixed_point_position);
- float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
+ float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
@@ -613,9 +490,9 @@
template <>
inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
- const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
+ const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
- float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
+ float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
return out;
}
@@ -642,28 +519,6 @@
vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}
-template <unsigned int stridex>
-void accumulate_results(qint16_t *buffer, const qint16x8x2_t &values);
-
-template <>
-void accumulate_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
-{
- vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
- vst1q_qs16(buffer + 8, vqaddq_qs16(vld1q_qs16(buffer + 8), values.val[1]));
-}
-
-template <>
-void accumulate_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
-{
- vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
-}
-
-template <>
-void accumulate_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
-{
- vst1_qs16(buffer, vqadd_qs16(vld1_qs16(buffer), vget_low_s16(values.val[0])));
-}
-
template <typename T1>
class convolver_nhwc
{
@@ -745,7 +600,7 @@
const auto we_addr = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
const auto we_values = internal_vld1q<1>(we_addr);
- out_values = internal_vmlal(out_values, in_values, we_values, 0);
+ out_values = internal_vmlal(out_values, in_values, we_values);
}
out_val += out_values[0];
@@ -784,24 +639,23 @@
const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
- const int input_stride_x = input->info()->strides_in_bytes().x();
- const int input_stride_y = input->info()->strides_in_bytes().y();
- const int input_stride_z = input->info()->strides_in_bytes().z();
- const int output_stride_y = output->info()->strides_in_bytes().y();
- const int output_stride_z = output->info()->strides_in_bytes().z();
- const int kernel_stride_x = weights->info()->strides_in_bytes().x();
- const int kernel_stride_y = weights->info()->strides_in_bytes().y();
- const int kernel_stride_z = weights->info()->strides_in_bytes().z();
- const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
- const int output_w = output->info()->dimension(0);
- const int output_h = output->info()->dimension(1);
- const int num_planes_z = window.z().end() - window.z().start();
- const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
- const int kernel_depth = weights->info()->dimension(Window::DimZ);
- const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
- const unsigned int conv_pad_left = conv_info.pad_left();
- const unsigned int conv_pad_top = conv_info.pad_top();
- const int fixed_point_position = input->info()->fixed_point_position();
+ const int input_stride_x = input->info()->strides_in_bytes().x();
+ const int input_stride_y = input->info()->strides_in_bytes().y();
+ const int input_stride_z = input->info()->strides_in_bytes().z();
+ const int output_stride_y = output->info()->strides_in_bytes().y();
+ const int output_stride_z = output->info()->strides_in_bytes().z();
+ const int kernel_stride_x = weights->info()->strides_in_bytes().x();
+ const int kernel_stride_y = weights->info()->strides_in_bytes().y();
+ const int kernel_stride_z = weights->info()->strides_in_bytes().z();
+ const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
+ const int output_w = output->info()->dimension(0);
+ const int output_h = output->info()->dimension(1);
+ const int num_planes_z = window.z().end() - window.z().start();
+ const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
+ const int kernel_depth = weights->info()->dimension(Window::DimZ);
+ const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
+ const unsigned int conv_pad_left = conv_info.pad_left();
+ const unsigned int conv_pad_top = conv_info.pad_top();
// setup output window for the iterator
Window window_out = window;
@@ -864,7 +718,7 @@
for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
{
- auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
+ auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2);
store_results<stridex>(p_out, vres);
}
}
@@ -889,7 +743,7 @@
for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
{
- auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
+ auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2);
accumulate_results<stridex>(p_out, vres);
}
}
@@ -908,24 +762,23 @@
const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
- const int input_stride_x = input->info()->strides_in_bytes().x();
- const int input_stride_y = input->info()->strides_in_bytes().y();
- const int input_stride_z = input->info()->strides_in_bytes().z();
- const int output_stride_y = output->info()->strides_in_bytes().y();
- const int output_stride_z = output->info()->strides_in_bytes().z();
- const int kernel_stride_x = weights->info()->strides_in_bytes().x();
- const int kernel_stride_y = weights->info()->strides_in_bytes().y();
- const int kernel_stride_z = weights->info()->strides_in_bytes().z();
- const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
- const int output_w = output->info()->dimension(0);
- const int output_h = output->info()->dimension(1);
- const int num_planes_z = window.z().end() - window.z().start();
- const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
- const int kernel_depth = weights->info()->dimension(Window::DimZ);
- const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
- const unsigned int conv_pad_left = conv_info.pad_left();
- const unsigned int conv_pad_top = conv_info.pad_top();
- const int fixed_point_position = input->info()->fixed_point_position();
+ const int input_stride_x = input->info()->strides_in_bytes().x();
+ const int input_stride_y = input->info()->strides_in_bytes().y();
+ const int input_stride_z = input->info()->strides_in_bytes().z();
+ const int output_stride_y = output->info()->strides_in_bytes().y();
+ const int output_stride_z = output->info()->strides_in_bytes().z();
+ const int kernel_stride_x = weights->info()->strides_in_bytes().x();
+ const int kernel_stride_y = weights->info()->strides_in_bytes().y();
+ const int kernel_stride_z = weights->info()->strides_in_bytes().z();
+ const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
+ const int output_w = output->info()->dimension(0);
+ const int output_h = output->info()->dimension(1);
+ const int num_planes_z = window.z().end() - window.z().start();
+ const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
+ const int kernel_depth = weights->info()->dimension(Window::DimZ);
+ const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
+ const unsigned int conv_pad_left = conv_info.pad_left();
+ const unsigned int conv_pad_top = conv_info.pad_top();
// setup output window for the iterator
Window window_out = window;
@@ -976,7 +829,7 @@
for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
{
- auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
+ auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
store_results<stridex>(p_out, vres);
}
}
@@ -1001,7 +854,7 @@
for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
{
- auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
+ auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
accumulate_results<stridex>(p_out, vres);
}
}
@@ -1120,7 +973,8 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
const DataLayout data_layout = input->data_layout();
@@ -1133,6 +987,7 @@
ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx));
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(width_idx) > 3) && (input->data_type() == DataType::F16));
// Checks performed when output is configured
if(output->total_size() != 0)
@@ -1140,11 +995,6 @@
TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info);
DataType data_type = input->data_type();
- if(is_data_type_fixed_point(data_type))
- {
- // Promote data type in case of fixed point
- data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
- }
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != data_type);
@@ -1180,11 +1030,9 @@
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- case DataType::QS8:
- case DataType::QS16:
num_elems_written_per_iteration = 8;
break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
if(run_optim_small_tensor_info(input))
{
@@ -1204,8 +1052,6 @@
break;
}
case 3:
- case 5:
- {
switch(input->data_type())
{
case DataType::F32:
@@ -1215,13 +1061,25 @@
break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- case DataType::QS8:
- case DataType::QS16:
num_weight_elems_read_per_row = 8 + kernel_size - 1;
num_elems_read_per_iteration = 24;
num_elems_written_per_iteration = 32 >> conv_stride_x;
break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ default:
+ ARM_COMPUTE_ERROR("Data type not supported.");
+ break;
+ }
+ break;
+ case 5:
+ {
+ switch(input->data_type())
+ {
+ case DataType::F32:
+ num_weight_elems_read_per_row = 4 + kernel_size - 1;
+ num_elems_read_per_iteration = 12;
+ num_elems_written_per_iteration = 16 >> conv_stride_x;
+ break;
default:
ARM_COMPUTE_ERROR("Data type not supported.");
break;
@@ -1315,14 +1173,8 @@
DataType data_type = input->info()->data_type();
- if(is_data_type_fixed_point(data_type))
- {
- // Promote data type in case of fixed point
- data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
- }
-
// Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), output_shape, 1, data_type, input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), output_shape, 1, data_type);
// Perform validation step
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), output->info(), conv_info));
@@ -1371,12 +1223,6 @@
{
switch(_input->info()->data_type())
{
- case DataType::QS8:
- convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
- break;
- case DataType::QS16:
- convolve_1x1<qint16_t, qint32_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
- break;
case DataType::F32:
convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
break;
@@ -1395,9 +1241,6 @@
{
switch(_input->info()->data_type())
{
- case DataType::QS8:
- convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
- break;
case DataType::F32:
convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
break;
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
index edda2cd..eefbd98 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
@@ -24,6 +24,7 @@
#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h"
#include "arm_compute/core/AccessWindowStatic.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
@@ -43,24 +44,17 @@
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output)
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8,
- DataType::QS16, DataType::F16,
- DataType::QS32, DataType::S32, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8,
+ DataType::F16,
+ DataType::S32, DataType::F32);
if(bias != nullptr)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::QS32, DataType::S32, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::F16, DataType::S32, DataType::F32);
- if(is_data_type_fixed_point(input->data_type()))
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS8 && bias->data_type() != DataType::QS8, "Wrong data type for bias");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS16 && bias->data_type() != DataType::QS8, "Wrong data type for bias");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS32 && bias->data_type() != DataType::QS16, "Wrong data type for bias");
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, bias);
- }
- else if(is_data_type_quantized_asymmetric(input->data_type()))
+ if(is_data_type_quantized_asymmetric(input->data_type()))
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::S32);
}
@@ -80,17 +74,10 @@
// Checks performed when output is configured
if((output != nullptr) && (output->total_size() != 0))
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
- if(is_data_type_fixed_point(input->data_type()))
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS8 && output->data_type() != DataType::QS8, "Wrong data type for output");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS16 && output->data_type() != DataType::QS8, "Wrong data type for output");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS32 && output->data_type() != DataType::QS16, "Wrong data type for output");
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
- }
- else if(is_data_type_quantized_asymmetric(output->data_type()))
+ if(is_data_type_quantized_asymmetric(output->data_type()))
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S32 && output->data_type() != DataType::QASYMM8, "Wrong data type for bias");
}
@@ -168,81 +155,24 @@
{
return vld1q_f32(in);
}
-inline qint8x16_t internal_vld1q(const qint8_t *in)
-{
- return vld1q_qs8(in);
-}
-inline qint16x8_t internal_vld1q(const qint16_t *in)
-{
- return vld1q_qs16(in);
-}
-inline qint32x4_t internal_vld1q(const qint32_t *in)
-{
- return vld1q_s32(in);
-}
// Internal store
inline void internal_vst1q(float *p, const float32x4_t &v)
{
vst1q_f32(p, v);
}
-inline void internal_vst1q(qint8_t *p, const qint8x16_t &v)
-{
- vst1q_qs8(p, v);
-}
-inline void internal_vst1q(qint8_t *p, const qint16x8_t &v)
-{
- vst1_qs8(p, vqmovn_s16(v));
-}
-inline void internal_vst1q(qint16_t *p, const qint16x8_t &v)
-{
- vst1q_qs16(p, v);
-}
-inline void internal_vst1q(qint32_t *p, const qint32x4_t &v)
-{
- vst1q_s32(p, v);
-}
-
-inline void internal_vst1q(qint16_t *p, const qint32x4_t &v)
-{
- vst1_qs16(p, vqmovn_qs32(v));
-}
// Internal vdup
inline float32x4_t internal_vdupq_n(float v)
{
return vdupq_n_f32(v);
}
-inline qint8x16_t internal_vdupq_n(qint8_t v)
-{
- return vdupq_n_qs8(v);
-}
-inline qint16x8_t internal_vdupq_n(qint16_t v)
-{
- return vdupq_n_qs16(v);
-}
-inline qint32x4_t internal_vdupq_n(qint32_t v)
-{
- return vdupq_n_qs32(v);
-}
// Internal vadd
inline float32x4_t internal_vqaddq(const float32x4_t &x, const float32x4_t &y)
{
return vaddq_f32(x, y);
}
-inline qint8x16_t internal_vqaddq(const qint8x16_t &x, const qint8x16_t &y)
-{
- return vqaddq_qs8(x, y);
-}
-inline qint16x8_t internal_vqaddq(const qint16x8_t &x, const qint16x8_t &y)
-{
- return vqaddq_qs16(x, y);
-}
-inline qint32x4_t internal_vqaddq(const qint32x4_t &x, const qint32x4_t &y)
-{
- return vqaddq_qs32(x, y);
-}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
inline float16x8_t internal_vld1q(const float16_t *in)
@@ -494,39 +424,6 @@
{
switch(input->info()->data_type())
{
- case DataType::QS8:
- {
- if(bias == nullptr)
- {
- _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, false> : &output_stage<qint8_t, qint8_t, false, false>;
- }
- else
- {
- _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, true> : &output_stage<qint8_t, qint8_t, false, true>;
- }
- break;
- }
- case DataType::QS16:
- {
- if(bias != nullptr && bias->info()->data_type() == DataType::QS8)
- {
- _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, true> : &output_stage<qint16_t, qint8_t, false, true>;
- }
- else if(bias == nullptr)
- {
- _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, false> : &output_stage<qint16_t, qint8_t, false, false>;
- }
- else
- {
- ARM_COMPUTE_ERROR("Not implemented");
- }
- break;
- }
- case DataType::QS32:
- {
- _func = (output == nullptr) ? &output_stage<qint32_t, qint16_t, true, true> : &output_stage<qint32_t, qint16_t, false, true>;
- break;
- }
case DataType::S32:
{
_func = (bias == nullptr) ? &output_stage<int32_t, uint8_t, false, false> : &output_stage<int32_t, uint8_t, false, true>;
@@ -570,7 +467,7 @@
Status NEDirectConvolutionLayerOutputStageKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), bias->clone().get(), output == nullptr ? nullptr : output->clone().get()).first);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), bias == nullptr ? nullptr : bias->clone().get(), output == nullptr ? nullptr : output->clone().get()).first);
return Status{};
}
diff --git a/src/core/NEON/kernels/NEFillBorderKernel.cpp b/src/core/NEON/kernels/NEFillBorderKernel.cpp
index 747b8b1..aef4d48 100644
--- a/src/core/NEON/kernels/NEFillBorderKernel.cpp
+++ b/src/core/NEON/kernels/NEFillBorderKernel.cpp
@@ -105,8 +105,9 @@
void NEFillBorderKernel::configure(ITensor *tensor, BorderSize border_size, BorderMode border_mode, const PixelValue &constant_border_value)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(tensor, 1, DataType::U8, DataType::QS8, DataType::QASYMM8,
- DataType::QS16, DataType::U16, DataType::S16,
+ //Note: ARM_COMPUTE_ERROR_ON_CPU_F16_UNSUPPORTED(tensor) is not needed here as this kernel doesn't use NEON FP16 instructions.
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(tensor, 1, DataType::U8, DataType::QASYMM8,
+ DataType::U16, DataType::S16,
DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
@@ -147,7 +148,6 @@
case DataType::U8:
fill_constant_value_single_channel<uint8_t>(window);
break;
- case DataType::QS8:
case DataType::S8:
fill_constant_value_single_channel<int8_t>(window);
break;
@@ -155,7 +155,6 @@
fill_constant_value_single_channel<uint16_t>(window);
break;
case DataType::S16:
- case DataType::QS16:
fill_constant_value_single_channel<int16_t>(window);
break;
case DataType::U32:
@@ -192,7 +191,6 @@
case DataType::U8:
fill_replicate_single_channel<uint8_t>(window);
break;
- case DataType::QS8:
case DataType::S8:
fill_replicate_single_channel<int8_t>(window);
break;
@@ -200,7 +198,6 @@
fill_replicate_single_channel<uint16_t>(window);
break;
case DataType::S16:
- case DataType::QS16:
fill_replicate_single_channel<int16_t>(window);
break;
case DataType::U32:
diff --git a/src/core/NEON/kernels/NEFloorKernel.cpp b/src/core/NEON/kernels/NEFloorKernel.cpp
index 72b652d..872ac26 100644
--- a/src/core/NEON/kernels/NEFloorKernel.cpp
+++ b/src/core/NEON/kernels/NEFloorKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -40,7 +40,7 @@
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Auto initialize output
- auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type());
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
diff --git a/src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.cpp b/src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.cpp
index 12755a4..5483602 100644
--- a/src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.cpp
@@ -44,11 +44,11 @@
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::U8, DataType::S8,
- DataType::QS16, DataType::U16, DataType::S16, DataType::U32, DataType::S32,
+ //Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use NEON FP16 instructions.
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::U8, DataType::S8,
+ DataType::U16, DataType::S16, DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
if(output->total_size() != 0)
{
@@ -57,7 +57,6 @@
output_shape.set(1, std::ceil(input->dimension(1) / 4.0f));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
diff --git a/src/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.cpp b/src/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.cpp
index ee334df..af84d02 100644
--- a/src/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.cpp
@@ -193,11 +193,14 @@
Window win_vector_sum_row(collapsed_window);
win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));
Iterator vector_sum_col(_vector_sum_col, win_vector_sum_col);
Iterator vector_sum_row(_vector_sum_row, win_vector_sum_row);
Iterator mm_result(_mm_result, window);
+ const size_t sum_row_stride_y = _vector_sum_row->info()->strides_in_bytes().y();
+
execute_window_loop(collapsed_window, [&](const Coordinates & id)
{
// Compute the leftover term due to a_offset.
@@ -217,7 +220,7 @@
a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], _a_offset);
// Compute the leftover term due to b_offset.
- int32x4_t b_offset_term_s32 = vld1q_dup_s32(reinterpret_cast<const int32_t *>(vector_sum_row.ptr()) + id.y());
+ int32x4_t b_offset_term_s32 = vld1q_dup_s32(reinterpret_cast<const int32_t *>(vector_sum_row.ptr() + id.z() * sum_row_stride_y) + id.y());
b_offset_term_s32 = vmulq_n_s32(b_offset_term_s32, _b_offset);
// Add a_offset_term_s32 and b_offset_term_s32
@@ -266,14 +269,17 @@
Window win_vector_sum_row(collapsed_window);
win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));
Iterator vector_sum_row(_vector_sum_row, win_vector_sum_row);
Iterator mm_result(_mm_result, window);
+ const size_t sum_row_stride_y = _vector_sum_row->info()->strides_in_bytes().y();
+
execute_window_loop(window, [&](const Coordinates & id)
{
// Compute the leftover term due to b_offset.
- int32x4_t b_offset_term_s32 = vld1q_dup_s32(reinterpret_cast<const int32_t *>(vector_sum_row.ptr()) + id.y());
+ int32x4_t b_offset_term_s32 = vld1q_dup_s32(reinterpret_cast<const int32_t *>(vector_sum_row.ptr() + id.z() * sum_row_stride_y) + id.y());
b_offset_term_s32 = vmulq_n_s32(b_offset_term_s32, _b_offset);
int32x4x4_t in_s32 =
diff --git a/src/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.cpp
index cab3c7a..42353ed 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.cpp
@@ -24,6 +24,7 @@
#include "arm_compute/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.h"
#include "arm_compute/core/AccessWindowStatic.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
@@ -43,9 +44,9 @@
{
inline Status validate_arguments(const ITensorInfo *accum, const ITensorInfo *biases)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(accum, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(accum);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(accum, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(biases, accum);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(biases, accum);
ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != accum->dimension(0));
@@ -161,33 +162,6 @@
break;
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- case DataType::QS8:
- {
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const qint8x16_t accum = vld1q_qs8(reinterpret_cast<const qint8_t *>(in0_out.ptr()));
- const qint8x16_t biases = vld1q_qs8(reinterpret_cast<const qint8_t *>(in1.ptr()));
-
- vst1q_qs8(reinterpret_cast<qint8_t *>(in0_out.ptr()), vqaddq_qs8(accum, biases));
- },
- in0_out, in1);
- break;
- }
- case DataType::QS16:
- {
- execute_window_loop(window, [&](const Coordinates & id)
- {
- qint16x8x2_t accum = vld2q_s16(reinterpret_cast<const qint16_t *>(in0_out.ptr()));
- const qint16x8x2_t biases = vld2q_s16(reinterpret_cast<const qint16_t *>(in1.ptr()));
-
- accum.val[0] = vqaddq_qs16(accum.val[0], biases.val[0]);
- accum.val[1] = vqaddq_qs16(accum.val[1], biases.val[1]);
-
- vst2q_s16(reinterpret_cast<qint16_t *>(in0_out.ptr()), accum);
- },
- in0_out, in1);
- break;
- }
default:
ARM_COMPUTE_ERROR("Data type not supported");
break;
diff --git a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
index dfba743..cd6aa55 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,7 @@
*/
#include "arm_compute/core/NEON/kernels/NEGEMMMatrixAdditionKernel.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
@@ -91,54 +92,6 @@
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-void matrix_addition_qs8(const ITensor *input, ITensor *output, const Window &window, float beta)
-{
- const int fixed_point_position = input->info()->fixed_point_position();
- const qint8x16_t beta_qs8 = vdupq_n_qs8(sqcvt_qs8_f32(beta, fixed_point_position));
-
- Iterator in(input, window);
- Iterator out(output, window);
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const auto in_ptr = reinterpret_cast<const qint8_t *>(in.ptr());
- const auto out_ptr = reinterpret_cast<qint8_t *>(out.ptr());
-
- qint8x16_t alpha_ab = vld1q_qs8(out_ptr);
- const qint8x16_t c = vld1q_qs8(in_ptr);
-
- // Multiply matrix C by its weight and accumulate
- alpha_ab = vqmlaq_qs8(alpha_ab, c, beta_qs8, fixed_point_position);
-
- vst1q_qs8(out_ptr, alpha_ab);
- },
- in, out);
-}
-
-void matrix_addition_qs16(const ITensor *input, ITensor *output, const Window &window, float beta)
-{
- const int fixed_point_position = input->info()->fixed_point_position();
- const qint16x8_t beta_qs16 = vdupq_n_qs16(sqcvt_qs16_f32(beta, fixed_point_position));
-
- Iterator in(input, window);
- Iterator out(output, window);
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const auto in_ptr = reinterpret_cast<const qint16_t *>(in.ptr());
- const auto out_ptr = reinterpret_cast<qint16_t *>(out.ptr());
-
- qint16x8x2_t alpha_ab = vld2q_s16(out_ptr);
- const qint16x8x2_t c = vld2q_s16(in_ptr);
-
- // Multiply matrix C by its weight and accumulate
- alpha_ab.val[0] = vqmlaq_qs16(alpha_ab.val[0], c.val[0], beta_qs16, fixed_point_position);
- alpha_ab.val[1] = vqmlaq_qs16(alpha_ab.val[1], c.val[1], beta_qs16, fixed_point_position);
-
- vst2q_s16(out_ptr, alpha_ab);
- },
- in, out);
-}
} // namespace
NEGEMMMatrixAdditionKernel::NEGEMMMatrixAdditionKernel()
@@ -148,10 +101,10 @@
void NEGEMMMatrixAdditionKernel::configure(const ITensor *input, ITensor *output, float beta)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_CPU_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != output->info()->dimension(0));
ARM_COMPUTE_ERROR_ON(input->info()->dimension(1) != output->info()->dimension(1));
@@ -160,12 +113,6 @@
case DataType::F32:
_func = &matrix_addition_f32;
break;
- case DataType::QS8:
- _func = &matrix_addition_qs8;
- break;
- case DataType::QS16:
- _func = &matrix_addition_qs16;
- break;
case DataType::F16:
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
_func = &matrix_addition_f16;
diff --git a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
index 69b052a..0ca2474 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
@@ -25,6 +25,7 @@
#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/AccessWindowTranspose.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
@@ -356,263 +357,6 @@
}
template <bool multiply_alpha>
-void vector_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
-{
- const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
- const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
- const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
- const int fixed_point_position = input0->info()->fixed_point_position();
-
- // The implementation computes 32 elements per iteration
- const int window_start_x = 32 * info.thread_id;
- const int window_step_x = 32 * info.num_threads;
- // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
- const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
-
- Window win_out(window);
- win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
- win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
-
- Window win_a(window);
- win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
- win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
-
- Window win_b;
- // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
- // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
- if(input1->info()->num_dimensions() >= 3)
- {
- win_b = window;
- }
- win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
- win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
-
- Iterator ina(input0, win_a);
- Iterator inb(input1, win_b);
- Iterator out(output, win_out);
-
- execute_window_loop(win_out, [&](const Coordinates & id)
- {
- if(id.x() > width_matrix_b)
- {
- return;
- }
-
- // Reset accumulators
- qint16x8_t acc00_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc01_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc02_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc03_qs16 = vdupq_n_qs16(0);
-
- auto vec_a = reinterpret_cast<const qint8_t *>(ina.ptr());
- auto matrix_b = reinterpret_cast<const qint8_t *>(inb.ptr());
-
- auto vec_a_end_addr = vec_a + num_elems_vec_a;
- for(; vec_a <= (vec_a_end_addr - 2);)
- {
- const qint8x8_t a0 = vld1_dup_qs8(vec_a + 0);
- const qint8x8_t a1 = vld1_dup_qs8(vec_a + 1);
-
- const qint8x8_t b00 = vld1_qs8(matrix_b + 0 + 0 * in_b_stride);
- const qint8x8_t b01 = vld1_qs8(matrix_b + 8 + 0 * in_b_stride);
- const qint8x8_t b02 = vld1_qs8(matrix_b + 16 + 0 * in_b_stride);
- const qint8x8_t b03 = vld1_qs8(matrix_b + 24 + 0 * in_b_stride);
- const qint8x8_t b10 = vld1_qs8(matrix_b + 0 + 1 * in_b_stride);
- const qint8x8_t b11 = vld1_qs8(matrix_b + 8 + 1 * in_b_stride);
- const qint8x8_t b12 = vld1_qs8(matrix_b + 16 + 1 * in_b_stride);
- const qint8x8_t b13 = vld1_qs8(matrix_b + 24 + 1 * in_b_stride);
-
- // First accumulation
- acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
- acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
- acc02_qs16 = vqmlal_qs8(acc02_qs16, b02, a0, fixed_point_position);
- acc03_qs16 = vqmlal_qs8(acc03_qs16, b03, a0, fixed_point_position);
-
- // Second accumulation
- acc00_qs16 = vqmlal_qs8(acc00_qs16, b10, a1, fixed_point_position);
- acc01_qs16 = vqmlal_qs8(acc01_qs16, b11, a1, fixed_point_position);
- acc02_qs16 = vqmlal_qs8(acc02_qs16, b12, a1, fixed_point_position);
- acc03_qs16 = vqmlal_qs8(acc03_qs16, b13, a1, fixed_point_position);
-
- vec_a += 2;
- matrix_b += 2 * in_b_stride;
- }
-
- for(; vec_a < vec_a_end_addr;)
- {
- const qint8x8_t a0 = vld1_dup_qs8(vec_a);
-
- const qint8x8_t b00 = vld1_qs8(matrix_b + 0);
- const qint8x8_t b01 = vld1_qs8(matrix_b + 8);
- const qint8x8_t b02 = vld1_qs8(matrix_b + 16);
- const qint8x8_t b03 = vld1_qs8(matrix_b + 24);
-
- acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
- acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
- acc02_qs16 = vqmlal_qs8(acc02_qs16, b02, a0, fixed_point_position);
- acc03_qs16 = vqmlal_qs8(acc03_qs16, b03, a0, fixed_point_position);
-
- vec_a += 1;
- matrix_b += in_b_stride;
- }
-
- // Convert back to qint8x8_t and saturate
- qint8x8_t acc00_qs8 = vqmovn_qs16(acc00_qs16);
- qint8x8_t acc01_qs8 = vqmovn_qs16(acc01_qs16);
- qint8x8_t acc02_qs8 = vqmovn_qs16(acc02_qs16);
- qint8x8_t acc03_qs8 = vqmovn_qs16(acc03_qs16);
-
- // Multiply by the weight of the matrix product (alpha)
- if(multiply_alpha)
- {
- const qint8x8_t alpha_qs8 = vdup_n_qs8(sqcvt_qs8_f32(alpha, fixed_point_position));
- acc00_qs8 = vqmul_qs8(acc00_qs8, alpha_qs8, fixed_point_position);
- acc01_qs8 = vqmul_qs8(acc01_qs8, alpha_qs8, fixed_point_position);
- acc02_qs8 = vqmul_qs8(acc02_qs8, alpha_qs8, fixed_point_position);
- acc03_qs8 = vqmul_qs8(acc03_qs8, alpha_qs8, fixed_point_position);
- }
-
- const auto mtx_out0 = reinterpret_cast<qint8_t *>(out.ptr());
-
- // Store 8x4 output elements
- vst1_qs8(mtx_out0 + 0, acc00_qs8);
- vst1_qs8(mtx_out0 + 8, acc01_qs8);
- vst1_qs8(mtx_out0 + 16, acc02_qs8);
- vst1_qs8(mtx_out0 + 24, acc03_qs8);
- },
- ina, inb, out);
-}
-
-template <bool multiply_alpha>
-void vector_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha)
-{
- const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
- const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
- const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
- const int fixed_point_position = input0->info()->fixed_point_position();
-
- // The implementation computes 16 elements per iteration
- const int window_start_x = 16 * info.thread_id;
- const int window_step_x = 16 * info.num_threads;
- // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
- const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
- ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");
-
- Window win_out(window);
- win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
- win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
-
- Window win_a(window);
- win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
- win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
-
- Window win_b;
- // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
- // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
- if(input1->info()->num_dimensions() >= 3)
- {
- win_b = window;
- }
- win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
- win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
-
- Iterator ina(input0, win_a);
- Iterator inb(input1, win_b);
- Iterator out(output, win_out);
-
- execute_window_loop(win_out, [&](const Coordinates & id)
- {
- if(id.x() > width_matrix_b)
- {
- return;
- }
-
- // Reset accumulators
- qint32x4_t acc00_qs32 = vdupq_n_qs32(0);
- qint32x4_t acc01_qs32 = vdupq_n_qs32(0);
- qint32x4_t acc02_qs32 = vdupq_n_qs32(0);
- qint32x4_t acc03_qs32 = vdupq_n_qs32(0);
-
- auto vec_a = reinterpret_cast<const qint16_t *>(ina.ptr());
- auto matrix_b = reinterpret_cast<const qint16_t *>(inb.ptr());
-
- auto vec_a_end_addr = vec_a + num_elems_vec_a;
- for(; vec_a <= (vec_a_end_addr - 2);)
- {
- const qint16x4_t a0 = vld1_dup_qs16(vec_a + 0);
- const qint16x4_t a1 = vld1_dup_qs16(vec_a + 1);
-
- const qint16x4_t b00 = vld1_qs16(matrix_b + 0 + 0 * in_b_stride);
- const qint16x4_t b01 = vld1_qs16(matrix_b + 4 + 0 * in_b_stride);
- const qint16x4_t b02 = vld1_qs16(matrix_b + 8 + 0 * in_b_stride);
- const qint16x4_t b03 = vld1_qs16(matrix_b + 12 + 0 * in_b_stride);
- const qint16x4_t b10 = vld1_qs16(matrix_b + 0 + 1 * in_b_stride);
- const qint16x4_t b11 = vld1_qs16(matrix_b + 4 + 1 * in_b_stride);
- const qint16x4_t b12 = vld1_qs16(matrix_b + 8 + 1 * in_b_stride);
- const qint16x4_t b13 = vld1_qs16(matrix_b + 12 + 1 * in_b_stride);
-
- // First accumulation
- acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
- acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
- acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position);
- acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position);
-
- // Second accumulation
- acc00_qs32 = vqmlal_qs16(acc00_qs32, b10, a1, fixed_point_position);
- acc01_qs32 = vqmlal_qs16(acc01_qs32, b11, a1, fixed_point_position);
- acc02_qs32 = vqmlal_qs16(acc02_qs32, b12, a1, fixed_point_position);
- acc03_qs32 = vqmlal_qs16(acc03_qs32, b13, a1, fixed_point_position);
-
- vec_a += 2;
- matrix_b += 2 * in_b_stride;
- }
-
- for(; vec_a < vec_a_end_addr;)
- {
- const qint16x4_t a0 = vld1_dup_qs16(vec_a);
-
- const qint16x4_t b00 = vld1_qs16(matrix_b + 0);
- const qint16x4_t b01 = vld1_qs16(matrix_b + 4);
- const qint16x4_t b02 = vld1_qs16(matrix_b + 8);
- const qint16x4_t b03 = vld1_qs16(matrix_b + 12);
-
- acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
- acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
- acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position);
- acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position);
-
- vec_a += 1;
- matrix_b += in_b_stride;
- }
-
- // Convert back to qint16x4_t and saturate
- qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32);
- qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32);
- qint16x4_t acc02_qs16 = vqmovn_qs32(acc02_qs32);
- qint16x4_t acc03_qs16 = vqmovn_qs32(acc03_qs32);
-
- // Multiply by the weight of the matrix product (alpha)
- if(multiply_alpha)
- {
- const qint16x4_t alpha_qs16 = vdup_n_qs16(sqcvt_qs16_f32(alpha, fixed_point_position));
- acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position);
- acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position);
- acc02_qs16 = vqmul_qs16(acc02_qs16, alpha_qs16, fixed_point_position);
- acc03_qs16 = vqmul_qs16(acc03_qs16, alpha_qs16, fixed_point_position);
- }
-
- const auto mtx_out0 = reinterpret_cast<qint16_t *>(out.ptr());
-
- // Store 16x4 output elements
- vst1_qs16(mtx_out0 + 0, acc00_qs16);
- vst1_qs16(mtx_out0 + 4, acc01_qs16);
- vst1_qs16(mtx_out0 + 8, acc02_qs16);
- vst1_qs16(mtx_out0 + 12, acc03_qs16);
- },
- ina, inb, out);
-}
-
-template <bool multiply_alpha>
void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
@@ -1063,361 +807,13 @@
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
}
-template <bool multiply_alpha>
-void matrix_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
-{
- const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
- const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
- const size_t out_stride2 = out_stride1 * 2;
- const size_t out_stride3 = out_stride1 * 3;
- const int num_elems_matrix_b_x = input1->info()->dimension(0);
- const int fixed_point_position = input0->info()->fixed_point_position();
- const qint8x8_t alpha_qs8 = vdup_n_qs8(sqcvt_qs8_f32(alpha, fixed_point_position));
- ARM_COMPUTE_UNUSED(alpha_qs8);
-
- // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix
- Window win_a(window);
- win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
- win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
-
- Window win_b;
- // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
- // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
- if(input1->info()->num_dimensions() >= 3)
- {
- win_b = window;
- }
- // Set step_x and step_y for matrix B. Scale by a factor of 16 the X range as the input transposed matrix A has 16 times less the cols of the output matrix
- // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 16x4
- win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, 2 * in_b_stride));
- win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
-
- Iterator ina(input0, win_a);
- Iterator inb(input1, win_b);
- Iterator out(output, window);
-
- // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
- // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
- // All the values needed for computing a single 32x4 block will be read from consecutive memory positions
- execute_window_loop(window, [&](const Coordinates & id)
- {
- auto mtx_a0 = reinterpret_cast<const qint8_t *>(ina.ptr());
- auto mtx_b0 = reinterpret_cast<const qint8_t *>(inb.ptr());
- auto mtx_b1 = mtx_b0 + in_b_stride;
-
- qint16x8_t acc00_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc10_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc20_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc30_qs16 = vdupq_n_qs16(0);
-
- qint16x8_t acc01_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc11_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc21_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc31_qs16 = vdupq_n_qs16(0);
-
- qint16x8_t acc02_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc12_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc22_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc32_qs16 = vdupq_n_qs16(0);
-
- qint16x8_t acc03_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc13_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc23_qs16 = vdupq_n_qs16(0);
- qint16x8_t acc33_qs16 = vdupq_n_qs16(0);
-
- int k = 0;
- // This for loop performs 2 accumulations
- for(; k <= (num_elems_matrix_b_x - 32); k += 32)
- {
- const qint8x8_t a0 = vld1_dup_qs8(mtx_a0 + 0);
- const qint8x8_t a1 = vld1_dup_qs8(mtx_a0 + 1);
- const qint8x8_t a2 = vld1_dup_qs8(mtx_a0 + 2);
- const qint8x8_t a3 = vld1_dup_qs8(mtx_a0 + 3);
- const qint8x8_t a4 = vld1_dup_qs8(mtx_a0 + 4);
- const qint8x8_t a5 = vld1_dup_qs8(mtx_a0 + 5);
- const qint8x8_t a6 = vld1_dup_qs8(mtx_a0 + 6);
- const qint8x8_t a7 = vld1_dup_qs8(mtx_a0 + 7);
-
- const qint8x8_t b00 = vld1_qs8(mtx_b0 + 0);
- const qint8x8_t b01 = vld1_qs8(mtx_b0 + 8);
- const qint8x8_t b10 = vld1_qs8(mtx_b1 + 0);
- const qint8x8_t b11 = vld1_qs8(mtx_b1 + 8);
-
- // First accumulation
- acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
- acc10_qs16 = vqmlal_qs8(acc10_qs16, b00, a1, fixed_point_position);
- acc20_qs16 = vqmlal_qs8(acc20_qs16, b00, a2, fixed_point_position);
- acc30_qs16 = vqmlal_qs8(acc30_qs16, b00, a3, fixed_point_position);
- acc02_qs16 = vqmlal_qs8(acc02_qs16, b10, a0, fixed_point_position);
- acc12_qs16 = vqmlal_qs8(acc12_qs16, b10, a1, fixed_point_position);
- acc22_qs16 = vqmlal_qs8(acc22_qs16, b10, a2, fixed_point_position);
- acc32_qs16 = vqmlal_qs8(acc32_qs16, b10, a3, fixed_point_position);
-
- const qint8x8_t b02 = vld1_qs8(mtx_b0 + 16);
- const qint8x8_t b03 = vld1_qs8(mtx_b0 + 24);
- const qint8x8_t b12 = vld1_qs8(mtx_b1 + 16);
- const qint8x8_t b13 = vld1_qs8(mtx_b1 + 24);
-
- acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
- acc11_qs16 = vqmlal_qs8(acc11_qs16, b01, a1, fixed_point_position);
- acc21_qs16 = vqmlal_qs8(acc21_qs16, b01, a2, fixed_point_position);
- acc31_qs16 = vqmlal_qs8(acc31_qs16, b01, a3, fixed_point_position);
- acc03_qs16 = vqmlal_qs8(acc03_qs16, b11, a0, fixed_point_position);
- acc13_qs16 = vqmlal_qs8(acc13_qs16, b11, a1, fixed_point_position);
- acc23_qs16 = vqmlal_qs8(acc23_qs16, b11, a2, fixed_point_position);
- acc33_qs16 = vqmlal_qs8(acc33_qs16, b11, a3, fixed_point_position);
-
-#if __arm__
- asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
- asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
- asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
-#endif /* __arm__ */
-
- // Second accumulation
- acc00_qs16 = vqmlal_qs8(acc00_qs16, b02, a4, fixed_point_position);
- acc10_qs16 = vqmlal_qs8(acc10_qs16, b02, a5, fixed_point_position);
- acc20_qs16 = vqmlal_qs8(acc20_qs16, b02, a6, fixed_point_position);
- acc30_qs16 = vqmlal_qs8(acc30_qs16, b02, a7, fixed_point_position);
- acc01_qs16 = vqmlal_qs8(acc01_qs16, b03, a4, fixed_point_position);
- acc11_qs16 = vqmlal_qs8(acc11_qs16, b03, a5, fixed_point_position);
- acc21_qs16 = vqmlal_qs8(acc21_qs16, b03, a6, fixed_point_position);
- acc31_qs16 = vqmlal_qs8(acc31_qs16, b03, a7, fixed_point_position);
- acc02_qs16 = vqmlal_qs8(acc02_qs16, b12, a4, fixed_point_position);
- acc12_qs16 = vqmlal_qs8(acc12_qs16, b12, a5, fixed_point_position);
- acc22_qs16 = vqmlal_qs8(acc22_qs16, b12, a6, fixed_point_position);
- acc32_qs16 = vqmlal_qs8(acc32_qs16, b12, a7, fixed_point_position);
- acc03_qs16 = vqmlal_qs8(acc03_qs16, b13, a4, fixed_point_position);
- acc13_qs16 = vqmlal_qs8(acc13_qs16, b13, a5, fixed_point_position);
- acc23_qs16 = vqmlal_qs8(acc23_qs16, b13, a6, fixed_point_position);
- acc33_qs16 = vqmlal_qs8(acc33_qs16, b13, a7, fixed_point_position);
-
- mtx_a0 += 8;
- mtx_b0 += 32;
- mtx_b1 += 32;
- }
-
- // This for loop performs the left over accumulations
- for(; k < num_elems_matrix_b_x; k += 16)
- {
- const qint8x8_t a0 = vld1_dup_qs8(mtx_a0 + 0);
- const qint8x8_t a1 = vld1_dup_qs8(mtx_a0 + 1);
- const qint8x8_t a2 = vld1_dup_qs8(mtx_a0 + 2);
- const qint8x8_t a3 = vld1_dup_qs8(mtx_a0 + 3);
-
- const qint8x8_t b00 = vld1_qs8(mtx_b0 + 0);
- const qint8x8_t b01 = vld1_qs8(mtx_b0 + 8);
- const qint8x8_t b10 = vld1_qs8(mtx_b1 + 0);
- const qint8x8_t b11 = vld1_qs8(mtx_b1 + 8);
-
- acc00_qs16 = vqmlal_qs8(acc00_qs16, b00, a0, fixed_point_position);
- acc10_qs16 = vqmlal_qs8(acc10_qs16, b00, a1, fixed_point_position);
- acc20_qs16 = vqmlal_qs8(acc20_qs16, b00, a2, fixed_point_position);
- acc30_qs16 = vqmlal_qs8(acc30_qs16, b00, a3, fixed_point_position);
- acc01_qs16 = vqmlal_qs8(acc01_qs16, b01, a0, fixed_point_position);
- acc11_qs16 = vqmlal_qs8(acc11_qs16, b01, a1, fixed_point_position);
- acc21_qs16 = vqmlal_qs8(acc21_qs16, b01, a2, fixed_point_position);
- acc31_qs16 = vqmlal_qs8(acc31_qs16, b01, a3, fixed_point_position);
- acc02_qs16 = vqmlal_qs8(acc02_qs16, b10, a0, fixed_point_position);
- acc12_qs16 = vqmlal_qs8(acc12_qs16, b10, a1, fixed_point_position);
- acc22_qs16 = vqmlal_qs8(acc22_qs16, b10, a2, fixed_point_position);
- acc32_qs16 = vqmlal_qs8(acc32_qs16, b10, a3, fixed_point_position);
- acc03_qs16 = vqmlal_qs8(acc03_qs16, b11, a0, fixed_point_position);
- acc13_qs16 = vqmlal_qs8(acc13_qs16, b11, a1, fixed_point_position);
- acc23_qs16 = vqmlal_qs8(acc23_qs16, b11, a2, fixed_point_position);
- acc33_qs16 = vqmlal_qs8(acc33_qs16, b11, a3, fixed_point_position);
-
- mtx_a0 += 4;
- mtx_b0 += 16;
- mtx_b1 += 16;
- }
-
- // Convert back to qint8x8_t and saturate
- qint8x8_t acc00_qs8 = vqmovn_qs16(acc00_qs16);
- qint8x8_t acc10_qs8 = vqmovn_qs16(acc10_qs16);
- qint8x8_t acc20_qs8 = vqmovn_qs16(acc20_qs16);
- qint8x8_t acc30_qs8 = vqmovn_qs16(acc30_qs16);
-
- qint8x8_t acc01_qs8 = vqmovn_qs16(acc01_qs16);
- qint8x8_t acc11_qs8 = vqmovn_qs16(acc11_qs16);
- qint8x8_t acc21_qs8 = vqmovn_qs16(acc21_qs16);
- qint8x8_t acc31_qs8 = vqmovn_qs16(acc31_qs16);
-
- qint8x8_t acc02_qs8 = vqmovn_qs16(acc02_qs16);
- qint8x8_t acc12_qs8 = vqmovn_qs16(acc12_qs16);
- qint8x8_t acc22_qs8 = vqmovn_qs16(acc22_qs16);
- qint8x8_t acc32_qs8 = vqmovn_qs16(acc32_qs16);
-
- qint8x8_t acc03_qs8 = vqmovn_qs16(acc03_qs16);
- qint8x8_t acc13_qs8 = vqmovn_qs16(acc13_qs16);
- qint8x8_t acc23_qs8 = vqmovn_qs16(acc23_qs16);
- qint8x8_t acc33_qs8 = vqmovn_qs16(acc33_qs16);
-
- // Multiply by the weight of the matrix product (alpha)
- if(multiply_alpha)
- {
- acc00_qs8 = vqmul_qs8(acc00_qs8, alpha_qs8, fixed_point_position);
- acc10_qs8 = vqmul_qs8(acc10_qs8, alpha_qs8, fixed_point_position);
- acc20_qs8 = vqmul_qs8(acc20_qs8, alpha_qs8, fixed_point_position);
- acc30_qs8 = vqmul_qs8(acc30_qs8, alpha_qs8, fixed_point_position);
- acc01_qs8 = vqmul_qs8(acc01_qs8, alpha_qs8, fixed_point_position);
- acc11_qs8 = vqmul_qs8(acc11_qs8, alpha_qs8, fixed_point_position);
- acc21_qs8 = vqmul_qs8(acc21_qs8, alpha_qs8, fixed_point_position);
- acc31_qs8 = vqmul_qs8(acc31_qs8, alpha_qs8, fixed_point_position);
- acc02_qs8 = vqmul_qs8(acc02_qs8, alpha_qs8, fixed_point_position);
- acc12_qs8 = vqmul_qs8(acc12_qs8, alpha_qs8, fixed_point_position);
- acc22_qs8 = vqmul_qs8(acc22_qs8, alpha_qs8, fixed_point_position);
- acc32_qs8 = vqmul_qs8(acc32_qs8, alpha_qs8, fixed_point_position);
- acc03_qs8 = vqmul_qs8(acc03_qs8, alpha_qs8, fixed_point_position);
- acc13_qs8 = vqmul_qs8(acc13_qs8, alpha_qs8, fixed_point_position);
- acc23_qs8 = vqmul_qs8(acc23_qs8, alpha_qs8, fixed_point_position);
- acc33_qs8 = vqmul_qs8(acc33_qs8, alpha_qs8, fixed_point_position);
- }
-
- const auto mtx_out0 = reinterpret_cast<qint8_t *>(out.ptr());
-
- // Store 32x4 output elements
- vst1_qs8(mtx_out0 + 0, acc00_qs8);
- vst1_qs8(mtx_out0 + 8, acc01_qs8);
- vst1_qs8(mtx_out0 + 16, acc02_qs8);
- vst1_qs8(mtx_out0 + 24, acc03_qs8);
- vst1_qs8(mtx_out0 + out_stride1 + 0, acc10_qs8);
- vst1_qs8(mtx_out0 + out_stride1 + 8, acc11_qs8);
- vst1_qs8(mtx_out0 + out_stride1 + 16, acc12_qs8);
- vst1_qs8(mtx_out0 + out_stride1 + 24, acc13_qs8);
- vst1_qs8(mtx_out0 + out_stride2 + 0, acc20_qs8);
- vst1_qs8(mtx_out0 + out_stride2 + 8, acc21_qs8);
- vst1_qs8(mtx_out0 + out_stride2 + 16, acc22_qs8);
- vst1_qs8(mtx_out0 + out_stride2 + 24, acc23_qs8);
- vst1_qs8(mtx_out0 + out_stride3 + 0, acc30_qs8);
- vst1_qs8(mtx_out0 + out_stride3 + 8, acc31_qs8);
- vst1_qs8(mtx_out0 + out_stride3 + 16, acc32_qs8);
- vst1_qs8(mtx_out0 + out_stride3 + 24, acc33_qs8);
- },
- ina, inb, out);
-}
-
-template <bool multiply_alpha>
-void matrix_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
-{
- const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
- const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
- const size_t out_stride2 = out_stride1 * 2;
- const size_t out_stride3 = out_stride1 * 3;
- const int num_elems_matrix_b_x = input1->info()->dimension(0);
- const int fixed_point_position = input0->info()->fixed_point_position();
- const qint16x4_t alpha_qs16 = vdup_n_qs16(sqcvt_qs16_f32(alpha, fixed_point_position));
- ARM_COMPUTE_UNUSED(alpha_qs16);
-
- // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix
- Window win_a(window);
- win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
- win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
-
- Window win_b;
- // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
- // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
- if(input1->info()->num_dimensions() >= 3)
- {
- win_b = window;
- }
- // Set step_x and step_y for matrix B. Scale by a factor of 16 the X range as the input transposed matrix A has 16 times less the cols of the output matrix
- win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
- win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
-
- Iterator ina(input0, win_a);
- Iterator inb(input1, win_b);
- Iterator out(output, window);
-
- // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
- // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 8x4 elements per iteration
- // All the values needed for computing a single 8x4 block will be read from consecutive memory positions
- execute_window_loop(window, [&](const Coordinates & id)
- {
- auto mtx_a0 = reinterpret_cast<const qint16_t *>(ina.ptr());
- auto mtx_b0 = reinterpret_cast<const qint16_t *>(inb.ptr());
- auto mtx_b1 = mtx_b0 + in_b_stride;
-
- qint32x4_t acc00_qs32 = vdupq_n_qs32(0);
- qint32x4_t acc10_qs32 = vdupq_n_qs32(0);
- qint32x4_t acc20_qs32 = vdupq_n_qs32(0);
- qint32x4_t acc30_qs32 = vdupq_n_qs32(0);
-
- qint32x4_t acc01_qs32 = vdupq_n_qs32(0);
- qint32x4_t acc11_qs32 = vdupq_n_qs32(0);
- qint32x4_t acc21_qs32 = vdupq_n_qs32(0);
- qint32x4_t acc31_qs32 = vdupq_n_qs32(0);
-
- // This for loop performs 1 accumulation
- for(int k = 0; k <= (num_elems_matrix_b_x - 8); k += 8)
- {
- const qint16x4_t a0 = vld1_dup_qs16(mtx_a0 + 0);
- const qint16x4_t a1 = vld1_dup_qs16(mtx_a0 + 1);
- const qint16x4_t a2 = vld1_dup_qs16(mtx_a0 + 2);
- const qint16x4_t a3 = vld1_dup_qs16(mtx_a0 + 3);
-
- const qint16x4_t b00 = vld1_qs16(mtx_b0 + 0);
- const qint16x4_t b01 = vld1_qs16(mtx_b0 + 4);
-
- acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
- acc10_qs32 = vqmlal_qs16(acc10_qs32, b00, a1, fixed_point_position);
- acc20_qs32 = vqmlal_qs16(acc20_qs32, b00, a2, fixed_point_position);
- acc30_qs32 = vqmlal_qs16(acc30_qs32, b00, a3, fixed_point_position);
- acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
- acc11_qs32 = vqmlal_qs16(acc11_qs32, b01, a1, fixed_point_position);
- acc21_qs32 = vqmlal_qs16(acc21_qs32, b01, a2, fixed_point_position);
- acc31_qs32 = vqmlal_qs16(acc31_qs32, b01, a3, fixed_point_position);
-
- mtx_a0 += 4;
- mtx_b0 += 8;
- mtx_b1 += 8;
- }
-
- // Convert back to qint16x4_t and saturate
- qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32);
- qint16x4_t acc10_qs16 = vqmovn_qs32(acc10_qs32);
- qint16x4_t acc20_qs16 = vqmovn_qs32(acc20_qs32);
- qint16x4_t acc30_qs16 = vqmovn_qs32(acc30_qs32);
-
- qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32);
- qint16x4_t acc11_qs16 = vqmovn_qs32(acc11_qs32);
- qint16x4_t acc21_qs16 = vqmovn_qs32(acc21_qs32);
- qint16x4_t acc31_qs16 = vqmovn_qs32(acc31_qs32);
-
- // Multiply by the weight of the matrix product (alpha)
- if(multiply_alpha)
- {
- acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position);
- acc10_qs16 = vqmul_qs16(acc10_qs16, alpha_qs16, fixed_point_position);
- acc20_qs16 = vqmul_qs16(acc20_qs16, alpha_qs16, fixed_point_position);
- acc30_qs16 = vqmul_qs16(acc30_qs16, alpha_qs16, fixed_point_position);
- acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position);
- acc11_qs16 = vqmul_qs16(acc11_qs16, alpha_qs16, fixed_point_position);
- acc21_qs16 = vqmul_qs16(acc21_qs16, alpha_qs16, fixed_point_position);
- acc31_qs16 = vqmul_qs16(acc31_qs16, alpha_qs16, fixed_point_position);
- }
-
- const auto mtx_out0 = reinterpret_cast<qint16_t *>(out.ptr());
-
- // Store 8x4 output elements
- vst1_qs16(mtx_out0 + 0, acc00_qs16);
- vst1_qs16(mtx_out0 + 4, acc01_qs16);
- vst1_qs16(mtx_out0 + out_stride1 + 0, acc10_qs16);
- vst1_qs16(mtx_out0 + out_stride1 + 4, acc11_qs16);
- vst1_qs16(mtx_out0 + out_stride2 + 0, acc20_qs16);
- vst1_qs16(mtx_out0 + out_stride2 + 4, acc21_qs16);
- vst1_qs16(mtx_out0 + out_stride3 + 0, acc30_qs16);
- vst1_qs16(mtx_out0 + out_stride3 + 4, acc31_qs16);
- },
- ina, inb, out);
-}
-
inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
{
ARM_COMPUTE_UNUSED(alpha);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32, DataType::QS8, DataType::QS16);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input0);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);
if(!is_interleaved)
{
@@ -1428,7 +824,6 @@
ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != output->dimension(0));
ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != output->dimension(1));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
}
}
else
@@ -1467,7 +862,6 @@
}
ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != static_cast<size_t>(m));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
}
}
@@ -1492,16 +886,6 @@
num_elems_processed_per_iteration_x = 16;
break;
}
- case DataType::QS8:
- {
- num_elems_processed_per_iteration_x = 32;
- break;
- }
- case DataType::QS16:
- {
- num_elems_processed_per_iteration_x = 16;
- break;
- }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
@@ -1539,16 +923,6 @@
num_elems_processed_per_iteration_x = 8;
break;
}
- case DataType::QS8:
- {
- num_elems_processed_per_iteration_x = 32;
- break;
- }
- case DataType::QS16:
- {
- num_elems_processed_per_iteration_x = 8;
- break;
- }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
@@ -1638,18 +1012,6 @@
vector_matrix_multiply_f32<false>(_input0, _input1, _output, window, info, _alpha);
break;
}
- case DataType::QS8:
- {
- multiply_alpha ? vector_matrix_multiply_qs8<true>(_input0, _input1, _output, window, info, _alpha) :
- vector_matrix_multiply_qs8<false>(_input0, _input1, _output, window, info, _alpha);
- break;
- }
- case DataType::QS16:
- {
- multiply_alpha ? vector_matrix_multiply_qs16<true>(_input0, _input1, _output, window, info, _alpha) :
- vector_matrix_multiply_qs16<false>(_input0, _input1, _output, window, info, _alpha);
- break;
- }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
@@ -1675,18 +1037,6 @@
matrix_matrix_multiply_f32<false>(_input0, _input1, _output, window, _alpha);
break;
}
- case DataType::QS8:
- {
- multiply_alpha ? matrix_matrix_multiply_qs8<true>(_input0, _input1, _output, window, _alpha) :
- matrix_matrix_multiply_qs8<false>(_input0, _input1, _output, window, _alpha);
- break;
- }
- case DataType::QS16:
- {
- multiply_alpha ? matrix_matrix_multiply_qs16<true>(_input0, _input1, _output, window, _alpha) :
- matrix_matrix_multiply_qs16<false>(_input0, _input1, _output, window, _alpha);
- break;
- }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
diff --git a/src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp
index c1e975e..2387869 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp
@@ -39,6 +39,43 @@
using namespace arm_compute;
+namespace
+{
+Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(output, DataType::S32, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
+ ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input0->data_type()) && (output->data_type() != DataType::S32));
+ ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_float(input0->data_type()) && (output->data_type() != DataType::F32));
+
+ ARM_COMPUTE_RETURN_ERROR_ON(input0->num_dimensions() == input1->num_dimensions());
+ ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(2) != input1->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(DataLayoutDimension::HEIGHT) != output->dimension(DataLayoutDimension::HEIGHT));
+ ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(DataLayoutDimension::WIDTH) != output->dimension(DataLayoutDimension::WIDTH));
+
+ return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output)
+{
+ const unsigned int num_elems_read_per_iteration = 16 / input0->element_size();
+
+ Window win = calculate_max_window(*input0, Steps(num_elems_read_per_iteration));
+
+ AccessWindowHorizontal input0_access(input0, 0, num_elems_read_per_iteration);
+ AccessWindowHorizontal input1_access(input1, 0, num_elems_read_per_iteration);
+ AccessWindowStatic output_access(output, 0, 0, output->dimension(0), output->dimension(1));
+
+ bool window_changed = update_window_and_padding(win, input0_access, input1_access, output_access);
+
+ output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
+
+ Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, win);
+}
+} // namespace
+
template <typename I0, typename I1, typename O>
void NEGEMMMatrixVectorMultiplyKernel::matrix_vector_multiply(const Window &window_in, const Window &window_w, const Window &window_out)
{
@@ -175,11 +212,9 @@
void NEGEMMMatrixVectorMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);
- ARM_COMPUTE_ERROR_ON(is_data_type_quantized_asymmetric(input0->info()->data_type()) && (output->info()->data_type() != DataType::S32));
- ARM_COMPUTE_ERROR_ON(input0->info()->dimension(2) != input1->info()->dimension(1));
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
+
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info()));
_input0 = input0;
_input1 = input1;
@@ -204,17 +239,17 @@
const unsigned int border_x = ceil_to_multiple(input0->info()->dimension(0), num_elems_read_per_iteration) - input0->info()->dimension(0);
_border_size = BorderSize(0, border_x);
- Window win = calculate_max_window(*input0->info(), Steps(num_elems_read_per_iteration));
+ auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info());
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+ INEKernel::configure(win_config.second);
+}
- AccessWindowHorizontal input0_access(input0->info(), 0, num_elems_read_per_iteration);
- AccessWindowHorizontal input1_access(input1->info(), 0, num_elems_read_per_iteration);
- AccessWindowStatic output_access(output->info(), 0, 0, output->info()->dimension(0), output->info()->dimension(1));
-
- update_window_and_padding(win, input0_access, input1_access, output_access);
-
- _output->info()->set_valid_region(ValidRegion(Coordinates(), _output->info()->tensor_shape()));
-
- INEKernel::configure(win);
+Status NEGEMMMatrixVectorMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), output->clone().get()).first);
+ return Status{};
}
void NEGEMMMatrixVectorMultiplyKernel::run(const Window &window, const ThreadInfo &info)
diff --git a/src/core/NEON/kernels/NEGEMMTranspose1xWKernel.cpp b/src/core/NEON/kernels/NEGEMMTranspose1xWKernel.cpp
index 5d6163d..2e14e7a 100644
--- a/src/core/NEON/kernels/NEGEMMTranspose1xWKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMTranspose1xWKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2018 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -54,17 +54,16 @@
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::U8, DataType::S8,
- DataType::QS16, DataType::U16, DataType::S16, DataType::U32, DataType::S32,
+ //Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use NEON FP16 instructions.
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::U8, DataType::S8,
+ DataType::U16, DataType::S16, DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
if(output->total_size() != 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), get_output_shape(input));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
@@ -102,7 +101,7 @@
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Output tensor auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), get_output_shape(input->info()), 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), get_output_shape(input->info()), 1, input->info()->data_type());
// Perform validate step
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info()));
diff --git a/src/core/NEON/kernels/NEHarrisCornersKernel.cpp b/src/core/NEON/kernels/NEHarrisCornersKernel.cpp
index 14fa1b4..5e1c216 100644
--- a/src/core/NEON/kernels/NEHarrisCornersKernel.cpp
+++ b/src/core/NEON/kernels/NEHarrisCornersKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,10 +41,6 @@
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-template class arm_compute::NEHarrisScoreFP16Kernel<3>;
-template class arm_compute::NEHarrisScoreFP16Kernel<5>;
-template class arm_compute::NEHarrisScoreFP16Kernel<7>;
-
namespace fp16
{
inline float16x8_t harris_score(float16x8_t gx2, float16x8_t gy2, float16x8_t gxgy, float sensitivity, float strength_thresh)
@@ -361,6 +357,10 @@
INEKernel::configure(win);
}
+template class arm_compute::NEHarrisScoreFP16Kernel<3>;
+template class arm_compute::NEHarrisScoreFP16Kernel<5>;
+template class arm_compute::NEHarrisScoreFP16Kernel<7>;
+
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
template class arm_compute::NEHarrisScoreKernel<3>;
diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp
index 86e3fd7..98b1488 100644
--- a/src/core/NEON/kernels/NEIm2ColKernel.cpp
+++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp
@@ -23,8 +23,8 @@
*/
#include "arm_compute/core/NEON/kernels/NEIm2ColKernel.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
-#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/Size2D.h"
@@ -45,33 +45,31 @@
namespace
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
- bool has_bias, bool is_fully_connected, bool is_flatten, const Size2D &dilation)
+ bool has_bias, const Size2D &dilation, unsigned int num_groups, bool is_fully_connected, bool is_flatten)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::QASYMM8 && has_bias);
ARM_COMPUTE_RETURN_ERROR_ON((dilation.x() < 1) || (dilation.y() < 1));
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups > 1, "Number of groups greater than one are not supported on NEON");
- TensorShape expected_output_shape;
- if(is_flatten) /* Called by FlattenLayer */
+ if(output->total_size() > 0)
{
- expected_output_shape = misc::shape_calculator::compute_im2col_flatten_shape(input);
- }
- else if(!is_fully_connected) /* Called by ConvolutionLayer */
- {
- expected_output_shape = misc::shape_calculator::compute_im2col_conv_shape(input, kernel_dims, conv_info, has_bias, dilation);
- }
- else /* Called by FullyConnectedLayer */
- {
- const int num_batch_dimensions = std::max(0, static_cast<int>(output->tensor_shape().num_dimensions()) - 1);
- const int num_input_dimensions = input->tensor_shape().num_dimensions() - num_batch_dimensions;
+ TensorShape expected_output_shape;
- expected_output_shape = misc::shape_calculator::compute_im2col_fc_shape(input, num_input_dimensions);
- }
+ if(is_flatten || is_fully_connected)
+ {
+ expected_output_shape = misc::shape_calculator::compute_flatten_shape(input);
+ }
+ else
+ {
+ expected_output_shape = misc::shape_calculator::compute_im2col_conv_shape(input, kernel_dims, conv_info, has_bias, dilation, false);
+ }
- TensorInfo expected_output = output->clone()->set_tensor_shape(expected_output_shape);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&expected_output, output);
+ TensorInfo expected_output = output->clone()->set_tensor_shape(expected_output_shape);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&expected_output, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ }
return Status{};
}
@@ -90,7 +88,6 @@
int input_stride_x,
int input_stride_y,
int input_stride_z,
- int fixed_point_position,
int pad_value,
int dilation_x,
int dilation_y)
@@ -171,18 +168,7 @@
// Append 1 if the convolution layer has biases
if(has_bias)
{
- if(std::is_same<T, qint8_t>::value)
- {
- *out_ptr = sqcvt_qs8_f32(1.0f, fixed_point_position);
- }
- else if(std::is_same<T, qint16_t>::value)
- {
- *out_ptr = sqcvt_qs16_f32(1.0f, fixed_point_position);
- }
- else
- {
- *out_ptr = static_cast<T>(1);
- }
+ *out_ptr = static_cast<T>(1);
}
}
} // namespace
@@ -251,7 +237,6 @@
input_stride_x,
input_stride_y,
input_stride_z,
- _input->info()->fixed_point_position(),
offset,
_dilation.x(),
_dilation.y());
@@ -294,18 +279,7 @@
// Add bias
if(_has_bias)
{
- if(std::is_same<T, qint8_t>::value)
- {
- *(reinterpret_cast<T *>(out_ptr) + out_width - 1) = sqcvt_qs8_f32(1.0f, _input->info()->fixed_point_position());
- }
- else if(std::is_same<T, qint16_t>::value)
- {
- *(reinterpret_cast<T *>(out_ptr) + out_width - 1) = sqcvt_qs16_f32(1.0f, _input->info()->fixed_point_position());
- }
- else
- {
- *(reinterpret_cast<T *>(out_ptr) + out_width - 1) = static_cast<T>(1);
- }
+ *(reinterpret_cast<T *>(out_ptr) + out_width - 1) = static_cast<T>(1);
}
}
while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_1D(out_slice));
@@ -317,13 +291,14 @@
}
void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
- bool has_bias, bool is_fully_connected, bool is_flatten, const Size2D &dilation)
+ bool has_bias, const Size2D &dilation, unsigned int num_groups, bool is_fully_connected, bool is_flatten)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Perform validation step
ARM_COMPUTE_UNUSED(is_fully_connected, is_flatten);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, is_fully_connected, is_flatten, dilation));
+ ARM_COMPUTE_UNUSED(num_groups);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, dilation, num_groups, is_fully_connected, is_flatten));
const DataLayout data_layout = input->info()->data_layout();
const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
@@ -366,12 +341,6 @@
_func = &NEIm2ColKernel::run_reduced<float16_t>;
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- case DataType::QS8:
- _func = &NEIm2ColKernel::run_reduced<qint8_t>;
- break;
- case DataType::QS16:
- _func = &NEIm2ColKernel::run_reduced<qint16_t>;
- break;
case DataType::QASYMM8:
_func = &NEIm2ColKernel::run_reduced<qasymm8_t>;
break;
@@ -392,12 +361,6 @@
_func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_generic<float16_t, false> : &NEIm2ColKernel::run_generic<float16_t, true>;
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- case DataType::QS8:
- _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_generic<qint8_t, false> : &NEIm2ColKernel::run_generic<qint8_t, true>;
- break;
- case DataType::QS16:
- _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_generic<qint16_t, false> : &NEIm2ColKernel::run_generic<qint16_t, true>;
- break;
case DataType::QASYMM8:
_func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_generic<qasymm8_t, false> : &NEIm2ColKernel::run_generic<qasymm8_t, true>;
break;
@@ -417,9 +380,9 @@
}
Status NEIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
- bool has_bias, bool is_fully_connected, bool is_flatten, const Size2D &dilation)
+ bool has_bias, const Size2D &dilation, unsigned int num_groups, bool is_fully_connected, bool is_flatten)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias, is_fully_connected, is_flatten, dilation));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias, dilation, num_groups, is_fully_connected, is_flatten));
return Status{};
}
diff --git a/src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp b/src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp
index 91776d8..ed03783 100644
--- a/src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp
@@ -103,7 +103,7 @@
Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
// Output auto initialization if not yet initialized
- auto_init_if_empty(*output, input->tensor_shape(), 1, input->data_type(), input->fixed_point_position());
+ auto_init_if_empty(*output, input->tensor_shape(), 1, input->data_type());
AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
AccessWindowHorizontal sum_access(sum, 0, num_elems_processed_per_iteration_sum);
diff --git a/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
index 099626d..4d3ec46 100644
--- a/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
@@ -24,6 +24,7 @@
#include "arm_compute/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.h"
#include "arm_compute/core/AccessWindowTranspose.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
@@ -305,7 +306,7 @@
Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input0);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32);
diff --git a/src/core/NEON/kernels/NEMagnitudePhaseKernel.cpp b/src/core/NEON/kernels/NEMagnitudePhaseKernel.cpp
index 2d7c29d..4a318f0 100644
--- a/src/core/NEON/kernels/NEMagnitudePhaseKernel.cpp
+++ b/src/core/NEON/kernels/NEMagnitudePhaseKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -51,386 +51,6 @@
constexpr float COEFF2 = 0.2447f;
} // namespace
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-namespace fp16
-{
-inline float16x8_t inv(float16x8_t x)
-{
- const float16x8_t estimate = vrecpeq_f16(x);
- return vmulq_f16(estimate, vrecpsq_f16(x, estimate));
-}
-
-inline float16x8_t atan2_fast(float16x8_t gx, float16x8_t gy, float16x8_t scale)
-{
- static const float16x8_t one = vdupq_n_f16(1.0f);
- static const float16x8_t ninety = vdupq_n_f16(90.f * SCALE_FACTOR);
- static const float16x8_t epsilon = vdupq_n_f16(1e-9f);
- static const float16x8_t piover4 = vdupq_n_f16(PI_4);
- static const float16x8_t coeff1 = vdupq_n_f16(COEFF1);
- static const float16x8_t coeff2 = vdupq_n_f16(COEFF2);
-
- const float16x8_t abs_gx = vabsq_f16(gx);
- const float16x8_t abs_gy = vabsq_f16(gy);
- const float16x8_t tmin = vminq_f16(abs_gx, abs_gy);
- const float16x8_t tmax = vmaxq_f16(abs_gx, abs_gy);
-
- // z = min(x, y) / max(x, y)
- const float16x8_t z = vmulq_f16(tmin, inv(vaddq_f16(tmax, epsilon)));
- const float16x8_t absz = vabsq_f16(z);
-
- // = x * [pi/4 + (1 - |x|) * (0.2447 + 0.0663 * |x|)]
- float16x8_t arctan = vmulq_f16(z, vfmaq_f16(piover4,
- vsubq_f16(one, absz),
- vfmaq_f16(coeff2, coeff1, absz)));
-
- // Radians to degrees conversion with applied a scale factor in order to have the result [0, 255]
- arctan = vmulq_f16(arctan, scale);
-
- /* If z > 1, result = 90 - result */
- return vbslq_f16(vcgeq_f16(abs_gx, abs_gy), arctan, vsubq_f16(ninety, arctan));
-}
-
-inline float16x8_t atan2_0_360(float16x8_t gx, float16x8_t gy)
-{
- static const float16x8_t scale = vdupq_n_f16(SCALE_360);
- static const float16x8_t threesixty = vdupq_n_f16(360.0f * SCALE_FACTOR);
- static const float16x8_t zero = vdupq_n_f16(0.0f);
- static const float16x8_t oneeighty = vdupq_n_f16(180.0f * SCALE_FACTOR);
-
- float16x8_t arctan = atan2_fast(gx, gy, scale);
-
- // Choose correct quadrant
- arctan = vbslq_f16(vcltq_f16(gx, zero), vsubq_f16(oneeighty, arctan), arctan);
- arctan = vbslq_f16(vcltq_f16(gy, zero), vsubq_f16(threesixty, arctan), arctan);
-
- return arctan;
-}
-
-inline float16x8_t atan2_0_180(float16x8_t gx, float16x8_t gy)
-{
- static const float16x8_t scale = vdupq_n_f16(SCALE_180);
- static const float16x8_t threesixty = vdupq_n_f16(360.0f * SCALE_FACTOR);
- static const float16x8_t oneeighty = vdupq_n_f16(180.0f * SCALE_FACTOR);
- static const float16x8_t zero = vdupq_n_f16(0.0f);
-
- float16x8_t arctan = atan2_fast(gx, gy, scale);
-
- // Choose correct quadrant
- arctan = vbslq_f16(vcltq_f16(gx, zero), vsubq_f16(oneeighty, arctan), arctan);
- arctan = vbslq_f16(vcltq_f16(gy, zero), vsubq_f16(threesixty, arctan), arctan);
- arctan = vbslq_f16(vcgtq_f16(arctan, oneeighty), vsubq_f16(arctan, oneeighty), arctan);
-
- return arctan;
-}
-
-inline float32x4_t invsqrtv(float32x4_t x)
-{
- float32x4_t sqrt_reciprocal = vrsqrteq_f32(x);
-
- sqrt_reciprocal = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, sqrt_reciprocal), sqrt_reciprocal),
- sqrt_reciprocal);
- sqrt_reciprocal = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, sqrt_reciprocal), sqrt_reciprocal),
- sqrt_reciprocal);
-
- return sqrt_reciprocal;
-}
-
-inline float32x4_t sqrtv(float32x4_t x)
-{
- float32x4_t res = vdupq_n_f32(0.5f);
- return vmlaq_f32(res, x, invsqrtv(x));
-}
-
-inline int16x8_t magnitude_l1(int16x8_t input1, int16x8_t input2)
-{
- return vqaddq_s16(vqabsq_s16(input1), vqabsq_s16(input2));
-}
-
-inline int16x8_t magnitude_l2(int16x8_t input1, int16x8_t input2)
-{
- const int32x4x2_t square_x =
- {
- vmull_s16(vget_low_s16(input1), vget_low_s16(input1)),
- vmull_s16(vget_high_s16(input1), vget_high_s16(input1))
- };
-
- const int32x4x2_t square_y =
- {
- vmull_s16(vget_low_s16(input2), vget_low_s16(input2)),
- vmull_s16(vget_high_s16(input2), vget_high_s16(input2))
- };
-
- const uint32x4x2_t sum =
- {
- vaddq_u32(vreinterpretq_u32_s32(square_x.val[0]),
- vreinterpretq_u32_s32(square_y.val[0])),
- vaddq_u32(vreinterpretq_u32_s32(square_x.val[1]),
- vreinterpretq_u32_s32(square_y.val[1]))
- };
-
- const float32x4x2_t res =
- {
- sqrtv(vcvtq_f32_u32(sum.val[0])),
- sqrtv(vcvtq_f32_u32(sum.val[1]))
- };
-
- return vcombine_s16(vqmovn_s32(vcvtq_s32_f32(res.val[0])),
- vqmovn_s32(vcvtq_s32_f32(res.val[1])));
-}
-
-inline uint8x8_t phase_signed(int16x8_t input1, int16x8_t input2)
-{
- static const float16x8_t zeropointfive = vdupq_n_f16(0.5f);
-
- const float16x8_t inputx_f16 = vcvtq_f16_s16(input1);
- const float16x8_t inputy_f16 = vcvtq_f16_s16(input2);
-
- // Compute fast atan2
- const float16x8_t angle = atan2_0_360(inputx_f16, inputy_f16);
-
- return vqmovun_s16(vcvtq_s16_f16(vaddq_f16(angle, zeropointfive)));
-}
-
-inline uint8x8_t phase_unsigned(int16x8_t input1, int16x8_t input2)
-{
- static const float16x8_t zeropointfive = vdupq_n_f16(0.5f);
-
- const float16x8_t inputx_f16 = vcvtq_f16_s16(input1);
- const float16x8_t inputy_f16 = vcvtq_f16_s16(input2);
-
- // Compute fast atan2
- const float16x8_t angle = atan2_0_180(inputx_f16, inputy_f16);
-
- return vqmovun_s16(vcvtq_s16_f16(vaddq_f16(angle, zeropointfive)));
-}
-
-template <MagnitudeType mag_type>
-inline int16x8x2_t compute_magnitude(const int16x8x2_t &in0, const int16x8x2_t &gx);
-
-template <>
-inline int16x8x2_t compute_magnitude<MagnitudeType::L2NORM>(const int16x8x2_t &in0, const int16x8x2_t &gx)
-{
- const int16x8x2_t mag =
- {
- magnitude_l2(in0.val[0], gx.val[0]),
- magnitude_l2(in0.val[1], gx.val[1])
- };
-
- return mag;
-}
-
-template <>
-inline int16x8x2_t compute_magnitude<MagnitudeType::L1NORM>(const int16x8x2_t &in0, const int16x8x2_t &gx)
-{
- const int16x8x2_t mag =
- {
- magnitude_l1(in0.val[0], gx.val[0]),
- magnitude_l1(in0.val[1], gx.val[1])
- };
-
- return mag;
-}
-
-template <PhaseType phase_type>
-inline uint8x16_t compute_phase(const int16x8x2_t &in0, const int16x8x2_t &gx);
-
-template <>
-inline uint8x16_t compute_phase<PhaseType::SIGNED>(const int16x8x2_t &in0, const int16x8x2_t &gx)
-{
- return vcombine_u8(phase_signed(in0.val[0], gx.val[0]),
- phase_signed(in0.val[1], gx.val[1]));
-}
-
-template <>
-inline uint8x16_t compute_phase<PhaseType::UNSIGNED>(const int16x8x2_t &in0, const int16x8x2_t &gx)
-{
- return vcombine_u8(phase_unsigned(in0.val[0], gx.val[0]),
- phase_unsigned(in0.val[1], gx.val[1]));
-}
-} // namespace fp16
-
-template <MagnitudeType mag_type, PhaseType phase_type>
-NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::NEMagnitudePhaseFP16Kernel()
- : _func(nullptr), _gx(nullptr), _gy(nullptr), _magnitude(nullptr), _phase(nullptr)
-{
-}
-
-template <MagnitudeType mag_type, PhaseType phase_type>
-void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::configure(const ITensor *gx, const ITensor *gy, ITensor *magnitude, ITensor *phase)
-{
- ARM_COMPUTE_ERROR_ON_FORMAT_NOT_IN(gx, Format::S16);
- ARM_COMPUTE_ERROR_ON_FORMAT_NOT_IN(gy, Format::S16);
- ARM_COMPUTE_ERROR_ON((nullptr == magnitude) && (nullptr == phase));
-
- const bool run_mag = magnitude != nullptr;
- const bool run_phase = phase != nullptr;
-
- if(run_mag)
- {
- ARM_COMPUTE_ERROR_ON_FORMAT_NOT_IN(magnitude, Format::S16);
- }
-
- if(run_phase)
- {
- ARM_COMPUTE_ERROR_ON_FORMAT_NOT_IN(phase, Format::U8);
- }
-
- _gx = gx;
- _gy = gy;
- _magnitude = magnitude;
- _phase = phase;
-
- if(run_mag && run_phase)
- {
- /* Run magnitude and phase */
- _func = &NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::magnitude_phase;
- }
- else if(run_mag)
- {
- /* Run magnitude */
- _func = &NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::magnitude;
- }
- else if(run_phase)
- {
- /* Run phase */
- _func = &NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::phase;
- }
- else
- {
- ARM_COMPUTE_ERROR("At least one output must be NOT NULL");
- }
-
- const unsigned int num_elems_processed_per_iteration = 16;
-
- // Configure kernel window
- Window win = calculate_max_window(*gx->info(), Steps(num_elems_processed_per_iteration));
- AccessWindowHorizontal magnitude_access(magnitude == nullptr ? nullptr : magnitude->info(), 0, num_elems_processed_per_iteration);
- AccessWindowHorizontal phase_access(phase == nullptr ? nullptr : phase->info(), 0, num_elems_processed_per_iteration);
-
- update_window_and_padding(win,
- AccessWindowHorizontal(gx->info(), 0, num_elems_processed_per_iteration),
- AccessWindowHorizontal(gy->info(), 0, num_elems_processed_per_iteration),
- magnitude_access,
- phase_access);
-
- ValidRegion valid_region = intersect_valid_regions(gx->info()->valid_region(),
- gy->info()->valid_region());
-
- magnitude_access.set_valid_region(win, valid_region);
- phase_access.set_valid_region(win, valid_region);
-
- INEKernel::configure(win);
-}
-
-template <MagnitudeType mag_type, PhaseType phase_type>
-void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::magnitude(const Window &window)
-{
- Iterator gx(_gx, window);
- Iterator gy(_gy, window);
- Iterator magnitude(_magnitude, window);
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const int16x8x2_t input1 =
- {
- vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr())),
- vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr()) + 8)
- };
-
- const int16x8x2_t input2 =
- {
- vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr())),
- vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr()) + 8)
- };
-
- // Compute and store magnitude
- const int16x8x2_t mag = fp16::compute_magnitude<mag_type>(input1, input2);
-
- /* Store magnitude */
- vst1q_s16(reinterpret_cast<int16_t *>(magnitude.ptr()), mag.val[0]);
- vst1q_s16(reinterpret_cast<int16_t *>(magnitude.ptr()) + 8, mag.val[1]);
- },
- gx, gy, magnitude);
-}
-
-template <MagnitudeType mag_type, PhaseType phase_type>
-void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::phase(const Window &window)
-{
- Iterator gx(_gx, window);
- Iterator gy(_gy, window);
- Iterator phase(_phase, window);
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const int16x8x2_t input1 =
- {
- vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr())),
- vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr()) + 8)
- };
-
- const int16x8x2_t input2 =
- {
- vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr())),
- vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr()) + 8)
- };
-
- // Compute and store phase
- vst1q_u8(phase.ptr(), fp16::compute_phase<phase_type>(input1, input2));
- },
- gx, gy, phase);
-}
-
-template <MagnitudeType mag_type, PhaseType phase_type>
-void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::magnitude_phase(const Window &window)
-{
- Iterator gx(_gx, window);
- Iterator gy(_gy, window);
- Iterator magnitude(_magnitude, window);
- Iterator phase(_phase, window);
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const int16x8x2_t input1 =
- {
- vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr())),
- vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr()) + 8)
- };
-
- const int16x8x2_t input2 =
- {
- vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr())),
- vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr()) + 8)
- };
-
- // Compute and store magnitude
- const int16x8x2_t mag = fp16::compute_magnitude<mag_type>(input1, input2);
-
- vst1q_s16(reinterpret_cast<int16_t *>(magnitude.ptr()), mag.val[0]);
- vst1q_s16(reinterpret_cast<int16_t *>(magnitude.ptr()) + 8, mag.val[1]);
-
- // Compute and store phase
- vst1q_u8(phase.ptr(), fp16::compute_phase<phase_type>(input1, input2));
- },
- gx, gy, magnitude, phase);
-}
-
-template <MagnitudeType mag_type, PhaseType phase_type>
-void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::run(const Window &window, const ThreadInfo &info)
-{
- ARM_COMPUTE_UNUSED(info);
- ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
- ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- ARM_COMPUTE_ERROR_ON(_func == nullptr);
-
- (this->*_func)(window);
-}
-
-template class arm_compute::NEMagnitudePhaseFP16Kernel<MagnitudeType::L1NORM, PhaseType::SIGNED>;
-template class arm_compute::NEMagnitudePhaseFP16Kernel<MagnitudeType::L2NORM, PhaseType::SIGNED>;
-template class arm_compute::NEMagnitudePhaseFP16Kernel<MagnitudeType::L1NORM, PhaseType::UNSIGNED>;
-template class arm_compute::NEMagnitudePhaseFP16Kernel<MagnitudeType::L2NORM, PhaseType::UNSIGNED>;
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
namespace
{
inline float32x4_t inv(float32x4_t x)
diff --git a/src/core/NEON/kernels/NEMinMaxLayerKernel.cpp b/src/core/NEON/kernels/NEMinMaxLayerKernel.cpp
index 434f4eb..5d1b4b3 100644
--- a/src/core/NEON/kernels/NEMinMaxLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEMinMaxLayerKernel.cpp
@@ -68,7 +68,7 @@
TensorShape output_shape = compute_min_max_shape(input);
// Output auto initialization if not yet initialized
- auto_init_if_empty(*output, output_shape, 1, input->data_type(), input->fixed_point_position());
+ auto_init_if_empty(*output, output_shape, 1, input->data_type());
constexpr unsigned int num_elems_processed_per_iteration = 1;
@@ -147,7 +147,7 @@
execute_window_loop(window_input, [&](const Coordinates & id)
{
int x = x_start;
- const auto in_ptr = reinterpret_cast<const float *const>(input.ptr() + id_batch[1] * _input->info()->strides_in_bytes()[3]);
+ const auto in_ptr = reinterpret_cast<const float *>(input.ptr() + id_batch[1] * _input->info()->strides_in_bytes()[3]);
// Vector loop
for(; x <= x_end - 8; x += 8)
@@ -181,7 +181,7 @@
const float min_i = std::min(vget_lane_f32(carry_min, 0), carry_min_scalar);
const float max_i = std::max(vget_lane_f32(carry_max, 0), carry_max_scalar);
- auto out_ptr = reinterpret_cast<float *const>(output.ptr());
+ auto out_ptr = reinterpret_cast<float *>(output.ptr());
// Perform reduction of local min/max values
update_min_max(out_ptr, min_i, max_i);
@@ -205,7 +205,7 @@
execute_window_loop(window_output, [&](const Coordinates & id)
{
- vst1_f32(reinterpret_cast<float *const>(output.ptr()), reset_values);
+ vst1_f32(reinterpret_cast<float *>(output.ptr()), reset_values);
},
output);
}
diff --git a/src/core/NEON/kernels/NEMinMaxLocationKernel.cpp b/src/core/NEON/kernels/NEMinMaxLocationKernel.cpp
index b90e813..befece2 100644
--- a/src/core/NEON/kernels/NEMinMaxLocationKernel.cpp
+++ b/src/core/NEON/kernels/NEMinMaxLocationKernel.cpp
@@ -212,7 +212,7 @@
execute_window_loop(win, [&](const Coordinates & id)
{
int x = x_start;
- const auto in_ptr = reinterpret_cast<const int16_t *const>(input.ptr());
+ const auto in_ptr = reinterpret_cast<const int16_t *>(input.ptr());
// Vector loop
for(; x <= x_end - 16; x += 16)
@@ -271,7 +271,7 @@
execute_window_loop(win, [&](const Coordinates & id)
{
int x = x_start;
- const auto in_ptr = reinterpret_cast<const float *const>(input.ptr());
+ const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
// Vector loop
for(; x <= x_end - 8; x += 8)
diff --git a/src/core/NEON/kernels/NENormalizationLayerKernel.cpp b/src/core/NEON/kernels/NENormalizationLayerKernel.cpp
index 776cb27..fe6b69c 100644
--- a/src/core/NEON/kernels/NENormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NENormalizationLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,7 @@
*/
#include "arm_compute/core/NEON/kernels/NENormalizationLayerKernel.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
@@ -39,26 +40,20 @@
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *input_squared, const ITensorInfo *output, const NormalizationLayerInfo &norm_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_squared, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_layout() == DataLayout::NHWC && norm_info.type() == NormType::IN_MAP_2D,
+ "Only Cross-map and 1D In-map normalization is supported for NHWC layout");
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, input_squared);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, input_squared);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(norm_info.norm_size() % 2), "Normalization size should be odd");
- if(is_data_type_fixed_point(input->data_type()))
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, input_squared);
- ARM_COMPUTE_RETURN_ERROR_ON_VALUE_NOT_REPRESENTABLE_IN_FIXED_POINT(norm_info.beta(), input);
- ARM_COMPUTE_RETURN_ERROR_ON_VALUE_NOT_REPRESENTABLE_IN_FIXED_POINT(norm_info.kappa(), input);
- ARM_COMPUTE_RETURN_ERROR_ON_VALUE_NOT_REPRESENTABLE_IN_FIXED_POINT(norm_info.scale_coeff(), input);
- }
-
// Checks performed when output is configured
if(output->total_size() != 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
@@ -68,8 +63,9 @@
{
unsigned int num_elems_processed_per_iteration = 16 / input->element_size();
const unsigned int num_elems_read_per_iteration = num_elems_processed_per_iteration + 2 * (norm_info.norm_size() / 2);
+ const unsigned int norm_idx = get_normalization_dimension_index(input->data_layout(), norm_info);
const unsigned int num_rows = (norm_info.type() == NormType::IN_MAP_2D) ? norm_info.norm_size() : 1;
- const unsigned int border_width = (norm_info.is_cross_map()) ? 0 : std::min<unsigned int>(norm_info.norm_size() / 2, 3U);
+ const unsigned int border_width = (norm_idx == 2) ? 0 : std::min<unsigned int>(norm_info.norm_size() / 2, 3U);
BorderSize border_size = BorderSize(0, border_width);
bool window_changed = false;
@@ -114,7 +110,8 @@
// Perform validation step
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), input_squared->info(), output->info(), norm_info));
- const unsigned int border_width = (norm_info.is_cross_map()) ? 0 : std::min<unsigned int>(norm_info.norm_size() / 2, 3U);
+ const unsigned int norm_idx = get_normalization_dimension_index(input->info()->data_layout(), norm_info);
+ const unsigned int border_width = (norm_idx == 2) ? 0 : std::min<unsigned int>(norm_info.norm_size() / 2, 3U);
_input = input;
_input_squared = input_squared;
@@ -126,16 +123,21 @@
{
case DataType::F32:
{
- switch(norm_info.type())
+ switch(norm_idx)
{
- case NormType::IN_MAP_1D:
- _func = &NENormalizationLayerKernel::normalize_float<DataType::F32, 0, false>;
+ case 0:
+ {
+ if(norm_info.type() == NormType::IN_MAP_2D)
+ {
+ _func = &NENormalizationLayerKernel::normalize_float<DataType::F32, 0, true>;
+ }
+ else
+ {
+ _func = &NENormalizationLayerKernel::normalize_float<DataType::F32, 0, false>;
+ }
break;
- case NormType::IN_MAP_2D:
- // Normalize over X and Y
- _func = &NENormalizationLayerKernel::normalize_float<DataType::F32, 0, true>;
- break;
- case NormType::CROSS_MAP:
+ }
+ case 2:
_func = &NENormalizationLayerKernel::normalize_float<DataType::F32, 2, false>;
break;
default:
@@ -143,18 +145,24 @@
}
break;
}
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
- switch(norm_info.type())
+ switch(norm_idx)
{
- case NormType::IN_MAP_1D:
- _func = &NENormalizationLayerKernel::normalize_float<DataType::F16, 0, false>;
+ case 0:
+ {
+ if(norm_info.type() == NormType::IN_MAP_2D)
+ {
+ _func = &NENormalizationLayerKernel::normalize_float<DataType::F16, 0, true>;
+ }
+ else
+ {
+ _func = &NENormalizationLayerKernel::normalize_float<DataType::F16, 0, false>;
+ }
break;
- case NormType::IN_MAP_2D:
- // Normalize over X and Y
- _func = &NENormalizationLayerKernel::normalize_float<DataType::F16, 0, true>;
- break;
- case NormType::CROSS_MAP:
+ }
+ case 2:
_func = &NENormalizationLayerKernel::normalize_float<DataType::F16, 2, false>;
break;
default:
@@ -162,44 +170,7 @@
}
break;
}
- case DataType::QS8:
- {
- switch(norm_info.type())
- {
- case NormType::IN_MAP_1D:
- _func = &NENormalizationLayerKernel::normalize_fixed_point<DataType::QS8, 0, false>;
- break;
- case NormType::IN_MAP_2D:
- // Normalize over X and Y
- _func = &NENormalizationLayerKernel::normalize_fixed_point<DataType::QS8, 0, true>;
- break;
- case NormType::CROSS_MAP:
- _func = &NENormalizationLayerKernel::normalize_fixed_point<DataType::QS8, 2, false>;
- break;
- default:
- break;
- }
- break;
- }
- case DataType::QS16:
- {
- switch(norm_info.type())
- {
- case NormType::IN_MAP_1D:
- _func = &NENormalizationLayerKernel::normalize_fixed_point<DataType::QS16, 0, false>;
- break;
- case NormType::IN_MAP_2D:
- // Normalize over X and Y
- _func = &NENormalizationLayerKernel::normalize_fixed_point<DataType::QS16, 0, true>;
- break;
- case NormType::CROSS_MAP:
- _func = &NENormalizationLayerKernel::normalize_fixed_point<DataType::QS16, 2, false>;
- break;
- default:
- break;
- }
- break;
- }
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
default:
ARM_COMPUTE_ERROR("NOT SUPPORTED!");
}
@@ -306,105 +277,6 @@
}
}
-template <DataType dt, unsigned int dim, bool do_2D_norm>
-void NENormalizationLayerKernel::normalize_fixed_point(const Window &window)
-{
- Iterator input(_input, window);
- Iterator input_squared(_input_squared, window);
- Iterator output(_output, window);
-
- const int dim_y = 1;
- const int radius = _norm_info.norm_size() / 2;
- const int total_size = _input->info()->dimension(dim) - 1;
- const int input_squared_stride = _input_squared->info()->strides_in_bytes()[dim];
- // We account padding across X only and we iterate over rows
- const int min_left = (dim == 2) ? 0 : -static_cast<int>(border_size().left);
- const int max_right = (dim == 2) ? total_size : total_size + border_size().left;
- const int min_top = 0;
- const int max_bottom = _input->info()->dimension(dim_y) - 1;
-
- const int fixed_point_position = _input->info()->fixed_point_position();
-
- if(dt == DataType::QS8)
- {
- const qint8x16_t coeff_vec = vdupq_n_qs8_f32(_norm_info.scale_coeff(), fixed_point_position);
- const qint8x16_t beta_vec = vdupq_n_qs8_f32(_norm_info.beta(), fixed_point_position);
- const qint8x16_t kappa_vec = vdupq_n_qs8_f32(_norm_info.kappa(), fixed_point_position);
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- // Get range to normalize
- const int current_row = do_2D_norm ? id[dim_y] : 0;
- const int current_slice = id[dim];
- const int first_row = do_2D_norm ? std::max(current_row - radius, min_top) : 0;
- const int last_row = do_2D_norm ? std::min(current_row + radius, max_bottom) : 0;
- const int first_slice = std::max(current_slice - radius, min_left);
- const int last_slice = std::min(current_slice + radius, max_right);
-
- // Accumulate 2D In-Map values
- qint8x16_t accu = vdupq_n_qs8(0);
- for(int j = first_row; j <= last_row; ++j)
- {
- // Compute row displacement
- const int row = (j - current_row) * _input_squared->info()->strides_in_bytes()[dim_y];
- const uint8_t *const input_squared_ptr = input_squared.ptr() + row - (current_slice * input_squared_stride);
- for(int i = first_slice; i <= last_slice; ++i)
- {
- accu = vqaddq_qs8(accu, vld1q_qs8(reinterpret_cast<const qint8_t *>(input_squared_ptr + i * input_squared_stride)));
- }
- }
-
- // Normalize
- const qint8x16_t accu_scale = vqmlaq_qs8(kappa_vec, coeff_vec, accu, fixed_point_position);
- const qint8x16_t normalized = vqpowq_qs8(accu_scale, beta_vec, fixed_point_position);
- const qint8x16_t normalized_pixel = vdivq_qs8(vld1q_qs8(reinterpret_cast<const qint8_t *>(input.ptr())), normalized, fixed_point_position);
- vst1q_qs8(reinterpret_cast<qint8_t *>(output.ptr()), normalized_pixel);
- },
- input, input_squared, output);
- }
- else if(dt == DataType::QS16)
- {
- const qint16x8_t coeff_vec = vdupq_n_qs16_f32(_norm_info.scale_coeff(), fixed_point_position);
- const qint16x8_t beta_vec = vdupq_n_qs16_f32(_norm_info.beta(), fixed_point_position);
- const qint16x8_t kappa_vec = vdupq_n_qs16_f32(_norm_info.kappa(), fixed_point_position);
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- // Get range to normalize
- const int current_row = do_2D_norm ? id[dim_y] : 0;
- const int current_slice = id[dim];
- const int first_row = do_2D_norm ? std::max(current_row - radius, min_top) : 0;
- const int last_row = do_2D_norm ? std::min(current_row + radius, max_bottom) : 0;
- const int first_slice = std::max(current_slice - radius, min_left);
- const int last_slice = std::min(current_slice + radius, max_right);
-
- // Accumulate 2D In-Map values
- qint16x8_t accu = vdupq_n_qs16(0);
- for(int j = first_row; j <= last_row; ++j)
- {
- // Compute row displacement
- const int row = (j - current_row) * _input_squared->info()->strides_in_bytes()[dim_y];
- const uint8_t *const input_squared_ptr = input_squared.ptr() + row - (current_slice * input_squared_stride);
- for(int i = first_slice; i <= last_slice; ++i)
- {
- accu = vqaddq_qs16(accu, vld1q_qs16(reinterpret_cast<const qint16_t *>(input_squared_ptr + i * input_squared_stride)));
- }
- }
-
- // Normalize
- const qint16x8_t accu_scale = vqmlaq_qs16(kappa_vec, coeff_vec, accu, fixed_point_position);
- const qint16x8_t normalized = vqpowq_qs16(accu_scale, beta_vec, fixed_point_position);
- const qint16x8_t normalized_pixel = vdivq_qs16(vld1q_qs16(reinterpret_cast<const qint16_t *>(input.ptr())), normalized, fixed_point_position);
- vst1q_qs16(reinterpret_cast<qint16_t *>(output.ptr()), normalized_pixel);
- },
- input, input_squared, output);
- }
- else
- {
- ARM_COMPUTE_ERROR("Not supported");
- }
-}
-
Status NENormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *input_squared, const ITensorInfo *output, const NormalizationLayerInfo norm_info)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, input_squared, output, norm_info));
diff --git a/src/core/NEON/kernels/NEPermuteKernel.cpp b/src/core/NEON/kernels/NEPermuteKernel.cpp
index ae1d48c..8d3fd88 100644
--- a/src/core/NEON/kernels/NEPermuteKernel.cpp
+++ b/src/core/NEON/kernels/NEPermuteKernel.cpp
@@ -45,8 +45,9 @@
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const PermutationVector &perm)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8,
- DataType::U16, DataType::S16, DataType::QS16,
+ // Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here because this kernel does not use any NEON FP16 instructions.
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8,
+ DataType::U16, DataType::S16,
DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG((perm.num_dimensions() == 3 && !(perm[0] == 2 && perm[1] == 0 && perm[2] == 1) && !(perm[0] == 1 && perm[1] == 2 && perm[2] == 0)),
@@ -59,7 +60,6 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index 193ca37..a4f5143 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
@@ -61,9 +62,10 @@
ARM_COMPUTE_UNUSED(overflow_policy);
ARM_COMPUTE_UNUSED(rounding_policy);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
"Output can only be U8 if both inputs are U8");
@@ -71,14 +73,6 @@
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
- if(is_data_type_fixed_point(input1->data_type()) || is_data_type_fixed_point(input2->data_type()) || is_data_type_fixed_point(output->data_type()))
- {
- // Check that all data types are the same and all fixed-point positions are the same
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
- // Check if scale is representable in fixed-point with the provided settings
- ARM_COMPUTE_RETURN_ERROR_ON_VALUE_NOT_REPRESENTABLE_IN_FIXED_POINT(scale, input1);
- }
-
if(std::abs(scale - scale255_constant) < 0.00001f)
{
ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
@@ -120,11 +114,6 @@
{
set_format_if_unknown(*output, Format::F16);
}
- else if(input1->data_type() == DataType::QS8 && input2->data_type() == DataType::QS8)
- {
- set_data_type_if_unknown(*output, DataType::QS8);
- set_fixed_point_position_if_zero(*output, input1->fixed_point_position());
- }
}
// Configure kernel window
@@ -220,105 +209,6 @@
}
template <bool is_scale255, bool is_sat>
-void mul_QS8_QS8_QS8_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n, int fixed_point_position)
-{
- const auto output = static_cast<qint8_t *__restrict>(output_ptr);
-
- const qint8x16_t ta1 = vld1q_qs8(static_cast<const qint8_t *__restrict>(input1_ptr));
- const qint8x16_t ta2 = vld1q_qs8(static_cast<const qint8_t *__restrict>(input2_ptr));
-
- if(is_scale255)
- {
- qint16x8_t tmp1_high = vmovl_s8(vget_high_s8(ta1));
- qint16x8_t tmp1_low = vmovl_s8(vget_low_s8(ta1));
- const qint16x8_t tmp2_high = vmovl_s8(vget_high_s8(ta2));
- const qint16x8_t tmp2_low = vmovl_s8(vget_low_s8(ta2));
-
- const float32x4x2_t scale255_f32 =
- {
- {
- scale255_constant_f32q,
- scale255_constant_f32q
- }
- };
- const qint16x8_t scale255 = vqcvtq_qs16_f32(scale255_f32, fixed_point_position);
-
- tmp1_high = vmulq_qs16(tmp1_high, tmp2_high, fixed_point_position);
- tmp1_low = vmulq_qs16(tmp1_low, tmp2_low, fixed_point_position);
- tmp1_high = vmulq_qs16(tmp1_high, scale255, fixed_point_position);
- tmp1_low = vmulq_qs16(tmp1_low, scale255, fixed_point_position);
-
- if(is_sat)
- {
- vst1q_qs8(output, vcombine_s8(vqmovn_s16(tmp1_low), vqmovn_s16(tmp1_high)));
- }
- else
- {
- vst1q_qs8(output, vcombine_s8(vmovn_s16(tmp1_low), vmovn_s16(tmp1_high)));
- }
- }
- else
- {
- const qint8x16_t vn = vdupq_n_s8(-n);
- qint8x16_t res = ta2;
-
- if(is_sat)
- {
- res = vqshlq_s8(vqmulq_qs8(ta1, res, fixed_point_position), vn);
- }
- else
- {
- res = vshlq_s8(vmulq_qs8(ta1, res, fixed_point_position), vn);
- }
- vst1q_qs8(output, res);
- }
-}
-
-template <bool is_scale255, bool is_sat>
-void mul_QS16_QS16_QS16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n, int fixed_point_position)
-{
- const qint16x8x2_t ta1 = vld2q_qs16(static_cast<const qint16_t *__restrict>(input1_ptr));
- qint16x8x2_t res = vld2q_qs16(static_cast<const qint16_t *__restrict>(input2_ptr));
-
- if(is_scale255)
- {
- const float32x4x2_t scale255_f32 =
- {
- {
- scale255_constant_f32q,
- scale255_constant_f32q
- }
- };
- const qint16x8_t scale255 = vqcvtq_qs16_f32(scale255_f32, fixed_point_position);
- if(is_sat)
- {
- res.val[0] = vqmulq_qs16(vqmulq_qs16(ta1.val[0], res.val[0], fixed_point_position), scale255, fixed_point_position);
- res.val[1] = vqmulq_qs16(vqmulq_qs16(ta1.val[1], res.val[1], fixed_point_position), scale255, fixed_point_position);
- }
- else
- {
- res.val[0] = vmulq_qs16(vmulq_qs16(ta1.val[0], res.val[0], fixed_point_position), scale255, fixed_point_position);
- res.val[1] = vmulq_qs16(vmulq_qs16(ta1.val[1], res.val[1], fixed_point_position), scale255, fixed_point_position);
- }
- }
- else
- {
- const qint16x8_t vn = vdupq_n_s16(-n);
- if(is_sat)
- {
- res.val[0] = vqshlq_s16(vqmulq_qs16(ta1.val[0], res.val[0], fixed_point_position), vn);
- res.val[1] = vqshlq_s16(vqmulq_qs16(ta1.val[1], res.val[1], fixed_point_position), vn);
- }
- else
- {
- res.val[0] = vshlq_s16(vmulq_qs16(ta1.val[0], res.val[0], fixed_point_position), vn);
- res.val[1] = vshlq_s16(vmulq_qs16(ta1.val[1], res.val[1], fixed_point_position), vn);
- }
- }
- vst2q_s16(static_cast<qint16_t *__restrict>(output_ptr), res);
-}
-
-template <bool is_scale255, bool is_sat>
inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &input1, const int16x8_t &input2, int n)
{
int32x4_t tmp1_high = vmovl_s16(vget_high_s16(input1));
@@ -529,7 +419,7 @@
} // namespace
NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
- : _func_float(nullptr), _func_int(nullptr), _func_q_int(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
+ : _func_float(nullptr), _func_int(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
{
}
@@ -550,7 +440,6 @@
_scale = scale;
_scale_exponent = 0;
_func_int = nullptr;
- _func_q_int = nullptr;
_func_float = nullptr;
bool is_scale_255 = false;
@@ -630,28 +519,6 @@
_func_int = is_sat ? &mul_U8_U8_S16_n<false, true> : &mul_U8_U8_S16_n<false, false>;
}
}
- else if(DataType::QS8 == dt_input1 && DataType::QS8 == dt_input2 && DataType::QS8 == dt_output)
- {
- if(is_scale_255)
- {
- _func_q_int = is_sat ? &mul_QS8_QS8_QS8_n<true, true> : &mul_QS8_QS8_QS8_n<true, false>;
- }
- else
- {
- _func_q_int = is_sat ? &mul_QS8_QS8_QS8_n<false, true> : &mul_QS8_QS8_QS8_n<false, false>;
- }
- }
- else if(DataType::QS16 == dt_input1 && DataType::QS16 == dt_input2 && DataType::QS16 == dt_output)
- {
- if(is_scale_255)
- {
- _func_q_int = is_sat ? &mul_QS16_QS16_QS16_n<true, true> : &mul_QS16_QS16_QS16_n<true, false>;
- }
- else
- {
- _func_q_int = is_sat ? &mul_QS16_QS16_QS16_n<false, true> : &mul_QS16_QS16_QS16_n<false, false>;
- }
- }
else if(DataType::F16 == dt_input1 && DataType::F16 == dt_input2 && DataType::F16 == dt_output)
{
_func_float = &mul_F16_F16_F16_n<false, false>;
@@ -724,17 +591,6 @@
},
input1, input2, output);
}
- else if(_func_q_int != nullptr)
- {
- int fixed_point_position = _input1->info()->fixed_point_position();
- execute_window_loop(collapsed, [&](const Coordinates & id)
- {
- (*_func_q_int)(input1.ptr(), input2.ptr(), output.ptr(), _scale_exponent, fixed_point_position);
- collapsed.slide_window_slice_3D(slice_input1);
- collapsed.slide_window_slice_3D(slice_input2);
- },
- input1, input2, output);
- }
else
{
ARM_COMPUTE_ERROR_ON(_func_float == nullptr);
@@ -753,4 +609,4 @@
const unsigned int replicateSize = _output->info()->dimension(0) - std::min(_input1->info()->dimension(0), _input2->info()->dimension(0));
const unsigned int border = std::min<unsigned int>(num_elems_processed_per_iteration - 1U, replicateSize);
return BorderSize(0, border, 0, 0);
-}
\ No newline at end of file
+}
diff --git a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
index 7877cf5..ad4b8f7 100644
--- a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
@@ -24,8 +24,8 @@
#include "arm_compute/core/NEON/kernels/NEPoolingLayerKernel.h"
#include "arm_compute/core/AccessWindowStatic.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
-#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEAsymm.h"
@@ -79,32 +79,6 @@
return 1.f / ((end_y - start_y) * (end_x - start_x));
}
-inline qint8_t calculate_avg_scale_q8(const Coordinates &id, int pool_size, int upper_bound_w, int upper_bound_h,
- int pad_x, int pad_y, int stride_x, int stride_y, int fixed_point_position)
-{
- static const std::array<qint8_t, 10> scale_values_q8 =
- { { 0x0, 0x0, 0x40, 0x2A, 0x20, 0x19, 0x15, 0x12, 0x10, 0xE } };
- const int start_x = id.x() * stride_x - pad_x;
- const int start_y = id.y() * stride_y - pad_y;
- const int end_x = std::min(start_x + pool_size, upper_bound_w);
- const int end_y = std::min(start_y + pool_size, upper_bound_h);
- const int val = ((end_y - start_y) * (end_x - start_x));
- return sshr_qs8(scale_values_q8[val], (7 - fixed_point_position));
-}
-
-inline qint16_t calculate_avg_scale_q16(const Coordinates &id, int pool_size, int upper_bound_w, int upper_bound_h,
- int pad_x, int pad_y, int stride_x, int stride_y, int fixed_point_position)
-{
- static std::array<qint16_t, 10> scale_values_q16 =
- { { 0x0, 0x0, 0x4000, 0x2AAB, 0x2000, 0x199A, 0x1555, 0x1249, 0x1000, 0xE38 } };
- const int start_x = id.x() * stride_x - pad_x;
- const int start_y = id.y() * stride_y - pad_y;
- const int end_x = std::min(start_x + pool_size, upper_bound_w);
- const int end_y = std::min(start_y + pool_size, upper_bound_h);
- const int val = ((end_y - start_y) * (end_x - start_x));
- return sshr_qs16(scale_values_q16[val], (15 - fixed_point_position));
-}
-
template <bool exclude_padding>
inline void scale_vector_s16x8(uint16x8_t &v, const Coordinates &id, int id_offset, int step,
const int pool_size, const int upper_bound_w, const int upper_bound_h,
@@ -155,7 +129,7 @@
v = vsetq_lane_u16(elems[7], v, 7);
}
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const PoolingLayerInfo &pool_info, unsigned int &pooled_w, unsigned int pooled_h, int pool_size_x)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const PoolingLayerInfo &pool_info, unsigned int &pooled_w, unsigned int pooled_h)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
@@ -163,22 +137,15 @@
int pool_stride_y = 0;
PoolingType pool_type = pool_info.pool_type();
const PadStrideInfo pad_stride_info = pool_info.pad_stride_info();
- const bool exclude_padding = pool_info.exclude_padding();
std::tie(pool_stride_x, pool_stride_y) = pad_stride_info.stride();
- static const std::set<int> supported_pool_sizes = { 2, 3 };
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON(pool_type == PoolingType::L2 && is_data_type_quantized(input->data_type()));
- ARM_COMPUTE_RETURN_ERROR_ON((supported_pool_sizes.find(pool_size_x) == supported_pool_sizes.end()) && ((input->data_type() != DataType::F32) && (input->data_type() != DataType::QASYMM8))
- && (pool_type != PoolingType::MAX));
- ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_fixed_point(input->data_type()) && pool_stride_x > 2);
- ARM_COMPUTE_RETURN_ERROR_ON(exclude_padding && is_data_type_fixed_point(input->data_type()));
-
if(output->total_size() != 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
ARM_COMPUTE_RETURN_ERROR_ON((output->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH)) != pooled_w)
|| (output->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT)) != pooled_h));
@@ -236,22 +203,6 @@
{
switch(input->data_type())
{
- case DataType::QS8:
- num_elems_read_per_iteration = 16;
- switch(pool_size_x)
- {
- case 2:
- num_elems_horizontal_window = (pool_stride_x == 2) ? 8 : 16;
- num_elems_processed_per_iteration = (pool_stride_x == 2) ? 8 : 15;
- break;
- case 3:
- num_elems_horizontal_window = (pool_stride_x == 2) ? 8 : 16;
- num_elems_processed_per_iteration = (pool_stride_x == 2) ? 7 : 14;
- break;
- default:
- break;
- }
- break;
case DataType::QASYMM8:
if(is_nhwc)
{
@@ -274,22 +225,6 @@
break;
}
break;
- case DataType::QS16:
- num_elems_read_per_iteration = 8;
- switch(pool_size_x)
- {
- case 2:
- num_elems_horizontal_window = (pool_stride_x == 2) ? 4 : 8;
- num_elems_processed_per_iteration = (pool_stride_x == 2) ? 4 : 7;
- break;
- case 3:
- num_elems_horizontal_window = (pool_stride_x == 2) ? 4 : 8;
- num_elems_processed_per_iteration = (pool_stride_x == 2) ? 3 : 6;
- break;
- default:
- break;
- }
- break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
if(is_nhwc)
@@ -300,10 +235,6 @@
switch(pool_size_x)
{
case 2:
- num_elems_read_per_iteration = 16;
- num_elems_processed_per_iteration = 8;
- num_elems_horizontal_window = 8;
- break;
case 3:
num_elems_read_per_iteration = 4;
num_elems_processed_per_iteration = 1;
@@ -346,14 +277,8 @@
{
if(is_nhwc)
{
- if(DataType::QASYMM8 == input->data_type())
- {
- num_elems_processed_per_iteration = 8;
- }
- else
- {
- num_elems_processed_per_iteration = 4;
- }
+ const unsigned int vector_size = 16 / input->element_size();
+ num_elems_processed_per_iteration = (input->data_type() == DataType::QASYMM8) ? 8 : vector_size;
}
}
@@ -450,7 +375,7 @@
auto_init(input->info(), output->info(), pooled_w, pooled_h);
// Perform validation step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), pool_info, pooled_w, pooled_h, pool_size_x));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), pool_info, pooled_w, pooled_h));
// Set instance variables
_input = input;
@@ -462,64 +387,7 @@
const DataType data_type = input->info()->data_type();
const bool is_nchw = data_layout == DataLayout::NCHW;
- // Select appropriate function
- if(data_type == DataType::QS8)
- {
- if(_is_square)
- {
- switch(pool_size_x)
- {
- case 2:
- switch(pool_type)
- {
- case PoolingType::AVG:
- _func = &NEPoolingLayerKernel::pooling2_q8_nchw<PoolingType::AVG>;
- break;
- case PoolingType::MAX:
- _func = &NEPoolingLayerKernel::pooling2_q8_nchw<PoolingType::MAX>;
- break;
- default:
- ARM_COMPUTE_ERROR("Unsupported pooling type!");
- }
- break;
- case 3:
- switch(pool_type)
- {
- case PoolingType::AVG:
- _func = &NEPoolingLayerKernel::pooling3_q8_nchw<PoolingType::AVG>;
- break;
- case PoolingType::MAX:
- _func = &NEPoolingLayerKernel::pooling3_q8_nchw<PoolingType::MAX>;
- break;
- default:
- ARM_COMPUTE_ERROR("Unsupported pooling type!");
- }
- break;
- default:
- switch(pool_type)
- {
- case PoolingType::MAX:
- _func = &NEPoolingLayerKernel::poolingMxN_q8_nchw<PoolingType::MAX>;
- break;
- default:
- ARM_COMPUTE_ERROR("Unsupported pooling type!");
- }
- break;
- }
- }
- else
- {
- switch(pool_type)
- {
- case PoolingType::MAX:
- _func = &NEPoolingLayerKernel::poolingMxN_q8_nchw<PoolingType::MAX>;
- break;
- default:
- ARM_COMPUTE_ERROR("Unsupported pooling type!");
- }
- }
- }
- else if(data_type == DataType::QASYMM8)
+ if(data_type == DataType::QASYMM8)
{
if(pool_size_x == 2 && pool_stride_x < 3 && _is_square)
{
@@ -606,62 +474,6 @@
}
}
}
- else if(data_type == DataType::QS16)
- {
- if(_is_square)
- {
- switch(pool_size_x)
- {
- case 2:
- switch(pool_type)
- {
- case PoolingType::AVG:
- _func = &NEPoolingLayerKernel::pooling2_q16_nchw<PoolingType::AVG>;
- break;
- case PoolingType::MAX:
- _func = &NEPoolingLayerKernel::pooling2_q16_nchw<PoolingType::MAX>;
- break;
- default:
- ARM_COMPUTE_ERROR("Unsupported pooling type!");
- }
- break;
- case 3:
- switch(pool_type)
- {
- case PoolingType::AVG:
- _func = &NEPoolingLayerKernel::pooling3_q16_nchw<PoolingType::AVG>;
- break;
- case PoolingType::MAX:
- _func = &NEPoolingLayerKernel::pooling3_q16_nchw<PoolingType::MAX>;
- break;
- default:
- ARM_COMPUTE_ERROR("Unsupported pooling type!");
- }
- break;
- default:
- switch(pool_type)
- {
- case PoolingType::MAX:
- _func = &NEPoolingLayerKernel::poolingMxN_q16_nchw<PoolingType::MAX>;
- break;
- default:
- ARM_COMPUTE_ERROR("Unsupported pooling type!");
- }
- break;
- }
- }
- else
- {
- switch(pool_type)
- {
- case PoolingType::MAX:
- _func = &NEPoolingLayerKernel::poolingMxN_q16_nchw<PoolingType::MAX>;
- break;
- default:
- ARM_COMPUTE_ERROR("Unsupported pooling type!");
- }
- }
- }
else if(data_type == DataType::F16)
{
if(_is_square)
@@ -1022,71 +834,6 @@
INEKernel::configure(win_config.second);
}
-template <PoolingType pooling_type>
-void NEPoolingLayerKernel::pooling2_q8_nchw(const Window &window_input, const Window &window)
-{
- Iterator input(_input, window_input);
- Iterator output(_output, window);
-
- const int fixed_point_position = _input->info()->fixed_point_position();
- constexpr int pool_size = 2;
- int pool_stride_x = 0;
- int pool_stride_y = 0;
- const int pool_pad_right = _pool_info.pad_stride_info().pad_right();
- const int pool_pad_top = _pool_info.pad_stride_info().pad_top();
- const int pool_pad_left = _pool_info.pad_stride_info().pad_left();
- const int pool_pad_bottom = _pool_info.pad_stride_info().pad_bottom();
- std::tie(pool_stride_x, pool_stride_y) = _pool_info.pad_stride_info().stride();
- const int upper_bound_w = _input->info()->dimension(0) + pool_pad_right;
- const int upper_bound_h = _input->info()->dimension(1) + pool_pad_bottom;
-
- const uint8_t *const input_top_ptr = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_left), -static_cast<int>(pool_pad_top)));
- const uint8_t *const input_bottom_ptr = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_left), -static_cast<int>(pool_pad_top) + 1));
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const auto top_data = vld1q_qs8(reinterpret_cast<const qint8_t *>(input_top_ptr + input.offset()));
- const auto bottom_data = vld1q_qs8(reinterpret_cast<const qint8_t *>(input_bottom_ptr + input.offset()));
- qint8x8_t lower_res = {};
- qint8x8_t upper_res = {};
- if(pooling_type == PoolingType::AVG)
- {
- // Calculate scale
- const qint8_t scale = calculate_avg_scale_q8(id, pool_size, upper_bound_w, upper_bound_h, pool_pad_left, pool_pad_top, pool_stride_x, pool_stride_y, fixed_point_position);
- const qint8x8_t scale_vec = vdup_n_qs8(scale);
-
- // Perform pooling
- const qint8x16_t sum_data = vqaddq_qs8(top_data, bottom_data);
- lower_res = vqmul_qs8(vpadd_s8(vget_low_s8(sum_data), vget_high_s8(sum_data)), scale_vec, fixed_point_position);
- if(pool_stride_x == 1)
- {
- const qint8x16_t sum_data_shifted = vextq_s8(sum_data, sum_data, 1);
- upper_res = vqmul_qs8(vpadd_s8(vget_low_s8(sum_data_shifted), vget_high_s8(sum_data_shifted)), scale_vec, fixed_point_position);
- }
- }
- else
- {
- const qint8x16_t max_data = vmaxq_s8(top_data, bottom_data);
- lower_res = vpmax_s8(vget_low_s8(max_data), vget_high_s8(max_data));
- if(pool_stride_x == 1)
- {
- const qint8x16_t max_data_shifted = vextq_s8(max_data, max_data, 1);
- upper_res = vpmax_s8(vget_low_s8(max_data_shifted), vget_high_s8(max_data_shifted));
- }
- }
- if(pool_stride_x == 1)
- {
- const qint8x8x2_t res = { { lower_res, upper_res } };
- vst2_s8(reinterpret_cast<qint8_t *>(output.ptr()), res);
- }
- else
- {
- vst1_qs8(reinterpret_cast<qint8_t *>(output.ptr()), lower_res);
- }
- },
- input, output);
-}
-
template <PoolingType pooling_type, bool exclude_padding>
void NEPoolingLayerKernel::pooling2_qasymm8_nchw(const Window &window_input, const Window &window)
{
@@ -1201,71 +948,6 @@
input, output);
}
-template <PoolingType pooling_type>
-void NEPoolingLayerKernel::pooling2_q16_nchw(const Window &window_input, const Window &window)
-{
- Iterator input(_input, window_input);
- Iterator output(_output, window);
-
- const int fixed_point_position = _input->info()->fixed_point_position();
- constexpr int pool_size = 2;
- const int pool_pad_right = _pool_info.pad_stride_info().pad_right();
- const int pool_pad_top = _pool_info.pad_stride_info().pad_top();
- const int pool_pad_left = _pool_info.pad_stride_info().pad_left();
- const int pool_pad_bottom = _pool_info.pad_stride_info().pad_bottom();
- int pool_stride_x = 0;
- int pool_stride_y = 0;
- std::tie(pool_stride_x, pool_stride_y) = _pool_info.pad_stride_info().stride();
- const int upper_bound_w = _input->info()->dimension(0) + pool_pad_right;
- const int upper_bound_h = _input->info()->dimension(1) + pool_pad_bottom;
-
- const unsigned char *const input_top_ptr = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_left), -static_cast<int>(pool_pad_top)));
- const unsigned char *const input_bottom_ptr = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_left), -static_cast<int>(pool_pad_top) + 1));
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const auto top_data = vld1q_qs16(reinterpret_cast<const qint16_t *>(input_top_ptr + input.offset()));
- const auto bottom_data = vld1q_qs16(reinterpret_cast<const qint16_t *>(input_bottom_ptr + input.offset()));
- qint16x4_t lower_res = {};
- qint16x4_t upper_res = {};
- if(pooling_type == PoolingType::AVG)
- {
- // Calculate scale
- const qint16_t scale = calculate_avg_scale_q16(id, pool_size, upper_bound_w, upper_bound_h, pool_pad_left, pool_pad_top, pool_stride_x, pool_stride_y, fixed_point_position);
- const qint16x4_t scale_vec = vdup_n_qs16(scale);
-
- // Perform pooling
- const qint16x8_t sum_data = vqaddq_qs16(top_data, bottom_data);
- lower_res = vqmul_qs16(vpadd_s16(vget_low_s16(sum_data), vget_high_s16(sum_data)), scale_vec, fixed_point_position);
- if(pool_stride_x == 1)
- {
- const qint16x8_t sum_data_shifted = vextq_s16(sum_data, sum_data, 1);
- upper_res = vqmul_qs16(vpadd_s16(vget_low_s16(sum_data_shifted), vget_high_s16(sum_data_shifted)), scale_vec, fixed_point_position);
- }
- }
- else
- {
- const qint16x8_t max_data = vmaxq_s16(top_data, bottom_data);
- lower_res = vpmax_s16(vget_low_s16(max_data), vget_high_s16(max_data));
- if(pool_stride_x == 1)
- {
- const qint16x8_t max_data_shifted = vextq_s16(max_data, max_data, 1);
- upper_res = vpmax_s16(vget_low_s16(max_data_shifted), vget_high_s16(max_data_shifted));
- }
- }
- if(pool_stride_x == 1)
- {
- const qint16x4x2_t res = { { lower_res, upper_res } };
- vst2_s16(reinterpret_cast<qint16_t *>(output.ptr()), res);
- }
- else
- {
- vst1_qs16(reinterpret_cast<qint16_t *>(output.ptr()), lower_res);
- }
- },
- input, output);
-}
-
template <PoolingType pooling_type, bool exclude_padding>
void NEPoolingLayerKernel::pooling3_f16_nchw(const Window &window_input, const Window &window)
{
@@ -1357,38 +1039,39 @@
execute_window_loop(window, [&](const Coordinates & id)
{
- auto top_data = vld2q_f16(reinterpret_cast<const float16_t *>(input_top_ptr + input.offset()));
- auto bottom_data = vld2q_f16(reinterpret_cast<const float16_t *>(input_bottom_ptr + input.offset()));
- float16x8_t res = {};
+ float16x4_t top_data = vld1_f16(reinterpret_cast<const float16_t *>(input_top_ptr + input.offset()));
+ float16x4_t bottom_data = vld1_f16(reinterpret_cast<const float16_t *>(input_bottom_ptr + input.offset()));
+ float16x4_t res = {};
// Get power of 2 in case of l2 pooling
if(pooling_type == PoolingType::L2)
{
- top_data.val[0] = vmulq_f16(top_data.val[0], top_data.val[0]);
- top_data.val[1] = vmulq_f16(top_data.val[1], top_data.val[1]);
- bottom_data.val[0] = vmulq_f16(bottom_data.val[0], bottom_data.val[0]);
- bottom_data.val[1] = vmulq_f16(bottom_data.val[1], bottom_data.val[1]);
+ top_data = vmul_f16(top_data, top_data);
+ bottom_data = vmul_f16(bottom_data, bottom_data);
}
if(pooling_type != PoolingType::MAX)
{
const float scale = calculate_avg_scale<exclude_padding, DataLayout::NCHW>(id, pool_size, pool_size, upper_bound_w, upper_bound_h, pool_pad_left, pool_pad_top, pool_stride_x, pool_stride_y);
- const float16x8_t scale_v = vdupq_n_f16(scale);
- res = vmulq_f16(scale_v, vaddq_f16(bottom_data.val[1], vaddq_f16(bottom_data.val[0], vaddq_f16(top_data.val[0], top_data.val[1]))));
+ const float16x4_t scale_v = vdup_n_f16(scale);
+
+ const float16x4_t sum_data = vadd_f16(top_data, bottom_data);
+ res = vmul_f16(vpadd_f16(sum_data, sum_data), scale_v);
}
else
{
- res = vmaxq_f16(bottom_data.val[1], vmaxq_f16(bottom_data.val[0], vmaxq_f16(top_data.val[0], top_data.val[1])));
+ const float16x4_t max_data = vmax_f16(top_data, bottom_data);
+ res = vpmax_f16(max_data, max_data);
}
// Calculate square-root in case of l2 pooling
if(pooling_type == PoolingType::L2)
{
- res = vinvq_f16(vinvsqrtq_f16(res));
+ res = vinv_f16(vinvsqrt_f16(res));
}
// Store result
- vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), res);
+ *(reinterpret_cast<float16_t *>(output.ptr())) = vget_lane_f16(res, 0);
},
input, output);
#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -1461,82 +1144,6 @@
input, output);
}
-template <PoolingType pooling_type>
-void NEPoolingLayerKernel::pooling3_q8_nchw(const Window &window_input, const Window &window)
-{
- Iterator input(_input, window_input);
- Iterator output(_output, window);
-
- const int fixed_point_position = _input->info()->fixed_point_position();
- constexpr int pool_size = 3;
- const int pool_pad_right = _pool_info.pad_stride_info().pad_right();
- const int pool_pad_top = _pool_info.pad_stride_info().pad_top();
- const int pool_pad_left = _pool_info.pad_stride_info().pad_left();
- const int pool_pad_bottom = _pool_info.pad_stride_info().pad_bottom();
- int pool_stride_x = 0;
- int pool_stride_y = 0;
- std::tie(pool_stride_x, pool_stride_y) = _pool_info.pad_stride_info().stride();
- const int upper_bound_w = _input->info()->dimension(0) + pool_pad_right;
- const int upper_bound_h = _input->info()->dimension(1) + pool_pad_bottom;
-
- const uint8_t *const input_top_ptr = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_left), -static_cast<int>(pool_pad_top)));
- const uint8_t *const input_middle_ptr = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_left), -static_cast<int>(pool_pad_top) + 1));
- const uint8_t *const input_bottom_ptr = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_left), -static_cast<int>(pool_pad_top) + 2));
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const auto top_data = vld1q_qs8(reinterpret_cast<const qint8_t *>(input_top_ptr + input.offset()));
- const auto middle_data = vld1q_qs8(reinterpret_cast<const qint8_t *>(input_middle_ptr + input.offset()));
- const auto bottom_data = vld1q_qs8(reinterpret_cast<const qint8_t *>(input_bottom_ptr + input.offset()));
- qint8x8_t res = {};
- if(pooling_type == PoolingType::AVG)
- {
- // Calculate scale
- const qint8_t scale = calculate_avg_scale_q8(id, pool_size, upper_bound_w, upper_bound_h, pool_pad_left, pool_pad_top, pool_stride_x, pool_stride_y, fixed_point_position);
-
- // Perform pooling for stride 2
- const qint8x16_t sum_data = vqaddq_qs8(vqaddq_qs8(top_data, bottom_data), middle_data);
- const qint8x16_t sum_data2 = vextq_s8(sum_data, sum_data, 1);
- const qint8x16_t sum_data3 = vextq_s8(sum_data, sum_data, 2);
- const qint8x16_t final_sum = vqaddq_qs8(vqaddq_qs8(sum_data, sum_data2), sum_data3);
- if(pool_stride_x == 2)
- {
- const qint8x8x2_t table = { { vget_low_s8(final_sum), vget_high_s8(final_sum) } };
- static const qint8x8_t lookup_val = { 0, 2, 4, 6, 8, 10, 12, 14 };
- const qint8x8_t scale_vec = vdup_n_qs8(scale);
- res = vtbl2_s8(table, lookup_val);
- res = vqmul_qs8(res, scale_vec, fixed_point_position);
- vst1_qs8(reinterpret_cast<qint8_t *>(output.ptr()), res);
- }
- else
- {
- const qint8x16_t scale_vec = vdupq_n_qs8(scale);
- vst1q_qs8(reinterpret_cast<qint8_t *>(output.ptr()), vqmulq_qs8(final_sum, scale_vec, fixed_point_position));
- }
- }
- else
- {
- const qint8x16_t max_data = vmaxq_s8(vmaxq_s8(top_data, bottom_data), middle_data);
- const qint8x16_t max_data2 = vextq_s8(max_data, max_data, 1);
- const qint8x16_t max_data3 = vextq_s8(max_data, max_data, 2);
- const qint8x16_t final_max = vmaxq_s8(vmaxq_s8(max_data, max_data2), max_data3);
-
- if(pool_stride_x == 2)
- {
- const qint8x8x2_t table = { { vget_low_s8(final_max), vget_high_s8(final_max) } };
- static const qint8x8_t lookup_val = { 0, 2, 4, 6, 8, 10, 12, 14 };
- res = vtbl2_s8(table, lookup_val);
- vst1_qs8(reinterpret_cast<qint8_t *>(output.ptr()), res);
- }
- else
- {
- vst1q_qs8(reinterpret_cast<qint8_t *>(output.ptr()), final_max);
- }
- }
- },
- input, output);
-}
-
template <PoolingType pooling_type, bool exclude_padding>
void NEPoolingLayerKernel::pooling3_qasymm8_nchw(const Window &window_input, const Window &window)
{
@@ -1657,77 +1264,6 @@
input, output);
}
-template <PoolingType pooling_type>
-void NEPoolingLayerKernel::pooling3_q16_nchw(const Window &window_input, const Window &window)
-{
- Iterator input(_input, window_input);
- Iterator output(_output, window);
-
- const int fixed_point_position = _input->info()->fixed_point_position();
- constexpr int pool_size = 3;
- const int pool_pad_right = _pool_info.pad_stride_info().pad_right();
- const int pool_pad_top = _pool_info.pad_stride_info().pad_top();
- const int pool_pad_left = _pool_info.pad_stride_info().pad_left();
- const int pool_pad_bottom = _pool_info.pad_stride_info().pad_bottom();
- int pool_stride_x = 0;
- int pool_stride_y = 0;
- std::tie(pool_stride_x, pool_stride_y) = _pool_info.pad_stride_info().stride();
- const int upper_bound_w = _input->info()->dimension(0) + pool_pad_right;
- const int upper_bound_h = _input->info()->dimension(1) + pool_pad_bottom;
-
- const unsigned char *const input_top_ptr = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_left), -static_cast<int>(pool_pad_top)));
- const unsigned char *const input_middle_ptr = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_left), -static_cast<int>(pool_pad_top) + 1));
- const unsigned char *const input_bottom_ptr = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_left), -static_cast<int>(pool_pad_top) + 2));
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const auto top_data = vld1q_qs16(reinterpret_cast<const qint16_t *>(input_top_ptr + input.offset()));
- const auto middle_data = vld1q_qs16(reinterpret_cast<const qint16_t *>(input_middle_ptr + input.offset()));
- const auto bottom_data = vld1q_qs16(reinterpret_cast<const qint16_t *>(input_bottom_ptr + input.offset()));
-
- if(pooling_type == PoolingType::AVG)
- {
- // Calculate scale
- const qint16_t scale = calculate_avg_scale_q16(id, pool_size, upper_bound_w, upper_bound_h, pool_pad_left, pool_pad_top, pool_stride_x, pool_stride_y, fixed_point_position);
-
- // Perform pooling for stride 2
- const qint16x8_t sum_data = vqaddq_qs16(vqaddq_qs16(top_data, bottom_data), middle_data);
- const qint16x8_t sum_data2 = vextq_s16(sum_data, sum_data, 1);
- const qint16x8_t sum_data3 = vextq_s16(sum_data, sum_data, 2);
- const qint16x8_t final_sum = vqaddq_qs16(vqaddq_qs16(sum_data, sum_data2), sum_data3);
- if(pool_stride_x == 2)
- {
- const qint16x4_t tmp = { vgetq_lane_s16(final_sum, 0), vgetq_lane_s16(final_sum, 2), vgetq_lane_s16(final_sum, 4), vgetq_lane_s16(final_sum, 6) };
- const qint16x4_t scale_vec = vdup_n_qs16(scale);
- vst1_qs16(reinterpret_cast<qint16_t *>(output.ptr()), vqmul_qs16(tmp, scale_vec, fixed_point_position));
- }
- else
- {
- const qint16x8_t scale_vec = vdupq_n_qs16(scale);
- vst1q_qs16(reinterpret_cast<qint16_t *>(output.ptr()), vqmulq_qs16(final_sum, scale_vec, fixed_point_position));
- }
- }
- else
- {
- const qint16x8_t max_data = vmaxq_s16(vmaxq_s16(top_data, bottom_data), middle_data);
- const qint16x8_t max_data2 = vextq_s16(max_data, max_data, 1);
- const qint16x8_t max_data3 = vextq_s16(max_data, max_data, 2);
- const qint16x8_t final_max = vmaxq_s16(vmaxq_s16(max_data, max_data2), max_data3);
-
- if(pool_stride_x == 2)
- {
- const qint16x4_t tmp = { vgetq_lane_s16(final_max, 0), vgetq_lane_s16(final_max, 2), vgetq_lane_s16(final_max, 4), vgetq_lane_s16(final_max, 6) };
- vst1_qs16(reinterpret_cast<qint16_t *>(output.ptr()), tmp);
- }
- else
- {
- vst1q_qs16(reinterpret_cast<qint16_t *>(output.ptr()), final_max);
- }
- }
- },
- input, output);
-}
-
template <PoolingType pooling_type, bool exclude_padding>
void NEPoolingLayerKernel::pooling3_f32_nchw(const Window &window_input, const Window &window)
{
@@ -1879,110 +1415,6 @@
input, output);
}
-template <PoolingType pooling_type>
-void NEPoolingLayerKernel::poolingMxN_q8_nchw(const Window &window_input, const Window &window)
-{
- Iterator input(_input, window_input);
- Iterator output(_output, window);
-
- const int pool_size_x = _pool_info.is_global_pooling() ? _input->info()->tensor_shape().x() : _pool_info.pool_size().width;
- const int pool_size_y = _pool_info.is_global_pooling() ? _input->info()->tensor_shape().y() : _pool_info.pool_size().height;
- const int pool_pad_top = _pool_info.pad_stride_info().pad_top();
- const int pool_pad_left = _pool_info.pad_stride_info().pad_left();
- int pool_stride_x = 0;
- int pool_stride_y = 0;
- std::tie(pool_stride_x, pool_stride_y) = _pool_info.pad_stride_info().stride();
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- qint8x16_t vres = {};
- qint8_t res = {};
-
- //PoolingType::MAX
- for(int y = 0; y < pool_size_y; ++y)
- {
- int x = 0;
- for(; x <= (pool_size_x - 16); x += 16)
- {
- const qint8x16_t data = vld1q_qs8(reinterpret_cast<const qint8_t *>(input.ptr() + (x - pool_pad_left) * _input->info()->strides_in_bytes().x() +
- (y - pool_pad_top) * _input->info()->strides_in_bytes().y()));
- vres = vmaxq_s8(vres, data);
- }
-
- // Leftover for loop
- for(; x < pool_size_x; ++x)
- {
- qint8_t data = *(reinterpret_cast<const qint8_t *>(input.ptr() + (x - pool_pad_left) * _input->info()->strides_in_bytes().x() + (y - pool_pad_top) * _input->info()->strides_in_bytes().y()));
- res = std::max(res, data);
- }
- }
- //Reduce
- const qint8x8_t half_vres = vpmax_s8(vget_low_s8(vres), vget_high_s8(vres));
- res = std::max(res, vget_lane_s8(half_vres, 0));
- res = std::max(res, vget_lane_s8(half_vres, 1));
- res = std::max(res, vget_lane_s8(half_vres, 2));
- res = std::max(res, vget_lane_s8(half_vres, 3));
- res = std::max(res, vget_lane_s8(half_vres, 4));
- res = std::max(res, vget_lane_s8(half_vres, 5));
- res = std::max(res, vget_lane_s8(half_vres, 6));
- res = std::max(res, vget_lane_s8(half_vres, 7));
-
- // Store result
- *(reinterpret_cast<qint8_t *>(output.ptr())) = res;
- },
- input, output);
-}
-
-template <PoolingType pooling_type>
-void NEPoolingLayerKernel::poolingMxN_q16_nchw(const Window &window_input, const Window &window)
-{
- Iterator input(_input, window_input);
- Iterator output(_output, window);
-
- const int pool_size_x = _pool_info.is_global_pooling() ? _input->info()->tensor_shape().x() : _pool_info.pool_size().width;
- const int pool_size_y = _pool_info.is_global_pooling() ? _input->info()->tensor_shape().y() : _pool_info.pool_size().height;
- const int pool_pad_top = _pool_info.pad_stride_info().pad_top();
- const int pool_pad_left = _pool_info.pad_stride_info().pad_left();
- int pool_stride_x = 0;
- int pool_stride_y = 0;
- std::tie(pool_stride_x, pool_stride_y) = _pool_info.pad_stride_info().stride();
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- qint16x8_t vres = {};
- qint16_t res = {};
-
- //PoolingType::MAX
- for(int y = 0; y < pool_size_y; ++y)
- {
- int x = 0;
- for(; x <= (pool_size_x - 8); x += 8)
- {
- const qint16x8_t data = vld1q_qs16(reinterpret_cast<const qint16_t *>(input.ptr() + (x - pool_pad_left) * _input->info()->strides_in_bytes().x() +
- (y - pool_pad_top) * _input->info()->strides_in_bytes().y()));
- vres = vmaxq_s16(vres, data);
- }
-
- // Leftover for loop
- for(; x < pool_size_x; ++x)
- {
- qint16_t data = *(reinterpret_cast<const qint16_t *>(input.ptr() + (x - pool_pad_left) * _input->info()->strides_in_bytes().x() + (y - pool_pad_top) * _input->info()->strides_in_bytes().y()));
- res = std::max(res, data);
- }
- }
- //Reduce
- const qint16x4_t half_vres = vpmax_s16(vget_low_s16(vres), vget_high_s16(vres));
- res = std::max(res, vget_lane_s16(half_vres, 0));
- res = std::max(res, vget_lane_s16(half_vres, 1));
- res = std::max(res, vget_lane_s16(half_vres, 2));
- res = std::max(res, vget_lane_s16(half_vres, 3));
-
- // Store result
- *(reinterpret_cast<qint16_t *>(output.ptr())) = res;
- },
- input, output);
-}
-
template <PoolingType pooling_type, bool exclude_padding>
void NEPoolingLayerKernel::poolingMxN_f16_nchw(const Window &window_input, const Window &window)
{
@@ -2662,7 +2094,7 @@
pool_size_y,
pool_info.pad_stride_info());
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, pool_info, pooled_w, pooled_h, pool_size_x));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, pool_info, pooled_w, pooled_h));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), pool_info, num_elems_processed_per_iteration, border_size, pooled_w, pooled_h,
pool_size_x, pool_size_y)
.first);
@@ -2688,13 +2120,6 @@
unsigned int window_x_inc = 0;
switch(_input->info()->data_type())
{
- case DataType::QS8:
- case DataType::QS16:
- case DataType::F16:
- {
- window_x_inc = (pool_stride_x == 2) ? _num_elems_processed_per_iteration * 2 : _num_elems_processed_per_iteration;
- break;
- }
case DataType::QASYMM8:
{
window_x_inc = pool_stride_x;
@@ -2704,6 +2129,7 @@
}
break;
}
+ case DataType::F16:
case DataType::F32:
{
window_x_inc = pool_stride_x;
diff --git a/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp b/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp
index ee23e76..b49400a 100644
--- a/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp
@@ -54,7 +54,7 @@
std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, ITensorInfo *min_max)
{
// Output tensor auto initialization if not yet initialized
- auto_init_if_empty(*output, input->tensor_shape(), 1, DataType::U8, 0);
+ auto_init_if_empty(*output, input->tensor_shape(), 1, DataType::U8);
constexpr unsigned int num_elems_processed_per_iteration = 8;
diff --git a/src/core/NEON/kernels/NEROIPoolingLayerKernel.cpp b/src/core/NEON/kernels/NEROIPoolingLayerKernel.cpp
index a209a52..4d908db 100644
--- a/src/core/NEON/kernels/NEROIPoolingLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEROIPoolingLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -51,7 +51,7 @@
// Output auto inizialitation if not yet initialized
TensorShape output_shape(pool_info.pooled_width(), pool_info.pooled_height(), input->info()->dimension(2), rois->num_values());
- auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type());
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_ERROR_ON((output->info()->dimension(0) != pool_info.pooled_width()) || (output->info()->dimension(1) != pool_info.pooled_height()));
diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
index 30d42fa..30f21bb 100644
--- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
@@ -134,7 +134,7 @@
const TensorShape output_shape = calculate_output_shape(input->tensor_shape(), axis);
// Output auto initialization if not yet initialized
- auto_init_if_empty(*output, output_shape, 1, input->data_type(), input->fixed_point_position());
+ auto_init_if_empty(*output, output_shape, 1, input->data_type());
unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
diff --git a/src/core/NEON/kernels/NEReshapeLayerKernel.cpp b/src/core/NEON/kernels/NEReshapeLayerKernel.cpp
index 45ba68d..8043e8b 100644
--- a/src/core/NEON/kernels/NEReshapeLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEReshapeLayerKernel.cpp
@@ -24,6 +24,7 @@
#include "arm_compute/core/NEON/kernels/NEReshapeLayerKernel.h"
#include "arm_compute/core/AccessWindowStatic.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
@@ -32,7 +33,6 @@
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Validate.h"
-#include <arm_neon.h>
#include <cstdint>
using namespace arm_compute;
@@ -59,11 +59,10 @@
void NEReshapeLayerKernel::configure(const ITensor *input, ITensor *output)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::QS8, DataType::U16, DataType::S16, DataType::QS16,
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::U16, DataType::S16,
DataType::U32, DataType::S32, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_NULLPTR(output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
ARM_COMPUTE_ERROR_ON(input->info()->tensor_shape().total_size() != output->info()->tensor_shape().total_size());
_input = input;
@@ -94,12 +93,10 @@
case DataType::U8:
case DataType::S8:
case DataType::QASYMM8:
- case DataType::QS8:
reshape_tensor<uint8_t>(window, _input, _output);
break;
case DataType::U16:
case DataType::S16:
- case DataType::QS16:
case DataType::F16:
reshape_tensor<uint16_t>(window, _input, _output);
break;
diff --git a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
index d91efd2..3d19c1d 100644
--- a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
+++ b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
@@ -24,6 +24,7 @@
#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"
#include "arm_compute/core/AccessWindowStatic.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
@@ -194,56 +195,7 @@
template <typename T>
T sqsub(T a, T b);
template <typename T>
-T sqmul(T a, T b, int fixed_point_position);
-
-#define DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT(TYPET, TYPEU, TAGT, TAGU) \
- inline vec_8_byte_t<TYPET> vqsub(vec_8_byte_t<TYPET> a, vec_8_byte_t<TYPET> b) \
- { \
- return vqsub_##TAGT(a, b); \
- } \
- inline vec_8_byte_t<TYPEU> vqadd(vec_8_byte_t<TYPEU> a, vec_8_byte_t<TYPEU> b) \
- { \
- return vqadd_##TAGU(a, b); \
- } \
- inline vec_16_byte_t<TYPEU> vqadd(vec_16_byte_t<TYPEU> a, vec_16_byte_t<TYPEU> b) \
- { \
- return vqaddq_##TAGU(a, b); \
- } \
- inline vec_8_byte_t<TYPET> vqexp(vec_8_byte_t<TYPET> vec, int fixed_point_position) \
- { \
- return vqexp_q##TAGT(vec, fixed_point_position); \
- } \
- inline auto vmovl(vec_8_byte_t<TYPET> vec)->decltype(vmovl_##TAGT(vec)) \
- { \
- return vmovl_##TAGT(vec); \
- } \
- inline vec_16_byte_t<TYPET> vqrecip(vec_16_byte_t<TYPET> vec, int fixed_point_position) \
- { \
- return vqrecipq_q##TAGT(vec, fixed_point_position); \
- } \
- inline vec_16_byte_t<TYPET> vqmul(vec_16_byte_t<TYPET> a, vec_16_byte_t<TYPET> b, int fixed_point_position) \
- { \
- return vqmulq_q##TAGT(a, b, fixed_point_position); \
- } \
- template <> \
- inline TYPEU sqadd<TYPEU>(TYPEU a, TYPEU b) \
- { \
- return sqadd_q##TAGU(a, b); \
- } \
- inline TYPET sqexp(TYPET val, int fixed_point_position) \
- { \
- return sqexp_q##TAGT(val, fixed_point_position); \
- } \
- template <> \
- inline TYPET sqsub<TYPET>(TYPET a, TYPET b) \
- { \
- return sqsub_q##TAGT(a, b); \
- } \
- template <> \
- inline TYPET sqmul<TYPET>(TYPET a, TYPET b, int fixed_point_position) \
- { \
- return sqmul_q##TAGT(a, b, fixed_point_position); \
- }
+T sqmul(T a, T b);
#define DECLARE_NEON_FUNCTIONS_FOR_FLOAT(TYPE, TAG) \
inline vec_8_byte_t<TYPE> vadd(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b) \
@@ -258,10 +210,6 @@
{ \
return vsubq_##TAG(a, b); \
} \
- inline vec_16_byte_t<TYPE> vexp(vec_16_byte_t<TYPE> vec) \
- { \
- return vexpq_##TAG(vec); \
- } \
inline vec_16_byte_t<TYPE> vmul_n(vec_16_byte_t<TYPE> vec, TYPE val) \
{ \
return vmulq_n_##TAG(vec, val); \
@@ -278,9 +226,6 @@
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
DECLARE_NEON_FUNCTIONS_FOR_TYPE(float, f32)
-DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT(int8_t, int16_t, s8, s16)
-DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT(int16_t, int32_t, s16, s32)
-
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float16_t, f16)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -331,6 +276,25 @@
return res;
}
+float32x4_t vexp(const float32x4_t &vec)
+{
+ return vexpq_f32(vec);
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+float16x8_t vexp(const float16x8_t &vec)
+{
+ float16x4x2_t res =
+ {
+ {
+ vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vget_low_f16(vec)))),
+ vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vget_high_f16(vec))))
+ }
+ };
+ return vcombine_f16(res.val[0], res.val[1]);
+}
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+
template <>
float32x4x4_t vdup_n<float32x4x4_t>(float val)
{
@@ -372,17 +336,13 @@
{
Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
{
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
-#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F32);
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
// Validate in case of configured output
if(output.total_size() != 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
}
@@ -395,7 +355,7 @@
// Softmax across the x dimension
const TensorShape output_shape = TensorShape(input.tensor_shape()).set(0, 1);
// Output auto initialization if not yet initialized
- auto_init_if_empty(output, output_shape, 1, input.data_type(), input.fixed_point_position(), input.quantization_info());
+ auto_init_if_empty(output, output_shape, 1, input.data_type(), input.quantization_info());
// Configure kernel window
const int input_width = input.valid_region().shape.x();
@@ -447,7 +407,7 @@
const auto out_ptr = reinterpret_cast<T *>(output.ptr());
// Init max value
- auto vec_max = vdup_n<vec_16_byte_t<T>>(std::numeric_limits<T>::lowest());
+ auto vec_max = vdup_n<vec_16_byte_t<T>>(support::cpp11::lowest<T>());
// Loop over input row
for(const T *it = in_ptr; it < (in_ptr + input_width); it += vec_size_of(vec_max))
@@ -488,12 +448,6 @@
case DataType::QASYMM8:
_func = &logits_1d_max<qasymm8_t>;
break;
- case DataType::QS8:
- _func = &logits_1d_max<qint8_t>;
- break;
- case DataType::QS16:
- _func = &logits_1d_max<qint16_t>;
- break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
_func = &logits_1d_max<float16_t>;
@@ -543,19 +497,16 @@
Status validate_arguments_logits_softmax(const ITensorInfo &input, const ITensorInfo &max,
const ITensorInfo &output, const float beta, const ITensorInfo &tmp)
{
+ ARM_COMPUTE_UNUSED(beta);
// Check input
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
-#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F32);
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());
// Check max
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &max);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(input.tensor_shape()).set(0, 1), max.tensor_shape());
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &max);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &max);
// Check output if configured
@@ -564,19 +515,14 @@
const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &output);
ARM_COMPUTE_RETURN_ERROR_ON(output.quantization_info() != output_quantization);
}
- // Check beta
- ARM_COMPUTE_RETURN_ERROR_ON((beta != 1.0f) && is_data_type_fixed_point(input.data_type()));
-
// Check tmp if configured
if(tmp.total_size() != 0)
{
const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &tmp);
// We could potentially reduce tmp memory if we could predict or make an assumption
// on the maximum number of threads that will run in parallel.
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &tmp);
@@ -727,88 +673,6 @@
in_it, max_it, out_it);
}
-template <typename T, typename U>
-void logits_1d_softmax_fixed_point(const ITensor &in, const ITensor &max, void *const tmp,
- ITensor &out, const float /*beta*/, const Window &window)
-{
- const int start_x = in.info()->valid_region().anchor.x();
- const int input_width = in.info()->valid_region().shape.x();
-
- const int fixed_point_position = in.info()->fixed_point_position();
-
- Iterator in_it(&in, window);
- Iterator max_it(&max, window);
- Iterator out_it(&out, window);
-
- execute_window_loop(window, [&](const Coordinates &)
- {
- /* Get pointers */
- const auto in_ptr = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
- const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
- const auto tmp_ptr = reinterpret_cast<T *>(tmp);
-
- vec_16_byte_t<T> vec_sum_inversed;
-
- /* Compute exponentials and sum */
- {
- /* Get max value */
- const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
- const auto vec_max = vdup_n<vec_8_byte_t<T>>(max_val);
-
- /* Init sum to zero */
- auto vec_sum = vdup_n<vec_16_byte_t<U>>(0);
-
- /* Loop over row and compute exponentials and sum */
- int i = 0;
- constexpr int vec_size = vec_size_of(vec_sum);
- for(; i <= (input_width - vec_size); i += vec_size)
- {
- auto vec_elements = vld<vec_8_byte_t<T>>(in_ptr + i);
- vec_elements = vqsub(vec_elements, vec_max);
- vec_elements = vqexp(vec_elements, fixed_point_position);
- vec_sum = vqadd(vec_sum, vmovl(vec_elements));
- vst(tmp_ptr + i, vec_elements);
- }
- /* Reduce sum */
- const vec_8_byte_t<U> sum_8_byte = vqadd(vget_high(vec_sum), vget_low(vec_sum));
- U sum = reduce_add(sqadd<U>, sum_8_byte);
-
- /* Run remaining elements */
- for(; i < input_width; ++i)
- {
- T element = sqexp(sqsub(in_ptr[i], max_val), fixed_point_position);
- sum = sqadd<U>(sum, element);
- tmp_ptr[i] = element;
- }
-
- const auto qsum = utility::saturate_cast<T>(sum);
- vec_sum_inversed = vqrecip(vdup_n<vec_16_byte_t<T>>(qsum), fixed_point_position);
- }
-
- /* Normalize exponentials */
- {
- /* Loop over row and compute softmax */
- int i = 0;
- constexpr int vec_size = vec_size_of(vec_sum_inversed);
- for(; i <= (input_width - vec_size); i += vec_size)
- {
- const auto vec_in = vld<vec_16_byte_t<T>>(tmp_ptr + i);
- const vec_16_byte_t<T> normalized_value = vqmul(vec_in, vec_sum_inversed, fixed_point_position);
- vst(out_ptr + i, normalized_value);
- }
-
- const T sum_inversed = vget_lane<0>(vec_sum_inversed);
-
- /* Run remaining elements */
- for(; i < input_width; ++i)
- {
- out_ptr[i] = sqmul(tmp_ptr[i], sum_inversed, fixed_point_position);
- }
- }
- },
- in_it, max_it, out_it);
-}
-
template <typename T>
void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
ITensor &out, const float beta, const Window &window)
@@ -845,7 +709,7 @@
{
auto vec_elements = vld<vec_16_byte_t<T>>(in_ptr + i);
vec_elements = vsub(vec_elements, vec_max);
- vec_elements = vexp(vmul_n(vec_elements, beta));
+ vec_elements = vexp(vmul_n(vec_elements, static_cast<T>(beta)));
vec_sum = vadd(vec_sum, vec_elements);
vst(tmp_ptr + i, vec_elements);
}
@@ -908,12 +772,6 @@
case DataType::QASYMM8:
_func = &logits_1d_softmax_qasymm8;
break;
- case DataType::QS8:
- _func = &logits_1d_softmax_fixed_point<qint8_t, qint16_t>;
- break;
- case DataType::QS16:
- _func = &logits_1d_softmax_fixed_point<qint16_t, qint32_t>;
- break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
_func = &logits_1d_softmax_float<float16_t>;
diff --git a/src/core/NEON/kernels/NETransposeKernel.cpp b/src/core/NEON/kernels/NETransposeKernel.cpp
index e6f3acc..7ac6cdb 100644
--- a/src/core/NEON/kernels/NETransposeKernel.cpp
+++ b/src/core/NEON/kernels/NETransposeKernel.cpp
@@ -74,7 +74,8 @@
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8, DataType::U16, DataType::S16, DataType::QS16, DataType::U32, DataType::S32,
+ //Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use NEON FP16 instructions.
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::U16, DataType::S16, DataType::U32, DataType::S32,
DataType::F16,
DataType::F32);
@@ -84,7 +85,6 @@
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
diff --git a/src/core/NEON/kernels/NEWarpKernel.cpp b/src/core/NEON/kernels/NEWarpKernel.cpp
index 0fa8278..d04bc07 100644
--- a/src/core/NEON/kernels/NEWarpKernel.cpp
+++ b/src/core/NEON/kernels/NEWarpKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,7 +45,7 @@
} // namespace
INEWarpKernel::INEWarpKernel()
- : _func(nullptr), _input(nullptr), _output(nullptr), _constant_border_value(0), _matrix(nullptr)
+ : _func(nullptr), _input(nullptr), _output(nullptr), _constant_border_value(0), _matrix()
{
}
@@ -64,11 +64,10 @@
(this->*_func)(window);
}
-void INEWarpKernel::configure(const ITensor *input, ITensor *output, const float *matrix, BorderMode border_mode, uint8_t constant_border_value)
+void INEWarpKernel::configure(const ITensor *input, ITensor *output, const std::array<float, 9> &matrix, BorderMode border_mode, uint8_t constant_border_value)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
- ARM_COMPUTE_ERROR_ON(nullptr == matrix);
_matrix = matrix;
_constant_border_value = constant_border_value;
diff --git a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp
index 3031a87..2c9ad92 100644
--- a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp
+++ b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp
@@ -105,14 +105,14 @@
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *biases, const ITensorInfo *output)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
+ //Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use NEON FP16 instructions.
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
if(biases != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input->data_type()));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, biases);
ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() == 4) && (biases->num_dimensions() != 1));
ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() == 5) && (biases->num_dimensions() != 2));
ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() == 4) && (biases->dimension(0) != input->tensor_shape()[3]));
@@ -124,7 +124,6 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), get_output_shape(input, biases != nullptr));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
return Status{};
diff --git a/src/core/NEON/kernels/NEWidthConcatenateLayerKernel.cpp b/src/core/NEON/kernels/NEWidthConcatenateLayerKernel.cpp
new file mode 100644
index 0000000..1b38677
--- /dev/null
+++ b/src/core/NEON/kernels/NEWidthConcatenateLayerKernel.cpp
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/NEON/kernels/NEWidthConcatenateLayerKernel.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/IAccessWindow.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/Window.h"
+
+#include <cstdint>
+#include <utility>
+
+using namespace arm_compute;
+
+namespace
+{
+/** Compute the kernel's execution window and extend the tensors' padding if required.
+ *
+ * The window covers the whole input, because every input column is copied, and steps
+ * 16 bytes' worth of elements along X. The output is accessed @p width_offset columns in.
+ *
+ * @param[in,out] input        Source tensor info (its padding may be updated).
+ * @param[in]     width_offset Column of @p output where the first input column is written.
+ * @param[in,out] output       Destination tensor info (its padding may be updated).
+ *
+ * @return Pair of (status, window). The status is an "Insufficient Padding!" error when
+ *         update_window_and_padding() reports that the window had to be changed.
+ */
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, unsigned int width_offset, ITensorInfo *output)
+{
+    // One 128-bit NEON vector (16 bytes) is processed per iteration
+    const unsigned int num_elems_processed_per_iteration = 16 / output->element_size();
+
+    // The window needs to be based on input as we copy all the widths of input
+    Window                 win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
+    AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
+    AccessWindowHorizontal output_access(output, width_offset, num_elems_processed_per_iteration);
+    bool                   window_changed = update_window_and_padding(win, input_access, output_access);
+
+    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+    return std::make_pair(err, win);
+}
+
+/** Check that the kernel arguments are valid.
+ *
+ * Requires a supported data type shared by both tensors, the input to fit inside @p output
+ * starting at column @p width_offset, identical sizes in every dimension except X, and an
+ * input of at most 3 dimensions.
+ */
+Status validate_arguments(const ITensorInfo *input, unsigned int width_offset, const ITensorInfo *output)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1,
+                                                         DataType::U8, DataType::S8, DataType::QASYMM8,
+                                                         DataType::U16, DataType::S16, DataType::F16,
+                                                         DataType::U32, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+    ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) + width_offset > output->dimension(0));
+
+    // Only the width (dimension 0) may differ between input and output
+    for(size_t i = 1; i < Coordinates::num_max_dimensions; ++i)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(i) != output->dimension(i));
+    }
+    ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 3);
+
+    return Status{};
+}
+} // namespace
+
+NEWidthConcatenateLayerKernel::NEWidthConcatenateLayerKernel()
+    : _input(nullptr), _output(nullptr), _width_offset(0)
+{
+}
+
+void NEWidthConcatenateLayerKernel::configure(const ITensor *input, unsigned int width_offset, ITensor *output)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), width_offset, output->info()));
+
+    _input        = input;
+    _output       = output;
+    _width_offset = width_offset;
+
+    // Configure kernel window
+    auto win_config = validate_and_configure_window(input->info(), width_offset, output->info());
+    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+
+    INEKernel::configure(win_config.second);
+}
+
+Status NEWidthConcatenateLayerKernel::validate(const ITensorInfo *input, unsigned int width_offset, const ITensorInfo *output)
+{
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, width_offset, output));
+    // Work on clones so that validation leaves the caller's tensor infos untouched
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), width_offset, output->clone().get()).first);
+    return Status{};
+}
+
+void NEWidthConcatenateLayerKernel::run(const Window &window, const ThreadInfo &info)
+{
+    ARM_COMPUTE_UNUSED(info);
+    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+
+    // Offset output pointer to the correct position
+    uint8_t *output_ptr = _output->buffer() + _output->info()->offset_first_element_in_bytes() + _width_offset * _output->info()->strides_in_bytes()[0];
+
+    // Create iterators
+    Iterator input(_input, window);
+    Iterator output(_output, window);
+
+    // Copy one 16-byte vector per step to the matching output position, shifted right by _width_offset columns
+    execute_window_loop(window, [&](const Coordinates &)
+    {
+        const auto in_ptr  = input.ptr();
+        const auto out_ptr = output_ptr + output.offset();
+
+        wrapper::vstore(out_ptr, wrapper::vloadq(in_ptr));
+    },
+    input, output);
+}
diff --git a/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp
index 672684d..3d7a16d 100644
--- a/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp
@@ -40,38 +40,6 @@
namespace
{
-Status validate_arguments_winograd_gemm(const ITensorInfo *a, const ITensorInfo *b, const ITensor *c, const ITensorInfo *output, const float alpha, const float beta,
- const GEMMInfo &gemm_info = GEMMInfo())
-{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a);
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(b);
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
-
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
-
- if(c != nullptr)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, c->info());
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->info()->dimension(1), "The matrix C must have the same number of rows as the matrix A");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->info()->dimension(0), "The matrix C must have the same number of columns as the matrix B");
- }
-
- if(output->total_size() != 0)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != output->dimension(0), "The output matrix must have the same number of columns as the matrix B");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != output->dimension(1), "The output matrix must have the same number of rows as the matrix A");
- ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() != a->num_dimensions());
- }
-
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
- ARM_COMPUTE_UNUSED(alpha, beta);
- return Status{};
-}
-
Status validate_arguments_winograd_weight_trans(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
@@ -182,7 +150,6 @@
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON(winograd_info.output_data_layout != DataLayout::NCHW);
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) != num_tiles.area());
ARM_COMPUTE_RETURN_ERROR_ON_MSG((kernel_dims.width != 3U && kernel_dims.width != 5U), "Winograd output transform only supports 3x3 and 5x5 kernels");
ARM_COMPUTE_RETURN_ERROR_ON_MSG((kernel_dims.width != kernel_dims.height), "Winograd output transform only supports 3x3 and 5x5 kernels");
@@ -233,85 +200,13 @@
return std::make_pair(err, win);
}
} // namespace
-template <typename TIn, typename TOut, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
-NEWinogradLayerBatchedGEMMKernel<TIn, TOut, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::NEWinogradLayerBatchedGEMMKernel()
- : _gemms()
-{
-}
-
-template <typename TIn, typename TOut, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
-void NEWinogradLayerBatchedGEMMKernel<TIn, TOut, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::configure(
- const unsigned int n_gemms,
- const int M, const int K, const int N,
- const int a_matrix_stride,
- const int a_row_stride,
- const int b_matrix_stride,
- const int b_row_stride,
- const int c_matrix_stride,
- const int c_row_stride,
- const TIn *const a_ptr,
- const TIn *const b_ptr,
- TOut *const c_ptr)
-{
- _gemms = support::cpp14::make_unique<MultiGEMM>(n_gemms, M, K, N, a_matrix_stride, a_row_stride, b_matrix_stride, b_row_stride, c_matrix_stride, c_row_stride, a_ptr, b_ptr, c_ptr);
- Window win;
- auto win_last = _gemms->get_window();
- win.set(Window::DimX, Window::Dimension(0, win_last, 1));
- INEKernel::configure(win);
-}
-
-template <typename TIn, typename TOut, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
-void NEWinogradLayerBatchedGEMMKernel<TIn, TOut, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::run(const Window &window, const ThreadInfo &info)
-{
- ARM_COMPUTE_UNUSED(info);
- ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
- const size_t first_gemm = window.x().start();
- const size_t last_gemm = window.x().end();
- _gemms->run(first_gemm, last_gemm);
-}
-
-template <typename TIn, typename TOut, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
-unsigned int NEWinogradLayerBatchedGEMMKernel<TIn, TOut, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_number_gemms() const
-{
- return WinogradBase::N_GEMMS;
-}
-
-template <typename TIn, typename TOut, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
-int NEWinogradLayerBatchedGEMMKernel<TIn, TOut, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_output_tile_rows() const
-{
- return _output_tile_rows;
-}
-
-template <typename TIn, typename TOut, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
-int NEWinogradLayerBatchedGEMMKernel<TIn, TOut, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_output_tile_cols() const
-{
- return _output_tile_cols;
-}
-
-template <typename TIn, typename TOut, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
-int NEWinogradLayerBatchedGEMMKernel<TIn, TOut, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_number_blocks() const
-{
- return WinogradConv::N_BLOCK;
-}
-
-template <typename TIn, typename TOut, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
-Status NEWinogradLayerBatchedGEMMKernel<TIn, TOut, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensor *c,
- const ITensorInfo *output, const float alpha, const float beta, const GEMMInfo &gemm_info)
-{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_winograd_gemm(a, b, c, output, alpha, beta, gemm_info));
- return Status{};
-}
-
-template class NEWinogradLayerBatchedGEMMKernel<float, float, 2, 2, 3, 3>;
-template class NEWinogradLayerBatchedGEMMKernel<float, float, 4, 4, 3, 3>;
-template class NEWinogradLayerBatchedGEMMKernel<float, float, 2, 2, 5, 5>;
// Weights transform
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
-unsigned int NEWinogradLayerTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_weight_storage_size(int n_output_channels, int n_input_channels) const
+unsigned int NEWinogradLayerTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_weight_storage_size(int num_output_channels, int num_input_channels) const
{
- const KernelShape shape(n_output_channels, KernelRows, KernelCols, n_input_channels);
+ const KernelShape shape(num_output_channels, KernelRows, KernelCols, num_input_channels);
return static_cast<unsigned int>(
// WinogradConv returns the size in bytes, we divide by `sizeof(T)` to express that in units of T
WinogradConv::get_kernel_storage_size(shape) / sizeof(T));
@@ -319,7 +214,8 @@
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
NEWinogradLayerTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::NEWinogradLayerTransformWeightsKernel()
- : _transform()
+ : _weights_hwio(nullptr), _output(nullptr), _matrix_stride(0), _num_output_channels(0), _num_input_channels(0)
+
{
}
@@ -332,16 +228,21 @@
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
void NEWinogradLayerTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::configure(
const ITensor *weights_hwio,
- T *const output,
- const int matrix_stride, /** Stride across matrices in the output. */
- const int n_output_channels, /** Number of filters. */
- const int n_input_channels) /** Number of channels in each filter. */
+ ITensor *output,
+ const int matrix_stride, /** Stride across matrices in the output. */
+ const int num_output_channels, /** Number of filters. */
+ const int num_input_channels) /** Number of channels in each filter. */
{
- const int matrix_row_stride = roundup(n_output_channels, WinogradConv::N_BLOCK);
- _transform = support::cpp14::make_unique<WeightsTransform>(reinterpret_cast<T *>(weights_hwio->buffer()), output, matrix_stride, matrix_row_stride, n_output_channels,
- n_input_channels);
- Window win;
- auto win_last = _transform->get_window();
+ _weights_hwio = weights_hwio;
+ _output = output;
+ _matrix_stride = matrix_stride;
+ _num_output_channels = num_output_channels;
+ _num_input_channels = num_input_channels;
+
+ const int matrix_row_stride = roundup(num_output_channels, WinogradConv::N_BLOCK);
+ WeightsTransform transform(nullptr, nullptr, matrix_stride, matrix_row_stride, num_output_channels, num_input_channels);
+ Window win;
+ auto win_last = transform.get_window();
win.set(Window::DimX, Window::Dimension(0, win_last, 1));
INEKernel::configure(win);
}
@@ -351,9 +252,12 @@
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
- const size_t fst = window.x().start();
- const size_t lst = window.x().end();
- _transform->run(fst, lst);
+
+ const int matrix_row_stride = roundup(_num_output_channels, WinogradConv::N_BLOCK);
+ WeightsTransform transform(reinterpret_cast<T *>(_weights_hwio->buffer()), reinterpret_cast<T *>(_output->buffer()), _matrix_stride, matrix_row_stride, _num_output_channels, _num_input_channels);
+ const size_t fst = window.x().start();
+ const size_t lst = window.x().end();
+ transform.run(fst, lst);
}
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
@@ -379,16 +283,16 @@
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
unsigned int NEWinogradLayerTransformInputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_input_storage_size(
- int n_batches, /** Number of batches in the input tensor. */
- int n_channels, /** Number of feature maps in the input tensor. */
- int n_rows, /** Number of rows in each feature map. */
- int n_cols, /** Number of columns in each feature map. */
- bool same_padding /** Use "SAME" padding, otherwise use "VALID". */
+ int num_batches, /* Number of batches in the input tensor. */
+ int num_channels, /* Number of feature maps in the input tensor. */
+ int num_rows, /* Number of rows in each feature map. */
+ int num_cols, /* Number of columns in each feature map. */
+ bool same_padding /* Use "SAME" padding, otherwise use "VALID". */
) const
{
// Construct shapes for the input and kernel tensors.
- const Tensor4DShape input_shape(n_batches, n_rows, n_cols, n_channels);
- const KernelShape kern_shape(1, KernelRows, KernelCols, n_channels);
+ const Tensor4DShape input_shape(num_batches, num_rows, num_cols, num_channels);
+ const KernelShape kern_shape(1, KernelRows, KernelCols, num_channels);
const PaddingType padding = (same_padding) ? PADDING_SAME : PADDING_VALID;
// Return the size, converted into units of TIn
return static_cast<unsigned int>(WinogradConv::get_input_storage_size(kern_shape, input_shape, padding) / sizeof(T));
@@ -403,25 +307,32 @@
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
NEWinogradLayerTransformInputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::NEWinogradLayerTransformInputKernel()
- : _transform()
+ : _input_nhwc(), _num_batches(0), _num_rows(0), _num_cols(0), _num_channels(0), _padding(), _output(nullptr), _matrix_stride(0)
{
}
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
void NEWinogradLayerTransformInputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::configure(
- const T *const input, /** Input tensor data */
- const int n_batches, /** Number of batches in input tensor. */
- const int n_rows, /** Number of rows in input tensor. */
- const int n_cols, /** Number of columns in input tensor. */
- const int n_channels, /** Number of channels in input tensor. */
- const PaddingType padding, /** Padding type. */
- T *const output, /** Base of output matrices. */
- const int matrix_stride) /** Stride between output matrices. */
+ const ITensor *input_nhwc,
+ const int num_batches, /* Number of batches in input tensor. */
+ const int num_rows, /* Number of rows in input tensor. */
+ const int num_cols, /* Number of columns in input tensor. */
+ const int num_channels, /* Number of channels in input tensor. */
+ const PaddingType padding, /* Padding type. */
+ ITensor *output, /* Base of output matrices. */
+ const int matrix_stride) /* Stride between output matrices. */
{
- // _input_matrix_row_stride(n_input_channels),
- _transform = support::cpp14::make_unique<InputTransform>(input, n_batches, n_rows, n_cols, n_channels, padding, output, matrix_stride, n_channels);
- Window win;
- auto win_last = _transform->get_window();
+ _input_nhwc = input_nhwc;
+ _num_batches = num_batches;
+ _num_rows = num_rows;
+ _num_cols = num_cols;
+ _num_channels = num_channels;
+ _padding = padding;
+ _output = output;
+ _matrix_stride = matrix_stride;
+ InputTransform transform(nullptr, num_batches, num_rows, num_cols, num_channels, padding, nullptr, matrix_stride, num_channels);
+ Window win;
+ auto win_last = transform.get_window();
win.set(Window::DimX, Window::Dimension(0, win_last, 1));
INEKernel::configure(win);
}
@@ -431,9 +342,21 @@
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+
+ const int element_size_in_bytes = _input_nhwc->info()->element_size();
+ const int input_col_stride = _input_nhwc->info()->strides_in_bytes().y() / element_size_in_bytes;
+ const int input_row_stride = _input_nhwc->info()->strides_in_bytes().z() / element_size_in_bytes;
+ const int input_batch_stride = _input_nhwc->info()->strides_in_bytes()[3] / element_size_in_bytes;
+
+ InputTransform input_transform(reinterpret_cast<const T *>(_input_nhwc->buffer() + _input_nhwc->info()->offset_first_element_in_bytes()),
+ _num_batches, _num_rows, _num_cols, _num_channels, _padding,
+ reinterpret_cast<T *>(_output->buffer() + _output->info()->offset_first_element_in_bytes()),
+ _matrix_stride, _num_channels, input_batch_stride, input_row_stride, input_col_stride);
+
+ // The tensor buffers are only accessed here, not in configure(), which builds the transform with nullptr data pointers since the buffers may not be allocated yet
const size_t fst = window.x().start();
const size_t lst = window.x().end();
- _transform->run(fst, lst);
+ input_transform.run(fst, lst);
}
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
@@ -453,16 +376,16 @@
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
unsigned int NEWinogradLayerTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_output_storage_size(
- int n_batches, /** Number of batches in the output tensor. */
- int n_rows, /** Number of rows in each feature map of the input tensor. */
- int n_cols, /** Number of columns in each feature map of the input tensor. */
- int n_output_channels, /** Number of feature maps in the output tensor. */
- bool same_padding /** Use "SAME" padding, otherwise use "VALID". */
+ int num_batches, /* Number of batches in the output tensor. */
+ int num_rows, /* Number of rows in each feature map of the input tensor. */
+ int num_cols, /* Number of columns in each feature map of the input tensor. */
+ int num_output_channels, /* Number of feature maps in the output tensor. */
+ bool same_padding /* Use "SAME" padding, otherwise use "VALID". */
) const
{
// Construct shapes for the input and kernel tensors.
- const Tensor4DShape input_shape(n_batches, n_rows, n_cols, 1);
- const KernelShape kern_shape(n_output_channels, KernelRows, KernelCols, 1);
+ const Tensor4DShape input_shape(num_batches, num_rows, num_cols, 1);
+ const KernelShape kern_shape(num_output_channels, KernelRows, KernelCols, 1);
const PaddingType padding = (same_padding) ? PADDING_SAME : PADDING_VALID;
// Return the size, converted into units of TOut
@@ -472,7 +395,7 @@
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
NEWinogradLayerTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::NEWinogradLayerTransformOutputKernel()
- : _biases(nullptr), _output_workspace(nullptr), _matrix_stride(0), _matrix_row_stride(0), _output(nullptr), _n_batches(0), _n_rows(0), _n_cols(0), _n_channels(0)
+ : _biases(nullptr), _output_workspace(nullptr), _matrix_stride(0), _matrix_row_stride(0), _output_nhwc(nullptr), _num_batches(0), _num_rows(0), _num_cols(0), _num_channels(0)
{
}
@@ -492,29 +415,32 @@
template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
void NEWinogradLayerTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::configure(
const ITensor *biases,
- const T *const output_workingspace,
+ const ITensor *output_workingspace,
const int matrix_stride,
- T *const output,
- const int n_batches,
- const int n_rows,
- const int n_cols,
- const int n_channels)
+ ITensor *output_nhwc,
+ const int num_batches,
+ const int num_rows,
+ const int num_cols,
+ const int num_channels)
{
_biases = biases;
_output_workspace = output_workingspace;
_matrix_stride = matrix_stride;
- _matrix_row_stride = roundup(n_channels, WinogradConv::N_BLOCK);
- _output = output;
- _n_batches = n_batches;
- _n_rows = n_rows;
- _n_cols = n_cols;
- _n_channels = n_channels;
-
+ _matrix_row_stride = roundup(num_channels, WinogradConv::N_BLOCK);
+ _output_nhwc = output_nhwc;
+ _num_batches = num_batches;
+ _num_rows = num_rows;
+ _num_cols = num_cols;
+ _num_channels = num_channels;
// We don't have the biases buffer at this stage as it hasn't been allocated, we pass in nullptr OutputTransform is only used here to compute the window
- OutputTransform output_transform(_output_workspace, _matrix_stride, _matrix_row_stride, nullptr, _output, _n_batches, _n_rows, _n_cols, _n_channels);
- Window win;
- auto win_last = output_transform.get_window();
+ OutputTransform output_transform(nullptr, _matrix_stride, _matrix_row_stride, nullptr, nullptr, _num_batches, _num_rows, _num_cols, _num_channels);
+
+ Window win;
+ auto win_last = output_transform.get_window();
win.set(Window::DimX, Window::Dimension(0, win_last, 1));
+
+ _output_nhwc->info()->set_valid_region(ValidRegion(Coordinates(), _output_nhwc->info()->tensor_shape()));
+
INEKernel::configure(win);
}
@@ -524,11 +450,12 @@
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_NULLPTR(_output_workspace);
- ARM_COMPUTE_ERROR_ON_NULLPTR(_output);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_output_nhwc);
- OutputTransform output_transform(_output_workspace, _matrix_stride, _matrix_row_stride,
- (_biases ? reinterpret_cast<T *>(_biases->buffer()) : nullptr), _output,
- _n_batches, _n_rows, _n_cols, _n_channels);
+ OutputTransform output_transform(reinterpret_cast<T *>(_output_workspace->buffer()), _matrix_stride, _matrix_row_stride,
+ (_biases ? reinterpret_cast<T *>(_biases->buffer() + _biases->info()->offset_first_element_in_bytes()) : nullptr),
+ reinterpret_cast<T *>(_output_nhwc->buffer() + _output_nhwc->info()->offset_first_element_in_bytes()),
+ _num_batches, _num_rows, _num_cols, _num_channels, 0, _output_nhwc->info()->strides_in_bytes()[2] / sizeof(T), _output_nhwc->info()->strides_in_bytes()[1] / sizeof(T));
// The code below cannot be moved to configure because biases hasn't been allocated at that point
const size_t fst = window.x().start();
diff --git a/src/core/NEON/kernels/arm_gemm/asmlib.hpp b/src/core/NEON/kernels/arm_gemm/asmlib.hpp
index b3fcb33..38f51ae 100644
--- a/src/core/NEON/kernels/arm_gemm/asmlib.hpp
+++ b/src/core/NEON/kernels/arm_gemm/asmlib.hpp
@@ -31,21 +31,21 @@
// used by the workaround.
// "Correct" version
-#define ASM_PREFETCH(address) "PRFM PLDL1KEEP, " address "\n"
-#define ASM_PREFETCHL2(address) "PRFM PLDL2KEEP, " address "\n"
-#define ASM_PREFETCHW(address) "PRFM PSTL1KEEP, " address "\n"
+#define ASM_PREFETCH(address) "PRFM PLDL1KEEP, " address "\n"
+#define ASM_PREFETCHL2(address) "PRFM PLDL2KEEP, " address "\n"
+#define ASM_PREFETCHW(address) "PRFM PSTL1KEEP, " address "\n"
#define ASM_PREFETCHWL2(address) "PRFM PSTL2KEEP, " address "\n"
// Lee's uarchsim hack
-//#define ASM_PREFETCH(address) "LDNP x20, x21, " address "\n"
+//#define ASM_PREFETCH(address) "LDNP x20, x21, " address "\n"
// No preload at all
//#define ASM_PREFETCH(address) ""
#else
// "Correct" versions for AArch32
-#define ASM_PREFETCH(address) "PLD " address "\n"
-#define ASM_PREFETCHW(address) "PLDW " address "\n"
+#define ASM_PREFETCH(address) "PLD " address "\n"
+#define ASM_PREFETCHW(address) "PLDW " address "\n"
#endif
@@ -53,76 +53,77 @@
* Do some prefetches.
*/
template <typename T>
-static inline void prefetch_6x(const T *pfp)
-{
- __asm __volatile(
+static inline void prefetch_6x(const T *pfp) {
+ __asm __volatile (
ASM_PREFETCH("[%[pfp]]")
ASM_PREFETCH("[%[pfp], #64]")
ASM_PREFETCH("[%[pfp], #128]")
ASM_PREFETCH("[%[pfp], #192]")
ASM_PREFETCH("[%[pfp], #256]")
ASM_PREFETCH("[%[pfp], #320]")
- :
- : [pfp] "r"(pfp)
- : "memory");
+ :
+ : [pfp] "r" (pfp)
+ : "memory"
+ );
}
template <typename T>
-static inline void prefetch_5x(const T *pfp)
-{
- __asm __volatile(
+static inline void prefetch_5x(const T *pfp) {
+ __asm __volatile (
ASM_PREFETCH("[%[pfp]]")
ASM_PREFETCH("[%[pfp], #64]")
ASM_PREFETCH("[%[pfp], #128]")
ASM_PREFETCH("[%[pfp], #192]")
ASM_PREFETCH("[%[pfp], #256]")
- :
- : [pfp] "r"(pfp)
- : "memory");
+ :
+ : [pfp] "r" (pfp)
+ : "memory"
+ );
}
template <typename T>
-static inline void prefetch_4x(const T *pfp)
-{
- __asm __volatile(
+static inline void prefetch_4x(const T *pfp) {
+ __asm __volatile (
ASM_PREFETCH("[%[pfp]]")
ASM_PREFETCH("[%[pfp], #64]")
ASM_PREFETCH("[%[pfp], #128]")
ASM_PREFETCH("[%[pfp], #192]")
- :
- : [pfp] "r"(pfp)
- : "memory");
+ :
+ : [pfp] "r" (pfp)
+ : "memory"
+ );
}
template <typename T>
-static inline void prefetch_3x(const T *pfp)
-{
- __asm __volatile(
+static inline void prefetch_3x(const T *pfp) {
+ __asm __volatile (
ASM_PREFETCH("[%[pfp]]")
ASM_PREFETCH("[%[pfp], #64]")
ASM_PREFETCH("[%[pfp], #128]")
- :
- : [pfp] "r"(pfp)
- : "memory");
+ :
+ : [pfp] "r" (pfp)
+ : "memory"
+ );
}
template <typename T>
-static inline void prefetch_2x(const T *pfp)
-{
- __asm __volatile(
+static inline void prefetch_2x(const T *pfp) {
+ __asm __volatile (
ASM_PREFETCH("[%[pfp]]")
ASM_PREFETCH("[%[pfp], #64]")
- :
- : [pfp] "r"(pfp)
- : "memory");
+ :
+ : [pfp] "r" (pfp)
+ : "memory"
+ );
}
template <typename T>
-static inline void prefetch_1x(const T *pfp)
-{
- __asm __volatile(
+static inline void prefetch_1x(const T *pfp) {
+ __asm __volatile (
ASM_PREFETCH("[%[pfp]]")
- :
- : [pfp] "r"(pfp)
- : "memory");
+ :
+ : [pfp] "r" (pfp)
+ : "memory"
+ );
}
+
diff --git a/src/core/NEON/kernels/arm_gemm/buffer_manager.hpp b/src/core/NEON/kernels/arm_gemm/buffer_manager.hpp
index dd74744..03f099d 100644
--- a/src/core/NEON/kernels/arm_gemm/buffer_manager.hpp
+++ b/src/core/NEON/kernels/arm_gemm/buffer_manager.hpp
@@ -38,36 +38,33 @@
#endif
-namespace arm_gemm
-{
+namespace arm_gemm {
+
#ifndef NO_MULTI_THREADING
-enum class BufferStatus
-{
+enum class BufferStatus {
IDLE,
POPULATING,
BUSY
};
-class Buffer
-{
+class Buffer {
private:
- const int _maxusers; // Maximum permissible threads.
- void *const _storage; // Storage for buffer content.
+ const int _maxusers; // Maximum permissible threads.
+ void * const _storage; // Storage for buffer content.
- int _numusers; // Actual number of threads (might be lower).
+ int _numusers; // Actual number of threads (might be lower).
- volatile BufferStatus _status = BufferStatus::IDLE; // Status
- std::atomic_int _users = {}; // How many users are still using the buffer.
- volatile int _index = 0; // Which block of data currently resides in the buffer.
+ volatile BufferStatus _status = BufferStatus::IDLE; // Status
+ std::atomic_int _users = { }; // How many users are still using the buffer.
+ volatile int _index = 0; // Which block of data currently resides in the buffer.
- std::mutex _lock = {};
+ std::mutex _lock = { };
#ifdef USE_SEMAPHORE
- std::condition_variable _cv = {};
+ std::condition_variable _cv = { };
#endif
template <typename T>
- void populate_buffer(T func)
- {
+ void populate_buffer(T func) {
func(_storage);
/* Now mark it as ready. */
@@ -78,17 +75,15 @@
_cv.notify_all();
}
#else
- _status = BufferStatus::BUSY;
+ _status = BufferStatus::BUSY;
#endif
}
public:
Buffer(Buffer &) = delete;
- Buffer &operator=(Buffer &) = delete;
+ Buffer &operator= (Buffer &) = delete;
- Buffer(void *storage, int maxusers)
- : _maxusers(maxusers), _storage(storage), _numusers(maxusers)
- {
+ Buffer(void *storage, int maxusers) : _maxusers(maxusers), _storage(storage), _numusers(maxusers) {
_status = BufferStatus::IDLE;
}
@@ -99,38 +94,32 @@
* If it's already being populated by another thread or is ready, return.
*/
template <typename T>
- void try_populate(const int index, T func)
- {
- for(;;)
- {
+ void try_populate(const int index, T func) {
+ for (;;) {
#ifdef USE_SEMAPHORE
/* If it's busy with a previous index, wait on the semaphore. */
- if((_status == BufferStatus::BUSY) && (_index != index))
- {
+ if ((_status == BufferStatus::BUSY) && (_index != index)) {
std::unique_lock<std::mutex> ul(_lock);
- if((_status == BufferStatus::BUSY) && (_index != index))
- {
+ if ((_status == BufferStatus::BUSY) && (_index != index)) {
_cv.wait(ul);
}
}
#endif
/* Return if another thread is populating it already. */
- if((_index == index) && ((_status == BufferStatus::POPULATING) || (_status == BufferStatus::BUSY)))
- {
+ if ((_index == index) &&
+ ((_status == BufferStatus::POPULATING) || (_status == BufferStatus::BUSY))) {
return;
}
- if(_status == BufferStatus::IDLE)
- {
+ if (_status == BufferStatus::IDLE) {
std::lock_guard<std::mutex> guard(_lock);
/* If the buffer is still idle, we can grab it and populate it. */
- if(_status == BufferStatus::IDLE)
- {
+ if (_status == BufferStatus::IDLE) {
_status = BufferStatus::POPULATING;
- _index = index;
- _users = _numusers;
+ _index = index;
+ _users = _numusers;
break;
}
}
@@ -141,26 +130,26 @@
}
template <typename T>
- void *get(const int index, T func)
- {
+ void *get(const int index, T func) {
// Loop until we achieve something.
- for(;;)
- {
+ for (;;) {
// If the index is correct and the buffer status is busy then we can
// just return the content. No locking is needed here as the index
// cannot change (and status cannot change from BUSY) until all
// users have finished.
- if((_index == index) && (_status == BufferStatus::BUSY))
- {
+ if ((_index == index) && (_status == BufferStatus::BUSY)) {
return _storage;
}
+
+ /* If the buffer still has some previous content, or is being
+ * populated, we can wait with the semaphore. */
#ifdef USE_SEMAPHORE
- if(((_status == BufferStatus::BUSY) && (_index != index)) || (_status == BufferStatus::POPULATING))
- {
+ if (((_status == BufferStatus::BUSY) && (_index != index)) ||
+ (_status == BufferStatus::POPULATING)) {
std::unique_lock<std::mutex> ul(_lock);
- if(((_status == BufferStatus::BUSY) && (_index != index)) || (_status == BufferStatus::POPULATING))
- {
+ if (((_status == BufferStatus::BUSY) && (_index != index)) ||
+ (_status == BufferStatus::POPULATING)) {
_cv.wait(ul);
}
}
@@ -168,17 +157,15 @@
// If it's idle, we need to populate it. The IDLE->POPULATING
// transition requires the lock.
- if(_status == BufferStatus::IDLE)
- {
+ if (_status == BufferStatus::IDLE) {
std::lock_guard<std::mutex> guard(_lock);
/* If it's still idle, grab it. Otherwise drop through and
* we'll do something else next time through the loop. */
- if(_status == BufferStatus::IDLE)
- {
+ if (_status == BufferStatus::IDLE) {
_status = BufferStatus::POPULATING;
- _index = index;
- _users = _numusers;
+ _index = index;
+ _users = _numusers;
break;
}
}
@@ -194,10 +181,8 @@
* simply (atomically) decrement the user count, and if it's hit zero we
* flag the buffer as idle.
*/
- void release(void)
- {
- if(--_users == 0)
- {
+ void release(void) {
+ if (--_users == 0) {
#ifdef USE_SEMAPHORE
std::unique_lock<std::mutex> ul(_lock);
_status = BufferStatus::IDLE;
@@ -211,110 +196,91 @@
}
/* This is called to change the number of users. */
- void set_numusers(int numusers)
- {
+ void set_numusers(int numusers) {
_numusers = std::min(numusers, _maxusers);
}
};
-class BufferManager
-{
+
+class BufferManager {
private:
/* This has to be a vector of Buffer *, because a Buffer cannot be moved
* or copied due to atomic members. */
- std::vector<Buffer *> _buffers = {};
- const int _maxthreads;
- void *const _storage;
+ std::vector<Buffer *> _buffers = { };
+ const int _maxthreads;
+ void * const _storage;
public:
BufferManager(BufferManager &) = delete;
- BufferManager &operator=(BufferManager &) = delete;
+ BufferManager & operator=(BufferManager &) = delete;
// Say how much storage is needed.
- static inline size_t get_storage_requirement(const int maxthreads, const size_t buffersize)
- {
+ static inline size_t get_storage_requirement(const int maxthreads, const size_t buffersize) {
return buffersize * ((maxthreads == 1) ? 1 : 3);
}
- BufferManager(const int maxthreads, const size_t buffersize, void *storage)
- : _maxthreads(maxthreads), _storage(storage)
- {
+ BufferManager(const int maxthreads, const size_t buffersize, void *storage) : _maxthreads(maxthreads), _storage(storage) {
const int numbuffers = (maxthreads == 1) ? 1 : 3;
/* We don't need any Buffer objects in single thread mode. */
- if(_maxthreads == 1)
- {
+ if (_maxthreads == 1) {
return;
}
/* Use intptr_t to avoid performing arithmetic on a void * */
intptr_t storage_int = reinterpret_cast<intptr_t>(_storage);
- for(int i = 0; i < numbuffers; i++)
- {
+ for (int i=0; i<numbuffers; i++) {
_buffers.push_back(new Buffer(reinterpret_cast<void *>(storage_int), _maxthreads));
storage_int += buffersize;
}
}
- ~BufferManager()
- {
- while(_buffers.size())
- {
+ ~BufferManager() {
+ while (_buffers.size()) {
delete _buffers.back();
_buffers.pop_back();
}
}
template <typename T>
- void *get(const int index, T func)
- {
+ void *get(const int index, T func) {
/* In single thread mode, we just directly call the populating
* function on the (single) buffer, otherwise forward to the
* relevant Buffer. */
- if(_maxthreads == 1)
- {
+ if (_maxthreads==1) {
func(_storage);
return _storage;
- }
- else
- {
+ } else {
return _buffers[index % _buffers.size()]->get(index, func);
}
}
template <typename T>
- void try_populate(const int index, T func)
- {
+ void try_populate(const int index, T func) {
/* No need for this in single thread mode. */
- if(_maxthreads == 1)
- {
+ if (_maxthreads==1) {
return;
}
_buffers[index % _buffers.size()]->try_populate(index, func);
}
- void release(const int index)
- {
+ void release(const int index) {
/* No need for this in single thread mode. */
- if(_maxthreads == 1)
- {
+ if (_maxthreads==1) {
return;
}
_buffers[index % _buffers.size()]->release();
}
- void set_nthreads(int threads)
- {
- if(_maxthreads == 1)
- {
+ void set_nthreads(int threads) {
+ if (_maxthreads==1) {
return;
}
- for(unsigned int i = 0; i < _buffers.size(); i++)
- {
+ for(unsigned int i=0; i<_buffers.size(); i++) {
_buffers[i]->set_numusers(threads);
}
}
@@ -329,49 +295,35 @@
* All the other methods do nothing.
*/
-class BufferManager
-{
+class BufferManager { // single-buffer variant: get() refills the one buffer on every call
private:
- void *const _storage;
+ void * const _storage; // caller-provided storage; the destructor below does not free it
public:
BufferManager(BufferManager &) = delete;
- BufferManager &operator=(BufferManager &) = delete;
+ BufferManager & operator=(BufferManager &) = delete;
- BufferManager(const int maxthreads, const size_t buffersize, void *storage)
- : _storage(storage)
- {
- }
+ BufferManager(const int maxthreads, const size_t buffersize, void *storage) : _storage(storage) { } // maxthreads and buffersize are ignored here
- ~BufferManager()
- {
- }
+ ~BufferManager() { }
// Say how much storage is needed.
- static inline size_t get_storage_requirement(const int maxthreads, const size_t buffersize)
- {
+ static inline size_t get_storage_requirement(const int maxthreads, const size_t buffersize) { // exactly one buffer's worth
return buffersize;
}
template <typename T>
- void try_populate(const int index, T func)
- {
- }
+ void try_populate(const int index, T func) { } // no-op: get() always (re)populates
- void release(const int index)
- {
- }
+ void release(const int index) { } // no-op: no per-buffer user count to drop
template <typename T>
- void *get(const int index, T func)
- {
+ void *get(const int index, T func) { // index is ignored; func is run on the single buffer every time
func(_storage);
return _storage;
}
- void set_nthreads(int)
- {
- }
+ void set_nthreads(int) { } // no-op: thread count does not affect this variant
};
#endif
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index d1180b1..4579ebd 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -28,52 +28,72 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
-#include "kernels/a32_sgemm_8x6.hpp"
#include "kernels/a64_hgemm_24x8.hpp"
#include "kernels/a64_sgemm_12x8.hpp"
+#include "kernels/a32_sgemm_8x6.hpp"
-namespace arm_gemm
-{
-template <>
-UniqueGemmCommon<__fp16, __fp16> gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const __fp16 alpha, const __fp16 beta,
- const int maxthreads, const bool pretransposed_hint)
-{
+namespace arm_gemm {
+
#ifdef __aarch64__
- // Only consider the native FP16 kernel if it will get built.
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS)
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- // If the compiler is configured to enable this feature always, then assume it is available at runtime too.
- const bool use_fp16 = true;
-#else
- // Otherwise, detect at runtime via CPUInfo.
- const bool use_fp16 = ci.has_fp16();
-#endif
-
- // If FP16 is supported, use it.
- if(use_fp16)
- {
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+class GemmImpl_gemm_fp16_interleaved_fp16 : public GemmImplementation<__fp16, __fp16> {
+public:
+#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC // feature not guaranteed at compile time: check the CPU at runtime
+ bool is_supported(const GemmArgs<__fp16> &args) override {
+ return args._ci->has_fp16();
}
#endif
- // Fallback to using the blocked SGEMM kernel.
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
-#else
- // For AArch32, only support the SGEMM route for now.
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override {
+ return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(args));
+ }
+
+ GemmImpl_gemm_fp16_interleaved_fp16() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED_FP16) { }
+};
#endif
+
+#endif // __aarch64__
+
+class GemmImpl_gemm_fp16_interleaved : public GemmImplementation<__fp16, __fp16> {
+public:
+ UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override {
+#ifdef __aarch64__
+ return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(args));
+#elif defined(__arm__)
+ return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(args));
+#else
+# error Unknown Architecture
+#endif
+ }
+
+ GemmImpl_gemm_fp16_interleaved() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+#if defined(__aarch64__) && (defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS)) // must match the guard around GemmImpl_gemm_fp16_interleaved_fp16
+static GemmImpl_gemm_fp16_interleaved_fp16 gemm_fp16_interleaved_fp16_impl{};
+#endif
+static GemmImpl_gemm_fp16_interleaved gemm_fp16_interleaved_impl{};
+
+static std::vector<GemmImplementation<__fp16, __fp16> *> gemm_fp16_methods = {
+#if defined(__aarch64__) && (defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS))
+ &gemm_fp16_interleaved_fp16_impl,
+#endif
+ &gemm_fp16_interleaved_impl
+};
+
+template<>
+std::vector<GemmImplementation<__fp16, __fp16> *> &gemm_implementation_list<__fp16, __fp16>() {
+ return gemm_fp16_methods;
}
-// Instantiate static class members if necessary.
-#if defined(__aarch64__) && (defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS))
-const int hgemm_24x8::out_width;
-const int hgemm_24x8::out_height;
-#endif
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16>(GemmArgs<__fp16> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<__fp16, __fp16>(GemmArgs<__fp16> &args);
+template bool method_is_compatible<__fp16, __fp16>(GemmMethod method, GemmArgs<__fp16> &args);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 43df1aa..e840e90 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -22,71 +22,116 @@
* SOFTWARE.
*/
#include "arm_gemm.hpp"
-#include "gemm_batched.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "gemm_native.hpp"
+#include "gemv_batched.hpp"
#include "gemv_native_transposed.hpp"
#include "gemv_pretransposed.hpp"
-#include "kernels/a32_sgemm_8x6.hpp"
#include "kernels/a64_sgemm_12x8.hpp"
-#include "kernels/a64_sgemm_native_16x4.hpp"
-#include "kernels/a64_sgemv_pretransposed.hpp"
+#include "kernels/a32_sgemm_8x6.hpp"
#include "kernels/a64_sgemv_trans.hpp"
+#include "kernels/a64_sgemv_pretransposed.hpp"
+#include "kernels/a64_sgemm_native_16x4.hpp"
-namespace arm_gemm
-{
-template <>
-UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const float alpha, const float beta,
- const int maxthreads, const bool pretransposed_hint)
-{
- /* Handle "batched GEMM" */
- if(M == 1 && nbatches > 1)
- {
- return UniqueGemmCommon<float, float>(new GemmBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
- }
+namespace arm_gemm {
+
#ifdef __aarch64__
- /* Cases in priority order */
- /* GemvPretransposed: requires M=1, alpha=1, and transposed hint set. nbatches must be 1 or we would have returned above so don't test. */
- if(M == 1 && alpha == 1.0f && pretransposed_hint)
- {
- return UniqueGemmCommon<float, float>(new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, nmulti, trB, beta));
+// SGEMM implementations for AArch64
+
+// Pretransposed GEMV
+class GemmImpl_sgemm_gemv_pretransposed : public GemmImplementation<float, float> {
+public:
+ bool is_supported(const GemmArgs<float> &args) override {
+ return (args._Msize==1 && args._alpha==1.0f && args._pretransposed_hint && args._nbatches==1);
}
- /* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle alpha */
- if(M == 1 && alpha == 1.0f && !trA && !trB)
- {
- return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, nmulti, beta));
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+ return UniqueGemmCommon<float, float> (new GemvPretransposed<sgemv_pretransposed, float, float>(args._ci, args._Nsize, args._Ksize, args._nmulti, args._trB, args._beta)); // M, alpha, batches not forwarded: fixed by is_supported()
}
- /* Native GEMM: requires M to be a multiple of 4, K at least 4, N a
- * multiple of 16, doesn't handle alpha and only makes sense for small
- * sizes. */
- if(N <= 128 && K <= 128 && ((M % 4) == 0) && (K >= 4) && ((N % 16) == 0) && alpha == 1.0f)
- {
- return UniqueGemmCommon<float, float>(new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
+ GemmImpl_sgemm_gemv_pretransposed() : GemmImplementation<float, float>(GemmMethod::GEMV_PRETRANSPOSED) { }
+};
+
+// Native GEMV
+class GemmImpl_sgemm_gemv_native_transposed : public GemmImplementation<float, float> {
+public:
+ bool is_supported(const GemmArgs<float> &args) override {
+ return (args._Msize==1 && args._alpha==1.0f && !args._trA && !args._trB && args._nbatches==1);
}
- /* Blocked GEMM, handles all cases. */
- return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_12x8, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+ return UniqueGemmCommon<float, float> (new GemvNativeTransposed<sgemv_trans, float, float>(args._ci, args._Nsize, args._Ksize, args._nmulti, args._beta));
+ }
+
+ GemmImpl_sgemm_gemv_native_transposed() : GemmImplementation<float, float>(GemmMethod::GEMV_NATIVE_TRANSPOSED) { }
+};
+
+// Native GEMM
+class GemmImpl_sgemm_gemm_native : public GemmImplementation<float, float> {
+public:
+ bool is_supported(const GemmArgs<float> &args) override {
+ return (args._Ksize>4 && (args._Nsize % 16)==0 && args._alpha==1.0f && !args._trA && !args._trB); // note: old selector also required M%4==0 and K>=4; this one needs K>4, any M
+ }
+
+ bool is_recommended(const GemmArgs<float> &args) override {
+ return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); // small K and N, or several multis with fewer than 8 M rows per thread
+ }
+
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+ return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(args._ci, args._Msize, args._Nsize, args._Ksize, args._nbatches, args._nmulti, args._beta));
+ }
+
+ GemmImpl_sgemm_gemm_native() : GemmImplementation<float, float>(GemmMethod::GEMM_NATIVE) { }
+};
+#endif // __aarch64__
+
+// Interleaved GEMM
+class GemmImpl_sgemm_gemm_interleaved : public GemmImplementation<float, float> { // no is_supported() override: accepts any args (last-resort entry)
+public:
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+#ifdef __aarch64__
+ return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_12x8, float, float>(args));
+#elif defined(__arm__)
+ return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_8x6, float, float>(args));
#else
- return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_8x6, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+# error Unknown Architecture.
#endif
+ }
+
+ GemmImpl_sgemm_gemm_interleaved() : GemmImplementation<float, float>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static GemmImpl_gemv_batched<float, float> gemv_batched_impl{}; // handles M==1 with nbatches>1
+#ifdef __aarch64__
+static GemmImpl_sgemm_gemv_pretransposed sgemm_gemv_pretransposed_impl{};
+static GemmImpl_sgemm_gemv_native_transposed sgemm_gemv_native_transposed_impl{};
+static GemmImpl_sgemm_gemm_native sgemm_gemm_native_impl{};
+#endif
+static GemmImpl_sgemm_gemm_interleaved sgemm_gemm_interleaved_impl{};
+
+/* List of implementations (order matters) */
+static std::vector<GemmImplementation<float, float> *> SGemmMethods = { // find_implementation() returns the first entry that qualifies
+ &gemv_batched_impl,
+#ifdef __aarch64__
+ &sgemm_gemv_pretransposed_impl,
+ &sgemm_gemv_native_transposed_impl,
+ &sgemm_gemm_native_impl,
+#endif
+ &sgemm_gemm_interleaved_impl
+};
+
+/* Templated function to return this list. */
+template<>
+std::vector<GemmImplementation<float, float> *> &gemm_implementation_list<float, float>() {
+ return SGemmMethods;
}
-// Instantiate static class variables.
-#ifdef __aarch64__
-const int sgemm_12x8::out_width;
-const int sgemm_12x8::out_height;
-
-const int sgemm_native_16x4::out_width;
-const int sgemm_native_16x4::out_height;
-#else
-const int sgemm_8x6::out_width;
-const int sgemm_8x6::out_height;
-#endif
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<float, float> gemm<float, float>(GemmArgs<float> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<float, float>(GemmArgs<float> &args);
+template bool method_is_compatible<float, float>(GemmMethod method, GemmArgs<float> &args);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
new file mode 100644
index 0000000..6734e3c
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -0,0 +1,142 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include <vector>
+
+#include "gemv_batched.hpp"
+
+namespace arm_gemm {
+
+template<typename Top, typename Tret>
+class GemmImplementation {
+public:
+ /* Is this implementation compatible with the args as provided? */
+ virtual bool is_supported(const GemmArgs<Tret> &args) { return true; }
+ /* Is this implementation "recommended" for these args (heuristic)? */
+ virtual bool is_recommended(const GemmArgs<Tret> &args) { return true; }
+ /* Instantiate this method please. */
+ virtual UniqueGemmCommon<Top, Tret> instantiate(const GemmArgs<Tret> &args) = 0;
+
+ /* Indicate the "GemmMethod" for use as a selector */
+ const GemmMethod method;
+
+ virtual ~GemmImplementation() { }
+
+ explicit GemmImplementation(GemmMethod method) : method(method) { }
+};
+
+/* "gemv_batched" implementation is type-agnostic, so template it here. */
+template<typename Top, typename Tret>
+class GemmImpl_gemv_batched : public GemmImplementation<Top, Tret> {
+public:
+ bool is_supported(const GemmArgs<Tret> &args) override {
+ return (args._Msize==1 && args._nbatches > 1);
+ }
+
+ UniqueGemmCommon<Top, Tret> instantiate(const GemmArgs<Tret> &args) override {
+ return UniqueGemmCommon<Top, Tret> (new GemvBatched<Top, Tret>(args));
+ }
+
+ GemmImpl_gemv_batched() : GemmImplementation<Top, Tret>(GemmMethod::GEMV_BATCHED) { }
+};
+
+/* "Master" function implemented for each valid combination of types.
+ * Returns a list of GEMM implementation descriptors for processing by the
+ * other functions. */
+template<typename Top, typename Tret>
+std::vector<GemmImplementation<Top, Tret> *> &gemm_implementation_list();
+
+/* Walk the implementation list in order and return the first entry that
+ * supports these args and either matches the requested method (cfg set to a
+ * non-DEFAULT method) or recommends itself (cfg null or DEFAULT).
+ * Returns nullptr if no entry qualifies. */
+template<typename Top, typename Tret>
+GemmImplementation<Top, Tret> *find_implementation(GemmArgs<Tret> &args, GemmConfig *cfg) {
+ const auto &gemms = gemm_implementation_list<Top, Tret>(); // bind by reference: don't copy the list on every lookup
+
+ for(auto &&i : gemms) {
+ /* Skip if this implementation doesn't support these args. */
+ if (!i->is_supported(args)) {
+ continue;
+ }
+
+ /* Skip if a specific method is requested and this is a different one. */
+ if (cfg && cfg->method != GemmMethod::DEFAULT && i->method != cfg->method) {
+ continue;
+ }
+
+ /* If no specific method is requested, check that this method recommends itself. */
+ if ((!cfg || cfg->method == GemmMethod::DEFAULT) && !i->is_recommended(args)) {
+ continue;
+ }
+
+ return i;
+ }
+
+ return nullptr;
+}
+
+/* Construct the GEMM selected by find_implementation(); returns an empty
+ * UniqueGemmCommon if no implementation qualifies. */
+template<typename Top, typename Tret>
+UniqueGemmCommon<Top, Tret> gemm(GemmArgs<Tret> &args, GemmConfig *cfg) {
+ auto impl = find_implementation<Top, Tret>(args, cfg);
+
+ if (impl) {
+ return impl->instantiate(args);
+ }
+
+ return UniqueGemmCommon<Top, Tret>(nullptr);
+}
+
+/* Report which method would be selected for these args with no explicit config. */
+template<typename Top, typename Tret>
+GemmMethod get_gemm_method(GemmArgs<Tret> &args) {
+ auto impl = find_implementation<Top, Tret>(args, nullptr);
+
+ if (impl) {
+ return impl->method;
+ }
+
+ /* This shouldn't happen - there should always be at least one valid implementation. */
+ return GemmMethod::DEFAULT;
+}
+
+template<typename Top, typename Tret>
+bool method_is_compatible(GemmMethod method, GemmArgs<Tret> &args) {
+ /* Determine if the method is valid by attempting to obtain an implementation specifying this method. */
+ GemmConfig cfg(method);
+
+ auto impl = find_implementation<Top, Tret>(args, &cfg);
+
+ if (impl) {
+ return true;
+ }
+
+ return false;
+}
+
+} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
index 7669fe0..b7e8fa2 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
@@ -25,24 +25,37 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "kernels/a64_gemm_s16_12x8.hpp"
-namespace arm_gemm
-{
-template <>
-UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const int32_t alpha, const int32_t beta,
- const int maxthreads, const bool pretransposed_hint)
-{
- return UniqueGemmCommon<int16_t, int32_t>(new GemmInterleaved<gemm_s16_12x8, int16_t, int32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+namespace arm_gemm {
+
+class GemmImpl_gemm_s16_interleaved : public GemmImplementation<int16_t, int32_t> { // no is_supported() override: the base default accepts any args
+public:
+ UniqueGemmCommon<int16_t, int32_t> instantiate(const GemmArgs<int32_t> &args) override {
+ return UniqueGemmCommon<int16_t, int32_t>(new GemmInterleaved<gemm_s16_12x8, int16_t, int32_t>(args));
+ }
+
+ GemmImpl_gemm_s16_interleaved() : GemmImplementation<int16_t, int32_t>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static GemmImpl_gemm_s16_interleaved gemm_s16_interleaved_impl{};
+
+static std::vector<GemmImplementation<int16_t, int32_t> *> gemm_s16_methods = { // single entry: the interleaved s16 kernel
+ &gemm_s16_interleaved_impl
+};
+
+template<>
+std::vector<GemmImplementation<int16_t, int32_t> *> &gemm_implementation_list<int16_t, int32_t>() { // list searched by find_implementation()
+ return gemm_s16_methods;
}
-// Instantiate static class members
-const int gemm_s16_12x8::out_width;
-const int gemm_s16_12x8::out_height;
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t>(GemmArgs<int32_t> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<int16_t, int32_t>(GemmArgs<int32_t> &args);
+template bool method_is_compatible<int16_t, int32_t>(GemmMethod method, GemmArgs<int32_t> &args);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index 6016af2..dffa056 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -25,35 +25,55 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "kernels/a64_gemm_s16_12x8.hpp"
#include "kernels/a64_gemm_s8_12x8.hpp"
#include "kernels/a64_gemm_s8_4x4.hpp"
-namespace arm_gemm
-{
-template <>
-UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const int32_t alpha, const int32_t beta,
- const int maxthreads, const bool pretransposed_hint)
-{
- if(ci.has_dotprod())
- {
- // Dot product supporting CPUs. This family has a special version for A55r1.
- return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_12x8, int8_t, int32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+namespace arm_gemm {
+
+class GemmImpl_gemm_s8_interleaved_dot : public GemmImplementation<int8_t, int32_t> { // usable only when the CPU reports dot-product support
+public:
+ bool is_supported(const GemmArgs<int32_t> &args) override {
+ return args._ci->has_dotprod();
}
- return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_4x4, int8_t, int32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ UniqueGemmCommon<int8_t, int32_t> instantiate(const GemmArgs<int32_t> &args) override {
+ return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_12x8, int8_t, int32_t>(args));
+ }
+
+ GemmImpl_gemm_s8_interleaved_dot() : GemmImplementation<int8_t, int32_t>(GemmMethod::GEMM_INTERLEAVED_DOT) { }
+};
+
+class GemmImpl_gemm_s8_interleaved : public GemmImplementation<int8_t, int32_t> { // fallback: base is_supported() accepts any args
+public:
+ UniqueGemmCommon<int8_t, int32_t> instantiate(const GemmArgs<int32_t> &args) override {
+ return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_4x4, int8_t, int32_t>(args));
+ }
+
+ GemmImpl_gemm_s8_interleaved() : GemmImplementation<int8_t, int32_t>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static GemmImpl_gemm_s8_interleaved_dot gemm_s8_interleaved_dot_impl{};
+static GemmImpl_gemm_s8_interleaved gemm_s8_interleaved_impl{};
+
+static std::vector<GemmImplementation<int8_t, int32_t> *> gemm_s8_methods = { // order matters: the dot-product kernel is tried first
+ &gemm_s8_interleaved_dot_impl,
+ &gemm_s8_interleaved_impl
+};
+
+template<>
+std::vector<GemmImplementation<int8_t, int32_t> *> &gemm_implementation_list<int8_t, int32_t>() {
+ return gemm_s8_methods;
}
-// Instantiate static class members
-const int gemm_s8_12x8::out_width;
-const int gemm_s8_12x8::out_height;
-const int gemm_s8_4x4::out_width;
-const int gemm_s8_4x4::out_height;
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t>(GemmArgs<int32_t> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<int8_t, int32_t>(GemmArgs<int32_t> &args);
+template bool method_is_compatible<int8_t, int32_t>(GemmMethod method, GemmArgs<int32_t> &args);
} // namespace arm_gemm
-#endif // aarch64
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index efc5171..bfa4908 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -23,8 +23,8 @@
*/
#pragma once
-#include <assert.h>
#include <stdio.h>
+#include <assert.h>
#include <algorithm>
@@ -41,23 +41,22 @@
// Some macros used to decide how much working space to allocate.
// Round allocations up to the next cache line.
-#define ALLOC_ROUND 64
-#define ROUND_UP(x) ((((x) + ALLOC_ROUND - 1) / ALLOC_ROUND) * ALLOC_ROUND)
+#define ALLOC_ROUND 64 // rounding granule in bytes (intended as the cache-line size)
+#define ROUND_UP(x) ((((x) + ALLOC_ROUND-1) / ALLOC_ROUND) * ALLOC_ROUND) // smallest multiple of ALLOC_ROUND that is >= x
// Implementation of the GemmCommon abstract class.
//
// This implementation interleaves the source matrices in blocks - good for
// larger matrices.
-namespace arm_gemm
-{
-template <typename strategy, typename To, typename Tr>
-class GemmInterleaved : public GemmCommon<To, Tr>
-{
+namespace arm_gemm {
+
+template<typename strategy, typename To, typename Tr>
+class GemmInterleaved : public GemmCommon<To, Tr> {
typedef typename strategy::operand_type Toi;
- typedef typename strategy::result_type Tri;
+ typedef typename strategy::result_type Tri;
/* const properties set by constructor */
- const CPUInfo *const _ci;
+ const CPUInfo * const _ci;
const unsigned int _Msize;
const unsigned int _Nsize;
@@ -72,173 +71,138 @@
const Tr _alpha;
const Tr _beta;
- const unsigned int _maxthreads;
- const bool _pretransposed;
+ const int _maxthreads;
+ int _nthreads;
+ const bool _pretransposed;
/* Blocking info */
- unsigned int _k_block = 0;
- unsigned int _x_block = 0;
- unsigned int _Mround = 0;
+ unsigned int _k_block=0;
+ unsigned int _x_block=0;
+ unsigned int _Mround=0;
/* Working space, pretransposed buffer, buffer manager */
- const Toi *_B_transposed = nullptr;
- BufferManager *_bm = nullptr;
- void *_working_space = nullptr;
+ const Toi *_B_transposed=nullptr;
+ BufferManager *_bm=nullptr;
+ void *_working_space=nullptr;
/* We will need to walk through the blocks of B in a few contexts, so
* factor that out. */
- class blockwalker
- {
+ class blockwalker {
private:
/* Size loops, etc. based on our parent's configuration */
const GemmInterleaved<strategy, To, Tr> &_parent;
- /* K and X and multi parameters for current iteration. */
- unsigned int _k0 = 0, _x0 = 0, _multi = 0;
+ /* K, X and multi parameters for current iteration. */
+ unsigned int _k0=0, _x0=0, _multi=0;
- unsigned int _index = 0;
- bool _done = false;
- bool _newkblock = true;
- bool _newmulti = true;
+ unsigned int _index=0;
+ bool _done=false;
+ bool _newkblock=true;
+ bool _newmulti=true;
public:
- blockwalker(const GemmInterleaved<strategy, To, Tr> &parent)
- : _parent(parent)
- {
- }
+ blockwalker(const GemmInterleaved<strategy, To, Tr> &parent) : _parent(parent) { }
- unsigned int xmax()
- {
+ unsigned int xmax() {
return std::min(_x0 + _parent._x_block, _parent._Nsize);
}
- unsigned int kmax()
- {
+ unsigned int kmax() {
return std::min(_k0 + _parent._k_block, _parent._Ksize);
}
/* Advance to the next block, return false at the end. */
- bool advance(void)
- {
- if(_done)
- {
+ bool advance(void) {
+ if (_done) {
return false;
}
- _newkblock = false;
+ _newkblock=false;
_x0 += _parent._x_block;
- if(_x0 >= _parent._Nsize)
- {
- _x0 = 0;
+ if (_x0 >= _parent._Nsize) {
+ _x0=0;
_k0 += _parent._k_block;
- if(_k0 >= _parent._Ksize)
- {
- _k0 = 0;
+ if (_k0 >= _parent._Ksize) {
+ _k0=0;
_multi++;
- if(_multi >= _parent._nmulti)
- {
- _done = true;
+ if (_multi >= _parent._nmulti) {
+ _done=true;
return false;
}
- _newmulti = true;
+ _newmulti=true;
}
- _newkblock = true;
+ _newkblock=true;
}
_index++;
return true;
}
- unsigned int k0(void)
- {
- return _k0;
- }
- unsigned int x0(void)
- {
- return _x0;
- }
- unsigned int multi(void)
- {
- return _multi;
- }
- unsigned int index(void)
- {
- return _index;
- }
- bool done(void)
- {
- return _done;
- }
- bool newkblock(void)
- {
- return _newkblock;
- }
+ unsigned int k0(void) { return _k0; }
+ unsigned int x0(void) { return _x0; }
+ unsigned int multi(void) { return _multi; }
+ unsigned int index(void) { return _index; }
+ bool done(void) { return _done; }
+ bool newkblock(void) { return _newkblock; }
};
// A working size: One of these needed, regardless of thread count. Divided according to window.
- size_t get_a_working_size() const
- {
+ size_t get_a_working_size() const {
return ROUND_UP(sizeof(Toi) * _k_block * _Mround * _nbatches);
}
// B working size: 0, 1 or 3 of these needed depending on pretransposed and threading settings.
- size_t get_b_working_size() const
- {
+ size_t get_b_working_size() const {
return ROUND_UP(sizeof(Toi) * _x_block * _k_block);
}
// C working size: One needed per thread.
- size_t get_c_working_size() const
- {
- return ROUND_UP(sizeof(Tri) * _x_block * strategy::out_height);
+ size_t get_c_working_size() const {
+ return ROUND_UP(sizeof(Tri) * _x_block * strategy::out_height());
}
// Internal execute function.
// This supports both the "pretransposed" and "standard" interfaces via the template parameter.
- template <bool pretransposed>
- void execute_internal(unsigned int start, unsigned int end, int threadid)
- {
+ template<bool pretransposed>
+ void execute_internal(unsigned int start, unsigned int end, int threadid) {
#ifdef CYCLE_PROFILING
profiler prof;
#endif
-
strategy strat(_ci);
blockwalker current(*this);
- blockwalker next = current;
+ blockwalker next=current;
/* Translate 'start' and 'end' into a position within the batches and rows. */
- const unsigned int window_per_batch = _Mround / strategy::out_height;
- unsigned int batch_0 = start / window_per_batch;
- unsigned int batch_end = end / window_per_batch;
+ const unsigned int window_per_batch = _Mround / strategy::out_height();
+ unsigned int batch_0 = start / window_per_batch;
+ unsigned int batch_end = end / window_per_batch;
/* Compute the M values to operate on */
- unsigned int m_0 = (start - (batch_0 * window_per_batch)) * strategy::out_height;
- unsigned int m_max = (end - (batch_end * window_per_batch)) * strategy::out_height;
+ unsigned int m_0 = (start - (batch_0 * window_per_batch)) * strategy::out_height();
+ unsigned int m_max = (end - (batch_end * window_per_batch)) * strategy::out_height();
/* Make sure we've been set up correctly. */
- if(pretransposed)
- {
+ if (pretransposed) {
assert(_B_transposed);
- }
- else
- {
+ } else {
assert(_bm);
}
assert(_working_space);
int8_t *working_space_bytes = reinterpret_cast<int8_t *>(_working_space);
- // Private buffers. Treat working_space as an array of C buffers (one per thread) first, followed by the (window-divided) A buffer.
+ // Private buffers. Treat working_space as an array of C buffers
+ // (one per thread) first, followed by the (window-divided) A
+ // buffer.
// Set a_panel to the base of the A buffers - compute offsets into it based on M/batches later.
- Toi *const a_panel = reinterpret_cast<Toi *>(working_space_bytes + (_maxthreads * get_c_working_size()));
- Tri *const c_panel = reinterpret_cast<Tri *>(working_space_bytes + (threadid * get_c_working_size()));
+ Toi * const a_panel = reinterpret_cast<Toi *>(working_space_bytes + (_maxthreads * get_c_working_size()));
+ Tri * const c_panel = reinterpret_cast<Tri *>(working_space_bytes + (threadid * get_c_working_size()));
// Shared buffers - these come either from BufferManager or _B_transposed.
const Toi *b_panel;
- if(pretransposed)
- {
+ if (pretransposed) {
b_panel = _B_transposed;
}
@@ -247,45 +211,31 @@
// newkblock() is always true on the first iteration, so this will be set properly on the first loop.
int kern_k = 0;
- for(; !current.done(); current.advance())
- {
- if(current.newkblock())
- {
+ for (;!current.done();current.advance()) {
+ if (current.newkblock()) {
#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height * (current.kmax() - current.k0()) * sizeof(Toi));
+ auto p=prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height() * (current.kmax()-current.k0()) * sizeof(Toi));
#endif
- for(unsigned int batch = batch_0; batch <= batch_end; batch++)
- {
- unsigned int first_m = (batch == batch_0) ? m_0 : 0;
+ for (unsigned int batch = batch_0; batch <= batch_end; batch++) {
+ unsigned int first_m = (batch == batch_0) ? m_0 : 0;
unsigned int last_m = (batch == batch_end) ? m_max : _Msize;
- if(first_m >= last_m)
+ if (first_m >= last_m)
continue;
- if(_trA ^ strategy::A_transpose)
- {
- Transform<strategy::A_interleave, strategy::A_block, true>(
- a_panel + ((batch * _Mround + first_m) * _k_block),
- this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
- this->_lda, first_m, last_m, current.k0(), current.kmax());
- }
- else
- {
- Transform<strategy::A_interleave, strategy::A_block, false>(
- a_panel + ((batch * _Mround + first_m) * _k_block),
- this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
- this->_lda, first_m, last_m, current.k0(), current.kmax());
- }
+
+ strat.transforms.PrepareA(a_panel + ((batch * _Mround + first_m) * _k_block),
+ this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
+ this->_lda, first_m, last_m, current.k0(), current.kmax(), _trA);
}
// Figure out how many "K" the kernel will actually process.
- kern_k = iceildiv(current.kmax() - current.k0(), strategy::k_unroll);
- kern_k *= strat.k_unroll;
+ kern_k = iceildiv(current.kmax() - current.k0(), strategy::k_unroll());
+ kern_k *= strat.k_unroll();
}
- int bblocks = iceildiv(current.xmax() - current.x0(), strategy::out_width);
+ int bblocks = iceildiv(current.xmax() - current.x0(), strategy::out_width());
- if(!pretransposed)
- {
+ if (!pretransposed) {
/* Look ahead to the next block and populate it if necessary.
* This avoids the populate operation becoming a bottleneck, and
* helps keep the threads synchronized (the first thread to get
@@ -294,96 +244,69 @@
* If we are running single threaded, bm->try_populate() will do
* nothing.
*/
- if(next.advance())
- {
- _bm->try_populate(next.index(), [&](void *buffer)
- {
+ if (next.advance()) {
+ _bm->try_populate(next.index(), [&](void *buffer) {
#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_PREPB, (next.xmax() - next.x0()) * (next.kmax() - next.k0()) * sizeof(Toi));
+ auto p=prof.ScopedProfiler(PROFILE_PREPB, (next.xmax()-next.x0()) * (next.kmax()-next.k0()) * sizeof(Toi));
#endif
Toi *b_panel = reinterpret_cast<Toi *>(buffer);
- if(_trB ^ strategy::B_transpose)
- {
- Transform<strategy::B_interleave, strategy::B_block, true>(
- b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
- next.x0(), next.xmax(), next.k0(), next.kmax());
- }
- else
- {
- Transform<strategy::B_interleave, strategy::B_block, false>(
- b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
- next.x0(), next.xmax(), next.k0(), next.kmax());
- }
+
+ strat.transforms.PrepareB(b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
+ next.x0(), next.xmax(), next.k0(), next.kmax(), _trB);
});
}
+
/* Get the buffer for this iteration from the BufferManager. */
- b_panel = reinterpret_cast<Toi *>(_bm->get(current.index(), [&](void *bpv)
- {
+ b_panel = reinterpret_cast<Toi *>(_bm->get(current.index(), [&](void *bpv) {
#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_PREPB, (current.xmax() - current.x0()) * (current.kmax() - current.k0()) * sizeof(Toi));
+ auto p=prof.ScopedProfiler(PROFILE_PREPB, (current.xmax()-current.x0()) * (current.kmax()-current.k0()) * sizeof(Toi));
#endif
Toi *b_panel = reinterpret_cast<Toi *>(bpv);
- if(_trB ^ strategy::B_transpose)
- {
- Transform<strategy::B_interleave, strategy::B_block, true>(
- b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
- current.x0(), current.xmax(), current.k0(), current.kmax());
- }
- else
- {
- Transform<strategy::B_interleave, strategy::B_block, false>(
- b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
- current.x0(), current.xmax(), current.k0(), current.kmax());
- }
+ strat.transforms.PrepareB(b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
+ current.x0(), current.xmax(), current.k0(), current.kmax(), _trB);
}));
}
/* Do the actual work. */
- for(unsigned int batch = batch_0; batch <= batch_end; batch++)
- {
- unsigned int first_m = (batch == batch_0) ? m_0 : 0;
+ for (unsigned int batch = batch_0; batch <= batch_end; batch++) {
+ unsigned int first_m = (batch == batch_0) ? m_0 : 0;
unsigned int last_m = (batch == batch_end) ? m_max : _Msize;
const Toi *a_ptr = a_panel + (batch * _Mround + first_m) * _k_block;
- if(first_m >= last_m)
+ if (first_m >= last_m)
continue;
- for(unsigned int y = first_m; y < last_m; y += strategy::out_height)
- {
- unsigned int ymax = std::min(_Msize, y + strategy::out_height);
+ for (unsigned int y=first_m; y<last_m; y+=strategy::out_height()) {
+ unsigned int ymax = std::min(_Msize, y + strategy::out_height());
{
#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height * bblocks * strategy::out_width * kern_k));
+ auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height() * bblocks * strategy::out_width() * kern_k));
#endif
strat.kernel(a_ptr, b_panel, c_panel, 1, bblocks, kern_k);
- a_ptr += (strategy::out_height * kern_k);
+ a_ptr += (strategy::out_height() * kern_k);
}
{
#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height * bblocks * strategy::out_width * sizeof(Tr)));
+ auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height() * bblocks * strategy::out_width() * sizeof(Tr)));
#endif
- MergeResults<strategy::out_width, strategy::out_height>(
- this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride),
- c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(),
- _alpha, (current.k0() == 0 ? _beta : static_cast<Tr>(1)));
+ strat.transforms.Merge(this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride),
+ c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(),
+ _alpha, (current.k0()==0 ? _beta : static_cast<Tr>(1)));
}
}
}
- if(pretransposed)
- {
- b_panel += (bblocks * strat.out_width * kern_k);
- }
- else
- {
+ if (pretransposed) {
+ b_panel += (bblocks * strat.out_width() * kern_k);
+ } else {
_bm->release(current.index());
}
}
@@ -391,57 +314,58 @@
public:
GemmInterleaved(GemmInterleaved &) = delete;
- GemmInterleaved &operator=(GemmInterleaved &) = delete;
+ GemmInterleaved & operator= (GemmInterleaved &) = delete;
/* Constructor */
- GemmInterleaved(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB,
- const Tr alpha, const Tr beta, const int maxthreads, const bool pretransposed)
- : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads), _pretransposed(pretransposed)
- {
- const unsigned int L1_size = ci->get_L1_cache_size();
- const unsigned int L2_size = ci->get_L2_cache_size();
+ GemmInterleaved(const GemmArgs<Tr> &args)
+ : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
+ _nbatches(args._nbatches), _nmulti(args._nmulti), _trA(args._trA), _trB(args._trB),
+ _alpha(args._alpha), _beta(args._beta), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _pretransposed(args._pretransposed_hint) {
+ const unsigned int L1_size = _ci->get_L1_cache_size();
+ const unsigned int L2_size = _ci->get_L2_cache_size();
- assert(maxthreads > 0);
+ assert(_maxthreads > 0);
// Work out blocking parameters
// k_block: Find out how much of the larger array can be loaded into half the cache.
// This should account for associative caches.
- _k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width, strategy::out_height)));
+ _k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height())));
// Needs to be (at least a single) multiple of the K unroll level.
- _k_block /= strategy::k_unroll;
- _k_block = std::max(_k_block, 1U) * strategy::k_unroll;
+ _k_block /= strategy::k_unroll();
+ _k_block = std::max(_k_block, 1U) * strategy::k_unroll();
// Now tune to presented problem size; this is how many blocks we need.
- int num_k_blocks = iceildiv(K, _k_block);
+ int num_k_blocks = iceildiv(_Ksize, _k_block);
// So divide the space equally into that many blocks.
- _k_block = iceildiv(K, num_k_blocks);
+ _k_block = iceildiv(_Ksize, num_k_blocks);
// And round UP to the K unroll level required.
- _k_block = iceildiv(_k_block, strategy::k_unroll);
- _k_block *= strategy::k_unroll;
+ _k_block = iceildiv(_k_block, strategy::k_unroll());
+ _k_block *= strategy::k_unroll();
// x_block: Work out how many rows (of length k_block) will fit in the L2
// Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents.
- _x_block = (((L2_size * 9) / 10) - (_k_block * sizeof(Toi) * (strategy::out_width + strategy::out_height))) / (sizeof(Toi) * _k_block);
+ _x_block = (((L2_size * 9) / 10) - (_k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()))) /
+ (sizeof(Toi) * _k_block);
// Needs to be (at least a single) multiple of the kernel output width.
- _x_block /= strategy::out_width;
- _x_block = std::max(_x_block, 1U) * strategy::out_width;
+ _x_block /= strategy::out_width();
+ _x_block = std::max(_x_block, 1U) * strategy::out_width();
// And tune to the presented problem size.
- int num_x_blocks = iceildiv(N, _x_block);
- _x_block = iceildiv(N, num_x_blocks);
+ int num_x_blocks = iceildiv(_Nsize, _x_block);
+ _x_block = iceildiv(_Nsize, num_x_blocks);
- _x_block = iceildiv(_x_block, strategy::out_width);
- _x_block *= strategy::out_width;
+ _x_block = iceildiv(_x_block, strategy::out_width());
+ _x_block *= strategy::out_width();
// Work out the rounded size of M - needed for some buffers.
- _Mround = iceildiv(M, strategy::out_height);
- _Mround *= strategy::out_height;
+ _Mround = iceildiv(_Msize, strategy::out_height());
+ _Mround *= strategy::out_height();
}
// Interface implementation - Compulsory functions
@@ -450,45 +374,36 @@
// out work in units of out_height. Factor batches into the window, but
// not multi for now (as this would cause problems with the buffer
// manager).
-
- unsigned int get_window_size() const override
- {
+ unsigned int get_window_size() const override {
// _Mround is a multiple of out_height by definition.
- return (_Mround / strategy::out_height) * _nbatches;
+ return (_Mround / strategy::out_height()) * _nbatches;
}
// set_nthreads: pass on to buffer manager to avoid it waiting for non-existant threads.
- void set_nthreads(int nthreads) override
- {
- if(_bm)
- {
- _bm->set_nthreads(nthreads);
+ void set_nthreads(int nthreads) override {
+ _nthreads = std::min(nthreads, _maxthreads);
+ if (_bm) {
+ _bm->set_nthreads(_nthreads);
}
}
// Execute
- void execute(unsigned int start, unsigned int end, int threadid) override
- {
- if(_pretransposed)
- {
+ void execute(unsigned int start, unsigned int end, int threadid) override {
+ if (_pretransposed) {
execute_internal<true>(start, end, threadid);
- }
- else
- {
+ } else {
execute_internal<false>(start, end, threadid);
}
}
// Interface implementation - working space
- size_t get_working_size() const override
- {
+ size_t get_working_size() const override {
// In all cases, we need one A buffer plus a C buffer per thread.
size_t size = get_a_working_size() + (get_c_working_size() * _maxthreads);
// For pretransposed case, there is no working space needed for B.
// Otherwise, we need a BufferManager.
- if(!_pretransposed)
- {
+ if (!_pretransposed) {
size += BufferManager::get_storage_requirement(_maxthreads, get_b_working_size());
}
@@ -497,33 +412,28 @@
return size;
}
- void set_working_space(void *working_space) override
- {
+ void set_working_space(void *working_space) override {
// Make sure everything ends up cache line aligned
int8_t *working_space_bytes = reinterpret_cast<int8_t *>(working_space);
- intptr_t working_space_int = reinterpret_cast<intptr_t>(working_space);
+ intptr_t working_space_int = reinterpret_cast<intptr_t>(working_space);
- size_t diff = 0;
+ size_t diff=0;
- if(working_space_int & 0x3F)
- {
+ if (working_space_int & 0x3F) {
diff = 0x40 - (working_space_int & 0x3F);
}
working_space_bytes += diff;
- if(_pretransposed)
- {
+ if (_pretransposed) {
// Pretransposed case: just set internal pointer to parameter value.
_working_space = reinterpret_cast<void *>(working_space_bytes);
- }
- else
- {
+ } else {
// Otherwise, use the first part of the working space for the buffer manager.
// It's legal to call this again so don't leak a buffer manager if it already existed.
delete _bm;
- _bm = new BufferManager(_maxthreads, get_b_working_size(), reinterpret_cast<void *>(working_space_bytes));
+ _bm = new BufferManager(_nthreads, get_b_working_size(), reinterpret_cast<void *>(working_space_bytes));
working_space_bytes += BufferManager::get_storage_requirement(_maxthreads, get_b_working_size());
@@ -532,85 +442,66 @@
}
// Interface implementation - pretransposed
- bool B_is_pretransposed() const override
- {
+ bool B_is_pretransposed() const override {
return _pretransposed;
}
- bool B_pretranspose_required() const override
- {
- return _pretransposed && (_B_transposed == nullptr);
+ bool B_pretranspose_required() const override {
+ return _pretransposed && (_B_transposed==nullptr);
}
- size_t get_B_pretransposed_array_size() const override
- {
- size_t total = 0;
+ size_t get_B_pretransposed_array_size() const override {
+ size_t total=0;
blockwalker current(*this);
- do
- {
+ do {
/* Figure out the size of each block. */
size_t x_size = (current.xmax() - current.x0());
size_t k_size = (current.kmax() - current.k0());
/* Round sizes up as needed. */
- x_size = iceildiv(x_size, strategy::out_width);
- x_size *= strategy::out_width;
+ x_size = iceildiv(x_size, strategy::out_width());
+ x_size *= strategy::out_width();
- k_size = iceildiv(k_size, strategy::k_unroll);
- k_size *= strategy::k_unroll;
+ k_size = iceildiv(k_size, strategy::k_unroll());
+ k_size *= strategy::k_unroll();
total += x_size * k_size * sizeof(Toi);
- }
- while(current.advance());
+ } while (current.advance());
return total;
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override
- {
+ void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
blockwalker current(*this);
- Toi *buffer = reinterpret_cast<Toi *>(in_buffer);
- _B_transposed = buffer;
+ Toi *buffer = reinterpret_cast<Toi *>(in_buffer);
+ _B_transposed = buffer;
+ strategy strat(_ci);
- do
- {
+ do {
/* Figure out the size of each block. */
size_t x_size = (current.xmax() - current.x0());
size_t k_size = (current.kmax() - current.k0());
/* Round sizes up as needed. */
- x_size = iceildiv(x_size, strategy::out_width);
- x_size *= strategy::out_width;
+ x_size = iceildiv(x_size, strategy::out_width());
+ x_size *= strategy::out_width();
- k_size = iceildiv(k_size, strategy::k_unroll);
- k_size *= strategy::k_unroll;
+ k_size = iceildiv(k_size, strategy::k_unroll());
+ k_size *= strategy::k_unroll();
- if(_trB ^ strategy::B_transpose)
- {
- Transform<strategy::B_interleave, strategy::B_block, true>(
- buffer, B + (current.multi() * B_multi_stride), ldb,
- current.x0(), current.xmax(), current.k0(), current.kmax());
- }
- else
- {
- Transform<strategy::B_interleave, strategy::B_block, false>(
- buffer, B + (current.multi() * B_multi_stride), ldb,
- current.x0(), current.xmax(), current.k0(), current.kmax());
- }
+ strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb,
+ current.x0(), current.xmax(), current.k0(), current.kmax(), _trB);
buffer += (x_size * k_size);
- }
- while(current.advance());
+ } while (current.advance());
}
- void set_pretransposed_B_data(void *in_buffer) override
- {
+ void set_pretransposed_B_data(void *in_buffer) override {
_B_transposed = reinterpret_cast<Toi *>(in_buffer);
}
- ~GemmInterleaved() override
- {
+ ~GemmInterleaved() override {
delete _bm;
}
};
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
index 075ab82..6bc7df0 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
@@ -34,8 +34,8 @@
#include "profiler.hpp"
#endif
-namespace arm_gemm
-{
+namespace arm_gemm {
+
// Implementation of the GemmCommon abstract class.
//
// This is implementation is for native GEMM with no transposition.
@@ -43,11 +43,10 @@
// By default the source data is used in-place, but if type conversion is
// needed we need to allocate working space (CURRENTLY NOT IMPLEMENTED).
-template <typename strategy, typename To, typename Tr>
-class GemmNative : public GemmCommon<To, Tr>
-{
+template<typename strategy, typename To, typename Tr>
+class GemmNative : public GemmCommon<To, Tr> {
typedef typename strategy::operand_type Toi;
- typedef typename strategy::result_type Tri;
+ typedef typename strategy::result_type Tri;
const unsigned int _Msize;
const unsigned int _Nsize;
@@ -58,72 +57,75 @@
Tr _beta;
- const CPUInfo *const _ci;
+ const CPUInfo * const _ci;
- unsigned int k_block = 0;
- unsigned int n_block = 0;
+ unsigned int k_block=0;
+ unsigned int n_block=0;
+
+ unsigned int window_per_batch() const {
+ return iceildiv(_Msize, strategy::out_height());
+ }
+
+ unsigned int window_per_multi() const {
+ return window_per_batch() * _nbatches;
+ }
public:
GemmNative(GemmNative &) = delete;
- GemmNative &operator=(GemmNative &) = delete;
+ GemmNative & operator= (GemmNative &) = delete;
- GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmultis, const Tr beta)
- : _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmultis(nmultis), _beta(beta), _ci(ci)
- {
+ GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmultis, const Tr beta) :
+ _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmultis(nmultis), _beta(beta), _ci(ci) {
/* For now don't do any blocking.*/
k_block = K;
n_block = N;
}
- // Window is number of out_height blocks
- unsigned int get_window_size() const override
- {
- return iceildiv(_Msize, strategy::out_height) * _nbatches * _nmultis;
+ // Window is amount per multi multiplied by total number of multis.
+ unsigned int get_window_size() const override {
+ return window_per_multi() * _nmultis;
}
// Actually execute the GEMM.
- void execute(unsigned int start, unsigned int end, int) override
- {
+ void execute(unsigned int start, unsigned int end, int) override {
#ifdef CYCLE_PROFILING
profiler prof;
#endif
- strategy strat(_ci);
- const unsigned int window_per_batch = iceildiv(_Msize, strategy::out_height);
- const unsigned int window_per_multi = window_per_batch * _nbatches;
-
- const unsigned int first_multi = start / window_per_multi;
- const unsigned int last_multi = end / window_per_multi;
-
- const unsigned int first_batch = (start - (first_multi * window_per_multi)) / window_per_batch;
- const unsigned int last_batch = (end - (last_multi * window_per_multi)) / window_per_batch;
-
- const unsigned int first_row = ((start - (first_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
- const unsigned int last_row = ((end - (last_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
+ strategy strat(_ci);
static_assert(std::is_same<To, Toi>::value, "gemm_native: Operand types must be the same.");
static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same.");
- for(unsigned int multi = first_multi; multi <= last_multi; multi++)
- {
- const unsigned int batch_0 = (multi == first_multi) ? first_batch : 0;
- const unsigned int batch_max = (multi == last_multi) ? last_batch : _nbatches - 1;
+ /* Compute starting point based on 'start' */
+ unsigned int multi = start / window_per_multi();
+ unsigned int multi_pos = start % window_per_multi();
- for(unsigned int batch = batch_0; batch <= batch_max; batch++)
- {
- const unsigned int m_start = ((multi == first_multi) && (batch == first_batch)) ? first_row : 0;
- const unsigned int m_end = ((multi == last_multi) && (batch == last_batch)) ? last_row : _Msize;
+ unsigned int batch = multi_pos / window_per_batch();
+ unsigned int batch_pos = multi_pos % window_per_batch();
- for(unsigned int y0 = m_start; y0 < m_end; y0 += strategy::out_height)
- {
- const unsigned int ymax = std::min(y0 + strategy::out_height, m_end);
+ unsigned int y0 = batch_pos * strategy::out_height();
+
+ for (unsigned int pos=start; pos<end; pos++) {
+ const unsigned int ymax = std::min(y0 + strategy::out_height(), _Msize);
#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax - y0) * _Nsize * _Ksize);
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize);
#endif
- strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda,
- this->_Bptr + (multi * this->_B_multi_stride), this->_ldb,
- this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc,
- _beta, (ymax - y0), _Nsize, _Ksize);
+ strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda,
+ this->_Bptr + (multi * this->_B_multi_stride), this->_ldb,
+ this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc,
+ _beta, (ymax-y0), _Nsize, _Ksize);
+
+ /* Advance to next item */
+ y0 += strategy::out_height();
+
+ /* Check for batch/multi overflow */
+ if (y0 >= _Msize) {
+ y0=0;
+ batch++;
+ if (batch == _nbatches) {
+ batch=0;
+ multi++;
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
index 8f1f377..feea482 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
@@ -25,24 +25,37 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "kernels/a64_gemm_u16_12x8.hpp"
-namespace arm_gemm
-{
-template <>
-UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, uint32_t alpha, uint32_t beta,
- const int maxthreads, const bool pretransposed_hint)
-{
- return UniqueGemmCommon<uint16_t, uint32_t>(new GemmInterleaved<gemm_u16_12x8, uint16_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+namespace arm_gemm {
+
+class GemmImpl_gemm_u16_interleaved : public GemmImplementation<uint16_t, uint32_t> {
+public:
+ UniqueGemmCommon<uint16_t, uint32_t> instantiate(const GemmArgs<uint32_t> &args) override {
+ return UniqueGemmCommon<uint16_t, uint32_t>(new GemmInterleaved<gemm_u16_12x8, uint16_t, uint32_t>(args));
+ }
+
+ GemmImpl_gemm_u16_interleaved() : GemmImplementation<uint16_t, uint32_t>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static GemmImpl_gemm_u16_interleaved gemm_u16_interleaved_impl{};
+
+static std::vector<GemmImplementation<uint16_t, uint32_t> *> gemm_u16_methods = {
+ &gemm_u16_interleaved_impl
+};
+
+template<>
+std::vector<GemmImplementation<uint16_t, uint32_t> *> &gemm_implementation_list<uint16_t, uint32_t>() {
+ return gemm_u16_methods;
}
-// Instantiate static class members
-const int gemm_u16_12x8::out_width;
-const int gemm_u16_12x8::out_height;
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t>(GemmArgs<uint32_t> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<uint16_t, uint32_t>(GemmArgs<uint32_t> &args);
+template bool method_is_compatible<uint16_t, uint32_t>(GemmMethod method, GemmArgs<uint32_t> &args);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index 12e5aa6..60b7954 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -25,36 +25,55 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
+#include "kernels/a64_gemm_u16_12x8.hpp"
#include "kernels/a64_gemm_u8_12x8.hpp"
#include "kernels/a64_gemm_u8_4x4.hpp"
-namespace arm_gemm
-{
-template <>
-UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const uint32_t alpha, const uint32_t beta,
- const int maxthreads, const bool pretransposed_hint)
-{
- if(ci.has_dotprod())
- {
- // Dot product supporting CPUs. This family has a special version for A55r1.
- return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_12x8, uint8_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+namespace arm_gemm {
+
+class GemmImpl_gemm_u8_interleaved_dot : public GemmImplementation<uint8_t, uint32_t> {
+public:
+ bool is_supported(const GemmArgs<uint32_t> &args) override {
+ return args._ci->has_dotprod();
}
- // Non dot-product code.
- return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_4x4, uint8_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ UniqueGemmCommon<uint8_t, uint32_t> instantiate(const GemmArgs<uint32_t> &args) override {
+ return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_12x8, uint8_t, uint32_t>(args));
+ }
+
+ GemmImpl_gemm_u8_interleaved_dot() : GemmImplementation<uint8_t, uint32_t>(GemmMethod::GEMM_INTERLEAVED_DOT) { }
+};
+
+class GemmImpl_gemm_u8_interleaved : public GemmImplementation<uint8_t, uint32_t> {
+public:
+ UniqueGemmCommon<uint8_t, uint32_t> instantiate(const GemmArgs<uint32_t> &args) override {
+ return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_4x4, uint8_t, uint32_t>(args));
+ }
+
+ GemmImpl_gemm_u8_interleaved() : GemmImplementation<uint8_t, uint32_t>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static GemmImpl_gemm_u8_interleaved_dot gemm_u8_interleaved_dot_impl{};
+static GemmImpl_gemm_u8_interleaved gemm_u8_interleaved_impl{};
+
+static std::vector<GemmImplementation<uint8_t, uint32_t> *> gemm_u8_methods = {
+ &gemm_u8_interleaved_dot_impl,
+ &gemm_u8_interleaved_impl
+};
+
+template<>
+std::vector<GemmImplementation<uint8_t, uint32_t> *> &gemm_implementation_list<uint8_t, uint32_t>() {
+ return gemm_u8_methods;
}
-// Instantiate static class members
-const int gemm_u8_12x8::out_width;
-const int gemm_u8_12x8::out_height;
-
-const int gemm_u8_4x4::out_width;
-const int gemm_u8_4x4::out_height;
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t>(GemmArgs<uint32_t> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<uint8_t, uint32_t>(GemmArgs<uint32_t> &args);
+template bool method_is_compatible<uint8_t, uint32_t>(GemmMethod method, GemmArgs<uint32_t> &args);
} // namespace arm_gemm
-#endif // aarch64
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
similarity index 68%
rename from src/core/NEON/kernels/arm_gemm/gemm_batched.hpp
rename to src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
index 385358f..d65971e 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_batched.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
@@ -25,80 +25,71 @@
#include "arm_gemm.hpp"
-namespace arm_gemm
-{
-template <typename To, typename Tr>
-class GemmBatched : public GemmCommon<To, Tr>
-{
+namespace arm_gemm {
+
+/* "Batched GEMV" (where M=1 and nbatches>1) can be executed much more
+ * efficiently as a GEMM (with M'=nbatches and nbatches'=1). This wrapper
+ * implements this. */
+template<typename To, typename Tr>
+class GemvBatched : public GemmCommon<To, Tr> {
private:
UniqueGemmCommon<To, Tr> _subgemm = nullptr;
public:
- GemmBatched(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB,
- const To alpha, const To beta, const int maxthreads, const bool pretransposed_hint)
- {
+ GemvBatched(const GemmArgs<Tr> &args) {
/* Just create a subgemm with batches->M */
- _subgemm = gemm<To, Tr>(ci, nbatches, N, K, 1, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint);
+ GemmArgs<Tr> newargs = args;
+ newargs._Msize = args._nbatches;
+ newargs._nbatches = 1;
+ _subgemm = gemm<To,Tr>(newargs, nullptr);
}
void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
const To *B, const int ldb, const int B_multi_stride,
- Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) override
- {
+ Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) override {
/* A and C's batch stride becomes their new row stride. New batch stride is 0 as nbatches for subgemm is always 1. */
_subgemm->set_arrays(A, A_batch_stride, 0, A_multi_stride,
B, ldb, B_multi_stride,
C, C_batch_stride, 0, C_multi_stride);
}
- unsigned int get_window_size() const override
- {
+ unsigned int get_window_size() const override {
return _subgemm->get_window_size();
}
- void set_nthreads(int nthreads) override
- {
+ void set_nthreads(int nthreads) override {
_subgemm->set_nthreads(nthreads);
}
- void execute(unsigned int start, unsigned int end, int threadid) override
- {
+ void execute(unsigned int start, unsigned int end, int threadid) override {
_subgemm->execute(start, end, threadid);
}
- size_t get_working_size() const override
- {
+ size_t get_working_size() const override {
return _subgemm->get_working_size();
}
- void set_working_space(void *space) override
- {
+ void set_working_space(void *space) override {
_subgemm->set_working_space(space);
}
- bool B_is_pretransposed() const override
- {
+ bool B_is_pretransposed() const override {
return _subgemm->B_is_pretransposed();
}
- bool B_pretranspose_required() const override
- {
+ bool B_pretranspose_required() const override {
return _subgemm->B_pretranspose_required();
}
- size_t get_B_pretransposed_array_size() const override
- {
+ size_t get_B_pretransposed_array_size() const override {
return _subgemm->get_B_pretransposed_array_size();
}
- void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override
- {
+ void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
_subgemm->pretranspose_B_array(buffer, B, ldb, B_multi_stride);
}
- void set_pretransposed_B_data(void *buffer) override
- {
+ void set_pretransposed_B_data(void *buffer) override {
_subgemm->set_pretransposed_B_data(buffer);
}
};
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp
index 63bb58a..e37d4c5 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp
@@ -34,8 +34,8 @@
#include "profiler.hpp"
#endif
-namespace arm_gemm
-{
+namespace arm_gemm {
+
// Implementation of the GemmCommon abstract class.
//
// This is implementation is for a "native" (no-transform) GEMV with a
@@ -43,53 +43,48 @@
//
// As a native operation the source data is used in-place, so the internal
// and external operand/result types must match.
-template <typename strategy, typename To, typename Tr>
-class GemvNativeTransposed : public GemmCommon<To, Tr>
-{
+template<typename strategy, typename To, typename Tr>
+class GemvNativeTransposed : public GemmCommon<To, Tr> {
typedef typename strategy::operand_type Toi;
- typedef typename strategy::result_type Tri;
+ typedef typename strategy::result_type Tri;
const unsigned int _Nsize;
const unsigned int _Ksize;
+
const unsigned int _nmultis;
const Tr _beta;
- const CPUInfo *const _ci;
+ const CPUInfo * const _ci;
- unsigned int m_block = 0;
- unsigned int n_block = 0;
+ unsigned int m_block=0;
+ unsigned int n_block=0;
public:
GemvNativeTransposed(GemvNativeTransposed &) = delete;
- GemvNativeTransposed &operator=(GemvNativeTransposed &) = delete;
+ GemvNativeTransposed & operator= (GemvNativeTransposed &) = delete;
- GemvNativeTransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const Tr beta)
- : _Nsize(N), _Ksize(K), _nmultis(nmultis), _beta(beta), _ci(ci)
- {
+ GemvNativeTransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const Tr beta) : _Nsize(N), _Ksize(K), _nmultis(nmultis), _beta(beta), _ci(ci) {
/* For now don't do any blocking.*/
m_block = K;
n_block = N;
}
// Window is number of out_width blocks times number of multis.
- unsigned int get_window_size() const override
- {
+ unsigned int get_window_size() const override {
return iceildiv(_Nsize, strategy::out_width) * _nmultis;
}
// Actually execute the GEMV.
- void execute(unsigned int start, unsigned int end, int) override
- {
+ void execute(unsigned int start, unsigned int end, int) override {
#ifdef CYCLE_PROFILING
profiler prof;
#endif
-
strategy strat(_ci);
const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width);
- const unsigned int multi_0 = start / window_per_multi;
- const unsigned int multi_end = end / window_per_multi;
+ const unsigned int multi_0 = start / window_per_multi;
+ const unsigned int multi_end = end / window_per_multi;
const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width;
const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width;
@@ -97,27 +92,25 @@
static_assert(std::is_same<To, Toi>::value, "gemv_transposed: Operand types must be the same.");
static_assert(std::is_same<Tr, Tri>::value, "gemv_transposed: Result types must be the same.");
- for(unsigned int multi = multi_0; multi <= multi_end; multi++)
- {
- const unsigned int n_start = (multi == multi_0) ? n_0 : 0;
- const unsigned int n_end = (multi == multi_end) ? n_max : _Nsize;
+ for (unsigned int multi=multi_0; multi<=multi_end; multi++) {
+ const unsigned int n_start = (multi==multi_0) ? n_0 : 0;
+ const unsigned int n_end = (multi==multi_end) ? n_max : _Nsize;
- if(n_end <= n_start)
+ if (n_end <= n_start)
continue;
- for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block)
- {
+ for (unsigned int m0=0; m0<_Ksize; m0+=m_block) {
unsigned int mmax = std::min(m0 + m_block, _Ksize);
- for(unsigned int n0 = n_start; n0 < n_end; n0 += n_block)
- {
+
+ for (unsigned int n0=n_start; n0<n_end; n0+=n_block) {
unsigned int nmax = std::min(n0 + n_block, n_end);
#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_KERNEL, (mmax - m0) * (nmax - n0));
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (mmax-m0) * (nmax-n0));
#endif
strat.kernel(this->_Bptr + (multi * this->_B_multi_stride) + (m0 * this->_ldb) + n0,
this->_Aptr + (multi * this->_A_multi_stride) + m0,
this->_Cptr + (multi * this->_C_multi_stride) + n0,
- _beta, this->_ldb, (mmax - m0), (nmax - n0));
+ _beta, this->_ldb, (mmax-m0), (nmax-n0));
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index 79f1359..d745883 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -34,66 +34,64 @@
#include "profiler.hpp"
#endif
-namespace arm_gemm
-{
+namespace arm_gemm {
+
// Implementation of the GemmCommon abstract class.
//
// This is implementation is for GEMV with pretransposition.
+//
// batches are not supported as a batched GEMV makes no sense (can be converted to a GEMM).
-
-template <typename strategy, typename To, typename Tr>
-class GemvPretransposed : public GemmCommon<To, Tr>
-{
+template<typename strategy, typename To, typename Tr>
+class GemvPretransposed : public GemmCommon<To, Tr> {
typedef typename strategy::operand_type Toi;
- typedef typename strategy::result_type Tri;
+ typedef typename strategy::result_type Tri;
const unsigned int _Nsize;
const unsigned int _Ksize;
+
const unsigned int _nmultis;
const bool _trB;
const Tr _beta;
- const CPUInfo *const _ci;
- const unsigned int _buffer_per_multi;
+ const CPUInfo * const _ci;
- unsigned int m_block = 0;
- unsigned int n_block = 0;
+ const unsigned int _buffer_per_multi;
+
+ unsigned int m_block=0;
+ unsigned int n_block=0;
const Toi *_A_pretransposed = nullptr;
public:
GemvPretransposed(GemvPretransposed &) = delete;
- GemvPretransposed &operator=(GemvPretransposed &) = delete;
+ GemvPretransposed & operator= (GemvPretransposed &) = delete;
- GemvPretransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const bool trB, const Tr beta)
- : _Nsize(N), _Ksize(K), _nmultis(nmultis), _trB(trB), _beta(beta), _ci(ci), _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave) * strategy::A_interleave)
- {
+ GemvPretransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const bool trB, const Tr beta) :
+ _Nsize(N), _Ksize(K), _nmultis(nmultis), _trB(trB), _beta(beta), _ci(ci),
+ _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave) * strategy::A_interleave) {
/* For now don't do any blocking.*/
m_block = K;
n_block = N;
}
// Window is number of out_width blocks, times number of multis.
- unsigned int get_window_size() const override
- {
+ unsigned int get_window_size() const override {
return iceildiv(_Nsize, strategy::out_width) * _nmultis;
}
// Actually execute the GEMV.
- void execute(unsigned int start, unsigned int end, int) override
- {
+ void execute(unsigned int start, unsigned int end, int) override {
#ifdef CYCLE_PROFILING
profiler prof;
#endif
-
strategy strat(_ci);
/* Break the window values down into multis of interest... */
const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width);
- const unsigned int multi_0 = start / window_per_multi;
- const unsigned int multi_end = end / window_per_multi;
+ const unsigned int multi_0 = start / window_per_multi;
+ const unsigned int multi_end = end / window_per_multi;
/* ... and figure out where we start and end in the first and last multi. */
const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width;
@@ -101,66 +99,56 @@
static_assert(std::is_same<Tr, Tri>::value, "GemvPretransposed: Result types must be the same.");
- for(unsigned int multi = multi_0; multi <= multi_end; multi++)
- {
- const unsigned int n_start = (multi == multi_0) ? n_0 : 0;
- const unsigned int n_end = (multi == multi_end) ? n_max : _Nsize;
+ for (unsigned int multi=multi_0; multi<=multi_end; multi++) {
+ const unsigned int n_start = (multi==multi_0) ? n_0 : 0;
+ const unsigned int n_end = (multi==multi_end) ? n_max : _Nsize;
- if(n_end <= n_start)
+ if (n_end <= n_start)
continue;
- for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block)
- {
+ for (unsigned int m0=0; m0<_Ksize; m0+=m_block) {
unsigned int mmax = std::min(m0 + m_block, _Ksize);
- for(unsigned int n = n_start; n < n_end; n += n_block)
- {
+
+ for (unsigned int n=n_start; n<n_end; n+=n_block) {
unsigned int nmax = std::min(n + n_block, n_end);
#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_KERNEL, (mmax - m0) * (nmax - n));
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (mmax-m0) * (nmax-n));
#endif
/* This assumes that the underlying call was a GEMM with M=1; for the N=1 case we would have to pick up this->_Bptr below instead */
strat.kernel(_A_pretransposed + (multi * _buffer_per_multi) + (n * _Ksize) + (m0 * strategy::A_interleave),
(_Ksize * strategy::A_interleave),
this->_Aptr + (multi * this->_A_multi_stride) + m0,
this->_Cptr + (multi * this->_C_multi_stride) + n,
- _beta, (mmax - m0), (nmax - n));
+ _beta, (mmax-m0), (nmax-n));
}
}
}
}
/* Pretransposed interface implementation */
- bool B_is_pretransposed() const override
- {
+ bool B_is_pretransposed() const override {
return true;
}
- bool B_pretranspose_required() const override
- {
+ bool B_pretranspose_required() const override {
/* Transpose is required if _A_pretransposed is still nullptr */
return (_A_pretransposed == nullptr);
}
- size_t get_B_pretransposed_array_size() const override
- {
+ size_t get_B_pretransposed_array_size() const override {
return _buffer_per_multi * _nmultis * sizeof(To);
}
- void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override
- {
+ void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
Toi *A_buffer = reinterpret_cast<Toi *>(buffer);
- for(unsigned int multi = 0; multi < _nmultis; multi++)
- {
+ for (unsigned int multi=0; multi<_nmultis; multi++) {
/* Reverse sense here as we are dealing with B rather than A. So if
* strategy::A_transpose is false and _trB is false, we still
* transpose. */
- if(_trB ^ strategy::A_transpose)
- {
+ if (_trB ^ strategy::A_transpose) {
Transform<strategy::A_interleave, strategy::A_block, false>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
- }
- else
- {
+ } else {
Transform<strategy::A_interleave, strategy::A_block, true>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
}
}
@@ -168,8 +156,7 @@
_A_pretransposed = A_buffer;
}
- void set_pretransposed_B_data(void *buffer) override
- {
+ void set_pretransposed_B_data(void *buffer) override {
_A_pretransposed = reinterpret_cast<Toi *>(buffer);
}
};
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp
index de11dc5..06e6245 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp
@@ -25,8 +25,10 @@
#ifdef __arm__
-namespace arm_gemm
-{
+#include "../std_transforms_fixed.hpp"
+
+namespace arm_gemm {
+
// Actual kernel implementations
void a32_sgemm_8x6(const float *, const float *, float *, int, int, int);
void a32_sgemm_8x6_a53(const float *, const float *, float *, int, int, int);
@@ -40,35 +42,33 @@
// All kernels in the family must share these characteristics. The actual
// kernel to be used can be chosen at runtime, based on the CPU_type
// structure.
-class sgemm_8x6
-{
+class sgemm_8x6 {
public:
typedef float operand_type;
typedef float result_type;
typedef void (*kern_type)(const float *, const float *, float *, int, int, int);
- /* Describes the data layout for A input */
- static const int A_interleave = 6;
- static const int A_block = 1;
- static const int A_transpose = 0;
-
- /* Same for B input */
- static const int B_interleave = 8;
- static const int B_block = 1;
- static const int B_transpose = 1;
-
/* Kernel blocking parameters */
- static const int out_width = 8;
- static const int out_height = 6;
- static const int k_unroll = 1;
+ static int out_width() {
+ return 8;
+ }
+
+ static int out_height() {
+ return 6;
+ }
+
+ static int k_unroll() {
+ return 1;
+ }
+
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 6, 8> transforms = {};
kern_type kernel = a32_sgemm_8x6;
- sgemm_8x6(const CPUInfo *ci)
- {
- switch(ci->get_cpu_model())
- {
+ sgemm_8x6(const CPUInfo *ci) {
+ switch(ci->get_cpu_model()) {
case CPUModel::A53:
kernel = a32_sgemm_8x6_a53;
break;
@@ -78,7 +78,7 @@
break;
default:
- kernel = a32_sgemm_8x6;
+ /* Generic kernel is selected by default. */
break;
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp
index 428498f..faabf66 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp
@@ -37,360 +37,371 @@
// Note that the intent of this is that either ablocks or bblocks will be 1
// - this construction allows the output loop to proceed in either order.
-namespace arm_gemm
-{
-void a32_sgemm_8x6_a53(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K)
-{
+namespace arm_gemm {
+
+void a32_sgemm_8x6_a53(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) {
const float *a_ptr = Apanel;
- float *c_ptr = Cpanel;
+ float *c_ptr = Cpanel;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ for (int yb=0; yb<ablocks; yb++) {
const float *a_ptr0 = a_ptr;
- const float *b_ptr = Bpanel;
+ const float *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
- a_ptr = a_ptr0;
+ for (int xb=0; xb<bblocks; xb++) {
+ a_ptr = a_ptr0;
int tails = (K & 3);
- if(tails == 0)
- {
+ if (tails == 0) {
tails = 4;
}
- int k = ((K + 3) / 4) - 1;
+ int k = ((K+3)/4) - 1;
- __asm __volatile(
- "vmov.i32 q4, #0\n"
- "vld1.32 {d0-d1}, [%[a_ptr] :64]\n"
- "vmov.i32 q5, #0\n"
- "vld1.32 {d4-d5}, [%[b_ptr] :128]\n"
- "vmov.i32 q6, #0\n"
- "ldr r0, [%[a_ptr], #0x10]\n"
- "vmov.i32 q7, #0\n"
- "ldr r1, [%[a_ptr], #0x14]\n"
- "vmov.i32 q8, #0\n" ASM_PREFETCH("[%[a_ptr], #0x40]") "vmov.i32 q9, #0\n" ASM_PREFETCH("[%[b_ptr], #0x40]") "vmov.i32 q10, #0\n" ASM_PREFETCH("[%[a_ptr], #0x80]") "vmov.i32 q11, #0\n"
+ __asm __volatile (
+ "vmov.i32 q4, #0\n"
+ "vld1.32 {d0-d1}, [%[a_ptr] :64]\n"
+ "vmov.i32 q5, #0\n"
+ "vld1.32 {d4-d5}, [%[b_ptr] :128]\n"
+ "vmov.i32 q6, #0\n"
+ "ldr r0, [%[a_ptr], #0x10]\n"
+ "vmov.i32 q7, #0\n"
+ "ldr r1, [%[a_ptr], #0x14]\n"
+ "vmov.i32 q8, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #0x40]")
+ "vmov.i32 q9, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #0x40]")
+ "vmov.i32 q10, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #0x80]")
+ "vmov.i32 q11, #0\n"
ASM_PREFETCH("[%[b_ptr], #0x80]")
- "vmov.i32 q12, #0\n"
- "vmov.i32 q13, #0\n" ASM_PREFETCH("[%[a_ptr], #0xC0]") "vmov.i32 q14, #0\n" ASM_PREFETCH("[%[b_ptr], #0XC0]")
- "vmov.i32 q15, #0\n"
- "cmp %[k], #0\n"
- "beq 6f\n"
+ "vmov.i32 q12, #0\n"
+ "vmov.i32 q13, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #0xC0]")
+ "vmov.i32 q14, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #0XC0]")
+ "vmov.i32 q15, #0\n"
+ "cmp %[k], #0\n"
+ "beq 6f\n"
"1:\n"
// Unroll 0
- "vldr d6, [%[b_ptr], #0x10]\n"
- "vmov d2, r0, r1\n"
- "vmla.f32 q4, q2, d0[0]\n"
- "ldr r0, [%[b_ptr], #0x18]\n"
- "vmla.f32 q5, q2, d0[1]\n"
- "ldr r1, [%[b_ptr], #0x1C]\n"
- "vmla.f32 q6, q2, d1[0]\n"
+ "vldr d6, [%[b_ptr], #0x10]\n"
+ "vmov d2, r0, r1\n"
+ "vmla.f32 q4, q2, d0[0]\n"
+ "ldr r0, [%[b_ptr], #0x18]\n"
+ "vmla.f32 q5, q2, d0[1]\n"
+ "ldr r1, [%[b_ptr], #0x1C]\n"
+ "vmla.f32 q6, q2, d1[0]\n"
- "vldr d3, [%[a_ptr], #0x18]\n"
- "vmov d7, r0, r1\n"
- "vmla.f32 q7, q2, d1[1]\n" ASM_PREFETCH("[%[a_ptr], #0x100]")
- "vmla.f32 q8, q2, d2[0]\n"
- "vmla.f32 q9, q2, d2[1]\n"
+ "vldr d3, [%[a_ptr], #0x18]\n"
+ "vmov d7, r0, r1\n"
+ "vmla.f32 q7, q2, d1[1]\n"
+ ASM_PREFETCH("[%[a_ptr], #0x100]")
+ "vmla.f32 q8, q2, d2[0]\n"
+ "vmla.f32 q9, q2, d2[1]\n"
- "vldr d4, [%[b_ptr], #0x20]\n"
- "vmla.f32 q10, q3, d0[0]\n"
- "ldr r0, [%[b_ptr], #0x28]\n"
- "vmla.f32 q11, q3, d0[1]\n"
- "ldr r1, [%[b_ptr], #0x2C]\n"
- "vmla.f32 q12, q3, d1[0]\n"
+ "vldr d4, [%[b_ptr], #0x20]\n"
+ "vmla.f32 q10, q3, d0[0]\n"
+ "ldr r0, [%[b_ptr], #0x28]\n"
+ "vmla.f32 q11, q3, d0[1]\n"
+ "ldr r1, [%[b_ptr], #0x2C]\n"
+ "vmla.f32 q12, q3, d1[0]\n"
- "vldr d0, [%[a_ptr], #0x20]\n"
- "vmov d5, r0, r1\n"
- "vmla.f32 q13, q3, d1[1]\n"
- "ldr r0, [%[a_ptr], #0x28]\n"
- "vmla.f32 q14, q3, d2[0]\n"
- "ldr r1, [%[a_ptr], #0x2C]\n"
- "vmla.f32 q15, q3, d2[1]\n"
+ "vldr d0, [%[a_ptr], #0x20]\n"
+ "vmov d5, r0, r1\n"
+ "vmla.f32 q13, q3, d1[1]\n"
+ "ldr r0, [%[a_ptr], #0x28]\n"
+ "vmla.f32 q14, q3, d2[0]\n"
+ "ldr r1, [%[a_ptr], #0x2C]\n"
+ "vmla.f32 q15, q3, d2[1]\n"
// Unroll 1
- "vldr d6, [%[b_ptr], #0x30]\n"
- "vmov d1, r0, r1\n"
- "vmla.f32 q4, q2, d3[0]\n"
- "ldr r0, [%[b_ptr], #0x38]\n"
- "vmla.f32 q5, q2, d3[1]\n"
- "ldr r1, [%[b_ptr], #0x3C]\n"
- "vmla.f32 q6, q2, d0[0]\n"
+ "vldr d6, [%[b_ptr], #0x30]\n"
+ "vmov d1, r0, r1\n"
+ "vmla.f32 q4, q2, d3[0]\n"
+ "ldr r0, [%[b_ptr], #0x38]\n"
+ "vmla.f32 q5, q2, d3[1]\n"
+ "ldr r1, [%[b_ptr], #0x3C]\n"
+ "vmla.f32 q6, q2, d0[0]\n"
- "vldr d2, [%[a_ptr], #0x30]\n"
- "vmov d7, r0, r1\n"
- "vmla.f32 q7, q2, d0[1]\n" ASM_PREFETCH("[%[b_ptr], #0x100]")
- "vmla.f32 q8, q2, d1[0]\n"
- "vmla.f32 q9, q2, d1[1]\n"
+ "vldr d2, [%[a_ptr], #0x30]\n"
+ "vmov d7, r0, r1\n"
+ "vmla.f32 q7, q2, d0[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #0x100]")
+ "vmla.f32 q8, q2, d1[0]\n"
+ "vmla.f32 q9, q2, d1[1]\n"
- "vldr d4, [%[b_ptr], #0x40]\n"
- "vmla.f32 q10, q3, d3[0]\n"
- "ldr r0, [%[b_ptr], #0x48]\n"
- "vmla.f32 q11, q3, d3[1]\n"
- "ldr r1, [%[b_ptr], #0x4C]\n"
- "vmla.f32 q12, q3, d0[0]\n"
+ "vldr d4, [%[b_ptr], #0x40]\n"
+ "vmla.f32 q10, q3, d3[0]\n"
+ "ldr r0, [%[b_ptr], #0x48]\n"
+ "vmla.f32 q11, q3, d3[1]\n"
+ "ldr r1, [%[b_ptr], #0x4C]\n"
+ "vmla.f32 q12, q3, d0[0]\n"
- "vldr d3, [%[a_ptr], #0x38]\n"
- "vmov d5, r0, r1\n"
- "vmla.f32 q13, q3, d0[1]\n"
- "ldr r0, [%[a_ptr], #0x40]\n"
- "vmla.f32 q14, q3, d1[0]\n"
- "ldr r1, [%[a_ptr], #0x44]\n"
- "vmla.f32 q15, q3, d1[1]\n"
+ "vldr d3, [%[a_ptr], #0x38]\n"
+ "vmov d5, r0, r1\n"
+ "vmla.f32 q13, q3, d0[1]\n"
+ "ldr r0, [%[a_ptr], #0x40]\n"
+ "vmla.f32 q14, q3, d1[0]\n"
+ "ldr r1, [%[a_ptr], #0x44]\n"
+ "vmla.f32 q15, q3, d1[1]\n"
// Unroll 2
- "vldr d6, [%[b_ptr], #0x50]\n"
- "vmov d0, r0, r1\n"
- "vmla.f32 q4, q2, d2[0]\n"
- "ldr r0, [%[b_ptr], #0x58]\n"
- "vmla.f32 q5, q2, d2[1]\n"
- "ldr r1, [%[b_ptr], #0x5C]\n"
- "vmla.f32 q6, q2, d3[0]\n"
+ "vldr d6, [%[b_ptr], #0x50]\n"
+ "vmov d0, r0, r1\n"
+ "vmla.f32 q4, q2, d2[0]\n"
+ "ldr r0, [%[b_ptr], #0x58]\n"
+ "vmla.f32 q5, q2, d2[1]\n"
+ "ldr r1, [%[b_ptr], #0x5C]\n"
+ "vmla.f32 q6, q2, d3[0]\n"
- "vldr d1, [%[a_ptr], #0x48]\n"
- "vmov d7, r0, r1\n"
- "vmla.f32 q7, q2, d3[1]\n" ASM_PREFETCH("[%[a_ptr], #0x140]")
- "vmla.f32 q8, q2, d0[0]\n"
- "vmla.f32 q9, q2, d0[1]\n"
+ "vldr d1, [%[a_ptr], #0x48]\n"
+ "vmov d7, r0, r1\n"
+ "vmla.f32 q7, q2, d3[1]\n"
+ ASM_PREFETCH("[%[a_ptr], #0x140]")
+ "vmla.f32 q8, q2, d0[0]\n"
+ "vmla.f32 q9, q2, d0[1]\n"
- "vldr d4, [%[b_ptr], #0x60]\n"
- "vmla.f32 q10, q3, d2[0]\n"
- "ldr r0, [%[b_ptr], #0x68]\n"
- "vmla.f32 q11, q3, d2[1]\n"
- "ldr r1, [%[b_ptr], #0x6C]\n"
- "vmla.f32 q12, q3, d3[0]\n"
+ "vldr d4, [%[b_ptr], #0x60]\n"
+ "vmla.f32 q10, q3, d2[0]\n"
+ "ldr r0, [%[b_ptr], #0x68]\n"
+ "vmla.f32 q11, q3, d2[1]\n"
+ "ldr r1, [%[b_ptr], #0x6C]\n"
+ "vmla.f32 q12, q3, d3[0]\n"
- "vldr d2, [%[a_ptr], #0x50]\n"
- "vmov d5, r0, r1\n"
- "vmla.f32 q13, q3, d3[1]\n"
- "ldr r0, [%[a_ptr], #0x58]\n"
- "vmla.f32 q14, q3, d0[0]\n"
- "ldr r1, [%[a_ptr], #0x5C]\n"
- "vmla.f32 q15, q3, d0[1]\n"
- "add %[a_ptr], %[a_ptr], #0x60\n"
+ "vldr d2, [%[a_ptr], #0x50]\n"
+ "vmov d5, r0, r1\n"
+ "vmla.f32 q13, q3, d3[1]\n"
+ "ldr r0, [%[a_ptr], #0x58]\n"
+ "vmla.f32 q14, q3, d0[0]\n"
+ "ldr r1, [%[a_ptr], #0x5C]\n"
+ "vmla.f32 q15, q3, d0[1]\n"
+ "add %[a_ptr], %[a_ptr], #0x60\n"
// Unroll 3
- "vldr d6, [%[b_ptr], #0x70]\n"
- "vmov d3, r0, r1\n"
- "vmla.f32 q4, q2, d1[0]\n"
- "ldr r0, [%[b_ptr], #0x78]\n"
- "vmla.f32 q5, q2, d1[1]\n"
- "ldr r1, [%[b_ptr], #0x7C]\n"
- "vmla.f32 q6, q2, d2[0]\n"
- "add %[b_ptr], %[b_ptr], #0x80\n"
+ "vldr d6, [%[b_ptr], #0x70]\n"
+ "vmov d3, r0, r1\n"
+ "vmla.f32 q4, q2, d1[0]\n"
+ "ldr r0, [%[b_ptr], #0x78]\n"
+ "vmla.f32 q5, q2, d1[1]\n"
+ "ldr r1, [%[b_ptr], #0x7C]\n"
+ "vmla.f32 q6, q2, d2[0]\n"
+ "add %[b_ptr], %[b_ptr], #0x80\n"
- "vldr d0, [%[a_ptr], #0x00]\n"
- "vmov d7, r0, r1\n"
- "vmla.f32 q7, q2, d2[1]\n" ASM_PREFETCH("[%[b_ptr], #0xC0]")
- "vmla.f32 q8, q2, d3[0]\n"
- "vmla.f32 q9, q2, d3[1]\n"
+ "vldr d0, [%[a_ptr], #0x00]\n"
+ "vmov d7, r0, r1\n"
+ "vmla.f32 q7, q2, d2[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #0xC0]")
+ "vmla.f32 q8, q2, d3[0]\n"
+ "vmla.f32 q9, q2, d3[1]\n"
- "vldr d4, [%[b_ptr], #0x00]\n"
- "vmla.f32 q10, q3, d1[0]\n"
- "ldr r0, [%[b_ptr], #0x08]\n"
- "vmla.f32 q11, q3, d1[1]\n"
- "ldr r1, [%[b_ptr], #0x0C]\n"
- "vmla.f32 q12, q3, d2[0]\n"
- "subs %[k], %[k], #1\n"
+ "vldr d4, [%[b_ptr], #0x00]\n"
+ "vmla.f32 q10, q3, d1[0]\n"
+ "ldr r0, [%[b_ptr], #0x08]\n"
+ "vmla.f32 q11, q3, d1[1]\n"
+ "ldr r1, [%[b_ptr], #0x0C]\n"
+ "vmla.f32 q12, q3, d2[0]\n"
+ "subs %[k], %[k], #1\n"
- "vldr d1, [%[a_ptr], #0x08]\n"
- "vmov d5, r0, r1\n"
- "vmla.f32 q13, q3, d2[1]\n"
- "ldr r0, [%[a_ptr], #0x10]\n"
- "vmla.f32 q14, q3, d3[0]\n"
- "ldr r1, [%[a_ptr], #0x14]\n"
- "vmla.f32 q15, q3, d3[1]\n"
- "bne 1b\n"
+ "vldr d1, [%[a_ptr], #0x08]\n"
+ "vmov d5, r0, r1\n"
+ "vmla.f32 q13, q3, d2[1]\n"
+ "ldr r0, [%[a_ptr], #0x10]\n"
+ "vmla.f32 q14, q3, d3[0]\n"
+ "ldr r1, [%[a_ptr], #0x14]\n"
+ "vmla.f32 q15, q3, d3[1]\n"
+ "bne 1b\n"
// "Tails" shows how many multiply blocks are needed at the
// end, must be 1-4 inclusive. Bail out to alternative tail
// immediately if it's 1.
"6:\n"
- "subs %[tails], %[tails], #1\n"
- "beq 3f\n"
+ "subs %[tails], %[tails], #1\n"
+ "beq 3f\n"
// Detached final iteration - for now adapt the generic
// tails rather than reimplementing for A53.
// Unroll 0
- "vmov d2, r0, r1\n"
- "add %[a_ptr], %[a_ptr], #0x18\n"
- "vmla.f32 q4, q2, d0[0]\n"
- "vld1.32 {d3}, [%[a_ptr] :64]!\n"
- "vmla.f32 q5, q2, d0[1]\n"
- "add %[b_ptr], %[b_ptr], #0x10\n"
- "vmla.f32 q6, q2, d1[0]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
- "vmla.f32 q7, q2, d1[1]\n"
- "vmla.f32 q8, q2, d2[0]\n"
- "subs %[tails], %[tails], #1\n"
- "vmla.f32 q9, q2, d2[1]\n"
- "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
+ "vmov d2, r0, r1\n"
+ "add %[a_ptr], %[a_ptr], #0x18\n"
+ "vmla.f32 q4, q2, d0[0]\n"
+ "vld1.32 {d3}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q5, q2, d0[1]\n"
+ "add %[b_ptr], %[b_ptr], #0x10\n"
+ "vmla.f32 q6, q2, d1[0]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d1[1]\n"
+ "vmla.f32 q8, q2, d2[0]\n"
+ "subs %[tails], %[tails], #1\n"
+ "vmla.f32 q9, q2, d2[1]\n"
+ "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
- "vmla.f32 q10, q3, d0[0]\n"
- "vmla.f32 q11, q3, d0[1]\n"
- "vmla.f32 q12, q3, d1[0]\n"
- "vmla.f32 q13, q3, d1[1]\n"
- "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
- "vmla.f32 q14, q3, d2[0]\n"
- "vmla.f32 q15, q3, d2[1]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
- "beq 4f\n"
+ "vmla.f32 q10, q3, d0[0]\n"
+ "vmla.f32 q11, q3, d0[1]\n"
+ "vmla.f32 q12, q3, d1[0]\n"
+ "vmla.f32 q13, q3, d1[1]\n"
+ "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q14, q3, d2[0]\n"
+ "vmla.f32 q15, q3, d2[1]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "beq 4f\n"
// Unroll 1
- "vmla.f32 q4, q2, d3[0]\n"
- "vmla.f32 q5, q2, d3[1]\n"
- "subs %[tails], %[tails], #1\n"
- "vmla.f32 q6, q2, d0[0]\n"
- "vmla.f32 q7, q2, d0[1]\n"
- "vmla.f32 q8, q2, d1[0]\n"
- "vmla.f32 q9, q2, d1[1]\n"
- "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q4, q2, d3[0]\n"
+ "vmla.f32 q5, q2, d3[1]\n"
+ "subs %[tails], %[tails], #1\n"
+ "vmla.f32 q6, q2, d0[0]\n"
+ "vmla.f32 q7, q2, d0[1]\n"
+ "vmla.f32 q8, q2, d1[0]\n"
+ "vmla.f32 q9, q2, d1[1]\n"
+ "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
- "vmla.f32 q10, q3, d3[0]\n"
- "vmla.f32 q11, q3, d3[1]\n"
- "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
- "vmla.f32 q12, q3, d0[0]\n"
- "vmla.f32 q13, q3, d0[1]\n"
- "vmla.f32 q14, q3, d1[0]\n"
- "vmla.f32 q15, q3, d1[1]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
- "beq 5f\n"
+ "vmla.f32 q10, q3, d3[0]\n"
+ "vmla.f32 q11, q3, d3[1]\n"
+ "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q12, q3, d0[0]\n"
+ "vmla.f32 q13, q3, d0[1]\n"
+ "vmla.f32 q14, q3, d1[0]\n"
+ "vmla.f32 q15, q3, d1[1]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "beq 5f\n"
// Unroll 2
- "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
- "vmla.f32 q4, q2, d2[0]\n"
- "vmla.f32 q5, q2, d2[1]\n"
- "vmla.f32 q6, q2, d3[0]\n"
- "vmla.f32 q7, q2, d3[1]\n"
- "vmla.f32 q8, q2, d0[0]\n"
- "vmla.f32 q9, q2, d0[1]\n"
- "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
+ "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q4, q2, d2[0]\n"
+ "vmla.f32 q5, q2, d2[1]\n"
+ "vmla.f32 q6, q2, d3[0]\n"
+ "vmla.f32 q7, q2, d3[1]\n"
+ "vmla.f32 q8, q2, d0[0]\n"
+ "vmla.f32 q9, q2, d0[1]\n"
+ "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
- "vmla.f32 q10, q3, d2[0]\n"
- "vmla.f32 q11, q3, d2[1]\n"
- "vmla.f32 q12, q3, d3[0]\n"
- "vmla.f32 q13, q3, d3[1]\n"
- "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
- "vmla.f32 q14, q3, d0[0]\n"
- "vmla.f32 q15, q3, d0[1]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q10, q3, d2[0]\n"
+ "vmla.f32 q11, q3, d2[1]\n"
+ "vmla.f32 q12, q3, d3[0]\n"
+ "vmla.f32 q13, q3, d3[1]\n"
+ "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q14, q3, d0[0]\n"
+ "vmla.f32 q15, q3, d0[1]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
// Unroll 3
- "vmla.f32 q4, q2, d1[0]\n"
- "vmla.f32 q10, q3, d1[0]\n"
- "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
- "vmla.f32 q5, q2, d1[1]\n"
- "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
- "vmla.f32 q11, q3, d1[1]\n"
- "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
- "vmla.f32 q6, q2, d2[0]\n"
- "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
- "vmla.f32 q12, q3, d2[0]\n"
- "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
- "vmla.f32 q7, q2, d2[1]\n"
- "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
- "vmla.f32 q13, q3, d2[1]\n"
- "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
- "vmla.f32 q8, q2, d3[0]\n"
- "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
- "vmla.f32 q14, q3, d3[0]\n"
- "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
- "vmla.f32 q9, q2, d3[1]\n"
- "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
- "vmla.f32 q15, q3, d3[1]\n"
- "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
- "b 2f\n"
+ "vmla.f32 q4, q2, d1[0]\n"
+ "vmla.f32 q10, q3, d1[0]\n"
+ "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q5, q2, d1[1]\n"
+ "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q11, q3, d1[1]\n"
+ "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q6, q2, d2[0]\n"
+ "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q12, q3, d2[0]\n"
+ "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d2[1]\n"
+ "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q13, q3, d2[1]\n"
+ "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q8, q2, d3[0]\n"
+ "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q14, q3, d3[0]\n"
+ "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q9, q2, d3[1]\n"
+ "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q15, q3, d3[1]\n"
+ "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
+ "b 2f\n"
// tails==1 final tail
"3:\n"
- "vmov d2, r0, r1\n"
- "add %[b_ptr], %[b_ptr], #0x10\n"
- "vmla.f32 q4, q2, d0[0]\n"
- "add %[a_ptr], %[a_ptr], #0x18\n"
- "vmla.f32 q5, q2, d0[1]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
- "vmla.f32 q6, q2, d1[0]\n"
- "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
- "vmla.f32 q10, q3, d0[0]\n"
- "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
- "vmla.f32 q11, q3, d0[1]\n"
- "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
- "vmla.f32 q12, q3, d1[0]\n"
- "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
- "vmla.f32 q7, q2, d1[1]\n"
- "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
- "vmla.f32 q13, q3, d1[1]\n"
- "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
- "vmla.f32 q8, q2, d2[0]\n"
- "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
- "vmla.f32 q14, q3, d2[0]\n"
- "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
- "vmla.f32 q9, q2, d2[1]\n"
- "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
- "vmla.f32 q15, q3, d2[1]\n"
- "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
- "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
- "b 2f\n"
+ "vmov d2, r0, r1\n"
+ "add %[b_ptr], %[b_ptr], #0x10\n"
+ "vmla.f32 q4, q2, d0[0]\n"
+ "add %[a_ptr], %[a_ptr], #0x18\n"
+ "vmla.f32 q5, q2, d0[1]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q6, q2, d1[0]\n"
+ "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q10, q3, d0[0]\n"
+ "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q11, q3, d0[1]\n"
+ "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q12, q3, d1[0]\n"
+ "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d1[1]\n"
+ "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q13, q3, d1[1]\n"
+ "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q8, q2, d2[0]\n"
+ "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q14, q3, d2[0]\n"
+ "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q9, q2, d2[1]\n"
+ "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q15, q3, d2[1]\n"
+ "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
+ "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
+ "b 2f\n"
// tails==2 final tail
"4:\n"
- "vmla.f32 q4, q2, d3[0]\n"
- "vmla.f32 q10, q3, d3[0]\n"
- "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
- "vmla.f32 q5, q2, d3[1]\n"
- "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
- "vmla.f32 q11, q3, d3[1]\n"
- "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
- "vmla.f32 q6, q2, d0[0]\n"
- "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
- "vmla.f32 q12, q3, d0[0]\n"
- "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
- "vmla.f32 q7, q2, d0[1]\n"
- "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
- "vmla.f32 q13, q3, d0[1]\n"
- "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
- "vmla.f32 q8, q2, d1[0]\n"
- "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
- "vmla.f32 q14, q3, d1[0]\n"
- "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
- "vmla.f32 q9, q2, d1[1]\n"
- "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
- "vmla.f32 q15, q3, d1[1]\n"
- "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
- "b 2f\n"
+ "vmla.f32 q4, q2, d3[0]\n"
+ "vmla.f32 q10, q3, d3[0]\n"
+ "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q5, q2, d3[1]\n"
+ "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q11, q3, d3[1]\n"
+ "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q6, q2, d0[0]\n"
+ "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q12, q3, d0[0]\n"
+ "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d0[1]\n"
+ "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q13, q3, d0[1]\n"
+ "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q8, q2, d1[0]\n"
+ "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q14, q3, d1[0]\n"
+ "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q9, q2, d1[1]\n"
+ "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q15, q3, d1[1]\n"
+ "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
+ "b 2f\n"
// tails==3 final tail
"5:\n"
- "vmla.f32 q4, q2, d2[0]\n"
- "vld1.32 {d0}, [%[a_ptr] :64]!\n"
- "vmla.f32 q5, q2, d2[1]\n"
- "vmla.f32 q6, q2, d3[0]\n"
- "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
- "vmla.f32 q10, q3, d2[0]\n"
- "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
- "vmla.f32 q11, q3, d2[1]\n"
- "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
- "vmla.f32 q12, q3, d3[0]\n"
- "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
- "vmla.f32 q7, q2, d3[1]\n"
- "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
- "vmla.f32 q13, q3, d3[1]\n"
- "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
- "vmla.f32 q8, q2, d0[0]\n"
- "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
- "vmla.f32 q14, q3, d0[0]\n"
- "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
- "vmla.f32 q9, q2, d0[1]\n"
- "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
- "vmla.f32 q15, q3, d0[1]\n"
- "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
- "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q4, q2, d2[0]\n"
+ "vld1.32 {d0}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q5, q2, d2[1]\n"
+ "vmla.f32 q6, q2, d3[0]\n"
+ "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q10, q3, d2[0]\n"
+ "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q11, q3, d2[1]\n"
+ "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q12, q3, d3[0]\n"
+ "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d3[1]\n"
+ "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q13, q3, d3[1]\n"
+ "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q8, q2, d0[0]\n"
+ "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q14, q3, d0[0]\n"
+ "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q9, q2, d0[1]\n"
+ "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q15, q3, d0[1]\n"
+ "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
+ "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
"2:\n"
- "vst1.32 {d30-d31}, [%[c_ptr] :128]!\n"
- : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), [k] "+r"(k), [tails] "+r"(tails)
- :
- : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0", "r1");
+ "vst1.32 {d30-d31}, [%[c_ptr] :128]!\n"
+ : [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr), [k] "+r" (k), [tails] "+r" (tails)
+ :
+ : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15",
+ "r0", "r1", "cc", "memory"
+ );
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a55r1.cpp
index 4cfb72a..76f51cc 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a55r1.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a55r1.cpp
@@ -37,358 +37,376 @@
// Note that the intent of this is that either ablocks or bblocks will be 1
// - this construction allows the output loop to proceed in either order.
-namespace arm_gemm
-{
-void a32_sgemm_8x6_a55r1(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K)
-{
+namespace arm_gemm {
+
+void a32_sgemm_8x6_a55r1(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) {
const float *a_ptr = Apanel;
- float *c_ptr = Cpanel;
+ float *c_ptr = Cpanel;
/* Work out starting values for "k" and "tails" in the inner loop. */
int tails_initial = (K & 3);
- if(tails_initial == 0)
- {
+ if (tails_initial == 0) {
tails_initial = 4;
}
- int k_initial = ((K + 3) / 4) - 1;
+ int k_initial = ((K+3)/4) - 1;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ for (int yb=0; yb<ablocks; yb++) {
const float *a_ptr0 = a_ptr;
- const float *b_ptr = Bpanel;
+ const float *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
+ for (int xb=0; xb<bblocks; xb++) {
int tails = tails_initial;
- int k = k_initial;
+ int k = k_initial;
a_ptr = a_ptr0;
- __asm __volatile(
- "vldr d0, [%[a_ptr]]\n"
- "vmov.i32 q4, #0\n"
- "vldr d1, [%[a_ptr], #0x08]\n"
- "vmov.i32 q5, #0\n"
- "vldr d4, [%[b_ptr]]\n"
- "vmov.i32 q6, #0\n"
- "vldr d5, [%[b_ptr], #0x08]\n"
- "vmov.i32 q7, #0\n"
- "vldr d2, [%[a_ptr], #0x10]\n"
- "vmov.i32 q8, #0\n" ASM_PREFETCH("[%[b_ptr], #0x40]") "vmov.i32 q9, #0\n" ASM_PREFETCH("[%[a_ptr], #0x40]") "vmov.i32 q10, #0\n" ASM_PREFETCH("[%[b_ptr], #0x80]") "vmov.i32 q11, #0\n"
- ASM_PREFETCH("[%[a_ptr], #0x80]") "vmov.i32 q12, #0\n" ASM_PREFETCH("[%[b_ptr], #0XC0]") "vmov.i32 q13, #0\n" ASM_PREFETCH("[%[a_ptr], #0xC0]") "vmov.i32 q14, #0\n"
- ASM_PREFETCH("[%[b_ptr], #0x100]") "vmov.i32 q15, #0\n" ASM_PREFETCH("[%[a_ptr], #0x100]") "cmp %[k], #0\n" ASM_PREFETCH("[%[b_ptr], #0x140]") "beq 6f\n"
+ __asm __volatile (
+ "vldr d0, [%[a_ptr]]\n"
+ "vmov.i32 q4, #0\n"
+ "vldr d1, [%[a_ptr], #0x08]\n"
+ "vmov.i32 q5, #0\n"
+ "vldr d4, [%[b_ptr]]\n"
+ "vmov.i32 q6, #0\n"
+ "vldr d5, [%[b_ptr], #0x08]\n"
+ "vmov.i32 q7, #0\n"
+ "vldr d2, [%[a_ptr], #0x10]\n"
+ "vmov.i32 q8, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #0x40]")
+ "vmov.i32 q9, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #0x40]")
+ "vmov.i32 q10, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #0x80]")
+ "vmov.i32 q11, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #0x80]")
+ "vmov.i32 q12, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #0XC0]")
+ "vmov.i32 q13, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #0xC0]")
+ "vmov.i32 q14, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #0x100]")
+ "vmov.i32 q15, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #0x100]")
+ "cmp %[k], #0\n"
+ ASM_PREFETCH("[%[b_ptr], #0x140]")
+ "beq 6f\n"
ASM_PREFETCH("[%[b_ptr], #0x180]")
"1:\n"
// Unroll 0
- "vmla.f32 q4, q2, d0[0]\n"
- "vldr d6, [%[b_ptr], #0x10]\n"
- "vmla.f32 q5, q2, d0[1]\n"
- "vldr d7, [%[b_ptr], #0x18]\n"
- "vmla.f32 q6, q2, d1[0]\n"
- "vldr d3, [%[a_ptr], #0x18]\n"
- "vmla.f32 q7, q2, d1[1]\n" ASM_PREFETCH("[%[a_ptr], #0x140]")
- "vmla.f32 q8, q2, d2[0]\n"
- "subs %[k], %[k], #1\n"
- "vmla.f32 q9, q2, d2[1]\n"
- "vldr d4, [%[b_ptr], #0x20]\n"
- "vmla.f32 q10, q3, d0[0]\n"
- "vldr d5, [%[b_ptr], #0x28]\n"
- "vmla.f32 q11, q3, d0[1]\n"
- "vldr d0, [%[a_ptr], #0x20]\n"
- "vmla.f32 q12, q3, d1[0]\n"
+ "vmla.f32 q4, q2, d0[0]\n"
+ "vldr d6, [%[b_ptr], #0x10]\n"
+ "vmla.f32 q5, q2, d0[1]\n"
+ "vldr d7, [%[b_ptr], #0x18]\n"
+ "vmla.f32 q6, q2, d1[0]\n"
+ "vldr d3, [%[a_ptr], #0x18]\n"
+ "vmla.f32 q7, q2, d1[1]\n"
+ ASM_PREFETCH("[%[a_ptr], #0x140]")
+ "vmla.f32 q8, q2, d2[0]\n"
+ "subs %[k], %[k], #1\n"
+ "vmla.f32 q9, q2, d2[1]\n"
+ "vldr d4, [%[b_ptr], #0x20]\n"
+ "vmla.f32 q10, q3, d0[0]\n"
+ "vldr d5, [%[b_ptr], #0x28]\n"
+ "vmla.f32 q11, q3, d0[1]\n"
+ "vldr d0, [%[a_ptr], #0x20]\n"
+ "vmla.f32 q12, q3, d1[0]\n"
- "vmla.f32 q13, q3, d1[1]\n"
- "vldr d1, [%[a_ptr], #0x28]\n"
- "vmla.f32 q14, q3, d2[0]\n"
+ "vmla.f32 q13, q3, d1[1]\n"
+ "vldr d1, [%[a_ptr], #0x28]\n"
+ "vmla.f32 q14, q3, d2[0]\n"
- "vmla.f32 q15, q3, d2[1]\n"
- "vldr d6, [%[b_ptr], #0x30]\n"
+ "vmla.f32 q15, q3, d2[1]\n"
+ "vldr d6, [%[b_ptr], #0x30]\n"
// Unroll 1
- "vmla.f32 q4, q2, d3[0]\n"
- "vldr d7, [%[b_ptr], #0x38]\n"
- "vmla.f32 q5, q2, d3[1]\n"
- "vldr d2, [%[a_ptr], #0x30]\n"
- "vmla.f32 q6, q2, d0[0]\n"
+ "vmla.f32 q4, q2, d3[0]\n"
+ "vldr d7, [%[b_ptr], #0x38]\n"
+ "vmla.f32 q5, q2, d3[1]\n"
+ "vldr d2, [%[a_ptr], #0x30]\n"
+ "vmla.f32 q6, q2, d0[0]\n"
- "vmla.f32 q7, q2, d0[1]\n" ASM_PREFETCH("[%[b_ptr], #0x1C0]")
- "vmla.f32 q8, q2, d1[0]\n"
+ "vmla.f32 q7, q2, d0[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #0x1C0]")
+ "vmla.f32 q8, q2, d1[0]\n"
- "vmla.f32 q9, q2, d1[1]\n"
- "vldr d4, [%[b_ptr], #0x40]\n"
- "vmla.f32 q10, q3, d3[0]\n"
- "vldr d5, [%[b_ptr], #0x48]\n"
- "vmla.f32 q11, q3, d3[1]\n"
- "vldr d3, [%[a_ptr], #0x38]\n"
- "vmla.f32 q12, q3, d0[0]\n"
+ "vmla.f32 q9, q2, d1[1]\n"
+ "vldr d4, [%[b_ptr], #0x40]\n"
+ "vmla.f32 q10, q3, d3[0]\n"
+ "vldr d5, [%[b_ptr], #0x48]\n"
+ "vmla.f32 q11, q3, d3[1]\n"
+ "vldr d3, [%[a_ptr], #0x38]\n"
+ "vmla.f32 q12, q3, d0[0]\n"
- "vmla.f32 q13, q3, d0[1]\n"
- "vldr d0, [%[a_ptr], #0x40]\n"
- "vmla.f32 q14, q3, d1[0]\n"
+ "vmla.f32 q13, q3, d0[1]\n"
+ "vldr d0, [%[a_ptr], #0x40]\n"
+ "vmla.f32 q14, q3, d1[0]\n"
- "vmla.f32 q15, q3, d1[1]\n"
- "vldr d6, [%[b_ptr], #0x50]\n"
+ "vmla.f32 q15, q3, d1[1]\n"
+ "vldr d6, [%[b_ptr], #0x50]\n"
// Unroll 2
- "vmla.f32 q4, q2, d2[0]\n"
- "vldr d7, [%[b_ptr], #0x58]\n"
- "vmla.f32 q5, q2, d2[1]\n"
- "vldr d1, [%[a_ptr], #0x48]\n"
- "vmla.f32 q6, q2, d3[0]\n"
+ "vmla.f32 q4, q2, d2[0]\n"
+ "vldr d7, [%[b_ptr], #0x58]\n"
+ "vmla.f32 q5, q2, d2[1]\n"
+ "vldr d1, [%[a_ptr], #0x48]\n"
+ "vmla.f32 q6, q2, d3[0]\n"
- "vmla.f32 q7, q2, d3[1]\n" ASM_PREFETCH("[%[a_ptr], #0x180]")
- "vmla.f32 q8, q2, d0[0]\n"
+ "vmla.f32 q7, q2, d3[1]\n"
+ ASM_PREFETCH("[%[a_ptr], #0x180]")
+ "vmla.f32 q8, q2, d0[0]\n"
- "vmla.f32 q9, q2, d0[1]\n"
- "vldr d4, [%[b_ptr], #0x60]\n"
- "vmla.f32 q10, q3, d2[0]\n"
- "vldr d5, [%[b_ptr], #0x68]\n"
- "vmla.f32 q11, q3, d2[1]\n"
- "vldr d2, [%[a_ptr], #0x50]\n"
- "vmla.f32 q12, q3, d3[0]\n"
+ "vmla.f32 q9, q2, d0[1]\n"
+ "vldr d4, [%[b_ptr], #0x60]\n"
+ "vmla.f32 q10, q3, d2[0]\n"
+ "vldr d5, [%[b_ptr], #0x68]\n"
+ "vmla.f32 q11, q3, d2[1]\n"
+ "vldr d2, [%[a_ptr], #0x50]\n"
+ "vmla.f32 q12, q3, d3[0]\n"
- "vmla.f32 q13, q3, d3[1]\n"
- "vldr d3, [%[a_ptr], #0x58]\n"
- "vmla.f32 q14, q3, d0[0]\n"
- "add %[a_ptr], %[a_ptr], #0x60\n"
- "vmla.f32 q15, q3, d0[1]\n"
- "vldr d6, [%[b_ptr], #0x70]\n"
+ "vmla.f32 q13, q3, d3[1]\n"
+ "vldr d3, [%[a_ptr], #0x58]\n"
+ "vmla.f32 q14, q3, d0[0]\n"
+ "add %[a_ptr], %[a_ptr], #0x60\n"
+ "vmla.f32 q15, q3, d0[1]\n"
+ "vldr d6, [%[b_ptr], #0x70]\n"
// Unroll 3
- "vmla.f32 q4, q2, d1[0]\n"
- "vldr d7, [%[b_ptr], #0x78]\n"
- "vmla.f32 q5, q2, d1[1]\n"
- "add %[b_ptr], %[b_ptr], #0x80\n"
- "vmla.f32 q6, q2, d2[0]\n"
- "vldr d0, [%[a_ptr], #0x00]\n"
- "vmla.f32 q7, q2, d2[1]\n" ASM_PREFETCH("[%[b_ptr], #0x180]")
- "vmla.f32 q8, q2, d3[0]\n"
+ "vmla.f32 q4, q2, d1[0]\n"
+ "vldr d7, [%[b_ptr], #0x78]\n"
+ "vmla.f32 q5, q2, d1[1]\n"
+ "add %[b_ptr], %[b_ptr], #0x80\n"
+ "vmla.f32 q6, q2, d2[0]\n"
+ "vldr d0, [%[a_ptr], #0x00]\n"
+ "vmla.f32 q7, q2, d2[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #0x180]")
+ "vmla.f32 q8, q2, d3[0]\n"
- "vmla.f32 q9, q2, d3[1]\n"
- "vldr d4, [%[b_ptr], #0x00]\n"
- "vmla.f32 q10, q3, d1[0]\n"
- "vldr d5, [%[b_ptr], #0x08]\n"
- "vmla.f32 q11, q3, d1[1]\n"
- "vldr d1, [%[a_ptr], #0x08]\n"
- "vmla.f32 q12, q3, d2[0]\n"
+ "vmla.f32 q9, q2, d3[1]\n"
+ "vldr d4, [%[b_ptr], #0x00]\n"
+ "vmla.f32 q10, q3, d1[0]\n"
+ "vldr d5, [%[b_ptr], #0x08]\n"
+ "vmla.f32 q11, q3, d1[1]\n"
+ "vldr d1, [%[a_ptr], #0x08]\n"
+ "vmla.f32 q12, q3, d2[0]\n"
- "vmla.f32 q13, q3, d2[1]\n"
- "vldr d2, [%[a_ptr], #0x10]\n"
- "vmla.f32 q14, q3, d3[0]\n"
+ "vmla.f32 q13, q3, d2[1]\n"
+ "vldr d2, [%[a_ptr], #0x10]\n"
+ "vmla.f32 q14, q3, d3[0]\n"
- "vmla.f32 q15, q3, d3[1]\n"
- "bne 1b\n"
+ "vmla.f32 q15, q3, d3[1]\n"
+ "bne 1b\n"
// "Tails" shows how many multiply blocks are needed at the
// end, must be 1-4 inclusive. Bail out to alternative tail
// immediately if it's 1.
"6:\n"
- "subs %[tails], %[tails], #1\n"
- "beq 3f\n"
+ "subs %[tails], %[tails], #1\n"
+ "beq 3f\n"
// Detached final iteration
// Unroll 0
- "vmla.f32 q4, q2, d0[0]\n"
- "vldr d6, [%[b_ptr], #0x10]\n"
- "vmla.f32 q5, q2, d0[1]\n"
- "vldr d7, [%[b_ptr], #0x18]\n"
- "vmla.f32 q6, q2, d1[0]\n"
- "vldr d3, [%[a_ptr], #0x18]\n"
- "vmla.f32 q7, q2, d1[1]\n"
- "subs %[tails], %[tails], #1\n"
- "vmla.f32 q8, q2, d2[0]\n"
- "vmla.f32 q9, q2, d2[1]\n"
- "vldr d4, [%[b_ptr], #0x20]\n"
+ "vmla.f32 q4, q2, d0[0]\n"
+ "vldr d6, [%[b_ptr], #0x10]\n"
+ "vmla.f32 q5, q2, d0[1]\n"
+ "vldr d7, [%[b_ptr], #0x18]\n"
+ "vmla.f32 q6, q2, d1[0]\n"
+ "vldr d3, [%[a_ptr], #0x18]\n"
+ "vmla.f32 q7, q2, d1[1]\n"
+ "subs %[tails], %[tails], #1\n"
+ "vmla.f32 q8, q2, d2[0]\n"
+ "vmla.f32 q9, q2, d2[1]\n"
+ "vldr d4, [%[b_ptr], #0x20]\n"
- "vmla.f32 q10, q3, d0[0]\n"
- "vldr d5, [%[b_ptr], #0x28]\n"
- "vmla.f32 q11, q3, d0[1]\n"
- "vldr d0, [%[a_ptr], #0x20]\n"
- "vmla.f32 q12, q3, d1[0]\n"
- "add %[b_ptr], %[b_ptr], #0x30\n"
- "vmla.f32 q13, q3, d1[1]\n"
- "vldr d1, [%[a_ptr], #0x28]\n"
- "vmla.f32 q14, q3, d2[0]\n"
- "vmla.f32 q15, q3, d2[1]\n"
- "beq 4f\n"
+ "vmla.f32 q10, q3, d0[0]\n"
+ "vldr d5, [%[b_ptr], #0x28]\n"
+ "vmla.f32 q11, q3, d0[1]\n"
+ "vldr d0, [%[a_ptr], #0x20]\n"
+ "vmla.f32 q12, q3, d1[0]\n"
+ "add %[b_ptr], %[b_ptr], #0x30\n"
+ "vmla.f32 q13, q3, d1[1]\n"
+ "vldr d1, [%[a_ptr], #0x28]\n"
+ "vmla.f32 q14, q3, d2[0]\n"
+ "vmla.f32 q15, q3, d2[1]\n"
+ "beq 4f\n"
// Unroll 1
- "vmla.f32 q4, q2, d3[0]\n"
- "vldr d6, [%[b_ptr], #0x30]\n"
- "vmla.f32 q5, q2, d3[1]\n"
- "vldr d7, [%[b_ptr], #0x38]\n"
- "vmla.f32 q6, q2, d0[0]\n"
- "vldr d2, [%[a_ptr], #0x30]\n"
- "vmla.f32 q7, q2, d0[1]\n"
- "subs %[tails], %[tails], #1\n"
- "vmla.f32 q8, q2, d1[0]\n"
+ "vmla.f32 q4, q2, d3[0]\n"
+ "vldr d6, [%[b_ptr], #0x30]\n"
+ "vmla.f32 q5, q2, d3[1]\n"
+ "vldr d7, [%[b_ptr], #0x38]\n"
+ "vmla.f32 q6, q2, d0[0]\n"
+ "vldr d2, [%[a_ptr], #0x30]\n"
+ "vmla.f32 q7, q2, d0[1]\n"
+ "subs %[tails], %[tails], #1\n"
+ "vmla.f32 q8, q2, d1[0]\n"
- "vmla.f32 q9, q2, d1[1]\n"
+ "vmla.f32 q9, q2, d1[1]\n"
- "vmla.f32 q10, q3, d3[0]\n"
- "vldr d4, [%[b_ptr], #0x40]\n"
- "vmla.f32 q11, q3, d3[1]\n"
- "vldr d5, [%[b_ptr], #0x48]\n"
- "vmla.f32 q12, q3, d0[0]\n"
- "vldr d3, [%[a_ptr], #0x38]\n"
- "vmla.f32 q13, q3, d0[1]\n"
- "vldr d0, [%[a_ptr], #0x40]\n"
- "vmla.f32 q14, q3, d1[0]\n"
- "vmla.f32 q15, q3, d1[1]\n"
- "beq 5f\n"
+ "vmla.f32 q10, q3, d3[0]\n"
+ "vldr d4, [%[b_ptr], #0x40]\n"
+ "vmla.f32 q11, q3, d3[1]\n"
+ "vldr d5, [%[b_ptr], #0x48]\n"
+ "vmla.f32 q12, q3, d0[0]\n"
+ "vldr d3, [%[a_ptr], #0x38]\n"
+ "vmla.f32 q13, q3, d0[1]\n"
+ "vldr d0, [%[a_ptr], #0x40]\n"
+ "vmla.f32 q14, q3, d1[0]\n"
+ "vmla.f32 q15, q3, d1[1]\n"
+ "beq 5f\n"
// Unroll 2
- "vmla.f32 q4, q2, d2[0]\n"
- "vldr d6, [%[b_ptr], #0x50]\n"
- "vmla.f32 q5, q2, d2[1]\n"
- "vldr d7, [%[b_ptr], #0x58]\n"
- "vmla.f32 q6, q2, d3[0]\n"
- "vldr d1, [%[a_ptr], #0x48]\n"
- "vmla.f32 q7, q2, d3[1]\n"
- "vmla.f32 q8, q2, d0[0]\n"
- "vmla.f32 q9, q2, d0[1]\n"
+ "vmla.f32 q4, q2, d2[0]\n"
+ "vldr d6, [%[b_ptr], #0x50]\n"
+ "vmla.f32 q5, q2, d2[1]\n"
+ "vldr d7, [%[b_ptr], #0x58]\n"
+ "vmla.f32 q6, q2, d3[0]\n"
+ "vldr d1, [%[a_ptr], #0x48]\n"
+ "vmla.f32 q7, q2, d3[1]\n"
+ "vmla.f32 q8, q2, d0[0]\n"
+ "vmla.f32 q9, q2, d0[1]\n"
- "vmla.f32 q10, q3, d2[0]\n"
- "vldr d4, [%[b_ptr], #0x60]\n"
- "vmla.f32 q11, q3, d2[1]\n"
- "vldr d5, [%[b_ptr], #0x68]\n"
- "vmla.f32 q12, q3, d3[0]\n"
- "vldr d2, [%[a_ptr], #0x50]\n"
- "vmla.f32 q13, q3, d3[1]\n"
- "vldr d3, [%[a_ptr], #0x58]\n"
- "vmla.f32 q14, q3, d0[0]\n"
- "vmla.f32 q15, q3, d0[1]\n"
+ "vmla.f32 q10, q3, d2[0]\n"
+ "vldr d4, [%[b_ptr], #0x60]\n"
+ "vmla.f32 q11, q3, d2[1]\n"
+ "vldr d5, [%[b_ptr], #0x68]\n"
+ "vmla.f32 q12, q3, d3[0]\n"
+ "vldr d2, [%[a_ptr], #0x50]\n"
+ "vmla.f32 q13, q3, d3[1]\n"
+ "vldr d3, [%[a_ptr], #0x58]\n"
+ "vmla.f32 q14, q3, d0[0]\n"
+ "vmla.f32 q15, q3, d0[1]\n"
// Unroll 3
- "vmla.f32 q4, q2, d1[0]\n"
- "vldr d6, [%[b_ptr], #0x70]\n"
- "vmla.f32 q5, q2, d1[1]\n"
- "vldr d7, [%[b_ptr], #0x78]\n"
- "vmla.f32 q10, q3, d1[0]\n"
- "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
- "vmla.f32 q11, q3, d1[1]\n"
- "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
- "vmla.f32 q6, q2, d2[0]\n"
- "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
- "vmla.f32 q12, q3, d2[0]\n"
- "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
- "vmla.f32 q7, q2, d2[1]\n"
- "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
- "vmla.f32 q13, q3, d2[1]\n"
- "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
- "vmla.f32 q8, q2, d3[0]\n"
- "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
- "vmla.f32 q14, q3, d3[0]\n"
- "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
- "vmla.f32 q9, q2, d3[1]\n"
- "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
- "vmla.f32 q15, q3, d3[1]\n"
- "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
- "add %[a_ptr], %[a_ptr], #0x60\n"
- "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
- "add %[b_ptr], %[b_ptr], #0x80\n"
- "b 2f\n"
+ "vmla.f32 q4, q2, d1[0]\n"
+ "vldr d6, [%[b_ptr], #0x70]\n"
+ "vmla.f32 q5, q2, d1[1]\n"
+ "vldr d7, [%[b_ptr], #0x78]\n"
+ "vmla.f32 q10, q3, d1[0]\n"
+ "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q11, q3, d1[1]\n"
+ "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q6, q2, d2[0]\n"
+ "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q12, q3, d2[0]\n"
+ "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d2[1]\n"
+ "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q13, q3, d2[1]\n"
+ "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q8, q2, d3[0]\n"
+ "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q14, q3, d3[0]\n"
+ "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q9, q2, d3[1]\n"
+ "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q15, q3, d3[1]\n"
+ "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
+ "add %[a_ptr], %[a_ptr], #0x60\n"
+ "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
+ "add %[b_ptr], %[b_ptr], #0x80\n"
+ "b 2f\n"
// tails==1 final tail
"3:\n"
- "vmla.f32 q4, q2, d0[0]\n"
- "vldr d6, [%[b_ptr], #0x10]\n"
- "vmla.f32 q5, q2, d0[1]\n"
- "vldr d7, [%[b_ptr], #0x18]\n"
- "vmla.f32 q6, q2, d1[0]\n"
- "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
- "vmla.f32 q10, q3, d0[0]\n"
- "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
- "vmla.f32 q11, q3, d0[1]\n"
- "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
- "vmla.f32 q12, q3, d1[0]\n"
- "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
- "vmla.f32 q7, q2, d1[1]\n"
- "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
- "vmla.f32 q13, q3, d1[1]\n"
- "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
- "vmla.f32 q8, q2, d2[0]\n"
- "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
- "vmla.f32 q14, q3, d2[0]\n"
- "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
- "vmla.f32 q9, q2, d2[1]\n"
- "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
- "vmla.f32 q15, q3, d2[1]\n"
- "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
- "add %[a_ptr], %[a_ptr], #0x18\n"
- "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
- "add %[b_ptr], %[b_ptr], #0x20\n"
- "b 2f\n"
+ "vmla.f32 q4, q2, d0[0]\n"
+ "vldr d6, [%[b_ptr], #0x10]\n"
+ "vmla.f32 q5, q2, d0[1]\n"
+ "vldr d7, [%[b_ptr], #0x18]\n"
+ "vmla.f32 q6, q2, d1[0]\n"
+ "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q10, q3, d0[0]\n"
+ "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q11, q3, d0[1]\n"
+ "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q12, q3, d1[0]\n"
+ "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d1[1]\n"
+ "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q13, q3, d1[1]\n"
+ "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q8, q2, d2[0]\n"
+ "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q14, q3, d2[0]\n"
+ "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q9, q2, d2[1]\n"
+ "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q15, q3, d2[1]\n"
+ "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
+ "add %[a_ptr], %[a_ptr], #0x18\n"
+ "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
+ "add %[b_ptr], %[b_ptr], #0x20\n"
+ "b 2f\n"
// tails==2 final tail
"4:\n"
- "vmla.f32 q4, q2, d3[0]\n"
- "vldr d6, [%[b_ptr], #0x30]\n"
- "vmla.f32 q5, q2, d3[1]\n"
- "vldr d7, [%[b_ptr], #0x38]\n"
- "vmla.f32 q10, q3, d3[0]\n"
- "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
- "vmla.f32 q11, q3, d3[1]\n"
- "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
- "vmla.f32 q6, q2, d0[0]\n"
- "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
- "vmla.f32 q12, q3, d0[0]\n"
- "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
- "vmla.f32 q7, q2, d0[1]\n"
- "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
- "vmla.f32 q13, q3, d0[1]\n"
- "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
- "vmla.f32 q8, q2, d1[0]\n"
- "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
- "vmla.f32 q14, q3, d1[0]\n"
- "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
- "vmla.f32 q9, q2, d1[1]\n"
- "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
- "vmla.f32 q15, q3, d1[1]\n"
- "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
- "add %[b_ptr], %[b_ptr], #0x40\n"
- "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
- "add %[a_ptr], %[a_ptr], #0x30\n"
- "b 2f\n"
+ "vmla.f32 q4, q2, d3[0]\n"
+ "vldr d6, [%[b_ptr], #0x30]\n"
+ "vmla.f32 q5, q2, d3[1]\n"
+ "vldr d7, [%[b_ptr], #0x38]\n"
+ "vmla.f32 q10, q3, d3[0]\n"
+ "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q11, q3, d3[1]\n"
+ "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q6, q2, d0[0]\n"
+ "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q12, q3, d0[0]\n"
+ "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d0[1]\n"
+ "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q13, q3, d0[1]\n"
+ "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q8, q2, d1[0]\n"
+ "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q14, q3, d1[0]\n"
+ "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q9, q2, d1[1]\n"
+ "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q15, q3, d1[1]\n"
+ "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
+ "add %[b_ptr], %[b_ptr], #0x40\n"
+ "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
+ "add %[a_ptr], %[a_ptr], #0x30\n"
+ "b 2f\n"
// tails==3 final tail
"5:\n"
- "vmla.f32 q4, q2, d2[0]\n"
- "vldr d6, [%[b_ptr], #0x50]\n"
- "vmla.f32 q5, q2, d2[1]\n"
- "vldr d7, [%[b_ptr], #0x58]\n"
- "vmla.f32 q6, q2, d3[0]\n"
- "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
- "vmla.f32 q10, q3, d2[0]\n"
- "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
- "vmla.f32 q11, q3, d2[1]\n"
- "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
- "vmla.f32 q12, q3, d3[0]\n"
- "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
- "vmla.f32 q7, q2, d3[1]\n"
- "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
- "vmla.f32 q13, q3, d3[1]\n"
- "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
- "vmla.f32 q8, q2, d0[0]\n"
- "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
- "vmla.f32 q14, q3, d0[0]\n"
- "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
- "vmla.f32 q9, q2, d0[1]\n"
- "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
- "vmla.f32 q15, q3, d0[1]\n"
- "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
- "add %[a_ptr], %[a_ptr], #0x48\n"
- "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
- "add %[b_ptr], %[b_ptr], #0x60\n"
+ "vmla.f32 q4, q2, d2[0]\n"
+ "vldr d6, [%[b_ptr], #0x50]\n"
+ "vmla.f32 q5, q2, d2[1]\n"
+ "vldr d7, [%[b_ptr], #0x58]\n"
+ "vmla.f32 q6, q2, d3[0]\n"
+ "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q10, q3, d2[0]\n"
+ "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q11, q3, d2[1]\n"
+ "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q12, q3, d3[0]\n"
+ "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d3[1]\n"
+ "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q13, q3, d3[1]\n"
+ "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q8, q2, d0[0]\n"
+ "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q14, q3, d0[0]\n"
+ "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q9, q2, d0[1]\n"
+ "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q15, q3, d0[1]\n"
+ "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
+ "add %[a_ptr], %[a_ptr], #0x48\n"
+ "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
+ "add %[b_ptr], %[b_ptr], #0x60\n"
"2:\n"
- "vst1.32 {d30-d31}, [%[c_ptr] :128]!\n"
- : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), [k] "+r"(k), [tails] "+r"(tails)
- :
- : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0", "r1");
+ "vst1.32 {d30-d31}, [%[c_ptr] :128]!\n"
+ : [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr), [k] "+r" (k), [tails] "+r" (tails)
+ :
+ : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15",
+ "r0", "r1", "cc", "memory"
+ );
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/generic.cpp
index d7d0484..3c840af 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/generic.cpp
@@ -37,120 +37,129 @@
// Note that the intent of this is that either ablocks or bblocks will be 1
// - this construction allows the output loop to proceed in either order.
-namespace arm_gemm
-{
-void a32_sgemm_8x6(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K)
-{
+namespace arm_gemm {
+
+void a32_sgemm_8x6(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) {
const float *a_ptr = Apanel;
- float *c_ptr = Cpanel;
+ float *c_ptr = Cpanel;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ for (int yb=0; yb<ablocks; yb++) {
const float *a_ptr0 = a_ptr;
- const float *b_ptr = Bpanel;
+ const float *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
- a_ptr = a_ptr0;
+ for (int xb=0; xb<bblocks; xb++) {
+ a_ptr = a_ptr0;
int tails = (K & 3);
- if(tails == 0)
- {
+ if (tails == 0) {
tails = 4;
}
- int k = ((K + 3) / 4) - 1;
+ int k = ((K+3)/4) - 1;
- __asm __volatile(
- "vmov.i32 q4, #0\n"
- "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
- "vmov.i32 q5, #0\n"
- "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
- "vmov.i32 q6, #0\n" ASM_PREFETCH("[%[a_ptr], #48]") "vmov.i32 q7, #0\n" ASM_PREFETCH("[%[b_ptr], #48]") "vmov.i32 q8, #0\n" ASM_PREFETCH("[%[a_ptr], #112]") "vmov.i32 q9, #0\n"
+ __asm __volatile (
+ "vmov.i32 q4, #0\n"
+ "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
+ "vmov.i32 q5, #0\n"
+ "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
+ "vmov.i32 q6, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #48]")
+ "vmov.i32 q7, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #48]")
+ "vmov.i32 q8, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #112]")
+ "vmov.i32 q9, #0\n"
ASM_PREFETCH("[%[b_ptr], #112]")
- "vmov.i32 q10, #0\n"
- "vmov.i32 q11, #0\n"
- "vmov.i32 q12, #0\n"
- "vmov.i32 q13, #0\n" ASM_PREFETCH("[%[a_ptr], #176]") "vmov.i32 q14, #0\n" ASM_PREFETCH("[%[b_ptr], #176]")
- "vmov.i32 q15, #0\n"
+ "vmov.i32 q10, #0\n"
+ "vmov.i32 q11, #0\n"
+ "vmov.i32 q12, #0\n"
+ "vmov.i32 q13, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #176]")
+ "vmov.i32 q14, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #176]")
+ "vmov.i32 q15, #0\n"
- "cmp %[k], #0\n"
- "beq 6f\n"
+ "cmp %[k], #0\n"
+ "beq 6f\n"
"1:\n"
// Unroll 0
- "vmla.f32 q4, q2, d0[0]\n"
- "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
- "vmla.f32 q5, q2, d0[1]\n"
- "vmla.f32 q6, q2, d1[0]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
- "vmla.f32 q7, q2, d1[1]\n"
- "vmla.f32 q8, q2, d2[0]\n"
- "vmla.f32 q9, q2, d2[1]\n"
- "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q4, q2, d0[0]\n"
+ "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q5, q2, d0[1]\n"
+ "vmla.f32 q6, q2, d1[0]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d1[1]\n"
+ "vmla.f32 q8, q2, d2[0]\n"
+ "vmla.f32 q9, q2, d2[1]\n"
+ "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
- "vmla.f32 q10, q3, d0[0]\n"
- "vmla.f32 q11, q3, d0[1]\n"
- "vmla.f32 q12, q3, d1[0]\n"
- "vmla.f32 q13, q3, d1[1]\n"
- "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
- "vmla.f32 q14, q3, d2[0]\n"
- "vmla.f32 q15, q3, d2[1]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q10, q3, d0[0]\n"
+ "vmla.f32 q11, q3, d0[1]\n"
+ "vmla.f32 q12, q3, d1[0]\n"
+ "vmla.f32 q13, q3, d1[1]\n"
+ "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q14, q3, d2[0]\n"
+ "vmla.f32 q15, q3, d2[1]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
// Unroll 1
- "vmla.f32 q4, q2, d3[0]\n"
- "subs %[k], %[k], #1\n"
- "vmla.f32 q5, q2, d3[1]\n" ASM_PREFETCH("[%[a_ptr], #208]")
- "vmla.f32 q6, q2, d0[0]\n"
- "vmla.f32 q7, q2, d0[1]\n" ASM_PREFETCH("[%[b_ptr], #192]")
- "vmla.f32 q8, q2, d1[0]\n"
- "vmla.f32 q9, q2, d1[1]\n"
- "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q4, q2, d3[0]\n"
+ "subs %[k], %[k], #1\n"
+ "vmla.f32 q5, q2, d3[1]\n"
+ ASM_PREFETCH("[%[a_ptr], #208]")
+ "vmla.f32 q6, q2, d0[0]\n"
+ "vmla.f32 q7, q2, d0[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "vmla.f32 q8, q2, d1[0]\n"
+ "vmla.f32 q9, q2, d1[1]\n"
+ "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
- "vmla.f32 q10, q3, d3[0]\n"
- "vmla.f32 q11, q3, d3[1]\n"
- "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
- "vmla.f32 q12, q3, d0[0]\n"
- "vmla.f32 q13, q3, d0[1]\n"
- "vmla.f32 q14, q3, d1[0]\n"
- "vmla.f32 q15, q3, d1[1]\n"
- "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q10, q3, d3[0]\n"
+ "vmla.f32 q11, q3, d3[1]\n"
+ "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q12, q3, d0[0]\n"
+ "vmla.f32 q13, q3, d0[1]\n"
+ "vmla.f32 q14, q3, d1[0]\n"
+ "vmla.f32 q15, q3, d1[1]\n"
+ "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
// Unroll 2
- "vmla.f32 q4, q2, d2[0]\n"
- "vmla.f32 q5, q2, d2[1]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
- "vmla.f32 q6, q2, d3[0]\n"
- "vmla.f32 q7, q2, d3[1]\n" ASM_PREFETCH("[%[a_ptr], #240]")
- "vmla.f32 q8, q2, d0[0]\n"
- "vmla.f32 q9, q2, d0[1]\n"
- "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q4, q2, d2[0]\n"
+ "vmla.f32 q5, q2, d2[1]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q6, q2, d3[0]\n"
+ "vmla.f32 q7, q2, d3[1]\n"
+ ASM_PREFETCH("[%[a_ptr], #240]")
+ "vmla.f32 q8, q2, d0[0]\n"
+ "vmla.f32 q9, q2, d0[1]\n"
+ "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
- "vmla.f32 q10, q3, d2[0]\n"
- "vmla.f32 q11, q3, d2[1]\n" ASM_PREFETCH("[%[b_ptr], #208]")
- "vmla.f32 q12, q3, d3[0]\n"
- "vmla.f32 q13, q3, d3[1]\n"
- "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
- "vmla.f32 q14, q3, d0[0]\n"
- "vmla.f32 q15, q3, d0[1]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q10, q3, d2[0]\n"
+ "vmla.f32 q11, q3, d2[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #208]")
+ "vmla.f32 q12, q3, d3[0]\n"
+ "vmla.f32 q13, q3, d3[1]\n"
+ "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q14, q3, d0[0]\n"
+ "vmla.f32 q15, q3, d0[1]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
// Unroll 3
- "vmla.f32 q4, q2, d1[0]\n"
- "vmla.f32 q5, q2, d1[1]\n"
- "vmla.f32 q6, q2, d2[0]\n"
- "vmla.f32 q7, q2, d2[1]\n"
- "vmla.f32 q8, q2, d3[0]\n"
- "vmla.f32 q9, q2, d3[1]\n"
- "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q4, q2, d1[0]\n"
+ "vmla.f32 q5, q2, d1[1]\n"
+ "vmla.f32 q6, q2, d2[0]\n"
+ "vmla.f32 q7, q2, d2[1]\n"
+ "vmla.f32 q8, q2, d3[0]\n"
+ "vmla.f32 q9, q2, d3[1]\n"
+ "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
- "vmla.f32 q10, q3, d1[0]\n"
- "vmla.f32 q11, q3, d1[1]\n"
- "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
- "vmla.f32 q12, q3, d2[0]\n"
- "vmla.f32 q13, q3, d2[1]\n"
- "vmla.f32 q14, q3, d3[0]\n"
- "vmla.f32 q15, q3, d3[1]\n"
- "bne 1b\n"
+ "vmla.f32 q10, q3, d1[0]\n"
+ "vmla.f32 q11, q3, d1[1]\n"
+ "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q12, q3, d2[0]\n"
+ "vmla.f32 q13, q3, d2[1]\n"
+ "vmla.f32 q14, q3, d3[0]\n"
+ "vmla.f32 q15, q3, d3[1]\n"
+ "bne 1b\n"
// Branch here if we never execute main loop.
"6:\n"
@@ -158,185 +167,187 @@
// "Tails" shows how many multiply blocks are needed at the
// end, must be 1-4 inclusive. Bail out to alternative tail
// immediately if it's 1.
- "subs %[tails], %[tails], #1\n"
- "beq 3f\n"
+ "subs %[tails], %[tails], #1\n"
+ "beq 3f\n"
// Detached final iteration
// Unroll 0
- "vmla.f32 q4, q2, d0[0]\n"
- "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
- "vmla.f32 q5, q2, d0[1]\n"
- "vmla.f32 q6, q2, d1[0]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
- "vmla.f32 q7, q2, d1[1]\n"
- "vmla.f32 q8, q2, d2[0]\n"
- "subs %[tails], %[tails], #1\n"
- "vmla.f32 q9, q2, d2[1]\n"
- "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q4, q2, d0[0]\n"
+ "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q5, q2, d0[1]\n"
+ "vmla.f32 q6, q2, d1[0]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d1[1]\n"
+ "vmla.f32 q8, q2, d2[0]\n"
+ "subs %[tails], %[tails], #1\n"
+ "vmla.f32 q9, q2, d2[1]\n"
+ "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
- "vmla.f32 q10, q3, d0[0]\n"
- "vmla.f32 q11, q3, d0[1]\n"
- "vmla.f32 q12, q3, d1[0]\n"
- "vmla.f32 q13, q3, d1[1]\n"
- "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
- "vmla.f32 q14, q3, d2[0]\n"
- "vmla.f32 q15, q3, d2[1]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
- "beq 4f\n"
+ "vmla.f32 q10, q3, d0[0]\n"
+ "vmla.f32 q11, q3, d0[1]\n"
+ "vmla.f32 q12, q3, d1[0]\n"
+ "vmla.f32 q13, q3, d1[1]\n"
+ "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q14, q3, d2[0]\n"
+ "vmla.f32 q15, q3, d2[1]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "beq 4f\n"
// Unroll 1
- "vmla.f32 q4, q2, d3[0]\n"
- "vmla.f32 q5, q2, d3[1]\n"
- "subs %[tails], %[tails], #1\n"
- "vmla.f32 q6, q2, d0[0]\n"
- "vmla.f32 q7, q2, d0[1]\n"
- "vmla.f32 q8, q2, d1[0]\n"
- "vmla.f32 q9, q2, d1[1]\n"
- "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q4, q2, d3[0]\n"
+ "vmla.f32 q5, q2, d3[1]\n"
+ "subs %[tails], %[tails], #1\n"
+ "vmla.f32 q6, q2, d0[0]\n"
+ "vmla.f32 q7, q2, d0[1]\n"
+ "vmla.f32 q8, q2, d1[0]\n"
+ "vmla.f32 q9, q2, d1[1]\n"
+ "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
- "vmla.f32 q10, q3, d3[0]\n"
- "vmla.f32 q11, q3, d3[1]\n"
- "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
- "vmla.f32 q12, q3, d0[0]\n"
- "vmla.f32 q13, q3, d0[1]\n"
- "vmla.f32 q14, q3, d1[0]\n"
- "vmla.f32 q15, q3, d1[1]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
- "beq 5f\n"
+ "vmla.f32 q10, q3, d3[0]\n"
+ "vmla.f32 q11, q3, d3[1]\n"
+ "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q12, q3, d0[0]\n"
+ "vmla.f32 q13, q3, d0[1]\n"
+ "vmla.f32 q14, q3, d1[0]\n"
+ "vmla.f32 q15, q3, d1[1]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "beq 5f\n"
// Unroll 2
- "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
- "vmla.f32 q4, q2, d2[0]\n"
- "vmla.f32 q5, q2, d2[1]\n"
- "vmla.f32 q6, q2, d3[0]\n"
- "vmla.f32 q7, q2, d3[1]\n"
- "vmla.f32 q8, q2, d0[0]\n"
- "vmla.f32 q9, q2, d0[1]\n"
- "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
+ "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q4, q2, d2[0]\n"
+ "vmla.f32 q5, q2, d2[1]\n"
+ "vmla.f32 q6, q2, d3[0]\n"
+ "vmla.f32 q7, q2, d3[1]\n"
+ "vmla.f32 q8, q2, d0[0]\n"
+ "vmla.f32 q9, q2, d0[1]\n"
+ "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
- "vmla.f32 q10, q3, d2[0]\n"
- "vmla.f32 q11, q3, d2[1]\n"
- "vmla.f32 q12, q3, d3[0]\n"
- "vmla.f32 q13, q3, d3[1]\n"
- "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
- "vmla.f32 q14, q3, d0[0]\n"
- "vmla.f32 q15, q3, d0[1]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q10, q3, d2[0]\n"
+ "vmla.f32 q11, q3, d2[1]\n"
+ "vmla.f32 q12, q3, d3[0]\n"
+ "vmla.f32 q13, q3, d3[1]\n"
+ "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q14, q3, d0[0]\n"
+ "vmla.f32 q15, q3, d0[1]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
// Unroll 3
- "vmla.f32 q4, q2, d1[0]\n"
- "vmla.f32 q10, q3, d1[0]\n"
- "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
- "vmla.f32 q5, q2, d1[1]\n"
- "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
- "vmla.f32 q11, q3, d1[1]\n"
- "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
- "vmla.f32 q6, q2, d2[0]\n"
- "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
- "vmla.f32 q12, q3, d2[0]\n"
- "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
- "vmla.f32 q7, q2, d2[1]\n"
- "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
- "vmla.f32 q13, q3, d2[1]\n"
- "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
- "vmla.f32 q8, q2, d3[0]\n"
- "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
- "vmla.f32 q14, q3, d3[0]\n"
- "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
- "vmla.f32 q9, q2, d3[1]\n"
- "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
- "vmla.f32 q15, q3, d3[1]\n"
- "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
- "b 2f\n"
+ "vmla.f32 q4, q2, d1[0]\n"
+ "vmla.f32 q10, q3, d1[0]\n"
+ "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q5, q2, d1[1]\n"
+ "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q11, q3, d1[1]\n"
+ "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q6, q2, d2[0]\n"
+ "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q12, q3, d2[0]\n"
+ "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d2[1]\n"
+ "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q13, q3, d2[1]\n"
+ "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q8, q2, d3[0]\n"
+ "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q14, q3, d3[0]\n"
+ "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q9, q2, d3[1]\n"
+ "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q15, q3, d3[1]\n"
+ "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
+ "b 2f\n"
// tails==1 final tail
"3:\n"
- "vmla.f32 q4, q2, d0[0]\n"
- "vld1.32 {d2}, [%[a_ptr] :64]!\n"
- "vmla.f32 q5, q2, d0[1]\n"
- "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
- "vmla.f32 q6, q2, d1[0]\n"
- "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
- "vmla.f32 q10, q3, d0[0]\n"
- "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
- "vmla.f32 q11, q3, d0[1]\n"
- "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
- "vmla.f32 q12, q3, d1[0]\n"
- "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
- "vmla.f32 q7, q2, d1[1]\n"
- "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
- "vmla.f32 q13, q3, d1[1]\n"
- "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
- "vmla.f32 q8, q2, d2[0]\n"
- "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
- "vmla.f32 q14, q3, d2[0]\n"
- "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
- "vmla.f32 q9, q2, d2[1]\n"
- "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
- "vmla.f32 q15, q3, d2[1]\n"
- "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
- "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
- "b 2f\n"
+ "vmla.f32 q4, q2, d0[0]\n"
+ "vld1.32 {d2}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q5, q2, d0[1]\n"
+ "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
+ "vmla.f32 q6, q2, d1[0]\n"
+ "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q10, q3, d0[0]\n"
+ "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q11, q3, d0[1]\n"
+ "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q12, q3, d1[0]\n"
+ "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d1[1]\n"
+ "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q13, q3, d1[1]\n"
+ "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q8, q2, d2[0]\n"
+ "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q14, q3, d2[0]\n"
+ "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q9, q2, d2[1]\n"
+ "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q15, q3, d2[1]\n"
+ "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
+ "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
+ "b 2f\n"
// tails==2 final tail
"4:\n"
- "vmla.f32 q4, q2, d3[0]\n"
- "vmla.f32 q10, q3, d3[0]\n"
- "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
- "vmla.f32 q5, q2, d3[1]\n"
- "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
- "vmla.f32 q11, q3, d3[1]\n"
- "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
- "vmla.f32 q6, q2, d0[0]\n"
- "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
- "vmla.f32 q12, q3, d0[0]\n"
- "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
- "vmla.f32 q7, q2, d0[1]\n"
- "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
- "vmla.f32 q13, q3, d0[1]\n"
- "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
- "vmla.f32 q8, q2, d1[0]\n"
- "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
- "vmla.f32 q14, q3, d1[0]\n"
- "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
- "vmla.f32 q9, q2, d1[1]\n"
- "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
- "vmla.f32 q15, q3, d1[1]\n"
- "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
- "b 2f\n"
+ "vmla.f32 q4, q2, d3[0]\n"
+ "vmla.f32 q10, q3, d3[0]\n"
+ "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q5, q2, d3[1]\n"
+ "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q11, q3, d3[1]\n"
+ "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q6, q2, d0[0]\n"
+ "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q12, q3, d0[0]\n"
+ "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d0[1]\n"
+ "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q13, q3, d0[1]\n"
+ "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q8, q2, d1[0]\n"
+ "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q14, q3, d1[0]\n"
+ "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q9, q2, d1[1]\n"
+ "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q15, q3, d1[1]\n"
+ "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
+ "b 2f\n"
// tails==3 final tail
"5:\n"
- "vmla.f32 q4, q2, d2[0]\n"
- "vld1.32 {d0}, [%[a_ptr] :64]!\n"
- "vmla.f32 q5, q2, d2[1]\n"
- "vmla.f32 q6, q2, d3[0]\n"
- "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
- "vmla.f32 q10, q3, d2[0]\n"
- "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
- "vmla.f32 q11, q3, d2[1]\n"
- "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
- "vmla.f32 q12, q3, d3[0]\n"
- "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
- "vmla.f32 q7, q2, d3[1]\n"
- "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
- "vmla.f32 q13, q3, d3[1]\n"
- "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
- "vmla.f32 q8, q2, d0[0]\n"
- "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
- "vmla.f32 q14, q3, d0[0]\n"
- "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
- "vmla.f32 q9, q2, d0[1]\n"
- "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
- "vmla.f32 q15, q3, d0[1]\n"
- "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
- "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q4, q2, d2[0]\n"
+ "vld1.32 {d0}, [%[a_ptr] :64]!\n"
+ "vmla.f32 q5, q2, d2[1]\n"
+ "vmla.f32 q6, q2, d3[0]\n"
+ "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q10, q3, d2[0]\n"
+ "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q11, q3, d2[1]\n"
+ "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q12, q3, d3[0]\n"
+ "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q7, q2, d3[1]\n"
+ "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q13, q3, d3[1]\n"
+ "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q8, q2, d0[0]\n"
+ "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q14, q3, d0[0]\n"
+ "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q9, q2, d0[1]\n"
+ "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
+ "vmla.f32 q15, q3, d0[1]\n"
+ "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
+ "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
"2:\n"
- "vst1.32 {d30-d31}, [%[c_ptr] :128]!\n"
- : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), [k] "+r"(k), [tails] "+r"(tails)
- :
- : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "cc");
+ "vst1.32 {d30-d31}, [%[c_ptr] :128]!\n"
+ : [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr), [k] "+r" (k), [tails] "+r" (tails)
+ :
+ : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15",
+ "cc", "memory"
+ );
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8.hpp
index 387f899..95a2bc2 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8.hpp
@@ -25,8 +25,10 @@
#ifdef __aarch64__
-namespace arm_gemm
-{
+#include "../std_transforms_fixed.hpp"
+
+namespace arm_gemm {
+
// Actual kernel implementations
void a64_gemm_s16_asimd_12x8(const int16_t *, const int16_t *, int32_t *, int, int, int);
@@ -38,34 +40,32 @@
// All kernels in the family must share these characteristics. The actual
// kernel to be used can be chosen at runtime, based on the CPU_type
// structure.
-class gemm_s16_12x8
-{
+class gemm_s16_12x8 {
public:
typedef int16_t operand_type;
typedef int32_t result_type;
typedef void (*kern_type)(const int16_t *, const int16_t *, int32_t *, int, int, int);
- /* Describes the data layout for A input */
- static const int A_interleave = 8;
- static const int A_block = 1;
- static const int A_transpose = 0;
-
- /* Same for B input */
- static const int B_interleave = 12;
- static const int B_block = 1;
- static const int B_transpose = 1;
-
/* Kernel blocking parameters */
- static const int out_width = 12;
- static const int out_height = 8;
- static const int k_unroll = 1;
+ static int out_width() {
+ return 12;
+ }
+
+ static int out_height() {
+ return 8;
+ }
+
+ static int k_unroll() {
+ return 1;
+ }
+
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 8, 12> transforms = {};
kern_type kernel = a64_gemm_s16_asimd_12x8;
- gemm_s16_12x8(const CPUInfo *ci)
- {
- }
+ gemm_s16_12x8(const CPUInfo *ci) { }
};
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8/generic.cpp
index b217dcf..823079a 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8/generic.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,281 +27,295 @@
#include "../../asmlib.hpp"
-namespace arm_gemm
-{
+namespace arm_gemm {
+
void a64_gemm_s16_asimd_12x8(const int16_t *Apanel, const int16_t *Bpanel, int32_t *Cpanel, int ablocks, int bblocks, int K)
{
- const int16_t *a_ptr = Apanel;
- int32_t *c_ptr = Cpanel;
+ const int16_t *a_ptr = Apanel;
+ int32_t *c_ptr = Cpanel;
- for(int yb = 0; yb < ablocks; yb++)
+ for (int yb = 0; yb < ablocks; yb++)
+ {
+ const int16_t *a_ptr0 = a_ptr;
+ const int16_t *b_ptr = Bpanel;
+
+ for (int xb = 0; xb < bblocks; xb++)
{
- const int16_t *a_ptr0 = a_ptr;
- const int16_t *b_ptr = Bpanel;
+ a_ptr = a_ptr0;
+ const bool odd_k = K & 0x1;
+ int k = (K+1)/2 - 1;
- for(int xb = 0; xb < bblocks; xb++)
- {
- a_ptr = a_ptr0;
- const bool odd_k = K & 0x1;
- int k = (K + 1) / 2 - 1;
+ register int16x8_t aa asm("v0");
+ register int16x8_t ab asm("v1");
+ register int16x8_t b0 asm("v2");
+ register int16x8_t b1 asm("v3");
+ register int16x8_t b2 asm("v4");
- register int16x8_t aa asm("v0");
- register int16x8_t ab asm("v1");
- register int16x8_t b0 asm("v2");
- register int16x8_t b1 asm("v3");
- register int16x8_t b2 asm("v4");
+ __asm __volatile (
+ "ldr %d[aa], [%x[a_ptr]]\n" // Load A[A].lower
+ "movi v5.4s, #0\n"
+ "ldr x20, [%x[a_ptr], #0x08]\n" // Load A[A].upper
+ "movi v6.4s, #0\n"
+ "ldr %d[b0], [%x[b_ptr]]\n" // Load B[0].lower
+ "ins %[aa].d[1], x20\n" // Merge A[A].lower and upper
+ "movi v7.4s, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v8.4s, #0\n"
+ "ldr x20, [%x[b_ptr], #0x08]\n" // Load B[0].upper
+ "movi v9.4s, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v10.4s, #0\n"
+ "ldr %d[b1], [%x[b_ptr], #0x10]\n" // Load B[1].lower
+ "ins %[b0].d[1], x20\n" // Merge B[0].lower and upper
+ "movi v11.4s, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #96]")
+ "movi v12.4s, #0\n"
+ "movi v13.4s, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #96]")
+ "movi v14.4s, #0\n"
+ "movi v15.4s, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
+ "movi v16.4s, #0\n"
+ "movi v17.4s, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v18.4s, #0\n"
+ "movi v19.4s, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #160]")
+ "movi v20.4s, #0\n"
+ "movi v21.4s, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #160]")
+ "movi v22.4s, #0\n"
+ "movi v23.4s, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #192]")
+ "movi v24.4s, #0\n"
+ "add %x[a_ptr], %x[a_ptr], #0x10\n"
+ "movi v25.4s, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v26.4s, #0\n"
+ "add %x[b_ptr], %x[b_ptr], #0x18\n"
+ "movi v27.4s, #0\n"
+ "movi v28.4s, #0\n"
- __asm __volatile(
- "ldr %d[aa], [%x[a_ptr]]\n" // Load A[A].lower
- "movi v5.4s, #0\n"
- "ldr x20, [%x[a_ptr], #0x08]\n" // Load A[A].upper
- "movi v6.4s, #0\n"
- "ldr %d[b0], [%x[b_ptr]]\n" // Load B[0].lower
- "ins %[aa].d[1], x20\n" // Merge A[A].lower and upper
- "movi v7.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #64]")
- "movi v8.4s, #0\n"
- "ldr x20, [%x[b_ptr], #0x08]\n" // Load B[0].upper
- "movi v9.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #64]")
- "movi v10.4s, #0\n"
- "ldr %d[b1], [%x[b_ptr], #0x10]\n" // Load B[1].lower
- "ins %[b0].d[1], x20\n" // Merge B[0].lower and upper
- "movi v11.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #96]")
- "movi v12.4s, #0\n"
- "movi v13.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #96]")
- "movi v14.4s, #0\n"
- "movi v15.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #128]")
- "movi v16.4s, #0\n"
- "movi v17.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #128]")
- "movi v18.4s, #0\n"
- "movi v19.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #160]")
- "movi v20.4s, #0\n"
- "movi v21.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #160]")
- "movi v22.4s, #0\n"
- "movi v23.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #192]")
- "movi v24.4s, #0\n"
- "add %x[a_ptr], %x[a_ptr], #0x10\n"
- "movi v25.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #192]")
- "movi v26.4s, #0\n"
- "add %x[b_ptr], %x[b_ptr], #0x18\n"
- "movi v27.4s, #0\n"
- "movi v28.4s, #0\n"
+ "cbz %x[k], 2f\n" // Skip the loop if doing zero iterations.
- "cbz %x[k], 2f\n" // Skip the loop if doing zero iterations.
+ "1:\n" // Main loop
+ // First unroll
+ "smlal v5.4s, %[b0].4h, %[aa].h[0]\n"
+ "ldr x20, [%x[b_ptr]]\n" // Load B[1].upper
+ "smlal v6.4s, %[b0].4h, %[aa].h[1]\n"
+ "smlal v7.4s, %[b0].4h, %[aa].h[2]\n"
+ "ldr %d[ab], [%x[a_ptr]]\n" // Load A[B].lower
+ "ins %[b1].d[1], x20\n" // Merge B[1].lower and .upper
+ "smlal v8.4s, %[b0].4h, %[aa].h[3]\n"
+ "smlal v9.4s, %[b0].4h, %[aa].h[4]\n"
+ "ldr x20, [%x[a_ptr], #0x8]\n" // Load A[B].upper
+ "smlal v10.4s, %[b0].4h, %[aa].h[5]\n"
+ "smlal v11.4s, %[b0].4h, %[aa].h[6]\n"
+ "ldr %d[b2], [%x[b_ptr], #0x8]\n" // Load B[2].lower
+ "ins %[ab].d[1], x20\n" // Merge A[B].lower and .upper
+ "smlal v12.4s, %[b0].4h, %[aa].h[7]\n"
+ "smlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
+ "ldr x20, [%x[b_ptr], #0x10]\n" // Load B[2].upper
+ "smlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
+ "smlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
+ "smlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
+ "smlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
+ "smlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
+ "smlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
+ "smlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
+ "ldr %d[b0], [%x[b_ptr], #0x18]\n" // Load B[0].lower
+ "ins %[b2].d[1], x20\n" // Merge B[2].lower and .upper
+ "smlal v21.4s, %[b1].4h, %[aa].h[0]\n"
+ "smlal v22.4s, %[b1].4h, %[aa].h[1]\n"
+ "ldr x20, [%x[b_ptr], #0x20]\n" // Load B[0].upper
+ "smlal v23.4s, %[b1].4h, %[aa].h[2]\n"
+ "smlal v24.4s, %[b1].4h, %[aa].h[3]\n"
+ "smlal v25.4s, %[b1].4h, %[aa].h[4]\n"
+ "smlal v26.4s, %[b1].4h, %[aa].h[5]\n"
+ "smlal v27.4s, %[b1].4h, %[aa].h[6]\n"
+ "smlal v28.4s, %[b1].4h, %[aa].h[7]\n"
- "1:\n" // Main loop
- // First unroll
- "smlal v5.4s, %[b0].4h, %[aa].h[0]\n"
- "ldr x20, [%x[b_ptr]]\n" // Load B[1].upper
- "smlal v6.4s, %[b0].4h, %[aa].h[1]\n"
- "smlal v7.4s, %[b0].4h, %[aa].h[2]\n"
- "ldr %d[ab], [%x[a_ptr]]\n" // Load A[B].lower
- "ins %[b1].d[1], x20\n" // Merge B[1].lower and .upper
- "smlal v8.4s, %[b0].4h, %[aa].h[3]\n"
- "smlal v9.4s, %[b0].4h, %[aa].h[4]\n"
- "ldr x20, [%x[a_ptr], #0x8]\n" // Load A[B].upper
- "smlal v10.4s, %[b0].4h, %[aa].h[5]\n"
- "smlal v11.4s, %[b0].4h, %[aa].h[6]\n"
- "ldr %d[b2], [%x[b_ptr], #0x8]\n" // Load B[2].lower
- "ins %[ab].d[1], x20\n" // Merge A[B].lower and .upper
- "smlal v12.4s, %[b0].4h, %[aa].h[7]\n"
- "smlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
- "ldr x20, [%x[b_ptr], #0x10]\n" // Load B[2].upper
- "smlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
- "smlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
- "smlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
- "smlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
- "smlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
- "smlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
- "smlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
- "ldr %d[b0], [%x[b_ptr], #0x18]\n" // Load B[0].lower
- "ins %[b2].d[1], x20\n" // Merge B[2].lower and .upper
- "smlal v21.4s, %[b1].4h, %[aa].h[0]\n"
- "smlal v22.4s, %[b1].4h, %[aa].h[1]\n"
- "ldr x20, [%x[b_ptr], #0x20]\n" // Load B[0].upper
- "smlal v23.4s, %[b1].4h, %[aa].h[2]\n"
- "smlal v24.4s, %[b1].4h, %[aa].h[3]\n"
- "smlal v25.4s, %[b1].4h, %[aa].h[4]\n"
- "smlal v26.4s, %[b1].4h, %[aa].h[5]\n"
- "smlal v27.4s, %[b1].4h, %[aa].h[6]\n"
- "smlal v28.4s, %[b1].4h, %[aa].h[7]\n"
+ // Second unroll
+ "smlal2 v5.4s, %[b1].8h, %[ab].h[0]\n"
+ "ldr %d[aa], [%x[a_ptr], #0x10]\n" // Load A[A].lower
+ "ins %[b0].d[1], x20\n" // Merge B[0].lower and .upper
+ "smlal2 v6.4s, %[b1].8h, %[ab].h[1]\n"
+ "smlal2 v7.4s, %[b1].8h, %[ab].h[2]\n"
+ "ldr x20, [%x[a_ptr], #0x18]\n" // Load A[A].upper
+ "smlal2 v8.4s, %[b1].8h, %[ab].h[3]\n"
+ "smlal2 v9.4s, %[b1].8h, %[ab].h[4]\n"
+ "smlal2 v10.4s, %[b1].8h, %[ab].h[5]\n"
+ "smlal2 v11.4s, %[b1].8h, %[ab].h[6]\n"
+ "add %x[a_ptr], %x[a_ptr], #0x20\n"
+ "smlal2 v12.4s, %[b1].8h, %[ab].h[7]\n"
+ "smlal v13.4s, %[b2].4h, %[ab].h[0]\n"
+ ASM_PREFETCH("[%[b_ptr], #320]")
+ "smlal v14.4s, %[b2].4h, %[ab].h[1]\n"
+ "smlal v15.4s, %[b2].4h, %[ab].h[2]\n"
+ ASM_PREFETCH("[%[a_ptr], #320]")
+ "smlal v16.4s, %[b2].4h, %[ab].h[3]\n"
+ "smlal v17.4s, %[b2].4h, %[ab].h[4]\n"
+ ASM_PREFETCH("[%[b_ptr], #448]")
+ "smlal v18.4s, %[b2].4h, %[ab].h[5]\n"
+ "smlal v19.4s, %[b2].4h, %[ab].h[6]\n"
+ "smlal v20.4s, %[b2].4h, %[ab].h[7]\n"
+ "smlal2 v21.4s, %[b2].8h, %[ab].h[0]\n"
+ "smlal2 v22.4s, %[b2].8h, %[ab].h[1]\n"
+ "subs %x[k], %x[k], #0x1\n"
+ "smlal2 v23.4s, %[b2].8h, %[ab].h[2]\n"
+ "smlal2 v24.4s, %[b2].8h, %[ab].h[3]\n"
+ "ldr %d[b1], [%x[b_ptr], #0x28]\n" // Load B[1].lower
+ "ins %[aa].d[1], x20\n" // Merge A[A].lower and .upper
+ "smlal2 v25.4s, %[b2].8h, %[ab].h[4]\n"
+ "smlal2 v26.4s, %[b2].8h, %[ab].h[5]\n"
+ "add %x[b_ptr], %x[b_ptr], #0x30\n"
+ "smlal2 v27.4s, %[b2].8h, %[ab].h[6]\n"
+ "smlal2 v28.4s, %[b2].8h, %[ab].h[7]\n"
+ "bne 1b\n"
- // Second unroll
- "smlal2 v5.4s, %[b1].8h, %[ab].h[0]\n"
- "ldr %d[aa], [%x[a_ptr], #0x10]\n" // Load A[A].lower
- "ins %[b0].d[1], x20\n" // Merge B[0].lower and .upper
- "smlal2 v6.4s, %[b1].8h, %[ab].h[1]\n"
- "smlal2 v7.4s, %[b1].8h, %[ab].h[2]\n"
- "ldr x20, [%x[a_ptr], #0x18]\n" // Load A[A].upper
- "smlal2 v8.4s, %[b1].8h, %[ab].h[3]\n"
- "smlal2 v9.4s, %[b1].8h, %[ab].h[4]\n"
- "smlal2 v10.4s, %[b1].8h, %[ab].h[5]\n"
- "smlal2 v11.4s, %[b1].8h, %[ab].h[6]\n"
- "add %x[a_ptr], %x[a_ptr], #0x20\n"
- "smlal2 v12.4s, %[b1].8h, %[ab].h[7]\n"
- "smlal v13.4s, %[b2].4h, %[ab].h[0]\n" ASM_PREFETCH("[%[b_ptr], #320]")
- "smlal v14.4s, %[b2].4h, %[ab].h[1]\n"
- "smlal v15.4s, %[b2].4h, %[ab].h[2]\n" ASM_PREFETCH("[%[a_ptr], #320]")
- "smlal v16.4s, %[b2].4h, %[ab].h[3]\n"
- "smlal v17.4s, %[b2].4h, %[ab].h[4]\n" ASM_PREFETCH("[%[b_ptr], #448]")
- "smlal v18.4s, %[b2].4h, %[ab].h[5]\n"
- "smlal v19.4s, %[b2].4h, %[ab].h[6]\n"
- "smlal v20.4s, %[b2].4h, %[ab].h[7]\n"
- "smlal2 v21.4s, %[b2].8h, %[ab].h[0]\n"
- "smlal2 v22.4s, %[b2].8h, %[ab].h[1]\n"
- "subs %x[k], %x[k], #0x1\n"
- "smlal2 v23.4s, %[b2].8h, %[ab].h[2]\n"
- "smlal2 v24.4s, %[b2].8h, %[ab].h[3]\n"
- "ldr %d[b1], [%x[b_ptr], #0x28]\n" // Load B[1].lower
- "ins %[aa].d[1], x20\n" // Merge A[A].lower and .upper
- "smlal2 v25.4s, %[b2].8h, %[ab].h[4]\n"
- "smlal2 v26.4s, %[b2].8h, %[ab].h[5]\n"
- "add %x[b_ptr], %x[b_ptr], #0x30\n"
- "smlal2 v27.4s, %[b2].8h, %[ab].h[6]\n"
- "smlal2 v28.4s, %[b2].8h, %[ab].h[7]\n"
- "bne 1b\n"
+ "2:\n" // Even tail
+ "cbnz %x[odd_k], 3f\n"
- "2:\n" // Even tail
- "cbnz %x[odd_k], 3f\n"
+ "smlal v5.4s, %[b0].4h, %[aa].h[0]\n"
+ "ldr x20, [%x[b_ptr]]\n" // Load B[1].upper
+ "smlal v6.4s, %[b0].4h, %[aa].h[1]\n"
+ "smlal v7.4s, %[b0].4h, %[aa].h[2]\n"
+ "ldr %d[ab], [%x[a_ptr]]\n" // Load A[B].lower
+ "ins %[b1].d[1], x20\n" // Merge B[1].lower and .upper
+ "smlal v8.4s, %[b0].4h, %[aa].h[3]\n"
+ "smlal v9.4s, %[b0].4h, %[aa].h[4]\n"
+ "ldr x20, [%x[a_ptr], #0x8]\n" // Load A[B].upper
+ "smlal v10.4s, %[b0].4h, %[aa].h[5]\n"
+ "smlal v11.4s, %[b0].4h, %[aa].h[6]\n"
+ "ldr %d[b2], [%x[b_ptr], #0x8]\n" // Load B[2].lower
+ "ins %[ab].d[1], x20\n" // Merge A[B].lower and .upper
+ "smlal v12.4s, %[b0].4h, %[aa].h[7]\n"
+ "smlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
+ "ldr x20, [%x[b_ptr], #0x10]\n" // Load B[2].upper
+ "smlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
+ "smlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
+ "smlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
+ "add %[a_ptr], %[a_ptr], #0x10\n"
+ "smlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
+ "add %[b_ptr], %[b_ptr], #0x18\n"
+ "smlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
+ "smlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
+ "smlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
+ "ins %[b2].d[1], x20\n" // Merge B[2].lower and .upper
+ "smlal v21.4s, %[b1].4h, %[aa].h[0]\n"
+ "smlal v22.4s, %[b1].4h, %[aa].h[1]\n"
+ "smlal v23.4s, %[b1].4h, %[aa].h[2]\n"
+ "smlal v24.4s, %[b1].4h, %[aa].h[3]\n"
+ "smlal v25.4s, %[b1].4h, %[aa].h[4]\n"
+ "smlal v26.4s, %[b1].4h, %[aa].h[5]\n"
+ "smlal v27.4s, %[b1].4h, %[aa].h[6]\n"
+ "smlal v28.4s, %[b1].4h, %[aa].h[7]\n"
- "smlal v5.4s, %[b0].4h, %[aa].h[0]\n"
- "ldr x20, [%x[b_ptr]]\n" // Load B[1].upper
- "smlal v6.4s, %[b0].4h, %[aa].h[1]\n"
- "smlal v7.4s, %[b0].4h, %[aa].h[2]\n"
- "ldr %d[ab], [%x[a_ptr]]\n" // Load A[B].lower
- "ins %[b1].d[1], x20\n" // Merge B[1].lower and .upper
- "smlal v8.4s, %[b0].4h, %[aa].h[3]\n"
- "smlal v9.4s, %[b0].4h, %[aa].h[4]\n"
- "ldr x20, [%x[a_ptr], #0x8]\n" // Load A[B].upper
- "smlal v10.4s, %[b0].4h, %[aa].h[5]\n"
- "smlal v11.4s, %[b0].4h, %[aa].h[6]\n"
- "ldr %d[b2], [%x[b_ptr], #0x8]\n" // Load B[2].lower
- "ins %[ab].d[1], x20\n" // Merge A[B].lower and .upper
- "smlal v12.4s, %[b0].4h, %[aa].h[7]\n"
- "smlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
- "ldr x20, [%x[b_ptr], #0x10]\n" // Load B[2].upper
- "smlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
- "smlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
- "smlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
- "add %[a_ptr], %[a_ptr], #0x10\n"
- "smlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
- "add %[b_ptr], %[b_ptr], #0x18\n"
- "smlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
- "smlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
- "smlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
- "ins %[b2].d[1], x20\n" // Merge B[2].lower and .upper
- "smlal v21.4s, %[b1].4h, %[aa].h[0]\n"
- "smlal v22.4s, %[b1].4h, %[aa].h[1]\n"
- "smlal v23.4s, %[b1].4h, %[aa].h[2]\n"
- "smlal v24.4s, %[b1].4h, %[aa].h[3]\n"
- "smlal v25.4s, %[b1].4h, %[aa].h[4]\n"
- "smlal v26.4s, %[b1].4h, %[aa].h[5]\n"
- "smlal v27.4s, %[b1].4h, %[aa].h[6]\n"
- "smlal v28.4s, %[b1].4h, %[aa].h[7]\n"
+ "smlal2 v5.4s, %[b1].8h, %[ab].h[0]\n"
+ "smlal v13.4s, %[b2].4h, %[ab].h[0]\n"
+ "smlal2 v21.4s, %[b2].8h, %[ab].h[0]\n"
+ "smlal2 v6.4s, %[b1].8h, %[ab].h[1]\n"
+ "smlal v14.4s, %[b2].4h, %[ab].h[1]\n"
+ "str q5, [%x[c_ptr]]\n"
+ "smlal2 v22.4s, %[b2].8h, %[ab].h[1]\n"
+ "str q13, [%x[c_ptr], #0x10]\n"
+ "smlal2 v7.4s, %[b1].8h, %[ab].h[2]\n"
+ "str q21, [%x[c_ptr], #0x20]\n"
+ "smlal v15.4s, %[b2].4h, %[ab].h[2]\n"
+ "str q6, [%x[c_ptr], #0x30]\n"
+ "smlal2 v23.4s, %[b2].8h, %[ab].h[2]\n"
+ "str q14, [%x[c_ptr], #0x40]\n"
+ "smlal2 v8.4s, %[b1].8h, %[ab].h[3]\n"
+ "str q22, [%x[c_ptr], #0x50]\n"
+ "smlal v16.4s, %[b2].4h, %[ab].h[3]\n"
+ "str q7, [%x[c_ptr], #0x60]\n"
+ "smlal2 v24.4s, %[b2].8h, %[ab].h[3]\n"
+ "str q15, [%x[c_ptr], #0x70]\n"
+ "smlal2 v9.4s, %[b1].8h, %[ab].h[4]\n"
+ "str q23, [%x[c_ptr], #0x80]\n"
+ "smlal v17.4s, %[b2].4h, %[ab].h[4]\n"
+ "str q8, [%x[c_ptr], #0x90]\n"
+ "smlal2 v25.4s, %[b2].8h, %[ab].h[4]\n"
+ "str q16, [%x[c_ptr], #0xa0]\n"
+ "smlal2 v10.4s, %[b1].8h, %[ab].h[5]\n"
+ "str q24, [%x[c_ptr], #0xb0]\n"
+ "smlal v18.4s, %[b2].4h, %[ab].h[5]\n"
+ "str q9, [%x[c_ptr], #0xc0]\n"
+ "smlal2 v26.4s, %[b2].8h, %[ab].h[5]\n"
+ "str q17, [%x[c_ptr], #0xd0]\n"
+ "smlal2 v11.4s, %[b1].8h, %[ab].h[6]\n"
+ "str q25, [%x[c_ptr], #0xe0]\n"
+ "smlal v19.4s, %[b2].4h, %[ab].h[6]\n"
+ "str q10, [%x[c_ptr], #0xf0]\n"
+ "smlal2 v27.4s, %[b2].8h, %[ab].h[6]\n"
+ "str q18, [%x[c_ptr], #0x100]\n"
+ "smlal2 v12.4s, %[b1].8h, %[ab].h[7]\n"
+ "str q26, [%x[c_ptr], #0x110]\n"
+ "smlal v20.4s, %[b2].4h, %[ab].h[7]\n"
+ "str q11, [%x[c_ptr], #0x120]\n"
+ "smlal2 v28.4s, %[b2].8h, %[ab].h[7]\n"
+ "str q19, [%x[c_ptr], #0x130]\n"
+ "b 4f\n" // Complete write out
- "smlal2 v5.4s, %[b1].8h, %[ab].h[0]\n"
- "smlal v13.4s, %[b2].4h, %[ab].h[0]\n"
- "smlal2 v21.4s, %[b2].8h, %[ab].h[0]\n"
- "smlal2 v6.4s, %[b1].8h, %[ab].h[1]\n"
- "smlal v14.4s, %[b2].4h, %[ab].h[1]\n"
- "str q5, [%x[c_ptr]]\n"
- "smlal2 v22.4s, %[b2].8h, %[ab].h[1]\n"
- "str q13, [%x[c_ptr], #0x10]\n"
- "smlal2 v7.4s, %[b1].8h, %[ab].h[2]\n"
- "str q21, [%x[c_ptr], #0x20]\n"
- "smlal v15.4s, %[b2].4h, %[ab].h[2]\n"
- "str q6, [%x[c_ptr], #0x30]\n"
- "smlal2 v23.4s, %[b2].8h, %[ab].h[2]\n"
- "str q14, [%x[c_ptr], #0x40]\n"
- "smlal2 v8.4s, %[b1].8h, %[ab].h[3]\n"
- "str q22, [%x[c_ptr], #0x50]\n"
- "smlal v16.4s, %[b2].4h, %[ab].h[3]\n"
- "str q7, [%x[c_ptr], #0x60]\n"
- "smlal2 v24.4s, %[b2].8h, %[ab].h[3]\n"
- "str q15, [%x[c_ptr], #0x70]\n"
- "smlal2 v9.4s, %[b1].8h, %[ab].h[4]\n"
- "str q23, [%x[c_ptr], #0x80]\n"
- "smlal v17.4s, %[b2].4h, %[ab].h[4]\n"
- "str q8, [%x[c_ptr], #0x90]\n"
- "smlal2 v25.4s, %[b2].8h, %[ab].h[4]\n"
- "str q16, [%x[c_ptr], #0xa0]\n"
- "smlal2 v10.4s, %[b1].8h, %[ab].h[5]\n"
- "str q24, [%x[c_ptr], #0xb0]\n"
- "smlal v18.4s, %[b2].4h, %[ab].h[5]\n"
- "str q9, [%x[c_ptr], #0xc0]\n"
- "smlal2 v26.4s, %[b2].8h, %[ab].h[5]\n"
- "str q17, [%x[c_ptr], #0xd0]\n"
- "smlal2 v11.4s, %[b1].8h, %[ab].h[6]\n"
- "str q25, [%x[c_ptr], #0xe0]\n"
- "smlal v19.4s, %[b2].4h, %[ab].h[6]\n"
- "str q10, [%x[c_ptr], #0xf0]\n"
- "smlal2 v27.4s, %[b2].8h, %[ab].h[6]\n"
- "str q18, [%x[c_ptr], #0x100]\n"
- "smlal2 v12.4s, %[b1].8h, %[ab].h[7]\n"
- "str q26, [%x[c_ptr], #0x110]\n"
- "smlal v20.4s, %[b2].4h, %[ab].h[7]\n"
- "str q11, [%x[c_ptr], #0x120]\n"
- "smlal2 v28.4s, %[b2].8h, %[ab].h[7]\n"
- "str q19, [%x[c_ptr], #0x130]\n"
- "b 4f\n" // Complete write out
+ "3:\n" // Odd tail
+ "smlal v5.4s, %[b0].4h, %[aa].h[0]\n"
+ "smlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
+ "smlal v21.4s, %[b1].4h, %[aa].h[0]\n"
+ "smlal v6.4s, %[b0].4h, %[aa].h[1]\n"
+ "smlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
+ "smlal v22.4s, %[b1].4h, %[aa].h[1]\n"
+ "str q5, [%x[c_ptr]]\n"
+ "smlal v7.4s, %[b0].4h, %[aa].h[2]\n"
+ "str q13, [%x[c_ptr], #0x10]\n"
+ "smlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
+ "str q21, [%x[c_ptr], #0x20]\n"
+ "smlal v23.4s, %[b1].4h, %[aa].h[2]\n"
+ "str q6, [%x[c_ptr], #0x30]\n"
+ "smlal v8.4s, %[b0].4h, %[aa].h[3]\n"
+ "str q14, [%x[c_ptr], #0x40]\n"
+ "smlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
+ "str q22, [%x[c_ptr], #0x50]\n"
+ "smlal v24.4s, %[b1].4h, %[aa].h[3]\n"
+ "str q7, [%x[c_ptr], #0x60]\n"
+ "smlal v9.4s, %[b0].4h, %[aa].h[4]\n"
+ "str q15, [%x[c_ptr], #0x70]\n"
+ "smlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
+ "str q23, [%x[c_ptr], #0x80]\n"
+ "smlal v25.4s, %[b1].4h, %[aa].h[4]\n"
+ "str q8, [%x[c_ptr], #0x90]\n"
+ "smlal v10.4s, %[b0].4h, %[aa].h[5]\n"
+ "str q16, [%x[c_ptr], #0xa0]\n"
+ "smlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
+ "str q24, [%x[c_ptr], #0xb0]\n"
+ "smlal v26.4s, %[b1].4h, %[aa].h[5]\n"
+ "str q9, [%x[c_ptr], #0xc0]\n"
+ "smlal v11.4s, %[b0].4h, %[aa].h[6]\n"
+ "str q17, [%x[c_ptr], #0xd0]\n"
+ "smlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
+ "str q25, [%x[c_ptr], #0xe0]\n"
+ "smlal v27.4s, %[b1].4h, %[aa].h[6]\n"
+ "str q10, [%x[c_ptr], #0xf0]\n"
+ "smlal v12.4s, %[b0].4h, %[aa].h[7]\n"
+ "str q18, [%x[c_ptr], #0x100]\n"
+ "smlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
+ "str q26, [%x[c_ptr], #0x110]\n"
+ "smlal v28.4s, %[b1].4h, %[aa].h[7]\n"
+ "str q11, [%x[c_ptr], #0x120]\n"
- "3:\n" // Odd tail
- "smlal v5.4s, %[b0].4h, %[aa].h[0]\n"
- "smlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
- "smlal v21.4s, %[b1].4h, %[aa].h[0]\n"
- "smlal v6.4s, %[b0].4h, %[aa].h[1]\n"
- "smlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
- "smlal v22.4s, %[b1].4h, %[aa].h[1]\n"
- "str q5, [%x[c_ptr]]\n"
- "smlal v7.4s, %[b0].4h, %[aa].h[2]\n"
- "str q13, [%x[c_ptr], #0x10]\n"
- "smlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
- "str q21, [%x[c_ptr], #0x20]\n"
- "smlal v23.4s, %[b1].4h, %[aa].h[2]\n"
- "str q6, [%x[c_ptr], #0x30]\n"
- "smlal v8.4s, %[b0].4h, %[aa].h[3]\n"
- "str q14, [%x[c_ptr], #0x40]\n"
- "smlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
- "str q22, [%x[c_ptr], #0x50]\n"
- "smlal v24.4s, %[b1].4h, %[aa].h[3]\n"
- "str q7, [%x[c_ptr], #0x60]\n"
- "smlal v9.4s, %[b0].4h, %[aa].h[4]\n"
- "str q15, [%x[c_ptr], #0x70]\n"
- "smlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
- "str q23, [%x[c_ptr], #0x80]\n"
- "smlal v25.4s, %[b1].4h, %[aa].h[4]\n"
- "str q8, [%x[c_ptr], #0x90]\n"
- "smlal v10.4s, %[b0].4h, %[aa].h[5]\n"
- "str q16, [%x[c_ptr], #0xa0]\n"
- "smlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
- "str q24, [%x[c_ptr], #0xb0]\n"
- "smlal v26.4s, %[b1].4h, %[aa].h[5]\n"
- "str q9, [%x[c_ptr], #0xc0]\n"
- "smlal v11.4s, %[b0].4h, %[aa].h[6]\n"
- "str q17, [%x[c_ptr], #0xd0]\n"
- "smlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
- "str q25, [%x[c_ptr], #0xe0]\n"
- "smlal v27.4s, %[b1].4h, %[aa].h[6]\n"
- "str q10, [%x[c_ptr], #0xf0]\n"
- "smlal v12.4s, %[b0].4h, %[aa].h[7]\n"
- "str q18, [%x[c_ptr], #0x100]\n"
- "smlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
- "str q26, [%x[c_ptr], #0x110]\n"
- "smlal v28.4s, %[b1].4h, %[aa].h[7]\n"
- "str q11, [%x[c_ptr], #0x120]\n"
-
- "4:\n" // End of function
- "str q19, [%x[c_ptr], #0x130]\n"
- "str q27, [%x[c_ptr], #0x140]\n"
- "str q12, [%x[c_ptr], #0x150]\n"
- "str q20, [%x[c_ptr], #0x160]\n"
- "str q28, [%x[c_ptr], #0x170]\n"
- "add %x[c_ptr], %x[c_ptr], #0x180\n"
- : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), [k] "+r"(k),
- [aa] "+w"(aa), [ab] "+w"(ab), [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2)
- : [odd_k] "r"(odd_k)
- : "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "x20", "cc");
- }
+ "4:\n" // End of function
+ "str q19, [%x[c_ptr], #0x130]\n"
+ "str q27, [%x[c_ptr], #0x140]\n"
+ "str q12, [%x[c_ptr], #0x150]\n"
+ "str q20, [%x[c_ptr], #0x160]\n"
+ "str q28, [%x[c_ptr], #0x170]\n"
+ "add %x[c_ptr], %x[c_ptr], #0x180\n"
+ : [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr), [k] "+r" (k),
+ [aa] "+w" (aa), [ab] "+w" (ab), [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2)
+ : [odd_k] "r" (odd_k)
+ : "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "x20", "cc"
+ );
}
+ }
}
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8.hpp
index 08f90e1..fdc0200 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8.hpp
@@ -27,41 +27,41 @@
#include "arm_gemm.hpp"
-namespace arm_gemm
-{
+#include "../std_transforms_fixed.hpp"
+
+namespace arm_gemm {
+
// Load the actual kernel
void a64_gemm_s8_12x8(const int8_t *, const int8_t *, int32_t *, int, int, int);
void a64_gemm_s8_12x8_a55r1(const int8_t *, const int8_t *, int32_t *, int, int, int);
-class gemm_s8_12x8
-{
+class gemm_s8_12x8 {
public:
- typedef int8_t operand_type;
+ typedef int8_t operand_type;
typedef int32_t result_type;
typedef void (*kern_type)(const int8_t *, const int8_t *, int32_t *, int, int, int);
- /* Describes the data layout for A input */
- static const int A_interleave = 8;
- static const int A_block = 4;
- static const bool A_transpose = false;
-
- /* Same for B input */
- static const int B_interleave = 12;
- static const int B_block = 4;
- static const bool B_transpose = true;
-
/* Kernel blocking parameters */
- static const int out_width = 12;
- static const int out_height = 8;
- static const int k_unroll = 4;
+ static int out_width() {
+ return 12;
+ }
+
+ static int out_height() {
+ return 8;
+ }
+
+ static int k_unroll() {
+ return 4;
+ }
+
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 8, 12, 4> transforms = {};
kern_type kernel = a64_gemm_s8_12x8;
- gemm_s8_12x8(const CPUInfo *ci)
- {
- if(ci->get_cpu_model() == CPUModel::A55r1)
- {
+ gemm_s8_12x8(const CPUInfo *ci) {
+ if (ci->get_cpu_model() == CPUModel::A55r1) {
kernel = a64_gemm_s8_12x8_a55r1;
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/a55r1.cpp
index ef2f291..eaa7979 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/a55r1.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/a55r1.cpp
@@ -31,40 +31,37 @@
#include "dot_toolchain_support.h"
#endif
-namespace arm_gemm
-{
-void a64_gemm_s8_12x8_a55r1(const int8_t *Apanel, const int8_t *Bpanel, int32_t *Cpanel, const int ablocks, const int bblocks, const int K)
-{
+namespace arm_gemm {
+
+void a64_gemm_s8_12x8_a55r1(const int8_t *Apanel, const int8_t *Bpanel, int32_t *Cpanel, const int ablocks, const int bblocks, const int K) {
const int8_t *a_ptr = Apanel;
- int32_t *c_ptr = Cpanel;
+ int32_t *c_ptr = Cpanel;
// We divide K by 4 because the sdot instruction processes 4 elements at a time.
- const int W = K / 4;
+ const int W = K/4;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
- const int oddk = (W & 1);
- const int k_iters = ((W + 1) / 2) - 1;
+ const int oddk = (W & 1);
+ const int k_iters = ((W+1)/2) - 1;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ for (int yb=0; yb<ablocks; yb++) {
const int8_t *a_ptr0 = a_ptr;
- const int8_t *b_ptr = Bpanel;
+ const int8_t *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
+ for (int xb=0; xb<bblocks; xb++) {
a_ptr = a_ptr0;
int k = k_iters;
- register int32x4_t a0 asm("v0");
- register int32x4_t a1 asm("v1");
- register int32x4_t b0 asm("v2");
- register int32x4_t b1 asm("v3");
- register int32x4_t b2 asm("v4");
+ register int32x4_t a0 asm("v0");
+ register int32x4_t a1 asm("v1");
+ register int32x4_t b0 asm("v2");
+ register int32x4_t b1 asm("v3");
+ register int32x4_t b2 asm("v4");
register int32x4_t a0a asm("v5");
register int32x4_t a1a asm("v6");
- __asm __volatile(
+ __asm __volatile (
#ifdef NO_DOT_IN_TOOLCHAIN
_DECLARE_SDOT
#else
@@ -79,22 +76,39 @@
"ldr %q[a1], [%[a_ptr], #16]\n"
"movi v11.4s, #0x0\n"
"ldr %q[b1], [%[b_ptr], #16]\n"
- "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]")
+ "movi v12.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v13.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v14.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v15.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
+ "movi v16.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v17.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
"movi v18.4s, #0x0\n"
- "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #192]")
+ "movi v19.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #192]")
"movi v20.4s, #0x0\n"
- "movi v21.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]")
+ "movi v21.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #320]")
"movi v22.4s, #0x0\n"
- "movi v23.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]")
+ "movi v23.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #256]")
"movi v24.4s, #0x0\n"
- "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #384]")
+ "movi v25.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #384]")
"movi v26.4s, #0x0\n"
- "movi v27.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #448]")
+ "movi v27.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #448]")
"movi v28.4s, #0x0\n"
- "movi v29.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #384]")
+ "movi v29.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #384]")
"movi v30.4s, #0x0\n"
- "movi v31.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #512]")
+ "movi v31.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #512]")
// The loop is offset by these two instructions which must
// always be executed.
@@ -105,102 +119,105 @@
"cbz %w[k], 4f\n"
"1:\n"
- "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n"
- "ldr x20, [%[b_ptr], #40]\n"
- "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
- "subs %w[k], %w[k], #1\n"
- "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
- "ldr %d[a0a], [%[a_ptr], #32]\n"
+ "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n"
+ "ldr x20, [%[b_ptr], #40]\n"
+ "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
+ "subs %w[k], %w[k], #1\n"
+ "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
+ "ldr %d[a0a], [%[a_ptr], #32]\n"
- "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
+ "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
"ins %[b2].d[1], x20\n"
- "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n"
+ "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n"
"ldr x20, [%[a_ptr], #40]\n"
- "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
- "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
- "ldr %d[a1a], [%[a_ptr], #48]\n"
+ "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
+ "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
+ "ldr %d[a1a], [%[a_ptr], #48]\n"
- "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
+ "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
"ins %[a0a].d[1], x20\n"
- "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
+ "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
"ldr x20, [%[a_ptr], #56]\n"
- "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
- "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
- "ldr %d[b0], [%[b_ptr], #48]\n"
+ "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
+ "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
+ "ldr %d[b0], [%[b_ptr], #48]\n"
- "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
+ "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
"ins %[a1a].d[1], x20\n"
- "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
+ "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
"ldr x20, [%[b_ptr], #56]\n"
- "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
- "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
- "ldr %d[b1], [%[b_ptr], #64]\n"
+ "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
+ "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
+ "ldr %d[b1], [%[b_ptr], #64]\n"
- "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
+ "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
"ins %[b0].d[1], x20\n"
- "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
+ "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
"ldr x20, [%[b_ptr], #72]\n"
- "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
- "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n" ASM_PREFETCH("[%[a_ptr], #448]")
+ "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
+ "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
+ ASM_PREFETCH("[%[a_ptr], #448]")
- "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
- "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n" ASM_PREFETCH("[%[b_ptr], #576]")
- "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
- "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
+ "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
+ "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #576]")
+ "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
+ "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
- // Unroll 1
- "ldr %d[b2], [%[b_ptr], #80]\n"
+ // Unroll 1
+ "ldr %d[b2], [%[b_ptr], #80]\n"
- "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
+ "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
"ins %[b1].d[1], x20\n"
- "sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
+ "sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
"ldr x20, [%[b_ptr], #88]\n"
- "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
- "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
- "ldr %d[a0], [%[a_ptr], #64]\n"
+ "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
+ "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
+ "ldr %d[a0], [%[a_ptr], #64]\n"
- "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
+ "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
"ins %[b2].d[1], x20\n"
"sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n"
"ldr x20, [%[a_ptr], #72]\n"
- "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
- "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
- "ldr %d[a1], [%[a_ptr], #80]\n"
+ "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
+ "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
+ "ldr %d[a1], [%[a_ptr], #80]\n"
- "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
+ "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
"ins %[a0].d[1], x20\n"
- "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
+ "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
"ldr x20, [%[a_ptr], #88]\n"
- "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
- "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
- "ldr %d[b0], [%[b_ptr], #96]\n"
+ "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
+ "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
+ "ldr %d[b0], [%[b_ptr], #96]\n"
- "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
+ "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
"ins %[a1].d[1], x20\n"
- "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
+ "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
"ldr x20, [%[b_ptr], #104]\n"
- "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
- "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
- "ldr %d[b1], [%[b_ptr], #112]\n"
+ "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
+ "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
+ "ldr %d[b1], [%[b_ptr], #112]\n"
- "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
+ "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
"ins %[b0].d[1], x20\n"
- "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
+ "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
"ldr x20, [%[b_ptr], #120]\n"
- "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
- "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
+ "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
+ "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
- "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n" ASM_PREFETCH("[%[b_ptr], #640]")
- "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
+ "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
+ ASM_PREFETCH("[%[b_ptr], #640]")
+ "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
"ins %[b1].d[1], x20\n"
- "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
+ "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
"ldr %d[b2], [%[b_ptr], #32]\n"
"sdot v8.4s , %[b0].16b, %[a0].4b[0]\n"
- "b.ne 1b\n"
+ "b.ne 1b\n"
// Branch here if K=1 or 2. Do the right thing for odd/even at the end.
"4:\n"
@@ -212,71 +229,83 @@
"cbnz %w[oddk], 2f\n"
// Even K continuation
- "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
- "ldr %d[a0a], [%[a_ptr], #32]\n"
+ "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
+ "ldr %d[a0a], [%[a_ptr], #32]\n"
- "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
+ "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
"ins %[b2].d[1], x20\n"
"sdot v13.4s, %[b0].16b, %[a1].4b[1]\n"
"ldr x20, [%[a_ptr], #40]\n"
- "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n" ASM_PREFETCHW("[%[c_ptr]]")
- "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
- "ldr %d[a1a], [%[a_ptr], #48]\n"
+ "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
+ ASM_PREFETCHW("[%[c_ptr]]")
+ "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
+ "ldr %d[a1a], [%[a_ptr], #48]\n"
- "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
+ "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
"ins %[a0a].d[1], x20\n"
- "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
+ "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
"ldr x20, [%[a_ptr], #56]\n"
- "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
- "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
- "ldr %d[b0], [%[b_ptr], #48]\n"
+ "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
+ "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
+ "ldr %d[b0], [%[b_ptr], #48]\n"
- "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
+ "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
"ins %[a1a].d[1], x20\n"
- "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
+ "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
"ldr x20, [%[b_ptr], #56]\n"
- "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n" ASM_PREFETCHW("[%[c_ptr], #64]")
- "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
+ "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
+ ASM_PREFETCHW("[%[c_ptr], #64]")
+ "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
- "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
- "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #128]")
- "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
- "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
- "ldr %d[b1], [%[b_ptr], #64]\n"
+ "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
+ "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #128]")
+ "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
+ "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
+ "ldr %d[b1], [%[b_ptr], #64]\n"
- "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
+ "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
"ins %[b0].d[1], x20\n"
- "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
+ "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
"ldr x20, [%[b_ptr], #72]\n"
- "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n" ASM_PREFETCHW("[%[c_ptr], #192]")
- "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
- "ldr %d[b2], [%[b_ptr], #80]\n"
+ "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
+ ASM_PREFETCHW("[%[c_ptr], #192]")
+ "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
+ "ldr %d[b2], [%[b_ptr], #80]\n"
- "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
+ "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
"ins %[b1].d[1], x20\n"
"sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
"ldr x20, [%[b_ptr], #88]\n"
- "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
+ "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
"ins %[b2].d[1], x20\n"
- "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]")
+ "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
+ ASM_PREFETCHW("[%[c_ptr], #256]")
"sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
"sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n"
- "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n" ASM_PREFETCHW("[%[c_ptr], #320]")
+ "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
+ ASM_PREFETCHW("[%[c_ptr], #320]")
"sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
- "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]")
+ "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #384]")
"sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
- "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]")
+ "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #448]")
"sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
"sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
- "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]")
+ "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #512]")
"sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
- "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]")
+ "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #576]")
"sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
"sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
- "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #640]")
+ "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #640]")
"sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
- "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]")
+ "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #704]")
"sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
"add %[a_ptr], %[a_ptr], #64\n"
"sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
@@ -286,27 +315,41 @@
// Odd K continuation
"2:\n"
- "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" ASM_PREFETCHW("[%[c_ptr]]")
+ "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
+ ASM_PREFETCHW("[%[c_ptr]]")
"sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
"ins %[b2].d[1], x20\n"
- "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #64]")
+ "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #64]")
"sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
"add %[a_ptr], %[a_ptr], #32\n"
- "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n" ASM_PREFETCHW("[%[c_ptr], #128]")
+ "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
+ ASM_PREFETCHW("[%[c_ptr], #128]")
"sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
"add %[b_ptr], %[b_ptr], #48\n"
- "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #192]")
+ "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #192]")
"sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
- "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]")
+ "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
+ ASM_PREFETCHW("[%[c_ptr], #256]")
"sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
- "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #320]")
+ "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #320]")
"sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
- "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]")
+ "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #384]")
"sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
- "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]")
+ "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #448]")
"sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
- "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]") "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]") "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
- ASM_PREFETCHWL2("[%[c_ptr], #640]") "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]")
+ "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #512]")
+ "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #576]")
+ "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #640]")
+ "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #704]")
"sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
// Common tail
@@ -340,13 +383,15 @@
#ifdef NO_DOT_IN_TOOLCHAIN
".purgem sdot\n"
#endif
- :
- [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr),
- [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a),
- [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k)
- : [oddk] "r"(oddk)
- : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
- "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory");
+ :
+ [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
+ [a0] "+w" (a0), [a1] "+w" (a1), [a0a] "+w" (a0a), [a1a] "+w" (a1a),
+ [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2), [k] "+r" (k)
+ : [oddk] "r" (oddk)
+ : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
+ "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"
+ );
+
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/dot_toolchain_support.h b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/dot_toolchain_support.h
index c76f99d..0bc688d 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/dot_toolchain_support.h
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/dot_toolchain_support.h
@@ -22,45 +22,47 @@
* SOFTWARE.
*/
+
+
// Define a macro to assemble the UDOT instruction (in the absence of toolchain support)
-#define _DECLARE_SDOT \
- ".altmacro\n" \
- ".macro sdot opd:req, opn:req, opm:req\n" \
- "local vd, vn, vm, h, l\n" \
- ".irp reg,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31\n" \
- ".ifeqs \"\\opd\",\"v\\reg\\.4s\"\n" \
- ".set vd,\\reg\n" \
- ".endif\n" \
- ".ifeqs \"\\opn\",\"v\\reg\\.16b\"\n" \
- ".set vn,\\reg\n" \
- ".endif\n" \
- ".irp idx,0,1,2,3\n" \
- ".ifeqs \"\\opm\",\"v\\reg\\.4b[\\idx\\]\"\n" \
- ".set vm,\\reg\n" \
- ".set h,\\idx / 2\n" \
- ".set l,\\idx %% 2\n" \
- ".endif\n" \
- ".endr\n" \
- ".endr\n" \
- ".ifndef vd\n" \
- ".error \"Bad operand \\opd\"\n" \
- ".exitm\n" \
- ".endif\n" \
- ".ifndef vn\n" \
- ".error \"Bad operand \\opn\"\n" \
- ".exitm\n" \
- ".endif\n" \
- ".ifndef vm\n" \
- ".error \"Bad operand \\opm\"\n" \
- ".exitm\n" \
- ".endif\n" \
- ".ifndef h\n" \
- ".error \"Bad operand \\opm\"\n" \
- ".exitm\n" \
- ".endif\n" \
- ".ifndef l\n" \
- ".error \"Bad operand \\opm\"\n" \
- ".exitm\n" \
- ".endif\n" \
- ".int 0x4f80e000 | vd | (vn << 5) | (vm << 16) | (l << 21) | (h << 11)\n" \
- ".endm\n"
+#define _DECLARE_SDOT ".altmacro\n"\
+ ".macro sdot opd:req, opn:req, opm:req\n"\
+ "local vd, vn, vm, h, l\n"\
+ ".irp reg,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31\n"\
+ ".ifeqs \"\\opd\",\"v\\reg\\.4s\"\n"\
+ ".set vd,\\reg\n"\
+ ".endif\n"\
+ ".ifeqs \"\\opn\",\"v\\reg\\.16b\"\n"\
+ ".set vn,\\reg\n"\
+ ".endif\n"\
+ ".irp idx,0,1,2,3\n"\
+ ".ifeqs \"\\opm\",\"v\\reg\\.4b[\\idx\\]\"\n"\
+ ".set vm,\\reg\n"\
+ ".set h,\\idx / 2\n"\
+ ".set l,\\idx %% 2\n"\
+ ".endif\n"\
+ ".endr\n"\
+ ".endr\n"\
+ ".ifndef vd\n"\
+ ".error \"Bad operand \\opd\"\n"\
+ ".exitm\n"\
+ ".endif\n"\
+ ".ifndef vn\n"\
+ ".error \"Bad operand \\opn\"\n"\
+ ".exitm\n"\
+ ".endif\n"\
+ ".ifndef vm\n"\
+ ".error \"Bad operand \\opm\"\n"\
+ ".exitm\n"\
+ ".endif\n"\
+ ".ifndef h\n"\
+ ".error \"Bad operand \\opm\"\n"\
+ ".exitm\n"\
+ ".endif\n"\
+ ".ifndef l\n"\
+ ".error \"Bad operand \\opm\"\n"\
+ ".exitm\n"\
+ ".endif\n"\
+ ".int 0x4f80e000 | vd | (vn << 5) | (vm << 16) | (l << 21) | (h << 11)\n"\
+ ".endm\n"\
+
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/generic.cpp
index 258ef5e..19225dd 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/generic.cpp
@@ -31,309 +31,328 @@
#include "dot_toolchain_support.h"
#endif
-namespace arm_gemm
-{
-void a64_gemm_s8_12x8(const int8_t *Apanel, const int8_t *Bpanel, int32_t *Cpanel, int ablocks, int bblocks, int K)
-{
+namespace arm_gemm {
+
+void a64_gemm_s8_12x8(const int8_t *Apanel, const int8_t *Bpanel, int32_t *Cpanel, int ablocks, int bblocks, int K) {
const int8_t *a_ptr = Apanel;
- int32_t *c_ptr = Cpanel;
+ int32_t *c_ptr = Cpanel;
// We divide K by 4 because the sdot instruction processes 4 elements at a time.
- const int W = K / 4;
+ const int W = K/4;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
- const int oddk = (W & 1);
- const int init_value_k = ((W + 1) / 2) - 1;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ const int oddk = (W & 1);
+ const int init_value_k = ((W+1)/2) - 1;
+ for (int yb=0; yb<ablocks; yb++) {
const int8_t *a_ptr0 = a_ptr;
- const int8_t *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
- a_ptr = a_ptr0;
- int k = init_value_k;
- register int32x4_t a0 asm("v0");
- register int32x4_t a1 asm("v1");
- register int32x4_t b0 asm("v2");
- register int32x4_t b1 asm("v3");
- register int32x4_t b2 asm("v4");
+ const int8_t *b_ptr = Bpanel;
+ for (int xb=0; xb<bblocks; xb++) {
+ a_ptr = a_ptr0;
+ int k = init_value_k;
+ register int32x4_t a0 asm("v0");
+ register int32x4_t a1 asm("v1");
+ register int32x4_t b0 asm("v2");
+ register int32x4_t b1 asm("v3");
+ register int32x4_t b2 asm("v4");
register int32x4_t a0a asm("v5");
register int32x4_t a1a asm("v6");
- __asm __volatile(
+ __asm __volatile (
#ifdef NO_DOT_IN_TOOLCHAIN
_DECLARE_SDOT
#else
".arch armv8.2-a+dotprod\n"
#endif
// Initialize result registers, load initial operands, prime prefetches.
- "movi v8.4s, #0x0\n"
- "ldr %q[a0], [%[a_ptr]]\n"
- "movi v9.4s, #0x0\n"
- "ldr %q[b0], [%[b_ptr]]\n"
- "movi v10.4s, #0x0\n"
- "ldr %q[a1], [%[a_ptr], #16]\n"
- "movi v11.4s, #0x0\n"
- "ldr %q[b1], [%[b_ptr], #16]\n"
- "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") "movi v18.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #192]") "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") "movi v20.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") "movi v21.4s, #0x0\n"
+ "movi v8.4s, #0x0\n"
+ "ldr %q[a0], [%[a_ptr]]\n"
+ "movi v9.4s, #0x0\n"
+ "ldr %q[b0], [%[b_ptr]]\n"
+ "movi v10.4s, #0x0\n"
+ "ldr %q[a1], [%[a_ptr], #16]\n"
+ "movi v11.4s, #0x0\n"
+ "ldr %q[b1], [%[b_ptr], #16]\n"
+ "movi v12.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v13.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v14.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v15.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
+ "movi v16.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v17.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
+ "movi v18.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #192]")
+ "movi v19.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #320]")
+ "movi v20.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #256]")
+ "movi v21.4s, #0x0\n"
ASM_PREFETCH("[%[b_ptr], #384]")
- "movi v22.4s, #0x0\n"
- "movi v23.4s, #0x0\n"
- "movi v24.4s, #0x0\n"
- "movi v25.4s, #0x0\n"
- "movi v26.4s, #0x0\n"
- "movi v27.4s, #0x0\n"
- "movi v28.4s, #0x0\n"
- "movi v29.4s, #0x0\n"
- "movi v30.4s, #0x0\n"
- "movi v31.4s, #0x0\n"
+ "movi v22.4s, #0x0\n"
+ "movi v23.4s, #0x0\n"
+ "movi v24.4s, #0x0\n"
+ "movi v25.4s, #0x0\n"
+ "movi v26.4s, #0x0\n"
+ "movi v27.4s, #0x0\n"
+ "movi v28.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ "movi v30.4s, #0x0\n"
+ "movi v31.4s, #0x0\n"
// Skip loop if we are doing zero iterations of it.
- "cbz %w[k], 4f\n"
+ "cbz %w[k], 4f\n"
// Loop proper
"1:\n"
- "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n"
- "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n"
+ "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n"
+ "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
- "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
- "ldr %q[a0a], [%[a_ptr], #32]\n"
- "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
- "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n"
- "ldr %q[a1a], [%[a_ptr], #48]\n"
- "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
- "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
- "ldr %q[b0], [%[b_ptr], #48]\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
+ "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
+ "ldr %q[a0a], [%[a_ptr], #32]\n"
+ "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
+ "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n"
+ "ldr %q[a1a], [%[a_ptr], #48]\n"
+ "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
+ "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
+ "ldr %q[b0], [%[b_ptr], #48]\n"
- "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
- "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n" ASM_PREFETCH("[%[a_ptr], #320]")
- "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
- "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
- "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
- "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
- "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
- "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
- "ldr %q[b1], [%[b_ptr], #64]\n"
+ "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
+ "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
+ ASM_PREFETCH("[%[a_ptr], #320]")
+ "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
+ "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
+ "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
+ "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
+ "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
+ "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
+ "ldr %q[b1], [%[b_ptr], #64]\n"
- "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
- "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" ASM_PREFETCH("[%[b_ptr], #448]")
- "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
- "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
- "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
- "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
- "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
- "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
- "ldr %q[b2], [%[b_ptr], #80]\n"
+ "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
+ "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #448]")
+ "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
+ "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
+ "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
+ "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
+ "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
+ "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
+ "ldr %q[b2], [%[b_ptr], #80]\n"
- "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
- "sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
- "ldr %q[a0], [%[a_ptr], #64]\n"
- "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
- "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
- "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
- "ldr %q[a1], [%[a_ptr], #80]\n"
+ "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
+ "sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
+ "ldr %q[a0], [%[a_ptr], #64]\n"
+ "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
+ "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
+ "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
+ "ldr %q[a1], [%[a_ptr], #80]\n"
"sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n"
- "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
- "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
- "ldr %q[b0], [%[b_ptr], #96]\n"
+ "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
+ "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
+ "ldr %q[b0], [%[b_ptr], #96]\n"
- "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
- "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n" ASM_PREFETCH("[%[b_ptr], #512]")
- "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
- "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
- "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
- "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
- "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
- "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
- "ldr %q[b1], [%[b_ptr], #112]\n"
+ "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
+ "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #512]")
+ "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
+ "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
+ "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
+ "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
+ "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
+ "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
+ "ldr %q[b1], [%[b_ptr], #112]\n"
- "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
- "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
- "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
- "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
- "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
- "subs %w[k], %w[k], #1\n"
- "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
- "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
- "bne 1b\n"
+ "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
+ "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
+ "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
+ "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
+ "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
+ "subs %w[k], %w[k], #1\n"
+ "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
+ "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
+ "bne 1b\n"
// Target to use when K is 1 or 2 (i.e. zero iterations of main loop)
"4:\n"
// Branch to alternative tail for odd K
- "cbnz %w[oddk], 2f\n"
+ "cbnz %w[oddk], 2f\n"
// Detached final iteration (even K)
- "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n"
+ "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n"
"sdot v9.4s , %[b0].16b, %[a0].4b[1]\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
- "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
- "ldr %q[a0a], [%[a_ptr], #32]\n"
- "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
+ "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
+ "ldr %q[a0a], [%[a_ptr], #32]\n"
+ "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
"sdot v13.4s, %[b0].16b, %[a1].4b[1]\n"
- "ldr %q[a1a], [%[a_ptr], #48]\n"
- "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
- "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
- "ldr %q[b0], [%[b_ptr], #48]\n"
+ "ldr %q[a1a], [%[a_ptr], #48]\n"
+ "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
+ "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
+ "ldr %q[b0], [%[b_ptr], #48]\n"
- "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
- "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
- "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
- "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
- "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
- "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
- "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
- "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
- "ldr %q[b1], [%[b_ptr], #64]\n"
+ "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
+ "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
+ "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
+ "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
+ "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
+ "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
+ "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
+ "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
+ "ldr %q[b1], [%[b_ptr], #64]\n"
- "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
- "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
- "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
- "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
- "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
- "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
- "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
- "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
- "ldr %q[b2], [%[b_ptr], #80]\n"
+ "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
+ "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
+ "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
+ "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
+ "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
+ "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
+ "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
+ "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
+ "ldr %q[b2], [%[b_ptr], #80]\n"
- "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
+ "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
- "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
+ "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
"sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
- "str q8, [%[c_ptr], #0]\n"
- "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
- "str q16, [%[c_ptr], #16]\n"
- "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
- "str q24, [%[c_ptr], #32]\n"
+ "str q8, [%[c_ptr], #0]\n"
+ "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
+ "str q16, [%[c_ptr], #16]\n"
+ "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
+ "str q24, [%[c_ptr], #32]\n"
- "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
- "str q9, [%[c_ptr], #48]\n"
- "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
- "str q17, [%[c_ptr], #64]\n"
- "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
- "str q25, [%[c_ptr], #80]\n"
- "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
- "str q10, [%[c_ptr], #96]\n"
+ "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
+ "str q9, [%[c_ptr], #48]\n"
+ "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
+ "str q17, [%[c_ptr], #64]\n"
+ "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
+ "str q25, [%[c_ptr], #80]\n"
+ "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
+ "str q10, [%[c_ptr], #96]\n"
- "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
- "str q18, [%[c_ptr], #112]\n"
- "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
- "str q26, [%[c_ptr], #128]\n"
- "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
- "str q11, [%[c_ptr], #144]\n"
+ "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
+ "str q18, [%[c_ptr], #112]\n"
+ "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
+ "str q26, [%[c_ptr], #128]\n"
+ "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
+ "str q11, [%[c_ptr], #144]\n"
- "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
- "str q19, [%[c_ptr], #160]\n"
- "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
- "str q27, [%[c_ptr], #176]\n"
- "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
- "str q12, [%[c_ptr], #192]\n"
+ "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
+ "str q19, [%[c_ptr], #160]\n"
+ "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
+ "str q27, [%[c_ptr], #176]\n"
+ "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
+ "str q12, [%[c_ptr], #192]\n"
"sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n"
- "str q20, [%[c_ptr], #208]\n"
- "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
- "str q28, [%[c_ptr], #224]\n"
- "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
- "str q13, [%[c_ptr], #240]\n"
+ "str q20, [%[c_ptr], #208]\n"
+ "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
+ "str q28, [%[c_ptr], #224]\n"
+ "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
+ "str q13, [%[c_ptr], #240]\n"
- "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
- "str q21, [%[c_ptr], #256]\n"
- "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
- "str q29, [%[c_ptr], #272]\n"
- "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
- "str q14, [%[c_ptr], #288]\n"
+ "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
+ "str q21, [%[c_ptr], #256]\n"
+ "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
+ "str q29, [%[c_ptr], #272]\n"
+ "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
+ "str q14, [%[c_ptr], #288]\n"
- "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
- "str q22, [%[c_ptr], #304]\n"
- "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
- "str q30, [%[c_ptr], #320]\n"
- "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
- "str q15, [%[c_ptr], #336]\n"
+ "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
+ "str q22, [%[c_ptr], #304]\n"
+ "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
+ "str q30, [%[c_ptr], #320]\n"
+ "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
+ "str q15, [%[c_ptr], #336]\n"
- "b 3f\n"
+ "b 3f\n"
// Detached final iteration (odd K)
"2:\n"
- "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
+ "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
"sdot v9.4s , %[b0].16b, %[a0].4b[1]\n"
- "str q8, [%[c_ptr], #0]\n"
- "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
- "str q16, [%[c_ptr], #16]\n"
- "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
- "add %[b_ptr], %[b_ptr], #48\n"
- "add %[a_ptr], %[a_ptr], #32\n"
- "str q24, [%[c_ptr], #32]\n"
- "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
- "str q9, [%[c_ptr], #48]\n"
+ "str q8, [%[c_ptr], #0]\n"
+ "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
+ "str q16, [%[c_ptr], #16]\n"
+ "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
+ "add %[b_ptr], %[b_ptr], #48\n"
+ "add %[a_ptr], %[a_ptr], #32\n"
+ "str q24, [%[c_ptr], #32]\n"
+ "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
+ "str q9, [%[c_ptr], #48]\n"
- "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
- "str q17, [%[c_ptr], #64]\n"
- "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
- "str q25, [%[c_ptr], #80]\n"
- "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
- "str q10, [%[c_ptr], #96]\n"
+ "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
+ "str q17, [%[c_ptr], #64]\n"
+ "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
+ "str q25, [%[c_ptr], #80]\n"
+ "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
+ "str q10, [%[c_ptr], #96]\n"
- "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
- "str q18, [%[c_ptr], #112]\n"
- "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
- "str q26, [%[c_ptr], #128]\n"
- "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
- "str q11, [%[c_ptr], #144]\n"
+ "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
+ "str q18, [%[c_ptr], #112]\n"
+ "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
+ "str q26, [%[c_ptr], #128]\n"
+ "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
+ "str q11, [%[c_ptr], #144]\n"
- "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
- "str q19, [%[c_ptr], #160]\n"
- "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
- "str q27, [%[c_ptr], #176]\n"
- "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
- "str q12, [%[c_ptr], #192]\n"
+ "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
+ "str q19, [%[c_ptr], #160]\n"
+ "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
+ "str q27, [%[c_ptr], #176]\n"
+ "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
+ "str q12, [%[c_ptr], #192]\n"
"sdot v13.4s, %[b0].16b, %[a1].4b[1]\n"
- "str q20, [%[c_ptr], #208]\n"
- "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
- "str q28, [%[c_ptr], #224]\n"
- "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
- "str q13, [%[c_ptr], #240]\n"
+ "str q20, [%[c_ptr], #208]\n"
+ "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
+ "str q28, [%[c_ptr], #224]\n"
+ "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
+ "str q13, [%[c_ptr], #240]\n"
- "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
- "str q21, [%[c_ptr], #256]\n"
- "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
- "str q29, [%[c_ptr], #272]\n"
- "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
- "str q14, [%[c_ptr], #288]\n"
+ "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
+ "str q21, [%[c_ptr], #256]\n"
+ "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
+ "str q29, [%[c_ptr], #272]\n"
+ "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
+ "str q14, [%[c_ptr], #288]\n"
- "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
- "str q22, [%[c_ptr], #304]\n"
- "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
- "str q30, [%[c_ptr], #320]\n"
- "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
- "str q15, [%[c_ptr], #336]\n"
+ "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
+ "str q22, [%[c_ptr], #304]\n"
+ "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
+ "str q30, [%[c_ptr], #320]\n"
+ "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
+ "str q15, [%[c_ptr], #336]\n"
+
// Common tail
"3:\n"
- "str q23, [%[c_ptr], #352]\n"
- "str q31, [%[c_ptr], #368]\n"
- "add %[c_ptr], %[c_ptr], #384\n"
+ "str q23, [%[c_ptr], #352]\n"
+ "str q31, [%[c_ptr], #368]\n"
+ "add %[c_ptr], %[c_ptr], #384\n"
#ifdef NO_DOT_IN_TOOLCHAIN
".purgem sdot\n"
#endif
- :
- [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr),
- [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a),
- [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k)
- : [oddk] "r"(oddk)
- : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
- "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc");
+ :
+ [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
+ [a0] "+w" (a0), [a1] "+w" (a1), [a0a] "+w" (a0a), [a1a] "+w" (a1a),
+ [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2), [k] "+r" (k)
+ : [oddk] "r" (oddk)
+ : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
+ "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"
+ );
+
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4.hpp
index 2ec28f4..be7ead9 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4.hpp
@@ -25,43 +25,44 @@
#ifdef __aarch64__
-namespace arm_gemm
-{
+#include "../std_transforms_fixed.hpp" // for StdTransformsFixed (type of the 'transforms' member)
+
+namespace arm_gemm {
+
// Load the actual kernel
void a64_gemm_s8_4x4(const int8_t *, const int8_t *, int32_t *, int, int, int);
#include "arm_gemm.hpp"
-class gemm_s8_4x4
-{
+class gemm_s8_4x4 { // AArch64 int8 -> int32 GEMM kernel: operand/result types, blocking parameters, transforms and kernel pointer
public:
- typedef int8_t operand_type;
+ typedef int8_t operand_type; // element type of both input panels (matches the kernel's const int8_t * arguments)
typedef int32_t result_type;
typedef void (*kern_type)(const int8_t *, const int8_t *, int32_t *, int, int, int);
- /* Describes the data layout for A input */
- static const int A_interleave = 4;
- static const int A_block = 16;
- static const bool A_transpose = false;
-
- /* Same for B input */
- static const int B_interleave = 4;
- static const int B_block = 16;
- static const bool B_transpose = true;
-
/* Kernel blocking parameters */
- static const int out_width = 4;
- static const int out_height = 4;
- static const int k_unroll = 16;
-
- kern_type kernel = a64_gemm_s8_4x4;
-
- gemm_s8_4x4(const CPUInfo *ci)
- {
+ static int out_width() {
+ return 4;
}
+
+ static int out_height() {
+ return 4;
+ }
+
+ static int k_unroll() { // K is consumed 16 at a time; the kernel itself starts with K /= 16
+ return 16;
+ }
+
+ // Use the standard fixed size transforms. NOTE(review): the 4, 4, 16 arguments appear to mirror out_width(), out_height(), k_unroll() - keep them in sync.
+ StdTransformsFixed<operand_type, result_type, 4, 4, 16> transforms = {};
+
+ kern_type kernel=a64_gemm_s8_4x4; // the assembly kernel declared above this class
+
+ gemm_s8_4x4(const CPUInfo *ci) { } // 'ci' is not used by this kernel
};
} // namespace arm_gemm
#endif // __aarch64__
+
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4/generic.cpp
index 243b94e..2fc54f8 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4/generic.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,56 +27,66 @@
#include "../../asmlib.hpp"
-namespace arm_gemm
-{
-void a64_gemm_s8_4x4(const int8_t *Apanel, const int8_t *Bpanel, int32_t *Cpanel, int ablocks, int bblocks, int K)
-{
+namespace arm_gemm {
+
+void a64_gemm_s8_4x4(const int8_t *Apanel, const int8_t *Bpanel, int32_t *Cpanel, int ablocks, int bblocks, int K) {
const int8_t *a_ptr = Apanel;
- int32_t *c_ptr = Cpanel;
+ int32_t *c_ptr = Cpanel;
K /= 16;
int oddk = (K & 1);
- for(int yb = 0; yb < ablocks; yb++)
- {
+ for (int yb=0; yb<ablocks; yb++) {
const int8_t *a_ptr0 = a_ptr;
- const int8_t *b_ptr = Bpanel;
+ const int8_t *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
+ for (int xb=0; xb<bblocks; xb++) {
a_ptr = a_ptr0;
- int k = ((K + 1) / 2) - 1;
+ int k = ((K+1)/2)-1;
- register int8x16_t b0 asm("v4");
- register int8x16_t b1 asm("v5");
- register int8x16_t b2 asm("v6");
- register int8x16_t b3 asm("v7");
+ register int8x16_t b0 asm("v4");
+ register int8x16_t b1 asm("v5");
+ register int8x16_t b2 asm("v6");
+ register int8x16_t b3 asm("v7");
register int8x16_t b0a asm("v8");
register int8x16_t b1a asm("v9");
register int8x16_t b2a asm("v10");
register int8x16_t b3a asm("v11");
- __asm __volatile(
- "movi v16.4s, #0x0\n"
- "ldr q0, [%[a_ptr]]\n"
- "movi v17.4s, #0x0\n"
- "ldr %q[b0], [%[b_ptr]]\n"
- "movi v18.4s, #0x0\n"
- "ldr %q[b1], [%[b_ptr], #16]\n"
- "movi v19.4s, #0x0\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "movi v20.4s, #0x0\n"
- "ldr %q[b3], [%[b_ptr], #48]\n"
- "movi v21.4s, #0x0\n"
- "ldr q1, [%[a_ptr], #16]\n"
- "movi v22.4s, #0x0\n"
- "ldr q2, [%[a_ptr], #32]\n"
- "movi v23.4s, #0x0\n"
- "ldr q3, [%[a_ptr], #48]\n"
- "movi v24.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v26.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v27.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #128]") "movi v28.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v29.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #192]") "movi v30.4s, #0x0\n"
- ASM_PREFETCH("[%[b_ptr], #256]") "movi v31.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]")
+ __asm __volatile (
+ "movi v16.4s, #0x0\n"
+ "ldr q0, [%[a_ptr]]\n"
+ "movi v17.4s, #0x0\n"
+ "ldr %q[b0], [%[b_ptr]]\n"
+ "movi v18.4s, #0x0\n"
+ "ldr %q[b1], [%[b_ptr], #16]\n"
+ "movi v19.4s, #0x0\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "movi v20.4s, #0x0\n"
+ "ldr %q[b3], [%[b_ptr], #48]\n"
+ "movi v21.4s, #0x0\n"
+ "ldr q1, [%[a_ptr], #16]\n"
+ "movi v22.4s, #0x0\n"
+ "ldr q2, [%[a_ptr], #32]\n"
+ "movi v23.4s, #0x0\n"
+ "ldr q3, [%[a_ptr], #48]\n"
+ "movi v24.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v25.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v26.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v27.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
+ "movi v28.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v29.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #192]")
+ "movi v30.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
+ "movi v31.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #256]")
// Loop structure optimized for A57 (after r0).
@@ -97,356 +107,360 @@
// of multiplies that need to be pulled out.
// Start of unroll 0 (first iteration)
- "smull v12.8h, v0.8b, %[b0].8b\n"
- "smull v13.8h, v0.8b, %[b1].8b\n"
+ "smull v12.8h, v0.8b, %[b0].8b\n"
+ "smull v13.8h, v0.8b, %[b1].8b\n"
// Skip loop if we are doing zero iterations of it.
- "cbz %w[k], 4f\n"
+ "cbz %w[k], 4f\n"
// Unroll 0 continuation (branch target)
"1:\n"
- "smull v14.8h, v0.8b, %[b2].8b\n"
- "subs %w[k], %w[k], #1\n"
- "smull v15.8h, v0.8b, %[b3].8b\n"
- "ldr %q[b0a], [%[b_ptr], #64]\n"
- "smlal2 v12.8h, v0.16b, %[b0].16b\n"
- "smlal2 v13.8h, v0.16b, %[b1].16b\n"
- "ldr %q[b1a], [%[b_ptr], #80]\n"
- "smlal2 v14.8h, v0.16b, %[b2].16b\n"
- "smlal2 v15.8h, v0.16b, %[b3].16b\n"
- "ldr q0, [%[a_ptr], #64]\n"
+ "smull v14.8h, v0.8b, %[b2].8b\n"
+ "subs %w[k], %w[k], #1\n"
+ "smull v15.8h, v0.8b, %[b3].8b\n"
+ "ldr %q[b0a], [%[b_ptr], #64]\n"
+ "smlal2 v12.8h, v0.16b, %[b0].16b\n"
+ "smlal2 v13.8h, v0.16b, %[b1].16b\n"
+ "ldr %q[b1a], [%[b_ptr], #80]\n"
+ "smlal2 v14.8h, v0.16b, %[b2].16b\n"
+ "smlal2 v15.8h, v0.16b, %[b3].16b\n"
+ "ldr q0, [%[a_ptr], #64]\n"
- "sadalp v16.4s, v12.8h\n"
- "smull v12.8h, v1.8b, %[b0].8b\n"
- "sadalp v17.4s, v13.8h\n"
- "sadalp v18.4s, v14.8h\n"
- "smull v13.8h, v1.8b, %[b1].8b\n"
- "sadalp v19.4s, v15.8h\n"
- "smull v14.8h, v1.8b, %[b2].8b\n"
- "ldr %q[b2a], [%[b_ptr], #96]\n"
- "smull v15.8h, v1.8b, %[b3].8b\n"
- "smlal2 v12.8h, v1.16b, %[b0].16b\n"
- "ldr %q[b3a], [%[b_ptr], #112]\n"
- "smlal2 v13.8h, v1.16b, %[b1].16b\n"
- "add %[b_ptr], %[b_ptr], #128\n"
- "smlal2 v14.8h, v1.16b, %[b2].16b\n"
- "smlal2 v15.8h, v1.16b, %[b3].16b\n"
- "ldr q1, [%[a_ptr], #80]\n"
+ "sadalp v16.4s, v12.8h\n"
+ "smull v12.8h, v1.8b, %[b0].8b\n"
+ "sadalp v17.4s, v13.8h\n"
+ "sadalp v18.4s, v14.8h\n"
+ "smull v13.8h, v1.8b, %[b1].8b\n"
+ "sadalp v19.4s, v15.8h\n"
+ "smull v14.8h, v1.8b, %[b2].8b\n"
+ "ldr %q[b2a], [%[b_ptr], #96]\n"
+ "smull v15.8h, v1.8b, %[b3].8b\n"
+ "smlal2 v12.8h, v1.16b, %[b0].16b\n"
+ "ldr %q[b3a], [%[b_ptr], #112]\n"
+ "smlal2 v13.8h, v1.16b, %[b1].16b\n"
+ "add %[b_ptr], %[b_ptr], #128\n"
+ "smlal2 v14.8h, v1.16b, %[b2].16b\n"
+ "smlal2 v15.8h, v1.16b, %[b3].16b\n"
+ "ldr q1, [%[a_ptr], #80]\n"
- "sadalp v20.4s, v12.8h\n"
- "smull v12.8h, v2.8b, %[b0].8b\n"
- "sadalp v21.4s, v13.8h\n"
- "sadalp v22.4s, v14.8h\n"
- "smull v13.8h, v2.8b, %[b1].8b\n"
- "sadalp v23.4s, v15.8h\n"
- "smull v14.8h, v2.8b, %[b2].8b\n"
- "smull v15.8h, v2.8b, %[b3].8b\n"
- "smlal2 v12.8h, v2.16b, %[b0].16b\n" ASM_PREFETCH("[%[b_ptr], #192]")
- "smlal2 v13.8h, v2.16b, %[b1].16b\n"
- "smlal2 v14.8h, v2.16b, %[b2].16b\n" ASM_PREFETCH("[%[a_ptr], #320]")
- "smlal2 v15.8h, v2.16b, %[b3].16b\n"
- "ldr q2, [%[a_ptr], #96]\n"
+ "sadalp v20.4s, v12.8h\n"
+ "smull v12.8h, v2.8b, %[b0].8b\n"
+ "sadalp v21.4s, v13.8h\n"
+ "sadalp v22.4s, v14.8h\n"
+ "smull v13.8h, v2.8b, %[b1].8b\n"
+ "sadalp v23.4s, v15.8h\n"
+ "smull v14.8h, v2.8b, %[b2].8b\n"
+ "smull v15.8h, v2.8b, %[b3].8b\n"
+ "smlal2 v12.8h, v2.16b, %[b0].16b\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "smlal2 v13.8h, v2.16b, %[b1].16b\n"
+ "smlal2 v14.8h, v2.16b, %[b2].16b\n"
+ ASM_PREFETCH("[%[a_ptr], #320]")
+ "smlal2 v15.8h, v2.16b, %[b3].16b\n"
+ "ldr q2, [%[a_ptr], #96]\n"
- "sadalp v24.4s, v12.8h\n"
- "smull v12.8h, v3.8b, %[b0].8b\n"
- "sadalp v25.4s, v13.8h\n"
- "sadalp v26.4s, v14.8h\n"
- "smull v13.8h, v3.8b, %[b1].8b\n"
- "sadalp v27.4s, v15.8h\n"
- "smull v14.8h, v3.8b, %[b2].8b\n"
- "smull v15.8h, v3.8b, %[b3].8b\n"
- "smlal2 v12.8h, v3.16b, %[b0].16b\n"
- "ldr %q[b0], [%[b_ptr], #0]\n"
- "smlal2 v13.8h, v3.16b, %[b1].16b\n"
- "smlal2 v14.8h, v3.16b, %[b2].16b\n"
- "smlal2 v15.8h, v3.16b, %[b3].16b\n"
- "ldr q3, [%[a_ptr], #112]\n"
+ "sadalp v24.4s, v12.8h\n"
+ "smull v12.8h, v3.8b, %[b0].8b\n"
+ "sadalp v25.4s, v13.8h\n"
+ "sadalp v26.4s, v14.8h\n"
+ "smull v13.8h, v3.8b, %[b1].8b\n"
+ "sadalp v27.4s, v15.8h\n"
+ "smull v14.8h, v3.8b, %[b2].8b\n"
+ "smull v15.8h, v3.8b, %[b3].8b\n"
+ "smlal2 v12.8h, v3.16b, %[b0].16b\n"
+ "ldr %q[b0], [%[b_ptr], #0]\n"
+ "smlal2 v13.8h, v3.16b, %[b1].16b\n"
+ "smlal2 v14.8h, v3.16b, %[b2].16b\n"
+ "smlal2 v15.8h, v3.16b, %[b3].16b\n"
+ "ldr q3, [%[a_ptr], #112]\n"
// Unroll 1
- "sadalp v28.4s, v12.8h\n"
- "smull v12.8h, v0.8b, %[b0a].8b\n"
- "sadalp v29.4s, v13.8h\n"
- "sadalp v30.4s, v14.8h\n"
- "smull v13.8h, v0.8b, %[b1a].8b\n"
- "sadalp v31.4s, v15.8h\n"
- "smull v14.8h, v0.8b, %[b2a].8b\n"
- "smull v15.8h, v0.8b, %[b3a].8b\n"
- "ldr %q[b1], [%[b_ptr], #16]\n"
- "smlal2 v12.8h, v0.16b, %[b0a].16b\n"
- "smlal2 v13.8h, v0.16b, %[b1a].16b\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "smlal2 v14.8h, v0.16b, %[b2a].16b\n"
- "smlal2 v15.8h, v0.16b, %[b3a].16b\n"
- "ldr q0, [%[a_ptr], #128]\n"
+ "sadalp v28.4s, v12.8h\n"
+ "smull v12.8h, v0.8b, %[b0a].8b\n"
+ "sadalp v29.4s, v13.8h\n"
+ "sadalp v30.4s, v14.8h\n"
+ "smull v13.8h, v0.8b, %[b1a].8b\n"
+ "sadalp v31.4s, v15.8h\n"
+ "smull v14.8h, v0.8b, %[b2a].8b\n"
+ "smull v15.8h, v0.8b, %[b3a].8b\n"
+ "ldr %q[b1], [%[b_ptr], #16]\n"
+ "smlal2 v12.8h, v0.16b, %[b0a].16b\n"
+ "smlal2 v13.8h, v0.16b, %[b1a].16b\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "smlal2 v14.8h, v0.16b, %[b2a].16b\n"
+ "smlal2 v15.8h, v0.16b, %[b3a].16b\n"
+ "ldr q0, [%[a_ptr], #128]\n"
- "sadalp v16.4s, v12.8h\n"
- "smull v12.8h, v1.8b, %[b0a].8b\n"
- "sadalp v17.4s, v13.8h\n"
- "sadalp v18.4s, v14.8h\n"
- "smull v13.8h, v1.8b, %[b1a].8b\n"
- "sadalp v19.4s, v15.8h\n"
- "add %[a_ptr], %[a_ptr], #128\n"
- "smull v14.8h, v1.8b, %[b2a].8b\n"
- "smull v15.8h, v1.8b, %[b3a].8b\n"
- "ldr %q[b3], [%[b_ptr], #48]\n"
- "smlal2 v12.8h, v1.16b, %[b0a].16b\n"
- "smlal2 v13.8h, v1.16b, %[b1a].16b\n"
- "smlal2 v14.8h, v1.16b, %[b2a].16b\n"
- "smlal2 v15.8h, v1.16b, %[b3a].16b\n"
- "ldr q1, [%[a_ptr], #16]\n"
+ "sadalp v16.4s, v12.8h\n"
+ "smull v12.8h, v1.8b, %[b0a].8b\n"
+ "sadalp v17.4s, v13.8h\n"
+ "sadalp v18.4s, v14.8h\n"
+ "smull v13.8h, v1.8b, %[b1a].8b\n"
+ "sadalp v19.4s, v15.8h\n"
+ "add %[a_ptr], %[a_ptr], #128\n"
+ "smull v14.8h, v1.8b, %[b2a].8b\n"
+ "smull v15.8h, v1.8b, %[b3a].8b\n"
+ "ldr %q[b3], [%[b_ptr], #48]\n"
+ "smlal2 v12.8h, v1.16b, %[b0a].16b\n"
+ "smlal2 v13.8h, v1.16b, %[b1a].16b\n"
+ "smlal2 v14.8h, v1.16b, %[b2a].16b\n"
+ "smlal2 v15.8h, v1.16b, %[b3a].16b\n"
+ "ldr q1, [%[a_ptr], #16]\n"
- "sadalp v20.4s, v12.8h\n"
- "smull v12.8h, v2.8b, %[b0a].8b\n"
- "sadalp v21.4s, v13.8h\n"
- "sadalp v22.4s, v14.8h\n"
- "smull v13.8h, v2.8b, %[b1a].8b\n"
- "sadalp v23.4s, v15.8h\n"
- "smull v14.8h, v2.8b, %[b2a].8b\n"
- "smull v15.8h, v2.8b, %[b3a].8b\n"
- "smlal2 v12.8h, v2.16b, %[b0a].16b\n" ASM_PREFETCH("[%[b_ptr], #256]")
- "smlal2 v13.8h, v2.16b, %[b1a].16b\n"
- "smlal2 v14.8h, v2.16b, %[b2a].16b\n" ASM_PREFETCH("[%[a_ptr], #256]")
- "smlal2 v15.8h, v2.16b, %[b3a].16b\n"
- "ldr q2, [%[a_ptr], #32]\n"
+ "sadalp v20.4s, v12.8h\n"
+ "smull v12.8h, v2.8b, %[b0a].8b\n"
+ "sadalp v21.4s, v13.8h\n"
+ "sadalp v22.4s, v14.8h\n"
+ "smull v13.8h, v2.8b, %[b1a].8b\n"
+ "sadalp v23.4s, v15.8h\n"
+ "smull v14.8h, v2.8b, %[b2a].8b\n"
+ "smull v15.8h, v2.8b, %[b3a].8b\n"
+ "smlal2 v12.8h, v2.16b, %[b0a].16b\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
+ "smlal2 v13.8h, v2.16b, %[b1a].16b\n"
+ "smlal2 v14.8h, v2.16b, %[b2a].16b\n"
+ ASM_PREFETCH("[%[a_ptr], #256]")
+ "smlal2 v15.8h, v2.16b, %[b3a].16b\n"
+ "ldr q2, [%[a_ptr], #32]\n"
- "sadalp v24.4s, v12.8h\n"
- "smull v12.8h, v3.8b, %[b0a].8b\n"
- "sadalp v25.4s, v13.8h\n"
- "sadalp v26.4s, v14.8h\n"
- "smull v13.8h, v3.8b, %[b1a].8b\n"
- "sadalp v27.4s, v15.8h\n"
- "smull v14.8h, v3.8b, %[b2a].8b\n"
- "smull v15.8h, v3.8b, %[b3a].8b\n"
- "smlal2 v12.8h, v3.16b, %[b0a].16b\n"
- "smlal2 v13.8h, v3.16b, %[b1a].16b\n"
- "smlal2 v14.8h, v3.16b, %[b2a].16b\n"
- "smlal2 v15.8h, v3.16b, %[b3a].16b\n"
- "ldr q3, [%[a_ptr], #48]\n"
+ "sadalp v24.4s, v12.8h\n"
+ "smull v12.8h, v3.8b, %[b0a].8b\n"
+ "sadalp v25.4s, v13.8h\n"
+ "sadalp v26.4s, v14.8h\n"
+ "smull v13.8h, v3.8b, %[b1a].8b\n"
+ "sadalp v27.4s, v15.8h\n"
+ "smull v14.8h, v3.8b, %[b2a].8b\n"
+ "smull v15.8h, v3.8b, %[b3a].8b\n"
+ "smlal2 v12.8h, v3.16b, %[b0a].16b\n"
+ "smlal2 v13.8h, v3.16b, %[b1a].16b\n"
+ "smlal2 v14.8h, v3.16b, %[b2a].16b\n"
+ "smlal2 v15.8h, v3.16b, %[b3a].16b\n"
+ "ldr q3, [%[a_ptr], #48]\n"
// Start of unroll 0 for next iteration.
- "sadalp v28.4s, v12.8h\n"
- "smull v12.8h, v0.8b, %[b0].8b\n"
- "sadalp v29.4s, v13.8h\n"
- "sadalp v30.4s, v14.8h\n"
- "smull v13.8h, v0.8b, %[b1].8b\n"
- "sadalp v31.4s, v15.8h\n"
- "bne 1b\n"
+ "sadalp v28.4s, v12.8h\n"
+ "smull v12.8h, v0.8b, %[b0].8b\n"
+ "sadalp v29.4s, v13.8h\n"
+ "sadalp v30.4s, v14.8h\n"
+ "smull v13.8h, v0.8b, %[b1].8b\n"
+ "sadalp v31.4s, v15.8h\n"
+ "bne 1b\n"
// Target to use when K=1 or 2 (i.e. zero iterations of main loop)
"4:\n"
// Branch to alternative tail for odd K
- "cbnz %w[oddk], 2f\n"
+ "cbnz %w[oddk], 2f\n"
// Detached final iteration (even K)
- "smull v14.8h, v0.8b, %[b2].8b\n"
- "smull v15.8h, v0.8b, %[b3].8b\n"
- "ldr %q[b0a], [%[b_ptr], #64]\n"
- "smlal2 v12.8h, v0.16b, %[b0].16b\n"
- "smlal2 v13.8h, v0.16b, %[b1].16b\n"
- "ldr %q[b1a], [%[b_ptr], #80]\n"
- "smlal2 v14.8h, v0.16b, %[b2].16b\n"
- "smlal2 v15.8h, v0.16b, %[b3].16b\n"
- "ldr q0, [%[a_ptr], #64]\n"
+ "smull v14.8h, v0.8b, %[b2].8b\n"
+ "smull v15.8h, v0.8b, %[b3].8b\n"
+ "ldr %q[b0a], [%[b_ptr], #64]\n"
+ "smlal2 v12.8h, v0.16b, %[b0].16b\n"
+ "smlal2 v13.8h, v0.16b, %[b1].16b\n"
+ "ldr %q[b1a], [%[b_ptr], #80]\n"
+ "smlal2 v14.8h, v0.16b, %[b2].16b\n"
+ "smlal2 v15.8h, v0.16b, %[b3].16b\n"
+ "ldr q0, [%[a_ptr], #64]\n"
- "sadalp v16.4s, v12.8h\n"
- "smull v12.8h, v1.8b, %[b0].8b\n"
- "sadalp v17.4s, v13.8h\n"
- "sadalp v18.4s, v14.8h\n"
- "smull v13.8h, v1.8b, %[b1].8b\n"
- "sadalp v19.4s, v15.8h\n"
- "smull v14.8h, v1.8b, %[b2].8b\n"
- "ldr %q[b2a], [%[b_ptr], #96]\n"
- "smull v15.8h, v1.8b, %[b3].8b\n"
- "smlal2 v12.8h, v1.16b, %[b0].16b\n"
- "ldr %q[b3a], [%[b_ptr], #112]\n"
- "smlal2 v13.8h, v1.16b, %[b1].16b\n"
- "add %[b_ptr], %[b_ptr], #128\n"
- "smlal2 v14.8h, v1.16b, %[b2].16b\n"
- "smlal2 v15.8h, v1.16b, %[b3].16b\n"
- "ldr q1, [%[a_ptr], #80]\n"
+ "sadalp v16.4s, v12.8h\n"
+ "smull v12.8h, v1.8b, %[b0].8b\n"
+ "sadalp v17.4s, v13.8h\n"
+ "sadalp v18.4s, v14.8h\n"
+ "smull v13.8h, v1.8b, %[b1].8b\n"
+ "sadalp v19.4s, v15.8h\n"
+ "smull v14.8h, v1.8b, %[b2].8b\n"
+ "ldr %q[b2a], [%[b_ptr], #96]\n"
+ "smull v15.8h, v1.8b, %[b3].8b\n"
+ "smlal2 v12.8h, v1.16b, %[b0].16b\n"
+ "ldr %q[b3a], [%[b_ptr], #112]\n"
+ "smlal2 v13.8h, v1.16b, %[b1].16b\n"
+ "add %[b_ptr], %[b_ptr], #128\n"
+ "smlal2 v14.8h, v1.16b, %[b2].16b\n"
+ "smlal2 v15.8h, v1.16b, %[b3].16b\n"
+ "ldr q1, [%[a_ptr], #80]\n"
- "sadalp v20.4s, v12.8h\n"
- "smull v12.8h, v2.8b, %[b0].8b\n"
- "sadalp v21.4s, v13.8h\n"
- "sadalp v22.4s, v14.8h\n"
- "smull v13.8h, v2.8b, %[b1].8b\n"
- "sadalp v23.4s, v15.8h\n"
- "smull v14.8h, v2.8b, %[b2].8b\n"
- "smull v15.8h, v2.8b, %[b3].8b\n"
- "smlal2 v12.8h, v2.16b, %[b0].16b\n"
- "smlal2 v13.8h, v2.16b, %[b1].16b\n"
- "smlal2 v14.8h, v2.16b, %[b2].16b\n"
- "smlal2 v15.8h, v2.16b, %[b3].16b\n"
- "ldr q2, [%[a_ptr], #96]\n"
+ "sadalp v20.4s, v12.8h\n"
+ "smull v12.8h, v2.8b, %[b0].8b\n"
+ "sadalp v21.4s, v13.8h\n"
+ "sadalp v22.4s, v14.8h\n"
+ "smull v13.8h, v2.8b, %[b1].8b\n"
+ "sadalp v23.4s, v15.8h\n"
+ "smull v14.8h, v2.8b, %[b2].8b\n"
+ "smull v15.8h, v2.8b, %[b3].8b\n"
+ "smlal2 v12.8h, v2.16b, %[b0].16b\n"
+ "smlal2 v13.8h, v2.16b, %[b1].16b\n"
+ "smlal2 v14.8h, v2.16b, %[b2].16b\n"
+ "smlal2 v15.8h, v2.16b, %[b3].16b\n"
+ "ldr q2, [%[a_ptr], #96]\n"
- "sadalp v24.4s, v12.8h\n"
- "smull v12.8h, v3.8b, %[b0].8b\n"
- "sadalp v25.4s, v13.8h\n"
- "sadalp v26.4s, v14.8h\n"
- "smull v13.8h, v3.8b, %[b1].8b\n"
- "sadalp v27.4s, v15.8h\n"
- "smull v14.8h, v3.8b, %[b2].8b\n"
- "smull v15.8h, v3.8b, %[b3].8b\n"
- "smlal2 v12.8h, v3.16b, %[b0].16b\n"
- "smlal2 v13.8h, v3.16b, %[b1].16b\n"
- "smlal2 v14.8h, v3.16b, %[b2].16b\n"
- "smlal2 v15.8h, v3.16b, %[b3].16b\n"
- "ldr q3, [%[a_ptr], #112]\n"
+ "sadalp v24.4s, v12.8h\n"
+ "smull v12.8h, v3.8b, %[b0].8b\n"
+ "sadalp v25.4s, v13.8h\n"
+ "sadalp v26.4s, v14.8h\n"
+ "smull v13.8h, v3.8b, %[b1].8b\n"
+ "sadalp v27.4s, v15.8h\n"
+ "smull v14.8h, v3.8b, %[b2].8b\n"
+ "smull v15.8h, v3.8b, %[b3].8b\n"
+ "smlal2 v12.8h, v3.16b, %[b0].16b\n"
+ "smlal2 v13.8h, v3.16b, %[b1].16b\n"
+ "smlal2 v14.8h, v3.16b, %[b2].16b\n"
+ "smlal2 v15.8h, v3.16b, %[b3].16b\n"
+ "ldr q3, [%[a_ptr], #112]\n"
// Unroll 1
- "sadalp v28.4s, v12.8h\n"
- "smull v12.8h, v0.8b, %[b0a].8b\n"
- "sadalp v29.4s, v13.8h\n"
- "sadalp v30.4s, v14.8h\n"
- "smull v13.8h, v0.8b, %[b1a].8b\n"
- "sadalp v31.4s, v15.8h\n"
- "smull v14.8h, v0.8b, %[b2a].8b\n"
- "add %[a_ptr], %[a_ptr], #128\n"
- "smull v15.8h, v0.8b, %[b3a].8b\n"
- "smlal2 v12.8h, v0.16b, %[b0a].16b\n"
- "smlal2 v13.8h, v0.16b, %[b1a].16b\n"
- "smlal2 v14.8h, v0.16b, %[b2a].16b\n"
- "smlal2 v15.8h, v0.16b, %[b3a].16b\n"
+ "sadalp v28.4s, v12.8h\n"
+ "smull v12.8h, v0.8b, %[b0a].8b\n"
+ "sadalp v29.4s, v13.8h\n"
+ "sadalp v30.4s, v14.8h\n"
+ "smull v13.8h, v0.8b, %[b1a].8b\n"
+ "sadalp v31.4s, v15.8h\n"
+ "smull v14.8h, v0.8b, %[b2a].8b\n"
+ "add %[a_ptr], %[a_ptr], #128\n"
+ "smull v15.8h, v0.8b, %[b3a].8b\n"
+ "smlal2 v12.8h, v0.16b, %[b0a].16b\n"
+ "smlal2 v13.8h, v0.16b, %[b1a].16b\n"
+ "smlal2 v14.8h, v0.16b, %[b2a].16b\n"
+ "smlal2 v15.8h, v0.16b, %[b3a].16b\n"
- "sadalp v16.4s, v12.8h\n"
- "smull v12.8h, v1.8b, %[b0a].8b\n"
- "sadalp v17.4s, v13.8h\n"
- "sadalp v18.4s, v14.8h\n"
- "smull v13.8h, v1.8b, %[b1a].8b\n"
- "sadalp v19.4s, v15.8h\n"
- "smull v14.8h, v1.8b, %[b2a].8b\n"
- "smull v15.8h, v1.8b, %[b3a].8b\n"
- "smlal2 v12.8h, v1.16b, %[b0a].16b\n"
- "addp v16.4s, v16.4s, v17.4s\n"
- "smlal2 v13.8h, v1.16b, %[b1a].16b\n"
- "addp v17.4s, v18.4s, v19.4s\n"
- "smlal2 v14.8h, v1.16b, %[b2a].16b\n"
- "smlal2 v15.8h, v1.16b, %[b3a].16b\n"
+ "sadalp v16.4s, v12.8h\n"
+ "smull v12.8h, v1.8b, %[b0a].8b\n"
+ "sadalp v17.4s, v13.8h\n"
+ "sadalp v18.4s, v14.8h\n"
+ "smull v13.8h, v1.8b, %[b1a].8b\n"
+ "sadalp v19.4s, v15.8h\n"
+ "smull v14.8h, v1.8b, %[b2a].8b\n"
+ "smull v15.8h, v1.8b, %[b3a].8b\n"
+ "smlal2 v12.8h, v1.16b, %[b0a].16b\n"
+ "addp v16.4s, v16.4s, v17.4s\n"
+ "smlal2 v13.8h, v1.16b, %[b1a].16b\n"
+ "addp v17.4s, v18.4s, v19.4s\n"
+ "smlal2 v14.8h, v1.16b, %[b2a].16b\n"
+ "smlal2 v15.8h, v1.16b, %[b3a].16b\n"
- "sadalp v20.4s, v12.8h\n"
- "smull v12.8h, v2.8b, %[b0a].8b\n"
- "sadalp v21.4s, v13.8h\n"
- "sadalp v22.4s, v14.8h\n"
- "smull v13.8h, v2.8b, %[b1a].8b\n"
- "sadalp v23.4s, v15.8h\n"
- "addp v16.4s, v16.4s, v17.4s\n"
- "smull v14.8h, v2.8b, %[b2a].8b\n"
- "addp v18.4s, v20.4s, v21.4s\n"
- "addp v19.4s, v22.4s, v23.4s\n"
- "smull v15.8h, v2.8b, %[b3a].8b\n"
- "smlal2 v12.8h, v2.16b, %[b0a].16b\n"
- "str q16, [%[c_ptr]]\n"
- "smlal2 v13.8h, v2.16b, %[b1a].16b\n"
- "smlal2 v14.8h, v2.16b, %[b2a].16b\n"
- "smlal2 v15.8h, v2.16b, %[b3a].16b\n"
+ "sadalp v20.4s, v12.8h\n"
+ "smull v12.8h, v2.8b, %[b0a].8b\n"
+ "sadalp v21.4s, v13.8h\n"
+ "sadalp v22.4s, v14.8h\n"
+ "smull v13.8h, v2.8b, %[b1a].8b\n"
+ "sadalp v23.4s, v15.8h\n"
+ "addp v16.4s, v16.4s, v17.4s\n"
+ "smull v14.8h, v2.8b, %[b2a].8b\n"
+ "addp v18.4s, v20.4s, v21.4s\n"
+ "addp v19.4s, v22.4s, v23.4s\n"
+ "smull v15.8h, v2.8b, %[b3a].8b\n"
+ "smlal2 v12.8h, v2.16b, %[b0a].16b\n"
+ "str q16, [%[c_ptr]]\n"
+ "smlal2 v13.8h, v2.16b, %[b1a].16b\n"
+ "smlal2 v14.8h, v2.16b, %[b2a].16b\n"
+ "smlal2 v15.8h, v2.16b, %[b3a].16b\n"
- "sadalp v24.4s, v12.8h\n"
- "smull v12.8h, v3.8b, %[b0a].8b\n"
- "sadalp v25.4s, v13.8h\n"
- "sadalp v26.4s, v14.8h\n"
- "smull v13.8h, v3.8b, %[b1a].8b\n"
- "sadalp v27.4s, v15.8h\n"
- "addp v17.4s, v18.4s, v19.4s\n"
- "smull v14.8h, v3.8b, %[b2a].8b\n"
- "addp v20.4s, v24.4s, v25.4s\n"
- "addp v21.4s, v26.4s, v27.4s\n"
- "smull v15.8h, v3.8b, %[b3a].8b\n"
- "smlal2 v12.8h, v3.16b, %[b0a].16b\n"
- "str q17, [%[c_ptr], #16]\n"
- "smlal2 v13.8h, v3.16b, %[b1a].16b\n"
- "smlal2 v14.8h, v3.16b, %[b2a].16b\n"
- "addp v18.4s, v20.4s, v21.4s\n"
- "smlal2 v15.8h, v3.16b, %[b3a].16b\n"
- "b 3f\n"
+ "sadalp v24.4s, v12.8h\n"
+ "smull v12.8h, v3.8b, %[b0a].8b\n"
+ "sadalp v25.4s, v13.8h\n"
+ "sadalp v26.4s, v14.8h\n"
+ "smull v13.8h, v3.8b, %[b1a].8b\n"
+ "sadalp v27.4s, v15.8h\n"
+ "addp v17.4s, v18.4s, v19.4s\n"
+ "smull v14.8h, v3.8b, %[b2a].8b\n"
+ "addp v20.4s, v24.4s, v25.4s\n"
+ "addp v21.4s, v26.4s, v27.4s\n"
+ "smull v15.8h, v3.8b, %[b3a].8b\n"
+ "smlal2 v12.8h, v3.16b, %[b0a].16b\n"
+ "str q17, [%[c_ptr], #16]\n"
+ "smlal2 v13.8h, v3.16b, %[b1a].16b\n"
+ "smlal2 v14.8h, v3.16b, %[b2a].16b\n"
+ "addp v18.4s, v20.4s, v21.4s\n"
+ "smlal2 v15.8h, v3.16b, %[b3a].16b\n"
+ "b 3f\n"
// Detached final iteration (odd K)
"2:\n"
- "smull v14.8h, v0.8b, %[b2].8b\n"
- "add %[a_ptr], %[a_ptr], #64\n"
- "smull v15.8h, v0.8b, %[b3].8b\n"
- "add %[b_ptr], %[b_ptr], #64\n"
- "smlal2 v12.8h, v0.16b, %[b0].16b\n"
- "smlal2 v13.8h, v0.16b, %[b1].16b\n"
- "smlal2 v14.8h, v0.16b, %[b2].16b\n"
- "smlal2 v15.8h, v0.16b, %[b3].16b\n"
+ "smull v14.8h, v0.8b, %[b2].8b\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
+ "smull v15.8h, v0.8b, %[b3].8b\n"
+ "add %[b_ptr], %[b_ptr], #64\n"
+ "smlal2 v12.8h, v0.16b, %[b0].16b\n"
+ "smlal2 v13.8h, v0.16b, %[b1].16b\n"
+ "smlal2 v14.8h, v0.16b, %[b2].16b\n"
+ "smlal2 v15.8h, v0.16b, %[b3].16b\n"
- "sadalp v16.4s, v12.8h\n"
- "smull v12.8h, v1.8b, %[b0].8b\n"
- "sadalp v17.4s, v13.8h\n"
- "sadalp v18.4s, v14.8h\n"
- "smull v13.8h, v1.8b, %[b1].8b\n"
- "sadalp v19.4s, v15.8h\n"
- "smull v14.8h, v1.8b, %[b2].8b\n"
- "smull v15.8h, v1.8b, %[b3].8b\n"
- "smlal2 v12.8h, v1.16b, %[b0].16b\n"
- "addp v16.4s, v16.4s, v17.4s\n"
- "smlal2 v13.8h, v1.16b, %[b1].16b\n"
- "addp v17.4s, v18.4s, v19.4s\n"
- "smlal2 v14.8h, v1.16b, %[b2].16b\n"
- "smlal2 v15.8h, v1.16b, %[b3].16b\n"
+ "sadalp v16.4s, v12.8h\n"
+ "smull v12.8h, v1.8b, %[b0].8b\n"
+ "sadalp v17.4s, v13.8h\n"
+ "sadalp v18.4s, v14.8h\n"
+ "smull v13.8h, v1.8b, %[b1].8b\n"
+ "sadalp v19.4s, v15.8h\n"
+ "smull v14.8h, v1.8b, %[b2].8b\n"
+ "smull v15.8h, v1.8b, %[b3].8b\n"
+ "smlal2 v12.8h, v1.16b, %[b0].16b\n"
+ "addp v16.4s, v16.4s, v17.4s\n"
+ "smlal2 v13.8h, v1.16b, %[b1].16b\n"
+ "addp v17.4s, v18.4s, v19.4s\n"
+ "smlal2 v14.8h, v1.16b, %[b2].16b\n"
+ "smlal2 v15.8h, v1.16b, %[b3].16b\n"
- "sadalp v20.4s, v12.8h\n"
- "smull v12.8h, v2.8b, %[b0].8b\n"
- "sadalp v21.4s, v13.8h\n"
- "sadalp v22.4s, v14.8h\n"
- "smull v13.8h, v2.8b, %[b1].8b\n"
- "sadalp v23.4s, v15.8h\n"
- "addp v16.4s, v16.4s, v17.4s\n"
- "smull v14.8h, v2.8b, %[b2].8b\n"
- "addp v18.4s, v20.4s, v21.4s\n"
- "addp v19.4s, v22.4s, v23.4s\n"
- "smull v15.8h, v2.8b, %[b3].8b\n"
- "smlal2 v12.8h, v2.16b, %[b0].16b\n"
- "str q16, [%[c_ptr]]\n"
- "smlal2 v13.8h, v2.16b, %[b1].16b\n"
- "smlal2 v14.8h, v2.16b, %[b2].16b\n"
- "smlal2 v15.8h, v2.16b, %[b3].16b\n"
+ "sadalp v20.4s, v12.8h\n"
+ "smull v12.8h, v2.8b, %[b0].8b\n"
+ "sadalp v21.4s, v13.8h\n"
+ "sadalp v22.4s, v14.8h\n"
+ "smull v13.8h, v2.8b, %[b1].8b\n"
+ "sadalp v23.4s, v15.8h\n"
+ "addp v16.4s, v16.4s, v17.4s\n"
+ "smull v14.8h, v2.8b, %[b2].8b\n"
+ "addp v18.4s, v20.4s, v21.4s\n"
+ "addp v19.4s, v22.4s, v23.4s\n"
+ "smull v15.8h, v2.8b, %[b3].8b\n"
+ "smlal2 v12.8h, v2.16b, %[b0].16b\n"
+ "str q16, [%[c_ptr]]\n"
+ "smlal2 v13.8h, v2.16b, %[b1].16b\n"
+ "smlal2 v14.8h, v2.16b, %[b2].16b\n"
+ "smlal2 v15.8h, v2.16b, %[b3].16b\n"
- "sadalp v24.4s, v12.8h\n"
- "smull v12.8h, v3.8b, %[b0].8b\n"
- "sadalp v25.4s, v13.8h\n"
- "sadalp v26.4s, v14.8h\n"
- "smull v13.8h, v3.8b, %[b1].8b\n"
- "sadalp v27.4s, v15.8h\n"
- "addp v17.4s, v18.4s, v19.4s\n"
- "smull v14.8h, v3.8b, %[b2].8b\n"
- "addp v20.4s, v24.4s, v25.4s\n"
- "addp v21.4s, v26.4s, v27.4s\n"
- "smull v15.8h, v3.8b, %[b3].8b\n"
- "smlal2 v12.8h, v3.16b, %[b0].16b\n"
- "str q17, [%[c_ptr], #16]\n"
- "smlal2 v13.8h, v3.16b, %[b1].16b\n"
- "smlal2 v14.8h, v3.16b, %[b2].16b\n"
- "addp v18.4s, v20.4s, v21.4s\n"
- "smlal2 v15.8h, v3.16b, %[b3].16b\n"
+ "sadalp v24.4s, v12.8h\n"
+ "smull v12.8h, v3.8b, %[b0].8b\n"
+ "sadalp v25.4s, v13.8h\n"
+ "sadalp v26.4s, v14.8h\n"
+ "smull v13.8h, v3.8b, %[b1].8b\n"
+ "sadalp v27.4s, v15.8h\n"
+ "addp v17.4s, v18.4s, v19.4s\n"
+ "smull v14.8h, v3.8b, %[b2].8b\n"
+ "addp v20.4s, v24.4s, v25.4s\n"
+ "addp v21.4s, v26.4s, v27.4s\n"
+ "smull v15.8h, v3.8b, %[b3].8b\n"
+ "smlal2 v12.8h, v3.16b, %[b0].16b\n"
+ "str q17, [%[c_ptr], #16]\n"
+ "smlal2 v13.8h, v3.16b, %[b1].16b\n"
+ "smlal2 v14.8h, v3.16b, %[b2].16b\n"
+ "addp v18.4s, v20.4s, v21.4s\n"
+ "smlal2 v15.8h, v3.16b, %[b3].16b\n"
"3:\n"
// Final additions
- "sadalp v28.4s, v12.8h\n"
- "str q18, [%[c_ptr], #32]\n"
- "sadalp v29.4s, v13.8h\n"
- "sadalp v30.4s, v14.8h\n"
- "sadalp v31.4s, v15.8h\n"
+ "sadalp v28.4s, v12.8h\n"
+ "str q18, [%[c_ptr], #32]\n"
+ "sadalp v29.4s, v13.8h\n"
+ "sadalp v30.4s, v14.8h\n"
+ "sadalp v31.4s, v15.8h\n"
// Horizontal reduction, phase 1
- "addp v22.4s, v28.4s, v29.4s\n"
- "addp v23.4s, v30.4s, v31.4s\n"
+ "addp v22.4s, v28.4s, v29.4s\n"
+ "addp v23.4s, v30.4s, v31.4s\n"
// Horizontal reduction, phase 2
- "addp v19.4s, v22.4s, v23.4s\n"
- "str q19, [%[c_ptr], #48]\n"
- "add %[c_ptr], %[c_ptr], #64\n"
+ "addp v19.4s, v22.4s, v23.4s\n"
+ "str q19, [%[c_ptr], #48]\n"
+ "add %[c_ptr], %[c_ptr], #64\n"
- :
- [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr),
- [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [b3] "+w"(b3),
- [b0a] "+w"(b0a), [b1a] "+w"(b1a), [b2a] "+w"(b2a), [b3a] "+w"(b3a),
- [k] "+r"(k)
- : [oddk] "r"(oddk)
- : "x20", "x21", "v0", "v1", "v2", "v3", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
- "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc");
+ :
+ [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
+ [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2), [b3] "+w" (b3),
+ [b0a] "+w" (b0a), [b1a] "+w" (b1a), [b2a] "+w" (b2a), [b3a] "+w" (b3a),
+ [k] "+r" (k)
+ : [oddk] "r" (oddk)
+ : "x20", "x21", "v0","v1","v2","v3","v12","v13","v14","v15","v16","v17","v18","v19",
+ "v20","v21","v22","v23","v24","v25","v26","v27","v28","v29","v30","v31", "cc");
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8.hpp
index 3975732..d2692ba 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8.hpp
@@ -25,8 +25,10 @@
#ifdef __aarch64__
-namespace arm_gemm
-{
+#include "../std_transforms_fixed.hpp"
+
+namespace arm_gemm {
+
// Actual kernel implementations
void a64_gemm_u16_asimd_12x8(const uint16_t *, const uint16_t *, uint32_t *, int, int, int);
@@ -38,34 +40,32 @@
// All kernels in the family must share these characteristics. The actual
// kernel to be used can be chosen at runtime, based on the CPU_type
// structure.
-class gemm_u16_12x8
-{
+class gemm_u16_12x8 {
public:
typedef uint16_t operand_type;
typedef uint32_t result_type;
typedef void (*kern_type)(const uint16_t *, const uint16_t *, uint32_t *, int, int, int);
- /* Describes the data layout for A input */
- static const int A_interleave = 8;
- static const int A_block = 1;
- static const int A_transpose = 0;
-
- /* Same for B input */
- static const int B_interleave = 12;
- static const int B_block = 1;
- static const int B_transpose = 1;
-
/* Kernel blocking parameters */
- static const int out_width = 12;
- static const int out_height = 8;
- static const int k_unroll = 1;
+ static int out_width() {
+ return 12;
+ }
+
+ static int out_height() {
+ return 8;
+ }
+
+ static int k_unroll() {
+ return 1;
+ }
+
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 8, 12> transforms = {};
kern_type kernel = a64_gemm_u16_asimd_12x8;
- gemm_u16_12x8(const CPUInfo *ci)
- {
- }
+ gemm_u16_12x8(const CPUInfo *ci) { }
};
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8/generic.cpp
index 7903878..4c21620 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8/generic.cpp
@@ -27,281 +27,295 @@
#include "../../asmlib.hpp"
-namespace arm_gemm
-{
+namespace arm_gemm {
+
void a64_gemm_u16_asimd_12x8(const uint16_t *Apanel, const uint16_t *Bpanel, uint32_t *Cpanel, int ablocks, int bblocks, int K)
{
- const uint16_t *a_ptr = Apanel;
- uint32_t *c_ptr = Cpanel;
+ const uint16_t *a_ptr = Apanel;
+ uint32_t *c_ptr = Cpanel;
- for(int yb = 0; yb < ablocks; yb++)
+ for (int yb = 0; yb < ablocks; yb++)
+ {
+ const uint16_t *a_ptr0 = a_ptr;
+ const uint16_t *b_ptr = Bpanel;
+
+ for (int xb = 0; xb < bblocks; xb++)
{
- const uint16_t *a_ptr0 = a_ptr;
- const uint16_t *b_ptr = Bpanel;
+ a_ptr = a_ptr0;
+ const bool odd_k = K & 0x1;
+ int k = (K+1)/2 - 1;
- for(int xb = 0; xb < bblocks; xb++)
- {
- a_ptr = a_ptr0;
- const bool odd_k = K & 0x1;
- int k = (K + 1) / 2 - 1;
+ register uint16x8_t aa asm("v0");
+ register uint16x8_t ab asm("v1");
+ register uint16x8_t b0 asm("v2");
+ register uint16x8_t b1 asm("v3");
+ register uint16x8_t b2 asm("v4");
- register uint16x8_t aa asm("v0");
- register uint16x8_t ab asm("v1");
- register uint16x8_t b0 asm("v2");
- register uint16x8_t b1 asm("v3");
- register uint16x8_t b2 asm("v4");
+ __asm __volatile (
+ "ldr %d[aa], [%x[a_ptr]]\n" // Load A[A].lower
+ "movi v5.4s, #0\n"
+ "ldr x20, [%x[a_ptr], #0x08]\n" // Load A[A].upper
+ "movi v6.4s, #0\n"
+ "ldr %d[b0], [%x[b_ptr]]\n" // Load B[0].lower
+ "ins %[aa].d[1], x20\n" // Merge A[A].lower and upper
+ "movi v7.4s, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v8.4s, #0\n"
+ "ldr x20, [%x[b_ptr], #0x08]\n" // Load B[0].upper
+ "movi v9.4s, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v10.4s, #0\n"
+ "ldr %d[b1], [%x[b_ptr], #0x10]\n" // Load B[1].lower
+ "ins %[b0].d[1], x20\n" // Merge B[0].lower and upper
+ "movi v11.4s, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #96]")
+ "movi v12.4s, #0\n"
+ "movi v13.4s, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #96]")
+ "movi v14.4s, #0\n"
+ "movi v15.4s, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
+ "movi v16.4s, #0\n"
+ "movi v17.4s, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v18.4s, #0\n"
+ "movi v19.4s, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #160]")
+ "movi v20.4s, #0\n"
+ "movi v21.4s, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #160]")
+ "movi v22.4s, #0\n"
+ "movi v23.4s, #0\n"
+ ASM_PREFETCH("[%[a_ptr], #192]")
+ "movi v24.4s, #0\n"
+ "add %x[a_ptr], %x[a_ptr], #0x10\n"
+ "movi v25.4s, #0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v26.4s, #0\n"
+ "add %x[b_ptr], %x[b_ptr], #0x18\n"
+ "movi v27.4s, #0\n"
+ "movi v28.4s, #0\n"
- __asm __volatile(
- "ldr %d[aa], [%x[a_ptr]]\n" // Load A[A].lower
- "movi v5.4s, #0\n"
- "ldr x20, [%x[a_ptr], #0x08]\n" // Load A[A].upper
- "movi v6.4s, #0\n"
- "ldr %d[b0], [%x[b_ptr]]\n" // Load B[0].lower
- "ins %[aa].d[1], x20\n" // Merge A[A].lower and upper
- "movi v7.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #64]")
- "movi v8.4s, #0\n"
- "ldr x20, [%x[b_ptr], #0x08]\n" // Load B[0].upper
- "movi v9.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #64]")
- "movi v10.4s, #0\n"
- "ldr %d[b1], [%x[b_ptr], #0x10]\n" // Load B[1].lower
- "ins %[b0].d[1], x20\n" // Merge B[0].lower and upper
- "movi v11.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #96]")
- "movi v12.4s, #0\n"
- "movi v13.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #96]")
- "movi v14.4s, #0\n"
- "movi v15.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #128]")
- "movi v16.4s, #0\n"
- "movi v17.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #128]")
- "movi v18.4s, #0\n"
- "movi v19.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #160]")
- "movi v20.4s, #0\n"
- "movi v21.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #160]")
- "movi v22.4s, #0\n"
- "movi v23.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #192]")
- "movi v24.4s, #0\n"
- "add %x[a_ptr], %x[a_ptr], #0x10\n"
- "movi v25.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #192]")
- "movi v26.4s, #0\n"
- "add %x[b_ptr], %x[b_ptr], #0x18\n"
- "movi v27.4s, #0\n"
- "movi v28.4s, #0\n"
+ "cbz %x[k], 2f\n" // Skip the loop if doing zero iterations.
- "cbz %x[k], 2f\n" // Skip the loop if doing zero iterations.
+ "1:\n" // Main loop
+ // First unroll
+ "umlal v5.4s, %[b0].4h, %[aa].h[0]\n" // Unsigned MAC: operands are uint16_t (smlal would sign-extend lanes >= 0x8000)
+ "ldr x20, [%x[b_ptr]]\n" // Load B[1].upper
+ "umlal v6.4s, %[b0].4h, %[aa].h[1]\n"
+ "umlal v7.4s, %[b0].4h, %[aa].h[2]\n"
+ "ldr %d[ab], [%x[a_ptr]]\n" // Load A[B].lower
+ "ins %[b1].d[1], x20\n" // Merge B[1].lower and .upper
+ "umlal v8.4s, %[b0].4h, %[aa].h[3]\n"
+ "umlal v9.4s, %[b0].4h, %[aa].h[4]\n"
+ "ldr x20, [%x[a_ptr], #0x8]\n" // Load A[B].upper
+ "umlal v10.4s, %[b0].4h, %[aa].h[5]\n"
+ "umlal v11.4s, %[b0].4h, %[aa].h[6]\n"
+ "ldr %d[b2], [%x[b_ptr], #0x8]\n" // Load B[2].lower
+ "ins %[ab].d[1], x20\n" // Merge A[B].lower and .upper
+ "umlal v12.4s, %[b0].4h, %[aa].h[7]\n"
+ "umlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
+ "ldr x20, [%x[b_ptr], #0x10]\n" // Load B[2].upper
+ "umlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
+ "umlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
+ "umlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
+ "umlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
+ "umlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
+ "umlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
+ "umlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
+ "ldr %d[b0], [%x[b_ptr], #0x18]\n" // Load B[0].lower
+ "ins %[b2].d[1], x20\n" // Merge B[2].lower and .upper
+ "umlal v21.4s, %[b1].4h, %[aa].h[0]\n"
+ "umlal v22.4s, %[b1].4h, %[aa].h[1]\n"
+ "ldr x20, [%x[b_ptr], #0x20]\n" // Load B[0].upper
+ "umlal v23.4s, %[b1].4h, %[aa].h[2]\n"
+ "umlal v24.4s, %[b1].4h, %[aa].h[3]\n"
+ "umlal v25.4s, %[b1].4h, %[aa].h[4]\n"
+ "umlal v26.4s, %[b1].4h, %[aa].h[5]\n"
+ "umlal v27.4s, %[b1].4h, %[aa].h[6]\n"
+ "umlal v28.4s, %[b1].4h, %[aa].h[7]\n"
- "1:\n" // Main loop
- // First unroll
- "smlal v5.4s, %[b0].4h, %[aa].h[0]\n"
- "ldr x20, [%x[b_ptr]]\n" // Load B[1].upper
- "umlal v6.4s, %[b0].4h, %[aa].h[1]\n"
- "umlal v7.4s, %[b0].4h, %[aa].h[2]\n"
- "ldr %d[ab], [%x[a_ptr]]\n" // Load A[B].lower
- "ins %[b1].d[1], x20\n" // Merge B[1].lower and .upper
- "umlal v8.4s, %[b0].4h, %[aa].h[3]\n"
- "umlal v9.4s, %[b0].4h, %[aa].h[4]\n"
- "ldr x20, [%x[a_ptr], #0x8]\n" // Load A[B].upper
- "umlal v10.4s, %[b0].4h, %[aa].h[5]\n"
- "umlal v11.4s, %[b0].4h, %[aa].h[6]\n"
- "ldr %d[b2], [%x[b_ptr], #0x8]\n" // Load B[2].lower
- "ins %[ab].d[1], x20\n" // Merge A[B].lower and .upper
- "umlal v12.4s, %[b0].4h, %[aa].h[7]\n"
- "umlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
- "ldr x20, [%x[b_ptr], #0x10]\n" // Load B[2].upper
- "umlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
- "umlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
- "umlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
- "umlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
- "umlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
- "umlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
- "umlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
- "ldr %d[b0], [%x[b_ptr], #0x18]\n" // Load B[0].lower
- "ins %[b2].d[1], x20\n" // Merge B[2].lower and .upper
- "umlal v21.4s, %[b1].4h, %[aa].h[0]\n"
- "umlal v22.4s, %[b1].4h, %[aa].h[1]\n"
- "ldr x20, [%x[b_ptr], #0x20]\n" // Load B[0].upper
- "umlal v23.4s, %[b1].4h, %[aa].h[2]\n"
- "umlal v24.4s, %[b1].4h, %[aa].h[3]\n"
- "umlal v25.4s, %[b1].4h, %[aa].h[4]\n"
- "umlal v26.4s, %[b1].4h, %[aa].h[5]\n"
- "umlal v27.4s, %[b1].4h, %[aa].h[6]\n"
- "umlal v28.4s, %[b1].4h, %[aa].h[7]\n"
+ // Second unroll
+ "umlal2 v5.4s, %[b1].8h, %[ab].h[0]\n"
+ "ldr %d[aa], [%x[a_ptr], #0x10]\n" // Load A[A].lower
+ "ins %[b0].d[1], x20\n" // Merge B[0].lower and .upper
+ "umlal2 v6.4s, %[b1].8h, %[ab].h[1]\n"
+ "umlal2 v7.4s, %[b1].8h, %[ab].h[2]\n"
+ "ldr x20, [%x[a_ptr], #0x18]\n" // Load A[A].upper
+ "umlal2 v8.4s, %[b1].8h, %[ab].h[3]\n"
+ "umlal2 v9.4s, %[b1].8h, %[ab].h[4]\n"
+ "umlal2 v10.4s, %[b1].8h, %[ab].h[5]\n"
+ "umlal2 v11.4s, %[b1].8h, %[ab].h[6]\n"
+ "add %x[a_ptr], %x[a_ptr], #0x20\n"
+ "umlal2 v12.4s, %[b1].8h, %[ab].h[7]\n"
+ "umlal v13.4s, %[b2].4h, %[ab].h[0]\n"
+ ASM_PREFETCH("[%[b_ptr], #320]")
+ "umlal v14.4s, %[b2].4h, %[ab].h[1]\n"
+ "umlal v15.4s, %[b2].4h, %[ab].h[2]\n"
+ ASM_PREFETCH("[%[a_ptr], #320]")
+ "umlal v16.4s, %[b2].4h, %[ab].h[3]\n"
+ "umlal v17.4s, %[b2].4h, %[ab].h[4]\n"
+ ASM_PREFETCH("[%[b_ptr], #448]")
+ "umlal v18.4s, %[b2].4h, %[ab].h[5]\n"
+ "umlal v19.4s, %[b2].4h, %[ab].h[6]\n"
+ "umlal v20.4s, %[b2].4h, %[ab].h[7]\n"
+ "umlal2 v21.4s, %[b2].8h, %[ab].h[0]\n"
+ "umlal2 v22.4s, %[b2].8h, %[ab].h[1]\n"
+ "subs %x[k], %x[k], #0x1\n"
+ "umlal2 v23.4s, %[b2].8h, %[ab].h[2]\n"
+ "umlal2 v24.4s, %[b2].8h, %[ab].h[3]\n"
+ "ldr %d[b1], [%x[b_ptr], #0x28]\n" // Load B[1].lower
+ "ins %[aa].d[1], x20\n" // Merge A[A].lower and .upper
+ "umlal2 v25.4s, %[b2].8h, %[ab].h[4]\n"
+ "umlal2 v26.4s, %[b2].8h, %[ab].h[5]\n"
+ "add %x[b_ptr], %x[b_ptr], #0x30\n"
+ "umlal2 v27.4s, %[b2].8h, %[ab].h[6]\n"
+ "umlal2 v28.4s, %[b2].8h, %[ab].h[7]\n"
+ "bne 1b\n"
- // Second unroll
- "umlal2 v5.4s, %[b1].8h, %[ab].h[0]\n"
- "ldr %d[aa], [%x[a_ptr], #0x10]\n" // Load A[A].lower
- "ins %[b0].d[1], x20\n" // Merge B[0].lower and .upper
- "umlal2 v6.4s, %[b1].8h, %[ab].h[1]\n"
- "umlal2 v7.4s, %[b1].8h, %[ab].h[2]\n"
- "ldr x20, [%x[a_ptr], #0x18]\n" // Load A[A].upper
- "umlal2 v8.4s, %[b1].8h, %[ab].h[3]\n"
- "umlal2 v9.4s, %[b1].8h, %[ab].h[4]\n"
- "umlal2 v10.4s, %[b1].8h, %[ab].h[5]\n"
- "umlal2 v11.4s, %[b1].8h, %[ab].h[6]\n"
- "add %x[a_ptr], %x[a_ptr], #0x20\n"
- "umlal2 v12.4s, %[b1].8h, %[ab].h[7]\n"
- "umlal v13.4s, %[b2].4h, %[ab].h[0]\n" ASM_PREFETCH("[%[b_ptr], #320]")
- "umlal v14.4s, %[b2].4h, %[ab].h[1]\n"
- "umlal v15.4s, %[b2].4h, %[ab].h[2]\n" ASM_PREFETCH("[%[a_ptr], #320]")
- "umlal v16.4s, %[b2].4h, %[ab].h[3]\n"
- "umlal v17.4s, %[b2].4h, %[ab].h[4]\n" ASM_PREFETCH("[%[b_ptr], #448]")
- "umlal v18.4s, %[b2].4h, %[ab].h[5]\n"
- "umlal v19.4s, %[b2].4h, %[ab].h[6]\n"
- "umlal v20.4s, %[b2].4h, %[ab].h[7]\n"
- "umlal2 v21.4s, %[b2].8h, %[ab].h[0]\n"
- "umlal2 v22.4s, %[b2].8h, %[ab].h[1]\n"
- "subs %x[k], %x[k], #0x1\n"
- "umlal2 v23.4s, %[b2].8h, %[ab].h[2]\n"
- "umlal2 v24.4s, %[b2].8h, %[ab].h[3]\n"
- "ldr %d[b1], [%x[b_ptr], #0x28]\n" // Load B[1].lower
- "ins %[aa].d[1], x20\n" // Merge A[A].lower and .upper
- "umlal2 v25.4s, %[b2].8h, %[ab].h[4]\n"
- "umlal2 v26.4s, %[b2].8h, %[ab].h[5]\n"
- "add %x[b_ptr], %x[b_ptr], #0x30\n"
- "umlal2 v27.4s, %[b2].8h, %[ab].h[6]\n"
- "umlal2 v28.4s, %[b2].8h, %[ab].h[7]\n"
- "bne 1b\n"
+ "2:\n" // Even tail
+ "cbnz %x[odd_k], 3f\n"
- "2:\n" // Even tail
- "cbnz %x[odd_k], 3f\n"
+ "umlal v5.4s, %[b0].4h, %[aa].h[0]\n"
+ "ldr x20, [%x[b_ptr]]\n" // Load B[1].upper
+ "umlal v6.4s, %[b0].4h, %[aa].h[1]\n"
+ "umlal v7.4s, %[b0].4h, %[aa].h[2]\n"
+ "ldr %d[ab], [%x[a_ptr]]\n" // Load A[B].lower
+ "ins %[b1].d[1], x20\n" // Merge B[1].lower and .upper
+ "umlal v8.4s, %[b0].4h, %[aa].h[3]\n"
+ "umlal v9.4s, %[b0].4h, %[aa].h[4]\n"
+ "ldr x20, [%x[a_ptr], #0x8]\n" // Load A[B].upper
+ "umlal v10.4s, %[b0].4h, %[aa].h[5]\n"
+ "umlal v11.4s, %[b0].4h, %[aa].h[6]\n"
+ "ldr %d[b2], [%x[b_ptr], #0x8]\n" // Load B[2].lower
+ "ins %[ab].d[1], x20\n" // Merge A[B].lower and .upper
+ "umlal v12.4s, %[b0].4h, %[aa].h[7]\n"
+ "umlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
+ "ldr x20, [%x[b_ptr], #0x10]\n" // Load B[2].upper
+ "umlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
+ "umlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
+ "umlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
+ "add %[a_ptr], %[a_ptr], #0x10\n"
+ "umlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
+ "add %[b_ptr], %[b_ptr], #0x18\n"
+ "umlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
+ "umlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
+ "umlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
+ "ins %[b2].d[1], x20\n" // Merge B[2].lower and .upper
+ "umlal v21.4s, %[b1].4h, %[aa].h[0]\n"
+ "umlal v22.4s, %[b1].4h, %[aa].h[1]\n"
+ "umlal v23.4s, %[b1].4h, %[aa].h[2]\n"
+ "umlal v24.4s, %[b1].4h, %[aa].h[3]\n"
+ "umlal v25.4s, %[b1].4h, %[aa].h[4]\n"
+ "umlal v26.4s, %[b1].4h, %[aa].h[5]\n"
+ "umlal v27.4s, %[b1].4h, %[aa].h[6]\n"
+ "umlal v28.4s, %[b1].4h, %[aa].h[7]\n"
- "umlal v5.4s, %[b0].4h, %[aa].h[0]\n"
- "ldr x20, [%x[b_ptr]]\n" // Load B[1].upper
- "umlal v6.4s, %[b0].4h, %[aa].h[1]\n"
- "umlal v7.4s, %[b0].4h, %[aa].h[2]\n"
- "ldr %d[ab], [%x[a_ptr]]\n" // Load A[B].lower
- "ins %[b1].d[1], x20\n" // Merge B[1].lower and .upper
- "umlal v8.4s, %[b0].4h, %[aa].h[3]\n"
- "umlal v9.4s, %[b0].4h, %[aa].h[4]\n"
- "ldr x20, [%x[a_ptr], #0x8]\n" // Load A[B].upper
- "umlal v10.4s, %[b0].4h, %[aa].h[5]\n"
- "umlal v11.4s, %[b0].4h, %[aa].h[6]\n"
- "ldr %d[b2], [%x[b_ptr], #0x8]\n" // Load B[2].lower
- "ins %[ab].d[1], x20\n" // Merge A[B].lower and .upper
- "umlal v12.4s, %[b0].4h, %[aa].h[7]\n"
- "umlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
- "ldr x20, [%x[b_ptr], #0x10]\n" // Load B[2].upper
- "umlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
- "umlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
- "umlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
- "add %[a_ptr], %[a_ptr], #0x10\n"
- "umlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
- "add %[b_ptr], %[b_ptr], #0x18\n"
- "umlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
- "umlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
- "umlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
- "ins %[b2].d[1], x20\n" // Merge B[2].lower and .upper
- "umlal v21.4s, %[b1].4h, %[aa].h[0]\n"
- "umlal v22.4s, %[b1].4h, %[aa].h[1]\n"
- "umlal v23.4s, %[b1].4h, %[aa].h[2]\n"
- "umlal v24.4s, %[b1].4h, %[aa].h[3]\n"
- "umlal v25.4s, %[b1].4h, %[aa].h[4]\n"
- "umlal v26.4s, %[b1].4h, %[aa].h[5]\n"
- "umlal v27.4s, %[b1].4h, %[aa].h[6]\n"
- "umlal v28.4s, %[b1].4h, %[aa].h[7]\n"
+ "umlal2 v5.4s, %[b1].8h, %[ab].h[0]\n"
+ "umlal v13.4s, %[b2].4h, %[ab].h[0]\n"
+ "umlal2 v21.4s, %[b2].8h, %[ab].h[0]\n"
+ "umlal2 v6.4s, %[b1].8h, %[ab].h[1]\n"
+ "umlal v14.4s, %[b2].4h, %[ab].h[1]\n"
+ "str q5, [%x[c_ptr]]\n"
+ "umlal2 v22.4s, %[b2].8h, %[ab].h[1]\n"
+ "str q13, [%x[c_ptr], #0x10]\n"
+ "umlal2 v7.4s, %[b1].8h, %[ab].h[2]\n"
+ "str q21, [%x[c_ptr], #0x20]\n"
+ "umlal v15.4s, %[b2].4h, %[ab].h[2]\n"
+ "str q6, [%x[c_ptr], #0x30]\n"
+ "umlal2 v23.4s, %[b2].8h, %[ab].h[2]\n"
+ "str q14, [%x[c_ptr], #0x40]\n"
+ "umlal2 v8.4s, %[b1].8h, %[ab].h[3]\n"
+ "str q22, [%x[c_ptr], #0x50]\n"
+ "umlal v16.4s, %[b2].4h, %[ab].h[3]\n"
+ "str q7, [%x[c_ptr], #0x60]\n"
+ "umlal2 v24.4s, %[b2].8h, %[ab].h[3]\n"
+ "str q15, [%x[c_ptr], #0x70]\n"
+ "umlal2 v9.4s, %[b1].8h, %[ab].h[4]\n"
+ "str q23, [%x[c_ptr], #0x80]\n"
+ "umlal v17.4s, %[b2].4h, %[ab].h[4]\n"
+ "str q8, [%x[c_ptr], #0x90]\n"
+ "umlal2 v25.4s, %[b2].8h, %[ab].h[4]\n"
+ "str q16, [%x[c_ptr], #0xa0]\n"
+ "umlal2 v10.4s, %[b1].8h, %[ab].h[5]\n"
+ "str q24, [%x[c_ptr], #0xb0]\n"
+ "umlal v18.4s, %[b2].4h, %[ab].h[5]\n"
+ "str q9, [%x[c_ptr], #0xc0]\n"
+ "umlal2 v26.4s, %[b2].8h, %[ab].h[5]\n"
+ "str q17, [%x[c_ptr], #0xd0]\n"
+ "umlal2 v11.4s, %[b1].8h, %[ab].h[6]\n"
+ "str q25, [%x[c_ptr], #0xe0]\n"
+ "umlal v19.4s, %[b2].4h, %[ab].h[6]\n"
+ "str q10, [%x[c_ptr], #0xf0]\n"
+ "umlal2 v27.4s, %[b2].8h, %[ab].h[6]\n"
+ "str q18, [%x[c_ptr], #0x100]\n"
+ "umlal2 v12.4s, %[b1].8h, %[ab].h[7]\n"
+ "str q26, [%x[c_ptr], #0x110]\n"
+ "umlal v20.4s, %[b2].4h, %[ab].h[7]\n"
+ "str q11, [%x[c_ptr], #0x120]\n"
+ "umlal2 v28.4s, %[b2].8h, %[ab].h[7]\n"
+ "str q19, [%x[c_ptr], #0x130]\n"
+ "b 4f\n" // Complete write out
- "umlal2 v5.4s, %[b1].8h, %[ab].h[0]\n"
- "umlal v13.4s, %[b2].4h, %[ab].h[0]\n"
- "umlal2 v21.4s, %[b2].8h, %[ab].h[0]\n"
- "umlal2 v6.4s, %[b1].8h, %[ab].h[1]\n"
- "umlal v14.4s, %[b2].4h, %[ab].h[1]\n"
- "str q5, [%x[c_ptr]]\n"
- "umlal2 v22.4s, %[b2].8h, %[ab].h[1]\n"
- "str q13, [%x[c_ptr], #0x10]\n"
- "umlal2 v7.4s, %[b1].8h, %[ab].h[2]\n"
- "str q21, [%x[c_ptr], #0x20]\n"
- "umlal v15.4s, %[b2].4h, %[ab].h[2]\n"
- "str q6, [%x[c_ptr], #0x30]\n"
- "umlal2 v23.4s, %[b2].8h, %[ab].h[2]\n"
- "str q14, [%x[c_ptr], #0x40]\n"
- "umlal2 v8.4s, %[b1].8h, %[ab].h[3]\n"
- "str q22, [%x[c_ptr], #0x50]\n"
- "umlal v16.4s, %[b2].4h, %[ab].h[3]\n"
- "str q7, [%x[c_ptr], #0x60]\n"
- "umlal2 v24.4s, %[b2].8h, %[ab].h[3]\n"
- "str q15, [%x[c_ptr], #0x70]\n"
- "umlal2 v9.4s, %[b1].8h, %[ab].h[4]\n"
- "str q23, [%x[c_ptr], #0x80]\n"
- "umlal v17.4s, %[b2].4h, %[ab].h[4]\n"
- "str q8, [%x[c_ptr], #0x90]\n"
- "umlal2 v25.4s, %[b2].8h, %[ab].h[4]\n"
- "str q16, [%x[c_ptr], #0xa0]\n"
- "umlal2 v10.4s, %[b1].8h, %[ab].h[5]\n"
- "str q24, [%x[c_ptr], #0xb0]\n"
- "umlal v18.4s, %[b2].4h, %[ab].h[5]\n"
- "str q9, [%x[c_ptr], #0xc0]\n"
- "umlal2 v26.4s, %[b2].8h, %[ab].h[5]\n"
- "str q17, [%x[c_ptr], #0xd0]\n"
- "umlal2 v11.4s, %[b1].8h, %[ab].h[6]\n"
- "str q25, [%x[c_ptr], #0xe0]\n"
- "umlal v19.4s, %[b2].4h, %[ab].h[6]\n"
- "str q10, [%x[c_ptr], #0xf0]\n"
- "umlal2 v27.4s, %[b2].8h, %[ab].h[6]\n"
- "str q18, [%x[c_ptr], #0x100]\n"
- "umlal2 v12.4s, %[b1].8h, %[ab].h[7]\n"
- "str q26, [%x[c_ptr], #0x110]\n"
- "umlal v20.4s, %[b2].4h, %[ab].h[7]\n"
- "str q11, [%x[c_ptr], #0x120]\n"
- "umlal2 v28.4s, %[b2].8h, %[ab].h[7]\n"
- "str q19, [%x[c_ptr], #0x130]\n"
- "b 4f\n" // Complete write out
+ "3:\n" // Odd tail
+ "umlal v5.4s, %[b0].4h, %[aa].h[0]\n"
+ "umlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
+ "umlal v21.4s, %[b1].4h, %[aa].h[0]\n"
+ "umlal v6.4s, %[b0].4h, %[aa].h[1]\n"
+ "umlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
+ "umlal v22.4s, %[b1].4h, %[aa].h[1]\n"
+ "str q5, [%x[c_ptr]]\n"
+ "umlal v7.4s, %[b0].4h, %[aa].h[2]\n"
+ "str q13, [%x[c_ptr], #0x10]\n"
+ "umlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
+ "str q21, [%x[c_ptr], #0x20]\n"
+ "umlal v23.4s, %[b1].4h, %[aa].h[2]\n"
+ "str q6, [%x[c_ptr], #0x30]\n"
+ "umlal v8.4s, %[b0].4h, %[aa].h[3]\n"
+ "str q14, [%x[c_ptr], #0x40]\n"
+ "umlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
+ "str q22, [%x[c_ptr], #0x50]\n"
+ "umlal v24.4s, %[b1].4h, %[aa].h[3]\n"
+ "str q7, [%x[c_ptr], #0x60]\n"
+ "umlal v9.4s, %[b0].4h, %[aa].h[4]\n"
+ "str q15, [%x[c_ptr], #0x70]\n"
+ "umlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
+ "str q23, [%x[c_ptr], #0x80]\n"
+ "umlal v25.4s, %[b1].4h, %[aa].h[4]\n"
+ "str q8, [%x[c_ptr], #0x90]\n"
+ "umlal v10.4s, %[b0].4h, %[aa].h[5]\n"
+ "str q16, [%x[c_ptr], #0xa0]\n"
+ "umlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
+ "str q24, [%x[c_ptr], #0xb0]\n"
+ "umlal v26.4s, %[b1].4h, %[aa].h[5]\n"
+ "str q9, [%x[c_ptr], #0xc0]\n"
+ "umlal v11.4s, %[b0].4h, %[aa].h[6]\n"
+ "str q17, [%x[c_ptr], #0xd0]\n"
+ "umlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
+ "str q25, [%x[c_ptr], #0xe0]\n"
+ "umlal v27.4s, %[b1].4h, %[aa].h[6]\n"
+ "str q10, [%x[c_ptr], #0xf0]\n"
+ "umlal v12.4s, %[b0].4h, %[aa].h[7]\n"
+ "str q18, [%x[c_ptr], #0x100]\n"
+ "umlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
+ "str q26, [%x[c_ptr], #0x110]\n"
+ "umlal v28.4s, %[b1].4h, %[aa].h[7]\n"
+ "str q11, [%x[c_ptr], #0x120]\n"
- "3:\n" // Odd tail
- "umlal v5.4s, %[b0].4h, %[aa].h[0]\n"
- "umlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
- "umlal v21.4s, %[b1].4h, %[aa].h[0]\n"
- "umlal v6.4s, %[b0].4h, %[aa].h[1]\n"
- "umlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
- "umlal v22.4s, %[b1].4h, %[aa].h[1]\n"
- "str q5, [%x[c_ptr]]\n"
- "umlal v7.4s, %[b0].4h, %[aa].h[2]\n"
- "str q13, [%x[c_ptr], #0x10]\n"
- "umlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
- "str q21, [%x[c_ptr], #0x20]\n"
- "umlal v23.4s, %[b1].4h, %[aa].h[2]\n"
- "str q6, [%x[c_ptr], #0x30]\n"
- "umlal v8.4s, %[b0].4h, %[aa].h[3]\n"
- "str q14, [%x[c_ptr], #0x40]\n"
- "umlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
- "str q22, [%x[c_ptr], #0x50]\n"
- "umlal v24.4s, %[b1].4h, %[aa].h[3]\n"
- "str q7, [%x[c_ptr], #0x60]\n"
- "umlal v9.4s, %[b0].4h, %[aa].h[4]\n"
- "str q15, [%x[c_ptr], #0x70]\n"
- "umlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
- "str q23, [%x[c_ptr], #0x80]\n"
- "umlal v25.4s, %[b1].4h, %[aa].h[4]\n"
- "str q8, [%x[c_ptr], #0x90]\n"
- "umlal v10.4s, %[b0].4h, %[aa].h[5]\n"
- "str q16, [%x[c_ptr], #0xa0]\n"
- "umlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
- "str q24, [%x[c_ptr], #0xb0]\n"
- "umlal v26.4s, %[b1].4h, %[aa].h[5]\n"
- "str q9, [%x[c_ptr], #0xc0]\n"
- "umlal v11.4s, %[b0].4h, %[aa].h[6]\n"
- "str q17, [%x[c_ptr], #0xd0]\n"
- "umlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
- "str q25, [%x[c_ptr], #0xe0]\n"
- "umlal v27.4s, %[b1].4h, %[aa].h[6]\n"
- "str q10, [%x[c_ptr], #0xf0]\n"
- "umlal v12.4s, %[b0].4h, %[aa].h[7]\n"
- "str q18, [%x[c_ptr], #0x100]\n"
- "umlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
- "str q26, [%x[c_ptr], #0x110]\n"
- "umlal v28.4s, %[b1].4h, %[aa].h[7]\n"
- "str q11, [%x[c_ptr], #0x120]\n"
-
- "4:\n" // End of function
- "str q19, [%x[c_ptr], #0x130]\n"
- "str q27, [%x[c_ptr], #0x140]\n"
- "str q12, [%x[c_ptr], #0x150]\n"
- "str q20, [%x[c_ptr], #0x160]\n"
- "str q28, [%x[c_ptr], #0x170]\n"
- "add %x[c_ptr], %x[c_ptr], #0x180\n"
- : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), [k] "+r"(k),
- [aa] "+w"(aa), [ab] "+w"(ab), [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2)
- : [odd_k] "r"(odd_k)
- : "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "x20", "cc");
- }
+ "4:\n" // End of function
+ "str q19, [%x[c_ptr], #0x130]\n"
+ "str q27, [%x[c_ptr], #0x140]\n"
+ "str q12, [%x[c_ptr], #0x150]\n"
+ "str q20, [%x[c_ptr], #0x160]\n"
+ "str q28, [%x[c_ptr], #0x170]\n"
+ "add %x[c_ptr], %x[c_ptr], #0x180\n"
+ : [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr), [k] "+r" (k),
+ [aa] "+w" (aa), [ab] "+w" (ab), [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2)
+ : [odd_k] "r" (odd_k)
+ : "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "x20", "cc"
+ );
}
+ }
}
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8.hpp
index 26255b1..a252abf 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8.hpp
@@ -27,41 +27,51 @@
#include "arm_gemm.hpp"
-namespace arm_gemm
-{
+#include "../std_transforms_fixed.hpp"
+
+namespace arm_gemm {
+
// Load the actual kernel
void a64_gemm_u8_12x8(const uint8_t *, const uint8_t *, uint32_t *, int, int, int);
void a64_gemm_u8_12x8_a55r1(const uint8_t *, const uint8_t *, uint32_t *, int, int, int);
-class gemm_u8_12x8
-{
+class gemm_u8_12x8 {
public:
- typedef uint8_t operand_type;
+ typedef uint8_t operand_type;
typedef uint32_t result_type;
typedef void (*kern_type)(const uint8_t *, const uint8_t *, uint32_t *, int, int, int);
/* Describes the data layout for A input */
- static const int A_interleave = 8;
- static const int A_block = 4;
- static const bool A_transpose = false;
+ static const int A_interleave = 8;
+ static const int A_block = 4;
+ static const bool A_transpose = false;
/* Same for B input */
- static const int B_interleave = 12;
- static const int B_block = 4;
- static const bool B_transpose = true;
+ static const int B_interleave = 12;
+ static const int B_block = 4;
+ static const bool B_transpose = true;
/* Kernel blocking parameters */
- static const int out_width = 12;
- static const int out_height = 8;
- static const int k_unroll = 4;
+ static int out_width() {
+ return 12;
+ }
+
+ static int out_height() {
+ return 8;
+ }
+
+ static int k_unroll() {
+ return 4;
+ }
+
+ // Use the standard fixed sized transforms.
+ StdTransformsFixed<operand_type, result_type, 8, 12, 4> transforms = {};
kern_type kernel = a64_gemm_u8_12x8;
- gemm_u8_12x8(const CPUInfo *ci)
- {
- if(ci->get_cpu_model() == CPUModel::A55r1)
- {
+ gemm_u8_12x8(const CPUInfo *ci) {
+ if (ci->get_cpu_model() == CPUModel::A55r1) {
kernel = a64_gemm_u8_12x8_a55r1;
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/a55r1.cpp
index f8fafbd..994aea6 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/a55r1.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/a55r1.cpp
@@ -31,40 +31,37 @@
#include "dot_toolchain_support.h"
#endif
-namespace arm_gemm
-{
-void a64_gemm_u8_12x8_a55r1(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, const int ablocks, const int bblocks, const int K)
-{
+namespace arm_gemm {
+
+void a64_gemm_u8_12x8_a55r1(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, const int ablocks, const int bblocks, const int K) {
const uint8_t *a_ptr = Apanel;
- uint32_t *c_ptr = Cpanel;
+ uint32_t *c_ptr = Cpanel;
// We divide K by 4 because the udot instruction processes 4 elements at a time.
- const int W = K / 4;
+ const int W = K/4;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
- const int oddk = (W & 1);
- const int k_iters = ((W + 1) / 2) - 1;
+ const int oddk = (W & 1);
+ const int k_iters = ((W+1)/2) - 1;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ for (int yb=0; yb<ablocks; yb++) {
const uint8_t *a_ptr0 = a_ptr;
- const uint8_t *b_ptr = Bpanel;
+ const uint8_t *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
+ for (int xb=0; xb<bblocks; xb++) {
a_ptr = a_ptr0;
int k = k_iters;
- register int32x4_t a0 asm("v0");
- register int32x4_t a1 asm("v1");
- register int32x4_t b0 asm("v2");
- register int32x4_t b1 asm("v3");
- register int32x4_t b2 asm("v4");
+ register int32x4_t a0 asm("v0");
+ register int32x4_t a1 asm("v1");
+ register int32x4_t b0 asm("v2");
+ register int32x4_t b1 asm("v3");
+ register int32x4_t b2 asm("v4");
register int32x4_t a0a asm("v5");
register int32x4_t a1a asm("v6");
- __asm __volatile(
+ __asm __volatile (
#ifdef NO_DOT_IN_TOOLCHAIN
_DECLARE_UDOT
#else
@@ -79,22 +76,39 @@
"ldr %q[a1], [%[a_ptr], #16]\n"
"movi v11.4s, #0x0\n"
"ldr %q[b1], [%[b_ptr], #16]\n"
- "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]")
+ "movi v12.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v13.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v14.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v15.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
+ "movi v16.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v17.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
"movi v18.4s, #0x0\n"
- "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #192]")
+ "movi v19.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #192]")
"movi v20.4s, #0x0\n"
- "movi v21.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]")
+ "movi v21.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #320]")
"movi v22.4s, #0x0\n"
- "movi v23.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]")
+ "movi v23.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #256]")
"movi v24.4s, #0x0\n"
- "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #384]")
+ "movi v25.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #384]")
"movi v26.4s, #0x0\n"
- "movi v27.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #448]")
+ "movi v27.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #448]")
"movi v28.4s, #0x0\n"
- "movi v29.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #384]")
+ "movi v29.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #384]")
"movi v30.4s, #0x0\n"
- "movi v31.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #512]")
+ "movi v31.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #512]")
// The loop is offset by these two instructions which must
// always be executed.
@@ -105,102 +119,105 @@
"cbz %w[k], 4f\n"
"1:\n"
- "udot v9.4s , %[b0].16b, %[a0].4b[1]\n"
- "ldr x20, [%[b_ptr], #40]\n"
- "udot v10.4s, %[b0].16b, %[a0].4b[2]\n"
- "subs %w[k], %w[k], #1\n"
- "udot v11.4s, %[b0].16b, %[a0].4b[3]\n"
- "ldr %d[a0a], [%[a_ptr], #32]\n"
+ "udot v9.4s , %[b0].16b, %[a0].4b[1]\n"
+ "ldr x20, [%[b_ptr], #40]\n"
+ "udot v10.4s, %[b0].16b, %[a0].4b[2]\n"
+ "subs %w[k], %w[k], #1\n"
+ "udot v11.4s, %[b0].16b, %[a0].4b[3]\n"
+ "ldr %d[a0a], [%[a_ptr], #32]\n"
- "udot v12.4s, %[b0].16b, %[a1].4b[0]\n"
+ "udot v12.4s, %[b0].16b, %[a1].4b[0]\n"
"ins %[b2].d[1], x20\n"
- "udot v13.4s, %[b0].16b, %[a1].4b[1]\n"
+ "udot v13.4s, %[b0].16b, %[a1].4b[1]\n"
"ldr x20, [%[a_ptr], #40]\n"
- "udot v14.4s, %[b0].16b, %[a1].4b[2]\n"
- "udot v15.4s, %[b0].16b, %[a1].4b[3]\n"
- "ldr %d[a1a], [%[a_ptr], #48]\n"
+ "udot v14.4s, %[b0].16b, %[a1].4b[2]\n"
+ "udot v15.4s, %[b0].16b, %[a1].4b[3]\n"
+ "ldr %d[a1a], [%[a_ptr], #48]\n"
- "udot v16.4s, %[b1].16b, %[a0].4b[0]\n"
+ "udot v16.4s, %[b1].16b, %[a0].4b[0]\n"
"ins %[a0a].d[1], x20\n"
- "udot v17.4s, %[b1].16b, %[a0].4b[1]\n"
+ "udot v17.4s, %[b1].16b, %[a0].4b[1]\n"
"ldr x20, [%[a_ptr], #56]\n"
- "udot v18.4s, %[b1].16b, %[a0].4b[2]\n"
- "udot v19.4s, %[b1].16b, %[a0].4b[3]\n"
- "ldr %d[b0], [%[b_ptr], #48]\n"
+ "udot v18.4s, %[b1].16b, %[a0].4b[2]\n"
+ "udot v19.4s, %[b1].16b, %[a0].4b[3]\n"
+ "ldr %d[b0], [%[b_ptr], #48]\n"
- "udot v20.4s, %[b1].16b, %[a1].4b[0]\n"
+ "udot v20.4s, %[b1].16b, %[a1].4b[0]\n"
"ins %[a1a].d[1], x20\n"
- "udot v21.4s, %[b1].16b, %[a1].4b[1]\n"
+ "udot v21.4s, %[b1].16b, %[a1].4b[1]\n"
"ldr x20, [%[b_ptr], #56]\n"
- "udot v22.4s, %[b1].16b, %[a1].4b[2]\n"
- "udot v23.4s, %[b1].16b, %[a1].4b[3]\n"
- "ldr %d[b1], [%[b_ptr], #64]\n"
+ "udot v22.4s, %[b1].16b, %[a1].4b[2]\n"
+ "udot v23.4s, %[b1].16b, %[a1].4b[3]\n"
+ "ldr %d[b1], [%[b_ptr], #64]\n"
- "udot v24.4s, %[b2].16b, %[a0].4b[0]\n"
+ "udot v24.4s, %[b2].16b, %[a0].4b[0]\n"
"ins %[b0].d[1], x20\n"
- "udot v25.4s, %[b2].16b, %[a0].4b[1]\n"
+ "udot v25.4s, %[b2].16b, %[a0].4b[1]\n"
"ldr x20, [%[b_ptr], #72]\n"
- "udot v26.4s, %[b2].16b, %[a0].4b[2]\n"
- "udot v27.4s, %[b2].16b, %[a0].4b[3]\n" ASM_PREFETCH("[%[a_ptr], #448]")
+ "udot v26.4s, %[b2].16b, %[a0].4b[2]\n"
+ "udot v27.4s, %[b2].16b, %[a0].4b[3]\n"
+ ASM_PREFETCH("[%[a_ptr], #448]")
- "udot v28.4s, %[b2].16b, %[a1].4b[0]\n"
- "udot v29.4s, %[b2].16b, %[a1].4b[1]\n" ASM_PREFETCH("[%[b_ptr], #576]")
- "udot v30.4s, %[b2].16b, %[a1].4b[2]\n"
- "udot v31.4s, %[b2].16b, %[a1].4b[3]\n"
+ "udot v28.4s, %[b2].16b, %[a1].4b[0]\n"
+ "udot v29.4s, %[b2].16b, %[a1].4b[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #576]")
+ "udot v30.4s, %[b2].16b, %[a1].4b[2]\n"
+ "udot v31.4s, %[b2].16b, %[a1].4b[3]\n"
- // Unroll 1
- "ldr %d[b2], [%[b_ptr], #80]\n"
+ // Unroll 1
+ "ldr %d[b2], [%[b_ptr], #80]\n"
- "udot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
+ "udot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
"ins %[b1].d[1], x20\n"
- "udot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
+ "udot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
"ldr x20, [%[b_ptr], #88]\n"
- "udot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
- "udot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
- "ldr %d[a0], [%[a_ptr], #64]\n"
+ "udot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
+ "udot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
+ "ldr %d[a0], [%[a_ptr], #64]\n"
- "udot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
+ "udot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
"ins %[b2].d[1], x20\n"
"udot v13.4s, %[b0].16b, %[a1a].4b[1]\n"
"ldr x20, [%[a_ptr], #72]\n"
- "udot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
- "udot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
- "ldr %d[a1], [%[a_ptr], #80]\n"
+ "udot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
+ "udot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
+ "ldr %d[a1], [%[a_ptr], #80]\n"
- "udot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
+ "udot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
"ins %[a0].d[1], x20\n"
- "udot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
+ "udot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
"ldr x20, [%[a_ptr], #88]\n"
- "udot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
- "udot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
- "ldr %d[b0], [%[b_ptr], #96]\n"
+ "udot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
+ "udot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
+ "ldr %d[b0], [%[b_ptr], #96]\n"
- "udot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
+ "udot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
"ins %[a1].d[1], x20\n"
- "udot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
+ "udot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
"ldr x20, [%[b_ptr], #104]\n"
- "udot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
- "udot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
- "ldr %d[b1], [%[b_ptr], #112]\n"
+ "udot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
+ "udot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
+ "ldr %d[b1], [%[b_ptr], #112]\n"
- "udot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
+ "udot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
"ins %[b0].d[1], x20\n"
- "udot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
+ "udot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
"ldr x20, [%[b_ptr], #120]\n"
- "udot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
- "udot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
+ "udot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
+ "udot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
- "udot v28.4s, %[b2].16b, %[a1a].4b[0]\n" ASM_PREFETCH("[%[b_ptr], #640]")
- "udot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "udot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
+ "udot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
+ ASM_PREFETCH("[%[b_ptr], #640]")
+ "udot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "udot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
"ins %[b1].d[1], x20\n"
- "udot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
+ "udot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
"ldr %d[b2], [%[b_ptr], #32]\n"
"udot v8.4s , %[b0].16b, %[a0].4b[0]\n"
- "b.ne 1b\n"
+ "b.ne 1b\n"
// Branch here if K=1 or 2. Do the right thing for odd/even at the end.
"4:\n"
@@ -212,71 +229,83 @@
"cbnz %w[oddk], 2f\n"
// Even K continuation
- "udot v11.4s, %[b0].16b, %[a0].4b[3]\n"
- "ldr %d[a0a], [%[a_ptr], #32]\n"
+ "udot v11.4s, %[b0].16b, %[a0].4b[3]\n"
+ "ldr %d[a0a], [%[a_ptr], #32]\n"
- "udot v12.4s, %[b0].16b, %[a1].4b[0]\n"
+ "udot v12.4s, %[b0].16b, %[a1].4b[0]\n"
"ins %[b2].d[1], x20\n"
"udot v13.4s, %[b0].16b, %[a1].4b[1]\n"
"ldr x20, [%[a_ptr], #40]\n"
- "udot v14.4s, %[b0].16b, %[a1].4b[2]\n" ASM_PREFETCHW("[%[c_ptr]]")
- "udot v15.4s, %[b0].16b, %[a1].4b[3]\n"
- "ldr %d[a1a], [%[a_ptr], #48]\n"
+ "udot v14.4s, %[b0].16b, %[a1].4b[2]\n"
+ ASM_PREFETCHW("[%[c_ptr]]")
+ "udot v15.4s, %[b0].16b, %[a1].4b[3]\n"
+ "ldr %d[a1a], [%[a_ptr], #48]\n"
- "udot v16.4s, %[b1].16b, %[a0].4b[0]\n"
+ "udot v16.4s, %[b1].16b, %[a0].4b[0]\n"
"ins %[a0a].d[1], x20\n"
- "udot v17.4s, %[b1].16b, %[a0].4b[1]\n"
+ "udot v17.4s, %[b1].16b, %[a0].4b[1]\n"
"ldr x20, [%[a_ptr], #56]\n"
- "udot v18.4s, %[b1].16b, %[a0].4b[2]\n"
- "udot v19.4s, %[b1].16b, %[a0].4b[3]\n"
- "ldr %d[b0], [%[b_ptr], #48]\n"
+ "udot v18.4s, %[b1].16b, %[a0].4b[2]\n"
+ "udot v19.4s, %[b1].16b, %[a0].4b[3]\n"
+ "ldr %d[b0], [%[b_ptr], #48]\n"
- "udot v20.4s, %[b1].16b, %[a1].4b[0]\n"
+ "udot v20.4s, %[b1].16b, %[a1].4b[0]\n"
"ins %[a1a].d[1], x20\n"
- "udot v21.4s, %[b1].16b, %[a1].4b[1]\n"
+ "udot v21.4s, %[b1].16b, %[a1].4b[1]\n"
"ldr x20, [%[b_ptr], #56]\n"
- "udot v22.4s, %[b1].16b, %[a1].4b[2]\n" ASM_PREFETCHW("[%[c_ptr], #64]")
- "udot v23.4s, %[b1].16b, %[a1].4b[3]\n"
+ "udot v22.4s, %[b1].16b, %[a1].4b[2]\n"
+ ASM_PREFETCHW("[%[c_ptr], #64]")
+ "udot v23.4s, %[b1].16b, %[a1].4b[3]\n"
- "udot v24.4s, %[b2].16b, %[a0].4b[0]\n"
- "udot v25.4s, %[b2].16b, %[a0].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #128]")
- "udot v26.4s, %[b2].16b, %[a0].4b[2]\n"
- "udot v27.4s, %[b2].16b, %[a0].4b[3]\n"
- "ldr %d[b1], [%[b_ptr], #64]\n"
+ "udot v24.4s, %[b2].16b, %[a0].4b[0]\n"
+ "udot v25.4s, %[b2].16b, %[a0].4b[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #128]")
+ "udot v26.4s, %[b2].16b, %[a0].4b[2]\n"
+ "udot v27.4s, %[b2].16b, %[a0].4b[3]\n"
+ "ldr %d[b1], [%[b_ptr], #64]\n"
- "udot v28.4s, %[b2].16b, %[a1].4b[0]\n"
+ "udot v28.4s, %[b2].16b, %[a1].4b[0]\n"
"ins %[b0].d[1], x20\n"
- "udot v29.4s, %[b2].16b, %[a1].4b[1]\n"
+ "udot v29.4s, %[b2].16b, %[a1].4b[1]\n"
"ldr x20, [%[b_ptr], #72]\n"
- "udot v30.4s, %[b2].16b, %[a1].4b[2]\n" ASM_PREFETCHW("[%[c_ptr], #192]")
- "udot v31.4s, %[b2].16b, %[a1].4b[3]\n"
- "ldr %d[b2], [%[b_ptr], #80]\n"
+ "udot v30.4s, %[b2].16b, %[a1].4b[2]\n"
+ ASM_PREFETCHW("[%[c_ptr], #192]")
+ "udot v31.4s, %[b2].16b, %[a1].4b[3]\n"
+ "ldr %d[b2], [%[b_ptr], #80]\n"
- "udot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
+ "udot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
"ins %[b1].d[1], x20\n"
"udot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
"ldr x20, [%[b_ptr], #88]\n"
- "udot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
+ "udot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
"ins %[b2].d[1], x20\n"
- "udot v11.4s, %[b0].16b, %[a0a].4b[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]")
+ "udot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
+ ASM_PREFETCHW("[%[c_ptr], #256]")
"udot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
"udot v13.4s, %[b0].16b, %[a1a].4b[1]\n"
- "udot v14.4s, %[b0].16b, %[a1a].4b[2]\n" ASM_PREFETCHW("[%[c_ptr], #320]")
+ "udot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
+ ASM_PREFETCHW("[%[c_ptr], #320]")
"udot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
- "udot v16.4s, %[b1].16b, %[a0a].4b[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]")
+ "udot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #384]")
"udot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
- "udot v18.4s, %[b1].16b, %[a0a].4b[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]")
+ "udot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #448]")
"udot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
"udot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
- "udot v21.4s, %[b1].16b, %[a1a].4b[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]")
+ "udot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #512]")
"udot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
- "udot v23.4s, %[b1].16b, %[a1a].4b[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]")
+ "udot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #576]")
"udot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
"udot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
- "udot v26.4s, %[b2].16b, %[a0a].4b[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #640]")
+ "udot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #640]")
"udot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
- "udot v28.4s, %[b2].16b, %[a1a].4b[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]")
+ "udot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #704]")
"udot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
"add %[a_ptr], %[a_ptr], #64\n"
"udot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
@@ -286,27 +315,41 @@
// Odd K continuation
"2:\n"
- "udot v11.4s, %[b0].16b, %[a0].4b[3]\n" ASM_PREFETCHW("[%[c_ptr]]")
+ "udot v11.4s, %[b0].16b, %[a0].4b[3]\n"
+ ASM_PREFETCHW("[%[c_ptr]]")
"udot v12.4s, %[b0].16b, %[a1].4b[0]\n"
"ins %[b2].d[1], x20\n"
- "udot v13.4s, %[b0].16b, %[a1].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #64]")
+ "udot v13.4s, %[b0].16b, %[a1].4b[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #64]")
"udot v14.4s, %[b0].16b, %[a1].4b[2]\n"
"add %[a_ptr], %[a_ptr], #32\n"
- "udot v15.4s, %[b0].16b, %[a1].4b[3]\n" ASM_PREFETCHW("[%[c_ptr], #128]")
+ "udot v15.4s, %[b0].16b, %[a1].4b[3]\n"
+ ASM_PREFETCHW("[%[c_ptr], #128]")
"udot v16.4s, %[b1].16b, %[a0].4b[0]\n"
"add %[b_ptr], %[b_ptr], #48\n"
- "udot v17.4s, %[b1].16b, %[a0].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #192]")
+ "udot v17.4s, %[b1].16b, %[a0].4b[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #192]")
"udot v18.4s, %[b1].16b, %[a0].4b[2]\n"
- "udot v19.4s, %[b1].16b, %[a0].4b[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]")
+ "udot v19.4s, %[b1].16b, %[a0].4b[3]\n"
+ ASM_PREFETCHW("[%[c_ptr], #256]")
"udot v20.4s, %[b1].16b, %[a1].4b[0]\n"
- "udot v21.4s, %[b1].16b, %[a1].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #320]")
+ "udot v21.4s, %[b1].16b, %[a1].4b[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #320]")
"udot v22.4s, %[b1].16b, %[a1].4b[2]\n"
- "udot v23.4s, %[b1].16b, %[a1].4b[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]")
+ "udot v23.4s, %[b1].16b, %[a1].4b[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #384]")
"udot v24.4s, %[b2].16b, %[a0].4b[0]\n"
- "udot v25.4s, %[b2].16b, %[a0].4b[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]")
+ "udot v25.4s, %[b2].16b, %[a0].4b[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #448]")
"udot v26.4s, %[b2].16b, %[a0].4b[2]\n"
- "udot v27.4s, %[b2].16b, %[a0].4b[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]") "udot v28.4s, %[b2].16b, %[a1].4b[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]") "udot v29.4s, %[b2].16b, %[a1].4b[1]\n"
- ASM_PREFETCHWL2("[%[c_ptr], #640]") "udot v30.4s, %[b2].16b, %[a1].4b[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]")
+ "udot v27.4s, %[b2].16b, %[a0].4b[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #512]")
+ "udot v28.4s, %[b2].16b, %[a1].4b[0]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #576]")
+ "udot v29.4s, %[b2].16b, %[a1].4b[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #640]")
+ "udot v30.4s, %[b2].16b, %[a1].4b[2]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #704]")
"udot v31.4s, %[b2].16b, %[a1].4b[3]\n"
// Common tail
@@ -340,13 +383,15 @@
#ifdef NO_DOT_IN_TOOLCHAIN
".purgem udot\n"
#endif
- :
- [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr),
- [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a),
- [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k)
- : [oddk] "r"(oddk)
- : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
- "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory");
+ :
+ [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
+ [a0] "+w" (a0), [a1] "+w" (a1), [a0a] "+w" (a0a), [a1a] "+w" (a1a),
+ [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2), [k] "+r" (k)
+ : [oddk] "r" (oddk)
+ : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
+ "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"
+ );
+
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/dot_toolchain_support.h b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/dot_toolchain_support.h
index 5ee273b..b05e899 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/dot_toolchain_support.h
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/dot_toolchain_support.h
@@ -22,45 +22,46 @@
* SOFTWARE.
*/
+
// Define a macro to assemble the UDOT instruction (in the absence of toolchain support)
-#define _DECLARE_UDOT \
- ".altmacro\n" \
- ".macro udot opd:req, opn:req, opm:req\n" \
- "local vd, vn, vm, h, l\n" \
- ".irp reg,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31\n" \
- ".ifeqs \"\\opd\",\"v\\reg\\.4s\"\n" \
- ".set vd,\\reg\n" \
- ".endif\n" \
- ".ifeqs \"\\opn\",\"v\\reg\\.16b\"\n" \
- ".set vn,\\reg\n" \
- ".endif\n" \
- ".irp idx,0,1,2,3\n" \
- ".ifeqs \"\\opm\",\"v\\reg\\.4b[\\idx\\]\"\n" \
- ".set vm,\\reg\n" \
- ".set h,\\idx / 2\n" \
- ".set l,\\idx %% 2\n" \
- ".endif\n" \
- ".endr\n" \
- ".endr\n" \
- ".ifndef vd\n" \
- ".error \"Bad operand \\opd\"\n" \
- ".exitm\n" \
- ".endif\n" \
- ".ifndef vn\n" \
- ".error \"Bad operand \\opn\"\n" \
- ".exitm\n" \
- ".endif\n" \
- ".ifndef vm\n" \
- ".error \"Bad operand \\opm\"\n" \
- ".exitm\n" \
- ".endif\n" \
- ".ifndef h\n" \
- ".error \"Bad operand \\opm\"\n" \
- ".exitm\n" \
- ".endif\n" \
- ".ifndef l\n" \
- ".error \"Bad operand \\opm\"\n" \
- ".exitm\n" \
- ".endif\n" \
- ".int 0x6f80e000 | vd | (vn << 5) | (vm << 16) | (l << 21) | (h << 11)\n" \
- ".endm\n"
+#define _DECLARE_UDOT /* Expands to assembler text that defines a `udot` macro which assembles "udot Vd.4s, Vn.16b, Vm.4b[idx]" as a raw .int word, for toolchains that cannot assemble UDOT themselves. */ ".altmacro\n"\
+ ".macro udot opd:req, opn:req, opm:req\n"\
+ "local vd, vn, vm, h, l\n"\
+ ".irp reg,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31\n" /* Try every register name v0..v31 against each operand to recover its register number. */\
+ ".ifeqs \"\\opd\",\"v\\reg\\.4s\"\n"\
+ ".set vd,\\reg\n"\
+ ".endif\n"\
+ ".ifeqs \"\\opn\",\"v\\reg\\.16b\"\n"\
+ ".set vn,\\reg\n"\
+ ".endif\n"\
+ ".irp idx,0,1,2,3\n" /* The by-element operand also carries an index 0..3, which is split into h (idx / 2) and l (idx % 2). */\
+ ".ifeqs \"\\opm\",\"v\\reg\\.4b[\\idx\\]\"\n"\
+ ".set vm,\\reg\n"\
+ ".set h,\\idx / 2\n"\
+ ".set l,\\idx %% 2\n" /* '%%' rather than '%': this text is placed inside extended-asm templates, where a literal '%' must be escaped. */\
+ ".endif\n"\
+ ".endr\n"\
+ ".endr\n"\
+ ".ifndef vd\n" /* Any symbol left undefined means that operand did not match an accepted form; report it and stop expanding. */\
+ ".error \"Bad operand \\opd\"\n"\
+ ".exitm\n"\
+ ".endif\n"\
+ ".ifndef vn\n"\
+ ".error \"Bad operand \\opn\"\n"\
+ ".exitm\n"\
+ ".endif\n"\
+ ".ifndef vm\n"\
+ ".error \"Bad operand \\opm\"\n"\
+ ".exitm\n"\
+ ".endif\n"\
+ ".ifndef h\n"\
+ ".error \"Bad operand \\opm\"\n"\
+ ".exitm\n"\
+ ".endif\n"\
+ ".ifndef l\n"\
+ ".error \"Bad operand \\opm\"\n"\
+ ".exitm\n"\
+ ".endif\n"\
+ ".int 0x6f80e000 | vd | (vn << 5) | (vm << 16) | (l << 21) | (h << 11)\n" /* Instruction word: base opcode with vd at bit 0, vn at bit 5, vm at bit 16, l at bit 21 and h at bit 11. */\
+ ".endm\n" /* NOTE(review): this trailing backslash continues onto the following blank line, which terminates the macro; keep that line blank. */\
+
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/generic.cpp
index d026dc5..80dd873 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/generic.cpp
@@ -31,309 +31,328 @@
#include "dot_toolchain_support.h"
#endif
-namespace arm_gemm
-{
-void a64_gemm_u8_12x8(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, int ablocks, int bblocks, int K)
-{
+namespace arm_gemm {
+
+void a64_gemm_u8_12x8(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, int ablocks, int bblocks, int K) {
const uint8_t *a_ptr = Apanel;
- uint32_t *c_ptr = Cpanel;
+ uint32_t *c_ptr = Cpanel;
// We divide K by 4 because the udot instruction processes 4 elements at a time.
- const int W = K / 4;
+ const int W = K/4;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
- const int oddk = (W & 1);
- const int init_value_k = ((W + 1) / 2) - 1;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ const int oddk = (W & 1);
+ const int init_value_k = ((W+1)/2) - 1;
+ for (int yb=0; yb<ablocks; yb++) {
const uint8_t *a_ptr0 = a_ptr;
- const uint8_t *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
- a_ptr = a_ptr0;
- int k = init_value_k;
- register uint8x16_t a0 asm("v0");
- register uint8x16_t a1 asm("v1");
- register uint8x16_t b0 asm("v2");
- register uint8x16_t b1 asm("v3");
- register uint8x16_t b2 asm("v4");
+ const uint8_t *b_ptr = Bpanel;
+ for (int xb=0; xb<bblocks; xb++) {
+ a_ptr = a_ptr0;
+ int k = init_value_k;
+ register uint8x16_t a0 asm("v0");
+ register uint8x16_t a1 asm("v1");
+ register uint8x16_t b0 asm("v2");
+ register uint8x16_t b1 asm("v3");
+ register uint8x16_t b2 asm("v4");
register uint8x16_t a0a asm("v5");
register uint8x16_t a1a asm("v6");
- __asm __volatile(
+ __asm __volatile (
#ifdef NO_DOT_IN_TOOLCHAIN
_DECLARE_UDOT
#else
".arch armv8.2-a+dotprod\n"
#endif
// Initialize result registers, load initial operands, prime prefetches.
- "movi v8.4s, #0x0\n"
- "ldr %q[a0], [%[a_ptr]]\n"
- "movi v9.4s, #0x0\n"
- "ldr %q[b0], [%[b_ptr]]\n"
- "movi v10.4s, #0x0\n"
- "ldr %q[a1], [%[a_ptr], #16]\n"
- "movi v11.4s, #0x0\n"
- "ldr %q[b1], [%[b_ptr], #16]\n"
- "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") "movi v18.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #192]") "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") "movi v20.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") "movi v21.4s, #0x0\n"
+ "movi v8.4s, #0x0\n"
+ "ldr %q[a0], [%[a_ptr]]\n"
+ "movi v9.4s, #0x0\n"
+ "ldr %q[b0], [%[b_ptr]]\n"
+ "movi v10.4s, #0x0\n"
+ "ldr %q[a1], [%[a_ptr], #16]\n"
+ "movi v11.4s, #0x0\n"
+ "ldr %q[b1], [%[b_ptr], #16]\n"
+ "movi v12.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v13.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v14.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v15.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
+ "movi v16.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v17.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
+ "movi v18.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #192]")
+ "movi v19.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #320]")
+ "movi v20.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #256]")
+ "movi v21.4s, #0x0\n"
ASM_PREFETCH("[%[b_ptr], #384]")
- "movi v22.4s, #0x0\n"
- "movi v23.4s, #0x0\n"
- "movi v24.4s, #0x0\n"
- "movi v25.4s, #0x0\n"
- "movi v26.4s, #0x0\n"
- "movi v27.4s, #0x0\n"
- "movi v28.4s, #0x0\n"
- "movi v29.4s, #0x0\n"
- "movi v30.4s, #0x0\n"
- "movi v31.4s, #0x0\n"
+ "movi v22.4s, #0x0\n"
+ "movi v23.4s, #0x0\n"
+ "movi v24.4s, #0x0\n"
+ "movi v25.4s, #0x0\n"
+ "movi v26.4s, #0x0\n"
+ "movi v27.4s, #0x0\n"
+ "movi v28.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ "movi v30.4s, #0x0\n"
+ "movi v31.4s, #0x0\n"
// Skip loop if we are doing zero iterations of it.
- "cbz %w[k], 4f\n"
+ "cbz %w[k], 4f\n"
// Loop proper
"1:\n"
- "udot v8.4s , %[b0].16b, %[a0].4b[0]\n"
- "udot v9.4s , %[b0].16b, %[a0].4b[1]\n"
+ "udot v8.4s , %[b0].16b, %[a0].4b[0]\n"
+ "udot v9.4s , %[b0].16b, %[a0].4b[1]\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "udot v10.4s, %[b0].16b, %[a0].4b[2]\n"
- "udot v11.4s, %[b0].16b, %[a0].4b[3]\n"
- "ldr %q[a0a], [%[a_ptr], #32]\n"
- "udot v12.4s, %[b0].16b, %[a1].4b[0]\n"
- "udot v13.4s, %[b0].16b, %[a1].4b[1]\n"
- "ldr %q[a1a], [%[a_ptr], #48]\n"
- "udot v14.4s, %[b0].16b, %[a1].4b[2]\n"
- "udot v15.4s, %[b0].16b, %[a1].4b[3]\n"
- "ldr %q[b0], [%[b_ptr], #48]\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "udot v10.4s, %[b0].16b, %[a0].4b[2]\n"
+ "udot v11.4s, %[b0].16b, %[a0].4b[3]\n"
+ "ldr %q[a0a], [%[a_ptr], #32]\n"
+ "udot v12.4s, %[b0].16b, %[a1].4b[0]\n"
+ "udot v13.4s, %[b0].16b, %[a1].4b[1]\n"
+ "ldr %q[a1a], [%[a_ptr], #48]\n"
+ "udot v14.4s, %[b0].16b, %[a1].4b[2]\n"
+ "udot v15.4s, %[b0].16b, %[a1].4b[3]\n"
+ "ldr %q[b0], [%[b_ptr], #48]\n"
- "udot v16.4s, %[b1].16b, %[a0].4b[0]\n"
- "udot v17.4s, %[b1].16b, %[a0].4b[1]\n" ASM_PREFETCH("[%[a_ptr], #320]")
- "udot v18.4s, %[b1].16b, %[a0].4b[2]\n"
- "udot v19.4s, %[b1].16b, %[a0].4b[3]\n"
- "udot v20.4s, %[b1].16b, %[a1].4b[0]\n"
- "udot v21.4s, %[b1].16b, %[a1].4b[1]\n"
- "udot v22.4s, %[b1].16b, %[a1].4b[2]\n"
- "udot v23.4s, %[b1].16b, %[a1].4b[3]\n"
- "ldr %q[b1], [%[b_ptr], #64]\n"
+ "udot v16.4s, %[b1].16b, %[a0].4b[0]\n"
+ "udot v17.4s, %[b1].16b, %[a0].4b[1]\n"
+ ASM_PREFETCH("[%[a_ptr], #320]")
+ "udot v18.4s, %[b1].16b, %[a0].4b[2]\n"
+ "udot v19.4s, %[b1].16b, %[a0].4b[3]\n"
+ "udot v20.4s, %[b1].16b, %[a1].4b[0]\n"
+ "udot v21.4s, %[b1].16b, %[a1].4b[1]\n"
+ "udot v22.4s, %[b1].16b, %[a1].4b[2]\n"
+ "udot v23.4s, %[b1].16b, %[a1].4b[3]\n"
+ "ldr %q[b1], [%[b_ptr], #64]\n"
- "udot v24.4s, %[b2].16b, %[a0].4b[0]\n"
- "udot v25.4s, %[b2].16b, %[a0].4b[1]\n" ASM_PREFETCH("[%[b_ptr], #448]")
- "udot v26.4s, %[b2].16b, %[a0].4b[2]\n"
- "udot v27.4s, %[b2].16b, %[a0].4b[3]\n"
- "udot v28.4s, %[b2].16b, %[a1].4b[0]\n"
- "udot v29.4s, %[b2].16b, %[a1].4b[1]\n"
- "udot v30.4s, %[b2].16b, %[a1].4b[2]\n"
- "udot v31.4s, %[b2].16b, %[a1].4b[3]\n"
- "ldr %q[b2], [%[b_ptr], #80]\n"
+ "udot v24.4s, %[b2].16b, %[a0].4b[0]\n"
+ "udot v25.4s, %[b2].16b, %[a0].4b[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #448]")
+ "udot v26.4s, %[b2].16b, %[a0].4b[2]\n"
+ "udot v27.4s, %[b2].16b, %[a0].4b[3]\n"
+ "udot v28.4s, %[b2].16b, %[a1].4b[0]\n"
+ "udot v29.4s, %[b2].16b, %[a1].4b[1]\n"
+ "udot v30.4s, %[b2].16b, %[a1].4b[2]\n"
+ "udot v31.4s, %[b2].16b, %[a1].4b[3]\n"
+ "ldr %q[b2], [%[b_ptr], #80]\n"
- "udot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
- "udot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
- "ldr %q[a0], [%[a_ptr], #64]\n"
- "udot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
- "udot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
- "udot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
- "ldr %q[a1], [%[a_ptr], #80]\n"
+ "udot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
+ "udot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
+ "ldr %q[a0], [%[a_ptr], #64]\n"
+ "udot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
+ "udot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
+ "udot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
+ "ldr %q[a1], [%[a_ptr], #80]\n"
"udot v13.4s, %[b0].16b, %[a1a].4b[1]\n"
- "udot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
- "udot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
- "ldr %q[b0], [%[b_ptr], #96]\n"
+ "udot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
+ "udot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
+ "ldr %q[b0], [%[b_ptr], #96]\n"
- "udot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
- "udot v17.4s, %[b1].16b, %[a0a].4b[1]\n" ASM_PREFETCH("[%[b_ptr], #512]")
- "udot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
- "udot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
- "udot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
- "udot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
- "udot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
- "udot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
- "ldr %q[b1], [%[b_ptr], #112]\n"
+ "udot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
+ "udot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #512]")
+ "udot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
+ "udot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
+ "udot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
+ "udot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
+ "udot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
+ "udot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
+ "ldr %q[b1], [%[b_ptr], #112]\n"
- "udot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
- "udot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
- "udot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
- "udot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "udot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
- "udot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
- "subs %w[k], %w[k], #1\n"
- "udot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
- "udot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
- "bne 1b\n"
+ "udot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
+ "udot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
+ "udot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
+ "udot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "udot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
+ "udot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
+ "subs %w[k], %w[k], #1\n"
+ "udot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
+ "udot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
+ "bne 1b\n"
// Target to use when K is 1 or 2 (i.e. zero iterations of main loop)
"4:\n"
// Branch to alternative tail for odd K
- "cbnz %w[oddk], 2f\n"
+ "cbnz %w[oddk], 2f\n"
// Detached final iteration (even K)
- "udot v8.4s , %[b0].16b, %[a0].4b[0]\n"
+ "udot v8.4s , %[b0].16b, %[a0].4b[0]\n"
"udot v9.4s , %[b0].16b, %[a0].4b[1]\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "udot v10.4s, %[b0].16b, %[a0].4b[2]\n"
- "udot v11.4s, %[b0].16b, %[a0].4b[3]\n"
- "ldr %q[a0a], [%[a_ptr], #32]\n"
- "udot v12.4s, %[b0].16b, %[a1].4b[0]\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "udot v10.4s, %[b0].16b, %[a0].4b[2]\n"
+ "udot v11.4s, %[b0].16b, %[a0].4b[3]\n"
+ "ldr %q[a0a], [%[a_ptr], #32]\n"
+ "udot v12.4s, %[b0].16b, %[a1].4b[0]\n"
"udot v13.4s, %[b0].16b, %[a1].4b[1]\n"
- "ldr %q[a1a], [%[a_ptr], #48]\n"
- "udot v14.4s, %[b0].16b, %[a1].4b[2]\n"
- "udot v15.4s, %[b0].16b, %[a1].4b[3]\n"
- "ldr %q[b0], [%[b_ptr], #48]\n"
+ "ldr %q[a1a], [%[a_ptr], #48]\n"
+ "udot v14.4s, %[b0].16b, %[a1].4b[2]\n"
+ "udot v15.4s, %[b0].16b, %[a1].4b[3]\n"
+ "ldr %q[b0], [%[b_ptr], #48]\n"
- "udot v16.4s, %[b1].16b, %[a0].4b[0]\n"
- "udot v17.4s, %[b1].16b, %[a0].4b[1]\n"
- "udot v18.4s, %[b1].16b, %[a0].4b[2]\n"
- "udot v19.4s, %[b1].16b, %[a0].4b[3]\n"
- "udot v20.4s, %[b1].16b, %[a1].4b[0]\n"
- "udot v21.4s, %[b1].16b, %[a1].4b[1]\n"
- "udot v22.4s, %[b1].16b, %[a1].4b[2]\n"
- "udot v23.4s, %[b1].16b, %[a1].4b[3]\n"
- "ldr %q[b1], [%[b_ptr], #64]\n"
+ "udot v16.4s, %[b1].16b, %[a0].4b[0]\n"
+ "udot v17.4s, %[b1].16b, %[a0].4b[1]\n"
+ "udot v18.4s, %[b1].16b, %[a0].4b[2]\n"
+ "udot v19.4s, %[b1].16b, %[a0].4b[3]\n"
+ "udot v20.4s, %[b1].16b, %[a1].4b[0]\n"
+ "udot v21.4s, %[b1].16b, %[a1].4b[1]\n"
+ "udot v22.4s, %[b1].16b, %[a1].4b[2]\n"
+ "udot v23.4s, %[b1].16b, %[a1].4b[3]\n"
+ "ldr %q[b1], [%[b_ptr], #64]\n"
- "udot v24.4s, %[b2].16b, %[a0].4b[0]\n"
- "udot v25.4s, %[b2].16b, %[a0].4b[1]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
- "udot v26.4s, %[b2].16b, %[a0].4b[2]\n"
- "udot v27.4s, %[b2].16b, %[a0].4b[3]\n"
- "udot v28.4s, %[b2].16b, %[a1].4b[0]\n"
- "udot v29.4s, %[b2].16b, %[a1].4b[1]\n"
- "udot v30.4s, %[b2].16b, %[a1].4b[2]\n"
- "udot v31.4s, %[b2].16b, %[a1].4b[3]\n"
- "ldr %q[b2], [%[b_ptr], #80]\n"
+ "udot v24.4s, %[b2].16b, %[a0].4b[0]\n"
+ "udot v25.4s, %[b2].16b, %[a0].4b[1]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
+ "udot v26.4s, %[b2].16b, %[a0].4b[2]\n"
+ "udot v27.4s, %[b2].16b, %[a0].4b[3]\n"
+ "udot v28.4s, %[b2].16b, %[a1].4b[0]\n"
+ "udot v29.4s, %[b2].16b, %[a1].4b[1]\n"
+ "udot v30.4s, %[b2].16b, %[a1].4b[2]\n"
+ "udot v31.4s, %[b2].16b, %[a1].4b[3]\n"
+ "ldr %q[b2], [%[b_ptr], #80]\n"
- "udot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
+ "udot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
- "udot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
+ "udot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
"udot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
- "str q8, [%[c_ptr], #0]\n"
- "udot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
- "str q16, [%[c_ptr], #16]\n"
- "udot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
- "str q24, [%[c_ptr], #32]\n"
+ "str q8, [%[c_ptr], #0]\n"
+ "udot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
+ "str q16, [%[c_ptr], #16]\n"
+ "udot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
+ "str q24, [%[c_ptr], #32]\n"
- "udot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
- "str q9, [%[c_ptr], #48]\n"
- "udot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
- "str q17, [%[c_ptr], #64]\n"
- "udot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
- "str q25, [%[c_ptr], #80]\n"
- "udot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
- "str q10, [%[c_ptr], #96]\n"
+ "udot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
+ "str q9, [%[c_ptr], #48]\n"
+ "udot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
+ "str q17, [%[c_ptr], #64]\n"
+ "udot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
+ "str q25, [%[c_ptr], #80]\n"
+ "udot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
+ "str q10, [%[c_ptr], #96]\n"
- "udot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
- "str q18, [%[c_ptr], #112]\n"
- "udot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
- "str q26, [%[c_ptr], #128]\n"
- "udot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
- "str q11, [%[c_ptr], #144]\n"
+ "udot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
+ "str q18, [%[c_ptr], #112]\n"
+ "udot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
+ "str q26, [%[c_ptr], #128]\n"
+ "udot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
+ "str q11, [%[c_ptr], #144]\n"
- "udot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
- "str q19, [%[c_ptr], #160]\n"
- "udot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
- "str q27, [%[c_ptr], #176]\n"
- "udot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
- "str q12, [%[c_ptr], #192]\n"
+ "udot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
+ "str q19, [%[c_ptr], #160]\n"
+ "udot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
+ "str q27, [%[c_ptr], #176]\n"
+ "udot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
+ "str q12, [%[c_ptr], #192]\n"
"udot v13.4s, %[b0].16b, %[a1a].4b[1]\n"
- "str q20, [%[c_ptr], #208]\n"
- "udot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
- "str q28, [%[c_ptr], #224]\n"
- "udot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
- "str q13, [%[c_ptr], #240]\n"
+ "str q20, [%[c_ptr], #208]\n"
+ "udot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
+ "str q28, [%[c_ptr], #224]\n"
+ "udot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
+ "str q13, [%[c_ptr], #240]\n"
- "udot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
- "str q21, [%[c_ptr], #256]\n"
- "udot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
- "str q29, [%[c_ptr], #272]\n"
- "udot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
- "str q14, [%[c_ptr], #288]\n"
+ "udot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
+ "str q21, [%[c_ptr], #256]\n"
+ "udot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
+ "str q29, [%[c_ptr], #272]\n"
+ "udot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
+ "str q14, [%[c_ptr], #288]\n"
- "udot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
- "str q22, [%[c_ptr], #304]\n"
- "udot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
- "str q30, [%[c_ptr], #320]\n"
- "udot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
- "str q15, [%[c_ptr], #336]\n"
+ "udot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
+ "str q22, [%[c_ptr], #304]\n"
+ "udot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
+ "str q30, [%[c_ptr], #320]\n"
+ "udot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
+ "str q15, [%[c_ptr], #336]\n"
- "b 3f\n"
+ "b 3f\n"
// Detached final iteration (odd K)
"2:\n"
- "udot v8.4s , %[b0].16b, %[a0].4b[0]\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "udot v16.4s, %[b1].16b, %[a0].4b[0]\n"
+ "udot v8.4s , %[b0].16b, %[a0].4b[0]\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "udot v16.4s, %[b1].16b, %[a0].4b[0]\n"
"udot v9.4s , %[b0].16b, %[a0].4b[1]\n"
- "str q8, [%[c_ptr], #0]\n"
- "udot v17.4s, %[b1].16b, %[a0].4b[1]\n"
- "str q16, [%[c_ptr], #16]\n"
- "udot v24.4s, %[b2].16b, %[a0].4b[0]\n"
- "add %[b_ptr], %[b_ptr], #48\n"
- "add %[a_ptr], %[a_ptr], #32\n"
- "str q24, [%[c_ptr], #32]\n"
- "udot v25.4s, %[b2].16b, %[a0].4b[1]\n"
- "str q9, [%[c_ptr], #48]\n"
+ "str q8, [%[c_ptr], #0]\n"
+ "udot v17.4s, %[b1].16b, %[a0].4b[1]\n"
+ "str q16, [%[c_ptr], #16]\n"
+ "udot v24.4s, %[b2].16b, %[a0].4b[0]\n"
+ "add %[b_ptr], %[b_ptr], #48\n"
+ "add %[a_ptr], %[a_ptr], #32\n"
+ "str q24, [%[c_ptr], #32]\n"
+ "udot v25.4s, %[b2].16b, %[a0].4b[1]\n"
+ "str q9, [%[c_ptr], #48]\n"
- "udot v10.4s, %[b0].16b, %[a0].4b[2]\n"
- "str q17, [%[c_ptr], #64]\n"
- "udot v18.4s, %[b1].16b, %[a0].4b[2]\n"
- "str q25, [%[c_ptr], #80]\n"
- "udot v26.4s, %[b2].16b, %[a0].4b[2]\n"
- "str q10, [%[c_ptr], #96]\n"
+ "udot v10.4s, %[b0].16b, %[a0].4b[2]\n"
+ "str q17, [%[c_ptr], #64]\n"
+ "udot v18.4s, %[b1].16b, %[a0].4b[2]\n"
+ "str q25, [%[c_ptr], #80]\n"
+ "udot v26.4s, %[b2].16b, %[a0].4b[2]\n"
+ "str q10, [%[c_ptr], #96]\n"
- "udot v11.4s, %[b0].16b, %[a0].4b[3]\n"
- "str q18, [%[c_ptr], #112]\n"
- "udot v19.4s, %[b1].16b, %[a0].4b[3]\n"
- "str q26, [%[c_ptr], #128]\n"
- "udot v27.4s, %[b2].16b, %[a0].4b[3]\n"
- "str q11, [%[c_ptr], #144]\n"
+ "udot v11.4s, %[b0].16b, %[a0].4b[3]\n"
+ "str q18, [%[c_ptr], #112]\n"
+ "udot v19.4s, %[b1].16b, %[a0].4b[3]\n"
+ "str q26, [%[c_ptr], #128]\n"
+ "udot v27.4s, %[b2].16b, %[a0].4b[3]\n"
+ "str q11, [%[c_ptr], #144]\n"
- "udot v12.4s, %[b0].16b, %[a1].4b[0]\n"
- "str q19, [%[c_ptr], #160]\n"
- "udot v20.4s, %[b1].16b, %[a1].4b[0]\n"
- "str q27, [%[c_ptr], #176]\n"
- "udot v28.4s, %[b2].16b, %[a1].4b[0]\n"
- "str q12, [%[c_ptr], #192]\n"
+ "udot v12.4s, %[b0].16b, %[a1].4b[0]\n"
+ "str q19, [%[c_ptr], #160]\n"
+ "udot v20.4s, %[b1].16b, %[a1].4b[0]\n"
+ "str q27, [%[c_ptr], #176]\n"
+ "udot v28.4s, %[b2].16b, %[a1].4b[0]\n"
+ "str q12, [%[c_ptr], #192]\n"
"udot v13.4s, %[b0].16b, %[a1].4b[1]\n"
- "str q20, [%[c_ptr], #208]\n"
- "udot v21.4s, %[b1].16b, %[a1].4b[1]\n"
- "str q28, [%[c_ptr], #224]\n"
- "udot v29.4s, %[b2].16b, %[a1].4b[1]\n"
- "str q13, [%[c_ptr], #240]\n"
+ "str q20, [%[c_ptr], #208]\n"
+ "udot v21.4s, %[b1].16b, %[a1].4b[1]\n"
+ "str q28, [%[c_ptr], #224]\n"
+ "udot v29.4s, %[b2].16b, %[a1].4b[1]\n"
+ "str q13, [%[c_ptr], #240]\n"
- "udot v14.4s, %[b0].16b, %[a1].4b[2]\n"
- "str q21, [%[c_ptr], #256]\n"
- "udot v22.4s, %[b1].16b, %[a1].4b[2]\n"
- "str q29, [%[c_ptr], #272]\n"
- "udot v30.4s, %[b2].16b, %[a1].4b[2]\n"
- "str q14, [%[c_ptr], #288]\n"
+ "udot v14.4s, %[b0].16b, %[a1].4b[2]\n"
+ "str q21, [%[c_ptr], #256]\n"
+ "udot v22.4s, %[b1].16b, %[a1].4b[2]\n"
+ "str q29, [%[c_ptr], #272]\n"
+ "udot v30.4s, %[b2].16b, %[a1].4b[2]\n"
+ "str q14, [%[c_ptr], #288]\n"
- "udot v15.4s, %[b0].16b, %[a1].4b[3]\n"
- "str q22, [%[c_ptr], #304]\n"
- "udot v23.4s, %[b1].16b, %[a1].4b[3]\n"
- "str q30, [%[c_ptr], #320]\n"
- "udot v31.4s, %[b2].16b, %[a1].4b[3]\n"
- "str q15, [%[c_ptr], #336]\n"
+ "udot v15.4s, %[b0].16b, %[a1].4b[3]\n"
+ "str q22, [%[c_ptr], #304]\n"
+ "udot v23.4s, %[b1].16b, %[a1].4b[3]\n"
+ "str q30, [%[c_ptr], #320]\n"
+ "udot v31.4s, %[b2].16b, %[a1].4b[3]\n"
+ "str q15, [%[c_ptr], #336]\n"
+
// Common tail
"3:\n"
- "str q23, [%[c_ptr], #352]\n"
- "str q31, [%[c_ptr], #368]\n"
- "add %[c_ptr], %[c_ptr], #384\n"
+ "str q23, [%[c_ptr], #352]\n"
+ "str q31, [%[c_ptr], #368]\n"
+ "add %[c_ptr], %[c_ptr], #384\n"
#ifdef NO_DOT_IN_TOOLCHAIN
".purgem udot\n"
#endif
- :
- [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr),
- [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a),
- [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k)
- : [oddk] "r"(oddk)
- : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
- "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc");
+ :
+ [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
+ [a0] "+w" (a0), [a1] "+w" (a1), [a0a] "+w" (a0a), [a1a] "+w" (a1a),
+ [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2), [k] "+r" (k)
+ : [oddk] "r" (oddk)
+ : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
+ "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"
+ );
+
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4.hpp
index 5aa5291..2da3ecd 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4.hpp
@@ -25,39 +25,49 @@
#ifdef __aarch64__
-namespace arm_gemm
-{
+#include "../std_transforms_fixed.hpp"
+
+namespace arm_gemm {
+
// Kernel definition
void a64_gemm_u8_4x4(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, int ablocks, int bblocks, int K);
-class gemm_u8_4x4
-{
+// Strategy descriptor for the a64_gemm_u8_4x4 kernel: uint8_t operands,
+// uint32_t results, 4x4 output blocks, K unrolled by 16.
+class gemm_u8_4x4 {
public:
- typedef uint8_t operand_type;
+ typedef uint8_t operand_type;
typedef uint32_t result_type;
typedef void (*kern_type)(const uint8_t *, const uint8_t *, uint32_t *, int, int, int);
/* Describes the data layout for A input */
- static const int A_interleave = 4;
- static const int A_block = 16;
- static const bool A_transpose = false;
+ static const int A_interleave = 4;
+ static const int A_block = 16;
+ static const bool A_transpose = false;
/* Same for B input */
- static const int B_interleave = 4;
- static const int B_block = 16;
- static const bool B_transpose = true;
+ static const int B_interleave = 4;
+ static const int B_block = 16;
+ static const bool B_transpose = true;
/* Kernel blocking parameters */
- static const int out_width = 4;
- static const int out_height = 4;
- static const int k_unroll = 16;
+ // Blocking parameters are now static functions (previously static const members).
+ static int out_width() {
+ return 4;
+ }
- kern_type kernel = nullptr;
+ static int out_height() {
+ return 4;
+ }
- gemm_u8_4x4(const CPUInfo *ci)
- {
- kernel = a64_gemm_u8_4x4;
+ // The kernel consumes K in steps of 16 (it divides K by 16 internally).
+ static int k_unroll() {
+ return 16;
+ }
+
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 4, 4, 16> transforms = {};
+
+ // Only one implementation exists, so the kernel is fixed at initialization.
+ kern_type kernel = a64_gemm_u8_4x4;
+
+ // ci is not consulted: no CPU-specific variant is selected for this strategy.
+ gemm_u8_4x4(const CPUInfo *ci) {
}
};
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4/generic.cpp
index 0a881ff..2e60833 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4/generic.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,243 +27,255 @@
#include "../../asmlib.hpp"
-namespace arm_gemm
-{
-void a64_gemm_u8_4x4(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, int ablocks, int bblocks, int K)
-{
+namespace arm_gemm {
+
+// Computes 4x4 blocks of uint32_t results from interleaved uint8_t panels.
+//
+// For each of `ablocks` A blocks, b_ptr restarts at Bpanel. For each of
+// `bblocks` B blocks, a_ptr restarts at a_ptr0, so the same A block is reused.
+// Each (A block, B block) pair works as follows:
+// - it accumulates u8*u8 products into v16-v31 (umull/umull2 + uadalp);
+// - it reduces them with two rounds of addp;
+// - it stores 16 uint32_t values (64 bytes) at c_ptr and advances c_ptr.
+//
+// K is divided by 16, and one 16-deep step always runs outside the loop, so
+// K must be at least 16. It is presumably a multiple of 16, matching
+// k_unroll() == 16 (TODO: confirm callers guarantee this). With K < 16, k
+// would be -1 and the cbz guard would not skip the loop.
+void a64_gemm_u8_4x4(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, int ablocks, int bblocks, int K) {
const uint8_t *a_ptr = Apanel;
- uint32_t *c_ptr = Cpanel;
+ uint32_t *c_ptr = Cpanel;
K /= 16;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ for (int yb=0; yb<ablocks; yb++) {
const uint8_t *a_ptr0 = a_ptr;
- const uint8_t *b_ptr = Bpanel;
+ const uint8_t *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
+ for (int xb=0; xb<bblocks; xb++) {
a_ptr = a_ptr0;
- int k = K - 1;
+ int k = K-1;
- register uint8x16_t b0 asm("v4");
- register uint8x16_t b1 asm("v5");
- register uint8x16_t b2 asm("v6");
- register uint8x16_t b3 asm("v7");
+ register uint8x16_t b0 asm("v4");
+ register uint8x16_t b1 asm("v5");
+ register uint8x16_t b2 asm("v6");
+ register uint8x16_t b3 asm("v7");
- __asm __volatile(
- "movi v16.4s, #0x0\n"
- "ldr q0, [%[a_ptr]]\n"
- "movi v17.4s, #0x0\n"
- "ldr %q[b0], [%[b_ptr]]\n"
- "movi v18.4s, #0x0\n"
- "ldr %q[b1], [%[b_ptr], #16]\n"
- "movi v19.4s, #0x0\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "movi v20.4s, #0x0\n"
- "ldr %q[b3], [%[b_ptr], #48]\n"
- "movi v21.4s, #0x0\n"
- "ldr q1, [%[a_ptr], #16]\n"
- "movi v22.4s, #0x0\n"
- "ldr q2, [%[a_ptr], #32]\n"
- "movi v23.4s, #0x0\n"
- "ldr q3, [%[a_ptr], #48]\n"
- "movi v24.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v26.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v27.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #128]") "movi v28.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v29.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #192]") "movi v30.4s, #0x0\n"
- ASM_PREFETCH("[%[b_ptr], #256]") "movi v31.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]")
+ __asm __volatile (
+ "movi v16.4s, #0x0\n"
+ "ldr q0, [%[a_ptr]]\n"
+ "movi v17.4s, #0x0\n"
+ "ldr %q[b0], [%[b_ptr]]\n"
+ "movi v18.4s, #0x0\n"
+ "ldr %q[b1], [%[b_ptr], #16]\n"
+ "movi v19.4s, #0x0\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "movi v20.4s, #0x0\n"
+ "ldr %q[b3], [%[b_ptr], #48]\n"
+ "movi v21.4s, #0x0\n"
+ "ldr q1, [%[a_ptr], #16]\n"
+ "movi v22.4s, #0x0\n"
+ "ldr q2, [%[a_ptr], #32]\n"
+ "movi v23.4s, #0x0\n"
+ "ldr q3, [%[a_ptr], #48]\n"
+ "movi v24.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v25.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v26.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v27.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
+ "movi v28.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v29.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #192]")
+ "movi v30.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
+ "movi v31.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #256]")
- "umull v12.8h, v0.8b, %[b0].8b\n"
- "add %[a_ptr], %[a_ptr], #64\n"
- "umull v13.8h, v0.8b, %[b1].8b\n"
- "umull v14.8h, v0.8b, %[b2].8b\n"
- "add %[b_ptr], %[b_ptr], #64\n"
- "umull v15.8h, v0.8b, %[b3].8b\n"
+ "umull v12.8h, v0.8b, %[b0].8b\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
+ "umull v13.8h, v0.8b, %[b1].8b\n"
+ "umull v14.8h, v0.8b, %[b2].8b\n"
+ "add %[b_ptr], %[b_ptr], #64\n"
+ "umull v15.8h, v0.8b, %[b3].8b\n"
// Skip loop if we are doing zero iterations of it.
- "cbz %w[k], 2f\n"
+ "cbz %w[k], 2f\n"
"1:\n"
- "uadalp v16.4s, v12.8h\n"
- "umull2 v12.8h, v0.16b, %[b0].16b\n"
- "uadalp v17.4s, v13.8h\n"
- "umull2 v13.8h, v0.16b, %[b1].16b\n"
- "uadalp v18.4s, v14.8h\n"
- "umull2 v14.8h, v0.16b, %[b2].16b\n"
- "uadalp v19.4s, v15.8h\n"
- "umull2 v15.8h, v0.16b, %[b3].16b\n"
- "ldr q0, [%[a_ptr]]\n"
+ "uadalp v16.4s, v12.8h\n"
+ "umull2 v12.8h, v0.16b, %[b0].16b\n"
+ "uadalp v17.4s, v13.8h\n"
+ "umull2 v13.8h, v0.16b, %[b1].16b\n"
+ "uadalp v18.4s, v14.8h\n"
+ "umull2 v14.8h, v0.16b, %[b2].16b\n"
+ "uadalp v19.4s, v15.8h\n"
+ "umull2 v15.8h, v0.16b, %[b3].16b\n"
+ "ldr q0, [%[a_ptr]]\n"
- "uadalp v16.4s, v12.8h\n"
- "umull v12.8h, v1.8b, %[b0].8b\n"
- "uadalp v17.4s, v13.8h\n"
- "umull v13.8h, v1.8b, %[b1].8b\n"
- "subs %w[k], %w[k], #1\n"
- "uadalp v18.4s, v14.8h\n"
- "umull v14.8h, v1.8b, %[b2].8b\n"
- "uadalp v19.4s, v15.8h\n"
- "umull v15.8h, v1.8b, %[b3].8b\n"
+ "uadalp v16.4s, v12.8h\n"
+ "umull v12.8h, v1.8b, %[b0].8b\n"
+ "uadalp v17.4s, v13.8h\n"
+ "umull v13.8h, v1.8b, %[b1].8b\n"
+ "subs %w[k], %w[k], #1\n"
+ "uadalp v18.4s, v14.8h\n"
+ "umull v14.8h, v1.8b, %[b2].8b\n"
+ "uadalp v19.4s, v15.8h\n"
+ "umull v15.8h, v1.8b, %[b3].8b\n"
- "uadalp v20.4s, v12.8h\n"
- "umull2 v12.8h, v1.16b, %[b0].16b\n"
- "uadalp v21.4s, v13.8h\n"
- "umull2 v13.8h, v1.16b, %[b1].16b\n" ASM_PREFETCH("[%[a_ptr], #256]")
- "uadalp v22.4s, v14.8h\n"
- "umull2 v14.8h, v1.16b, %[b2].16b\n"
- "uadalp v23.4s, v15.8h\n"
- "umull2 v15.8h, v1.16b, %[b3].16b\n"
- "ldr q1, [%[a_ptr], #16]\n"
+ "uadalp v20.4s, v12.8h\n"
+ "umull2 v12.8h, v1.16b, %[b0].16b\n"
+ "uadalp v21.4s, v13.8h\n"
+ "umull2 v13.8h, v1.16b, %[b1].16b\n"
+ ASM_PREFETCH("[%[a_ptr], #256]")
+ "uadalp v22.4s, v14.8h\n"
+ "umull2 v14.8h, v1.16b, %[b2].16b\n"
+ "uadalp v23.4s, v15.8h\n"
+ "umull2 v15.8h, v1.16b, %[b3].16b\n"
+ "ldr q1, [%[a_ptr], #16]\n"
- "uadalp v20.4s, v12.8h\n"
- "umull v12.8h, v2.8b, %[b0].8b\n"
- "uadalp v21.4s, v13.8h\n"
- "umull v13.8h, v2.8b, %[b1].8b\n" ASM_PREFETCH("[%[b_ptr], #256]")
- "uadalp v22.4s, v14.8h\n"
- "umull v14.8h, v2.8b, %[b2].8b\n"
- "uadalp v23.4s, v15.8h\n"
- "umull v15.8h, v2.8b, %[b3].8b\n"
+ "uadalp v20.4s, v12.8h\n"
+ "umull v12.8h, v2.8b, %[b0].8b\n"
+ "uadalp v21.4s, v13.8h\n"
+ "umull v13.8h, v2.8b, %[b1].8b\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
+ "uadalp v22.4s, v14.8h\n"
+ "umull v14.8h, v2.8b, %[b2].8b\n"
+ "uadalp v23.4s, v15.8h\n"
+ "umull v15.8h, v2.8b, %[b3].8b\n"
- "uadalp v24.4s, v12.8h\n"
- "umull2 v12.8h, v2.16b, %[b0].16b\n"
- "uadalp v25.4s, v13.8h\n"
- "umull2 v13.8h, v2.16b, %[b1].16b\n"
- "uadalp v26.4s, v14.8h\n"
- "umull2 v14.8h, v2.16b, %[b2].16b\n"
- "uadalp v27.4s, v15.8h\n"
- "umull2 v15.8h, v2.16b, %[b3].16b\n"
- "ldr q2, [%[a_ptr], #32]\n"
+ "uadalp v24.4s, v12.8h\n"
+ "umull2 v12.8h, v2.16b, %[b0].16b\n"
+ "uadalp v25.4s, v13.8h\n"
+ "umull2 v13.8h, v2.16b, %[b1].16b\n"
+ "uadalp v26.4s, v14.8h\n"
+ "umull2 v14.8h, v2.16b, %[b2].16b\n"
+ "uadalp v27.4s, v15.8h\n"
+ "umull2 v15.8h, v2.16b, %[b3].16b\n"
+ "ldr q2, [%[a_ptr], #32]\n"
- "uadalp v24.4s, v12.8h\n"
- "umull v12.8h, v3.8b, %[b0].8b\n"
- "uadalp v25.4s, v13.8h\n"
- "umull v13.8h, v3.8b, %[b1].8b\n"
- "uadalp v26.4s, v14.8h\n"
- "umull v14.8h, v3.8b, %[b2].8b\n"
- "uadalp v27.4s, v15.8h\n"
- "umull v15.8h, v3.8b, %[b3].8b\n"
+ "uadalp v24.4s, v12.8h\n"
+ "umull v12.8h, v3.8b, %[b0].8b\n"
+ "uadalp v25.4s, v13.8h\n"
+ "umull v13.8h, v3.8b, %[b1].8b\n"
+ "uadalp v26.4s, v14.8h\n"
+ "umull v14.8h, v3.8b, %[b2].8b\n"
+ "uadalp v27.4s, v15.8h\n"
+ "umull v15.8h, v3.8b, %[b3].8b\n"
- "uadalp v28.4s, v12.8h\n"
- "umull2 v12.8h, v3.16b, %[b0].16b\n"
- "ldr %q[b0], [%[b_ptr]]\n"
- "uadalp v29.4s, v13.8h\n"
- "umull2 v13.8h, v3.16b, %[b1].16b\n"
- "ldr %q[b1], [%[b_ptr], #16]\n"
- "uadalp v30.4s, v14.8h\n"
- "umull2 v14.8h, v3.16b, %[b2].16b\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "uadalp v31.4s, v15.8h\n"
- "umull2 v15.8h, v3.16b, %[b3].16b\n"
- "ldr %q[b3], [%[b_ptr], #48]\n"
+ "uadalp v28.4s, v12.8h\n"
+ "umull2 v12.8h, v3.16b, %[b0].16b\n"
+ "ldr %q[b0], [%[b_ptr]]\n"
+ "uadalp v29.4s, v13.8h\n"
+ "umull2 v13.8h, v3.16b, %[b1].16b\n"
+ "ldr %q[b1], [%[b_ptr], #16]\n"
+ "uadalp v30.4s, v14.8h\n"
+ "umull2 v14.8h, v3.16b, %[b2].16b\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "uadalp v31.4s, v15.8h\n"
+ "umull2 v15.8h, v3.16b, %[b3].16b\n"
+ "ldr %q[b3], [%[b_ptr], #48]\n"
- "uadalp v28.4s, v12.8h\n"
- "umull v12.8h, v0.8b, %[b0].8b\n"
- "add %[b_ptr], %[b_ptr], #64\n"
- "uadalp v29.4s, v13.8h\n"
- "umull v13.8h, v0.8b, %[b1].8b\n"
- "ldr q3, [%[a_ptr], #48]\n"
- "uadalp v30.4s, v14.8h\n"
- "umull v14.8h, v0.8b, %[b2].8b\n"
- "add %[a_ptr], %[a_ptr], #64\n"
- "uadalp v31.4s, v15.8h\n"
- "umull v15.8h, v0.8b, %[b3].8b\n"
- "bne 1b\n"
+ "uadalp v28.4s, v12.8h\n"
+ "umull v12.8h, v0.8b, %[b0].8b\n"
+ "add %[b_ptr], %[b_ptr], #64\n"
+ "uadalp v29.4s, v13.8h\n"
+ "umull v13.8h, v0.8b, %[b1].8b\n"
+ "ldr q3, [%[a_ptr], #48]\n"
+ "uadalp v30.4s, v14.8h\n"
+ "umull v14.8h, v0.8b, %[b2].8b\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
+ "uadalp v31.4s, v15.8h\n"
+ "umull v15.8h, v0.8b, %[b3].8b\n"
+ "bne 1b\n"
// Branch target
"2:\n"
- "uadalp v16.4s, v12.8h\n"
- "umull2 v12.8h, v0.16b, %[b0].16b\n"
- "uadalp v17.4s, v13.8h\n"
- "umull2 v13.8h, v0.16b, %[b1].16b\n"
- "uadalp v18.4s, v14.8h\n"
- "umull2 v14.8h, v0.16b, %[b2].16b\n"
- "uadalp v19.4s, v15.8h\n"
- "umull2 v15.8h, v0.16b, %[b3].16b\n"
+ "uadalp v16.4s, v12.8h\n"
+ "umull2 v12.8h, v0.16b, %[b0].16b\n"
+ "uadalp v17.4s, v13.8h\n"
+ "umull2 v13.8h, v0.16b, %[b1].16b\n"
+ "uadalp v18.4s, v14.8h\n"
+ "umull2 v14.8h, v0.16b, %[b2].16b\n"
+ "uadalp v19.4s, v15.8h\n"
+ "umull2 v15.8h, v0.16b, %[b3].16b\n"
- "uadalp v16.4s, v12.8h\n"
- "umull v12.8h, v1.8b, %[b0].8b\n"
- "uadalp v17.4s, v13.8h\n"
- "umull v13.8h, v1.8b, %[b1].8b\n"
- "uadalp v18.4s, v14.8h\n"
- "umull v14.8h, v1.8b, %[b2].8b\n"
- "uadalp v19.4s, v15.8h\n"
- "umull v15.8h, v1.8b, %[b3].8b\n"
+ "uadalp v16.4s, v12.8h\n"
+ "umull v12.8h, v1.8b, %[b0].8b\n"
+ "uadalp v17.4s, v13.8h\n"
+ "umull v13.8h, v1.8b, %[b1].8b\n"
+ "uadalp v18.4s, v14.8h\n"
+ "umull v14.8h, v1.8b, %[b2].8b\n"
+ "uadalp v19.4s, v15.8h\n"
+ "umull v15.8h, v1.8b, %[b3].8b\n"
- "uadalp v20.4s, v12.8h\n"
- "umull2 v12.8h, v1.16b, %[b0].16b\n"
- "uadalp v21.4s, v13.8h\n"
- "umull2 v13.8h, v1.16b, %[b1].16b\n"
- "uadalp v22.4s, v14.8h\n"
- "umull2 v14.8h, v1.16b, %[b2].16b\n"
- "uadalp v23.4s, v15.8h\n"
- "umull2 v15.8h, v1.16b, %[b3].16b\n"
+ "uadalp v20.4s, v12.8h\n"
+ "umull2 v12.8h, v1.16b, %[b0].16b\n"
+ "uadalp v21.4s, v13.8h\n"
+ "umull2 v13.8h, v1.16b, %[b1].16b\n"
+ "uadalp v22.4s, v14.8h\n"
+ "umull2 v14.8h, v1.16b, %[b2].16b\n"
+ "uadalp v23.4s, v15.8h\n"
+ "umull2 v15.8h, v1.16b, %[b3].16b\n"
- "uadalp v20.4s, v12.8h\n"
- "umull v12.8h, v2.8b, %[b0].8b\n"
- "uadalp v21.4s, v13.8h\n"
- "umull v13.8h, v2.8b, %[b1].8b\n"
- "uadalp v22.4s, v14.8h\n"
- "umull v14.8h, v2.8b, %[b2].8b\n"
- "uadalp v23.4s, v15.8h\n"
- "umull v15.8h, v2.8b, %[b3].8b\n"
+ "uadalp v20.4s, v12.8h\n"
+ "umull v12.8h, v2.8b, %[b0].8b\n"
+ "uadalp v21.4s, v13.8h\n"
+ "umull v13.8h, v2.8b, %[b1].8b\n"
+ "uadalp v22.4s, v14.8h\n"
+ "umull v14.8h, v2.8b, %[b2].8b\n"
+ "uadalp v23.4s, v15.8h\n"
+ "umull v15.8h, v2.8b, %[b3].8b\n"
- "uadalp v24.4s, v12.8h\n"
- "umull2 v12.8h, v2.16b, %[b0].16b\n"
- "uadalp v25.4s, v13.8h\n"
- "umull2 v13.8h, v2.16b, %[b1].16b\n"
- "uadalp v26.4s, v14.8h\n"
- "umull2 v14.8h, v2.16b, %[b2].16b\n"
- "uadalp v27.4s, v15.8h\n"
- "umull2 v15.8h, v2.16b, %[b3].16b\n"
+ "uadalp v24.4s, v12.8h\n"
+ "umull2 v12.8h, v2.16b, %[b0].16b\n"
+ "uadalp v25.4s, v13.8h\n"
+ "umull2 v13.8h, v2.16b, %[b1].16b\n"
+ "uadalp v26.4s, v14.8h\n"
+ "umull2 v14.8h, v2.16b, %[b2].16b\n"
+ "uadalp v27.4s, v15.8h\n"
+ "umull2 v15.8h, v2.16b, %[b3].16b\n"
- "uadalp v24.4s, v12.8h\n"
- "umull v12.8h, v3.8b, %[b0].8b\n"
- "uadalp v25.4s, v13.8h\n"
- "umull v13.8h, v3.8b, %[b1].8b\n"
- "uadalp v26.4s, v14.8h\n"
- "umull v14.8h, v3.8b, %[b2].8b\n"
- "uadalp v27.4s, v15.8h\n"
- "umull v15.8h, v3.8b, %[b3].8b\n"
+ "uadalp v24.4s, v12.8h\n"
+ "umull v12.8h, v3.8b, %[b0].8b\n"
+ "uadalp v25.4s, v13.8h\n"
+ "umull v13.8h, v3.8b, %[b1].8b\n"
+ "uadalp v26.4s, v14.8h\n"
+ "umull v14.8h, v3.8b, %[b2].8b\n"
+ "uadalp v27.4s, v15.8h\n"
+ "umull v15.8h, v3.8b, %[b3].8b\n"
- "uadalp v28.4s, v12.8h\n"
- "umull2 v12.8h, v3.16b, %[b0].16b\n"
- "uadalp v29.4s, v13.8h\n"
- "umull2 v13.8h, v3.16b, %[b1].16b\n"
- "uadalp v30.4s, v14.8h\n"
- "umull2 v14.8h, v3.16b, %[b2].16b\n"
- "uadalp v31.4s, v15.8h\n"
- "umull2 v15.8h, v3.16b, %[b3].16b\n"
+ "uadalp v28.4s, v12.8h\n"
+ "umull2 v12.8h, v3.16b, %[b0].16b\n"
+ "uadalp v29.4s, v13.8h\n"
+ "umull2 v13.8h, v3.16b, %[b1].16b\n"
+ "uadalp v30.4s, v14.8h\n"
+ "umull2 v14.8h, v3.16b, %[b2].16b\n"
+ "uadalp v31.4s, v15.8h\n"
+ "umull2 v15.8h, v3.16b, %[b3].16b\n"
- "uadalp v28.4s, v12.8h\n"
- "uadalp v29.4s, v13.8h\n"
- "uadalp v30.4s, v14.8h\n"
- "uadalp v31.4s, v15.8h\n"
+ "uadalp v28.4s, v12.8h\n"
+ "uadalp v29.4s, v13.8h\n"
+ "uadalp v30.4s, v14.8h\n"
+ "uadalp v31.4s, v15.8h\n"
- "addp v16.4s, v16.4s, v17.4s\n"
- "addp v17.4s, v18.4s, v19.4s\n"
- "addp v18.4s, v20.4s, v21.4s\n"
- "addp v19.4s, v22.4s, v23.4s\n"
- "addp v20.4s, v24.4s, v25.4s\n"
- "addp v21.4s, v26.4s, v27.4s\n"
- "addp v22.4s, v28.4s, v29.4s\n"
- "addp v23.4s, v30.4s, v31.4s\n"
+ "addp v16.4s, v16.4s, v17.4s\n"
+ "addp v17.4s, v18.4s, v19.4s\n"
+ "addp v18.4s, v20.4s, v21.4s\n"
+ "addp v19.4s, v22.4s, v23.4s\n"
+ "addp v20.4s, v24.4s, v25.4s\n"
+ "addp v21.4s, v26.4s, v27.4s\n"
+ "addp v22.4s, v28.4s, v29.4s\n"
+ "addp v23.4s, v30.4s, v31.4s\n"
- "addp v16.4s, v16.4s, v17.4s\n"
- "addp v17.4s, v18.4s, v19.4s\n"
- "addp v18.4s, v20.4s, v21.4s\n"
- "addp v19.4s, v22.4s, v23.4s\n"
+ "addp v16.4s, v16.4s, v17.4s\n"
+ "addp v17.4s, v18.4s, v19.4s\n"
+ "addp v18.4s, v20.4s, v21.4s\n"
+ "addp v19.4s, v22.4s, v23.4s\n"
- "str q16, [%[c_ptr]]\n"
- "str q17, [%[c_ptr], #16]\n"
- "str q18, [%[c_ptr], #32]\n"
- "str q19, [%[c_ptr], #48]\n"
- "add %[c_ptr], %[c_ptr], #64\n"
+ "str q16, [%[c_ptr]]\n"
+ "str q17, [%[c_ptr], #16]\n"
+ "str q18, [%[c_ptr], #32]\n"
+ "str q19, [%[c_ptr], #48]\n"
+ "add %[c_ptr], %[c_ptr], #64\n"
- :
- [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr),
- [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [b3] "+w"(b3),
- [k] "+r"(k)
- :
- : "x20", "x21", "v0", "v1", "v2", "v3", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
- "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc");
+ :
+ [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
+ [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2), [b3] "+w" (b3),
+ [k] "+r" (k)
+ :
+ // "memory" is required. The asm loads from the A/B panels and stores to C
+ // through pointer operands, and the operand lists alone do not describe that.
+ : "x20", "x21", "v0","v1","v2","v3","v12","v13","v14","v15","v16","v17","v18","v19",
+ "v20","v21","v22","v23","v24","v25","v26","v27","v28","v29","v30","v31", "cc", "memory");
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp
index 5fc0a7b..911a4eb 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp
@@ -27,8 +27,10 @@
#include "arm_gemm.hpp"
-namespace arm_gemm
-{
+#include "../std_transforms_fixed.hpp"
+
+namespace arm_gemm {
+
// Actual kernel implementations
void a64_hgemm_asimd_24x8(const __fp16 *, const __fp16 *, __fp16 *, int, int, int);
void a64_hgemm_asimd_24x8_a55r1(const __fp16 *, const __fp16 *, __fp16 *, int, int, int);
@@ -37,33 +39,34 @@
//
// The generic "gemm_opt" function will instantiate one of these (allowing
// the constructor to pick a kernel implementation).
-class hgemm_24x8
-{
+class hgemm_24x8 {
public:
typedef __fp16 operand_type;
typedef __fp16 result_type;
typedef void (*kern_type)(const __fp16 *, const __fp16 *, __fp16 *, int, int, int);
- static const int A_block = 1;
- static const int A_interleave = 8;
- static const bool A_transpose = false;
+ /* Kernel blocking parameters */
+ static int out_width() {
+ return 24;
+ }
- static const int B_block = 1;
- static const int B_interleave = 24;
- static const bool B_transpose = true;
+ static int out_height() {
+ return 8;
+ }
- static const int out_width = 24;
- static const int out_height = 8;
- static const int k_unroll = 1;
+ static int k_unroll() {
+ return 1;
+ }
+
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 8, 24> transforms = {};
// Default to the generic kernel
kern_type kernel = a64_hgemm_asimd_24x8;
- hgemm_24x8(const CPUInfo *ci)
- {
- if(ci->get_cpu_model() == CPUModel::A55r1)
- {
+ hgemm_24x8(const CPUInfo *ci) {
+ if (ci->get_cpu_model() == CPUModel::A55r1) {
kernel = a64_hgemm_asimd_24x8_a55r1;
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp
index 2186117..a3839ce 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp
@@ -39,25 +39,22 @@
// Note that the intent of this is that either ablocks or bblocks will be 1
// - this construction allows the output loop to proceed in either order.
-namespace arm_gemm
-{
-void a64_hgemm_asimd_24x8_a55r1(const __fp16 *Apanel, const __fp16 *Bpanel, __fp16 *Cpanel, int ablocks, int bblocks, int K)
-{
+namespace arm_gemm {
+
+void a64_hgemm_asimd_24x8_a55r1(const __fp16 *Apanel, const __fp16 *Bpanel, __fp16 *Cpanel, int ablocks, int bblocks, int K) {
const __fp16 *a_ptr = Apanel;
- __fp16 *c_ptr = Cpanel;
+ __fp16 *c_ptr = Cpanel;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
- int oddk = (K & 1);
- int k_iters = ((K + 1) / 2) - 1;
+ int oddk = (K & 1);
+ int k_iters = ((K+1)/2) - 1;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ for (int yb=0; yb<ablocks; yb++) {
const __fp16 *a_ptr0 = a_ptr;
- const __fp16 *b_ptr = Bpanel;
+ const __fp16 *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
+ for (int xb=0; xb<bblocks; xb++) {
int k = k_iters;
a_ptr = a_ptr0;
@@ -65,294 +62,333 @@
// "A" operands to save on "ins" instructions. Since A55 is
// in-order, two sets of "A" operands and one set of "B" is
// sufficient.
- register float16x8_t a0 asm("v0");
- register float16x8_t a1 asm("v1");
+ register float16x8_t a0 asm("v0");
+ register float16x8_t a1 asm("v1");
register float16x8_t a0a asm("v2");
register float16x8_t a1a asm("v3");
- register float16x8_t b0 asm("v4");
- register float16x8_t b1 asm("v5");
- register float16x8_t b2 asm("v6");
+ register float16x8_t b0 asm("v4");
+ register float16x8_t b1 asm("v5");
+ register float16x8_t b2 asm("v6");
- __asm __volatile(
- // Enable FP16 extensions
- ".arch armv8.2-a+fp16\n"
+ __asm __volatile (
+ // Enable FP16 instruction support (but only if it's not already on).
+#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ ".arch armv8.2-a+fp16\n"
+#endif
// Initialize result registers, load initial operands, prime prefetches.
- "movi v8.8h, #0x0\n"
- "ldr %d[a0], [%[a_ptr]]\n"
- "movi v9.8h, #0x0\n"
- "ldr %q[b0], [%[b_ptr]]\n"
- "movi v10.8h, #0x0\n"
- "ldr %d[a1], [%[a_ptr], #8]\n"
- "movi v11.8h, #0x0\n"
- "ldr %q[b1], [%[b_ptr], #16]\n"
- "movi v12.8h, #0x0\n"
- "movi v13.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]")
- "movi v14.8h, #0x0\n"
- "movi v15.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]")
- "movi v16.8h, #0x0\n"
- "movi v17.8h, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]")
- "movi v18.8h, #0x0\n"
- "movi v19.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]")
- "movi v20.8h, #0x0\n"
- "movi v21.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]")
- "movi v22.8h, #0x0\n"
- "movi v23.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]")
- "movi v24.8h, #0x0\n"
- "movi v25.8h, #0x0\n"
- "movi v26.8h, #0x0\n"
- "movi v27.8h, #0x0\n"
- "movi v28.8h, #0x0\n"
- "movi v29.8h, #0x0\n"
- "movi v30.8h, #0x0\n"
- "movi v31.8h, #0x0\n"
+ "movi v8.8h, #0x0\n"
+ "ldr %d[a0], [%[a_ptr]]\n"
+ "movi v9.8h, #0x0\n"
+ "ldr %q[b0], [%[b_ptr]]\n"
+ "movi v10.8h, #0x0\n"
+ "ldr %d[a1], [%[a_ptr], #8]\n"
+ "movi v11.8h, #0x0\n"
+ "ldr %q[b1], [%[b_ptr], #16]\n"
+ "movi v12.8h, #0x0\n"
+ "movi v13.8h, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v14.8h, #0x0\n"
+ "movi v15.8h, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v16.8h, #0x0\n"
+ "movi v17.8h, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v18.8h, #0x0\n"
+ "movi v19.8h, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v20.8h, #0x0\n"
+ "movi v21.8h, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
+ "movi v22.8h, #0x0\n"
+ "movi v23.8h, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #320]")
+ "movi v24.8h, #0x0\n"
+ "movi v25.8h, #0x0\n"
+ "movi v26.8h, #0x0\n"
+ "movi v27.8h, #0x0\n"
+ "movi v28.8h, #0x0\n"
+ "movi v29.8h, #0x0\n"
+ "movi v30.8h, #0x0\n"
+ "movi v31.8h, #0x0\n"
// The loop is offset by these two instructions which must
// always be executed.
- "fmla v8.8h , %[b0].8h, %[a0].h[0]\n"
- "ldr %d[b2], [%[b_ptr], #32]\n"
+ "fmla v8.8h , %[b0].8h, %[a0].h[0]\n"
+ "ldr %d[b2], [%[b_ptr], #32]\n"
// Skip loop if we are doing zero iterations of it.
- "cbz %w[k], 4f\n"
+ "cbz %w[k], 4f\n"
"1:\n"
- "fmla v9.8h , %[b0].8h, %[a0].h[1]\n"
- "ldr x20, [%[b_ptr], #40]\n"
- "fmla v10.8h, %[b0].8h, %[a0].h[2]\n"
- "subs %w[k], %w[k], #1\n"
- "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
- "ldr %d[a0a], [%[a_ptr], #16]\n"
+ "fmla v9.8h , %[b0].8h, %[a0].h[1]\n"
+ "ldr x20, [%[b_ptr], #40]\n"
+ "fmla v10.8h, %[b0].8h, %[a0].h[2]\n"
+ "subs %w[k], %w[k], #1\n"
+ "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
+ "ldr %d[a0a], [%[a_ptr], #16]\n"
- "fmla v12.8h, %[b0].8h, %[a1].h[0]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v13.8h, %[b0].8h, %[a1].h[1]\n"
- "fmla v14.8h, %[b0].8h, %[a1].h[2]\n"
- "fmla v15.8h, %[b0].8h, %[a1].h[3]\n"
- "ldr %d[a1a], [%[a_ptr], #24]\n"
+ "fmla v12.8h, %[b0].8h, %[a1].h[0]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v13.8h, %[b0].8h, %[a1].h[1]\n"
+ "fmla v14.8h, %[b0].8h, %[a1].h[2]\n"
+ "fmla v15.8h, %[b0].8h, %[a1].h[3]\n"
+ "ldr %d[a1a], [%[a_ptr], #24]\n"
- "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
- "fmla v17.8h, %[b1].8h, %[a0].h[1]\n"
- "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
- "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
- "ldr %d[b0], [%[b_ptr], #48]\n"
+ "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
+ "fmla v17.8h, %[b1].8h, %[a0].h[1]\n"
+ "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
+ "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
+ "ldr %d[b0], [%[b_ptr], #48]\n"
- "fmla v20.8h, %[b1].8h, %[a1].h[0]\n"
- "fmla v21.8h, %[b1].8h, %[a1].h[1]\n"
- "ldr x20, [%[b_ptr], #56]\n"
- "fmla v22.8h, %[b1].8h, %[a1].h[2]\n"
- "fmla v23.8h, %[b1].8h, %[a1].h[3]\n"
- "ldr %d[b1], [%[b_ptr], #64]\n"
+ "fmla v20.8h, %[b1].8h, %[a1].h[0]\n"
+ "fmla v21.8h, %[b1].8h, %[a1].h[1]\n"
+ "ldr x20, [%[b_ptr], #56]\n"
+ "fmla v22.8h, %[b1].8h, %[a1].h[2]\n"
+ "fmla v23.8h, %[b1].8h, %[a1].h[3]\n"
+ "ldr %d[b1], [%[b_ptr], #64]\n"
- "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
- "ins %[b0].d[1], x20\n"
- "fmla v25.8h, %[b2].8h, %[a0].h[1]\n"
- "ldr x20, [%[b_ptr], #72]\n"
- "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
- "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" ASM_PREFETCH("[%[a_ptr], #128]")
+ "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
+ "ins %[b0].d[1], x20\n"
+ "fmla v25.8h, %[b2].8h, %[a0].h[1]\n"
+ "ldr x20, [%[b_ptr], #72]\n"
+ "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
+ "fmla v27.8h, %[b2].8h, %[a0].h[3]\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
- "fmla v28.8h, %[b2].8h, %[a1].h[0]\n"
- "fmla v29.8h, %[b2].8h, %[a1].h[1]\n" ASM_PREFETCH("[%[b_ptr], #384]")
- "fmla v30.8h, %[b2].8h, %[a1].h[2]\n"
- "fmla v31.8h, %[b2].8h, %[a1].h[3]\n"
- "ldr %d[b2], [%[b_ptr], #80]\n"
+ "fmla v28.8h, %[b2].8h, %[a1].h[0]\n"
+ "fmla v29.8h, %[b2].8h, %[a1].h[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #384]")
+ "fmla v30.8h, %[b2].8h, %[a1].h[2]\n"
+ "fmla v31.8h, %[b2].8h, %[a1].h[3]\n"
+ "ldr %d[b2], [%[b_ptr], #80]\n"
// Unroll 1
- "fmla v8.8h , %[b0].8h, %[a0a].h[0]\n"
- "ins %[b1].d[1], x20\n"
- "fmla v9.8h , %[b0].8h, %[a0a].h[1]\n"
- "ldr x20, [%[b_ptr], #88]\n"
- "fmla v10.8h, %[b0].8h, %[a0a].h[2]\n"
- "fmla v11.8h, %[b0].8h, %[a0a].h[3]\n"
- "ldr %d[a0], [%[a_ptr], #32]\n"
+ "fmla v8.8h , %[b0].8h, %[a0a].h[0]\n"
+ "ins %[b1].d[1], x20\n"
+ "fmla v9.8h , %[b0].8h, %[a0a].h[1]\n"
+ "ldr x20, [%[b_ptr], #88]\n"
+ "fmla v10.8h, %[b0].8h, %[a0a].h[2]\n"
+ "fmla v11.8h, %[b0].8h, %[a0a].h[3]\n"
+ "ldr %d[a0], [%[a_ptr], #32]\n"
- "fmla v12.8h, %[b0].8h, %[a1a].h[0]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v13.8h, %[b0].8h, %[a1a].h[1]\n"
- "fmla v14.8h, %[b0].8h, %[a1a].h[2]\n"
- "fmla v15.8h, %[b0].8h, %[a1a].h[3]\n"
- "ldr %d[a1], [%[a_ptr], #40]\n"
+ "fmla v12.8h, %[b0].8h, %[a1a].h[0]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v13.8h, %[b0].8h, %[a1a].h[1]\n"
+ "fmla v14.8h, %[b0].8h, %[a1a].h[2]\n"
+ "fmla v15.8h, %[b0].8h, %[a1a].h[3]\n"
+ "ldr %d[a1], [%[a_ptr], #40]\n"
- "fmla v16.8h, %[b1].8h, %[a0a].h[0]\n"
- "add %[a_ptr], %[a_ptr], #32\n"
- "fmla v17.8h, %[b1].8h, %[a0a].h[1]\n"
- "fmla v18.8h, %[b1].8h, %[a0a].h[2]\n"
- "fmla v19.8h, %[b1].8h, %[a0a].h[3]\n"
- "ldr %d[b0], [%[b_ptr], #96]\n"
+ "fmla v16.8h, %[b1].8h, %[a0a].h[0]\n"
+ "add %[a_ptr], %[a_ptr], #32\n"
+ "fmla v17.8h, %[b1].8h, %[a0a].h[1]\n"
+ "fmla v18.8h, %[b1].8h, %[a0a].h[2]\n"
+ "fmla v19.8h, %[b1].8h, %[a0a].h[3]\n"
+ "ldr %d[b0], [%[b_ptr], #96]\n"
- "fmla v20.8h, %[b1].8h, %[a1a].h[0]\n"
- "fmla v21.8h, %[b1].8h, %[a1a].h[1]\n"
- "ldr x20, [%[b_ptr], #104]\n"
- "fmla v22.8h, %[b1].8h, %[a1a].h[2]\n"
- "fmla v23.8h, %[b1].8h, %[a1a].h[3]\n"
- "ldr %d[b1], [%[b_ptr], #112]\n"
+ "fmla v20.8h, %[b1].8h, %[a1a].h[0]\n"
+ "fmla v21.8h, %[b1].8h, %[a1a].h[1]\n"
+ "ldr x20, [%[b_ptr], #104]\n"
+ "fmla v22.8h, %[b1].8h, %[a1a].h[2]\n"
+ "fmla v23.8h, %[b1].8h, %[a1a].h[3]\n"
+ "ldr %d[b1], [%[b_ptr], #112]\n"
- "fmla v24.8h, %[b2].8h, %[a0a].h[0]\n"
- "ins %[b0].d[1], x20\n"
- "fmla v25.8h, %[b2].8h, %[a0a].h[1]\n"
- "ldr x20, [%[b_ptr], #120]\n"
- "fmla v26.8h, %[b2].8h, %[a0a].h[2]\n"
- "fmla v27.8h, %[b2].8h, %[a0a].h[3]\n"
+ "fmla v24.8h, %[b2].8h, %[a0a].h[0]\n"
+ "ins %[b0].d[1], x20\n"
+ "fmla v25.8h, %[b2].8h, %[a0a].h[1]\n"
+ "ldr x20, [%[b_ptr], #120]\n"
+ "fmla v26.8h, %[b2].8h, %[a0a].h[2]\n"
+ "fmla v27.8h, %[b2].8h, %[a0a].h[3]\n"
- "fmla v28.8h, %[b2].8h, %[a1a].h[0]\n" ASM_PREFETCH("[%[b_ptr], #448]")
- "fmla v29.8h, %[b2].8h, %[a1a].h[1]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "fmla v30.8h, %[b2].8h, %[a1a].h[2]\n"
- "ins %[b1].d[1], x20\n"
- "fmla v31.8h, %[b2].8h, %[a1a].h[3]\n"
- "ldr %d[b2], [%[b_ptr], #32]\n"
+ "fmla v28.8h, %[b2].8h, %[a1a].h[0]\n"
+ ASM_PREFETCH("[%[b_ptr], #448]")
+ "fmla v29.8h, %[b2].8h, %[a1a].h[1]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "fmla v30.8h, %[b2].8h, %[a1a].h[2]\n"
+ "ins %[b1].d[1], x20\n"
+ "fmla v31.8h, %[b2].8h, %[a1a].h[3]\n"
+ "ldr %d[b2], [%[b_ptr], #32]\n"
- "fmla v8.8h , %[b0].8h, %[a0].h[0]\n"
- "bne 1b\n"
+ "fmla v8.8h , %[b0].8h, %[a0].h[0]\n"
+ "bne 1b\n"
"4:\n"
// Start final iteration - branch off to "odd" code before we load a0a
- "fmla v9.8h , %[b0].8h, %[a0].h[1]\n"
- "ldr x20, [%[b_ptr], #40]\n"
- "fmla v10.8h, %[b0].8h, %[a0].h[2]\n"
- "cbnz %w[oddk], 2f\n"
+ "fmla v9.8h , %[b0].8h, %[a0].h[1]\n"
+ "ldr x20, [%[b_ptr], #40]\n"
+ "fmla v10.8h, %[b0].8h, %[a0].h[2]\n"
+ "cbnz %w[oddk], 2f\n"
// Even K continuation
- "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
- "ldr %d[a0a], [%[a_ptr], #16]\n"
+ "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
+ "ldr %d[a0a], [%[a_ptr], #16]\n"
- "fmla v12.8h, %[b0].8h, %[a1].h[0]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v13.8h, %[b0].8h, %[a1].h[1]\n" ASM_PREFETCHW("[%[c_ptr]]")
- "fmla v14.8h, %[b0].8h, %[a1].h[2]\n"
- "fmla v15.8h, %[b0].8h, %[a1].h[3]\n"
- "ldr %d[a1a], [%[a_ptr], #24]\n"
+ "fmla v12.8h, %[b0].8h, %[a1].h[0]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v13.8h, %[b0].8h, %[a1].h[1]\n"
+ ASM_PREFETCHW("[%[c_ptr]]")
+ "fmla v14.8h, %[b0].8h, %[a1].h[2]\n"
+ "fmla v15.8h, %[b0].8h, %[a1].h[3]\n"
+ "ldr %d[a1a], [%[a_ptr], #24]\n"
- "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
- "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" ASM_PREFETCHW("[%[c_ptr], #64]")
- "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
- "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
- "ldr %d[b0], [%[b_ptr], #48]\n"
+ "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
+ "fmla v17.8h, %[b1].8h, %[a0].h[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #64]")
+ "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
+ "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
+ "ldr %d[b0], [%[b_ptr], #48]\n"
- "fmla v20.8h, %[b1].8h, %[a1].h[0]\n"
- "fmla v21.8h, %[b1].8h, %[a1].h[1]\n"
- "ldr x20, [%[b_ptr], #56]\n"
- "fmla v22.8h, %[b1].8h, %[a1].h[2]\n"
- "fmla v23.8h, %[b1].8h, %[a1].h[3]\n"
- "ldr %d[b1], [%[b_ptr], #64]\n"
+ "fmla v20.8h, %[b1].8h, %[a1].h[0]\n"
+ "fmla v21.8h, %[b1].8h, %[a1].h[1]\n"
+ "ldr x20, [%[b_ptr], #56]\n"
+ "fmla v22.8h, %[b1].8h, %[a1].h[2]\n"
+ "fmla v23.8h, %[b1].8h, %[a1].h[3]\n"
+ "ldr %d[b1], [%[b_ptr], #64]\n"
- "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
- "ins %[b0].d[1], x20\n"
- "fmla v25.8h, %[b2].8h, %[a0].h[1]\n"
- "ldr x20, [%[b_ptr], #72]\n"
- "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
- "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" ASM_PREFETCHW("[%[c_ptr], #128]")
+ "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
+ "ins %[b0].d[1], x20\n"
+ "fmla v25.8h, %[b2].8h, %[a0].h[1]\n"
+ "ldr x20, [%[b_ptr], #72]\n"
+ "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
+ "fmla v27.8h, %[b2].8h, %[a0].h[3]\n"
+ ASM_PREFETCHW("[%[c_ptr], #128]")
- "fmla v28.8h, %[b2].8h, %[a1].h[0]\n"
- "fmla v29.8h, %[b2].8h, %[a1].h[1]\n" ASM_PREFETCHW("[%[c_ptr], #192]")
- "fmla v30.8h, %[b2].8h, %[a1].h[2]\n"
- "fmla v31.8h, %[b2].8h, %[a1].h[3]\n"
- "ldr %d[b2], [%[b_ptr], #80]\n"
+ "fmla v28.8h, %[b2].8h, %[a1].h[0]\n"
+ "fmla v29.8h, %[b2].8h, %[a1].h[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #192]")
+ "fmla v30.8h, %[b2].8h, %[a1].h[2]\n"
+ "fmla v31.8h, %[b2].8h, %[a1].h[3]\n"
+ "ldr %d[b2], [%[b_ptr], #80]\n"
- "fmla v8.8h , %[b0].8h, %[a0a].h[0]\n"
- "ins %[b1].d[1], x20\n"
- "fmla v9.8h , %[b0].8h, %[a0a].h[1]\n"
- "ldr x20, [%[b_ptr], #88]\n"
- "fmla v10.8h, %[b0].8h, %[a0a].h[2]\n"
- "fmla v11.8h, %[b0].8h, %[a0a].h[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]")
+ "fmla v8.8h , %[b0].8h, %[a0a].h[0]\n"
+ "ins %[b1].d[1], x20\n"
+ "fmla v9.8h , %[b0].8h, %[a0a].h[1]\n"
+ "ldr x20, [%[b_ptr], #88]\n"
+ "fmla v10.8h, %[b0].8h, %[a0a].h[2]\n"
+ "fmla v11.8h, %[b0].8h, %[a0a].h[3]\n"
+ ASM_PREFETCHW("[%[c_ptr], #256]")
- "fmla v12.8h, %[b0].8h, %[a1a].h[0]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v13.8h, %[b0].8h, %[a1a].h[1]\n" ASM_PREFETCHW("[%[c_ptr], #320]")
- "fmla v14.8h, %[b0].8h, %[a1a].h[2]\n"
- "fmla v15.8h, %[b0].8h, %[a1a].h[3]\n"
- "ldr %d[a1], [%[a_ptr], #40]\n"
+ "fmla v12.8h, %[b0].8h, %[a1a].h[0]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v13.8h, %[b0].8h, %[a1a].h[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #320]")
+ "fmla v14.8h, %[b0].8h, %[a1a].h[2]\n"
+ "fmla v15.8h, %[b0].8h, %[a1a].h[3]\n"
+ "ldr %d[a1], [%[a_ptr], #40]\n"
- "fmla v16.8h, %[b1].8h, %[a0a].h[0]\n"
- "add %[a_ptr], %[a_ptr], #32\n"
- "fmla v17.8h, %[b1].8h, %[a0a].h[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]")
- "fmla v18.8h, %[b1].8h, %[a0a].h[2]\n"
- "fmla v19.8h, %[b1].8h, %[a0a].h[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]")
+ "fmla v16.8h, %[b1].8h, %[a0a].h[0]\n"
+ "add %[a_ptr], %[a_ptr], #32\n"
+ "fmla v17.8h, %[b1].8h, %[a0a].h[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #384]")
+ "fmla v18.8h, %[b1].8h, %[a0a].h[2]\n"
+ "fmla v19.8h, %[b1].8h, %[a0a].h[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #448]")
- "fmla v20.8h, %[b1].8h, %[a1a].h[0]\n"
- "fmla v21.8h, %[b1].8h, %[a1a].h[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]")
- "fmla v22.8h, %[b1].8h, %[a1a].h[2]\n"
- "fmla v23.8h, %[b1].8h, %[a1a].h[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]")
+ "fmla v20.8h, %[b1].8h, %[a1a].h[0]\n"
+ "fmla v21.8h, %[b1].8h, %[a1a].h[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #512]")
+ "fmla v22.8h, %[b1].8h, %[a1a].h[2]\n"
+ "fmla v23.8h, %[b1].8h, %[a1a].h[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #576]")
- "fmla v24.8h, %[b2].8h, %[a0a].h[0]\n"
- "fmla v25.8h, %[b2].8h, %[a0a].h[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #640]")
- "fmla v26.8h, %[b2].8h, %[a0a].h[2]\n"
- "fmla v27.8h, %[b2].8h, %[a0a].h[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]")
+ "fmla v24.8h, %[b2].8h, %[a0a].h[0]\n"
+ "fmla v25.8h, %[b2].8h, %[a0a].h[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #640]")
+ "fmla v26.8h, %[b2].8h, %[a0a].h[2]\n"
+ "fmla v27.8h, %[b2].8h, %[a0a].h[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #704]")
- "fmla v28.8h, %[b2].8h, %[a1a].h[0]\n"
- "fmla v29.8h, %[b2].8h, %[a1a].h[1]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "fmla v30.8h, %[b2].8h, %[a1a].h[2]\n"
- "fmla v31.8h, %[b2].8h, %[a1a].h[3]\n"
- "b 3f\n"
+ "fmla v28.8h, %[b2].8h, %[a1a].h[0]\n"
+ "fmla v29.8h, %[b2].8h, %[a1a].h[1]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "fmla v30.8h, %[b2].8h, %[a1a].h[2]\n"
+ "fmla v31.8h, %[b2].8h, %[a1a].h[3]\n"
+ "b 3f\n"
"2:\n"
// Odd tail
- "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" ASM_PREFETCHW("[%[c_ptr]]")
+ "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
+ ASM_PREFETCHW("[%[c_ptr]]")
- "fmla v12.8h, %[b0].8h, %[a1].h[0]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v13.8h, %[b0].8h, %[a1].h[1]\n" ASM_PREFETCHW("[%[c_ptr], #64]")
- "fmla v14.8h, %[b0].8h, %[a1].h[2]\n"
- "add %[a_ptr], %[a_ptr], #16\n"
- "fmla v15.8h, %[b0].8h, %[a1].h[3]\n" ASM_PREFETCHW("[%[c_ptr], #128]")
+ "fmla v12.8h, %[b0].8h, %[a1].h[0]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v13.8h, %[b0].8h, %[a1].h[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #64]")
+ "fmla v14.8h, %[b0].8h, %[a1].h[2]\n"
+ "add %[a_ptr], %[a_ptr], #16\n"
+ "fmla v15.8h, %[b0].8h, %[a1].h[3]\n"
+ ASM_PREFETCHW("[%[c_ptr], #128]")
- "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
- "add %[b_ptr], %[b_ptr], #48\n"
- "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" ASM_PREFETCHW("[%[c_ptr], #192]")
- "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
- "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]")
+ "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
+ "add %[b_ptr], %[b_ptr], #48\n"
+ "fmla v17.8h, %[b1].8h, %[a0].h[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #192]")
+ "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
+ "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
+ ASM_PREFETCHW("[%[c_ptr], #256]")
- "fmla v20.8h, %[b1].8h, %[a1].h[0]\n"
- "fmla v21.8h, %[b1].8h, %[a1].h[1]\n" ASM_PREFETCHW("[%[c_ptr], #320]")
- "fmla v22.8h, %[b1].8h, %[a1].h[2]\n"
- "fmla v23.8h, %[b1].8h, %[a1].h[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]")
+ "fmla v20.8h, %[b1].8h, %[a1].h[0]\n"
+ "fmla v21.8h, %[b1].8h, %[a1].h[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #320]")
+ "fmla v22.8h, %[b1].8h, %[a1].h[2]\n"
+ "fmla v23.8h, %[b1].8h, %[a1].h[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #384]")
- "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
- "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]")
- "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
- "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]")
+ "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
+ "fmla v25.8h, %[b2].8h, %[a0].h[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #384]")
+ "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
+ "fmla v27.8h, %[b2].8h, %[a0].h[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #448]")
- "fmla v28.8h, %[b2].8h, %[a1].h[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]") "fmla v29.8h, %[b2].8h, %[a1].h[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]") "fmla v30.8h, %[b2].8h, %[a1].h[2]\n"
- ASM_PREFETCHWL2("[%[c_ptr], #640]") "fmla v31.8h, %[b2].8h, %[a1].h[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]")
+ "fmla v28.8h, %[b2].8h, %[a1].h[0]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #512]")
+ "fmla v29.8h, %[b2].8h, %[a1].h[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #576]")
+ "fmla v30.8h, %[b2].8h, %[a1].h[2]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #640]")
+ "fmla v31.8h, %[b2].8h, %[a1].h[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #704]")
// Common tail
// A55 won't dual issue these stores with anything else, so
// simplest to do them all in this common code.
"3:\n"
- "str q8, [%[c_ptr]]\n"
- "str q16, [%[c_ptr], #16]\n"
- "str q24, [%[c_ptr], #32]\n"
- "str q9, [%[c_ptr], #48]\n"
- "str q17, [%[c_ptr], #64]\n"
- "str q25, [%[c_ptr], #80]\n"
- "str q10, [%[c_ptr], #96]\n"
- "str q18, [%[c_ptr], #112]\n"
- "str q26, [%[c_ptr], #128]\n"
- "str q11, [%[c_ptr], #144]\n"
- "str q19, [%[c_ptr], #160]\n"
- "str q27, [%[c_ptr], #176]\n"
- "str q12, [%[c_ptr], #192]\n"
- "str q20, [%[c_ptr], #208]\n"
- "str q28, [%[c_ptr], #224]\n"
- "str q13, [%[c_ptr], #240]\n"
- "str q21, [%[c_ptr], #256]\n"
- "str q29, [%[c_ptr], #272]\n"
- "str q14, [%[c_ptr], #288]\n"
- "str q22, [%[c_ptr], #304]\n"
- "str q30, [%[c_ptr], #320]\n"
- "str q15, [%[c_ptr], #336]\n"
- "str q23, [%[c_ptr], #352]\n"
- "str q31, [%[c_ptr], #368]\n"
+ "str q8, [%[c_ptr]]\n"
+ "str q16, [%[c_ptr], #16]\n"
+ "str q24, [%[c_ptr], #32]\n"
+ "str q9, [%[c_ptr], #48]\n"
+ "str q17, [%[c_ptr], #64]\n"
+ "str q25, [%[c_ptr], #80]\n"
+ "str q10, [%[c_ptr], #96]\n"
+ "str q18, [%[c_ptr], #112]\n"
+ "str q26, [%[c_ptr], #128]\n"
+ "str q11, [%[c_ptr], #144]\n"
+ "str q19, [%[c_ptr], #160]\n"
+ "str q27, [%[c_ptr], #176]\n"
+ "str q12, [%[c_ptr], #192]\n"
+ "str q20, [%[c_ptr], #208]\n"
+ "str q28, [%[c_ptr], #224]\n"
+ "str q13, [%[c_ptr], #240]\n"
+ "str q21, [%[c_ptr], #256]\n"
+ "str q29, [%[c_ptr], #272]\n"
+ "str q14, [%[c_ptr], #288]\n"
+ "str q22, [%[c_ptr], #304]\n"
+ "str q30, [%[c_ptr], #320]\n"
+ "str q15, [%[c_ptr], #336]\n"
+ "str q23, [%[c_ptr], #352]\n"
+ "str q31, [%[c_ptr], #368]\n"
"5:\n"
- "add %[c_ptr], %[c_ptr], #384\n"
- :
- [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr),
- [a0] "=w"(a0), [a0a] "=w"(a0a), [a1] "=w"(a1), [a1a] "=w"(a1a),
- [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), [k] "+r"(k)
- : [oddk] "r"(oddk)
- : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
- "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory");
+ "add %[c_ptr], %[c_ptr], #384\n"
+ :
+ [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
+ [a0] "=w" (a0), [a0a] "=w" (a0a), [a1] "=w" (a1), [a1a] "=w" (a1a),
+ [b0] "=w" (b0), [b1] "=w" (b1), [b2] "=w" (b2), [k] "+r" (k)
+ : [oddk] "r" (oddk)
+ : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
+ "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"
+ );
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/generic.cpp
index 65a5d43..418a375 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/generic.cpp
@@ -39,297 +39,311 @@
// Note that the intent of this is that either ablocks or bblocks will be 1
// - this construction allows the output loop to proceed in either order.
-namespace arm_gemm
-{
-void a64_hgemm_asimd_24x8(const __fp16 *Apanel, const __fp16 *Bpanel, __fp16 *Cpanel, int ablocks, int bblocks, int K)
-{
+namespace arm_gemm {
+
+void a64_hgemm_asimd_24x8(const __fp16 *Apanel, const __fp16 *Bpanel, __fp16 *Cpanel, int ablocks, int bblocks, int K) {
const __fp16 *a_ptr = Apanel;
- __fp16 *c_ptr = Cpanel;
+ __fp16 *c_ptr = Cpanel;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ for (int yb=0; yb<ablocks; yb++) {
const __fp16 *a_ptr0 = a_ptr;
- const __fp16 *b_ptr = Bpanel;
+ const __fp16 *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
+ for (int xb=0; xb<bblocks; xb++) {
a_ptr = a_ptr0;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
int oddk = (K & 1);
- int k = ((K + 1) / 2) - 1;
+ int k = ((K+1)/2) - 1;
- register float16x8_t a0 asm("v0");
+ register float16x8_t a0 asm("v0");
register float16x8_t a0a asm("v1");
- register float16x8_t b0 asm("v2");
- register float16x8_t b1 asm("v3");
- register float16x8_t b2 asm("v4");
+ register float16x8_t b0 asm("v2");
+ register float16x8_t b1 asm("v3");
+ register float16x8_t b2 asm("v4");
register float16x8_t b0a asm("v5");
register float16x8_t b1a asm("v6");
register float16x8_t b2a asm("v7");
- __asm __volatile(
- ".arch armv8.2-a+fp16\n"
+ __asm __volatile (
+ // Enable FP16 instruction support (but only if it's not already on).
+#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ ".arch armv8.2-a+fp16\n"
+#endif
// Initialize result registers, load initial operands, prime prefetches.
- "movi v8.8h, #0x0\n"
- "ldr %q[a0], [%[a_ptr]]\n"
- "movi v9.8h, #0x0\n"
- "ldr %q[b0], [%[b_ptr]]\n"
- "movi v10.8h, #0x0\n"
- "ldr %q[b1], [%[b_ptr], #16]\n"
- "movi v11.8h, #0x0\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "movi v12.8h, #0x0\n"
- "ldr %q[b0a], [%[b_ptr], #48]\n"
- "movi v13.8h, #0x0\n"
- "ldr %q[b1a], [%[b_ptr], #64]\n"
- "movi v14.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v15.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v16.8h, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v17.8h, #0x0\n"
- ASM_PREFETCH("[%[b_ptr], #192]") "movi v18.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") "movi v19.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]")
- "movi v20.8h, #0x0\n"
- "movi v21.8h, #0x0\n"
- "movi v22.8h, #0x0\n"
- "movi v23.8h, #0x0\n"
- "movi v24.8h, #0x0\n"
- "movi v25.8h, #0x0\n"
- "movi v26.8h, #0x0\n"
- "movi v27.8h, #0x0\n"
- "movi v28.8h, #0x0\n"
- "movi v29.8h, #0x0\n"
- "movi v30.8h, #0x0\n"
- "movi v31.8h, #0x0\n"
+ "movi v8.8h, #0x0\n"
+ "ldr %q[a0], [%[a_ptr]]\n"
+ "movi v9.8h, #0x0\n"
+ "ldr %q[b0], [%[b_ptr]]\n"
+ "movi v10.8h, #0x0\n"
+ "ldr %q[b1], [%[b_ptr], #16]\n"
+ "movi v11.8h, #0x0\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "movi v12.8h, #0x0\n"
+ "ldr %q[b0a], [%[b_ptr], #48]\n"
+ "movi v13.8h, #0x0\n"
+ "ldr %q[b1a], [%[b_ptr], #64]\n"
+ "movi v14.8h, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v15.8h, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v16.8h, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v17.8h, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v18.8h, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
+ "movi v19.8h, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #320]")
+ "movi v20.8h, #0x0\n"
+ "movi v21.8h, #0x0\n"
+ "movi v22.8h, #0x0\n"
+ "movi v23.8h, #0x0\n"
+ "movi v24.8h, #0x0\n"
+ "movi v25.8h, #0x0\n"
+ "movi v26.8h, #0x0\n"
+ "movi v27.8h, #0x0\n"
+ "movi v28.8h, #0x0\n"
+ "movi v29.8h, #0x0\n"
+ "movi v30.8h, #0x0\n"
+ "movi v31.8h, #0x0\n"
// Skip loop if we are doing zero iterations of it.
- "cbz %w[k], 4f\n"
+ "cbz %w[k], 4f\n"
"1:\n"
- "fmla v8.8h , %[b0].8h, %[a0].h[0]\n"
- "fmla v9.8h , %[b0].8h, %[a0].h[1]\n"
- "ldr %q[a0a], [%[a_ptr], #16]\n"
- "fmla v10.8h, %[b0].8h, %[a0].h[2]\n"
- "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
- "ldr %q[b2a], [%[b_ptr], #80]\n"
- "fmla v12.8h, %[b0].8h, %[a0].h[4]\n"
- "fmla v13.8h, %[b0].8h, %[a0].h[5]\n"
- "fmla v14.8h, %[b0].8h, %[a0].h[6]\n"
- "fmla v15.8h, %[b0].8h, %[a0].h[7]\n"
- "ldr %q[b0], [%[b_ptr], #96]\n"
+ "fmla v8.8h , %[b0].8h, %[a0].h[0]\n"
+ "fmla v9.8h , %[b0].8h, %[a0].h[1]\n"
+ "ldr %q[a0a], [%[a_ptr], #16]\n"
+ "fmla v10.8h, %[b0].8h, %[a0].h[2]\n"
+ "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
+ "ldr %q[b2a], [%[b_ptr], #80]\n"
+ "fmla v12.8h, %[b0].8h, %[a0].h[4]\n"
+ "fmla v13.8h, %[b0].8h, %[a0].h[5]\n"
+ "fmla v14.8h, %[b0].8h, %[a0].h[6]\n"
+ "fmla v15.8h, %[b0].8h, %[a0].h[7]\n"
+ "ldr %q[b0], [%[b_ptr], #96]\n"
- "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
- "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" ASM_PREFETCH("[%[a_ptr], #128]")
- "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
- "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "fmla v20.8h, %[b1].8h, %[a0].h[4]\n"
- "fmla v21.8h, %[b1].8h, %[a0].h[5]\n"
- "fmla v22.8h, %[b1].8h, %[a0].h[6]\n"
- "fmla v23.8h, %[b1].8h, %[a0].h[7]\n"
- "ldr %q[b1], [%[b_ptr], #16]\n"
+ "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
+ "fmla v17.8h, %[b1].8h, %[a0].h[1]\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
+ "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
+ "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "fmla v20.8h, %[b1].8h, %[a0].h[4]\n"
+ "fmla v21.8h, %[b1].8h, %[a0].h[5]\n"
+ "fmla v22.8h, %[b1].8h, %[a0].h[6]\n"
+ "fmla v23.8h, %[b1].8h, %[a0].h[7]\n"
+ "ldr %q[b1], [%[b_ptr], #16]\n"
- "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
- "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" ASM_PREFETCH("[%[b_ptr], #288]")
- "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
- "fmla v27.8h, %[b2].8h, %[a0].h[3]\n"
- "fmla v28.8h, %[b2].8h, %[a0].h[4]\n"
- "fmla v29.8h, %[b2].8h, %[a0].h[5]\n"
- "fmla v30.8h, %[b2].8h, %[a0].h[6]\n"
- "fmla v31.8h, %[b2].8h, %[a0].h[7]\n"
- "ldr %q[a0], [%[a_ptr], #32]\n"
+ "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
+ "fmla v25.8h, %[b2].8h, %[a0].h[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #288]")
+ "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
+ "fmla v27.8h, %[b2].8h, %[a0].h[3]\n"
+ "fmla v28.8h, %[b2].8h, %[a0].h[4]\n"
+ "fmla v29.8h, %[b2].8h, %[a0].h[5]\n"
+ "fmla v30.8h, %[b2].8h, %[a0].h[6]\n"
+ "fmla v31.8h, %[b2].8h, %[a0].h[7]\n"
+ "ldr %q[a0], [%[a_ptr], #32]\n"
- "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n"
- "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n"
- "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n"
- "fmla v12.8h, %[b0a].8h, %[a0a].h[4]\n"
- "fmla v13.8h, %[b0a].8h, %[a0a].h[5]\n"
- "fmla v14.8h, %[b0a].8h, %[a0a].h[6]\n"
- "fmla v15.8h, %[b0a].8h, %[a0a].h[7]\n"
- "ldr %q[b0a], [%[b_ptr], #48]\n"
+ "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n"
+ "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n"
+ "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n"
+ "fmla v12.8h, %[b0a].8h, %[a0a].h[4]\n"
+ "fmla v13.8h, %[b0a].8h, %[a0a].h[5]\n"
+ "fmla v14.8h, %[b0a].8h, %[a0a].h[6]\n"
+ "fmla v15.8h, %[b0a].8h, %[a0a].h[7]\n"
+ "ldr %q[b0a], [%[b_ptr], #48]\n"
- "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n"
- "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" ASM_PREFETCH("[%[b_ptr], #352]")
- "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n"
- "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n"
- "fmla v20.8h, %[b1a].8h, %[a0a].h[4]\n"
- "fmla v21.8h, %[b1a].8h, %[a0a].h[5]\n"
- "fmla v22.8h, %[b1a].8h, %[a0a].h[6]\n"
- "fmla v23.8h, %[b1a].8h, %[a0a].h[7]\n"
- "ldr %q[b1a], [%[b_ptr], #64]\n"
+ "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n"
+ "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #352]")
+ "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n"
+ "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n"
+ "fmla v20.8h, %[b1a].8h, %[a0a].h[4]\n"
+ "fmla v21.8h, %[b1a].8h, %[a0a].h[5]\n"
+ "fmla v22.8h, %[b1a].8h, %[a0a].h[6]\n"
+ "fmla v23.8h, %[b1a].8h, %[a0a].h[7]\n"
+ "ldr %q[b1a], [%[b_ptr], #64]\n"
- "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n"
- "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n"
- "add %[a_ptr], %[a_ptr], #32\n"
- "fmla v26.8h, %[b2a].8h, %[a0a].h[2]\n"
- "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n"
- "fmla v28.8h, %[b2a].8h, %[a0a].h[4]\n"
- "fmla v29.8h, %[b2a].8h, %[a0a].h[5]\n"
- "subs %w[k], %w[k], #1\n"
- "fmla v30.8h, %[b2a].8h, %[a0a].h[6]\n"
- "fmla v31.8h, %[b2a].8h, %[a0a].h[7]\n"
+ "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n"
+ "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n"
+ "add %[a_ptr], %[a_ptr], #32\n"
+ "fmla v26.8h, %[b2a].8h, %[a0a].h[2]\n"
+ "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n"
+ "fmla v28.8h, %[b2a].8h, %[a0a].h[4]\n"
+ "fmla v29.8h, %[b2a].8h, %[a0a].h[5]\n"
+ "subs %w[k], %w[k], #1\n"
+ "fmla v30.8h, %[b2a].8h, %[a0a].h[6]\n"
+ "fmla v31.8h, %[b2a].8h, %[a0a].h[7]\n"
- "bne 1b\n"
+ "bne 1b\n"
"4:\n"
// Jump to odd tail if necessary.
- "cbnz %w[oddk], 2f\n"
+ "cbnz %w[oddk], 2f\n"
// Even tail.
- "fmla v8.8h , %[b0].8h, %[a0].h[0]\n"
+ "fmla v8.8h , %[b0].8h, %[a0].h[0]\n"
"fmla v9.8h , %[b0].8h, %[a0].h[1]\n"
- "ldr %q[a0a], [%[a_ptr], #16]\n"
- "fmla v10.8h, %[b0].8h, %[a0].h[2]\n"
- "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
- "ldr %q[b2a], [%[b_ptr], #80]\n"
- "fmla v12.8h, %[b0].8h, %[a0].h[4]\n"
+ "ldr %q[a0a], [%[a_ptr], #16]\n"
+ "fmla v10.8h, %[b0].8h, %[a0].h[2]\n"
+ "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
+ "ldr %q[b2a], [%[b_ptr], #80]\n"
+ "fmla v12.8h, %[b0].8h, %[a0].h[4]\n"
"fmla v13.8h, %[b0].8h, %[a0].h[5]\n"
- "fmla v14.8h, %[b0].8h, %[a0].h[6]\n"
- "fmla v15.8h, %[b0].8h, %[a0].h[7]\n"
+ "fmla v14.8h, %[b0].8h, %[a0].h[6]\n"
+ "fmla v15.8h, %[b0].8h, %[a0].h[7]\n"
- "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
- "fmla v17.8h, %[b1].8h, %[a0].h[1]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
- "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
- "fmla v20.8h, %[b1].8h, %[a0].h[4]\n"
- "fmla v21.8h, %[b1].8h, %[a0].h[5]\n"
- "add %[a_ptr], %[a_ptr], #32\n"
- "fmla v22.8h, %[b1].8h, %[a0].h[6]\n"
- "fmla v23.8h, %[b1].8h, %[a0].h[7]\n"
+ "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
+ "fmla v17.8h, %[b1].8h, %[a0].h[1]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
+ "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
+ "fmla v20.8h, %[b1].8h, %[a0].h[4]\n"
+ "fmla v21.8h, %[b1].8h, %[a0].h[5]\n"
+ "add %[a_ptr], %[a_ptr], #32\n"
+ "fmla v22.8h, %[b1].8h, %[a0].h[6]\n"
+ "fmla v23.8h, %[b1].8h, %[a0].h[7]\n"
- "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
- "fmla v25.8h, %[b2].8h, %[a0].h[1]\n"
- "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
- "fmla v27.8h, %[b2].8h, %[a0].h[3]\n"
- "fmla v28.8h, %[b2].8h, %[a0].h[4]\n"
- "fmla v29.8h, %[b2].8h, %[a0].h[5]\n"
- "fmla v30.8h, %[b2].8h, %[a0].h[6]\n"
- "fmla v31.8h, %[b2].8h, %[a0].h[7]\n"
+ "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
+ "fmla v25.8h, %[b2].8h, %[a0].h[1]\n"
+ "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
+ "fmla v27.8h, %[b2].8h, %[a0].h[3]\n"
+ "fmla v28.8h, %[b2].8h, %[a0].h[4]\n"
+ "fmla v29.8h, %[b2].8h, %[a0].h[5]\n"
+ "fmla v30.8h, %[b2].8h, %[a0].h[6]\n"
+ "fmla v31.8h, %[b2].8h, %[a0].h[7]\n"
- "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n"
- "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n"
- "str q8, [%[c_ptr]]\n"
- "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n"
- "str q16, [%[c_ptr], #16]\n"
+ "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n"
+ "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n"
+ "str q8, [%[c_ptr]]\n"
+ "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n"
+ "str q16, [%[c_ptr], #16]\n"
- "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n"
- "str q24, [%[c_ptr], #32]\n"
- "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n"
- "str q9, [%[c_ptr], #48]\n"
- "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n"
- "str q17, [%[c_ptr], #64]\n"
+ "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n"
+ "str q24, [%[c_ptr], #32]\n"
+ "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n"
+ "str q9, [%[c_ptr], #48]\n"
+ "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n"
+ "str q17, [%[c_ptr], #64]\n"
- "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n"
- "str q25, [%[c_ptr], #80]\n"
- "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n"
- "str q10, [%[c_ptr], #96]\n"
- "fmla v26.8h, %[b2a].8h, %[a0a].h[2]\n"
- "str q18, [%[c_ptr], #112]\n"
+ "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n"
+ "str q25, [%[c_ptr], #80]\n"
+ "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n"
+ "str q10, [%[c_ptr], #96]\n"
+ "fmla v26.8h, %[b2a].8h, %[a0a].h[2]\n"
+ "str q18, [%[c_ptr], #112]\n"
- "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n"
- "str q26, [%[c_ptr], #128]\n"
- "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n"
- "str q11, [%[c_ptr], #144]\n"
- "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n"
- "str q19, [%[c_ptr], #160]\n"
+ "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n"
+ "str q26, [%[c_ptr], #128]\n"
+ "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n"
+ "str q11, [%[c_ptr], #144]\n"
+ "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n"
+ "str q19, [%[c_ptr], #160]\n"
- "fmla v12.8h, %[b0a].8h, %[a0a].h[4]\n"
- "str q27, [%[c_ptr], #176]\n"
- "fmla v20.8h, %[b1a].8h, %[a0a].h[4]\n"
- "str q12, [%[c_ptr], #192]\n"
- "fmla v28.8h, %[b2a].8h, %[a0a].h[4]\n"
- "str q20, [%[c_ptr], #208]\n"
+ "fmla v12.8h, %[b0a].8h, %[a0a].h[4]\n"
+ "str q27, [%[c_ptr], #176]\n"
+ "fmla v20.8h, %[b1a].8h, %[a0a].h[4]\n"
+ "str q12, [%[c_ptr], #192]\n"
+ "fmla v28.8h, %[b2a].8h, %[a0a].h[4]\n"
+ "str q20, [%[c_ptr], #208]\n"
- "fmla v13.8h, %[b0a].8h, %[a0a].h[5]\n"
- "str q28, [%[c_ptr], #224]\n"
- "fmla v21.8h, %[b1a].8h, %[a0a].h[5]\n"
- "str q13, [%[c_ptr], #240]\n"
- "fmla v29.8h, %[b2a].8h, %[a0a].h[5]\n"
- "str q21, [%[c_ptr], #256]\n"
+ "fmla v13.8h, %[b0a].8h, %[a0a].h[5]\n"
+ "str q28, [%[c_ptr], #224]\n"
+ "fmla v21.8h, %[b1a].8h, %[a0a].h[5]\n"
+ "str q13, [%[c_ptr], #240]\n"
+ "fmla v29.8h, %[b2a].8h, %[a0a].h[5]\n"
+ "str q21, [%[c_ptr], #256]\n"
- "fmla v14.8h, %[b0a].8h, %[a0a].h[6]\n"
- "str q29, [%[c_ptr], #272]\n"
- "fmla v22.8h, %[b1a].8h, %[a0a].h[6]\n"
- "str q14, [%[c_ptr], #288]\n"
- "fmla v30.8h, %[b2a].8h, %[a0a].h[6]\n"
- "str q22, [%[c_ptr], #304]\n"
+ "fmla v14.8h, %[b0a].8h, %[a0a].h[6]\n"
+ "str q29, [%[c_ptr], #272]\n"
+ "fmla v22.8h, %[b1a].8h, %[a0a].h[6]\n"
+ "str q14, [%[c_ptr], #288]\n"
+ "fmla v30.8h, %[b2a].8h, %[a0a].h[6]\n"
+ "str q22, [%[c_ptr], #304]\n"
- "fmla v15.8h, %[b0a].8h, %[a0a].h[7]\n"
- "str q30, [%[c_ptr], #320]\n"
- "fmla v23.8h, %[b1a].8h, %[a0a].h[7]\n"
- "str q15, [%[c_ptr], #336]\n"
- "fmla v31.8h, %[b2a].8h, %[a0a].h[7]\n"
- "b 3f\n"
+ "fmla v15.8h, %[b0a].8h, %[a0a].h[7]\n"
+ "str q30, [%[c_ptr], #320]\n"
+ "fmla v23.8h, %[b1a].8h, %[a0a].h[7]\n"
+ "str q15, [%[c_ptr], #336]\n"
+ "fmla v31.8h, %[b2a].8h, %[a0a].h[7]\n"
+ "b 3f\n"
// Odd tail
"2:\n"
- "fmla v8.8h , %[b0].8h, %[a0].h[0]\n"
- "add %[b_ptr], %[b_ptr], #48\n"
- "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
- "add %[a_ptr], %[a_ptr], #16\n"
- "str q8, [%[c_ptr]]\n"
- "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
- "str q16, [%[c_ptr], #16]\n"
+ "fmla v8.8h , %[b0].8h, %[a0].h[0]\n"
+ "add %[b_ptr], %[b_ptr], #48\n"
+ "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
+ "add %[a_ptr], %[a_ptr], #16\n"
+ "str q8, [%[c_ptr]]\n"
+ "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
+ "str q16, [%[c_ptr], #16]\n"
- "fmla v9.8h , %[b0].8h, %[a0].h[1]\n"
- "str q24, [%[c_ptr], #32]\n"
- "fmla v17.8h, %[b1].8h, %[a0].h[1]\n"
- "str q9, [%[c_ptr], #48]\n"
- "fmla v25.8h, %[b2].8h, %[a0].h[1]\n"
- "str q17, [%[c_ptr], #64]\n"
+ "fmla v9.8h , %[b0].8h, %[a0].h[1]\n"
+ "str q24, [%[c_ptr], #32]\n"
+ "fmla v17.8h, %[b1].8h, %[a0].h[1]\n"
+ "str q9, [%[c_ptr], #48]\n"
+ "fmla v25.8h, %[b2].8h, %[a0].h[1]\n"
+ "str q17, [%[c_ptr], #64]\n"
- "fmla v10.8h, %[b0].8h, %[a0].h[2]\n"
- "str q25, [%[c_ptr], #80]\n"
- "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
- "str q10, [%[c_ptr], #96]\n"
- "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
- "str q18, [%[c_ptr], #112]\n"
+ "fmla v10.8h, %[b0].8h, %[a0].h[2]\n"
+ "str q25, [%[c_ptr], #80]\n"
+ "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
+ "str q10, [%[c_ptr], #96]\n"
+ "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
+ "str q18, [%[c_ptr], #112]\n"
- "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
- "str q26, [%[c_ptr], #128]\n"
- "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
- "str q11, [%[c_ptr], #144]\n"
- "fmla v27.8h, %[b2].8h, %[a0].h[3]\n"
- "str q19, [%[c_ptr], #160]\n"
+ "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
+ "str q26, [%[c_ptr], #128]\n"
+ "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
+ "str q11, [%[c_ptr], #144]\n"
+ "fmla v27.8h, %[b2].8h, %[a0].h[3]\n"
+ "str q19, [%[c_ptr], #160]\n"
- "fmla v12.8h, %[b0].8h, %[a0].h[4]\n"
- "str q27, [%[c_ptr], #176]\n"
- "fmla v20.8h, %[b1].8h, %[a0].h[4]\n"
- "str q12, [%[c_ptr], #192]\n"
- "fmla v28.8h, %[b2].8h, %[a0].h[4]\n"
- "str q20, [%[c_ptr], #208]\n"
+ "fmla v12.8h, %[b0].8h, %[a0].h[4]\n"
+ "str q27, [%[c_ptr], #176]\n"
+ "fmla v20.8h, %[b1].8h, %[a0].h[4]\n"
+ "str q12, [%[c_ptr], #192]\n"
+ "fmla v28.8h, %[b2].8h, %[a0].h[4]\n"
+ "str q20, [%[c_ptr], #208]\n"
- "fmla v13.8h, %[b0].8h, %[a0].h[5]\n"
- "str q28, [%[c_ptr], #224]\n"
- "fmla v21.8h, %[b1].8h, %[a0].h[5]\n"
- "str q13, [%[c_ptr], #240]\n"
- "fmla v29.8h, %[b2].8h, %[a0].h[5]\n"
- "str q21, [%[c_ptr], #256]\n"
+ "fmla v13.8h, %[b0].8h, %[a0].h[5]\n"
+ "str q28, [%[c_ptr], #224]\n"
+ "fmla v21.8h, %[b1].8h, %[a0].h[5]\n"
+ "str q13, [%[c_ptr], #240]\n"
+ "fmla v29.8h, %[b2].8h, %[a0].h[5]\n"
+ "str q21, [%[c_ptr], #256]\n"
- "fmla v14.8h, %[b0].8h, %[a0].h[6]\n"
- "str q29, [%[c_ptr], #272]\n"
- "fmla v22.8h, %[b1].8h, %[a0].h[6]\n"
- "str q14, [%[c_ptr], #288]\n"
- "fmla v30.8h, %[b2].8h, %[a0].h[6]\n"
- "str q22, [%[c_ptr], #304]\n"
+ "fmla v14.8h, %[b0].8h, %[a0].h[6]\n"
+ "str q29, [%[c_ptr], #272]\n"
+ "fmla v22.8h, %[b1].8h, %[a0].h[6]\n"
+ "str q14, [%[c_ptr], #288]\n"
+ "fmla v30.8h, %[b2].8h, %[a0].h[6]\n"
+ "str q22, [%[c_ptr], #304]\n"
- "fmla v15.8h, %[b0].8h, %[a0].h[7]\n"
- "str q30, [%[c_ptr], #320]\n"
- "fmla v23.8h, %[b1].8h, %[a0].h[7]\n"
- "str q15, [%[c_ptr], #336]\n"
- "fmla v31.8h, %[b2].8h, %[a0].h[7]\n"
+ "fmla v15.8h, %[b0].8h, %[a0].h[7]\n"
+ "str q30, [%[c_ptr], #320]\n"
+ "fmla v23.8h, %[b1].8h, %[a0].h[7]\n"
+ "str q15, [%[c_ptr], #336]\n"
+ "fmla v31.8h, %[b2].8h, %[a0].h[7]\n"
"3:\n"
- "str q23, [%[c_ptr], #352]\n"
- "str q31, [%[c_ptr], #368]\n"
- "add %[c_ptr], %[c_ptr], #384\n"
- :
- [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr),
- [a0] "+w"(a0), [a0a] "+w"(a0a),
- [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k),
- [b0a] "+w"(b0a), [b1a] "+w"(b1a), [b2a] "+w"(b2a)
- : [oddk] "r"(oddk)
- : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
- "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc");
+ "str q23, [%[c_ptr], #352]\n"
+ "str q31, [%[c_ptr], #368]\n"
+ "add %[c_ptr], %[c_ptr], #384\n"
+ :
+ [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
+ [a0] "+w" (a0), [a0a] "+w" (a0a),
+ [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2), [k] "+r" (k),
+ [b0a] "+w" (b0a), [b1a] "+w" (b1a), [b2a] "+w" (b2a)
+ : [oddk] "r" (oddk)
+ : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
+ "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory" // asm loads A/B panels and stores the C panel via pointers that are not memory operands, so memory must be clobbered (matches the a55r1 kernel)
+ );
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp
index 91a9e8d..10d1069 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp
@@ -25,8 +25,10 @@
#ifdef __aarch64__
-namespace arm_gemm
-{
+#include "../std_transforms_fixed.hpp"
+
+namespace arm_gemm {
+
// Actual kernel implementations
void a64_sgemm_asimd_12x8(const float *, const float *, float *, int, int, int);
void a64_sgemm_asimd_12x8_a53(const float *, const float *, float *, int, int, int);
@@ -41,36 +43,34 @@
// All kernels in the family must share these characteristics. The actual
// kernel to be used can be chosen at runtime, based on the CPU_type
// structure.
-class sgemm_12x8
-{
+class sgemm_12x8 {
public:
typedef float operand_type;
typedef float result_type;
typedef void (*kern_type)(const float *, const float *, float *, int, int, int);
- /* Describes the data layout for A input */
- static const int A_interleave = 8;
- static const int A_block = 1;
- static const int A_transpose = 0;
-
- /* Same for B input */
- static const int B_interleave = 12;
- static const int B_block = 1;
- static const int B_transpose = 1;
-
/* Kernel blocking parameters */
- static const int out_width = 12;
- static const int out_height = 8;
- static const int k_unroll = 1;
+ static int out_width() {
+ return 12;
+ }
- kern_type kernel = a64_sgemm_asimd_12x8;
+ static int out_height() {
+ return 8;
+ }
- sgemm_12x8(const CPUInfo *ci)
- {
+ static int k_unroll() {
+ return 1;
+ }
+
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 8, 12> transforms = {};
+
+ kern_type kernel=a64_sgemm_asimd_12x8;
+
+ sgemm_12x8(const CPUInfo *ci) {
// Select specific kernel if available
- switch(ci->get_cpu_model())
- {
+ switch(ci->get_cpu_model()) {
case CPUModel::A53:
kernel = a64_sgemm_asimd_12x8_a53;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a53.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a53.cpp
index 618ebc7..2400191 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a53.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a53.cpp
@@ -27,333 +27,367 @@
#include "../../asmlib.hpp"
-namespace arm_gemm
-{
-void a64_sgemm_asimd_12x8_a53(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K)
-{
+namespace arm_gemm {
+
+// Cortex-A53-scheduled variant of the sgemm_12x8 micro-kernel; the
+// sgemm_12x8 strategy selects it when the CPU model is CPUModel::A53.
+//
+// For each of 'ablocks' A panels and each of 'bblocks' B panels this
+// accumulates one 8x12 tile of A*B over depth K and stores it row-major
+// as 96 consecutive floats at Cpanel (384 bytes per tile).
+//
+// Apanel: A data, 8 floats per K step per panel; the same A panel is
+// re-read for every B panel.
+// Bpanel: B data, 12 floats per K step per panel; re-read from the start
+// for every A panel.
+// Cpanel: output, ablocks * bblocks consecutive 8x12 tiles.
+// K: depth; must be >= 1. With K == 0 the counter k starts at -1,
+// the "cbz" skip does not fire and the loop runs ~2^32 times.
+//
+// Most 128-bit operands are loaded as two 64-bit halves (ldr d / ldr x /
+// ins), presumably to suit the A53's in-order pipeline (assumption).
+void a64_sgemm_asimd_12x8_a53(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) {
const float *a_ptr = Apanel;
- float *c_ptr = Cpanel;
+ float *c_ptr = Cpanel;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ for (int yb=0; yb<ablocks; yb++) {
const float *a_ptr0 = a_ptr;
- const float *b_ptr = Bpanel;
+ const float *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
+ for (int xb=0; xb<bblocks; xb++) {
a_ptr = a_ptr0;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
int oddk = (K & 1);
- int k = ((K + 1) / 2) - 1;
+ int k = ((K+1)/2) - 1;
- register float32x4_t a0 asm("v0");
- register float32x4_t a1 asm("v1");
- register float32x4_t b0 asm("v2");
- register float32x4_t b1 asm("v3");
- register float32x4_t b2 asm("v4");
+ register float32x4_t a0 asm("v0");
+ register float32x4_t a1 asm("v1");
+ register float32x4_t b0 asm("v2");
+ register float32x4_t b1 asm("v3");
+ register float32x4_t b2 asm("v4");
register float32x4_t a0a asm("v5");
register float32x4_t a1a asm("v6");
- __asm __volatile(
+ __asm __volatile (
// Initialize result registers, load initial operands, prime prefetches.
- "movi v8.4s, #0x0\n"
- "ldr %q[a0], [%[a_ptr]]\n"
- "movi v9.4s, #0x0\n"
- "ldr %q[b0], [%[b_ptr]]\n"
- "movi v10.4s, #0x0\n"
- "ldr %q[a1], [%[a_ptr], #16]\n"
- "movi v11.4s, #0x0\n"
- "ldr %q[b1], [%[b_ptr], #16]\n"
- "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") "movi v18.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #192]") "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") "movi v20.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") "movi v21.4s, #0x0\n"
+ "movi v8.4s, #0x0\n"
+ "ldr %q[a0], [%[a_ptr]]\n"
+ "movi v9.4s, #0x0\n"
+ "ldr %q[b0], [%[b_ptr]]\n"
+ "movi v10.4s, #0x0\n"
+ "ldr %q[a1], [%[a_ptr], #16]\n"
+ "movi v11.4s, #0x0\n"
+ "ldr %q[b1], [%[b_ptr], #16]\n"
+ "movi v12.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v13.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v14.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v15.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
+ "movi v16.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v17.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
+ "movi v18.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #192]")
+ "movi v19.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #320]")
+ "movi v20.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #256]")
+ "movi v21.4s, #0x0\n"
ASM_PREFETCH("[%[b_ptr], #384]")
- "movi v22.4s, #0x0\n"
- "movi v23.4s, #0x0\n"
- "movi v24.4s, #0x0\n"
- "movi v25.4s, #0x0\n"
- "movi v26.4s, #0x0\n"
- "movi v27.4s, #0x0\n"
- "movi v28.4s, #0x0\n"
- "movi v29.4s, #0x0\n"
- "movi v30.4s, #0x0\n"
- "movi v31.4s, #0x0\n"
+ "movi v22.4s, #0x0\n"
+ "movi v23.4s, #0x0\n"
+ "movi v24.4s, #0x0\n"
+ "movi v25.4s, #0x0\n"
+ "movi v26.4s, #0x0\n"
+ "movi v27.4s, #0x0\n"
+ "movi v28.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ "movi v30.4s, #0x0\n"
+ "movi v31.4s, #0x0\n"
// Skip loop if we are doing zero iterations of it.
- "cbz %w[k], 4f\n"
+ "cbz %w[k], 4f\n"
"1:\n"
+ // x20 is scratch: it carries the upper 64 bits of each split 128-bit load into "ins".
// Unroll 0
- "ldr %d[b2], [%[b_ptr], #32]\n"
+ "ldr %d[b2], [%[b_ptr], #32]\n"
"nop\n"
- "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
- "ldr x20, [%[b_ptr], #40]\n"
- "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
- "subs %w[k], %w[k], #1\n"
- "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
+ "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
+ "ldr x20, [%[b_ptr], #40]\n"
+ "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
+ "subs %w[k], %w[k], #1\n"
+ "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
- "ldr %d[a0a], [%[a_ptr], #32]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
- "ldr x20, [%[a_ptr], #40]\n"
- "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
- "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
+ "ldr %d[a0a], [%[a_ptr], #32]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
+ "ldr x20, [%[a_ptr], #40]\n"
+ "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
+ "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
- "ldr %d[a1a], [%[a_ptr], #48]\n"
- "ins %[a0a].d[1], x20\n"
- "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
- "ldr x20, [%[a_ptr], #56]\n"
- "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
- "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
+ "ldr %d[a1a], [%[a_ptr], #48]\n"
+ "ins %[a0a].d[1], x20\n"
+ "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
+ "ldr x20, [%[a_ptr], #56]\n"
+ "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
+ "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
- "ldr %d[b0], [%[b_ptr], #48]\n"
- "ins %[a1a].d[1], x20\n"
- "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
- "ldr x20, [%[b_ptr], #56]\n"
- "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
+ "ldr %d[b0], [%[b_ptr], #48]\n"
+ "ins %[a1a].d[1], x20\n"
+ "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
+ "ldr x20, [%[b_ptr], #56]\n"
+ "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
ASM_PREFETCH("[%[a_ptr], #320]")
- "ins %[b0].d[1], x20\n"
- "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
- "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
+ "ins %[b0].d[1], x20\n"
+ "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
+ "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
ASM_PREFETCH("[%[b_ptr], #448]")
"nop\n"
- "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
- "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
+ "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
+ "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
- "ldr %d[b1], [%[b_ptr], #64]\n"
+ "ldr %d[b1], [%[b_ptr], #64]\n"
"nop\n"
- "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
- "ldr x20, [%[b_ptr], #72]\n"
- "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
- "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
+ "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
+ "ldr x20, [%[b_ptr], #72]\n"
+ "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
+ "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
ASM_PREFETCH("[%[b_ptr], #512]")
- "ins %[b1].d[1], x20\n"
- "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
- "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
- "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
+ "ins %[b1].d[1], x20\n"
+ "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
+ "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
+ "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
// Unroll 1
- "ldr %d[b2], [%[b_ptr], #80]\n"
+ "ldr %d[b2], [%[b_ptr], #80]\n"
"nop\n"
- "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
- "ldr x20, [%[b_ptr], #88]\n"
- "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
- "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
+ "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
+ "ldr x20, [%[b_ptr], #88]\n"
+ "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
+ "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
- "ldr %d[a0], [%[a_ptr], #64]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
- "ldr x20, [%[a_ptr], #72]\n"
- "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
- "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
+ "ldr %d[a0], [%[a_ptr], #64]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
+ "ldr x20, [%[a_ptr], #72]\n"
+ "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
+ "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
- "ldr %d[a1], [%[a_ptr], #80]\n"
- "ins %[a0].d[1], x20\n"
- "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
- "ldr x20, [%[a_ptr], #88]\n"
- "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
- "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
+ "ldr %d[a1], [%[a_ptr], #80]\n"
+ "ins %[a0].d[1], x20\n"
+ "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
+ "ldr x20, [%[a_ptr], #88]\n"
+ "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
+ "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
- "ldr %d[b0], [%[b_ptr], #96]\n"
- "ins %[a1].d[1], x20\n"
- "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
- "ldr x20, [%[b_ptr], #104]\n"
- "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
+ "ldr %d[b0], [%[b_ptr], #96]\n"
+ "ins %[a1].d[1], x20\n"
+ "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
+ "ldr x20, [%[b_ptr], #104]\n"
+ "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
"nop\n"
- "ins %[b0].d[1], x20\n"
- "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
- "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
+ "ins %[b0].d[1], x20\n"
+ "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
+ "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
"nop\n"
"nop\n"
- "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
- "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
+ "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
+ "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
- "ldr %d[b1], [%[b_ptr], #112]\n"
+ "ldr %d[b1], [%[b_ptr], #112]\n"
"nop\n"
- "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
- "ldr x20, [%[b_ptr], #120]\n"
- "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
- "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
+ "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
+ "ldr x20, [%[b_ptr], #120]\n"
+ "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
+ "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
"nop\n"
- "ins %[b1].d[1], x20\n"
- "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
- "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
- "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
+ "ins %[b1].d[1], x20\n"
+ "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
+ "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
+ "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
- "bne 1b\n"
+ "bne 1b\n"
// Branch here if K=1 or 2. Do the right thing for odd/even at the end.
"4:\n"
- "cbnz %w[oddk], 2f\n"
+ "cbnz %w[oddk], 2f\n"
// Detached final iteration. (even K)
- "ldr %d[b2], [%[b_ptr], #32]\n"
+ "ldr %d[b2], [%[b_ptr], #32]\n"
"nop\n"
- "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
- "ldr x20, [%[b_ptr], #40]\n"
- "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
- "subs %w[k], %w[k], #1\n"
- "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
+ "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
+ "ldr x20, [%[b_ptr], #40]\n"
+ "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
+ "subs %w[k], %w[k], #1\n"
+ "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
- "ldr %d[a0a], [%[a_ptr], #32]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
- "ldr x20, [%[a_ptr], #40]\n"
- "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
- "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
+ "ldr %d[a0a], [%[a_ptr], #32]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
+ "ldr x20, [%[a_ptr], #40]\n"
+ "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
+ "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
- "ldr %d[a1a], [%[a_ptr], #48]\n"
- "ins %[a0a].d[1], x20\n"
- "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
- "ldr x20, [%[a_ptr], #56]\n"
- "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
- "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
+ "ldr %d[a1a], [%[a_ptr], #48]\n"
+ "ins %[a0a].d[1], x20\n"
+ "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
+ "ldr x20, [%[a_ptr], #56]\n"
+ "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
+ "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
- "ldr %d[b0], [%[b_ptr], #48]\n"
- "ins %[a1a].d[1], x20\n"
- "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
- "ldr x20, [%[b_ptr], #56]\n"
- "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
+ "ldr %d[b0], [%[b_ptr], #48]\n"
+ "ins %[a1a].d[1], x20\n"
+ "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
+ "ldr x20, [%[b_ptr], #56]\n"
+ "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
- "ins %[b0].d[1], x20\n"
- "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
- "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
+ "ins %[b0].d[1], x20\n"
+ "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
+ "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
"nop\n"
- "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
- "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
+ "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
+ "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
- "ldr %d[b1], [%[b_ptr], #64]\n"
+ "ldr %d[b1], [%[b_ptr], #64]\n"
"nop\n"
- "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
- "ldr x20, [%[b_ptr], #72]\n"
- "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
- "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
+ "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
+ "ldr x20, [%[b_ptr], #72]\n"
+ "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
+ "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
- "ins %[b1].d[1], x20\n"
- "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
- "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
- "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
+ "ins %[b1].d[1], x20\n"
+ "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
+ "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
+ "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
- "ldr %d[b2], [%[b_ptr], #80]\n"
+ "ldr %d[b2], [%[b_ptr], #80]\n"
"nop\n"
- "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
- "ldr x20, [%[b_ptr], #88]\n"
- "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
- "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
+ "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
+ "ldr x20, [%[b_ptr], #88]\n"
+ "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
+ "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
- "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
- "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
- "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
- "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
- "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
- "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
- "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
- "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
- "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
- "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
- "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
- "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
- "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
- "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
- "b 3f\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
+ "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
+ "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
+ "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
+ "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
+ "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
+ "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
+ "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
+ "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
+ "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
+ "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
+ "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
+ "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
+ "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
+ "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
+ "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
+ "b 3f\n"
// Detached final iteration. (odd K)
"2:\n"
- "ldr %d[b2], [%[b_ptr], #32]\n"
+ "ldr %d[b2], [%[b_ptr], #32]\n"
"nop\n"
- "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
- "ldr x20, [%[b_ptr], #40]\n"
- "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
- "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
+ "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
+ "ldr x20, [%[b_ptr], #40]\n"
+ "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
+ "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
- "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
- "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
- "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
- "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
- "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
- "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
- "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
- "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
- "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
- "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
- "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
- "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
- "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
- "add %[a_ptr], %[a_ptr], #32\n"
- "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
- "add %[b_ptr], %[b_ptr], #48\n"
- "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
+ "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
+ "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
+ "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
+ "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
+ "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
+ "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
+ "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
+ "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
+ "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
+ "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
+ "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
+ "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
+ "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
+ "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
+ "add %[a_ptr], %[a_ptr], #32\n"
+ "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
+ "add %[b_ptr], %[b_ptr], #48\n"
+ "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
// Common tail
"3:\n"
- "str q8, [%[c_ptr]]\n"
- "str q16, [%[c_ptr], #16]\n"
- "str q24, [%[c_ptr], #32]\n"
- "str q9, [%[c_ptr], #48]\n"
- "str q17, [%[c_ptr], #64]\n"
- "str q25, [%[c_ptr], #80]\n"
- "str q10, [%[c_ptr], #96]\n"
- "str q18, [%[c_ptr], #112]\n"
- "str q26, [%[c_ptr], #128]\n"
- "str q11, [%[c_ptr], #144]\n"
- "str q19, [%[c_ptr], #160]\n"
- "str q27, [%[c_ptr], #176]\n"
- "str q12, [%[c_ptr], #192]\n"
- "str q20, [%[c_ptr], #208]\n"
- "str q28, [%[c_ptr], #224]\n"
- "str q13, [%[c_ptr], #240]\n"
- "str q21, [%[c_ptr], #256]\n"
- "str q29, [%[c_ptr], #272]\n"
- "str q14, [%[c_ptr], #288]\n"
- "str q22, [%[c_ptr], #304]\n"
- "str q30, [%[c_ptr], #320]\n"
- "str q15, [%[c_ptr], #336]\n"
- "str q23, [%[c_ptr], #352]\n"
- "str q31, [%[c_ptr], #368]\n"
- "add %[c_ptr], %[c_ptr], #384\n"
- :
- [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr),
- [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a),
- [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k)
- : [oddk] "r"(oddk)
- : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
- "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc");
+ "str q8, [%[c_ptr]]\n"
+ "str q16, [%[c_ptr], #16]\n"
+ "str q24, [%[c_ptr], #32]\n"
+ "str q9, [%[c_ptr], #48]\n"
+ "str q17, [%[c_ptr], #64]\n"
+ "str q25, [%[c_ptr], #80]\n"
+ "str q10, [%[c_ptr], #96]\n"
+ "str q18, [%[c_ptr], #112]\n"
+ "str q26, [%[c_ptr], #128]\n"
+ "str q11, [%[c_ptr], #144]\n"
+ "str q19, [%[c_ptr], #160]\n"
+ "str q27, [%[c_ptr], #176]\n"
+ "str q12, [%[c_ptr], #192]\n"
+ "str q20, [%[c_ptr], #208]\n"
+ "str q28, [%[c_ptr], #224]\n"
+ "str q13, [%[c_ptr], #240]\n"
+ "str q21, [%[c_ptr], #256]\n"
+ "str q29, [%[c_ptr], #272]\n"
+ "str q14, [%[c_ptr], #288]\n"
+ "str q22, [%[c_ptr], #304]\n"
+ "str q30, [%[c_ptr], #320]\n"
+ "str q15, [%[c_ptr], #336]\n"
+ "str q23, [%[c_ptr], #352]\n"
+ "str q31, [%[c_ptr], #368]\n"
+ "add %[c_ptr], %[c_ptr], #384\n"
+ :
+ [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
+ [a0] "+w" (a0), [a1] "+w" (a1), [a0a] "+w" (a0a), [a1a] "+w" (a1a),
+ [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2), [k] "+r" (k)
+ : [oddk] "r" (oddk)
+ // "memory": the asm reads the A/B panels and writes the C tile through its
+ // pointer operands, which the register operand list alone does not convey.
+ : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
+ "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"
+ );
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55.cpp
index 4ca25eb..d9aaee1 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55.cpp
@@ -27,326 +27,348 @@
#include "../../asmlib.hpp"
-namespace arm_gemm
-{
-void a64_sgemm_asimd_12x8_a55(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K)
-{
+namespace arm_gemm {
+
+void a64_sgemm_asimd_12x8_a55(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) {
const float *a_ptr = Apanel;
- float *c_ptr = Cpanel;
+ float *c_ptr = Cpanel;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ for (int yb=0; yb<ablocks; yb++) {
const float *a_ptr0 = a_ptr;
- const float *b_ptr = Bpanel;
+ const float *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
+ for (int xb=0; xb<bblocks; xb++) {
a_ptr = a_ptr0;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
int oddk = (K & 1);
- int k = ((K + 1) / 2) - 1;
+ int k = ((K+1)/2) - 1;
- register float32x4_t a0 asm("v0");
- register float32x4_t a1 asm("v1");
- register float32x4_t b0 asm("v2");
- register float32x4_t b1 asm("v3");
- register float32x4_t b2 asm("v4");
+ register float32x4_t a0 asm("v0");
+ register float32x4_t a1 asm("v1");
+ register float32x4_t b0 asm("v2");
+ register float32x4_t b1 asm("v3");
+ register float32x4_t b2 asm("v4");
register float32x4_t a0a asm("v5");
register float32x4_t a1a asm("v6");
- __asm __volatile(
+ __asm __volatile (
// Initialize result registers, load initial operands, prime prefetches.
- "movi v8.4s, #0x0\n"
- "ldr %q[a0], [%[a_ptr]]\n"
- "movi v9.4s, #0x0\n"
- "ldr %q[b0], [%[b_ptr]]\n"
- "movi v10.4s, #0x0\n"
- "ldr %q[a1], [%[a_ptr], #16]\n"
- "movi v11.4s, #0x0\n"
- "ldr %q[b1], [%[b_ptr], #16]\n"
- "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") "movi v18.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #192]") "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") "movi v20.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") "movi v21.4s, #0x0\n"
+ "movi v8.4s, #0x0\n"
+ "ldr %q[a0], [%[a_ptr]]\n"
+ "movi v9.4s, #0x0\n"
+ "ldr %q[b0], [%[b_ptr]]\n"
+ "movi v10.4s, #0x0\n"
+ "ldr %q[a1], [%[a_ptr], #16]\n"
+ "movi v11.4s, #0x0\n"
+ "ldr %q[b1], [%[b_ptr], #16]\n"
+ "movi v12.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v13.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v14.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v15.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
+ "movi v16.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v17.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
+ "movi v18.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #192]")
+ "movi v19.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #320]")
+ "movi v20.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #256]")
+ "movi v21.4s, #0x0\n"
ASM_PREFETCH("[%[b_ptr], #384]")
- "movi v22.4s, #0x0\n"
- "movi v23.4s, #0x0\n"
- "movi v24.4s, #0x0\n"
- "movi v25.4s, #0x0\n"
- "movi v26.4s, #0x0\n"
- "movi v27.4s, #0x0\n"
- "movi v28.4s, #0x0\n"
- "movi v29.4s, #0x0\n"
- "movi v30.4s, #0x0\n"
- "movi v31.4s, #0x0\n"
+ "movi v22.4s, #0x0\n"
+ "movi v23.4s, #0x0\n"
+ "movi v24.4s, #0x0\n"
+ "movi v25.4s, #0x0\n"
+ "movi v26.4s, #0x0\n"
+ "movi v27.4s, #0x0\n"
+ "movi v28.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ "movi v30.4s, #0x0\n"
+ "movi v31.4s, #0x0\n"
// Skip loop if we are doing zero iterations of it.
- "cbz %w[k], 4f\n"
+ "cbz %w[k], 4f\n"
"1:\n"
// Unroll 0
- "ldr %d[b2], [%[b_ptr], #32]\n"
- "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
+ "ldr %d[b2], [%[b_ptr], #32]\n"
+ "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
- "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
- "ldr x20, [%[b_ptr], #40]\n"
- "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
- "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
- "subs %w[k], %w[k], #1\n"
+ "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
+ "ldr x20, [%[b_ptr], #40]\n"
+ "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
+ "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
+ "subs %w[k], %w[k], #1\n"
- "ldr %d[a0a], [%[a_ptr], #32]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
- "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
- "ldr x20, [%[a_ptr], #40]\n"
- "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
- "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
+ "ldr %d[a0a], [%[a_ptr], #32]\n"
+ "ins %[b2].d[1], x20\n"
- "ldr %d[a1a], [%[a_ptr], #48]\n"
- "ins %[a0a].d[1], x20\n"
+ "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
+ "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
+ "ldr x20, [%[a_ptr], #40]\n"
+ "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
+ "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
- "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
- "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
- "ldr x20, [%[a_ptr], #56]\n"
- "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
+ "ldr %d[a1a], [%[a_ptr], #48]\n"
+ "ins %[a0a].d[1], x20\n"
- "ldr %d[b0], [%[b_ptr], #48]\n"
- "ins %[a1a].d[1], x20\n" ASM_PREFETCH("[%[a_ptr], #320]")
- "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
- "ldr x20, [%[b_ptr], #56]\n"
- "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
+ "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
+ "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
+ "ldr x20, [%[a_ptr], #56]\n"
+ "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
- "ldr %d[b1], [%[b_ptr], #64]\n"
- "ins %[b0].d[1], x20\n"
+ "ldr %d[b0], [%[b_ptr], #48]\n"
+ "ins %[a1a].d[1], x20\n"
+ ASM_PREFETCH("[%[a_ptr], #320]")
+ "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
+ "ldr x20, [%[b_ptr], #56]\n"
+ "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
- "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
- "ldr x20, [%[b_ptr], #72]\n"
- "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" ASM_PREFETCH("[%[b_ptr], #448]")
+ "ldr %d[b1], [%[b_ptr], #64]\n"
+ "ins %[b0].d[1], x20\n"
- "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
- "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" ASM_PREFETCH("[%[b_ptr], #512]")
- "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
- "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
+ "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
+ "ldr x20, [%[b_ptr], #72]\n"
+ "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
+ "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
+ ASM_PREFETCH("[%[b_ptr], #448]")
+
+
+ "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
+ "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #512]")
+ "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
+ "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
// Unroll 1
- "ldr %d[b2], [%[b_ptr], #80]\n"
- "ins %[b1].d[1], x20\n"
+ "ldr %d[b2], [%[b_ptr], #80]\n"
+ "ins %[b1].d[1], x20\n"
- "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
- "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
- "ldr x20, [%[b_ptr], #88]\n"
- "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
- "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
+ "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
+ "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
+ "ldr x20, [%[b_ptr], #88]\n"
+ "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
+ "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
- "ldr %d[a0], [%[a_ptr], #64]\n"
- "ins %[b2].d[1], x20\n"
+ "ldr %d[a0], [%[a_ptr], #64]\n"
+ "ins %[b2].d[1], x20\n"
- "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
- "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
- "ldr x20, [%[a_ptr], #72]\n"
- "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
- "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
+ "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
+ "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
+ "ldr x20, [%[a_ptr], #72]\n"
+ "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
+ "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
- "ldr %d[a1], [%[a_ptr], #80]\n"
- "ins %[a0].d[1], x20\n"
+ "ldr %d[a1], [%[a_ptr], #80]\n"
+ "ins %[a0].d[1], x20\n"
- "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
- "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
- "ldr x20, [%[a_ptr], #88]\n"
- "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
+ "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
+ "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
+ "ldr x20, [%[a_ptr], #88]\n"
+ "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
- "ldr %d[b0], [%[b_ptr], #96]\n"
- "ins %[a1].d[1], x20\n"
- "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
- "ldr x20, [%[b_ptr], #104]\n"
- "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
+ "ldr %d[b0], [%[b_ptr], #96]\n"
+ "ins %[a1].d[1], x20\n"
- "ldr %d[b1], [%[b_ptr], #112]\n"
- "ins %[b0].d[1], x20\n"
+ "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
+ "ldr x20, [%[b_ptr], #104]\n"
+ "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
- "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
- "ldr x20, [%[b_ptr], #120]\n"
- "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
+ "ldr %d[b1], [%[b_ptr], #112]\n"
+ "ins %[b0].d[1], x20\n"
- "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
- "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
- "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
+ "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
+ "ldr x20, [%[b_ptr], #120]\n"
+ "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
+ "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
- "ldr %d[b2], [%[b_ptr], #32]\n"
- "ins %[b1].d[1], x20\n"
+ "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
+ "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
+ "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
- "bne 1b\n"
+
+ "ldr %d[b2], [%[b_ptr], #32]\n"
+ "ins %[b1].d[1], x20\n"
+
+
+ "bne 1b\n"
// Branch here if K=1 or 2. Do the right thing for odd/even at the end.
"4:\n"
- "cbnz %w[oddk], 2f\n"
- "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
+ "cbnz %w[oddk], 2f\n"
+ "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
// Detached final iteration. (even K)
- "ldr x20, [%[b_ptr], #40]\n"
- "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
- "subs %w[k], %w[k], #1\n"
- "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
- "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
+ "ldr x20, [%[b_ptr], #40]\n"
+ "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
+ "subs %w[k], %w[k], #1\n"
+ "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
+ "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
- "ldr %d[a0a], [%[a_ptr], #32]\n"
- "ins %[b2].d[1], x20\n"
+ "ldr %d[a0a], [%[a_ptr], #32]\n"
+ "ins %[b2].d[1], x20\n"
- "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
- "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
- "ldr x20, [%[a_ptr], #40]\n"
- "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
- "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
+ "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
+ "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
+ "ldr x20, [%[a_ptr], #40]\n"
+ "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
+ "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
- "ldr %d[a1a], [%[a_ptr], #48]\n"
- "ins %[a0a].d[1], x20\n"
+ "ldr %d[a1a], [%[a_ptr], #48]\n"
+ "ins %[a0a].d[1], x20\n"
- "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
- "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
- "ldr x20, [%[a_ptr], #56]\n"
- "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
+ "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
+ "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
+ "ldr x20, [%[a_ptr], #56]\n"
+ "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
- "ldr %d[b0], [%[b_ptr], #48]\n"
- "ins %[a1a].d[1], x20\n"
+ "ldr %d[b0], [%[b_ptr], #48]\n"
+ "ins %[a1a].d[1], x20\n"
- "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
- "ldr x20, [%[b_ptr], #56]\n"
- "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
+ "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
+ "ldr x20, [%[b_ptr], #56]\n"
+ "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
- "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
- "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
+ "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
+ "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
+ "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
- "ldr %d[b1], [%[b_ptr], #64]\n"
- "ins %[b0].d[1], x20\n"
+ "ldr %d[b1], [%[b_ptr], #64]\n"
+ "ins %[b0].d[1], x20\n"
- "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
- "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
- "ldr x20, [%[b_ptr], #72]\n"
- "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
- "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
+ "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
+ "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
+ "ldr x20, [%[b_ptr], #72]\n"
+ "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
+ "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
- "ldr %d[b2], [%[b_ptr], #80]\n"
- "ins %[b1].d[1], x20\n"
+ "ldr %d[b2], [%[b_ptr], #80]\n"
+ "ins %[b1].d[1], x20\n"
- "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
- "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
- "ldr x20, [%[b_ptr], #88]\n"
- "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
+ "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
+ "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
+ "ldr x20, [%[b_ptr], #88]\n"
+ "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
- "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
- "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
- "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
- "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
- "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
- "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
- "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
- "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
- "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
- "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
- "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
- "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
- "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
- "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
- "b 3f\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
+ "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
+ "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
+ "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
+ "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
+ "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
+ "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
+ "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
+ "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
+ "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
+ "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
+ "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
+ "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
+ "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
+ "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
+ "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
+ "b 3f\n"
// Detached final iteration. (odd K)
"2:\n"
- "ldr %d[b2], [%[b_ptr], #32]\n"
- "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
+ "ldr %d[b2], [%[b_ptr], #32]\n"
+ "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
- "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
- "ldr x20, [%[b_ptr], #40]\n"
- "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
- "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
- "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
- "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
- "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
- "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
- "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
- "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
- "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
- "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
- "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
- "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
- "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
- "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
- "add %[a_ptr], %[a_ptr], #32\n"
- "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
- "add %[b_ptr], %[b_ptr], #48\n"
- "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
+ "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
+ "ldr x20, [%[b_ptr], #40]\n"
+ "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
+ "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
+ "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
+ "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
+ "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
+ "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
+ "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
+ "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
+ "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
+ "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
+ "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
+ "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
+ "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
+ "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
+ "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
+ "add %[a_ptr], %[a_ptr], #32\n"
+ "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
+ "add %[b_ptr], %[b_ptr], #48\n"
+ "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
// Common tail
"3:\n"
- "str q8, [%[c_ptr]]\n"
- "str q16, [%[c_ptr], #16]\n"
- "str q24, [%[c_ptr], #32]\n"
- "str q9, [%[c_ptr], #48]\n"
- "str q17, [%[c_ptr], #64]\n"
- "str q25, [%[c_ptr], #80]\n"
- "str q10, [%[c_ptr], #96]\n"
- "str q18, [%[c_ptr], #112]\n"
- "str q26, [%[c_ptr], #128]\n"
- "str q11, [%[c_ptr], #144]\n"
- "str q19, [%[c_ptr], #160]\n"
- "str q27, [%[c_ptr], #176]\n"
- "str q12, [%[c_ptr], #192]\n"
- "str q20, [%[c_ptr], #208]\n"
- "str q28, [%[c_ptr], #224]\n"
- "str q13, [%[c_ptr], #240]\n"
- "str q21, [%[c_ptr], #256]\n"
- "str q29, [%[c_ptr], #272]\n"
- "str q14, [%[c_ptr], #288]\n"
- "str q22, [%[c_ptr], #304]\n"
- "str q30, [%[c_ptr], #320]\n"
- "str q15, [%[c_ptr], #336]\n"
- "str q23, [%[c_ptr], #352]\n"
- "str q31, [%[c_ptr], #368]\n"
- "add %[c_ptr], %[c_ptr], #384\n"
- :
- [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr),
- [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a),
- [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k)
- : [oddk] "r"(oddk)
- : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
- "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc");
+ "str q8, [%[c_ptr]]\n"
+ "str q16, [%[c_ptr], #16]\n"
+ "str q24, [%[c_ptr], #32]\n"
+ "str q9, [%[c_ptr], #48]\n"
+ "str q17, [%[c_ptr], #64]\n"
+ "str q25, [%[c_ptr], #80]\n"
+ "str q10, [%[c_ptr], #96]\n"
+ "str q18, [%[c_ptr], #112]\n"
+ "str q26, [%[c_ptr], #128]\n"
+ "str q11, [%[c_ptr], #144]\n"
+ "str q19, [%[c_ptr], #160]\n"
+ "str q27, [%[c_ptr], #176]\n"
+ "str q12, [%[c_ptr], #192]\n"
+ "str q20, [%[c_ptr], #208]\n"
+ "str q28, [%[c_ptr], #224]\n"
+ "str q13, [%[c_ptr], #240]\n"
+ "str q21, [%[c_ptr], #256]\n"
+ "str q29, [%[c_ptr], #272]\n"
+ "str q14, [%[c_ptr], #288]\n"
+ "str q22, [%[c_ptr], #304]\n"
+ "str q30, [%[c_ptr], #320]\n"
+ "str q15, [%[c_ptr], #336]\n"
+ "str q23, [%[c_ptr], #352]\n"
+ "str q31, [%[c_ptr], #368]\n"
+ "add %[c_ptr], %[c_ptr], #384\n"
+ :
+ [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
+ [a0] "+w" (a0), [a1] "+w" (a1), [a0a] "+w" (a0a), [a1a] "+w" (a1a),
+ [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2), [k] "+r" (k)
+ : [oddk] "r" (oddk)
+ : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
+ "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"
+ );
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55r1.cpp
index 89fe6ac..114c807 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55r1.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55r1.cpp
@@ -27,37 +27,34 @@
#include "../../asmlib.hpp"
-namespace arm_gemm
-{
-void a64_sgemm_asimd_12x8_a55r1(const float *Apanel, const float *Bpanel, float *Cpanel, const int ablocks, const int bblocks, const int K)
-{
+namespace arm_gemm {
+
+void a64_sgemm_asimd_12x8_a55r1(const float *Apanel, const float *Bpanel, float *Cpanel, const int ablocks, const int bblocks, const int K) {
const float *a_ptr = Apanel;
- float *c_ptr = Cpanel;
+ float *c_ptr = Cpanel;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
- int oddk = (K & 1);
- int k_iters = ((K + 1) / 2) - 1;
+ int oddk = (K & 1);
+ int k_iters = ((K+1)/2) - 1;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ for (int yb=0; yb<ablocks; yb++) {
const float *a_ptr0 = a_ptr;
- const float *b_ptr = Bpanel;
+ const float *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
+ for (int xb=0; xb<bblocks; xb++) {
a_ptr = a_ptr0;
int k = k_iters;
- register float32x4_t a0 asm("v0");
- register float32x4_t a1 asm("v1");
- register float32x4_t b0 asm("v2");
- register float32x4_t b1 asm("v3");
- register float32x4_t b2 asm("v4");
+ register float32x4_t a0 asm("v0");
+ register float32x4_t a1 asm("v1");
+ register float32x4_t b0 asm("v2");
+ register float32x4_t b1 asm("v3");
+ register float32x4_t b2 asm("v4");
register float32x4_t a0a asm("v5");
register float32x4_t a1a asm("v6");
- __asm __volatile(
+ __asm __volatile (
// Initialize result registers, load initial operands, prime prefetches.
"movi v8.4s, #0x0\n"
"ldr %q[a0], [%[a_ptr]]\n"
@@ -67,272 +64,319 @@
"ldr %q[a1], [%[a_ptr], #16]\n"
"movi v11.4s, #0x0\n"
"ldr %q[b1], [%[b_ptr], #16]\n"
- "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]")
+ "movi v12.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v13.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v14.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v15.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
+ "movi v16.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v17.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
"movi v18.4s, #0x0\n"
- "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #192]")
+ "movi v19.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #192]")
"movi v20.4s, #0x0\n"
- "movi v21.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]")
+ "movi v21.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #320]")
"movi v22.4s, #0x0\n"
- "movi v23.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]")
+ "movi v23.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #256]")
"movi v24.4s, #0x0\n"
- "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #384]")
+ "movi v25.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #384]")
"movi v26.4s, #0x0\n"
- "movi v27.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #448]")
+ "movi v27.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #448]")
"movi v28.4s, #0x0\n"
- "movi v29.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #384]")
+ "movi v29.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #384]")
"movi v30.4s, #0x0\n"
- "movi v31.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #512]")
+ "movi v31.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #512]")
// The loop is offset by these two instructions which must
// always be executed.
- "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
- "ldr %d[b2], [%[b_ptr], #32]\n"
+ "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
+ "ldr %d[b2], [%[b_ptr], #32]\n"
// Skip loop if we are doing zero iterations of it.
- "cbz %w[k], 4f\n"
+ "cbz %w[k], 4f\n"
"1:\n"
// Unroll 0
- "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
- "ldr x20, [%[b_ptr], #40]\n"
- "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
- "subs %w[k], %w[k], #1\n"
- "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
- "ldr %d[a0a], [%[a_ptr], #32]\n"
+ "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
+ "ldr x20, [%[b_ptr], #40]\n"
+ "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
+ "subs %w[k], %w[k], #1\n"
+ "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
+ "ldr %d[a0a], [%[a_ptr], #32]\n"
- "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
- "ldr x20, [%[a_ptr], #40]\n"
- "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
- "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
- "ldr %d[a1a], [%[a_ptr], #48]\n"
+ "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
+ "ldr x20, [%[a_ptr], #40]\n"
+ "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
+ "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
+ "ldr %d[a1a], [%[a_ptr], #48]\n"
- "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
- "ins %[a0a].d[1], x20\n"
- "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
- "ldr x20, [%[a_ptr], #56]\n"
- "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
- "ldr %d[b0], [%[b_ptr], #48]\n"
+ "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
+ "ins %[a0a].d[1], x20\n"
+ "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
+ "ldr x20, [%[a_ptr], #56]\n"
+ "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
+ "ldr %d[b0], [%[b_ptr], #48]\n"
- "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
- "ins %[a1a].d[1], x20\n"
- "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
- "ldr x20, [%[b_ptr], #56]\n"
- "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
- "ldr %d[b1], [%[b_ptr], #64]\n"
+ "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
+ "ins %[a1a].d[1], x20\n"
+ "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
+ "ldr x20, [%[b_ptr], #56]\n"
+ "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
+ "ldr %d[b1], [%[b_ptr], #64]\n"
- "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
- "ins %[b0].d[1], x20\n"
- "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
- "ldr x20, [%[b_ptr], #72]\n"
- "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" ASM_PREFETCH("[%[a_ptr], #448]")
+ "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
+ "ins %[b0].d[1], x20\n"
+ "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
+ "ldr x20, [%[b_ptr], #72]\n"
+ "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
+ "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
+ ASM_PREFETCH("[%[a_ptr], #448]")
- "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
- "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" ASM_PREFETCH("[%[b_ptr], #576]")
- "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
- "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
+ "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
+ "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #576]")
+ "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
+ "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
// Unroll 1
- "ldr %d[b2], [%[b_ptr], #80]\n"
+ "ldr %d[b2], [%[b_ptr], #80]\n"
- "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
- "ins %[b1].d[1], x20\n"
- "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
- "ldr x20, [%[b_ptr], #88]\n"
- "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
- "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
- "ldr %d[a0], [%[a_ptr], #64]\n"
+ "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
+ "ins %[b1].d[1], x20\n"
+ "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
+ "ldr x20, [%[b_ptr], #88]\n"
+ "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
+ "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
+ "ldr %d[a0], [%[a_ptr], #64]\n"
- "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
- "ldr x20, [%[a_ptr], #72]\n"
- "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
- "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
- "ldr %d[a1], [%[a_ptr], #80]\n"
+ "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
+ "ldr x20, [%[a_ptr], #72]\n"
+ "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
+ "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
+ "ldr %d[a1], [%[a_ptr], #80]\n"
- "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
- "ins %[a0].d[1], x20\n"
- "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
- "ldr x20, [%[a_ptr], #88]\n"
- "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
- "ldr %d[b0], [%[b_ptr], #96]\n"
+ "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
+ "ins %[a0].d[1], x20\n"
+ "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
+ "ldr x20, [%[a_ptr], #88]\n"
+ "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
+ "ldr %d[b0], [%[b_ptr], #96]\n"
- "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
- "ins %[a1].d[1], x20\n"
- "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
- "ldr x20, [%[b_ptr], #104]\n"
- "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
- "ldr %d[b1], [%[b_ptr], #112]\n"
+ "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
+ "ins %[a1].d[1], x20\n"
+ "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
+ "ldr x20, [%[b_ptr], #104]\n"
+ "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
+ "ldr %d[b1], [%[b_ptr], #112]\n"
- "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
- "ins %[b0].d[1], x20\n"
- "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
- "ldr x20, [%[b_ptr], #120]\n"
- "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
+ "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
+ "ins %[b0].d[1], x20\n"
+ "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
+ "ldr x20, [%[b_ptr], #120]\n"
+ "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
+ "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
- "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n" ASM_PREFETCH("[%[b_ptr], #640]")
- "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
- "ins %[b1].d[1], x20\n"
- "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
- "ldr %d[b2], [%[b_ptr], #32]\n"
+ "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
+ ASM_PREFETCH("[%[b_ptr], #640]")
+ "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
+ "ins %[b1].d[1], x20\n"
+ "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
+ "ldr %d[b2], [%[b_ptr], #32]\n"
- "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
- "b.ne 1b\n"
+ "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
+ "b.ne 1b\n"
// Branch here if K=1 or 2. Do the right thing for odd/even at the end.
"4:\n"
- // Start final iteration - branch off to "odd" code before we load a0a.
- "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
- "ldr x20, [%[b_ptr], #40]\n"
- "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
- "cbnz %w[oddk], 2f\n"
+ // Start final iteration - branch off to "odd" code before we load a0a.
+ "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
+ "ldr x20, [%[b_ptr], #40]\n"
+ "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
+ "cbnz %w[oddk], 2f\n"
// Even K continuation
- "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
- "ldr %d[a0a], [%[a_ptr], #32]\n"
+ "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
+ "ldr %d[a0a], [%[a_ptr], #32]\n"
- "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
- "ldr x20, [%[a_ptr], #40]\n"
- "fmla v14.4s, %[b0].4s, %[a1].s[2]\n" ASM_PREFETCHW("[%[c_ptr]]")
- "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
- "ldr %d[a1a], [%[a_ptr], #48]\n"
+ "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
+ "ldr x20, [%[a_ptr], #40]\n"
+ "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
+ ASM_PREFETCHW("[%[c_ptr]]")
+ "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
+ "ldr %d[a1a], [%[a_ptr], #48]\n"
- "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
- "ins %[a0a].d[1], x20\n"
- "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
- "ldr x20, [%[a_ptr], #56]\n"
- "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
- "ldr %d[b0], [%[b_ptr], #48]\n"
+ "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
+ "ins %[a0a].d[1], x20\n"
+ "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
+ "ldr x20, [%[a_ptr], #56]\n"
+ "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
+ "ldr %d[b0], [%[b_ptr], #48]\n"
- "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
- "ins %[a1a].d[1], x20\n"
- "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
- "ldr x20, [%[b_ptr], #56]\n"
- "fmla v22.4s, %[b1].4s, %[a1].s[2]\n" ASM_PREFETCHW("[%[c_ptr], #64]")
- "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
+ "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
+ "ins %[a1a].d[1], x20\n"
+ "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
+ "ldr x20, [%[b_ptr], #56]\n"
+ "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
+ ASM_PREFETCHW("[%[c_ptr], #64]")
+ "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
- "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" ASM_PREFETCHW("[%[c_ptr], #128]")
- "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
- "ldr %d[b1], [%[b_ptr], #64]\n"
+ "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #128]")
+ "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
+ "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
+ "ldr %d[b1], [%[b_ptr], #64]\n"
- "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
- "ins %[b0].d[1], x20\n"
- "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
- "ldr x20, [%[b_ptr], #72]\n"
- "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" ASM_PREFETCHW("[%[c_ptr], #192]")
- "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
- "ldr %d[b2], [%[b_ptr], #80]\n"
+ "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
+ "ins %[b0].d[1], x20\n"
+ "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
+ "ldr x20, [%[b_ptr], #72]\n"
+ "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
+ ASM_PREFETCHW("[%[c_ptr], #192]")
+ "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
+ "ldr %d[b2], [%[b_ptr], #80]\n"
- "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
- "ins %[b1].d[1], x20\n"
- "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
- "ldr x20, [%[b_ptr], #88]\n"
- "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
- "ins %[b2].d[1], x20\n"
+ "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
+ "ins %[b1].d[1], x20\n"
+ "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
+ "ldr x20, [%[b_ptr], #88]\n"
+ "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
+ "ins %[b2].d[1], x20\n"
- "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]")
- "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
- "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
- "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n" ASM_PREFETCHW("[%[c_ptr], #320]")
- "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
- "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]")
- "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
- "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]")
- "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
- "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]")
- "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]")
- "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
- "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #640]")
- "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
- "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]")
- "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
- "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
- "b 3f\n"
+ "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
+ ASM_PREFETCHW("[%[c_ptr], #256]")
+ "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
+ "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
+ "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
+ ASM_PREFETCHW("[%[c_ptr], #320]")
+ "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
+ "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #384]")
+ "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
+ "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #448]")
+ "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
+ "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #512]")
+ "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #576]")
+ "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
+ "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #640]")
+ "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
+ "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #704]")
+ "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
+ "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
+ "b 3f\n"
// Odd K continuation
"2:\n"
- "fmla v11.4s, %[b0].4s, %[a0].s[3]\n" ASM_PREFETCHW("[%[c_ptr]]")
- "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
- "ins %[b2].d[1], x20\n"
- "fmla v13.4s, %[b0].4s, %[a1].s[1]\n" ASM_PREFETCHW("[%[c_ptr], #64]")
- "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
- "add %[a_ptr], %[a_ptr], #32\n"
- "fmla v15.4s, %[b0].4s, %[a1].s[3]\n" ASM_PREFETCHW("[%[c_ptr], #128]")
- "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
- "add %[b_ptr], %[b_ptr], #48\n"
- "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" ASM_PREFETCHW("[%[c_ptr], #192]")
- "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0].s[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]")
- "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1].s[1]\n" ASM_PREFETCHW("[%[c_ptr], #320]")
- "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1].s[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]")
- "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]")
- "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]") "fmla v28.4s, %[b2].4s, %[a1].s[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]") "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
- ASM_PREFETCHWL2("[%[c_ptr], #640]") "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]")
- "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
+ "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
+ ASM_PREFETCHW("[%[c_ptr]]")
+ "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
+ "ins %[b2].d[1], x20\n"
+ "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #64]")
+ "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
+ "add %[a_ptr], %[a_ptr], #32\n"
+ "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
+ ASM_PREFETCHW("[%[c_ptr], #128]")
+ "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
+ "add %[b_ptr], %[b_ptr], #48\n"
+ "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #192]")
+ "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
+ ASM_PREFETCHW("[%[c_ptr], #256]")
+ "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
+ ASM_PREFETCHW("[%[c_ptr], #320]")
+ "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #384]")
+ "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #448]")
+ "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
+ "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #512]")
+ "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #576]")
+ "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #640]")
+ "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
+ ASM_PREFETCHWL2("[%[c_ptr], #704]")
+ "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
// Common tail
"3:\n"
- "str q8, [%[c_ptr]]\n"
- "str q16, [%[c_ptr], #16]\n"
- "str q24, [%[c_ptr], #32]\n"
- "str q9, [%[c_ptr], #48]\n"
- "str q17, [%[c_ptr], #64]\n"
- "str q25, [%[c_ptr], #80]\n"
- "str q10, [%[c_ptr], #96]\n"
- "str q18, [%[c_ptr], #112]\n"
- "str q26, [%[c_ptr], #128]\n"
- "str q11, [%[c_ptr], #144]\n"
- "str q19, [%[c_ptr], #160]\n"
- "str q27, [%[c_ptr], #176]\n"
- "str q12, [%[c_ptr], #192]\n"
- "str q20, [%[c_ptr], #208]\n"
- "str q28, [%[c_ptr], #224]\n"
- "str q13, [%[c_ptr], #240]\n"
- "str q21, [%[c_ptr], #256]\n"
- "str q29, [%[c_ptr], #272]\n"
- "str q14, [%[c_ptr], #288]\n"
- "str q22, [%[c_ptr], #304]\n"
- "str q30, [%[c_ptr], #320]\n"
- "str q15, [%[c_ptr], #336]\n"
- "str q23, [%[c_ptr], #352]\n"
- "str q31, [%[c_ptr], #368]\n"
- "add %[c_ptr], %[c_ptr], #384\n"
- :
- [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr),
- [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a),
- [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k)
- : [oddk] "r"(oddk)
- : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
- "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc");
+ "str q8, [%[c_ptr]]\n"
+ "str q16, [%[c_ptr], #16]\n"
+ "str q24, [%[c_ptr], #32]\n"
+ "str q9, [%[c_ptr], #48]\n"
+ "str q17, [%[c_ptr], #64]\n"
+ "str q25, [%[c_ptr], #80]\n"
+ "str q10, [%[c_ptr], #96]\n"
+ "str q18, [%[c_ptr], #112]\n"
+ "str q26, [%[c_ptr], #128]\n"
+ "str q11, [%[c_ptr], #144]\n"
+ "str q19, [%[c_ptr], #160]\n"
+ "str q27, [%[c_ptr], #176]\n"
+ "str q12, [%[c_ptr], #192]\n"
+ "str q20, [%[c_ptr], #208]\n"
+ "str q28, [%[c_ptr], #224]\n"
+ "str q13, [%[c_ptr], #240]\n"
+ "str q21, [%[c_ptr], #256]\n"
+ "str q29, [%[c_ptr], #272]\n"
+ "str q14, [%[c_ptr], #288]\n"
+ "str q22, [%[c_ptr], #304]\n"
+ "str q30, [%[c_ptr], #320]\n"
+ "str q15, [%[c_ptr], #336]\n"
+ "str q23, [%[c_ptr], #352]\n"
+ "str q31, [%[c_ptr], #368]\n"
+ "add %[c_ptr], %[c_ptr], #384\n"
+ :
+ [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
+ [a0] "+w" (a0), [a1] "+w" (a1), [a0a] "+w" (a0a), [a1a] "+w" (a1a),
+ [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2), [k] "+r" (k)
+ : [oddk] "r" (oddk)
+ : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
+ "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"
+ );
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/generic.cpp
index 42e870e..7169c8b 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/generic.cpp
@@ -37,311 +37,327 @@
// Note that the intent of this is that either ablocks or bblocks will be 1
// - this construction allows the output loop to proceed in either order.
-namespace arm_gemm
-{
-void a64_sgemm_asimd_12x8_jumps(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K, long int row_jump = 0, long int block_jump = 0)
-{
+namespace arm_gemm {
+
+void a64_sgemm_asimd_12x8_jumps(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K, long int row_jump=0, long int block_jump=0) { // Each (ablock, bblock) pair writes 384 bytes (96 floats) to Cpanel; row_jump is added to b_ptr once per K step, block_jump once per block in the even-K tail.
const float *a_ptr = Apanel;
- float *c_ptr = Cpanel;
+ float *c_ptr = Cpanel;
- for(int yb = 0; yb < ablocks; yb++)
- {
+ for (int yb=0; yb<ablocks; yb++) {
const float *a_ptr0 = a_ptr;
- const float *b_ptr = Bpanel;
+ const float *b_ptr = Bpanel;
- for(int xb = 0; xb < bblocks; xb++)
- {
+ for (int xb=0; xb<bblocks; xb++) {
a_ptr = a_ptr0;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
int oddk = (K & 1);
- int k = ((K + 1) / 2) - 1;
+ int k = ((K+1)/2) - 1;
- register float32x4_t a0 asm("v0");
- register float32x4_t a1 asm("v1");
- register float32x4_t b0 asm("v2");
- register float32x4_t b1 asm("v3");
- register float32x4_t b2 asm("v4");
+ register float32x4_t a0 asm("v0");
+ register float32x4_t a1 asm("v1");
+ register float32x4_t b0 asm("v2");
+ register float32x4_t b1 asm("v3");
+ register float32x4_t b2 asm("v4");
register float32x4_t a0a asm("v5");
register float32x4_t a1a asm("v6");
- __asm __volatile(
+ __asm __volatile (
// Initialize result registers, load initial operands, prime prefetches.
- "movi v8.4s, #0x0\n"
- "ldr %q[a0], [%[a_ptr]]\n"
- "movi v9.4s, #0x0\n"
- "ldr %q[b0], [%[b_ptr]]\n"
- "movi v10.4s, #0x0\n"
- "ldr %q[a1], [%[a_ptr], #16]\n"
- "movi v11.4s, #0x0\n"
- "ldr %q[b1], [%[b_ptr], #16]\n"
- "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") "movi v18.4s, #0x0\n"
- ASM_PREFETCH("[%[a_ptr], #192]") "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") "movi v20.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") "movi v21.4s, #0x0\n"
+ "movi v8.4s, #0x0\n"
+ "ldr %q[a0], [%[a_ptr]]\n"
+ "movi v9.4s, #0x0\n"
+ "ldr %q[b0], [%[b_ptr]]\n"
+ "movi v10.4s, #0x0\n"
+ "ldr %q[a1], [%[a_ptr], #16]\n"
+ "movi v11.4s, #0x0\n"
+ "ldr %q[b1], [%[b_ptr], #16]\n"
+ "movi v12.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #64]")
+ "movi v13.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #64]")
+ "movi v14.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #128]")
+ "movi v15.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #128]")
+ "movi v16.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #192]")
+ "movi v17.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #256]")
+ "movi v18.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #192]")
+ "movi v19.4s, #0x0\n"
+ ASM_PREFETCH("[%[b_ptr], #320]")
+ "movi v20.4s, #0x0\n"
+ ASM_PREFETCH("[%[a_ptr], #256]")
+ "movi v21.4s, #0x0\n"
ASM_PREFETCH("[%[b_ptr], #384]")
- "movi v22.4s, #0x0\n"
- "movi v23.4s, #0x0\n"
- "movi v24.4s, #0x0\n"
- "movi v25.4s, #0x0\n"
- "movi v26.4s, #0x0\n"
- "movi v27.4s, #0x0\n"
- "movi v28.4s, #0x0\n"
- "movi v29.4s, #0x0\n"
- "movi v30.4s, #0x0\n"
- "movi v31.4s, #0x0\n"
+ "movi v22.4s, #0x0\n"
+ "movi v23.4s, #0x0\n"
+ "movi v24.4s, #0x0\n"
+ "movi v25.4s, #0x0\n"
+ "movi v26.4s, #0x0\n"
+ "movi v27.4s, #0x0\n"
+ "movi v28.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ "movi v30.4s, #0x0\n"
+ "movi v31.4s, #0x0\n"
// Skip loop if we are doing zero iterations of it.
- "cbz %w[k], 4f\n"
+ "cbz %w[k], 4f\n"
// Loop proper
"1:\n"
- "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
- "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
- "add %[b_ptr], %[b_ptr], %[row_jump]\n"
- "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
- "ldr %q[a0a], [%[a_ptr], #32]\n"
- "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
- "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
- "ldr %q[a1a], [%[a_ptr], #48]\n"
- "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
- "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
- "ldr %q[b0], [%[b_ptr], #48]\n"
+ "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
+ "fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
+ "add %[b_ptr], %[b_ptr], %[row_jump]\n"
+ "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
+ "ldr %q[a0a], [%[a_ptr], #32]\n"
+ "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
+ "fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
+ "ldr %q[a1a], [%[a_ptr], #48]\n"
+ "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
+ "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
+ "ldr %q[b0], [%[b_ptr], #48]\n"
- "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
- "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" ASM_PREFETCH("[%[a_ptr], #320]")
- "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
- "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
- "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
- "ldr %q[b1], [%[b_ptr], #64]\n"
+ "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
+ "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
+ ASM_PREFETCH("[%[a_ptr], #320]")
+ "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
+ "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
+ "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
+ "ldr %q[b1], [%[b_ptr], #64]\n"
- "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" ASM_PREFETCH("[%[b_ptr], #448]")
- "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
- "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
- "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
- "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
- "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
- "ldr %q[b2], [%[b_ptr], #80]\n"
+ "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #448]")
+ "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
+ "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
+ "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
+ "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
+ "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
+ "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
+ "ldr %q[b2], [%[b_ptr], #80]\n"
- "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
- "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
- "ldr %q[a0], [%[a_ptr], #64]\n"
- "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
- "add %[b_ptr], %[b_ptr], %[row_jump]\n"
- "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
- "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
- "ldr %q[a1], [%[a_ptr], #80]\n"
+ "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
+ "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
+ "ldr %q[a0], [%[a_ptr], #64]\n"
+ "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
+ "add %[b_ptr], %[b_ptr], %[row_jump]\n"
+ "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
+ "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
+ "ldr %q[a1], [%[a_ptr], #80]\n"
"fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
- "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
- "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
- "ldr %q[b0], [%[b_ptr], #96]\n"
+ "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
+ "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
+ "ldr %q[b0], [%[b_ptr], #96]\n"
- "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
- "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n" ASM_PREFETCH("[%[b_ptr], #512]")
- "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
- "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
- "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
- "ldr %q[b1], [%[b_ptr], #112]\n"
+ "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
+ "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
+ ASM_PREFETCH("[%[b_ptr], #512]")
+ "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
+ "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
+ "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
+ "ldr %q[b1], [%[b_ptr], #112]\n"
- "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
- "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
- "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
- "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
- "subs %w[k], %w[k], #1\n"
- "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
- "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
- "bne 1b\n"
+ "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
+ "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
+ "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
+ "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
+ "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
+ "subs %w[k], %w[k], #1\n"
+ "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
+ "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
+ "bne 1b\n"
// Target to use when K is 1 or 2 (i.e. zero iterations of main loop)
"4:\n"
// Branch to alternative tail for odd K
- "cbnz %w[oddk], 2f\n"
+ "cbnz %w[oddk], 2f\n"
// Detached final iteration (even K)
- "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
+ "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
"fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
- "add %[b_ptr], %[b_ptr], %[row_jump]\n"
- "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
- "ldr %q[a0a], [%[a_ptr], #32]\n"
- "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
+ "add %[b_ptr], %[b_ptr], %[row_jump]\n"
+ "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
+ "ldr %q[a0a], [%[a_ptr], #32]\n"
+ "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
"fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
- "ldr %q[a1a], [%[a_ptr], #48]\n"
- "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
- "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
- "ldr %q[b0], [%[b_ptr], #48]\n"
+ "ldr %q[a1a], [%[a_ptr], #48]\n"
+ "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
+ "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
+ "ldr %q[b0], [%[b_ptr], #48]\n"
- "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
- "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
- "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
- "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
- "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
- "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
- "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
- "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
- "ldr %q[b1], [%[b_ptr], #64]\n"
+ "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
+ "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
+ "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
+ "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
+ "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
+ "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
+ "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
+ "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
+ "ldr %q[b1], [%[b_ptr], #64]\n"
- "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
- "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
- "add %[a_ptr], %[a_ptr], #64\n"
- "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
- "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
- "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
- "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
- "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
- "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
- "ldr %q[b2], [%[b_ptr], #80]\n"
+ "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
+ "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
+ "add %[a_ptr], %[a_ptr], #64\n"
+ "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
+ "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
+ "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
+ "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
+ "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
+ "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
+ "ldr %q[b2], [%[b_ptr], #80]\n"
- "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
- "add %[b_ptr], %[b_ptr], %[block_jump]\n"
- "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
- "add %[b_ptr], %[b_ptr], #96\n"
+ "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n"
+ "add %[b_ptr], %[b_ptr], %[block_jump]\n"
+ "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n"
+ "add %[b_ptr], %[b_ptr], #96\n"
"fmla v9.4s , %[b0].4s, %[a0a].s[1]\n"
- "add %[b_ptr], %[b_ptr], %[row_jump]\n"
- "str q8, [%[c_ptr], #0]\n"
- "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
- "str q16, [%[c_ptr], #16]\n"
- "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
- "str q24, [%[c_ptr], #32]\n"
+ "add %[b_ptr], %[b_ptr], %[row_jump]\n"
+ "str q8, [%[c_ptr], #0]\n"
+ "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n"
+ "str q16, [%[c_ptr], #16]\n"
+ "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n"
+ "str q24, [%[c_ptr], #32]\n"
- "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
- "str q9, [%[c_ptr], #48]\n"
- "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
- "str q17, [%[c_ptr], #64]\n"
- "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
- "str q25, [%[c_ptr], #80]\n"
- "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
- "str q10, [%[c_ptr], #96]\n"
+ "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n"
+ "str q9, [%[c_ptr], #48]\n"
+ "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n"
+ "str q17, [%[c_ptr], #64]\n"
+ "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n"
+ "str q25, [%[c_ptr], #80]\n"
+ "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n"
+ "str q10, [%[c_ptr], #96]\n"
- "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
- "str q18, [%[c_ptr], #112]\n"
- "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
- "str q26, [%[c_ptr], #128]\n"
- "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
- "str q11, [%[c_ptr], #144]\n"
+ "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n"
+ "str q18, [%[c_ptr], #112]\n"
+ "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n"
+ "str q26, [%[c_ptr], #128]\n"
+ "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n"
+ "str q11, [%[c_ptr], #144]\n"
- "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
- "str q19, [%[c_ptr], #160]\n"
- "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
- "str q27, [%[c_ptr], #176]\n"
- "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
- "str q12, [%[c_ptr], #192]\n"
+ "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n"
+ "str q19, [%[c_ptr], #160]\n"
+ "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n"
+ "str q27, [%[c_ptr], #176]\n"
+ "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n"
+ "str q12, [%[c_ptr], #192]\n"
"fmla v13.4s, %[b0].4s, %[a1a].s[1]\n"
- "str q20, [%[c_ptr], #208]\n"
- "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
- "str q28, [%[c_ptr], #224]\n"
- "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
- "str q13, [%[c_ptr], #240]\n"
+ "str q20, [%[c_ptr], #208]\n"
+ "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n"
+ "str q28, [%[c_ptr], #224]\n"
+ "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n"
+ "str q13, [%[c_ptr], #240]\n"
- "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
- "str q21, [%[c_ptr], #256]\n"
- "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
- "str q29, [%[c_ptr], #272]\n"
- "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
- "str q14, [%[c_ptr], #288]\n"
+ "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n"
+ "str q21, [%[c_ptr], #256]\n"
+ "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n"
+ "str q29, [%[c_ptr], #272]\n"
+ "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n"
+ "str q14, [%[c_ptr], #288]\n"
- "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
- "str q22, [%[c_ptr], #304]\n"
- "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
- "str q30, [%[c_ptr], #320]\n"
- "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
- "str q15, [%[c_ptr], #336]\n"
+ "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n"
+ "str q22, [%[c_ptr], #304]\n"
+ "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n"
+ "str q30, [%[c_ptr], #320]\n"
+ "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n"
+ "str q15, [%[c_ptr], #336]\n"
- "b 3f\n"
+ "b 3f\n"
// Detached final iteration (odd K)
"2:\n"
- "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
- "ldr %q[b2], [%[b_ptr], #32]\n"
- "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
- "add %[b_ptr], %[b_ptr], %[row_jump]\n"
+ "fmla v8.4s , %[b0].4s, %[a0].s[0]\n"
+ "ldr %q[b2], [%[b_ptr], #32]\n"
+ "fmla v16.4s, %[b1].4s, %[a0].s[0]\n"
+ "add %[b_ptr], %[b_ptr], %[row_jump]\n"
"fmla v9.4s , %[b0].4s, %[a0].s[1]\n"
- "str q8, [%[c_ptr], #0]\n"
- "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
- "str q16, [%[c_ptr], #16]\n"
- "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
- "add %[b_ptr], %[b_ptr], #48\n"
- "add %[a_ptr], %[a_ptr], #32\n"
- "str q24, [%[c_ptr], #32]\n"
- "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
- "str q9, [%[c_ptr], #48]\n"
+ "str q8, [%[c_ptr], #0]\n"
+ "fmla v17.4s, %[b1].4s, %[a0].s[1]\n"
+ "str q16, [%[c_ptr], #16]\n"
+ "fmla v24.4s, %[b2].4s, %[a0].s[0]\n"
+ "add %[b_ptr], %[b_ptr], #48\n" // NOTE(review): unlike the even-K tail, block_jump is never added to b_ptr on this path -- confirm this is intended.
+ "add %[a_ptr], %[a_ptr], #32\n"
+ "str q24, [%[c_ptr], #32]\n"
+ "fmla v25.4s, %[b2].4s, %[a0].s[1]\n"
+ "str q9, [%[c_ptr], #48]\n"
- "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
- "str q17, [%[c_ptr], #64]\n"
- "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
- "str q25, [%[c_ptr], #80]\n"
- "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
- "str q10, [%[c_ptr], #96]\n"
+ "fmla v10.4s, %[b0].4s, %[a0].s[2]\n"
+ "str q17, [%[c_ptr], #64]\n"
+ "fmla v18.4s, %[b1].4s, %[a0].s[2]\n"
+ "str q25, [%[c_ptr], #80]\n"
+ "fmla v26.4s, %[b2].4s, %[a0].s[2]\n"
+ "str q10, [%[c_ptr], #96]\n"
- "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
- "str q18, [%[c_ptr], #112]\n"
- "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
- "str q26, [%[c_ptr], #128]\n"
- "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
- "str q11, [%[c_ptr], #144]\n"
+ "fmla v11.4s, %[b0].4s, %[a0].s[3]\n"
+ "str q18, [%[c_ptr], #112]\n"
+ "fmla v19.4s, %[b1].4s, %[a0].s[3]\n"
+ "str q26, [%[c_ptr], #128]\n"
+ "fmla v27.4s, %[b2].4s, %[a0].s[3]\n"
+ "str q11, [%[c_ptr], #144]\n"
- "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
- "str q19, [%[c_ptr], #160]\n"
- "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
- "str q27, [%[c_ptr], #176]\n"
- "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
- "str q12, [%[c_ptr], #192]\n"
+ "fmla v12.4s, %[b0].4s, %[a1].s[0]\n"
+ "str q19, [%[c_ptr], #160]\n"
+ "fmla v20.4s, %[b1].4s, %[a1].s[0]\n"
+ "str q27, [%[c_ptr], #176]\n"
+ "fmla v28.4s, %[b2].4s, %[a1].s[0]\n"
+ "str q12, [%[c_ptr], #192]\n"
"fmla v13.4s, %[b0].4s, %[a1].s[1]\n"
- "str q20, [%[c_ptr], #208]\n"
- "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
- "str q28, [%[c_ptr], #224]\n"
- "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
- "str q13, [%[c_ptr], #240]\n"
+ "str q20, [%[c_ptr], #208]\n"
+ "fmla v21.4s, %[b1].4s, %[a1].s[1]\n"
+ "str q28, [%[c_ptr], #224]\n"
+ "fmla v29.4s, %[b2].4s, %[a1].s[1]\n"
+ "str q13, [%[c_ptr], #240]\n"
- "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
- "str q21, [%[c_ptr], #256]\n"
- "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
- "str q29, [%[c_ptr], #272]\n"
- "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
- "str q14, [%[c_ptr], #288]\n"
+ "fmla v14.4s, %[b0].4s, %[a1].s[2]\n"
+ "str q21, [%[c_ptr], #256]\n"
+ "fmla v22.4s, %[b1].4s, %[a1].s[2]\n"
+ "str q29, [%[c_ptr], #272]\n"
+ "fmla v30.4s, %[b2].4s, %[a1].s[2]\n"
+ "str q14, [%[c_ptr], #288]\n"
- "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
- "str q22, [%[c_ptr], #304]\n"
- "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
- "str q30, [%[c_ptr], #320]\n"
- "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
- "str q15, [%[c_ptr], #336]\n"
+ "fmla v15.4s, %[b0].4s, %[a1].s[3]\n"
+ "str q22, [%[c_ptr], #304]\n"
+ "fmla v23.4s, %[b1].4s, %[a1].s[3]\n"
+ "str q30, [%[c_ptr], #320]\n"
+ "fmla v31.4s, %[b2].4s, %[a1].s[3]\n"
+ "str q15, [%[c_ptr], #336]\n"
// Common tail
"3:\n"
- "str q23, [%[c_ptr], #352]\n"
- "str q31, [%[c_ptr], #368]\n"
- "add %[c_ptr], %[c_ptr], #384\n"
- :
- [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr),
- [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a),
- [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k)
- : [oddk] "r"(oddk), [row_jump] "r"(row_jump), [block_jump] "r"(block_jump)
- : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
- "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc");
+ "str q23, [%[c_ptr], #352]\n"
+ "str q31, [%[c_ptr], #368]\n"
+ "add %[c_ptr], %[c_ptr], #384\n"
+ :
+ [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
+ [a0] "+w" (a0), [a1] "+w" (a1), [a0a] "+w" (a0a), [a1a] "+w" (a1a),
+ [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2), [k] "+r" (k)
+ : [oddk] "r" (oddk), [row_jump] "r" (row_jump), [block_jump] "r" (block_jump)
+ : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
+ "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"
+ ); // "memory": the asm loads from Apanel/Bpanel and stores to Cpanel through pointer operands, so the compiler must not cache or reorder those accesses across it.
}
}
}
-void a64_sgemm_asimd_12x8(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K)
-{
+void a64_sgemm_asimd_12x8(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) { // Plain entry point: forwards to the _jumps variant with row_jump = block_jump = 0.
a64_sgemm_asimd_12x8_jumps(Apanel, Bpanel, Cpanel, ablocks, bblocks, K, 0, 0);
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp
index eceacc9..1a35965 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp
@@ -25,8 +25,8 @@
#ifdef __aarch64__
-namespace arm_gemm
-{
+namespace arm_gemm {
+
// Actual kernel implementations
void a64_sgemm_native_16x4(const float *, int, const float *, int, float *, int, float, int, int, int);
@@ -38,8 +38,7 @@
// All kernels in the family must share these characteristics. The actual
// kernel to be used can be chosen at runtime, based on the CPU_type
// structure.
-class sgemm_native_16x4
-{
+class sgemm_native_16x4 {
public:
typedef float operand_type;
typedef float result_type;
@@ -47,15 +46,23 @@
typedef void (*kern_type)(const float *, int, const float *, int, float *, int, float, int, int, int);
/* Kernel blocking parameters */
- static const int out_width = 16;
- static const int out_height = 4;
- static const int k_unroll = 1;
+ static int out_width() { // Columns of C per kernel step (the native kernel walks N with x0 += 16).
+ return 16;
+ }
+
+ static int out_height() { // Rows of C per kernel step (the native kernel walks M with y += 4).
+ return 4;
+ }
+
+ static int k_unroll() { // NOTE(review): presumably means K needs no padding/rounding for this kernel -- confirm against the caller.
+ return 1;
+ }
// Default to the generic kernel
- kern_type kernel = a64_sgemm_native_16x4;
+ kern_type kernel=a64_sgemm_native_16x4;
- sgemm_native_16x4(const CPUInfo *ci)
- {
+ sgemm_native_16x4(const CPUInfo *ci) { // ci is currently unused: the generic kernel above is always selected.
+
}
};
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp
index 8d4a38c..8325b3f 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp
@@ -23,41 +23,55 @@
*/
#ifdef __aarch64__
+#include <algorithm>
#include <cstddef>
+#include <cstring>
#include <arm_neon.h>
-namespace arm_gemm
-{
-void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, float *C, int ldc, float beta, int M, int N, int K)
-{
- const int oddk = ((K % 8) >= 4) ? 1 : 0;
- const int beta0 = (beta == 0.0f) ? 1 : 0;
+namespace arm_gemm {
+
+void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, float *C, int ldc, float beta, int M, int N, int K) {
+ const int oddk = ((K % 8) >= 4) ? 1 : 0;
+ const int beta0 = (beta == 0.0f) ? 1 : 0;
const int oddones = (K % 4);
+ float dummy_buffer[16];
+
+ std::memset(dummy_buffer, 0, sizeof(dummy_buffer));
+
/* For now, very naive with no blocking */
- for(int y = 0; y < M; y += 4)
- {
- for(int x0 = 0; x0 < N; x0 += 16)
- {
- const float *a_ptr0 = A + (y * lda);
- const float *a_ptr1 = a_ptr0 + lda;
- const float *a_ptr2 = a_ptr1 + lda;
- const float *a_ptr3 = a_ptr2 + lda;
+ for (int y=0; y<M; y+=4) {
+ const int activerows = std::min(M-y, 4);
+
+ const float * const a_ptr0_base = A + (y * lda);
+ const float * const a_ptr1_base = (activerows > 1) ? (a_ptr0_base + lda) : dummy_buffer;
+ const float * const a_ptr2_base = (activerows > 2) ? (a_ptr1_base + lda) : dummy_buffer;
+ const float * const a_ptr3_base = (activerows > 3) ? (a_ptr2_base + lda) : dummy_buffer;
+
+ const unsigned long a_incr1 = (activerows > 1) ? 32 : 0;
+ const unsigned long a_incr2 = (activerows > 2) ? 32 : 0;
+ const unsigned long a_incr3 = (activerows > 3) ? 32 : 0;
+
+ float *c_ptr0 = C + (y * ldc);
+ float *c_ptr1 = (activerows > 1) ? c_ptr0 + ldc : dummy_buffer;
+ float *c_ptr2 = (activerows > 2) ? c_ptr1 + ldc : dummy_buffer;
+ float *c_ptr3 = (activerows > 3) ? c_ptr2 + ldc : dummy_buffer;
+
+ for (int x0=0; x0<N; x0+=16) {
+ const float *a_ptr0 = a_ptr0_base;
+ const float *a_ptr1 = a_ptr1_base;
+ const float *a_ptr2 = a_ptr2_base;
+ const float *a_ptr3 = a_ptr3_base;
const float *b_ptr = B + x0;
- float *c_ptr0 = C + (y * ldc) + x0;
- float *c_ptr1 = c_ptr0 + ldc;
- float *c_ptr2 = c_ptr1 + ldc;
- float *c_ptr3 = c_ptr2 + ldc;
-
- int loops = ((K + 4) / 8) - 1;
- int odds = oddones;
+ int loops = ((K+4)/8) - 1;
+ int odds = oddones;
size_t ldbb = ldb * sizeof(float);
- __asm __volatile(
+ __asm __volatile (
"a0 .req v0\n"
"a1 .req v1\n"
"a2 .req v2\n"
@@ -92,774 +106,780 @@
"b2aq .req q14\n"
"b3aq .req q15\n"
- "movi v16.4s, #0x0\n"
- "ldr a0q, [%[a_ptr0]]\n"
- "movi v17.4s, #0x0\n"
- "ldr b0q, [%[b_ptr]]\n"
- "movi v18.4s, #0x0\n"
- "ldr b1q, [%[b_ptr], #16]\n"
- "movi v19.4s, #0x0\n"
- "ldr b2q, [%[b_ptr], #32]\n"
- "movi v20.4s, #0x0\n"
- "ldr b3q, [%[b_ptr], #48]\n"
- "movi v21.4s, #0x0\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "ldr a1q, [%[a_ptr1]]\n"
- "movi v22.4s, #0x0\n"
- "ldr a2q, [%[a_ptr2]]\n"
- "movi v23.4s, #0x0\n"
- "ldr a3q, [%[a_ptr3]]\n"
- "movi v24.4s, #0x0\n"
- "ldr b0aq, [%[b_ptr]]\n"
- "movi v25.4s, #0x0\n"
- "ldr b1aq, [%[b_ptr], #16]\n"
- "movi v26.4s, #0x0\n"
- "ldr b2aq, [%[b_ptr], #32]\n"
- "cbz %w[beta0], 5f\n"
- "movi v27.4s, #0x0\n"
- "movi v28.4s, #0x0\n"
- "movi v29.4s, #0x0\n"
- "movi v30.4s, #0x0\n"
- "movi v31.4s, #0x0\n"
+ "movi v16.4s, #0x0\n"
+ "ldr a0q, [%[a_ptr0]]\n"
+ "movi v17.4s, #0x0\n"
+ "ldr b0q, [%[b_ptr]]\n"
+ "movi v18.4s, #0x0\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
+ "movi v19.4s, #0x0\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
+ "movi v20.4s, #0x0\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
+ "movi v21.4s, #0x0\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "ldr a1q, [%[a_ptr1]]\n"
+ "movi v22.4s, #0x0\n"
+ "ldr a2q, [%[a_ptr2]]\n"
+ "movi v23.4s, #0x0\n"
+ "ldr a3q, [%[a_ptr3]]\n"
+ "movi v24.4s, #0x0\n"
+ "ldr b0aq, [%[b_ptr]]\n"
+ "movi v25.4s, #0x0\n"
+ "ldr b1aq, [%[b_ptr], #16]\n"
+ "movi v26.4s, #0x0\n"
+ "ldr b2aq, [%[b_ptr], #32]\n"
+ "cbz %w[beta0], 5f\n"
+ "movi v27.4s, #0x0\n"
+ "movi v28.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ "movi v30.4s, #0x0\n"
+ "movi v31.4s, #0x0\n"
// Skip if no complete loops.
- "cbz %w[loops], 4f\n"
- "b 1f\n"
+ "cbz %w[loops], 4f\n"
+ "b 1f\n"
// If beta is non-zero, need to load and multiply by beta
"5:\n"
- "ld1r {v4.4s}, [%[betaptr]]\n"
- "ldr q16, [%[c_ptr0]]\n"
- "ldr q17, [%[c_ptr0], #16]\n"
- "ldr q18, [%[c_ptr0], #32]\n"
- "ldr q19, [%[c_ptr0], #48]\n"
+ "ld1r {v4.4s}, [%[betaptr]]\n"
+ "ldr q16, [%[c_ptr0]]\n"
+ "ldr q17, [%[c_ptr0], #16]\n"
+ "ldr q18, [%[c_ptr0], #32]\n"
+ "ldr q19, [%[c_ptr0], #48]\n"
- "ldr q20, [%[c_ptr1]]\n"
- "fmul v16.4s, v16.4s, v4.4s\n"
- "ldr q21, [%[c_ptr1], #16]\n"
- "fmul v17.4s, v17.4s, v4.4s\n"
- "ldr q22, [%[c_ptr1], #32]\n"
- "fmul v18.4s, v18.4s, v4.4s\n"
- "ldr q23, [%[c_ptr1], #48]\n"
- "fmul v19.4s, v19.4s, v4.4s\n"
+ "ldr q20, [%[c_ptr1]]\n"
+ "fmul v16.4s, v16.4s, v4.4s\n"
+ "ldr q21, [%[c_ptr1], #16]\n"
+ "fmul v17.4s, v17.4s, v4.4s\n"
+ "ldr q22, [%[c_ptr1], #32]\n"
+ "fmul v18.4s, v18.4s, v4.4s\n"
+ "ldr q23, [%[c_ptr1], #48]\n"
+ "fmul v19.4s, v19.4s, v4.4s\n"
- "ldr q24, [%[c_ptr2]]\n"
- "fmul v20.4s, v20.4s, v4.4s\n"
- "ldr q25, [%[c_ptr2], #16]\n"
- "fmul v21.4s, v21.4s, v4.4s\n"
- "ldr q26, [%[c_ptr2], #32]\n"
- "fmul v22.4s, v22.4s, v4.4s\n"
- "ldr q27, [%[c_ptr2], #48]\n"
- "fmul v23.4s, v23.4s, v4.4s\n"
+ "ldr q24, [%[c_ptr2]]\n"
+ "fmul v20.4s, v20.4s, v4.4s\n"
+ "ldr q25, [%[c_ptr2], #16]\n"
+ "fmul v21.4s, v21.4s, v4.4s\n"
+ "ldr q26, [%[c_ptr2], #32]\n"
+ "fmul v22.4s, v22.4s, v4.4s\n"
+ "ldr q27, [%[c_ptr2], #48]\n"
+ "fmul v23.4s, v23.4s, v4.4s\n"
- "ldr q28, [%[c_ptr3]]\n"
- "fmul v24.4s, v24.4s, v4.4s\n"
- "ldr q29, [%[c_ptr3], #16]\n"
- "fmul v25.4s, v25.4s, v4.4s\n"
- "ldr q30, [%[c_ptr3], #32]\n"
- "fmul v26.4s, v26.4s, v4.4s\n"
- "ldr q31, [%[c_ptr3], #48]\n"
- "fmul v27.4s, v27.4s, v4.4s\n"
+ "ldr q28, [%[c_ptr3]]\n"
+ "fmul v24.4s, v24.4s, v4.4s\n"
+ "ldr q29, [%[c_ptr3], #16]\n"
+ "fmul v25.4s, v25.4s, v4.4s\n"
+ "ldr q30, [%[c_ptr3], #32]\n"
+ "fmul v26.4s, v26.4s, v4.4s\n"
+ "ldr q31, [%[c_ptr3], #48]\n"
+ "fmul v27.4s, v27.4s, v4.4s\n"
- "fmul v28.4s, v28.4s, v4.4s\n"
- "fmul v29.4s, v29.4s, v4.4s\n"
- "fmul v30.4s, v30.4s, v4.4s\n"
- "fmul v31.4s, v31.4s, v4.4s\n"
+ "fmul v28.4s, v28.4s, v4.4s\n"
+ "fmul v29.4s, v29.4s, v4.4s\n"
+ "fmul v30.4s, v30.4s, v4.4s\n"
+ "fmul v31.4s, v31.4s, v4.4s\n"
- "cbz %w[loops], 4f\n"
+ "cbz %w[loops], 4f\n"
"1:\n"
// Unroll 0
- "fmla v16.4s, bb0.4s, a0.s[0]\n"
- "fmla v20.4s, bb0.4s, a1.s[0]\n"
- "ldr b3aq, [%[b_ptr], #48]\n"
- "fmla v24.4s, bb0.4s, a2.s[0]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v28.4s, bb0.4s, a3.s[0]\n"
- "ldr b0q, [%[b_ptr]]\n"
+ "fmla v16.4s, bb0.4s, a0.s[0]\n"
+ "fmla v20.4s, bb0.4s, a1.s[0]\n"
+ "ldr b3aq, [%[b_ptr], #48]\n"
+ "fmla v24.4s, bb0.4s, a2.s[0]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v28.4s, bb0.4s, a3.s[0]\n"
+ "ldr b0q, [%[b_ptr]]\n"
- "fmla v17.4s, bb1.4s, a0.s[0]\n"
- "fmla v21.4s, bb1.4s, a1.s[0]\n"
- "ldr a0aq, [%[a_ptr0], #16]\n"
- "fmla v25.4s, bb1.4s, a2.s[0]\n"
- "fmla v29.4s, bb1.4s, a3.s[0]\n"
- "ldr b1q, [%[b_ptr], #16]\n"
+ "fmla v17.4s, bb1.4s, a0.s[0]\n"
+ "fmla v21.4s, bb1.4s, a1.s[0]\n"
+ "ldr a0aq, [%[a_ptr0], #16]\n"
+ "fmla v25.4s, bb1.4s, a2.s[0]\n"
+ "fmla v29.4s, bb1.4s, a3.s[0]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
- "fmla v18.4s, bb2.4s, a0.s[0]\n"
- "fmla v22.4s, bb2.4s, a1.s[0]\n"
- "ldr a1aq, [%[a_ptr1], #16]\n"
- "fmla v26.4s, bb2.4s, a2.s[0]\n"
- "fmla v30.4s, bb2.4s, a3.s[0]\n"
- "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v18.4s, bb2.4s, a0.s[0]\n"
+ "fmla v22.4s, bb2.4s, a1.s[0]\n"
+ "ldr a1aq, [%[a_ptr1], #16]\n"
+ "fmla v26.4s, bb2.4s, a2.s[0]\n"
+ "fmla v30.4s, bb2.4s, a3.s[0]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
- "fmla v19.4s, bb3.4s, a0.s[0]\n"
- "fmla v23.4s, bb3.4s, a1.s[0]\n"
- "ldr a2aq, [%[a_ptr2], #16]\n"
- "fmla v27.4s, bb3.4s, a2.s[0]\n"
- "fmla v31.4s, bb3.4s, a3.s[0]\n"
- "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v19.4s, bb3.4s, a0.s[0]\n"
+ "fmla v23.4s, bb3.4s, a1.s[0]\n"
+ "ldr a2aq, [%[a_ptr2], #16]\n"
+ "fmla v27.4s, bb3.4s, a2.s[0]\n"
+ "fmla v31.4s, bb3.4s, a3.s[0]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
// Unroll 1
- "fmla v16.4s, b0a.4s, a0.s[1]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v20.4s, b0a.4s, a1.s[1]\n"
- "ldr a3aq, [%[a_ptr3], #16]\n"
- "fmla v24.4s, b0a.4s, a2.s[1]\n"
- "fmla v28.4s, b0a.4s, a3.s[1]\n"
- "ldr b0aq, [%[b_ptr]]\n"
+ "fmla v16.4s, b0a.4s, a0.s[1]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v20.4s, b0a.4s, a1.s[1]\n"
+ "ldr a3aq, [%[a_ptr3], #16]\n"
+ "fmla v24.4s, b0a.4s, a2.s[1]\n"
+ "fmla v28.4s, b0a.4s, a3.s[1]\n"
+ "ldr b0aq, [%[b_ptr]]\n"
- "fmla v17.4s, b1a.4s, a0.s[1]\n"
- "fmla v21.4s, b1a.4s, a1.s[1]\n"
- "subs %w[loops], %w[loops], #1\n"
- "fmla v25.4s, b1a.4s, a2.s[1]\n"
- "fmla v29.4s, b1a.4s, a3.s[1]\n"
- "ldr b1aq, [%[b_ptr], #16]\n"
+ "fmla v17.4s, b1a.4s, a0.s[1]\n"
+ "fmla v21.4s, b1a.4s, a1.s[1]\n"
+ "subs %w[loops], %w[loops], #1\n"
+ "fmla v25.4s, b1a.4s, a2.s[1]\n"
+ "fmla v29.4s, b1a.4s, a3.s[1]\n"
+ "ldr b1aq, [%[b_ptr], #16]\n"
- "fmla v18.4s, b2a.4s, a0.s[1]\n"
- "fmla v22.4s, b2a.4s, a1.s[1]\n"
- "fmla v26.4s, b2a.4s, a2.s[1]\n"
- "fmla v30.4s, b2a.4s, a3.s[1]\n"
- "ldr b2aq, [%[b_ptr], #32]\n"
+ "fmla v18.4s, b2a.4s, a0.s[1]\n"
+ "fmla v22.4s, b2a.4s, a1.s[1]\n"
+ "fmla v26.4s, b2a.4s, a2.s[1]\n"
+ "fmla v30.4s, b2a.4s, a3.s[1]\n"
+ "ldr b2aq, [%[b_ptr], #32]\n"
- "fmla v19.4s, b3a.4s, a0.s[1]\n"
- "fmla v23.4s, b3a.4s, a1.s[1]\n"
- "fmla v27.4s, b3a.4s, a2.s[1]\n"
- "fmla v31.4s, b3a.4s, a3.s[1]\n"
- "ldr b3aq, [%[b_ptr], #48]\n"
+ "fmla v19.4s, b3a.4s, a0.s[1]\n"
+ "fmla v23.4s, b3a.4s, a1.s[1]\n"
+ "fmla v27.4s, b3a.4s, a2.s[1]\n"
+ "fmla v31.4s, b3a.4s, a3.s[1]\n"
+ "ldr b3aq, [%[b_ptr], #48]\n"
// Unroll 2
- "fmla v16.4s, bb0.4s, a0.s[2]\n"
- "fmla v20.4s, bb0.4s, a1.s[2]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v24.4s, bb0.4s, a2.s[2]\n"
- "fmla v28.4s, bb0.4s, a3.s[2]\n"
- "ldr b0q, [%[b_ptr]]\n"
+ "fmla v16.4s, bb0.4s, a0.s[2]\n"
+ "fmla v20.4s, bb0.4s, a1.s[2]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v24.4s, bb0.4s, a2.s[2]\n"
+ "fmla v28.4s, bb0.4s, a3.s[2]\n"
+ "ldr b0q, [%[b_ptr]]\n"
- "fmla v17.4s, bb1.4s, a0.s[2]\n"
- "add %[a_ptr0], %[a_ptr0], #32\n"
- "fmla v21.4s, bb1.4s, a1.s[2]\n"
- "add %[a_ptr1], %[a_ptr1], #32\n"
- "fmla v25.4s, bb1.4s, a2.s[2]\n"
- "add %[a_ptr2], %[a_ptr2], #32\n"
- "fmla v29.4s, bb1.4s, a3.s[2]\n"
- "ldr b1q, [%[b_ptr], #16]\n"
+ "fmla v17.4s, bb1.4s, a0.s[2]\n"
+ "add %[a_ptr0], %[a_ptr0], #32\n"
+ "fmla v21.4s, bb1.4s, a1.s[2]\n"
+ "add %[a_ptr1], %[a_ptr1], %[a_incr1]\n"
+ "fmla v25.4s, bb1.4s, a2.s[2]\n"
+ "add %[a_ptr2], %[a_ptr2], %[a_incr2]\n"
+ "fmla v29.4s, bb1.4s, a3.s[2]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
- "fmla v18.4s, bb2.4s, a0.s[2]\n"
- "add %[a_ptr3], %[a_ptr3], #32\n"
- "fmla v22.4s, bb2.4s, a1.s[2]\n"
- "fmla v26.4s, bb2.4s, a2.s[2]\n"
- "fmla v30.4s, bb2.4s, a3.s[2]\n"
- "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v18.4s, bb2.4s, a0.s[2]\n"
+ "add %[a_ptr3], %[a_ptr3], %[a_incr3]\n"
+ "fmla v22.4s, bb2.4s, a1.s[2]\n"
+ "fmla v26.4s, bb2.4s, a2.s[2]\n"
+ "fmla v30.4s, bb2.4s, a3.s[2]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
- "fmla v19.4s, bb3.4s, a0.s[2]\n"
- "fmla v23.4s, bb3.4s, a1.s[2]\n"
- "fmla v27.4s, bb3.4s, a2.s[2]\n"
- "fmla v31.4s, bb3.4s, a3.s[2]\n"
- "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v19.4s, bb3.4s, a0.s[2]\n"
+ "fmla v23.4s, bb3.4s, a1.s[2]\n"
+ "fmla v27.4s, bb3.4s, a2.s[2]\n"
+ "fmla v31.4s, bb3.4s, a3.s[2]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
// Unroll 3
- "fmla v16.4s, b0a.4s, a0.s[3]\n"
- "fmla v20.4s, b0a.4s, a1.s[3]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v24.4s, b0a.4s, a2.s[3]\n"
- "fmla v28.4s, b0a.4s, a3.s[3]\n"
- "ldr b0aq, [%[b_ptr]]\n"
+ "fmla v16.4s, b0a.4s, a0.s[3]\n"
+ "fmla v20.4s, b0a.4s, a1.s[3]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v24.4s, b0a.4s, a2.s[3]\n"
+ "fmla v28.4s, b0a.4s, a3.s[3]\n"
+ "ldr b0aq, [%[b_ptr]]\n"
- "fmla v17.4s, b1a.4s, a0.s[3]\n"
- "fmla v21.4s, b1a.4s, a1.s[3]\n"
- "fmla v25.4s, b1a.4s, a2.s[3]\n"
- "fmla v29.4s, b1a.4s, a3.s[3]\n"
- "ldr b1aq, [%[b_ptr], #16]\n"
+ "fmla v17.4s, b1a.4s, a0.s[3]\n"
+ "fmla v21.4s, b1a.4s, a1.s[3]\n"
+ "fmla v25.4s, b1a.4s, a2.s[3]\n"
+ "fmla v29.4s, b1a.4s, a3.s[3]\n"
+ "ldr b1aq, [%[b_ptr], #16]\n"
- "fmla v18.4s, b2a.4s, a0.s[3]\n"
- "fmla v22.4s, b2a.4s, a1.s[3]\n"
- "fmla v26.4s, b2a.4s, a2.s[3]\n"
- "fmla v30.4s, b2a.4s, a3.s[3]\n"
- "ldr b2aq, [%[b_ptr], #32]\n"
+ "fmla v18.4s, b2a.4s, a0.s[3]\n"
+ "fmla v22.4s, b2a.4s, a1.s[3]\n"
+ "fmla v26.4s, b2a.4s, a2.s[3]\n"
+ "fmla v30.4s, b2a.4s, a3.s[3]\n"
+ "ldr b2aq, [%[b_ptr], #32]\n"
- "fmla v19.4s, b3a.4s, a0.s[3]\n"
- "fmla v23.4s, b3a.4s, a1.s[3]\n"
- "ldr a0q, [%[a_ptr0]]\n"
- "fmla v27.4s, b3a.4s, a2.s[3]\n"
- "fmla v31.4s, b3a.4s, a3.s[3]\n"
- "ldr b3aq, [%[b_ptr], #48]\n"
+ "fmla v19.4s, b3a.4s, a0.s[3]\n"
+ "fmla v23.4s, b3a.4s, a1.s[3]\n"
+ "ldr a0q, [%[a_ptr0]]\n"
+ "fmla v27.4s, b3a.4s, a2.s[3]\n"
+ "fmla v31.4s, b3a.4s, a3.s[3]\n"
+ "ldr b3aq, [%[b_ptr], #48]\n"
// Unroll 4
- "fmla v16.4s, bb0.4s, a0a.s[0]\n"
- "fmla v20.4s, bb0.4s, a1a.s[0]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v24.4s, bb0.4s, a2a.s[0]\n"
- "fmla v28.4s, bb0.4s, a3a.s[0]\n"
- "ldr b0q, [%[b_ptr]]\n"
+ "fmla v16.4s, bb0.4s, a0a.s[0]\n"
+ "fmla v20.4s, bb0.4s, a1a.s[0]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v24.4s, bb0.4s, a2a.s[0]\n"
+ "fmla v28.4s, bb0.4s, a3a.s[0]\n"
+ "ldr b0q, [%[b_ptr]]\n"
- "fmla v17.4s, bb1.4s, a0a.s[0]\n"
- "fmla v21.4s, bb1.4s, a1a.s[0]\n"
- "ldr a1q, [%[a_ptr1]]\n"
- "fmla v25.4s, bb1.4s, a2a.s[0]\n"
- "fmla v29.4s, bb1.4s, a3a.s[0]\n"
- "ldr b1q, [%[b_ptr], #16]\n"
+ "fmla v17.4s, bb1.4s, a0a.s[0]\n"
+ "fmla v21.4s, bb1.4s, a1a.s[0]\n"
+ "ldr a1q, [%[a_ptr1]]\n"
+ "fmla v25.4s, bb1.4s, a2a.s[0]\n"
+ "fmla v29.4s, bb1.4s, a3a.s[0]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
- "fmla v18.4s, bb2.4s, a0a.s[0]\n"
- "fmla v22.4s, bb2.4s, a1a.s[0]\n"
- "ldr a2q, [%[a_ptr2]]\n"
- "fmla v26.4s, bb2.4s, a2a.s[0]\n"
- "fmla v30.4s, bb2.4s, a3a.s[0]\n"
- "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v18.4s, bb2.4s, a0a.s[0]\n"
+ "fmla v22.4s, bb2.4s, a1a.s[0]\n"
+ "ldr a2q, [%[a_ptr2]]\n"
+ "fmla v26.4s, bb2.4s, a2a.s[0]\n"
+ "fmla v30.4s, bb2.4s, a3a.s[0]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
- "fmla v19.4s, bb3.4s, a0a.s[0]\n"
- "fmla v23.4s, bb3.4s, a1a.s[0]\n"
- "ldr a3q, [%[a_ptr3]]\n"
- "fmla v27.4s, bb3.4s, a2a.s[0]\n"
- "fmla v31.4s, bb3.4s, a3a.s[0]\n"
- "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v19.4s, bb3.4s, a0a.s[0]\n"
+ "fmla v23.4s, bb3.4s, a1a.s[0]\n"
+ "ldr a3q, [%[a_ptr3]]\n"
+ "fmla v27.4s, bb3.4s, a2a.s[0]\n"
+ "fmla v31.4s, bb3.4s, a3a.s[0]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
// Unroll 5
- "fmla v16.4s, b0a.4s, a0a.s[1]\n"
- "fmla v20.4s, b0a.4s, a1a.s[1]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v24.4s, b0a.4s, a2a.s[1]\n"
- "fmla v28.4s, b0a.4s, a3a.s[1]\n"
- "ldr b0aq, [%[b_ptr]]\n"
+ "fmla v16.4s, b0a.4s, a0a.s[1]\n"
+ "fmla v20.4s, b0a.4s, a1a.s[1]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v24.4s, b0a.4s, a2a.s[1]\n"
+ "fmla v28.4s, b0a.4s, a3a.s[1]\n"
+ "ldr b0aq, [%[b_ptr]]\n"
- "fmla v17.4s, b1a.4s, a0a.s[1]\n"
- "fmla v21.4s, b1a.4s, a1a.s[1]\n"
- "fmla v25.4s, b1a.4s, a2a.s[1]\n"
- "fmla v29.4s, b1a.4s, a3a.s[1]\n"
- "ldr b1aq, [%[b_ptr], #16]\n"
+ "fmla v17.4s, b1a.4s, a0a.s[1]\n"
+ "fmla v21.4s, b1a.4s, a1a.s[1]\n"
+ "fmla v25.4s, b1a.4s, a2a.s[1]\n"
+ "fmla v29.4s, b1a.4s, a3a.s[1]\n"
+ "ldr b1aq, [%[b_ptr], #16]\n"
- "fmla v18.4s, b2a.4s, a0a.s[1]\n"
- "fmla v22.4s, b2a.4s, a1a.s[1]\n"
- "fmla v26.4s, b2a.4s, a2a.s[1]\n"
- "fmla v30.4s, b2a.4s, a3a.s[1]\n"
- "ldr b2aq, [%[b_ptr], #32]\n"
+ "fmla v18.4s, b2a.4s, a0a.s[1]\n"
+ "fmla v22.4s, b2a.4s, a1a.s[1]\n"
+ "fmla v26.4s, b2a.4s, a2a.s[1]\n"
+ "fmla v30.4s, b2a.4s, a3a.s[1]\n"
+ "ldr b2aq, [%[b_ptr], #32]\n"
- "fmla v19.4s, b3a.4s, a0a.s[1]\n"
- "fmla v23.4s, b3a.4s, a1a.s[1]\n"
- "fmla v27.4s, b3a.4s, a2a.s[1]\n"
- "fmla v31.4s, b3a.4s, a3a.s[1]\n"
- "ldr b3aq, [%[b_ptr], #48]\n"
+ "fmla v19.4s, b3a.4s, a0a.s[1]\n"
+ "fmla v23.4s, b3a.4s, a1a.s[1]\n"
+ "fmla v27.4s, b3a.4s, a2a.s[1]\n"
+ "fmla v31.4s, b3a.4s, a3a.s[1]\n"
+ "ldr b3aq, [%[b_ptr], #48]\n"
// Unroll 6
- "fmla v16.4s, bb0.4s, a0a.s[2]\n"
- "fmla v20.4s, bb0.4s, a1a.s[2]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v24.4s, bb0.4s, a2a.s[2]\n"
- "fmla v28.4s, bb0.4s, a3a.s[2]\n"
- "ldr b0q, [%[b_ptr]]\n"
+ "fmla v16.4s, bb0.4s, a0a.s[2]\n"
+ "fmla v20.4s, bb0.4s, a1a.s[2]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v24.4s, bb0.4s, a2a.s[2]\n"
+ "fmla v28.4s, bb0.4s, a3a.s[2]\n"
+ "ldr b0q, [%[b_ptr]]\n"
- "fmla v17.4s, bb1.4s, a0a.s[2]\n"
- "fmla v21.4s, bb1.4s, a1a.s[2]\n"
- "fmla v25.4s, bb1.4s, a2a.s[2]\n"
- "fmla v29.4s, bb1.4s, a3a.s[2]\n"
- "ldr b1q, [%[b_ptr], #16]\n"
+ "fmla v17.4s, bb1.4s, a0a.s[2]\n"
+ "fmla v21.4s, bb1.4s, a1a.s[2]\n"
+ "fmla v25.4s, bb1.4s, a2a.s[2]\n"
+ "fmla v29.4s, bb1.4s, a3a.s[2]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
- "fmla v18.4s, bb2.4s, a0a.s[2]\n"
- "fmla v22.4s, bb2.4s, a1a.s[2]\n"
- "fmla v26.4s, bb2.4s, a2a.s[2]\n"
- "fmla v30.4s, bb2.4s, a3a.s[2]\n"
- "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v18.4s, bb2.4s, a0a.s[2]\n"
+ "fmla v22.4s, bb2.4s, a1a.s[2]\n"
+ "fmla v26.4s, bb2.4s, a2a.s[2]\n"
+ "fmla v30.4s, bb2.4s, a3a.s[2]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
- "fmla v19.4s, bb3.4s, a0a.s[2]\n"
- "fmla v23.4s, bb3.4s, a1a.s[2]\n"
- "fmla v27.4s, bb3.4s, a2a.s[2]\n"
- "fmla v31.4s, bb3.4s, a3a.s[2]\n"
- "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v19.4s, bb3.4s, a0a.s[2]\n"
+ "fmla v23.4s, bb3.4s, a1a.s[2]\n"
+ "fmla v27.4s, bb3.4s, a2a.s[2]\n"
+ "fmla v31.4s, bb3.4s, a3a.s[2]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
// Unroll 7
- "fmla v16.4s, b0a.4s, a0a.s[3]\n"
- "fmla v20.4s, b0a.4s, a1a.s[3]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v24.4s, b0a.4s, a2a.s[3]\n"
- "fmla v28.4s, b0a.4s, a3a.s[3]\n"
- "ldr b0aq, [%[b_ptr]]\n"
+ "fmla v16.4s, b0a.4s, a0a.s[3]\n"
+ "fmla v20.4s, b0a.4s, a1a.s[3]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v24.4s, b0a.4s, a2a.s[3]\n"
+ "fmla v28.4s, b0a.4s, a3a.s[3]\n"
+ "ldr b0aq, [%[b_ptr]]\n"
- "fmla v17.4s, b1a.4s, a0a.s[3]\n"
- "fmla v21.4s, b1a.4s, a1a.s[3]\n"
- "fmla v25.4s, b1a.4s, a2a.s[3]\n"
- "fmla v29.4s, b1a.4s, a3a.s[3]\n"
- "ldr b1aq, [%[b_ptr], #16]\n"
+ "fmla v17.4s, b1a.4s, a0a.s[3]\n"
+ "fmla v21.4s, b1a.4s, a1a.s[3]\n"
+ "fmla v25.4s, b1a.4s, a2a.s[3]\n"
+ "fmla v29.4s, b1a.4s, a3a.s[3]\n"
+ "ldr b1aq, [%[b_ptr], #16]\n"
- "fmla v18.4s, b2a.4s, a0a.s[3]\n"
- "fmla v22.4s, b2a.4s, a1a.s[3]\n"
- "fmla v26.4s, b2a.4s, a2a.s[3]\n"
- "fmla v30.4s, b2a.4s, a3a.s[3]\n"
- "ldr b2aq, [%[b_ptr], #32]\n"
+ "fmla v18.4s, b2a.4s, a0a.s[3]\n"
+ "fmla v22.4s, b2a.4s, a1a.s[3]\n"
+ "fmla v26.4s, b2a.4s, a2a.s[3]\n"
+ "fmla v30.4s, b2a.4s, a3a.s[3]\n"
+ "ldr b2aq, [%[b_ptr], #32]\n"
- "fmla v19.4s, b3a.4s, a0a.s[3]\n"
- "fmla v23.4s, b3a.4s, a1a.s[3]\n"
- "fmla v27.4s, b3a.4s, a2a.s[3]\n"
- "fmla v31.4s, b3a.4s, a3a.s[3]\n"
- "bne 1b\n"
+ "fmla v19.4s, b3a.4s, a0a.s[3]\n"
+ "fmla v23.4s, b3a.4s, a1a.s[3]\n"
+ "fmla v27.4s, b3a.4s, a2a.s[3]\n"
+ "fmla v31.4s, b3a.4s, a3a.s[3]\n"
+ "bne 1b\n"
// Skip to here
"4:\n"
// Detached final iteration
// Unroll 0
- "fmla v16.4s, bb0.4s, a0.s[0]\n"
- "fmla v20.4s, bb0.4s, a1.s[0]\n"
- "ldr b3aq, [%[b_ptr], #48]\n"
- "fmla v24.4s, bb0.4s, a2.s[0]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v28.4s, bb0.4s, a3.s[0]\n"
- "ldr b0q, [%[b_ptr]]\n"
+ "fmla v16.4s, bb0.4s, a0.s[0]\n"
+ "fmla v20.4s, bb0.4s, a1.s[0]\n"
+ "ldr b3aq, [%[b_ptr], #48]\n"
+ "fmla v24.4s, bb0.4s, a2.s[0]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v28.4s, bb0.4s, a3.s[0]\n"
+ "ldr b0q, [%[b_ptr]]\n"
- "fmla v17.4s, bb1.4s, a0.s[0]\n"
- "cbnz %w[oddk], 2f\n" // Deal with odd K before we load a0a
- "fmla v21.4s, bb1.4s, a1.s[0]\n"
- "ldr a0aq, [%[a_ptr0], #16]\n"
- "fmla v25.4s, bb1.4s, a2.s[0]\n"
- "fmla v29.4s, bb1.4s, a3.s[0]\n"
- "ldr b1q, [%[b_ptr], #16]\n"
+ "fmla v17.4s, bb1.4s, a0.s[0]\n"
+ "cbnz %w[oddk], 2f\n" // Deal with odd K before we load a0a
+ "fmla v21.4s, bb1.4s, a1.s[0]\n"
+ "ldr a0aq, [%[a_ptr0], #16]\n"
+ "fmla v25.4s, bb1.4s, a2.s[0]\n"
+ "fmla v29.4s, bb1.4s, a3.s[0]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
- "fmla v18.4s, bb2.4s, a0.s[0]\n"
- "fmla v22.4s, bb2.4s, a1.s[0]\n"
- "ldr a1aq, [%[a_ptr1], #16]\n"
- "fmla v26.4s, bb2.4s, a2.s[0]\n"
- "fmla v30.4s, bb2.4s, a3.s[0]\n"
- "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v18.4s, bb2.4s, a0.s[0]\n"
+ "fmla v22.4s, bb2.4s, a1.s[0]\n"
+ "ldr a1aq, [%[a_ptr1], #16]\n"
+ "fmla v26.4s, bb2.4s, a2.s[0]\n"
+ "fmla v30.4s, bb2.4s, a3.s[0]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
- "fmla v19.4s, bb3.4s, a0.s[0]\n"
- "fmla v23.4s, bb3.4s, a1.s[0]\n"
- "ldr a2aq, [%[a_ptr2], #16]\n"
- "fmla v27.4s, bb3.4s, a2.s[0]\n"
- "fmla v31.4s, bb3.4s, a3.s[0]\n"
- "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v19.4s, bb3.4s, a0.s[0]\n"
+ "fmla v23.4s, bb3.4s, a1.s[0]\n"
+ "ldr a2aq, [%[a_ptr2], #16]\n"
+ "fmla v27.4s, bb3.4s, a2.s[0]\n"
+ "fmla v31.4s, bb3.4s, a3.s[0]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
// Unroll 1
- "fmla v16.4s, b0a.4s, a0.s[1]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v20.4s, b0a.4s, a1.s[1]\n"
- "ldr a3aq, [%[a_ptr3], #16]\n"
- "fmla v24.4s, b0a.4s, a2.s[1]\n"
- "fmla v28.4s, b0a.4s, a3.s[1]\n"
- "ldr b0aq, [%[b_ptr]]\n"
+ "fmla v16.4s, b0a.4s, a0.s[1]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v20.4s, b0a.4s, a1.s[1]\n"
+ "ldr a3aq, [%[a_ptr3], #16]\n"
+ "fmla v24.4s, b0a.4s, a2.s[1]\n"
+ "fmla v28.4s, b0a.4s, a3.s[1]\n"
+ "ldr b0aq, [%[b_ptr]]\n"
- "fmla v17.4s, b1a.4s, a0.s[1]\n"
- "add %[a_ptr0], %[a_ptr0], #32\n"
- "fmla v21.4s, b1a.4s, a1.s[1]\n"
- "add %[a_ptr1], %[a_ptr1], #32\n"
- "fmla v25.4s, b1a.4s, a2.s[1]\n"
- "add %[a_ptr2], %[a_ptr2], #32\n"
- "fmla v29.4s, b1a.4s, a3.s[1]\n"
- "ldr b1aq, [%[b_ptr], #16]\n"
+ "fmla v17.4s, b1a.4s, a0.s[1]\n"
+ "add %[a_ptr0], %[a_ptr0], #32\n"
+ "fmla v21.4s, b1a.4s, a1.s[1]\n"
+ "add %[a_ptr1], %[a_ptr1], %[a_incr1]\n"
+ "fmla v25.4s, b1a.4s, a2.s[1]\n"
+ "add %[a_ptr2], %[a_ptr2], %[a_incr2]\n"
+ "fmla v29.4s, b1a.4s, a3.s[1]\n"
+ "ldr b1aq, [%[b_ptr], #16]\n"
- "fmla v18.4s, b2a.4s, a0.s[1]\n"
- "fmla v22.4s, b2a.4s, a1.s[1]\n"
- "add %[a_ptr3], %[a_ptr3], #32\n"
- "fmla v26.4s, b2a.4s, a2.s[1]\n"
- "fmla v30.4s, b2a.4s, a3.s[1]\n"
- "ldr b2aq, [%[b_ptr], #32]\n"
+ "fmla v18.4s, b2a.4s, a0.s[1]\n"
+ "fmla v22.4s, b2a.4s, a1.s[1]\n"
+ "add %[a_ptr3], %[a_ptr3], %[a_incr3]\n"
+ "fmla v26.4s, b2a.4s, a2.s[1]\n"
+ "fmla v30.4s, b2a.4s, a3.s[1]\n"
+ "ldr b2aq, [%[b_ptr], #32]\n"
- "fmla v19.4s, b3a.4s, a0.s[1]\n"
- "fmla v23.4s, b3a.4s, a1.s[1]\n"
- "fmla v27.4s, b3a.4s, a2.s[1]\n"
- "fmla v31.4s, b3a.4s, a3.s[1]\n"
- "ldr b3aq, [%[b_ptr], #48]\n"
+ "fmla v19.4s, b3a.4s, a0.s[1]\n"
+ "fmla v23.4s, b3a.4s, a1.s[1]\n"
+ "fmla v27.4s, b3a.4s, a2.s[1]\n"
+ "fmla v31.4s, b3a.4s, a3.s[1]\n"
+ "ldr b3aq, [%[b_ptr], #48]\n"
// Unroll 2
- "fmla v16.4s, bb0.4s, a0.s[2]\n"
- "fmla v20.4s, bb0.4s, a1.s[2]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v24.4s, bb0.4s, a2.s[2]\n"
- "fmla v28.4s, bb0.4s, a3.s[2]\n"
- "ldr b0q, [%[b_ptr]]\n"
+ "fmla v16.4s, bb0.4s, a0.s[2]\n"
+ "fmla v20.4s, bb0.4s, a1.s[2]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v24.4s, bb0.4s, a2.s[2]\n"
+ "fmla v28.4s, bb0.4s, a3.s[2]\n"
+ "ldr b0q, [%[b_ptr]]\n"
- "fmla v17.4s, bb1.4s, a0.s[2]\n"
- "fmla v21.4s, bb1.4s, a1.s[2]\n"
- "fmla v25.4s, bb1.4s, a2.s[2]\n"
- "fmla v29.4s, bb1.4s, a3.s[2]\n"
- "ldr b1q, [%[b_ptr], #16]\n"
+ "fmla v17.4s, bb1.4s, a0.s[2]\n"
+ "fmla v21.4s, bb1.4s, a1.s[2]\n"
+ "fmla v25.4s, bb1.4s, a2.s[2]\n"
+ "fmla v29.4s, bb1.4s, a3.s[2]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
- "fmla v18.4s, bb2.4s, a0.s[2]\n"
- "fmla v22.4s, bb2.4s, a1.s[2]\n"
- "fmla v26.4s, bb2.4s, a2.s[2]\n"
- "fmla v30.4s, bb2.4s, a3.s[2]\n"
- "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v18.4s, bb2.4s, a0.s[2]\n"
+ "fmla v22.4s, bb2.4s, a1.s[2]\n"
+ "fmla v26.4s, bb2.4s, a2.s[2]\n"
+ "fmla v30.4s, bb2.4s, a3.s[2]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
- "fmla v19.4s, bb3.4s, a0.s[2]\n"
- "fmla v23.4s, bb3.4s, a1.s[2]\n"
- "fmla v27.4s, bb3.4s, a2.s[2]\n"
- "fmla v31.4s, bb3.4s, a3.s[2]\n"
- "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v19.4s, bb3.4s, a0.s[2]\n"
+ "fmla v23.4s, bb3.4s, a1.s[2]\n"
+ "fmla v27.4s, bb3.4s, a2.s[2]\n"
+ "fmla v31.4s, bb3.4s, a3.s[2]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
// Unroll 3
- "fmla v16.4s, b0a.4s, a0.s[3]\n"
- "fmla v20.4s, b0a.4s, a1.s[3]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v24.4s, b0a.4s, a2.s[3]\n"
- "fmla v28.4s, b0a.4s, a3.s[3]\n"
- "ldr b0aq, [%[b_ptr]]\n"
+ "fmla v16.4s, b0a.4s, a0.s[3]\n"
+ "fmla v20.4s, b0a.4s, a1.s[3]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v24.4s, b0a.4s, a2.s[3]\n"
+ "fmla v28.4s, b0a.4s, a3.s[3]\n"
+ "ldr b0aq, [%[b_ptr]]\n"
- "fmla v17.4s, b1a.4s, a0.s[3]\n"
- "fmla v21.4s, b1a.4s, a1.s[3]\n"
- "fmla v25.4s, b1a.4s, a2.s[3]\n"
- "fmla v29.4s, b1a.4s, a3.s[3]\n"
- "ldr b1aq, [%[b_ptr], #16]\n"
+ "fmla v17.4s, b1a.4s, a0.s[3]\n"
+ "fmla v21.4s, b1a.4s, a1.s[3]\n"
+ "fmla v25.4s, b1a.4s, a2.s[3]\n"
+ "fmla v29.4s, b1a.4s, a3.s[3]\n"
+ "ldr b1aq, [%[b_ptr], #16]\n"
- "fmla v18.4s, b2a.4s, a0.s[3]\n"
- "fmla v22.4s, b2a.4s, a1.s[3]\n"
- "fmla v26.4s, b2a.4s, a2.s[3]\n"
- "fmla v30.4s, b2a.4s, a3.s[3]\n"
- "ldr b2aq, [%[b_ptr], #32]\n"
+ "fmla v18.4s, b2a.4s, a0.s[3]\n"
+ "fmla v22.4s, b2a.4s, a1.s[3]\n"
+ "fmla v26.4s, b2a.4s, a2.s[3]\n"
+ "fmla v30.4s, b2a.4s, a3.s[3]\n"
+ "ldr b2aq, [%[b_ptr], #32]\n"
- "fmla v19.4s, b3a.4s, a0.s[3]\n"
- "fmla v23.4s, b3a.4s, a1.s[3]\n"
- "fmla v27.4s, b3a.4s, a2.s[3]\n"
- "fmla v31.4s, b3a.4s, a3.s[3]\n"
- "ldr b3aq, [%[b_ptr], #48]\n"
+ "fmla v19.4s, b3a.4s, a0.s[3]\n"
+ "fmla v23.4s, b3a.4s, a1.s[3]\n"
+ "fmla v27.4s, b3a.4s, a2.s[3]\n"
+ "fmla v31.4s, b3a.4s, a3.s[3]\n"
+ "ldr b3aq, [%[b_ptr], #48]\n"
// Unroll 4
- "fmla v16.4s, bb0.4s, a0a.s[0]\n"
- "fmla v20.4s, bb0.4s, a1a.s[0]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v24.4s, bb0.4s, a2a.s[0]\n"
- "fmla v28.4s, bb0.4s, a3a.s[0]\n"
- "ldr b0q, [%[b_ptr]]\n"
+ "fmla v16.4s, bb0.4s, a0a.s[0]\n"
+ "fmla v20.4s, bb0.4s, a1a.s[0]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v24.4s, bb0.4s, a2a.s[0]\n"
+ "fmla v28.4s, bb0.4s, a3a.s[0]\n"
+ "ldr b0q, [%[b_ptr]]\n"
- "fmla v17.4s, bb1.4s, a0a.s[0]\n"
- "fmla v21.4s, bb1.4s, a1a.s[0]\n"
- "fmla v25.4s, bb1.4s, a2a.s[0]\n"
- "fmla v29.4s, bb1.4s, a3a.s[0]\n"
- "ldr b1q, [%[b_ptr], #16]\n"
+ "fmla v17.4s, bb1.4s, a0a.s[0]\n"
+ "fmla v21.4s, bb1.4s, a1a.s[0]\n"
+ "fmla v25.4s, bb1.4s, a2a.s[0]\n"
+ "fmla v29.4s, bb1.4s, a3a.s[0]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
- "fmla v18.4s, bb2.4s, a0a.s[0]\n"
- "fmla v22.4s, bb2.4s, a1a.s[0]\n"
- "fmla v26.4s, bb2.4s, a2a.s[0]\n"
- "fmla v30.4s, bb2.4s, a3a.s[0]\n"
- "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v18.4s, bb2.4s, a0a.s[0]\n"
+ "fmla v22.4s, bb2.4s, a1a.s[0]\n"
+ "fmla v26.4s, bb2.4s, a2a.s[0]\n"
+ "fmla v30.4s, bb2.4s, a3a.s[0]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
- "fmla v19.4s, bb3.4s, a0a.s[0]\n"
- "fmla v23.4s, bb3.4s, a1a.s[0]\n"
- "fmla v27.4s, bb3.4s, a2a.s[0]\n"
- "fmla v31.4s, bb3.4s, a3a.s[0]\n"
- "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v19.4s, bb3.4s, a0a.s[0]\n"
+ "fmla v23.4s, bb3.4s, a1a.s[0]\n"
+ "fmla v27.4s, bb3.4s, a2a.s[0]\n"
+ "fmla v31.4s, bb3.4s, a3a.s[0]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
// Unroll 5
- "fmla v16.4s, b0a.4s, a0a.s[1]\n"
- "fmla v20.4s, b0a.4s, a1a.s[1]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v24.4s, b0a.4s, a2a.s[1]\n"
- "fmla v28.4s, b0a.4s, a3a.s[1]\n"
- "ldr b0aq, [%[b_ptr]]\n"
+ "fmla v16.4s, b0a.4s, a0a.s[1]\n"
+ "fmla v20.4s, b0a.4s, a1a.s[1]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v24.4s, b0a.4s, a2a.s[1]\n"
+ "fmla v28.4s, b0a.4s, a3a.s[1]\n"
+ "ldr b0aq, [%[b_ptr]]\n"
- "fmla v17.4s, b1a.4s, a0a.s[1]\n"
- "fmla v21.4s, b1a.4s, a1a.s[1]\n"
- "fmla v25.4s, b1a.4s, a2a.s[1]\n"
- "fmla v29.4s, b1a.4s, a3a.s[1]\n"
- "ldr b1aq, [%[b_ptr], #16]\n"
+ "fmla v17.4s, b1a.4s, a0a.s[1]\n"
+ "fmla v21.4s, b1a.4s, a1a.s[1]\n"
+ "fmla v25.4s, b1a.4s, a2a.s[1]\n"
+ "fmla v29.4s, b1a.4s, a3a.s[1]\n"
+ "ldr b1aq, [%[b_ptr], #16]\n"
- "fmla v18.4s, b2a.4s, a0a.s[1]\n"
- "fmla v22.4s, b2a.4s, a1a.s[1]\n"
- "fmla v26.4s, b2a.4s, a2a.s[1]\n"
- "fmla v30.4s, b2a.4s, a3a.s[1]\n"
- "ldr b2aq, [%[b_ptr], #32]\n"
+ "fmla v18.4s, b2a.4s, a0a.s[1]\n"
+ "fmla v22.4s, b2a.4s, a1a.s[1]\n"
+ "fmla v26.4s, b2a.4s, a2a.s[1]\n"
+ "fmla v30.4s, b2a.4s, a3a.s[1]\n"
+ "ldr b2aq, [%[b_ptr], #32]\n"
- "fmla v19.4s, b3a.4s, a0a.s[1]\n"
- "fmla v23.4s, b3a.4s, a1a.s[1]\n"
- "fmla v27.4s, b3a.4s, a2a.s[1]\n"
- "fmla v31.4s, b3a.4s, a3a.s[1]\n"
- "ldr b3aq, [%[b_ptr], #48]\n"
+ "fmla v19.4s, b3a.4s, a0a.s[1]\n"
+ "fmla v23.4s, b3a.4s, a1a.s[1]\n"
+ "fmla v27.4s, b3a.4s, a2a.s[1]\n"
+ "fmla v31.4s, b3a.4s, a3a.s[1]\n"
+ "ldr b3aq, [%[b_ptr], #48]\n"
// Unroll 6
- "fmla v16.4s, bb0.4s, a0a.s[2]\n"
- "fmla v20.4s, bb0.4s, a1a.s[2]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v24.4s, bb0.4s, a2a.s[2]\n"
- "fmla v28.4s, bb0.4s, a3a.s[2]\n"
+ "fmla v16.4s, bb0.4s, a0a.s[2]\n"
+ "fmla v20.4s, bb0.4s, a1a.s[2]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v24.4s, bb0.4s, a2a.s[2]\n"
+ "fmla v28.4s, bb0.4s, a3a.s[2]\n"
- "fmla v17.4s, bb1.4s, a0a.s[2]\n"
- "fmla v21.4s, bb1.4s, a1a.s[2]\n"
- "fmla v25.4s, bb1.4s, a2a.s[2]\n"
- "fmla v29.4s, bb1.4s, a3a.s[2]\n"
+ "fmla v17.4s, bb1.4s, a0a.s[2]\n"
+ "fmla v21.4s, bb1.4s, a1a.s[2]\n"
+ "fmla v25.4s, bb1.4s, a2a.s[2]\n"
+ "fmla v29.4s, bb1.4s, a3a.s[2]\n"
- "fmla v18.4s, bb2.4s, a0a.s[2]\n"
- "fmla v22.4s, bb2.4s, a1a.s[2]\n"
- "fmla v26.4s, bb2.4s, a2a.s[2]\n"
- "fmla v30.4s, bb2.4s, a3a.s[2]\n"
+ "fmla v18.4s, bb2.4s, a0a.s[2]\n"
+ "fmla v22.4s, bb2.4s, a1a.s[2]\n"
+ "fmla v26.4s, bb2.4s, a2a.s[2]\n"
+ "fmla v30.4s, bb2.4s, a3a.s[2]\n"
- "fmla v19.4s, bb3.4s, a0a.s[2]\n"
- "fmla v23.4s, bb3.4s, a1a.s[2]\n"
- "fmla v27.4s, bb3.4s, a2a.s[2]\n"
- "fmla v31.4s, bb3.4s, a3a.s[2]\n"
+ "fmla v19.4s, bb3.4s, a0a.s[2]\n"
+ "fmla v23.4s, bb3.4s, a1a.s[2]\n"
+ "fmla v27.4s, bb3.4s, a2a.s[2]\n"
+ "fmla v31.4s, bb3.4s, a3a.s[2]\n"
// Unroll 7
- "fmla v16.4s, b0a.4s, a0a.s[3]\n"
- "fmla v17.4s, b1a.4s, a0a.s[3]\n"
- "fmla v18.4s, b2a.4s, a0a.s[3]\n"
- "fmla v19.4s, b3a.4s, a0a.s[3]\n"
- "cbnz %w[odds], 6f\n"
+ "fmla v16.4s, b0a.4s, a0a.s[3]\n"
+ "fmla v17.4s, b1a.4s, a0a.s[3]\n"
+ "fmla v18.4s, b2a.4s, a0a.s[3]\n"
+ "fmla v19.4s, b3a.4s, a0a.s[3]\n"
+ "cbnz %w[odds], 6f\n"
- "fmla v20.4s, b0a.4s, a1a.s[3]\n"
- "str q16, [%[c_ptr0]]\n"
- "fmla v21.4s, b1a.4s, a1a.s[3]\n"
- "str q17, [%[c_ptr0], #16]\n"
- "fmla v22.4s, b2a.4s, a1a.s[3]\n"
- "str q18, [%[c_ptr0], #32]\n"
- "fmla v23.4s, b3a.4s, a1a.s[3]\n"
- "str q19, [%[c_ptr0], #48]\n"
+ "fmla v20.4s, b0a.4s, a1a.s[3]\n"
+ "str q16, [%[c_ptr0]]\n"
+ "fmla v21.4s, b1a.4s, a1a.s[3]\n"
+ "str q17, [%[c_ptr0], #16]\n"
+ "fmla v22.4s, b2a.4s, a1a.s[3]\n"
+ "str q18, [%[c_ptr0], #32]\n"
+ "fmla v23.4s, b3a.4s, a1a.s[3]\n"
+ "str q19, [%[c_ptr0], #48]\n"
- "fmla v24.4s, b0a.4s, a2a.s[3]\n"
- "str q20, [%[c_ptr1]]\n"
- "fmla v25.4s, b1a.4s, a2a.s[3]\n"
- "str q21, [%[c_ptr1], #16]\n"
- "fmla v26.4s, b2a.4s, a2a.s[3]\n"
- "str q22, [%[c_ptr1], #32]\n"
- "fmla v27.4s, b3a.4s, a2a.s[3]\n"
- "str q23, [%[c_ptr1], #48]\n"
+ "fmla v24.4s, b0a.4s, a2a.s[3]\n"
+ "str q20, [%[c_ptr1]]\n"
+ "fmla v25.4s, b1a.4s, a2a.s[3]\n"
+ "str q21, [%[c_ptr1], #16]\n"
+ "fmla v26.4s, b2a.4s, a2a.s[3]\n"
+ "str q22, [%[c_ptr1], #32]\n"
+ "fmla v27.4s, b3a.4s, a2a.s[3]\n"
+ "str q23, [%[c_ptr1], #48]\n"
- "fmla v28.4s, b0a.4s, a3a.s[3]\n"
- "str q24, [%[c_ptr2]]\n"
- "fmla v29.4s, b1a.4s, a3a.s[3]\n"
- "str q25, [%[c_ptr2], #16]\n"
- "fmla v30.4s, b2a.4s, a3a.s[3]\n"
- "str q26, [%[c_ptr2], #32]\n"
- "fmla v31.4s, b3a.4s, a3a.s[3]\n"
- "str q27, [%[c_ptr2], #48]\n"
- "b 3f\n"
+ "fmla v28.4s, b0a.4s, a3a.s[3]\n"
+ "str q24, [%[c_ptr2]]\n"
+ "fmla v29.4s, b1a.4s, a3a.s[3]\n"
+ "str q25, [%[c_ptr2], #16]\n"
+ "fmla v30.4s, b2a.4s, a3a.s[3]\n"
+ "str q26, [%[c_ptr2], #32]\n"
+ "fmla v31.4s, b3a.4s, a3a.s[3]\n"
+ "str q27, [%[c_ptr2], #48]\n"
+ "b 3f\n"
// Odd K case: Just do 4 more.
"2:\n"
- "fmla v21.4s, bb1.4s, a1.s[0]\n"
- "add %[a_ptr0], %[a_ptr0], #16\n"
- "fmla v25.4s, bb1.4s, a2.s[0]\n"
- "add %[a_ptr1], %[a_ptr1], #16\n"
- "fmla v29.4s, bb1.4s, a3.s[0]\n"
- "ldr b1q, [%[b_ptr], #16]\n"
+ "fmla v21.4s, bb1.4s, a1.s[0]\n"
+ "add %[a_ptr0], %[a_ptr0], #16\n"
+ "fmla v25.4s, bb1.4s, a2.s[0]\n"
+ "add %[a_ptr1], %[a_ptr1], #16\n"
+ "fmla v29.4s, bb1.4s, a3.s[0]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
- "fmla v18.4s, bb2.4s, a0.s[0]\n"
- "add %[a_ptr2], %[a_ptr2], #16\n"
- "fmla v22.4s, bb2.4s, a1.s[0]\n"
- "add %[a_ptr3], %[a_ptr3], #16\n"
- "fmla v26.4s, bb2.4s, a2.s[0]\n"
- "fmla v30.4s, bb2.4s, a3.s[0]\n"
- "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v18.4s, bb2.4s, a0.s[0]\n"
+ "add %[a_ptr2], %[a_ptr2], #16\n"
+ "fmla v22.4s, bb2.4s, a1.s[0]\n"
+ "add %[a_ptr3], %[a_ptr3], #16\n"
+ "fmla v26.4s, bb2.4s, a2.s[0]\n"
+ "fmla v30.4s, bb2.4s, a3.s[0]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
- "fmla v19.4s, bb3.4s, a0.s[0]\n"
- "fmla v23.4s, bb3.4s, a1.s[0]\n"
- "fmla v27.4s, bb3.4s, a2.s[0]\n"
- "fmla v31.4s, bb3.4s, a3.s[0]\n"
- "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v19.4s, bb3.4s, a0.s[0]\n"
+ "fmla v23.4s, bb3.4s, a1.s[0]\n"
+ "fmla v27.4s, bb3.4s, a2.s[0]\n"
+ "fmla v31.4s, bb3.4s, a3.s[0]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
// Unroll 1
- "fmla v16.4s, b0a.4s, a0.s[1]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v20.4s, b0a.4s, a1.s[1]\n"
- "fmla v24.4s, b0a.4s, a2.s[1]\n"
- "fmla v28.4s, b0a.4s, a3.s[1]\n"
- "ldr b0aq, [%[b_ptr]]\n"
+ "fmla v16.4s, b0a.4s, a0.s[1]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v20.4s, b0a.4s, a1.s[1]\n"
+ "fmla v24.4s, b0a.4s, a2.s[1]\n"
+ "fmla v28.4s, b0a.4s, a3.s[1]\n"
+ "ldr b0aq, [%[b_ptr]]\n"
- "fmla v17.4s, b1a.4s, a0.s[1]\n"
- "fmla v21.4s, b1a.4s, a1.s[1]\n"
- "fmla v25.4s, b1a.4s, a2.s[1]\n"
- "fmla v29.4s, b1a.4s, a3.s[1]\n"
- "ldr b1aq, [%[b_ptr], #16]\n"
+ "fmla v17.4s, b1a.4s, a0.s[1]\n"
+ "fmla v21.4s, b1a.4s, a1.s[1]\n"
+ "fmla v25.4s, b1a.4s, a2.s[1]\n"
+ "fmla v29.4s, b1a.4s, a3.s[1]\n"
+ "ldr b1aq, [%[b_ptr], #16]\n"
- "fmla v18.4s, b2a.4s, a0.s[1]\n"
- "fmla v22.4s, b2a.4s, a1.s[1]\n"
- "fmla v26.4s, b2a.4s, a2.s[1]\n"
- "fmla v30.4s, b2a.4s, a3.s[1]\n"
- "ldr b2aq, [%[b_ptr], #32]\n"
+ "fmla v18.4s, b2a.4s, a0.s[1]\n"
+ "fmla v22.4s, b2a.4s, a1.s[1]\n"
+ "fmla v26.4s, b2a.4s, a2.s[1]\n"
+ "fmla v30.4s, b2a.4s, a3.s[1]\n"
+ "ldr b2aq, [%[b_ptr], #32]\n"
- "fmla v19.4s, b3a.4s, a0.s[1]\n"
- "fmla v23.4s, b3a.4s, a1.s[1]\n"
- "fmla v27.4s, b3a.4s, a2.s[1]\n"
- "fmla v31.4s, b3a.4s, a3.s[1]\n"
- "ldr b3aq, [%[b_ptr], #48]\n"
+ "fmla v19.4s, b3a.4s, a0.s[1]\n"
+ "fmla v23.4s, b3a.4s, a1.s[1]\n"
+ "fmla v27.4s, b3a.4s, a2.s[1]\n"
+ "fmla v31.4s, b3a.4s, a3.s[1]\n"
+ "ldr b3aq, [%[b_ptr], #48]\n"
// Unroll 2
- "fmla v16.4s, bb0.4s, a0.s[2]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v20.4s, bb0.4s, a1.s[2]\n"
- "fmla v24.4s, bb0.4s, a2.s[2]\n"
- "fmla v28.4s, bb0.4s, a3.s[2]\n"
+ "fmla v16.4s, bb0.4s, a0.s[2]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v20.4s, bb0.4s, a1.s[2]\n"
+ "fmla v24.4s, bb0.4s, a2.s[2]\n"
+ "fmla v28.4s, bb0.4s, a3.s[2]\n"
- "fmla v17.4s, bb1.4s, a0.s[2]\n"
- "fmla v21.4s, bb1.4s, a1.s[2]\n"
- "fmla v25.4s, bb1.4s, a2.s[2]\n"
- "fmla v29.4s, bb1.4s, a3.s[2]\n"
+ "fmla v17.4s, bb1.4s, a0.s[2]\n"
+ "fmla v21.4s, bb1.4s, a1.s[2]\n"
+ "fmla v25.4s, bb1.4s, a2.s[2]\n"
+ "fmla v29.4s, bb1.4s, a3.s[2]\n"
- "fmla v18.4s, bb2.4s, a0.s[2]\n"
- "fmla v22.4s, bb2.4s, a1.s[2]\n"
- "fmla v26.4s, bb2.4s, a2.s[2]\n"
- "fmla v30.4s, bb2.4s, a3.s[2]\n"
+ "fmla v18.4s, bb2.4s, a0.s[2]\n"
+ "fmla v22.4s, bb2.4s, a1.s[2]\n"
+ "fmla v26.4s, bb2.4s, a2.s[2]\n"
+ "fmla v30.4s, bb2.4s, a3.s[2]\n"
- "fmla v19.4s, bb3.4s, a0.s[2]\n"
- "fmla v23.4s, bb3.4s, a1.s[2]\n"
- "fmla v27.4s, bb3.4s, a2.s[2]\n"
- "fmla v31.4s, bb3.4s, a3.s[2]\n"
+ "fmla v19.4s, bb3.4s, a0.s[2]\n"
+ "fmla v23.4s, bb3.4s, a1.s[2]\n"
+ "fmla v27.4s, bb3.4s, a2.s[2]\n"
+ "fmla v31.4s, bb3.4s, a3.s[2]\n"
// Unroll 3
- "fmla v16.4s, b0a.4s, a0.s[3]\n"
- "fmla v17.4s, b1a.4s, a0.s[3]\n"
- "fmla v18.4s, b2a.4s, a0.s[3]\n"
- "fmla v19.4s, b3a.4s, a0.s[3]\n"
- "cbnz %w[odds], 7f\n"
+ "fmla v16.4s, b0a.4s, a0.s[3]\n"
+ "fmla v17.4s, b1a.4s, a0.s[3]\n"
+ "fmla v18.4s, b2a.4s, a0.s[3]\n"
+ "fmla v19.4s, b3a.4s, a0.s[3]\n"
+ "cbnz %w[odds], 7f\n"
- "fmla v20.4s, b0a.4s, a1.s[3]\n"
- "str q16, [%[c_ptr0]]\n"
- "fmla v21.4s, b1a.4s, a1.s[3]\n"
- "str q17, [%[c_ptr0], #16]\n"
- "fmla v22.4s, b2a.4s, a1.s[3]\n"
- "str q18, [%[c_ptr0], #32]\n"
- "fmla v23.4s, b3a.4s, a1.s[3]\n"
- "str q19, [%[c_ptr0], #48]\n"
+ "fmla v20.4s, b0a.4s, a1.s[3]\n"
+ "str q16, [%[c_ptr0]]\n"
+ "fmla v21.4s, b1a.4s, a1.s[3]\n"
+ "str q17, [%[c_ptr0], #16]\n"
+ "fmla v22.4s, b2a.4s, a1.s[3]\n"
+ "str q18, [%[c_ptr0], #32]\n"
+ "fmla v23.4s, b3a.4s, a1.s[3]\n"
+ "str q19, [%[c_ptr0], #48]\n"
- "fmla v24.4s, b0a.4s, a2.s[3]\n"
- "str q20, [%[c_ptr1]]\n"
- "fmla v25.4s, b1a.4s, a2.s[3]\n"
- "str q21, [%[c_ptr1], #16]\n"
- "fmla v26.4s, b2a.4s, a2.s[3]\n"
- "str q22, [%[c_ptr1], #32]\n"
- "fmla v27.4s, b3a.4s, a2.s[3]\n"
- "str q23, [%[c_ptr1], #48]\n"
+ "fmla v24.4s, b0a.4s, a2.s[3]\n"
+ "str q20, [%[c_ptr1]]\n"
+ "fmla v25.4s, b1a.4s, a2.s[3]\n"
+ "str q21, [%[c_ptr1], #16]\n"
+ "fmla v26.4s, b2a.4s, a2.s[3]\n"
+ "str q22, [%[c_ptr1], #32]\n"
+ "fmla v27.4s, b3a.4s, a2.s[3]\n"
+ "str q23, [%[c_ptr1], #48]\n"
- "fmla v28.4s, b0a.4s, a3.s[3]\n"
- "str q24, [%[c_ptr2]]\n"
- "fmla v29.4s, b1a.4s, a3.s[3]\n"
- "str q25, [%[c_ptr2], #16]\n"
- "fmla v30.4s, b2a.4s, a3.s[3]\n"
- "str q26, [%[c_ptr2], #32]\n"
- "fmla v31.4s, b3a.4s, a3.s[3]\n"
- "str q27, [%[c_ptr2], #48]\n"
- "b 3f\n"
+ "fmla v28.4s, b0a.4s, a3.s[3]\n"
+ "str q24, [%[c_ptr2]]\n"
+ "fmla v29.4s, b1a.4s, a3.s[3]\n"
+ "str q25, [%[c_ptr2], #16]\n"
+ "fmla v30.4s, b2a.4s, a3.s[3]\n"
+ "str q26, [%[c_ptr2], #32]\n"
+ "fmla v31.4s, b3a.4s, a3.s[3]\n"
+ "str q27, [%[c_ptr2], #48]\n"
+ "b 3f\n"
// "Odd ones" - lead in from even
"6:\n"
- "fmla v20.4s, b0a.4s, a1a.s[3]\n"
- "fmla v21.4s, b1a.4s, a1a.s[3]\n"
- "ldr b0q, [%[b_ptr]]\n"
- "fmla v22.4s, b2a.4s, a1a.s[3]\n"
- "subs %w[odds], %w[odds], #1\n"
- "fmla v23.4s, b3a.4s, a1a.s[3]\n"
- "ldr b1q, [%[b_ptr], #16]\n"
+ "fmla v20.4s, b0a.4s, a1a.s[3]\n"
+ "fmla v21.4s, b1a.4s, a1a.s[3]\n"
+ "ldr b0q, [%[b_ptr]]\n"
+ "fmla v22.4s, b2a.4s, a1a.s[3]\n"
+ "subs %w[odds], %w[odds], #1\n"
+ "fmla v23.4s, b3a.4s, a1a.s[3]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
- "fmla v24.4s, b0a.4s, a2a.s[3]\n"
- "fmla v25.4s, b1a.4s, a2a.s[3]\n"
- "ldr b2q, [%[b_ptr], #32]\n"
- "fmla v26.4s, b2a.4s, a2a.s[3]\n"
- "fmla v27.4s, b3a.4s, a2a.s[3]\n"
- "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v24.4s, b0a.4s, a2a.s[3]\n"
+ "fmla v25.4s, b1a.4s, a2a.s[3]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v26.4s, b2a.4s, a2a.s[3]\n"
+ "fmla v27.4s, b3a.4s, a2a.s[3]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
- "fmla v28.4s, b0a.4s, a3a.s[3]\n"
- "ld1r {a0.4s}, [%[a_ptr0]], #4\n"
- "fmla v29.4s, b1a.4s, a3a.s[3]\n"
- "fmla v30.4s, b2a.4s, a3a.s[3]\n"
- "ld1r {a1.4s}, [%[a_ptr1]], #4\n"
- "fmla v31.4s, b3a.4s, a3a.s[3]\n"
+ "fmla v28.4s, b0a.4s, a3a.s[3]\n"
+ "ld1r {a0.4s}, [%[a_ptr0]], #4\n"
+ "fmla v29.4s, b1a.4s, a3a.s[3]\n"
+ "fmla v30.4s, b2a.4s, a3a.s[3]\n"
+ "ld1r {a1.4s}, [%[a_ptr1]], #4\n"
+ "fmla v31.4s, b3a.4s, a3a.s[3]\n"
- "fmla v16.4s, bb0.4s, a0.4s\n"
- "beq 9f\n"
- "b 8f\n"
+ "fmla v16.4s, bb0.4s, a0.4s\n"
+ "beq 9f\n"
+ "b 8f\n"
// "Odd ones" - lead in from odd
"7:\n"
- "fmla v20.4s, b0a.4s, a1.s[3]\n"
- "subs %w[odds], %w[odds], #1\n"
- "fmla v21.4s, b1a.4s, a1.s[3]\n"
- "ldr b0q, [%[b_ptr]]\n"
- "fmla v22.4s, b2a.4s, a1.s[3]\n"
- "fmla v23.4s, b3a.4s, a1.s[3]\n"
- "ldr b1q, [%[b_ptr], #16]\n"
+ "fmla v20.4s, b0a.4s, a1.s[3]\n"
+ "subs %w[odds], %w[odds], #1\n"
+ "fmla v21.4s, b1a.4s, a1.s[3]\n"
+ "ldr b0q, [%[b_ptr]]\n"
+ "fmla v22.4s, b2a.4s, a1.s[3]\n"
+ "fmla v23.4s, b3a.4s, a1.s[3]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
- "fmla v24.4s, b0a.4s, a2.s[3]\n"
- "fmla v25.4s, b1a.4s, a2.s[3]\n"
- "ldr b2q, [%[b_ptr], #32]\n"
- "fmla v26.4s, b2a.4s, a2.s[3]\n"
- "fmla v27.4s, b3a.4s, a2.s[3]\n"
- "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v24.4s, b0a.4s, a2.s[3]\n"
+ "fmla v25.4s, b1a.4s, a2.s[3]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v26.4s, b2a.4s, a2.s[3]\n"
+ "fmla v27.4s, b3a.4s, a2.s[3]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
- "fmla v28.4s, b0a.4s, a3.s[3]\n"
- "ld1r {a0.4s}, [%[a_ptr0]], #4\n"
- "fmla v29.4s, b1a.4s, a3.s[3]\n"
- "fmla v30.4s, b2a.4s, a3.s[3]\n"
- "ld1r {a1.4s}, [%[a_ptr1]], #4\n"
- "fmla v31.4s, b3a.4s, a3.s[3]\n"
+ "fmla v28.4s, b0a.4s, a3.s[3]\n"
+ "ld1r {a0.4s}, [%[a_ptr0]], #4\n"
+ "fmla v29.4s, b1a.4s, a3.s[3]\n"
+ "fmla v30.4s, b2a.4s, a3.s[3]\n"
+ "ld1r {a1.4s}, [%[a_ptr1]], #4\n"
+ "fmla v31.4s, b3a.4s, a3.s[3]\n"
- "fmla v16.4s, bb0.4s, a0.4s\n"
- "beq 9f\n"
+ "fmla v16.4s, bb0.4s, a0.4s\n"
+ "beq 9f\n"
// "Odd ones" - loop
"8:\n"
- "fmla v17.4s, bb1.4s, a0.4s\n"
- "ld1r {a2.4s}, [%[a_ptr2]], #4\n"
- "fmla v18.4s, bb2.4s, a0.4s\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v19.4s, bb3.4s, a0.4s\n"
- "ld1r {a3.4s}, [%[a_ptr3]], #4\n"
+ "fmla v17.4s, bb1.4s, a0.4s\n"
+ "ld1r {a2.4s}, [%[a_ptr2]], #4\n"
+ "fmla v18.4s, bb2.4s, a0.4s\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v19.4s, bb3.4s, a0.4s\n"
+ "ld1r {a3.4s}, [%[a_ptr3]], #4\n"
- "fmla v20.4s, bb0.4s, a1.4s\n"
- "subs %w[odds], %w[odds], #1\n"
- "fmla v21.4s, bb1.4s, a1.4s\n"
- "ld1r {a0.4s}, [%[a_ptr0]], #4\n"
- "fmla v22.4s, bb2.4s, a1.4s\n"
- "fmla v23.4s, bb3.4s, a1.4s\n"
- "ld1r {a1.4s}, [%[a_ptr1]], #4\n"
+ "fmla v20.4s, bb0.4s, a1.4s\n"
+ "subs %w[odds], %w[odds], #1\n"
+ "fmla v21.4s, bb1.4s, a1.4s\n"
+ "ld1r {a0.4s}, [%[a_ptr0]], #4\n"
+ "fmla v22.4s, bb2.4s, a1.4s\n"
+ "fmla v23.4s, bb3.4s, a1.4s\n"
+ "ld1r {a1.4s}, [%[a_ptr1]], #4\n"
- "fmla v24.4s, bb0.4s, a2.4s\n"
- "fmla v28.4s, bb0.4s, a3.4s\n"
- "ldr b0q, [%[b_ptr]]\n"
- "fmla v25.4s, bb1.4s, a2.4s\n"
- "fmla v29.4s, bb1.4s, a3.4s\n"
- "ldr b1q, [%[b_ptr], #16]\n"
+ "fmla v24.4s, bb0.4s, a2.4s\n"
+ "fmla v28.4s, bb0.4s, a3.4s\n"
+ "ldr b0q, [%[b_ptr]]\n"
+ "fmla v25.4s, bb1.4s, a2.4s\n"
+ "fmla v29.4s, bb1.4s, a3.4s\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
- "fmla v26.4s, bb2.4s, a2.4s\n"
- "fmla v30.4s, bb2.4s, a3.4s\n"
- "ldr b2q, [%[b_ptr], #32]\n"
- "fmla v27.4s, bb3.4s, a2.4s\n"
- "fmla v31.4s, bb3.4s, a3.4s\n"
- "ldr b3q, [%[b_ptr], #48]\n"
- "fmla v16.4s, bb0.4s, a0.4s\n"
- "bne 8b\n"
+ "fmla v26.4s, bb2.4s, a2.4s\n"
+ "fmla v30.4s, bb2.4s, a3.4s\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v27.4s, bb3.4s, a2.4s\n"
+ "fmla v31.4s, bb3.4s, a3.4s\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v16.4s, bb0.4s, a0.4s\n"
+ "bne 8b\n"
// "Odd ones" - detached final iteration
"9:\n"
- "fmla v17.4s, bb1.4s, a0.4s\n"
- "ld1r {a2.4s}, [%[a_ptr2]], #4\n"
- "fmla v18.4s, bb2.4s, a0.4s\n"
- "fmla v19.4s, bb3.4s, a0.4s\n"
- "ld1r {a3.4s}, [%[a_ptr3]], #4\n"
+ "fmla v17.4s, bb1.4s, a0.4s\n"
+ "ld1r {a2.4s}, [%[a_ptr2]], #4\n"
+ "fmla v18.4s, bb2.4s, a0.4s\n"
+ "fmla v19.4s, bb3.4s, a0.4s\n"
+ "ld1r {a3.4s}, [%[a_ptr3]], #4\n"
- "fmla v20.4s, bb0.4s, a1.4s\n"
- "str q16, [%[c_ptr0]]\n"
- "fmla v21.4s, bb1.4s, a1.4s\n"
- "str q17, [%[c_ptr0], #16]\n"
- "fmla v22.4s, bb2.4s, a1.4s\n"
- "str q18, [%[c_ptr0], #32]\n"
- "fmla v23.4s, bb3.4s, a1.4s\n"
- "str q19, [%[c_ptr0], #48]\n"
+ "fmla v20.4s, bb0.4s, a1.4s\n"
+ "str q16, [%[c_ptr0]]\n"
+ "fmla v21.4s, bb1.4s, a1.4s\n"
+ "str q17, [%[c_ptr0], #16]\n"
+ "fmla v22.4s, bb2.4s, a1.4s\n"
+ "str q18, [%[c_ptr0], #32]\n"
+ "fmla v23.4s, bb3.4s, a1.4s\n"
+ "str q19, [%[c_ptr0], #48]\n"
- "fmla v24.4s, bb0.4s, a2.4s\n"
- "str q20, [%[c_ptr1]]\n"
- "fmla v25.4s, bb1.4s, a2.4s\n"
- "str q21, [%[c_ptr1], #16]\n"
- "fmla v26.4s, bb2.4s, a2.4s\n"
- "str q22, [%[c_ptr1], #32]\n"
- "fmla v27.4s, bb3.4s, a2.4s\n"
- "str q23, [%[c_ptr1], #48]\n"
+ "fmla v24.4s, bb0.4s, a2.4s\n"
+ "str q20, [%[c_ptr1]]\n"
+ "fmla v25.4s, bb1.4s, a2.4s\n"
+ "str q21, [%[c_ptr1], #16]\n"
+ "fmla v26.4s, bb2.4s, a2.4s\n"
+ "str q22, [%[c_ptr1], #32]\n"
+ "fmla v27.4s, bb3.4s, a2.4s\n"
+ "str q23, [%[c_ptr1], #48]\n"
- "fmla v28.4s, bb0.4s, a3.4s\n"
- "str q24, [%[c_ptr2]]\n"
- "fmla v29.4s, bb1.4s, a3.4s\n"
- "str q25, [%[c_ptr2], #16]\n"
- "fmla v30.4s, bb2.4s, a3.4s\n"
- "str q26, [%[c_ptr2], #32]\n"
- "fmla v31.4s, bb3.4s, a3.4s\n"
- "str q27, [%[c_ptr2], #48]\n"
+ "fmla v28.4s, bb0.4s, a3.4s\n"
+ "str q24, [%[c_ptr2]]\n"
+ "fmla v29.4s, bb1.4s, a3.4s\n"
+ "str q25, [%[c_ptr2], #16]\n"
+ "fmla v30.4s, bb2.4s, a3.4s\n"
+ "str q26, [%[c_ptr2], #32]\n"
+ "fmla v31.4s, bb3.4s, a3.4s\n"
+ "str q27, [%[c_ptr2], #48]\n"
"3:\n"
- "str q28, [%[c_ptr3]]\n"
- "str q29, [%[c_ptr3], #16]\n"
- "str q30, [%[c_ptr3], #32]\n"
- "str q31, [%[c_ptr3], #48]\n"
+ "str q28, [%[c_ptr3]]\n"
+ "add %[c_ptr0], %[c_ptr0], #64\n"
+ "str q29, [%[c_ptr3], #16]\n"
+ "add %[c_ptr1], %[c_ptr1], %[a_incr1], LSL #1\n"
+ "str q30, [%[c_ptr3], #32]\n"
+ "add %[c_ptr2], %[c_ptr2], %[a_incr2], LSL #1\n"
+ "str q31, [%[c_ptr3], #48]\n"
+ "add %[c_ptr3], %[c_ptr3], %[a_incr3], LSL #1\n"
- : [a_ptr0] "+r"(a_ptr0), [a_ptr1] "+r"(a_ptr1), [a_ptr2] "+r"(a_ptr2), [a_ptr3] "+r"(a_ptr3),
- [b_ptr] "+r"(b_ptr), [loops] "+r"(loops), [odds] "+r"(odds)
- : [ldb] "r"(ldbb), [oddk] "r"(oddk), [beta0] "r"(beta0), [betaptr] "r"(&beta),
- [c_ptr0] "r"(c_ptr0), [c_ptr1] "r"(c_ptr1), [c_ptr2] "r"(c_ptr2), [c_ptr3] "r"(c_ptr3)
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
- "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
- "cc", "memory");
+ : [a_ptr0] "+r" (a_ptr0), [a_ptr1] "+r" (a_ptr1), [a_ptr2] "+r" (a_ptr2), [a_ptr3] "+r" (a_ptr3),
+ [b_ptr] "+r" (b_ptr), [loops] "+r" (loops), [odds] "+r" (odds),
+ [c_ptr0] "+r" (c_ptr0), [c_ptr1] "+r" (c_ptr1), [c_ptr2] "+r" (c_ptr2), [c_ptr3] "+r" (c_ptr3)
+ : [ldb] "r" (ldbb), [oddk] "r" (oddk), [beta0] "r" (beta0), [betaptr] "r" (&beta),
+ [a_incr1] "r" (a_incr1), [a_incr2] "r" (a_incr2), [a_incr3] "r" (a_incr3)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
+ "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
+ "cc", "memory"
+ );
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed.hpp
index c89514f..a73bc76 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,14 +25,13 @@
#ifdef __aarch64__
-namespace arm_gemm
-{
+namespace arm_gemm {
+
// Actual kernel implementations
void a64_sgemv_pretransposed(const float *, int, const float *, float *, float, int, int);
// Pretransposed SGEMV strategy class.
-class sgemv_pretransposed
-{
+class sgemv_pretransposed {
public:
typedef float operand_type;
typedef float result_type;
@@ -47,19 +46,17 @@
* terms of this standard arrangement, so if the A matrix is in fact the
* B matrix from a GEMM call, the sense of the transpose needs to be
* reversed. */
- static const int A_interleave = 32;
- static const int A_block = 1;
- static const bool A_transpose = false;
+ static const int A_interleave = 32;
+ static const int A_block = 1;
+ static const bool A_transpose = false;
/* Kernel blocking parameters */
static const int out_width = 32;
- static const int k_unroll = 1;
+ static const int k_unroll = 1;
kern_type kernel = a64_sgemv_pretransposed;
- sgemv_pretransposed(const CPUInfo *ci)
- {
- }
+ sgemv_pretransposed(const CPUInfo *ci) { }
};
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed/generic.cpp
index 2907598..165e0a6 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed/generic.cpp
@@ -30,15 +30,13 @@
#include "../../asmlib.hpp"
#include "../../utils.hpp"
-namespace arm_gemm
-{
-void a64_sgemv_pretransposed(const float *A, int lda, const float *X, float *Y, float beta, int M, int N)
-{
- const bool beta0 = (beta == 0.0f);
- const bool beta1 = (beta == 1.0f);
+namespace arm_gemm {
- for(int x = 0; x < N; x += 32)
- {
+void a64_sgemv_pretransposed(const float *A, int lda, const float *X, float *Y, float beta, int M, int N) {
+ const bool beta0 = (beta==0.0f);
+ const bool beta1 = (beta==1.0f);
+
+ for (int x=0; x<N; x+=32) {
float *y_ptr = Y + x;
// How many elements are we processing in this loop?
@@ -53,20 +51,16 @@
register float32x4_t r6 asm("v30");
register float32x4_t r7 asm("v31");
- register float32x4_t x0 asm("v0");
+ register float32x4_t x0 asm("v0");
register float32x4_t x0a asm("v1");
const float *x_ptr = X;
- const float *a_ptr = A + ((x / 32) * lda);
+ const float *a_ptr = A + ((x/32) * lda);
- if(beta0)
- {
- r0 = r1 = r2 = r3 = r4 = r5 = r6 = r7 = vdupq_n_f32(0.0f);
- }
- else
- {
- if(l == 32)
- {
+ if (beta0) {
+ r0=r1=r2=r3=r4=r5=r6=r7=vdupq_n_f32(0.0f);
+ } else {
+ if (l==32) {
// Fastest path - load all 8 vectors
r0 = vld1q_f32(y_ptr);
r1 = vld1q_f32(y_ptr + 4);
@@ -76,29 +70,25 @@
r5 = vld1q_f32(y_ptr + 20);
r6 = vld1q_f32(y_ptr + 24);
r7 = vld1q_f32(y_ptr + 28);
- }
- else
- {
+ } else {
// Slow case - leftovers. Note that we don't care about
// out-of-range vectors and lanes as we will throw them away at
// the end.
- int vecs = l / 4; // How many leftover vectors?
- int oddbits = l % 4; // And how many odd single values?
+ int vecs=l/4; // How many leftover vectors?
+ int oddbits=l%4; // And how many odd single values?
- if(oddbits)
- {
+ if (oddbits) {
// Load the outstanding odd values into a vector first
- float32x4_t oddvec = vdupq_n_f32(0.0f); // This does not really need to be initialized, but the compiler has a hard time with that.
- float *oddbase = y_ptr + l - oddbits;
+ float32x4_t oddvec = vdupq_n_f32(0.0f); // This does not really need to be initialized, but the compiler has a hard time with that.
+ float *oddbase = y_ptr + l - oddbits;
- switch(oddbits)
- {
+ switch (oddbits) {
case 3:
oddvec = vld1q_lane_f32(oddbase + 2, oddvec, 2);
- // fall through
+ // fall through
case 2:
oddvec = vld1q_lane_f32(oddbase + 1, oddvec, 1);
- // fall through
+ // fall through
case 1:
oddvec = vld1q_lane_f32(oddbase, oddvec, 0);
break;
@@ -108,116 +98,60 @@
}
// Now load the whole vectors, putting the oddments in when we run out.
- do
- {
- if(vecs == 0)
- {
- r0 = oddvec;
- break;
- }
+ do {
+ if (vecs==0) { r0 = oddvec; break; }
r0 = vld1q_f32(y_ptr);
- if(--vecs == 0)
- {
- r1 = oddvec;
- break;
- }
+ if (--vecs==0) { r1 = oddvec; break; }
r1 = vld1q_f32(y_ptr + 4);
- if(--vecs == 0)
- {
- r2 = oddvec;
- break;
- }
+ if (--vecs==0) { r2 = oddvec; break; }
r2 = vld1q_f32(y_ptr + 8);
- if(--vecs == 0)
- {
- r3 = oddvec;
- break;
- }
+ if (--vecs==0) { r3 = oddvec; break; }
r3 = vld1q_f32(y_ptr + 12);
- if(--vecs == 0)
- {
- r4 = oddvec;
- break;
- }
+ if (--vecs==0) { r4 = oddvec; break; }
r4 = vld1q_f32(y_ptr + 16);
- if(--vecs == 0)
- {
- r5 = oddvec;
- break;
- }
+ if (--vecs==0) { r5 = oddvec; break; }
r5 = vld1q_f32(y_ptr + 20);
- if(--vecs == 0)
- {
- r6 = oddvec;
- break;
- }
+ if (--vecs==0) { r6 = oddvec; break; }
r6 = vld1q_f32(y_ptr + 24);
r7 = oddvec;
- }
- while(0);
- }
- else
- {
+ } while (0);
+ } else {
// Slightly less slow path - just load the whole vectors
- do
- {
+ do {
// It can't be the case that oddbits==0 AND vecs==0 or we wouldn't be here.
- if(vecs == 0)
- {
- UNREACHABLE("Impossible lack of work to do");
- }
+ if (vecs==0) { UNREACHABLE("Impossible lack of work to do"); }
r0 = vld1q_f32(y_ptr);
- if(--vecs == 0)
- {
- break;
- }
+ if (--vecs==0) { break; }
r1 = vld1q_f32(y_ptr + 4);
- if(--vecs == 0)
- {
- break;
- }
+ if (--vecs==0) { break; }
r2 = vld1q_f32(y_ptr + 8);
- if(--vecs == 0)
- {
- break;
- }
+ if (--vecs==0) { break; }
r3 = vld1q_f32(y_ptr + 12);
- if(--vecs == 0)
- {
- break;
- }
+ if (--vecs==0) { break; }
r4 = vld1q_f32(y_ptr + 16);
- if(--vecs == 0)
- {
- break;
- }
+ if (--vecs==0) { break; }
r5 = vld1q_f32(y_ptr + 20);
- if(--vecs == 0)
- {
- break;
- }
+ if (--vecs==0) { break; }
r6 = vld1q_f32(y_ptr + 24);
- }
- while(0);
+ } while (0);
}
}
- if(!beta1)
- {
+ if (!beta1) {
const float32x4_t vb = vdupq_n_f32(beta);
r0 = vmulq_f32(r0, vb);
@@ -231,34 +165,34 @@
}
}
- if(M >= 8)
- {
- int k = (M / 8) - 1;
- x0 = vld1q_f32(x_ptr);
+ if (M>=8) {
+ int k = (M/8)-1;
+ x0 = vld1q_f32(x_ptr);
- __asm __volatile(
- "ldr q2, [%[a_ptr], #0]\n"
- "ldr q3, [%[a_ptr], #16]\n"
- "ldr q4, [%[a_ptr], #32]\n"
- "ldr q5, [%[a_ptr], #48]\n"
- "ldr q6, [%[a_ptr], #64]\n"
- "ldr q7, [%[a_ptr], #80]\n"
- "ldr q8, [%[a_ptr], #96]\n"
- "ldr q9, [%[a_ptr], #112]\n"
- "ldr q10, [%[a_ptr], #128]\n"
- "ldr q11, [%[a_ptr], #144]\n"
- "ldr q12, [%[a_ptr], #160]\n"
- "ldr q13, [%[a_ptr], #176]\n"
- "ldr q14, [%[a_ptr], #192]\n"
- "ldr q15, [%[a_ptr], #208]\n"
- "ldr q16, [%[a_ptr], #224]\n"
- "ldr q17, [%[a_ptr], #240]\n"
- "ldr q18, [%[a_ptr], #256]\n"
- "ldr q19, [%[a_ptr], #272]\n"
- "ldr q20, [%[a_ptr], #288]\n"
- "ldr q21, [%[a_ptr], #304]\n"
- "ldr q22, [%[a_ptr], #320]\n"
- "ldr q23, [%[a_ptr], #336]\n" ASM_PREFETCH("[%[a_ptr], #384]")
+ __asm __volatile (
+ "ldr q2, [%[a_ptr], #0]\n"
+ "ldr q3, [%[a_ptr], #16]\n"
+ "ldr q4, [%[a_ptr], #32]\n"
+ "ldr q5, [%[a_ptr], #48]\n"
+ "ldr q6, [%[a_ptr], #64]\n"
+ "ldr q7, [%[a_ptr], #80]\n"
+ "ldr q8, [%[a_ptr], #96]\n"
+ "ldr q9, [%[a_ptr], #112]\n"
+ "ldr q10, [%[a_ptr], #128]\n"
+ "ldr q11, [%[a_ptr], #144]\n"
+ "ldr q12, [%[a_ptr], #160]\n"
+ "ldr q13, [%[a_ptr], #176]\n"
+ "ldr q14, [%[a_ptr], #192]\n"
+ "ldr q15, [%[a_ptr], #208]\n"
+ "ldr q16, [%[a_ptr], #224]\n"
+ "ldr q17, [%[a_ptr], #240]\n"
+ "ldr q18, [%[a_ptr], #256]\n"
+ "ldr q19, [%[a_ptr], #272]\n"
+ "ldr q20, [%[a_ptr], #288]\n"
+ "ldr q21, [%[a_ptr], #304]\n"
+ "ldr q22, [%[a_ptr], #320]\n"
+ "ldr q23, [%[a_ptr], #336]\n"
+ ASM_PREFETCH("[%[a_ptr], #384]")
ASM_PREFETCH("[%[a_ptr], #448]")
ASM_PREFETCH("[%[a_ptr], #512]")
ASM_PREFETCH("[%[a_ptr], #576]")
@@ -284,363 +218,377 @@
ASM_PREFETCH("[%[a_ptr], #1856]")
ASM_PREFETCH("[%[a_ptr], #1920]")
ASM_PREFETCH("[%[a_ptr], #1984]")
- "add %[a_ptr], %[a_ptr], #352\n"
+ "add %[a_ptr], %[a_ptr], #352\n"
- "cbz %w[k], 2f\n"
+ "cbz %w[k], 2f\n"
"1:\n"
// Unroll 0
- "fmla %[r0].4s, v2.4s, %[x0].s[0]\n"
- "ldr %q[x0a], [%[x_ptr], #16]\n"
- "fmla %[r1].4s, v3.4s, %[x0].s[0]\n"
- "ldr q3, [%[a_ptr], #0]\n"
- "subs %w[k], %w[k], #1\n"
- "fmla %[r2].4s, v4.4s, %[x0].s[0]\n"
- "ldr q4, [%[a_ptr], #16]\n"
- "fmla %[r3].4s, v5.4s, %[x0].s[0]\n"
- "ldr q5, [%[a_ptr], #32]\n"
- "add %[x_ptr], %[x_ptr], #32\n" ASM_PREFETCH("[%[a_ptr], #1664]")
- "fmla %[r4].4s, v6.4s, %[x0].s[0]\n"
- "ldr q6, [%[a_ptr], #48]\n"
- "fmla %[r5].4s, v7.4s, %[x0].s[0]\n"
- "ldr q7, [%[a_ptr], #64]\n"
- "fmla %[r6].4s, v8.4s, %[x0].s[0]\n"
- "ldr q8, [%[a_ptr], #80]\n"
- "fmla %[r7].4s, v9.4s, %[x0].s[0]\n"
- "ldr q9, [%[a_ptr], #96]\n" ASM_PREFETCH("[%[a_ptr], #1728]")
+ "fmla %[r0].4s, v2.4s, %[x0].s[0]\n"
+ "ldr %q[x0a], [%[x_ptr], #16]\n"
+ "fmla %[r1].4s, v3.4s, %[x0].s[0]\n"
+ "ldr q3, [%[a_ptr], #0]\n"
+ "subs %w[k], %w[k], #1\n"
+ "fmla %[r2].4s, v4.4s, %[x0].s[0]\n"
+ "ldr q4, [%[a_ptr], #16]\n"
+ "fmla %[r3].4s, v5.4s, %[x0].s[0]\n"
+ "ldr q5, [%[a_ptr], #32]\n"
+ "add %[x_ptr], %[x_ptr], #32\n"
+ ASM_PREFETCH("[%[a_ptr], #1664]")
+ "fmla %[r4].4s, v6.4s, %[x0].s[0]\n"
+ "ldr q6, [%[a_ptr], #48]\n"
+ "fmla %[r5].4s, v7.4s, %[x0].s[0]\n"
+ "ldr q7, [%[a_ptr], #64]\n"
+ "fmla %[r6].4s, v8.4s, %[x0].s[0]\n"
+ "ldr q8, [%[a_ptr], #80]\n"
+ "fmla %[r7].4s, v9.4s, %[x0].s[0]\n"
+ "ldr q9, [%[a_ptr], #96]\n"
+ ASM_PREFETCH("[%[a_ptr], #1728]")
// Unroll 1
- "fmla %[r0].4s, v10.4s, %[x0].s[1]\n"
- "ldr q10, [%[a_ptr], #112]\n"
- "fmla %[r1].4s, v11.4s, %[x0].s[1]\n"
- "ldr q11, [%[a_ptr], #128]\n"
- "fmla %[r2].4s, v12.4s, %[x0].s[1]\n"
- "ldr q12, [%[a_ptr], #144]\n"
- "fmla %[r3].4s, v13.4s, %[x0].s[1]\n"
- "ldr q13, [%[a_ptr], #160]\n" ASM_PREFETCH("[%[a_ptr], #1792]")
- "fmla %[r4].4s, v14.4s, %[x0].s[1]\n"
- "ldr q14, [%[a_ptr], #176]\n"
- "fmla %[r5].4s, v15.4s, %[x0].s[1]\n"
- "ldr q15, [%[a_ptr], #192]\n"
- "fmla %[r6].4s, v16.4s, %[x0].s[1]\n"
- "ldr q16, [%[a_ptr], #208]\n"
- "fmla %[r7].4s, v17.4s, %[x0].s[1]\n"
- "ldr q17, [%[a_ptr], #224]\n" ASM_PREFETCH("[%[a_ptr], #1856]")
+ "fmla %[r0].4s, v10.4s, %[x0].s[1]\n"
+ "ldr q10, [%[a_ptr], #112]\n"
+ "fmla %[r1].4s, v11.4s, %[x0].s[1]\n"
+ "ldr q11, [%[a_ptr], #128]\n"
+ "fmla %[r2].4s, v12.4s, %[x0].s[1]\n"
+ "ldr q12, [%[a_ptr], #144]\n"
+ "fmla %[r3].4s, v13.4s, %[x0].s[1]\n"
+ "ldr q13, [%[a_ptr], #160]\n"
+ ASM_PREFETCH("[%[a_ptr], #1792]")
+ "fmla %[r4].4s, v14.4s, %[x0].s[1]\n"
+ "ldr q14, [%[a_ptr], #176]\n"
+ "fmla %[r5].4s, v15.4s, %[x0].s[1]\n"
+ "ldr q15, [%[a_ptr], #192]\n"
+ "fmla %[r6].4s, v16.4s, %[x0].s[1]\n"
+ "ldr q16, [%[a_ptr], #208]\n"
+ "fmla %[r7].4s, v17.4s, %[x0].s[1]\n"
+ "ldr q17, [%[a_ptr], #224]\n"
+ ASM_PREFETCH("[%[a_ptr], #1856]")
// Unroll 2
- "fmla %[r0].4s, v18.4s, %[x0].s[2]\n"
- "ldr q18, [%[a_ptr], #240]\n"
- "fmla %[r1].4s, v19.4s, %[x0].s[2]\n"
- "ldr q19, [%[a_ptr], #256]\n"
- "fmla %[r2].4s, v20.4s, %[x0].s[2]\n"
- "ldr q20, [%[a_ptr], #272]\n"
- "fmla %[r3].4s, v21.4s, %[x0].s[2]\n"
- "ldr q21, [%[a_ptr], #288]\n" ASM_PREFETCH("[%[a_ptr], #1920]")
- "fmla %[r4].4s, v22.4s, %[x0].s[2]\n"
- "ldr q22, [%[a_ptr], #304]\n"
- "fmla %[r5].4s, v23.4s, %[x0].s[2]\n"
- "ldr q23, [%[a_ptr], #320]\n"
- "fmla %[r6].4s, v3.4s, %[x0].s[2]\n"
- "ldr q2, [%[a_ptr], #336]\n"
- "ldr q3, [%[a_ptr], #352]\n"
- "fmla %[r7].4s, v4.4s, %[x0].s[2]\n"
- "ldr q4, [%[a_ptr], #368]\n" ASM_PREFETCH("[%[a_ptr], #1984]")
+ "fmla %[r0].4s, v18.4s, %[x0].s[2]\n"
+ "ldr q18, [%[a_ptr], #240]\n"
+ "fmla %[r1].4s, v19.4s, %[x0].s[2]\n"
+ "ldr q19, [%[a_ptr], #256]\n"
+ "fmla %[r2].4s, v20.4s, %[x0].s[2]\n"
+ "ldr q20, [%[a_ptr], #272]\n"
+ "fmla %[r3].4s, v21.4s, %[x0].s[2]\n"
+ "ldr q21, [%[a_ptr], #288]\n"
+ ASM_PREFETCH("[%[a_ptr], #1920]")
+ "fmla %[r4].4s, v22.4s, %[x0].s[2]\n"
+ "ldr q22, [%[a_ptr], #304]\n"
+ "fmla %[r5].4s, v23.4s, %[x0].s[2]\n"
+ "ldr q23, [%[a_ptr], #320]\n"
+ "fmla %[r6].4s, v3.4s, %[x0].s[2]\n"
+ "ldr q2, [%[a_ptr], #336]\n"
+ "ldr q3, [%[a_ptr], #352]\n"
+ "fmla %[r7].4s, v4.4s, %[x0].s[2]\n"
+ "ldr q4, [%[a_ptr], #368]\n"
+ ASM_PREFETCH("[%[a_ptr], #1984]")
// Unroll 3
- "fmla %[r0].4s, v5.4s, %[x0].s[3]\n"
- "ldr q5, [%[a_ptr], #384]\n"
- "fmla %[r1].4s, v6.4s, %[x0].s[3]\n"
- "ldr q6, [%[a_ptr], #400]\n"
- "fmla %[r2].4s, v7.4s, %[x0].s[3]\n"
- "ldr q7, [%[a_ptr], #416]\n"
- "fmla %[r3].4s, v8.4s, %[x0].s[3]\n" ASM_PREFETCH("[%[a_ptr], #2048]")
- "ldr q8, [%[a_ptr], #432]\n"
- "fmla %[r4].4s, v9.4s, %[x0].s[3]\n"
- "ldr q9, [%[a_ptr], #448]\n"
- "fmla %[r5].4s, v10.4s, %[x0].s[3]\n"
- "ldr q10, [%[a_ptr], #464]\n"
- "fmla %[r6].4s, v11.4s, %[x0].s[3]\n"
- "ldr q11, [%[a_ptr], #480]\n"
- "fmla %[r7].4s, v12.4s, %[x0].s[3]\n"
- "ldr q12, [%[a_ptr], #496]\n" ASM_PREFETCH("[%[a_ptr], #2112]")
+ "fmla %[r0].4s, v5.4s, %[x0].s[3]\n"
+ "ldr q5, [%[a_ptr], #384]\n"
+ "fmla %[r1].4s, v6.4s, %[x0].s[3]\n"
+ "ldr q6, [%[a_ptr], #400]\n"
+ "fmla %[r2].4s, v7.4s, %[x0].s[3]\n"
+ "ldr q7, [%[a_ptr], #416]\n"
+ "fmla %[r3].4s, v8.4s, %[x0].s[3]\n"
+ ASM_PREFETCH("[%[a_ptr], #2048]")
+ "ldr q8, [%[a_ptr], #432]\n"
+ "fmla %[r4].4s, v9.4s, %[x0].s[3]\n"
+ "ldr q9, [%[a_ptr], #448]\n"
+ "fmla %[r5].4s, v10.4s, %[x0].s[3]\n"
+ "ldr q10, [%[a_ptr], #464]\n"
+ "fmla %[r6].4s, v11.4s, %[x0].s[3]\n"
+ "ldr q11, [%[a_ptr], #480]\n"
+ "fmla %[r7].4s, v12.4s, %[x0].s[3]\n"
+ "ldr q12, [%[a_ptr], #496]\n"
+ ASM_PREFETCH("[%[a_ptr], #2112]")
// Unroll 4
- "fmla %[r0].4s, v13.4s, %[x0a].s[0]\n"
- "ldr %q[x0], [%[x_ptr]]\n"
- "fmla %[r1].4s, v14.4s, %[x0a].s[0]\n"
- "ldr q14, [%[a_ptr], #512]\n"
- "fmla %[r2].4s, v15.4s, %[x0a].s[0]\n"
- "ldr q15, [%[a_ptr], #528]\n"
- "fmla %[r3].4s, v16.4s, %[x0a].s[0]\n" ASM_PREFETCH("[%[a_ptr], #2176]")
- "ldr q16, [%[a_ptr], #544]\n"
- "fmla %[r4].4s, v17.4s, %[x0a].s[0]\n"
- "ldr q17, [%[a_ptr], #560]\n"
- "fmla %[r5].4s, v18.4s, %[x0a].s[0]\n"
- "ldr q18, [%[a_ptr], #576]\n"
- "fmla %[r6].4s, v19.4s, %[x0a].s[0]\n"
- "ldr q19, [%[a_ptr], #592]\n"
- "fmla %[r7].4s, v20.4s, %[x0a].s[0]\n"
- "ldr q20, [%[a_ptr], #608]\n" ASM_PREFETCH("[%[a_ptr], #2240]")
+ "fmla %[r0].4s, v13.4s, %[x0a].s[0]\n"
+ "ldr %q[x0], [%[x_ptr]]\n"
+ "fmla %[r1].4s, v14.4s, %[x0a].s[0]\n"
+ "ldr q14, [%[a_ptr], #512]\n"
+ "fmla %[r2].4s, v15.4s, %[x0a].s[0]\n"
+ "ldr q15, [%[a_ptr], #528]\n"
+ "fmla %[r3].4s, v16.4s, %[x0a].s[0]\n"
+ ASM_PREFETCH("[%[a_ptr], #2176]")
+ "ldr q16, [%[a_ptr], #544]\n"
+ "fmla %[r4].4s, v17.4s, %[x0a].s[0]\n"
+ "ldr q17, [%[a_ptr], #560]\n"
+ "fmla %[r5].4s, v18.4s, %[x0a].s[0]\n"
+ "ldr q18, [%[a_ptr], #576]\n"
+ "fmla %[r6].4s, v19.4s, %[x0a].s[0]\n"
+ "ldr q19, [%[a_ptr], #592]\n"
+ "fmla %[r7].4s, v20.4s, %[x0a].s[0]\n"
+ "ldr q20, [%[a_ptr], #608]\n"
+ ASM_PREFETCH("[%[a_ptr], #2240]")
// Unroll 5
- "fmla %[r0].4s, v21.4s, %[x0a].s[1]\n"
- "ldr q21, [%[a_ptr], #624]\n"
- "fmla %[r1].4s, v22.4s, %[x0a].s[1]\n"
- "ldr q22, [%[a_ptr], #640]\n"
- "fmla %[r2].4s, v23.4s, %[x0a].s[1]\n"
- "ldr q23, [%[a_ptr], #656]\n"
- "fmla %[r3].4s, v2.4s, %[x0a].s[1]\n"
- "ldr q2, [%[a_ptr], #672]\n" ASM_PREFETCH("[%[a_ptr], #2304]")
- "fmla %[r4].4s, v3.4s, %[x0a].s[1]\n"
- "ldr q3, [%[a_ptr], #688]\n"
- "fmla %[r5].4s, v4.4s, %[x0a].s[1]\n"
- "ldr q4, [%[a_ptr], #704]\n"
- "fmla %[r6].4s, v5.4s, %[x0a].s[1]\n"
- "ldr q5, [%[a_ptr], #720]\n"
- "fmla %[r7].4s, v6.4s, %[x0a].s[1]\n"
- "ldr q6, [%[a_ptr], #736]\n" ASM_PREFETCH("[%[a_ptr], #2368]")
+ "fmla %[r0].4s, v21.4s, %[x0a].s[1]\n"
+ "ldr q21, [%[a_ptr], #624]\n"
+ "fmla %[r1].4s, v22.4s, %[x0a].s[1]\n"
+ "ldr q22, [%[a_ptr], #640]\n"
+ "fmla %[r2].4s, v23.4s, %[x0a].s[1]\n"
+ "ldr q23, [%[a_ptr], #656]\n"
+ "fmla %[r3].4s, v2.4s, %[x0a].s[1]\n"
+ "ldr q2, [%[a_ptr], #672]\n"
+ ASM_PREFETCH("[%[a_ptr], #2304]")
+ "fmla %[r4].4s, v3.4s, %[x0a].s[1]\n"
+ "ldr q3, [%[a_ptr], #688]\n"
+ "fmla %[r5].4s, v4.4s, %[x0a].s[1]\n"
+ "ldr q4, [%[a_ptr], #704]\n"
+ "fmla %[r6].4s, v5.4s, %[x0a].s[1]\n"
+ "ldr q5, [%[a_ptr], #720]\n"
+ "fmla %[r7].4s, v6.4s, %[x0a].s[1]\n"
+ "ldr q6, [%[a_ptr], #736]\n"
+ ASM_PREFETCH("[%[a_ptr], #2368]")
// Unroll 6
- "fmla %[r0].4s, v7.4s, %[x0a].s[2]\n"
- "ldr q7, [%[a_ptr], #752]\n"
- "fmla %[r1].4s, v8.4s, %[x0a].s[2]\n"
- "ldr q8, [%[a_ptr], #768]\n"
- "fmla %[r2].4s, v9.4s, %[x0a].s[2]\n"
- "ldr q9, [%[a_ptr], #784]\n"
- "fmla %[r3].4s, v10.4s, %[x0a].s[2]\n"
- "ldr q10, [%[a_ptr], #800]\n" ASM_PREFETCH("[%[a_ptr], #2432]")
- "fmla %[r4].4s, v11.4s, %[x0a].s[2]\n"
- "ldr q11, [%[a_ptr], #816]\n"
- "fmla %[r5].4s, v12.4s, %[x0a].s[2]\n"
- "ldr q12, [%[a_ptr], #832]\n"
- "fmla %[r6].4s, v14.4s, %[x0a].s[2]\n"
- "ldr q13, [%[a_ptr], #848]\n"
- "ldr q14, [%[a_ptr], #864]\n"
- "fmla %[r7].4s, v15.4s, %[x0a].s[2]\n"
- "ldr q15, [%[a_ptr], #880]\n" ASM_PREFETCH("[%[a_ptr], #2496]")
+ "fmla %[r0].4s, v7.4s, %[x0a].s[2]\n"
+ "ldr q7, [%[a_ptr], #752]\n"
+ "fmla %[r1].4s, v8.4s, %[x0a].s[2]\n"
+ "ldr q8, [%[a_ptr], #768]\n"
+ "fmla %[r2].4s, v9.4s, %[x0a].s[2]\n"
+ "ldr q9, [%[a_ptr], #784]\n"
+ "fmla %[r3].4s, v10.4s, %[x0a].s[2]\n"
+ "ldr q10, [%[a_ptr], #800]\n"
+ ASM_PREFETCH("[%[a_ptr], #2432]")
+ "fmla %[r4].4s, v11.4s, %[x0a].s[2]\n"
+ "ldr q11, [%[a_ptr], #816]\n"
+ "fmla %[r5].4s, v12.4s, %[x0a].s[2]\n"
+ "ldr q12, [%[a_ptr], #832]\n"
+ "fmla %[r6].4s, v14.4s, %[x0a].s[2]\n"
+ "ldr q13, [%[a_ptr], #848]\n"
+ "ldr q14, [%[a_ptr], #864]\n"
+ "fmla %[r7].4s, v15.4s, %[x0a].s[2]\n"
+ "ldr q15, [%[a_ptr], #880]\n"
+ ASM_PREFETCH("[%[a_ptr], #2496]")
// Unroll 7
- "fmla %[r0].4s, v16.4s, %[x0a].s[3]\n"
- "ldr q16, [%[a_ptr], #896]\n"
- "fmla %[r1].4s, v17.4s, %[x0a].s[3]\n"
- "ldr q17, [%[a_ptr], #912]\n"
- "fmla %[r2].4s, v18.4s, %[x0a].s[3]\n"
- "ldr q18, [%[a_ptr], #928]\n"
- "fmla %[r3].4s, v19.4s, %[x0a].s[3]\n" ASM_PREFETCH("[%[a_ptr], #2560]")
- "ldr q19, [%[a_ptr], #944]\n"
- "fmla %[r4].4s, v20.4s, %[x0a].s[3]\n"
- "ldr q20, [%[a_ptr], #960]\n"
- "fmla %[r5].4s, v21.4s, %[x0a].s[3]\n"
- "ldr q21, [%[a_ptr], #976]\n"
- "add %[a_ptr], %[a_ptr], #1024\n"
- "fmla %[r6].4s, v22.4s, %[x0a].s[3]\n"
- "ldr q22, [%[a_ptr], #-32]\n"
- "fmla %[r7].4s, v23.4s, %[x0a].s[3]\n"
- "ldr q23, [%[a_ptr], #-16]\n" ASM_PREFETCH("[%[a_ptr], #1600]")
- "bne 1b\n"
+ "fmla %[r0].4s, v16.4s, %[x0a].s[3]\n"
+ "ldr q16, [%[a_ptr], #896]\n"
+ "fmla %[r1].4s, v17.4s, %[x0a].s[3]\n"
+ "ldr q17, [%[a_ptr], #912]\n"
+ "fmla %[r2].4s, v18.4s, %[x0a].s[3]\n"
+ "ldr q18, [%[a_ptr], #928]\n"
+ "fmla %[r3].4s, v19.4s, %[x0a].s[3]\n"
+ ASM_PREFETCH("[%[a_ptr], #2560]")
+ "ldr q19, [%[a_ptr], #944]\n"
+ "fmla %[r4].4s, v20.4s, %[x0a].s[3]\n"
+ "ldr q20, [%[a_ptr], #960]\n"
+ "fmla %[r5].4s, v21.4s, %[x0a].s[3]\n"
+ "ldr q21, [%[a_ptr], #976]\n"
+ "add %[a_ptr], %[a_ptr], #1024\n"
+ "fmla %[r6].4s, v22.4s, %[x0a].s[3]\n"
+ "ldr q22, [%[a_ptr], #-32]\n"
+ "fmla %[r7].4s, v23.4s, %[x0a].s[3]\n"
+ "ldr q23, [%[a_ptr], #-16]\n"
+ ASM_PREFETCH("[%[a_ptr], #1600]")
+ "bne 1b\n"
// Detached final iteration
"2:\n"
// Unroll 0
- "fmla %[r0].4s, v2.4s, %[x0].s[0]\n"
- "ldr %q[x0a], [%[x_ptr], #16]\n"
- "fmla %[r1].4s, v3.4s, %[x0].s[0]\n"
- "ldr q3, [%[a_ptr], #0]\n"
- "subs %w[k], %w[k], #1\n"
- "fmla %[r2].4s, v4.4s, %[x0].s[0]\n"
- "ldr q4, [%[a_ptr], #16]\n"
- "fmla %[r3].4s, v5.4s, %[x0].s[0]\n"
- "ldr q5, [%[a_ptr], #32]\n"
- "add %[x_ptr], %[x_ptr], #32\n"
- "fmla %[r4].4s, v6.4s, %[x0].s[0]\n"
- "ldr q6, [%[a_ptr], #48]\n"
- "fmla %[r5].4s, v7.4s, %[x0].s[0]\n"
- "ldr q7, [%[a_ptr], #64]\n"
- "fmla %[r6].4s, v8.4s, %[x0].s[0]\n"
- "ldr q8, [%[a_ptr], #80]\n"
- "fmla %[r7].4s, v9.4s, %[x0].s[0]\n"
- "ldr q9, [%[a_ptr], #96]\n"
+ "fmla %[r0].4s, v2.4s, %[x0].s[0]\n"
+ "ldr %q[x0a], [%[x_ptr], #16]\n"
+ "fmla %[r1].4s, v3.4s, %[x0].s[0]\n"
+ "ldr q3, [%[a_ptr], #0]\n"
+ "subs %w[k], %w[k], #1\n"
+ "fmla %[r2].4s, v4.4s, %[x0].s[0]\n"
+ "ldr q4, [%[a_ptr], #16]\n"
+ "fmla %[r3].4s, v5.4s, %[x0].s[0]\n"
+ "ldr q5, [%[a_ptr], #32]\n"
+ "add %[x_ptr], %[x_ptr], #32\n"
+ "fmla %[r4].4s, v6.4s, %[x0].s[0]\n"
+ "ldr q6, [%[a_ptr], #48]\n"
+ "fmla %[r5].4s, v7.4s, %[x0].s[0]\n"
+ "ldr q7, [%[a_ptr], #64]\n"
+ "fmla %[r6].4s, v8.4s, %[x0].s[0]\n"
+ "ldr q8, [%[a_ptr], #80]\n"
+ "fmla %[r7].4s, v9.4s, %[x0].s[0]\n"
+ "ldr q9, [%[a_ptr], #96]\n"
// Unroll 1
- "fmla %[r0].4s, v10.4s, %[x0].s[1]\n"
- "ldr q10, [%[a_ptr], #112]\n"
- "fmla %[r1].4s, v11.4s, %[x0].s[1]\n"
- "ldr q11, [%[a_ptr], #128]\n"
- "fmla %[r2].4s, v12.4s, %[x0].s[1]\n"
- "ldr q12, [%[a_ptr], #144]\n"
- "fmla %[r3].4s, v13.4s, %[x0].s[1]\n"
- "ldr q13, [%[a_ptr], #160]\n"
- "fmla %[r4].4s, v14.4s, %[x0].s[1]\n"
- "ldr q14, [%[a_ptr], #176]\n"
- "fmla %[r5].4s, v15.4s, %[x0].s[1]\n"
- "ldr q15, [%[a_ptr], #192]\n"
- "fmla %[r6].4s, v16.4s, %[x0].s[1]\n"
- "ldr q16, [%[a_ptr], #208]\n"
- "fmla %[r7].4s, v17.4s, %[x0].s[1]\n"
- "ldr q17, [%[a_ptr], #224]\n"
+ "fmla %[r0].4s, v10.4s, %[x0].s[1]\n"
+ "ldr q10, [%[a_ptr], #112]\n"
+ "fmla %[r1].4s, v11.4s, %[x0].s[1]\n"
+ "ldr q11, [%[a_ptr], #128]\n"
+ "fmla %[r2].4s, v12.4s, %[x0].s[1]\n"
+ "ldr q12, [%[a_ptr], #144]\n"
+ "fmla %[r3].4s, v13.4s, %[x0].s[1]\n"
+ "ldr q13, [%[a_ptr], #160]\n"
+ "fmla %[r4].4s, v14.4s, %[x0].s[1]\n"
+ "ldr q14, [%[a_ptr], #176]\n"
+ "fmla %[r5].4s, v15.4s, %[x0].s[1]\n"
+ "ldr q15, [%[a_ptr], #192]\n"
+ "fmla %[r6].4s, v16.4s, %[x0].s[1]\n"
+ "ldr q16, [%[a_ptr], #208]\n"
+ "fmla %[r7].4s, v17.4s, %[x0].s[1]\n"
+ "ldr q17, [%[a_ptr], #224]\n"
// Unroll 2
- "fmla %[r0].4s, v18.4s, %[x0].s[2]\n"
- "ldr q18, [%[a_ptr], #240]\n"
- "fmla %[r1].4s, v19.4s, %[x0].s[2]\n"
- "ldr q19, [%[a_ptr], #256]\n"
- "fmla %[r2].4s, v20.4s, %[x0].s[2]\n"
- "ldr q20, [%[a_ptr], #272]\n"
- "fmla %[r3].4s, v21.4s, %[x0].s[2]\n"
- "ldr q21, [%[a_ptr], #288]\n"
- "fmla %[r4].4s, v22.4s, %[x0].s[2]\n"
- "ldr q22, [%[a_ptr], #304]\n"
- "fmla %[r5].4s, v23.4s, %[x0].s[2]\n"
- "ldr q23, [%[a_ptr], #320]\n"
- "fmla %[r6].4s, v3.4s, %[x0].s[2]\n"
- "ldr q2, [%[a_ptr], #336]\n"
- "ldr q3, [%[a_ptr], #352]\n"
- "fmla %[r7].4s, v4.4s, %[x0].s[2]\n"
- "ldr q4, [%[a_ptr], #368]\n"
+ "fmla %[r0].4s, v18.4s, %[x0].s[2]\n"
+ "ldr q18, [%[a_ptr], #240]\n"
+ "fmla %[r1].4s, v19.4s, %[x0].s[2]\n"
+ "ldr q19, [%[a_ptr], #256]\n"
+ "fmla %[r2].4s, v20.4s, %[x0].s[2]\n"
+ "ldr q20, [%[a_ptr], #272]\n"
+ "fmla %[r3].4s, v21.4s, %[x0].s[2]\n"
+ "ldr q21, [%[a_ptr], #288]\n"
+ "fmla %[r4].4s, v22.4s, %[x0].s[2]\n"
+ "ldr q22, [%[a_ptr], #304]\n"
+ "fmla %[r5].4s, v23.4s, %[x0].s[2]\n"
+ "ldr q23, [%[a_ptr], #320]\n"
+ "fmla %[r6].4s, v3.4s, %[x0].s[2]\n"
+ "ldr q2, [%[a_ptr], #336]\n"
+ "ldr q3, [%[a_ptr], #352]\n"
+ "fmla %[r7].4s, v4.4s, %[x0].s[2]\n"
+ "ldr q4, [%[a_ptr], #368]\n"
// Unroll 3
- "fmla %[r0].4s, v5.4s, %[x0].s[3]\n"
- "ldr q5, [%[a_ptr], #384]\n"
- "fmla %[r1].4s, v6.4s, %[x0].s[3]\n"
- "ldr q6, [%[a_ptr], #400]\n"
- "fmla %[r2].4s, v7.4s, %[x0].s[3]\n"
- "ldr q7, [%[a_ptr], #416]\n"
- "fmla %[r3].4s, v8.4s, %[x0].s[3]\n"
- "ldr q8, [%[a_ptr], #432]\n"
- "fmla %[r4].4s, v9.4s, %[x0].s[3]\n"
- "ldr q9, [%[a_ptr], #448]\n"
- "fmla %[r5].4s, v10.4s, %[x0].s[3]\n"
- "ldr q10, [%[a_ptr], #464]\n"
- "fmla %[r6].4s, v11.4s, %[x0].s[3]\n"
- "ldr q11, [%[a_ptr], #480]\n"
- "fmla %[r7].4s, v12.4s, %[x0].s[3]\n"
- "ldr q12, [%[a_ptr], #496]\n"
+ "fmla %[r0].4s, v5.4s, %[x0].s[3]\n"
+ "ldr q5, [%[a_ptr], #384]\n"
+ "fmla %[r1].4s, v6.4s, %[x0].s[3]\n"
+ "ldr q6, [%[a_ptr], #400]\n"
+ "fmla %[r2].4s, v7.4s, %[x0].s[3]\n"
+ "ldr q7, [%[a_ptr], #416]\n"
+ "fmla %[r3].4s, v8.4s, %[x0].s[3]\n"
+ "ldr q8, [%[a_ptr], #432]\n"
+ "fmla %[r4].4s, v9.4s, %[x0].s[3]\n"
+ "ldr q9, [%[a_ptr], #448]\n"
+ "fmla %[r5].4s, v10.4s, %[x0].s[3]\n"
+ "ldr q10, [%[a_ptr], #464]\n"
+ "fmla %[r6].4s, v11.4s, %[x0].s[3]\n"
+ "ldr q11, [%[a_ptr], #480]\n"
+ "fmla %[r7].4s, v12.4s, %[x0].s[3]\n"
+ "ldr q12, [%[a_ptr], #496]\n"
// Unroll 4
- "fmla %[r0].4s, v13.4s, %[x0a].s[0]\n"
- "fmla %[r1].4s, v14.4s, %[x0a].s[0]\n"
- "ldr q14, [%[a_ptr], #512]\n"
- "fmla %[r2].4s, v15.4s, %[x0a].s[0]\n"
- "ldr q15, [%[a_ptr], #528]\n"
- "fmla %[r3].4s, v16.4s, %[x0a].s[0]\n"
- "ldr q16, [%[a_ptr], #544]\n"
- "fmla %[r4].4s, v17.4s, %[x0a].s[0]\n"
- "ldr q17, [%[a_ptr], #560]\n"
- "fmla %[r5].4s, v18.4s, %[x0a].s[0]\n"
- "ldr q18, [%[a_ptr], #576]\n"
- "fmla %[r6].4s, v19.4s, %[x0a].s[0]\n"
- "ldr q19, [%[a_ptr], #592]\n"
- "fmla %[r7].4s, v20.4s, %[x0a].s[0]\n"
- "ldr q20, [%[a_ptr], #608]\n"
+ "fmla %[r0].4s, v13.4s, %[x0a].s[0]\n"
+ "fmla %[r1].4s, v14.4s, %[x0a].s[0]\n"
+ "ldr q14, [%[a_ptr], #512]\n"
+ "fmla %[r2].4s, v15.4s, %[x0a].s[0]\n"
+ "ldr q15, [%[a_ptr], #528]\n"
+ "fmla %[r3].4s, v16.4s, %[x0a].s[0]\n"
+ "ldr q16, [%[a_ptr], #544]\n"
+ "fmla %[r4].4s, v17.4s, %[x0a].s[0]\n"
+ "ldr q17, [%[a_ptr], #560]\n"
+ "fmla %[r5].4s, v18.4s, %[x0a].s[0]\n"
+ "ldr q18, [%[a_ptr], #576]\n"
+ "fmla %[r6].4s, v19.4s, %[x0a].s[0]\n"
+ "ldr q19, [%[a_ptr], #592]\n"
+ "fmla %[r7].4s, v20.4s, %[x0a].s[0]\n"
+ "ldr q20, [%[a_ptr], #608]\n"
// Unroll 5
- "fmla %[r0].4s, v21.4s, %[x0a].s[1]\n"
- "ldr q21, [%[a_ptr], #624]\n"
- "fmla %[r1].4s, v22.4s, %[x0a].s[1]\n"
- "ldr q22, [%[a_ptr], #640]\n"
- "fmla %[r2].4s, v23.4s, %[x0a].s[1]\n"
- "ldr q23, [%[a_ptr], #656]\n"
- "fmla %[r3].4s, v2.4s, %[x0a].s[1]\n"
- "add %[a_ptr], %[a_ptr], #672\n"
- "fmla %[r4].4s, v3.4s, %[x0a].s[1]\n"
- "fmla %[r5].4s, v4.4s, %[x0a].s[1]\n"
- "fmla %[r6].4s, v5.4s, %[x0a].s[1]\n"
- "fmla %[r7].4s, v6.4s, %[x0a].s[1]\n"
+ "fmla %[r0].4s, v21.4s, %[x0a].s[1]\n"
+ "ldr q21, [%[a_ptr], #624]\n"
+ "fmla %[r1].4s, v22.4s, %[x0a].s[1]\n"
+ "ldr q22, [%[a_ptr], #640]\n"
+ "fmla %[r2].4s, v23.4s, %[x0a].s[1]\n"
+ "ldr q23, [%[a_ptr], #656]\n"
+ "fmla %[r3].4s, v2.4s, %[x0a].s[1]\n"
+ "add %[a_ptr], %[a_ptr], #672\n"
+ "fmla %[r4].4s, v3.4s, %[x0a].s[1]\n"
+ "fmla %[r5].4s, v4.4s, %[x0a].s[1]\n"
+ "fmla %[r6].4s, v5.4s, %[x0a].s[1]\n"
+ "fmla %[r7].4s, v6.4s, %[x0a].s[1]\n"
// Unroll 6
- "fmla %[r0].4s, v7.4s, %[x0a].s[2]\n"
- "fmla %[r1].4s, v8.4s, %[x0a].s[2]\n"
- "fmla %[r2].4s, v9.4s, %[x0a].s[2]\n"
- "fmla %[r3].4s, v10.4s, %[x0a].s[2]\n"
- "fmla %[r4].4s, v11.4s, %[x0a].s[2]\n"
- "fmla %[r5].4s, v12.4s, %[x0a].s[2]\n"
- "fmla %[r6].4s, v14.4s, %[x0a].s[2]\n"
- "fmla %[r7].4s, v15.4s, %[x0a].s[2]\n"
+ "fmla %[r0].4s, v7.4s, %[x0a].s[2]\n"
+ "fmla %[r1].4s, v8.4s, %[x0a].s[2]\n"
+ "fmla %[r2].4s, v9.4s, %[x0a].s[2]\n"
+ "fmla %[r3].4s, v10.4s, %[x0a].s[2]\n"
+ "fmla %[r4].4s, v11.4s, %[x0a].s[2]\n"
+ "fmla %[r5].4s, v12.4s, %[x0a].s[2]\n"
+ "fmla %[r6].4s, v14.4s, %[x0a].s[2]\n"
+ "fmla %[r7].4s, v15.4s, %[x0a].s[2]\n"
// Unroll 7
- "fmla %[r0].4s, v16.4s, %[x0a].s[3]\n"
- "fmla %[r1].4s, v17.4s, %[x0a].s[3]\n"
- "fmla %[r2].4s, v18.4s, %[x0a].s[3]\n"
- "fmla %[r3].4s, v19.4s, %[x0a].s[3]\n"
- "fmla %[r4].4s, v20.4s, %[x0a].s[3]\n"
- "fmla %[r5].4s, v21.4s, %[x0a].s[3]\n"
- "fmla %[r6].4s, v22.4s, %[x0a].s[3]\n"
- "fmla %[r7].4s, v23.4s, %[x0a].s[3]\n"
- :
- [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr),
- [x0] "+w"(x0), [x0a] "+w"(x0a), [k] "+r"(k),
- [r0] "+w"(r0), [r1] "+w"(r1), [r2] "+w"(r2), [r3] "+w"(r3),
- [r4] "+w"(r4), [r5] "+w"(r5), [r6] "+w"(r6), [r7] "+w"(r7)
- :
- : "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
- "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "cc", "memory");
+ "fmla %[r0].4s, v16.4s, %[x0a].s[3]\n"
+ "fmla %[r1].4s, v17.4s, %[x0a].s[3]\n"
+ "fmla %[r2].4s, v18.4s, %[x0a].s[3]\n"
+ "fmla %[r3].4s, v19.4s, %[x0a].s[3]\n"
+ "fmla %[r4].4s, v20.4s, %[x0a].s[3]\n"
+ "fmla %[r5].4s, v21.4s, %[x0a].s[3]\n"
+ "fmla %[r6].4s, v22.4s, %[x0a].s[3]\n"
+ "fmla %[r7].4s, v23.4s, %[x0a].s[3]\n"
+ :
+ [a_ptr] "+r" (a_ptr), [x_ptr] "+r" (x_ptr),
+ [x0] "+w" (x0), [x0a] "+w" (x0a), [k] "+r" (k),
+ [r0] "+w" (r0), [r1] "+w" (r1), [r2] "+w" (r2), [r3] "+w" (r3),
+ [r4] "+w" (r4), [r5] "+w" (r5), [r6] "+w" (r6), [r7] "+w" (r7)
+ :
+ : "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
+ "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "cc", "memory");
}
// Deal with ragged M
- if(M % 8)
- {
- int l = (M % 8) - 1;
+ if (M % 8) {
+ int l=(M%8)-1;
- __asm __volatile(
- "ldr q2, [%[a_ptr], #0]\n"
- "ldr q3, [%[a_ptr], #16]\n"
- "ldr q4, [%[a_ptr], #32]\n"
- "ldr q5, [%[a_ptr], #48]\n"
- "ldr q6, [%[a_ptr], #64]\n"
- "ldr q7, [%[a_ptr], #80]\n"
- "ldr q8, [%[a_ptr], #96]\n"
- "ldr q9, [%[a_ptr], #112]\n"
- "ldr %s[x0], [%[x_ptr]]\n"
- "add %[a_ptr], %[a_ptr], #128\n"
- "add %[x_ptr], %[x_ptr], #4\n"
+ __asm __volatile (
+ "ldr q2, [%[a_ptr], #0]\n"
+ "ldr q3, [%[a_ptr], #16]\n"
+ "ldr q4, [%[a_ptr], #32]\n"
+ "ldr q5, [%[a_ptr], #48]\n"
+ "ldr q6, [%[a_ptr], #64]\n"
+ "ldr q7, [%[a_ptr], #80]\n"
+ "ldr q8, [%[a_ptr], #96]\n"
+ "ldr q9, [%[a_ptr], #112]\n"
+ "ldr %s[x0], [%[x_ptr]]\n"
+ "add %[a_ptr], %[a_ptr], #128\n"
+ "add %[x_ptr], %[x_ptr], #4\n"
- "cbz %w[l], 2f\n"
+ "cbz %w[l], 2f\n"
"1:\n"
- "fmla %[r0].4s, v2.4s, %[x0].s[0]\n"
- "ldr q2, [%[a_ptr], #0]\n"
- "subs %w[l], %w[l], #1\n"
- "fmla %[r1].4s, v3.4s, %[x0].s[0]\n"
- "ldr q3, [%[a_ptr], #16]\n"
- "fmla %[r2].4s, v4.4s, %[x0].s[0]\n"
- "ldr q4, [%[a_ptr], #32]\n"
- "fmla %[r3].4s, v5.4s, %[x0].s[0]\n"
- "ldr q5, [%[a_ptr], #48]\n"
- "fmla %[r4].4s, v6.4s, %[x0].s[0]\n"
- "ldr q6, [%[a_ptr], #64]\n"
- "fmla %[r5].4s, v7.4s, %[x0].s[0]\n"
- "ldr q7, [%[a_ptr], #80]\n"
- "fmla %[r6].4s, v8.4s, %[x0].s[0]\n"
- "ldr q8, [%[a_ptr], #96]\n"
- "fmla %[r7].4s, v9.4s, %[x0].s[0]\n"
- "ldr q9, [%[a_ptr], #112]\n"
- "ldr %s[x0], [%[x_ptr]]\n"
- "add %[a_ptr], %[a_ptr], #128\n"
- "add %[x_ptr], %[x_ptr], #4\n"
- "bne 1b\n"
+ "fmla %[r0].4s, v2.4s, %[x0].s[0]\n"
+ "ldr q2, [%[a_ptr], #0]\n"
+ "subs %w[l], %w[l], #1\n"
+ "fmla %[r1].4s, v3.4s, %[x0].s[0]\n"
+ "ldr q3, [%[a_ptr], #16]\n"
+ "fmla %[r2].4s, v4.4s, %[x0].s[0]\n"
+ "ldr q4, [%[a_ptr], #32]\n"
+ "fmla %[r3].4s, v5.4s, %[x0].s[0]\n"
+ "ldr q5, [%[a_ptr], #48]\n"
+ "fmla %[r4].4s, v6.4s, %[x0].s[0]\n"
+ "ldr q6, [%[a_ptr], #64]\n"
+ "fmla %[r5].4s, v7.4s, %[x0].s[0]\n"
+ "ldr q7, [%[a_ptr], #80]\n"
+ "fmla %[r6].4s, v8.4s, %[x0].s[0]\n"
+ "ldr q8, [%[a_ptr], #96]\n"
+ "fmla %[r7].4s, v9.4s, %[x0].s[0]\n"
+ "ldr q9, [%[a_ptr], #112]\n"
+ "ldr %s[x0], [%[x_ptr]]\n"
+ "add %[a_ptr], %[a_ptr], #128\n"
+ "add %[x_ptr], %[x_ptr], #4\n"
+ "bne 1b\n"
"2:\n"
- "fmla %[r0].4s, v2.4s, %[x0].s[0]\n"
- "fmla %[r1].4s, v3.4s, %[x0].s[0]\n"
- "fmla %[r2].4s, v4.4s, %[x0].s[0]\n"
- "fmla %[r3].4s, v5.4s, %[x0].s[0]\n"
- "fmla %[r4].4s, v6.4s, %[x0].s[0]\n"
- "fmla %[r5].4s, v7.4s, %[x0].s[0]\n"
- "fmla %[r6].4s, v8.4s, %[x0].s[0]\n"
- "fmla %[r7].4s, v9.4s, %[x0].s[0]\n"
- :
- [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr),
- [x0] "+w"(x0), [l] "+r"(l),
- [r0] "+w"(r0), [r1] "+w"(r1), [r2] "+w"(r2), [r3] "+w"(r3),
- [r4] "+w"(r4), [r5] "+w"(r5), [r6] "+w"(r6), [r7] "+w"(r7)
- :
- : "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", "memory");
+ "fmla %[r0].4s, v2.4s, %[x0].s[0]\n"
+ "fmla %[r1].4s, v3.4s, %[x0].s[0]\n"
+ "fmla %[r2].4s, v4.4s, %[x0].s[0]\n"
+ "fmla %[r3].4s, v5.4s, %[x0].s[0]\n"
+ "fmla %[r4].4s, v6.4s, %[x0].s[0]\n"
+ "fmla %[r5].4s, v7.4s, %[x0].s[0]\n"
+ "fmla %[r6].4s, v8.4s, %[x0].s[0]\n"
+ "fmla %[r7].4s, v9.4s, %[x0].s[0]\n"
+ :
+ [a_ptr] "+r" (a_ptr), [x_ptr] "+r" (x_ptr),
+ [x0] "+w" (x0), [l] "+r" (l),
+ [r0] "+w" (r0), [r1] "+w" (r1), [r2] "+w" (r2), [r3] "+w" (r3),
+ [r4] "+w" (r4), [r5] "+w" (r5), [r6] "+w" (r6), [r7] "+w" (r7)
+ :
+ : "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", "memory");
}
- if(l == 32)
- {
+ if (l==32) {
// Fast path
vst1q_f32(y_ptr, r0);
vst1q_f32(y_ptr + 4, r1);
@@ -650,82 +598,48 @@
vst1q_f32(y_ptr + 20, r5);
vst1q_f32(y_ptr + 24, r6);
vst1q_f32(y_ptr + 28, r7);
- }
- else
- {
- int vecs = l / 4;
- int oddbits = l % 4;
+ } else {
+ int vecs=l/4;
+ int oddbits=l%4;
- if(oddbits)
- {
+ if (oddbits) {
// As above - slowest path deals with vectors plus odd bits
float32x4_t oddvec;
- do
- {
- if(vecs == 0)
- {
- oddvec = r0;
- break;
- }
+ do {
+ if (vecs==0) { oddvec=r0; break; }
vst1q_f32(y_ptr, r0);
- if(--vecs == 0)
- {
- oddvec = r1;
- break;
- }
+ if (--vecs==0) { oddvec=r1; break; }
vst1q_f32(y_ptr + 4, r1);
- if(--vecs == 0)
- {
- oddvec = r2;
- break;
- }
+ if (--vecs==0) { oddvec=r2; break; }
vst1q_f32(y_ptr + 8, r2);
- if(--vecs == 0)
- {
- oddvec = r3;
- break;
- }
+ if (--vecs==0) { oddvec=r3; break; }
vst1q_f32(y_ptr + 12, r3);
- if(--vecs == 0)
- {
- oddvec = r4;
- break;
- }
+ if (--vecs==0) { oddvec=r4; break; }
vst1q_f32(y_ptr + 16, r4);
- if(--vecs == 0)
- {
- oddvec = r5;
- break;
- }
+ if (--vecs==0) { oddvec=r5; break; }
vst1q_f32(y_ptr + 20, r5);
- if(--vecs == 0)
- {
- oddvec = r6;
- break;
- }
+ if (--vecs==0) { oddvec=r6; break; }
vst1q_f32(y_ptr + 24, r6);
- oddvec = r7;
- }
- while(0);
+ oddvec=r7;
+ } while (0);
float *oddbase = y_ptr + l - oddbits;
- switch(oddbits)
- {
+ switch(oddbits) {
case 3:
vst1q_lane_f32(oddbase + 2, oddvec, 2);
- // fall through
+ // fall through
case 2:
vst1q_lane_f32(oddbase + 1, oddvec, 1);
- // fall through
+ // fall through
case 1:
vst1q_lane_f32(oddbase, oddvec, 0);
break;
@@ -734,56 +648,31 @@
// oddbits must be 1, 2 or 3.
UNREACHABLE("Impossible case in switch.");
}
- }
- else
- {
+ } else {
// As above - medium path deals with vectors only
- do
- {
- if(vecs == 0)
- {
- UNREACHABLE("vecs and oddbits can't both be 0");
- }
+ do {
+ if (vecs==0) { UNREACHABLE("vecs and oddbits can't both be 0"); }
vst1q_f32(y_ptr, r0);
- if(--vecs == 0)
- {
- break;
- }
+ if (--vecs==0) { break; }
vst1q_f32(y_ptr + 4, r1);
- if(--vecs == 0)
- {
- break;
- }
+ if (--vecs==0) { break; }
vst1q_f32(y_ptr + 8, r2);
- if(--vecs == 0)
- {
- break;
- }
+ if (--vecs==0) { break; }
vst1q_f32(y_ptr + 12, r3);
- if(--vecs == 0)
- {
- break;
- }
+ if (--vecs==0) { break; }
vst1q_f32(y_ptr + 16, r4);
- if(--vecs == 0)
- {
- break;
- }
+ if (--vecs==0) { break; }
vst1q_f32(y_ptr + 20, r5);
- if(--vecs == 0)
- {
- break;
- }
+ if (--vecs==0) { break; }
vst1q_f32(y_ptr + 24, r6);
- }
- while(0);
+ } while (0);
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans.hpp
index 5b9bd72..18c5c3a 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,14 +25,13 @@
#ifdef __aarch64__
-namespace arm_gemm
-{
+namespace arm_gemm {
+
// Actual kernel implementations
void a64_sgemv_trans(const float *, const float *, float *, float, int, int, int);
// Transposed SGEMV strategy class.
-class sgemv_trans
-{
+class sgemv_trans {
public:
typedef float operand_type;
typedef float result_type;
@@ -41,13 +40,11 @@
/* Kernel blocking parameters */
static const int out_width = 96;
- static const int k_unroll = 1;
+ static const int k_unroll = 1;
- kern_type kernel = a64_sgemv_trans;
+ kern_type kernel=a64_sgemv_trans;
- sgemv_trans(const CPUInfo *ci)
- {
- }
+ sgemv_trans(const CPUInfo *ci) { }
};
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp
index 8fa403b..64ef9d8 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp
@@ -42,464 +42,472 @@
// higher performance, but that's left to the outer loop. In this kernel we
// process all of M at the same time.
+
// How far ahead to prefetch for the first and subsequent prefetches.
// These values work for A72 on JunoR2...
#define FIRST_PFD 9
#define PFD 6
-namespace arm_gemm
-{
-void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, float beta, int lda, int M, int N)
-{
+namespace arm_gemm {
+
+void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, float beta, int lda, int M, int N) {
const float *a_ptr_base = Astart;
- float *y_ptr = Ystart;
+ float *y_ptr = Ystart;
register const float32x4_t vb asm("v1") = vdupq_n_f32(beta);
- int firstpfd = FIRST_PFD;
- if(firstpfd > M)
- {
- firstpfd = (M - 1);
+ int firstpfd=FIRST_PFD;
+ if (firstpfd > M) {
+ firstpfd = (M-1);
}
int pfd = PFD;
- if(pfd > M)
- {
- pfd = (M - 1);
+ if (pfd > M) {
+ pfd = (M-1);
}
ptrdiff_t jump = lda * sizeof(int);
- for(; N >= 96; N -= 96)
- {
- int k = M - 1;
+ for (;N>=96;N-=96) {
+ int k = M-1;
- const float *a_ptr = a_ptr_base;
- const float *x_ptr = Xstart;
- const float *pf_ptr = a_ptr;
+ const float *a_ptr = a_ptr_base;
+ const float *x_ptr = Xstart;
+ const float *pf_ptr = a_ptr;
const float *firstpf_ptr = a_ptr;
- const float *pf_limit = a_ptr + (M * lda);
+ const float *pf_limit = a_ptr + (M * lda);
- for(int i = 0; i < firstpfd; i++)
- {
+ for (int i=0; i<firstpfd; i++) {
prefetch_1x(firstpf_ptr);
firstpf_ptr += lda;
}
- for(int i = 0; i < pfd; i++)
- {
+ for (int i=0; i<pfd; i++) {
prefetch_5x(pf_ptr + 16);
pf_ptr += lda;
}
a_ptr_base += 96;
- __asm __volatile(
- "movi v8.4s,#0x0\n"
- "ldr w0, [%[x_ptr]]\n"
- "movi v9.4s,#0x0\n"
- "ldr q2, [%[a_ptr], #0]\n"
- "movi v10.4s,#0x0\n"
- "ldr q3, [%[a_ptr], #0x10]\n"
- "movi v11.4s,#0x0\n"
- "ldr q4, [%[a_ptr], #0x20]\n"
- "movi v12.4s,#0x0\n"
- "ldr q5, [%[a_ptr], #0x30]\n"
- "movi v13.4s,#0x0\n"
- "ldr q6, [%[a_ptr], #0x40]\n"
- "movi v14.4s,#0x0\n"
- "ldr q7, [%[a_ptr], #0x50]\n"
- "movi v15.4s,#0x0\n" ASM_PREFETCH("[%[firstpf_ptr]]")
- "movi v16.4s, #0x0\n"
- "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #64]")
- "movi v18.4s, #0x0\n"
- "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #128]")
- "movi v20.4s, #0x0\n"
- "movi v21.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #192]")
- "movi v22.4s, #0x0\n"
- "movi v23.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #256]")
- "movi v24.4s, #0x0\n"
- "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #320]")
- "movi v26.4s, #0x0\n"
- "movi v27.4s, #0x0\n"
- "add %[pf_ptr], %[pf_ptr], %[jump]\n"
- "movi v28.4s, #0x0\n"
- "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n"
- "movi v29.4s, #0x0\n"
- "movi v30.4s, #0x0\n"
- "movi v31.4s, #0x0\n"
+ __asm __volatile (
+ "movi v8.4s,#0x0\n"
+ "ldr w0, [%[x_ptr]]\n"
+ "movi v9.4s,#0x0\n"
+ "ldr q2, [%[a_ptr], #0]\n"
+ "movi v10.4s,#0x0\n"
+ "ldr q3, [%[a_ptr], #0x10]\n"
+ "movi v11.4s,#0x0\n"
+ "ldr q4, [%[a_ptr], #0x20]\n"
+ "movi v12.4s,#0x0\n"
+ "ldr q5, [%[a_ptr], #0x30]\n"
+ "movi v13.4s,#0x0\n"
+ "ldr q6, [%[a_ptr], #0x40]\n"
+ "movi v14.4s,#0x0\n"
+ "ldr q7, [%[a_ptr], #0x50]\n"
+ "movi v15.4s,#0x0\n"
+ ASM_PREFETCH("[%[firstpf_ptr]]")
+ "movi v16.4s, #0x0\n"
+ "movi v17.4s, #0x0\n"
+ ASM_PREFETCH("[%[pf_ptr], #64]")
+ "movi v18.4s, #0x0\n"
+ "movi v19.4s, #0x0\n"
+ ASM_PREFETCH("[%[pf_ptr], #128]")
+ "movi v20.4s, #0x0\n"
+ "movi v21.4s, #0x0\n"
+ ASM_PREFETCH("[%[pf_ptr], #192]")
+ "movi v22.4s, #0x0\n"
+ "movi v23.4s, #0x0\n"
+ ASM_PREFETCH("[%[pf_ptr], #256]")
+ "movi v24.4s, #0x0\n"
+ "movi v25.4s, #0x0\n"
+ ASM_PREFETCH("[%[pf_ptr], #320]")
+ "movi v26.4s, #0x0\n"
+ "movi v27.4s, #0x0\n"
+ "add %[pf_ptr], %[pf_ptr], %[jump]\n"
+ "movi v28.4s, #0x0\n"
+ "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n"
+ "movi v29.4s, #0x0\n"
+ "movi v30.4s, #0x0\n"
+ "movi v31.4s, #0x0\n"
// Skip everything if there are no iterations of the main loop to do.
- "cbz %w[k], 10f\n"
+ "cbz %w[k], 10f\n"
// Loop with all prefetches. Exit this loop when firstpf_ptr
// hits pf_limit.
"1:\n"
- "dup v0.4s, w0\n"
- "ldr w0, [%[x_ptr], #4]\n"
- "add %[x_ptr], %[x_ptr], #0x4\n"
- "fmla v8.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0x60]\n"
- "fmla v9.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0x70]\n" ASM_PREFETCH("[%[firstpf_ptr]]")
- "fmla v10.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0x80]\n"
- "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n"
- "fmla v11.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0x90]\n"
- "sub %w[k], %w[k], #1\n" ASM_PREFETCH("[%[x_ptr], #128]")
- "fmla v12.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[a_ptr], #0xa0]\n"
- "fmla v13.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[a_ptr], #0xb0]\n" ASM_PREFETCH("[%[pf_ptr], #0x40]")
- "fmla v14.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0xc0]\n"
- "fmla v15.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0xd0]\n"
- "fmla v16.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0xe0]\n"
- "fmla v17.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0xf0]\n" ASM_PREFETCH("[%[pf_ptr], #0x80]")
- "fmla v18.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[a_ptr], #0x100]\n"
- "fmla v19.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[a_ptr], #0x110]\n"
- "fmla v20.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0x120]\n"
- "fmla v21.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0x130]\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]")
- "fmla v22.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0x140]\n"
- "fmla v23.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0x150]\n"
- "fmla v24.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[a_ptr], #0x160]\n"
- "fmla v25.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[a_ptr], #0x170]\n" ASM_PREFETCH("[%[pf_ptr], #0x100]")
- "add %[a_ptr], %[a_ptr], %[jump]\n"
- "fmla v26.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0x00]\n"
- "fmla v27.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0x10]\n"
- "fmla v28.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0x20]\n"
- "fmla v29.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0x30]\n" ASM_PREFETCH("[%[pf_ptr], #0x140]")
- "fmla v30.4s, v6.4s, v0.4s\n"
- "add %[pf_ptr], %[pf_ptr], %[jump]\n"
- "ldr q6, [%[a_ptr], #0x40]\n"
- "fmla v31.4s, v7.4s, v0.4s\n"
- "cmp %[firstpf_ptr], %[pf_limit]\n"
- "ldr q7, [%[a_ptr], #0x50]\n"
- "blt 1b\n"
+ "dup v0.4s, w0\n"
+ "ldr w0, [%[x_ptr], #4]\n"
+ "add %[x_ptr], %[x_ptr], #0x4\n"
+ "fmla v8.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x60]\n"
+ "fmla v9.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x70]\n"
+ ASM_PREFETCH("[%[firstpf_ptr]]")
+ "fmla v10.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x80]\n"
+ "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n"
+ "fmla v11.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x90]\n"
+ "sub %w[k], %w[k], #1\n"
+ ASM_PREFETCH("[%[x_ptr], #128]")
+ "fmla v12.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0xa0]\n"
+ "fmla v13.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0xb0]\n"
+ ASM_PREFETCH("[%[pf_ptr], #0x40]")
+ "fmla v14.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0xc0]\n"
+ "fmla v15.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0xd0]\n"
+ "fmla v16.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0xe0]\n"
+ "fmla v17.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0xf0]\n"
+ ASM_PREFETCH("[%[pf_ptr], #0x80]")
+ "fmla v18.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x100]\n"
+ "fmla v19.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x110]\n"
+ "fmla v20.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x120]\n"
+ "fmla v21.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x130]\n"
+ ASM_PREFETCH("[%[pf_ptr], #0xc0]")
+ "fmla v22.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x140]\n"
+ "fmla v23.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x150]\n"
+ "fmla v24.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x160]\n"
+ "fmla v25.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x170]\n"
+ ASM_PREFETCH("[%[pf_ptr], #0x100]")
+ "add %[a_ptr], %[a_ptr], %[jump]\n"
+ "fmla v26.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x00]\n"
+ "fmla v27.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x10]\n"
+ "fmla v28.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x20]\n"
+ "fmla v29.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x30]\n"
+ ASM_PREFETCH("[%[pf_ptr], #0x140]")
+ "fmla v30.4s, v6.4s, v0.4s\n"
+ "add %[pf_ptr], %[pf_ptr], %[jump]\n"
+ "ldr q6, [%[a_ptr], #0x40]\n"
+ "fmla v31.4s, v7.4s, v0.4s\n"
+ "cmp %[firstpf_ptr], %[pf_limit]\n"
+ "ldr q7, [%[a_ptr], #0x50]\n"
+ "blt 1b\n"
// Check that there are still "main" prefetches to do.
- "cmp %[pf_ptr], %[pf_limit]\n"
- "bge 9f\n"
+ "cmp %[pf_ptr], %[pf_limit]\n"
+ "bge 9f\n"
// Just the main prefetches, exit this loop when pf_ptr hits pf_limit.
"8:\n"
- "dup v0.4s, w0\n"
- "ldr w0, [%[x_ptr], #4]\n"
- "add %[x_ptr], %[x_ptr], #0x4\n"
- "fmla v8.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0x60]\n"
- "fmla v9.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0x70]\n"
- "fmla v10.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0x80]\n"
- "fmla v11.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0x90]\n"
- "sub %w[k], %w[k], #1\n" ASM_PREFETCH("[%[x_ptr], #128]")
- "fmla v12.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[a_ptr], #0xa0]\n"
- "fmla v13.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[a_ptr], #0xb0]\n" ASM_PREFETCH("[%[pf_ptr], #0x40]")
- "fmla v14.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0xc0]\n"
- "fmla v15.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0xd0]\n"
- "fmla v16.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0xe0]\n"
- "fmla v17.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0xf0]\n" ASM_PREFETCH("[%[pf_ptr], #0x80]")
- "fmla v18.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[a_ptr], #0x100]\n"
- "fmla v19.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[a_ptr], #0x110]\n"
- "fmla v20.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0x120]\n"
- "fmla v21.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0x130]\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]")
- "fmla v22.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0x140]\n"
- "fmla v23.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0x150]\n"
- "fmla v24.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[a_ptr], #0x160]\n"
- "fmla v25.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[a_ptr], #0x170]\n" ASM_PREFETCH("[%[pf_ptr], #0x100]")
- "add %[a_ptr], %[a_ptr], %[jump]\n"
- "fmla v26.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0x00]\n"
- "fmla v27.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0x10]\n"
- "fmla v28.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0x20]\n"
- "fmla v29.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0x30]\n" ASM_PREFETCH("[%[pf_ptr], #0x140]")
- "fmla v30.4s, v6.4s, v0.4s\n"
- "add %[pf_ptr], %[pf_ptr], %[jump]\n"
- "ldr q6, [%[a_ptr], #0x40]\n"
- "fmla v31.4s, v7.4s, v0.4s\n"
- "cmp %[pf_ptr], %[pf_limit]\n"
- "ldr q7, [%[a_ptr], #0x50]\n"
- "blt 8b\n"
+ "dup v0.4s, w0\n"
+ "ldr w0, [%[x_ptr], #4]\n"
+ "add %[x_ptr], %[x_ptr], #0x4\n"
+ "fmla v8.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x60]\n"
+ "fmla v9.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x70]\n"
+ "fmla v10.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x80]\n"
+ "fmla v11.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x90]\n"
+ "sub %w[k], %w[k], #1\n"
+ ASM_PREFETCH("[%[x_ptr], #128]")
+ "fmla v12.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0xa0]\n"
+ "fmla v13.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0xb0]\n"
+ ASM_PREFETCH("[%[pf_ptr], #0x40]")
+ "fmla v14.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0xc0]\n"
+ "fmla v15.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0xd0]\n"
+ "fmla v16.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0xe0]\n"
+ "fmla v17.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0xf0]\n"
+ ASM_PREFETCH("[%[pf_ptr], #0x80]")
+ "fmla v18.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x100]\n"
+ "fmla v19.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x110]\n"
+ "fmla v20.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x120]\n"
+ "fmla v21.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x130]\n"
+ ASM_PREFETCH("[%[pf_ptr], #0xc0]")
+ "fmla v22.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x140]\n"
+ "fmla v23.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x150]\n"
+ "fmla v24.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x160]\n"
+ "fmla v25.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x170]\n"
+ ASM_PREFETCH("[%[pf_ptr], #0x100]")
+ "add %[a_ptr], %[a_ptr], %[jump]\n"
+ "fmla v26.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x00]\n"
+ "fmla v27.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x10]\n"
+ "fmla v28.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x20]\n"
+ "fmla v29.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x30]\n"
+ ASM_PREFETCH("[%[pf_ptr], #0x140]")
+ "fmla v30.4s, v6.4s, v0.4s\n"
+ "add %[pf_ptr], %[pf_ptr], %[jump]\n"
+ "ldr q6, [%[a_ptr], #0x40]\n"
+ "fmla v31.4s, v7.4s, v0.4s\n"
+ "cmp %[pf_ptr], %[pf_limit]\n"
+ "ldr q7, [%[a_ptr], #0x50]\n"
+ "blt 8b\n"
// Check that there is still work to do.
"9:\n"
- "cmp %w[k], #0\n"
- "beq 10f\n"
+ "cmp %w[k], #0\n"
+ "beq 10f\n"
// Loop without prefetches, exit when k hits 0.
"2:\n"
- "dup v0.4s, w0\n"
- "ldr w0, [%[x_ptr], #4]\n"
- "add %[x_ptr], %[x_ptr], #0x4\n"
- "fmla v8.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0x60]\n"
- "fmla v9.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0x70]\n"
- "fmla v10.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0x80]\n"
- "fmla v11.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0x90]\n"
- "subs %w[k], %w[k], #1\n"
- "fmla v12.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[a_ptr], #0xa0]\n"
- "fmla v13.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[a_ptr], #0xb0]\n"
- "fmla v14.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0xc0]\n"
- "fmla v15.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0xd0]\n"
- "fmla v16.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0xe0]\n"
- "fmla v17.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0xf0]\n"
- "fmla v18.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[a_ptr], #0x100]\n"
- "fmla v19.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[a_ptr], #0x110]\n"
- "fmla v20.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0x120]\n"
- "fmla v21.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0x130]\n"
- "fmla v22.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0x140]\n"
- "fmla v23.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0x150]\n"
- "fmla v24.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[a_ptr], #0x160]\n"
- "fmla v25.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[a_ptr], #0x170]\n"
- "add %[a_ptr], %[a_ptr], %[jump]\n"
- "fmla v26.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0x00]\n"
- "fmla v27.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0x10]\n"
- "fmla v28.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0x20]\n"
- "fmla v29.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0x30]\n"
- "fmla v30.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[a_ptr], #0x40]\n"
- "fmla v31.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[a_ptr], #0x50]\n"
- "bne 2b\n"
+ "dup v0.4s, w0\n"
+ "ldr w0, [%[x_ptr], #4]\n"
+ "add %[x_ptr], %[x_ptr], #0x4\n"
+ "fmla v8.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x60]\n"
+ "fmla v9.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x70]\n"
+ "fmla v10.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x80]\n"
+ "fmla v11.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x90]\n"
+ "subs %w[k], %w[k], #1\n"
+ "fmla v12.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0xa0]\n"
+ "fmla v13.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0xb0]\n"
+ "fmla v14.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0xc0]\n"
+ "fmla v15.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0xd0]\n"
+ "fmla v16.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0xe0]\n"
+ "fmla v17.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0xf0]\n"
+ "fmla v18.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x100]\n"
+ "fmla v19.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x110]\n"
+ "fmla v20.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x120]\n"
+ "fmla v21.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x130]\n"
+ "fmla v22.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x140]\n"
+ "fmla v23.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x150]\n"
+ "fmla v24.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x160]\n"
+ "fmla v25.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x170]\n"
+ "add %[a_ptr], %[a_ptr], %[jump]\n"
+ "fmla v26.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x00]\n"
+ "fmla v27.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x10]\n"
+ "fmla v28.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x20]\n"
+ "fmla v29.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x30]\n"
+ "fmla v30.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x40]\n"
+ "fmla v31.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x50]\n"
+ "bne 2b\n"
"10:\n"
// Final iteration
- "dup v0.4s, w0\n"
- "fmla v8.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0x60]\n"
- "fmla v9.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0x70]\n"
- "fmla v10.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0x80]\n"
- "fmla v11.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0x90]\n"
- "fmla v12.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[a_ptr], #0xa0]\n"
- "fmla v13.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[a_ptr], #0xb0]\n"
- "fmla v14.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0xc0]\n"
- "fmla v15.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0xd0]\n"
- "fmla v16.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0xe0]\n"
- "fmla v17.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0xf0]\n"
- "fmla v18.4s, v6.4s, v0.4s\n"
+ "dup v0.4s, w0\n"
+ "fmla v8.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x60]\n"
+ "fmla v9.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x70]\n"
+ "fmla v10.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x80]\n"
+ "fmla v11.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x90]\n"
+ "fmla v12.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0xa0]\n"
+ "fmla v13.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0xb0]\n"
+ "fmla v14.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0xc0]\n"
+ "fmla v15.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0xd0]\n"
+ "fmla v16.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0xe0]\n"
+ "fmla v17.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0xf0]\n"
+ "fmla v18.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[a_ptr], #0x100]\n"
- "fmla v19.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[a_ptr], #0x110]\n"
- "fmla v20.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[a_ptr], #0x120]\n"
- "fmla v21.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[a_ptr], #0x130]\n"
- "fmla v22.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[a_ptr], #0x140]\n"
- "fmla v23.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[a_ptr], #0x150]\n"
- "fmla v24.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[a_ptr], #0x160]\n"
- "fmla v25.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[a_ptr], #0x170]\n"
- "fmla v26.4s, v2.4s, v0.4s\n"
- "ldr q2, [%[y_ptr]]\n"
- "fmla v27.4s, v3.4s, v0.4s\n"
- "ldr q3, [%[y_ptr], #0x10]\n"
- "fmla v28.4s, v4.4s, v0.4s\n"
- "ldr q4, [%[y_ptr], #0x20]\n"
- "fmla v29.4s, v5.4s, v0.4s\n"
- "ldr q5, [%[y_ptr], #0x30]\n"
- "fmla v30.4s, v6.4s, v0.4s\n"
- "ldr q6, [%[y_ptr], #0x40]\n"
- "fmla v31.4s, v7.4s, v0.4s\n"
- "ldr q7, [%[y_ptr], #0x50]\n"
+ "ldr q6, [%[a_ptr], #0x100]\n"
+ "fmla v19.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x110]\n"
+ "fmla v20.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x120]\n"
+ "fmla v21.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x130]\n"
+ "fmla v22.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x140]\n"
+ "fmla v23.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x150]\n"
+ "fmla v24.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x160]\n"
+ "fmla v25.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x170]\n"
+ "fmla v26.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[y_ptr]]\n"
+ "fmla v27.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[y_ptr], #0x10]\n"
+ "fmla v28.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[y_ptr], #0x20]\n"
+ "fmla v29.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[y_ptr], #0x30]\n"
+ "fmla v30.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[y_ptr], #0x40]\n"
+ "fmla v31.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[y_ptr], #0x50]\n"
- "fmla v8.4s, v2.4s, %[vb].4s\n"
- "ldr q2, [%[y_ptr], #0x60]\n"
- "fmla v9.4s, v3.4s, %[vb].4s\n"
- "ldr q3, [%[y_ptr], #0x70]\n"
- "fmla v10.4s, v4.4s, %[vb].4s\n"
- "ldr q4, [%[y_ptr], #0x80]\n"
- "fmla v11.4s, v5.4s, %[vb].4s\n"
- "ldr q5, [%[y_ptr], #0x90]\n"
- "fmla v12.4s, v6.4s, %[vb].4s\n"
- "ldr q6, [%[y_ptr], #0xa0]\n"
- "str q8, [%[y_ptr], #0x00]\n"
- "fmla v13.4s, v7.4s, %[vb].4s\n"
- "ldr q7, [%[y_ptr], #0xb0]\n"
- "str q9, [%[y_ptr], #0x10]\n"
- "fmla v14.4s, v2.4s, %[vb].4s\n"
- "ldr q2, [%[y_ptr], #0xc0]\n"
- "str q10, [%[y_ptr], #0x20]\n"
- "fmla v15.4s, v3.4s, %[vb].4s\n"
- "ldr q3, [%[y_ptr], #0xd0]\n"
- "str q11, [%[y_ptr], #0x30]\n"
- "fmla v16.4s, v4.4s, %[vb].4s\n"
- "ldr q4, [%[y_ptr], #0xe0]\n"
- "str q12, [%[y_ptr], #0x40]\n"
- "fmla v17.4s, v5.4s, %[vb].4s\n"
- "ldr q5, [%[y_ptr], #0xf0]\n"
- "str q13, [%[y_ptr], #0x50]\n"
- "fmla v18.4s, v6.4s, %[vb].4s\n"
- "ldr q6, [%[y_ptr], #0x100]\n"
- "str q14, [%[y_ptr], #0x60]\n"
- "fmla v19.4s, v7.4s, %[vb].4s\n"
- "ldr q7, [%[y_ptr], #0x110]\n"
- "str q15, [%[y_ptr], #0x70]\n"
- "fmla v20.4s, v2.4s, %[vb].4s\n"
- "ldr q2, [%[y_ptr], #0x120]\n"
- "str q16, [%[y_ptr], #0x80]\n"
- "fmla v21.4s, v3.4s, %[vb].4s\n"
- "ldr q3, [%[y_ptr], #0x130]\n"
- "str q17, [%[y_ptr], #0x90]\n"
- "fmla v22.4s, v4.4s, %[vb].4s\n"
- "ldr q4, [%[y_ptr], #0x140]\n"
- "str q18, [%[y_ptr], #0xa0]\n"
- "fmla v23.4s, v5.4s, %[vb].4s\n"
- "ldr q5, [%[y_ptr], #0x150]\n"
- "str q19, [%[y_ptr], #0xb0]\n"
- "fmla v24.4s, v6.4s, %[vb].4s\n"
- "ldr q6, [%[y_ptr], #0x160]\n"
- "str q20, [%[y_ptr], #0xc0]\n"
- "fmla v25.4s, v7.4s, %[vb].4s\n"
- "ldr q7, [%[y_ptr], #0x170]\n"
- "str q21, [%[y_ptr], #0xd0]\n"
- "fmla v26.4s, v2.4s, %[vb].4s\n"
- "str q22, [%[y_ptr], #0xe0]\n"
- "fmla v27.4s, v3.4s, %[vb].4s\n"
- "str q23, [%[y_ptr], #0xf0]\n"
- "fmla v28.4s, v4.4s, %[vb].4s\n"
- "str q24, [%[y_ptr], #0x100]\n"
- "fmla v29.4s, v5.4s, %[vb].4s\n"
- "str q25, [%[y_ptr], #0x110]\n"
- "fmla v30.4s, v6.4s, %[vb].4s\n"
- "str q26, [%[y_ptr], #0x120]\n"
- "fmla v31.4s, v7.4s, %[vb].4s\n"
- "str q27, [%[y_ptr], #0x130]\n"
+ "fmla v8.4s, v2.4s, %[vb].4s\n"
+ "ldr q2, [%[y_ptr], #0x60]\n"
+ "fmla v9.4s, v3.4s, %[vb].4s\n"
+ "ldr q3, [%[y_ptr], #0x70]\n"
+ "fmla v10.4s, v4.4s, %[vb].4s\n"
+ "ldr q4, [%[y_ptr], #0x80]\n"
+ "fmla v11.4s, v5.4s, %[vb].4s\n"
+ "ldr q5, [%[y_ptr], #0x90]\n"
+ "fmla v12.4s, v6.4s, %[vb].4s\n"
+ "ldr q6, [%[y_ptr], #0xa0]\n"
+ "str q8, [%[y_ptr], #0x00]\n"
+ "fmla v13.4s, v7.4s, %[vb].4s\n"
+ "ldr q7, [%[y_ptr], #0xb0]\n"
+ "str q9, [%[y_ptr], #0x10]\n"
+ "fmla v14.4s, v2.4s, %[vb].4s\n"
+ "ldr q2, [%[y_ptr], #0xc0]\n"
+ "str q10, [%[y_ptr], #0x20]\n"
+ "fmla v15.4s, v3.4s, %[vb].4s\n"
+ "ldr q3, [%[y_ptr], #0xd0]\n"
+ "str q11, [%[y_ptr], #0x30]\n"
+ "fmla v16.4s, v4.4s, %[vb].4s\n"
+ "ldr q4, [%[y_ptr], #0xe0]\n"
+ "str q12, [%[y_ptr], #0x40]\n"
+ "fmla v17.4s, v5.4s, %[vb].4s\n"
+ "ldr q5, [%[y_ptr], #0xf0]\n"
+ "str q13, [%[y_ptr], #0x50]\n"
+ "fmla v18.4s, v6.4s, %[vb].4s\n"
+ "ldr q6, [%[y_ptr], #0x100]\n"
+ "str q14, [%[y_ptr], #0x60]\n"
+ "fmla v19.4s, v7.4s, %[vb].4s\n"
+ "ldr q7, [%[y_ptr], #0x110]\n"
+ "str q15, [%[y_ptr], #0x70]\n"
+ "fmla v20.4s, v2.4s, %[vb].4s\n"
+ "ldr q2, [%[y_ptr], #0x120]\n"
+ "str q16, [%[y_ptr], #0x80]\n"
+ "fmla v21.4s, v3.4s, %[vb].4s\n"
+ "ldr q3, [%[y_ptr], #0x130]\n"
+ "str q17, [%[y_ptr], #0x90]\n"
+ "fmla v22.4s, v4.4s, %[vb].4s\n"
+ "ldr q4, [%[y_ptr], #0x140]\n"
+ "str q18, [%[y_ptr], #0xa0]\n"
+ "fmla v23.4s, v5.4s, %[vb].4s\n"
+ "ldr q5, [%[y_ptr], #0x150]\n"
+ "str q19, [%[y_ptr], #0xb0]\n"
+ "fmla v24.4s, v6.4s, %[vb].4s\n"
+ "ldr q6, [%[y_ptr], #0x160]\n"
+ "str q20, [%[y_ptr], #0xc0]\n"
+ "fmla v25.4s, v7.4s, %[vb].4s\n"
+ "ldr q7, [%[y_ptr], #0x170]\n"
+ "str q21, [%[y_ptr], #0xd0]\n"
+ "fmla v26.4s, v2.4s, %[vb].4s\n"
+ "str q22, [%[y_ptr], #0xe0]\n"
+ "fmla v27.4s, v3.4s, %[vb].4s\n"
+ "str q23, [%[y_ptr], #0xf0]\n"
+ "fmla v28.4s, v4.4s, %[vb].4s\n"
+ "str q24, [%[y_ptr], #0x100]\n"
+ "fmla v29.4s, v5.4s, %[vb].4s\n"
+ "str q25, [%[y_ptr], #0x110]\n"
+ "fmla v30.4s, v6.4s, %[vb].4s\n"
+ "str q26, [%[y_ptr], #0x120]\n"
+ "fmla v31.4s, v7.4s, %[vb].4s\n"
+ "str q27, [%[y_ptr], #0x130]\n"
- "stp q28, q29, [%[y_ptr], #0x140]\n"
- "stp q30, q31, [%[y_ptr], #0x160]\n"
- "add %[y_ptr], %[y_ptr], #0x180\n"
+ "stp q28, q29, [%[y_ptr], #0x140]\n"
+ "stp q30, q31, [%[y_ptr], #0x160]\n"
+ "add %[y_ptr], %[y_ptr], #0x180\n"
- : [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), [y_ptr] "+r"(y_ptr), [k] "+r"(k), [pf_ptr] "+r"(pf_ptr), [firstpf_ptr] "+r"(firstpf_ptr)
- : [jump] "r"(jump), [vb] "w"(vb), [pf_limit] "r"(pf_limit)
- : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
+ : [a_ptr] "+r" (a_ptr), [x_ptr] "+r" (x_ptr), [y_ptr] "+r" (y_ptr), [k] "+r" (k), [pf_ptr] "+r" (pf_ptr), [firstpf_ptr] "+r" (firstpf_ptr)
+ : [jump] "r" (jump), [vb] "w" (vb), [pf_limit] "r" (pf_limit)
+ : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
- "v27", "v28", "v29", "v30", "v31", "cc");
+ "v27", "v28", "v29", "v30", "v31", "cc"
+ );
}
- if(N > 0)
- {
+ if (N>0) {
// Handle N tail - up to 95 stragglers.
// This is 0-23 vectors, plus optionally an 64-bit vector and/or a
// single value for the remainder.
// Independent pointers into the matrix for the odd 2 and odd 1.
// Double up as flag to indicate whether they are needed.
- const float *odd2_aptr = NULL;
- const float *odd1_aptr = NULL;
+ const float *odd2_aptr=NULL;
+ const float *odd1_aptr=NULL;
// Figure out how much work we need to do.
- int numvecs = N / 4;
- int rem = N % 4;
- int k = M;
+ int numvecs = N/4;
+ int rem = N%4;
+ int k=M;
// Set up pointers for the odd 2/1 if needed.
- if(rem >= 2)
- {
+ if (rem >= 2) {
odd2_aptr = a_ptr_base + (numvecs * 4);
}
- if(rem & 1)
- {
- odd1_aptr = a_ptr_base + (numvecs * 4) + (odd2_aptr == NULL ? 0 : 2);
+ if (rem & 1) {
+ odd1_aptr = a_ptr_base + (numvecs * 4) + (odd2_aptr==NULL ? 0 : 2);
}
- const float *a_ptr = a_ptr_base;
+ const float *a_ptr = a_ptr_base;
const float *firstpf_ptr = a_ptr_base;
- const float *pf_ptr = a_ptr_base;
- const float *pf_limit = a_ptr + (M * lda);
+ const float *pf_ptr = a_ptr_base;
+ const float *pf_limit = a_ptr + (M * lda);
const float *x_ptr = Xstart;
- int vecs = 0; // Working variable to count how many vectors to work on.
- int dopf = 1; // Track whether we are doing prefetches.
+ int vecs=0; // Working variable to count how many vectors to work on.
+ int dopf=1; // Track whether we are doing prefetches.
// Figure out how many cache lines we need to prefetch each time.
int numpfs = (N + 15) / 16;
// Do initial prefetches
- for(int i = 0; i < firstpfd + 1; i++)
- {
+ for (int i=0; i<firstpfd+1; i++) {
prefetch_1x(firstpf_ptr);
firstpf_ptr += lda;
}
// Do "main" prefetches - adapt number to the number we actually need.
- if(numpfs > 1)
- {
- for(int i = 0; i < pfd + 1; i++)
- {
- switch(numpfs)
- {
+ if (numpfs > 1) {
+ for (int i=0; i<pfd+1; i++) {
+ switch (numpfs) {
case 2:
prefetch_1x(pf_ptr + 16);
break;
@@ -525,387 +533,392 @@
}
pf_ptr += lda;
}
- }
- else
- {
+ } else {
// Just disable additional prefetches
- dopf = 0;
+ dopf=0;
}
// Do the real work
- __asm __volatile(
+ __asm __volatile (
// Initialize all the vectors - not worth skipping this if only
// some are needed.
- "movi v8.4s,#0x0\n"
- "ldr w0, [%[x_ptr]]\n"
- "movi v9.4s,#0x0\n"
- "movi v10.4s,#0x0\n"
- "movi v11.4s,#0x0\n"
- "movi v12.4s,#0x0\n"
- "movi v13.4s,#0x0\n"
- "movi v14.4s,#0x0\n"
- "movi v15.4s,#0x0\n"
- "movi v16.4s, #0x0\n"
- "movi v17.4s, #0x0\n"
- "movi v18.4s, #0x0\n"
- "movi v19.4s, #0x0\n"
- "movi v20.4s, #0x0\n"
- "movi v21.4s, #0x0\n"
- "movi v22.4s, #0x0\n"
- "movi v23.4s, #0x0\n"
- "movi v24.4s, #0x0\n"
- "movi v25.4s, #0x0\n"
- "movi v26.4s, #0x0\n"
- "movi v27.4s, #0x0\n"
- "movi v28.4s, #0x0\n"
- "movi v29.4s, #0x0\n"
- "movi v30.4s, #0x0\n"
- "movi v6.2s, #0x0\n"
- "movi v5.2s, #0x0\n"
+ "movi v8.4s,#0x0\n"
+ "ldr w0, [%[x_ptr]]\n"
+ "movi v9.4s,#0x0\n"
+ "movi v10.4s,#0x0\n"
+ "movi v11.4s,#0x0\n"
+ "movi v12.4s,#0x0\n"
+ "movi v13.4s,#0x0\n"
+ "movi v14.4s,#0x0\n"
+ "movi v15.4s,#0x0\n"
+ "movi v16.4s, #0x0\n"
+ "movi v17.4s, #0x0\n"
+ "movi v18.4s, #0x0\n"
+ "movi v19.4s, #0x0\n"
+ "movi v20.4s, #0x0\n"
+ "movi v21.4s, #0x0\n"
+ "movi v22.4s, #0x0\n"
+ "movi v23.4s, #0x0\n"
+ "movi v24.4s, #0x0\n"
+ "movi v25.4s, #0x0\n"
+ "movi v26.4s, #0x0\n"
+ "movi v27.4s, #0x0\n"
+ "movi v28.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ "movi v30.4s, #0x0\n"
+ "movi v6.2s, #0x0\n"
+ "movi v5.2s, #0x0\n"
- "1:\n" ASM_PREFETCH("[%[firstpf_ptr]]\n")
+ "1:\n"
+ ASM_PREFETCH("[%[firstpf_ptr]]\n")
"11:\n"
- "dup v0.4s, w0\n"
- "ldr w0, [%[x_ptr], #4]\n"
- "add %[x_ptr], %[x_ptr], #4\n"
+ "dup v0.4s, w0\n"
+ "ldr w0, [%[x_ptr], #4]\n"
+ "add %[x_ptr], %[x_ptr], #4\n"
- "cbz %w[numvecs], 2f\n"
- "mov %w[vecs], %w[numvecs]\n"
+ "cbz %w[numvecs], 2f\n"
+ "mov %w[vecs], %w[numvecs]\n"
// Vector 0
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x00]\n"
- "fmla v8.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x00]\n"
+ "fmla v8.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 1
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x10]\n"
- "fmla v9.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x10]\n"
+ "fmla v9.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 2
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x20]\n"
- "fmla v10.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x20]\n"
+ "fmla v10.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 3
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x30]\n"
- "fmla v11.4s, v7.4s, v0.4s\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x30]\n"
+ "fmla v11.4s, v7.4s, v0.4s\n"
// Prefetch
- "cbz %w[dopf], 3f\n" ASM_PREFETCH("[%[pf_ptr], #0x40]")
+ "cbz %w[dopf], 3f\n"
+ ASM_PREFETCH("[%[pf_ptr], #0x40]")
"3:\n"
- "beq 2f\n"
+ "beq 2f\n"
// Vector 4
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x40]\n"
- "fmla v12.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x40]\n"
+ "fmla v12.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 5
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x50]\n"
- "fmla v13.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x50]\n"
+ "fmla v13.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 6
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x60]\n"
- "fmla v14.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x60]\n"
+ "fmla v14.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 7
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x70]\n"
- "fmla v15.4s, v7.4s, v0.4s\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x70]\n"
+ "fmla v15.4s, v7.4s, v0.4s\n"
// Prefetch
- "cbz %w[dopf], 4f\n" ASM_PREFETCH("[%[pf_ptr], #0x80]")
+ "cbz %w[dopf], 4f\n"
+ ASM_PREFETCH("[%[pf_ptr], #0x80]")
"4:\n"
- "beq 2f\n"
+ "beq 2f\n"
// Vector 8
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x80]\n"
- "fmla v16.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x80]\n"
+ "fmla v16.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 9
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x90]\n"
- "fmla v17.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x90]\n"
+ "fmla v17.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 10
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0xa0]\n"
- "fmla v18.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0xa0]\n"
+ "fmla v18.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 11
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0xb0]\n"
- "fmla v19.4s, v7.4s, v0.4s\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0xb0]\n"
+ "fmla v19.4s, v7.4s, v0.4s\n"
// Prefetch
- "cbz %w[dopf], 5f\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]")
+ "cbz %w[dopf], 5f\n"
+ ASM_PREFETCH("[%[pf_ptr], #0xc0]")
"5:\n"
- "beq 2f\n"
+ "beq 2f\n"
// Vector 12
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0xc0]\n"
- "fmla v20.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0xc0]\n"
+ "fmla v20.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 13
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0xd0]\n"
- "fmla v21.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0xd0]\n"
+ "fmla v21.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 14
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0xe0]\n"
- "fmla v22.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0xe0]\n"
+ "fmla v22.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 15
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0xf0]\n"
- "fmla v23.4s, v7.4s, v0.4s\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0xf0]\n"
+ "fmla v23.4s, v7.4s, v0.4s\n"
// Prefetch
- "cbz %w[dopf], 6f\n" ASM_PREFETCH("[%[pf_ptr], #0x100]")
+ "cbz %w[dopf], 6f\n"
+ ASM_PREFETCH("[%[pf_ptr], #0x100]")
"6:\n"
- "beq 2f\n"
+ "beq 2f\n"
// Vector 16
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x100]\n"
- "fmla v24.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x100]\n"
+ "fmla v24.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 17
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x110]\n"
- "fmla v25.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x110]\n"
+ "fmla v25.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 18
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x120]\n"
- "fmla v26.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x120]\n"
+ "fmla v26.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 19
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x130]\n"
- "fmla v27.4s, v7.4s, v0.4s\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x130]\n"
+ "fmla v27.4s, v7.4s, v0.4s\n"
// Prefetch
- "cbz %w[dopf], 7f\n" ASM_PREFETCH("[%[pf_ptr], #0x140]")
+ "cbz %w[dopf], 7f\n"
+ ASM_PREFETCH("[%[pf_ptr], #0x140]")
"7:\n"
- "beq 2f\n"
+ "beq 2f\n"
// Vector 20
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x140]\n"
- "fmla v28.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x140]\n"
+ "fmla v28.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 21
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x150]\n"
- "fmla v29.4s, v7.4s, v0.4s\n"
- "beq 2f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x150]\n"
+ "fmla v29.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
// Vector 22
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7,[%[a_ptr], #0x160]\n"
- "fmla v30.4s, v7.4s, v0.4s\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x160]\n"
+ "fmla v30.4s, v7.4s, v0.4s\n"
"2:\n"
- "add %[a_ptr], %[a_ptr], %[jump]\n"
+ "add %[a_ptr], %[a_ptr], %[jump]\n"
// Do the odd 2-vector, if needed
- "cbz %[odd2_aptr], 8f\n"
- "ldr d7, [%[odd2_aptr]]\n"
- "fmla v6.2s, v7.2s, v0.2s\n"
- "add %[odd2_aptr], %[odd2_aptr], %[jump]\n"
+ "cbz %[odd2_aptr], 8f\n"
+ "ldr d7, [%[odd2_aptr]]\n"
+ "fmla v6.2s, v7.2s, v0.2s\n"
+ "add %[odd2_aptr], %[odd2_aptr], %[jump]\n"
"8:\n"
// Do the odd 1-vector, if needed
- "cbz %[odd1_aptr], 9f\n"
- "ldr s7, [%[odd1_aptr]]\n"
- "fmla v5.2s, v7.2s, v0.2s\n"
- "add %[odd1_aptr], %[odd1_aptr], %[jump]\n"
+ "cbz %[odd1_aptr], 9f\n"
+ "ldr s7, [%[odd1_aptr]]\n"
+ "fmla v5.2s, v7.2s, v0.2s\n"
+ "add %[odd1_aptr], %[odd1_aptr], %[jump]\n"
// Get out if needed.
"9:\n"
- "subs %w[k], %w[k], #1\n"
- "beq 10f\n"
+ "subs %w[k], %w[k], #1\n"
+ "beq 10f\n"
// Update the "main" prefetch pointer, if it strays beyond the limit turn off "dopf"
- "add %[pf_ptr], %[pf_ptr], %[jump]\n"
- "cmp %[pf_ptr], %[pf_limit]\n"
- "csel %w[dopf], %w[dopf], WZR, LT\n"
+ "add %[pf_ptr], %[pf_ptr], %[jump]\n"
+ "cmp %[pf_ptr], %[pf_limit]\n"
+ "csel %w[dopf], %w[dopf], WZR, LT\n"
// Update the "leading" prefetch pointer, don't do the first
// instruction of the loop if it's over the limit.
- "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n"
- "cmp %[firstpf_ptr], %[pf_limit]\n"
- "blt 1b\n"
- "b 11b\n"
+ "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n"
+ "cmp %[firstpf_ptr], %[pf_limit]\n"
+ "blt 1b\n"
+ "b 11b\n"
// Now write out the outputs
"10:\n"
- "cbz %w[numvecs], 12f\n"
- "mov %w[vecs], %w[numvecs]\n"
+ "cbz %w[numvecs], 12f\n"
+ "mov %w[vecs], %w[numvecs]\n"
// Vector 0
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v8.4s, v7.4s, %[vb].4s\n"
- "str q8, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v8.4s, v7.4s, %[vb].4s\n"
+ "str q8, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 1
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v9.4s, v7.4s, %[vb].4s\n"
- "str q9, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v9.4s, v7.4s, %[vb].4s\n"
+ "str q9, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 2
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v10.4s, v7.4s, %[vb].4s\n"
- "str q10, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v10.4s, v7.4s, %[vb].4s\n"
+ "str q10, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 3
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v11.4s, v7.4s, %[vb].4s\n"
- "str q11, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v11.4s, v7.4s, %[vb].4s\n"
+ "str q11, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 4
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v12.4s, v7.4s, %[vb].4s\n"
- "str q12, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v12.4s, v7.4s, %[vb].4s\n"
+ "str q12, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 5
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v13.4s, v7.4s, %[vb].4s\n"
- "str q13, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v13.4s, v7.4s, %[vb].4s\n"
+ "str q13, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 6
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v14.4s, v7.4s, %[vb].4s\n"
- "str q14, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v14.4s, v7.4s, %[vb].4s\n"
+ "str q14, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 7
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v15.4s, v7.4s, %[vb].4s\n"
- "str q15, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v15.4s, v7.4s, %[vb].4s\n"
+ "str q15, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 8
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v16.4s, v7.4s, %[vb].4s\n"
- "str q16, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v16.4s, v7.4s, %[vb].4s\n"
+ "str q16, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 9
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v17.4s, v7.4s, %[vb].4s\n"
- "str q17, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v17.4s, v7.4s, %[vb].4s\n"
+ "str q17, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 10
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v18.4s, v7.4s, %[vb].4s\n"
- "str q18, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v18.4s, v7.4s, %[vb].4s\n"
+ "str q18, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 11
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v19.4s, v7.4s, %[vb].4s\n"
- "str q19, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v19.4s, v7.4s, %[vb].4s\n"
+ "str q19, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 12
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v20.4s, v7.4s, %[vb].4s\n"
- "str q20, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v20.4s, v7.4s, %[vb].4s\n"
+ "str q20, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 13
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v21.4s, v7.4s, %[vb].4s\n"
- "str q21, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v21.4s, v7.4s, %[vb].4s\n"
+ "str q21, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 14
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v22.4s, v7.4s, %[vb].4s\n"
- "str q22, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v22.4s, v7.4s, %[vb].4s\n"
+ "str q22, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 15
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v23.4s, v7.4s, %[vb].4s\n"
- "str q23, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v23.4s, v7.4s, %[vb].4s\n"
+ "str q23, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 16
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v24.4s, v7.4s, %[vb].4s\n"
- "str q24, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v24.4s, v7.4s, %[vb].4s\n"
+ "str q24, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 17
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v25.4s, v7.4s, %[vb].4s\n"
- "str q25, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v25.4s, v7.4s, %[vb].4s\n"
+ "str q25, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 18
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v26.4s, v7.4s, %[vb].4s\n"
- "str q26, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v26.4s, v7.4s, %[vb].4s\n"
+ "str q26, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 19
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v27.4s, v7.4s, %[vb].4s\n"
- "str q27, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v27.4s, v7.4s, %[vb].4s\n"
+ "str q27, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 20
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v28.4s, v7.4s, %[vb].4s\n"
- "str q28, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v28.4s, v7.4s, %[vb].4s\n"
+ "str q28, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 21
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v29.4s, v7.4s, %[vb].4s\n"
- "str q29, [%[y_ptr]], #0x10\n"
- "beq 12f\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v29.4s, v7.4s, %[vb].4s\n"
+ "str q29, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
// Vector 22
- "subs %w[vecs], %w[vecs], #1\n"
- "ldr q7, [%[y_ptr]]\n"
- "fmla v30.4s, v7.4s, %[vb].4s\n"
- "str q30, [%[y_ptr]], #0x10\n"
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v30.4s, v7.4s, %[vb].4s\n"
+ "str q30, [%[y_ptr]], #0x10\n"
// Odd 2
"12:\n"
- "cbz %[odd2_aptr], 13f\n"
- "ldr d7, [%[y_ptr]]\n"
- "fmla v6.2s, v7.2s, %[vb].2s\n"
- "str d6, [%[y_ptr]], #0x8\n"
+ "cbz %[odd2_aptr], 13f\n"
+ "ldr d7, [%[y_ptr]]\n"
+ "fmla v6.2s, v7.2s, %[vb].2s\n"
+ "str d6, [%[y_ptr]], #0x8\n"
// Odd 1
"13:\n"
- "cbz %[odd1_aptr], 14f\n"
- "ldr s7, [%[y_ptr]]\n"
- "fmla v5.2s, v7.2s, %[vb].2s\n"
- "str s5, [%[y_ptr]]\n"
+ "cbz %[odd1_aptr], 14f\n"
+ "ldr s7, [%[y_ptr]]\n"
+ "fmla v5.2s, v7.2s, %[vb].2s\n"
+ "str s5, [%[y_ptr]]\n"
"14:\n"
- : [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), [y_ptr] "+r"(y_ptr), [k] "+r"(k),
- [pf_ptr] "+r"(pf_ptr), [firstpf_ptr] "+r"(firstpf_ptr),
- [odd1_aptr] "+r"(odd1_aptr), [odd2_aptr] "+r"(odd2_aptr),
- [dopf] "+r"(dopf), [vecs] "+r"(vecs)
- : [jump] "r"(jump), [vb] "w"(vb), [pf_limit] "r"(pf_limit), [numvecs] "r"(numvecs)
- : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
+ : [a_ptr] "+r" (a_ptr), [x_ptr] "+r" (x_ptr), [y_ptr] "+r" (y_ptr), [k] "+r" (k),
+ [pf_ptr] "+r" (pf_ptr), [firstpf_ptr] "+r" (firstpf_ptr),
+ [odd1_aptr] "+r" (odd1_aptr), [odd2_aptr] "+r" (odd2_aptr),
+ [dopf] "+r" (dopf), [vecs] "+r" (vecs)
+ : [jump] "r" (jump), [vb] "w" (vb), [pf_limit] "r" (pf_limit), [numvecs] "r" (numvecs)
+ : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
- "v27", "v28", "v29", "v30", "v31", "cc");
+ "v27", "v28", "v29", "v30", "v31", "cc"
+ );
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/mergeresults.hpp b/src/core/NEON/kernels/arm_gemm/mergeresults.hpp
index 4a6da3d..04d1343 100644
--- a/src/core/NEON/kernels/arm_gemm/mergeresults.hpp
+++ b/src/core/NEON/kernels/arm_gemm/mergeresults.hpp
@@ -30,38 +30,43 @@
#include "asmlib.hpp"
#include "utils.hpp"
-namespace arm_gemm
-{
-template <unsigned int width, unsigned int height, typename Tin, typename Tout>
-inline void MergeResults(Tout *out, const Tin *in, int ldc, int y0, int ymax, int x0, int xmax, const Tout alpha, const Tout beta)
-{
- int full_y_blocks = (ymax - y0) / height;
- int y_remainder = (ymax - y0) % height;
- int y_blocks = full_y_blocks + (y_remainder ? 1 : 0);
+namespace arm_gemm {
- int full_x_blocks = (xmax - x0) / width;
- int x_remainder = (xmax - x0) % width;
- int x_blocks = full_x_blocks + (x_remainder ? 1 : 0);
+template<unsigned int twidth, unsigned int height, bool sve=false, typename Tin, typename Tout>
+inline void MergeResults(Tout * out, const Tin * in, int ldc, int y0, int ymax, int x0, int xmax, const Tout alpha, const Tout beta) {
+ // For SVE cases, multiply the width up by the vector length.
+ // Use the *input* type to determine this, since this will be what the kernel operated on.
+ const int width = twidth * (sve ? get_vector_length<Tin>() : 1);
- for(int y_block = 0; y_block < y_blocks; y_block++)
- {
+ const int full_y_blocks = (ymax - y0) / height;
+ const int y_remainder = (ymax - y0) % height;
+ const int y_blocks = full_y_blocks + (y_remainder ? 1 : 0);
+
+ const int full_x_blocks = (xmax - x0) / width;
+ const int x_remainder = (xmax - x0) % width;
+ const int x_blocks = full_x_blocks + (x_remainder ? 1 : 0);
+
+ for (int y_block = 0; y_block < y_blocks; y_block++) {
int ybase = y0 + (y_block * height);
int fill_rows = (y_block < full_y_blocks) ? height : y_remainder;
- for(int x_block = 0; x_block < x_blocks; x_block++)
- {
+ for (int x_block = 0; x_block < x_blocks; x_block++) {
int xbase = x0 + (x_block * width);
int fill_cols = (x_block < full_x_blocks) ? width : x_remainder;
- for(int row = 0; row < fill_rows; row++)
- {
- for(int col = 0; col < fill_cols; col++)
- {
+ for (int row=0; row < fill_rows; row++) {
+ for (int col=0; col < fill_cols; col++) {
Tout &p = out[(ybase + row) * ldc + xbase + col];
- p = (p * beta) + (alpha * in[row * width + col]);
+ // Special case for beta==0 - don't read the input;
+ // (0 * x == 0) is not always true for FP types.
+ if (beta == static_cast<Tout>(0)) {
+ p = (alpha * in[row * width + col]);
+ } else {
+ p = (p * beta) + (alpha * in[row * width + col]);
+ }
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
index b44e564..f4485bc 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
@@ -27,9 +27,8 @@
#include <arm_neon.h>
-template <>
-inline void MergeResults<8, 6>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float alpha, const float beta)
-{
+template<>
+inline void MergeResults<8, 6, false>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float alpha, const float beta) {
const float *inptr = in;
prefetch_6x(inptr);
prefetch_6x(inptr + 96);
@@ -37,8 +36,7 @@
float32x4_t av = vdupq_n_f32(alpha);
float32x4_t bv = vdupq_n_f32(beta);
- for(int y = y0; y < ymax; y += 8)
- {
+ for (int y=y0; y<ymax; y+=8) {
float *outptr0 = out + (y * ldout) + x0;
float *outptr1 = outptr0 + ldout;
float *outptr2 = outptr1 + ldout;
@@ -53,17 +51,14 @@
prefetch_2x(outptr4);
prefetch_2x(outptr5);
- for(int i = x0; i < xmax; i += 8)
- {
+ for (int i=x0; i<xmax; i+=8) {
float dummyres[8];
/* Make sure we throw away results if Y isn't a multiple of 8.
* We do this by pointing the result pointer at a dummy buffer
* we later discard. */
- if((y + 5) >= ymax)
- {
- switch((y + 5) - ymax)
- {
+ if ((y+5) >= ymax) {
+ switch ((y + 5) - ymax) {
case 4:
outptr1 = dummyres;
case 3:
@@ -81,84 +76,168 @@
}
}
- /* For ragged X, manually copy over the valid results. */
- if((i + 7) >= xmax)
- {
- for(int xi = 0; xi < 8; xi++)
- {
- if((i + xi) < xmax)
- {
- *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta);
- outptr0++;
- *outptr1 = (alpha * inptr[xi + 8]) + (*outptr1 * beta);
- outptr1++;
- *outptr2 = (alpha * inptr[xi + 16]) + (*outptr2 * beta);
- outptr2++;
- *outptr3 = (alpha * inptr[xi + 24]) + (*outptr3 * beta);
- outptr3++;
- *outptr4 = (alpha * inptr[xi + 32]) + (*outptr4 * beta);
- outptr4++;
- *outptr5 = (alpha * inptr[xi + 40]) + (*outptr5 * beta);
- outptr5++;
+ if (beta == 0.0f) {
+ /* If beta=0, don't read the original input at all. */
+
+ /* For ragged X, manually copy over the valid results. */
+ if ((i+7) >= xmax) {
+ for (int xi=0; xi<8; xi++) {
+ if ((i+xi) < xmax) {
+ *outptr0 = (alpha * inptr[xi]);
+ outptr0++;
+ *outptr1 = (alpha * inptr[xi + 8]);
+ outptr1++;
+ *outptr2 = (alpha * inptr[xi + 16]);
+ outptr2++;
+ *outptr3 = (alpha * inptr[xi + 24]);
+ outptr3++;
+ *outptr4 = (alpha * inptr[xi + 32]);
+ outptr4++;
+ *outptr5 = (alpha * inptr[xi + 40]);
+ outptr5++;
+ }
}
+ inptr += 48;
+ } else {
+ /* Optimized routine to copy an entire block */
+ __asm __volatile (
+ // Rows 0-1
+ "VLD1.32 {d0-d3}, [%[inptr]]!\n"
+ "VLD1.32 {d4-d7}, [%[inptr]]!\n"
+
+ "VMUL.f32 q4, q0, %q[av]\n"
+ ASM_PREFETCH("[%[inptr], #352]")
+ "VMUL.f32 q5, q1, %q[av]\n"
+ "VST1.32 {d8-d11}, [%[outptr0]]!\n"
+ ASM_PREFETCH("[%[inptr], #416]")
+ "VMUL.f32 q6, q2, %q[av]\n"
+ ASM_PREFETCH("[%[inptr], #480]")
+ "VMUL.f32 q7, q3, %q[av]\n"
+ "VST1.32 {d12-d15}, [%[outptr1]]!\n"
+
+ // Rows 2-3
+ "VLD1.32 {d0-d3}, [%[inptr]]!\n"
+ "VLD1.32 {d4-d7}, [%[inptr]]!\n"
+
+ "VMUL.f32 q4, q0, %q[av]\n"
+ ASM_PREFETCH("[%[outptr0], #96]")
+ "VMUL.f32 q5, q1, %q[av]\n"
+ "VST1.32 {d8-d11}, [%[outptr2]]!\n"
+ ASM_PREFETCH("[%[outptr1], #96]")
+ "VMUL.f32 q6, q2, %q[av]\n"
+ ASM_PREFETCH("[%[outptr2], #96]")
+ "VMUL.f32 q7, q3, %q[av]\n"
+ "VST1.32 {d12-d15}, [%[outptr3]]!\n"
+
+ // Rows 4-5
+ "VLD1.32 {d0-d3}, [%[inptr]]!\n"
+ "VLD1.32 {d4-d7}, [%[inptr]]!\n"
+
+ "VMUL.f32 q4, q0, %q[av]\n"
+ ASM_PREFETCH("[%[outptr3], #96]")
+ "VMUL.f32 q5, q1, %q[av]\n"
+ "VST1.32 {d8-d11}, [%[outptr4]]!\n"
+ ASM_PREFETCH("[%[outptr4], #96]")
+ "VMUL.f32 q6, q2, %q[av]\n"
+ ASM_PREFETCH("[%[outptr5], #128]")
+ "VMUL.f32 q7, q3, %q[av]\n"
+ "VST1.32 {d12-d15}, [%[outptr5]]!\n"
+ : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+ [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [inptr] "+r" (inptr)
+ : [av] "w" (av), [bv] "w" (bv)
+ : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"
+ );
}
- inptr += 48;
- }
- else
- {
- /* Optimized routine to copy an entire block */
- __asm __volatile(
- // Rows 0-1
- "VLD1.32 {d8-d11}, [%[outptr0]]\n"
- "VMUL.f32 q4, q4, %q[bv]\n"
- "VLD1.32 {d12-d15}, [%[outptr1]]\n"
- "VMUL.f32 q5, q5, %q[bv]\n"
- "VLD1.32 {d0-d3}, [%[inptr]]!\n"
- "VMUL.f32 q6, q6, %q[bv]\n"
- "VLD1.32 {d4-d7}, [%[inptr]]!\n"
- "VMUL.f32 q7, q7, %q[bv]\n"
+ } else {
+ /* Non-zero beta: Read output and apply beta. */
- "VMLA.f32 q4, q0, %q[av]\n" ASM_PREFETCH("[%[inptr], #352]")
- "VMLA.f32 q5, q1, %q[av]\n"
- "VST1.32 {d8-d11}, [%[outptr0]]!\n" ASM_PREFETCH("[%[inptr], #416]") "VMLA.f32 q6, q2, %q[av]\n" ASM_PREFETCH("[%[inptr], #480]")
- "VMLA.f32 q7, q3, %q[av]\n"
- "VST1.32 {d12-d15}, [%[outptr1]]!\n"
+ /* For ragged X, manually copy over the valid results. */
+ if ((i+7) >= xmax) {
+ for (int xi=0; xi<8; xi++) {
+ if ((i+xi) < xmax) {
+ *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta);
+ outptr0++;
+ *outptr1 = (alpha * inptr[xi + 8]) + (*outptr1 * beta);
+ outptr1++;
+ *outptr2 = (alpha * inptr[xi + 16]) + (*outptr2 * beta);
+ outptr2++;
+ *outptr3 = (alpha * inptr[xi + 24]) + (*outptr3 * beta);
+ outptr3++;
+ *outptr4 = (alpha * inptr[xi + 32]) + (*outptr4 * beta);
+ outptr4++;
+ *outptr5 = (alpha * inptr[xi + 40]) + (*outptr5 * beta);
+ outptr5++;
+ }
+ }
+ inptr += 48;
+ } else {
+ /* Optimized routine to copy an entire block */
+ __asm __volatile (
+ // Rows 0-1
+ "VLD1.32 {d8-d11}, [%[outptr0]]\n"
+ "VMUL.f32 q4, q4, %q[bv]\n"
+ "VLD1.32 {d12-d15}, [%[outptr1]]\n"
+ "VMUL.f32 q5, q5, %q[bv]\n"
+ "VLD1.32 {d0-d3}, [%[inptr]]!\n"
+ "VMUL.f32 q6, q6, %q[bv]\n"
+ "VLD1.32 {d4-d7}, [%[inptr]]!\n"
+ "VMUL.f32 q7, q7, %q[bv]\n"
- // Rows 2-3
- "VLD1.32 {d8-d11}, [%[outptr2]]\n"
- "VMUL.f32 q4, q4, %q[bv]\n"
- "VLD1.32 {d12-d15}, [%[outptr3]]\n"
- "VMUL.f32 q5, q5, %q[bv]\n"
- "VLD1.32 {d0-d3}, [%[inptr]]!\n"
- "VMUL.f32 q6, q6, %q[bv]\n"
- "VLD1.32 {d4-d7}, [%[inptr]]!\n"
- "VMUL.f32 q7, q7, %q[bv]\n"
+ "VMLA.f32 q4, q0, %q[av]\n"
+ ASM_PREFETCH("[%[inptr], #352]")
+ "VMLA.f32 q5, q1, %q[av]\n"
+ "VST1.32 {d8-d11}, [%[outptr0]]!\n"
+ ASM_PREFETCH("[%[inptr], #416]")
+ "VMLA.f32 q6, q2, %q[av]\n"
+ ASM_PREFETCH("[%[inptr], #480]")
+ "VMLA.f32 q7, q3, %q[av]\n"
+ "VST1.32 {d12-d15}, [%[outptr1]]!\n"
- "VMLA.f32 q4, q0, %q[av]\n" ASM_PREFETCH("[%[outptr0], #96]")
- "VMLA.f32 q5, q1, %q[av]\n"
- "VST1.32 {d8-d11}, [%[outptr2]]!\n" ASM_PREFETCH("[%[outptr1], #96]") "VMLA.f32 q6, q2, %q[av]\n" ASM_PREFETCH("[%[outptr2], #96]")
- "VMLA.f32 q7, q3, %q[av]\n"
- "VST1.32 {d12-d15}, [%[outptr3]]!\n"
+ // Rows 2-3
+ "VLD1.32 {d8-d11}, [%[outptr2]]\n"
+ "VMUL.f32 q4, q4, %q[bv]\n"
+ "VLD1.32 {d12-d15}, [%[outptr3]]\n"
+ "VMUL.f32 q5, q5, %q[bv]\n"
+ "VLD1.32 {d0-d3}, [%[inptr]]!\n"
+ "VMUL.f32 q6, q6, %q[bv]\n"
+ "VLD1.32 {d4-d7}, [%[inptr]]!\n"
+ "VMUL.f32 q7, q7, %q[bv]\n"
- // Rows 4-5
- "VLD1.32 {d8-d11}, [%[outptr4]]\n"
- "VMUL.f32 q4, q4, %q[bv]\n"
- "VLD1.32 {d12-d15}, [%[outptr5]]\n"
- "VMUL.f32 q5, q5, %q[bv]\n"
- "VLD1.32 {d0-d3}, [%[inptr]]!\n"
- "VMUL.f32 q6, q6, %q[bv]\n"
- "VLD1.32 {d4-d7}, [%[inptr]]!\n"
- "VMUL.f32 q7, q7, %q[bv]\n"
+ "VMLA.f32 q4, q0, %q[av]\n"
+ ASM_PREFETCH("[%[outptr0], #96]")
+ "VMLA.f32 q5, q1, %q[av]\n"
+ "VST1.32 {d8-d11}, [%[outptr2]]!\n"
+ ASM_PREFETCH("[%[outptr1], #96]")
+ "VMLA.f32 q6, q2, %q[av]\n"
+ ASM_PREFETCH("[%[outptr2], #96]")
+ "VMLA.f32 q7, q3, %q[av]\n"
+ "VST1.32 {d12-d15}, [%[outptr3]]!\n"
- "VMLA.f32 q4, q0, %q[av]\n" ASM_PREFETCH("[%[outptr3], #96]")
- "VMLA.f32 q5, q1, %q[av]\n"
- "VST1.32 {d8-d11}, [%[outptr4]]!\n" ASM_PREFETCH("[%[outptr4], #96]") "VMLA.f32 q6, q2, %q[av]\n" ASM_PREFETCH("[%[outptr5], #128]")
- "VMLA.f32 q7, q3, %q[av]\n"
- "VST1.32 {d12-d15}, [%[outptr5]]!\n"
- : [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3),
- [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), [inptr] "+r"(inptr)
- : [av] "w"(av), [bv] "w"(bv)
- : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
+ // Rows 4-5
+ "VLD1.32 {d8-d11}, [%[outptr4]]\n"
+ "VMUL.f32 q4, q4, %q[bv]\n"
+ "VLD1.32 {d12-d15}, [%[outptr5]]\n"
+ "VMUL.f32 q5, q5, %q[bv]\n"
+ "VLD1.32 {d0-d3}, [%[inptr]]!\n"
+ "VMUL.f32 q6, q6, %q[bv]\n"
+ "VLD1.32 {d4-d7}, [%[inptr]]!\n"
+ "VMUL.f32 q7, q7, %q[bv]\n"
+
+ "VMLA.f32 q4, q0, %q[av]\n"
+ ASM_PREFETCH("[%[outptr3], #96]")
+ "VMLA.f32 q5, q1, %q[av]\n"
+ "VST1.32 {d8-d11}, [%[outptr4]]!\n"
+ ASM_PREFETCH("[%[outptr4], #96]")
+ "VMLA.f32 q6, q2, %q[av]\n"
+ ASM_PREFETCH("[%[outptr5], #128]")
+ "VMLA.f32 q7, q3, %q[av]\n"
+ "VST1.32 {d12-d15}, [%[outptr5]]!\n"
+ : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+ [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [inptr] "+r" (inptr)
+ : [av] "w" (av), [bv] "w" (bv)
+ : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"
+ );
+ }
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
index 3b59a43..be23978 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
@@ -25,9 +25,8 @@
#ifdef __aarch64__
-template <>
-inline void MergeResults<12, 8>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float alpha, const float beta)
-{
+template<>
+inline void MergeResults<12, 8, false>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float alpha, const float beta) {
const float *inptr = in;
prefetch_6x(inptr);
prefetch_6x(inptr + 96);
@@ -35,8 +34,7 @@
float32x4_t av = vdupq_n_f32(alpha);
float32x4_t bv = vdupq_n_f32(beta);
- for(int y = y0; y < ymax; y += 8)
- {
+ for (int y=y0; y<ymax; y+=8) {
float *outptr0 = out + (y * ldout) + x0;
float *outptr1 = outptr0 + ldout;
float *outptr2 = outptr1 + ldout;
@@ -55,17 +53,14 @@
prefetch_2x(outptr6);
prefetch_2x(outptr7);
- for(int i = x0; i < xmax; i += 12)
- {
+ for (int i=x0; i<xmax; i+=12) {
float dummyres[12];
/* Make sure we throw away results if Y isn't a multiple of 8.
* We do this by pointing the result pointer at a dummy buffer
* we later discard. */
- if((y + 7) >= ymax)
- {
- switch((y + 7) - ymax)
- {
+ if ((y+7) >= ymax) {
+ switch ((y + 7) - ymax) {
case 6:
outptr1 = dummyres;
case 5:
@@ -87,147 +82,259 @@
}
}
- /* For ragged X, manually copy over the valid results. */
- if((i + 11) >= xmax)
- {
- for(int xi = 0; xi < 12; xi++)
- {
- if((i + xi) < xmax)
- {
- *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta);
- outptr0++;
- *outptr1 = (alpha * inptr[xi + 12]) + (*outptr1 * beta);
- outptr1++;
- *outptr2 = (alpha * inptr[xi + 24]) + (*outptr2 * beta);
- outptr2++;
- *outptr3 = (alpha * inptr[xi + 36]) + (*outptr3 * beta);
- outptr3++;
- *outptr4 = (alpha * inptr[xi + 48]) + (*outptr4 * beta);
- outptr4++;
- *outptr5 = (alpha * inptr[xi + 60]) + (*outptr5 * beta);
- outptr5++;
- *outptr6 = (alpha * inptr[xi + 72]) + (*outptr6 * beta);
- outptr6++;
- *outptr7 = (alpha * inptr[xi + 84]) + (*outptr7 * beta);
- outptr7++;
+ if (beta==0.0f) {
+ /* If beta==0, don't read the existing output values at all. */
+
+ /* For ragged X, manually copy over the valid results. */
+ if ((i+11) >= xmax) {
+ for (int xi=0; xi<12; xi++) {
+ if ((i+xi) < xmax) {
+ *outptr0 = (alpha * inptr[xi]);
+ outptr0++;
+ *outptr1 = (alpha * inptr[xi + 12]);
+ outptr1++;
+ *outptr2 = (alpha * inptr[xi + 24]);
+ outptr2++;
+ *outptr3 = (alpha * inptr[xi + 36]);
+ outptr3++;
+ *outptr4 = (alpha * inptr[xi + 48]);
+ outptr4++;
+ *outptr5 = (alpha * inptr[xi + 60]);
+ outptr5++;
+ *outptr6 = (alpha * inptr[xi + 72]);
+ outptr6++;
+ *outptr7 = (alpha * inptr[xi + 84]);
+ outptr7++;
+ }
}
+ inptr += 96;
+ } else {
+ /* Optimized routine to copy an entire block */
+ __asm __volatile (
+ // Rows 0-1
+ "LDP q0, q1, [%[inptr]]\n"
+ "FMUL v16.4s, v0.4s, %[av].4s\n"
+ "LDP q2, q3, [%[inptr], #32]\n"
+ "FMUL v17.4s, v1.4s, %[av].4s\n"
+ "LDP q4, q5, [%[inptr], #64]\n"
+ "FMUL v18.4s, v2.4s, %[av].4s\n"
+ "STP q16, q17, [%[outptr0]], #32\n"
+ ASM_PREFETCH("[%[inptr], #768]")
+ "FMUL v19.4s, v3.4s, %[av].4s\n"
+ "STR q18, [%[outptr0]], #16\n"
+ "FMUL v20.4s, v4.4s, %[av].4s\n"
+ "STP q19, q20, [%[outptr1]], #32\n"
+ ASM_PREFETCH("[%[inptr], #832]")
+ "FMUL v21.4s, v5.4s, %[av].4s\n"
+ "STR q21, [%[outptr1]], #16\n"
+
+ // Rows 2-3
+ "LDP q0, q1, [%[inptr], #96]\n"
+ "FMUL v16.4s, v0.4s, %[av].4s\n"
+ "LDP q2, q3, [%[inptr], #128]\n"
+ "FMUL v17.4s, v1.4s, %[av].4s\n"
+ "LDP q4, q5, [%[inptr], #160]\n"
+ "FMUL v18.4s, v2.4s, %[av].4s\n"
+ "STP q16, q17, [%[outptr2]], #32\n"
+ ASM_PREFETCH("[%[inptr], #896]")
+ "FMUL v19.4s, v3.4s, %[av].4s\n"
+ "STR q18, [%[outptr2]], #16\n"
+ "FMUL v20.4s, v4.4s, %[av].4s\n"
+ "STP q19, q20, [%[outptr3]], #32\n"
+ ASM_PREFETCH("[%[inptr], #1024]")
+ "FMUL v21.4s, v5.4s, %[av].4s\n"
+ "STR q21, [%[outptr3]], #16\n"
+
+ // Rows 4-5
+ "LDP q0, q1, [%[inptr], #192]\n"
+ "FMUL v16.4s, v0.4s, %[av].4s\n"
+ "LDP q2, q3, [%[inptr], #224]\n"
+ "FMUL v17.4s, v1.4s, %[av].4s\n"
+ "LDP q4, q5, [%[inptr], #256]\n"
+ "FMUL v18.4s, v2.4s, %[av].4s\n"
+ "STP q16, q17, [%[outptr4]], #32\n"
+ ASM_PREFETCH("[%[inptr], #960]")
+ "FMUL v19.4s, v3.4s, %[av].4s\n"
+ "STR q18, [%[outptr4]], #16\n"
+ "FMUL v20.4s, v4.4s, %[av].4s\n"
+ "STP q19, q20, [%[outptr5]], #32\n"
+ ASM_PREFETCH("[%[inptr], #1088]")
+ "FMUL v21.4s, v5.4s, %[av].4s\n"
+ "STR q21, [%[outptr5]], #16\n"
+
+ // Rows 6-7
+ "LDP q0, q1, [%[inptr], #288]\n"
+ "FMUL v16.4s, v0.4s, %[av].4s\n"
+ "LDP q2, q3, [%[inptr], #320]\n"
+ "FMUL v17.4s, v1.4s, %[av].4s\n"
+ "LDP q4, q5, [%[inptr], #352]\n"
+ "FMUL v18.4s, v2.4s, %[av].4s\n"
+ "STP q16, q17, [%[outptr6]], #32\n"
+ "FMUL v19.4s, v3.4s, %[av].4s\n"
+ "STR q18, [%[outptr6]], #16\n"
+ "FMUL v20.4s, v4.4s, %[av].4s\n"
+ "STP q19, q20, [%[outptr7]], #32\n"
+ "FMUL v21.4s, v5.4s, %[av].4s\n"
+ "STR q21, [%[outptr7]], #16\n"
+ "ADD %[inptr], %[inptr], #384\n"
+ : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+ [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [outptr6] "+r" (outptr6), [outptr7] "+r" (outptr7),
+ [inptr] "+r" (inptr)
+ : [av] "w" (av), [bv] "w" (bv)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21", "memory" // "memory": the asm loads via inptr and stores via outptr0..7, which the register operands alone do not describe
+ );
}
- inptr += 96;
- }
- else
- {
- /* Optimized routine to copy an entire block */
- __asm __volatile(
- // Rows 0-1
- "LDP q16, q17, [%[outptr0]]\n"
- "FMUL v16.4s, v16.4s, %[bv].4s\n"
- "LDR q18, [%[outptr0], #32]\n"
- "FMUL v17.4s, v17.4s, %[bv].4s\n"
- "LDP q19, q20, [%[outptr1]]\n"
- "FMUL v18.4s, v18.4s, %[bv].4s\n"
- "LDR q21, [%[outptr1], #32]\n" ASM_PREFETCH("[%[inptr], #768]")
- "FMUL v19.4s, v19.4s, %[bv].4s\n"
- "LDP q0, q1, [%[inptr]]\n"
- "FMUL v20.4s, v20.4s, %[bv].4s\n"
- "LDP q2, q3, [%[inptr], #32]\n"
- "FMUL v21.4s, v21.4s, %[bv].4s\n"
- "LDP q4, q5, [%[inptr], #64]\n"
- "FMLA v16.4s, v0.4s, %[av].4s\n" ASM_PREFETCH("[%[inptr], #832]")
- "FMLA v17.4s, v1.4s, %[av].4s\n"
- "STP q16, q17, [%[outptr0]], #32\n"
- "FMLA v18.4s, v2.4s, %[av].4s\n"
- "STR q18, [%[outptr0]], #16\n"
- "FMLA v19.4s, v3.4s, %[av].4s\n" ASM_PREFETCH("[%[inptr], #896]")
- "FMLA v20.4s, v4.4s, %[av].4s\n"
- "STP q19, q20, [%[outptr1]], #32\n"
- "FMLA v21.4s, v5.4s, %[av].4s\n"
- "STR q21, [%[outptr1]], #16\n"
+ } else {
+ /* For ragged X, manually copy over the valid results. */
+ if ((i+11) >= xmax) {
+ for (int xi=0; xi<12; xi++) {
+ if ((i+xi) < xmax) {
+ *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta);
+ outptr0++;
+ *outptr1 = (alpha * inptr[xi + 12]) + (*outptr1 * beta);
+ outptr1++;
+ *outptr2 = (alpha * inptr[xi + 24]) + (*outptr2 * beta);
+ outptr2++;
+ *outptr3 = (alpha * inptr[xi + 36]) + (*outptr3 * beta);
+ outptr3++;
+ *outptr4 = (alpha * inptr[xi + 48]) + (*outptr4 * beta);
+ outptr4++;
+ *outptr5 = (alpha * inptr[xi + 60]) + (*outptr5 * beta);
+ outptr5++;
+ *outptr6 = (alpha * inptr[xi + 72]) + (*outptr6 * beta);
+ outptr6++;
+ *outptr7 = (alpha * inptr[xi + 84]) + (*outptr7 * beta);
+ outptr7++;
+ }
+ }
+ inptr += 96;
+ } else {
+ /* Optimized routine to copy an entire block */
+ __asm __volatile (
+ // Rows 0-1
+ "LDP q16, q17, [%[outptr0]]\n"
+ "FMUL v16.4s, v16.4s, %[bv].4s\n"
+ "LDR q18, [%[outptr0], #32]\n"
+ "FMUL v17.4s, v17.4s, %[bv].4s\n"
+ "LDP q19, q20, [%[outptr1]]\n"
+ "FMUL v18.4s, v18.4s, %[bv].4s\n"
+ "LDR q21, [%[outptr1], #32]\n"
+ ASM_PREFETCH("[%[inptr], #768]")
+ "FMUL v19.4s, v19.4s, %[bv].4s\n"
+ "LDP q0, q1, [%[inptr]]\n"
+ "FMUL v20.4s, v20.4s, %[bv].4s\n"
+ "LDP q2, q3, [%[inptr], #32]\n"
+ "FMUL v21.4s, v21.4s, %[bv].4s\n"
+ "LDP q4, q5, [%[inptr], #64]\n"
+ "FMLA v16.4s, v0.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[inptr], #832]")
+ "FMLA v17.4s, v1.4s, %[av].4s\n"
+ "STP q16, q17, [%[outptr0]], #32\n"
+ "FMLA v18.4s, v2.4s, %[av].4s\n"
+ "STR q18, [%[outptr0]], #16\n"
+ "FMLA v19.4s, v3.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[inptr], #896]")
+ "FMLA v20.4s, v4.4s, %[av].4s\n"
+ "STP q19, q20, [%[outptr1]], #32\n"
+ "FMLA v21.4s, v5.4s, %[av].4s\n"
+ "STR q21, [%[outptr1]], #16\n"
- // Rows 2-3
- "LDP q16, q17, [%[outptr2]]\n"
- "FMUL v16.4s, v16.4s, %[bv].4s\n"
- "LDR q18, [%[outptr2], #32]\n"
- "FMUL v17.4s, v17.4s, %[bv].4s\n"
- "LDP q19, q20, [%[outptr3]]\n"
- "FMUL v18.4s, v18.4s, %[bv].4s\n"
- "LDR q21, [%[outptr3], #32]\n" ASM_PREFETCH("[%[inptr], #960]")
- "FMUL v19.4s, v19.4s, %[bv].4s\n"
- "LDP q0, q1, [%[inptr], #96]\n"
- "FMUL v20.4s, v20.4s, %[bv].4s\n"
- "LDP q2, q3, [%[inptr], #128]\n"
- "FMUL v21.4s, v21.4s, %[bv].4s\n"
- "LDP q4, q5, [%[inptr], #160]\n"
- "FMLA v16.4s, v0.4s, %[av].4s\n" ASM_PREFETCH("[%[inptr], #1024]")
- "FMLA v17.4s, v1.4s, %[av].4s\n"
- "STP q16, q17, [%[outptr2]], #32\n"
- "FMLA v18.4s, v2.4s, %[av].4s\n"
- "STR q18, [%[outptr2]], #16\n"
- "FMLA v19.4s, v3.4s, %[av].4s\n" ASM_PREFETCH("[%[inptr], #1088]")
- "FMLA v20.4s, v4.4s, %[av].4s\n"
- "STP q19, q20, [%[outptr3]], #32\n"
- "FMLA v21.4s, v5.4s, %[av].4s\n"
- "STR q21, [%[outptr3]], #16\n"
+ // Rows 2-3
+ "LDP q16, q17, [%[outptr2]]\n"
+ "FMUL v16.4s, v16.4s, %[bv].4s\n"
+ "LDR q18, [%[outptr2], #32]\n"
+ "FMUL v17.4s, v17.4s, %[bv].4s\n"
+ "LDP q19, q20, [%[outptr3]]\n"
+ "FMUL v18.4s, v18.4s, %[bv].4s\n"
+ "LDR q21, [%[outptr3], #32]\n"
+ ASM_PREFETCH("[%[inptr], #960]")
+ "FMUL v19.4s, v19.4s, %[bv].4s\n"
+ "LDP q0, q1, [%[inptr], #96]\n"
+ "FMUL v20.4s, v20.4s, %[bv].4s\n"
+ "LDP q2, q3, [%[inptr], #128]\n"
+ "FMUL v21.4s, v21.4s, %[bv].4s\n"
+ "LDP q4, q5, [%[inptr], #160]\n"
+ "FMLA v16.4s, v0.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[inptr], #1024]")
+ "FMLA v17.4s, v1.4s, %[av].4s\n"
+ "STP q16, q17, [%[outptr2]], #32\n"
+ "FMLA v18.4s, v2.4s, %[av].4s\n"
+ "STR q18, [%[outptr2]], #16\n"
+ "FMLA v19.4s, v3.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[inptr], #1088]")
+ "FMLA v20.4s, v4.4s, %[av].4s\n"
+ "STP q19, q20, [%[outptr3]], #32\n"
+ "FMLA v21.4s, v5.4s, %[av].4s\n"
+ "STR q21, [%[outptr3]], #16\n"
- // Rows 4-5
- ASM_PREFETCH("[%[outptr0], #80]")
- "LDP q16, q17, [%[outptr4]]\n"
- "FMUL v16.4s, v16.4s, %[bv].4s\n"
- "LDR q18, [%[outptr4], #32]\n"
- "FMUL v17.4s, v17.4s, %[bv].4s\n"
- "LDP q19, q20, [%[outptr5]]\n"
- "FMUL v18.4s, v18.4s, %[bv].4s\n"
- "LDR q21, [%[outptr5], #32]\n" ASM_PREFETCH("[%[outptr1], #80]")
- "FMUL v19.4s, v19.4s, %[bv].4s\n"
- "LDP q0, q1, [%[inptr], #192]\n"
- "FMUL v20.4s, v20.4s, %[bv].4s\n"
- "LDP q2, q3, [%[inptr], #224]\n"
- "FMUL v21.4s, v21.4s, %[bv].4s\n"
- "LDP q4, q5, [%[inptr], #256]\n"
- "FMLA v16.4s, v0.4s, %[av].4s\n" ASM_PREFETCH("[%[outptr2], #80]")
- "FMLA v17.4s, v1.4s, %[av].4s\n"
- "STP q16, q17, [%[outptr4]], #32\n"
- "FMLA v18.4s, v2.4s, %[av].4s\n"
- "STR q18, [%[outptr4]], #16\n"
- "FMLA v19.4s, v3.4s, %[av].4s\n" ASM_PREFETCH("[%[outptr3], #80]")
- "FMLA v20.4s, v4.4s, %[av].4s\n"
- "STP q19, q20, [%[outptr5]], #32\n"
- "FMLA v21.4s, v5.4s, %[av].4s\n"
- "STR q21, [%[outptr5]], #16\n"
+ // Rows 4-5
+ ASM_PREFETCH("[%[outptr0], #80]")
+ "LDP q16, q17, [%[outptr4]]\n"
+ "FMUL v16.4s, v16.4s, %[bv].4s\n"
+ "LDR q18, [%[outptr4], #32]\n"
+ "FMUL v17.4s, v17.4s, %[bv].4s\n"
+ "LDP q19, q20, [%[outptr5]]\n"
+ "FMUL v18.4s, v18.4s, %[bv].4s\n"
+ "LDR q21, [%[outptr5], #32]\n"
+ ASM_PREFETCH("[%[outptr1], #80]")
+ "FMUL v19.4s, v19.4s, %[bv].4s\n"
+ "LDP q0, q1, [%[inptr], #192]\n"
+ "FMUL v20.4s, v20.4s, %[bv].4s\n"
+ "LDP q2, q3, [%[inptr], #224]\n"
+ "FMUL v21.4s, v21.4s, %[bv].4s\n"
+ "LDP q4, q5, [%[inptr], #256]\n"
+ "FMLA v16.4s, v0.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[outptr2], #80]")
+ "FMLA v17.4s, v1.4s, %[av].4s\n"
+ "STP q16, q17, [%[outptr4]], #32\n"
+ "FMLA v18.4s, v2.4s, %[av].4s\n"
+ "STR q18, [%[outptr4]], #16\n"
+ "FMLA v19.4s, v3.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[outptr3], #80]")
+ "FMLA v20.4s, v4.4s, %[av].4s\n"
+ "STP q19, q20, [%[outptr5]], #32\n"
+ "FMLA v21.4s, v5.4s, %[av].4s\n"
+ "STR q21, [%[outptr5]], #16\n"
- // Rows 6-7
- ASM_PREFETCH("[%[outptr4], #80]")
- "LDP q16, q17, [%[outptr6]]\n"
- "FMUL v16.4s, v16.4s, %[bv].4s\n"
- "LDR q18, [%[outptr6], #32]\n"
- "FMUL v17.4s, v17.4s, %[bv].4s\n"
- "LDP q19, q20, [%[outptr7]]\n"
- "FMUL v18.4s, v18.4s, %[bv].4s\n"
- "LDR q21, [%[outptr7], #32]\n" ASM_PREFETCH("[%[outptr5], #80]")
- "FMUL v19.4s, v19.4s, %[bv].4s\n"
- "LDP q0, q1, [%[inptr], #288]\n"
- "FMUL v20.4s, v20.4s, %[bv].4s\n"
- "LDP q2, q3, [%[inptr], #320]\n"
- "FMUL v21.4s, v21.4s, %[bv].4s\n"
- "LDP q4, q5, [%[inptr], #352]\n"
- "FMLA v16.4s, v0.4s, %[av].4s\n" ASM_PREFETCH("[%[outptr6], #128]")
- "FMLA v17.4s, v1.4s, %[av].4s\n"
- "STP q16, q17, [%[outptr6]], #32\n"
- "FMLA v18.4s, v2.4s, %[av].4s\n"
- "STR q18, [%[outptr6]], #16\n"
- "FMLA v19.4s, v3.4s, %[av].4s\n" ASM_PREFETCH("[%[outptr7], #128]")
- "FMLA v20.4s, v4.4s, %[av].4s\n"
- "STP q19, q20, [%[outptr7]], #32\n"
- "FMLA v21.4s, v5.4s, %[av].4s\n"
- "STR q21, [%[outptr7]], #16\n"
- "ADD %[inptr], %[inptr], #384\n"
- : [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3),
- [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), [outptr6] "+r"(outptr6), [outptr7] "+r"(outptr7),
- [inptr] "+r"(inptr)
- : [av] "w"(av), [bv] "w"(bv)
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21");
+ // Rows 6-7
+ ASM_PREFETCH("[%[outptr4], #80]")
+ "LDP q16, q17, [%[outptr6]]\n"
+ "FMUL v16.4s, v16.4s, %[bv].4s\n"
+ "LDR q18, [%[outptr6], #32]\n"
+ "FMUL v17.4s, v17.4s, %[bv].4s\n"
+ "LDP q19, q20, [%[outptr7]]\n"
+ "FMUL v18.4s, v18.4s, %[bv].4s\n"
+ "LDR q21, [%[outptr7], #32]\n"
+ ASM_PREFETCH("[%[outptr5], #80]")
+ "FMUL v19.4s, v19.4s, %[bv].4s\n"
+ "LDP q0, q1, [%[inptr], #288]\n"
+ "FMUL v20.4s, v20.4s, %[bv].4s\n"
+ "LDP q2, q3, [%[inptr], #320]\n"
+ "FMUL v21.4s, v21.4s, %[bv].4s\n"
+ "LDP q4, q5, [%[inptr], #352]\n"
+ "FMLA v16.4s, v0.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[outptr6], #128]")
+ "FMLA v17.4s, v1.4s, %[av].4s\n"
+ "STP q16, q17, [%[outptr6]], #32\n"
+ "FMLA v18.4s, v2.4s, %[av].4s\n"
+ "STR q18, [%[outptr6]], #16\n"
+ "FMLA v19.4s, v3.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[outptr7], #128]")
+ "FMLA v20.4s, v4.4s, %[av].4s\n"
+ "STP q19, q20, [%[outptr7]], #32\n"
+ "FMLA v21.4s, v5.4s, %[av].4s\n"
+ "STR q21, [%[outptr7]], #16\n"
+ "ADD %[inptr], %[inptr], #384\n"
+ : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+ [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [outptr6] "+r" (outptr6), [outptr7] "+r" (outptr7),
+ [inptr] "+r" (inptr)
+ : [av] "w" (av), [bv] "w" (bv)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21", "memory" // "memory": the asm loads via inptr/outptr0..7 and stores via outptr0..7, which the register operands alone do not describe
+ );
+ }
}
}
}
}
-#endif // __aarch64__
\ No newline at end of file
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
index 9708fe1..9e5eb88 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
@@ -28,9 +28,8 @@
#include <arm_neon.h>
-template <>
-inline void MergeResults<12, 8>(__fp16 *out, const float *in, int ldout, int y0, int ymax, int x0, int xmax, const __fp16 alpha, const __fp16 beta)
-{
+template<>
+inline void MergeResults<12,8,false>(__fp16 *out, const float *in, int ldout, int y0, int ymax, int x0, int xmax, const __fp16 alpha, const __fp16 beta) {
const float *inptr = in;
prefetch_6x(inptr);
prefetch_6x(inptr + 24);
@@ -38,8 +37,7 @@
float32x4_t av = vdupq_n_f32(alpha);
float32x4_t bv = vdupq_n_f32(beta);
- for(int y = y0; y < ymax; y += 8)
- {
+ for (int y=y0; y<ymax; y+=8) {
__fp16 *outptr0 = out + (y * ldout) + x0;
__fp16 *outptr1 = outptr0 + ldout;
__fp16 *outptr2 = outptr1 + ldout;
@@ -58,17 +56,14 @@
prefetch_2x(outptr6);
prefetch_2x(outptr7);
- for(int i = x0; i < xmax; i += 12)
- {
+ for (int i=x0; i<xmax; i+=12) {
__fp16 dummyres[12];
/* Make sure we throw away results if Y isn't a multiple of 8.
* We do this by pointing the result pointer at a dummy buffer
* we later discard. */
- if((y + 7) >= ymax)
- {
- switch((y + 7) - ymax)
- {
+ if ((y+7) >= ymax) {
+ switch ((y + 7) - ymax) {
case 6:
outptr1 = dummyres;
case 5:
@@ -90,182 +85,335 @@
}
}
- /* For ragged X, manually copy over the valid results. */
- if((i + 11) >= xmax)
- {
- for(int xi = 0; xi < 12; xi++)
- {
- if((i + xi) < xmax)
- {
- *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta);
- outptr0++;
- *outptr1 = (alpha * inptr[xi + 12]) + (*outptr1 * beta);
- outptr1++;
- *outptr2 = (alpha * inptr[xi + 24]) + (*outptr2 * beta);
- outptr2++;
- *outptr3 = (alpha * inptr[xi + 36]) + (*outptr3 * beta);
- outptr3++;
- *outptr4 = (alpha * inptr[xi + 48]) + (*outptr4 * beta);
- outptr4++;
- *outptr5 = (alpha * inptr[xi + 60]) + (*outptr5 * beta);
- outptr5++;
- *outptr6 = (alpha * inptr[xi + 72]) + (*outptr6 * beta);
- outptr6++;
- *outptr7 = (alpha * inptr[xi + 84]) + (*outptr7 * beta);
- outptr7++;
+ if (beta == ((__fp16)0.0f)) {
+ /* If beta==0, don't read the output. */
+ /* For ragged X, manually copy over the valid results. */
+ if ((i+11) >= xmax) {
+ for (int xi=0; xi<12; xi++) {
+ if ((i+xi) < xmax) {
+ *outptr0 = (alpha * inptr[xi]);
+ outptr0++;
+ *outptr1 = (alpha * inptr[xi + 12]);
+ outptr1++;
+ *outptr2 = (alpha * inptr[xi + 24]);
+ outptr2++;
+ *outptr3 = (alpha * inptr[xi + 36]);
+ outptr3++;
+ *outptr4 = (alpha * inptr[xi + 48]);
+ outptr4++;
+ *outptr5 = (alpha * inptr[xi + 60]);
+ outptr5++;
+ *outptr6 = (alpha * inptr[xi + 72]);
+ outptr6++;
+ *outptr7 = (alpha * inptr[xi + 84]);
+ outptr7++;
+ }
}
+ inptr += 96;
+ } else {
+ /* Optimized routine to copy an entire block */
+ __asm __volatile (
+ // Rows 0-1
+ "LDP q0, q1, [%[inptr]]\n"
+ "LDP q2, q3, [%[inptr], #32]\n"
+ "LDP q4, q5, [%[inptr], #64]\n"
+ "FMUL v16.4s, v0.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[inptr], #768]")
+ "FMUL v17.4s, v1.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[inptr], #832]")
+ "FCVTN v16.4h, v16.4s\n"
+ ASM_PREFETCH("[%[inptr], #896]")
+ "FCVTN2 v16.8h, v17.4s\n"
+ ASM_PREFETCH("[%[inptr], #960]")
+ "FMUL v18.4s, v2.4s, %[av].4s\n"
+ "STR q16, [%[outptr0]], #16\n"
+ "FCVTN v18.4h, v18.4s\n"
+ "STR d18, [%[outptr0]], #8\n"
+ "FMUL v19.4s, v3.4s, %[av].4s\n"
+ "FMUL v20.4s, v4.4s, %[av].4s\n"
+ "FCVTN v19.4h, v19.4s\n"
+ "FCVTN2 v19.8h, v20.4s\n"
+ "STR q19, [%[outptr1]], #16\n"
+ "FMUL v21.4s, v5.4s, %[av].4s\n"
+ "FCVTN v21.4h, v21.4s\n"
+ "STR d21, [%[outptr1]], #8\n"
+
+ // Rows 2-3
+ "LDP q0, q1, [%[inptr], #96]\n"
+ "LDP q2, q3, [%[inptr], #128]\n"
+ "LDP q4, q5, [%[inptr], #160]\n"
+ "FMUL v16.4s, v0.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[inptr], #1024]")
+ "FMUL v17.4s, v1.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[inptr], #1088]")
+ "FCVTN v16.4h, v16.4s\n"
+ ASM_PREFETCH("[%[outptr0], #64]")
+ "FCVTN2 v16.8h, v17.4s\n"
+ ASM_PREFETCH("[%[outptr1], #64]")
+ "FMUL v18.4s, v2.4s, %[av].4s\n"
+ "STR q16, [%[outptr2]], #16\n"
+ "FCVTN v18.4h, v18.4s\n"
+ "STR d18, [%[outptr2]], #8\n"
+ "FMUL v19.4s, v3.4s, %[av].4s\n"
+ "FMUL v20.4s, v4.4s, %[av].4s\n"
+ "FCVTN v19.4h, v19.4s\n"
+ "FCVTN2 v19.8h, v20.4s\n"
+ "STR q19, [%[outptr3]], #16\n"
+ "FMUL v21.4s, v5.4s, %[av].4s\n"
+ "FCVTN v21.4h, v21.4s\n"
+ "STR d21, [%[outptr3]], #8\n"
+
+ // Rows 4-5
+ "LDP q0, q1, [%[inptr], #192]\n"
+ "LDP q2, q3, [%[inptr], #224]\n"
+ "LDP q4, q5, [%[inptr], #256]\n"
+ "FMUL v16.4s, v0.4s, %[av].4s\n"
+ "FMUL v17.4s, v1.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[outptr2], #64]")
+ "FCVTN v16.4h, v16.4s\n"
+ ASM_PREFETCH("[%[outptr3], #64]")
+ "FCVTN2 v16.8h, v17.4s\n"
+ ASM_PREFETCH("[%[outptr4], #88]")
+ "FMUL v18.4s, v2.4s, %[av].4s\n"
+ "STR q16, [%[outptr4]], #16\n"
+ "FCVTN v18.4h, v18.4s\n"
+ "STR d18, [%[outptr4]], #8\n"
+ "FMUL v19.4s, v3.4s, %[av].4s\n"
+ "FMUL v20.4s, v4.4s, %[av].4s\n"
+ "FCVTN v19.4h, v19.4s\n"
+ "FCVTN2 v19.8h, v20.4s\n"
+ "STR q19, [%[outptr5]], #16\n"
+ "FMUL v21.4s, v5.4s, %[av].4s\n"
+ "FCVTN v21.4h, v21.4s\n"
+ "STR d21, [%[outptr5]], #8\n"
+
+ // Rows 6-7
+ "LDP q0, q1, [%[inptr], #288]\n"
+ "LDP q2, q3, [%[inptr], #320]\n"
+ "LDP q4, q5, [%[inptr], #352]\n"
+ "FMUL v16.4s, v0.4s, %[av].4s\n"
+ "FMUL v17.4s, v1.4s, %[av].4s\n"
+ ASM_PREFETCH("[%[outptr5], #64]")
+ "FCVTN v16.4h, v16.4s\n"
+ ASM_PREFETCH("[%[outptr6], #88]")
+ "FCVTN2 v16.8h, v17.4s\n"
+ ASM_PREFETCH("[%[outptr7], #88]")
+ "FMUL v18.4s, v2.4s, %[av].4s\n"
+ "STR q16, [%[outptr6]], #16\n"
+ "FCVTN v18.4h, v18.4s\n"
+ "STR d18, [%[outptr6]], #8\n"
+ "FMUL v19.4s, v3.4s, %[av].4s\n"
+ "FMUL v20.4s, v4.4s, %[av].4s\n"
+ "FCVTN v19.4h, v19.4s\n"
+ "FCVTN2 v19.8h, v20.4s\n"
+ "STR q19, [%[outptr7]], #16\n"
+ "FMUL v21.4s, v5.4s, %[av].4s\n"
+ "FCVTN v21.4h, v21.4s\n"
+ "STR d21, [%[outptr7]], #8\n"
+ "ADD %[inptr], %[inptr], #384\n"
+ : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+ [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [outptr6] "+r" (outptr6), [outptr7] "+r" (outptr7),
+ [inptr] "+r" (inptr)
+ : [av] "w" (av), [bv] "w" (bv)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21", "memory" // "memory": the asm loads via inptr and stores via outptr0..7, which the register operands alone do not describe
+ );
}
- inptr += 96;
- }
- else
- {
- /* Optimized routine to copy an entire block */
- __asm __volatile(
- // Rows 0-1
- "LDR q16, [%[outptr0]]\n"
- "FCVTL2 v17.4s, v16.8h\n"
- "LDR d18, [%[outptr0], #16]\n"
- "FCVTL v16.4s, v16.4h\n"
- "LDR q19, [%[outptr1]]\n"
- "FMUL v17.4s, v17.4s, %[bv].4s\n"
- "LDR d21, [%[outptr1], #16]\n"
- "FMUL v16.4s, v16.4s, %[bv].4s\n"
- "LDP q0, q1, [%[inptr]]\n"
- "FCVTL v18.4s, v18.4h\n"
- "LDP q2, q3, [%[inptr], #32]\n"
- "FCVTL2 v20.4s, v19.8h\n"
- "LDP q4, q5, [%[inptr], #64]\n"
- "FCVTL v19.4s, v19.4h\n" ASM_PREFETCH("[%[inptr], #768]") "FCVTL v21.4s, v21.4h\n" ASM_PREFETCH("[%[inptr], #832]") "FMUL v18.4s, v18.4s, %[bv].4s\n" ASM_PREFETCH("[%[inptr], #896]")
- "FMUL v20.4s, v20.4s, %[bv].4s\n" ASM_PREFETCH("[%[inptr], #960]")
- "FMUL v19.4s, v19.4s, %[bv].4s\n"
- "FMUL v21.4s, v21.4s, %[bv].4s\n"
- "FMLA v16.4s, v0.4s, %[av].4s\n"
- "FMLA v17.4s, v1.4s, %[av].4s\n"
- "FCVTN v16.4h, v16.4s\n"
- "FCVTN2 v16.8h, v17.4s\n"
- "FMLA v18.4s, v2.4s, %[av].4s\n"
- "STR q16, [%[outptr0]], #16\n"
- "FCVTN v18.4h, v18.4s\n"
- "STR d18, [%[outptr0]], #8\n"
- "FMLA v19.4s, v3.4s, %[av].4s\n"
- "FMLA v20.4s, v4.4s, %[av].4s\n"
- "FCVTN v19.4h, v19.4s\n"
- "FCVTN2 v19.8h, v20.4s\n"
- "STR q19, [%[outptr1]], #16\n"
- "FMLA v21.4s, v5.4s, %[av].4s\n"
- "FCVTN v21.4h, v21.4s\n"
- "STR d21, [%[outptr1]], #8\n"
+ } else {
+ /* For ragged X, manually copy over the valid results. */
+ if ((i+11) >= xmax) {
+ for (int xi=0; xi<12; xi++) {
+ if ((i+xi) < xmax) {
+ *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta);
+ outptr0++;
+ *outptr1 = (alpha * inptr[xi + 12]) + (*outptr1 * beta);
+ outptr1++;
+ *outptr2 = (alpha * inptr[xi + 24]) + (*outptr2 * beta);
+ outptr2++;
+ *outptr3 = (alpha * inptr[xi + 36]) + (*outptr3 * beta);
+ outptr3++;
+ *outptr4 = (alpha * inptr[xi + 48]) + (*outptr4 * beta);
+ outptr4++;
+ *outptr5 = (alpha * inptr[xi + 60]) + (*outptr5 * beta);
+ outptr5++;
+ *outptr6 = (alpha * inptr[xi + 72]) + (*outptr6 * beta);
+ outptr6++;
+ *outptr7 = (alpha * inptr[xi + 84]) + (*outptr7 * beta);
+ outptr7++;
+ }
+ }
+ inptr += 96;
+ } else {
+ /* Optimized routine to copy an entire block */
+ __asm __volatile (
+ // Rows 0-1
+ "LDR q16, [%[outptr0]]\n"
+ "FCVTL2 v17.4s, v16.8h\n"
+ "LDR d18, [%[outptr0], #16]\n"
+ "FCVTL v16.4s, v16.4h\n"
+ "LDR q19, [%[outptr1]]\n"
+ "FMUL v17.4s, v17.4s, %[bv].4s\n"
+ "LDR d21, [%[outptr1], #16]\n"
+ "FMUL v16.4s, v16.4s, %[bv].4s\n"
+ "LDP q0, q1, [%[inptr]]\n"
+ "FCVTL v18.4s, v18.4h\n"
+ "LDP q2, q3, [%[inptr], #32]\n"
+ "FCVTL2 v20.4s, v19.8h\n"
+ "LDP q4, q5, [%[inptr], #64]\n"
+ "FCVTL v19.4s, v19.4h\n"
+ ASM_PREFETCH("[%[inptr], #768]")
+ "FCVTL v21.4s, v21.4h\n"
+ ASM_PREFETCH("[%[inptr], #832]")
+ "FMUL v18.4s, v18.4s, %[bv].4s\n"
+ ASM_PREFETCH("[%[inptr], #896]")
+ "FMUL v20.4s, v20.4s, %[bv].4s\n"
+ ASM_PREFETCH("[%[inptr], #960]")
+ "FMUL v19.4s, v19.4s, %[bv].4s\n"
+ "FMUL v21.4s, v21.4s, %[bv].4s\n"
+ "FMLA v16.4s, v0.4s, %[av].4s\n"
+ "FMLA v17.4s, v1.4s, %[av].4s\n"
+ "FCVTN v16.4h, v16.4s\n"
+ "FCVTN2 v16.8h, v17.4s\n"
+ "FMLA v18.4s, v2.4s, %[av].4s\n"
+ "STR q16, [%[outptr0]], #16\n"
+ "FCVTN v18.4h, v18.4s\n"
+ "STR d18, [%[outptr0]], #8\n"
+ "FMLA v19.4s, v3.4s, %[av].4s\n"
+ "FMLA v20.4s, v4.4s, %[av].4s\n"
+ "FCVTN v19.4h, v19.4s\n"
+ "FCVTN2 v19.8h, v20.4s\n"
+ "STR q19, [%[outptr1]], #16\n"
+ "FMLA v21.4s, v5.4s, %[av].4s\n"
+ "FCVTN v21.4h, v21.4s\n"
+ "STR d21, [%[outptr1]], #8\n"
- // Rows 2-3
- "LDR q16, [%[outptr2]]\n"
- "FCVTL2 v17.4s, v16.8h\n"
- "LDR d18, [%[outptr2], #16]\n"
- "FCVTL v16.4s, v16.4h\n"
- "LDR q19, [%[outptr3]]\n"
- "FMUL v17.4s, v17.4s, %[bv].4s\n"
- "LDR d21, [%[outptr3], #16]\n"
- "FMUL v16.4s, v16.4s, %[bv].4s\n"
- "LDP q0, q1, [%[inptr], #96]\n"
- "FCVTL v18.4s, v18.4h\n"
- "LDP q2, q3, [%[inptr], #128]\n"
- "FCVTL2 v20.4s, v19.8h\n"
- "LDP q4, q5, [%[inptr], #160]\n"
- "FCVTL v19.4s, v19.4h\n" ASM_PREFETCH("[%[inptr], #1024]") "FCVTL v21.4s, v21.4h\n" ASM_PREFETCH("[%[inptr], #1088]") "FMUL v18.4s, v18.4s, %[bv].4s\n" ASM_PREFETCH("[%[outptr0], #64]")
- "FMUL v20.4s, v20.4s, %[bv].4s\n" ASM_PREFETCH("[%[outptr1], #64]")
- "FMUL v19.4s, v19.4s, %[bv].4s\n"
- "FMUL v21.4s, v21.4s, %[bv].4s\n"
- "FMLA v16.4s, v0.4s, %[av].4s\n"
- "FMLA v17.4s, v1.4s, %[av].4s\n"
- "FCVTN v16.4h, v16.4s\n"
- "FCVTN2 v16.8h, v17.4s\n"
- "FMLA v18.4s, v2.4s, %[av].4s\n"
- "STR q16, [%[outptr2]], #16\n"
- "FCVTN v18.4h, v18.4s\n"
- "STR d18, [%[outptr2]], #8\n"
- "FMLA v19.4s, v3.4s, %[av].4s\n"
- "FMLA v20.4s, v4.4s, %[av].4s\n"
- "FCVTN v19.4h, v19.4s\n"
- "FCVTN2 v19.8h, v20.4s\n"
- "STR q19, [%[outptr3]], #16\n"
- "FMLA v21.4s, v5.4s, %[av].4s\n"
- "FCVTN v21.4h, v21.4s\n"
- "STR d21, [%[outptr3]], #8\n"
+ // Rows 2-3
+ "LDR q16, [%[outptr2]]\n"
+ "FCVTL2 v17.4s, v16.8h\n"
+ "LDR d18, [%[outptr2], #16]\n"
+ "FCVTL v16.4s, v16.4h\n"
+ "LDR q19, [%[outptr3]]\n"
+ "FMUL v17.4s, v17.4s, %[bv].4s\n"
+ "LDR d21, [%[outptr3], #16]\n"
+ "FMUL v16.4s, v16.4s, %[bv].4s\n"
+ "LDP q0, q1, [%[inptr], #96]\n"
+ "FCVTL v18.4s, v18.4h\n"
+ "LDP q2, q3, [%[inptr], #128]\n"
+ "FCVTL2 v20.4s, v19.8h\n"
+ "LDP q4, q5, [%[inptr], #160]\n"
+ "FCVTL v19.4s, v19.4h\n"
+ ASM_PREFETCH("[%[inptr], #1024]")
+ "FCVTL v21.4s, v21.4h\n"
+ ASM_PREFETCH("[%[inptr], #1088]")
+ "FMUL v18.4s, v18.4s, %[bv].4s\n"
+ ASM_PREFETCH("[%[outptr0], #64]")
+ "FMUL v20.4s, v20.4s, %[bv].4s\n"
+ ASM_PREFETCH("[%[outptr1], #64]")
+ "FMUL v19.4s, v19.4s, %[bv].4s\n"
+ "FMUL v21.4s, v21.4s, %[bv].4s\n"
+ "FMLA v16.4s, v0.4s, %[av].4s\n"
+ "FMLA v17.4s, v1.4s, %[av].4s\n"
+ "FCVTN v16.4h, v16.4s\n"
+ "FCVTN2 v16.8h, v17.4s\n"
+ "FMLA v18.4s, v2.4s, %[av].4s\n"
+ "STR q16, [%[outptr2]], #16\n"
+ "FCVTN v18.4h, v18.4s\n"
+ "STR d18, [%[outptr2]], #8\n"
+ "FMLA v19.4s, v3.4s, %[av].4s\n"
+ "FMLA v20.4s, v4.4s, %[av].4s\n"
+ "FCVTN v19.4h, v19.4s\n"
+ "FCVTN2 v19.8h, v20.4s\n"
+ "STR q19, [%[outptr3]], #16\n"
+ "FMLA v21.4s, v5.4s, %[av].4s\n"
+ "FCVTN v21.4h, v21.4s\n"
+ "STR d21, [%[outptr3]], #8\n"
- // Rows 4-5
- "LDR q16, [%[outptr4]]\n"
- "FCVTL2 v17.4s, v16.8h\n"
- "LDR d18, [%[outptr4], #16]\n"
- "FCVTL v16.4s, v16.4h\n"
- "LDR q19, [%[outptr5]]\n"
- "FMUL v17.4s, v17.4s, %[bv].4s\n"
- "LDR d21, [%[outptr5], #16]\n"
- "FMUL v16.4s, v16.4s, %[bv].4s\n"
- "LDP q0, q1, [%[inptr], #192]\n"
- "FCVTL v18.4s, v18.4h\n"
- "LDP q2, q3, [%[inptr], #224]\n"
- "FCVTL2 v20.4s, v19.8h\n"
- "LDP q4, q5, [%[inptr], #256]\n"
- "FCVTL v19.4s, v19.4h\n" ASM_PREFETCH("[%[outptr2], #64]") "FCVTL v21.4s, v21.4h\n" ASM_PREFETCH("[%[outptr3], #64]") "FMUL v18.4s, v18.4s, %[bv].4s\n" ASM_PREFETCH("[%[outptr4], #88]")
- "FMUL v20.4s, v20.4s, %[bv].4s\n"
- "FMUL v19.4s, v19.4s, %[bv].4s\n"
- "FMUL v21.4s, v21.4s, %[bv].4s\n"
- "FMLA v16.4s, v0.4s, %[av].4s\n"
- "FMLA v17.4s, v1.4s, %[av].4s\n"
- "FCVTN v16.4h, v16.4s\n"
- "FCVTN2 v16.8h, v17.4s\n"
- "FMLA v18.4s, v2.4s, %[av].4s\n"
- "STR q16, [%[outptr4]], #16\n"
- "FCVTN v18.4h, v18.4s\n"
- "STR d18, [%[outptr4]], #8\n"
- "FMLA v19.4s, v3.4s, %[av].4s\n"
- "FMLA v20.4s, v4.4s, %[av].4s\n"
- "FCVTN v19.4h, v19.4s\n"
- "FCVTN2 v19.8h, v20.4s\n"
- "STR q19, [%[outptr5]], #16\n"
- "FMLA v21.4s, v5.4s, %[av].4s\n"
- "FCVTN v21.4h, v21.4s\n"
- "STR d21, [%[outptr5]], #8\n"
+ // Rows 4-5
+ "LDR q16, [%[outptr4]]\n"
+ "FCVTL2 v17.4s, v16.8h\n"
+ "LDR d18, [%[outptr4], #16]\n"
+ "FCVTL v16.4s, v16.4h\n"
+ "LDR q19, [%[outptr5]]\n"
+ "FMUL v17.4s, v17.4s, %[bv].4s\n"
+ "LDR d21, [%[outptr5], #16]\n"
+ "FMUL v16.4s, v16.4s, %[bv].4s\n"
+ "LDP q0, q1, [%[inptr], #192]\n"
+ "FCVTL v18.4s, v18.4h\n"
+ "LDP q2, q3, [%[inptr], #224]\n"
+ "FCVTL2 v20.4s, v19.8h\n"
+ "LDP q4, q5, [%[inptr], #256]\n"
+ "FCVTL v19.4s, v19.4h\n"
+ ASM_PREFETCH("[%[outptr2], #64]")
+ "FCVTL v21.4s, v21.4h\n"
+ ASM_PREFETCH("[%[outptr3], #64]")
+ "FMUL v18.4s, v18.4s, %[bv].4s\n"
+ ASM_PREFETCH("[%[outptr4], #88]")
+ "FMUL v20.4s, v20.4s, %[bv].4s\n"
+ "FMUL v19.4s, v19.4s, %[bv].4s\n"
+ "FMUL v21.4s, v21.4s, %[bv].4s\n"
+ "FMLA v16.4s, v0.4s, %[av].4s\n"
+ "FMLA v17.4s, v1.4s, %[av].4s\n"
+ "FCVTN v16.4h, v16.4s\n"
+ "FCVTN2 v16.8h, v17.4s\n"
+ "FMLA v18.4s, v2.4s, %[av].4s\n"
+ "STR q16, [%[outptr4]], #16\n"
+ "FCVTN v18.4h, v18.4s\n"
+ "STR d18, [%[outptr4]], #8\n"
+ "FMLA v19.4s, v3.4s, %[av].4s\n"
+ "FMLA v20.4s, v4.4s, %[av].4s\n"
+ "FCVTN v19.4h, v19.4s\n"
+ "FCVTN2 v19.8h, v20.4s\n"
+ "STR q19, [%[outptr5]], #16\n"
+ "FMLA v21.4s, v5.4s, %[av].4s\n"
+ "FCVTN v21.4h, v21.4s\n"
+ "STR d21, [%[outptr5]], #8\n"
- // Rows 6-7
- "LDR q16, [%[outptr6]]\n"
- "FCVTL2 v17.4s, v16.8h\n"
- "LDR d18, [%[outptr6], #16]\n"
- "FCVTL v16.4s, v16.4h\n"
- "LDR q19, [%[outptr7]]\n"
- "FMUL v17.4s, v17.4s, %[bv].4s\n"
- "LDR d21, [%[outptr7], #16]\n"
- "FMUL v16.4s, v16.4s, %[bv].4s\n"
- "LDP q0, q1, [%[inptr], #288]\n"
- "FCVTL v18.4s, v18.4h\n"
- "LDP q2, q3, [%[inptr], #320]\n"
- "FCVTL2 v20.4s, v19.8h\n"
- "LDP q4, q5, [%[inptr], #352]\n"
- "FCVTL v19.4s, v19.4h\n" ASM_PREFETCH("[%[outptr5], #64]") "FCVTL v21.4s, v21.4h\n" ASM_PREFETCH("[%[outptr6], #88]") "FMUL v18.4s, v18.4s, %[bv].4s\n" ASM_PREFETCH("[%[outptr7], #88]")
- "FMUL v20.4s, v20.4s, %[bv].4s\n"
- "FMUL v19.4s, v19.4s, %[bv].4s\n"
- "FMUL v21.4s, v21.4s, %[bv].4s\n"
- "FMLA v16.4s, v0.4s, %[av].4s\n"
- "FMLA v17.4s, v1.4s, %[av].4s\n"
- "FCVTN v16.4h, v16.4s\n"
- "FCVTN2 v16.8h, v17.4s\n"
- "FMLA v18.4s, v2.4s, %[av].4s\n"
- "STR q16, [%[outptr6]], #16\n"
- "FCVTN v18.4h, v18.4s\n"
- "STR d18, [%[outptr6]], #8\n"
- "FMLA v19.4s, v3.4s, %[av].4s\n"
- "FMLA v20.4s, v4.4s, %[av].4s\n"
- "FCVTN v19.4h, v19.4s\n"
- "FCVTN2 v19.8h, v20.4s\n"
- "STR q19, [%[outptr7]], #16\n"
- "FMLA v21.4s, v5.4s, %[av].4s\n"
- "FCVTN v21.4h, v21.4s\n"
- "STR d21, [%[outptr7]], #8\n"
- "ADD %[inptr], %[inptr], #384\n"
- : [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3),
- [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), [outptr6] "+r"(outptr6), [outptr7] "+r"(outptr7),
- [inptr] "+r"(inptr)
- : [av] "w"(av), [bv] "w"(bv)
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21");
+ // Rows 6-7
+ "LDR q16, [%[outptr6]]\n"
+ "FCVTL2 v17.4s, v16.8h\n"
+ "LDR d18, [%[outptr6], #16]\n"
+ "FCVTL v16.4s, v16.4h\n"
+ "LDR q19, [%[outptr7]]\n"
+ "FMUL v17.4s, v17.4s, %[bv].4s\n"
+ "LDR d21, [%[outptr7], #16]\n"
+ "FMUL v16.4s, v16.4s, %[bv].4s\n"
+ "LDP q0, q1, [%[inptr], #288]\n"
+ "FCVTL v18.4s, v18.4h\n"
+ "LDP q2, q3, [%[inptr], #320]\n"
+ "FCVTL2 v20.4s, v19.8h\n"
+ "LDP q4, q5, [%[inptr], #352]\n"
+ "FCVTL v19.4s, v19.4h\n"
+ ASM_PREFETCH("[%[outptr5], #64]")
+ "FCVTL v21.4s, v21.4h\n"
+ ASM_PREFETCH("[%[outptr6], #88]")
+ "FMUL v18.4s, v18.4s, %[bv].4s\n"
+ ASM_PREFETCH("[%[outptr7], #88]")
+ "FMUL v20.4s, v20.4s, %[bv].4s\n"
+ "FMUL v19.4s, v19.4s, %[bv].4s\n"
+ "FMUL v21.4s, v21.4s, %[bv].4s\n"
+ "FMLA v16.4s, v0.4s, %[av].4s\n"
+ "FMLA v17.4s, v1.4s, %[av].4s\n"
+ "FCVTN v16.4h, v16.4s\n"
+ "FCVTN2 v16.8h, v17.4s\n"
+ "FMLA v18.4s, v2.4s, %[av].4s\n"
+ "STR q16, [%[outptr6]], #16\n"
+ "FCVTN v18.4h, v18.4s\n"
+ "STR d18, [%[outptr6]], #8\n"
+ "FMLA v19.4s, v3.4s, %[av].4s\n"
+ "FMLA v20.4s, v4.4s, %[av].4s\n"
+ "FCVTN v19.4h, v19.4s\n"
+ "FCVTN2 v19.8h, v20.4s\n"
+ "STR q19, [%[outptr7]], #16\n"
+ "FMLA v21.4s, v5.4s, %[av].4s\n"
+ "FCVTN v21.4h, v21.4s\n"
+ "STR d21, [%[outptr7]], #8\n"
+ "ADD %[inptr], %[inptr], #384\n"
+ : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+ [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [outptr6] "+r" (outptr6), [outptr7] "+r" (outptr7),
+ [inptr] "+r" (inptr)
+ : [av] "w" (av), [bv] "w" (bv)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21", "memory" // "memory": the asm loads via inptr/outptr0..7 and stores via outptr0..7, which the register operands alone do not describe
+ );
+ }
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp
index 08cfc00..3ed43b1 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp
@@ -23,12 +23,12 @@
*/
#pragma once
-#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
+// AArch64 only, and either the FP16_KERNELS option set or the target explicitly supports FP16 vectors.
+#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))
-template <>
+template<>
inline void MergeResults<24, 8>(__fp16 *out, const __fp16 *in, const int ldout, const int y0, const int ymax,
- const int x0, const int xmax, const __fp16 alpha, const __fp16 beta)
-{
+ const int x0, const int xmax, const __fp16 alpha, const __fp16 beta) {
const __fp16 *inptr = in;
prefetch_6x(inptr);
prefetch_6x(inptr + 48);
@@ -36,8 +36,7 @@
float16x8_t va = vdupq_n_f16(alpha);
float16x8_t vb = vdupq_n_f16(beta);
- for(int y = y0; y < ymax; y += 8)
- {
+ for (int y=y0; y<ymax; y+=8) {
__fp16 *outptr0 = out + (y * ldout) + x0;
__fp16 *outptr1 = outptr0 + ldout;
__fp16 *outptr2 = outptr1 + ldout;
@@ -56,17 +55,14 @@
prefetch_2x(outptr6);
prefetch_2x(outptr7);
- for(int i = x0; i < xmax; i += 24)
- {
+ for (int i=x0; i<xmax; i+=24) {
__fp16 dummyres[24];
/* Make sure we throw away results if Y isn't a multiple of 8.
* We do this by pointing the result pointer at a dummy buffer
* we later discard. */
- if((y + 7) >= ymax)
- {
- switch((y + 7) - ymax)
- {
+ if ((y+7) >= ymax) {
+ switch ((y + 7) - ymax) {
case 6:
outptr1 = dummyres;
case 5:
@@ -85,149 +81,277 @@
default:
UNREACHABLE("Impossible.");
+
}
}
- /* For ragged X, manually copy over the valid results. */
- if((i + 23) >= xmax)
- {
- for(int xi = 0; xi < 24; xi++)
- {
- if((i + xi) < xmax)
- {
- *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta);
- outptr0++;
- *outptr1 = (alpha * inptr[xi + 24]) + (*outptr1 * beta);
- outptr1++;
- *outptr2 = (alpha * inptr[xi + 48]) + (*outptr2 * beta);
- outptr2++;
- *outptr3 = (alpha * inptr[xi + 72]) + (*outptr3 * beta);
- outptr3++;
- *outptr4 = (alpha * inptr[xi + 96]) + (*outptr4 * beta);
- outptr4++;
- *outptr5 = (alpha * inptr[xi + 120]) + (*outptr5 * beta);
- outptr5++;
- *outptr6 = (alpha * inptr[xi + 144]) + (*outptr6 * beta);
- outptr6++;
- *outptr7 = (alpha * inptr[xi + 168]) + (*outptr7 * beta);
- outptr7++;
+ if (beta == (__fp16)0.0f) {
+ /* If beta===0, don't read the output. */
+
+ /* For ragged X, manually copy over the valid results. */
+ if ((i+23) >= xmax) {
+ for (int xi=0; xi<24; xi++) {
+ if ((i+xi) < xmax) {
+ *outptr0 = (alpha * inptr[xi]);
+ outptr0++;
+ *outptr1 = (alpha * inptr[xi + 24]);
+ outptr1++;
+ *outptr2 = (alpha * inptr[xi + 48]);
+ outptr2++;
+ *outptr3 = (alpha * inptr[xi + 72]);
+ outptr3++;
+ *outptr4 = (alpha * inptr[xi + 96]);
+ outptr4++;
+ *outptr5 = (alpha * inptr[xi + 120]);
+ outptr5++;
+ *outptr6 = (alpha * inptr[xi + 144]);
+ outptr6++;
+ *outptr7 = (alpha * inptr[xi + 168]);
+ outptr7++;
+ }
}
+ inptr += 192;
+ } else {
+ /* Optimized routine to copy an entire block */
+ __asm __volatile (
+#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ ".arch armv8.2-a+fp16\n"
+#endif
+ // Rows 0-1
+ ASM_PREFETCH("[%[inptr], #768]")
+ "LDP q0, q1, [%[inptr]]\n"
+ "LDP q2, q3, [%[inptr], #32]\n"
+ "LDP q4, q5, [%[inptr], #64]\n"
+ "FMUL v16.8h, v0.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[inptr], #832]")
+ "FMUL v17.8h, v1.8h, %[va].8h\n"
+ "STP q16, q17, [%[outptr0]], #32\n"
+ "FMUL v18.8h, v2.8h, %[va].8h\n"
+ "STR q18, [%[outptr0]], #16\n"
+ "FMUL v19.8h, v3.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[inptr], #896]")
+ "FMUL v20.8h, v4.8h, %[va].8h\n"
+ "STP q19, q20, [%[outptr1]], #32\n"
+ "FMUL v21.8h, v5.8h, %[va].8h\n"
+ "STR q21, [%[outptr1]], #16\n"
+ ASM_PREFETCH("[%[inptr], #960]")
+
+ // Rows 2-3
+ ASM_PREFETCH("[%[inptr], #1024]")
+ "LDP q0, q1, [%[inptr], #96]\n"
+ "LDP q2, q3, [%[inptr], #128]\n"
+ "LDP q4, q5, [%[inptr], #160]\n"
+ "FMUL v16.8h, v0.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[inptr], #1088]")
+ "FMUL v17.8h, v1.8h, %[va].8h\n"
+ "STP q16, q17, [%[outptr2]], #32\n"
+ "FMUL v18.8h, v2.8h, %[va].8h\n"
+ "STR q18, [%[outptr2]], #16\n"
+ "FMUL v19.8h, v3.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[outptr0], #80]")
+ "FMUL v20.8h, v4.8h, %[va].8h\n"
+ "STP q19, q20, [%[outptr3]], #32\n"
+ "FMUL v21.8h, v5.8h, %[va].8h\n"
+ "STR q21, [%[outptr3]], #16\n"
+ ASM_PREFETCH("[%[outptr1], #80]")
+
+ // Rows 4-5
+ ASM_PREFETCH("[%[outptr2], #80]")
+ "LDP q0, q1, [%[inptr], #192]\n"
+ "LDP q2, q3, [%[inptr], #224]\n"
+ "LDP q4, q5, [%[inptr], #256]\n"
+ "FMUL v16.8h, v0.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[outptr3], #80]")
+ "FMUL v17.8h, v1.8h, %[va].8h\n"
+ "STP q16, q17, [%[outptr4]], #32\n"
+ "FMUL v18.8h, v2.8h, %[va].8h\n"
+ "STR q18, [%[outptr4]], #16\n"
+ "FMUL v19.8h, v3.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[outptr4], #80]")
+ "FMUL v20.8h, v4.8h, %[va].8h\n"
+ "STP q19, q20, [%[outptr5]], #32\n"
+ "FMUL v21.8h, v5.8h, %[va].8h\n"
+ "STR q21, [%[outptr5]], #16\n"
+
+ // Rows 6-7
+ ASM_PREFETCH("[%[outptr5], #80]")
+ "LDP q0, q1, [%[inptr], #288]\n"
+ "LDP q2, q3, [%[inptr], #320]\n"
+ "LDP q4, q5, [%[inptr], #352]\n"
+ "FMUL v16.8h, v0.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[outptr6], #128]")
+ "FMUL v17.8h, v1.8h, %[va].8h\n"
+ "STP q16, q17, [%[outptr6]], #32\n"
+ "FMUL v18.8h, v2.8h, %[va].8h\n"
+ "STR q18, [%[outptr6]], #16\n"
+ "FMUL v19.8h, v3.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[outptr7], #128]")
+ "FMUL v20.8h, v4.8h, %[va].8h\n"
+ "STP q19, q20, [%[outptr7]], #32\n"
+ "FMUL v21.8h, v5.8h, %[va].8h\n"
+ "STR q21, [%[outptr7]], #16\n"
+ "ADD %[inptr], %[inptr], #384\n"
+ : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+ [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [outptr6] "+r" (outptr6), [outptr7] "+r" (outptr7),
+ [inptr] "+r" (inptr)
+ : [va] "w" (va), [vb] "w" (vb)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21"
+ );
}
- inptr += 192;
- }
- else
- {
- /* Optimized routine to copy an entire block */
- __asm __volatile(
- ".arch armv8.2-a+fp16\n"
- // Rows 0-1
- "LDP q16, q17, [%[outptr0]]\n"
- "FMUL v16.8h, v16.8h, %[vb].8h\n"
- "LDR q18, [%[outptr0], #32]\n"
- "FMUL v17.8h, v17.8h, %[vb].8h\n"
- "LDP q19, q20, [%[outptr1]]\n"
- "FMUL v18.8h, v18.8h, %[vb].8h\n" ASM_PREFETCH("[%[inptr], #768]")
- "LDR q21, [%[outptr1], #32]\n"
- "FMUL v19.8h, v19.8h, %[vb].8h\n"
- "LDP q0, q1, [%[inptr]]\n"
- "FMUL v20.8h, v20.8h, %[vb].8h\n"
- "LDP q2, q3, [%[inptr], #32]\n"
- "FMUL v21.8h, v21.8h, %[vb].8h\n"
- "LDP q4, q5, [%[inptr], #64]\n"
- "FMLA v16.8h, v0.8h, %[va].8h\n" ASM_PREFETCH("[%[inptr], #832]")
- "FMLA v17.8h, v1.8h, %[va].8h\n"
- "STP q16, q17, [%[outptr0]], #32\n"
- "FMLA v18.8h, v2.8h, %[va].8h\n"
- "STR q18, [%[outptr0]], #16\n"
- "FMLA v19.8h, v3.8h, %[va].8h\n" ASM_PREFETCH("[%[inptr], #896]")
- "FMLA v20.8h, v4.8h, %[va].8h\n"
- "STP q19, q20, [%[outptr1]], #32\n"
- "FMLA v21.8h, v5.8h, %[va].8h\n"
- "STR q21, [%[outptr1]], #16\n" ASM_PREFETCH("[%[inptr], #960]")
+ } else {
+ /* For ragged X, manually copy over the valid results. */
+ if ((i+23) >= xmax) {
+ for (int xi=0; xi<24; xi++) {
+ if ((i+xi) < xmax) {
+ *outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta);
+ outptr0++;
+ *outptr1 = (alpha * inptr[xi + 24]) + (*outptr1 * beta);
+ outptr1++;
+ *outptr2 = (alpha * inptr[xi + 48]) + (*outptr2 * beta);
+ outptr2++;
+ *outptr3 = (alpha * inptr[xi + 72]) + (*outptr3 * beta);
+ outptr3++;
+ *outptr4 = (alpha * inptr[xi + 96]) + (*outptr4 * beta);
+ outptr4++;
+ *outptr5 = (alpha * inptr[xi + 120]) + (*outptr5 * beta);
+ outptr5++;
+ *outptr6 = (alpha * inptr[xi + 144]) + (*outptr6 * beta);
+ outptr6++;
+ *outptr7 = (alpha * inptr[xi + 168]) + (*outptr7 * beta);
+ outptr7++;
+ }
+ }
+ inptr += 192;
+ } else {
+ /* Optimized routine to copy an entire block */
+ __asm __volatile (
+#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ ".arch armv8.2-a+fp16\n"
+#endif
+ // Rows 0-1
+ "LDP q16, q17, [%[outptr0]]\n"
+ "FMUL v16.8h, v16.8h, %[vb].8h\n"
+ "LDR q18, [%[outptr0], #32]\n"
+ "FMUL v17.8h, v17.8h, %[vb].8h\n"
+ "LDP q19, q20, [%[outptr1]]\n"
+ "FMUL v18.8h, v18.8h, %[vb].8h\n"
+ ASM_PREFETCH("[%[inptr], #768]")
+ "LDR q21, [%[outptr1], #32]\n"
+ "FMUL v19.8h, v19.8h, %[vb].8h\n"
+ "LDP q0, q1, [%[inptr]]\n"
+ "FMUL v20.8h, v20.8h, %[vb].8h\n"
+ "LDP q2, q3, [%[inptr], #32]\n"
+ "FMUL v21.8h, v21.8h, %[vb].8h\n"
+ "LDP q4, q5, [%[inptr], #64]\n"
+ "FMLA v16.8h, v0.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[inptr], #832]")
+ "FMLA v17.8h, v1.8h, %[va].8h\n"
+ "STP q16, q17, [%[outptr0]], #32\n"
+ "FMLA v18.8h, v2.8h, %[va].8h\n"
+ "STR q18, [%[outptr0]], #16\n"
+ "FMLA v19.8h, v3.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[inptr], #896]")
+ "FMLA v20.8h, v4.8h, %[va].8h\n"
+ "STP q19, q20, [%[outptr1]], #32\n"
+ "FMLA v21.8h, v5.8h, %[va].8h\n"
+ "STR q21, [%[outptr1]], #16\n"
+ ASM_PREFETCH("[%[inptr], #960]")
- // Rows 2-3
- "LDP q16, q17, [%[outptr2]]\n"
- "FMUL v16.8h, v16.8h, %[vb].8h\n"
- "LDR q18, [%[outptr2], #32]\n"
- "FMUL v17.8h, v17.8h, %[vb].8h\n"
- "LDP q19, q20, [%[outptr3]]\n"
- "FMUL v18.8h, v18.8h, %[vb].8h\n" ASM_PREFETCH("[%[inptr], #1024]")
- "LDR q21, [%[outptr3], #32]\n"
- "FMUL v19.8h, v19.8h, %[vb].8h\n"
- "LDP q0, q1, [%[inptr], #96]\n"
- "FMUL v20.8h, v20.8h, %[vb].8h\n"
- "LDP q2, q3, [%[inptr], #128]\n"
- "FMUL v21.8h, v21.8h, %[vb].8h\n"
- "LDP q4, q5, [%[inptr], #160]\n"
- "FMLA v16.8h, v0.8h, %[va].8h\n" ASM_PREFETCH("[%[inptr], #1088]")
- "FMLA v17.8h, v1.8h, %[va].8h\n"
- "STP q16, q17, [%[outptr2]], #32\n"
- "FMLA v18.8h, v2.8h, %[va].8h\n"
- "STR q18, [%[outptr2]], #16\n"
- "FMLA v19.8h, v3.8h, %[va].8h\n" ASM_PREFETCH("[%[outptr0], #80]")
- "FMLA v20.8h, v4.8h, %[va].8h\n"
- "STP q19, q20, [%[outptr3]], #32\n"
- "FMLA v21.8h, v5.8h, %[va].8h\n"
- "STR q21, [%[outptr3]], #16\n" ASM_PREFETCH("[%[outptr1], #80]")
+ // Rows 2-3
+ "LDP q16, q17, [%[outptr2]]\n"
+ "FMUL v16.8h, v16.8h, %[vb].8h\n"
+ "LDR q18, [%[outptr2], #32]\n"
+ "FMUL v17.8h, v17.8h, %[vb].8h\n"
+ "LDP q19, q20, [%[outptr3]]\n"
+ "FMUL v18.8h, v18.8h, %[vb].8h\n"
+ ASM_PREFETCH("[%[inptr], #1024]")
+ "LDR q21, [%[outptr3], #32]\n"
+ "FMUL v19.8h, v19.8h, %[vb].8h\n"
+ "LDP q0, q1, [%[inptr], #96]\n"
+ "FMUL v20.8h, v20.8h, %[vb].8h\n"
+ "LDP q2, q3, [%[inptr], #128]\n"
+ "FMUL v21.8h, v21.8h, %[vb].8h\n"
+ "LDP q4, q5, [%[inptr], #160]\n"
+ "FMLA v16.8h, v0.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[inptr], #1088]")
+ "FMLA v17.8h, v1.8h, %[va].8h\n"
+ "STP q16, q17, [%[outptr2]], #32\n"
+ "FMLA v18.8h, v2.8h, %[va].8h\n"
+ "STR q18, [%[outptr2]], #16\n"
+ "FMLA v19.8h, v3.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[outptr0], #80]")
+ "FMLA v20.8h, v4.8h, %[va].8h\n"
+ "STP q19, q20, [%[outptr3]], #32\n"
+ "FMLA v21.8h, v5.8h, %[va].8h\n"
+ "STR q21, [%[outptr3]], #16\n"
+ ASM_PREFETCH("[%[outptr1], #80]")
- // Rows 4-5
- "LDP q16, q17, [%[outptr4]]\n"
- "FMUL v16.8h, v16.8h, %[vb].8h\n"
- "LDR q18, [%[outptr4], #32]\n"
- "FMUL v17.8h, v17.8h, %[vb].8h\n"
- "LDP q19, q20, [%[outptr5]]\n"
- "FMUL v18.8h, v18.8h, %[vb].8h\n" ASM_PREFETCH("[%[outptr2], #80]")
- "LDR q21, [%[outptr5], #32]\n"
- "FMUL v19.8h, v19.8h, %[vb].8h\n"
- "LDP q0, q1, [%[inptr], #192]\n"
- "FMUL v20.8h, v20.8h, %[vb].8h\n"
- "LDP q2, q3, [%[inptr], #224]\n"
- "FMUL v21.8h, v21.8h, %[vb].8h\n"
- "LDP q4, q5, [%[inptr], #256]\n"
- "FMLA v16.8h, v0.8h, %[va].8h\n" ASM_PREFETCH("[%[outptr3], #80]")
- "FMLA v17.8h, v1.8h, %[va].8h\n"
- "STP q16, q17, [%[outptr4]], #32\n"
- "FMLA v18.8h, v2.8h, %[va].8h\n"
- "STR q18, [%[outptr4]], #16\n"
- "FMLA v19.8h, v3.8h, %[va].8h\n" ASM_PREFETCH("[%[outptr4], #80]")
- "FMLA v20.8h, v4.8h, %[va].8h\n"
- "STP q19, q20, [%[outptr5]], #32\n"
- "FMLA v21.8h, v5.8h, %[va].8h\n"
- "STR q21, [%[outptr5]], #16\n"
+ // Rows 4-5
+ "LDP q16, q17, [%[outptr4]]\n"
+ "FMUL v16.8h, v16.8h, %[vb].8h\n"
+ "LDR q18, [%[outptr4], #32]\n"
+ "FMUL v17.8h, v17.8h, %[vb].8h\n"
+ "LDP q19, q20, [%[outptr5]]\n"
+ "FMUL v18.8h, v18.8h, %[vb].8h\n"
+ ASM_PREFETCH("[%[outptr2], #80]")
+ "LDR q21, [%[outptr5], #32]\n"
+ "FMUL v19.8h, v19.8h, %[vb].8h\n"
+ "LDP q0, q1, [%[inptr], #192]\n"
+ "FMUL v20.8h, v20.8h, %[vb].8h\n"
+ "LDP q2, q3, [%[inptr], #224]\n"
+ "FMUL v21.8h, v21.8h, %[vb].8h\n"
+ "LDP q4, q5, [%[inptr], #256]\n"
+ "FMLA v16.8h, v0.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[outptr3], #80]")
+ "FMLA v17.8h, v1.8h, %[va].8h\n"
+ "STP q16, q17, [%[outptr4]], #32\n"
+ "FMLA v18.8h, v2.8h, %[va].8h\n"
+ "STR q18, [%[outptr4]], #16\n"
+ "FMLA v19.8h, v3.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[outptr4], #80]")
+ "FMLA v20.8h, v4.8h, %[va].8h\n"
+ "STP q19, q20, [%[outptr5]], #32\n"
+ "FMLA v21.8h, v5.8h, %[va].8h\n"
+ "STR q21, [%[outptr5]], #16\n"
- // Rows 6-7
- "LDP q16, q17, [%[outptr6]]\n"
- "FMUL v16.8h, v16.8h, %[vb].8h\n"
- "LDR q18, [%[outptr6], #32]\n"
- "FMUL v17.8h, v17.8h, %[vb].8h\n"
- "LDP q19, q20, [%[outptr7]]\n" ASM_PREFETCH("[%[outptr5], #80]")
- "FMUL v18.8h, v18.8h, %[vb].8h\n"
- "LDR q21, [%[outptr7], #32]\n"
- "FMUL v19.8h, v19.8h, %[vb].8h\n"
- "LDP q0, q1, [%[inptr], #288]\n"
- "FMUL v20.8h, v20.8h, %[vb].8h\n"
- "LDP q2, q3, [%[inptr], #320]\n"
- "FMUL v21.8h, v21.8h, %[vb].8h\n"
- "LDP q4, q5, [%[inptr], #352]\n"
- "FMLA v16.8h, v0.8h, %[va].8h\n" ASM_PREFETCH("[%[outptr6], #128]")
- "FMLA v17.8h, v1.8h, %[va].8h\n"
- "STP q16, q17, [%[outptr6]], #32\n"
- "FMLA v18.8h, v2.8h, %[va].8h\n"
- "STR q18, [%[outptr6]], #16\n"
- "FMLA v19.8h, v3.8h, %[va].8h\n" ASM_PREFETCH("[%[outptr7], #128]")
- "FMLA v20.8h, v4.8h, %[va].8h\n"
- "STP q19, q20, [%[outptr7]], #32\n"
- "FMLA v21.8h, v5.8h, %[va].8h\n"
- "STR q21, [%[outptr7]], #16\n"
- "ADD %[inptr], %[inptr], #384\n"
- : [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3),
- [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), [outptr6] "+r"(outptr6), [outptr7] "+r"(outptr7),
- [inptr] "+r"(inptr)
- : [va] "w"(va), [vb] "w"(vb)
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21");
+ // Rows 6-7
+ "LDP q16, q17, [%[outptr6]]\n"
+ "FMUL v16.8h, v16.8h, %[vb].8h\n"
+ "LDR q18, [%[outptr6], #32]\n"
+ "FMUL v17.8h, v17.8h, %[vb].8h\n"
+ "LDP q19, q20, [%[outptr7]]\n"
+ ASM_PREFETCH("[%[outptr5], #80]")
+ "FMUL v18.8h, v18.8h, %[vb].8h\n"
+ "LDR q21, [%[outptr7], #32]\n"
+ "FMUL v19.8h, v19.8h, %[vb].8h\n"
+ "LDP q0, q1, [%[inptr], #288]\n"
+ "FMUL v20.8h, v20.8h, %[vb].8h\n"
+ "LDP q2, q3, [%[inptr], #320]\n"
+ "FMUL v21.8h, v21.8h, %[vb].8h\n"
+ "LDP q4, q5, [%[inptr], #352]\n"
+ "FMLA v16.8h, v0.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[outptr6], #128]")
+ "FMLA v17.8h, v1.8h, %[va].8h\n"
+ "STP q16, q17, [%[outptr6]], #32\n"
+ "FMLA v18.8h, v2.8h, %[va].8h\n"
+ "STR q18, [%[outptr6]], #16\n"
+ "FMLA v19.8h, v3.8h, %[va].8h\n"
+ ASM_PREFETCH("[%[outptr7], #128]")
+ "FMLA v20.8h, v4.8h, %[va].8h\n"
+ "STP q19, q20, [%[outptr7]], #32\n"
+ "FMLA v21.8h, v5.8h, %[va].8h\n"
+ "STR q21, [%[outptr7]], #16\n"
+ "ADD %[inptr], %[inptr], #384\n"
+ : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+ [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [outptr6] "+r" (outptr6), [outptr7] "+r" (outptr7),
+ [inptr] "+r" (inptr)
+ : [va] "w" (va), [vb] "w" (vb)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v16", "v17", "v18", "v19", "v20", "v21"
+ );
+ }
}
}
}
}
-#endif // __aarch64__ && __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
+#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
index 79dd1f0..ee32ce7 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
@@ -25,18 +25,16 @@
#ifdef __aarch64__
-template <>
-inline void MergeResults<12, 8>(int32_t *out, const int32_t *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const int32_t alpha, const int32_t beta)
-{
+template<>
+inline void MergeResults<12, 8, false>(int32_t *out, const int32_t *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const int32_t alpha, const int32_t beta) {
const int32_t *inptr = in;
prefetch_6x(inptr);
prefetch_6x(inptr + 96);
int32x4_t alpha_value = vdupq_n_s32(alpha);
- int32x4_t beta_value = vdupq_n_s32(beta);
+ int32x4_t beta_value = vdupq_n_s32(beta);
- for(int y = y0; y < ymax; y += 8)
- {
+ for (int y=y0; y<ymax; y+=8) {
int32_t *outptr0 = out + (y * ldout) + x0;
int32_t *outptr1 = outptr0 + ldout;
int32_t *outptr2 = outptr1 + ldout;
@@ -55,17 +53,14 @@
prefetch_2x(outptr6);
prefetch_2x(outptr7);
- for(int i = x0; i < xmax; i += 12)
- {
+ for (int i=x0; i<xmax; i+=12) {
int32_t dummyres[12];
/* Make sure we throw away results if Y isn't a multiple of 8.
* We do this by pointing the result pointer at a dummy buffer
* we later discard. */
- if((y + 7) >= ymax)
- {
- switch((y + 7) - ymax)
- {
+ if ((y+7) >= ymax) {
+ switch ((y + 7) - ymax) {
case 6:
outptr1 = dummyres;
case 5:
@@ -88,12 +83,9 @@
}
/* For ragged X, manually copy over the valid results. */
- if((i + 11) >= xmax)
- {
- for(int xi = 0; xi < 12; xi++)
- {
- if((i + xi) < xmax)
- {
+ if ((i+11) >= xmax) {
+ for (int xi=0; xi<12; xi++) {
+ if ((i+xi) < xmax) {
*outptr0 = (alpha * inptr[xi]) + (*outptr0 * beta);
outptr0++;
*outptr1 = (alpha * inptr[xi + 12]) + (*outptr1 * beta);
@@ -113,177 +105,175 @@
}
}
inptr += 96;
- }
- else
- {
+ } else {
/* Optimized routine to copy an entire block */
- __asm __volatile(
- // Row 0
- ASM_PREFETCH("[%x[outptr1], #192]")
- "ldr q3, [%x[outptr0]]\n"
- "ldr q4, [%x[outptr0], #0x10]\n"
- "ldr q5, [%x[outptr0], #0x20]\n"
- "mul v3.4s, v3.4s, %[beta_value].4s\n"
- "ldr q6, [%x[inptr]]\n"
- "mul v4.4s, v4.4s, %[beta_value].4s\n"
- "ldr q7, [%x[inptr], #0x10]\n"
- "mul v5.4s, v5.4s, %[beta_value].4s\n"
- "ldr q8, [%x[inptr], #0x20]\n"
- "mla v3.4s, v6.4s, %[alpha_value].4s\n"
- "ldr q0, [%x[outptr1]]\n"
- "mla v4.4s, v7.4s, %[alpha_value].4s\n"
- "ldr q1, [%x[outptr1], #0x10]\n"
- "mla v5.4s, v8.4s, %[alpha_value].4s\n"
- "ldr q2, [%x[outptr1], #0x20]\n"
+ __asm __volatile (
+ // Row 0
+ ASM_PREFETCH("[%x[outptr1], #192]")
+ "ldr q3, [%x[outptr0]]\n"
+ "ldr q4, [%x[outptr0], #0x10]\n"
+ "ldr q5, [%x[outptr0], #0x20]\n"
+ "mul v3.4s, v3.4s, %[beta_value].4s\n"
+ "ldr q6, [%x[inptr]]\n"
+ "mul v4.4s, v4.4s, %[beta_value].4s\n"
+ "ldr q7, [%x[inptr], #0x10]\n"
+ "mul v5.4s, v5.4s, %[beta_value].4s\n"
+ "ldr q8, [%x[inptr], #0x20]\n"
+ "mla v3.4s, v6.4s, %[alpha_value].4s\n"
+ "ldr q0, [%x[outptr1]]\n"
+ "mla v4.4s, v7.4s, %[alpha_value].4s\n"
+ "ldr q1, [%x[outptr1], #0x10]\n"
+ "mla v5.4s, v8.4s, %[alpha_value].4s\n"
+ "ldr q2, [%x[outptr1], #0x20]\n"
- // Row 1
- ASM_PREFETCH("[%x[outptr2], #192]")
- "mul v0.4s, v0.4s, %[beta_value].4s\n"
- "ldr q6, [%x[inptr], #0x30]\n"
- "str q3, [%x[outptr0]], #0x10\n"
- "mul v1.4s, v1.4s, %[beta_value].4s\n"
- "ldr q7, [%x[inptr], #0x40]\n"
- "str q4, [%x[outptr0]], #0x10\n"
- "mul v2.4s, v2.4s, %[beta_value].4s\n"
- "ldr q8, [%x[inptr], #0x50]\n"
- "str q5, [%x[outptr0]], #0x10\n"
- "mla v0.4s, v6.4s, %[alpha_value].4s\n"
- "ldr q3, [%x[outptr2]]\n"
- "mla v1.4s, v7.4s, %[alpha_value].4s\n"
- "ldr q4, [%x[outptr2], #0x10]\n"
- "mla v2.4s, v8.4s, %[alpha_value].4s\n"
- "ldr q5, [%x[outptr2], #0x20]\n"
+ // Row 1
+ ASM_PREFETCH("[%x[outptr2], #192]")
+ "mul v0.4s, v0.4s, %[beta_value].4s\n"
+ "ldr q6, [%x[inptr], #0x30]\n"
+ "str q3, [%x[outptr0]], #0x10\n"
+ "mul v1.4s, v1.4s, %[beta_value].4s\n"
+ "ldr q7, [%x[inptr], #0x40]\n"
+ "str q4, [%x[outptr0]], #0x10\n"
+ "mul v2.4s, v2.4s, %[beta_value].4s\n"
+ "ldr q8, [%x[inptr], #0x50]\n"
+ "str q5, [%x[outptr0]], #0x10\n"
+ "mla v0.4s, v6.4s, %[alpha_value].4s\n"
+ "ldr q3, [%x[outptr2]]\n"
+ "mla v1.4s, v7.4s, %[alpha_value].4s\n"
+ "ldr q4, [%x[outptr2], #0x10]\n"
+ "mla v2.4s, v8.4s, %[alpha_value].4s\n"
+ "ldr q5, [%x[outptr2], #0x20]\n"
- // Row 2
- ASM_PREFETCH("[%x[outptr3], #192]")
- "mul v3.4s, v3.4s, %[beta_value].4s\n"
- "ldr q6, [%x[inptr], #0x60]\n"
- "str q0, [%x[outptr1]], #0x10\n"
- "mul v4.4s, v4.4s, %[beta_value].4s\n"
- "ldr q7, [%x[inptr], #0x70]\n"
- "str q1, [%x[outptr1]], #0x10\n"
- "mul v5.4s, v5.4s, %[beta_value].4s\n"
- "ldr q8, [%x[inptr], #0x80]\n"
- "str q2, [%x[outptr1]], #0x10\n"
- "mla v3.4s, v6.4s, %[alpha_value].4s\n"
- "ldr q0, [%x[outptr3]]\n"
- "mla v4.4s, v7.4s, %[alpha_value].4s\n"
- "ldr q1, [%x[outptr3], #0x10]\n"
- "mla v5.4s, v8.4s, %[alpha_value].4s\n"
- "ldr q2, [%x[outptr3], #0x20]\n"
+ // Row 2
+ ASM_PREFETCH("[%x[outptr3], #192]")
+ "mul v3.4s, v3.4s, %[beta_value].4s\n"
+ "ldr q6, [%x[inptr], #0x60]\n"
+ "str q0, [%x[outptr1]], #0x10\n"
+ "mul v4.4s, v4.4s, %[beta_value].4s\n"
+ "ldr q7, [%x[inptr], #0x70]\n"
+ "str q1, [%x[outptr1]], #0x10\n"
+ "mul v5.4s, v5.4s, %[beta_value].4s\n"
+ "ldr q8, [%x[inptr], #0x80]\n"
+ "str q2, [%x[outptr1]], #0x10\n"
+ "mla v3.4s, v6.4s, %[alpha_value].4s\n"
+ "ldr q0, [%x[outptr3]]\n"
+ "mla v4.4s, v7.4s, %[alpha_value].4s\n"
+ "ldr q1, [%x[outptr3], #0x10]\n"
+ "mla v5.4s, v8.4s, %[alpha_value].4s\n"
+ "ldr q2, [%x[outptr3], #0x20]\n"
- // Row 3
- ASM_PREFETCH("[%x[outptr4], #192]")
- "mul v0.4s, v0.4s, %[beta_value].4s\n"
- "ldr q6, [%x[inptr], #0x90]\n"
- "str q3, [%x[outptr2]], #0x10\n"
- "mul v1.4s, v1.4s, %[beta_value].4s\n"
- "ldr q7, [%x[inptr], #0xa0]\n"
- "str q4, [%x[outptr2]], #0x10\n"
- "mul v2.4s, v2.4s, %[beta_value].4s\n"
- "ldr q8, [%x[inptr], #0xb0]\n"
- "str q5, [%x[outptr2]], #0x10\n"
- "mla v0.4s, v6.4s, %[alpha_value].4s\n"
- "ldr q3, [%x[outptr4]]\n"
- "mla v1.4s, v7.4s, %[alpha_value].4s\n"
- "ldr q4, [%x[outptr4], #0x10]\n"
- "mla v2.4s, v8.4s, %[alpha_value].4s\n"
- "ldr q5, [%x[outptr4], #0x20]\n"
+ // Row 3
+ ASM_PREFETCH("[%x[outptr4], #192]")
+ "mul v0.4s, v0.4s, %[beta_value].4s\n"
+ "ldr q6, [%x[inptr], #0x90]\n"
+ "str q3, [%x[outptr2]], #0x10\n"
+ "mul v1.4s, v1.4s, %[beta_value].4s\n"
+ "ldr q7, [%x[inptr], #0xa0]\n"
+ "str q4, [%x[outptr2]], #0x10\n"
+ "mul v2.4s, v2.4s, %[beta_value].4s\n"
+ "ldr q8, [%x[inptr], #0xb0]\n"
+ "str q5, [%x[outptr2]], #0x10\n"
+ "mla v0.4s, v6.4s, %[alpha_value].4s\n"
+ "ldr q3, [%x[outptr4]]\n"
+ "mla v1.4s, v7.4s, %[alpha_value].4s\n"
+ "ldr q4, [%x[outptr4], #0x10]\n"
+ "mla v2.4s, v8.4s, %[alpha_value].4s\n"
+ "ldr q5, [%x[outptr4], #0x20]\n"
- // Row 4
- ASM_PREFETCH("[%x[outptr5], #192]")
- "mul v3.4s, v3.4s, %[beta_value].4s\n"
- "ldr q6, [%x[inptr], #0xc0]\n"
- "str q0, [%x[outptr3]], #0x10\n"
- "mul v4.4s, v4.4s, %[beta_value].4s\n"
- "ldr q7, [%x[inptr], #0xd0]\n"
- "str q1, [%x[outptr3]], #0x10\n"
- "mul v5.4s, v5.4s, %[beta_value].4s\n"
- "ldr q8, [%x[inptr], #0xe0]\n"
- "str q2, [%x[outptr3]], #0x10\n"
- "mla v3.4s, v6.4s, %[alpha_value].4s\n"
- "ldr q0, [%x[outptr5]]\n"
- "mla v4.4s, v7.4s, %[alpha_value].4s\n"
- "ldr q1, [%x[outptr5], #0x10]\n"
- "mla v5.4s, v8.4s, %[alpha_value].4s\n"
- "ldr q2, [%x[outptr5], #0x20]\n"
+ // Row 4
+ ASM_PREFETCH("[%x[outptr5], #192]")
+ "mul v3.4s, v3.4s, %[beta_value].4s\n"
+ "ldr q6, [%x[inptr], #0xc0]\n"
+ "str q0, [%x[outptr3]], #0x10\n"
+ "mul v4.4s, v4.4s, %[beta_value].4s\n"
+ "ldr q7, [%x[inptr], #0xd0]\n"
+ "str q1, [%x[outptr3]], #0x10\n"
+ "mul v5.4s, v5.4s, %[beta_value].4s\n"
+ "ldr q8, [%x[inptr], #0xe0]\n"
+ "str q2, [%x[outptr3]], #0x10\n"
+ "mla v3.4s, v6.4s, %[alpha_value].4s\n"
+ "ldr q0, [%x[outptr5]]\n"
+ "mla v4.4s, v7.4s, %[alpha_value].4s\n"
+ "ldr q1, [%x[outptr5], #0x10]\n"
+ "mla v5.4s, v8.4s, %[alpha_value].4s\n"
+ "ldr q2, [%x[outptr5], #0x20]\n"
- // Row 5
- ASM_PREFETCH("[%x[outptr6], #192]")
- "mul v0.4s, v0.4s, %[beta_value].4s\n"
- "ldr q6, [%x[inptr], #0xf0]\n"
- "str q3, [%x[outptr4]], #0x10\n"
- "mul v1.4s, v1.4s, %[beta_value].4s\n"
- "ldr q7, [%x[inptr], #0x100]\n"
- "str q4, [%x[outptr4]], #0x10\n"
- "mul v2.4s, v2.4s, %[beta_value].4s\n"
- "ldr q8, [%x[inptr], #0x110]\n"
- "str q5, [%x[outptr4]], #0x10\n"
- "mla v0.4s, v6.4s, %[alpha_value].4s\n"
- "ldr q3, [%x[outptr6]]\n"
- "mla v1.4s, v7.4s, %[alpha_value].4s\n"
- "ldr q4, [%x[outptr6], #0x10]\n"
- "mla v2.4s, v8.4s, %[alpha_value].4s\n"
- "ldr q5, [%x[outptr6], #0x20]\n"
+ // Row 5
+ ASM_PREFETCH("[%x[outptr6], #192]")
+ "mul v0.4s, v0.4s, %[beta_value].4s\n"
+ "ldr q6, [%x[inptr], #0xf0]\n"
+ "str q3, [%x[outptr4]], #0x10\n"
+ "mul v1.4s, v1.4s, %[beta_value].4s\n"
+ "ldr q7, [%x[inptr], #0x100]\n"
+ "str q4, [%x[outptr4]], #0x10\n"
+ "mul v2.4s, v2.4s, %[beta_value].4s\n"
+ "ldr q8, [%x[inptr], #0x110]\n"
+ "str q5, [%x[outptr4]], #0x10\n"
+ "mla v0.4s, v6.4s, %[alpha_value].4s\n"
+ "ldr q3, [%x[outptr6]]\n"
+ "mla v1.4s, v7.4s, %[alpha_value].4s\n"
+ "ldr q4, [%x[outptr6], #0x10]\n"
+ "mla v2.4s, v8.4s, %[alpha_value].4s\n"
+ "ldr q5, [%x[outptr6], #0x20]\n"
- // Row 6
- ASM_PREFETCH("[%x[outptr7], #192]")
- "mul v3.4s, v3.4s, %[beta_value].4s\n"
- "ldr q6, [%x[inptr], #0x120]\n"
- "str q0, [%x[outptr5]], #0x10\n"
- "mul v4.4s, v4.4s, %[beta_value].4s\n"
- "ldr q7, [%x[inptr], #0x130]\n"
- "str q1, [%x[outptr5]], #0x10\n"
- "mul v5.4s, v5.4s, %[beta_value].4s\n"
- "ldr q8, [%x[inptr], #0x140]\n"
- "str q2, [%x[outptr5]], #0x10\n"
- "mla v3.4s, v6.4s, %[alpha_value].4s\n"
- "ldr q0, [%x[outptr7]]\n"
- "mla v4.4s, v7.4s, %[alpha_value].4s\n"
- "ldr q1, [%x[outptr7], #0x10]\n"
- "mla v5.4s, v8.4s, %[alpha_value].4s\n"
- "ldr q2, [%x[outptr7], #0x20]\n"
+ // Row 6
+ ASM_PREFETCH("[%x[outptr7], #192]")
+ "mul v3.4s, v3.4s, %[beta_value].4s\n"
+ "ldr q6, [%x[inptr], #0x120]\n"
+ "str q0, [%x[outptr5]], #0x10\n"
+ "mul v4.4s, v4.4s, %[beta_value].4s\n"
+ "ldr q7, [%x[inptr], #0x130]\n"
+ "str q1, [%x[outptr5]], #0x10\n"
+ "mul v5.4s, v5.4s, %[beta_value].4s\n"
+ "ldr q8, [%x[inptr], #0x140]\n"
+ "str q2, [%x[outptr5]], #0x10\n"
+ "mla v3.4s, v6.4s, %[alpha_value].4s\n"
+ "ldr q0, [%x[outptr7]]\n"
+ "mla v4.4s, v7.4s, %[alpha_value].4s\n"
+ "ldr q1, [%x[outptr7], #0x10]\n"
+ "mla v5.4s, v8.4s, %[alpha_value].4s\n"
+ "ldr q2, [%x[outptr7], #0x20]\n"
- // Row 7
- "mul v0.4s, v0.4s, %[beta_value].4s\n"
- "ldr q6, [%x[inptr], #0x150]\n"
- "str q3, [%x[outptr6]], #0x10\n"
- "mul v1.4s, v1.4s, %[beta_value].4s\n"
- "ldr q7, [%x[inptr], #0x160]\n"
- "str q4, [%x[outptr6]], #0x10\n"
- "mul v2.4s, v2.4s, %[beta_value].4s\n"
- "ldr q8, [%x[inptr], #0x170]\n"
- "str q5, [%x[outptr6]], #0x10\n"
- "mla v0.4s, v6.4s, %[alpha_value].4s\n"
- "mla v1.4s, v7.4s, %[alpha_value].4s\n"
- "mla v2.4s, v8.4s, %[alpha_value].4s\n"
- "str q0, [%x[outptr7]], #0x10\n"
- "str q1, [%x[outptr7]], #0x10\n"
- "str q2, [%x[outptr7]], #0x10\n"
+ // Row 7
+ "mul v0.4s, v0.4s, %[beta_value].4s\n"
+ "ldr q6, [%x[inptr], #0x150]\n"
+ "str q3, [%x[outptr6]], #0x10\n"
+ "mul v1.4s, v1.4s, %[beta_value].4s\n"
+ "ldr q7, [%x[inptr], #0x160]\n"
+ "str q4, [%x[outptr6]], #0x10\n"
+ "mul v2.4s, v2.4s, %[beta_value].4s\n"
+ "ldr q8, [%x[inptr], #0x170]\n"
+ "str q5, [%x[outptr6]], #0x10\n"
+ "mla v0.4s, v6.4s, %[alpha_value].4s\n"
+ "mla v1.4s, v7.4s, %[alpha_value].4s\n"
+ "mla v2.4s, v8.4s, %[alpha_value].4s\n"
+ "str q0, [%x[outptr7]], #0x10\n"
+ "str q1, [%x[outptr7]], #0x10\n"
+ "str q2, [%x[outptr7]], #0x10\n"
- "add %x[inptr], %x[inptr], #0x180\n"
- : [outptr0] "+r"(outptr0),
- [outptr1] "+r"(outptr1),
- [outptr2] "+r"(outptr2),
- [outptr3] "+r"(outptr3),
- [outptr4] "+r"(outptr4),
- [outptr5] "+r"(outptr5),
- [outptr6] "+r"(outptr6),
- [outptr7] "+r"(outptr7),
- [inptr] "+r"(inptr)
- : [alpha_value] "w"(alpha_value),
- [beta_value] "w"(beta_value)
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8");
+ "add %x[inptr], %x[inptr], #0x180\n"
+ : [outptr0] "+r" (outptr0),
+ [outptr1] "+r" (outptr1),
+ [outptr2] "+r" (outptr2),
+ [outptr3] "+r" (outptr3),
+ [outptr4] "+r" (outptr4),
+ [outptr5] "+r" (outptr5),
+ [outptr6] "+r" (outptr6),
+ [outptr7] "+r" (outptr7),
+ [inptr] "+r" (inptr)
+ : [alpha_value] "w" (alpha_value),
+ [beta_value] "w" (beta_value)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"
+ );
}
}
}
}
-template <>
-inline void MergeResults<12, 8>(uint32_t *out, const uint32_t *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const uint32_t alpha, const uint32_t beta)
-{
- // Since the above code uses only MUL and MLA instructions discard the "unsignedness" and proceed safely.
- MergeResults<12, 8>(reinterpret_cast<int32_t *>(out), reinterpret_cast<const int32_t *>(in), ldout, y0, ymax, x0, xmax, static_cast<const int32_t>(alpha), static_cast<const int32_t>(beta));
+template<>
+inline void MergeResults<12, 8>(uint32_t *out, const uint32_t *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const uint32_t alpha, const uint32_t beta) {
+ // Since the above code uses only MUL and MLA instructions discard the "unsignedness" and proceed safely.
+ MergeResults<12, 8>(reinterpret_cast<int32_t*>(out), reinterpret_cast<const int32_t*>(in), ldout, y0, ymax, x0, xmax, static_cast<const int32_t>(alpha), static_cast<const int32_t>(beta));
}
#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/profiler.hpp b/src/core/NEON/kernels/arm_gemm/profiler.hpp
index ada0c95..1b944c4 100644
--- a/src/core/NEON/kernels/arm_gemm/profiler.hpp
+++ b/src/core/NEON/kernels/arm_gemm/profiler.hpp
@@ -31,75 +31,65 @@
#include <mutex>
#endif
-namespace arm_gemm
-{
+namespace arm_gemm {
+
#ifndef NO_MULTI_THREADING
extern std::mutex report_mutex;
#endif
-class profiler
-{
+class profiler {
private:
- static const int maxevents = 100000;
- unsigned long times[maxevents] = {};
- unsigned long units[maxevents] = {};
- int events[maxevents] = {};
- int currentevent = 0;
- int countfd = 0;
+ static const int maxevents = 100000;
+ unsigned long times[maxevents] = { };
+ unsigned long units[maxevents] = { };
+ int events[maxevents] = { };
+ int currentevent=0;
+ int countfd=0;
- class ScopedProfilerClass
- {
+ class ScopedProfilerClass {
private:
profiler &_parent;
- bool legal = false;
+ bool legal=false;
public:
- ScopedProfilerClass(profiler &prof, int i, unsigned long u)
- : _parent(prof)
- {
- if(prof.currentevent == maxevents)
+ ScopedProfilerClass(profiler &prof, int i, unsigned long u) : _parent(prof) {
+ if (prof.currentevent==maxevents)
return;
- prof.events[prof.currentevent] = i;
- prof.units[prof.currentevent] = u;
- legal = true;
+ prof.events[prof.currentevent]=i;
+ prof.units[prof.currentevent]=u;
+ legal=true;
start_counter(prof.countfd);
}
- ~ScopedProfilerClass()
- {
- if(!legal)
- return;
+ ~ScopedProfilerClass() {
+ if (!legal) return;
- long long cycs = stop_counter(_parent.countfd);
+ long long cycs = stop_counter(_parent.countfd);
_parent.times[_parent.currentevent++] = cycs;
}
};
public:
- profiler()
- {
- countfd = open_cycle_counter();
+ profiler() {
+ countfd=open_cycle_counter();
}
- ~profiler()
- {
+ ~profiler() {
close(countfd);
- int tots[5];
+ int tots[5];
unsigned long counts[5];
unsigned long tunits[5];
- const char *descs[] = { "Prepare A", "Prepare B", "Kernel", "Merge" };
+ const char * descs[] = { "Prepare A", "Prepare B", "Kernel", "Merge" };
- for(int i = 1; i < 5; i++)
- {
- tots[i] = 0;
+ for (int i=1; i<5; i++) {
+ tots[i] = 0;
counts[i] = 0;
tunits[i] = 0;
}
- for(int i = 0; i < currentevent; i++)
- {
- // printf("%10s: %ld\n", descs[events[i]-1], times[i]);
+ for (int i=0; i<currentevent; i++) {
+// printf("%10s: %ld\n", descs[events[i]-1], times[i]);
tots[events[i]]++;
counts[events[i]] += times[i];
tunits[events[i]] += units[i];
@@ -113,31 +103,26 @@
#endif
printf("%20s %9s %9s %9s %12s %9s\n", "", "Events", "Total", "Average", "Bytes/MACs", "Per cycle");
- for(int i = 1; i < 5; i++)
- {
- printf("%20s: %9d %9ld %9ld %12lu %9.2f\n", descs[i - 1], tots[i], counts[i], counts[i] / tots[i], tunits[i], (float)tunits[i] / counts[i]);
+ for (int i=1; i<5; i++) {
+ printf("%20s: %9d %9ld %9ld %12lu %9.2f\n",descs[i-1],tots[i],counts[i],counts[i]/tots[i],tunits[i],(float)tunits[i]/counts[i]);
}
}
template <typename T>
- void operator()(int i, unsigned long u, T func)
- {
- if(currentevent == maxevents)
- {
+ void operator() (int i, unsigned long u, T func) {
+ if (currentevent==maxevents) {
func();
- }
- else
- {
+ } else {
events[currentevent] = i;
- units[currentevent] = u;
+ units[currentevent] = u;
start_counter(countfd);
func();
- long long cycs = stop_counter(countfd);
+ long long cycs = stop_counter(countfd);
times[currentevent++] = cycs;
}
}
- ScopedProfilerClass ScopedProfiler(int i, unsigned long u)
- {
+
+ ScopedProfilerClass ScopedProfiler(int i, unsigned long u) {
return ScopedProfilerClass(*this, i, u);
}
};
diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp
new file mode 100644
index 0000000..44124a7
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2017-2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+namespace arm_gemm {
+
+/*
+ * "Standard" packing and merge transforms for blocked GEMMs with a fixed
+ * vector length. A is interleaved 'height' ways; B is interleaved 'width'
+ * ways and transposed; the merge works on 'height' x 'width' result blocks.
+ * 'block' (default 1) groups K for dot-product kernels (e.g. UDOT/SDOT).
+ *
+ * PrepareA/PrepareB pack rows [y0,ymax) / [x0,xmax) over K range [k0,kmax)
+ * of 'in' (leading dimension 'stride') into 'out'. Merge forwards to
+ * MergeResults<width, height>. NOTE: this header has no #includes, so
+ * Transform and MergeResults must already be declared where it is included.
+ */
+template<typename TOperand, typename TResult, unsigned int height, unsigned int width, unsigned int block=1>
+class StdTransformsFixed
+{
+public:
+ template<typename TIn>
+ void PrepareA(TOperand *out, const TIn *in, const int stride, const int y0,
+ const int ymax, const int k0, const int kmax, bool transposed) {
+ if (transposed) {
+ Transform<height, block, true>(out, in, stride, y0, ymax, k0, kmax);
+ } else {
+ Transform<height, block, false>(out, in, stride, y0, ymax, k0, kmax);
+ }
+ }
+
+ template<typename TIn>
+ void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0,
+ const int xmax, const int k0, const int kmax, bool transposed) {
+ if (transposed) {
+ Transform<width, block, false>(out, in, stride, x0, xmax, k0, kmax); // B is packed transposed, so the flag is inverted w.r.t. PrepareA
+ } else {
+ Transform<width, block, true>(out, in, stride, x0, xmax, k0, kmax);
+ }
+ }
+
+ template<typename TOut>
+ void Merge(TOut *out, const TResult *in, int stride, int y0, int ymax, int x0, int xmax, const TOut alpha, const TOut beta) {
+ MergeResults<width, height>(out, in, stride, y0, ymax, x0, xmax, alpha, beta);
+ }
+};
+
+} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/transform.hpp b/src/core/NEON/kernels/arm_gemm/transform.hpp
index c80bb59..77d0d87 100644
--- a/src/core/NEON/kernels/arm_gemm/transform.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transform.hpp
@@ -34,64 +34,55 @@
* Need to cope with the work requested in either dimension not actually
* being a multiple of the block sizes.
*/
-template <unsigned IntBy, unsigned int BlockBy, bool Transposed, size_t TOutSize, size_t TInSize>
-struct TransformImpl
-{
+template <unsigned int tIntBy, unsigned int BlockBy, bool Transposed, size_t TOutSize, size_t TInSize, bool sve>
+struct TransformImpl {
template <typename TOut, typename TIn>
- static void Transform(TOut *out, const TIn *const in, const int stride,
- const int y0, const int ymax, const int x0, const int xmax)
- {
+ static void Transform(TOut* out, const TIn* const in, const int stride,
+ const int y0, const int ymax, const int x0, const int xmax) {
+ // For SVE cases we multiply the interleave factor by the vector length.
+ const unsigned int IntBy = tIntBy * (sve ? get_vector_length<TOut>() : 1);
+
const int n_whole_y_blocks = (ymax - y0) / IntBy;
- const int y_remainders = (ymax - y0) % IntBy;
- const int n_y_blocks = n_whole_y_blocks + (y_remainders ? 1 : 0);
+ const int y_remainders = (ymax - y0) % IntBy;
+ const int n_y_blocks = n_whole_y_blocks + (y_remainders ? 1 : 0);
const int n_whole_x_blocks = (xmax - x0) / BlockBy;
- const int x_remainders = (xmax - x0) % BlockBy;
- const int n_x_blocks = n_whole_x_blocks + (x_remainders ? 1 : 0);
+ const int x_remainders = (xmax - x0) % BlockBy;
+ const int n_x_blocks = n_whole_x_blocks + (x_remainders ? 1 : 0);
// "Y" loop: advance down the rows of the source IntBy rows at a time.
// Set up fill_rows to show the number rows to copy from, and blank_rows
// for the number of blank rows to add.
- for(int y_block = 0; y_block < n_y_blocks; y_block++)
- {
- int fill_rows = (y_block < n_whole_y_blocks) ? IntBy : y_remainders;
+ for (int y_block=0 ; y_block < n_y_blocks; y_block++) {
+ int fill_rows = (y_block < n_whole_y_blocks) ? IntBy : y_remainders;
int blank_rows = IntBy - fill_rows;
int y_base = y0 + (y_block * IntBy);
// So now advance along this block of rows, BlockBy columns at a time.
- for(int x_block = 0; x_block < n_x_blocks; x_block++)
- {
- int fill_cols = (x_block < n_whole_x_blocks) ? BlockBy : x_remainders;
+ for (int x_block=0 ; x_block < n_x_blocks; x_block++) {
+ int fill_cols = (x_block < n_whole_x_blocks) ? BlockBy : x_remainders;
int blank_cols = BlockBy - fill_cols;
int x_base = x0 + (x_block * BlockBy);
- for(int row = 0; row < fill_rows; row++)
- {
- for(int col = 0; col < fill_cols; col++)
- {
+ for (int row = 0; row < fill_rows; row++) {
+ for (int col = 0; col < fill_cols; col++) {
// In-range copy. If it's transposed, we reverse the sense of rows and columns here.
- if(Transposed)
- {
+ if (Transposed) {
*out++ = static_cast<TOut>(in[(x_base + col) * stride + y_base + row]);
- }
- else
- {
+ } else {
*out++ = static_cast<TOut>(in[(y_base + row) * stride + x_base + col]);
}
}
// "col" tail - row is in range but column is out of range.
- for(int col = 0; col < blank_cols; col++)
- {
+ for (int col=0; col < blank_cols; col++) {
*out++ = static_cast<TOut>(0);
}
}
// "row" tail - row is out of range so fill with zeros always.
- for(int row = 0; row < blank_rows; row++)
- {
- for(int col = 0; col < (fill_cols + blank_cols); col++)
- {
+ for (int row = 0; row < blank_rows; row++) {
+ for (int col=0; col < (fill_cols + blank_cols); col++) {
*out++ = static_cast<TOut>(0);
}
}
@@ -100,22 +91,22 @@
}
template <typename T>
- static inline void Transform(T *out, const T *const in, const int stride,
- const int k0, const int kmax, const int x0, const int xmax)
- {
+ static inline void Transform(T* out, const T* const in, const int stride,
+ const int k0, const int kmax, const int x0, const int xmax) {
Transform<T, T>(out, in, stride, k0, kmax, x0, xmax);
}
};
/*****************************************************************************/
-template <unsigned int IntBy, unsigned int BlockBy, bool Transposed, typename TOut, typename TIn>
+template <unsigned int IntBy, unsigned int BlockBy, bool Transposed, bool sve=false, typename TOut, typename TIn>
void Transform(
- TOut *out, const TIn *const in, const int stride,
- const int k0, const int kmax, const int x0, const int xmax)
-{
- // Redirect to a specialised implementation predicated on argument size.
- TransformImpl<IntBy, BlockBy, Transposed, sizeof(TOut), sizeof(TIn)>::Transform(
- out, in, stride, k0, kmax, x0, xmax);
+ TOut* out, const TIn* const in, const int stride,
+ const int k0, const int kmax, const int x0, const int xmax
+) {
+ // Redirect to a specialised implementation predicated on argument size.
+ TransformImpl<IntBy, BlockBy, Transposed, sizeof(TOut), sizeof(TIn), sve>::Transform(
+ out, in, stride, k0, kmax, x0, xmax
+ );
}
/*****************************************************************************/
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
index 501d6bf..492abe5 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
@@ -29,17 +29,15 @@
#include "../asmlib.hpp"
-template <>
-template <typename T>
-inline void TransformImpl<6, 1, false, 4, 4>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax)
-{
- uint32_t *outptr = reinterpret_cast<uint32_t *>(out);
- const uint32_t *inptr = reinterpret_cast<const uint32_t *>(in);
+template<>
+template<typename T>
+inline void TransformImpl<6, 1, false, 4, 4, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
+ uint32_t *outptr = reinterpret_cast<uint32_t *>(out);
+ const uint32_t *inptr = reinterpret_cast<const uint32_t *>(in);
uint32_t zerobuff[8];
- for(int y = y0; y < ymax; y += 6)
- {
+ for (int y=y0; y<ymax; y+=6) {
const uint32_t *inptr0 = inptr + y * ldin + k0;
const uint32_t *inptr1 = inptr0 + ldin;
const uint32_t *inptr2 = inptr1 + ldin;
@@ -54,14 +52,11 @@
//prefetch_2x(inptr4);
//prefetch_2x(inptr5);
- int x = (kmax - k0);
- for(; x > 7; x -= 8)
- {
+ int x=(kmax-k0);
+ for (;x>7;x-=8) {
/* Cope with ragged cases by copying from a buffer of zeroes instead */
- if((y + 5) >= ymax)
- {
- switch((y + 5) - ymax)
- {
+ if ((y + 5) >= ymax) {
+ switch ((y + 5) - ymax) {
/* Everything falls through in here */
case 4:
inptr1 = zerobuff;
@@ -80,67 +75,73 @@
}
}
- __asm __volatile(
+
+ __asm __volatile (
// Load up 8 elements (2 vectors) from each of 8 sources.
- "VLD1.32 {d0-d3}, [%[inptr0]]!\n" // q0=A0A1A2A3
- "VLD1.32 {d4-d7}, [%[inptr1]]!\n" // q2=B0B1B2B3
- "VLD1.32 {d8-d11}, [%[inptr2]]!\n" // q4=C0C1C2C3
- "VZIP.32 q0, q4\n" // q0=A0C0A1C1, q4 = A2C2A3C3
- "VLD1.32 {d12-d15}, [%[inptr3]]!\n" // q6=D0D1D2D3
- "VZIP.32 q2, q6\n" // q2=B0D0B1D1, q6 = B2D2B3D3
- "VLD1.32 {d16-d19}, [%[inptr4]]!\n"
- "VLD1.32 {d20-d23}, [%[inptr5]]!\n"
- "VZIP.32 q8, q10\n" // q8=E0F0E1F1, q10 = E2F2E3F3
+ "VLD1.32 {d0-d3}, [%[inptr0]]!\n" // q0=A0A1A2A3
+ "VLD1.32 {d4-d7}, [%[inptr1]]!\n" // q2=B0B1B2B3
+ "VLD1.32 {d8-d11}, [%[inptr2]]!\n" // q4=C0C1C2C3
+ "VZIP.32 q0, q4\n" // q0=A0C0A1C1, q4 = A2C2A3C3
+ "VLD1.32 {d12-d15}, [%[inptr3]]!\n" // q6=D0D1D2D3
+ "VZIP.32 q2, q6\n" // q2=B0D0B1D1, q6 = B2D2B3D3
+ "VLD1.32 {d16-d19}, [%[inptr4]]!\n"
+ "VLD1.32 {d20-d23}, [%[inptr5]]!\n"
+ "VZIP.32 q8, q10\n" // q8=E0F0E1F1, q10 = E2F2E3F3
ASM_PREFETCH("[%[inptr0], #128]")
- "VZIP.32 q0, q2\n" // q0 = A0B0C0D0, q2 = A1B1C1D1
+ "VZIP.32 q0, q2\n" // q0 = A0B0C0D0, q2 = A1B1C1D1
// Store first elements
- "VST1.32 {d0-d1}, [%[outptr]]!\n"
- "VST1.32 {d16}, [%[outptr]]!\n"
+ "VST1.32 {d0-d1}, [%[outptr]]!\n"
+ "VST1.32 {d16}, [%[outptr]]!\n"
- "VZIP.32 q4, q6\n" // q4 = A2B2C2D2, q6 = A3B3C3D3
+ "VZIP.32 q4, q6\n" // q4 = A2B2C2D2, q6 = A3B3C3D3
// Store second elements
- "VST1.32 {d4-d5}, [%[outptr]]!\n"
- "VZIP.32 q1, q5\n" ASM_PREFETCH("[%[inptr1], #128]")
- "VST1.32 {d17}, [%[outptr]]!\n"
- "VZIP.32 q3, q7\n"
+ "VST1.32 {d4-d5}, [%[outptr]]!\n"
+ "VZIP.32 q1, q5\n"
+ ASM_PREFETCH("[%[inptr1], #128]")
+ "VST1.32 {d17}, [%[outptr]]!\n"
+ "VZIP.32 q3, q7\n"
// Store third elements
- "VZIP.32 q9, q11\n"
- "VST1.32 {d8-d9}, [%[outptr]]!\n"
- "VZIP.32 q1, q3\n" ASM_PREFETCH("[%[inptr2], #128]")
- "VST1.32 {d20}, [%[outptr]]!\n"
+ "VZIP.32 q9, q11\n"
+ "VST1.32 {d8-d9}, [%[outptr]]!\n"
+ "VZIP.32 q1, q3\n"
+ ASM_PREFETCH("[%[inptr2], #128]")
+ "VST1.32 {d20}, [%[outptr]]!\n"
// Store fourth elements
- "VZIP.32 q5, q7\n"
- "VST1.32 {d12-d13}, [%[outptr]]!\n" ASM_PREFETCH("[%[inptr3], #128]")
- "VST1.32 {d21}, [%[outptr]]!\n"
+ "VZIP.32 q5, q7\n"
+ "VST1.32 {d12-d13}, [%[outptr]]!\n"
+ ASM_PREFETCH("[%[inptr3], #128]")
+ "VST1.32 {d21}, [%[outptr]]!\n"
// Fifth
- "VST1.32 {d2-d3}, [%[outptr]]!\n" ASM_PREFETCH("[%[inptr4], #128]")
- "VST1.32 {d18}, [%[outptr]]!\n"
+ "VST1.32 {d2-d3}, [%[outptr]]!\n"
+ ASM_PREFETCH("[%[inptr4], #128]")
+ "VST1.32 {d18}, [%[outptr]]!\n"
// Sixth
- "VST1.32 {d6-d7}, [%[outptr]]!\n" ASM_PREFETCH("[%[inptr5], #128]")
- "VST1.32 {d19}, [%[outptr]]!\n"
+ "VST1.32 {d6-d7}, [%[outptr]]!\n"
+ ASM_PREFETCH("[%[inptr5], #128]")
+ "VST1.32 {d19}, [%[outptr]]!\n"
// Seventh
- "VST1.32 {d10-d11}, [%[outptr]]!\n"
- "VST1.32 {d22}, [%[outptr]]!\n"
+ "VST1.32 {d10-d11}, [%[outptr]]!\n"
+ "VST1.32 {d22}, [%[outptr]]!\n"
// Eighth
- "VST1.32 {d14-d15}, [%[outptr]]!\n"
- "VST1.32 {d23}, [%[outptr]]!\n"
+ "VST1.32 {d14-d15}, [%[outptr]]!\n"
+ "VST1.32 {d23}, [%[outptr]]!\n"
- : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
- [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), [outptr] "+r"(outptr)
+ : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3),
+ [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [outptr] "+r" (outptr)
:
- : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12");
+ : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12"
+ );
}
- for(; x > 0; x--)
- {
+ for (;x>0;x--) {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
@@ -151,4 +152,4 @@
}
}
-#endif // __arm__
+#endif // __arm__
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp
index ea32c96..587bec3 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp
@@ -30,87 +30,98 @@
// Generic unblocked transposed 8x32-bit sized specialisation
template <>
template <typename T>
-inline void TransformImpl<8, 1, true, 4, 4>::Transform(
- T *out, const T *const in, const int stride,
- const int x0, const int xmax, const int k0, const int kmax)
-{
- // Redirect to a 16x uint16_t specialisation
- TransformImpl<16, 1, true, 2, 2>::Transform(
- reinterpret_cast<uint16_t *>(out),
- reinterpret_cast<const uint16_t *const>(in),
- stride * 2, x0 * 2, xmax * 2, k0, kmax);
+inline void TransformImpl<8, 1, true, 4, 4, false>::Transform(
+ T* out, const T* const in, const int stride,
+ const int x0, const int xmax, const int k0, const int kmax
+) {
+ // Redirect to a 16x uint16_t specialisation
+ TransformImpl<16, 1, true, 2, 2, false>::Transform(
+ reinterpret_cast<uint16_t *>(out),
+ reinterpret_cast<const uint16_t *>(in),
+ stride*2, x0*2, xmax*2, k0, kmax
+ );
}
// Generic 12x16-bit sized specialisation
template <>
template <typename T>
-inline void TransformImpl<16, 1, true, 2, 2>::Transform(
- T *out, const T *const in, const int stride,
- const int x0, const int xmax, const int k0, const int kmax)
-{
- // Redirect to a uint16_t specialisation
- Transform(
- reinterpret_cast<uint16_t *>(out),
- reinterpret_cast<const uint16_t *const>(in),
- stride, x0, xmax, k0, kmax);
+inline void TransformImpl<16, 1, true, 2, 2, false>::Transform(
+ T* out, const T* const in, const int stride,
+ const int x0, const int xmax, const int k0, const int kmax
+) {
+ // Redirect to a uint16_t specialisation
+ Transform(
+ reinterpret_cast<uint16_t *>(out),
+ reinterpret_cast<const uint16_t *>(in),
+ stride, x0, xmax, k0, kmax
+ );
}
// Specialised 16 x uint16_t version
template <>
-inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out)
-{
- __asm volatile(
- "VLD1.32 {d0-d3}, [%[in0]]!\n"
- "VST1.32 {d0-d3}, [%[out]]\n" ASM_PREFETCH("[%[in0], #192]")
- : [in0] "+r"(in0),
- [out] "+r"(out)
- :
- : "q0", "q1", "memory");
+inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out) {
+ __asm volatile (
+ "VLD1.32 {d0-d3}, [%[in0]]!\n"
+ "VST1.32 {d0-d3}, [%[out]]\n"
+ ASM_PREFETCH("[%[in0], #192]")
+ : [in0] "+r" (in0),
+ [out] "+r" (out)
+ :
+ : "q0", "q1", "memory"
+ );
}
template <>
-inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *out)
-{
- __asm volatile(
- "VLD1.32 {d0-d3}, [%[in0]]!\n"
- "VST1.32 {d0-d3}, [%[out]]!\n" ASM_PREFETCH("[%[in0], #192]")
- "VLD1.32 {d0-d3}, [%[in1]]!\n"
- "VST1.32 {d0-d3}, [%[out]]\n" ASM_PREFETCH("[%[in1], #192]") "SUB %[out], %[out], #32\n"
- : [in0] "+r"(in0),
- [in1] "+r"(in1),
- [out] "+r"(out)
- :
- : "q0", "q1", "memory");
+inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *out) {
+ __asm volatile (
+ "VLD1.32 {d0-d3}, [%[in0]]!\n"
+ "VST1.32 {d0-d3}, [%[out]]!\n"
+ ASM_PREFETCH("[%[in0], #192]")
+ "VLD1.32 {d0-d3}, [%[in1]]!\n"
+ "VST1.32 {d0-d3}, [%[out]]\n"
+ ASM_PREFETCH("[%[in1], #192]")
+ "SUB %[out], %[out], #32\n"
+ : [in0] "+r" (in0),
+ [in1] "+r" (in1),
+ [out] "+r" (out)
+ :
+ : "q0", "q1", "memory"
+ );
}
template <>
-inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out)
-{
- __asm __volatile(
- "VLD1.32 {d0-d3}, [%[in0]]!\n"
- "VST1.32 {d0-d3}, [%[out]]!\n" ASM_PREFETCH("[%[in0], #192]")
- "VLD1.32 {d0-d3}, [%[in1]]!\n"
- "VST1.32 {d0-d3}, [%[out]]!\n" ASM_PREFETCH("[%[in1], #192]")
- "VLD1.32 {d0-d3}, [%[in2]]!\n"
- "VST1.32 {d0-d3}, [%[out]]!\n" ASM_PREFETCH("[%[in2], #192]")
- "VLD1.32 {d0-d3}, [%[in3]]!\n"
- "VST1.32 {d0-d3}, [%[out]]\n" ASM_PREFETCH("[%[in3], #192]") "SUB %[out], %[out], #96\n"
- : [in0] "+r"(in0),
- [in1] "+r"(in1),
- [in2] "+r"(in2),
- [in3] "+r"(in3),
- [out] "+r"(out)
- :
- : "q0", "q1", "memory");
+inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out) {
+ __asm __volatile (
+ "VLD1.32 {d0-d3}, [%[in0]]!\n"
+ "VST1.32 {d0-d3}, [%[out]]!\n"
+ ASM_PREFETCH("[%[in0], #192]")
+ "VLD1.32 {d0-d3}, [%[in1]]!\n"
+ "VST1.32 {d0-d3}, [%[out]]!\n"
+ ASM_PREFETCH("[%[in1], #192]")
+ "VLD1.32 {d0-d3}, [%[in2]]!\n"
+ "VST1.32 {d0-d3}, [%[out]]!\n"
+ ASM_PREFETCH("[%[in2], #192]")
+ "VLD1.32 {d0-d3}, [%[in3]]!\n"
+ "VST1.32 {d0-d3}, [%[out]]\n"
+ ASM_PREFETCH("[%[in3], #192]")
+ "SUB %[out], %[out], #96\n"
+ : [in0] "+r" (in0),
+ [in1] "+r" (in1),
+ [in2] "+r" (in2),
+ [in3] "+r" (in3),
+ [out] "+r" (out)
+ :
+ : "q0", "q1", "memory"
+ );
}
template <>
template <>
-inline void TransformImpl<16, 1, true, 2, 2>::Transform(
- uint16_t *out, const uint16_t *const in, const int stride,
- const int x0, const int xmax, const int k0, const int kmax)
-{
- TransposeInterleaveCommon<16, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax);
+inline void TransformImpl<16, 1, true, 2, 2, false>::Transform(
+ uint16_t* out, const uint16_t* const in, const int stride,
+ const int x0, const int xmax, const int k0, const int kmax
+) {
+ TransposeInterleaveCommon<16, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax);
}
#endif // __arm__
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
index 8d61f15..8ea0483 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
@@ -30,17 +30,15 @@
#include "../asmlib.hpp"
#include "../utils.hpp"
-template <>
-template <typename T>
-void TransformImpl<4, 16, false, 1, 1>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax)
-{
- uint8_t *outptr = (uint8_t *)out;
- const uint8_t *inptr = (uint8_t *)in;
+template<>
+template<typename T>
+void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
+ uint8_t *outptr = (uint8_t *)out;
+ const uint8_t *inptr = (uint8_t *)in;
uint8_t zerobuff[16];
- for(int y = y0; y < ymax; y += 4)
- {
+ for (int y=y0; y<ymax; y+=4) {
const uint8_t *inptr0 = inptr + y * ldin + k0;
const uint8_t *inptr1 = inptr0 + ldin;
const uint8_t *inptr2 = inptr1 + ldin;
@@ -51,14 +49,11 @@
prefetch_2x(inptr2);
prefetch_2x(inptr3);
- int x = (kmax - k0);
- for(; x > 15; x -= 16)
- {
+ int x=(kmax-k0);
+ for (;x>15;x-=16) {
/* Cope with ragged cases by copying from a buffer of zeroes instead */
- if((y + 3) >= ymax)
- {
- switch((y + 3) - ymax)
- {
+ if ((y + 3) >= ymax) {
+ switch ((y + 3) - ymax) {
/* Everything falls through in here */
case 2:
inptr1 = zerobuff;
@@ -73,23 +68,28 @@
}
}
- __asm __volatile(
- "LDR q0, [%[inptr0]], #16\n" ASM_PREFETCH("[%[inptr0], #176]") "LDR q1, [%[inptr1]], #16\n" ASM_PREFETCH("[%[inptr1], #176]")
- "STP q0, q1, [%[outptr]], #32\n"
- "LDR q0, [%[inptr2]], #16\n" ASM_PREFETCH("[%[inptr2], #176]") "LDR q1, [%[inptr3]], #16\n" ASM_PREFETCH("[%[inptr3], #176]") "STP q0, q1, [%[outptr]], #32\n"
- : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
- [outptr] "+r"(outptr)
+ __asm __volatile (
+ "LDR q0, [%[inptr0]], #16\n"
+ ASM_PREFETCH("[%[inptr0], #176]")
+ "LDR q1, [%[inptr1]], #16\n"
+ ASM_PREFETCH("[%[inptr1], #176]")
+ "STP q0, q1, [%[outptr]], #32\n"
+ "LDR q0, [%[inptr2]], #16\n"
+ ASM_PREFETCH("[%[inptr2], #176]")
+ "LDR q1, [%[inptr3]], #16\n"
+ ASM_PREFETCH("[%[inptr3], #176]")
+ "STP q0, q1, [%[outptr]], #32\n"
+ : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3),
+ [outptr] "+r" (outptr)
:
- : "v0", "v1");
+ : "v0", "v1"
+ );
}
- if(x > 0)
- {
+ if (x>0) {
/* Need to duplicate this here, in case we didn't run the main loop. */
- if((y + 3) >= ymax)
- {
- switch((y + 3) - ymax)
- {
+ if ((y + 3) >= ymax) {
+ switch ((y + 3) - ymax) {
/* Everything falls through in here */
case 2:
inptr1 = zerobuff;
@@ -105,16 +105,11 @@
}
/* We have to write out 16 values, copy as many legal values as there are and pad with 0 */
- auto f = [&outptr, x](const uint8_t *&p)
- {
- for(int i = 0; i < 16; i++)
- {
- if(i < x)
- {
+ auto f = [&outptr, x](const uint8_t *&p) {
+ for (int i=0; i<16; i++) {
+ if (i < x) {
*outptr++ = *p++;
- }
- else
- {
+ } else {
*outptr++ = 0;
}
}
@@ -128,4 +123,4 @@
}
}
-#endif // __aarch64__
\ No newline at end of file
+#endif // __aarch64__
\ No newline at end of file
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
index 3cbc881..91ee492 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
@@ -29,17 +29,15 @@
#include "../asmlib.hpp"
-template <>
-template <typename T>
-void TransformImpl<8, 1, false, 2, 2>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax)
-{
- uint16_t *outptr = (uint16_t *)out;
- const uint16_t *inptr = (const uint16_t *)in;
+template<>
+template<typename T>
+void TransformImpl<8, 1, false, 2, 2, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
+ uint16_t *outptr = (uint16_t *)out;
+ const uint16_t *inptr = (const uint16_t *)in;
uint16_t zerobuff[24];
- for(int y = y0; y < ymax; y += 8)
- {
+ for (int y=y0; y<ymax; y+=8) {
const uint16_t *inptr0 = inptr + y * ldin + k0;
const uint16_t *inptr1 = inptr0 + ldin;
const uint16_t *inptr2 = inptr1 + ldin;
@@ -58,14 +56,11 @@
prefetch_2x(inptr6);
prefetch_2x(inptr7);
- int x = (kmax - k0);
- for(; x > 7; x -= 8)
- {
+ int x=(kmax-k0);
+ for (;x>7;x-=8) {
/* Cope with ragged cases by copying from a buffer of zeroes instead */
- if((y + 7) >= ymax)
- {
- switch((y + 7) - ymax)
- {
+ if ((y + 7) >= ymax) {
+ switch ((y + 7) - ymax) {
/* Everything falls through in here */
case 6:
inptr1 = zerobuff;
@@ -89,72 +84,74 @@
}
int skippf = (x & 31);
- __asm __volatile(
+ __asm __volatile (
// Load up 8 elements (1 vector) from each of 8 sources.
- "CBNZ %w[skippf], 1f\n" ASM_PREFETCH("[%[inptr0], #128]")
+ "CBNZ %w[skippf], 1f\n"
+ ASM_PREFETCH("[%[inptr0], #128]")
ASM_PREFETCH("[%[inptr1], #128]")
ASM_PREFETCH("[%[inptr2], #128]")
ASM_PREFETCH("[%[inptr3], #128]")
"1:\n"
- "LDR q0, [%[inptr0]], #16\n" // q0=A0A1A2A3A4A5A6A7
- "LDR q4, [%[inptr4]], #16\n" // q8=E0E1E2E3E4E5E6E7
- "LDR q2, [%[inptr2]], #16\n" // q4=C0C1C2C3...
- "LDR q6, [%[inptr6]], #16\n"
- "ZIP1 v8.8h, v0.8h, v4.8h\n" // q8=A0E0A1E1A2E2A3E3
- "ZIP2 v16.8h, v0.8h, v4.8h\n" // q16=A4E4A5E5A6E6A7E7
- "ZIP1 v9.8h, v2.8h, v6.8h\n" // q9=C0G0C1G1C2G2C3G3
- "ZIP2 v17.8h, v2.8h, v6.8h\n" // q17=C4G4C5G5C6G6C7G7
- "LDR q1, [%[inptr1]], #16\n" // q1=B0B1B2B3B4B5B6B7
- "LDR q5, [%[inptr5]], #16\n"
- "LDR q3, [%[inptr3]], #16\n" // q3=D0D1D2D3....
- "LDR q7, [%[inptr7]], #16\n"
- "ZIP1 v10.8h, v1.8h, v5.8h\n" // q18=B0F0B1F1B2F2B3F3
- "ZIP2 v18.8h, v1.8h, v5.8h\n" // q18=B4F4B5F5B6F6B7F7
- "ZIP1 v11.8h, v3.8h, v7.8h\n" // q19=D0H0D1H1D2H2D3H3
- "ZIP2 v19.8h, v3.8h, v7.8h\n" // q19=D4H4D5H5D6H6D7H7
+ "LDR q0, [%[inptr0]], #16\n" // q0=A0A1A2A3A4A5A6A7
+ "LDR q4, [%[inptr4]], #16\n" // q8=E0E1E2E3E4E5E6E7
+ "LDR q2, [%[inptr2]], #16\n" // q4=C0C1C2C3...
+ "LDR q6, [%[inptr6]], #16\n"
+ "ZIP1 v8.8h, v0.8h, v4.8h\n" // q8=A0E0A1E1A2E2A3E3
+ "ZIP2 v16.8h, v0.8h, v4.8h\n" // q16=A4E4A5E5A6E6A7E7
+ "ZIP1 v9.8h, v2.8h, v6.8h\n" // q9=C0G0C1G1C2G2C3G3
+ "ZIP2 v17.8h, v2.8h, v6.8h\n" // q17=C4G4C5G5C6G6C7G7
+ "LDR q1, [%[inptr1]], #16\n" // q1=B0B1B2B3B4B5B6B7
+ "LDR q5, [%[inptr5]], #16\n"
+ "LDR q3, [%[inptr3]], #16\n" // q3=D0D1D2D3....
+ "LDR q7, [%[inptr7]], #16\n"
+ "ZIP1 v10.8h, v1.8h, v5.8h\n" // q18=B0F0B1F1B2F2B3F3
+ "ZIP2 v18.8h, v1.8h, v5.8h\n" // q18=B4F4B5F5B6F6B7F7
+ "ZIP1 v11.8h, v3.8h, v7.8h\n" // q19=D0H0D1H1D2H2D3H3
+ "ZIP2 v19.8h, v3.8h, v7.8h\n" // q19=D4H4D5H5D6H6D7H7
- "ZIP1 v12.8h, v8.8h, v9.8h\n" // q20=A0C0E0G0A1C1E1G1
- "ZIP2 v20.8h, v8.8h, v9.8h\n"
- "ZIP1 v13.8h, v10.8h, v11.8h\n" // q21=B0D0F0H0B1I1F1H1
- "ZIP2 v21.8h, v10.8h, v11.8h\n"
+ "ZIP1 v12.8h, v8.8h, v9.8h\n" // q20=A0C0E0G0A1C1E1G1
+ "ZIP2 v20.8h, v8.8h, v9.8h\n"
+ "ZIP1 v13.8h, v10.8h, v11.8h\n" // q21=B0D0F0H0B1I1F1H1
+ "ZIP2 v21.8h, v10.8h, v11.8h\n"
- "CBNZ %w[skippf], 2f\n" ASM_PREFETCH("[%[inptr4], #112]")
+ "CBNZ %w[skippf], 2f\n"
+ ASM_PREFETCH("[%[inptr4], #112]")
ASM_PREFETCH("[%[inptr5], #112]")
ASM_PREFETCH("[%[inptr6], #112]")
ASM_PREFETCH("[%[inptr7], #112]")
"2:\n"
- "ZIP1 v22.8h, v16.8h, v17.8h\n"
- "ZIP2 v30.8h, v16.8h, v17.8h\n"
- "ZIP1 v23.8h, v18.8h, v19.8h\n"
- "ZIP2 v31.8h, v18.8h, v19.8h\n"
+ "ZIP1 v22.8h, v16.8h, v17.8h\n"
+ "ZIP2 v30.8h, v16.8h, v17.8h\n"
+ "ZIP1 v23.8h, v18.8h, v19.8h\n"
+ "ZIP2 v31.8h, v18.8h, v19.8h\n"
- "ZIP1 v14.8h, v12.8h, v13.8h\n" // q22=A0B0C0D0E0F0G0H0
- "ZIP2 v15.8h, v12.8h, v13.8h\n" // q23=A1B1C1D1E1F1G1H1
- "STP q14, q15, [%[outptr]], #32\n" // Write back first two elements
+ "ZIP1 v14.8h, v12.8h, v13.8h\n" // q22=A0B0C0D0E0F0G0H0
+ "ZIP2 v15.8h, v12.8h, v13.8h\n" // q23=A1B1C1D1E1F1G1H1
+ "STP q14, q15, [%[outptr]], #32\n" // Write back first two elements
- "ZIP1 v0.8h, v20.8h, v21.8h\n"
- "ZIP2 v1.8h, v20.8h, v21.8h\n"
- "STP q0, q1, [%[outptr]], #32\n" // Write back next two elements
+ "ZIP1 v0.8h, v20.8h, v21.8h\n"
+ "ZIP2 v1.8h, v20.8h, v21.8h\n"
+ "STP q0, q1, [%[outptr]], #32\n" // Write back next two elements
- "ZIP1 v2.8h, v22.8h, v23.8h\n"
- "ZIP2 v3.8h, v22.8h, v23.8h\n"
- "STP q2, q3, [%[outptr]], #32\n" // Write back next two elements
+ "ZIP1 v2.8h, v22.8h, v23.8h\n"
+ "ZIP2 v3.8h, v22.8h, v23.8h\n"
+ "STP q2, q3, [%[outptr]], #32\n" // Write back next two elements
- "ZIP1 v4.8h, v30.8h, v31.8h\n"
- "ZIP2 v5.8h, v30.8h, v31.8h\n"
- "STP q4, q5, [%[outptr]], #32\n" // Write back last two elements
- : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
- [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr)
- : [skippf] "r"(skippf)
+ "ZIP1 v4.8h, v30.8h, v31.8h\n"
+ "ZIP2 v5.8h, v30.8h, v31.8h\n"
+ "STP q4, q5, [%[outptr]], #32\n" // Write back last two elements
+ : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3),
+ [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7), [outptr] "+r" (outptr)
+ : [skippf] "r" (skippf)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
- "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
- "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
+ "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
}
- for(; x > 0; x--)
- {
+ for (;x>0;x--) {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
index 47e4fa2..7a32f33 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
@@ -29,17 +29,15 @@
#include "../asmlib.hpp"
-template <>
-template <typename T>
-inline void TransformImpl<8, 1, false, 4, 4>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax)
-{
- uint32_t *outptr = (uint32_t *)out;
- const uint32_t *inptr = (uint32_t *)in;
+template<>
+template<typename T>
+inline void TransformImpl<8, 1, false, 4, 4, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
+ uint32_t *outptr = (uint32_t *)out;
+ const uint32_t *inptr = (uint32_t *)in;
uint32_t zerobuff[8];
- for(int y = y0; y < ymax; y += 8)
- {
+ for (int y=y0; y<ymax; y+=8) {
const uint32_t *inptr0 = inptr + y * ldin + k0;
const uint32_t *inptr1 = inptr0 + ldin;
const uint32_t *inptr2 = inptr1 + ldin;
@@ -58,14 +56,11 @@
prefetch_2x(inptr6);
prefetch_2x(inptr7);
- int x = (kmax - k0);
- for(; x > 7; x -= 8)
- {
+ int x=(kmax-k0);
+ for (;x>7;x-=8) {
/* Cope with ragged cases by copying from a buffer of zeroes instead */
- if((y + 7) >= ymax)
- {
- switch((y + 7) - ymax)
- {
+ if ((y + 7) >= ymax) {
+ switch ((y + 7) - ymax) {
/* Everything falls through in here */
case 6:
inptr1 = zerobuff;
@@ -88,19 +83,20 @@
}
}
- __asm __volatile(
+ __asm __volatile (
// Load up 8 elements (2 vectors) from each of 8 sources.
"LDP q0, q1, [%[inptr0]], #32\n" // q0=A0A1A2A3
"LDP q2, q3, [%[inptr1]], #32\n" // q2=B0B1B2B3
"LDP q4, q5, [%[inptr2]], #32\n" // q4=C0C1C2C3
- "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1
+ "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1
ASM_PREFETCH("[%[inptr0], #128]")
"LDP q6, q7, [%[inptr3]], #32\n" // q6=D0D1D2D3
- "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1
+ "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1
"LDP q8, q9, [%[inptr4]], #32\n"
"LDP q10, q11, [%[inptr5]], #32\n"
"LDP q12, q13, [%[inptr6]], #32\n"
- "ZIP1 v18.4s, v8.4s, v12.4s\n" ASM_PREFETCH("[%[inptr1], #128]")
+ "ZIP1 v18.4s, v8.4s, v12.4s\n"
+ ASM_PREFETCH("[%[inptr1], #128]")
"LDP q14, q15, [%[inptr7]], #32\n"
"ZIP1 v19.4s, v10.4s, v14.4s\n"
@@ -110,7 +106,8 @@
"ZIP2 v22.4s, v16.4s, v17.4s\n"
"ZIP2 v23.4s, v18.4s, v19.4s\n"
- "ZIP2 v16.4s, v0.4s, v4.4s\n" ASM_PREFETCH("[%[inptr3], #128]")
+ "ZIP2 v16.4s, v0.4s, v4.4s\n"
+ ASM_PREFETCH("[%[inptr3], #128]")
"ZIP2 v17.4s, v2.4s, v6.4s\n"
"STP q20, q21, [%[outptr]], #32\n" // Write back the first element of each source
@@ -118,12 +115,14 @@
"ZIP2 v19.4s, v10.4s, v14.4s\n"
"STP q22, q23, [%[outptr]], #32\n" // Write back the second element of each source
- "ZIP1 v20.4s, v16.4s, v17.4s\n" ASM_PREFETCH("[%[inptr4], #128]")
+ "ZIP1 v20.4s, v16.4s, v17.4s\n"
+ ASM_PREFETCH("[%[inptr4], #128]")
"ZIP1 v21.4s, v18.4s, v19.4s\n"
"ZIP2 v22.4s, v16.4s, v17.4s\n"
"ZIP2 v23.4s, v18.4s, v19.4s\n"
- "ZIP1 v16.4s, v1.4s, v5.4s\n" ASM_PREFETCH("[%[inptr5], #128]")
+ "ZIP1 v16.4s, v1.4s, v5.4s\n"
+ ASM_PREFETCH("[%[inptr5], #128]")
"ZIP1 v17.4s, v3.4s, v7.4s\n"
"STP q20, q21, [%[outptr]], #32\n" // Third element
@@ -133,14 +132,16 @@
"ZIP1 v20.4s, v16.4s, v17.4s\n"
"ZIP1 v21.4s, v18.4s, v19.4s\n"
- "ZIP2 v22.4s, v16.4s, v17.4s\n" ASM_PREFETCH("[%[inptr6], #128]")
+ "ZIP2 v22.4s, v16.4s, v17.4s\n"
+ ASM_PREFETCH("[%[inptr6], #128]")
"ZIP2 v23.4s, v18.4s, v19.4s\n"
"ZIP2 v16.4s, v1.4s, v5.4s\n"
"ZIP2 v17.4s, v3.4s, v7.4s\n"
"STP q20, q21, [%[outptr]], #32\n" // Fifth element
- "ZIP2 v18.4s, v9.4s, v13.4s\n" ASM_PREFETCH("[%[inptr7], #128]")
+ "ZIP2 v18.4s, v9.4s, v13.4s\n"
+ ASM_PREFETCH("[%[inptr7], #128]")
"ZIP2 v19.4s, v11.4s, v15.4s\n"
"STP q22, q23, [%[outptr]], #32\n" // Sixth element
@@ -151,15 +152,15 @@
"ZIP2 v22.4s, v16.4s, v17.4s\n"
"ZIP2 v23.4s, v18.4s, v19.4s\n"
"STP q22, q23, [%[outptr]], #32\n" // Eighth element
- : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
- [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr)
+ : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3),
+ [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7), [outptr] "+r" (outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
- "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
+                  "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "memory" // "memory": the asm stores through outptr and loads through inptr0..7, which are passed only as register operands
+            );
}
- for(; x > 0; x--)
- {
+ for (;x>0;x--) {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
@@ -172,4 +173,4 @@
}
}
-#endif // __aarch64__
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
index 1d2d496..773d56d 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
@@ -29,17 +29,15 @@
#include "../asmlib.hpp"
-template <>
-template <>
-inline void TransformImpl<8, 1, false, 4, 2>::Transform(float *out, const __fp16 *in, int ldin, int y0, int ymax, int k0, int kmax)
-{
- float *outptr = out;
- const __fp16 *inptr = in;
+template<>
+template<>
+inline void TransformImpl<8, 1, false, 4, 2, false>::Transform(float *out, const __fp16 *in, int ldin, int y0, int ymax, int k0, int kmax) {
+ float *outptr = out;
+ const __fp16 *inptr = in;
__fp16 zerobuff[8];
- for(int y = y0; y < ymax; y += 8)
- {
+ for (int y=y0; y<ymax; y+=8) {
const __fp16 *inptr0 = inptr + y * ldin + k0;
const __fp16 *inptr1 = inptr0 + ldin;
const __fp16 *inptr2 = inptr1 + ldin;
@@ -58,14 +56,11 @@
prefetch_2x(inptr6);
prefetch_2x(inptr7);
- int x = (kmax - k0);
- for(; x > 7; x -= 8)
- {
+ int x=(kmax-k0);
+ for (;x>7;x-=8) {
/* Cope with ragged cases by copying from a buffer of zeroes instead */
- if((y + 7) >= ymax)
- {
- switch((y + 7) - ymax)
- {
+ if ((y + 7) >= ymax) {
+ switch ((y + 7) - ymax) {
/* Everything falls through in here */
case 6:
inptr1 = zerobuff;
@@ -88,95 +83,100 @@
}
}
- __asm __volatile(
+ __asm __volatile (
// Load up 8 elements (2 vectors) from each of 8 sources.
- "LDR q0, [%[inptr0]], #16\n"
- "LDR q2, [%[inptr1]], #16\n"
- "FCVTL2 v1.4s, v0.8h\n"
- "FCVTL v0.4s, v0.4h\n"
- "LDR q4, [%[inptr2]], #16\n" // q4=C0C1C2C3
- "FCVTL2 v3.4s, v2.8h\n"
- "FCVTL v2.4s, v2.4h\n"
- "FCVTL2 v5.4s, v4.8h\n"
- "FCVTL v4.4s, v4.4h\n"
- "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1
+ "LDR q0, [%[inptr0]], #16\n"
+ "LDR q2, [%[inptr1]], #16\n"
+ "FCVTL2 v1.4s, v0.8h\n"
+ "FCVTL v0.4s, v0.4h\n"
+ "LDR q4, [%[inptr2]], #16\n" // q4=C0C1C2C3
+ "FCVTL2 v3.4s, v2.8h\n"
+ "FCVTL v2.4s, v2.4h\n"
+ "FCVTL2 v5.4s, v4.8h\n"
+ "FCVTL v4.4s, v4.4h\n"
+ "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1
ASM_PREFETCH("[%[inptr0], #128]")
- "LDR q6, [%[inptr3]], #16\n" // q6=D0D1D2D3
- "FCVTL2 v7.4s, v6.8h\n"
- "FCVTL v6.4s, v6.4h\n"
- "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1
- "LDR q8, [%[inptr4]], #16\n"
- "LDR q10, [%[inptr5]], #16\n"
- "FCVTL2 v9.4s, v8.8h\n"
- "FCVTL v8.4s, v8.4h\n" ASM_PREFETCH("[%[inptr1], #128]")
- "LDR q12, [%[inptr6]], #16\n"
- "FCVTL2 v11.4s, v10.8h\n"
- "FCVTL v10.4s, v10.4h\n"
- "FCVTL2 v13.4s, v12.8h\n"
- "FCVTL v12.4s, v12.4h\n"
- "ZIP1 v18.4s, v8.4s, v12.4s\n"
- "LDR q14, [%[inptr7]], #16\n"
- "FCVTL2 v15.4s, v14.8h\n"
- "FCVTL v14.4s, v14.4h\n"
- "ZIP1 v19.4s, v10.4s, v14.4s\n"
+ "LDR q6, [%[inptr3]], #16\n" // q6=D0D1D2D3
+ "FCVTL2 v7.4s, v6.8h\n"
+ "FCVTL v6.4s, v6.4h\n"
+ "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1
+ "LDR q8, [%[inptr4]], #16\n"
+ "LDR q10, [%[inptr5]], #16\n"
+ "FCVTL2 v9.4s, v8.8h\n"
+ "FCVTL v8.4s, v8.4h\n"
+ ASM_PREFETCH("[%[inptr1], #128]")
+ "LDR q12, [%[inptr6]], #16\n"
+ "FCVTL2 v11.4s, v10.8h\n"
+ "FCVTL v10.4s, v10.4h\n"
+ "FCVTL2 v13.4s, v12.8h\n"
+ "FCVTL v12.4s, v12.4h\n"
+ "ZIP1 v18.4s, v8.4s, v12.4s\n"
+ "LDR q14, [%[inptr7]], #16\n"
+ "FCVTL2 v15.4s, v14.8h\n"
+ "FCVTL v14.4s, v14.4h\n"
+ "ZIP1 v19.4s, v10.4s, v14.4s\n"
ASM_PREFETCH("[%[inptr2], #128]")
- "ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0
- "ZIP1 v21.4s, v18.4s, v19.4s\n"
- "ZIP2 v22.4s, v16.4s, v17.4s\n"
- "ZIP2 v23.4s, v18.4s, v19.4s\n" ASM_PREFETCH("[%[inptr3], #128]")
+ "ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0
+ "ZIP1 v21.4s, v18.4s, v19.4s\n"
+ "ZIP2 v22.4s, v16.4s, v17.4s\n"
+ "ZIP2 v23.4s, v18.4s, v19.4s\n"
+ ASM_PREFETCH("[%[inptr3], #128]")
- "ZIP2 v16.4s, v0.4s, v4.4s\n"
- "ZIP2 v17.4s, v2.4s, v6.4s\n"
- "STP q20, q21, [%[outptr]], #32\n" // Write back the first element of each source
+ "ZIP2 v16.4s, v0.4s, v4.4s\n"
+ "ZIP2 v17.4s, v2.4s, v6.4s\n"
+ "STP q20, q21, [%[outptr]], #32\n" // Write back the first element of each source
- "ZIP2 v18.4s, v8.4s, v12.4s\n" ASM_PREFETCH("[%[inptr4], #128]")
- "ZIP2 v19.4s, v10.4s, v14.4s\n"
- "STP q22, q23, [%[outptr]], #32\n" // Write back the second element of each source
+ "ZIP2 v18.4s, v8.4s, v12.4s\n"
+ ASM_PREFETCH("[%[inptr4], #128]")
+ "ZIP2 v19.4s, v10.4s, v14.4s\n"
+ "STP q22, q23, [%[outptr]], #32\n" // Write back the second element of each source
- "ZIP1 v20.4s, v16.4s, v17.4s\n"
- "ZIP1 v21.4s, v18.4s, v19.4s\n" ASM_PREFETCH("[%[inptr5], #128]")
- "ZIP2 v22.4s, v16.4s, v17.4s\n"
- "ZIP2 v23.4s, v18.4s, v19.4s\n"
+ "ZIP1 v20.4s, v16.4s, v17.4s\n"
+ "ZIP1 v21.4s, v18.4s, v19.4s\n"
+ ASM_PREFETCH("[%[inptr5], #128]")
+ "ZIP2 v22.4s, v16.4s, v17.4s\n"
+ "ZIP2 v23.4s, v18.4s, v19.4s\n"
- "ZIP1 v16.4s, v1.4s, v5.4s\n"
- "ZIP1 v17.4s, v3.4s, v7.4s\n" ASM_PREFETCH("[%[inptr6], #128]")
- "STP q20, q21, [%[outptr]], #32\n" // Third element
+ "ZIP1 v16.4s, v1.4s, v5.4s\n"
+ "ZIP1 v17.4s, v3.4s, v7.4s\n"
+ ASM_PREFETCH("[%[inptr6], #128]")
+ "STP q20, q21, [%[outptr]], #32\n" // Third element
- "ZIP1 v18.4s, v9.4s, v13.4s\n"
- "ZIP1 v19.4s, v11.4s, v15.4s\n"
- "STP q22, q23, [%[outptr]], #32\n" // Fourth element
+ "ZIP1 v18.4s, v9.4s, v13.4s\n"
+ "ZIP1 v19.4s, v11.4s, v15.4s\n"
+ "STP q22, q23, [%[outptr]], #32\n" // Fourth element
ASM_PREFETCH("[%[inptr7], #128]")
- "ZIP1 v20.4s, v16.4s, v17.4s\n"
- "ZIP1 v21.4s, v18.4s, v19.4s\n"
- "ZIP2 v22.4s, v16.4s, v17.4s\n"
- "ZIP2 v23.4s, v18.4s, v19.4s\n"
+ "ZIP1 v20.4s, v16.4s, v17.4s\n"
+ "ZIP1 v21.4s, v18.4s, v19.4s\n"
+ "ZIP2 v22.4s, v16.4s, v17.4s\n"
+ "ZIP2 v23.4s, v18.4s, v19.4s\n"
- "ZIP2 v16.4s, v1.4s, v5.4s\n"
- "ZIP2 v17.4s, v3.4s, v7.4s\n"
- "STP q20, q21, [%[outptr]], #32\n" // Fifth element
+ "ZIP2 v16.4s, v1.4s, v5.4s\n"
+ "ZIP2 v17.4s, v3.4s, v7.4s\n"
+ "STP q20, q21, [%[outptr]], #32\n" // Fifth element
- "ZIP2 v18.4s, v9.4s, v13.4s\n"
- "ZIP2 v19.4s, v11.4s, v15.4s\n"
- "STP q22, q23, [%[outptr]], #32\n" // Sixth element
+ "ZIP2 v18.4s, v9.4s, v13.4s\n"
+ "ZIP2 v19.4s, v11.4s, v15.4s\n"
+ "STP q22, q23, [%[outptr]], #32\n" // Sixth element
- "ZIP1 v20.4s, v16.4s, v17.4s\n"
- "ZIP1 v21.4s, v18.4s, v19.4s\n"
- "STP q20, q21, [%[outptr]], #32\n" // Seventh element
+ "ZIP1 v20.4s, v16.4s, v17.4s\n"
+ "ZIP1 v21.4s, v18.4s, v19.4s\n"
+ "STP q20, q21, [%[outptr]], #32\n" // Seventh element
- "ZIP2 v22.4s, v16.4s, v17.4s\n"
- "ZIP2 v23.4s, v18.4s, v19.4s\n"
- "STP q22, q23, [%[outptr]], #32\n" // Eighth element
- : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
- [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr)
+ "ZIP2 v22.4s, v16.4s, v17.4s\n"
+ "ZIP2 v23.4s, v18.4s, v19.4s\n"
+ "STP q22, q23, [%[outptr]], #32\n" // Eighth element
+ : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3),
+ [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7), [outptr] "+r" (outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
- "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
+                  "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "memory" // "memory": the asm stores through outptr and loads through inptr0..7, which are passed only as register operands
+            );
}
- for(; x > 0; x--)
- {
+ for (;x>0;x--) {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp
index fd6a253..ec54ce0 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp
@@ -30,106 +30,116 @@
// Generic unblocked transposed 6x32-bit sized specialisation
template <>
template <typename T>
-inline void TransformImpl<6, 1, true, 4, 4>::Transform(
- T *out, const T *const in, const int stride,
- const int x0, const int xmax, const int k0, const int kmax)
-{
- // Redirect to a 12 x uint16_t specialisation
- TransformImpl<12, 1, true, 2, 2>::Transform(
- reinterpret_cast<uint16_t *>(out),
- reinterpret_cast<const uint16_t *const>(in),
- stride * 2, x0 * 2, xmax * 2, k0, kmax);
+inline void TransformImpl<6, 1, true, 4, 4, false>::Transform(
+ T* out, const T* const in, const int stride,
+ const int x0, const int xmax, const int k0, const int kmax
+) {
+ // Redirect to a 12 x uint16_t specialisation
+ TransformImpl<12, 1, true, 2, 2, false>::Transform(
+ reinterpret_cast<uint16_t *>(out),
+ reinterpret_cast<const uint16_t *>(in),
+ stride*2, x0*2, xmax*2, k0, kmax
+ );
}
// Generic 12x16-bit sized specialisation
template <>
template <typename T>
-inline void TransformImpl<12, 1, true, 2, 2>::Transform(
- T *out, const T *const in, const int stride,
- const int x0, const int xmax, const int k0, const int kmax)
-{
- // Redirect to a uint16_t specialisation
- Transform(
- reinterpret_cast<uint16_t *>(out),
- reinterpret_cast<const uint16_t *const>(in),
- stride, x0, xmax, k0, kmax);
+inline void TransformImpl<12, 1, true, 2, 2, false>::Transform(
+ T* out, const T* const in, const int stride,
+ const int x0, const int xmax, const int k0, const int kmax
+) {
+ // Redirect to a uint16_t specialisation
+ Transform(
+ reinterpret_cast<uint16_t *>(out),
+ reinterpret_cast<const uint16_t *>(in),
+ stride, x0, xmax, k0, kmax
+ );
}
// Specialised 12 x uint16_t version
template <>
-inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out)
-{
- __asm volatile(
- "LDR q0, [%[in0]]\n"
- "STR q0, [%[out]]\n"
- "LDR d1, [%[in0], #0x10]\n"
- "STR d1, [%[out], #0x10]\n"
- "ADD %x[in0], %x[in0], #0x18\n" ASM_PREFETCH("[%[in0], #192]")
- : [in0] "+r"(in0),
- [out] "+r"(out)
- :
- : "v0", "v1", "memory");
+inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out) {
+ __asm volatile (
+ "LDR q0, [%[in0]]\n"
+ "STR q0, [%[out]]\n"
+ "LDR d1, [%[in0], #0x10]\n"
+ "STR d1, [%[out], #0x10]\n"
+ "ADD %x[in0], %x[in0], #0x18\n"
+ ASM_PREFETCH("[%[in0], #192]")
+ : [in0] "+r" (in0),
+ [out] "+r" (out)
+ :
+ : "v0", "v1", "memory"
+ );
}
template <>
-inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *out)
-{
- __asm volatile(
- "LDR q0, [%[in0]]\n"
- "LDR d1, [%[in0], #0x10]\n"
- "ADD %x[in0], %x[in0], #0x18\n" ASM_PREFETCH("[%[in0], #192]")
+inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *out) {
+ __asm volatile (
+ "LDR q0, [%[in0]]\n"
+ "LDR d1, [%[in0], #0x10]\n"
+ "ADD %x[in0], %x[in0], #0x18\n"
+ ASM_PREFETCH("[%[in0], #192]")
- "LDR x21, [%[in1]]\n"
- "LDR q2, [%[in1], #0x08]\n"
- "INS v1.d[1], x21\n"
- "ADD %x[in1], %x[in1], #0x18\n"
- "STP q0, q1, [%[out]]\n"
- "STR q2, [%x[out], #0x20]\n" ASM_PREFETCH("[%[in1], #192]")
- : [in0] "+r"(in0),
- [in1] "+r"(in1),
- [out] "+r"(out)
- :
- : "x21", "v0", "v1", "v2", "memory");
+ "LDR x21, [%[in1]]\n"
+ "LDR q2, [%[in1], #0x08]\n"
+ "INS v1.d[1], x21\n"
+ "ADD %x[in1], %x[in1], #0x18\n"
+ "STP q0, q1, [%[out]]\n"
+ "STR q2, [%x[out], #0x20]\n"
+ ASM_PREFETCH("[%[in1], #192]")
+ : [in0] "+r" (in0),
+ [in1] "+r" (in1),
+ [out] "+r" (out)
+ :
+ : "x21", "v0", "v1", "v2", "memory"
+ );
}
template <>
-inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out)
-{
- __asm __volatile(
- "LDR q0, [%x[in0]], #0x10\n"
- "STR q0, [%x[out]]\n"
- "LDR d1, [%x[in0]], #0x08\n" ASM_PREFETCH("[%[in0], #192]")
- "STR d1, [%x[out], #0x10]\n"
+inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out) {
+ __asm __volatile (
+ "LDR q0, [%x[in0]], #0x10\n"
+ "STR q0, [%x[out]]\n"
+ "LDR d1, [%x[in0]], #0x08\n"
+ ASM_PREFETCH("[%[in0], #192]")
+ "STR d1, [%x[out], #0x10]\n"
- "LDR q0, [%x[in1]], #0x10\n"
- "STR q0, [%x[out], #0x18]\n"
- "LDR d1, [%x[in1]], #0x08\n" ASM_PREFETCH("[%[in1], #192]")
- "STR d1, [%x[out], #0x28]\n"
+ "LDR q0, [%x[in1]], #0x10\n"
+ "STR q0, [%x[out], #0x18]\n"
+ "LDR d1, [%x[in1]], #0x08\n"
+ ASM_PREFETCH("[%[in1], #192]")
+ "STR d1, [%x[out], #0x28]\n"
- "LDR q0, [%x[in2]], #0x10\n"
- "STR q0, [%x[out], #0x30]\n"
- "LDR d1, [%x[in2]], #0x08\n" ASM_PREFETCH("[%[in2], #192]")
- "STR d1, [%x[out], #0x40]\n"
+ "LDR q0, [%x[in2]], #0x10\n"
+ "STR q0, [%x[out], #0x30]\n"
+ "LDR d1, [%x[in2]], #0x08\n"
+ ASM_PREFETCH("[%[in2], #192]")
+ "STR d1, [%x[out], #0x40]\n"
- "LDR q0, [%x[in3]], #0x10\n"
- "STR q0, [%x[out], #0x48]\n"
- "LDR d1, [%x[in3]], #0x08\n" ASM_PREFETCH("[%[in3], #192]") "STR d1, [%x[out], #0x58]\n"
- : [in0] "+r"(in0),
- [in1] "+r"(in1),
- [in2] "+r"(in2),
- [in3] "+r"(in3),
- [out] "+r"(out)
- :
- : "v0", "v1", "memory");
+ "LDR q0, [%x[in3]], #0x10\n"
+ "STR q0, [%x[out], #0x48]\n"
+ "LDR d1, [%x[in3]], #0x08\n"
+ ASM_PREFETCH("[%[in3], #192]")
+ "STR d1, [%x[out], #0x58]\n"
+ : [in0] "+r" (in0),
+ [in1] "+r" (in1),
+ [in2] "+r" (in2),
+ [in3] "+r" (in3),
+ [out] "+r" (out)
+ :
+ : "v0", "v1", "memory"
+ );
}
template <>
template <>
-inline void TransformImpl<12, 1, true, 2, 2>::Transform(
- uint16_t *out, const uint16_t *const in, const int stride,
- const int x0, const int xmax, const int k0, const int kmax)
-{
- TransposeInterleaveCommon<12, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax);
+inline void TransformImpl<12, 1, true, 2, 2, false>::Transform(
+ uint16_t* out, const uint16_t* const in, const int stride,
+ const int x0, const int xmax, const int k0, const int kmax
+) {
+ TransposeInterleaveCommon<12, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax);
}
#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp
index b79f32f..46b4bf5 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp
@@ -28,86 +28,93 @@
#include "transpose_interleave_common.hpp"
template <>
-inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x1(const __fp16 *&in0, float *out)
-{
- __asm __volatile(
+inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x1(const __fp16 *&in0, float *out) {
+ __asm __volatile (
"LDR q0, [%[in0]], #16\n"
- "FCVTL2 v1.4s, v0.8h\n"
- "FCVTL v0.4s, v0.4h\n"
- "STP q0, q1, [%[out]]\n" ASM_PREFETCH("[%[in0], #192]")
- "LDR d2, [%[in0]], #8\n"
- "FCVTL v2.4s, v2.4h\n"
- "STR q2, [%[out], #32]\n"
- : [in0] "+r"(in0), [out] "+r"(out)
- :
- : "v0", "v1", "v2", "memory");
-}
-
-template <>
-inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x2(const __fp16 *&in0, const __fp16 *&in1, float *out)
-{
- __asm __volatile(
- "LDR q0, [%[in0]], #16\n"
- "FCVTL2 v1.4s, v0.8h\n"
- "FCVTL v0.4s, v0.4h\n"
- "STP q0, q1, [%[out]]\n" ASM_PREFETCH("[%[in0], #192]")
- "LDR d2, [%[in0]], #8\n"
- "FCVTL v2.4s, v2.4h\n"
- "LDR q3, [%[in1]], #16\n"
- "FCVTL2 v4.4s, v3.8h\n"
- "FCVTL v3.4s, v3.4h\n"
- "STP q2, q3, [%[out], #32]\n" ASM_PREFETCH("[%[in1], #192]")
- "LDR d5, [%[in1]], #16\n"
- "FCVTL v5.4s, v5.4h\n"
- "STP q4, q5, [%[out], #64]\n"
- : [in0] "+r"(in0), [in1] "+r"(in1), [out] "+r"(out)
- :
- : "v0", "v1", "v2", "v3", "v4", "v5", "memory");
-}
-
-template <>
-inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x4(const __fp16 *&in0, const __fp16 *&in1, const __fp16 *&in2, const __fp16 *&in3, float *out)
-{
- __asm __volatile(
- "LDR q0, [%[in0]], #16\n"
- "FCVTL2 v1.4s, v0.8h\n"
- "FCVTL v0.4s, v0.4h\n"
+ "FCVTL2 v1.4s, v0.8h\n"
+ "FCVTL v0.4s, v0.4h\n"
"STP q0, q1, [%[out]]\n"
- "LDR d2, [%[in0]], #8\n" ASM_PREFETCH("[%[in0], #192]")
- "FCVTL v2.4s, v2.4h\n"
- "LDR q3, [%[in1]], #16\n"
- "FCVTL2 v4.4s, v3.8h\n"
- "FCVTL v3.4s, v3.4h\n"
- "STP q2, q3, [%[out], #32]\n"
- "LDR d5, [%[in1]], #8\n"
- "FCVTL v5.4s, v5.4h\n" ASM_PREFETCH("[%[in1], #192]")
- "STP q4, q5, [%[out], #64]\n"
- "LDR q6, [%[in2]], #16\n"
- "FCVTL2 v7.4s, v6.8h\n"
- "FCVTL v6.4s, v6.4h\n"
- "STP q6, q7, [%[out], #96]\n"
- "LDR d8, [%[in2]], #8\n"
- "FCVTL v8.4s, v8.4h\n" ASM_PREFETCH("[%[in2], #192]")
- "LDR q9, [%[in3]], #16\n"
- "FCVTL2 v10.4s, v9.8h\n"
- "FCVTL v9.4s, v9.4h\n"
- "STP q8, q9, [%[out], #128]\n"
- "LDR d11, [%[in3]], #8\n"
- "FCVTL v11.4s, v11.4h\n"
- "STP q10, q11, [%[out], #160]\n" ASM_PREFETCH("[%[in3], #192]")
+ ASM_PREFETCH("[%[in0], #192]")
+ "LDR d2, [%[in0]], #8\n"
+ "FCVTL v2.4s, v2.4h\n"
+ "STR q2, [%[out], #32]\n"
+ : [in0] "+r" (in0), [out] "+r" (out)
+ :
+ : "v0", "v1", "v2", "memory"
+ );
+}
- : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), [in3] "+r"(in3), [out] "+r"(out)
- :
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "memory");
+template <>
+inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x2(const __fp16 *&in0, const __fp16 *&in1, float *out) { // Widen 12 fp16 from in0, then 12 from in1, into 24 floats at out; in0 and in1 each advance by 12 elements
+    __asm __volatile (
+        "LDR    q0, [%[in0]], #16\n"          // in0[0..7]
+        "FCVTL2 v1.4s, v0.8h\n"
+        "FCVTL  v0.4s, v0.4h\n"
+        "STP    q0, q1, [%[out]]\n"
+        ASM_PREFETCH("[%[in0], #192]")
+        "LDR    d2, [%[in0]], #8\n"           // in0[8..11]; in0 has now advanced 24 bytes
+        "FCVTL  v2.4s, v2.4h\n"
+        "LDR    q3, [%[in1]], #16\n"          // in1[0..7]
+        "FCVTL2 v4.4s, v3.8h\n"
+        "FCVTL  v3.4s, v3.4h\n"
+        "STP    q2, q3, [%[out], #32]\n"
+        ASM_PREFETCH("[%[in1], #192]")
+        "LDR    d5, [%[in1]], #8\n"           // in1[8..11]: a D load is 8 bytes, so post-increment by 8 to advance in1 by 24 bytes total, matching in0
+        "FCVTL  v5.4s, v5.4h\n"
+        "STP    q4, q5, [%[out], #64]\n"
+    : [in0] "+r" (in0), [in1] "+r" (in1), [out] "+r" (out)
+    :
+    : "v0", "v1", "v2", "v3", "v4", "v5", "memory"
+    );
+}
+
+template <>
+inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x4(const __fp16 *&in0, const __fp16 *&in1, const __fp16 *&in2, const __fp16 *&in3, float *out) {
+ __asm __volatile (
+ "LDR q0, [%[in0]], #16\n"
+ "FCVTL2 v1.4s, v0.8h\n"
+ "FCVTL v0.4s, v0.4h\n"
+ "STP q0, q1, [%[out]]\n"
+ "LDR d2, [%[in0]], #8\n"
+ ASM_PREFETCH("[%[in0], #192]")
+ "FCVTL v2.4s, v2.4h\n"
+ "LDR q3, [%[in1]], #16\n"
+ "FCVTL2 v4.4s, v3.8h\n"
+ "FCVTL v3.4s, v3.4h\n"
+ "STP q2, q3, [%[out], #32]\n"
+ "LDR d5, [%[in1]], #8\n"
+ "FCVTL v5.4s, v5.4h\n"
+ ASM_PREFETCH("[%[in1], #192]")
+ "STP q4, q5, [%[out], #64]\n"
+ "LDR q6, [%[in2]], #16\n"
+ "FCVTL2 v7.4s, v6.8h\n"
+ "FCVTL v6.4s, v6.4h\n"
+ "STP q6, q7, [%[out], #96]\n"
+ "LDR d8, [%[in2]], #8\n"
+ "FCVTL v8.4s, v8.4h\n"
+ ASM_PREFETCH("[%[in2], #192]")
+ "LDR q9, [%[in3]], #16\n"
+ "FCVTL2 v10.4s, v9.8h\n"
+ "FCVTL v9.4s, v9.4h\n"
+ "STP q8, q9, [%[out], #128]\n"
+ "LDR d11, [%[in3]], #8\n"
+ "FCVTL v11.4s, v11.4h\n"
+ "STP q10, q11, [%[out], #160]\n"
+ ASM_PREFETCH("[%[in3], #192]")
+
+ : [in0] "+r" (in0), [in1] "+r" (in1), [in2] "+r" (in2), [in3] "+r" (in3), [out] "+r" (out)
+ :
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "memory"
+ );
}
template <>
template <>
-inline void TransformImpl<12, 1, true, 4, 2>::Transform(
- float *out, const __fp16 *const in, const int stride,
- const int x0, const int xmax, const int k0, const int kmax)
-{
- TransposeInterleaveCommon<12, __fp16, float>::Transform(out, in, stride, x0, xmax, k0, kmax);
+inline void TransformImpl<12, 1, true, 4, 2, false>::Transform(
+ float* out, const __fp16* const in, const int stride,
+ const int x0, const int xmax, const int k0, const int kmax
+) {
+ TransposeInterleaveCommon<12, __fp16, float>::Transform(out, in, stride, x0, xmax, k0, kmax);
}
#endif // __aarch64__ && __ARM_FP16_ARGS
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp
index 5434599..80420dd 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp
@@ -30,92 +30,101 @@
// Generic unblocked transposed 12x32-bit sized specialisation
template <>
template <typename T>
-inline void TransformImpl<12, 1, true, 4, 4>::Transform(
- T *out, const T *const in, const int stride,
- const int x0, const int xmax, const int k0, const int kmax)
-{
- // Redirect to a 24 x uint16_t specialisation
- TransformImpl<24, 1, true, 2, 2>::Transform(
- reinterpret_cast<uint16_t *>(out),
- reinterpret_cast<const uint16_t *const>(in),
- stride * 2, x0 * 2, xmax * 2, k0, kmax);
+inline void TransformImpl<12, 1, true, 4, 4, false>::Transform(
+ T* out, const T* const in, const int stride,
+ const int x0, const int xmax, const int k0, const int kmax
+) {
+ // Redirect to a 24 x uint16_t specialisation
+ TransformImpl<24, 1, true, 2, 2, false>::Transform(
+ reinterpret_cast<uint16_t *>(out),
+ reinterpret_cast<const uint16_t *>(in),
+ stride*2, x0*2, xmax*2, k0, kmax
+ );
}
// Generic 24x16-bit sized specialisation
template <>
template <typename T>
-inline void TransformImpl<24, 1, true, 2, 2>::Transform(
- T *out, const T *const in, const int stride,
- const int x0, const int xmax, const int k0, const int kmax)
-{
- // Redirect to a uint16_t specialisation
- Transform(
- reinterpret_cast<uint16_t *>(out),
- reinterpret_cast<const uint16_t *const>(in),
- stride, x0, xmax, k0, kmax);
+inline void TransformImpl<24, 1, true, 2, 2, false>::Transform(
+ T* out, const T* const in, const int stride,
+ const int x0, const int xmax, const int k0, const int kmax
+) {
+ // Redirect to a uint16_t specialisation
+ Transform(
+ reinterpret_cast<uint16_t *>(out),
+ reinterpret_cast<const uint16_t *>(in),
+ stride, x0, xmax, k0, kmax
+ );
}
// Specialised 24 x uint16_t version
template <>
-inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out)
-{
- __asm __volatile(
- "LDP q0, q1, [%[in0]], #32\n"
- "STP q0, q1, [%[out]]\n" ASM_PREFETCH("[%[in0], #192]")
- "LDR q2, [%[in0]], #16\n"
- "STR q2, [%[out], #32]\n"
- : [in0] "+r"(in0), [out] "+r"(out)
- :
- : "v0", "v1", "v2", "memory");
-}
-
-template <>
-inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *out)
-{
- __asm __volatile(
- "LDP q0, q1, [%[in0]], #32\n"
- "STP q0, q1, [%[out]]\n" ASM_PREFETCH("[%[in0], #192]")
- "LDR q2, [%[in0]], #16\n"
- "LDP q3, q4, [%[in1]], #32\n"
- "STP q2, q3, [%[out], #32]\n" ASM_PREFETCH("[%[in1], #192]")
- "LDR q5, [%[in1]], #16\n"
- "STP q4, q5, [%[out], #64]\n"
- : [in0] "+r"(in0), [in1] "+r"(in1), [out] "+r"(out)
- :
- : "v0", "v1", "v2", "v3", "v4", "v5", "memory");
-}
-
-template <>
-inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out)
-{
- __asm __volatile(
+inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out) {
+ __asm __volatile (
"LDP q0, q1, [%[in0]], #32\n"
"STP q0, q1, [%[out]]\n"
- "LDR q2, [%[in0]], #16\n" ASM_PREFETCH("[%[in0], #192]")
- "LDP q3, q4, [%[in1]], #32\n"
+ ASM_PREFETCH("[%[in0], #192]")
+ "LDR q2, [%[in0]], #16\n"
+ "STR q2, [%[out], #32]\n"
+ : [in0] "+r" (in0), [out] "+r" (out)
+ :
+ : "v0", "v1", "v2", "memory"
+ );
+}
+
+template <>
+inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1,uint16_t *out) {
+ __asm __volatile (
+ "LDP q0, q1, [%[in0]], #32\n"
+ "STP q0, q1, [%[out]]\n"
+ ASM_PREFETCH("[%[in0], #192]")
+ "LDR q2, [%[in0]], #16\n"
+ "LDP q3, q4, [%[in1]], #32\n"
"STP q2, q3, [%[out], #32]\n"
- "LDR q5, [%[in1]], #16\n" ASM_PREFETCH("[%[in1], #192]")
+ ASM_PREFETCH("[%[in1], #192]")
+ "LDR q5, [%[in1]], #16\n"
"STP q4, q5, [%[out], #64]\n"
- "LDP q6, q7, [%[in2]], #32\n"
+ : [in0] "+r" (in0), [in1] "+r" (in1), [out] "+r" (out)
+ :
+ : "v0", "v1", "v2", "v3", "v4", "v5", "memory"
+ );
+}
+
+template <>
+inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out) {
+ __asm __volatile (
+ "LDP q0, q1, [%[in0]], #32\n"
+ "STP q0, q1, [%[out]]\n"
+ "LDR q2, [%[in0]], #16\n"
+ ASM_PREFETCH("[%[in0], #192]")
+ "LDP q3, q4, [%[in1]], #32\n"
+ "STP q2, q3, [%[out], #32]\n"
+ "LDR q5, [%[in1]], #16\n"
+ ASM_PREFETCH("[%[in1], #192]")
+ "STP q4, q5, [%[out], #64]\n"
+ "LDP q6, q7, [%[in2]], #32\n"
"STP q6, q7, [%[out], #96]\n"
- "LDR q8, [%[in2]], #16\n" ASM_PREFETCH("[%[in2], #192]")
- "LDP q9, q10, [%[in3]], #32\n"
+ "LDR q8, [%[in2]], #16\n"
+ ASM_PREFETCH("[%[in2], #192]")
+ "LDP q9, q10, [%[in3]], #32\n"
"STP q8, q9, [%[out], #128]\n"
- "LDR q11, [%[in3]], #16\n"
- "STP q10, q11, [%[out], #160]\n" ASM_PREFETCH("[%[in3], #192]")
+ "LDR q11, [%[in3]], #16\n"
+ "STP q10, q11, [%[out], #160]\n"
+ ASM_PREFETCH("[%[in3], #192]")
- : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), [in3] "+r"(in3), [out] "+r"(out)
- :
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "memory");
+ : [in0] "+r" (in0), [in1] "+r" (in1), [in2] "+r" (in2), [in3] "+r" (in3), [out] "+r" (out)
+ :
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "memory"
+ );
}
template <>
template <>
-inline void TransformImpl<24, 1, true, 2, 2>::Transform(
- uint16_t *out, const uint16_t *const in, const int stride,
- const int x0, const int xmax, const int k0, const int kmax)
-{
- TransposeInterleaveCommon<24, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax);
+inline void TransformImpl<24, 1, true, 2, 2, false>::Transform(
+ uint16_t* out, const uint16_t* const in, const int stride,
+ const int x0, const int xmax, const int k0, const int kmax
+) {
+ TransposeInterleaveCommon<24, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax);
}
-#endif // __arch64__
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp b/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp
index 3218ca1..63e85c1 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp
@@ -24,137 +24,117 @@
#pragma once
template <unsigned int IntBy, typename TIn, typename TOut>
-struct TransposeInterleaveCommon
-{
- // Override the moveblock_1xY methods to improve performance
- static inline void moveblock_1x1(const TIn *&in0, TOut *out)
- {
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in0++);
+struct TransposeInterleaveCommon {
+ // Override the moveblock_1xY methods to improve performance
+ static inline void moveblock_1x1(const TIn *&in0, TOut *out) {
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in0++);
+ }
+ }
+
+ static inline void moveblock_1x2(const TIn *&in0, const TIn *&in1, TOut *out) {
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in0++);
+ }
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in1++);
+ }
+ }
+
+ static inline void moveblock_1x4(const TIn *&in0, const TIn *&in1, const TIn *&in2, const TIn *&in3, TOut *out) {
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in0++);
+ }
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in1++);
+ }
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in2++);
+ }
+ for (unsigned int i = 0; i < IntBy; i++) {
+ *out++ = static_cast<TOut>(*in3++);
+ }
+ }
+
+ static inline void Transform(TOut *out, const TIn *in, const int stride, const int x0, const int xmax, const int k0, const int kmax) {
+ const auto ldin = stride;
+
+ TOut *outarray = out;
+ const TIn *inarray = in;
+ TOut *outptr_base = outarray;
+ const TIn *inptr_base = inarray + x0 + (k0 * ldin);
+ int ldout = (kmax - k0) * IntBy;
+
+ int k=(kmax-k0);
+ for ( ; k>3; k-=4) {
+ TOut *outptr = outptr_base;
+ const TIn *inptr = inptr_base;
+ const TIn *inptr1 = inptr + ldin;
+ const TIn *inptr2 = inptr1 + ldin;
+ const TIn *inptr3 = inptr2 + ldin;
+
+ prefetch_3x(inptr);
+ prefetch_3x(inptr1);
+ prefetch_3x(inptr2);
+ prefetch_3x(inptr3);
+
+ outptr_base += IntBy * 4;
+ inptr_base += ldin * 4;
+
+ for (int x = (xmax-x0) / IntBy; x > 0 ; x--) {
+ moveblock_1x4(inptr, inptr1, inptr2, inptr3, outptr);
+ outptr += ldout;
}
}
- static inline void moveblock_1x2(const TIn *&in0, const TIn *&in1, TOut *out)
- {
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in0++);
- }
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in1++);
- }
- }
+ if (k) {
+ TOut *outptr = outptr_base;
+ const TIn *inptr = inptr_base;
+ const TIn *inptr1 = inptr + ldin;
+ const TIn *inptr2 = inptr1 + ldin;
- static inline void moveblock_1x4(const TIn *&in0, const TIn *&in1, const TIn *&in2, const TIn *&in3, TOut *out)
- {
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in0++);
- }
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in1++);
- }
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in2++);
- }
- for(unsigned int i = 0; i < IntBy; i++)
- {
- *out++ = static_cast<TOut>(*in3++);
- }
- }
+ prefetch_3x(inptr);
+ prefetch_3x(inptr1);
+ prefetch_3x(inptr2);
- static inline void Transform(TOut *out, const TIn *in, const int stride, const int x0, const int xmax, const int k0, const int kmax)
- {
- const auto ldin = stride;
+ for (int x = (xmax-x0) / IntBy; x > 0 ; x--) {
+ switch(k) {
+ case 3:
+ moveblock_1x2(inptr, inptr1, outptr);
+ moveblock_1x1(inptr2, outptr + IntBy * 2);
+ break;
- TOut *outarray = out;
- const TIn *inarray = in;
- TOut *outptr_base = outarray;
- const TIn *inptr_base = inarray + x0 + (k0 * ldin);
- int ldout = (kmax - k0) * IntBy;
+ case 2:
+ moveblock_1x2(inptr, inptr1, outptr);
+ break;
- int k = (kmax - k0);
- for(; k > 3; k -= 4)
- {
- TOut *outptr = outptr_base;
- const TIn *inptr = inptr_base;
- const TIn *inptr1 = inptr + ldin;
- const TIn *inptr2 = inptr1 + ldin;
- const TIn *inptr3 = inptr2 + ldin;
+ case 1:
+ moveblock_1x1(inptr, outptr);
+ break;
- prefetch_3x(inptr);
- prefetch_3x(inptr1);
- prefetch_3x(inptr2);
- prefetch_3x(inptr3);
-
- outptr_base += IntBy * 4;
- inptr_base += ldin * 4;
-
- for(int x = (xmax - x0) / IntBy; x > 0; x--)
- {
- moveblock_1x4(inptr, inptr1, inptr2, inptr3, outptr);
- outptr += ldout;
+ default:
+ UNREACHABLE("Impossible.");
}
+
+ outptr += ldout;
}
+ }
- if(k)
- {
- TOut *outptr = outptr_base;
- const TIn *inptr = inptr_base;
- const TIn *inptr1 = inptr + ldin;
- const TIn *inptr2 = inptr1 + ldin;
+ // Cope with ragged X cases
+ const unsigned int overflow = (xmax - x0) % IntBy;
+ if (overflow) {
+ const TIn *inptr_base = inarray + (xmax - overflow) + (k0 * ldin);
+ TOut *outptr = outarray + ((xmax - x0) / IntBy) * ldout;
- prefetch_3x(inptr);
- prefetch_3x(inptr1);
- prefetch_3x(inptr2);
+ for (int k=(kmax-k0); k>0; k--) {
+ const TIn *inptr = inptr_base;
+ inptr_base += ldin;
- for(int x = (xmax - x0) / IntBy; x > 0; x--)
- {
- switch(k)
- {
- case 3:
- moveblock_1x2(inptr, inptr1, outptr);
- moveblock_1x1(inptr2, outptr + IntBy * 2);
- break;
-
- case 2:
- moveblock_1x2(inptr, inptr1, outptr);
- break;
-
- case 1:
- moveblock_1x1(inptr, outptr);
- break;
-
- default:
- UNREACHABLE("Impossible.");
- }
-
- outptr += ldout;
- }
- }
-
- // Cope with ragged X cases
- const unsigned int overflow = (xmax - x0) % IntBy;
- if(overflow)
- {
- const TIn *inptr_base = inarray + (xmax - overflow) + (k0 * ldin);
- TOut *outptr = outarray + ((xmax - x0) / IntBy) * ldout;
-
- for(int k = (kmax - k0); k > 0; k--)
- {
- const TIn *inptr = inptr_base;
- inptr_base += ldin;
-
- for(unsigned int x = 0; x < IntBy; x++)
- {
- TOut val = (x < overflow) ? static_cast<TOut>(*inptr++) : static_cast<TOut>(0);
- *outptr++ = val;
- }
+ for (unsigned int x=0; x < IntBy; x++) {
+ TOut val = (x < overflow) ? static_cast<TOut>(*inptr++) : static_cast<TOut>(0);
+ *outptr++ = val;
}
}
}
+}
};
diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp
index 6c5b92a..b77bc7a 100644
--- a/src/core/NEON/kernels/arm_gemm/utils.hpp
+++ b/src/core/NEON/kernels/arm_gemm/utils.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,27 +25,29 @@
#pragma once
// Macro for unreachable code (e.g. impossible default cases on switch)
-#define UNREACHABLE(why) __builtin_unreachable()
+#define UNREACHABLE(why) __builtin_unreachable()
// Paranoid option for the above with assert
// #define UNREACHABLE(why) assert(0 && why)
-inline int iceildiv(const int a, const int b)
-{
- return (a + b - 1) / b;
+inline int iceildiv(const int a, const int b) {
+ return (a + b - 1) / b;
}
template <typename T>
-inline T roundup(const T a, const T b)
-{
- T rem = a % b;
+inline T roundup(const T a, const T b) {
+ T rem = a % b;
- if(rem)
- {
- return a + b - rem;
- }
- else
- {
- return a;
- }
+ if (rem) {
+ return a + b - rem;
+ } else {
+ return a;
+ }
+}
+
+template <typename T>
+inline unsigned long get_vector_length() {
+ const unsigned long length = 16;
+
+ return length / sizeof(T);
}
diff --git a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
new file mode 100644
index 0000000..c9037ab
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/WindowIterator.h"
+
+using namespace arm_compute;
+
+INEGEMMWrapperKernel::INEGEMMWrapperKernel()
+ : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _window3d(), _window_shape()
+{
+}
+// Derive the GEMM sizes (M, N, K, batches, multis) from the shapes of A, B and C.
+INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c)
+{
+ Params p;
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(a);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(b);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(c);
+ // C's x is N and y is M; K is A's x; B's z is the number of multis; the rest of C's upper dims are batches.
+ p.M = c->info()->tensor_shape().y();
+ p.N = c->info()->tensor_shape().x();
+ p.K = a->info()->tensor_shape().x();
+ p.multis = b->info()->tensor_shape().z();
+ p.batches = c->info()->tensor_shape().total_size_upper(2) / p.multis;
+
+ return p;
+}
+// Store the operands, let configure_internal() build the 3D window, then register a collapsed 1D window.
+void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta)
+{
+ _params = extract_parameters(a, b, c);
+ _a = a;
+ _b = b;
+ _c = c;
+
+ _window3d = configure_internal(alpha, beta);
+ _window_shape = _window3d.shape();
+
+ // Collapse the 3D window into a 1D window so that the scheduler can split it arbitrarily.
+ Window collapsed;
+ collapsed.set(0, Window::Dimension(0, _window3d.num_iterations_total()));
+
+ INEKernel::configure(collapsed);
+}
+// Run a sub-range of the collapsed 1D window by mapping its first and last indices back to 3D coordinates.
+void INEGEMMWrapperKernel::run(const Window &window, const ThreadInfo &info)
+{
+ const Coordinates start_offset = index2coords(_window_shape, window.x().start());
+ const Coordinates end_offset = index2coords(_window_shape, window.x().end() - 1);
+ // end_offset is inclusive: it is the coordinate of the last index in the range (end() - 1).
+ run_internal(_window3d, start_offset, end_offset, info);
+}
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp b/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp
new file mode 100644
index 0000000..715fe70
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp
@@ -0,0 +1,141 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h"
+
+#include "NEGEMMInterleavedStrategies.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/WindowIterator.h"
+
+namespace arm_compute
+{
+template <typename To, typename Tr, bool use_dot>
+void NEGEMMInterleavedMatrixMultiplyWrapperTemplate<To, Tr, use_dot>::configure(const ITensor *prepared_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker,
+ const BlockSizes &block_sizes, const INEGEMMWrapperKernel::Params &params, bool b_is_pretransposed, float alpha, float beta, unsigned int max_num_threads)
+{
+ // Select the strategy with use_dot so it matches the one the A / B transform wrappers used.
+ using strategy = typename Kernel<To, use_dot>::strategy;
+ _prepared_a = prepared_a;
+ _transformed_b = transformed_b;
+ _tmp_c = tmp_c;
+ _c = c;
+ _block_walker = block_walker;
+ _block_sizes = block_sizes;
+ _params = params;
+ _b_is_pretransposed = b_is_pretransposed;
+ _alpha = alpha;
+ _beta = beta;
+ // tmp_c: one scratch row of x_block * out_height() results per thread (indexed by thread_id in transform()).
+ auto_init_if_empty(*_tmp_c->info(), c->info()->clone()->set_tensor_shape(TensorShape{ _block_sizes.x_block * strategy::out_height(), max_num_threads }));
+}
+
+template <typename To, typename Tr, bool use_dot>
+void NEGEMMInterleavedMatrixMultiplyWrapperTemplate<To, Tr, use_dot>::transform(const MatrixMultiplyWorkload &wl, const ThreadInfo &info, const Window &batch_window, const Coordinates &start_offset,
+ const Coordinates &end_offset)
+{
+ using strategy = typename Kernel<To, use_dot>::strategy;
+
+ strategy strat(info.cpu_info);
+ TensorAccessor<To> prepared_a(*_prepared_a);
+ TensorAccessor<To> transformed_b(*_transformed_b);
+ TensorAccessor<Tr> c(*_c);
+ TensorAccessor<Tr> tmp_c(*_tmp_c);
+ // a_ptr walks the prepared A panel of the current batch, advancing one out_height() block per kernel call.
+ int prev_batch = -1;
+ To *a_ptr = nullptr;
+ auto window_iterator = arm_compute::create_window_iterator(batch_window, start_offset, end_offset, [&](const Coordinates & id)
+ {
+ const unsigned int y = id.x();
+ const unsigned int batch = id.y();
+ const unsigned int ymax = std::min(_params.M, y + strategy::out_height());
+
+ // If it's the first block of a new batch then reset the pointer to A.
+ if(prev_batch != static_cast<int>(batch))
+ {
+ const unsigned int first_m = id.x();
+ a_ptr = prepared_a(0, first_m, batch);
+ prev_batch = batch;
+ }
+
+ // Call matrix multiply assembly routine to process the block:
+ strat.kernel(a_ptr, transformed_b(wl._offset_transformed_b), tmp_c(0, info.thread_id), 1, wl._bblocks, wl._kern_k);
+ a_ptr += strategy::out_height() * wl._kern_k;
+
+ // Merge into C: beta is applied only for the first K block (k0 == 0); later K blocks accumulate (beta = 1).
+ strat.transforms.Merge(c(0, 0, batch, wl._multi), tmp_c(0, info.thread_id), c.stride(1), y, ymax, wl._x0, wl._xmax, _alpha, (wl._k0 == 0 ? _beta : static_cast<Tr>(1)));
+ });
+ auto on_new_row_size = [&](unsigned int start, unsigned int end)
+ {
+ // Nothing to do on a new row range: a_ptr is reset per batch in the loop above.
+ };
+ window_iterator.iterate_2D(on_new_row_size);
+}
+
+template <typename To, typename Tr, bool use_dot>
+void NEGEMMInterleavedMatrixMultiplyWrapperTemplate<To, Tr, use_dot>::create_workloads(std::vector<MatrixMultiplyWorkload> &workloads)
+{
+ using strategy = typename Kernel<To, use_dot>::strategy;
+
+ unsigned int offset_transformed_b = 0;
+ execute_window_loop(_block_walker, [&](const Coordinates & id)
+ {
+ const unsigned int x0 = id.x();
+ const unsigned int k0 = id.y();
+ const unsigned int multi = id.z();
+
+ const unsigned int xmax = std::min(x0 + _block_walker.x().step(), _params.N);
+ const unsigned int kmax = std::min(k0 + _block_walker.y().step(), _params.K);
+
+ // Figure out how many "K" the kernel will actually process.
+ const int kern_k = ceil_to_multiple(kmax - k0, strategy::k_unroll());
+ const int bblocks = DIV_CEIL(xmax - x0, strategy::out_width());
+
+ workloads.push_back(MatrixMultiplyWorkload(offset_transformed_b, x0, xmax, k0, kmax, multi, kern_k, bblocks));
+ // Advance to the next pre-transposed B block (the same padded size the kernel consumes).
+ if(_b_is_pretransposed)
+ {
+ offset_transformed_b += bblocks * strategy::out_width() * kern_k;
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ });
+}
+
+template class NEGEMMInterleavedMatrixMultiplyWrapperTemplate<float, float>;
+#ifdef __aarch64__
+template class NEGEMMInterleavedMatrixMultiplyWrapperTemplate<uint8_t, uint32_t>;
+template class NEGEMMInterleavedMatrixMultiplyWrapperTemplate<int8_t, int32_t>;
+template class NEGEMMInterleavedMatrixMultiplyWrapperTemplate<uint8_t, uint32_t, true>;
+template class NEGEMMInterleavedMatrixMultiplyWrapperTemplate<int8_t, int32_t, true>;
+#endif /* __aarch64__ */
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template class NEGEMMInterleavedMatrixMultiplyWrapperTemplate<float16_t, float16_t>;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+} // namespace arm_compute
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp b/src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp
new file mode 100644
index 0000000..f33a14f
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp
@@ -0,0 +1,170 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"
+
+#include "NEGEMMInterleavedStrategies.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Validate.h"
+
+namespace arm_compute
+{
+namespace
+{
+// Call 'lambda' with one PrepareBWorkload per (x0, k0, multi) block visited by 'window'.
+template <typename To, bool use_dot, typename Lambda>
+void for_each_element_in_window(const Window &window, const ITensor *b, ITensor *transformed_b, unsigned int N, unsigned int K, Lambda &&lambda)
+{
+ using strategy = typename Kernel<To, use_dot>::strategy;
+ // offset_transformed_b is a running byte offset into transformed_b's buffer.
+ unsigned int offset_transformed_b = transformed_b->info()->offset_first_element_in_bytes();
+ execute_window_loop(window, [&](const Coordinates & coordinates)
+ {
+ const unsigned int x0 = coordinates.x();
+ const unsigned int k0 = coordinates.y();
+ const unsigned int multi = coordinates.z();
+
+ const unsigned int offset_b = b->info()->offset_element_in_bytes(Coordinates(0, 0, multi));
+ const unsigned int xmax = std::min(x0 + window.x().step(), N);
+ const unsigned int kmax = std::min(k0 + window.y().step(), K);
+
+ /* Figure out the size of each block. */
+ unsigned int x_size = (xmax - x0);
+ unsigned int k_size = (kmax - k0);
+
+ /* Round sizes up to multiples of the strategy's out_width() and k_unroll(). */
+ x_size = ceil_to_multiple(x_size, strategy::out_width());
+ k_size = ceil_to_multiple(k_size, strategy::k_unroll());
+
+ lambda(PrepareBWorkload(offset_b, offset_transformed_b, x0, xmax, k0, kmax));
+
+ // Each workload represents one padded block: advance to where the next one starts.
+ offset_transformed_b += (x_size * k_size * sizeof(To));
+ });
+}
+
+// Size in bytes of transformed_b: every block of B is padded to multiples of out_width() (N) and k_unroll() (K).
+template <typename To, bool use_dot>
+unsigned int get_B_pretransposed_array_size(unsigned int N, unsigned int K, const BlockSizes &bs)
+{
+ using strategy = typename Kernel<To, use_dot>::strategy;
+
+ // How many full blocks do N / K contain ?
+ size_t num_full_k = K / bs.k_block;
+ size_t num_full_x = N / bs.x_block;
+
+ ARM_COMPUTE_ERROR_ON(bs.x_block % strategy::out_width() != 0);
+ ARM_COMPUTE_ERROR_ON(bs.k_block % strategy::k_unroll() != 0);
+
+ size_t normal_x_size = bs.x_block;
+ size_t normal_k_size = bs.k_block;
+
+ // Round up the leftovers to be a multiple of the strategy processing size:
+ size_t left_over_x_size = ceil_to_multiple(N % bs.x_block, strategy::out_width());
+ size_t left_over_k_size = ceil_to_multiple(K % bs.k_block, strategy::k_unroll());
+
+ // Total elements = (full K rows + padded leftover K) x (full X columns + padded leftover X), then converted to bytes:
+ size_t total = num_full_k * normal_k_size * (num_full_x * normal_x_size + left_over_x_size);
+ total += left_over_k_size * (left_over_x_size + num_full_x * normal_x_size);
+ total *= sizeof(To);
+ return total;
+}
+
+} // namespace
+
+template <typename To, bool use_dot>
+BlockSizes NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::block_sizes() const
+{
+ return _block_sizes;
+}
+
+template <typename To, bool use_dot>
+void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::configure(const ITensor *b, ITensor *transformed_b, bool transpose_b, const CPUInfo &ci, const INEGEMMWrapperKernel::Params &params)
+{
+ using strategy = typename Kernel<To, use_dot>::strategy;
+
+ const unsigned int multis = b->info()->tensor_shape().z();
+ _Nsize = b->info()->tensor_shape().x();
+ _Ksize = b->info()->tensor_shape().y();
+ _b = b;
+ _transformed_b = transformed_b;
+ _transpose_b = transpose_b;
+ // Block sizes are derived from the whole GEMM problem (M, N, K), not from B alone.
+ _block_sizes = calculate_block_sizes<strategy>(ci, params.M, params.N, params.K);
+ // NOTE(review): the size below is in bytes but is used as an element count of b's data type (over-allocates by sizeof(To)); confirm.
+ auto_init_if_empty(*transformed_b->info(), b->info()->clone()->set_tensor_shape(TensorShape{ get_B_pretransposed_array_size<To, use_dot>(_Nsize, _Ksize, _block_sizes) }));
+ // One window step per (x_block, k_block) block of B, for every multi.
+ Window window;
+ window.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_Nsize, _block_sizes.x_block), _block_sizes.x_block));
+ window.set(Window::DimY, Window::Dimension(0, ceil_to_multiple(_Ksize, _block_sizes.k_block), _block_sizes.k_block));
+ window.set(Window::DimZ, Window::Dimension(0, multis));
+
+ INEKernel::configure(window);
+}
+
+template <typename To, bool use_dot>
+void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::transform(const PrepareBWorkload &wl, const ThreadInfo &info)
+{
+ using strategy = typename Kernel<To, use_dot>::strategy;
+
+ strategy strat(info.cpu_info);
+ strat.transforms.PrepareB(reinterpret_cast<To *>(_transformed_b->buffer() + wl._offset_transformed_b),
+ reinterpret_cast<To *>(_b->buffer() + wl._offset_b),
+ _b->info()->strides_in_bytes().y() / sizeof(To),
+ wl._x0, wl._xmax, wl._k0, wl._kmax, _transpose_b);
+}
+
+template <typename To, bool use_dot>
+void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::create_workloads(std::vector<PrepareBWorkload> &workloads)
+{
+ for_each_element_in_window<To, use_dot>(window(), _b, _transformed_b, _Nsize, _Ksize, [&workloads](PrepareBWorkload && wl)
+ {
+ workloads.push_back(std::move(wl));
+ });
+}
+
+template <typename To, bool use_dot>
+void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::run(const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(window, INEKernel::window());
+ for_each_element_in_window<To, use_dot>(window, _b, _transformed_b, _Nsize, _Ksize, [&](PrepareBWorkload && wl)
+ {
+ this->transform(wl, info);
+ });
+}
+
+template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<float>;
+#ifdef __aarch64__
+template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<uint8_t>;
+template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<int8_t>;
+template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<uint8_t, true>;
+template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<int8_t, true>;
+#endif /* __aarch64__ */
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template class NEGEMMInterleavedPrepareBWrapperKernelTemplate<float16_t>;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+} // namespace arm_compute
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
new file mode 100644
index 0000000..26a8ade
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_NEGEMMINTERLEAVEDSTRATEGIES_H__
+#define __ARM_COMPUTE_NEGEMMINTERLEAVEDSTRATEGIES_H__
+
+#include "../arm_gemm/utils.hpp"
+#include "arm_gemm.hpp"
+
+#include "../arm_gemm/mergeresults.hpp"
+#include "../arm_gemm/transform.hpp"
+
+#include "../arm_gemm/kernels/a32_sgemm_8x6.hpp"
+#include "../arm_gemm/kernels/a64_gemm_s8_12x8.hpp"
+#include "../arm_gemm/kernels/a64_gemm_s8_4x4.hpp"
+#include "../arm_gemm/kernels/a64_gemm_u8_12x8.hpp"
+#include "../arm_gemm/kernels/a64_gemm_u8_4x4.hpp"
+#include "../arm_gemm/kernels/a64_hgemm_24x8.hpp"
+#include "../arm_gemm/kernels/a64_sgemm_12x8.hpp"
+
+namespace arm_compute
+{
+namespace
+{
+template <typename To, bool use_dot = false>
+struct Kernel
+{
+};
+// Only the specializations below define 'strategy'; using Kernel<...>::strategy for any other combination fails to compile.
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+struct Kernel<float16_t, false>
+{
+ using strategy = arm_gemm::hgemm_24x8;
+};
+#endif /*__ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+#ifdef __aarch64__
+template <>
+struct Kernel<float, false>
+{
+ using strategy = arm_gemm::sgemm_12x8;
+};
+template <>
+struct Kernel<int8_t, false>
+{
+ using strategy = arm_gemm::gemm_s8_4x4;
+};
+template <>
+struct Kernel<uint8_t, false>
+{
+ using strategy = arm_gemm::gemm_u8_4x4;
+};
+
+// With use_dot == true, the 8-bit types use the 12x8 strategies instead of the 4x4 ones:
+template <>
+struct Kernel<int8_t, true>
+{
+ using strategy = arm_gemm::gemm_s8_12x8;
+};
+template <>
+struct Kernel<uint8_t, true>
+{
+ using strategy = arm_gemm::gemm_u8_12x8;
+};
+#else /* !__aarch64__ */
+template <>
+struct Kernel<float, false>
+{
+ using strategy = arm_gemm::sgemm_8x6;
+};
+#endif /* __aarch64__ */
+
+} // namespace
+} // namespace arm_compute
+#endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDSTRATEGIES_H__ */
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.cpp b/src/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.cpp
new file mode 100644
index 0000000..1780a18
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.cpp
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h"
+
+#include "NEGEMMInterleavedStrategies.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/WindowIterator.h"
+
+#include "utils/TypePrinter.h"
+
+namespace arm_compute
+{
+template <typename To, bool use_dot>
+void NEGEMMInterleavedTransformAWrapperTemplate<To, use_dot>::configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, const Window &block_walker,
+ const INEGEMMWrapperKernel::Params &params)
+{
+ _a = a;
+ _transformed_a = transformed_a;
+ _transpose_a = transpose_a;
+ _Ksize = params.K;
+ _Msize = params.M;
+ _k_multi_window = block_walker.shift_dimensions(1); // Drop block_walker's first dimension so that only (K, Multi) remain
+}
+
+template <typename To, bool use_dot>
+void NEGEMMInterleavedTransformAWrapperTemplate<To, use_dot>::transform(const TransformAWorkload &wl, const ThreadInfo &info, const Window &batch_window, const Coordinates &start_offset,
+ const Coordinates &end_offset)
+{
+ using strategy = typename Kernel<To, use_dot>::strategy;
+
+ strategy strat(info.cpu_info);
+ TensorAccessor<To> a(*_a);
+ TensorAccessor<To> transformed_a(*_transformed_a);
+
+ if(_a->info()->data_layout() == DataLayout::NHWC)
+ {
+ // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
+ // the relevant multiple of the row stride.
+ const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _Msize;
+ a.set_stride(2, nhwc_batch_stride);
+ }
+ // last_m: end of the current row range clamped to M; set by on_new_row_size (presumably invoked by iterate_2D per row).
+ unsigned int last_m = 0;
+ int last_y = -1;
+ auto window_iterator = arm_compute::create_window_iterator(batch_window, start_offset, end_offset, [&](const Coordinates & id)
+ {
+ if(id.y() != last_y)
+ {
+ last_y = id.y();
+ unsigned int batch = id.y();
+ unsigned int first_m = id.x();
+ // Only the first element seen for each y (batch) triggers a PrepareA call covering rows first_m..last_m.
+ if(first_m >= last_m)
+ return;
+
+ strat.transforms.PrepareA(transformed_a(0, first_m, batch),
+ a(0, 0, batch, wl._multi),
+ a.stride(1), first_m, last_m, wl._k0, wl._kmax, _transpose_a);
+ }
+ });
+ auto on_new_row_size = [&](unsigned int start, unsigned int end)
+ {
+ last_m = std::min(end, _Msize);
+ };
+ window_iterator.iterate_2D(on_new_row_size);
+}
+
+template <typename To, bool use_dot>
+void NEGEMMInterleavedTransformAWrapperTemplate<To, use_dot>::create_workloads(std::vector<TransformAWorkload> &workloads)
+{
+ execute_window_loop(_k_multi_window, [&](const Coordinates & id)
+ {
+ const unsigned int k0 = id.x();
+ const unsigned int multi = id.y();
+ const unsigned int kmax = std::min(k0 + _k_multi_window.x().step(), _Ksize);
+
+ workloads.push_back(TransformAWorkload(k0, kmax, multi));
+ });
+}
+
+template class NEGEMMInterleavedTransformAWrapperTemplate<float>;
+#ifdef __aarch64__
+template class NEGEMMInterleavedTransformAWrapperTemplate<uint8_t>;
+template class NEGEMMInterleavedTransformAWrapperTemplate<int8_t>;
+template class NEGEMMInterleavedTransformAWrapperTemplate<uint8_t, true>;
+template class NEGEMMInterleavedTransformAWrapperTemplate<int8_t, true>;
+#endif /* __aarch64__ */
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template class NEGEMMInterleavedTransformAWrapperTemplate<float16_t>;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+} // namespace arm_compute
diff --git a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
new file mode 100644
index 0000000..fb217f0
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/WindowIterator.h"
+
+#include "../arm_gemm/utils.hpp"
+#include "arm_gemm.hpp"
+
+#include "../arm_gemm/mergeresults.hpp"
+#include "../arm_gemm/transform.hpp"
+
+#include "../arm_gemm/kernels/a32_sgemm_8x6.hpp"
+#include "../arm_gemm/kernels/a64_sgemm_12x8.hpp"
+#include "../arm_gemm/kernels/a64_sgemm_native_16x4.hpp"
+#include "../arm_gemm/kernels/a64_sgemv_pretransposed.hpp"
+#include "../arm_gemm/kernels/a64_sgemv_trans.hpp"
+
+namespace arm_compute
+{
+namespace
+{
+template <typename To, typename Tr>
+struct Kernel
+{
+};
+
+#ifdef __aarch64__
+template <>
+struct Kernel<float, float>
+{
+ using strategy = arm_gemm::sgemm_native_16x4;
+};
+#endif /* __aarch64__ */
+
+} // namespace
+
+template <typename To, typename Tr>
+Window NEGEMMNativeWrapperKernel<To, Tr>::configure_internal(float alpha, float beta)
+{
+ using strategy = typename Kernel<To, Tr>::strategy;
+
+ _beta = beta;
+
+ //Note: The window is shifted down by 1 dimension compare to the tensors
+ Window window;
+ window.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_params.M, strategy::out_height()), strategy::out_height()));
+ window.set(Window::DimY, Window::Dimension(0, _params.batches));
+ window.set(Window::DimZ, Window::Dimension(0, _params.multis));
+
+ return window;
+}
+
+template <typename To, typename Tr>
+void NEGEMMNativeWrapperKernel<To, Tr>::run_internal(const Window &window, const Coordinates &start_offset, const Coordinates &end_offset, const ThreadInfo &info)
+{
+ using strategy = typename Kernel<To, Tr>::strategy;
+
+ TensorAccessor<To> a(*_a);
+ TensorAccessor<To> b(*_b);
+ TensorAccessor<Tr> c(*_c);
+
+ if(_a->info()->data_layout() == DataLayout::NHWC)
+ {
+ // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
+ // the relevant multiple of the row stride.
+ const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _c->info()->dimension(1);
+ a.set_stride(2, nhwc_batch_stride);
+ }
+
+ unsigned int m_end = 0;
+
+ strategy strat(info.cpu_info);
+ auto window_iterator = arm_compute::create_window_iterator(window, start_offset, end_offset, [&](const Coordinates & id)
+ {
+ const unsigned int y0 = id.x();
+ const unsigned int batch = id.y();
+ const unsigned int multi = id.z();
+ const unsigned int ymax = std::min(y0 + strategy::out_height(), m_end);
+
+ strat.kernel(a(0, y0, batch, multi), a.stride(Window::DimY),
+ b(0, 0, multi), b.stride(Window::DimY),
+ c(0, y0, batch, multi), c.stride(Window::DimY),
+ _beta, (ymax - y0), _params.N, _params.K);
+ });
+
+ auto on_new_row_size = [&](unsigned int start, unsigned int end)
+ {
+ m_end = std::min(end, _params.M);
+ };
+
+ window_iterator.iterate_3D(on_new_row_size);
+}
+
+#ifdef __aarch64__
+template class NEGEMMNativeWrapperKernel<float, float>;
+#endif /* __aarch64__ */
+
+} // namespace arm_compute
diff --git a/src/core/TensorInfo.cpp b/src/core/TensorInfo.cpp
index 676938a..b77a47e 100644
--- a/src/core/TensorInfo.cpp
+++ b/src/core/TensorInfo.cpp
@@ -33,8 +33,8 @@
using namespace arm_compute;
TensorInfo::TensorInfo()
- : _total_size(0), _fixed_point_position(0), _offset_first_element_in_bytes(0), _strides_in_bytes(), _num_channels(0), _tensor_shape(), _data_type(DataType::UNKNOWN), _format(Format::UNKNOWN),
- _is_resizable{ true }, _valid_region{ Coordinates(), _tensor_shape }, _padding{ 0 }, _quantization_info(), _data_layout(DataLayout::NCHW)
+ : _total_size(0), _offset_first_element_in_bytes(0), _strides_in_bytes(), _num_channels(0), _tensor_shape(), _data_type(DataType::UNKNOWN), _format(Format::UNKNOWN), _is_resizable{ true },
+ _valid_region{ Coordinates(), _tensor_shape }, _padding{ 0 }, _quantization_info(), _data_layout(DataLayout::NCHW)
{
}
@@ -42,7 +42,6 @@
: TensorInfo()
{
_total_size = info.total_size();
- _fixed_point_position = info.fixed_point_position();
_offset_first_element_in_bytes = info.offset_first_element_in_bytes();
_strides_in_bytes = info.strides_in_bytes();
_num_channels = info.num_channels();
@@ -72,22 +71,22 @@
init(tensor_shape, format);
}
-TensorInfo::TensorInfo(size_t num_channels, DataType data_type, size_t fixed_point_position)
+TensorInfo::TensorInfo(size_t num_channels, DataType data_type)
: TensorInfo()
{
- init(TensorShape(), num_channels, data_type, fixed_point_position);
+ init(TensorShape(), num_channels, data_type);
}
-TensorInfo::TensorInfo(const TensorShape &tensor_shape, size_t num_channels, DataType data_type, int fixed_point_position)
+TensorInfo::TensorInfo(const TensorShape &tensor_shape, size_t num_channels, DataType data_type)
: TensorInfo()
{
- init(tensor_shape, num_channels, data_type, fixed_point_position);
+ init(tensor_shape, num_channels, data_type);
}
TensorInfo::TensorInfo(const TensorShape &tensor_shape, size_t num_channels, DataType data_type, QuantizationInfo quantization_info)
: TensorInfo()
{
- init(tensor_shape, num_channels, data_type, 0);
+ init(tensor_shape, num_channels, data_type);
_quantization_info = quantization_info;
}
@@ -124,34 +123,28 @@
_format = format;
}
-void TensorInfo::init(size_t num_channels, DataType data_type, size_t fixed_point_position)
+void TensorInfo::init(size_t num_channels, DataType data_type)
{
- init(TensorShape(), num_channels, data_type, fixed_point_position);
+ init(TensorShape(), num_channels, data_type);
}
-void TensorInfo::init(const TensorShape &tensor_shape, size_t num_channels, DataType data_type, int fixed_point_position)
+void TensorInfo::init(const TensorShape &tensor_shape, size_t num_channels, DataType data_type)
{
ARM_COMPUTE_ERROR_ON(num_channels == 0);
- ARM_COMPUTE_ERROR_ON(data_type == DataType::QS8 && (fixed_point_position < 1 || fixed_point_position > 6));
- ARM_COMPUTE_ERROR_ON(data_type == DataType::QS16 && (fixed_point_position < 1 || fixed_point_position > 14));
- _fixed_point_position = fixed_point_position;
- _data_type = data_type;
- _num_channels = num_channels;
- _format = Format::UNKNOWN;
+ _data_type = data_type;
+ _num_channels = num_channels;
+ _format = Format::UNKNOWN;
set_tensor_shape(tensor_shape);
}
void TensorInfo::init(const TensorShape &tensor_shape, size_t num_channels, DataType data_type,
const Strides &strides_in_bytes, size_t offset_first_element_in_bytes,
- size_t total_size_in_bytes, int fixed_point_position)
+ size_t total_size_in_bytes)
{
ARM_COMPUTE_ERROR_ON(num_channels == 0);
- ARM_COMPUTE_ERROR_ON(data_type == DataType::QS8 && (fixed_point_position < 1 || fixed_point_position > 6));
- ARM_COMPUTE_ERROR_ON(data_type == DataType::QS16 && (fixed_point_position < 1 || fixed_point_position > 14));
- _fixed_point_position = fixed_point_position;
_data_type = data_type;
_num_channels = num_channels;
_format = Format::UNKNOWN;
@@ -188,17 +181,14 @@
return total_size;
}
-size_t TensorInfo::init_auto_padding(const TensorShape &tensor_shape, size_t num_channels, DataType data_type, int fixed_point_position)
+size_t TensorInfo::init_auto_padding(const TensorShape &tensor_shape, size_t num_channels, DataType data_type)
{
ARM_COMPUTE_ERROR_ON(num_channels == 0);
- ARM_COMPUTE_ERROR_ON(data_type == DataType::QS8 && (fixed_point_position < 1 || fixed_point_position > 6));
- ARM_COMPUTE_ERROR_ON(data_type == DataType::QS16 && (fixed_point_position < 1 || fixed_point_position > 14));
- _fixed_point_position = fixed_point_position;
- _data_type = data_type;
- _num_channels = num_channels;
- _format = Format::UNKNOWN;
- _tensor_shape = tensor_shape;
+ _data_type = data_type;
+ _num_channels = num_channels;
+ _format = Format::UNKNOWN;
+ _tensor_shape = tensor_shape;
_valid_region = ValidRegion{ Coordinates(), _tensor_shape };
@@ -371,14 +361,6 @@
return *this;
}
-ITensorInfo &TensorInfo::set_fixed_point_position(int fixed_point_position)
-{
- ARM_COMPUTE_ERROR_ON(_data_type == DataType::QS8 && (fixed_point_position < 1 || fixed_point_position > 6));
- ARM_COMPUTE_ERROR_ON(_data_type == DataType::QS16 && (fixed_point_position < 1 || fixed_point_position > 14));
- _fixed_point_position = fixed_point_position;
- return *this;
-}
-
ITensorInfo &TensorInfo::set_quantization_info(const QuantizationInfo &quantization_info)
{
_quantization_info = quantization_info;
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp
index b1c5992..11bdbda 100644
--- a/src/core/Utils.cpp
+++ b/src/core/Utils.cpp
@@ -24,8 +24,6 @@
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/FixedPoint.h"
-
#include "support/ToolchainSupport.h"
#include <algorithm>
@@ -145,10 +143,8 @@
{ DataType::UNKNOWN, "UNKNOWN" },
{ DataType::S8, "S8" },
{ DataType::U8, "U8" },
- { DataType::QS8, "QS8" },
{ DataType::S16, "S16" },
{ DataType::U16, "U16" },
- { DataType::QS16, "QS16" },
{ DataType::S32, "S32" },
{ DataType::U32, "U32" },
{ DataType::S64, "S64" },
@@ -353,14 +349,12 @@
case DataType::U8:
print_consecutive_elements_impl<uint8_t>(s, ptr, n, stream_width, element_delim);
break;
- case DataType::QS8:
case DataType::S8:
print_consecutive_elements_impl<int8_t>(s, reinterpret_cast<const int8_t *>(ptr), n, stream_width, element_delim);
break;
case DataType::U16:
print_consecutive_elements_impl<uint16_t>(s, reinterpret_cast<const uint16_t *>(ptr), n, stream_width, element_delim);
break;
- case DataType::QS16:
case DataType::S16:
print_consecutive_elements_impl<int16_t>(s, reinterpret_cast<const int16_t *>(ptr), n, stream_width, element_delim);
break;
@@ -388,12 +382,10 @@
case DataType::QASYMM8:
case DataType::U8:
return max_consecutive_elements_display_width_impl<uint8_t>(s, ptr, n);
- case DataType::QS8:
case DataType::S8:
return max_consecutive_elements_display_width_impl<int8_t>(s, reinterpret_cast<const int8_t *>(ptr), n);
case DataType::U16:
return max_consecutive_elements_display_width_impl<uint16_t>(s, reinterpret_cast<const uint16_t *>(ptr), n);
- case DataType::QS16:
case DataType::S16:
return max_consecutive_elements_display_width_impl<int16_t>(s, reinterpret_cast<const int16_t *>(ptr), n);
case DataType::U32:
diff --git a/src/core/Validate.cpp b/src/core/Validate.cpp
index d4fabd4..5587dad 100644
--- a/src/core/Validate.cpp
+++ b/src/core/Validate.cpp
@@ -100,6 +100,16 @@
return arm_compute::Status{};
}
+arm_compute::Status arm_compute::error_on_tensor_not_2d(const char *function, const char *file, const int line,
+ const arm_compute::ITensorInfo *tensor)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor == nullptr, function, file, line);
+ ARM_COMPUTE_RETURN_ERROR_ON_LOC_MSG(tensor->num_dimensions() != 2,
+ function, file, line,
+ "Only 2D Tensors are supported by this kernel (%d passed)", tensor->num_dimensions());
+ return arm_compute::Status{};
+}
+
arm_compute::Status arm_compute::error_on_channel_not_in_known_format(const char *function, const char *file, const int line,
arm_compute::Format fmt, arm_compute::Channel cn)
{
@@ -169,7 +179,7 @@
// Subtensor should not index in x, y dimensions.
ARM_COMPUTE_RETURN_ERROR_ON_LOC(((coords.x() != 0) || (coords.y() != 0)), function, file, line);
// Subtensor shape should match parent tensor in x, y dimensions.
- ARM_COMPUTE_RETURN_ERROR_ON_LOC(((parent_shape.x() != shape.x()) || (parent_shape.y() != parent_shape.y())), function, file, line);
+ ARM_COMPUTE_RETURN_ERROR_ON_LOC(((parent_shape.x() != shape.x()) || (parent_shape.y() != shape.y())), function, file, line);
// Check dimensions
for(unsigned int i = 0; i < TensorShape::num_max_dimensions; ++i)
diff --git a/src/graph/Graph.cpp b/src/graph/Graph.cpp
index e1ffeed..88e2682 100644
--- a/src/graph/Graph.cpp
+++ b/src/graph/Graph.cpp
@@ -41,17 +41,24 @@
std::unique_ptr<INode> &node = _nodes[nid];
- // Remove node connections
if(node)
{
+ // Remove input connections
for(auto &input_eid : node->_input_edges)
{
remove_connection(input_eid);
}
- for(auto &outpud_eid : node->_output_edges)
+
+ // Remove output connections
+ std::set<EdgeID> output_edges_copy = node->output_edges();
+ for(auto &outpud_eid : output_edges_copy)
{
remove_connection(outpud_eid);
}
+
+ // Remove nid from tagged nodes
+ std::vector<NodeID> &tnodes = _tagged_nodes.at(node->type());
+ tnodes.erase(std::remove(tnodes.begin(), tnodes.end(), nid), tnodes.end());
}
node = nullptr;
@@ -164,9 +171,9 @@
return _id;
}
-const std::vector<NodeID> &Graph::inputs()
+const std::vector<NodeID> &Graph::nodes(NodeType type)
{
- return _tagged_nodes[NodeType::Input];
+ return _tagged_nodes[type];
}
std::vector<std::unique_ptr<INode>> &Graph::nodes()
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index 4c5d30a..81a18c4 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -25,9 +25,11 @@
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/Utils.h"
-#include "arm_compute/graph/algorithms/BFS.h"
+#include "arm_compute/graph/algorithms/TopologicalSort.h"
#include "arm_compute/graph/nodes/Nodes.h"
+#include "support/ToolchainSupport.h"
+
#define CHECK_NODEIDX_PAIR(pair, g) \
ARM_COMPUTE_ERROR_ON(((pair).node_id >= (g).nodes().size()) || ((g).node((pair).node_id) == nullptr) || ((pair).index >= (g).node((pair).node_id)->num_outputs()));
@@ -79,43 +81,6 @@
return nid;
}
-
-NodeID create_grouped_convolution(Graph &g, NodeParams ¶ms, NodeIdxPair input, NodeID weights, NodeID bias,
- PadStrideInfo conv_info, ConvolutionMethod method, FastMathHint fast_math_hint, unsigned int num_groups)
-{
- bool has_bias = (bias != EmptyNodeID);
-
- // Split input
- NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, 2);
-
- // Split weights
- NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, 3);
-
- // Split bias
- NodeID bias_split = EmptyNodeID;
- if(has_bias)
- {
- // Split bias
- bias_split = GraphBuilder::add_split_node(g, params, { bias, 0 }, num_groups, 0);
- }
-
- std::vector<NodeIdxPair> convolution_outputs;
- for(unsigned int i = 0; i < num_groups; ++i)
- {
- NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, method, fast_math_hint);
- g.add_connection(input_split, i, conv_nid, 0);
- g.add_connection(weights_split, i, conv_nid, 1);
- if(has_bias)
- {
- g.add_connection(bias_split, i, conv_nid, 2);
- }
- set_node_params(g, conv_nid, params);
- convolution_outputs.push_back({ conv_nid, 0 });
- }
-
- // Depth concatenate output
- return GraphBuilder::add_depth_concatenate_node(g, params, convolution_outputs);
-}
} // namespace
NodeID GraphBuilder::add_const_node(Graph &g, NodeParams params, TensorDescriptor desc, ITensorAccessorUPtr accessor)
@@ -203,6 +168,11 @@
return batch_norm_nid;
}
+NodeID GraphBuilder::add_channel_shuffle_node(Graph &g, NodeParams params, NodeIdxPair input, unsigned int num_groups)
+{
+ return create_simple_single_input_output_node<ChannelShuffleLayerNode>(g, params, input, num_groups);
+}
+
NodeID GraphBuilder::add_convolution_node(Graph &g, NodeParams params, NodeIdxPair input,
Size2D kernel_spatial_extend, unsigned int depth, PadStrideInfo conv_info,
unsigned int num_groups, ConvolutionMethod method, FastMathHint fast_math_hint,
@@ -239,34 +209,81 @@
{
TensorDescriptor b_desc = input_tensor_desc;
b_desc.shape = TensorShape(depth);
- b_nid = add_const_node_with_name(g, params, "Bias", b_desc, std::move(bias_accessor));
- }
-
- if(num_groups == 1)
- {
- // Create convolution node and connect
- NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, method, fast_math_hint, out_quant_info);
- g.add_connection(input.node_id, input.index, conv_nid, 0);
- g.add_connection(w_nid, 0, conv_nid, 1);
- if(has_bias)
+ if(is_data_type_quantized_asymmetric(input_tensor_desc.data_type))
{
- g.add_connection(b_nid, 0, conv_nid, 2);
+ b_desc.data_type = DataType::S32;
}
- set_node_params(g, conv_nid, params);
+ b_nid = add_const_node_with_name(g, params, "Bias", b_desc, std::move(bias_accessor));
+ }
- return conv_nid;
- }
- else
+ // Create convolution node and connect
+ NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, num_groups, method, fast_math_hint, out_quant_info);
+ g.add_connection(input.node_id, input.index, conv_nid, 0);
+ g.add_connection(w_nid, 0, conv_nid, 1);
+ if(has_bias)
{
- return create_grouped_convolution(g, params, input, w_nid, b_nid, conv_info, method, fast_math_hint, num_groups);
+ g.add_connection(b_nid, 0, conv_nid, 2);
}
+ set_node_params(g, conv_nid, params);
+
+ return conv_nid;
}
-NodeID GraphBuilder::add_depth_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs)
+NodeID GraphBuilder::add_deconvolution_node(Graph &g, NodeParams params, NodeIdxPair input,
+ Size2D kernel_spatial_extend, unsigned int depth, PadStrideInfo deconv_info,
+ Size2D inner_border, ITensorAccessorUPtr weights_accessor,
+ ITensorAccessorUPtr bias_accessor)
+{
+ CHECK_NODEIDX_PAIR(input, g);
+ ARM_COMPUTE_ERROR_ON(depth == 0);
+ ARM_COMPUTE_ERROR_ON((kernel_spatial_extend.width == 0) || (kernel_spatial_extend.height == 0));
+
+ bool has_bias = (bias_accessor != nullptr);
+
+ // Get input tensor descriptor
+ const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]);
+
+ // Create weights node
+ TensorDescriptor w_desc = input_tensor_desc;
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width);
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height);
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL),
+ get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL));
+ w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::BATCHES), depth);
+
+ NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor));
+
+ // Create bias nodes
+ NodeID b_nid = EmptyNodeID;
+ if(has_bias)
+ {
+ TensorDescriptor b_desc = input_tensor_desc;
+ b_desc.shape = TensorShape(depth);
+ if(is_data_type_quantized_asymmetric(input_tensor_desc.data_type))
+ {
+ b_desc.data_type = DataType::S32;
+ }
+ b_nid = add_const_node_with_name(g, params, "Bias", b_desc, std::move(bias_accessor));
+ }
+
+ // Create convolution node and connect
+ NodeID deconv_nid = g.add_node<DeconvolutionLayerNode>(deconv_info, inner_border);
+ g.add_connection(input.node_id, input.index, deconv_nid, 0);
+ g.add_connection(w_nid, 0, deconv_nid, 1);
+ if(has_bias)
+ {
+ g.add_connection(b_nid, 0, deconv_nid, 2);
+ }
+ set_node_params(g, deconv_nid, params);
+
+ return deconv_nid;
+}
+
+NodeID GraphBuilder::add_concatenate_node(Graph &g, NodeParams params, std::vector<NodeIdxPair> inputs, DataLayoutDimension axis)
{
ARM_COMPUTE_ERROR_ON(inputs.size() == 0);
- NodeID nid = g.add_node<DepthConcatenateLayerNode>(inputs.size());
+ NodeID nid = g.add_node<ConcatenateLayerNode>(inputs.size(), axis);
unsigned int i = 0;
for(const auto &input : inputs)
@@ -309,7 +326,7 @@
if(has_bias)
{
TensorDescriptor b_desc = input_tensor_desc;
- b_desc.shape = TensorShape(b_desc.shape.z());
+ b_desc.shape = TensorShape(get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL));
b_nid = add_const_node_with_name(g, params, "Bias", b_desc, std::move(bias_accessor));
}
@@ -326,6 +343,11 @@
return conv_nid;
}
+NodeID GraphBuilder::add_dummy_node(Graph &g, NodeParams params, NodeIdxPair input, TensorShape shape)
+{
+ return create_simple_single_input_output_node<DummyNode>(g, params, input, shape);
+}
+
NodeID GraphBuilder::add_elementwise_node(Graph &g, NodeParams params, NodeIdxPair input0, NodeIdxPair input1, EltwiseOperation operation)
{
CHECK_NODEIDX_PAIR(input0, g);
@@ -347,7 +369,9 @@
}
NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, NodeIdxPair input, unsigned int num_outputs,
- ITensorAccessorUPtr weights_accessor, ITensorAccessorUPtr bias_accessor)
+ ITensorAccessorUPtr weights_accessor, ITensorAccessorUPtr bias_accessor,
+ const FullyConnectedLayerInfo fc_info,
+ const QuantizationInfo weights_quant_info, const QuantizationInfo out_quant_info)
{
CHECK_NODEIDX_PAIR(input, g);
ARM_COMPUTE_ERROR_ON(num_outputs == 0);
@@ -358,7 +382,7 @@
const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]);
// Create weights node
- TensorDescriptor w_desc = FullyConnectedLayerNode::compute_weights_descriptor(input_tensor_desc, num_outputs);
+ TensorDescriptor w_desc = FullyConnectedLayerNode::compute_weights_descriptor(input_tensor_desc, num_outputs, fc_info, weights_quant_info);
NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor));
// Create bias nodes
@@ -367,11 +391,15 @@
{
TensorDescriptor b_desc = input_tensor_desc;
b_desc.shape = TensorShape(num_outputs);
- b_nid = add_const_node_with_name(g, params, "Bias", b_desc, std::move(bias_accessor));
+ if(is_data_type_quantized_asymmetric(input_tensor_desc.data_type))
+ {
+ b_desc.data_type = DataType::S32;
+ }
+ b_nid = add_const_node_with_name(g, params, "Bias", b_desc, std::move(bias_accessor));
}
- // Create convolution node and connect
- NodeID fc_nid = g.add_node<FullyConnectedLayerNode>(num_outputs);
+ // Create fully connected node and connect
+ NodeID fc_nid = g.add_node<FullyConnectedLayerNode>(num_outputs, out_quant_info, fc_info);
g.add_connection(input.node_id, input.index, fc_nid, 0);
g.add_connection(w_nid, 0, fc_nid, 1);
if(has_bias)
@@ -389,6 +417,11 @@
return create_simple_single_input_output_node<NormalizationLayerNode>(g, params, input, norm_info);
}
+NodeID GraphBuilder::add_permute_node(Graph &g, NodeParams params, NodeIdxPair input, PermutationVector perm, DataLayout layout)
+{
+ return create_simple_single_input_output_node<PermuteLayerNode>(g, params, input, perm, layout);
+}
+
NodeID GraphBuilder::add_pooling_node(Graph &g, NodeParams params, NodeIdxPair input, PoolingLayerInfo pool_info)
{
return create_simple_single_input_output_node<PoolingLayerNode>(g, params, input, pool_info);
@@ -399,6 +432,12 @@
return create_simple_single_input_output_node<ReshapeLayerNode>(g, params, input, shape);
}
+NodeID GraphBuilder::add_resize_node(Graph &g, NodeParams params, NodeIdxPair input, InterpolationPolicy policy,
+ float width_scale, float height_scale)
+{
+ return create_simple_single_input_output_node<ResizeLayerNode>(g, params, input, policy, width_scale, height_scale);
+}
+
NodeID GraphBuilder::add_scale_layer(Graph &g, const NodeParams ¶ms, NodeIdxPair input, ITensorAccessorUPtr mul_accessor, ITensorAccessorUPtr add_accessor)
{
CHECK_NODEIDX_PAIR(input, g);
@@ -421,9 +460,9 @@
NodeIdxPair add_const_nidxp = { add_const_nid, 0 };
// Create node and connect
- NodeID mul_node = GraphBuilder::add_elementwise_node(g, params, input, mul_const_nidxp, EltwiseOperation::MUL);
+ NodeID mul_node = GraphBuilder::add_elementwise_node(g, params, input, mul_const_nidxp, EltwiseOperation::Mul);
NodeIdxPair mulnode_nidxp = { mul_node, 0 };
- NodeID add_node = GraphBuilder::add_elementwise_node(g, params, mulnode_nidxp, add_const_nidxp, EltwiseOperation::ADD);
+ NodeID add_node = GraphBuilder::add_elementwise_node(g, params, mulnode_nidxp, add_const_nidxp, EltwiseOperation::Add);
return add_node;
}
@@ -438,4 +477,4 @@
return create_simple_single_input_output_node<SplitLayerNode>(g, params, input, num_splits, axis);
}
} // namespace graph
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/graph/GraphContext.cpp b/src/graph/GraphContext.cpp
index 3f31114..5f33ed3 100644
--- a/src/graph/GraphContext.cpp
+++ b/src/graph/GraphContext.cpp
@@ -22,7 +22,9 @@
* SOFTWARE.
*/
#include "arm_compute/graph/GraphContext.h"
-#include <arm_compute/graph.h>
+
+#include "arm_compute/graph.h"
+#include "arm_compute/graph/Utils.h"
namespace arm_compute
{
@@ -33,6 +35,12 @@
{
}
+GraphContext::~GraphContext()
+{
+ _memory_managers.clear();
+ release_default_graph_context(*this);
+}
+
const GraphConfig &GraphContext::config() const
{
return _config;
@@ -82,4 +90,4 @@
}
}
} // namespace graph
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/graph/GraphManager.cpp b/src/graph/GraphManager.cpp
index ad45845..f9d13ac 100644
--- a/src/graph/GraphManager.cpp
+++ b/src/graph/GraphManager.cpp
@@ -27,10 +27,13 @@
#include "arm_compute/graph/GraphContext.h"
#include "arm_compute/graph/Logger.h"
#include "arm_compute/graph/PassManager.h"
+#include "arm_compute/graph/TypePrinter.h"
#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/detail/CrossLayerMemoryManagerHelpers.h"
#include "arm_compute/graph/detail/ExecutionHelpers.h"
+#include "arm_compute/graph/algorithms/TopologicalSort.h"
+
namespace arm_compute
{
namespace graph
@@ -38,7 +41,6 @@
GraphManager::GraphManager()
: _workloads()
{
- detail::default_initialize_backends();
}
void GraphManager::finalize_graph(Graph &graph, GraphContext &ctx, PassManager &pm, Target target)
@@ -53,7 +55,12 @@
}
// Force target to all graph construct
- Target forced_target = is_target_supported(target) ? target : get_default_target();
+ Target forced_target = target;
+ if(!is_target_supported(target))
+ {
+ forced_target = get_default_target();
+ ARM_COMPUTE_LOG_GRAPH_INFO("Switching target from " << target << " to " << forced_target << std::endl);
+ }
force_target_to_graph(graph, forced_target);
// Configure all tensors
@@ -62,22 +69,22 @@
// Apply all mutating passes
pm.run_all(graph);
+ // Perform topological sort
+ std::vector<NodeID> topological_sorted_nodes = dfs(graph);
+
// Validate all nodes
detail::validate_all_nodes(graph);
// Configure all nodes
- auto workload = detail::configure_all_nodes(graph, ctx);
+ auto workload = detail::configure_all_nodes(graph, ctx, topological_sorted_nodes);
ARM_COMPUTE_ERROR_ON_MSG(workload.tasks.empty(), "Could not configure all nodes!");
// Allocate const tensors and call accessors
detail::allocate_const_tensors(graph);
detail::call_all_const_node_accessors(graph);
- if(forced_target == Target::CL)
- {
- // Prepare graph
- detail::prepare_all_tasks(workload);
- }
+ // Prepare graph
+ detail::prepare_all_tasks(workload);
// Setup tensor memory (Allocate all tensors or setup transition manager)
if(ctx.config().use_transition_memory_manager)
@@ -95,15 +102,6 @@
// Register graph
_workloads.insert(std::make_pair(graph.id(), std::move(workload)));
ARM_COMPUTE_LOG_GRAPH_VERBOSE("Created workload for graph with ID : " << graph.id().get() << std::endl);
-
- if(forced_target != Target::CL)
- {
- // Make first run
- execute_graph(graph);
-
- // Release all unused const tensors
- detail::release_unused_tensors(graph);
- }
}
void GraphManager::execute_graph(Graph &graph)
@@ -112,14 +110,23 @@
auto it = _workloads.find(graph.id());
ARM_COMPUTE_ERROR_ON_MSG(it == std::end(_workloads), "Graph is not registered!");
- // Call input accessors
- detail::call_all_input_node_accessors(it->second);
+ while(true)
+ {
+ // Call input accessors
+ if(!detail::call_all_input_node_accessors(it->second))
+ {
+ return;
+ }
- // Run graph
- detail::call_all_tasks(it->second);
+ // Run graph
+ detail::call_all_tasks(it->second);
- // Call output accessors
- detail::call_all_output_node_accessors(it->second);
+ // Call output accessors
+ if(!detail::call_all_output_node_accessors(it->second))
+ {
+ return;
+ }
+ }
}
void GraphManager::invalidate_graph(Graph &graph)
diff --git a/src/graph/INode.cpp b/src/graph/INode.cpp
index cd9a46a..b0c3137 100644
--- a/src/graph/INode.cpp
+++ b/src/graph/INode.cpp
@@ -185,6 +185,11 @@
return _outputs.size();
}
+NodeParams INode::common_node_params() const
+{
+ return _common_params;
+}
+
Target INode::requested_target() const
{
return _common_params.target;
diff --git a/src/graph/Tensor.cpp b/src/graph/Tensor.cpp
index 287e783..9850128 100644
--- a/src/graph/Tensor.cpp
+++ b/src/graph/Tensor.cpp
@@ -67,6 +67,11 @@
return _accessor.get();
}
+std::unique_ptr<ITensorAccessor> Tensor::extract_accessor()
+{
+ return std::move(_accessor);
+}
+
bool Tensor::call_accessor()
{
// Early exit guard
@@ -85,12 +90,12 @@
}
// Call accessor
- _accessor->access_tensor(_handle->tensor());
+ bool retval = _accessor->access_tensor(_handle->tensor());
// Unmap tensor
_handle->unmap();
- return true;
+ return retval;
}
void Tensor::bind_edge(EdgeID eid)
diff --git a/src/graph/TypeLoader.cpp b/src/graph/TypeLoader.cpp
new file mode 100644
index 0000000..30a3546
--- /dev/null
+++ b/src/graph/TypeLoader.cpp
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWNISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph/TypeLoader.h"
+
+#include "arm_compute/core/utils/misc/Utility.h"
+
+#include <map>
+
+namespace arm_compute
+{
+arm_compute::DataType data_type_from_name(const std::string &name)
+{
+ static const std::map<std::string, arm_compute::DataType> data_types =
+ {
+ { "f16", DataType::F16 },
+ { "f32", DataType::F32 },
+ { "qasymm8", DataType::QASYMM8 },
+ };
+
+ try
+ {
+ return data_types.at(arm_compute::utility::tolower(name));
+ }
+ catch(const std::out_of_range &)
+ {
+ throw std::invalid_argument(name);
+ }
+}
+
+arm_compute::DataLayout data_layout_from_name(const std::string &name)
+{
+ static const std::map<std::string, arm_compute::DataLayout> data_layouts =
+ {
+ { "nhwc", DataLayout::NHWC },
+ { "nchw", DataLayout::NCHW },
+ };
+
+ try
+ {
+ return data_layouts.at(arm_compute::utility::tolower(name));
+ }
+ catch(const std::out_of_range &)
+ {
+ throw std::invalid_argument(name);
+ }
+}
+namespace graph
+{
+Target target_from_name(const std::string &name)
+{
+ static const std::map<std::string, Target> targets =
+ {
+ { "neon", Target::NEON },
+ { "cl", Target::CL },
+ { "gles", Target::GC },
+ };
+
+ try
+ {
+ return targets.at(arm_compute::utility::tolower(name));
+ }
+ catch(const std::out_of_range &)
+ {
+ throw std::invalid_argument(name);
+ }
+}
+} // namespace graph
+} // namespace arm_compute
diff --git a/src/graph/Utils.cpp b/src/graph/Utils.cpp
index 030fa2d..0a85a7f 100644
--- a/src/graph/Utils.cpp
+++ b/src/graph/Utils.cpp
@@ -78,22 +78,44 @@
{
+ // Assemble the default pass pipeline for `target`; passes are appended in the order below
PassManager pm;
+ // Passes that mutate graph IR
+ pm.append(support::cpp14::make_unique<GroupedConvolutionMutator>());
+ // Node fusion and in-place rewriting are only appended for non-GC targets
if(target != Target::GC)
{
- pm.append(support::cpp14::make_unique<InPlaceOperationMutator>());
pm.append(support::cpp14::make_unique<NodeFusionMutator>());
- pm.append(support::cpp14::make_unique<SplitLayerSubTensorMutator>());
- pm.append(support::cpp14::make_unique<DepthConcatSubTensorMutator>());
+ pm.append(support::cpp14::make_unique<InPlaceOperationMutator>());
}
+ // Passes that mutate backend information
+ // The depth-concat and split-layer sub-tensor mutators are likewise skipped for GC
+ if(target != Target::GC)
+ {
+ pm.append(support::cpp14::make_unique<DepthConcatSubTensorMutator>());
+ pm.append(support::cpp14::make_unique<SplitLayerSubTensorMutator>());
+ }
+ // Appended last, for every target
+ pm.append(support::cpp14::make_unique<NodeExecutionMethodMutator>());
+
return pm;
}
+/** Calls release_backend_context(ctx) on every registered backend that reports itself as supported */
+void release_default_graph_context(GraphContext &ctx)
+{
+    for(const auto &backend : backends::BackendRegistry::get().backends())
+    {
+        // Same filter as setup_default_graph_context(): unsupported backends are left untouched
+        if(!backend.second->is_backend_supported())
+        {
+            continue;
+        }
+        backend.second->release_backend_context(ctx);
+    }
+}
+
void setup_default_graph_context(GraphContext &ctx)
{
for(const auto &backend : backends::BackendRegistry::get().backends())
{
- backend.second->setup_backend_context(ctx);
+        // Only backends that report themselves as supported get a context set up
+        if(!backend.second->is_backend_supported())
+        {
+            continue;
+        }
+        backend.second->setup_backend_context(ctx);
}
}
@@ -131,5 +153,37 @@
break;
}
}
+
+/** Collects one (consumer_id, consumer_idx) pair per resolvable output edge of @p node */
+std::vector<NodeIdxPair> get_driving_nodes(const INode &node)
+{
+    const Graph *g = node.graph();
+    ARM_COMPUTE_ERROR_ON(g == nullptr);
+
+    std::vector<NodeIdxPair> driving_nodes;
+    for(const auto &eid : node.output_edges())
+    {
+        const auto *edge = g->edge(eid);
+        // Edge ids that do not resolve to an edge are skipped
+        if(edge == nullptr)
+        {
+            continue;
+        }
+        ARM_COMPUTE_ERROR_ON(edge->consumer() == nullptr);
+        driving_nodes.push_back({ edge->consumer_id(), edge->consumer_idx() });
+    }
+    return driving_nodes;
+}
+
+/** Creates a backend handle for @p tensor on its descriptor's target and attaches it, unless the tensor is null or already has one */
+void configure_tensor(Tensor *tensor)
+{
+    // Nothing to do for a missing tensor or one that already owns a handle
+    if(tensor == nullptr || tensor->handle() != nullptr)
+    {
+        return;
+    }
+    backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(tensor->desc().target);
+    std::unique_ptr<ITensorHandle> handle = backend.create_tensor(*tensor);
+    ARM_COMPUTE_ERROR_ON_MSG(!handle, "Couldn't create backend handle!");
+    tensor->set_handle(std::move(handle));
+}
} // namespace graph
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/graph/algorithms/TopologicalSort.cpp b/src/graph/algorithms/TopologicalSort.cpp
new file mode 100644
index 0000000..0fbf6e3
--- /dev/null
+++ b/src/graph/algorithms/TopologicalSort.cpp
@@ -0,0 +1,188 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph/algorithms/TopologicalSort.h"
+
+#include "arm_compute/graph/Graph.h"
+
+#include "arm_compute/core/utils/misc/Iterable.h"
+
+#include <list>
+#include <stack>
+
+namespace arm_compute
+{
+namespace graph
+{
+namespace detail
+{
+/** Checks if all the input dependencies of a node have been visited
+ *
+ * Input slots that hold EmptyEdgeID (i.e. no edge) impose no dependency.
+ *
+ * @param[in] node    Node to check
+ * @param[in] visited Visited flags indexed by NodeID
+ *
+ * @return True if the producers of all connected inputs have been visited else false
+ */
+inline bool all_inputs_are_visited(const INode *node, const std::vector<bool> &visited)
+{
+    ARM_COMPUTE_ERROR_ON(node == nullptr);
+    const Graph *graph = node->graph();
+    ARM_COMPUTE_ERROR_ON(graph == nullptr);
+
+    for(const auto &input_edge_id : node->input_edges())
+    {
+        // These are edge ids, so the sentinel to compare against is EmptyEdgeID (not EmptyNodeID)
+        if(input_edge_id == EmptyEdgeID)
+        {
+            continue;
+        }
+        const Edge *input_edge = graph->edge(input_edge_id);
+        ARM_COMPUTE_ERROR_ON(input_edge == nullptr);
+        ARM_COMPUTE_ERROR_ON(input_edge->producer() == nullptr);
+        if(!visited[input_edge->producer_id()])
+        {
+            return false;
+        }
+    }
+    return true;
+}
+} // namespace detail
+
+/** Breadth-first traversal seeded with all Input nodes followed by all Const nodes.
+ *
+ * A consumer is enqueued only once every one of its producers has been enqueued, so the
+ * returned order lists each node after all of its producers.
+ */
+std::vector<NodeID> bfs(Graph &g)
+{
+    // Doubles as the FIFO work list: entries before `head` have been processed, entries from
+    // `head` onwards are still waiting. In BFS the processing order equals the enqueue order.
+    std::vector<NodeID> bfs_order_vector;
+
+    // visited[id] becomes true as soon as node id is enqueued
+    std::vector<bool> visited(g.nodes().size(), false);
+
+    const auto enqueue_seeds = [&](NodeType type)
+    {
+        for(auto &seed : g.nodes(type))
+        {
+            if(seed != EmptyNodeID)
+            {
+                visited[seed] = true;
+                bfs_order_vector.push_back(seed);
+            }
+        }
+    };
+    enqueue_seeds(NodeType::Input);
+    enqueue_seeds(NodeType::Const);
+
+    for(size_t head = 0; head < bfs_order_vector.size(); ++head)
+    {
+        const NodeID n    = bfs_order_vector[head];
+        const INode *node = g.node(n);
+        ARM_COMPUTE_ERROR_ON(node == nullptr);
+        for(const auto &eid : node->output_edges())
+        {
+            const Edge *e = g.edge(eid);
+            ARM_COMPUTE_ERROR_ON(e == nullptr);
+            if(!visited[e->consumer_id()] && detail::all_inputs_are_visited(e->consumer(), visited))
+            {
+                visited[e->consumer_id()] = true;
+                bfs_order_vector.push_back(e->consumer_id());
+            }
+        }
+    }
+
+    return bfs_order_vector;
+}
+
+/** Depth-first topological traversal seeded with all Input and Const nodes.
+ *
+ * A node is marked visited only when it is popped and emitted, and a consumer is pushed only
+ * once all of its producers have been emitted, so every reachable node appears exactly once
+ * and after all of its producers.
+ */
+std::vector<NodeID> dfs(Graph &g)
+{
+    std::vector<NodeID> dfs_order_vector;
+
+    // visited[id] becomes true when node id is emitted (not when it is pushed)
+    std::vector<bool> visited(g.nodes().size(), false);
+
+    // Create DFS stack
+    std::stack<NodeID> stack;
+
+    // Push inputs. They are deliberately not marked as visited yet: marking seeds on push let a
+    // consumer of several seeds be emitted before seeds that were still waiting on the stack.
+    for(auto &input : g.nodes(NodeType::Input))
+    {
+        if(input != EmptyNodeID)
+        {
+            stack.push(input);
+        }
+    }
+
+    // Push const nodes (also left unvisited until they are emitted)
+    for(auto &const_node : g.nodes(NodeType::Const))
+    {
+        if(const_node != EmptyNodeID)
+        {
+            stack.push(const_node);
+        }
+    }
+
+    // Iterate over vector and edges
+    while(!stack.empty())
+    {
+        // Pop a node from stack
+        const NodeID n = stack.top();
+        stack.pop();
+
+        // A node can be pushed more than once, e.g. when it consumes the same producer through
+        // several edges: emit it only the first time it is popped
+        if(visited[n])
+        {
+            continue;
+        }
+        visited[n] = true;
+        dfs_order_vector.push_back(n);
+
+        const INode *node = g.node(n);
+        ARM_COMPUTE_ERROR_ON(node == nullptr);
+        // Reverse iterate to push branches from right to left and pop on the opposite order
+        for(const auto &eid : arm_compute::utils::iterable::reverse_iterate(node->output_edges()))
+        {
+            const Edge *e = g.edge(eid);
+            ARM_COMPUTE_ERROR_ON(e == nullptr);
+            if(!visited[e->consumer_id()] && detail::all_inputs_are_visited(e->consumer(), visited))
+            {
+                stack.push(e->consumer_id());
+            }
+        }
+    }
+
+    return dfs_order_vector;
+}
+} // namespace graph
+} // namespace arm_compute
\ No newline at end of file
diff --git a/src/graph/backends/BackendRegistry.cpp b/src/graph/backends/BackendRegistry.cpp
index 2803322..dccfefc 100644
--- a/src/graph/backends/BackendRegistry.cpp
+++ b/src/graph/backends/BackendRegistry.cpp
@@ -48,6 +48,14 @@
return _registered_backends[target].get();
}
+/** Returns the registered backend for @p target; raises an error if none is registered */
+IDeviceBackend &BackendRegistry::get_backend(Target target)
+{
+    IDeviceBackend *backend = find_backend(target);
+    // ARM_COMPUTE_ERROR_ON_MSG is only active when asserts are enabled; a missing backend must still
+    // fail explicitly in release builds instead of dereferencing a null pointer below
+    if(backend == nullptr)
+    {
+        ARM_COMPUTE_ERROR("Requested backend doesn't exist!");
+    }
+    ARM_COMPUTE_ERROR_ON_MSG(!backend->is_backend_supported(), "Requested backend isn't supported");
+    return *backend;
+}
+
bool BackendRegistry::contains(Target target) const
{
auto it = _registered_backends.find(target);
diff --git a/src/graph/backends/CL/CLDeviceBackend.cpp b/src/graph/backends/CL/CLDeviceBackend.cpp
index bf17f80..1dbeae9 100644
--- a/src/graph/backends/CL/CLDeviceBackend.cpp
+++ b/src/graph/backends/CL/CLDeviceBackend.cpp
@@ -62,19 +62,16 @@
/** Register CL backend */
static detail::BackendRegistrar<CLDeviceBackend> CLDeviceBackend_registrar(Target::CL);
-/** Tuner export file */
-static const std::string tuner_data_filename = "acl_tuner.csv";
-
CLDeviceBackend::CLDeviceBackend()
- : _tuner(), _allocator(cl::Context::getDefault())
+ : _context_count(0), _tuner(), _allocator(nullptr), _tuner_file()
{
+ // No context uses the backend yet: the allocator is only created by initialize_backend(),
+ // which setup_backend_context() calls once the first context is set up
}
CLDeviceBackend::~CLDeviceBackend()
{
- if(_tuner.tune_new_kernels() && !_tuner.lws_table().empty())
+ // Persist the LWS table only when tuning of new kernels is enabled, the table is non-empty
+ // and a context supplied a tuner file name (see setup_backend_context)
+ if(_tuner.tune_new_kernels() && !_tuner.lws_table().empty() && !_tuner_file.empty())
{
- _tuner.save_to_file(tuner_data_filename);
+ _tuner.save_to_file(_tuner_file);
}
}
@@ -85,22 +82,40 @@
void CLDeviceBackend::initialize_backend()
{
- // Load tuner data if available
- if(_tuner.lws_table().empty() && file_exists(tuner_data_filename))
- {
- _tuner.load_from_file(tuner_data_filename);
- }
-
+ // Invoked from setup_backend_context() when the first context starts using the backend;
+ // tuner data is loaded there from the context's tuner file
// Setup Scheduler
CLScheduler::get().default_init(&_tuner);
// Create allocator with new context
- _allocator = CLBufferAllocator();
+ // Held through a unique_ptr so that release_backend_context() can free it when the last context goes away
+ _allocator = support::cpp14::make_unique<CLBufferAllocator>();
+}
+
+/** Drops one context reference to the backend; the CL allocator is freed when the last context is released */
+void CLDeviceBackend::release_backend_context(GraphContext &ctx)
+{
+    ARM_COMPUTE_UNUSED(ctx);
+    // Ignore releases that have no matching setup_backend_context() call: letting the counter drop
+    // below zero would make the next setup skip initialize_backend() and leave _allocator null
+    if(_context_count <= 0)
+    {
+        return;
+    }
+    _context_count--;
+    if(_context_count == 0) // No more context using the backend: free resources
+    {
+        _allocator = nullptr;
+    }
}
void CLDeviceBackend::setup_backend_context(GraphContext &ctx)
{
+ // Force backend initialization
+ _context_count++;
+ if(_context_count == 1)
+ {
+ initialize_backend();
+ }
+
// Setup tuner
+ _tuner_file = ctx.config().tuner_file;
+ // Load tuner data if available
+ if(file_exists(_tuner_file))
+ {
+ _tuner.load_from_file(_tuner_file);
+ }
+
set_kernel_tuning(ctx.config().use_tuner);
// Setup a management backend
@@ -123,7 +138,7 @@
IAllocator *CLDeviceBackend::backend_allocator()
{
- return &_allocator;
+ // nullptr before initialize_backend() has run and again after the last context has been released
+ return _allocator.get();
}
std::unique_ptr<ITensorHandle> CLDeviceBackend::create_tensor(const Tensor &tensor)
@@ -179,7 +194,7 @@
auto pool_mgr = std::make_shared<PoolManager>();
auto mm = std::make_shared<MemoryManagerOnDemand>(lifetime_mgr, pool_mgr);
- mm->set_allocator(&_allocator);
+ mm->set_allocator(_allocator.get());
return mm;
}
diff --git a/src/graph/backends/CL/CLFunctionsFactory.cpp b/src/graph/backends/CL/CLFunctionsFactory.cpp
index db8a7a0..bf3dcba 100644
--- a/src/graph/backends/CL/CLFunctionsFactory.cpp
+++ b/src/graph/backends/CL/CLFunctionsFactory.cpp
@@ -25,16 +25,9 @@
#include "arm_compute/core/utils/misc/Cast.h"
#include "arm_compute/graph/Graph.h"
-#include "arm_compute/graph/GraphContext.h"
-#include "arm_compute/graph/Logger.h"
-#include "arm_compute/graph/TypePrinter.h"
-#include "arm_compute/graph/Types.h"
-#include "arm_compute/graph/backends/Utils.h"
-#include "arm_compute/graph/nodes/Nodes.h"
+#include "arm_compute/graph/backends/FunctionHelpers.h"
#include "arm_compute/runtime/CL/CLFunctions.h"
-#include "support/ToolchainSupport.h"
-
using namespace arm_compute::utils::cast;
namespace arm_compute
@@ -43,526 +36,38 @@
{
namespace backends
{
-namespace
+/** Target specific information structure used to pass information to the layer templates */
+struct CLTargetInfo
{
-/** Returns backing tensor of a given tensor
- *
- * @param[in] tensor Tensor to extract the backing tensor from
- *
- * @return Backing tensor if present else nullptr
- */
-arm_compute::ICLTensor *get_backing_tensor(arm_compute::graph::Tensor *tensor)
+ using TensorType = arm_compute::ICLTensor; /**< CL tensor interface used for backing tensors */
+ static Target TargetType; /**< Graph target served by this factory (defined as Target::CL below) */
+};
+
+// NOTE(review): TargetType is a mutable static; consider making it const if no helper template writes to it
+Target CLTargetInfo::TargetType = Target::CL;
+
+/** Collection of CL convolution functions, passed to detail::create_convolution_layer (see CLFunctionFactory::create) */
+struct CLConvolutionLayerFunctions
{
- arm_compute::ICLTensor *backing_tensor = nullptr;
- if(tensor != nullptr)
- {
- ARM_COMPUTE_ERROR_ON(tensor->desc().target != arm_compute::graph::Target::CL);
- // Get backing tensor handle
- ITensorHandle *tensor_handle = tensor->handle();
- // Get backing tensor
- backing_tensor = (tensor_handle != nullptr) ? polymorphic_cast<ICLTensor *>(&tensor_handle->tensor()) : nullptr;
- }
+ using GenericConvolutionLayer = CLConvolutionLayer; /**< Generic convolution function */
+ using GEMMConvolutionLayer = CLGEMMConvolutionLayer; /**< GEMM-based convolution */
+ using DirectConvolutionLayer = CLDirectConvolutionLayer; /**< Direct convolution */
+ using WinogradConvolutionLayer = CLWinogradConvolutionLayer; /**< Winograd convolution */
+};
- return backing_tensor;
-}
-
-/** Create a backend activation layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend activation layer function
- */
-std::unique_ptr<IFunction> create_activation_layer(ActivationLayerNode &node)
+/** Collection of CL depthwise convolution functions, passed to detail::create_depthwise_convolution_layer (see CLFunctionFactory::create) */
+struct CLDepthwiseConvolutionLayerFunctions
{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating CL ActivationLayerNode node with ID : " << node.id() << " and Name: " << node.name()
- << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
+ using GenericDepthwiseConvolutionLayer = CLDepthwiseConvolutionLayer; /**< Generic depthwise convolution */
+ using DepthwiseConvolutionLayer3x3 = CLDepthwiseConvolutionLayer3x3; /**< Depthwise convolution specialised for 3x3 kernels */
+};
- // Extract IO and info
- ICLTensor *input = get_backing_tensor(node.input(0));
- ICLTensor *output = get_backing_tensor(node.output(0));
- const ActivationLayerInfo act_info = node.activation_info();
-
- // Create function
- auto func = support::cpp14::make_unique<CLActivationLayer>();
- func->configure(input, output, act_info);
-
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLActivationLayer"
- << " Data Type: " << input->info()->data_type()
- << " Shape: " << input->info()->tensor_shape()
- << " Activation function: " << act_info.activation()
- << " a: " << act_info.a()
- << " b: " << act_info.b()
- << " InPlace : " << is_in_place_operation(input, output)
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend batch normalization layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend batch normalization layer function
- */
-std::unique_ptr<IFunction> create_batch_normalization_layer(BatchNormalizationLayerNode &node)
+/** Collection of CL element-wise functions, passed to detail::create_eltwise_layer (see CLFunctionFactory::create) */
+struct CLEltwiseFunctions
{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating CL BatchNormalization node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
-
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 5);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ICLTensor *input = get_backing_tensor(node.input(0));
- ICLTensor *mean = get_backing_tensor(node.input(1));
- ICLTensor *var = get_backing_tensor(node.input(2));
- ICLTensor *beta = get_backing_tensor(node.input(3));
- ICLTensor *gamma = get_backing_tensor(node.input(4));
- ICLTensor *output = get_backing_tensor(node.output(0));
- const float epsilon = node.epsilon();
- const ActivationLayerInfo fused_act = node.fused_activation();
-
- // Create and configure function
- auto func = support::cpp14::make_unique<CLBatchNormalizationLayer>();
- func->configure(input, output, mean, var, beta, gamma, epsilon, fused_act);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLBatchNormalizationLayer"
- << " Data Type: " << input->info()->data_type()
- << " Shape: " << input->info()->tensor_shape()
- << " Epsilon: " << epsilon << " "
- << (fused_act.enabled() ? to_string(fused_act.activation()) : "")
- << " InPlace : " << is_in_place_operation(input, output)
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend convolution layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend convolution layer function
- */
-std::unique_ptr<IFunction> create_convolution_layer(ConvolutionLayerNode &node, GraphContext &ctx)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating CL ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ICLTensor *input = get_backing_tensor(node.input(0));
- ICLTensor *weights = get_backing_tensor(node.input(1));
- ICLTensor *biases = get_backing_tensor(node.input(2));
- ICLTensor *output = get_backing_tensor(node.output(0));
-
- if(is_data_type_quantized_asymmetric(input->info()->data_type()))
- {
- biases->info()->set_data_type(DataType::S32);
- }
-
- const PadStrideInfo conv_info = node.convolution_info();
- const ConvolutionMethod conv_algorithm = node.convolution_method();
- const bool fast_math = node.fast_math_hint() == FastMathHint::ENABLED;
-
- // Create and configure function (we assume that functions have been validated before creation)
- std::shared_ptr<IMemoryManager> mm = get_memory_manager(ctx, Target::CL);
- std::unique_ptr<IFunction> func;
- std::string func_name;
-
- if(conv_algorithm == ConvolutionMethod::WINOGRAD)
- {
- std::tie(func, func_name) = create_named_memory_managed_function<CLWinogradConvolutionLayer>(
- std::string("CLWinogradConvolutionLayer"), mm, input, weights, biases, output, conv_info, ActivationLayerInfo(), fast_math);
- }
- else if(conv_algorithm == ConvolutionMethod::DIRECT)
- {
- std::tie(func, func_name) = create_named_function<CLDirectConvolutionLayer>(
- std::string("CLDirectConvolutionLayer"), input, weights, biases, output, conv_info);
- }
- else if(conv_algorithm == ConvolutionMethod::GEMM)
- {
- std::tie(func, func_name) = create_named_memory_managed_function<CLGEMMConvolutionLayer>(std::string("CLGEMMConvolutionLayer"), mm,
- input, weights, biases, output, conv_info);
- }
- else
- {
- std::tie(func, func_name) = create_named_memory_managed_function<CLConvolutionLayer>(std::string("CLConvolutionLayer"), mm,
- input, weights, biases, output, conv_info, WeightsInfo(), Size2D(1U, 1U), ActivationLayerInfo(), fast_math);
- }
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name
- << " Data Type: " << input->info()->data_type()
- << " Input QuantInfo: " << input->info()->quantization_info()
- << " Weights QuantInfo: " << weights->info()->quantization_info()
- << " Input shape: " << input->info()->tensor_shape()
- << " Weights shape: " << weights->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << std::endl);
- return func;
-}
-
-/** Create a backend layer depth concatenate function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend depth concatenate layer function
- */
-std::unique_ptr<arm_compute::IFunction> create_depth_concatenate_layer(DepthConcatenateLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating CL DepthConcatenate node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Return nullptr if depth concatenate is switched off
- if(!node.is_enabled())
- {
- return nullptr;
- }
-
- // Extract IO and info
- std::vector<arm_compute::ICLTensor *> inputs;
- for(unsigned int i = 0; i < node.num_inputs(); ++i)
- {
- inputs.push_back(get_backing_tensor(node.input(i)));
- }
- ICLTensor *output = get_backing_tensor(node.output(0));
-
- // Create and configure function
- auto func = support::cpp14::make_unique<CLDepthConcatenateLayer>();
- func->configure(inputs, output);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLDepthConcatenateLayer"
- << " Data Type: " << output->info()->data_type()
- << " Shape: " << output->info()->tensor_shape()
- << " Num Inputs: " << inputs.size()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend layer depth-wise convolution function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend depth-wise convolution layer function
- */
-std::unique_ptr<IFunction> create_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating CL DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name()
- << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ICLTensor *input = get_backing_tensor(node.input(0));
- ICLTensor *weights = get_backing_tensor(node.input(1));
- ICLTensor *biases = get_backing_tensor(node.input(2));
- ICLTensor *output = get_backing_tensor(node.output(0));
-
- if(is_data_type_quantized_asymmetric(input->info()->data_type()))
- {
- biases->info()->set_data_type(DataType::S32);
- }
-
- const PadStrideInfo conv_info = node.convolution_info();
- const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method();
-
- // Create and configure function (we assume that functions have been validated before creation)
- std::unique_ptr<IFunction> func;
- std::string func_name;
- if(dwc_algorithm == DepthwiseConvolutionMethod::OPTIMIZED_3x3)
- {
- std::tie(func, func_name) = create_named_function<CLDepthwiseConvolutionLayer3x3>(
- std::string("CLDepthwiseConvolutionLayer3x3"), input, weights, biases, output, conv_info);
- }
- else
- {
- std::tie(func, func_name) = create_named_function<CLDepthwiseConvolutionLayer>(
- std::string("CLDepthwiseConvolutionLayer"), input, weights, biases, output, conv_info);
- }
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name
- << " Data Type: " << input->info()->data_type()
- << " Input QuantInfo: " << input->info()->quantization_info()
- << " Weights QuantInfo: " << weights->info()->quantization_info()
- << " Input shape: " << input->info()->tensor_shape()
- << " Weights shape: " << weights->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << std::endl);
- return func;
-}
-
-/** Create a backend element-wise operation layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend element-wise operation layer function
- */
-std::unique_ptr<IFunction> create_eltwise_layer(EltwiseLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating CL EltwiseLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 2);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ICLTensor *input1 = get_backing_tensor(node.input(0));
- ICLTensor *input2 = get_backing_tensor(node.input(1));
- ICLTensor *output = get_backing_tensor(node.output(0));
- const EltwiseOperation eltwise_op = node.eltwise_operation();
- const ConvertPolicy convert_policy = node.convert_policy();
- ARM_COMPUTE_ERROR_ON(input1 == nullptr);
- ARM_COMPUTE_ERROR_ON(input2 == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- std::unique_ptr<IFunction> func = nullptr;
- std::string func_name;
- if(eltwise_op == EltwiseOperation::ADD)
- {
- std::tie(func, func_name) = create_named_function<CLArithmeticAddition>(std::string("CLArithmeticAddition"),
- input1, input2, output,
- convert_policy);
- }
- else if(eltwise_op == EltwiseOperation::SUB)
- {
- std::tie(func, func_name) = create_named_function<CLArithmeticSubtraction>(
- std::string("CLArithmeticSubtraction"), input1, input2, output, convert_policy);
- }
- else if(eltwise_op == EltwiseOperation::MUL)
- {
- std::tie(func, func_name) = create_named_function<CLPixelWiseMultiplication>(
- std::string("CLPixelWiseMultiplication"), input1, input2, output, 1.f, convert_policy,
- node.rounding_policy());
- }
- else
- {
- ARM_COMPUTE_ERROR("Unsupported element-wise operation!");
- }
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name
- << " Data Type: " << input1->info()->data_type()
- << " Shape : " << input1->info()->tensor_shape()
- << std::endl);
-
- return func;
-}
-
-/** Create a backend flatten layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend flatten layer function
- */
-std::unique_ptr<IFunction> create_flatten_layer(FlattenLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating CL FlattenLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ICLTensor *input = get_backing_tensor(node.input(0));
- ICLTensor *output = get_backing_tensor(node.output(0));
-
- // Create and configure function
- auto func = support::cpp14::make_unique<CLFlattenLayer>();
- func->configure(input, output);
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLFlattenLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend fully connected layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend fully connected layer function
- */
-std::unique_ptr<IFunction> create_fully_connected_layer(FullyConnectedLayerNode &node, GraphContext &ctx)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating CL FullyConnectedLayer node with ID : " << node.id() << " and Name: " << node.name()
- << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ICLTensor *input = get_backing_tensor(node.input(0));
- ICLTensor *weights = get_backing_tensor(node.input(1));
- ICLTensor *biases = get_backing_tensor(node.input(2));
- ICLTensor *output = get_backing_tensor(node.output(0));
-
- // Create and configure function
- auto func = support::cpp14::make_unique<CLFullyConnectedLayer>(get_memory_manager(ctx, Target::CL));
- func->configure(input, weights, biases, output);
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(weights == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLFullyConnectedLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Weights shape: " << weights->info()->tensor_shape()
- << " Biases Shape: " << biases->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend normalization layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend normalization layer function
- */
-std::unique_ptr<IFunction> create_normalization_layer(NormalizationLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating CL NormalizationLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ICLTensor *input = get_backing_tensor(node.input(0));
- ICLTensor *output = get_backing_tensor(node.output(0));
- const NormalizationLayerInfo norm_info = node.normalization_info();
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Create and configure function
- auto func = support::cpp14::make_unique<CLNormalizationLayer>();
- func->configure(input, output, norm_info);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLNormalizationLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << " Normalization info: " << norm_info.type()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend pooling layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend pooling layer function
- */
-std::unique_ptr<IFunction> create_pooling_layer(PoolingLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating CL PoolingLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ICLTensor *input = get_backing_tensor(node.input(0));
- ICLTensor *output = get_backing_tensor(node.output(0));
- const PoolingLayerInfo pool_info = node.pooling_info();
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Create and configure function
- auto func = support::cpp14::make_unique<CLPoolingLayer>();
- func->configure(input, output, pool_info);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLPoolingLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << " Pooling info: " << pool_info.pool_type()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend reshape layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend reshape layer function
- */
-std::unique_ptr<IFunction> create_reshape_layer(ReshapeLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating CL ReshapeLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ICLTensor *input = get_backing_tensor(node.input(0));
- ICLTensor *output = get_backing_tensor(node.output(0));
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Create and configure function
- auto func = support::cpp14::make_unique<CLReshapeLayer>();
- func->configure(input, output);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLReshapeLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend softmax layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend softmax layer function
- */
-std::unique_ptr<IFunction> create_softmax_layer(SoftmaxLayerNode &node, GraphContext &ctx)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating CL SoftmaxLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ICLTensor *input = get_backing_tensor(node.input(0));
- ICLTensor *output = get_backing_tensor(node.output(0));
- const float beta = node.beta();
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Create and configure function
- auto func = support::cpp14::make_unique<CLSoftmaxLayer>(get_memory_manager(ctx, Target::CL));
- func->configure(input, output, beta);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLSoftmaxLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << std::endl);
-
- return std::move(func);
-}
-} // namespace
+ using Addition = CLArithmeticAddition; /**< Element-wise addition */
+ using Subtraction = CLArithmeticSubtraction; /**< Element-wise subtraction */
+ using Multiplication = CLPixelWiseMultiplication; /**< Element-wise (pixel-wise) multiplication */
+};
std::unique_ptr<IFunction> CLFunctionFactory::create(INode *node, GraphContext &ctx)
{
@@ -575,33 +80,41 @@
switch(type)
{
case NodeType::ActivationLayer:
- return create_activation_layer(*polymorphic_downcast<ActivationLayerNode *>(node));
+ return detail::create_activation_layer<CLActivationLayer, CLTargetInfo>(*polymorphic_downcast<ActivationLayerNode *>(node));
case NodeType::BatchNormalizationLayer:
- return create_batch_normalization_layer(*polymorphic_downcast<BatchNormalizationLayerNode *>(node));
+ return detail::create_batch_normalization_layer<CLBatchNormalizationLayer, CLTargetInfo>(*polymorphic_downcast<BatchNormalizationLayerNode *>(node));
+ case NodeType::ChannelShuffleLayer:
+ return detail::create_channel_shuffle_layer<CLChannelShuffleLayer, CLTargetInfo>(*polymorphic_downcast<ChannelShuffleLayerNode *>(node));
case NodeType::ConvolutionLayer:
- return create_convolution_layer(*polymorphic_downcast<ConvolutionLayerNode *>(node), ctx);
- case NodeType::DepthConcatenateLayer:
- return create_depth_concatenate_layer(*polymorphic_downcast<DepthConcatenateLayerNode *>(node));
+ return detail::create_convolution_layer<CLConvolutionLayerFunctions, CLTargetInfo>(*polymorphic_downcast<ConvolutionLayerNode *>(node), ctx);
+ case NodeType::DeconvolutionLayer:
+ return detail::create_deconvolution_layer<CLDeconvolutionLayer, CLTargetInfo>(*polymorphic_downcast<DeconvolutionLayerNode *>(node), ctx);
+ case NodeType::ConcatenateLayer:
+ return detail::create_concatenate_layer<CLConcatenateLayer, CLTargetInfo>(*polymorphic_downcast<ConcatenateLayerNode *>(node));
case NodeType::DepthwiseConvolutionLayer:
- return create_depthwise_convolution_layer(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
+ return detail::create_depthwise_convolution_layer<CLDepthwiseConvolutionLayerFunctions, CLTargetInfo>(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
case NodeType::EltwiseLayer:
- return create_eltwise_layer(*polymorphic_downcast<EltwiseLayerNode *>(node));
+ return detail::create_eltwise_layer<CLEltwiseFunctions, CLTargetInfo>(*polymorphic_downcast<EltwiseLayerNode *>(node));
case NodeType::FlattenLayer:
- return create_flatten_layer(*polymorphic_downcast<FlattenLayerNode *>(node));
+ return detail::create_flatten_layer<CLFlattenLayer, CLTargetInfo>(*polymorphic_downcast<FlattenLayerNode *>(node));
case NodeType::FullyConnectedLayer:
- return create_fully_connected_layer(*polymorphic_downcast<FullyConnectedLayerNode *>(node), ctx);
+ return detail::create_fully_connected_layer<CLFullyConnectedLayer, CLTargetInfo>(*polymorphic_downcast<FullyConnectedLayerNode *>(node), ctx);
case NodeType::NormalizationLayer:
- return create_normalization_layer(*polymorphic_downcast<NormalizationLayerNode *>(node));
+ return detail::create_normalization_layer<CLNormalizationLayer, CLTargetInfo>(*polymorphic_downcast<NormalizationLayerNode *>(node), ctx);
+ case NodeType::PermuteLayer:
+ return detail::create_permute_layer<CLPermute, CLTargetInfo>(*polymorphic_downcast<PermuteLayerNode *>(node));
case NodeType::PoolingLayer:
- return create_pooling_layer(*polymorphic_downcast<PoolingLayerNode *>(node));
+ return detail::create_pooling_layer<CLPoolingLayer, CLTargetInfo>(*polymorphic_downcast<PoolingLayerNode *>(node));
case NodeType::ReshapeLayer:
- return create_reshape_layer(*polymorphic_downcast<ReshapeLayerNode *>(node));
+ return detail::create_reshape_layer<CLReshapeLayer, CLTargetInfo>(*polymorphic_downcast<ReshapeLayerNode *>(node));
+ case NodeType::ResizeLayer:
+ return detail::create_resize_layer<CLScale, CLTargetInfo>(*polymorphic_downcast<ResizeLayerNode *>(node));
case NodeType::SoftmaxLayer:
- return create_softmax_layer(*polymorphic_downcast<SoftmaxLayerNode *>(node), ctx);
+ return detail::create_softmax_layer<CLSoftmaxLayer, CLTargetInfo>(*polymorphic_downcast<SoftmaxLayerNode *>(node), ctx);
default:
return nullptr;
}
}
} // namespace backends
} // namespace graph
-} // namespace arm_compute
+} // namespace arm_compute
\ No newline at end of file
diff --git a/src/graph/backends/CL/CLNodeValidator.cpp b/src/graph/backends/CL/CLNodeValidator.cpp
index c16b2e6..ba5b59d 100644
--- a/src/graph/backends/CL/CLNodeValidator.cpp
+++ b/src/graph/backends/CL/CLNodeValidator.cpp
@@ -47,6 +47,8 @@
NodeType type = node->type();
switch(type)
{
+ case NodeType::ChannelShuffleLayer:
+ return detail::validate_channel_shuffle_layer<CLChannelShuffleLayer>(*polymorphic_downcast<ChannelShuffleLayerNode *>(node));
case NodeType::ConvolutionLayer:
return detail::validate_convolution_layer<CLConvolutionLayer,
CLDirectConvolutionLayer,
@@ -55,6 +57,8 @@
case NodeType::DepthwiseConvolutionLayer:
return detail::validate_depthwise_convolution_layer<CLDepthwiseConvolutionLayer,
CLDepthwiseConvolutionLayer3x3>(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
+ case NodeType::PermuteLayer:
+ return detail::validate_permute_layer<CLPermute>(*polymorphic_downcast<PermuteLayerNode *>(node));
default:
return Status{};
}
diff --git a/src/graph/backends/GLES/GCDeviceBackend.cpp b/src/graph/backends/GLES/GCDeviceBackend.cpp
index 770cca5..ec3cf4f 100644
--- a/src/graph/backends/GLES/GCDeviceBackend.cpp
+++ b/src/graph/backends/GLES/GCDeviceBackend.cpp
@@ -53,7 +53,7 @@
static detail::BackendRegistrar<GCDeviceBackend> GCDeviceBackend_registrar(Target::GC);
GCDeviceBackend::GCDeviceBackend()
- : _allocator()
+ : _initialized(false), _allocator()
{
}
@@ -63,8 +63,21 @@
GCScheduler::get().default_init();
}
+void GCDeviceBackend::release_backend_context(GraphContext &ctx)
+{
+ //Nothing to do
+ ARM_COMPUTE_UNUSED(ctx);
+}
+
void GCDeviceBackend::setup_backend_context(GraphContext &ctx)
{
+ // Force backend initialization
+ if(!_initialized)
+ {
+ initialize_backend();
+ _initialized = true;
+ }
+
// Setup a management backend
if(ctx.memory_management_ctx(Target::GC) == nullptr)
{
@@ -144,4 +157,4 @@
}
} // namespace backends
} // namespace graph
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/graph/backends/GLES/GCFunctionsFactory.cpp b/src/graph/backends/GLES/GCFunctionsFactory.cpp
index e61e840..f72513c 100644
--- a/src/graph/backends/GLES/GCFunctionsFactory.cpp
+++ b/src/graph/backends/GLES/GCFunctionsFactory.cpp
@@ -25,16 +25,9 @@
#include "arm_compute/core/utils/misc/Cast.h"
#include "arm_compute/graph/Graph.h"
-#include "arm_compute/graph/GraphContext.h"
-#include "arm_compute/graph/Logger.h"
-#include "arm_compute/graph/TypePrinter.h"
-#include "arm_compute/graph/Types.h"
-#include "arm_compute/graph/backends/Utils.h"
-#include "arm_compute/graph/nodes/Nodes.h"
+#include "arm_compute/graph/backends/FunctionHelpers.h"
#include "arm_compute/runtime/GLES_COMPUTE/GCFunctions.h"
-#include "support/ToolchainSupport.h"
-
using namespace arm_compute::utils::cast;
namespace arm_compute
@@ -43,120 +36,84 @@
{
namespace backends
{
-namespace
+/** Target specific information structure used to pass information to the layer templates */
+struct GCTargetInfo
{
-/** Returns backing tensor of a given tensor
- *
- * @param[in] tensor Tensor to extract the backing tensor from
- *
- * @return Backing tensor if present else nullptr
- */
-arm_compute::IGCTensor *get_backing_tensor(arm_compute::graph::Tensor *tensor)
+ using TensorType = arm_compute::IGCTensor;
+ static Target TargetType;
+};
+
+Target GCTargetInfo::TargetType = Target::GC;
+
+/** Collection of GC convolution functions */
+struct GCConvolutionLayerFunctions
{
- arm_compute::IGCTensor *backing_tensor = nullptr;
- if(tensor != nullptr)
+ using GenericConvolutionLayer = GCConvolutionLayer;
+ using GEMMConvolutionLayer = GCConvolutionLayer;
+ using DirectConvolutionLayer = GCDirectConvolutionLayer;
+};
+
+/** Collection of GC depthwise convolution functions */
+struct GCDepthwiseConvolutionLayerFunctions
+{
+ using DepthwiseConvolutionLayer3x3 = GCDepthwiseConvolutionLayer3x3;
+};
+
+/** Collection of GC element-wise functions */
+struct GCEltwiseFunctions
+{
+ using Addition = GCArithmeticAddition;
+ using Multiplication = GCPixelWiseMultiplication;
+};
+
+namespace detail
+{
+// Specialize functions
+template <>
+std::unique_ptr<IFunction> create_concatenate_layer<GCDepthConcatenateLayer, GCTargetInfo>(ConcatenateLayerNode &node)
+{
+ ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating Concatenate node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
+ ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
+
+ // Return nullptr if depth concatenate is switched off
+ if(!node.is_enabled())
{
- ARM_COMPUTE_ERROR_ON(tensor->desc().target != arm_compute::graph::Target::GC);
- // Get backing tensor handle
- ITensorHandle *tensor_handle = tensor->handle();
- // Get backing tensor
- backing_tensor = (tensor_handle != nullptr) ? polymorphic_cast<IGCTensor *>(&tensor_handle->tensor()) : nullptr;
+ return nullptr;
}
- return backing_tensor;
-}
-
-/** Create a backend activation layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend activation layer function
- */
-std::unique_ptr<IFunction> create_activation_layer(ActivationLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating GC ActivationLayerNode node with ID : " << node.id() << " and Name: " << node.name()
- << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
// Extract IO and info
- IGCTensor *input = get_backing_tensor(node.input(0));
- IGCTensor *output = get_backing_tensor(node.output(0));
- const ActivationLayerInfo act_info = node.activation_info();
-
- // Create function
- auto func = support::cpp14::make_unique<GCActivationLayer>();
- func->configure(input, output, act_info);
-
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCActivationLayer"
- << " Data Type: " << input->info()->data_type()
- << " Shape: " << input->info()->tensor_shape()
- << " Activation function: " << act_info.activation()
- << " a: " << act_info.a()
- << " b: " << act_info.b()
- << " InPlace : " << is_in_place_operation(input, output)
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend batch normalization layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend batch normalization layer function
- */
-std::unique_ptr<IFunction> create_batch_normalization_layer(BatchNormalizationLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating GC BatchNormalization node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
-
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 5);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- IGCTensor *input = get_backing_tensor(node.input(0));
- IGCTensor *mean = get_backing_tensor(node.input(1));
- IGCTensor *var = get_backing_tensor(node.input(2));
- IGCTensor *beta = get_backing_tensor(node.input(3));
- IGCTensor *gamma = get_backing_tensor(node.input(4));
- IGCTensor *output = get_backing_tensor(node.output(0));
- const float epsilon = node.epsilon();
- const ActivationLayerInfo fused_act = node.fused_activation();
+ std::vector<GCTargetInfo::TensorType *> inputs;
+ for(unsigned int i = 0; i < node.num_inputs(); ++i)
+ {
+ inputs.push_back(get_backing_tensor<GCTargetInfo>(node.input(i)));
+ }
+ typename GCTargetInfo::TensorType *output = get_backing_tensor<GCTargetInfo>(node.output(0));
// Create and configure function
- auto func = support::cpp14::make_unique<GCBatchNormalizationLayer>();
- func->configure(input, output, mean, var, beta, gamma, epsilon, fused_act);
+ auto func = support::cpp14::make_unique<GCDepthConcatenateLayer>();
+ func->configure(inputs, output);
// Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCBatchNormalizationLayer"
- << " Data Type: " << input->info()->data_type()
- << " Shape: " << input->info()->tensor_shape()
- << " Epsilon: " << epsilon << " "
- << (fused_act.enabled() ? to_string(fused_act.activation()) : "")
- << " InPlace : " << is_in_place_operation(input, output)
+ ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << node.type()
+ << " Target " << GCTargetInfo::TargetType
+ << " Data Type: " << output->info()->data_type()
+ << " Shape: " << output->info()->tensor_shape()
+ << " Num Inputs: " << inputs.size()
<< std::endl);
return std::move(func);
}
-/** Create a backend convolution layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend convolution layer function
- */
-std::unique_ptr<IFunction> create_convolution_layer(ConvolutionLayerNode &node, GraphContext &ctx)
+template <>
+std::unique_ptr<IFunction> create_convolution_layer<GCConvolutionLayerFunctions, GCTargetInfo>(ConvolutionLayerNode &node, GraphContext &ctx)
{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating GC ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
+ validate_node<GCTargetInfo>(node, 3 /* expected inputs */, 1 /* expected outputs */);
// Extract IO and info
- IGCTensor *input = get_backing_tensor(node.input(0));
- IGCTensor *weights = get_backing_tensor(node.input(1));
- IGCTensor *biases = get_backing_tensor(node.input(2));
- IGCTensor *output = get_backing_tensor(node.output(0));
+ GCTargetInfo::TensorType *input = get_backing_tensor<GCTargetInfo>(node.input(0));
+ GCTargetInfo::TensorType *weights = get_backing_tensor<GCTargetInfo>(node.input(1));
+ GCTargetInfo::TensorType *biases = get_backing_tensor<GCTargetInfo>(node.input(2));
+ GCTargetInfo::TensorType *output = get_backing_tensor<GCTargetInfo>(node.output(0));
if(is_data_type_quantized_asymmetric(input->info()->data_type()))
{
@@ -167,19 +124,21 @@
const ConvolutionMethod conv_algorithm = node.convolution_method();
// Create and configure function (we assume that functions have been validated before creation)
- std::shared_ptr<IMemoryManager> mm = get_memory_manager(ctx, Target::GC);
+ std::shared_ptr<IMemoryManager> mm = get_memory_manager(ctx, GCTargetInfo::TargetType);
std::unique_ptr<IFunction> func;
std::string func_name;
- if(conv_algorithm == ConvolutionMethod::DIRECT)
+ if(conv_algorithm == ConvolutionMethod::Direct)
{
- std::tie(func, func_name) = create_named_function<GCDirectConvolutionLayer>(
- std::string("GCDirectConvolutionLayer"), input, weights, biases, output, conv_info);
+ std::tie(func, func_name) = create_named_function<GCConvolutionLayerFunctions::DirectConvolutionLayer>(
+ std::string("DirectConvolutionLayer"),
+ input, weights, biases, output, conv_info);
}
else
{
- std::tie(func, func_name) = create_named_memory_managed_function<GCConvolutionLayer>(std::string("GCConvolutionLayer"), mm,
- input, weights, biases, output, conv_info);
+ std::tie(func, func_name) = create_named_memory_managed_function<GCConvolutionLayerFunctions::GenericConvolutionLayer>(
+ std::string("ConvolutionLayer"), mm,
+ input, weights, biases, output, conv_info);
}
// Log info
@@ -194,64 +153,16 @@
return func;
}
-/** Create a backend layer depth concatenate function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend depth concatenate layer function
- */
-std::unique_ptr<arm_compute::IFunction> create_depth_concatenate_layer(DepthConcatenateLayerNode &node)
+template <>
+std::unique_ptr<IFunction> create_depthwise_convolution_layer<GCDepthwiseConvolutionLayerFunctions, GCTargetInfo>(DepthwiseConvolutionLayerNode &node)
{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating GC DepthConcatenate node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Return nullptr if depth concatenate is switched off
- if(!node.is_enabled())
- {
- return nullptr;
- }
+ validate_node<GCTargetInfo>(node, 3 /* expected inputs */, 1 /* expected outputs */);
// Extract IO and info
- std::vector<arm_compute::IGCTensor *> inputs;
- for(unsigned int i = 0; i < node.num_inputs(); ++i)
- {
- inputs.push_back(get_backing_tensor(node.input(i)));
- }
- IGCTensor *output = get_backing_tensor(node.output(0));
-
- // Create and configure function
- auto func = support::cpp14::make_unique<GCDepthConcatenateLayer>();
- func->configure(inputs, output);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCDepthConcatenateLayer"
- << " Data Type: " << output->info()->data_type()
- << " Shape: " << output->info()->tensor_shape()
- << " Num Inputs: " << inputs.size()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend layer depth-wise convolution function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend depth-wise convolution layer function
- */
-std::unique_ptr<IFunction> create_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating GC DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name()
- << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- IGCTensor *input = get_backing_tensor(node.input(0));
- IGCTensor *weights = get_backing_tensor(node.input(1));
- IGCTensor *biases = get_backing_tensor(node.input(2));
- IGCTensor *output = get_backing_tensor(node.output(0));
+ GCTargetInfo::TensorType *input = get_backing_tensor<GCTargetInfo>(node.input(0));
+ GCTargetInfo::TensorType *weights = get_backing_tensor<GCTargetInfo>(node.input(1));
+ GCTargetInfo::TensorType *biases = get_backing_tensor<GCTargetInfo>(node.input(2));
+ GCTargetInfo::TensorType *output = get_backing_tensor<GCTargetInfo>(node.output(0));
if(is_data_type_quantized_asymmetric(input->info()->data_type()))
{
@@ -264,10 +175,11 @@
// Create and configure function (we assume that functions have been validated before creation)
std::unique_ptr<IFunction> func;
std::string func_name;
- if(dwc_algorithm == DepthwiseConvolutionMethod::OPTIMIZED_3x3)
+ if(dwc_algorithm == DepthwiseConvolutionMethod::Optimized3x3)
{
- std::tie(func, func_name) = create_named_function<GCDepthwiseConvolutionLayer3x3>(
- std::string("GCDepthwiseConvolutionLayer3x3"), input, weights, biases, output, conv_info);
+ std::tie(func, func_name) = create_named_function<GCDepthwiseConvolutionLayerFunctions::DepthwiseConvolutionLayer3x3>(
+ std::string("DepthwiseConvolutionLayer3x3"),
+ input, weights, biases, output, conv_info);
}
else
{
@@ -276,6 +188,7 @@
// Log info
ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name
+ << " Target " << GCTargetInfo::TargetType
<< " Data Type: " << input->info()->data_type()
<< " Input QuantInfo: " << input->info()->quantization_info()
<< " Weights QuantInfo: " << weights->info()->quantization_info()
@@ -286,13 +199,8 @@
return func;
}
-/** Create a backend element-wise operation layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend element-wise operation layer function
- */
-std::unique_ptr<IFunction> create_eltwise_layer(EltwiseLayerNode &node)
+template <>
+std::unique_ptr<IFunction> create_eltwise_layer<GCEltwiseFunctions, GCTargetInfo>(EltwiseLayerNode &node)
{
ARM_COMPUTE_LOG_GRAPH_VERBOSE(
"Creating GC EltwiseLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
@@ -300,31 +208,32 @@
ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
// Extract IO and info
- IGCTensor *input1 = get_backing_tensor(node.input(0));
- IGCTensor *input2 = get_backing_tensor(node.input(1));
- IGCTensor *output = get_backing_tensor(node.output(0));
- const EltwiseOperation eltwise_op = node.eltwise_operation();
- const ConvertPolicy convert_policy = node.convert_policy();
+ GCTargetInfo::TensorType *input1 = get_backing_tensor<GCTargetInfo>(node.input(0));
+ GCTargetInfo::TensorType *input2 = get_backing_tensor<GCTargetInfo>(node.input(1));
+ GCTargetInfo::TensorType *output = get_backing_tensor<GCTargetInfo>(node.output(0));
+ const EltwiseOperation eltwise_op = node.eltwise_operation();
+ const ConvertPolicy convert_policy = node.convert_policy();
ARM_COMPUTE_ERROR_ON(input1 == nullptr);
ARM_COMPUTE_ERROR_ON(input2 == nullptr);
ARM_COMPUTE_ERROR_ON(output == nullptr);
std::unique_ptr<IFunction> func = nullptr;
std::string func_name;
- if(eltwise_op == EltwiseOperation::ADD)
+ if(eltwise_op == EltwiseOperation::Add)
{
- std::tie(func, func_name) = create_named_function<GCArithmeticAddition>(std::string("GCArithmeticAddition"),
- input1, input2, output,
- convert_policy);
+ std::tie(func, func_name) = create_named_function<GCEltwiseFunctions::Addition>(
+ std::string("GCArithmeticAddition"),
+ input1, input2, output, convert_policy);
}
- else if(eltwise_op == EltwiseOperation::SUB)
+ else if(eltwise_op == EltwiseOperation::Sub)
{
ARM_COMPUTE_ERROR("Arithmetic subtraction is not supported in GLES backend");
}
- else if(eltwise_op == EltwiseOperation::MUL)
+ else if(eltwise_op == EltwiseOperation::Mul)
{
- std::tie(func, func_name) = create_named_function<GCPixelWiseMultiplication>(
- std::string("GCPixelWiseMultiplication"), input1, input2, output, 1.f);
+ std::tie(func, func_name) = create_named_function<GCEltwiseFunctions::Multiplication>(
+ std::string("PixelWiseMultiplication"),
+ input1, input2, output, 1.f);
}
else
{
@@ -332,157 +241,16 @@
}
// Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name
+ ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << node.type()
+ << " Target " << GCTargetInfo::TargetType
+ << " Operation " << func_name
<< " Data Type: " << input1->info()->data_type()
<< " Shape : " << input1->info()->tensor_shape()
<< std::endl);
return func;
}
-
-/** Create a backend fully connected layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend fully connected layer function
- */
-std::unique_ptr<IFunction> create_fully_connected_layer(FullyConnectedLayerNode &node, GraphContext &ctx)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating GC FullyConnectedLayer node with ID : " << node.id() << " and Name: " << node.name()
- << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- IGCTensor *input = get_backing_tensor(node.input(0));
- IGCTensor *weights = get_backing_tensor(node.input(1));
- IGCTensor *biases = get_backing_tensor(node.input(2));
- IGCTensor *output = get_backing_tensor(node.output(0));
-
- // Create and configure function
- auto func = support::cpp14::make_unique<GCFullyConnectedLayer>(get_memory_manager(ctx, Target::GC));
- func->configure(input, weights, biases, output);
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(weights == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCFullyConnectedLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Weights shape: " << weights->info()->tensor_shape()
- << " Biases Shape: " << biases->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend normalization layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend normalization layer function
- */
-std::unique_ptr<IFunction> create_normalization_layer(NormalizationLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating GC NormalizationLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- IGCTensor *input = get_backing_tensor(node.input(0));
- IGCTensor *output = get_backing_tensor(node.output(0));
- const NormalizationLayerInfo norm_info = node.normalization_info();
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Create and configure function
- auto func = support::cpp14::make_unique<GCNormalizationLayer>();
- func->configure(input, output, norm_info);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCNormalizationLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << " Normalization info: " << norm_info.type()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend pooling layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend pooling layer function
- */
-std::unique_ptr<IFunction> create_pooling_layer(PoolingLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating GC PoolingLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- IGCTensor *input = get_backing_tensor(node.input(0));
- IGCTensor *output = get_backing_tensor(node.output(0));
- const PoolingLayerInfo pool_info = node.pooling_info();
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Create and configure function
- auto func = support::cpp14::make_unique<GCPoolingLayer>();
- func->configure(input, output, pool_info);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCPoolingLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << " Pooling info: " << pool_info.pool_type()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend softmax layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend softmax layer function
- */
-std::unique_ptr<IFunction> create_softmax_layer(SoftmaxLayerNode &node, GraphContext &ctx)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE(
- "Creating GC SoftmaxLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- IGCTensor *input = get_backing_tensor(node.input(0));
- IGCTensor *output = get_backing_tensor(node.output(0));
- const float beta = node.beta();
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Create and configure function
- auto func = support::cpp14::make_unique<GCSoftmaxLayer>(get_memory_manager(ctx, Target::CL));
- func->configure(input, output, beta);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCSoftmaxLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << std::endl);
-
- return std::move(func);
-}
-} // namespace
+} //namespace detail
std::unique_ptr<IFunction> GCFunctionFactory::create(INode *node, GraphContext &ctx)
{
@@ -495,29 +263,31 @@
switch(type)
{
case NodeType::ActivationLayer:
- return create_activation_layer(*polymorphic_downcast<ActivationLayerNode *>(node));
+ return detail::create_activation_layer<GCActivationLayer, GCTargetInfo>(*polymorphic_downcast<ActivationLayerNode *>(node));
case NodeType::BatchNormalizationLayer:
- return create_batch_normalization_layer(*polymorphic_downcast<BatchNormalizationLayerNode *>(node));
+ return detail::create_batch_normalization_layer<GCBatchNormalizationLayer, GCTargetInfo>(*polymorphic_downcast<BatchNormalizationLayerNode *>(node));
case NodeType::ConvolutionLayer:
- return create_convolution_layer(*polymorphic_downcast<ConvolutionLayerNode *>(node), ctx);
- case NodeType::DepthConcatenateLayer:
- return create_depth_concatenate_layer(*polymorphic_downcast<DepthConcatenateLayerNode *>(node));
+ return detail::create_convolution_layer<GCConvolutionLayerFunctions, GCTargetInfo>(*polymorphic_downcast<ConvolutionLayerNode *>(node), ctx);
+ case NodeType::ConcatenateLayer:
+ return detail::create_concatenate_layer<GCDepthConcatenateLayer, GCTargetInfo>(*polymorphic_downcast<ConcatenateLayerNode *>(node));
case NodeType::DepthwiseConvolutionLayer:
- return create_depthwise_convolution_layer(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
+ return detail::create_depthwise_convolution_layer<GCDepthwiseConvolutionLayerFunctions, GCTargetInfo>(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
case NodeType::EltwiseLayer:
- return create_eltwise_layer(*polymorphic_downcast<EltwiseLayerNode *>(node));
+ return detail::create_eltwise_layer<GCEltwiseFunctions, GCTargetInfo>(*polymorphic_downcast<EltwiseLayerNode *>(node));
case NodeType::FullyConnectedLayer:
- return create_fully_connected_layer(*polymorphic_downcast<FullyConnectedLayerNode *>(node), ctx);
+ return detail::create_fully_connected_layer<GCFullyConnectedLayer, GCTargetInfo>(*polymorphic_downcast<FullyConnectedLayerNode *>(node), ctx);
case NodeType::NormalizationLayer:
- return create_normalization_layer(*polymorphic_downcast<NormalizationLayerNode *>(node));
+ return detail::create_normalization_layer<GCNormalizationLayer, GCTargetInfo>(*polymorphic_downcast<NormalizationLayerNode *>(node), ctx);
case NodeType::PoolingLayer:
- return create_pooling_layer(*polymorphic_downcast<PoolingLayerNode *>(node));
+ return detail::create_pooling_layer<GCPoolingLayer, GCTargetInfo>(*polymorphic_downcast<PoolingLayerNode *>(node));
+ case NodeType::ResizeLayer:
+ return detail::create_resize_layer<GCScale, GCTargetInfo>(*polymorphic_downcast<ResizeLayerNode *>(node));
case NodeType::SoftmaxLayer:
- return create_softmax_layer(*polymorphic_downcast<SoftmaxLayerNode *>(node), ctx);
+ return detail::create_softmax_layer<GCSoftmaxLayer, GCTargetInfo>(*polymorphic_downcast<SoftmaxLayerNode *>(node), ctx);
default:
return nullptr;
}
}
} // namespace backends
} // namespace graph
-} // namespace arm_compute
+} // namespace arm_compute
\ No newline at end of file
diff --git a/src/graph/backends/GLES/GCNodeValidator.cpp b/src/graph/backends/GLES/GCNodeValidator.cpp
index c7f7d81..53049c7 100644
--- a/src/graph/backends/GLES/GCNodeValidator.cpp
+++ b/src/graph/backends/GLES/GCNodeValidator.cpp
@@ -57,7 +57,7 @@
// Validate function
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->tensor_shape().x() != 3 && weights->tensor_shape().y() != 3, "Unsupported depthwise convolution");
- node.set_depthwise_convolution_method(DepthwiseConvolutionMethod::OPTIMIZED_3x3);
+ node.set_depthwise_convolution_method(DepthwiseConvolutionMethod::Optimized3x3);
return Status{};
}
@@ -79,15 +79,13 @@
const ConvolutionMethod conv_algorithm = node.convolution_method();
// Validate function
- if(conv_algorithm == ConvolutionMethod::DIRECT)
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(node.num_groups() != 1, "Grouping is not supported by ConvolutionLayer!");
+ if(conv_algorithm == ConvolutionMethod::Direct)
{
bool is_square = weights->tensor_shape().x() == weights->tensor_shape().y();
bool is_direct = (weights->tensor_shape().x() == 1) || (weights->tensor_shape().x() == 3) || (weights->tensor_shape().x() == 5);
bool is_correct_stride = (conv_info.stride().first) <= 2 && (conv_info.stride().second <= 2);
- if(!(is_square && is_direct && is_correct_stride))
- {
- node.set_convolution_method(ConvolutionMethod::DEFAULT);
- }
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(is_square && is_direct && is_correct_stride), "Direct convolution is not supported for given configuration");
}
return Status{};
@@ -104,14 +102,18 @@
NodeType type = node->type();
switch(type)
{
+ case NodeType::ChannelShuffleLayer:
+ return ARM_COMPUTE_CREATE_ERROR(arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported operation : ChannelShuffleLayer");
case NodeType::ConvolutionLayer:
return validate_convolution_layer(*polymorphic_downcast<ConvolutionLayerNode *>(node));
case NodeType::DepthwiseConvolutionLayer:
return validate_depthwise_convolution_layer(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
case NodeType::FlattenLayer:
- return ARM_COMPUTE_CREATE_ERROR(arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported operation");
+ return ARM_COMPUTE_CREATE_ERROR(arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported operation : FlattenLayer");
+ case NodeType::PermuteLayer:
+ return ARM_COMPUTE_CREATE_ERROR(arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported operation : PermuteLayer");
case NodeType::ReshapeLayer:
- return ARM_COMPUTE_CREATE_ERROR(arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported operation");
+ return ARM_COMPUTE_CREATE_ERROR(arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported operation : ReshapeLayer");
default:
return Status{};
}
diff --git a/src/graph/backends/NEON/NEDeviceBackend.cpp b/src/graph/backends/NEON/NEDeviceBackend.cpp
index 7c2db40..5fc44d0 100644
--- a/src/graph/backends/NEON/NEDeviceBackend.cpp
+++ b/src/graph/backends/NEON/NEDeviceBackend.cpp
@@ -61,6 +61,13 @@
void NEDeviceBackend::initialize_backend()
{
+ //Nothing to do
+}
+
+void NEDeviceBackend::release_backend_context(GraphContext &ctx)
+{
+ //Nothing to do
+ ARM_COMPUTE_UNUSED(ctx);
}
void NEDeviceBackend::setup_backend_context(GraphContext &ctx)
@@ -155,4 +162,4 @@
}
} // namespace backends
} // namespace graph
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/graph/backends/NEON/NEFunctionFactory.cpp b/src/graph/backends/NEON/NEFunctionFactory.cpp
index 7b1c50f..36a25ad 100644
--- a/src/graph/backends/NEON/NEFunctionFactory.cpp
+++ b/src/graph/backends/NEON/NEFunctionFactory.cpp
@@ -28,6 +28,7 @@
#include "arm_compute/graph/GraphContext.h"
#include "arm_compute/graph/Logger.h"
#include "arm_compute/graph/TypePrinter.h"
+#include "arm_compute/graph/backends/FunctionHelpers.h"
#include "arm_compute/graph/backends/Utils.h"
#include "arm_compute/graph/nodes/Nodes.h"
#include "arm_compute/runtime/NEON/NEFunctions.h"
@@ -41,108 +42,53 @@
{
namespace backends
{
-namespace
+/** Target specific information structure used to pass information to the layer templates */
+struct NETargetInfo
{
-/** Returns backing tensor of a given tensor
- *
- * @param[in] tensor Tensor to extract the backing tensor from
- *
- * @return Backing tensor if present else nullptr
- */
-arm_compute::ITensor *get_backing_tensor(arm_compute::graph::Tensor *tensor)
-{
- return ((tensor == nullptr) || (tensor->handle() == nullptr)) ? nullptr : &tensor->handle()->tensor();
-}
+ using TensorType = arm_compute::ITensor;
+ static Target TargetType;
+};
-/** Create a backend activation layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend activation layer function
- */
-std::unique_ptr<IFunction> create_activation_layer(ActivationLayerNode &node)
+Target NETargetInfo::TargetType = Target::NEON;
+
+/** Collection of CL convolution functions */
+struct NEConvolutionLayerFunctions
{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON ActivationLayerNode node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
+ using GenericConvolutionLayer = NEConvolutionLayer;
+ using GEMMConvolutionLayer = NEGEMMConvolutionLayer;
+ using DirectConvolutionLayer = NEDirectConvolutionLayer;
+ using WinogradConvolutionLayer = NEWinogradConvolutionLayer;
+};
+
+/** Collection of CL depthwise convolution functions */
+struct NEDepthwiseConvolutionLayerFunctions
+{
+ using GenericDepthwiseConvolutionLayer = NEDepthwiseConvolutionLayer;
+ using DepthwiseConvolutionLayer3x3 = NEDepthwiseConvolutionLayer3x3;
+};
+
+/** Collection of CL element-wise functions */
+struct NEEltwiseFunctions
+{
+ using Addition = NEArithmeticAddition;
+ using Subtraction = NEArithmeticSubtraction;
+ using Multiplication = NEPixelWiseMultiplication;
+};
+
+namespace detail
+{
+// Specialize functions
+template <>
+std::unique_ptr<IFunction> create_convolution_layer<NEConvolutionLayerFunctions, NETargetInfo>(ConvolutionLayerNode &node,
+ GraphContext &ctx)
+{
+ validate_node<NETargetInfo>(node, 3 /* expected inputs */, 1 /* expected outputs */);
// Extract IO and info
- ITensor *input = get_backing_tensor(node.input(0));
- ITensor *output = get_backing_tensor(node.output(0));
- const ActivationLayerInfo act_info = node.activation_info();
-
- // Create function
- auto func = support::cpp14::make_unique<NEActivationLayer>();
- func->configure(input, output, act_info);
-
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEActivationLayer"
- << " Data Type: " << input->info()->data_type()
- << " Shape: " << input->info()->tensor_shape()
- << " Activation function: " << act_info.activation()
- << " a: " << act_info.a()
- << " b: " << act_info.b()
- << " InPlace : " << is_in_place_operation(input, output)
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend batch normalization layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend batch normalization layer function
- */
-std::unique_ptr<IFunction> create_batch_normalization_layer(BatchNormalizationLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON BatchNormalization node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
-
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 5);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ITensor *input = get_backing_tensor(node.input(0));
- ITensor *mean = get_backing_tensor(node.input(1));
- ITensor *var = get_backing_tensor(node.input(2));
- ITensor *beta = get_backing_tensor(node.input(3));
- ITensor *gamma = get_backing_tensor(node.input(4));
- ITensor *output = get_backing_tensor(node.output(0));
- const float epsilon = node.epsilon();
- const ActivationLayerInfo fused_act = node.fused_activation();
-
- // Create and configure function
- auto func = support::cpp14::make_unique<NEBatchNormalizationLayer>();
- func->configure(input, output, mean, var, beta, gamma, epsilon, fused_act);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEBatchNormalizationLayer"
- << " Data Type: " << input->info()->data_type()
- << " Shape: " << input->info()->tensor_shape()
- << " Epsilon: " << epsilon << " "
- << (fused_act.enabled() ? to_string(fused_act.activation()) : "")
- << " InPlace : " << is_in_place_operation(input, output)
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend convolution layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend convolution layer function
- */
-std::unique_ptr<IFunction> create_convolution_layer(ConvolutionLayerNode &node, GraphContext &ctx)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ITensor *input = get_backing_tensor(node.input(0));
- ITensor *weights = get_backing_tensor(node.input(1));
- ITensor *biases = get_backing_tensor(node.input(2));
- ITensor *output = get_backing_tensor(node.output(0));
+ NETargetInfo::TensorType *input = get_backing_tensor<NETargetInfo>(node.input(0));
+ NETargetInfo::TensorType *weights = get_backing_tensor<NETargetInfo>(node.input(1));
+ NETargetInfo::TensorType *biases = get_backing_tensor<NETargetInfo>(node.input(2));
+ NETargetInfo::TensorType *output = get_backing_tensor<NETargetInfo>(node.output(0));
if(is_data_type_quantized_asymmetric(input->info()->data_type()))
{
@@ -156,29 +102,30 @@
std::shared_ptr<IMemoryManager> mm = get_memory_manager(ctx, Target::NEON);
std::unique_ptr<IFunction> func;
std::string func_name;
- if(conv_algorithm == ConvolutionMethod::DIRECT)
+ if(conv_algorithm == ConvolutionMethod::Direct)
{
- std::tie(func, func_name) = create_named_memory_managed_function<NEDirectConvolutionLayer>(std::string("NEDirectConvolutionLayer"), mm,
- input, weights, biases, output, conv_info);
+ std::tie(func, func_name) = create_named_memory_managed_function<NEDirectConvolutionLayer>(
+ std::string("DirectConvolutionLayer"), mm, input, weights, biases, output, conv_info);
}
else if(conv_algorithm == ConvolutionMethod::GEMM)
{
- std::tie(func, func_name) = create_named_memory_managed_function<NEGEMMConvolutionLayer>(std::string("NEGEMMConvolutionLayer"), mm,
- input, weights, biases, output, conv_info);
+ std::tie(func, func_name) = create_named_memory_managed_function<NEGEMMConvolutionLayer>(
+ std::string("GEMMConvolutionLayer"), mm, input, weights, biases, output, conv_info);
}
- else if(conv_algorithm == ConvolutionMethod::WINOGRAD)
+ else if(conv_algorithm == ConvolutionMethod::Winograd)
{
- std::tie(func, func_name) = create_named_memory_managed_function<NEWinogradConvolutionLayer>(std::string("NEWinogradConvolutionLayer"), mm,
- input, weights, biases, output, conv_info);
+ std::tie(func, func_name) = create_named_memory_managed_function<NEWinogradConvolutionLayer>(
+ std::string("WinogradConvolutionLayer"), mm, input, weights, biases, output, conv_info);
}
else
{
- std::tie(func, func_name) = create_named_memory_managed_function<NEConvolutionLayer>(std::string("NEConvolutionLayer"), mm,
- input, weights, biases, output, conv_info);
+ std::tie(func, func_name) = create_named_memory_managed_function<NEConvolutionLayer>(
+ std::string("ConvolutionLayer"), mm, input, weights, biases, output, conv_info);
}
// Log info
ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name
+ << " Target " << NETargetInfo::TargetType
<< " Data Type: " << input->info()->data_type()
<< " Input QuantInfo: " << input->info()->quantization_info()
<< " Weights QuantInfo: " << weights->info()->quantization_info()
@@ -189,244 +136,25 @@
return func;
}
-/** Create a backend layer depth concatenate function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend depth concatenate layer function
- */
-std::unique_ptr<arm_compute::IFunction> create_depth_concatenate_layer(DepthConcatenateLayerNode &node)
+template <>
+std::unique_ptr<IFunction> create_normalization_layer<NENormalizationLayer, NETargetInfo>(NormalizationLayerNode &node, GraphContext &ctx)
{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON DepthConcatenate node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Return nullptr if depth concatenate is switched off
- if(!node.is_enabled())
- {
- return nullptr;
- }
+ validate_node<NETargetInfo>(node, 1 /* expected inputs */, 1 /* expected outputs */);
// Extract IO and info
- std::vector<arm_compute::ITensor *> inputs;
- for(unsigned int i = 0; i < node.num_inputs(); ++i)
- {
- inputs.push_back(get_backing_tensor(node.input(i)));
- }
- ITensor *output = get_backing_tensor(node.output(0));
-
- // Create and configure function
- auto func = support::cpp14::make_unique<NEDepthConcatenateLayer>();
- func->configure(inputs, output);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEDepthConcatenateLayer"
- << " Data Type: " << output->info()->data_type()
- << " Shape: " << output->info()->tensor_shape()
- << " Num Inputs: " << inputs.size()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend layer depth-wise convolution function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend depth-wise convolution layer function
- */
-std::unique_ptr<IFunction> create_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ITensor *input = get_backing_tensor(node.input(0));
- ITensor *weights = get_backing_tensor(node.input(1));
- ITensor *biases = get_backing_tensor(node.input(2));
- ITensor *output = get_backing_tensor(node.output(0));
-
- if(is_data_type_quantized_asymmetric(input->info()->data_type()))
- {
- biases->info()->set_data_type(DataType::S32);
- }
-
- const PadStrideInfo conv_info = node.convolution_info();
- const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method();
-
- // Create and configure function (we assume that functions have been validated before creation)
- std::unique_ptr<IFunction> func;
- std::string func_name;
- if(dwc_algorithm == DepthwiseConvolutionMethod::OPTIMIZED_3x3)
- {
- std::tie(func, func_name) = create_named_function<NEDepthwiseConvolutionLayer3x3>(std::string("NEDepthwiseConvolutionLayer3x3"),
- input, weights, biases, output, conv_info);
- }
- else
- {
- std::tie(func, func_name) = create_named_function<NEDepthwiseConvolutionLayer>(std::string("NEDepthwiseConvolutionLayer"),
- input, weights, biases, output, conv_info);
- }
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name
- << " Data Type: " << input->info()->data_type()
- << " Input QuantInfo: " << input->info()->quantization_info()
- << " Weights QuantInfo: " << weights->info()->quantization_info()
- << " Input shape: " << input->info()->tensor_shape()
- << " Weights shape: " << weights->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << std::endl);
- return func;
-}
-
-/** Create a backend element-wise operation layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend element-wise operation layer function
- */
-std::unique_ptr<IFunction> create_eltwise_layer(EltwiseLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON EltwiseLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 2);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ITensor *input1 = get_backing_tensor(node.input(0));
- ITensor *input2 = get_backing_tensor(node.input(1));
- ITensor *output = get_backing_tensor(node.output(0));
- const EltwiseOperation eltwise_op = node.eltwise_operation();
- const ConvertPolicy convert_policy = node.convert_policy();
- ARM_COMPUTE_ERROR_ON(input1 == nullptr);
- ARM_COMPUTE_ERROR_ON(input2 == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- std::unique_ptr<IFunction> func = nullptr;
- std::string func_name;
- if(eltwise_op == EltwiseOperation::ADD)
- {
- std::tie(func, func_name) = create_named_function<NEArithmeticAddition>(std::string("NEArithmeticAddition"),
- input1, input2, output, convert_policy);
- }
- else if(eltwise_op == EltwiseOperation::SUB)
- {
- std::tie(func, func_name) = create_named_function<NEArithmeticSubtraction>(std::string("NEArithmeticSubtraction"),
- input1, input2, output, convert_policy);
- }
- else if(eltwise_op == EltwiseOperation::MUL)
- {
- std::tie(func, func_name) = create_named_function<NEPixelWiseMultiplication>(std::string("NEPixelWiseMultiplication"),
- input1, input2, output, 1.f,
- convert_policy, node.rounding_policy());
- }
- else
- {
- ARM_COMPUTE_ERROR("Unsupported element-wise operation!");
- }
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name
- << " Data Type: " << input1->info()->data_type()
- << " Shape : " << input1->info()->tensor_shape()
- << std::endl);
-
- return func;
-}
-
-/** Create a backend flatten layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend flatten layer function
- */
-std::unique_ptr<IFunction> create_flatten_layer(FlattenLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON FlattenLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ITensor *input = get_backing_tensor(node.input(0));
- ITensor *output = get_backing_tensor(node.output(0));
-
- // Create and configure function
- auto func = support::cpp14::make_unique<NEFlattenLayer>();
- func->configure(input, output);
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEFlattenLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend fully connected layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend fully connected layer function
- */
-std::unique_ptr<IFunction> create_fully_connected_layer(FullyConnectedLayerNode &node, GraphContext &ctx)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON FullyConnectedLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ITensor *input = get_backing_tensor(node.input(0));
- ITensor *weights = get_backing_tensor(node.input(1));
- ITensor *biases = get_backing_tensor(node.input(2));
- ITensor *output = get_backing_tensor(node.output(0));
-
- // Create and configure function
- auto func = support::cpp14::make_unique<NEFullyConnectedLayer>(get_memory_manager(ctx, Target::NEON));
- func->configure(input, weights, biases, output);
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(weights == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEFullyConnectedLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Weights shape: " << weights->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend normalization layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend normalization layer function
- */
-std::unique_ptr<IFunction> create_normalization_layer(NormalizationLayerNode &node, GraphContext &ctx)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON NormalizationLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ITensor *input = get_backing_tensor(node.input(0));
- ITensor *output = get_backing_tensor(node.output(0));
+ NETargetInfo::TensorType *input = get_backing_tensor<NETargetInfo>(node.input(0));
+ NETargetInfo::TensorType *output = get_backing_tensor<NETargetInfo>(node.output(0));
const NormalizationLayerInfo norm_info = node.normalization_info();
ARM_COMPUTE_ERROR_ON(input == nullptr);
ARM_COMPUTE_ERROR_ON(output == nullptr);
// Create and configure function
- auto func = support::cpp14::make_unique<NENormalizationLayer>(get_memory_manager(ctx, Target::NEON));
+ auto func = support::cpp14::make_unique<NENormalizationLayer>(get_memory_manager(ctx, NETargetInfo::TargetType));
func->configure(input, output, norm_info);
// Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NENormalizationLayer"
+ ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << node.type()
+ << " Target " << NETargetInfo::TargetType
<< " Data Type: " << input->info()->data_type()
<< " Input shape: " << input->info()->tensor_shape()
<< " Output shape: " << output->info()->tensor_shape()
@@ -435,106 +163,7 @@
return std::move(func);
}
-
-/** Create a backend pooling layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend pooling layer function
- */
-std::unique_ptr<IFunction> create_pooling_layer(PoolingLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON PoolingLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ITensor *input = get_backing_tensor(node.input(0));
- ITensor *output = get_backing_tensor(node.output(0));
- const PoolingLayerInfo pool_info = node.pooling_info();
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Create and configure function
- auto func = support::cpp14::make_unique<NEPoolingLayer>();
- func->configure(input, output, pool_info);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEPoolingLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << " Pooling info: " << pool_info.pool_type()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend reshape layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend reshape layer function
- */
-std::unique_ptr<IFunction> create_reshape_layer(ReshapeLayerNode &node)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON ReshapeLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ITensor *input = get_backing_tensor(node.input(0));
- ITensor *output = get_backing_tensor(node.output(0));
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Create and configure function
- auto func = support::cpp14::make_unique<NEReshapeLayer>();
- func->configure(input, output);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEReshapeLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << std::endl);
-
- return std::move(func);
-}
-
-/** Create a backend softmax layer function
- *
- * @param[in] node Node to create the backend function for
- *
- * @return Backend softmax layer function
- */
-std::unique_ptr<IFunction> create_softmax_layer(SoftmaxLayerNode &node, GraphContext &ctx)
-{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON SoftmaxLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
- ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1);
- ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1);
-
- // Extract IO and info
- ITensor *input = get_backing_tensor(node.input(0));
- ITensor *output = get_backing_tensor(node.output(0));
- const float beta = node.beta();
- ARM_COMPUTE_ERROR_ON(input == nullptr);
- ARM_COMPUTE_ERROR_ON(output == nullptr);
-
- // Create and configure function
- auto func = support::cpp14::make_unique<NESoftmaxLayer>(get_memory_manager(ctx, Target::NEON));
- func->configure(input, output, beta);
-
- // Log info
- ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NESoftmaxLayer"
- << " Data Type: " << input->info()->data_type()
- << " Input shape: " << input->info()->tensor_shape()
- << " Output shape: " << output->info()->tensor_shape()
- << std::endl);
-
- return std::move(func);
-}
-} // namespace
+} // namespace detail
std::unique_ptr<IFunction> NEFunctionFactory::create(INode *node, GraphContext &ctx)
{
@@ -547,33 +176,39 @@
switch(type)
{
case NodeType::ActivationLayer:
- return create_activation_layer(*polymorphic_downcast<ActivationLayerNode *>(node));
+ return detail::create_activation_layer<NEActivationLayer, NETargetInfo>(*polymorphic_downcast<ActivationLayerNode *>(node));
case NodeType::BatchNormalizationLayer:
- return create_batch_normalization_layer(*polymorphic_downcast<BatchNormalizationLayerNode *>(node));
+ return detail::create_batch_normalization_layer<NEBatchNormalizationLayer, NETargetInfo>(*polymorphic_downcast<BatchNormalizationLayerNode *>(node));
case NodeType::ConvolutionLayer:
- return create_convolution_layer(*polymorphic_downcast<ConvolutionLayerNode *>(node), ctx);
- case NodeType::DepthConcatenateLayer:
- return create_depth_concatenate_layer(*polymorphic_downcast<DepthConcatenateLayerNode *>(node));
+ return detail::create_convolution_layer<NEConvolutionLayerFunctions, NETargetInfo>(*polymorphic_downcast<ConvolutionLayerNode *>(node), ctx);
+ case NodeType::DeconvolutionLayer:
+ return detail::create_deconvolution_layer<NEDeconvolutionLayer, NETargetInfo>(*polymorphic_downcast<DeconvolutionLayerNode *>(node), ctx);
+ case NodeType::ConcatenateLayer:
+ return detail::create_concatenate_layer<NEConcatenateLayer, NETargetInfo>(*polymorphic_downcast<ConcatenateLayerNode *>(node));
case NodeType::DepthwiseConvolutionLayer:
- return create_depthwise_convolution_layer(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
+ return detail::create_depthwise_convolution_layer<NEDepthwiseConvolutionLayerFunctions, NETargetInfo>(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
case NodeType::EltwiseLayer:
- return create_eltwise_layer(*polymorphic_downcast<EltwiseLayerNode *>(node));
+ return detail::create_eltwise_layer<NEEltwiseFunctions, NETargetInfo>(*polymorphic_downcast<EltwiseLayerNode *>(node));
case NodeType::FlattenLayer:
- return create_flatten_layer(*polymorphic_downcast<FlattenLayerNode *>(node));
+ return detail::create_flatten_layer<NEFlattenLayer, NETargetInfo>(*polymorphic_downcast<FlattenLayerNode *>(node));
case NodeType::FullyConnectedLayer:
- return create_fully_connected_layer(*polymorphic_downcast<FullyConnectedLayerNode *>(node), ctx);
+ return detail::create_fully_connected_layer<NEFullyConnectedLayer, NETargetInfo>(*polymorphic_downcast<FullyConnectedLayerNode *>(node), ctx);
case NodeType::NormalizationLayer:
- return create_normalization_layer(*polymorphic_downcast<NormalizationLayerNode *>(node), ctx);
+ return detail::create_normalization_layer<NENormalizationLayer, NETargetInfo>(*polymorphic_downcast<NormalizationLayerNode *>(node), ctx);
+ case NodeType::PermuteLayer:
+ return detail::create_permute_layer<NEPermute, NETargetInfo>(*polymorphic_downcast<PermuteLayerNode *>(node));
case NodeType::PoolingLayer:
- return create_pooling_layer(*polymorphic_downcast<PoolingLayerNode *>(node));
+ return detail::create_pooling_layer<NEPoolingLayer, NETargetInfo>(*polymorphic_downcast<PoolingLayerNode *>(node));
case NodeType::ReshapeLayer:
- return create_reshape_layer(*polymorphic_downcast<ReshapeLayerNode *>(node));
+ return detail::create_reshape_layer<NEReshapeLayer, NETargetInfo>(*polymorphic_downcast<ReshapeLayerNode *>(node));
+ case NodeType::ResizeLayer:
+ return detail::create_resize_layer<NEScale, NETargetInfo>(*polymorphic_downcast<ResizeLayerNode *>(node));
case NodeType::SoftmaxLayer:
- return create_softmax_layer(*polymorphic_downcast<SoftmaxLayerNode *>(node), ctx);
+ return detail::create_softmax_layer<NESoftmaxLayer, NETargetInfo>(*polymorphic_downcast<SoftmaxLayerNode *>(node), ctx);
default:
return nullptr;
}
}
} // namespace backends
} // namespace graph
-} // namespace arm_compute
+} // namespace arm_compute
\ No newline at end of file
diff --git a/src/graph/backends/NEON/NENodeValidator.cpp b/src/graph/backends/NEON/NENodeValidator.cpp
index e438e79..58ffaf0 100644
--- a/src/graph/backends/NEON/NENodeValidator.cpp
+++ b/src/graph/backends/NEON/NENodeValidator.cpp
@@ -47,6 +47,8 @@
NodeType type = node->type();
switch(type)
{
+ case NodeType::ChannelShuffleLayer:
+ return ARM_COMPUTE_CREATE_ERROR(arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported operation : ChannelShuffleLayer");
case NodeType::ConvolutionLayer:
return detail::validate_convolution_layer<NEConvolutionLayer,
NEDirectConvolutionLayer,
@@ -55,7 +57,8 @@
case NodeType::DepthwiseConvolutionLayer:
return detail::validate_depthwise_convolution_layer<NEDepthwiseConvolutionLayer,
NEDepthwiseConvolutionLayer3x3>(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
-
+ case NodeType::PermuteLayer:
+ return detail::validate_permute_layer<NEPermute>(*polymorphic_downcast<PermuteLayerNode *>(node));
default:
return Status{};
}
diff --git a/src/graph/detail/ExecutionHelpers.cpp b/src/graph/detail/ExecutionHelpers.cpp
index c370fdf..f479963 100644
--- a/src/graph/detail/ExecutionHelpers.cpp
+++ b/src/graph/detail/ExecutionHelpers.cpp
@@ -35,14 +35,6 @@
{
namespace detail
{
-void default_initialize_backends()
-{
- for(const auto &backend : backends::BackendRegistry::get().backends())
- {
- backend.second->initialize_backend();
- }
-}
-
void validate_all_nodes(Graph &g)
{
auto &nodes = g.nodes();
@@ -52,10 +44,9 @@
{
if(node != nullptr)
{
- Target assigned_target = node->assigned_target();
- auto backend = backends::BackendRegistry::get().find_backend(assigned_target);
- ARM_COMPUTE_ERROR_ON_MSG(!backend, "Requested backend doesn't exist!");
- Status status = backend->validate_node(*node);
+ Target assigned_target = node->assigned_target();
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(assigned_target);
+ Status status = backend.validate_node(*node);
ARM_COMPUTE_ERROR_ON_MSG(!bool(status), status.error_description().c_str());
}
}
@@ -67,13 +58,12 @@
for(auto &tensor : tensors)
{
- if(tensor)
+ if(tensor && tensor->handle() == nullptr)
{
- Target target = tensor->desc().target;
- auto backend = backends::BackendRegistry::get().find_backend(target);
- ARM_COMPUTE_ERROR_ON_MSG(!backend, "Requested backend doesn't exist!");
- auto handle = backend->create_tensor(*tensor);
- ARM_COMPUTE_ERROR_ON_MSG(!backend, "Couldn't create backend handle!");
+ Target target = tensor->desc().target;
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(target);
+ std::unique_ptr<ITensorHandle> handle = backend.create_tensor(*tensor);
+ ARM_COMPUTE_ERROR_ON_MSG(!handle, "Couldn't create backend handle!");
tensor->set_handle(std::move(handle));
}
}
@@ -139,35 +129,33 @@
}
}
-ExecutionWorkload configure_all_nodes(Graph &g, GraphContext &ctx)
+ExecutionWorkload configure_all_nodes(Graph &g, GraphContext &ctx, const std::vector<NodeID> &node_order)
{
ExecutionWorkload workload;
workload.graph = &g;
workload.ctx = &ctx;
- auto &nodes = g.nodes();
-
// Create tasks
- for(auto &node : nodes)
+ for(auto &node_id : node_order)
{
+ auto node = g.node(node_id);
if(node != nullptr)
{
- Target assigned_target = node->assigned_target();
- auto backend = backends::BackendRegistry::get().find_backend(assigned_target);
- ARM_COMPUTE_ERROR_ON_MSG(!backend, "Requested backend doesn't exist!");
- auto func = backend->configure_node(*node, ctx);
+ Target assigned_target = node->assigned_target();
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(assigned_target);
+ std::unique_ptr<IFunction> func = backend.configure_node(*node, ctx);
if(func != nullptr)
{
ExecutionTask task;
task.task = std::move(func);
- task.node = node.get();
+ task.node = node;
workload.tasks.push_back(std::move(task));
}
}
}
// Add inputs and outputs
- for(auto &node : nodes)
+ for(auto &node : g.nodes())
{
if(node != nullptr && node->type() == NodeType::Input)
{
@@ -214,15 +202,12 @@
}
}
-void call_all_input_node_accessors(ExecutionWorkload &workload)
+bool call_all_input_node_accessors(ExecutionWorkload &workload)
{
- for(auto &input : workload.inputs)
+ return !std::any_of(std::begin(workload.inputs), std::end(workload.inputs), [](Tensor * input_tensor)
{
- if(input != nullptr)
- {
- input->call_accessor();
- }
- }
+ return (input_tensor == nullptr) || !input_tensor->call_accessor();
+ });
}
void prepare_all_tasks(ExecutionWorkload &workload)
@@ -264,16 +249,16 @@
}
}
-void call_all_output_node_accessors(ExecutionWorkload &workload)
+bool call_all_output_node_accessors(ExecutionWorkload &workload) // true only if every output exists and its accessor succeeded
{
- for(auto &output : workload.outputs)
+ bool is_valid = true;
+ std::for_each(std::begin(workload.outputs), std::end(workload.outputs), [&](Tensor * output_tensor)
{
- if(output != nullptr)
- {
- output->call_accessor();
- }
- }
+ bool valid_output = (output_tensor != nullptr) && output_tensor->call_accessor(); // evaluated first so an earlier failure never skips this accessor
+ is_valid = is_valid && valid_output;
+ });
+ return is_valid;
}
} // namespace detail
} // namespace graph
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/graph/frontend/Stream.cpp b/src/graph/frontend/Stream.cpp
index 96a166c..878d688 100644
--- a/src/graph/frontend/Stream.cpp
+++ b/src/graph/frontend/Stream.cpp
@@ -33,7 +33,7 @@
namespace frontend
{
Stream::Stream(size_t id, std::string name)
- : _manager(), _ctx(), _g(id, std::move(name))
+ : _ctx(), _manager(), _g(id, std::move(name)) // NOTE(review): reorder presumably matches member declaration order (-Wreorder); confirm in Stream.h
{
}
@@ -66,4 +66,4 @@
}
} // namespace frontend
} // namespace graph
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/graph/mutators/DepthConcatSubTensorMutator.cpp b/src/graph/mutators/DepthConcatSubTensorMutator.cpp
index c56f4c5..a170c4d 100644
--- a/src/graph/mutators/DepthConcatSubTensorMutator.cpp
+++ b/src/graph/mutators/DepthConcatSubTensorMutator.cpp
@@ -25,8 +25,10 @@
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/Logger.h"
+#include "arm_compute/graph/Utils.h"
+#include "arm_compute/graph/algorithms/TopologicalSort.h"
#include "arm_compute/graph/backends/BackendRegistry.h"
-#include "arm_compute/graph/nodes/DepthConcatenateLayerNode.h"
+#include "arm_compute/graph/nodes/ConcatenateLayerNode.h"
#include "arm_compute/core/utils/misc/Cast.h"
#include "arm_compute/core/utils/misc/Iterable.h"
@@ -42,14 +44,31 @@
void DepthConcatSubTensorMutator::mutate(Graph &g)
{
- // Should be in reverse order of execution
- for(auto &node : arm_compute::utils::iterable::reverse_iterate(g.nodes()))
+ // Early exit if no Concatenation layers exist in graph
+ if(g.nodes(NodeType::ConcatenateLayer).empty())
{
- if(node && node->type() == NodeType::DepthConcatenateLayer && node->output(0) != nullptr)
+ return;
+ }
+
+ // Perform topological sort
+ std::vector<NodeID> topological_sorted_node_ids = dfs(g);
+
+ // Should be in reverse order of execution
+ for(auto &node_id : arm_compute::utils::iterable::reverse_iterate(topological_sorted_node_ids))
+ {
+ INode *node = g.node(node_id);
+ if(node != nullptr && node->type() == NodeType::ConcatenateLayer && node->output(0) != nullptr)
{
// Get output tensor
auto output_tensor = node->output(0);
+ // Check concatenation axis (sub-tensor optimization is only supported for concatenation axis index >= 2; the offsets below use dimension 2)
+ auto *concat_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node);
+ if(output_tensor == nullptr || get_dimension_idx(output_tensor->desc(), concat_node->concatenation_axis()) < 2)
+ {
+ continue;
+ }
+
// Check that all tensor have the same target and valid inputs
bool is_valid = std::all_of(node->input_edges().cbegin(), node->input_edges().cend(),
[&](const EdgeID & eid)
@@ -58,7 +77,7 @@
});
// Create subtensors
- if(is_valid && backends::BackendRegistry::get().find_backend(output_tensor->desc().target) != nullptr)
+ if(is_valid && is_target_supported(output_tensor->desc().target))
{
ARM_COMPUTE_LOG_GRAPH_VERBOSE("Using sub-tensors for the node with ID : "
<< node->id() << " and name : " << node->name() << std::endl);
@@ -69,14 +88,14 @@
auto input_tensor = node->input(i);
const auto input_shape = input_tensor->desc().shape;
- auto backend = backends::BackendRegistry::get().find_backend(input_tensor->desc().target);
- auto handle = backend->create_subtensor(output_tensor->handle(), input_shape, Coordinates(0, 0, depth), false);
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(input_tensor->desc().target);
+ std::unique_ptr<ITensorHandle> handle = backend.create_subtensor(output_tensor->handle(), input_shape, Coordinates(0, 0, depth), false); // input i becomes a sub-tensor of the output at dimension-2 offset 'depth'
input_tensor->set_handle(std::move(handle));
depth += input_shape.z();
}
- auto *dc_node = arm_compute::utils::cast::polymorphic_downcast<DepthConcatenateLayerNode *>(node.get());
+ auto *dc_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node);
dc_node->set_enabled(false);
}
}
diff --git a/src/graph/mutators/GroupedConvolutionMutator.cpp b/src/graph/mutators/GroupedConvolutionMutator.cpp
new file mode 100644
index 0000000..0d65d6a
--- /dev/null
+++ b/src/graph/mutators/GroupedConvolutionMutator.cpp
@@ -0,0 +1,185 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph/mutators/GroupedConvolutionMutator.h"
+
+#include "arm_compute/graph/Graph.h"
+#include "arm_compute/graph/GraphBuilder.h"
+#include "arm_compute/graph/Logger.h"
+#include "arm_compute/graph/Utils.h"
+#include "arm_compute/graph/backends/BackendRegistry.h"
+#include "arm_compute/graph/nodes/Nodes.h"
+
+#include "arm_compute/core/utils/misc/Cast.h"
+
+#include <set>
+
+namespace arm_compute
+{
+namespace graph
+{
+namespace
+{
+NodeID create_grouped_convolution(Graph &g, const NodeParams &params, NodeIdxPair input, NodeID weights, NodeID bias,
+ PadStrideInfo conv_info, ConvolutionMethod method, FastMathHint fast_math_hint, unsigned int num_groups)
+{
+ const bool has_bias = (bias != EmptyNodeID);
+ // Lower to split(input, weights[, bias]) -> num_groups single-group convolutions -> concatenate; returns the concatenate node id
+ // Split input along its channel dimension (descriptor taken from the producer output actually consumed)
+ const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[input.index]);
+ const unsigned int input_idx = get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL);
+ NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, input_idx);
+
+ // Split weights along their BATCHES dimension (presumably the output-channel axis of the weights)
+ const TensorDescriptor weights_tensor_desc = get_tensor_descriptor(g, g.node(weights)->outputs()[0]);
+ const unsigned int batch_idx = get_dimension_idx(weights_tensor_desc, DataLayoutDimension::BATCHES);
+ NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, batch_idx);
+
+ // Split bias
+ NodeID bias_split = EmptyNodeID;
+ if(has_bias)
+ {
+ // Bias is split along dimension 0 (NOTE(review): assumes a 1-D bias tensor)
+ bias_split = GraphBuilder::add_split_node(g, params, { bias, 0 }, num_groups, 0);
+ }
+ // Create one single-group convolution per group, fed by the i-th slice of each split
+ std::vector<NodeIdxPair> convolution_outputs;
+ for(unsigned int i = 0; i < num_groups; ++i)
+ {
+ NodeParams group_params = params;
+ NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, 1, method, fast_math_hint);
+ g.add_connection(input_split, i, conv_nid, 0);
+ g.add_connection(weights_split, i, conv_nid, 1);
+ if(has_bias)
+ {
+ g.add_connection(bias_split, i, conv_nid, 2);
+ }
+
+ // Add group name
+ if(!group_params.name.empty())
+ {
+ group_params.name.append("_g" + arm_compute::support::cpp11::to_string(i));
+ }
+
+ // Set node parameters
+ INode *node = g.node(conv_nid);
+ ARM_COMPUTE_ERROR_ON(node == nullptr);
+ node->set_common_node_parameters(group_params);
+
+ convolution_outputs.push_back({ conv_nid, 0 });
+ }
+
+ // Concatenate the per-group outputs along the channel dimension
+ return GraphBuilder::add_concatenate_node(g, params, convolution_outputs, DataLayoutDimension::CHANNEL);
+}
+} // namespace
+
+const char *GroupedConvolutionMutator::name()
+{
+ return "GroupedConvolutionMutator";
+}
+
+void GroupedConvolutionMutator::mutate(Graph &g)
+{
+ // Early exit if no Convolution layers exist in graph
+ if(g.nodes(NodeType::ConvolutionLayer).empty())
+ {
+ return;
+ }
+
+ // Number of nodes before mutation; nodes added below are not revisited by this loop
+ size_t total_nodes = g.nodes().size();
+
+ // Iterate over the original nodes, acting only on grouped convolutions (num_groups != 1)
+ for(unsigned int i = 0; i < total_nodes; ++i)
+ {
+ INode *node = g.node(i);
+ if(node != nullptr && node->type() == NodeType::ConvolutionLayer && arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node)->num_groups() != 1)
+ {
+ // Validate node
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(node->assigned_target());
+ Status status = backend.validate_node(*node);
+
+ // If grouped convolution is not supported
+ if(!bool(status))
+ {
+ // Down-cast node
+ auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node);
+
+ // Get internal convolution info
+ const PadStrideInfo conv_info = conv_node->convolution_info();
+ const ConvolutionMethod conv_method = conv_node->convolution_method();
+ const FastMathHint fast_math_hint = conv_node->fast_math_hint();
+ const unsigned int num_groups = conv_node->num_groups();
+ const NodeParams params = conv_node->common_node_params();
+ const Target assigned_target = conv_node->assigned_target();
+
+ // Extract node ids
+ const NodeID input_id = conv_node->input_id(0);
+ const NodeID weights_id = conv_node->input_id(1);
+ const NodeID bias_id = conv_node->input_id(2);
+
+ // Get driving nodes
+ std::vector<NodeIdxPair> driving_nodes = get_driving_nodes(*node);
+
+ // Extract the convolution output accessor (if any) so it can be moved to the new output
+ auto node_accessor = conv_node->output(0)->extract_accessor();
+
+ // Record current tensor/node counts so the ones created below can be configured afterwards
+ TensorID latest_tid = g.tensors().size();
+ NodeID latest_nid = g.nodes().size();
+
+ // Create split -> per-group convolutions -> concatenate (NOTE(review): assumes the input comes from output 0 of its producer)
+ NodeID grouped_conv_id = create_grouped_convolution(g, params, { input_id, 0 }, weights_id, bias_id,
+ conv_info, conv_method, fast_math_hint, num_groups);
+
+ // Remove convolution node
+ g.remove_node(node->id());
+
+ // Reconnect the original consumers to the new grouped output
+ for(auto &driving_node : driving_nodes)
+ {
+ g.add_connection(grouped_conv_id, 0, driving_node.node_id, driving_node.index);
+ }
+
+ // Move the extracted accessor to the new grouped output
+ g.node(grouped_conv_id)->output(0)->set_accessor(std::move(node_accessor));
+
+ // Configure new tensors and nodes
+ std::for_each(g.tensors().begin() + latest_tid, g.tensors().end(), [](std::unique_ptr<Tensor> &t)
+ {
+ configure_tensor(t.get());
+ });
+ std::for_each(g.nodes().begin() + latest_nid, g.nodes().end(), [&assigned_target](std::unique_ptr<INode> &n)
+ {
+ if(n != nullptr)
+ {
+ n->set_assigned_target(assigned_target);
+ }
+ });
+ }
+ }
+ }
+}
+} // namespace graph
+} // namespace arm_compute
diff --git a/src/graph/mutators/InPlaceOperationMutator.cpp b/src/graph/mutators/InPlaceOperationMutator.cpp
index bd3f098..31921b3 100644
--- a/src/graph/mutators/InPlaceOperationMutator.cpp
+++ b/src/graph/mutators/InPlaceOperationMutator.cpp
@@ -50,11 +50,26 @@
// Check if parent has a single output if yes then force in place calculation else not
if((input_edge != nullptr) && (input_edge->producer() != nullptr) && (input_edge->producer()->output_edges().size() == 1))
{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Switching to in-place computation for the node with ID : "
- << node->id() << " and name : " << node->name() << std::endl);
- // Update output
- auto tensor = input_edge->tensor();
- node->set_output_tensor(tensor->id(), 0);
+ // Get current and new output tensors
+ auto current_output_tensor = node->output(0);
+ auto new_output_tensor = input_edge->tensor();
+
+ ARM_COMPUTE_ERROR_ON(current_output_tensor == nullptr || new_output_tensor == nullptr);
+
+ // Prevent in-place operation if the input tensor already has an accessor (set_accessor below would replace it)
+ if(new_output_tensor->accessor() == nullptr)
+ {
+ ARM_COMPUTE_LOG_GRAPH_VERBOSE("Switching to in-place computation for the node with ID : "
+ << node->id() << " and name : " << node->name() << std::endl);
+ // Move the node's output accessor (if any) onto the tensor it now writes in place
+ new_output_tensor->set_accessor(current_output_tensor->extract_accessor());
+ // Update output
+ node->set_output_tensor(new_output_tensor->id(), 0);
+ }
+ else
+ {
+ ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented in-place operation as there is an accessor bound to the input tensor\n");
+ }
}
}
}
diff --git a/src/graph/mutators/NodeExecutionMethodMutator.cpp b/src/graph/mutators/NodeExecutionMethodMutator.cpp
new file mode 100644
index 0000000..b420121
--- /dev/null
+++ b/src/graph/mutators/NodeExecutionMethodMutator.cpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph/mutators/NodeExecutionMethodMutator.h"
+
+#include "arm_compute/graph/Graph.h"
+#include "arm_compute/graph/Logger.h"
+#include "arm_compute/graph/Utils.h"
+#include "arm_compute/graph/backends/BackendRegistry.h"
+#include "arm_compute/graph/nodes/Nodes.h"
+
+#include "arm_compute/core/utils/misc/Cast.h"
+
+namespace arm_compute
+{
+namespace graph
+{
+namespace
+{
+/** Validates every node of the given type on its assigned backend and runs a setter on each node that fails
+ *
+ * @tparam Setter Callable invoked as setter(INode *)
+ *
+ * @param[in, out] g Graph to extract the nodes from
+ * @param[in] node_type Type of the nodes to validate
+ * @param[in] setter Setter invoked on each node whose validation fails
+ */
+template <typename Setter>
+void set_default_on_invalid_method(Graph &g, NodeType node_type, Setter &&setter)
+{
+ const std::vector<NodeID> &node_ids = g.nodes(node_type); // NOTE(review): presumably refers into the graph; setter must not add/remove nodes of this type
+ for(auto &node_id : node_ids)
+ {
+ INode *node = g.node(node_id);
+ if(node != nullptr)
+ {
+ // Validate node
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(node->assigned_target());
+ Status status = backend.validate_node(*node);
+
+ // Set default execution method in case of failure
+ if(!bool(status))
+ {
+ setter(node);
+ }
+ }
+ }
+}
+} // namespace
+
+const char *NodeExecutionMethodMutator::name()
+{
+ return "NodeExecutionMethodMutator";
+}
+
+void NodeExecutionMethodMutator::mutate(Graph &g)
+{
+ // Convolution layers: fall back to ConvolutionMethod::Default when backend validation fails
+ set_default_on_invalid_method(g, NodeType::ConvolutionLayer, [](INode * n)
+ {
+ ARM_COMPUTE_LOG_GRAPH_INFO("Switched ConvolutionLayer method of node with ID : "
+ << n->id() << " and Name: " << n->name() << std::endl);
+ auto *casted_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(n);
+ casted_node->set_convolution_method(ConvolutionMethod::Default);
+ });
+
+ // Depthwise convolution layers: fall back to DepthwiseConvolutionMethod::Default when backend validation fails
+ set_default_on_invalid_method(g, NodeType::DepthwiseConvolutionLayer, [](INode * n)
+ {
+ ARM_COMPUTE_LOG_GRAPH_INFO("Switched Depthwise ConvolutionLayer method of node with ID : "
+ << n->id() << " and Name: " << n->name() << std::endl);
+ auto *casted_node = arm_compute::utils::cast::polymorphic_downcast<DepthwiseConvolutionLayerNode *>(n);
+ casted_node->set_depthwise_convolution_method(DepthwiseConvolutionMethod::Default);
+ });
+}
+} // namespace graph
+} // namespace arm_compute
diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp
index 2e893c2..82bfe25 100644
--- a/src/graph/mutators/NodeFusionMutator.cpp
+++ b/src/graph/mutators/NodeFusionMutator.cpp
@@ -25,10 +25,13 @@
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/Logger.h"
+#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/nodes/Nodes.h"
#include "arm_compute/core/utils/misc/Cast.h"
+#include <set>
+
namespace arm_compute
{
namespace graph
@@ -37,6 +40,9 @@
{
void fuse_batch_norm_with_activation(Graph &g)
{
+ // Only these activations are fused into batch normalization; any other activation keeps its own node
+ const std::set<Activation> supported_fused_activations = { Activation::RELU, Activation::BOUNDED_RELU, Activation::LU_BOUNDED_RELU };
+
// Not interested in the order of nodes
for(auto &node : g.nodes())
{
@@ -48,34 +54,47 @@
// Check if following node is an activation layer node
if((output_edge != nullptr) && (output_edge->consumer() != nullptr) && (output_edge->consumer()->type() == NodeType::ActivationLayer))
{
- ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing Batch Normalization node with ID : " << output_edge->producer_id()
- << " with Activation Layer node with ID : " << output_edge->consumer_id() << std::endl);
-
auto *bn_node = arm_compute::utils::cast::polymorphic_downcast<BatchNormalizationLayerNode *>(output_edge->producer());
auto *act_node = arm_compute::utils::cast::polymorphic_downcast<ActivationLayerNode *>(output_edge->consumer());
- // Get driving nodes of activation node
- std::vector<NodeIdxPair> act_driving_nodes;
- for(auto &act_output_edge_id : act_node->output_edges())
+ ARM_COMPUTE_ERROR_ON(act_node->output(0) == nullptr || bn_node->output(0) == nullptr);
+
+ // Check if activation is supported for fusion
+ if(supported_fused_activations.count(act_node->activation_info().activation()) == 0)
{
- auto act_output_edge = g.edge(act_output_edge_id);
- if(act_output_edge != nullptr)
- {
- ARM_COMPUTE_ERROR_ON(act_output_edge->consumer() == nullptr);
- act_driving_nodes.push_back({ act_output_edge->consumer_id(), act_output_edge->consumer_idx() });
- }
+ continue;
}
- // Set activation info to batch normalization
- bn_node->set_fused_activation(act_node->activation_info());
+ ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing Batch Normalization node with ID : " << output_edge->producer_id()
+ << " with Activation Layer node with ID : " << output_edge->consumer_id() << std::endl);
- // Remove activation node
- g.remove_node(act_node->id());
-
- // Update batch normalization node outputs
- for(auto &driving_node : act_driving_nodes)
+ // Prevent fusion if batch normalization node has an output accessor (NOTE(review): the "Fusing ..." log above is emitted even when fusion is prevented)
+ if(bn_node->output(0)->accessor() == nullptr)
{
- g.add_connection(bn_node->id(), 0, driving_node.node_id, driving_node.index);
+ // Get driving nodes of activation node
+ std::vector<NodeIdxPair> act_driving_nodes = get_driving_nodes(*act_node);
+
+ // Set activation info to batch normalization
+ bn_node->set_fused_activation(act_node->activation_info());
+
+ // Extract activation node accessor if any
+ auto act_node_accessor = act_node->output(0)->extract_accessor();
+
+ // Remove activation node
+ g.remove_node(act_node->id());
+
+ // Reconnect the activation's former consumers to the batch normalization output
+ for(auto &driving_node : act_driving_nodes)
+ {
+ g.add_connection(bn_node->id(), 0, driving_node.node_id, driving_node.index);
+ }
+
+ // Move the activation's output accessor (if any) to the batch normalization output
+ bn_node->output(0)->set_accessor(std::move(act_node_accessor));
+ }
+ else
+ {
+ ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion as batch normalization node has an output accessor\n");
}
}
}
diff --git a/src/graph/mutators/SplitLayerSubTensorMutator.cpp b/src/graph/mutators/SplitLayerSubTensorMutator.cpp
index 2a8c029..e21252a 100644
--- a/src/graph/mutators/SplitLayerSubTensorMutator.cpp
+++ b/src/graph/mutators/SplitLayerSubTensorMutator.cpp
@@ -25,6 +25,8 @@
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/Logger.h"
+#include "arm_compute/graph/Utils.h"
+#include "arm_compute/graph/algorithms/TopologicalSort.h"
#include "arm_compute/graph/backends/BackendRegistry.h"
#include "arm_compute/graph/nodes/SplitLayerNode.h"
@@ -42,10 +44,20 @@
void SplitLayerSubTensorMutator::mutate(Graph &g)
{
- // Should be in reverse order of execution
- for(auto &node : arm_compute::utils::iterable::reverse_iterate(g.nodes()))
+ // Early exit if no Split layers exist in graph
+ if(g.nodes(NodeType::SplitLayer).empty())
{
- if(node && node->type() == NodeType::SplitLayer && node->input(0) != nullptr)
+ return;
+ }
+
+ // Perform topological sort (dfs) to obtain the node execution order
+ std::vector<NodeID> topological_sorted_node_ids = dfs(g);
+
+ // Should be in reverse order of execution
+ for(auto &node_id : arm_compute::utils::iterable::reverse_iterate(topological_sorted_node_ids))
+ {
+ INode *node = g.node(node_id); // may be nullptr (checked below)
+ if(node != nullptr && node->type() == NodeType::SplitLayer && node->input(0) != nullptr)
{
// Get output tensor
Tensor *input_tensor = node->input(0);
@@ -58,12 +70,12 @@
});
// Create subtensors
- if(is_valid && backends::BackendRegistry::get().find_backend(input_tensor->desc().target) != nullptr)
+ if(is_valid && is_target_supported(input_tensor->desc().target))
{
ARM_COMPUTE_LOG_GRAPH_VERBOSE("Using sub-tensors for the node with ID : "
<< node->id() << " and name : " << node->name() << std::endl);
- auto *split_node = arm_compute::utils::cast::polymorphic_downcast<SplitLayerNode *>(node.get());
+ auto *split_node = arm_compute::utils::cast::polymorphic_downcast<SplitLayerNode *>(node);
const unsigned int axis = split_node->axis();
const unsigned int num_splits = split_node->num_splits();
@@ -77,8 +89,8 @@
Coordinates coords;
std::tie(std::ignore, coords) = SplitLayerNode::compute_output_descriptor(input_tensor->desc(), num_splits, axis, i);
- backends::IDeviceBackend *backend = backends::BackendRegistry::get().find_backend(output_tensor->desc().target);
- std::unique_ptr<ITensorHandle> handle = backend->create_subtensor(input_tensor->handle(), output_shape, coords, extend_parent);
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(output_tensor->desc().target);
+ std::unique_ptr<ITensorHandle> handle = backend.create_subtensor(input_tensor->handle(), output_shape, coords, extend_parent); // output i becomes a sub-tensor of the input at coords
output_tensor->set_handle(std::move(handle));
}
}
diff --git a/src/graph/nodes/ChannelShuffleLayerNode.cpp b/src/graph/nodes/ChannelShuffleLayerNode.cpp
new file mode 100644
index 0000000..08fcce1
--- /dev/null
+++ b/src/graph/nodes/ChannelShuffleLayerNode.cpp
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph/nodes/ChannelShuffleLayerNode.h"
+
+#include "arm_compute/graph/Graph.h"
+#include "arm_compute/graph/INodeVisitor.h"
+
+namespace arm_compute
+{
+namespace graph
+{
+ChannelShuffleLayerNode::ChannelShuffleLayerNode(unsigned int num_groups)
+ : _num_groups(num_groups)
+{
+ _input_edges.resize(1, EmptyEdgeID); // single input
+ _outputs.resize(1, NullTensorID); // single output
+}
+
+unsigned int ChannelShuffleLayerNode::num_groups() const
+{
+ return _num_groups;
+}
+
+bool ChannelShuffleLayerNode::forward_descriptors()
+{
+ if((input_id(0) != NullTensorID) && (output_id(0) != NullTensorID)) // NOTE(review): input_id() yields a NodeID compared to NullTensorID; relies on both sentinels having the same value
+ {
+ Tensor *dst = output(0);
+ ARM_COMPUTE_ERROR_ON(dst == nullptr);
+ dst->desc() = configure_output(0);
+ return true;
+ }
+ return false;
+}
+
+TensorDescriptor ChannelShuffleLayerNode::configure_output(size_t idx) const
+{
+ ARM_COMPUTE_UNUSED(idx);
+ ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
+
+ const Tensor *src = input(0);
+ ARM_COMPUTE_ERROR_ON(src == nullptr);
+
+ return src->desc(); // channel shuffle keeps the input descriptor unchanged
+}
+
+NodeType ChannelShuffleLayerNode::type() const
+{
+ return NodeType::ChannelShuffleLayer;
+}
+
+void ChannelShuffleLayerNode::accept(INodeVisitor &v)
+{
+ v.visit(*this);
+}
+} // namespace graph
+} // namespace arm_compute
diff --git a/src/graph/nodes/DepthConcatenateLayerNode.cpp b/src/graph/nodes/ConcatenateLayerNode.cpp
similarity index 61%
rename from src/graph/nodes/DepthConcatenateLayerNode.cpp
rename to src/graph/nodes/ConcatenateLayerNode.cpp
index 08cccc1..ade3f6e 100644
--- a/src/graph/nodes/DepthConcatenateLayerNode.cpp
+++ b/src/graph/nodes/ConcatenateLayerNode.cpp
@@ -21,58 +21,74 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#include "arm_compute/graph/nodes/DepthConcatenateLayerNode.h"
+#include "arm_compute/graph/nodes/ConcatenateLayerNode.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/INodeVisitor.h"
+#include "arm_compute/graph/Utils.h"
+
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
namespace arm_compute
{
namespace graph
{
-DepthConcatenateLayerNode::DepthConcatenateLayerNode(unsigned int total_nodes)
- : _total_nodes(total_nodes), _is_enabled(true)
+ConcatenateLayerNode::ConcatenateLayerNode(unsigned int total_nodes, DataLayoutDimension axis)
+ : _total_nodes(total_nodes), _axis(axis), _is_enabled(true) // enabled by default; sub-tensor mutators may disable it via set_enabled(false)
{
_input_edges.resize(_total_nodes, EmptyEdgeID);
_outputs.resize(1, NullTensorID);
}
-void DepthConcatenateLayerNode::set_enabled(bool is_enabled)
+void ConcatenateLayerNode::set_enabled(bool is_enabled)
{
_is_enabled = is_enabled;
}
-bool DepthConcatenateLayerNode::is_enabled() const
+bool ConcatenateLayerNode::is_enabled() const
{
return _is_enabled;
}
-TensorDescriptor DepthConcatenateLayerNode::compute_output_descriptor(const std::vector<TensorDescriptor> &input_descriptors)
+DataLayoutDimension ConcatenateLayerNode::concatenation_axis() const
+{
+ return _axis;
+}
+
+TensorDescriptor ConcatenateLayerNode::compute_output_descriptor(const std::vector<TensorDescriptor> &input_descriptors,
+ DataLayoutDimension axis)
{
ARM_COMPUTE_ERROR_ON(input_descriptors.size() == 0);
TensorDescriptor output_descriptor = input_descriptors[0];
+ const int axis_idx = get_dimension_idx(output_descriptor, axis); // shape index of the concatenation axis, resolved from the first input's descriptor
- size_t max_x = 0;
- size_t max_y = 0;
- size_t depth = 0;
-
- for(const auto &input_descriptor : input_descriptors)
+ // Collect pointers to each input shape (they point into input_descriptors and are only used within this call)
+ std::vector<const TensorShape *> shapes;
+ for(auto &input_descriptor : input_descriptors)
{
- max_x = std::max(input_descriptor.shape.x(), max_x);
- max_y = std::max(input_descriptor.shape.y(), max_y);
- depth += input_descriptor.shape.z();
+ shapes.emplace_back(&input_descriptor.shape);
}
- output_descriptor.shape.set(0, max_x);
- output_descriptor.shape.set(1, max_y);
- output_descriptor.shape.set(2, depth);
+ // Calculate output shape: only concatenation along shape index 0 or 2 is handled
+ if(axis_idx == 0)
+ {
+ output_descriptor.shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(shapes);
+ }
+ else if(axis_idx == 2)
+ {
+ output_descriptor.shape = arm_compute::misc::shape_calculator::calculate_depth_concatenate_shape(shapes);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Unsupported concatenation axis!");
+ }
return output_descriptor;
}
-bool DepthConcatenateLayerNode::forward_descriptors()
+bool ConcatenateLayerNode::forward_descriptors()
{
if(_outputs[0] != NullTensorID)
{
@@ -84,7 +100,7 @@
return false;
}
-TensorDescriptor DepthConcatenateLayerNode::configure_output(size_t idx) const
+TensorDescriptor ConcatenateLayerNode::configure_output(size_t idx) const
{
ARM_COMPUTE_UNUSED(idx);
ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
@@ -106,18 +122,18 @@
ARM_COMPUTE_ERROR_ON(t == nullptr);
inputs_descriptors.push_back(t->desc());
}
- output_info = compute_output_descriptor(inputs_descriptors);
+ output_info = compute_output_descriptor(inputs_descriptors, _axis); // concatenate along the axis given at construction
}
return output_info;
}
-NodeType DepthConcatenateLayerNode::type() const
+NodeType ConcatenateLayerNode::type() const
{
- return NodeType::DepthConcatenateLayer;
+ return NodeType::ConcatenateLayer;
}
-void DepthConcatenateLayerNode::accept(INodeVisitor &v)
+void ConcatenateLayerNode::accept(INodeVisitor &v)
{
v.visit(*this);
}
diff --git a/src/graph/nodes/ConvolutionLayerNode.cpp b/src/graph/nodes/ConvolutionLayerNode.cpp
index 6c31a6b..e9cb039 100644
--- a/src/graph/nodes/ConvolutionLayerNode.cpp
+++ b/src/graph/nodes/ConvolutionLayerNode.cpp
@@ -32,8 +32,12 @@
{
namespace graph
{
-ConvolutionLayerNode::ConvolutionLayerNode(PadStrideInfo info, ConvolutionMethod method, FastMathHint fast_math_hint, QuantizationInfo out_quant_info)
- : _info(std::move(info)), _method(method), _fast_math_hint(fast_math_hint), _out_quant_info(out_quant_info)
+ConvolutionLayerNode::ConvolutionLayerNode(PadStrideInfo info,
+ unsigned int num_groups, // 1 = ordinary convolution; other values are treated as grouped by GroupedConvolutionMutator
+ ConvolutionMethod method,
+ FastMathHint fast_math_hint,
+ QuantizationInfo out_quant_info)
+ : _info(std::move(info)), _num_groups(num_groups), _method(method), _fast_math_hint(fast_math_hint), _out_quant_info(out_quant_info)
{
_input_edges.resize(3, EmptyEdgeID);
_outputs.resize(1, NullTensorID);
@@ -64,6 +68,11 @@
return _info;
}
+unsigned int ConvolutionLayerNode::num_groups() const
+{
+ return _num_groups; // number of convolution groups given at construction
+}
+
TensorDescriptor ConvolutionLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor,
const TensorDescriptor &weights_descriptor,
const PadStrideInfo &info)
@@ -125,4 +134,4 @@
v.visit(*this);
}
} // namespace graph
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/graph/nodes/DeconvolutionLayerNode.cpp b/src/graph/nodes/DeconvolutionLayerNode.cpp
new file mode 100644
index 0000000..9329ae3
--- /dev/null
+++ b/src/graph/nodes/DeconvolutionLayerNode.cpp
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph/nodes/DeconvolutionLayerNode.h"
+
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/graph/Graph.h"
+#include "arm_compute/graph/INodeVisitor.h"
+#include "arm_compute/graph/Utils.h"
+
+#include <tuple>
+#include <utility>
+
+namespace arm_compute
+{
+namespace graph
+{
+// Deconvolution node: input 0 is the source, input 1 the weights, input 2 presumably the optional bias (TODO confirm).
+DeconvolutionLayerNode::DeconvolutionLayerNode(PadStrideInfo info, Size2D inner_border)
+ : _info(std::move(info)), _inner_border(inner_border)
+{
+ _input_edges.resize(3, EmptyEdgeID);
+ _outputs.resize(1, NullTensorID);
+}
+
+PadStrideInfo DeconvolutionLayerNode::deconvolution_info() const
+{
+ return _info;
+}
+
+Size2D DeconvolutionLayerNode::inner_border() const
+{
+ return _inner_border;
+}
+
+// The output copies the input descriptor, takes width/height from deconvolution_output_dimensions()
+// and its channel count from dimension 3 of the weights shape.
+TensorDescriptor DeconvolutionLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor,
+ const TensorDescriptor &weights_descriptor,
+ const PadStrideInfo &info,
+ const Size2D &inner_border)
+{
+ unsigned int output_width = 0;
+ unsigned int output_height = 0;
+
+ const unsigned int input_width = get_dimension_size(input_descriptor, DataLayoutDimension::WIDTH);
+ const unsigned int input_height = get_dimension_size(input_descriptor, DataLayoutDimension::HEIGHT);
+ const unsigned int kernel_width = get_dimension_size(weights_descriptor, DataLayoutDimension::WIDTH);
+ const unsigned int kernel_height = get_dimension_size(weights_descriptor, DataLayoutDimension::HEIGHT);
+
+ std::tie(output_width, output_height) = deconvolution_output_dimensions(input_width, input_height,
+ kernel_width, kernel_height,
+ info.pad().first, info.pad().second,
+ inner_border.x(), inner_border.y(),
+ info.stride().first, info.stride().second);
+
+ TensorDescriptor output_descriptor = input_descriptor;
+ output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), output_width);
+ output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), output_height);
+ output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::CHANNEL), weights_descriptor.shape[3]);
+
+ return output_descriptor;
+}
+
+// Propagates the output descriptor once the source, weights and output tensors are all connected.
+bool DeconvolutionLayerNode::forward_descriptors()
+{
+ if((input_id(0) != NullTensorID) && (input_id(1) != NullTensorID) && (output_id(0) != NullTensorID))
+ {
+ Tensor *dst = output(0);
+ ARM_COMPUTE_ERROR_ON(dst == nullptr);
+ dst->desc() = configure_output(0);
+ return true;
+ }
+ return false;
+}
+
+TensorDescriptor DeconvolutionLayerNode::configure_output(size_t idx) const
+{
+ ARM_COMPUTE_UNUSED(idx);
+ ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
+
+ const Tensor *src = input(0);
+ const Tensor *weights = input(1);
+
+ ARM_COMPUTE_ERROR_ON(src == nullptr || weights == nullptr);
+
+ TensorDescriptor output_info = compute_output_descriptor(src->desc(), weights->desc(), _info, _inner_border);
+ return output_info;
+}
+
+NodeType DeconvolutionLayerNode::type() const
+{
+ return NodeType::DeconvolutionLayer;
+}
+
+void DeconvolutionLayerNode::accept(INodeVisitor &v)
+{
+ v.visit(*this);
+}
+} // namespace graph
+} // namespace arm_compute
\ No newline at end of file
diff --git a/src/graph/nodes/DummyNode.cpp b/src/graph/nodes/DummyNode.cpp
new file mode 100644
index 0000000..e641181
--- /dev/null
+++ b/src/graph/nodes/DummyNode.cpp
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph/nodes/DummyNode.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/graph/Graph.h"
+#include "arm_compute/graph/INodeVisitor.h"
+#include "arm_compute/graph/Tensor.h"
+
+namespace arm_compute
+{
+namespace graph
+{
+// Placeholder node with one input and one output: the output copies the input descriptor
+// except for its shape, which is replaced by the shape given at construction.
+DummyNode::DummyNode(TensorShape shape)
+ : _shape(shape)
+{
+ _input_edges.resize(1, EmptyEdgeID);
+ _outputs.resize(1, NullTensorID);
+}
+
+bool DummyNode::forward_descriptors()
+{
+ // A descriptor can only be pushed forward once both ends are connected
+ const bool is_connected = (input_id(0) != NullTensorID) && (output_id(0) != NullTensorID);
+ if(!is_connected)
+ {
+ return false;
+ }
+
+ Tensor *dst = output(0);
+ ARM_COMPUTE_ERROR_ON(dst == nullptr);
+ dst->desc() = configure_output(0);
+ return true;
+}
+
+TensorDescriptor DummyNode::configure_output(size_t idx) const
+{
+ ARM_COMPUTE_UNUSED(idx);
+ ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
+
+ const Tensor *src = input(0);
+ ARM_COMPUTE_ERROR_ON(src == nullptr);
+
+ TensorDescriptor desc = src->desc();
+ desc.shape = _shape;
+ return desc;
+}
+
+NodeType DummyNode::type() const
+{
+ return NodeType::Dummy;
+}
+
+void DummyNode::accept(INodeVisitor &v)
+{
+ v.visit(*this);
+}
+} // namespace graph
+} // namespace arm_compute
\ No newline at end of file
diff --git a/src/graph/nodes/FullyConnectedLayer.cpp b/src/graph/nodes/FullyConnectedLayer.cpp
index d94a785..6ea0292 100644
--- a/src/graph/nodes/FullyConnectedLayer.cpp
+++ b/src/graph/nodes/FullyConnectedLayer.cpp
@@ -31,15 +31,17 @@
{
namespace graph
{
-FullyConnectedLayerNode::FullyConnectedLayerNode(unsigned int num_outputs)
- : _num_outputs(num_outputs)
+FullyConnectedLayerNode::FullyConnectedLayerNode(unsigned int num_outputs, QuantizationInfo out_quant_info, FullyConnectedLayerInfo fc_info) // output quantization and FC layer info are now stored on the node
+ : _num_outputs(num_outputs), _out_quant_info(out_quant_info), _info(fc_info)
{
_input_edges.resize(3, EmptyEdgeID);
_outputs.resize(1, NullTensorID);
}
TensorDescriptor FullyConnectedLayerNode::compute_weights_descriptor(const TensorDescriptor &input_descriptor,
- unsigned int num_outputs)
+ unsigned int num_outputs,
+ FullyConnectedLayerInfo fc_info,
+ QuantizationInfo weights_quant_info)
{
unsigned int num_weights = 1;
unsigned int num_dimensions = input_descriptor.shape.num_dimensions();
@@ -56,11 +58,24 @@
TensorDescriptor weights_descriptor = input_descriptor;
weights_descriptor.shape = TensorShape(num_weights, num_outputs);
+ // If the weights must not be transposed (presumably they are already transposed), use the transposed (num_outputs, num_weights) shape
+ if(!fc_info.transpose_weights)
+ {
+ weights_descriptor.shape = TensorShape(num_outputs, num_weights);
+ }
+
+ // Set quantization info if present
+ if(!weights_quant_info.empty())
+ {
+ weights_descriptor.quant_info = weights_quant_info;
+ }
+
return weights_descriptor;
}
TensorDescriptor FullyConnectedLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor,
- unsigned int num_outputs)
+ unsigned int num_outputs,
+ QuantizationInfo out_quant_info)
{
// Note: Only 1D batch space is supported at the moment
unsigned int batches = input_descriptor.shape[1];
@@ -69,12 +84,24 @@
batches = input_descriptor.shape[3];
}
+ // Set descriptor shape
TensorDescriptor output_descriptor = input_descriptor;
output_descriptor.shape = TensorShape(num_outputs, batches);
+ // Set quantization info if present
+ if(!out_quant_info.empty())
+ {
+ output_descriptor.quant_info = out_quant_info;
+ }
+
return output_descriptor;
}
+FullyConnectedLayerInfo FullyConnectedLayerNode::info() const // returns (by value) the FullyConnectedLayerInfo given at construction
+{
+ return _info;
+}
+
bool FullyConnectedLayerNode::forward_descriptors()
{
if((input_id(0) != NullTensorID) && (output_id(0) != NullTensorID))
@@ -93,7 +120,7 @@
const Tensor *src = input(0);
ARM_COMPUTE_ERROR_ON(src == nullptr);
- return compute_output_descriptor(src->desc(), _num_outputs);
+ return compute_output_descriptor(src->desc(), _num_outputs, _out_quant_info);
}
NodeType FullyConnectedLayerNode::type() const
diff --git a/src/graph/nodes/PermuteLayerNode.cpp b/src/graph/nodes/PermuteLayerNode.cpp
new file mode 100644
index 0000000..042ec09
--- /dev/null
+++ b/src/graph/nodes/PermuteLayerNode.cpp
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph/nodes/PermuteLayerNode.h"
+
+#include "arm_compute/graph/Graph.h"
+#include "arm_compute/graph/INodeVisitor.h"
+
+#include "arm_compute/core/Helpers.h"
+
+namespace arm_compute
+{
+namespace graph
+{
+// Node that reorders the dimensions of its single input according to a permutation vector,
+// optionally relabelling the output with a different data layout.
+PermuteLayerNode::PermuteLayerNode(PermutationVector perm, DataLayout layout)
+ : _perm(perm), _layout(layout)
+{
+ _input_edges.resize(1, EmptyEdgeID);
+ _outputs.resize(1, NullTensorID);
+}
+
+const PermutationVector &PermuteLayerNode::permutation_vector() const
+{
+ return _perm;
+}
+
+bool PermuteLayerNode::forward_descriptors()
+{
+ // Nothing can be propagated until both the input and the output tensors exist
+ const bool is_connected = (input_id(0) != NullTensorID) && (output_id(0) != NullTensorID);
+ if(!is_connected)
+ {
+ return false;
+ }
+
+ Tensor *dst = output(0);
+ ARM_COMPUTE_ERROR_ON(dst == nullptr);
+ dst->desc() = configure_output(0);
+ return true;
+}
+
+TensorDescriptor PermuteLayerNode::configure_output(size_t idx) const
+{
+ ARM_COMPUTE_UNUSED(idx);
+ ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
+
+ const Tensor *src = input(0);
+ ARM_COMPUTE_ERROR_ON(src == nullptr);
+
+ TensorDescriptor desc = src->desc();
+ permute(desc.shape, _perm);
+
+ // DataLayout::UNKNOWN means "keep the input's layout"
+ const bool relabel_layout = (_layout != DataLayout::UNKNOWN);
+ if(relabel_layout)
+ {
+ desc.layout = _layout;
+ }
+ return desc;
+}
+
+NodeType PermuteLayerNode::type() const
+{
+ return NodeType::PermuteLayer;
+}
+
+void PermuteLayerNode::accept(INodeVisitor &v)
+{
+ v.visit(*this);
+}
+} // namespace graph
+} // namespace arm_compute
\ No newline at end of file
diff --git a/src/graph/nodes/ResizeLayerNode.cpp b/src/graph/nodes/ResizeLayerNode.cpp
new file mode 100644
index 0000000..a6aa7bf
--- /dev/null
+++ b/src/graph/nodes/ResizeLayerNode.cpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph/nodes/ResizeLayerNode.h"
+
+#include "arm_compute/graph/Graph.h"
+#include "arm_compute/graph/INodeVisitor.h"
+#include "arm_compute/graph/Utils.h"
+
+namespace arm_compute
+{
+namespace graph
+{
+// Node that rescales the width and height of its single input by fixed factors, using the given interpolation policy.
+ResizeLayerNode::ResizeLayerNode(InterpolationPolicy policy, float scale_width, float scale_height)
+ : _policy(policy), _scale_width(scale_width), _scale_height(scale_height)
+{
+ _input_edges.resize(1, EmptyEdgeID);
+ _outputs.resize(1, NullTensorID);
+}
+
+InterpolationPolicy ResizeLayerNode::policy() const
+{
+ return _policy;
+}
+
+std::pair<float, float> ResizeLayerNode::scaling_factor() const
+{
+ // (width factor, height factor)
+ return std::make_pair(_scale_width, _scale_height);
+}
+
+bool ResizeLayerNode::forward_descriptors()
+{
+ // Nothing can be propagated until both the input and the output tensors exist
+ const bool is_connected = (input_id(0) != NullTensorID) && (output_id(0) != NullTensorID);
+ if(!is_connected)
+ {
+ return false;
+ }
+
+ Tensor *dst = output(0);
+ ARM_COMPUTE_ERROR_ON(dst == nullptr);
+ dst->desc() = configure_output(0);
+ return true;
+}
+
+TensorDescriptor ResizeLayerNode::configure_output(size_t idx) const
+{
+ ARM_COMPUTE_UNUSED(idx);
+ ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
+
+ const Tensor *src = input(0);
+ ARM_COMPUTE_ERROR_ON(src == nullptr);
+
+ TensorDescriptor desc = src->desc();
+ const size_t w_idx = get_dimension_idx(desc, DataLayoutDimension::WIDTH);
+ const size_t h_idx = get_dimension_idx(desc, DataLayoutDimension::HEIGHT);
+
+ // The scaled extents are truncated towards zero by the conversion to int
+ desc.shape.set(w_idx, static_cast<int>(desc.shape[w_idx] * _scale_width));
+ desc.shape.set(h_idx, static_cast<int>(desc.shape[h_idx] * _scale_height));
+
+ return desc;
+}
+
+NodeType ResizeLayerNode::type() const
+{
+ return NodeType::ResizeLayer;
+}
+
+void ResizeLayerNode::accept(INodeVisitor &v)
+{
+ v.visit(*this);
+}
+} // namespace graph
+} // namespace arm_compute
\ No newline at end of file
diff --git a/src/graph/printers/DotGraphPrinter.cpp b/src/graph/printers/DotGraphPrinter.cpp
index 61cf423..ef156ea 100644
--- a/src/graph/printers/DotGraphPrinter.cpp
+++ b/src/graph/printers/DotGraphPrinter.cpp
@@ -47,6 +47,15 @@
_info = ss.str();
}
+void DotGraphVisitor::visit(ConcatenateLayerNode &n) // replaces the DepthConcatenateLayerNode overload removed below and also reports the axis
+{
+ std::stringstream ss;
+ ss << "Enabled: " << n.is_enabled();
+ ss << R"( \n )"; // raw string: emits a literal backslash-n, used as the label line separator
+ ss << "Axis: " << n.concatenation_axis();
+ _info = ss.str();
+}
+
void DotGraphVisitor::visit(ConvolutionLayerNode &n)
{
std::stringstream ss;
@@ -54,13 +63,6 @@
_info = ss.str();
}
-void DotGraphVisitor::visit(DepthConcatenateLayerNode &n)
-{
- std::stringstream ss;
- ss << "Enabled: " << n.is_enabled();
- _info = ss.str();
-}
-
void DotGraphVisitor::visit(DepthwiseConvolutionLayerNode &n)
{
std::stringstream ss;
diff --git a/src/runtime/CL/CLMemory.cpp b/src/runtime/CL/CLMemory.cpp
index 534c4f9..bbc513d 100644
--- a/src/runtime/CL/CLMemory.cpp
+++ b/src/runtime/CL/CLMemory.cpp
@@ -61,7 +61,7 @@
void CLMemory::create_empty_region()
{
- _region_owned = std::make_shared<CLBufferMemoryRegion>(cl::Context::getDefault(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, 0);
+ _region_owned = std::make_shared<CLBufferMemoryRegion>(cl::Context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, 0); // zero-sized placeholder built on a default-constructed cl::Context rather than cl::Context::getDefault() (presumably to avoid creating a default context here - confirm)
_region = _region_owned.get();
}
} // namespace arm_compute
\ No newline at end of file
diff --git a/src/runtime/CL/CLScheduler.cpp b/src/runtime/CL/CLScheduler.cpp
index fdae615..a311c6f 100644
--- a/src/runtime/CL/CLScheduler.cpp
+++ b/src/runtime/CL/CLScheduler.cpp
@@ -25,13 +25,24 @@
#include "arm_compute/core/CL/ICLKernel.h"
#include "arm_compute/runtime/CL/CLTuner.h"
+#include "arm_compute/runtime/CL/tuners/Tuners.h"
using namespace arm_compute;
+namespace
+{
+#if defined(ARM_COMPUTE_DEBUG_ENABLED)
+void printf_callback(const char *buffer, unsigned int len, size_t /*complete*/, void * /*user_data*/) // forwards cl_arm_printf output to stdout
+{
+ printf("%.*s", static_cast<int>(len), buffer); // the '.*' precision argument must be an int
+}
+#endif /* defined(ARM_COMPUTE_DEBUG_ENABLED) */
+} // namespace
+
std::once_flag CLScheduler::_initialize_symbols;
CLScheduler::CLScheduler()
- : _queue(), _target(GPUTarget::MIDGARD), _is_initialised(false), _cl_tuner()
+ : _context(), _queue(), _target(GPUTarget::MIDGARD), _is_initialised(false), _cl_tuner(nullptr), _cl_default_static_tuner(nullptr)
{
}
@@ -42,6 +53,61 @@
return scheduler;
}
+void CLScheduler::default_init(ICLTuner *cl_tuner) // first call: context, queue and kernel library on the first platform's first default device; every call: selects the tuner
+{
+ if(!_is_initialised)
+ {
+ std::vector<cl::Platform> platforms;
+ cl::Platform::get(&platforms);
+ if(platforms.empty()) { ARM_COMPUTE_ERROR("Couldn't find any OpenCL platform"); } // checked in every build type: platforms[0] below must not run on an empty list
+ cl::Platform p = platforms[0];
+ cl::Context ctx;
+ cl::Device device;
+ std::vector<cl::Device> platform_devices;
+ p.getDevices(CL_DEVICE_TYPE_DEFAULT, &platform_devices);
+ if(platform_devices.empty()) { ARM_COMPUTE_ERROR("Couldn't find any OpenCL device"); } // same reason: platform_devices[0] below
+ device = platform_devices[0];
+#if defined(ARM_COMPUTE_DEBUG_ENABLED)
+
+ // Query devices in the context for cl_arm_printf support
+ if(device_supports_extension(device, "cl_arm_printf"))
+ {
+ // Create a cl_context with a printf_callback and user specified buffer size.
+ cl_context_properties properties[] =
+ {
+ CL_CONTEXT_PLATFORM, reinterpret_cast<cl_context_properties>(p()),
+ // Enable a printf callback function for this context.
+ CL_PRINTF_CALLBACK_ARM, reinterpret_cast<cl_context_properties>(printf_callback),
+ // Request a printf buffer of size 0x1000 for devices in the context that support this extension.
+ // NOTE(review): this was described as 4MB, but 0x1000 is 4096 - confirm the unit CL_PRINTF_BUFFERSIZE_ARM expects.
+ CL_PRINTF_BUFFERSIZE_ARM, 0x1000,
+ 0
+ };
+ ctx = cl::Context(device, properties);
+ }
+ else
+#endif // defined(ARM_COMPUTE_DEBUG_ENABLED)
+ {
+ cl_context_properties properties[] =
+ {
+ CL_CONTEXT_PLATFORM, reinterpret_cast<cl_context_properties>(p()),
+ 0
+ };
+ ctx = cl::Context(device, properties);
+ }
+
+ cl::CommandQueue queue = cl::CommandQueue(ctx, device);
+ CLKernelLibrary::get().init("./cl_kernels/", ctx, device);
+ init(ctx, queue, device, cl_tuner);
+
+ // Create a default static tuner and set if none was provided
+ _cl_default_static_tuner = tuners::TunerFactory::create_tuner(_target);
+ }
+
+ // Set CL tuner
+ _cl_tuner = (cl_tuner == nullptr) ? _cl_default_static_tuner.get() : cl_tuner;
+}
+
void CLScheduler::enqueue(ICLKernel &kernel, bool flush)
{
ARM_COMPUTE_ERROR_ON_MSG(!_is_initialised,
diff --git a/src/runtime/CL/CLTensorAllocator.cpp b/src/runtime/CL/CLTensorAllocator.cpp
index 54e7c5b..dd716f7 100644
--- a/src/runtime/CL/CLTensorAllocator.cpp
+++ b/src/runtime/CL/CLTensorAllocator.cpp
@@ -74,8 +74,16 @@
if(_associated_memory_group == nullptr)
{
- ARM_COMPUTE_ERROR_ON(_memory.region()->cl_data().get() != nullptr);
- _memory = CLMemory(allocate_region(CLScheduler::get().context(), info().total_size(), 0));
+ if(_memory.region()->cl_data().get() != nullptr) // the region already holds CL data from an earlier allocation
+ {
+ // Memory is already allocated: reuse it if it is big enough, otherwise fire an assertion
+ ARM_COMPUTE_ERROR_ON_MSG(info().total_size() > _memory.region()->size(), "Reallocation of a bigger memory region is not allowed!"); // NOTE(review): ERROR_ON_* is presumably assert-only, so an oversized reuse may go unchecked in release builds - confirm
+ }
+ else
+ {
+ // Nothing allocated yet: allocate a region of info().total_size() from the scheduler's context
+ _memory = CLMemory(allocate_region(CLScheduler::get().context(), info().total_size(), 0));
+ }
}
else
{
diff --git a/src/runtime/CL/functions/CLArithmeticDivision.cpp b/src/runtime/CL/functions/CLArithmeticDivision.cpp
new file mode 100644
index 0000000..1c2849c
--- /dev/null
+++ b/src/runtime/CL/functions/CLArithmeticDivision.cpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/CL/functions/CLArithmeticDivision.h"
+
+#include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/CL/kernels/CLArithmeticDivisionKernel.h"
+#include "support/ToolchainSupport.h"
+
+#include <utility>
+
+using namespace arm_compute;
+
+// Configures the element-wise division kernel. When the output has more than one element along X
+// and an operand has exactly one, that operand also gets a REPLICATE border handler.
+void CLArithmeticDivision::configure(ICLTensor *input1, ICLTensor *input2, ICLTensor *output)
+{
+ auto division_kernel = arm_compute::support::cpp14::make_unique<CLArithmeticDivisionKernel>();
+ division_kernel->configure(input1, input2, output);
+ _kernel = std::move(division_kernel);
+
+ // Nothing to broadcast when the output itself is at most one element wide
+ if(output->info()->dimension(0) <= 1)
+ {
+ return;
+ }
+
+ // input1 is checked first as the broadcast candidate, otherwise input2 is used
+ ICLTensor *const candidate = (input1->info()->dimension(0) == 1) ? input1 : input2;
+ if(candidate->info()->dimension(0) == 1)
+ {
+ _border_handler.configure(candidate, _kernel->border_size(), BorderMode::REPLICATE);
+ }
+}
+
+// Validation is delegated entirely to the kernel
+Status CLArithmeticDivision::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
+{
+ return CLArithmeticDivisionKernel::validate(input1, input2, output);
+}
index 5acb8e7..84e8709 100644
--- a/src/runtime/CL/functions/CLCannyEdge.cpp
+++ b/src/runtime/CL/functions/CLCannyEdge.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -50,16 +50,22 @@
_visited(),
_recorded(),
_l1_list_counter(),
- _l1_stack()
+ _l1_stack(),
+ _output(nullptr)
{
}
-void CLCannyEdge::configure(ICLTensor *input, ICLTensor *output, int32_t upper_thr, int32_t lower_thr, int32_t gradient_size, int32_t norm_type, BorderMode border_mode, uint8_t constant_border_value)
+void CLCannyEdge::configure(ICLTensor *input, ICLTensor *output, int32_t upper_thr, int32_t lower_thr, int32_t gradient_size, int32_t norm_type, BorderMode border_mode,
+ uint8_t constant_border_value)
{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
ARM_COMPUTE_ERROR_ON((1 != norm_type) && (2 != norm_type));
- ARM_COMPUTE_ERROR_ON(lower_thr > upper_thr);
+ ARM_COMPUTE_ERROR_ON((gradient_size != 3) && (gradient_size != 5) && (gradient_size != 7));
+ ARM_COMPUTE_ERROR_ON((lower_thr < 0) || (lower_thr >= upper_thr));
+
+ _output = output;
const unsigned int L1_hysteresis_stack_size = 8;
const TensorShape shape = input->info()->tensor_shape();
@@ -122,7 +128,7 @@
}
else
{
- ARM_COMPUTE_ERROR("Gradient %d size not supported", gradient_size);
+ ARM_COMPUTE_ERROR("Gradient size %d not supported", gradient_size);
}
// Manage intermediate buffers
@@ -187,6 +193,7 @@
CLScheduler::get().enqueue(_non_max_suppr, false);
// Clear temporary structures and run edge trace
+ _output->clear(CLScheduler::get().queue());
_visited.clear(CLScheduler::get().queue());
_recorded.clear(CLScheduler::get().queue());
_l1_list_counter.clear(CLScheduler::get().queue());
diff --git a/src/runtime/CL/functions/CLConcatenateLayer.cpp b/src/runtime/CL/functions/CLConcatenateLayer.cpp
new file mode 100644
index 0000000..018c674
--- /dev/null
+++ b/src/runtime/CL/functions/CLConcatenateLayer.cpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/CL/functions/CLConcatenateLayer.h"
+
+#include "arm_compute/runtime/CL/functions/CLDepthConcatenateLayer.h"
+#include "arm_compute/runtime/CL/functions/CLWidthConcatenateLayer.h"
+
+#include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "support/ToolchainSupport.h"
+
+namespace arm_compute
+{
+CLConcatenateLayer::CLConcatenateLayer()
+ : _concat_function(nullptr)
+{
+}
+
+// Width (layout index 0) and depth (layout index 2) concatenation are delegated to the
+// dedicated functions; any other position of `axis` in the output layout is rejected.
+void CLConcatenateLayer::configure(const std::vector<ICLTensor *> &inputs_vector, ICLTensor *output, DataLayoutDimension axis)
+{
+ ARM_COMPUTE_ERROR_ON(output == nullptr);
+
+ const auto axis_idx = get_data_layout_dimension_index(output->info()->data_layout(), axis);
+ if(axis_idx == 0)
+ {
+ auto width_concat = support::cpp14::make_unique<CLWidthConcatenateLayer>();
+ width_concat->configure(inputs_vector, output);
+ _concat_function = std::move(width_concat);
+ }
+ else if(axis_idx == 2)
+ {
+ auto depth_concat = support::cpp14::make_unique<CLDepthConcatenateLayer>();
+ depth_concat->configure(inputs_vector, output);
+ _concat_function = std::move(depth_concat);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Concatenation is supported across width and depth only!");
+ }
+}
+
+// Mirrors configure(): validates against whichever function the axis would select.
+Status CLConcatenateLayer::validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output, DataLayoutDimension axis)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON(output == nullptr);
+
+ const auto axis_idx = get_data_layout_dimension_index(output->data_layout(), axis);
+ if(axis_idx == 0)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenateLayer::validate(inputs_vector, output));
+ }
+ else if(axis_idx == 2)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(CLDepthConcatenateLayer::validate(inputs_vector, output));
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_MSG("Concatenation is supported across width and depth only!");
+ }
+ return Status{};
+}
+
+void CLConcatenateLayer::run()
+{
+ ARM_COMPUTE_ERROR_ON(_concat_function == nullptr);
+ _concat_function->run();
+}
+} // namespace arm_compute
diff --git a/src/runtime/CL/functions/CLConvolutionLayer.cpp b/src/runtime/CL/functions/CLConvolutionLayer.cpp
index 47a8d5f..0014e71 100644
--- a/src/runtime/CL/functions/CLConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLConvolutionLayer.cpp
@@ -43,17 +43,18 @@
}
void CLConvolutionLayer::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info,
- const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math)
+ const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
ARM_COMPUTE_ERROR_THROW_ON(CLConvolutionLayer::validate(input->info(), weights->info(), ((biases != nullptr) ? biases->info() : nullptr), output->info(), conv_info, weights_info, dilation, act_info,
- enable_fast_math));
+ enable_fast_math, num_groups));
switch(CLConvolutionLayer::get_convolution_method(input->info(), weights->info(), output->info(), conv_info,
weights_info, act_info, CLScheduler::get().target(), dilation, enable_fast_math))
{
case ConvolutionMethod::WINOGRAD:
{
+ ARM_COMPUTE_ERROR_ON(num_groups != 1); // grouping is only forwarded to the GEMM path; validate() rejects it for Winograd
auto f = arm_compute::support::cpp14::make_unique<CLWinogradConvolutionLayer>(_memory_manager);
f->configure(input, weights, biases, output, conv_info, act_info, enable_fast_math);
_function = std::move(f);
@@ -61,6 +62,7 @@
}
case ConvolutionMethod::DIRECT:
{
+ ARM_COMPUTE_ERROR_ON(num_groups != 1); // grouping is only forwarded to the GEMM path; validate() rejects it for direct convolution
auto f = arm_compute::support::cpp14::make_unique<CLDirectConvolutionLayer>();
f->configure(input, weights, biases, output, conv_info, act_info);
_function = std::move(f);
@@ -69,7 +71,7 @@
case ConvolutionMethod::GEMM:
{
auto f = arm_compute::support::cpp14::make_unique<CLGEMMConvolutionLayer>(_memory_manager);
- f->configure(input, weights, biases, output, conv_info, weights_info, dilation, act_info);
+ f->configure(input, weights, biases, output, conv_info, weights_info, dilation, act_info, num_groups);
_function = std::move(f);
break;
}
@@ -80,9 +82,10 @@
}
Status CLConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
- const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math)
+ const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((num_groups != 1) && (input->data_layout() != DataLayout::NCHW), "Grouping (num_groups != 1) with NHWC data layout is not supported"); // NOTE: the condition rejects every non-NCHW layout, not only NHWC
const GPUTarget gpu_target = CLScheduler::get().target();
@@ -91,19 +94,21 @@
case ConvolutionMethod::WINOGRAD:
{
//Validate Winograd
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "Grouping (num_groups != 1) with CLWinogradConvolutionLayer is not supported");
ARM_COMPUTE_RETURN_ON_ERROR(CLWinogradConvolutionLayer::validate(input, weights, biases, output, conv_info, act_info, enable_fast_math));
break;
}
case ConvolutionMethod::DIRECT:
{
// Validate direct convolution layer
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "Grouping (num_groups != 1) with CLDirectConvolutionLayer is not supported");
ARM_COMPUTE_RETURN_ON_ERROR(CLDirectConvolutionLayer::validate(input, weights, biases, output, conv_info, act_info));
break;
}
case ConvolutionMethod::GEMM:
{
// Validate gemm-based convolution layer
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, weights_info, dilation, act_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, weights_info, dilation, act_info, num_groups));
break;
}
default:
@@ -123,8 +128,47 @@
ARM_COMPUTE_UNUSED(weights_info);
ARM_COMPUTE_UNUSED(gpu_target);
+ const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
const size_t idx_c = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
+ /* Input spatial dims, kernel size, IFM/OFM, conv info, data layout */
+ using ConvolutionConfiguration = std::tuple<Size2D, Size2D, Size2D, PadStrideInfo, DataLayout>;
+ using ConfigurationMethod = std::pair<ConvolutionConfiguration, ConvolutionMethod>;
+
+ const std::vector<ConfigurationMethod> known_configs =
+ {
+ // AlexNet (NCHW)
+ ConfigurationMethod(ConvolutionConfiguration(Size2D(27U, 27U), Size2D(5U, 5U), Size2D(48U, 128U), PadStrideInfo(1U, 1U, 2U, 2U), DataLayout::NCHW), ConvolutionMethod::DIRECT),
+ // VGG16 / VGG19 (NCHW)
+ ConfigurationMethod(ConvolutionConfiguration(Size2D(224U, 224U), Size2D(3U, 3U), Size2D(3U, 64U), PadStrideInfo(1U, 1U, 1U, 1U), DataLayout::NCHW), ConvolutionMethod::DIRECT),
+ // MobileNet 224 (NCHW)
+ ConfigurationMethod(ConvolutionConfiguration(Size2D(224U, 224U), Size2D(3U, 3U), Size2D(3U, 32U), PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::FLOOR), DataLayout::NCHW), ConvolutionMethod::GEMM),
+ // MobileNet 160 (NCHW)
+ ConfigurationMethod(ConvolutionConfiguration(Size2D(160U, 160U), Size2D(3U, 3U), Size2D(3U, 24U), PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::FLOOR), DataLayout::NCHW), ConvolutionMethod::GEMM),
+ // MobileNet 224 (NHWC)
+ ConfigurationMethod(ConvolutionConfiguration(Size2D(224U, 224U), Size2D(3U, 3U), Size2D(3U, 32U), PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::FLOOR), DataLayout::NHWC), ConvolutionMethod::GEMM),
+ // MobileNet 160 (NHWC)
+ ConfigurationMethod(ConvolutionConfiguration(Size2D(160U, 160U), Size2D(3U, 3U), Size2D(3U, 24U), PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::FLOOR), DataLayout::NHWC), ConvolutionMethod::GEMM),
+ };
+
+ const auto find_config = [&](const ConfigurationMethod &c) // by reference: avoid copying each table entry
+ {
+ const ConvolutionConfiguration &config = c.first;
+ const PadStrideInfo &info = std::get<3>(config);
+ const DataLayout data_layout = std::get<4>(config);
+
+ return std::get<0>(config) == Size2D(input->dimension(idx_w), input->dimension(idx_h)) && std::get<1>(config) == Size2D(weights->dimension(idx_w), weights->dimension(idx_h))
+ && std::get<2>(config) == Size2D(weights->dimension(idx_c), weights->dimension(3)) && info.pad_top() == conv_info.pad_top() && info.pad_right() == conv_info.pad_right()
+ && info.pad_bottom() == conv_info.pad_bottom() && info.pad_left() == conv_info.pad_left() && info.stride() == conv_info.stride() && (data_layout == input->data_layout());
+ };
+
+ const auto found = std::find_if(known_configs.begin(), known_configs.end(), find_config); // first known config matching this layer, if any
+ if(found != known_configs.end())
+ {
+ return found->second;
+ }
+
if(dilation != Size2D(1U, 1U) || (input->dimension(idx_c) < 16))
{
return ConvolutionMethod::GEMM;
diff --git a/src/runtime/CL/functions/CLCopy.cpp b/src/runtime/CL/functions/CLCopy.cpp
index 3442e37..d1b7926 100644
--- a/src/runtime/CL/functions/CLCopy.cpp
+++ b/src/runtime/CL/functions/CLCopy.cpp
@@ -41,3 +41,8 @@
k->configure(input, output);
_kernel = std::move(k);
}
+
+Status CLCopy::validate(const arm_compute::ITensorInfo *input, const arm_compute::ITensorInfo *output)
+{
+ return CLCopyKernel::validate(input, output);
+}
diff --git a/src/runtime/CL/functions/CLDeconvolutionLayer.cpp b/src/runtime/CL/functions/CLDeconvolutionLayer.cpp
index cb8dc02..40562b5 100644
--- a/src/runtime/CL/functions/CLDeconvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLDeconvolutionLayer.cpp
@@ -38,15 +38,16 @@
: _memory_group(std::move(memory_manager)),
_scale_f(),
_conv_f(),
- _scaled_output()
+ _scaled_output(),
+ _is_prepared(false)
{
}
Status CLDeconvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias, ITensorInfo *output, const PadStrideInfo &info,
- unsigned int inner_border_right, unsigned int inner_border_top)
+ unsigned int inner_border_right, unsigned int inner_border_top, const WeightsInfo &weights_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(0) != weights->dimension(1));
ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(0) < 1);
ARM_COMPUTE_RETURN_ERROR_ON(!info.padding_is_symmetric());
@@ -63,12 +64,10 @@
const TensorShape output_shape = deconvolution_output_shape(out_dims, input->tensor_shape(), weights->tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, weights);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output, weights);
if(bias != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, bias);
}
ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->dimension(Window::DimX) != output_shape.x(), "Output's width is invalid.");
@@ -80,13 +79,13 @@
const PadStrideInfo conv_info(1, 1, 0, 0, 0, 0, DimensionRoundingType::CEIL);
ARM_COMPUTE_RETURN_ON_ERROR(CLDeconvolutionLayerUpsample::validate(input, &scale_out_info, BorderSize(inner_border_right, inner_border_top), info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayer::validate(&scale_out_info, weights, bias, output, conv_info, WeightsInfo()));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayer::validate(&scale_out_info, weights, bias, output, conv_info, weights_info));
return Status{};
}
void CLDeconvolutionLayer::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const PadStrideInfo &info,
- unsigned int inner_border_right, unsigned int inner_border_top)
+ unsigned int inner_border_right, unsigned int inner_border_top, const WeightsInfo &weights_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
@@ -94,16 +93,18 @@
const unsigned int stride_y = info.stride().second;
auto out_dims = deconvolution_output_dimensions(input->info()->dimension(0), input->info()->dimension(1), weights->info()->dimension(0), weights->info()->dimension(1),
- info.pad().first, info.pad().second, inner_border_top, inner_border_right, stride_x, stride_y);
+ info.pad().first, info.pad().second, inner_border_right, inner_border_top, stride_x, stride_y);
const TensorShape output_shape = deconvolution_output_shape(out_dims, input->info()->tensor_shape(), weights->info()->tensor_shape());
// Output auto initialization if not yet initialized
- auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type());
// Perform validation step
ARM_COMPUTE_ERROR_THROW_ON(CLDeconvolutionLayer::validate(input->info(), weights->info(), bias == nullptr ? nullptr : bias->info(), output->info(), info, inner_border_right, inner_border_top));
+ _is_prepared = false;
+
_memory_group.manage(&_scaled_output);
// configure scale function
@@ -113,21 +114,34 @@
const unsigned int out_y = input->info()->dimension(1) + (input->info()->dimension(1) - 1) * (stride_y - 1) + inner_border_top + 2 * info.pad().second;
scale_out_shape.set(0, out_x);
scale_out_shape.set(1, out_y);
- TensorInfo scale_out_info(scale_out_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
+ TensorInfo scale_out_info(scale_out_shape, 1, input->info()->data_type());
_scaled_output.allocator()->init(scale_out_info);
_scale_f.configure(input, &_scaled_output, BorderSize(inner_border_top, inner_border_right), info);
// setup the function to convolve the upscaled output
const PadStrideInfo conv_info(1, 1, 0, 0, 0, 0, DimensionRoundingType::CEIL);
- _conv_f.configure(&_scaled_output, weights, bias, output, conv_info);
+ _conv_f.configure(&_scaled_output, weights, bias, output, conv_info, weights_info);
_scaled_output.allocator()->allocate();
}
void CLDeconvolutionLayer::run()
{
+ prepare();
+
_memory_group.acquire();
+
_scale_f.run();
_conv_f.run();
+
_memory_group.release();
}
+
+void CLDeconvolutionLayer::prepare()
+{
+ if(!_is_prepared)
+ {
+ _conv_f.prepare();
+ _is_prepared = true;
+ }
+}
diff --git a/src/runtime/CL/functions/CLDepthConcatenateLayer.cpp b/src/runtime/CL/functions/CLDepthConcatenateLayer.cpp
index 26d46a4..b5e8fd9 100644
--- a/src/runtime/CL/functions/CLDepthConcatenateLayer.cpp
+++ b/src/runtime/CL/functions/CLDepthConcatenateLayer.cpp
@@ -27,7 +27,9 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/PixelValue.h"
+#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "support/ToolchainSupport.h"
@@ -41,22 +43,26 @@
{
}
-void CLDepthConcatenateLayer::configure(std::vector<ICLTensor *> inputs_vector, ICLTensor *output) // NOLINT
+void CLDepthConcatenateLayer::configure(const std::vector<ICLTensor *> &inputs_vector, ICLTensor *output) // NOLINT
{
- ARM_COMPUTE_ERROR_ON(inputs_vector.size() < 2);
-
_num_inputs = inputs_vector.size();
- unsigned int depth_offset = 0;
+ std::vector<ITensorInfo *> inputs_vector_info;
+ for(unsigned int i = 0; i < _num_inputs; i++)
+ {
+ inputs_vector_info.emplace_back(inputs_vector.at(i)->info());
+ }
_concat_kernels_vector = arm_compute::support::cpp14::make_unique<CLDepthConcatenateLayerKernel[]>(_num_inputs);
_border_handlers_vector = arm_compute::support::cpp14::make_unique<CLFillBorderKernel[]>(_num_inputs);
- TensorShape output_shape = calculate_depth_concatenate_shape(inputs_vector);
+ TensorShape output_shape = arm_compute::misc::shape_calculator::calculate_depth_concatenate_shape(inputs_vector_info);
// Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), output_shape, 1, inputs_vector[0]->info()->data_type(), inputs_vector[0]->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), output_shape, 1, inputs_vector[0]->info()->data_type());
+ ARM_COMPUTE_ERROR_THROW_ON(CLDepthConcatenateLayer::validate(inputs_vector_info, output->info()));
+ unsigned int depth_offset = 0;
for(unsigned int i = 0; i < _num_inputs; i++)
{
_concat_kernels_vector[i].configure(inputs_vector.at(i), depth_offset, output);
@@ -69,6 +75,27 @@
output->info()->set_valid_region(ValidRegion(Coordinates(), output_shape));
}
+Status CLDepthConcatenateLayer::validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
+ ARM_COMPUTE_RETURN_ERROR_ON(inputs_vector.size() < 2);
+
+ // Auto-initialize a temporary clone of the output info (the caller's info is left untouched) if it is not yet initialized
+ TensorInfo tmp_output_info = *output->clone();
+ TensorShape output_shape = arm_compute::misc::shape_calculator::calculate_depth_concatenate_shape(inputs_vector);
+ auto_init_if_empty(tmp_output_info, output_shape, 1, inputs_vector[0]->data_type());
+
+ unsigned int depth_offset = 0;
+ for(const auto &input : inputs_vector)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
+ ARM_COMPUTE_RETURN_ON_ERROR(CLDepthConcatenateLayerKernel::validate(input, depth_offset, &tmp_output_info));
+ depth_offset += input->dimension(2); // each input is placed after the previous inputs along dimension 2
+ }
+
+ return Status{};
+}
+
void CLDepthConcatenateLayer::run()
{
cl::CommandQueue q = CLScheduler::get().queue();
diff --git a/src/runtime/CL/functions/CLDepthConvertLayer.cpp b/src/runtime/CL/functions/CLDepthConvertLayer.cpp
index b448465..2e52e8a 100644
--- a/src/runtime/CL/functions/CLDepthConvertLayer.cpp
+++ b/src/runtime/CL/functions/CLDepthConvertLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,3 +36,8 @@
k->configure(input, output, policy, shift);
_kernel = std::move(k);
}
+
+Status CLDepthConvertLayer::validate(const ITensorInfo *input, const ITensorInfo *output, ConvertPolicy policy, uint32_t shift)
+{
+ return CLDepthConvertLayerKernel::validate(input, output, policy, shift);
+}
diff --git a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
index 676a121..76451af 100644
--- a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
@@ -73,7 +73,7 @@
ActivationLayerInfo act_info, GPUTarget gpu_target)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
- ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() != DataLayout::NCHW && input->data_layout() != DataLayout::NHWC);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
if(input->data_layout() == DataLayout::NCHW)
{
@@ -91,7 +91,7 @@
CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayer()
: _im2col_kernel(), _weights_reshape_kernel(), _v2mm_kernel(), _vector_to_tensor_kernel(), _output_stage_kernel(), _v2mm_input_fill_border(), _v2mm_weights_fill_border(), _input_reshaped(),
- _weights_reshaped(), _v2mm_output(), _output_reshaped(), _is_first_run(true), _is_quantized(false), _original_weights(nullptr)
+ _weights_reshaped(), _v2mm_output(), _output_reshaped(), _is_prepared(false), _is_quantized(false), _original_weights(nullptr)
{
}
@@ -99,12 +99,17 @@
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
- const size_t weights_w = weights->info()->dimension(0);
- const size_t weights_h = weights->info()->dimension(1);
- const size_t weights_z = weights->info()->dimension(2);
+ const size_t idx_w = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::HEIGHT);
+ const size_t idx_c = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::CHANNEL);
- _is_first_run = true;
+ const size_t weights_w = weights->info()->dimension(idx_w);
+ const size_t weights_h = weights->info()->dimension(idx_h);
+ const size_t weights_z = weights->info()->dimension(idx_c);
+
+ _is_prepared = false;
_original_weights = weights;
_is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
@@ -119,8 +124,8 @@
ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
// Output width and height
- const unsigned int conv_w = output_shape.x();
- const unsigned int conv_h = output_shape.y();
+ const unsigned int conv_w = output_shape[idx_w];
+ const unsigned int conv_h = output_shape[idx_h];
// Set up intermediate tensors
const size_t patch_size = weights_w * weights_h + ((append_bias) ? 1 : 0);
@@ -134,6 +139,7 @@
_input_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
_im2col_kernel.set_target(gpu_target);
_im2col_kernel.configure(input, &_input_reshaped, Size2D(weights_w, weights_h), conv_info, append_bias, depth_multiplier);
+ CLScheduler::get().tune_kernel_static(_im2col_kernel);
// Weights reshape configuration
const TensorShape shape_weights_reshape(patch_size, weights_z);
@@ -149,6 +155,7 @@
_v2mm_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_data_type(v2mm_dt).set_tensor_shape(shape_v2mm_out));
_v2mm_kernel.set_target(gpu_target);
_v2mm_kernel.configure(&_input_reshaped, &_weights_reshaped, &_v2mm_output);
+ CLScheduler::get().tune_kernel_static(_v2mm_kernel);
_output_reshaped.allocator()->init(_v2mm_output.info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape));
_vector_to_tensor_kernel.configure(&_v2mm_output, (_is_quantized) ? &_output_reshaped : output, conv_w, conv_h);
@@ -180,24 +187,27 @@
// Allocate intermediate tensors
_input_reshaped.allocator()->allocate();
- _weights_reshaped.allocator()->allocate();
_v2mm_output.allocator()->allocate();
}
Status CLDepthwiseConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
unsigned int depth_multiplier)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); // must run before input->data_layout() is dereferenced below
+ const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
+ const size_t idx_c = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
- ARM_COMPUTE_RETURN_ERROR_ON((input->dimension(2) * depth_multiplier) != weights->dimension(2));
+ ARM_COMPUTE_RETURN_ERROR_ON((input->dimension(idx_c) * depth_multiplier) != weights->dimension(idx_c));
const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
const bool append_bias = (biases != nullptr) && !is_quantized;
const TensorShape output_shape = shape_calculator::compute_depthwise_convolution_shape(*input, *weights, conv_info, depth_multiplier);
- const size_t weights_w = weights->dimension(0);
- const size_t weights_h = weights->dimension(1);
- const size_t weights_z = weights->dimension(2);
- const unsigned int conv_w = output_shape.x();
- const unsigned int conv_h = output_shape.y();
+ const size_t weights_w = weights->dimension(idx_w);
+ const size_t weights_h = weights->dimension(idx_h);
+ const size_t weights_z = weights->dimension(idx_c);
+ const unsigned int conv_w = output_shape[idx_w];
+ const unsigned int conv_h = output_shape[idx_h];
const size_t patch_size = weights_w * weights_h + ((append_bias) ? 1 : 0);
const size_t conv_size = conv_w * conv_h;
@@ -233,18 +243,7 @@
void CLDepthwiseConvolutionLayer::run()
{
- // Run weights reshaping (Runs once for every configure)
- if(_is_first_run)
- {
- ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
-
- CLScheduler::get().enqueue(_weights_reshape_kernel);
- CLScheduler::get().enqueue(_v2mm_weights_fill_border);
- _is_first_run = false;
-
- // Mark original weights tensor as unused
- _original_weights->mark_as_unused();
- }
+ prepare();
CLScheduler::get().enqueue(_im2col_kernel);
CLScheduler::get().enqueue(_v2mm_input_fill_border);
@@ -255,3 +254,20 @@
CLScheduler::get().enqueue(_output_stage_kernel);
}
}
+
+void CLDepthwiseConvolutionLayer::prepare()
+{
+ if(!_is_prepared)
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+
+ // Run weights reshaping and mark original weights tensor as unused
+ _weights_reshaped.allocator()->allocate();
+ CLScheduler::get().enqueue(_weights_reshape_kernel);
+ CLScheduler::get().enqueue(_v2mm_weights_fill_border);
+ _original_weights->mark_as_unused();
+
+ CLScheduler::get().queue().finish();
+ _is_prepared = true;
+ }
+}
diff --git a/src/runtime/CL/functions/CLDepthwiseSeparableConvolutionLayer.cpp b/src/runtime/CL/functions/CLDepthwiseSeparableConvolutionLayer.cpp
index af2c6f0..fa2c3af 100644
--- a/src/runtime/CL/functions/CLDepthwiseSeparableConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLDepthwiseSeparableConvolutionLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,6 +45,14 @@
void CLDepthwiseSeparableConvolutionLayer::run()
{
+ prepare();
+
_depthwise_conv.run();
_pointwise_conv.run();
+}
+
+void CLDepthwiseSeparableConvolutionLayer::prepare()
+{
+ _depthwise_conv.prepare();
+ _pointwise_conv.prepare();
}
\ No newline at end of file
diff --git a/src/runtime/CL/functions/CLFlattenLayer.cpp b/src/runtime/CL/functions/CLFlattenLayer.cpp
index 9f571b2..b372c35 100644
--- a/src/runtime/CL/functions/CLFlattenLayer.cpp
+++ b/src/runtime/CL/functions/CLFlattenLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,15 +23,21 @@
*/
#include "arm_compute/runtime/CL/functions/CLFlattenLayer.h"
-#include "arm_compute/core/CL/kernels/CLIm2ColKernel.h"
-#include "arm_compute/core/Size2D.h"
+#include "arm_compute/core/CL/kernels/CLFlattenLayerKernel.h"
+#include "arm_compute/runtime/CL/CLScheduler.h"
#include "support/ToolchainSupport.h"
using namespace arm_compute;
void CLFlattenLayer::configure(const ICLTensor *input, ICLTensor *output)
{
- auto k = arm_compute::support::cpp14::make_unique<CLIm2ColKernel>();
- k->configure(input, output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
+ auto k = arm_compute::support::cpp14::make_unique<CLFlattenLayerKernel>();
+ k->configure(input, output);
_kernel = std::move(k);
+ CLScheduler::get().tune_kernel_static(*_kernel);
}
+
+Status CLFlattenLayer::validate(const ITensorInfo *input, const ITensorInfo *output)
+{
+ return CLFlattenLayerKernel::validate(input, output);
+}
\ No newline at end of file
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index 151fa1b..010985d 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -73,12 +73,12 @@
}
CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(memory_manager), _im2col_kernel(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _accumulate_biases_kernel(),
- _im2col_output(), _gemmlowp_output(), _reshape_weights_output(), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _original_weights(nullptr)
+ : _memory_group(memory_manager), _convert_weights(), _flatten_layer(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
+ _accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true),
+ _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _is_prepared(false), _original_weights(nullptr)
{
}
-
-void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output)
+void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights)
{
if(_is_quantized)
{
@@ -100,40 +100,41 @@
else
{
// Configure matrix multiply kernel
- _mm_gemm.configure(input, weights, nullptr, output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */));
+ _mm_gemm.configure(input, weights, nullptr, output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */, 1, false, retain_internal_weights));
}
}
-void CLFullyConnectedLayer::configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output)
+void CLFullyConnectedLayer::configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights)
{
ARM_COMPUTE_ERROR_ON((weights->info()->dimension(1) != (input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2))));
// If the fully connected layer is called after a convolution layer, the input tensor must be linearized
- // Initialize output tensor for im2col
- TensorShape shape_im2col = compute_im2col_fc_shape(input->info());
- _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
+ // Initialize output tensor for flatten
+ TensorShape shape_flatten = compute_flatten_shape(input->info());
+ _flatten_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_flatten).set_data_layout(DataLayout::NCHW));
- // Configure im2col kernel
- _memory_group.manage(&_im2col_output);
- _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
+ // Configure flatten kernel
+ _memory_group.manage(&_flatten_output);
+ _flatten_layer.configure(input, &_flatten_output);
// Configure matrix multiply kernel
- configure_mm(&_im2col_output, weights, output);
+ configure_mm(&_flatten_output, weights, output, retain_internal_weights);
- // Allocate the output tensor for im2col once all the configure methods have been called
- _im2col_output.allocator()->allocate();
+ // Allocate the output tensor for flatten once all the configure methods have been called
+ _flatten_output.allocator()->allocate();
}
-void CLFullyConnectedLayer::configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output)
+void CLFullyConnectedLayer::configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights)
{
ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != weights->info()->dimension(1));
// Configure matrix multiply kernel
- configure_mm(input, weights, output);
+ configure_mm(input, weights, output, retain_internal_weights);
}
-void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, bool transpose_weights, bool are_weights_reshaped)
+void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output,
+ FullyConnectedLayerInfo fc_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
@@ -142,14 +143,15 @@
weights->info(),
biases != nullptr ? biases->info() : nullptr,
output->info(),
- transpose_weights,
- are_weights_reshaped));
+ fc_info));
- _are_weights_reshaped = transpose_weights ? are_weights_reshaped : true;
- _is_fc_after_conv = true;
- _accumulate_biases = false;
- _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
- _original_weights = weights;
+ _are_weights_converted = true;
+ _are_weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
+ _is_fc_after_conv = true;
+ _accumulate_biases = false;
+ _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
+ _is_prepared = fc_info.retain_internal_weights;
+ _original_weights = weights;
// Configure gemmlowp output
if(_is_quantized)
@@ -169,25 +171,16 @@
_accumulate_biases_kernel.configure(output, biases);
}
+ const ICLTensor *weights_to_use = weights;
+
// With the Fully Connected layer we can have 4 different cases:
// 1) Convolution layer -> Fully Connected layer without batches
// 2) Fully Connected layer -> Fully Connected layer without batches
// 3) Convolution layer -> Fully Connected layer with batches
// 4) Fully Connected layer -> Fully Connected layer with batches
- const ICLTensor *weights_to_use = weights;
-
- if(!_are_weights_reshaped)
- {
- weights_to_use = &_reshape_weights_output;
-
- // Reshape the weights
- _reshape_weights_kernel.configure(weights, &_reshape_weights_output);
- }
-
// Check if we have a fully connected layer with batches
const bool is_batched_fc_layer = output->info()->dimension(1) > 1;
-
if(is_batched_fc_layer)
{
_is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->info()->tensor_shape().cbegin() + 3,
@@ -199,16 +192,38 @@
_is_fc_after_conv = input->info()->num_dimensions() > 1;
}
+ // Reshape weights if needed
+ if(!_are_weights_reshaped)
+ {
+ // Reshape the weights
+ _reshape_weights_kernel.configure(weights, &_reshape_weights_output);
+ weights_to_use = &_reshape_weights_output;
+ }
+
+ // Convert weights if needed
+ if(_is_fc_after_conv && (input->info()->data_layout() != fc_info.weights_trained_layout))
+ {
+ // Convert weights
+ _convert_weights.configure(weights_to_use,
+ &_converted_weights_output,
+ input->info()->tensor_shape(),
+ fc_info.weights_trained_layout);
+
+ weights_to_use = &_converted_weights_output;
+ _are_weights_converted = false;
+ }
+
+ // Configure fc core
ICLTensor *tmp_output = (_is_quantized) ? &_gemmlowp_output : output;
if(_is_fc_after_conv)
{
// Fully Connected layer after a Convolution Layer without batches
- configure_conv_fc(input, weights_to_use, tmp_output);
+ configure_conv_fc(input, weights_to_use, tmp_output, fc_info.retain_internal_weights);
}
else
{
// Fully Connected layer after a Fully Connected Layer without batches
- configure_fc_fc(input, weights_to_use, tmp_output);
+ configure_fc_fc(input, weights_to_use, tmp_output, fc_info.retain_internal_weights);
}
// Configure output stage for asymmetric quantized types
@@ -222,21 +237,23 @@
}
}
-Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, bool transpose_weights, bool are_weights_reshaped)
+Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+ FullyConnectedLayerInfo fc_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
- bool weights_reshaped = transpose_weights ? are_weights_reshaped : true;
+ bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
bool is_fc_after_conv = true;
bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
const GPUTarget gpu_target = CLScheduler::get().target();
- const ITensorInfo &im2col_input = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input)));
- const ITensorInfo &reshaped_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
- const ITensorInfo &gemmlowp_output = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
+ const ITensorInfo &flatten_input = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_flatten_shape(input)).set_data_layout(DataLayout::NCHW));
+ const ITensorInfo &reshaped_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
+ const ITensorInfo &converted_weights = weights_reshaped ? TensorInfo(weights->clone()->set_is_resizable(true).reset_padding()) : TensorInfo(*reshaped_weights.clone());
+ const ITensorInfo &gemmlowp_output = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
// Configure accumulate biases kernel for non quantized asymmetric types
if(biases != nullptr && !is_quantized)
@@ -255,16 +272,8 @@
const ITensorInfo *weights_to_use = weights;
const ITensorInfo *tmp_output = (is_quantized) ? &gemmlowp_output : output;
- if(!weights_reshaped)
- {
- // Validate reshape weights kernel
- ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayerReshapeWeights::validate(weights, &reshaped_weights));
- weights_to_use = &reshaped_weights;
- }
-
// Check if we have a fully connected layer with batches
const bool is_batched_fc_layer = output->dimension(1) > 1;
-
if(is_batched_fc_layer)
{
is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->tensor_shape().cbegin() + 3,
@@ -276,14 +285,31 @@
is_fc_after_conv = input->num_dimensions() > 1;
}
+ if(!weights_reshaped)
+ {
+ // Validate reshape weights kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayerReshapeWeights::validate(weights, &reshaped_weights));
+ weights_to_use = &reshaped_weights;
+ }
+
+ if(is_fc_after_conv && (input->data_layout() != fc_info.weights_trained_layout))
+ {
+ // Validate convert weights kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(CLConvertFullyConnectedWeights::validate(weights_to_use,
+ &converted_weights,
+ input->tensor_shape(),
+ fc_info.weights_trained_layout));
+ weights_to_use = &converted_weights;
+ }
+
if(is_fc_after_conv)
{
// Fully Connected layer after a Convolution Layer without batches
ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(1) != (input->dimension(0) * input->dimension(1) * input->dimension(2))));
- // Validate im2col kernel
- ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_input, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false));
- input_to_use = &im2col_input;
+ // Validate flatten kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFlattenLayer::validate(input, &flatten_input));
+ input_to_use = &flatten_input;
}
else
{
@@ -311,7 +337,7 @@
// Linearize input if it comes from a convolutional layer
if(_is_fc_after_conv)
{
- CLScheduler::get().enqueue(_im2col_kernel, false);
+ _flatten_layer.run();
}
// Run matrix multiply
@@ -342,27 +368,57 @@
void CLFullyConnectedLayer::prepare()
{
- // Reshape of the weights (happens only once)
- if(!_are_weights_reshaped)
+ if(!_is_prepared) // Runs once: transform the weights, prepare GEMM, then free intermediates that are no longer used
{
ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
- // Run reshape weights kernel and mark weights as unused
- _reshape_weights_output.allocator()->allocate();
- _reshape_weights_kernel.run();
- _original_weights->mark_as_unused();
+ auto release_unused = [](CLTensor * w) // Frees w if it is no longer marked as used; finishes the CL queue first so enqueued kernels are done with it
+ {
+ if(!w->is_used())
+ {
+ CLScheduler::get().queue().finish();
+ w->allocator()->free();
+ }
+ };
+
+ // Tensor currently holding the weights; it is marked unused once the next transformation has consumed it
+ const ICLTensor *cur_weights = _original_weights;
+
+ // Reshape of the weights if needed (happens only once)
+ if(!_are_weights_reshaped)
+ {
+ // Run reshape weights kernel and mark weights as unused
+ _reshape_weights_output.allocator()->allocate();
+ _reshape_weights_kernel.run();
+
+ cur_weights->mark_as_unused();
+ cur_weights = &_reshape_weights_output;
+ _are_weights_reshaped = true;
+ }
+
+ // Convert weights if needed (happens only once)
+ if(!_are_weights_converted)
+ {
+ _converted_weights_output.allocator()->allocate();
+ _convert_weights.run();
+
+ cur_weights->mark_as_unused();
+ _are_weights_converted = true;
+ }
+
+ // Release reshaped weights if unused (e.g. already consumed by the conversion step above)
+ release_unused(&_reshape_weights_output);
// Prepare GEMM prepare and release unused weights
if(!_is_quantized)
{
_mm_gemm.prepare();
- if(!_reshape_weights_output.is_used())
- {
- _reshape_weights_output.allocator()->free();
- }
}
- CLScheduler::get().queue().finish();
- _are_weights_reshaped = true;
+ // Release reshaped/converted weights if unused (the GEMM prepare step above may have marked them unused)
+ release_unused(&_reshape_weights_output);
+ release_unused(&_converted_weights_output);
+
+ _is_prepared = true;
}
}
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index f81da6c..f16d1c0 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -24,10 +24,6 @@
#include "arm_compute/runtime/CL/functions/CLGEMM.h"
#include "arm_compute/core/CL/ICLTensor.h"
-#include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMTranspose1xWKernel.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/GPUTarget.h"
#include "arm_compute/core/Helpers.h"
@@ -48,13 +44,16 @@
{
bool flag = true;
- if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
+ if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G76))
{
- // COMPMID-852
if(k > 256 && m > 4 && is_data_type_float(data_type) && reshape_b_only_on_first_run)
{
- const float scale = k < 1024 ? 2.0f : 2.5f;
- flag = (scale * n) > ((1.66f * n) + 38.4f);
+ constexpr float alpha = 3.2f;
+ constexpr float fact0 = 1.51f;
+ constexpr float fact1 = 1.66f;
+ constexpr float ops = 12.0f;
+ const float scale = k > 1024 ? 1.07f : 1.0f;
+ flag = alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops);
}
else
{
@@ -84,12 +83,10 @@
// Perform validation step
ARM_COMPUTE_ERROR_THROW_ON(validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, output->info(), alpha, beta, gemm_info));
- // Store original b matrix
- _original_b = b;
-
// Check if we need to reshape the matrix B only on the first run
_reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
- _is_prepared = false;
+ _is_prepared = gemm_info.retain_internal_weights();
+ _original_b = b;
const ICLTensor *matrix_a = a;
const ICLTensor *matrix_b = b;
@@ -104,9 +101,11 @@
// Arguments used by GEMMReshapeInfo
// If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
// in order to know how the matrices have been reshaped
- const int m = a->info()->dimension(1);
+ bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+ const int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
const int n = b->info()->dimension(0);
const int k = a->info()->dimension(0);
+ const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
int mult_transpose1xW_width = 1;
int mult_interleave4x4_height = 1;
@@ -119,6 +118,12 @@
// Check if we need to reshape the matrix A and matrix B
_is_interleaved_transposed = is_interleaved_transposed(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target);
+ // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
+ if(_is_interleaved_transposed)
+ {
+ reinterpret_input_as_3d = false;
+ }
+
if(_is_interleaved_transposed)
{
matrix_a = &_tmp_a;
@@ -133,13 +138,16 @@
// _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
// Configure interleave kernel
- _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height);
+ _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d());
// Configure transpose kernel
_transpose_kernel.configure(b, &_tmp_b, mult_transpose1xW_width);
}
- _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height));
+ // Configure and tune matrix multiply kernel
+ _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d,
+ reinterpret_input_as_3d));
+ CLScheduler::get().tune_kernel_static(_mm_kernel);
if(_is_interleaved_transposed)
{
@@ -162,6 +170,7 @@
Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_UNUSED(alpha);
+ ARM_COMPUTE_UNUSED(output);
// Check if we need to reshape the matrix B only on the first run
const bool reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
@@ -171,7 +180,6 @@
TensorInfo tmp_a_info{};
TensorInfo tmp_b_info{};
- TensorInfo tmp_output_info = *output->clone();
// Get the GPU target
const GPUTarget gpu_target = CLScheduler::get().target();
@@ -179,11 +187,13 @@
// Arguments used by GEMMReshapeInfo
// If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
// in order to know how the matrices have been reshaped
- const int m = a->dimension(1);
+ bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+ const int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
const int n = b->dimension(0);
const int k = a->dimension(0);
int mult_transpose1xW_width = 1;
int mult_interleave4x4_height = 1;
+ const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
{
@@ -191,19 +201,25 @@
mult_interleave4x4_height = 2;
}
- const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height);
-
// Check if we need to reshape the matrix A and matrix B
const bool run_interleave_transpose = is_interleaved_transposed(m, n, k, a->data_type(), reshape_b_only_on_first_run, gpu_target);
+ // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
+ if(run_interleave_transpose)
+ {
+ reinterpret_input_as_3d = false;
+ }
+
+ const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d);
+
if(run_interleave_transpose)
{
matrix_a_info = &tmp_a_info;
matrix_b_info = &tmp_b_info;
// Validate interleave kernel
- auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height)));
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height));
+ auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()));
// Validate transpose kernel
auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
@@ -211,13 +227,12 @@
}
// Validate matrix multiply
- auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info, gpu_target));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, alpha, run_interleave_transpose, reshape_info, gpu_target));
if(beta != 0 && c != nullptr)
{
// Validate matrix addition kernel
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, &tmp_output_info, beta));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta));
}
return Status{};
@@ -259,7 +274,7 @@
{
if(_is_interleaved_transposed && _reshape_b_only_on_first_run)
{
- // Run transpose kernel
+ // Run transpose kernel and mark original weights tensor as unused
_tmp_b.allocator()->allocate();
CLScheduler::get().enqueue(_transpose_kernel, false);
_original_b->mark_as_unused();
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index 79495e4..92d04d6 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -43,42 +43,43 @@
{
}
-void CLConvolutionLayerReshapeWeights::configure(const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output)
+void CLConvolutionLayerReshapeWeights::configure(const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, unsigned int num_groups)
{
// Perform validation step
ARM_COMPUTE_ERROR_ON_NULLPTR(weights, output);
ARM_COMPUTE_ERROR_THROW_ON(CLConvolutionLayerReshapeWeights::validate(weights->info(),
(biases != nullptr) ? biases->info() : nullptr,
- output->info()));
+ output->info(),
+ num_groups));
const bool append_biases = (biases != nullptr) && !is_data_type_quantized_asymmetric(weights->info()->data_type());
const ICLTensor *biases_to_use = (append_biases) ? biases : nullptr;
- _weights_reshape_kernel.configure(weights, biases_to_use, output);
+ _weights_reshape_kernel.configure(weights, biases_to_use, output, num_groups);
output->info()->set_quantization_info(weights->info()->quantization_info());
}
-Status CLConvolutionLayerReshapeWeights::validate(const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output)
+Status CLConvolutionLayerReshapeWeights::validate(const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, unsigned int num_groups)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(weights);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
if(biases != nullptr)
{
+ const int idx_kernels = get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::BATCHES);
ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(weights->data_type()));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
- ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(3));
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(idx_kernels));
ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
}
if((output != nullptr) && (output->total_size() != 0))
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(weights, output);
- CLWeightsReshapeKernel::validate(weights, biases, output);
+ ARM_COMPUTE_RETURN_ON_ERROR(CLWeightsReshapeKernel::validate(weights, biases, output, num_groups)); // propagate the kernel's Status instead of silently discarding it
}
return Status{};
@@ -91,14 +92,15 @@
CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _col2im_kernel(), _activationlayer_function(),
- _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
+ _add_bias_kernel(), _reshape_layer(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _append_bias(false),
+ _skip_im2col(false), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
{
}
-void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output)
+void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, int gemm_3d_depth)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
- ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info()));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info(), gemm_3d_depth, _skip_im2col));
if(_is_quantized)
{
@@ -119,15 +121,16 @@
else
{
// Configure matrix multiply function
- _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/));
+ _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, gemm_3d_depth,
+ _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */));
}
}
-Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output)
+Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth, bool skip_im2col)
{
const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
- const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */);
+ const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
if(is_quantized)
{
// Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -141,18 +144,17 @@
weights_qa->set_quantization_info(QuantizationInfo(weights_quantization_info.scale, -weights_quantization_info.offset));
// Perform validation step on GEMMLowp
- CLGEMMLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), output, gemm_info);
+ return CLGEMMLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), output, gemm_info);
}
else
{
// Perform validation step on Matrix multiply function
- CLGEMM::validate(input, weights, nullptr, output, 1.0f, 0.0f, gemm_info);
+ return CLGEMM::validate(input, weights, nullptr, output, 1.0f, 0.0f, gemm_info);
}
- return Status{};
}
void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info,
- const Size2D &dilation, const ActivationLayerInfo &act_info)
+ const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
@@ -163,22 +165,35 @@
conv_info,
weights_info,
dilation,
- act_info));
+ act_info,
+ num_groups));
- _is_prepared = false;
+ const DataType data_type = input->info()->data_type();
+ const DataLayout data_layout = input->info()->data_layout();
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const int idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
+
+ const unsigned int kernel_width = weights->info()->dimension(idx_width);
+ const unsigned int kernel_height = weights->info()->dimension(idx_height);
+
+ _is_prepared = weights_info.retain_internal_weights();
_original_weights = weights;
_is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
-
- const DataType dt = input->info()->data_type();
+ _data_layout = data_layout;
+ _skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1) && !_is_quantized;
+ _append_bias = (biases != nullptr) && (!_is_quantized);
// Set the GPU target for im2col and col2im
_im2col_kernel.set_target(CLScheduler::get().target());
_col2im_kernel.set_target(CLScheduler::get().target());
- const bool append_bias = (biases != nullptr) && (!_is_quantized);
+ bool is_nhwc = _data_layout == DataLayout::NHWC;
+ const ICLTensor *gemm_input_to_use = input;
+ ICLTensor *gemm_output_to_use = output;
+ ICLTensor *gemm_output_staged_to_use = output;
- const unsigned bias_element = (append_bias) ? 1 : 0;
- const ICLTensor *biases_to_use = (append_bias) ? biases : nullptr;
+ const ICLTensor *biases_to_use = (_append_bias && !_skip_im2col) ? biases : nullptr;
// Get parameters from conv_info
unsigned int stride_x = 0;
@@ -188,51 +203,66 @@
// Get convolved dimensions
unsigned int conv_w = 0;
unsigned int conv_h = 0;
+ std::tie(conv_w, conv_h) = scaled_dimensions(input->info()->dimension(idx_width),
+ input->info()->dimension(idx_height),
+ kernel_width,
+ kernel_height,
+ conv_info,
+ dilation);
- const unsigned int kernel_width = weights->info()->dimension(0);
- const unsigned int kernel_height = weights->info()->dimension(1);
- std::tie(conv_w, conv_h) = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1), kernel_width, kernel_height,
- conv_info, dilation);
-
- unsigned int mat_weights_cols = weights->info()->dimension(3);
- unsigned int mat_weights_rows = weights->info()->dimension(0) * weights->info()->dimension(1) * weights->info()->dimension(2) + bias_element;
+ unsigned int mat_weights_cols = weights->info()->dimension(idx_kernels) / num_groups;
// _weights_reshaped will be auto configured in the kernel.
// Just append biases and do not transpose 1xW as it will be reshaped in CLGEMM
- _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped);
-
- weights = &_weights_reshaped;
+ _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped, num_groups);
// Create tensor to store im2col reshaped inputs
- const unsigned int mat_input_cols = mat_weights_rows;
- const unsigned int mat_input_rows = conv_w * conv_h;
- TensorShape shape_im2col = input->info()->tensor_shape();
- shape_im2col.set(0, mat_input_cols);
- shape_im2col.set(1, mat_input_rows);
- shape_im2col.set(2, 1);
- TensorInfo im2col_reshaped_info(shape_im2col, 1, dt, input->info()->fixed_point_position());
- im2col_reshaped_info.set_quantization_info(input->info()->quantization_info());
- _im2col_output.allocator()->init(im2col_reshaped_info);
- _memory_group.manage(&_im2col_output);
+ if(!_skip_im2col)
+ {
+ _memory_group.manage(&_im2col_output);
+
+ // Configure and tune im2col. im2col output shape is auto-initialized
+ _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, dilation, num_groups);
+
+ // Set quantization info
+ _im2col_output.info()->set_quantization_info(input->info()->quantization_info());
+ CLScheduler::get().tune_kernel_static(_im2col_kernel);
+
+ // Update GEMM input
+ gemm_input_to_use = &_im2col_output;
+ }
+ else if(_append_bias)
+ {
+ // Configure add bias kernel
+ _add_bias_kernel.configure(output, biases, output, ConvertPolicy::SATURATE);
+ }
// Create GEMM output tensor
- TensorShape shape_gemm = _im2col_output.info()->tensor_shape();
- shape_gemm.set(0, mat_weights_cols);
- shape_gemm.set(1, mat_input_rows);
- const DataType gemm_data_type = _is_quantized ? DataType::S32 : dt;
- // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input.
- TensorInfo info_gemm(shape_gemm, 1, gemm_data_type, input->info()->fixed_point_position());
- info_gemm.set_quantization_info(output->info()->quantization_info());
- _gemm_output.allocator()->init(info_gemm);
- _memory_group.manage(&_gemm_output);
+ if(!is_nhwc || _is_quantized)
+ {
+ // Calculate GEMM output shape
+ TensorShape shape_gemm = _im2col_output.info()->tensor_shape();
+ shape_gemm.set(0, mat_weights_cols);
+ shape_gemm.set(1, conv_w * conv_h);
- // Configure im2col
- _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation);
+ // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input.
+ const DataType gemm_data_type = _is_quantized ? DataType::S32 : data_type;
+ TensorInfo info_gemm(shape_gemm, 1, gemm_data_type);
+ info_gemm.set_quantization_info(output->info()->quantization_info());
+ _gemm_output.allocator()->init(info_gemm);
+ _memory_group.manage(&_gemm_output);
- // Configure GEMM
- configure_mm(&_im2col_output, weights, &_gemm_output);
+ // Update GEMM output
+ gemm_output_to_use = &_gemm_output;
+ }
- _im2col_output.allocator()->allocate();
+ // Configure and tune GEMM
+ configure_mm(gemm_input_to_use, &_weights_reshaped, gemm_output_to_use, (data_layout == DataLayout::NHWC) ? conv_h : 1);
+
+ if(!_skip_im2col)
+ {
+ _im2col_output.allocator()->allocate();
+ }
// Configure output stage for quantized case
if(_is_quantized)
@@ -242,19 +272,36 @@
float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale;
int output_multiplier, output_shift;
quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+
_memory_group.manage(&_tmp_output);
- _gemmlowp_output_stage.configure(&_gemm_output, biases, &_tmp_output, output_multiplier, output_shift, output_quant_info.offset);
+ gemm_output_staged_to_use = &_tmp_output;
+
+ _gemmlowp_output_stage.configure(gemm_output_to_use, biases, gemm_output_staged_to_use, output_multiplier, output_shift, output_quant_info.offset);
}
- // Configure Col2Im
- _col2im_kernel.configure(_is_quantized ? &_tmp_output : &_gemm_output, output, std::make_pair(conv_w, conv_h));
- if(_is_quantized)
+ if(!is_nhwc || _is_quantized)
+ {
+ if(input->info()->data_layout() == DataLayout::NCHW)
+ {
+ // Configure and tune Col2Im
+ _col2im_kernel.configure(_is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, output, std::make_pair(conv_w, conv_h), num_groups);
+ CLScheduler::get().tune_kernel_static(_col2im_kernel);
+ }
+ else
+ {
+ // Configure reshape layer
+ _reshape_layer.configure(_is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, output);
+ }
+ }
+
+ if(!is_nhwc || _is_quantized)
{
_tmp_output.allocator()->allocate();
+ _gemm_output.allocator()->allocate();
}
- _gemm_output.allocator()->allocate();
- ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(0) != conv_w) || (output->info()->dimension(1) != conv_h), "Output shape does not match the expected one");
+ ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(idx_width) != conv_w) || (output->info()->dimension(idx_height) != conv_h),
+ "Output shape does not match the expected one");
//Configure Activation Layer
_is_activationlayer_enabled = act_info.enabled();
@@ -268,76 +315,42 @@
}
Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
- const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info)
+ const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!");
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, weights);
- ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(2) != input->dimension(2));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((num_groups != 1) && (input->data_layout() != DataLayout::NCHW), "Grouping (num_groups != 1) with NHWC data layout is not supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((num_groups != 1) && (input->data_type() == DataType::QASYMM8), "Grouping (num_groups != 1) is not supported with QASYMM8");
+ ARM_COMPUTE_RETURN_ERROR_ON(((input->dimension(2) / weights->dimension(2)) != num_groups) && (input->data_layout() == DataLayout::NCHW));
+
+ const DataLayout data_layout = input->data_layout();
+ const DataType data_type = input->data_type();
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
+ const int idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
+
+ const unsigned int kernel_width = weights->dimension(idx_width);
+ const unsigned int kernel_height = weights->dimension(idx_height);
+
+ TensorInfo im2col_reshaped_info, info_gemm, tmp_info, weights_reshaped_info;
+ const ITensorInfo *gemm_input_to_use = input;
+ const ITensorInfo *gemm_output_to_use = output;
+ const ITensorInfo *gemm_output_staged_to_use = output;
+ const ITensorInfo *weights_to_use = weights;
+
+ const bool is_nhwc = data_layout == DataLayout::NHWC;
+ const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
+ const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1) && !is_quantized;
+ const bool append_bias = (biases != nullptr) && (!is_quantized);
+
+ ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(idx_channel) * num_groups) != input->dimension(idx_channel));
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
- if(act_info.enabled())
- {
- ARM_COMPUTE_ERROR_ON(act_info.b() > act_info.a());
- }
-
- const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
- const bool append_bias = (biases != nullptr) && (!is_quantized);
- const unsigned bias_element = (append_bias) ? 1 : 0;
- const DataType dt = input->data_type();
-
- // Get convolved dimensions
- unsigned int conv_w = 0;
- unsigned int conv_h = 0;
-
- const unsigned int kernel_width = weights->dimension(0);
- const unsigned int kernel_height = weights->dimension(1);
-
- std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_width, kernel_height, conv_info, dilation);
-
- unsigned int mat_weights_cols = weights->dimension(3);
- unsigned int mat_weights_rows = weights->dimension(0) * weights->dimension(1) * weights->dimension(2) + bias_element;
-
- ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayerReshapeWeights::validate(weights, is_quantized ? nullptr : biases, nullptr));
-
- // Create tensor info for im2col reshaped inputs
- const unsigned int mat_input_cols = mat_weights_rows;
- const unsigned int mat_input_rows = conv_w * conv_h;
- TensorShape shape_im2col = input->tensor_shape();
- shape_im2col.set(0, mat_input_cols);
- shape_im2col.set(1, mat_input_rows);
- shape_im2col.set(2, 1);
- TensorInfo im2col_reshaped_info(shape_im2col, 1, dt, input->fixed_point_position());
- im2col_reshaped_info.set_quantization_info(input->quantization_info());
- ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation));
-
- // Create GEMM output tensor
- TensorShape shape_gemm = im2col_reshaped_info.tensor_shape();
- shape_gemm.set(0, mat_weights_cols);
- shape_gemm.set(1, mat_input_rows);
- const DataType gemm_data_type = is_quantized ? DataType::S32 : dt;
- // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input.
- TensorInfo info_gemm(shape_gemm, 1, gemm_data_type, input->fixed_point_position());
- info_gemm.set_quantization_info(output->quantization_info());
-
- ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(&im2col_reshaped_info, weights, &info_gemm));
- TensorInfo tmp_info(shape_gemm, 1, DataType::QASYMM8, input->fixed_point_position());
- tmp_info.set_quantization_info(output->quantization_info());
-
- if(is_quantized)
- {
- float multiplier = input->quantization_info().scale * weights->quantization_info().scale / output->quantization_info().scale;
- int output_multiplier, output_shift;
- quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
- // Validate output stage for quantized case
- CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(&info_gemm, biases, &tmp_info, output->quantization_info().offset);
- }
-
- // Validate Col2Im
- ARM_COMPUTE_RETURN_ON_ERROR(CLCol2ImKernel::validate(is_quantized ? &tmp_info : &info_gemm, output, std::make_pair(conv_w, conv_h)));
-
+ // Validate biases
if(biases != nullptr)
{
if(is_quantized)
@@ -348,11 +361,91 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
}
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, biases);
- ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(3));
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(idx_kernels));
ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
}
+ if(act_info.enabled())
+ {
+ ARM_COMPUTE_ERROR_ON(act_info.b() > act_info.a());
+ }
+
+ // Get convolved dimensions
+ unsigned int conv_w = 0;
+ unsigned int conv_h = 0;
+
+ std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(idx_width),
+ input->dimension(idx_height),
+ kernel_width,
+ kernel_height,
+ conv_info,
+ dilation);
+
+ unsigned int mat_weights_cols = weights->dimension(idx_kernels) / num_groups;
+
+ // Output tensor auto inizialitation if not yet initialized
+ ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayerReshapeWeights::validate(weights, is_quantized ? nullptr : biases, nullptr, num_groups));
+ weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, (append_bias && !skip_im2col), num_groups), 1, data_type);
+ weights_to_use = &weights_reshaped_info;
+
+ if(!skip_im2col)
+ {
+ const Size2D kernel_dims(kernel_width, kernel_height);
+
+ // Output tensor auto initialization if not yet initialized
+ TensorShape expected_output_shape = compute_im2col_conv_shape(input, kernel_dims, conv_info, append_bias, dilation, num_groups == 1, num_groups);
+
+ auto_init_if_empty(im2col_reshaped_info, input->clone()->set_tensor_shape(expected_output_shape));
+
+ ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, kernel_dims, conv_info, append_bias, dilation, num_groups));
+ gemm_input_to_use = &im2col_reshaped_info;
+ }
+ else if(append_bias)
+ {
+ // Validate add bias kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAdditionKernel::validate(output, biases, output, ConvertPolicy::SATURATE));
+ }
+
+ // Create GEMM output tensor
+ if(!is_nhwc || is_quantized)
+ {
+ TensorShape shape_gemm = gemm_input_to_use->tensor_shape();
+ shape_gemm.set(0, mat_weights_cols);
+ shape_gemm.set(1, conv_w * conv_h);
+ const DataType gemm_data_type = is_quantized ? DataType::S32 : data_type;
+ // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input.
+ info_gemm = TensorInfo(shape_gemm, 1, gemm_data_type);
+ info_gemm.set_quantization_info(output->quantization_info());
+ gemm_output_to_use = &info_gemm;
+ }
+
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, (data_layout == DataLayout::NHWC) ? conv_h : 1, skip_im2col));
+
+ if(is_quantized)
+ {
+ float multiplier = input->quantization_info().scale * weights_to_use->quantization_info().scale / output->quantization_info().scale;
+ int output_multiplier, output_shift;
+ quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+
+ tmp_info = TensorInfo(gemm_output_to_use->tensor_shape(), 1, DataType::QASYMM8);
+ tmp_info.set_quantization_info(output->quantization_info());
+ gemm_output_staged_to_use = &tmp_info;
+
+ // Validate output stage for quantized case
+ CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(gemm_output_to_use, biases, gemm_output_staged_to_use, output->quantization_info().offset);
+ }
+
+ // Validate Col2Im
+ if(!is_nhwc || is_quantized)
+ {
+ if(input->data_layout() == DataLayout::NCHW)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(CLCol2ImKernel::validate(is_quantized ? gemm_output_staged_to_use : gemm_output_to_use,
+ output,
+ std::make_pair(conv_w, conv_h), num_groups));
+ }
+ }
+
//Validate Activation Layer
if(act_info.enabled())
{
@@ -369,7 +462,10 @@
_memory_group.acquire();
// Run im2col
- CLScheduler::get().enqueue(_im2col_kernel);
+ if(!_skip_im2col)
+ {
+ CLScheduler::get().enqueue(_im2col_kernel);
+ }
// Runs CLGEMM or CLGEMMLowpMatrixMultiplyCore functions
if(_is_quantized)
@@ -386,8 +482,23 @@
_mm_gemm.run();
}
+ if(_skip_im2col && _append_bias)
+ {
+ CLScheduler::get().enqueue(_add_bias_kernel);
+ }
+
// Reshape output matrix
- CLScheduler::get().enqueue(_col2im_kernel, false);
+ if(_data_layout == DataLayout::NCHW || _is_quantized)
+ {
+ if(_data_layout == DataLayout::NCHW)
+ {
+ CLScheduler::get().enqueue(_col2im_kernel, false);
+ }
+ else
+ {
+ _reshape_layer.run();
+ }
+ }
//Run Activation Layer if enabled
if(_is_activationlayer_enabled)
@@ -402,20 +513,18 @@
{
if(!_is_prepared)
{
- // Run weights reshaping and mark as unused
ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+
+ // Run weights reshaping and mark original weights tensor as unused
_weights_reshaped.allocator()->allocate();
_reshape_weights.run();
_original_weights->mark_as_unused();
- // Run GEMM prepare
- if(!_is_quantized)
+ // Prepare GEMM
+ _is_quantized ? _mm_gemmlowp.prepare() : _mm_gemm.prepare();
+ if(!_weights_reshaped.is_used())
{
- _mm_gemm.prepare();
- if(!_weights_reshaped.is_used())
- {
- _weights_reshaped.allocator()->free();
- }
+ _weights_reshaped.allocator()->free();
}
CLScheduler::get().queue().finish();
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 711b006..0ce07c3 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -41,9 +41,8 @@
{
bool flag = true;
- if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
+ if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::G76))
{
- // COMPMID-852
if(k > 256 && m > 4 && reshape_b_only_on_first_run)
{
flag = ((0.72f + n * 0.10766f) < (n * 0.1284f));
@@ -59,8 +58,23 @@
} // namespace
+// Construct the function. The optional memory manager is moved into the internal memory group;
+// all kernels and intermediate tensors are default-constructed, _original_b stays null until
+// configure() records B, and _is_prepared starts false so the first run() performs the one-off
+// B-side work in prepare().
CLGEMMLowpMatrixMultiplyCore::CLGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _mm_kernel(), _mtx_a_reshape_kernel(), _mtx_b_reshape_kernel(), _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(), _offset_contribution_kernel(),
- _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _a_offset(0), _b_offset(0), _is_interleaved_transposed(true), _is_first_run(true), _reshape_b_only_on_first_run(false)
+ : _memory_group(std::move(memory_manager)),
+ _mm_kernel(),
+ _mtx_a_reshape_kernel(),
+ _mtx_b_reshape_kernel(),
+ _mtx_a_reduction_kernel(),
+ _mtx_b_reduction_kernel(),
+ _offset_contribution_kernel(),
+ _vector_sum_col(),
+ _vector_sum_row(),
+ _tmp_a(),
+ _tmp_b(),
+ _original_b(nullptr),
+ _a_offset(0),
+ _b_offset(0),
+ _is_interleaved_transposed(true),
+ _reshape_b_only_on_first_run(false),
+ _is_prepared(false)
{
}
@@ -70,6 +84,8 @@
ARM_COMPUTE_UNUSED(gemm_info);
ARM_COMPUTE_ERROR_THROW_ON(CLGEMMLowpMatrixMultiplyCore::validate(a->info(), b->info(), output->info(), gemm_info));
+ _is_prepared = false;
+ _original_b = b;
_reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
_a_offset = a->info()->quantization_info().offset;
_b_offset = b->info()->quantization_info().offset;
@@ -149,10 +165,13 @@
if(_is_interleaved_transposed)
{
_tmp_a.allocator()->allocate();
- _tmp_b.allocator()->allocate();
+ if(!_reshape_b_only_on_first_run)
+ {
+ _tmp_b.allocator()->allocate();
+ }
}
- if(_a_offset != 0)
+ if(_a_offset != 0 && !_reshape_b_only_on_first_run)
{
_vector_sum_col.allocator()->allocate();
}
@@ -185,16 +204,17 @@
const int k = a->dimension(0);
constexpr int mult_transpose1xW_width = 1;
constexpr int mult_interleave4x4_height = 1;
- const GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height);
+ const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
+ const GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d);
bool reshape_matrices = is_interleaved_transposed(m, n, k, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target());
if(reshape_matrices)
{
- TensorInfo info_a(compute_interleaved_shape(*a, mult_interleave4x4_height), 1, a->data_type());
+ TensorInfo info_a(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()), 1, a->data_type());
TensorInfo info_b(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width), 1, b->data_type());
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &info_a, mult_interleave4x4_height));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &info_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMTranspose1xWKernel::validate(b, &info_b, mult_transpose1xW_width));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(&info_a, &info_b, output, reshape_matrices, reshape_info));
}
@@ -234,6 +254,8 @@
void CLGEMMLowpMatrixMultiplyCore::run()
{
+ prepare();
+
_memory_group.acquire();
if(_is_interleaved_transposed)
@@ -241,21 +263,17 @@
// Run reshape matrix A
CLScheduler::get().enqueue(_mtx_a_reshape_kernel, false);
- if(_is_first_run || !_reshape_b_only_on_first_run)
+ if(!_reshape_b_only_on_first_run)
{
// Run reshape matrix B
CLScheduler::get().enqueue(_mtx_b_reshape_kernel, false);
}
}
- // Note: if _reshape_b_only_on_first_run = true, the reduction kernel can be executed only once
- if(_is_first_run || !_reshape_b_only_on_first_run)
+ // Run matrix B reduction kernel only if _a_offset is not equal to 0
+ if(_a_offset != 0 && !_reshape_b_only_on_first_run)
{
- // Run matrix B reduction kernel only if _a_offset is not equal to 0
- if(_a_offset != 0)
- {
- CLScheduler::get().enqueue(_mtx_b_reduction_kernel, false);
- }
+ CLScheduler::get().enqueue(_mtx_b_reduction_kernel, false);
}
// Run matrix multiply
@@ -271,6 +289,30 @@
CLScheduler::get().enqueue(_offset_contribution_kernel, true);
_memory_group.release();
+}
- _is_first_run = false;
+// Perform the one-off work on matrix B that run() skips when _reshape_b_only_on_first_run is set.
+// run() calls this at the start of every execution; the _is_prepared guard (reset to false by
+// configure()) turns every call after the first into a no-op.
+void CLGEMMLowpMatrixMultiplyCore::prepare()
+{
+ if(!_is_prepared)
+ {
+ if(_is_interleaved_transposed && _reshape_b_only_on_first_run)
+ {
+ // The original B must not already have been marked as unused before its one-off reshape.
+ ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+
+ // Reshape B into _tmp_b exactly once. configure() only allocates _tmp_b when B is reshaped on every
+ // run, so it is allocated here; afterwards the original B is marked unused so its owner can release it.
+ _tmp_b.allocator()->allocate();
+ CLScheduler::get().enqueue(_mtx_b_reshape_kernel, false);
+ _original_b->mark_as_unused();
+ }
+
+ // The B reduction kernel is only needed when A's quantization offset (_a_offset) is non-zero. In the
+ // reshape-B-on-first-run mode it runs once here instead of on every run(), and configure() defers
+ // allocating its output buffer _vector_sum_col to this point.
+ if(_a_offset != 0 && _reshape_b_only_on_first_run)
+ {
+ _vector_sum_col.allocator()->allocate();
+ CLScheduler::get().enqueue(_mtx_b_reduction_kernel, false);
+ }
+
+ // Block until the one-off kernels have completed before flagging the function as prepared.
+ CLScheduler::get().queue().finish();
+ _is_prepared = true;
+ }
}
diff --git a/src/runtime/CL/functions/CLGaussianPyramid.cpp b/src/runtime/CL/functions/CLGaussianPyramid.cpp
index ddce5fb..fd82769 100644
--- a/src/runtime/CL/functions/CLGaussianPyramid.cpp
+++ b/src/runtime/CL/functions/CLGaussianPyramid.cpp
@@ -166,7 +166,7 @@
_gauss5x5[i].configure(_pyramid->get_pyramid_level(i), _tmp.get_pyramid_level(i), border_mode, constant_border_value);
/* Configure scale image kernel */
- _scale_nearest[i].configure(_tmp.get_pyramid_level(i), _pyramid->get_pyramid_level(i + 1), InterpolationPolicy::NEAREST_NEIGHBOR, border_mode == BorderMode::UNDEFINED, SamplingPolicy::CENTER);
+ _scale_nearest[i].configure(_tmp.get_pyramid_level(i), _pyramid->get_pyramid_level(i + 1), InterpolationPolicy::NEAREST_NEIGHBOR, border_mode, SamplingPolicy::CENTER);
}
_tmp.allocate();
diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp
index 930d311..3458135 100644
--- a/src/runtime/CL/functions/CLLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLLSTMLayer.cpp
@@ -38,85 +38,91 @@
using namespace arm_compute::misc::shape_calculator;
+// Construct the LSTM function. The optional memory manager is moved into the internal memory group;
+// every sub-function and intermediate tensor is default-constructed, and all optimisation flags
+// (peephole, CIFG, cell/projection clipping, projection weights) start false until configure()
+// enables them according to the supplied LSTM parameters.
CLLSTMLayer::CLLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _fully_connected_input_gate(), _gemm_input_gate1(), _gemm_input_gate2(), _transpose_input_gate1(), _transpose_input_gate2(), _accum_input_gate1(),
- _accum_input_gate2(), _subtract_input_gate(), _activation_input_gate(), _fully_connected_forget_gate(), _gemm_forget_gate1(), _gemm_forget_gate2(), _transpose_forget_gate1(),
- _transpose_forget_gate2(), _accum_forget_gate1(), _accum_forget_gate2(), _activation_forget_gate(), _fully_connected_cell_state(), _gemm_cell_state1(), _gemm_cell_state2(), _transpose_cell_state1(),
- _accum_cell_state1(), _accum_cell_state2(), _pixelwise_mul_cell_state1(), _activation_cell_state(), _cell_clip(), _pixelwise_mul_cell_state2(), _fully_connected_output(), _gemm_output1(),
- _gemm_output2(), _transpose_output1(), _transpose_output2(), _accum_output1(), _accum_output2(), _activation_output(), _activation_output_state(), _pixelwise_mul_output_state(),
- _fully_connected_output_state(), _gemm_output_state(), _accum_output_state(), _projection_clip(), _copy_cell_state(), _copy_output(), _concat_scratch_buffer(), _input_gate_out1(), _input_gate_out2(),
- _input_gate_out3(), _input_gate_out4(), _input_gate_out5(), _input_gate_out6(), _forget_gate_out1(), _forget_gate_out2(), _forget_gate_out3(), _forget_gate_out4(), _forget_gate_out5(),
- _forget_gate_out6(), _cell_state_out1(), _cell_state_out2(), _cell_state_out3(), _cell_state_out4(), _cell_state_out5(), _output1(), _output2(), _output3(), _output4(), _output5(), _output6(),
- _cell_state_activation(), _output_projection1(), _ones(), _run_peephole_opt(false), _run_cifg_opt(false), _perform_cell_clipping(false), _has_projection_weights(false),
- _perform_projection_clipping(false)
+ : _memory_group(std::move(memory_manager)), _fully_connected_input_gate(), _gemm_input_gate(), _transpose_input_gate(), _accum_input_gate1(), _accum_input_gate2(), _subtract_input_gate(),
+ _pixelwise_mul_input_gate(), _activation_input_gate(), _fully_connected_forget_gate(), _gemm_forget_gate(), _transpose_forget_gate(), _accum_forget_gate1(), _accum_forget_gate2(),
+ _pixelwise_mul_forget_gate(), _activation_forget_gate(), _fully_connected_cell_state(), _gemm_cell_state1(), _gemm_cell_state2(), _transpose_cell_state(), _accum_cell_state1(), _accum_cell_state2(),
+ _pixelwise_mul_cell_state1(), _activation_cell_state(), _cell_clip(), _pixelwise_mul_cell_state2(), _fully_connected_output(), _gemm_output(), _pixelwise_mul_output_state1(), _transpose_output(),
+ _accum_output1(), _accum_output2(), _activation_output(), _activation_output_state(), _pixelwise_mul_output_state2(), _fully_connected_output_state(), _gemm_output_state(), _accum_output_state(),
+ _projection_clip(), _copy_cell_state(), _copy_output(), _concat_scratch_buffer(), _input_gate_out1(), _input_gate_out2(), _input_gate_out3(), _input_gate_out4(), _input_gate_out5(),
+ _forget_gate_out1(), _forget_gate_out2(), _forget_gate_out3(), _forget_gate_out4(), _forget_gate_out5(), _cell_state_out1(), _cell_state_out2(), _cell_state_out3(), _cell_state_out4(),
+ _cell_state_out5(), _output1(), _output2(), _output3(), _output4(), _output5(), _cell_state_activation(), _output_state1(), _ones(), _run_peephole_opt(false), _run_cifg_opt(false),
+ _perform_cell_clipping(false), _has_projection_weights(false), _perform_projection_clipping(false)
{
}
-void CLLSTMLayer::configure(const ICLTensor *input, const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
+void CLLSTMLayer::configure(const ICLTensor *input,
+ const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
- ICLTensor *output_state, ICLTensor *cell_state, ICLTensor *scratch_buffer, ICLTensor *output, const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info,
- float cell_threshold, float projection_threshold)
+ const ICLTensor *output_state_in, const ICLTensor *cell_state_in,
+ ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
+ const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
- forget_gate_bias, cell_bias, output_gate_bias, output_state, cell_state);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input,
+ input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
+ recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
+ forget_gate_bias, cell_bias, output_gate_bias,
+ output_state_in, cell_state_in,
+ scratch_buffer, output_state_out, cell_state_out, output);
+
+ // Set lstm parameters
LSTMParams<ITensorInfo> lstm_params_info;
if(lstm_params.has_peephole_opt())
{
- lstm_params_info.set_peephole_params(lstm_params.cell_to_input_weights()->info(), lstm_params.cell_to_forget_weights()->info(), lstm_params.cell_to_output_weights()->info());
+ lstm_params_info.set_peephole_params(lstm_params.cell_to_forget_weights()->info(), lstm_params.cell_to_output_weights()->info());
}
if(lstm_params.has_projection())
{
- lstm_params_info.set_projection_params(lstm_params.projection_weights()->info(), lstm_params.projection_bias()->info());
+ lstm_params_info.set_projection_params(lstm_params.projection_weights()->info(),
+ lstm_params.projection_bias() != nullptr ? lstm_params.projection_bias()->info() : nullptr);
}
if(!lstm_params.has_cifg_opt())
{
+ const ITensorInfo *cell_to_input_weights_info = (lstm_params.has_peephole_opt()) ? lstm_params.cell_to_input_weights()->info() : nullptr;
lstm_params_info.set_cifg_params(lstm_params.input_to_input_weights()->info(), lstm_params.recurrent_to_input_weights()->info(),
- lstm_params.cell_to_input_weights()->info(), lstm_params.input_gate_bias()->info());
+ cell_to_input_weights_info, lstm_params.input_gate_bias()->info());
}
+
+ // Validate
ARM_COMPUTE_ERROR_THROW_ON(CLLSTMLayer::validate(input->info(), input_to_forget_weights->info(),
input_to_cell_weights->info(), input_to_output_weights->info(),
recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
- output_state->info(), cell_state->info(), scratch_buffer->info(), output->info(), lstm_params_info,
- activation_info, cell_threshold, projection_threshold));
+ output_state_in->info(), cell_state_in->info(),
+ scratch_buffer->info(), output_state_out->info(), cell_state_out->info(), output->info(),
+ lstm_params_info, activation_info, cell_threshold, projection_threshold));
- const TensorShape cell_state_shape = cell_state->info()->tensor_shape();
+ const TensorShape cell_state_shape = cell_state_in->info()->tensor_shape();
+ // Configure block that calculates the forget gate
+ // forget_gate = Activation(input * input_to_forget_weights + output_state_in * recurrent_to_forget_weights + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
TensorShape forget_gate1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
- TensorShape forget_gate2_shape = compute_transposed_shape(*forget_gate_bias->info());
- TensorShape forget_gate3_shape{ 1, output_state->info()->dimension(1) };
_forget_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
_forget_gate_out2.allocator()->init(TensorInfo(forget_gate1_shape, 1, input->info()->data_type()));
_forget_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
- _forget_gate_out6.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
+ _forget_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
- // Configure block that calculates the forget gate
- // forget_gate = Activation(input * input_to_forget_weights + output_state * recurrent_to_forget_weights + cell_state * cell_to_forget_weights + forget_gate_bias)
_memory_group.manage(&_forget_gate_out1);
- _fully_connected_forget_gate.configure(input, input_to_forget_weights, forget_gate_bias, &_forget_gate_out1, true, false);
+ _fully_connected_forget_gate.configure(input, input_to_forget_weights, forget_gate_bias, &_forget_gate_out1);
_memory_group.manage(&_forget_gate_out2);
- _transpose_forget_gate1.configure(recurrent_to_forget_weights, &_forget_gate_out2);
+ _transpose_forget_gate.configure(recurrent_to_forget_weights, &_forget_gate_out2);
_memory_group.manage(&_forget_gate_out3);
- _gemm_forget_gate1.configure(output_state, &_forget_gate_out2, nullptr, &_forget_gate_out3, 1.f, 0.f);
+ _gemm_forget_gate.configure(output_state_in, &_forget_gate_out2, nullptr, &_forget_gate_out3, 1.f, 0.f);
_forget_gate_out2.allocator()->allocate();
- _memory_group.manage(&_forget_gate_out6);
- _accum_forget_gate1.configure(&_forget_gate_out1, &_forget_gate_out3, &_forget_gate_out6, ConvertPolicy::SATURATE);
- CLTensor *forget_gate_out = &_forget_gate_out6;
+ _memory_group.manage(&_forget_gate_out5);
+ _accum_forget_gate1.configure(&_forget_gate_out1, &_forget_gate_out3, &_forget_gate_out5, ConvertPolicy::SATURATE);
+ CLTensor *forget_gate_out = &_forget_gate_out5;
if(lstm_params.has_peephole_opt())
{
- _forget_gate_out4.allocator()->init(TensorInfo(forget_gate2_shape, 1, input->info()->data_type()));
- _forget_gate_out5.allocator()->init(TensorInfo(forget_gate3_shape, 1, input->info()->data_type()));
+ _forget_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
_run_peephole_opt = true;
_memory_group.manage(&_forget_gate_out4);
- _transpose_forget_gate2.configure(lstm_params.cell_to_forget_weights(), &_forget_gate_out4);
- _memory_group.manage(&_forget_gate_out5);
- _gemm_forget_gate2.configure(cell_state, &_forget_gate_out4, nullptr, &_forget_gate_out5, 1.f, 0.f);
+ _pixelwise_mul_forget_gate.configure(cell_state_in, lstm_params.cell_to_forget_weights(), &_forget_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ _accum_forget_gate2.configure(&_forget_gate_out5, &_forget_gate_out4, &_forget_gate_out3, ConvertPolicy::SATURATE);
_forget_gate_out4.allocator()->allocate();
- _accum_forget_gate2.configure(&_forget_gate_out6, &_forget_gate_out5, &_forget_gate_out3, ConvertPolicy::SATURATE);
_forget_gate_out5.allocator()->allocate();
- _forget_gate_out6.allocator()->allocate();
forget_gate_out = &_forget_gate_out3;
}
else
@@ -126,13 +132,10 @@
_activation_forget_gate.configure(forget_gate_out, &_forget_gate_out1, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
forget_gate_out->allocator()->allocate();
- TensorShape input_gate3_shape{ 1, output_state->info()->dimension(1) };
- _input_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
- _input_gate_out5.allocator()->init(TensorInfo(input_gate3_shape, 1, input->info()->data_type()));
-
// Configure block that calculates the input gate
- // input_gate = Activation(input * input_to_input_weights + output_state * recurrent_to_input_weights + cell_state * cell_to_input_weights + input_gate_bias), without CIFG
+ // input_gate = Activation(input * input_to_input_weights + output_state * recurrent_to_input_weights + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
// input_gate = 1 - forget_gate, with CIFG
+ _input_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
if(lstm_params.has_cifg_opt())
{
_memory_group.manage(&_input_gate_out1);
@@ -143,35 +146,36 @@
}
else
{
- TensorShape input_gate1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
- TensorShape input_gate2_shape = compute_transposed_shape(*lstm_params.cell_to_input_weights()->info());
+ TensorShape input_gate_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
- _input_gate_out2.allocator()->init(TensorInfo(input_gate1_shape, 1, input->info()->data_type()));
+ _input_gate_out2.allocator()->init(TensorInfo(input_gate_shape, 1, input->info()->data_type()));
_input_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
- _input_gate_out4.allocator()->init(TensorInfo(input_gate2_shape, 1, input->info()->data_type()));
- _input_gate_out6.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
+ _input_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
+ _input_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
_memory_group.manage(&_input_gate_out1);
- _fully_connected_input_gate.configure(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), &_input_gate_out1, true, false);
+ _fully_connected_input_gate.configure(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), &_input_gate_out1);
_memory_group.manage(&_input_gate_out2);
- _transpose_input_gate1.configure(lstm_params.recurrent_to_input_weights(), &_input_gate_out2);
+ _transpose_input_gate.configure(lstm_params.recurrent_to_input_weights(), &_input_gate_out2);
_memory_group.manage(&_input_gate_out3);
- _gemm_input_gate1.configure(output_state, &_input_gate_out2, nullptr, &_input_gate_out3, 1.f, 0.f);
+ _gemm_input_gate.configure(output_state_in, &_input_gate_out2, nullptr, &_input_gate_out3, 1.f, 0.f);
_input_gate_out2.allocator()->allocate();
_memory_group.manage(&_input_gate_out4);
- _transpose_input_gate2.configure(lstm_params.cell_to_input_weights(), &_input_gate_out4);
- _memory_group.manage(&_input_gate_out5);
- _gemm_input_gate2.configure(cell_state, &_input_gate_out4, nullptr, &_input_gate_out5, 1.f, 0.f);
- _input_gate_out4.allocator()->allocate();
- _memory_group.manage(&_input_gate_out6);
- _accum_input_gate1.configure(&_input_gate_out1, &_input_gate_out3, &_input_gate_out6, ConvertPolicy::SATURATE);
+ _accum_input_gate1.configure(&_input_gate_out1, &_input_gate_out3, &_input_gate_out4, ConvertPolicy::SATURATE);
+ if(_run_peephole_opt)
+ {
+ _memory_group.manage(&_input_gate_out5);
+ _pixelwise_mul_input_gate.configure(cell_state_in, lstm_params.cell_to_input_weights(), &_input_gate_out5, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ _accum_input_gate2.configure(&_input_gate_out4, &_input_gate_out5, &_input_gate_out1, ConvertPolicy::SATURATE);
+ _input_gate_out5.allocator()->allocate();
+ }
_input_gate_out3.allocator()->allocate();
- _accum_input_gate2.configure(&_input_gate_out6, &_input_gate_out5, &_input_gate_out1, ConvertPolicy::SATURATE);
- _input_gate_out5.allocator()->allocate();
- _input_gate_out6.allocator()->allocate();
+ _input_gate_out4.allocator()->allocate();
_activation_input_gate.configure(&_input_gate_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
}
+ // Configure block that calculates the cell state
+ // cell_state = Clip((PixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state_in * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold)
TensorShape cell_state1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
_cell_state_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
_cell_state_out2.allocator()->init(TensorInfo(cell_state1_shape, 1, input->info()->data_type()));
@@ -179,14 +183,12 @@
_cell_state_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
_cell_state_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
- // Configure block that calculates the cell state
- // cell_state = Clip((RixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold)
_memory_group.manage(&_cell_state_out1);
- _fully_connected_cell_state.configure(input, input_to_cell_weights, cell_bias, &_cell_state_out1, true, false);
+ _fully_connected_cell_state.configure(input, input_to_cell_weights, cell_bias, &_cell_state_out1);
_memory_group.manage(&_cell_state_out2);
- _transpose_cell_state1.configure(recurrent_to_cell_weights, &_cell_state_out2);
+ _transpose_cell_state.configure(recurrent_to_cell_weights, &_cell_state_out2);
_memory_group.manage(&_cell_state_out3);
- _gemm_cell_state1.configure(output_state, &_cell_state_out2, nullptr, &_cell_state_out3, 1.f, 0.f);
+ _gemm_cell_state1.configure(output_state_in, &_cell_state_out2, nullptr, &_cell_state_out3, 1.f, 0.f);
_cell_state_out2.allocator()->allocate();
_memory_group.manage(&_cell_state_out4);
_accum_cell_state1.configure(&_cell_state_out1, &_cell_state_out3, &_cell_state_out4, ConvertPolicy::SATURATE);
@@ -195,12 +197,11 @@
_pixelwise_mul_cell_state1.configure(&_cell_state_out4, &_input_gate_out1, &_cell_state_out5, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
_input_gate_out1.allocator()->allocate();
_cell_state_out4.allocator()->allocate();
- _pixelwise_mul_cell_state2.configure(&_forget_gate_out1, cell_state, &_cell_state_out3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ _pixelwise_mul_cell_state2.configure(&_forget_gate_out1, cell_state_in, &_cell_state_out3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
_forget_gate_out1.allocator()->allocate();
_accum_cell_state2.configure(&_cell_state_out5, &_cell_state_out3, &_cell_state_out1, ConvertPolicy::SATURATE);
_cell_state_out3.allocator()->allocate();
_cell_state_out5.allocator()->allocate();
-
// Perform clipping
if(cell_threshold != 0.f)
{
@@ -208,53 +209,45 @@
_cell_clip.configure(&_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
}
+ // Configure block that calculates the output
+ // output_state_out = Activation(input * input_to_output_weights + output_state_in * recurrent_to_output_weights + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
TensorShape output1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
- TensorShape output2_shape = compute_transposed_shape(*cell_bias->info());
- TensorShape output3_shape{ 1, output_state->info()->dimension(1) };
_output1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
_output2.allocator()->init(TensorInfo(output1_shape, 1, input->info()->data_type()));
_output3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
- _output6.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
+ _output5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
- // Configure block that calculates the output
- // output_gate = Activation(input * input_to_output_weights + output_state * recurrent_to_output_weights + cell_state * cell_to_output_weights + output_gate_bias)
_memory_group.manage(&_output1);
- _fully_connected_output.configure(input, input_to_output_weights, output_gate_bias, &_output1, true, false);
+ _fully_connected_output.configure(input, input_to_output_weights, output_gate_bias, &_output1);
_memory_group.manage(&_output2);
- _transpose_output1.configure(recurrent_to_output_weights, &_output2);
+ _transpose_output.configure(recurrent_to_output_weights, &_output2);
_memory_group.manage(&_output3);
- _gemm_output1.configure(output_state, &_output2, nullptr, &_output3, 1.f, 0.f);
+ _gemm_output.configure(output_state_in, &_output2, nullptr, &_output3, 1.f, 0.f);
_output2.allocator()->allocate();
- _memory_group.manage(&_output6);
- _accum_output1.configure(&_output1, &_output3, &_output6, ConvertPolicy::SATURATE);
+ _memory_group.manage(&_output5);
+ _accum_output1.configure(&_output1, &_output3, &_output5, ConvertPolicy::SATURATE);
_output3.allocator()->allocate();
- CLTensor *output_gate_out = &_output6;
+ CLTensor *output_gate_out = &_output5;
if(lstm_params.has_peephole_opt())
{
- _output4.allocator()->init(TensorInfo(output2_shape, 1, input->info()->data_type()));
- _output5.allocator()->init(TensorInfo(output3_shape, 1, input->info()->data_type()));
+ _output4.allocator()->init(TensorInfo(_cell_state_out1.info()->tensor_shape(), 1, input->info()->data_type()));
_memory_group.manage(&_output4);
- _transpose_output2.configure(lstm_params.cell_to_output_weights(), &_output4);
- _memory_group.manage(&_output5);
- _gemm_output2.configure(&_cell_state_out1, &_output4, nullptr, &_output5, 1.f, 0.f);
- _accum_output2.configure(&_output6, &_output5, &_output1, ConvertPolicy::SATURATE);
- _output6.allocator()->allocate();
+ _pixelwise_mul_output_state1.configure(&_cell_state_out1, lstm_params.cell_to_output_weights(), &_output4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ _accum_output2.configure(&_output5, &_output4, &_output1, ConvertPolicy::SATURATE);
+ _output5.allocator()->allocate();
output_gate_out = &_output1;
// Allocate intermediate buffers
_output4.allocator()->allocate();
- _output5.allocator()->allocate();
}
else
{
_output1.allocator()->allocate();
}
- _activation_output.configure(output_gate_out, output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+ _activation_output.configure(output_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
output_gate_out->allocator()->allocate();
- _cell_state_activation.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
-
// Configure block that calculates the output state
/** lstm_res = PixelwiseMul(output, Activation(cell_state))
*
@@ -264,32 +257,32 @@
* \
* -- lstm_res , otherwise
*/
+ ICLTensor *output_state_out_tmp = lstm_params.has_projection() ? &_output_state1 : output_state_out;
+ _cell_state_activation.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
+ _output_state1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
+
_memory_group.manage(&_cell_state_activation);
_activation_output_state.configure(&_cell_state_out1, &_cell_state_activation, activation_info);
- _pixelwise_mul_output_state.configure(&_cell_state_activation, output, output_state, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ _pixelwise_mul_output_state2.configure(&_cell_state_activation, output_gate_out, output_state_out_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
_cell_state_activation.allocator()->allocate();
if(lstm_params.has_projection())
{
_has_projection_weights = true;
- _output_projection1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
- _memory_group.manage(&_output_projection1);
- _fully_connected_output_state.configure(output_state, lstm_params.projection_weights(), lstm_params.projection_bias(), &_output_projection1, true, false);
+ _fully_connected_output_state.configure(output_state_out_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out);
+ _output_state1.allocator()->allocate();
// Perform clipping
if(projection_threshold != 0.f)
{
_perform_projection_clipping = true;
- _projection_clip.configure(&_output_projection1, output_state, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
+ _projection_clip.configure(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
}
-
- // Allocate intermediate buffer
- _output_projection1.allocator()->allocate();
}
// Copy cell state and output
- _copy_cell_state.configure(&_cell_state_out1, cell_state);
+ _copy_cell_state.configure(&_cell_state_out1, cell_state_out);
_cell_state_out1.allocator()->allocate();
- _copy_output.configure(output_state, output);
+ _copy_output.configure(output_state_out, output);
// Vector for holding the tensors to store in scratch buffer
std::vector<ICLTensor *> scratch_inputs;
@@ -303,121 +296,161 @@
_concat_scratch_buffer.configure(scratch_inputs, scratch_buffer);
}
-Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
+Status CLLSTMLayer::validate(const ITensorInfo *input,
+ const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
- const ITensorInfo *output_state, const ITensorInfo *cell_state, const ITensorInfo *scratch_buffer, const ITensorInfo *output,
+ const ITensorInfo *output_state_in, const ITensorInfo *cell_state_in,
+ const ITensorInfo *scratch_buffer, const ITensorInfo *output_state_out, const ITensorInfo *cell_state_out, const ITensorInfo *output,
const LSTMParams<ITensorInfo> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
- forget_gate_bias, cell_bias, output_gate_bias, output_state, cell_state);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
- recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, output_state, cell_state);
- ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(output_state->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(cell_state->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0) && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input,
+ input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
+ recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
+ forget_gate_bias, cell_bias, output_gate_bias,
+ output_state_in, cell_state_in,
+ scratch_buffer, output_state_out, cell_state_out, output);
+ // Check data types
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input,
+ input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
+ recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
+ forget_gate_bias, cell_bias, output_gate_bias,
+ output_state_in, cell_state_in,
+ scratch_buffer, output_state_out, cell_state_out, output);
+
+ // Check dimensions
+ ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(output_state_out->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(cell_state_out->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0)
+ && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
+
+ const unsigned int num_batches = input->dimension(1);
+ const unsigned int num_cells = input_to_output_weights->dimension(1);
+
+ // Check peephole optimization
if(lstm_params.has_peephole_opt())
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights(), lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() != 1);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
}
TensorShape units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
- TensorShape gemmv_shape{ 1, output_state->dimension(1) };
TensorShape num_units_transposed_shape = compute_transposed_shape(*forget_gate_bias);
const TensorInfo units_out_transposed_info = TensorInfo(units_out_transposed_shape, 1, input->data_type());
- const TensorInfo gemmv_shape_info = TensorInfo(gemmv_shape, 1, input->data_type());
const TensorInfo num_units_transposed_info = TensorInfo(num_units_transposed_shape, 1, input->data_type());
+ TensorInfo input_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
+ TensorInfo forget_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
+ TensorInfo output_gate_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
+ TensorInfo cell_state_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
+
// Validate forget gate
- ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_forget_weights, forget_gate_bias, cell_state, true, false));
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state, &units_out_transposed_info, nullptr, cell_state, 1.f, 0.f, GEMMInfo()));
- ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAdditionKernel::validate(cell_state, cell_state, cell_state, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_forget_weights, forget_gate_bias, &forget_gate));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &forget_gate, 1.f, 0.f, GEMMInfo()));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAdditionKernel::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
if(lstm_params.has_peephole_opt())
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(cell_state, &num_units_transposed_info, nullptr, &gemmv_shape_info, 1.f, 0.f, GEMMInfo()));
- ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(cell_state, &gemmv_shape_info, cell_state, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
}
- ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, cell_state, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
// Validate input gate
if(!lstm_params.has_cifg_opt())
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.cell_to_input_weights(), lstm_params.input_gate_bias());
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), cell_state, true, false));
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(cell_state, &num_units_transposed_info, nullptr, &gemmv_shape_info, 1.f, 0.f, GEMMInfo()));
- ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(cell_state, &gemmv_shape_info, cell_state, ConvertPolicy::SATURATE));
- ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(),
+ lstm_params.recurrent_to_input_weights(),
+ lstm_params.input_gate_bias());
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);
+
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), &input_gate));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &input_gate, 1.f, 0.f, GEMMInfo()));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
+ if(lstm_params.has_peephole_opt())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
+ }
+ ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
}
else
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticSubtractionKernel::validate(cell_state, cell_state, cell_state, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticSubtractionKernel::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
}
// Validate cell state
- ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_cell_weights, cell_bias, cell_state, true, false));
- ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, nullptr, activation_info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state, cell_state, cell_state, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
-
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_cell_weights, cell_bias, &cell_state_tmp));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &cell_state_tmp, 1.f, 0.f, GEMMInfo()));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, nullptr, activation_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
if(cell_threshold != 0.f)
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold)));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold,
+ cell_threshold)));
}
- ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_output_weights, output_gate_bias, cell_state, true, false));
+ // Validate output gate tmp
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_output_weights, output_gate_bias, &output_gate_tmp));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &output_gate_tmp, 1.f, 0.f, GEMMInfo()));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE));
if(lstm_params.has_peephole_opt())
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(cell_state, cell_state, cell_state, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
+ RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE));
}
- ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
// Validate output state
- ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, cell_state, activation_info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state, output, output_state, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
if(lstm_params.has_projection())
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(output_state, lstm_params.projection_weights(), lstm_params.projection_bias(), cell_state, true, false));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out));
if(projection_threshold != 0.f)
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(cell_state, output_state, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold,
- projection_threshold)));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(output_state_out, output_state_out,
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)));
}
}
- std::vector<TensorInfo> inputs_vector_info;
+ // Validate copy kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(&cell_state_tmp, cell_state_out));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(output_state_out, output));
+
+ // Validate scratch concatenation
+ std::vector<ITensorInfo *> inputs_vector_info_raw;
if(lstm_params.has_cifg_opt())
{
- inputs_vector_info.emplace_back(*cell_state);
+ inputs_vector_info_raw.push_back(&input_gate);
}
- inputs_vector_info.emplace_back(*cell_state);
- inputs_vector_info.emplace_back(*cell_state);
- inputs_vector_info.emplace_back(*cell_state);
-
- std::vector<ITensorInfo *> inputs_vector_info_raw;
- for(auto &input : inputs_vector_info)
- {
- inputs_vector_info_raw.emplace_back(&input);
- }
+ inputs_vector_info_raw.push_back(&cell_state_tmp);
+ inputs_vector_info_raw.push_back(&forget_gate);
+ inputs_vector_info_raw.push_back(&output_gate_tmp);
ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenateLayer::validate(inputs_vector_info_raw, scratch_buffer));
return Status{};
@@ -428,14 +461,13 @@
_memory_group.acquire();
_fully_connected_forget_gate.run();
- CLScheduler::get().enqueue(_transpose_forget_gate1);
- _gemm_forget_gate1.run();
+ CLScheduler::get().enqueue(_transpose_forget_gate);
+ _gemm_forget_gate.run();
CLScheduler::get().enqueue(_accum_forget_gate1);
if(_run_peephole_opt)
{
- CLScheduler::get().enqueue(_transpose_forget_gate2);
- _gemm_forget_gate2.run();
+ CLScheduler::get().enqueue(_pixelwise_mul_forget_gate);
_accum_forget_gate2.run();
}
CLScheduler::get().enqueue(_activation_forget_gate);
@@ -443,24 +475,33 @@
if(_run_cifg_opt)
{
_ones.map(true);
- std::fill_n(_ones.buffer(), _ones.info()->total_size(), 1);
+ if(_ones.info()->data_type() == DataType::F16)
+ {
+ std::fill_n(reinterpret_cast<half *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 1);
+ }
+ else
+ {
+ std::fill_n(reinterpret_cast<float *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 1);
+ }
_ones.unmap();
CLScheduler::get().enqueue(_subtract_input_gate);
}
else
{
_fully_connected_input_gate.run();
- CLScheduler::get().enqueue(_transpose_input_gate1);
- _gemm_input_gate1.run();
- CLScheduler::get().enqueue(_transpose_input_gate2);
- _gemm_input_gate2.run();
+ CLScheduler::get().enqueue(_transpose_input_gate);
+ _gemm_input_gate.run();
CLScheduler::get().enqueue(_accum_input_gate1);
- _accum_input_gate2.run();
+ if(_run_peephole_opt)
+ {
+ CLScheduler::get().enqueue(_pixelwise_mul_input_gate);
+ _accum_input_gate2.run();
+ }
CLScheduler::get().enqueue(_activation_input_gate);
}
_fully_connected_cell_state.run();
- CLScheduler::get().enqueue(_transpose_cell_state1);
+ CLScheduler::get().enqueue(_transpose_cell_state);
_gemm_cell_state1.run();
CLScheduler::get().enqueue(_accum_cell_state1);
CLScheduler::get().enqueue(_activation_cell_state);
@@ -474,21 +515,19 @@
}
_fully_connected_output.run();
- CLScheduler::get().enqueue(_transpose_output1);
- _gemm_output1.run();
+ CLScheduler::get().enqueue(_transpose_output);
+ _gemm_output.run();
CLScheduler::get().enqueue(_accum_output1);
- CLScheduler::get().enqueue(_pixelwise_mul_output_state);
if(_run_peephole_opt)
{
- CLScheduler::get().enqueue(_transpose_output2);
- _gemm_output2.run();
+ CLScheduler::get().enqueue(_pixelwise_mul_output_state1);
_accum_output2.run();
}
CLScheduler::get().enqueue(_activation_output);
CLScheduler::get().enqueue(_activation_output_state);
- CLScheduler::get().enqueue(_pixelwise_mul_output_state);
+ CLScheduler::get().enqueue(_pixelwise_mul_output_state2);
if(_has_projection_weights)
{
diff --git a/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp b/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp
index 986fe00..40bf032 100644
--- a/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp
@@ -48,7 +48,10 @@
// Get convolved dimensions
unsigned int conv_w = 0;
unsigned int conv_h = 0;
- std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_width, kernel_height,
+ std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(0),
+ input->dimension(1),
+ kernel_width,
+ kernel_height,
conv_info);
const size_t mat_weights_cols = weights->dimension(3);
@@ -61,9 +64,12 @@
const size_t mat_input_rows = conv_w * conv_h;
shape_im2col = input->tensor_shape();
+ if(shape_im2col.num_dimensions() >= 3)
+ {
+ shape_im2col.remove_dimension(2);
+ }
shape_im2col.set(0, mat_input_cols);
shape_im2col.set(1, mat_input_rows);
- shape_im2col.set(2, 1);
shape_gemm = shape_im2col;
shape_gemm.set(0, mat_weights_cols);
@@ -73,7 +79,7 @@
CLLocallyConnectedLayer::CLLocallyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(std::move(memory_manager)), _input_im2col_kernel(), _weights_reshape_kernel(), _mm_kernel(), _output_col2im_kernel(), _input_im2col_reshaped(), _weights_reshaped(), _gemm_output(),
- _is_first_run(false), _original_weights(nullptr)
+ _is_prepared(false), _original_weights(nullptr)
{
}
@@ -128,7 +134,7 @@
bool _has_bias = (biases != nullptr);
_original_weights = weights;
- _is_first_run = true;
+ _is_prepared = false;
const unsigned int kernel_width = weights->info()->dimension(0);
const unsigned int kernel_height = weights->info()->dimension(1);
@@ -160,24 +166,15 @@
_output_col2im_kernel.configure(&_gemm_output, output, std::make_pair(conv_w, conv_h));
// Allocate intermediate tensors
- _weights_reshaped.allocator()->allocate();
_input_im2col_reshaped.allocator()->allocate();
_gemm_output.allocator()->allocate();
+
+ CLScheduler::get().tune_kernel_static(_input_im2col_kernel);
}
void CLLocallyConnectedLayer::run()
{
- // Run weights reshaping (Runs once for every configure)
- if(_is_first_run)
- {
- ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
-
- _is_first_run = false;
- CLScheduler::get().enqueue(_weights_reshape_kernel);
-
- // Mark original weights tensor as unused
- _original_weights->mark_as_unused();
- }
+ prepare();
_memory_group.acquire();
@@ -192,3 +189,19 @@
_memory_group.release();
}
+
+void CLLocallyConnectedLayer::prepare()
+{
+ if(!_is_prepared)
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+
+ // Run weights reshaping and mark original weights tensor as unused
+ _weights_reshaped.allocator()->allocate();
+ CLScheduler::get().enqueue(_weights_reshape_kernel);
+ _original_weights->mark_as_unused();
+
+ CLScheduler::get().queue().finish();
+ _is_prepared = true;
+ }
+}
diff --git a/src/runtime/CL/functions/CLMagnitude.cpp b/src/runtime/CL/functions/CLMagnitude.cpp
index b1284db..e2dfe3a 100644
--- a/src/runtime/CL/functions/CLMagnitude.cpp
+++ b/src/runtime/CL/functions/CLMagnitude.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,10 +30,8 @@
using namespace arm_compute;
-void CLMagnitude::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, MagnitudeType mag_type, bool use_fp16)
+void CLMagnitude::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, MagnitudeType mag_type)
{
- ARM_COMPUTE_UNUSED(use_fp16);
-
auto k = arm_compute::support::cpp14::make_unique<CLMagnitudePhaseKernel>();
k->configure(input1, input2, output, nullptr, mag_type);
_kernel = std::move(k);
diff --git a/src/runtime/CL/functions/CLMeanStdDev.cpp b/src/runtime/CL/functions/CLMeanStdDev.cpp
index 838f7e7..157f306 100644
--- a/src/runtime/CL/functions/CLMeanStdDev.cpp
+++ b/src/runtime/CL/functions/CLMeanStdDev.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,35 +21,149 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#include "arm_compute/runtime/CL/functions/CLMeanStdDev.h"
+#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
+#include "arm_compute/runtime/CL/functions/CLMeanStdDev.h"
using namespace arm_compute;
-CLMeanStdDev::CLMeanStdDev()
- : _mean_stddev_kernel(),
+CLMeanStdDev::CLMeanStdDev(std::shared_ptr<IMemoryManager> memory_manager) // NOLINT
+ : _memory_group(std::move(memory_manager)),
+ _data_type(),
+ _num_pixels(),
+ _run_stddev(),
+ _reduction_operation_mean(),
+ _reduction_operation_stddev(),
+ _reduction_output_mean(),
+ _reduction_output_stddev(),
+ _mean(nullptr),
+ _stddev(nullptr),
+ _mean_stddev_kernel(),
_fill_border_kernel(),
_global_sum(),
_global_sum_squared()
{
}
+Status CLMeanStdDev::validate(ITensorInfo *input, float *mean, float *stddev)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_TENSOR_NOT_2D(input);
+ if(is_data_type_float(input->data_type()))
+ {
+ ARM_COMPUTE_UNUSED(mean);
+ ARM_COMPUTE_UNUSED(stddev);
+
+ TensorShape output_shape = TensorShape{ 1, input->dimension(1) };
+ TensorInfo output_shape_info = TensorInfo(output_shape, 1, DataType::U8);
+ return CLReductionOperation::validate(input, &output_shape_info, 0, ReductionOperation::SUM);
+ }
+ else
+ {
+ return CLMeanStdDevKernel::validate(input, mean, nullptr, stddev, nullptr);
+ }
+}
+
void CLMeanStdDev::configure(ICLImage *input, float *mean, float *stddev)
{
- _global_sum = cl::Buffer(CLScheduler::get().context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_ulong));
+ // In the case of F16/F32 we call reduction operation for calculating CLMeanStdDev
+ _data_type = input->info()->data_type();
- if(stddev != nullptr)
+ if(is_data_type_float(_data_type))
{
- _global_sum_squared = cl::Buffer(CLScheduler::get().context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_ulong));
+ _num_pixels = input->info()->dimension(0) * input->info()->dimension(1);
+
+ _memory_group.manage(&_reduction_output_mean);
+ _reduction_operation_mean.configure(input, &_reduction_output_mean, 0, ReductionOperation::SUM);
+ _reduction_output_mean.allocator()->allocate();
+ _mean = mean;
+
+ if(stddev != nullptr)
+ {
+ _memory_group.manage(&_reduction_output_stddev);
+ _reduction_operation_stddev.configure(input, &_reduction_output_stddev, 0, ReductionOperation::SUM_SQUARE);
+ _reduction_output_stddev.allocator()->allocate();
+ _stddev = stddev;
+ _run_stddev = true;
+ }
+ }
+ else
+ {
+ _global_sum = cl::Buffer(CLScheduler::get().context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_ulong));
+
+ if(stddev != nullptr)
+ {
+ _global_sum_squared = cl::Buffer(CLScheduler::get().context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_ulong));
+ }
+
+ _mean_stddev_kernel.configure(input, mean, &_global_sum, stddev, &_global_sum_squared);
+ _fill_border_kernel.configure(input, _mean_stddev_kernel.border_size(), BorderMode::CONSTANT, PixelValue(static_cast<uint8_t>(0)));
+ }
+}
+
+template <typename T>
+void CLMeanStdDev::run_float()
+{
+ _memory_group.acquire();
+
+ // Perform reduction on x-axis
+ _reduction_operation_mean.run();
+ if(_run_stddev)
+ {
+ _reduction_operation_stddev.run();
+ _reduction_output_stddev.map(true);
}
- _mean_stddev_kernel.configure(input, mean, &_global_sum, stddev, &_global_sum_squared);
- _fill_border_kernel.configure(input, _mean_stddev_kernel.border_size(), BorderMode::CONSTANT, PixelValue(static_cast<uint8_t>(0)));
+ _reduction_output_mean.map(true);
+
+ auto mean = static_cast<T>(0);
+
+ // Calculate final result for mean
+ for(unsigned int i = 0; i < _reduction_output_mean.info()->dimension(1); ++i)
+ {
+ mean += *reinterpret_cast<T *>(_reduction_output_mean.buffer() + _reduction_output_mean.info()->offset_element_in_bytes(Coordinates(0, i)));
+ }
+
+ mean /= _num_pixels;
+ *_mean = mean;
+
+ if(_run_stddev)
+ {
+ auto stddev = static_cast<T>(0);
+ // Calculate final result for stddev
+ for(unsigned int i = 0; i < _reduction_output_stddev.info()->dimension(1); ++i)
+ {
+ stddev += *reinterpret_cast<T *>(_reduction_output_stddev.buffer() + _reduction_output_stddev.info()->offset_element_in_bytes(Coordinates(0, i)));
+ }
+ *_stddev = std::sqrt((stddev / _num_pixels) - (mean * mean));
+
+ _reduction_output_stddev.unmap();
+ }
+ _reduction_output_mean.unmap();
+
+ _memory_group.release();
+}
+
+void CLMeanStdDev::run_int()
+{
+ CLScheduler::get().enqueue(_fill_border_kernel);
+ CLScheduler::get().enqueue(_mean_stddev_kernel);
}
+/** Dispatch on _data_type: F16/F32 use the host-finalised reduction path, U8 the dedicated kernel. */
void CLMeanStdDev::run()
{
- CLScheduler::get().enqueue(_fill_border_kernel);
- CLScheduler::get().enqueue(_mean_stddev_kernel);
+ switch(_data_type)
+ {
+ case DataType::F16:
+ run_float<half>();
+ break;
+ case DataType::F32:
+ run_float<float>();
+ break;
+ case DataType::U8:
+ run_int();
+ break;
+ default:
+ // Fail unconditionally. ARM_COMPUTE_ERROR_ON() expects a condition (a string literal is
+ // only incidentally "true") and is compiled out when asserts are disabled.
+ ARM_COMPUTE_ERROR("Not supported");
+ }
}
diff --git a/src/runtime/CL/functions/CLPoolingLayer.cpp b/src/runtime/CL/functions/CLPoolingLayer.cpp
index 17875a3..cbe1ce3 100644
--- a/src/runtime/CL/functions/CLPoolingLayer.cpp
+++ b/src/runtime/CL/functions/CLPoolingLayer.cpp
@@ -63,6 +63,9 @@
ARM_COMPUTE_ERROR("Data layout not supported");
}
_border_handler.configure(input, _kernel->border_size(), border_mode, pixel_value);
+
+ // Tune kernels
+ CLScheduler::get().tune_kernel_static(*_kernel);
}
Status CLPoolingLayer::validate(const ITensorInfo *input, const ITensorInfo *output, const PoolingLayerInfo &pool_info)
diff --git a/src/runtime/CL/functions/CLRNNLayer.cpp b/src/runtime/CL/functions/CLRNNLayer.cpp
index 4843ba6..1809e6e 100644
--- a/src/runtime/CL/functions/CLRNNLayer.cpp
+++ b/src/runtime/CL/functions/CLRNNLayer.cpp
@@ -36,7 +36,8 @@
using namespace arm_compute::misc::shape_calculator;
+/** Construct the layer. @p memory_manager backs the internal memory group; _is_prepared starts false so the first run() calls prepare(). */
CLRNNLayer::CLRNNLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _gemm_state_f(), _add_kernel(), _activation_kernel(), _fully_connected_kernel(), _copy_kernel(), _fully_connected_out(), _gemm_output(), _add_output()
+ : _memory_group(std::move(memory_manager)), _gemm_state_f(), _add_kernel(), _activation_kernel(), _fully_connected_kernel(), _copy_kernel(), _fully_connected_out(), _gemm_output(), _add_output(),
+ _is_prepared(false)
{
}
@@ -57,7 +58,7 @@
auto shape_info = TensorInfo(compute_rnn_shape(recurrent_weights, hidden_state->dimension(idx_height)), 1, input->data_type());
- ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, weights, bias, &shape_info, true, false));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, weights, bias, &shape_info));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(hidden_state, recurrent_weights, nullptr, &shape_info, 1.f, 0.f));
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAdditionKernel::validate(&shape_info, &shape_info, &shape_info, ConvertPolicy::SATURATE));
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&shape_info, &shape_info, info));
@@ -74,12 +75,14 @@
const int idx_height = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::HEIGHT);
TensorShape shape = compute_rnn_shape(recurrent_weights->info(), hidden_state->info()->dimension(idx_height));
+ _is_prepared = false;
+
_fully_connected_out.allocator()->init(TensorInfo(shape, 1, input->info()->data_type()));
_gemm_output.allocator()->init(TensorInfo(shape, 1, input->info()->data_type()));
// Manage intermediate buffers and configure
_memory_group.manage(&_fully_connected_out);
- _fully_connected_kernel.configure(input, weights, bias, &_fully_connected_out, true, false);
+ _fully_connected_kernel.configure(input, weights, bias, &_fully_connected_out);
_memory_group.manage(&_gemm_output);
_gemm_state_f.configure(hidden_state, recurrent_weights, nullptr, &_gemm_output, 1.f, 0.f);
@@ -100,7 +103,10 @@
void CLRNNLayer::run()
{
+ prepare();
+
_memory_group.acquire();
+
_fully_connected_kernel.run();
_gemm_state_f.run();
CLScheduler::get().enqueue(_add_kernel);
@@ -108,5 +114,17 @@
// copy hidden out to output
CLScheduler::get().enqueue(_copy_kernel);
+
_memory_group.release();
+}
+
+/** Prepare the fully connected and GEMM sub-functions once.
+ *
+ * run() calls this on every invocation; only the first call does any work, later calls
+ * return immediately because _is_prepared is set.
+ */
+void CLRNNLayer::prepare()
+{
+ if(!_is_prepared)
+ {
+ _fully_connected_kernel.prepare();
+ _gemm_state_f.prepare();
+
+ _is_prepared = true;
+ }
}
\ No newline at end of file
diff --git a/src/runtime/CL/functions/CLReductionOperation.cpp b/src/runtime/CL/functions/CLReductionOperation.cpp
index 3a5133d..2a171c3 100644
--- a/src/runtime/CL/functions/CLReductionOperation.cpp
+++ b/src/runtime/CL/functions/CLReductionOperation.cpp
@@ -71,7 +71,6 @@
sums_vector[i].set_data_type(input->data_type());
sums_vector[i].set_tensor_shape(shape);
sums_vector[i].set_num_channels(input->num_channels());
- sums_vector[i].set_fixed_point_position(input->fixed_point_position());
}
// Validate ReductionOperation only on first kernel
@@ -105,7 +104,7 @@
for(unsigned int i = 0; i < _num_of_stages - 1; i++)
{
shape.set(0, ceil(shape.x() / 128.f));
- _sums_vector[i].allocator()->init(TensorInfo(shape, input->info()->num_channels(), input->info()->data_type(), input->info()->fixed_point_position()));
+ _sums_vector[i].allocator()->init(TensorInfo(shape, input->info()->num_channels(), input->info()->data_type()));
}
// Apply ReductionOperation only on first kernel
diff --git a/src/runtime/CL/functions/CLScale.cpp b/src/runtime/CL/functions/CLScale.cpp
index cb68481..4ff9763 100644
--- a/src/runtime/CL/functions/CLScale.cpp
+++ b/src/runtime/CL/functions/CLScale.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "arm_compute/core/CL/kernels/CLScaleKernel.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/CL/CLScheduler.h"
#include "support/ToolchainSupport.h"
using namespace arm_compute;
@@ -34,7 +35,18 @@
+/** Configure a CLScaleKernel for @p input -> @p output, apply static LWS tuning, and set up the border handler. */
void CLScale::configure(ICLTensor *input, ICLTensor *output, InterpolationPolicy policy, BorderMode border_mode, PixelValue constant_border_value, SamplingPolicy sampling_policy)
{
auto k = arm_compute::support::cpp14::make_unique<CLScaleKernel>();
- k->configure(input, output, policy, border_mode == BorderMode::UNDEFINED, sampling_policy);
+ // The target is set before configure(), presumably so target-dependent setup can use it — confirm.
+ k->set_target(CLScheduler::get().target());
+ k->configure(input, output, policy, border_mode, sampling_policy);
_kernel = std::move(k);
+
+ // Tune kernels (static LWS heuristics of the scheduler's active tuner)
+ CLScheduler::get().tune_kernel_static(*_kernel);
+
+ // In the case of NHWC we can't have undefined border mode as this would require to access elements outside z dimension,
+ // so we treat it like border constant.
+ // Note: the kernel above was configured with the caller's border_mode; only the border handler sees this override.
+ if(border_mode == BorderMode::UNDEFINED && input->info()->data_layout() == DataLayout::NHWC)
+ {
+ border_mode = BorderMode::CONSTANT;
+ }
_border_handler.configure(input, _kernel->border_size(), border_mode, constant_border_value);
}
diff --git a/src/runtime/CL/functions/CLSoftmaxLayer.cpp b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
index a92fbce..7a20d9f 100644
--- a/src/runtime/CL/functions/CLSoftmaxLayer.cpp
+++ b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
@@ -77,15 +77,16 @@
Status CLSoftmaxLayer::validate(const ITensorInfo *input, const ITensorInfo *output)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() > 2, "Only 2D inputs are supported");
// Create intermediate tensor info
DataType tmp_data_type = is_data_type_quantized_asymmetric(input->data_type()) ? DataType::S32 : input->data_type();
- TensorInfo tensor_info_tmp(input->clone()->set_data_type(tmp_data_type));
+ TensorInfo tensor_info_tmp(input->clone()->set_data_type(tmp_data_type).set_is_resizable(true));
TensorShape max_sum_shape = input->tensor_shape();
max_sum_shape.set(0, 1);
- TensorInfo tensor_info_max(input->clone()->set_tensor_shape(max_sum_shape));
- TensorInfo tensor_info_sum(input->clone()->set_tensor_shape(max_sum_shape).set_data_type(tmp_data_type).set_quantization_info(QuantizationInfo()));
+ TensorInfo tensor_info_max(input->clone()->set_tensor_shape(max_sum_shape).set_is_resizable(true));
+ TensorInfo tensor_info_sum(input->clone()->set_tensor_shape(max_sum_shape).set_data_type(tmp_data_type).set_quantization_info(QuantizationInfo()).set_is_resizable(true));
ARM_COMPUTE_RETURN_ON_ERROR(CLLogits1DMaxShiftExpSumKernel::validate(input, &tensor_info_max, &tensor_info_tmp, &tensor_info_sum));
ARM_COMPUTE_RETURN_ON_ERROR(CLLogits1DNormKernel::validate(&tensor_info_tmp, &tensor_info_sum, output));
diff --git a/src/runtime/CL/functions/CLWarpAffine.cpp b/src/runtime/CL/functions/CLWarpAffine.cpp
index f785c75..4286cf6 100644
--- a/src/runtime/CL/functions/CLWarpAffine.cpp
+++ b/src/runtime/CL/functions/CLWarpAffine.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -31,7 +31,7 @@
using namespace arm_compute;
-void CLWarpAffine::configure(ICLTensor *input, ICLTensor *output, const float *matrix, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value)
+void CLWarpAffine::configure(ICLTensor *input, ICLTensor *output, const std::array<float, 9> &matrix, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value)
{
auto k = arm_compute::support::cpp14::make_unique<CLWarpAffineKernel>();
k->configure(input, output, matrix, policy);
diff --git a/src/runtime/CL/functions/CLWarpPerspective.cpp b/src/runtime/CL/functions/CLWarpPerspective.cpp
index b445b3b..4603ee0 100644
--- a/src/runtime/CL/functions/CLWarpPerspective.cpp
+++ b/src/runtime/CL/functions/CLWarpPerspective.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -31,7 +31,7 @@
using namespace arm_compute;
-void CLWarpPerspective::configure(ICLTensor *input, ICLTensor *output, const float *matrix, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value)
+void CLWarpPerspective::configure(ICLTensor *input, ICLTensor *output, const std::array<float, 9> &matrix, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value)
{
auto k = arm_compute::support::cpp14::make_unique<CLWarpPerspectiveKernel>();
k->configure(input, output, matrix, policy);
diff --git a/src/runtime/CL/functions/CLWidthConcatenateLayer.cpp b/src/runtime/CL/functions/CLWidthConcatenateLayer.cpp
index d542781..5233ff4 100644
--- a/src/runtime/CL/functions/CLWidthConcatenateLayer.cpp
+++ b/src/runtime/CL/functions/CLWidthConcatenateLayer.cpp
@@ -48,7 +48,7 @@
// Output auto inizialitation if not yet initialized
TensorInfo tmp_output_info = *output->clone();
TensorShape output_shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(inputs_vector);
- auto_init_if_empty(tmp_output_info, output_shape, 1, inputs_vector[0]->data_type(), inputs_vector[0]->fixed_point_position());
+ auto_init_if_empty(tmp_output_info, output_shape, 1, inputs_vector[0]->data_type());
unsigned int width_offset = 0;
for(const auto &input : inputs_vector)
@@ -73,7 +73,7 @@
TensorShape output_shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(inputs_vector);
// Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), output_shape, 1, inputs_vector[0]->info()->data_type(), inputs_vector[0]->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), output_shape, 1, inputs_vector[0]->info()->data_type());
ARM_COMPUTE_ERROR_THROW_ON(CLWidthConcatenateLayer::validate(inputs_vector_info, output->info()));
unsigned int width_offset = 0;
diff --git a/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp b/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
index 49753ad..a70389a 100644
--- a/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
@@ -33,17 +33,34 @@
namespace
{
-Size2D winograd_output_tile(const Size2D &input_dims, const Size2D &kernel_dims)
+Size2D winograd_output_tile(const Size2D &input_dims, const Size2D &kernel_dims, DataLayout data_layout)
{
Size2D output_tile = Size2D{};
- if(kernel_dims == Size2D(3U, 3U))
+ const unsigned int kernel_max_dim = std::max(kernel_dims.width, kernel_dims.height);
+
+ // Check if the input spatial dimensions are smaller than 4
+ const bool is_input_lt4_nchw = (input_dims.width <= 4 && input_dims.height <= 4) && (data_layout == DataLayout::NCHW);
+
+ if(kernel_max_dim == 3U)
{
- output_tile = (input_dims.width <= 4 && input_dims.height <= 4) ? Size2D(2U, 2U) : Size2D(4U, 4U);
+ if(kernel_dims == Size2D(3U, 3U))
+ {
+ output_tile = is_input_lt4_nchw ? Size2D(2U, 2U) : Size2D(4U, 4U);
+ }
+ else if(kernel_dims == Size2D(3U, 1U))
+ {
+ output_tile = is_input_lt4_nchw ? Size2D(2U, 1U) : Size2D(4U, 1U);
+ }
+ else
+ {
+ output_tile = is_input_lt4_nchw ? Size2D(1U, 2U) : Size2D(1U, 4U);
+ }
}
- else if(kernel_dims == Size2D(5U, 5U))
+ else if(kernel_max_dim == 5U)
{
- output_tile = Size2D(4U, 4U);
+ output_tile = Size2D(kernel_dims.width == 1 ? 1U : 4U,
+ kernel_dims.height == 1 ? 1U : 4U);
}
return output_tile;
@@ -82,7 +99,7 @@
// Input shape, kernel size and output tile
const Size2D input_dims = Size2D(input->info()->tensor_shape()[idx_width], input->info()->tensor_shape()[idx_height]);
const Size2D kernel_size = Size2D(weights->info()->tensor_shape()[idx_width], weights->info()->tensor_shape()[idx_height]);
- const Size2D output_tile = winograd_output_tile(input_dims, kernel_size);
+ const Size2D output_tile = winograd_output_tile(input_dims, kernel_size, input->info()->data_layout());
// Check if the Winograd configuration requires fast math
if(!enable_fast_math)
@@ -139,7 +156,7 @@
// Input shape, kernel size and output tile
const Size2D input_dims = Size2D(input->tensor_shape()[idx_width], input->tensor_shape()[idx_height]);
const Size2D kernel_size = Size2D(weights->tensor_shape()[idx_width], weights->tensor_shape()[idx_height]);
- const Size2D output_tile = winograd_output_tile(input_dims, kernel_size);
+ const Size2D output_tile = winograd_output_tile(input_dims, kernel_size, input->data_layout());
// Check if the Winograd configuration requires fast math
if(!enable_fast_math)
diff --git a/src/runtime/CL/tuners/BifrostTuner.cpp b/src/runtime/CL/tuners/BifrostTuner.cpp
index c0ebd24..2d52f33 100644
--- a/src/runtime/CL/tuners/BifrostTuner.cpp
+++ b/src/runtime/CL/tuners/BifrostTuner.cpp
@@ -124,15 +124,195 @@
k.set_lws_hint(lws_hint);
}
}
+
+/** Static LWS heuristic for CLCol2ImKernel.
+ *
+ * On the listed Bifrost targets the hint becomes (1, 7, 1) when k._convolved_dims.first is 7 or 14,
+ * and (1, 8, 1) otherwise. On any other target the kernel's current hint is written back unchanged.
+ */
+void tune_col2im_kernel(CLCol2ImKernel &k)
+{
+ cl::NDRange lws_hint = k.lws_hint();
+ const GPUTarget gpu_target = k.get_target();
+
+ // Configure the local work size for Bifrost with a value obtained
+ // via exhaustive autotuning over 30 representative tensor shapes.
+ if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::G76))
+ {
+ if((k._convolved_dims.first == 7) || (k._convolved_dims.first == 14))
+ {
+ lws_hint = cl::NDRange(1, 7, 1);
+ }
+ else
+ {
+ lws_hint = cl::NDRange(1, 8, 1);
+ }
+ }
+
+ k.set_lws_hint(lws_hint);
+}
+
+/** Static LWS heuristic for CLIm2ColKernel.
+ *
+ * On the listed Bifrost targets, a kernel of width 11 that is non-square and has no padding gets the
+ * hint (1, 1, 1). In every other case the kernel's current hint is written back unchanged.
+ */
+void tune_im2col_kernel(CLIm2ColKernel &k)
+{
+ cl::NDRange lws_hint = k.lws_hint();
+ const GPUTarget gpu_target = k.get_target();
+
+ // Local work size optimized for the 11x11 AlexNet convolution on Bifrost.
+ // NOTE(review): as written, a square 11x11 kernel never matches: the inner branch requires a
+ // non-square kernel (and width > 1 is always true once width == 11). Only 11xN kernels with
+ // N != 11 and no padding get (1, 1, 1). Confirm this is the intended condition.
+ if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::G76) && k._kernel_dims.width == 11)
+ {
+ const bool is_square_kernel = (k._kernel_dims.width == k._kernel_dims.height);
+ if(!is_square_kernel && k._kernel_dims.width > 1 && !k._conv_info.has_padding())
+ {
+ lws_hint = cl::NDRange(1, 1, 1);
+ }
+ }
+ k.set_lws_hint(lws_hint);
+}
+
+/** Static LWS heuristic for CLDepthwiseIm2ColKernel: (1, 2, 1) on the listed Bifrost targets, current hint otherwise. */
+void tune_depthwise_im2col_kernel(CLDepthwiseIm2ColKernel &k)
+{
+ cl::NDRange lws_hint = k.lws_hint();
+ const GPUTarget gpu_target = k.get_target();
+
+ // Configure the local work size for Bifrost with a value obtained
+ // via exhaustive autotuning for the MobileNets tensor shapes.
+ if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::G76))
+ {
+ lws_hint = cl::NDRange(1, 2, 1);
+ }
+
+ k.set_lws_hint(lws_hint);
+}
+
+/** Static LWS heuristic for CLGEMMMatrixVectorMultiplyKernel: (1, 1, 1) on the listed Bifrost targets, current hint otherwise. */
+void tune_gemv_kernel(CLGEMMMatrixVectorMultiplyKernel &k)
+{
+ cl::NDRange lws_hint = k.lws_hint();
+ const GPUTarget gpu_target = k.get_target();
+
+ // Configure the local work size for Bifrost with a value obtained
+ // via exhaustive autotuning for the MobileNets tensor shapes.
+ if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::G76))
+ {
+ lws_hint = cl::NDRange(1, 1, 1);
+ }
+
+ k.set_lws_hint(lws_hint);
+}
+
+/** Static LWS heuristic for CLGEMMMatrixMultiplyKernel on Bifrost.
+ *
+ * On the listed Bifrost targets the hint is:
+ * - (2, 2) when dimension 1 of input1 is 24;
+ * - (1, 7) when dimension 1 of the output is 196;
+ * - (8, 8) otherwise.
+ * Unlike the other helpers, every other target gets cl::NullRange, so any existing hint is discarded.
+ */
+void tune_gemm_kernel(CLGEMMMatrixMultiplyKernel &k)
+{
+ cl::NDRange lws_hint = k.lws_hint();
+ const GPUTarget gpu_target = k.get_target();
+
+ // Configure LWS hint (every branch below overwrites the value read above)
+ switch(gpu_target)
+ {
+ case GPUTarget::G71:
+ case GPUTarget::G72:
+ case GPUTarget::G51:
+ case GPUTarget::G51BIG:
+ case GPUTarget::G51LIT:
+ case GPUTarget::G76:
+ if(k._input1->info()->dimension(1) == 24)
+ {
+ // LWS optimized for the 11x11 AlexNet convolution on Bifrost.
+ lws_hint = cl::NDRange(2, 2);
+ }
+ else if(k._output->info()->dimension(1) == 196)
+ {
+ lws_hint = cl::NDRange(1, 7);
+ }
+ else
+ {
+ lws_hint = cl::NDRange(8, 8);
+ }
+ break;
+ default:
+ lws_hint = cl::NullRange;
+ }
+
+ k.set_lws_hint(lws_hint);
+}
+
+/** Static LWS heuristic for CLPoolingLayerKernel.
+ *
+ * For NCHW inputs on the listed Bifrost targets, the hint is (gws[0], gws[1], 1), where gws is
+ * derived from the kernel's configured window. Other layouts and targets keep their current hint.
+ */
+void tune_pooling_kernel(CLPoolingLayerKernel &k)
+{
+ cl::NDRange lws_hint = k.lws_hint();
+ const GPUTarget gpu_target = k.get_target();
+
+ // Configure the local work size (hint) from the first two dimensions of the global work size.
+ // On Bifrost, this works for up to 35x35xC filters, for which the pooling_layer_3_optimized
+ // kernel is launched with gws=(9, 33, C). In any case, the hint will be ignored if it is
+ // invalid (e.g. exceeds the maximum workgroup size that the kernel can be launched with).
+ if(k._input->info()->data_layout() == DataLayout::NCHW)
+ {
+ if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::G76))
+ {
+ cl::NDRange gws = ICLKernel::gws_from_window(k.window());
+ lws_hint = cl::NDRange(gws[0], gws[1], 1);
+ }
+ }
+
+ k.set_lws_hint(lws_hint);
+}
+
+/** Static LWS heuristic for CLScaleKernel.
+ *
+ * Only applies to G71/G72 with F32 data and bilinear interpolation. The hint is chosen from the
+ * output width (dimension 0): 480 -> (2, 1), 3120 -> (2, 8), 4160 -> (4, 8). Other widths and
+ * configurations leave the kernel's hint untouched.
+ */
+void tune_scale_kernel(CLScaleKernel &k)
+{
+ cl::NDRange lws_hint = k.lws_hint();
+ const GPUTarget gpu_target = k.get_target();
+ const DataType dt = k.input()->info()->data_type();
+ const InterpolationPolicy interpolation = k._interpolationPolicy;
+
+ // Configure the local work size for Bifrost, interpolation (bilinear) and datatype F32.
+ // The values are obtained via exhaustive autotuning.
+ if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72) && (dt == DataType::F32) && (interpolation == InterpolationPolicy::BILINEAR))
+ {
+ auto dim_0 = k.output()->info()->dimension(0);
+ if(dim_0 == 480)
+ {
+ lws_hint = cl::NDRange(2, 1);
+ }
+ else if(dim_0 == 3120)
+ {
+ lws_hint = cl::NDRange(2, 8);
+ }
+ else if(dim_0 == 4160)
+ {
+ lws_hint = cl::NDRange(4, 8);
+ }
+ k.set_lws_hint(lws_hint);
+ }
+}
} // namespace
+/** Apply the Bifrost static LWS heuristics to @p kernel.
+ *
+ * The concrete kernel type is identified with dynamic_cast and forwarded to the matching tune_*
+ * helper. Kernel types not listed here are left untouched.
+ */
void BifrostTuner::tune_kernel_static(ICLKernel &kernel)
{
- // Continue on tuning if dynamic tuning
if(dynamic_cast<CLDirectConvolutionLayerKernel *>(&kernel) != nullptr)
{
tune_direct_convolution_kernel(*utils::cast::polymorphic_downcast<CLDirectConvolutionLayerKernel *>(&kernel));
}
+ else if(dynamic_cast<CLCol2ImKernel *>(&kernel) != nullptr)
+ {
+ tune_col2im_kernel(*utils::cast::polymorphic_downcast<CLCol2ImKernel *>(&kernel));
+ }
+ else if(dynamic_cast<CLIm2ColKernel *>(&kernel) != nullptr)
+ {
+ tune_im2col_kernel(*utils::cast::polymorphic_downcast<CLIm2ColKernel *>(&kernel));
+ }
+ else if(dynamic_cast<CLDepthwiseIm2ColKernel *>(&kernel) != nullptr)
+ {
+ tune_depthwise_im2col_kernel(*utils::cast::polymorphic_downcast<CLDepthwiseIm2ColKernel *>(&kernel));
+ }
+ else if(dynamic_cast<CLGEMMMatrixVectorMultiplyKernel *>(&kernel) != nullptr)
+ {
+ tune_gemv_kernel(*utils::cast::polymorphic_downcast<CLGEMMMatrixVectorMultiplyKernel *>(&kernel));
+ }
+ else if(dynamic_cast<CLGEMMMatrixMultiplyKernel *>(&kernel) != nullptr)
+ {
+ tune_gemm_kernel(*utils::cast::polymorphic_downcast<CLGEMMMatrixMultiplyKernel *>(&kernel));
+ }
+ else if(dynamic_cast<CLPoolingLayerKernel *>(&kernel) != nullptr)
+ {
+ tune_pooling_kernel(*utils::cast::polymorphic_downcast<CLPoolingLayerKernel *>(&kernel));
+ }
+ else if(dynamic_cast<CLScaleKernel *>(&kernel) != nullptr)
+ {
+ tune_scale_kernel(*utils::cast::polymorphic_downcast<CLScaleKernel *>(&kernel));
+ }
}
void BifrostTuner::tune_kernel_dynamic(ICLKernel &kernel)
diff --git a/src/runtime/CL/tuners/MidgardTuner.cpp b/src/runtime/CL/tuners/MidgardTuner.cpp
new file mode 100644
index 0000000..cae3123
--- /dev/null
+++ b/src/runtime/CL/tuners/MidgardTuner.cpp
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/CL/tuners/MidgardTuner.h"
+
+#include "arm_compute/core/CL/CLHelpers.h"
+#include "arm_compute/core/CL/CLKernels.h"
+#include "arm_compute/core/utils/misc/Cast.h"
+
+namespace arm_compute
+{
+namespace tuners
+{
+namespace
+{
+/** Static LWS heuristic for CLGEMMMatrixMultiplyKernel on Midgard.
+ *
+ * On MIDGARD/T600/T700/T800 the hint is (1, 7) when dimension 1 of the output is 196, else (8, 8).
+ * Any other target gets cl::NullRange, so an existing hint is discarded.
+ */
+void tune_gemm_kernel(CLGEMMMatrixMultiplyKernel &k)
+{
+ cl::NDRange lws_hint = k.lws_hint();
+ const GPUTarget gpu_target = k.get_target();
+
+ switch(gpu_target)
+ {
+ case GPUTarget::MIDGARD:
+ case GPUTarget::T600:
+ case GPUTarget::T700:
+ case GPUTarget::T800:
+ if(k._output->info()->dimension(1) == 196)
+ {
+ lws_hint = cl::NDRange(1, 7);
+ }
+ else
+ {
+ lws_hint = cl::NDRange(8, 8);
+ }
+ break;
+ default:
+ lws_hint = cl::NullRange;
+ }
+
+ k.set_lws_hint(lws_hint);
+}
+} // namespace
+
+/** Apply the Midgard static LWS heuristics; only CLGEMMMatrixMultiplyKernel is tuned, other kernels are left untouched. */
+void MidgardTuner::tune_kernel_static(ICLKernel &kernel)
+{
+ if(dynamic_cast<CLGEMMMatrixMultiplyKernel *>(&kernel) != nullptr)
+ {
+ tune_gemm_kernel(*utils::cast::polymorphic_downcast<CLGEMMMatrixMultiplyKernel *>(&kernel));
+ }
+}
+
+/** No dynamic tuning is implemented for Midgard: @p kernel is ignored. */
+void MidgardTuner::tune_kernel_dynamic(ICLKernel &kernel)
+{
+ ARM_COMPUTE_UNUSED(kernel);
+}
+} // namespace tuners
+} // namespace arm_compute
diff --git a/src/runtime/CPP/CPPScheduler.cpp b/src/runtime/CPP/CPPScheduler.cpp
index 92dce34..de28b4f 100644
--- a/src/runtime/CPP/CPPScheduler.cpp
+++ b/src/runtime/CPP/CPPScheduler.cpp
@@ -29,6 +29,7 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/runtime/CPUUtils.h"
+#include <atomic>
#include <condition_variable>
#include <iostream>
#include <mutex>
@@ -37,7 +38,59 @@
namespace arm_compute
{
-class Thread
+namespace
+{
+/** Hands out consecutive indices of the range [start, end) to several threads.
+ *
+ * The counter is a std::atomic, so get_next() may be called concurrently and each index below
+ * end is returned to exactly one caller.
+ */
+class ThreadFeeder
+{
+public:
+ /** Constructor
+ *
+ * @param[in] start First value that will be returned by the feeder
+ * @param[in] end End condition (The last valid value returned by get_next() will be end - 1)
+ */
+ explicit ThreadFeeder(unsigned int start = 0, unsigned int end = 0)
+ : _atomic_counter(start), _end(end)
+ {
+ }
+ /** Return the next element in the range if there is one.
+ *
+ * @param[out] next Will contain the next element if there is one. It is overwritten even when
+ * the range is exhausted (with a value >= end) and must not be used in that case.
+ *
+ * @return False if the end of the range has been reached.
+ */
+ bool get_next(unsigned int &next)
+ {
+ // Relaxed ordering: the counter only has to hand out unique indices; the workloads themselves
+ // are published to the worker threads through Thread::start(), not through this counter.
+ next = atomic_fetch_add_explicit(&_atomic_counter, 1u, std::memory_order_relaxed);
+ return next < _end;
+ }
+
+private:
+ std::atomic_uint _atomic_counter;
+ const unsigned int _end;
+};
+
+/** Execute workloads[info.thread_id] first, then call the feeder to get the index of the next workload to run.
+ *
+ * Will run workloads until the feeder reaches the end of its range. Because the first index is the
+ * thread id itself, the feeder must start after the last thread id (run_workloads() starts it at
+ * num_threads).
+ *
+ * @param[in] workloads The array of workloads
+ * @param[in,out] feeder The feeder indicating which workload to execute next.
+ * @param[in] info Threading and CPU info.
+ */
+void process_workloads(std::vector<IScheduler::Workload> &workloads, ThreadFeeder &feeder, const ThreadInfo &info)
+{
+ unsigned int workload_index = info.thread_id;
+ do
+ {
+ ARM_COMPUTE_ERROR_ON(workload_index >= workloads.size());
+ workloads[workload_index](info);
+ }
+ while(feeder.get_next(workload_index));
+}
+
+} //namespace
+
+class CPPScheduler::Thread
{
public:
/** Start a new thread. */
@@ -51,11 +104,15 @@
/** Destructor. Make the thread join. */
~Thread();
- /** Request the worker thread to start executing the given kernel
- * This function will return as soon as the kernel has been sent to the worker thread.
+ /** Request the worker thread to start executing workloads.
+ *
+ * The thread will start by executing workloads[info.thread_id] and will then call the feeder to
+ * get the index of the following workload to run.
+ *
+ * @note This function will return as soon as the workloads have been sent to the worker thread.
* wait() needs to be called to ensure the execution is complete.
*/
- void start(ICPPKernel *kernel, const Window &window, const ThreadInfo &info);
+ void start(std::vector<IScheduler::Workload> *workloads, ThreadFeeder &feeder, const ThreadInfo &info);
/** Wait for the current kernel execution to complete. */
void wait();
@@ -64,39 +121,38 @@
void worker_thread();
private:
- std::thread _thread;
- ICPPKernel *_kernel{ nullptr };
- Window _window;
- ThreadInfo _info;
- std::mutex _m;
- std::condition_variable _cv;
- bool _wait_for_work{ false };
- bool _job_complete{ true };
- std::exception_ptr _current_exception;
+ std::thread _thread{};
+ ThreadInfo _info{};
+ std::vector<IScheduler::Workload> *_workloads{ nullptr };
+ ThreadFeeder *_feeder{ nullptr };
+ std::mutex _m{};
+ std::condition_variable _cv{};
+ bool _wait_for_work{ false };
+ bool _job_complete{ true };
+ std::exception_ptr _current_exception{ nullptr };
};
-Thread::Thread()
- : _thread(), _window(), _info(), _m(), _cv(), _current_exception(nullptr)
+/** Spawn the worker thread, which runs worker_thread() until it receives a null workloads pointer.
+ * All other members rely on their in-class initializers.
+ */
+CPPScheduler::Thread::Thread()
{
_thread = std::thread(&Thread::worker_thread, this);
}
-Thread::~Thread()
+/** Ask the worker thread to exit and join it. */
+CPPScheduler::Thread::~Thread()
{
// Make sure worker thread has ended
if(_thread.joinable())
{
- start(nullptr, Window(), ThreadInfo());
+ // A null workloads pointer makes worker_thread() return before it touches the feeder,
+ // so a local, empty feeder is sufficient here.
+ ThreadFeeder feeder;
+ start(nullptr, feeder, ThreadInfo());
_thread.join();
}
}
-void Thread::start(ICPPKernel *kernel, const Window &window, const ThreadInfo &info)
+void CPPScheduler::Thread::start(std::vector<IScheduler::Workload> *workloads, ThreadFeeder &feeder, const ThreadInfo &info)
{
- _kernel = kernel;
- _window = window;
- _info = info;
-
+ _workloads = workloads;
+ _feeder = &feeder;
+ _info = info;
{
std::lock_guard<std::mutex> lock(_m);
_wait_for_work = true;
@@ -105,7 +161,7 @@
_cv.notify_one();
}
-void Thread::wait()
+void CPPScheduler::Thread::wait()
{
{
std::unique_lock<std::mutex> lock(_m);
@@ -118,7 +174,7 @@
}
}
-void Thread::worker_thread()
+void CPPScheduler::Thread::worker_thread()
{
while(true)
{
@@ -129,15 +185,14 @@
_current_exception = nullptr;
// Time to exit
- if(_kernel == nullptr)
+ if(_workloads == nullptr)
{
return;
}
try
{
- _window.validate();
- _kernel->run(_window, _info);
+ process_workloads(*_workloads, *_feeder, _info);
}
catch(...)
{
@@ -174,56 +229,90 @@
return _num_threads;
}
-void CPPScheduler::schedule(ICPPKernel *kernel, unsigned int split_dimension)
+/** Run @p workloads on the thread pool and block until every started thread has finished.
+ *
+ * Each participating thread first runs the workload whose index equals its thread id. The calling
+ * thread is the last participant. Further indices are then pulled from a shared ThreadFeeder.
+ *
+ * If a workload throws, either on the calling thread or rethrown by Thread::wait(), the remaining
+ * threads are still waited for, because they hold pointers to @p workloads and to the local feeder.
+ * The first such exception is rethrown afterwards. A std::system_error raised while waiting is
+ * only logged.
+ */
+void CPPScheduler::run_workloads(std::vector<IScheduler::Workload> &workloads)
+{
+ const unsigned int num_threads = std::min(_num_threads, static_cast<unsigned int>(workloads.size()));
+ if(num_threads < 1)
+ {
+ return;
+ }
+ // Indices [0, num_threads) are consumed directly as thread ids, so the feeder starts after them.
+ ThreadFeeder feeder(num_threads, workloads.size());
+ ThreadInfo info;
+ info.cpu_info = &_cpu_info;
+ info.num_threads = num_threads;
+ unsigned int t = 0;
+ auto thread_it = _threads.begin();
+ for(; t < num_threads - 1; ++t, ++thread_it)
+ {
+ info.thread_id = t;
+ thread_it->start(&workloads, feeder, info);
+ }
+
+ // Run the last share on the calling thread, but never let an exception escape before the
+ // workers are done with 'feeder' and 'workloads'.
+ std::exception_ptr first_exception = nullptr;
+ info.thread_id = t;
+ try
+ {
+ process_workloads(workloads, feeder, info);
+ }
+ catch(...)
+ {
+ first_exception = std::current_exception();
+ }
+
+ for(auto &thread : _threads)
+ {
+ try
+ {
+ thread.wait();
+ }
+ catch(const std::system_error &e)
+ {
+ std::cerr << "Caught system_error with code " << e.code() << " meaning " << e.what() << '\n';
+ }
+ catch(...)
+ {
+ // Keep waiting on the remaining threads; only the first failure is reported.
+ if(first_exception == nullptr)
+ {
+ first_exception = std::current_exception();
+ }
+ }
+ }
+
+ if(first_exception != nullptr)
+ {
+ std::rethrow_exception(first_exception);
+ }
+}
+
+/** Run @p kernel over its window, split along hints.split_dimension().
+ *
+ * Non-parallelisable kernels, or kernels whose window yields a single thread's worth of iterations,
+ * run on the calling thread. Otherwise the window is cut into num_windows pieces:
+ * - STATIC: one piece per thread;
+ * - DYNAMIC: up to 3 pieces per thread, distributed through the ThreadFeeder.
+ * The pieces are then executed with run_workloads().
+ */
+void CPPScheduler::schedule(ICPPKernel *kernel, const Hints &hints)
{
ARM_COMPUTE_ERROR_ON_MSG(!kernel, "The child class didn't set the kernel");
- /** [Scheduler example] */
- ThreadInfo info;
- info.cpu_info = &_cpu_info;
-
const Window &max_window = kernel->window();
- const unsigned int num_iterations = max_window.num_iterations(split_dimension);
- info.num_threads = std::min(num_iterations, _num_threads);
+ const unsigned int num_iterations = max_window.num_iterations(hints.split_dimension());
+ const unsigned int num_threads = std::min(num_iterations, _num_threads);
if(num_iterations == 0)
{
return;
}
- if(!kernel->is_parallelisable() || info.num_threads == 1)
+ if(!kernel->is_parallelisable() || num_threads == 1)
{
+ ThreadInfo info;
+ info.cpu_info = &_cpu_info;
kernel->run(max_window, info);
}
else
{
- int t = 0;
- auto thread_it = _threads.begin();
-
- for(; t < info.num_threads - 1; ++t, ++thread_it)
+ unsigned int num_windows = 0;
+ switch(hints.strategy())
{
- Window win = max_window.split_window(split_dimension, t, info.num_threads);
- info.thread_id = t;
- thread_it->start(kernel, win, info);
- }
-
- // Run last part on main thread
- Window win = max_window.split_window(split_dimension, t, info.num_threads);
- info.thread_id = t;
- kernel->run(win, info);
-
- try
- {
- for(auto &thread : _threads)
+ case StrategyHint::STATIC:
+ num_windows = num_threads;
+ break;
+ case StrategyHint::DYNAMIC:
{
- thread.wait();
+ // Make sure we don't use some windows which are too small as this might create some contention on the ThreadFeeder
+ const unsigned int max_iterations = static_cast<unsigned int>(_num_threads) * 3;
+ num_windows = num_iterations > max_iterations ? max_iterations : num_iterations;
+ break;
}
+ default:
+ ARM_COMPUTE_ERROR("Unknown strategy");
}
- catch(const std::system_error &e)
+ std::vector<IScheduler::Workload> workloads(num_windows);
+ for(unsigned int t = 0; t < num_windows; t++)
{
- std::cerr << "Caught system_error with code " << e.code() << " meaning " << e.what() << '\n';
+ //Capture 't' by copy, all the other variables by reference:
+ // (the references stay valid because run_workloads() below waits for the worker threads before returning)
+ workloads[t] = [t, &hints, &max_window, &num_windows, &kernel](const ThreadInfo & info)
+ {
+ Window win = max_window.split_window(hints.split_dimension(), t, num_windows);
+ win.validate();
+ kernel->run(win, info);
+ };
}
+ run_workloads(workloads);
}
- /** [Scheduler example] */
}
} // namespace arm_compute
diff --git a/src/runtime/CPP/SingleThreadScheduler.cpp b/src/runtime/CPP/SingleThreadScheduler.cpp
index 2adc14c..3701159 100644
--- a/src/runtime/CPP/SingleThreadScheduler.cpp
+++ b/src/runtime/CPP/SingleThreadScheduler.cpp
@@ -41,14 +41,23 @@
ARM_COMPUTE_ERROR_ON(num_threads != 1);
}
-void SingleThreadScheduler::schedule(ICPPKernel *kernel, unsigned int split_dimension)
+/** Run @p kernel over its whole window on the calling thread; @p hints are ignored. */
+void SingleThreadScheduler::schedule(ICPPKernel *kernel, const Hints &hints)
{
- ARM_COMPUTE_UNUSED(split_dimension);
+ ARM_COMPUTE_UNUSED(hints);
ThreadInfo info;
info.cpu_info = &_cpu_info;
kernel->run(kernel->window(), info);
}
+/** Execute @p workloads sequentially, in order, on the calling thread.
+ * Every workload receives the same ThreadInfo, in which only cpu_info is set explicitly.
+ */
+void SingleThreadScheduler::run_workloads(std::vector<Workload> &workloads)
+{
+ ThreadInfo info;
+ info.cpu_info = &_cpu_info;
+ for(auto &wl : workloads)
+ {
+ wl(info);
+ }
+}
unsigned int SingleThreadScheduler::num_threads() const
{
return 1;
diff --git a/src/runtime/CPUUtils.cpp b/src/runtime/CPUUtils.cpp
index 7e8bf2b..6c21086 100644
--- a/src/runtime/CPUUtils.cpp
+++ b/src/runtime/CPUUtils.cpp
@@ -69,49 +69,83 @@
using namespace arm_compute;
#if !defined(BARE_METAL) && (defined(__arm__) || defined(__aarch64__))
-struct PerCPUData
-{
- CPUModel model = CPUModel::GENERIC;
- unsigned int midr = 0;
- bool model_set = false;
-};
+/* Whether @p model is one of the CPU models classed as dot-product capable.
+ * get_cpu_configuration() enables dot product when every detected CPU passes this check. */
+bool model_supports_dot(CPUModel model)
+{
+ return (model == CPUModel::GENERIC_FP16_DOT) || (model == CPUModel::A55r1);
+}
+
+/* Whether @p model is one of the CPU models classed as FP16 capable.
+ * get_cpu_configuration() enables FP16 when every detected CPU passes this check. */
+bool model_supports_fp16(CPUModel model)
+{
+ return (model == CPUModel::GENERIC_FP16) || (model == CPUModel::GENERIC_FP16_DOT) || (model == CPUModel::A55r1);
+}
/* Convert an MIDR register value to a CPUModel enum value. */
CPUModel midr_to_model(const unsigned int midr)
{
- CPUModel model;
+ CPUModel model = CPUModel::GENERIC;
+ // Anything not matched below (including CPUs from non-Arm implementers) is reported as GENERIC.
+ // Fields extracted: implementer = bits [31:24], variant = bits [23:20], part number = bits [15:4].
// Unpack variant and CPU ID
- const int variant = (midr >> 20) & 0xF;
- const int cpunum = (midr >> 4) & 0xFFF;
+ const int implementer = (midr >> 24) & 0xFF;
+ const int variant = (midr >> 20) & 0xF;
+ const int cpunum = (midr >> 4) & 0xFFF;
- // Only CPUs we have code paths for are detected. All other CPUs can be safely classed as "GENERIC"
- switch(cpunum)
+ if(implementer == 0x41) // Arm CPUs
{
- case 0xd03:
- model = CPUModel::A53;
- break;
-
- case 0xd05:
- if(variant != 0)
- {
- model = CPUModel::A55r1;
- }
- else
- {
- model = CPUModel::A55r0;
- }
- break;
-
- default:
- model = CPUModel::GENERIC;
- break;
+ // Only CPUs we have code paths for are detected. All other CPUs can be safely classed as "GENERIC"
+ switch(cpunum)
+ {
+ case 0xd03: // A53
+ case 0xd04: // A35
+ model = CPUModel::A53;
+ break;
+ case 0xd05: // A55
+ if(variant != 0)
+ {
+ model = CPUModel::A55r1;
+ }
+ else
+ {
+ model = CPUModel::A55r0;
+ }
+ break;
+ // A75: a non-zero variant is classed as GENERIC_FP16_DOT, variant 0 as GENERIC_FP16 (no dot product)
+ case 0xd0a: // A75
+ if(variant != 0)
+ {
+ model = CPUModel::GENERIC_FP16_DOT;
+ }
+ else
+ {
+ model = CPUModel::GENERIC_FP16;
+ }
+ break;
+ case 0xd0b: // A76
+ model = CPUModel::GENERIC_FP16_DOT;
+ break;
+ default:
+ model = CPUModel::GENERIC;
+ break;
+ }
}
return model;
}
-void populate_models_cpuid(std::vector<PerCPUData> &cpusv)
+void populate_models_cpuid(std::vector<CPUModel> &cpusv)
{
// If the CPUID capability is present, MIDR information is provided in /sys. Use that to populate the CPU model table.
uint32_t i = 0;
@@ -126,16 +160,14 @@
std::string line;
if(bool(getline(file, line)))
{
- const unsigned long midr = support::cpp11::stoul(line, nullptr, 16);
- c.midr = (midr & 0xffffffff);
- c.model = midr_to_model(c.midr);
- c.model_set = true;
+ const unsigned long midr = support::cpp11::stoul(line, nullptr, support::cpp11::NumericBase::BASE_16);
+ c = midr_to_model(midr & 0xffffffff);
}
}
}
}
-void populate_models_cpuinfo(std::vector<PerCPUData> &cpusv)
+void populate_models_cpuinfo(std::vector<CPUModel> &cpusv)
{
// If "long-form" cpuinfo is present, parse that to populate models.
std::regex proc_regex("^processor.*(\\d+)$");
@@ -160,7 +192,7 @@
if(std::regex_match(line, match, proc_regex))
{
std::string id = match[1];
- int newcpu = support::cpp11::stoi(id, nullptr, 0);
+ int newcpu = support::cpp11::stoi(id, nullptr);
if(curcpu >= 0 && midr == 0)
{
@@ -170,9 +202,7 @@
if(curcpu >= 0)
{
- cpusv[curcpu].midr = midr;
- cpusv[curcpu].model = midr_to_model(midr);
- cpusv[curcpu].model_set = true;
+ cpusv[curcpu] = midr_to_model(midr);
}
midr = 0;
@@ -183,28 +213,28 @@
if(std::regex_match(line, match, imp_regex))
{
- int impv = support::cpp11::stoi(match[1], nullptr, 16);
+ int impv = support::cpp11::stoi(match[1], nullptr, support::cpp11::NumericBase::BASE_16);
midr |= (impv << 24);
continue;
}
if(std::regex_match(line, match, var_regex))
{
- int varv = support::cpp11::stoi(match[1], nullptr, 16);
- midr |= (varv << 16);
+ int varv = support::cpp11::stoi(match[1], nullptr, support::cpp11::NumericBase::BASE_16);
+ midr |= (varv << 20);
continue;
}
if(std::regex_match(line, match, part_regex))
{
- int partv = support::cpp11::stoi(match[1], nullptr, 16);
+ int partv = support::cpp11::stoi(match[1], nullptr, support::cpp11::NumericBase::BASE_16);
midr |= (partv << 4);
continue;
}
if(std::regex_match(line, match, rev_regex))
{
- int regv = support::cpp11::stoi(match[1], nullptr, 10);
+ int regv = support::cpp11::stoi(match[1], nullptr);
midr |= (regv);
midr |= (0xf << 16);
continue;
@@ -213,9 +243,7 @@
if(curcpu >= 0)
{
- cpusv[curcpu].midr = midr;
- cpusv[curcpu].model = midr_to_model(midr);
- cpusv[curcpu].model_set = true;
+ cpusv[curcpu] = midr_to_model(midr);
}
}
}
@@ -251,7 +279,7 @@
line.erase(line.begin(), startfrom);
- max_cpus = support::cpp11::stoi(line, nullptr, 0) + 1;
+ max_cpus = support::cpp11::stoi(line, nullptr) + 1;
success = true;
}
}
@@ -262,7 +290,6 @@
max_cpus = std::thread::hardware_concurrency();
}
#endif /* BARE_METAL */
-
return max_cpus;
}
#endif /* !defined(BARE_METAL) && (defined(__arm__) || defined(__aarch64__)) */
@@ -274,9 +301,9 @@
void get_cpu_configuration(CPUInfo &cpuinfo)
{
#if !defined(BARE_METAL) && (defined(__arm__) || defined(__aarch64__))
- bool cpuid = false;
- bool fp16_support = false;
- bool dot_support = false;
+ bool cpuid = false;
+ bool hwcaps_fp16_support = false;
+ bool hwcaps_dot_support = false;
const uint32_t hwcaps = getauxval(AT_HWCAP);
@@ -287,54 +314,17 @@
if((hwcaps & HWCAP_ASIMDHP) != 0)
{
- fp16_support = true;
+ hwcaps_fp16_support = true;
}
if((hwcaps & HWCAP_ASIMDDP) != 0)
{
- dot_support = true;
+ hwcaps_dot_support = true;
}
-#ifdef __aarch64__
- /* Pre-4.15 kernels don't have the ASIMDDP bit.
- *
- * Although the CPUID bit allows us to read the feature register
- * directly, the kernel quite sensibly masks this to only show
- * features known by it to be safe to show to userspace. As a
- * result, pre-4.15 kernels won't show the relevant bit in the
- * feature registers either.
- *
- * So for now, use a whitelist of CPUs known to support the feature.
- */
- if(!dot_support && cpuid)
- {
- /* List of CPUs with dot product support: A55r1 A75r1 A75r2 */
- const unsigned int dotprod_whitelist_masks[] = { 0xfff0fff0, 0xfff0fff0, 0xfff0fff0, 0 };
- const unsigned int dotprod_whitelist_values[] = { 0x4110d050, 0x4110d0a0, 0x4120d0a0, 0 };
-
- unsigned long cpuid;
-
- __asm __volatile(
- "mrs %0, midr_el1\n"
- : "=r"(cpuid)
- :
- : );
-
- for(int i = 0; dotprod_whitelist_values[i] != 0; i++)
- {
- if((cpuid & dotprod_whitelist_masks[i]) == dotprod_whitelist_values[i])
- {
- dot_support = true;
- break;
- }
- }
- }
-#endif /* __aarch64__ */
const unsigned int max_cpus = get_max_cpus();
cpuinfo.set_cpu_num(max_cpus);
- cpuinfo.set_fp16(fp16_support);
- cpuinfo.set_dotprod(dot_support);
- std::vector<PerCPUData> percpu(max_cpus);
+ std::vector<CPUModel> percpu(max_cpus, CPUModel::GENERIC);
if(cpuid)
{
populate_models_cpuid(percpu);
@@ -344,10 +334,17 @@
populate_models_cpuinfo(percpu);
}
int j(0);
+ // Update dot product and FP16 support if all CPUs support these features:
+ bool all_support_dot = true;
+ bool all_support_fp16 = true;
for(const auto &v : percpu)
{
- cpuinfo.set_cpu_model(j++, v.model);
+ all_support_dot &= model_supports_dot(v);
+ all_support_fp16 &= model_supports_fp16(v);
+ cpuinfo.set_cpu_model(j++, v);
}
+ cpuinfo.set_dotprod(all_support_dot || hwcaps_dot_support);
+ cpuinfo.set_fp16(all_support_fp16 || hwcaps_fp16_support);
#else /* !defined(BARE_METAL) && (defined(__arm__) || defined(__aarch64__)) */
ARM_COMPUTE_UNUSED(cpuinfo);
#endif /* !defined(BARE_METAL) && (defined(__arm__) || defined(__aarch64__)) */
diff --git a/src/runtime/GLES_COMPUTE/functions/GCConvolutionLayer.cpp b/src/runtime/GLES_COMPUTE/functions/GCConvolutionLayer.cpp
index 2a710f7..a7a56b6 100644
--- a/src/runtime/GLES_COMPUTE/functions/GCConvolutionLayer.cpp
+++ b/src/runtime/GLES_COMPUTE/functions/GCConvolutionLayer.cpp
@@ -37,7 +37,7 @@
using namespace arm_compute;
+/** Default constructor: value-initialises the weights reshape kernel. */
GCConvolutionLayerReshapeWeights::GCConvolutionLayerReshapeWeights()
- : _weights_reshape_kernel(), _weights_reshaped()
+ : _weights_reshape_kernel()
{
}
@@ -68,7 +68,7 @@
+// _is_prepared starts false so that the first run() performs the one-off weights reshape in prepare().
GCConvolutionLayer::GCConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(std::move(memory_manager)), _reshape_weights(), _input_im2col_kernel(), _mm_gemm(), _output_col2im_kernel(), _fill_border(), _activationlayer_function(), _original_weights(nullptr),
- _input_im2col_reshaped(), _input_interleaved_reshaped(), _weights_reshaped(), _weights_transposed(), _gemm_output(), _tmp_output(), _is_first_run(true), _is_activationlayer_enabled(false)
+ _input_im2col_reshaped(), _input_interleaved_reshaped(), _weights_reshaped(), _weights_transposed(), _gemm_output(), _tmp_output(), _is_activationlayer_enabled(false), _is_prepared(false)
{
}
@@ -88,7 +88,7 @@
}
void GCConvolutionLayer::configure(const IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info,
- const Size2D &dilation, const ActivationLayerInfo &act_info)
+ const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
@@ -96,8 +96,10 @@
ARM_COMPUTE_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!");
ARM_COMPUTE_ERROR_ON(weights->info()->dimension(2) != input->info()->dimension(2));
ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4);
+ ARM_COMPUTE_ERROR_ON(num_groups > 1);
+ ARM_COMPUTE_UNUSED(num_groups);
- _is_first_run = true;
+ _is_prepared = false;
_original_weights = weights;
if(biases != nullptr)
@@ -148,7 +150,7 @@
shape_im2col.set(1, mat_input_rows);
shape_im2col.set(2, 1);
- TensorInfo im2col_reshaped_info(shape_im2col, 1, dt, input->info()->fixed_point_position());
+ TensorInfo im2col_reshaped_info(shape_im2col, 1, dt);
_input_im2col_reshaped.allocator()->init(im2col_reshaped_info);
_memory_group.manage(&_input_im2col_reshaped);
@@ -158,7 +160,7 @@
shape_gemm.set(1, mat_input_rows);
const DataType gemm_data_type = dt;
- TensorInfo info_gemm(shape_gemm, 1, gemm_data_type, input->info()->fixed_point_position());
+ TensorInfo info_gemm(shape_gemm, 1, gemm_data_type);
_gemm_output.allocator()->init(info_gemm);
_memory_group.manage(&_gemm_output);
@@ -182,9 +184,6 @@
ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(0) != conv_w) || (output->info()->dimension(1) != conv_h), "Output shape does not match the expected one");
- // Allocate intermediate tensor
- _weights_reshaped.allocator()->allocate();
-
//Configure Activation Layer
_is_activationlayer_enabled = act_info.enabled();
@@ -198,17 +197,7 @@
void GCConvolutionLayer::run()
{
- // Run weights reshaping (Runs once for every configure)
- if(_is_first_run)
- {
- ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
-
- _reshape_weights.run();
- _is_first_run = false;
-
- // Mark original weights tensor as unused
- _original_weights->mark_as_unused();
- }
+ prepare();
_memory_group.acquire();
@@ -219,17 +208,34 @@
// Run gemm on reshaped matrices
_mm_gemm.run();
-
GCScheduler::get().memory_barrier();
+
// Reshape output matrix
GCScheduler::get().dispatch(_output_col2im_kernel, false);
+ GCScheduler::get().memory_barrier();
_memory_group.release();
- GCScheduler::get().memory_barrier();
// Run Activation Layer
if(_is_activationlayer_enabled)
{
_activationlayer_function.run();
}
}
+
+/** One-off preparation: allocate the reshaped-weights tensor, fill it from the original weights
+ * and flag the original weights as unused. Subsequent calls do nothing. */
+void GCConvolutionLayer::prepare()
+{
+ if(_is_prepared)
+ {
+ return;
+ }
+
+ ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+
+ // Lazily allocate the destination, then reshape the weights into it
+ _weights_reshaped.allocator()->allocate();
+ _reshape_weights.run();
+
+ // Flag the original weights tensor as unused by this function
+ _original_weights->mark_as_unused();
+
+ _is_prepared = true;
+}
diff --git a/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp b/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp
index a300033..6b8e341 100644
--- a/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp
+++ b/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp
@@ -40,7 +40,7 @@
+// _original_weights is set in configure(); prepare() flags it as unused after the one-off weights reshape.
GCFullyConnectedLayer::GCFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(std::move(memory_manager)), _im2col_kernel(), _reshape_weights_kernel(), _mm_kernel(), _accumulate_biases_kernel(), _im2col_output(), _reshape_weights_output(),
- _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false)
+ _original_weights(nullptr), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false)
{
}
@@ -80,13 +80,14 @@
}
void GCFullyConnectedLayer::configure(const IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output,
- bool transpose_weights, bool are_weights_reshaped, bool retain_internal_weights)
+ FullyConnectedLayerInfo fc_info)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32, DataType::F16);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 2);
- _are_weights_reshaped = transpose_weights ? are_weights_reshaped : true;
+ _original_weights = weights;
+ _are_weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
_is_fc_after_conv = true;
_accumulate_biases = false;
@@ -141,25 +142,13 @@
configure_fc_fc(input, weights_to_use, output);
}
- // Allocate the transpose tensor if the are_weights_reshaped flag is false and once all the configure methods have been called
- if(!_are_weights_reshaped && !retain_internal_weights)
- {
- // Allocate the tensor for the weights reshaped
- _reshape_weights_output.allocator()->allocate();
- }
-
- ARM_COMPUTE_ERROR_ON(retain_internal_weights && _reshape_weights_output.gc_buffer() == 0);
- _are_weights_reshaped = _are_weights_reshaped || retain_internal_weights;
+ ARM_COMPUTE_ERROR_ON(fc_info.retain_internal_weights && _reshape_weights_output.gc_buffer() == 0);
+ _are_weights_reshaped = _are_weights_reshaped || fc_info.retain_internal_weights;
}
void GCFullyConnectedLayer::run()
{
- // Reshape of the weights (happens only once)
- if(!_are_weights_reshaped)
- {
- _are_weights_reshaped = true;
- _reshape_weights_kernel.run();
- }
+ prepare();
_memory_group.acquire();
@@ -187,3 +176,21 @@
_memory_group.release();
}
+
+/** One-off weights reshape: allocate the reshaped-weights tensor, run the reshape kernel and
+ * flag the original weights as unused. Does nothing once the weights are reshaped. */
+void GCFullyConnectedLayer::prepare()
+{
+ if(_are_weights_reshaped)
+ {
+ return;
+ }
+
+ ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+
+ // Lazily allocate the destination, then reshape the weights into it
+ _reshape_weights_output.allocator()->allocate();
+ _reshape_weights_kernel.run();
+
+ // Flag the original weights tensor as unused by this function
+ _original_weights->mark_as_unused();
+
+ _are_weights_reshaped = true;
+}
\ No newline at end of file
diff --git a/src/runtime/GLES_COMPUTE/functions/GCGEMM.cpp b/src/runtime/GLES_COMPUTE/functions/GCGEMM.cpp
index 79f8f71..8ae91ee 100644
--- a/src/runtime/GLES_COMPUTE/functions/GCGEMM.cpp
+++ b/src/runtime/GLES_COMPUTE/functions/GCGEMM.cpp
@@ -73,8 +73,8 @@
} // namespace
+// _is_prepared starts false so that the first run() performs any one-off reshape of matrix B in prepare().
GCGEMM::GCGEMM(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _ma_kernel(), _tmp_a(), _tmp_b(), _is_interleaved_transposed(false), _run_addition(false),
- _is_first_run(true), _reshape_b_only_on_first_run(false)
+ : _memory_group(std::move(memory_manager)), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _ma_kernel(), _tmp_a(), _tmp_b(), _original_b(nullptr), _is_interleaved_transposed(false),
+ _run_addition(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
{
}
@@ -87,6 +87,8 @@
// Check if we need to reshape the matrix B only on the first run
_reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
+ _is_prepared = false;
+ _original_b = b;
const IGCTensor *matrix_a = a;
const IGCTensor *matrix_b = b;
@@ -136,7 +138,10 @@
{
// Allocate intermediate tensors
_tmp_a.allocator()->allocate();
- _tmp_b.allocator()->allocate();
+ if(!_reshape_b_only_on_first_run)
+ {
+ _tmp_b.allocator()->allocate();
+ }
}
// Configure matrix addition kernel
@@ -155,23 +160,21 @@
void GCGEMM::run()
{
+ prepare();
+
_memory_group.acquire();
+
if(_is_interleaved_transposed)
{
// Run interleave kernel
GCScheduler::get().dispatch(_interleave_kernel, false);
- if(_is_first_run)
- {
- // Run transpose kernel
- GCScheduler::get().dispatch(_transpose_kernel, false);
- _is_first_run = false;
- }
- else if(!_reshape_b_only_on_first_run)
+ if(!_reshape_b_only_on_first_run)
{
// Run transpose kernel
GCScheduler::get().dispatch(_transpose_kernel, false);
}
+
GCScheduler::get().memory_barrier();
}
@@ -184,5 +187,27 @@
GCScheduler::get().memory_barrier();
GCScheduler::get().dispatch(_ma_kernel);
}
+
_memory_group.release();
}
+
+/** One-off preparation. When the interleaved/transposed path is used and B is to be reshaped only on
+ * the first run, allocate the transposed-B tensor, transpose B into it and flag the original B as unused. */
+void GCGEMM::prepare()
+{
+ if(_is_prepared)
+ {
+ return;
+ }
+
+ const bool transpose_b_once = _is_interleaved_transposed && _reshape_b_only_on_first_run;
+ if(transpose_b_once)
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+
+ // Allocate the transposed-B buffer and fill it, waiting for the dispatch to complete
+ _tmp_b.allocator()->allocate();
+ GCScheduler::get().dispatch(_transpose_kernel, false);
+ GCScheduler::get().memory_barrier();
+
+ // Flag the original B tensor as unused by this function
+ _original_b->mark_as_unused();
+ }
+
+ _is_prepared = true;
+}
diff --git a/src/runtime/GLES_COMPUTE/functions/GCSoftmaxLayer.cpp b/src/runtime/GLES_COMPUTE/functions/GCSoftmaxLayer.cpp
index 1748a59..0c8769b 100644
--- a/src/runtime/GLES_COMPUTE/functions/GCSoftmaxLayer.cpp
+++ b/src/runtime/GLES_COMPUTE/functions/GCSoftmaxLayer.cpp
@@ -42,11 +42,11 @@
ARM_COMPUTE_ERROR_ON(beta != 1.0f);
// Create intermediate tensors shapes
- _tmp.allocator()->init(TensorInfo(input->info()->tensor_shape(), input->info()->num_channels(), input->info()->data_type(), input->info()->fixed_point_position()));
+ _tmp.allocator()->init(TensorInfo(input->info()->tensor_shape(), input->info()->num_channels(), input->info()->data_type()));
TensorShape shape = input->info()->tensor_shape();
shape.set(0, 1);
- TensorInfo tensor_info_max_sum(shape, input->info()->num_channels(), input->info()->data_type(), input->info()->fixed_point_position());
+ TensorInfo tensor_info_max_sum(shape, input->info()->num_channels(), input->info()->data_type());
_max.allocator()->init(tensor_info_max_sum);
_sum.allocator()->init(tensor_info_max_sum);
diff --git a/src/runtime/ITensorAllocator.cpp b/src/runtime/ITensorAllocator.cpp
index 8294201..087f324 100644
--- a/src/runtime/ITensorAllocator.cpp
+++ b/src/runtime/ITensorAllocator.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -31,13 +31,14 @@
using namespace arm_compute;
+// _alignment defaults to 0 until init() sets it.
ITensorAllocator::ITensorAllocator()
- : _info()
+ : _info(), _alignment(0)
{
}
-void ITensorAllocator::init(const TensorInfo &input)
+/** Store the tensor metadata and the requested alignment; they are returned later by info() and alignment(). */
+void ITensorAllocator::init(const TensorInfo &input, size_t alignment)
{
- _info = input;
+ _info = input;
+ _alignment = alignment;
}
TensorInfo &ITensorAllocator::info()
@@ -49,3 +50,8 @@
{
return _info;
}
+
+/** @return the alignment recorded by the last call to init(), or 0 if init() has not been called. */
+size_t ITensorAllocator::alignment() const
+{
+ return this->_alignment;
+}
diff --git a/src/runtime/NEON/functions/NECannyEdge.cpp b/src/runtime/NEON/functions/NECannyEdge.cpp
index c27ff2f..d72c98b 100644
--- a/src/runtime/NEON/functions/NECannyEdge.cpp
+++ b/src/runtime/NEON/functions/NECannyEdge.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -61,12 +61,12 @@
void NECannyEdge::configure(ITensor *input, ITensor *output, int32_t upper_thr, int32_t lower_thr, int32_t gradient_size, int32_t norm_type, BorderMode border_mode, uint8_t constant_border_value,
bool use_fp16)
{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
- ARM_COMPUTE_ERROR_ON(gradient_size < 3);
- ARM_COMPUTE_ERROR_ON(gradient_size > 7);
- ARM_COMPUTE_ERROR_ON(lower_thr > upper_thr);
ARM_COMPUTE_ERROR_ON((1 != norm_type) && (2 != norm_type));
+ ARM_COMPUTE_ERROR_ON((gradient_size != 3) && (gradient_size != 5) && (gradient_size != 7));
+ ARM_COMPUTE_ERROR_ON((lower_thr < 0) || (lower_thr >= upper_thr));
_output = output;
@@ -119,7 +119,7 @@
}
else
{
- ARM_COMPUTE_ERROR("Gradient size not supported\n");
+ ARM_COMPUTE_ERROR("Gradient size %d not supported\n", gradient_size);
}
// Manage intermediate buffers
@@ -171,24 +171,23 @@
void NECannyEdge::run()
{
ARM_COMPUTE_ERROR_ON_MSG(_sobel == nullptr, "Unconfigured function");
- ARM_COMPUTE_ERROR_ON(_output == nullptr);
_memory_group.acquire();
// Run sobelNxN
_sobel->run();
- // Fill border before non-maxima suppression. Nop for border mode undefined.
- NEScheduler::get().schedule(&_border_mag_gradient, Window::DimZ);
-
// Run gradient
NEScheduler::get().schedule(_gradient.get(), Window::DimY);
+ // Fill border before non-maxima suppression. Nop for border mode undefined.
+ NEScheduler::get().schedule(&_border_mag_gradient, Window::DimZ);
+
// Run non-maxima suppression
NEScheduler::get().schedule(&_non_max_suppr, Window::DimY);
ARM_COMPUTE_ERROR_ON(_output->buffer() == nullptr);
- memset(_output->buffer(), 0, _output->info()->total_size());
+ std::fill_n(_output->buffer(), _output->info()->total_size(), 0);
// Fill border before edge trace
NEScheduler::get().schedule(&_border_edge_trace, Window::DimZ);
diff --git a/src/runtime/NEON/functions/NEConcatenateLayer.cpp b/src/runtime/NEON/functions/NEConcatenateLayer.cpp
new file mode 100644
index 0000000..21ab47d
--- /dev/null
+++ b/src/runtime/NEON/functions/NEConcatenateLayer.cpp
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h"
+
+#include "arm_compute/runtime/NEON/functions/NEDepthConcatenateLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEWidthConcatenateLayer.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "support/ToolchainSupport.h"
+
+namespace arm_compute
+{
+NEConcatenateLayer::NEConcatenateLayer()
+ : _concat_function(nullptr)
+{
+}
+
+/** Choose width (layout index 0) or depth (layout index 2) concatenation from where @p axis
+ * falls in the output's data layout, configure it and keep ownership of it. */
+void NEConcatenateLayer::configure(const std::vector<ITensor *> &inputs_vector, ITensor *output, DataLayoutDimension axis)
+{
+ ARM_COMPUTE_ERROR_ON(output == nullptr);
+
+ const auto concat_idx = get_data_layout_dimension_index(output->info()->data_layout(), axis);
+ if(concat_idx == 0)
+ {
+ auto width_concat = support::cpp14::make_unique<NEWidthConcatenateLayer>();
+ width_concat->configure(inputs_vector, output);
+ _concat_function = std::move(width_concat);
+ }
+ else if(concat_idx == 2)
+ {
+ auto depth_concat = support::cpp14::make_unique<NEDepthConcatenateLayer>();
+ depth_concat->configure(inputs_vector, output);
+ _concat_function = std::move(depth_concat);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Concatenation is supported across width and depth only!");
+ }
+}
+
+/** Static check mirroring configure(): delegates to the width or depth validator selected by @p axis. */
+Status NEConcatenateLayer::validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output, DataLayoutDimension axis)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON(output == nullptr);
+
+ const auto concat_idx = get_data_layout_dimension_index(output->data_layout(), axis);
+ if(concat_idx == 0)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NEWidthConcatenateLayer::validate(inputs_vector, output));
+ }
+ else if(concat_idx == 2)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NEDepthConcatenateLayer::validate(inputs_vector, output));
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_MSG("Concatenation is supported across width and depth only!");
+ }
+ return Status{};
+}
+
+/** Run the concatenation function selected in configure(). */
+void NEConcatenateLayer::run()
+{
+ ARM_COMPUTE_ERROR_ON(_concat_function == nullptr);
+ _concat_function->run();
+}
+} // namespace arm_compute
diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
index 7053c7e..931e5db 100644
--- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/core/PixelValue.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "support/ToolchainSupport.h"
#include <cmath>
@@ -41,10 +42,11 @@
}
void NEConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info,
- const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math)
+ const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups)
{
// Perform validate step
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
+ ARM_COMPUTE_UNUSED(num_groups);
ARM_COMPUTE_ERROR_THROW_ON(NEConvolutionLayer::validate(input->info(), weights->info(), ((biases != nullptr) ? biases->info() : nullptr), output->info(), conv_info, weights_info, dilation, act_info,
enable_fast_math));
@@ -78,8 +80,10 @@
}
Status NEConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
- const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math)
+ const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((num_groups != 1), "Grouping (num_groups != 1) is not supported on NEON");
+
switch(NEConvolutionLayer::get_convolution_method(input, weights, output, conv_info, weights_info, dilation, act_info))
{
case ConvolutionMethod::WINOGRAD:
@@ -108,6 +112,42 @@
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output, weights);
ARM_COMPUTE_UNUSED(weights_info);
+ const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
+ const size_t idx_c = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
+
+ /* Input spatial dims, kernel size, IFM/OFM, conv info*/
+ using ConvolutionConfiguration = std::tuple<Size2D, Size2D, Size2D, PadStrideInfo>;
+ using ConfigurationMethod = std::pair<ConvolutionConfiguration, ConvolutionMethod>;
+
+ const std::vector<ConfigurationMethod> known_configs =
+ {
+ // Alexnet
+ ConfigurationMethod(ConvolutionConfiguration(Size2D(27U, 27U), Size2D(5U, 5U), Size2D(48U, 128U), PadStrideInfo(1U, 1U, 2U, 2U)), ConvolutionMethod::GEMM),
+ // VGG16 / VGG19
+ ConfigurationMethod(ConvolutionConfiguration(Size2D(224U, 224U), Size2D(3U, 3U), Size2D(3U, 64U), PadStrideInfo(1U, 1U, 1U, 1U)), ConvolutionMethod::GEMM),
+ // Mobilenet 224
+ ConfigurationMethod(ConvolutionConfiguration(Size2D(224U, 224U), Size2D(3U, 3U), Size2D(3U, 32U), PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::FLOOR)), ConvolutionMethod::GEMM),
+ // Mobilenet 160
+ ConfigurationMethod(ConvolutionConfiguration(Size2D(160U, 160U), Size2D(3U, 3U), Size2D(3U, 24U), PadStrideInfo(2U, 2U, 0U, 1U, 0U, 1U, DimensionRoundingType::FLOOR)), ConvolutionMethod::GEMM)
+ };
+
+ const auto find_config = [&](ConfigurationMethod c)
+ {
+ const ConvolutionConfiguration config = c.first;
+ const PadStrideInfo info = std::get<3>(config);
+
+ return std::get<0>(config) == Size2D(input->dimension(idx_w), input->dimension(idx_h)) && std::get<1>(config) == Size2D(weights->dimension(idx_w), weights->dimension(idx_h))
+ && std::get<2>(config) == Size2D(weights->dimension(idx_c), weights->dimension(3)) && info.pad_top() == conv_info.pad_top() && info.pad_right() == conv_info.pad_right()
+ && info.pad_bottom() == conv_info.pad_bottom() && info.pad_left() == conv_info.pad_left() && info.stride() == conv_info.stride();
+ };
+
+ std::vector<ConfigurationMethod>::const_iterator found;
+ if((found = std::find_if(known_configs.begin(), known_configs.end(), find_config)) != known_configs.end())
+ {
+ return (*found).second;
+ }
+
if(dilation != Size2D(1U, 1U) || Scheduler::get().cpu_info().get_cpu_model() == CPUModel::A53
|| input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)) <= 16)
{
@@ -119,6 +159,12 @@
void NEConvolutionLayer::run()
{
+ // Let the underlying convolution function perform its preparation step before executing it
+ prepare();
_function->run();
}
+
+/** Forward the prepare step to the underlying convolution function (_function).
+ * Calling this on an unconfigured layer is a programming error, caught by the assertion below. */
+void NEConvolutionLayer::prepare()
+{
+ ARM_COMPUTE_ERROR_ON(_function == nullptr);
+ _function->prepare();
+}
} // namespace arm_compute
diff --git a/src/runtime/NEON/functions/NECopy.cpp b/src/runtime/NEON/functions/NECopy.cpp
new file mode 100644
index 0000000..efa8b89
--- /dev/null
+++ b/src/runtime/NEON/functions/NECopy.cpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/NEON/functions/NECopy.h"
+
+#include "arm_compute/core/NEON/kernels/NECopyKernel.h"
+#include "support/ToolchainSupport.h"
+
+#include <utility>
+
+using namespace arm_compute;
+
+/** Create a copy kernel, configure it on @p input and @p output, and store it in _kernel. */
+void NECopy::configure(ITensor *input, ITensor *output)
+{
+ auto copy_kernel = arm_compute::support::cpp14::make_unique<NECopyKernel>();
+ copy_kernel->configure(input, output);
+ _kernel = std::move(copy_kernel);
+}
+
+/** Static validation: the result is exactly that of NECopyKernel::validate(). */
+Status NECopy::validate(const arm_compute::ITensorInfo *input, const arm_compute::ITensorInfo *output)
+{
+ Status kernel_status = NECopyKernel::validate(input, output);
+ return kernel_status;
+}
diff --git a/src/runtime/NEON/functions/NEDeconvolutionLayer.cpp b/src/runtime/NEON/functions/NEDeconvolutionLayer.cpp
index 40ada8f..fda9f57 100644
--- a/src/runtime/NEON/functions/NEDeconvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDeconvolutionLayer.cpp
@@ -38,7 +38,8 @@
_scaled_output(),
_input(nullptr),
_info(),
- _inner_border()
+ _inner_border(),
+ _is_prepared(false)
{
}
@@ -62,18 +63,15 @@
info.pad().first, info.pad().second, inner_border_right, inner_border_top, stride_x, stride_y);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, bias);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, weights, bias);
if(bias != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, bias);
}
if(output->tensor_shape().total_size() > 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
const TensorShape output_shape = deconvolution_output_shape(out_dims, input->tensor_shape(), weights->tensor_shape());
@@ -104,6 +102,7 @@
_input = input;
_info = info;
_inner_border = std::make_pair(inner_border_right, inner_border_top);
+ _is_prepared = false;
const unsigned int stride_x = info.stride().first;
const unsigned int stride_y = info.stride().second;
@@ -115,8 +114,7 @@
// configure scale function
// Init and allocate intermmidiate tensor for output, same size as input but the first two axis are the same as the output tensor
- const TensorInfo scale_out_info(compute_deconvolution_shape(*input->info(), stride_x, stride_y, inner_border_right, inner_border_top, info), 1, input->info()->data_type(),
- input->info()->fixed_point_position());
+ const TensorInfo scale_out_info(compute_deconvolution_shape(*input->info(), stride_x, stride_y, inner_border_right, inner_border_top, info), 1, input->info()->data_type());
_scaled_output.allocator()->init(scale_out_info);
// setup the function to convolve the upscaled output
@@ -132,13 +130,21 @@
void NEDeconvolutionLayer::run()
{
+ prepare();
+
_memory_group.acquire();
- // Run upsample kernel
_upsample_f.run();
-
- // Run convolution layer
_conv_f.run();
_memory_group.release();
+}
+
+void NEDeconvolutionLayer::prepare()
+{
+ if(!_is_prepared)
+ {
+ _conv_f.prepare();
+ _is_prepared = true;
+ }
}
\ No newline at end of file
diff --git a/src/runtime/NEON/functions/NEDepthConcatenateLayer.cpp b/src/runtime/NEON/functions/NEDepthConcatenateLayer.cpp
index 930f8d5..49db855 100644
--- a/src/runtime/NEON/functions/NEDepthConcatenateLayer.cpp
+++ b/src/runtime/NEON/functions/NEDepthConcatenateLayer.cpp
@@ -27,7 +27,9 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/PixelValue.h"
+#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "support/ToolchainSupport.h"
@@ -41,18 +43,22 @@
{
}
-void NEDepthConcatenateLayer::configure(std::vector<ITensor *> inputs_vector, ITensor *output) // NOLINT
+void NEDepthConcatenateLayer::configure(const std::vector<ITensor *> &inputs_vector, ITensor *output) // NOLINT
{
- ARM_COMPUTE_ERROR_ON(inputs_vector.size() < 2);
-
_num_inputs = inputs_vector.size();
_concat_kernels_vector = arm_compute::support::cpp14::make_unique<NEDepthConcatenateLayerKernel[]>(_num_inputs);
_border_handlers_vector = arm_compute::support::cpp14::make_unique<NEFillBorderKernel[]>(_num_inputs);
- TensorShape output_shape = calculate_depth_concatenate_shape(inputs_vector);
+ std::vector<ITensorInfo *> inputs_vector_info;
+ for(unsigned int i = 0; i < _num_inputs; i++)
+ {
+ inputs_vector_info.emplace_back(inputs_vector.at(i)->info());
+ }
+ TensorShape output_shape = arm_compute::misc::shape_calculator::calculate_depth_concatenate_shape(inputs_vector_info);
// Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), output_shape, 1, inputs_vector[0]->info()->data_type(), inputs_vector[0]->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), output_shape, 1, inputs_vector[0]->info()->data_type());
+ ARM_COMPUTE_ERROR_THROW_ON(NEDepthConcatenateLayer::validate(inputs_vector_info, output->info()));
unsigned int depth_offset = 0;
for(unsigned int i = 0; i < _num_inputs; ++i)
@@ -67,6 +73,27 @@
output->info()->set_valid_region(ValidRegion(Coordinates(), output_shape));
}
+Status NEDepthConcatenateLayer::validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
+ ARM_COMPUTE_RETURN_ERROR_ON(inputs_vector.size() < 2);
+
+ // Output auto inizialitation if not yet initialized
+ TensorInfo tmp_output_info = *output->clone();
+ TensorShape output_shape = arm_compute::misc::shape_calculator::calculate_depth_concatenate_shape(inputs_vector);
+ auto_init_if_empty(tmp_output_info, output_shape, 1, inputs_vector[0]->data_type());
+
+ unsigned int depth_offset = 0;
+ for(const auto &input : inputs_vector)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
+ ARM_COMPUTE_RETURN_ON_ERROR(NEDepthConcatenateLayerKernel::validate(input, depth_offset, &tmp_output_info));
+ depth_offset += input->dimension(2);
+ }
+
+ return Status{};
+}
+
void NEDepthConcatenateLayer::run()
{
for(unsigned i = 0; i < _num_inputs; ++i)
diff --git a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
index 0a977ad..24b12f4 100644
--- a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
@@ -36,8 +36,8 @@
using namespace arm_compute::misc::shape_calculator;
NEDepthwiseConvolutionLayer3x3::NEDepthwiseConvolutionLayer3x3()
- : _dwc_kernel(), _output_stage_kernel(), _border_handler(), _permute_input(), _permute_weights(), _permute_output(), _accumulator(), _input_nhwc(), _weights_hwio(), _output_nhwc(), _has_bias(false),
- _is_quantized(false), _is_optimized(false), _are_weights_reshaped(false), _is_nchw(true), _is_first_run(true)
+ : _dwc_kernel(), _output_stage_kernel(), _border_handler(), _permute_input(), _permute_weights(), _permute_output(), _accumulator(), _permuted_input(), _permuted_weights(), _permuted_output(),
+ _has_bias(false), _is_quantized(false), _is_optimized(false), _are_weights_reshaped(false), _is_nchw(true), _is_first_run(true), _permute(false)
{
}
@@ -57,29 +57,31 @@
input->info()->data_layout());
_are_weights_reshaped = false;
_is_nchw = input->info()->data_layout() == DataLayout::NCHW;
-
- ARM_COMPUTE_ERROR_ON(!_is_optimized && !_is_nchw);
+ _permute = _is_optimized == _is_nchw;
if(_is_optimized)
{
if(_is_nchw)
{
// Configure the function to transform the input tensor from NCHW -> NHWC
- _permute_input.configure(input, &_input_nhwc, PermutationVector(2U, 0U, 1U));
+ _permute_input.configure(input, &_permuted_input, PermutationVector(2U, 0U, 1U));
+ _permuted_input.info()->set_data_layout(DataLayout::NHWC);
// Configure the function to transform the weights tensor from IHW -> HWI
- _permute_weights.configure(weights, &_weights_hwio, PermutationVector(2U, 0U, 1U));
+ _permute_weights.configure(weights, &_permuted_weights, PermutationVector(2U, 0U, 1U));
+ _permuted_weights.info()->set_data_layout(DataLayout::NHWC);
// Configure optimized depthwise
- _dwc_kernel.configure(&_input_nhwc, &_weights_hwio, &_output_nhwc, conv_info, depth_multiplier, DataLayout::NHWC);
+ _dwc_kernel.configure(&_permuted_input, &_permuted_weights, &_permuted_output, conv_info, depth_multiplier, DataLayout::NHWC);
// Configure the function to transform the convoluted output to ACL's native ordering format NCHW
- _permute_output.configure(&_output_nhwc, output, PermutationVector(1U, 2U, 0U));
+ _permute_output.configure(&_permuted_output, output, PermutationVector(1U, 2U, 0U));
+ _permuted_output.info()->set_data_layout(DataLayout::NCHW);
// Allocate tensors
- _input_nhwc.allocator()->allocate();
- _weights_hwio.allocator()->allocate();
- _output_nhwc.allocator()->allocate();
+ _permuted_input.allocator()->allocate();
+ _permuted_weights.allocator()->allocate();
+ _permuted_output.allocator()->allocate();
}
else
{
@@ -88,39 +90,88 @@
}
else
{
- // Allocate the intermediate accumulator tensor in case of fixed point input
+ // Allocate the intermediate accumulator tensor in case of quantized input
if(_is_quantized)
{
- _accumulator.allocator()->init(TensorInfo(output->info()->tensor_shape(), 1, DataType::S32));
+ TensorShape accum_shape = output->info()->tensor_shape();
+
+ if(!_is_nchw)
+ {
+ permute(accum_shape, PermutationVector(1U, 2U, 0U));
+ }
+
+ _accumulator.allocator()->init(TensorInfo(accum_shape, 1, DataType::S32));
_accumulator.info()->set_quantization_info(input->info()->quantization_info());
zero_value = PixelValue(static_cast<uint32_t>(input->info()->quantization_info().offset));
}
- // Configure depthwise convolution kernel
- _dwc_kernel.configure(input, weights, (_is_quantized) ? &_accumulator : output, conv_info, depth_multiplier);
-
- // Configure border handler
- _border_handler.configure(input, _dwc_kernel.border_size(), BorderMode::CONSTANT, zero_value);
- }
-
- // Configure biases accumulation
- if(_has_bias || _is_quantized)
- {
- if(_is_quantized)
+ if(!_is_nchw)
{
- const QuantizationInfo output_quant_info = (output->info()->total_size() == 0) ? input->info()->quantization_info() : output->info()->quantization_info();
+ // Configure the function to transform the input tensor from NHWC -> NCHW
+ _permute_input.configure(input, &_permuted_input, PermutationVector(1U, 2U, 0U));
+ _permuted_input.info()->set_data_layout(DataLayout::NCHW);
- float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale;
- int output_multiplier, output_shift;
- quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
- _output_stage_kernel.configure(&_accumulator, biases, output, output_multiplier, output_shift, output_quant_info.offset);
- _accumulator.allocator()->allocate();
+ // Configure the function to transform the weights tensor from HWI -> IHW
+ _permute_weights.configure(weights, &_permuted_weights, PermutationVector(1U, 2U, 0U));
+ _permuted_weights.info()->set_data_layout(DataLayout::NCHW);
+
+ // Configure optimized depthwise
+ _dwc_kernel.configure(&_permuted_input, &_permuted_weights, (_is_quantized) ? &_accumulator : &_permuted_output, conv_info, depth_multiplier);
+
+ // Configure border handler
+ _border_handler.configure(&_permuted_input, _dwc_kernel.border_size(), BorderMode::CONSTANT, zero_value);
+
+ // Allocate tensors
+ _permuted_input.allocator()->allocate();
+ _permuted_weights.allocator()->allocate();
}
else
{
- _output_stage_kernel.configure(output, biases);
+ // Configure depthwise convolution kernel
+ _dwc_kernel.configure(input, weights, (_is_quantized) ? &_accumulator : output, conv_info, depth_multiplier);
+
+ // Configure border handler
+ _border_handler.configure(input, _dwc_kernel.border_size(), BorderMode::CONSTANT, zero_value);
}
}
+
+ // Configure biases accumulation
+ if(_is_quantized)
+ {
+ const QuantizationInfo output_quant_info = (output->info()->total_size() == 0) ? input->info()->quantization_info() : output->info()->quantization_info();
+
+ float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale;
+ int output_multiplier, output_shift;
+ quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+ _output_stage_kernel.configure(&_accumulator, biases, _is_nchw ? output : &_permuted_output, output_multiplier, output_shift, output_quant_info.offset);
+ _accumulator.allocator()->allocate();
+ }
+ else if(_has_bias)
+ {
+ _output_stage_kernel.configure((_is_nchw || _is_optimized) ? output : &_permuted_output, biases);
+ }
+
+ if(!_is_optimized && !_is_nchw)
+ {
+ // Configure the function to transform the convoluted output to NHWC
+ _permute_output.configure(&_permuted_output, output, PermutationVector(2U, 0U, 1U));
+ _permuted_output.allocator()->allocate();
+ }
+}
+
+Status NEDepthwiseConvolutionLayer3x3::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
+ unsigned int depth_multiplier)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() != DataLayout::NCHW && input->data_layout() != DataLayout::NHWC);
+
+ if(biases != nullptr)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(3));
+ }
+
+ return NEDepthwiseConvolutionLayer3x3Kernel::validate(input, weights, output, conv_info, depth_multiplier);
}
void NEDepthwiseConvolutionLayer3x3::run()
@@ -132,32 +183,29 @@
_dwc_kernel.generate_convolver();
}
- // Permute weights in HWIO format if the optimized kernel will be executedd
- if(!_are_weights_reshaped && _is_optimized && _is_nchw)
+ // Permute weights
+ if(_permute)
{
- _are_weights_reshaped = true;
- _permute_weights.run();
+ if(!_are_weights_reshaped)
+ {
+ _are_weights_reshaped = true;
+ _permute_weights.run();
+ }
+
+ _permute_input.run();
}
// Handle input
- if(_is_optimized)
+ if(!_is_optimized)
{
- if(_is_nchw)
- {
- // Permute input to NHWC format execution
- _permute_input.run();
- }
- }
- else
- {
- // Fill border in NCHW format execution
+ // Fill border
NEScheduler::get().schedule(&_border_handler, Window::DimX);
}
// Execute depthwise convolution
NEScheduler::get().schedule(&_dwc_kernel, Window::DimX);
- // Permute output to ACL's native NCHW format in case of NHWC execution
+ // Permute output
if(_is_optimized && _is_nchw)
{
_permute_output.run();
@@ -168,27 +216,54 @@
{
NEScheduler::get().schedule(&_output_stage_kernel, Window::DimX);
}
+
+ // Permute output
+ if(!_is_optimized && !_is_nchw)
+ {
+ _permute_output.run();
+ }
}
NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayer()
- : _im2col_kernel(), _weights_reshape_kernel(), _v2mm_kernel(), _vector_to_tensor_kernel(), _output_stage_kernel(), _v2mm_input_fill_border(), _v2mm_weights_fill_border(), _input_reshaped(),
- _weights_reshaped(), _v2mm_output(), _output_reshaped(), _is_first_run(true), _is_quantized(false), _original_weights(nullptr)
+ : _im2col_kernel(), _weights_reshape_kernel(), _v2mm_kernel(), _vector_to_tensor_kernel(), _output_stage_kernel(), _v2mm_input_fill_border(), _v2mm_weights_fill_border(), _permute_input(),
+ _permute_weights(), _permute_output(), _input_reshaped(), _weights_reshaped(), _v2mm_output(), _output_reshaped(), _permuted_input(), _permuted_weights(), _permuted_output(), _is_prepared(false),
+ _is_quantized(false), _is_nhwc(false), _original_weights(nullptr)
{
}
void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier)
{
+ const unsigned int channel_idx = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::CHANNEL);
+ ARM_COMPUTE_UNUSED(channel_idx);
+
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
- ARM_COMPUTE_ERROR_ON((input->info()->dimension(2) * depth_multiplier) != weights->info()->dimension(2));
+ ARM_COMPUTE_ERROR_ON((input->info()->dimension(channel_idx) * depth_multiplier) != weights->info()->dimension(channel_idx));
- const size_t weights_w = weights->info()->dimension(0);
- const size_t weights_h = weights->info()->dimension(1);
- const size_t weights_z = weights->info()->dimension(2);
+ _is_nhwc = input->info()->data_layout() == DataLayout::NHWC;
+
+ ITensor *input_to_use = input;
+ const ITensor *weights_to_use = weights;
+ ITensor *output_to_use = output;
+
+ if(_is_nhwc)
+ {
+ _permute_input.configure(input, &_permuted_input, PermutationVector(1U, 2U, 0U));
+ _permuted_input.info()->set_data_layout(DataLayout::NCHW);
+ input_to_use = &_permuted_input;
+
+ _permute_weights.configure(weights, &_permuted_weights, PermutationVector(1U, 2U, 0U));
+ _permuted_weights.info()->set_data_layout(DataLayout::NCHW);
+ weights_to_use = &_permuted_weights;
+ }
+
+ const size_t weights_w = weights_to_use->info()->dimension(0);
+ const size_t weights_h = weights_to_use->info()->dimension(1);
+ const size_t weights_z = weights_to_use->info()->dimension(2);
_is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
- _is_first_run = true;
- _original_weights = weights;
+ _is_prepared = false;
+ _original_weights = weights_to_use;
// Should bias be appended ?
bool append_bias = (biases != nullptr) && !_is_quantized;
@@ -200,6 +275,14 @@
auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));
ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
+ if(_is_nhwc)
+ {
+ permute(output_shape, PermutationVector(1U, 2U, 0U));
+ _permuted_output.allocator()->init(output->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape));
+ _permuted_output.info()->set_data_layout(DataLayout::NCHW);
+ output_to_use = &_permuted_output;
+ }
+
// Output width and height
const unsigned int conv_w = output_shape.x();
const unsigned int conv_h = output_shape.y();
@@ -209,41 +292,50 @@
const size_t conv_size = conv_w * conv_h;
// Im2Col configuration
- TensorShape shape_im2col = input->info()->tensor_shape();
+ TensorShape shape_im2col = input_to_use->info()->tensor_shape();
shape_im2col.set(0, patch_size);
shape_im2col.set(1, conv_size);
shape_im2col.set(2, weights_z);
- _input_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
- _im2col_kernel.configure(input, &_input_reshaped, Size2D(weights_w, weights_h), conv_info, append_bias, depth_multiplier);
+ _input_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col).set_data_layout(DataLayout::NCHW));
+ _im2col_kernel.configure(input_to_use, &_input_reshaped, Size2D(weights_w, weights_h), conv_info, append_bias, depth_multiplier);
// Weights reshape configuration
const TensorShape shape_weights_reshape(patch_size, weights_z);
- _weights_reshaped.allocator()->init(weights->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_weights_reshape));
- _weights_reshape_kernel.configure(weights, &_weights_reshaped, append_bias ? biases : nullptr);
+ _weights_reshaped.allocator()->init(weights->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_weights_reshape).set_data_layout(DataLayout::NCHW));
+ _weights_reshape_kernel.configure(weights_to_use, &_weights_reshaped, append_bias ? biases : nullptr);
// GEMV configuration
DataType v2mm_dt = (input->info()->data_type() == DataType::QASYMM8) ? DataType::S32 : input->info()->data_type();
- TensorShape shape_v2mm_out = input->info()->tensor_shape();
+ TensorShape shape_v2mm_out = input_to_use->info()->tensor_shape();
shape_v2mm_out.set(0, conv_size * weights_z);
shape_v2mm_out.set(1, 1);
shape_v2mm_out.set(2, 1);
- _v2mm_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_data_type(v2mm_dt).set_tensor_shape(shape_v2mm_out));
+ _v2mm_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_data_type(v2mm_dt).set_tensor_shape(shape_v2mm_out).set_data_layout(DataLayout::NCHW));
_v2mm_kernel.configure(&_input_reshaped, &_weights_reshaped, &_v2mm_output);
_output_reshaped.allocator()->init(_v2mm_output.info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape));
- _vector_to_tensor_kernel.configure(&_v2mm_output, (_is_quantized) ? &_output_reshaped : output, conv_w, conv_h);
+ _vector_to_tensor_kernel.configure(&_v2mm_output, (_is_quantized) ? &_output_reshaped : output_to_use, conv_w, conv_h);
// Output staged configuration
if(_is_quantized)
{
- const QuantizationInfo output_quant_info = (output->info()->total_size() == 0) ? input->info()->quantization_info() : output->info()->quantization_info();
+ const QuantizationInfo output_quant_info = output->info()->quantization_info();
float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale;
int output_multiplier, output_shift;
quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
- _output_stage_kernel.configure(&_output_reshaped, biases, output, output_multiplier, output_shift, output_quant_info.offset);
+ _output_stage_kernel.configure(&_output_reshaped, biases, output_to_use, output_multiplier, output_shift, output_quant_info.offset);
_output_reshaped.allocator()->allocate();
}
+ if(_is_nhwc)
+ {
+ _permute_output.configure(&_permuted_output, output, PermutationVector(2U, 0U, 1U));
+
+ _permuted_input.allocator()->allocate();
+ _permuted_weights.allocator()->allocate();
+ _permuted_output.allocator()->allocate();
+ }
+
// Fill borders on inputs
PixelValue zero_in(static_cast<int32_t>(0));
PixelValue zero_w(static_cast<int32_t>(0));
@@ -260,23 +352,102 @@
// Allocate intermediate tensors
_input_reshaped.allocator()->allocate();
- _weights_reshaped.allocator()->allocate();
_v2mm_output.allocator()->allocate();
}
+Status NEDepthwiseConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
+ unsigned int depth_multiplier)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() != DataLayout::NCHW && input->data_layout() != DataLayout::NHWC);
+
+ // Clone output to use auto init
+ auto output_clone = output->clone();
+
+ const ITensorInfo *input_to_use = input;
+ const ITensorInfo *weights_to_use = weights;
+ const ITensorInfo *output_to_use = output_clone.get();
+
+ TensorShape permuted_input_shape = input->tensor_shape();
+ TensorShape permuted_weights_shape = weights->tensor_shape();
+ TensorInfo permuted_input;
+ TensorInfo permuted_weights;
+
+ if(input->data_layout() == DataLayout::NHWC)
+ {
+ permute(permuted_input_shape, PermutationVector(1U, 2U, 0U));
+ permute(permuted_weights_shape, PermutationVector(1U, 2U, 0U));
+
+ permuted_input = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_input_shape).set_data_layout(DataLayout::NCHW));
+ permuted_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_weights_shape).set_data_layout(DataLayout::NCHW));
+
+ input_to_use = &permuted_input;
+ weights_to_use = &permuted_weights;
+ }
+
+ const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
+ const bool append_bias = (biases != nullptr) && !is_quantized;
+ TensorShape output_shape = shape_calculator::compute_depthwise_convolution_shape(*input, *weights, conv_info, depth_multiplier);
+ const size_t weights_w = weights_to_use->dimension(0);
+ const size_t weights_h = weights_to_use->dimension(1);
+ const size_t weights_z = weights_to_use->dimension(2);
+ const unsigned int conv_w = output_shape.x();
+ const unsigned int conv_h = output_shape.y();
+ const size_t patch_size = weights_w * weights_h + (append_bias ? 1 : 0);
+ const size_t conv_size = conv_w * conv_h;
+
+ // Output auto inizialitation if not yet initialized
+ auto_init_if_empty(*output_clone, input->clone()->set_tensor_shape(output_shape));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
+
+ TensorInfo permuted_output;
+ if(input->data_layout() == DataLayout::NHWC)
+ {
+ permute(output_shape, PermutationVector(1U, 2U, 0U));
+ permuted_output = TensorInfo(output_clone->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape).set_data_layout(DataLayout::NCHW));
+ output_to_use = &permuted_output;
+ }
+
+ // Im2Col configuration
+ TensorShape shape_im2col = input_to_use->tensor_shape();
+ shape_im2col.set(0, patch_size);
+ shape_im2col.set(1, conv_size);
+ shape_im2col.set(2, weights_z);
+ TensorInfo input_reshaped(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col).set_data_layout(DataLayout::NCHW));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEDepthwiseIm2ColKernel::validate(input_to_use, &input_reshaped, Size2D(weights_w, weights_h), conv_info, append_bias, depth_multiplier));
+
+ // Weights reshape configuration
+ const TensorShape shape_weights_reshape(patch_size, weights_z);
+ TensorInfo weights_reshaped(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_weights_reshape).set_data_layout(DataLayout::NCHW));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEDepthwiseWeightsReshapeKernel::validate(weights_to_use, &weights_reshaped, append_bias ? biases : nullptr));
+
+ // GEMV configuration
+ DataType v2mm_dt = (input->data_type() == DataType::QASYMM8) ? DataType::S32 : input->data_type();
+ TensorShape shape_v2mm_out = input_to_use->tensor_shape();
+ shape_v2mm_out.set(0, conv_size * weights_z);
+ shape_v2mm_out.set(1, 1);
+ shape_v2mm_out.set(2, 1);
+ TensorInfo v2mm_output(input->clone()->set_is_resizable(true).reset_padding().set_data_type(v2mm_dt).set_tensor_shape(shape_v2mm_out).set_data_layout(DataLayout::NCHW));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixVectorMultiplyKernel::validate(&input_reshaped, &weights_reshaped, &v2mm_output));
+
+ TensorInfo output_reshaped(v2mm_output.clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_to_use->tensor_shape()));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEDepthwiseVectorToTensorKernel::validate(&v2mm_output, (is_quantized) ? &output_reshaped : output_to_use, conv_w, conv_h));
+
+ if(is_quantized)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayerOutputStageKernel::validate(&output_reshaped, biases, output_to_use));
+ }
+
+ return Status{};
+}
+
void NEDepthwiseConvolutionLayer::run()
{
- // Run weights reshaping (Runs once for every configure)
- if(_is_first_run)
+ prepare();
+
+ if(_is_nhwc)
{
- ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
-
- NEScheduler::get().schedule(&_weights_reshape_kernel, Window::DimX);
- NEScheduler::get().schedule(&_v2mm_weights_fill_border, Window::DimX);
- _is_first_run = false;
-
- // Mark original weights tensor as unused
- _original_weights->mark_as_unused();
+ _permute_input.run();
}
NEScheduler::get().schedule(&_im2col_kernel, Window::DimX);
@@ -287,4 +458,30 @@
{
NEScheduler::get().schedule(&_output_stage_kernel, Window::DimX);
}
+
+ if(_is_nhwc)
+ {
+ _permute_output.run();
+ }
+}
+
+void NEDepthwiseConvolutionLayer::prepare()
+{
+ if(!_is_prepared)
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+
+ if(_is_nhwc)
+ {
+ _permute_weights.run();
+ }
+
+ // Run reshape and mark original weights as unused
+ _weights_reshaped.allocator()->allocate();
+ NEScheduler::get().schedule(&_weights_reshape_kernel, Window::DimX);
+ NEScheduler::get().schedule(&_v2mm_weights_fill_border, Window::DimX);
+ _original_weights->mark_as_unused();
+
+ _is_prepared = true;
+ }
}
diff --git a/src/runtime/NEON/functions/NEDepthwiseSeparableConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDepthwiseSeparableConvolutionLayer.cpp
index d70a668..da2e49c 100644
--- a/src/runtime/NEON/functions/NEDepthwiseSeparableConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDepthwiseSeparableConvolutionLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,6 +45,14 @@
void NEDepthwiseSeparableConvolutionLayer::run()
{
+ prepare();
+
_depthwise_conv.run();
_pointwise_conv.run();
+}
+
+void NEDepthwiseSeparableConvolutionLayer::prepare()
+{
+ _depthwise_conv.prepare();
+ _pointwise_conv.prepare();
}
\ No newline at end of file
diff --git a/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp
index 445864c..40e40c8 100644
--- a/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp
@@ -34,7 +34,7 @@
using namespace arm_compute;
NEDirectConvolutionLayer::NEDirectConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _output_stage_kernel(), _conv_kernel(), _input_border_handler(), _activationlayer_function(), _accumulator(), _has_bias(false), _is_fixed_point(false),
+ : _memory_group(std::move(memory_manager)), _output_stage_kernel(), _conv_kernel(), _input_border_handler(), _activationlayer_function(), _accumulator(), _has_bias(false),
_is_activationlayer_enabled(false), _dim_split(Window::DimZ)
{
}
@@ -54,26 +54,10 @@
// Check if bias should be added in the convolution result
_has_bias = (bias != nullptr);
- // Allocate the intermediate accumulator tensor in case of fixed point input
- _is_fixed_point = is_data_type_fixed_point(input->info()->data_type());
- if(_is_fixed_point)
+ _conv_kernel.configure(input, weights, output, conv_info);
+ if(_has_bias)
{
- const DataType promoted_dt = (input->info()->data_type() == DataType::QS8) ? DataType::QS16 : DataType::QS32;
- _accumulator.allocator()->init(TensorInfo(output->info()->tensor_shape(), 1, promoted_dt, output->info()->fixed_point_position()));
- _memory_group.manage(&_accumulator);
- _conv_kernel.configure(input, weights, &_accumulator, conv_info);
-
- // When no bias is provided, we need to downscale the accumulator tensor
- _output_stage_kernel.configure(&_accumulator, bias, output);
- _accumulator.allocator()->allocate();
- }
- else
- {
- _conv_kernel.configure(input, weights, output, conv_info);
- if(_has_bias)
- {
- _output_stage_kernel.configure(output, bias);
- }
+ _output_stage_kernel.configure(output, bias);
}
// Add zero padding XY
@@ -92,12 +76,7 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
- DataType data_type = output->data_type();
- if(is_data_type_fixed_point(data_type))
- {
- // Promote data type in case of fixed point
- data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
- }
+ DataType data_type = output->data_type();
TensorInfo accumulator(output->clone()->set_is_resizable(true).reset_padding().set_data_type(data_type));
// Validate Convolution kernel
@@ -129,7 +108,7 @@
_memory_group.acquire();
NEScheduler::get().schedule(&_conv_kernel, _dim_split);
- if(_has_bias || _is_fixed_point)
+ if(_has_bias)
{
NEScheduler::get().schedule(&_output_stage_kernel, Window::DimY);
}
diff --git a/src/runtime/NEON/functions/NEFlattenLayer.cpp b/src/runtime/NEON/functions/NEFlattenLayer.cpp
index 32edf93..1814d61 100644
--- a/src/runtime/NEON/functions/NEFlattenLayer.cpp
+++ b/src/runtime/NEON/functions/NEFlattenLayer.cpp
@@ -32,6 +32,6 @@
void NEFlattenLayer::configure(const ITensor *input, ITensor *output)
{
auto k = arm_compute::support::cpp14::make_unique<NEIm2ColKernel>();
- k->configure(input, output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false, false, true);
+ k->configure(input, output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false, Size2D(1U, 1U), 1, false, true);
_kernel = std::move(k);
}
\ No newline at end of file
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 958d081..f1606aa 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include <algorithm>
@@ -35,120 +36,108 @@
using namespace arm_compute;
using namespace arm_compute::misc::shape_calculator;
-NEFullyConnectedLayerReshapeWeights::NEFullyConnectedLayerReshapeWeights(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _transpose_kernel(), _transpose1xW_kernel(), _transpose_output(), _transpose_weights(false), _is_batched_fc_layer(false)
+namespace
{
-}
-
-void NEFullyConnectedLayerReshapeWeights::configure(const ITensor *input, ITensor *output, bool transpose_weights, bool is_batched_fc_layer)
+Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo &output)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
-
- // Perform validate step
- ARM_COMPUTE_ERROR_THROW_ON(NEFullyConnectedLayerReshapeWeights::validate(input->info(), output->info(), transpose_weights, is_batched_fc_layer));
-
- _transpose_weights = transpose_weights;
- _is_batched_fc_layer = is_batched_fc_layer;
-
- // Check if we need to transpose the weights
- if(_transpose_weights)
+ if(is_data_type_quantized_asymmetric(input.data_type()))
{
- if(_is_batched_fc_layer)
- {
- // Initialize the output tensor for transpose
- _transpose_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*input->info())));
- _memory_group.manage(&_transpose_output);
- _transpose_kernel.configure(input, &_transpose_output);
+ // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
+ // Extract and negate input and weights offset
+ const QuantizationInfo input_quantization_info(input.quantization_info().scale, -input.quantization_info().offset);
+ const QuantizationInfo weights_quantization_info(weights.quantization_info().scale, -weights.quantization_info().offset);
- // Configure transpose 1xW kernel
- _transpose1xW_kernel.configure(&_transpose_output, output);
-
- // Allocate temporary tensor used for transposing the weights
- _transpose_output.allocator()->allocate();
- }
- else
- {
- _transpose_kernel.configure(input, output);
- }
+ // Validate gemmlowp function
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyCore::validate(&input.clone()->set_quantization_info(input_quantization_info),
+ &weights.clone()->set_quantization_info(weights_quantization_info),
+ &output));
}
else
{
- if(_is_batched_fc_layer)
- {
- // Configure transpose 1xW kernel
- _transpose1xW_kernel.configure(input, output);
- }
- }
-}
-
-Status NEFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, const ITensorInfo *output, bool transpose_weights, bool is_batched_fc_layer)
-{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!transpose_weights && !is_batched_fc_layer, "Configuration transpose_weights=false & is_batched_fc_layer=false not supported");
-
- if(transpose_weights)
- {
- if(is_batched_fc_layer)
- {
- std::unique_ptr<ITensorInfo> use_output = output->clone();
- use_output->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*input));
-
- ARM_COMPUTE_RETURN_ON_ERROR(NETransposeKernel::validate(input, use_output.get()));
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(use_output.get(), output));
- }
- else
- {
- ARM_COMPUTE_RETURN_ON_ERROR(NETransposeKernel::validate(input, output));
- }
- }
- else
- {
- if(is_batched_fc_layer)
- {
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(input, output));
- }
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMM::validate(&input, &weights, nullptr, &output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */)));
}
return Status{};
}
+} // namespace
-void NEFullyConnectedLayerReshapeWeights::run()
+void NEFullyConnectedLayerReshapeWeights::configure(const ITensor *input, ITensor *output)
{
- _memory_group.acquire();
+ auto k = arm_compute::support::cpp14::make_unique<NETransposeKernel>();
+ k->configure(input, output);
+ _kernel = std::move(k);
+}
- if(_transpose_weights)
- {
- NEScheduler::get().schedule(&_transpose_kernel, Window::DimY);
- }
-
- if(_is_batched_fc_layer)
- {
- NEScheduler::get().schedule(&_transpose1xW_kernel, Window::DimY);
- }
-
- _memory_group.release();
+Status NEFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, const ITensorInfo *output)
+{
+ return NETransposeKernel::validate(input, output);
}
NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _im2col_kernel(), _reshape_weights_kernel(), _interleave4x4_kernel(), _mm_kernel(), _accumulate_biases_kernel(), _im2col_output(), _interleave4x4_output(),
- _reshape_weights_output(), _are_weights_reshaped(false), _is_batched_fc_layer(false), _linearize_input(false), _accumulate_biases(false), _original_weights(nullptr)
+ : _memory_group(std::move(memory_manager)), _im2col_kernel(), _convert_weights(), _reshape_weights_function(), _mm_gemm(), _mm_gemmlowp(), _gemmlowp_output_stage(), _accumulate_biases_kernel(),
+ _im2col_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _original_weights(nullptr), _are_weights_converted(true), _are_weights_reshaped(false),
+ _is_fc_after_conv(false), _accumulate_biases(false), _is_quantized(false), _is_prepared(false)
{
}
-void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, bool transpose_weights, bool are_weights_reshaped)
+void NEFullyConnectedLayer::configure_mm(const ITensor *input, const ITensor *weights, ITensor *output)
{
- // With the Fully Connected layer we can have 4 different cases:
- // 1) Convolution layer -> Fully Connected layer without batches
- // 2) Fully Connected layer -> Fully Connected layer without batches
- // 3) Convolution layer -> Fully Connected layer with batches
- // 4) Fully Connected layer -> Fully Connected layer with batches
+ if(_is_quantized)
+ {
+ // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
+ // Extract and negate input and weights offset
+ const QuantizationInfo input_quantization_info = input->info()->quantization_info();
+ const QuantizationInfo weights_quantization_info = weights->info()->quantization_info();
- // Expected shape before transpose and reshaping
- // Input: In x B (In and B can be multi-dimensional)
- // Weights: flat(In) x Out
- // Biases: Out
- // Output: Out x B (B can be multi-dimensional)
+ input->info()->set_quantization_info(QuantizationInfo(input_quantization_info.scale, -input_quantization_info.offset));
+ weights->info()->set_quantization_info(QuantizationInfo(weights_quantization_info.scale, -weights_quantization_info.offset));
+
+ // Configure gemmlowp function
+ _mm_gemmlowp.configure(input, weights, output);
+
+ // Revert back QuantizatioInfo as input and weights could be used in other fully connected layers
+ input->info()->set_quantization_info(input_quantization_info);
+ weights->info()->set_quantization_info(weights_quantization_info);
+ }
+ else
+ {
+ // Configure matrix multiply kernel
+ _mm_gemm.configure(input, weights, nullptr, output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */));
+ }
+}
+
+void NEFullyConnectedLayer::configure_conv_fc(const ITensor *input, const ITensor *weights, ITensor *output)
+{
+ ARM_COMPUTE_ERROR_ON((weights->info()->dimension(1) != (input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2))));
+
+ // If the fully connected layer is called after a convolution layer, the input tensor must be linearized
+
+ // Initialize output tensor for im2col
+ TensorShape shape_im2col = compute_flatten_shape(input->info());
+ _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
+
+ // Configure im2col kernel
+ _memory_group.manage(&_im2col_output);
+ _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false, Size2D(1U, 1U), 1, true);
+
+ // Configure matrix multiply kernel
+ configure_mm(&_im2col_output, weights, output);
+
+ // Allocate the output tensor for im2col once all the configure methods have been called
+ _im2col_output.allocator()->allocate();
+}
+
+void NEFullyConnectedLayer::configure_fc_fc(const ITensor *input, const ITensor *weights, ITensor *output)
+{
+ ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != weights->info()->dimension(1));
+
+ // Configure matrix multiply kernel
+ configure_mm(input, weights, output);
+}
+
+void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output,
+ FullyConnectedLayerInfo fc_info)
+{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
// Perform validate step
@@ -156,165 +145,184 @@
weights->info(),
biases != nullptr ? biases->info() : nullptr,
output->info(),
- transpose_weights,
- are_weights_reshaped));
+ fc_info));
- const int num_batch_dimensions = std::max(0, static_cast<int>(output->info()->tensor_shape().num_dimensions()) - 1);
- const int num_input_dimensions = input->info()->tensor_shape().num_dimensions() - num_batch_dimensions;
- const size_t linear_input_size = input->info()->tensor_shape().total_size_lower(num_input_dimensions);
+ _are_weights_converted = true;
+ _are_weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
+ _is_fc_after_conv = true;
+ _accumulate_biases = false;
+ _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
+ _original_weights = weights;
- _original_weights = weights;
- _linearize_input = (input->info()->tensor_shape().x() != linear_input_size) || (num_input_dimensions > 1 && linear_input_size == 1);
- _are_weights_reshaped = are_weights_reshaped;
- _accumulate_biases = biases != nullptr;
- _is_batched_fc_layer = num_batch_dimensions > 0;
-
- const size_t interleave_width = 16 / input->info()->element_size();
- const ITensor *weights_to_use = weights;
-
- if(!are_weights_reshaped && (transpose_weights || _is_batched_fc_layer))
+ // Configure gemmlowp output
+ if(_is_quantized)
{
- weights_to_use = &_reshape_weights_output;
-
- _reshape_weights_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_fully_connected_reshaped_weights_shape(weights->info(),
- transpose_weights,
- _is_batched_fc_layer, interleave_width)));
-
- // Reshape the weights
- _reshape_weights_kernel.configure(weights, &_reshape_weights_output, transpose_weights, _is_batched_fc_layer);
+ _gemmlowp_output.allocator()->init(output->info()->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
}
- const ITensor *multiply_input = input;
-
- if(_linearize_input)
+ // Configure accumulate biases kernel for non quantized asymmetric types
+ if(biases != nullptr && !_is_quantized)
{
- _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input->info(), num_input_dimensions)));
+ _accumulate_biases = true;
- // Configure im2col kernel
- _memory_group.manage(&_im2col_output);
- _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false, true);
-
- multiply_input = &_im2col_output;
- }
-
- int m = multiply_input->info()->dimension(1);
- int k = multiply_input->info()->dimension(0);
-
- if(_is_batched_fc_layer)
- {
- _interleave4x4_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_interleaved_shape(*multiply_input->info())));
-
- // Configure interleave4x4 kernel
- _memory_group.manage(&_interleave4x4_output);
- _interleave4x4_kernel.configure(multiply_input, &_interleave4x4_output);
-
- multiply_input = &_interleave4x4_output;
- }
-
- // Configure matrix multiply kernel
- _mm_kernel.configure(multiply_input, weights_to_use, output, 1.0f, _is_batched_fc_layer, GEMMReshapeInfo(m, 0 /* no transpose */, k));
-
- if(_accumulate_biases)
- {
// Configure accumulate biases kernel
_accumulate_biases_kernel.configure(output, biases);
}
- // Allocate the transpose tensor if the are_weights_reshaped flag is false and once all the configure methods have been called
- if(!are_weights_reshaped && (transpose_weights || _is_batched_fc_layer))
- {
- // Allocate the tensor for the weights reshaped
- _reshape_weights_output.allocator()->allocate();
- }
+ // With the Fully Connected layer we can have 4 different cases:
+ // 1) Convolution layer -> Fully Connected layer without batches
+ // 2) Fully Connected layer -> Fully Connected layer without batches
+ // 3) Convolution layer -> Fully Connected layer with batches
+ // 4) Fully Connected layer -> Fully Connected layer with batches
- if(_linearize_input)
- {
- _im2col_output.allocator()->allocate();
- }
+ const ITensor *weights_to_use = weights;
- if(_is_batched_fc_layer)
- {
- _interleave4x4_output.allocator()->allocate();
- }
-}
-
-Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, bool transpose_weights, bool are_weights_reshaped)
-{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, weights, output);
-
- const int num_batch_dimensions = std::max(0, static_cast<int>(output->tensor_shape().num_dimensions()) - 1);
- const int num_input_dimensions = input->tensor_shape().num_dimensions() - num_batch_dimensions;
- const size_t linear_input_size = input->tensor_shape().total_size_lower(num_input_dimensions);
-
- const bool linearize_input = (input->tensor_shape().x() != linear_input_size) || (num_input_dimensions > 1 && linear_input_size == 1);
- const bool accumulate_biases = biases != nullptr;
- const bool is_batched_fc_layer = num_batch_dimensions > 0;
-
- ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape().total_size_upper(num_input_dimensions) != output->tensor_shape().total_size_upper(1));
- ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
-
- const size_t interleave_width = 16 / input->element_size();
- const ITensorInfo *weights_to_use = weights;
- std::unique_ptr<ITensorInfo> reshape_weights_output = input->clone();
-
- if(!are_weights_reshaped && (transpose_weights || is_batched_fc_layer))
- {
- reshape_weights_output->set_tensor_shape(compute_fully_connected_reshaped_weights_shape(weights, transpose_weights, is_batched_fc_layer, interleave_width));
-
- ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayerReshapeWeights::validate(weights, reshape_weights_output.get(), transpose_weights, is_batched_fc_layer));
-
- weights_to_use = reshape_weights_output.get();
- }
-
- // Check correct shape of weights
+ // Check if we have a fully connected layer with batches
+ const bool is_batched_fc_layer = output->info()->dimension(1) > 1;
if(is_batched_fc_layer)
{
- // Transpose + Transpose1xW
- ARM_COMPUTE_RETURN_ERROR_ON(weights_to_use->tensor_shape().x() != linear_input_size * interleave_width);
- ARM_COMPUTE_RETURN_ERROR_ON(weights_to_use->tensor_shape().y() != static_cast<unsigned int>(std::ceil(static_cast<float>(output->tensor_shape().x()) / interleave_width)));
+ _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->info()->tensor_shape().cbegin() + 3,
+ input->info()->tensor_shape().cend(),
+ output->info()->tensor_shape().cbegin() + 1));
}
else
{
- // Transpose
- ARM_COMPUTE_RETURN_ERROR_ON(weights_to_use->tensor_shape().x() != output->tensor_shape().x());
- ARM_COMPUTE_RETURN_ERROR_ON(weights_to_use->tensor_shape().y() != linear_input_size);
+ _is_fc_after_conv = input->info()->num_dimensions() > 1;
}
- const ITensorInfo *multiply_input = input;
- std::unique_ptr<ITensorInfo> im2col_output = input->clone();
- std::unique_ptr<ITensorInfo> interleave4x4_output = input->clone();
-
- if(linearize_input)
+ // Reshape weights if needed
+ if(!_are_weights_reshaped)
{
- im2col_output->set_tensor_shape(compute_im2col_fc_shape(input, num_input_dimensions));
-
- ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, im2col_output.get(), Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false, true));
-
- multiply_input = im2col_output.get();
+ // Reshape the weights
+ _reshape_weights_function.configure(weights, &_reshape_weights_output);
+ weights_to_use = &_reshape_weights_output;
}
- int m = multiply_input->dimension(1);
- int k = multiply_input->dimension(0);
+ // Convert weights if needed
+ if(_is_fc_after_conv && (input->info()->data_layout() != fc_info.weights_trained_layout))
+ {
+ // Convert weights
+ _convert_weights.configure(weights_to_use,
+ &_converted_weights_output,
+ input->info()->tensor_shape(),
+ fc_info.weights_trained_layout);
+
+ weights_to_use = &_converted_weights_output;
+ _are_weights_converted = false;
+ }
+
+ ITensor *tmp_output = (_is_quantized) ? &_gemmlowp_output : output;
+ if(_is_fc_after_conv)
+ {
+ // Fully Connected layer after a Convolution Layer without batches
+ configure_conv_fc(input, weights_to_use, tmp_output);
+ }
+ else
+ {
+ // Fully Connected layer after a Fully Connected Layer without batches
+ configure_fc_fc(input, weights_to_use, tmp_output);
+ }
+
+ // Configure output stage for asymmetric quantized types
+ if(_is_quantized)
+ {
+ float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output->info()->quantization_info().scale;
+ int output_multiplier, output_shift;
+ quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+ _gemmlowp_output_stage.configure(&_gemmlowp_output, biases, output, output_multiplier, output_shift, output->info()->quantization_info().offset);
+ _gemmlowp_output.allocator()->allocate();
+ }
+
+ _are_weights_reshaped = _are_weights_reshaped || fc_info.retain_internal_weights;
+}
+
+Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+ FullyConnectedLayerInfo fc_info)
+{
+ ARM_COMPUTE_UNUSED(fc_info.retain_internal_weights);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
+ ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
+
+ bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
+ bool is_fc_after_conv = true;
+ bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
+
+ const ITensorInfo &im2col_input = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_flatten_shape(input)));
+ const ITensorInfo &reshaped_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
+ const ITensorInfo &converted_weights = weights_reshaped ? TensorInfo(weights->clone()->set_is_resizable(true).reset_padding()) : TensorInfo(*reshaped_weights.clone());
+ const ITensorInfo &gemmlowp_output = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
+
+ // Configure accumulate biases kernel for non quantized asymmetric types
+ if(biases != nullptr && !is_quantized)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixAccumulateBiasesKernel::validate(output, biases));
+ }
+
+ // With the Fully Connected layer we can have 4 different cases:
+ // 1) Convolution layer -> Fully Connected layer without batches
+ // 2) Fully Connected layer -> Fully Connected layer without batches
+ // 3) Convolution layer -> Fully Connected layer with batches
+ // 4) Fully Connected layer -> Fully Connected layer with batches
+
+ const ITensorInfo *input_to_use = input;
+ const ITensorInfo *weights_to_use = weights;
+ const ITensorInfo *tmp_output = (is_quantized) ? &gemmlowp_output : output;
+
+ // Check if we have a fully connected layer with batches
+ const bool is_batched_fc_layer = output->dimension(1) > 1;
if(is_batched_fc_layer)
{
- interleave4x4_output->set_tensor_shape(compute_interleaved_shape(*multiply_input));
-
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(multiply_input, interleave4x4_output.get()));
-
- multiply_input = interleave4x4_output.get();
+ is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->tensor_shape().cbegin() + 3,
+ input->tensor_shape().cend(),
+ output->tensor_shape().cbegin() + 1));
+ }
+ else
+ {
+ is_fc_after_conv = input->num_dimensions() > 1;
}
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(multiply_input, weights_to_use, output, 1.0f, is_batched_fc_layer, GEMMReshapeInfo(m, 0 /* no transpose */, k)));
-
- if(accumulate_biases)
+ if(!weights_reshaped)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
- ARM_COMPUTE_RETURN_ERROR_ON(biases->tensor_shape().x() != output->tensor_shape().x());
+ // Validate reshape weights kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayerReshapeWeights::validate(weights, &reshaped_weights));
+ weights_to_use = &reshaped_weights;
+ }
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixAccumulateBiasesKernel::validate(output, biases));
+ if(is_fc_after_conv && (input->data_layout() != fc_info.weights_trained_layout))
+ {
+ // Validate convert weights kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(NEConvertFullyConnectedWeights::validate(weights_to_use,
+ &converted_weights,
+ input->tensor_shape(),
+ fc_info.weights_trained_layout));
+ weights_to_use = &converted_weights;
+ }
+
+ if(is_fc_after_conv)
+ {
+ // Fully Connected layer after a Convolution Layer without batches
+ ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(1) != (input->dimension(0) * input->dimension(1) * input->dimension(2))));
+
+ // Validate im2col kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &im2col_input, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false, Size2D(1U, 1U), 1, true));
+ input_to_use = &im2col_input;
+ }
+ else
+ {
+ // Fully Connected layer after a Fully Connected Layer without batches
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) != weights_to_use->dimension(1));
+ }
+ // Validate matrix multiply kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(*input_to_use, *weights_to_use, *tmp_output));
+
+ // Validate output stage for asymmetric quantized types
+ if(is_quantized)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(&gemmlowp_output, biases, output));
}
return Status{};
@@ -322,40 +330,94 @@
void NEFullyConnectedLayer::run()
{
- // Reshape of the weights (happens only once)
- if(!_are_weights_reshaped)
- {
- ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
-
- _are_weights_reshaped = true;
- _reshape_weights_kernel.run();
-
- // Mark original weights tensor as unused
- _original_weights->mark_as_unused();
- }
+ prepare();
_memory_group.acquire();
// Linearize input if it comes from a convolutional layer
- if(_linearize_input)
+ if(_is_fc_after_conv)
{
NEScheduler::get().schedule(&_im2col_kernel, Window::DimY);
}
- // Interleave input
- if(_is_batched_fc_layer)
+ // Run matrix multiply
+ if(_is_quantized)
{
- NEScheduler::get().schedule(&_interleave4x4_kernel, Window::DimY);
+ _mm_gemmlowp.run();
+ }
+ else
+ {
+ _mm_gemm.run();
}
- // Run matrix multiply
- NEScheduler::get().schedule(&_mm_kernel, _is_batched_fc_layer ? Window::DimY : Window::DimX);
-
// Accumulate biases if provided
- if(_accumulate_biases)
+ if(_is_quantized)
{
- NEScheduler::get().schedule(&_accumulate_biases_kernel, Window::DimY);
+ _gemmlowp_output_stage.run();
+ }
+ else
+ {
+ if(_accumulate_biases)
+ {
+ NEScheduler::get().schedule(&_accumulate_biases_kernel, Window::DimY);
+ }
}
_memory_group.release();
}
+
+void NEFullyConnectedLayer::prepare()
+{
+ if(!_is_prepared)
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+
+ auto release_unused = [](Tensor * w)
+ {
+ if(!w->is_used())
+ {
+ w->allocator()->free();
+ }
+ };
+
+ // Pointer to current weights
+ const ITensor *cur_weights = _original_weights;
+
+ // Reshape of the weights (happens only once)
+ if(!_are_weights_reshaped)
+ {
+ // Run reshape weights kernel and mark weights as unused
+ _reshape_weights_output.allocator()->allocate();
+ _reshape_weights_function.run();
+
+ cur_weights->mark_as_unused();
+ cur_weights = &_reshape_weights_output;
+ _are_weights_reshaped = true;
+ }
+
+ // Convert weights if needed (happens only once)
+ if(!_are_weights_converted)
+ {
+ _converted_weights_output.allocator()->allocate();
+ _convert_weights.run();
+
+ cur_weights->mark_as_unused();
+ _are_weights_converted = true;
+ }
+
+ // Release reshaped weights if unused
+ release_unused(&_reshape_weights_output);
+
+ // Prepare GEMM prepare and release unused weights
+ if(!_is_quantized)
+ {
+ _mm_gemm.prepare();
+ }
+
+ // Release converted weights if unused
+ release_unused(&_reshape_weights_output);
+ release_unused(&_converted_weights_output);
+
+ _is_prepared = true;
+ }
+}
\ No newline at end of file
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 9168ed4..de51266 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -23,72 +23,56 @@
*/
#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
-#include "arm_compute/runtime/NEON/AssemblyHelper.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
#include "arm_compute/runtime/TensorAllocator.h"
#include "support/ToolchainSupport.h"
#include <cmath>
+using namespace arm_compute::misc::shape_calculator;
+
namespace arm_compute
{
NEGEMM::NEGEMM(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(), _ma_kernel(), _tmp_a(), _tmp_b(), _workspace(), _B_pretransposed(),
- _run_vector_matrix_multiplication(false), _run_addition(false), _is_first_run(true), _reshape_b_only_on_first_run(false)
+ : _memory_group(memory_manager), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(memory_manager), _ma_kernel(), _tmp_a(), _tmp_b(), _original_b(nullptr),
+ _run_vector_matrix_multiplication(false), _run_addition(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
{
}
void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::F16, DataType::QS8, DataType::QS16);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(a, b, d);
- ARM_COMPUTE_ERROR_ON_MSG(a->info()->dimension(0) != b->info()->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
- ARM_COMPUTE_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
- ARM_COMPUTE_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
-
- if(c != nullptr)
- {
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(c, 1, DataType::F32, DataType::F16, DataType::QS8, DataType::QS16);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(a, c);
- ARM_COMPUTE_ERROR_ON_MSG(a->info()->dimension(1) != c->info()->dimension(1), "The C matrix must have the same number of rows as the matrix A");
- ARM_COMPUTE_ERROR_ON_MSG(b->info()->dimension(0) != c->info()->dimension(0), "The C matrix must have the same number of columns as the matrix B");
- ARM_COMPUTE_ERROR_ON_MSG(c->info()->dimension(0) != d->info()->dimension(0), "The C matrix must have the same number of rows as the output matrix");
- ARM_COMPUTE_ERROR_ON_MSG(c->info()->dimension(1) != d->info()->dimension(1), "The C matrix must have the same number of columns as the output matrix");
- }
+ ARM_COMPUTE_ERROR_THROW_ON(NEGEMM::validate(a->info(), b->info(), (c != nullptr) ? c->info() : nullptr, d->info(), alpha, beta, gemm_info));
// Check if we need to reshape the matrix B only on the first run
+ _is_prepared = false;
_reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
_run_vector_matrix_multiplication = a->info()->dimension(1) < 2;
+ _original_b = b;
- const bool run_optimised = a->info()->data_type() == DataType::F32 && (c == nullptr || beta == 0.f)
- && setup_assembly_kernel(a, b, d, alpha, beta, _reshape_b_only_on_first_run, _workspace, _B_pretransposed, _memory_group, _asm_glue);
+ bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, _reshape_b_only_on_first_run));
- // Check if the first input tensor is a vector.
- // If so, all the kernels for reshaping the tensors can be skipped
- if(_run_vector_matrix_multiplication)
+ if(run_optimised)
{
- if(!run_optimised)
+ _asm_glue.configure(a, b, d, alpha, beta, _reshape_b_only_on_first_run);
+ ARM_COMPUTE_ERROR_ON(!_asm_glue.is_configured());
+ }
+ else
+ {
+ if(_run_vector_matrix_multiplication)
{
// Configure the matrix multiply kernel
_mm_kernel.configure(a, b, d, alpha, false);
}
-
- // Configure matrix addition kernel
- if(beta != 0 && c != nullptr)
- {
- _ma_kernel.configure(c, d, beta);
- _run_addition = true;
- }
- }
- else
- {
- if(!run_optimised)
+ else
{
TensorShape shape_tmp_a = a->info()->tensor_shape();
TensorShape shape_tmp_b = b->info()->tensor_shape();
@@ -100,8 +84,8 @@
shape_tmp_b.set(0, b->info()->dimension(1) * transpose_w);
shape_tmp_b.set(1, std::ceil(b->info()->dimension(0) / static_cast<float>(transpose_w)));
- TensorInfo info_a(shape_tmp_a, 1, a->info()->data_type(), a->info()->fixed_point_position());
- TensorInfo info_b(shape_tmp_b, 1, b->info()->data_type(), a->info()->fixed_point_position());
+ TensorInfo info_a(shape_tmp_a, 1, a->info()->data_type());
+ TensorInfo info_b(shape_tmp_b, 1, b->info()->data_type());
_tmp_a.allocator()->init(info_a);
_tmp_b.allocator()->init(info_b);
@@ -128,42 +112,135 @@
// Allocate once the all configure methods have been called
_tmp_a.allocator()->allocate();
- _tmp_b.allocator()->allocate();
-
- // Configure matrix addition kernel
- if(beta != 0 && c != nullptr)
+ if(!_reshape_b_only_on_first_run)
{
- _ma_kernel.configure(c, d, beta);
- _run_addition = true;
+ _tmp_b.allocator()->allocate();
}
}
+
+ // Configure matrix addition kernel
+ if(beta != 0 && c != nullptr)
+ {
+ _ma_kernel.configure(c, d, beta);
+ _run_addition = true;
+ }
}
}
+Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info) // Static counterpart of configure(): returns an error Status for the first requirement that fails
+{
+ ARM_COMPUTE_UNUSED(alpha); // NOTE(review): alpha is in fact used below, so this marker looks redundant
+ // A must be F32 or F16 (F16 only where the CPU supports it); B and output must share A's data type
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::F16);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
+ // Optional C: no 3D reinterpretation, same data type as A, A's number of rows and B's number of columns
+ if(c != nullptr)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d());
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, c);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->dimension(1), "The C matrix must have the same number of rows as the matrix A");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->dimension(0), "The C matrix must have the same number of columns as the matrix B");
+ }
+ // The output shape is only checked once the output info has been initialised (total_size() != 0)
+ if(output->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != output->dimension(0));
+ if(gemm_info.depth_output_gemm3d() != 1)
+ {
+ if(gemm_info.reinterpret_input_as_3d())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(2) != output->dimension(2));
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1) * output->dimension(2));
+ }
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1));
+ }
+ }
+
+ // The optimized assembly path is only considered when there is no C (same rule as configure())
+ const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, output, alpha, beta, true));
+ // Otherwise validate the generic interleave / transpose / matrix-multiply kernel chain
+ if(!run_optimised)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.reinterpret_input_as_3d(), "NEGEMM cannot reinterpret the input tensor as 3D");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 1, "NEGEMM cannot reinterpret the output tensor as 3D");
+
+ // A with a single row (vector x matrix product) needs no reshaping
+ const bool run_vector_matrix_multiplication = a->dimension(1) < 2;
+ // NOTE(review): configure() reshapes A and B whenever A is not a vector, even with reshape_b_only_on_first_run; confirm that validating the un-reshaped path in that case is intended
+ const bool run_interleave_transpose = !run_vector_matrix_multiplication && !(gemm_info.reshape_b_only_on_first_run());
+
+ // Arguments used by GEMMReshapeInfo
+ // If we pass the matrix A and matrix B reshaped to NEGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to GEMMReshapeInfo
+ // in order to know how the matrices have been reshaped
+ const int m = a->dimension(1);
+ const int n = b->dimension(0);
+ const int k = a->dimension(0);
+ int mult_transpose1xW_width = 1;
+ int mult_interleave4x4_height = 1;
+
+ const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());
+
+ const ITensorInfo *matrix_a_info = a;
+ const ITensorInfo *matrix_b_info = b;
+
+ TensorInfo tmp_a_info{};
+ TensorInfo tmp_b_info{};
+ TensorInfo tmp_output_info = *output->clone(); // local copy so the auto-initialisation below never touches the caller's (const) info
+ // When reshaping, the multiply kernel is validated against the interleaved A / transposed B infos
+ if(run_interleave_transpose)
+ {
+ matrix_a_info = &tmp_a_info;
+ matrix_b_info = &tmp_b_info;
+
+ // Validate interleave kernel
+ auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(a, &tmp_a_info));
+
+ // Validate transpose kernel
+ auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(b, &tmp_b_info));
+ }
+
+ // Validate matrix multiply
+ auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info));
+ }
+
+ return Status{};
+}
+
void NEGEMM::run()
{
- _memory_group.acquire();
+ prepare();
- if(_asm_glue._optimised_kernel != nullptr)
+ if(_asm_glue.is_configured())
{
+ _memory_group.acquire();
_asm_glue.run();
_memory_group.release();
}
else
{
+ _memory_group.acquire();
+
if(!_run_vector_matrix_multiplication)
{
// Run interleave kernel
NEScheduler::get().schedule(&_interleave_kernel, Window::DimY);
- if(_is_first_run)
- {
- // Run transpose kernel
- NEScheduler::get().schedule(&_transpose_kernel, Window::DimY);
-
- _is_first_run = false;
- }
- else if(!_reshape_b_only_on_first_run)
+ if(!_reshape_b_only_on_first_run)
{
// Run transpose kernel
NEScheduler::get().schedule(&_transpose_kernel, Window::DimY);
@@ -181,4 +258,27 @@
}
}
}
+
+void NEGEMM::prepare() // One-off work done before the first run(): the assembly dispatch's preparation or the one-time transpose of B
+{
+ if(!_is_prepared)
+ {
+ if(_asm_glue.is_configured())
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+ // Delegate to the assembly dispatch (which may pretranspose B and mark it as unused)
+ _asm_glue.prepare();
+ }
+ else if(_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication) // assembly path already excluded by the branch above
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+ // configure() leaves _tmp_b unallocated in this mode: allocate it and transpose B exactly once
+ _tmp_b.allocator()->allocate();
+ NEScheduler::get().schedule(&_transpose_kernel, Window::DimY);
+ _original_b->mark_as_unused(); // the transposed copy in _tmp_b is used from now on
+ }
+ // Later calls are no-ops
+ _is_prepared = true;
+ }
+}
} // namespace arm_compute
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
new file mode 100644
index 0000000..29db654
--- /dev/null
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -0,0 +1,448 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
+
+#include "arm_compute/core/CPP/Validate.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
+#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h"
+
+#include <arm_neon.h>
+
+namespace arm_compute
+{
+namespace
+{
+std::unique_ptr<IFunction> create_function_all_types(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager)
+// Data-type agnostic factory: returns an ACL function implementing 'method', or nullptr if ACL has none
+{
+ //Note: It's safe to not check for FP16 support because this was already checked by NEGEMMAssemblyDispatch::validate(), which NEGEMMAssemblyDispatch::configure() calls first
+ switch(method)
+ {
+ case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
+ {
+ if(!pretranspose_hint)
+ {
+ return nullptr; // the interleaved wrapper is only used when pretranspose_hint is set
+ }
+ auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+ function->configure(a, b, d, alpha, beta, pretranspose_hint);
+ return std::move(function); // converted to std::unique_ptr<IFunction> on return
+ }
+ default:
+ return nullptr; // no data-type agnostic ACL function for this method: the caller tries other options
+ }
+}
+
+template <typename TypeInput, typename TypeOutput> // Type-specific factory; this primary template provides no ACL function for any method
+std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager)
+{
+ ARM_COMPUTE_UNUSED(method);
+ ARM_COMPUTE_UNUSED(a);
+ ARM_COMPUTE_UNUSED(b);
+ ARM_COMPUTE_UNUSED(d);
+ ARM_COMPUTE_UNUSED(alpha);
+ ARM_COMPUTE_UNUSED(beta);
+ ARM_COMPUTE_UNUSED(pretranspose_hint);
+ ARM_COMPUTE_UNUSED(memory_manager);
+ return nullptr; // the caller then falls back to running the arm_gemm kernel directly
+}
+
+#ifdef __aarch64__
+template <> // S8 input / S32 output: only the dot-product interleaved method has an ACL function
+std::unique_ptr<IFunction> create_function<int8_t, int32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager)
+{
+ switch(method)
+ {
+ case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
+ {
+ if(!pretranspose_hint)
+ {
+ return nullptr; // the interleaved wrapper is only used when pretranspose_hint is set
+ }
+ auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+ function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
+ return std::move(function);
+ }
+ default:
+ break; // no ACL function for this method
+ }
+ return nullptr; // single exit for unsupported methods: the caller falls back to arm_gemm
+}
+
+template <> // U8/QASYMM8 input / U32 output: only the dot-product interleaved method has an ACL function
+std::unique_ptr<IFunction> create_function<uint8_t, uint32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager)
+{
+ switch(method)
+ {
+ case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
+ {
+ if(!pretranspose_hint)
+ {
+ return nullptr; // the interleaved wrapper is only used when pretranspose_hint is set
+ }
+ auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+ function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
+ return std::move(function);
+ }
+ default:
+ break; // no ACL function for this method
+ }
+ return nullptr; // single exit for unsupported methods: the caller falls back to arm_gemm
+}
+
+template <> // F32: only the native method has an ACL function (a single kernel wrapped in NESimpleAssemblyFunction)
+std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager)
+{
+ ARM_COMPUTE_UNUSED(pretranspose_hint);
+ ARM_COMPUTE_UNUSED(memory_manager);
+ switch(method)
+ {
+ case arm_gemm::GemmMethod::GEMM_NATIVE:
+ {
+ auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
+ kernel->configure(a, b, d, alpha, beta);
+ auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
+ function->configure(std::move(kernel)); // the kernel is moved into the function
+ return std::move(function);
+ }
+ default:
+ return nullptr; // the caller falls back to running arm_gemm directly
+ }
+}
+#endif /* __aarch64__ */
+
+/** Fallback used when ACL has no function for the selected method: runs the arm_gemm kernel directly through an ACL kernel wrapper */
+template <typename TypeInput, typename TypeOutput>
+class Fallback : public NEGEMMAssemblyDispatch::IFallback
+{
+public:
+ void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group); // leaves the object unconfigured if arm_gemm returns no kernel for args
+ void run() override; // prepares on first use, binds the tensors and schedules the wrapped kernel
+ void prepare() override; // one-off: pretransposes B if the kernel requires it
+ bool is_configured() const override; // true once configure() obtained a kernel
+
+private:
+ /** Allocate a workspace tensor.
+ *
+ * @param[in] workspace_size Size in bytes requested by the kernel (alignment bytes of slack are added).
+ * @param[in] memory_group Tensor memory group that manages the workspace.
+ * @param[in] alignment Workspace memory alignment.
+ */
+ void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);
+
+ /** Assembly Gemm kernel */
+ std::unique_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
+ /** ACL kernel wrapping _gemm_kernel_asm; non-null only after a successful configure() */
+ std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
+ /** Input A */
+ const ITensor *_a
+ {
+ nullptr
+ };
+ /** Input B */
+ const ITensor *_b
+ {
+ nullptr
+ };
+ /** Output */
+ ITensor *_d{ nullptr };
+ /** Working space requested by the arm_gemm kernel (managed by the memory group) */
+ Tensor _workspace{};
+ /** Buffer receiving the pretransposed B; allocated only when the kernel requires it */
+ Tensor _pretranspose{};
+ /** True once prepare() has run */
+ bool _is_prepared{ false };
+};
+
+template <typename TypeInput, typename TypeOutput> // Creates the arm_gemm kernel for args and wraps it; on failure the object stays unconfigured
+void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
+{
+ _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args, nullptr);
+ if(_gemm_kernel_asm == nullptr)
+ {
+ // Configuration not supported: leave the function unconfigured (is_configured() stays false)
+ return;
+ }
+
+ // Wrap the arm_gemm object created above in an ACL kernel so that NEScheduler can run it
+ std::unique_ptr<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>> acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>>();
+ ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
+ acl_gemm_wrapper->configure(_gemm_kernel_asm.get());
+ const size_t workspace_size = _gemm_kernel_asm->get_working_size();
+ if(workspace_size > 0)
+ {
+ // Allocate the working space requested by arm_gemm; its memory is managed by memory_group
+ const unsigned int alignment = 4096;
+ allocate_workspace(workspace_size, memory_group, alignment);
+ }
+
+ // Do not remove the block below: without capping the thread count to the window size, ConvLayer deadlocks when threads > 1 and
+ // the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
+ {
+ const int window_size = _gemm_kernel_asm->get_window_size();
+ if(window_size < args._maxthreads)
+ {
+ _gemm_kernel_asm->set_nthreads(window_size);
+ }
+ }
+
+ _optimised_kernel = std::move(acl_gemm_wrapper); // from here on is_configured() returns true
+ _a = a;
+ _b = b;
+ _d = d;
+ // If the kernel needs B pretransposed, allocate its buffer now (outside the memory group); prepare() fills it
+ if(_gemm_kernel_asm->B_pretranspose_required())
+ {
+ // Forcing 128-byte alignment (required by 32-bit kernels)
+ const unsigned int alignment = 128;
+ const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
+ _pretranspose.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment) }, 1, DataType::S8), alignment);
+ _pretranspose.allocator()->allocate();
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_pretranspose.buffer());
+ }
+}
+
+template <typename TypeInput, typename TypeOutput> // One-off preparation; later calls are no-ops
+void Fallback<TypeInput, TypeOutput>::prepare()
+{
+ if(!_is_prepared)
+ {
+ // Pretranspose B into _pretranspose if the kernel requires it, then mark B as no longer needed
+ if(_gemm_kernel_asm->B_pretranspose_required())
+ {
+ ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
+ const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); // row stride of B in elements
+ const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
+ const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput); // stride of B along dimension 2, in elements
+
+ _gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b);
+ _b->mark_as_unused();
+ }
+
+ _is_prepared = true;
+ }
+}
+
+template <typename TypeInput, typename TypeOutput> // Initialises _workspace and hands it to memory_group
+void Fallback<TypeInput, TypeOutput>::allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment)
+{
+ ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
+ _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment) }, 1, DataType::S8), alignment); // S8 = 1 byte per element, plus 'alignment' bytes of slack
+ memory_group.manage(&_workspace); // backing memory is provided by the memory group (acquired around NEGEMMAssemblyDispatch::run())
+ _workspace.allocator()->allocate();
+}
+
+template <typename TypeInput, typename TypeOutput>
+bool Fallback<TypeInput, TypeOutput>::is_configured() const
+{
+ return _optimised_kernel != nullptr; // set by configure() only after arm_gemm returned a kernel
+}
+
+template <typename TypeInput, typename TypeOutput> // Binds A/B/D pointers and strides to the arm_gemm kernel, then schedules the wrapped kernel
+void Fallback<TypeInput, TypeOutput>::run()
+{
+ const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput); // leading dimensions and strides passed to arm_gemm are in elements, not bytes
+ int ldb = 0;
+ const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
+
+ // For NHWC the output is interpreted as 3D, so the batch stride of A is
+ // A's row stride multiplied by the output's dimension 1.
+ const bool is_nhwc = _a->info()->data_layout() == DataLayout::NHWC;
+ const int stride_in_bytes_a = is_nhwc ? _a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z();
+
+ const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput);
+ const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);
+
+ const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput); // multi strides of A and D are taken along dimension 3
+ int multi_stride_b = 0;
+ const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);
+
+ const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
+ const TypeInput *in1_ptr = nullptr; // stays null (with ldb = multi_stride_b = 0) when B is already pretransposed
+ auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes());
+
+ // Pass B's pointer and strides only if B is not pre-transposed. NOTE(review): prepare() is called after this check, so on the first run this branch is presumably taken even when B gets pretransposed below
+ if(!_gemm_kernel_asm->B_is_pretransposed())
+ {
+ ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
+ }
+
+ // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
+ if(_workspace.buffer() != nullptr)
+ {
+ _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace.buffer()));
+ const unsigned int window_size = _gemm_kernel_asm->get_window_size();
+ unsigned int num_threads = NEScheduler::get().num_threads();
+ if(window_size < num_threads)
+ {
+ num_threads = window_size;
+ _gemm_kernel_asm->set_nthreads(num_threads);
+ }
+ }
+
+ // Pretranspose B on the first call (no-op afterwards)
+ prepare();
+
+ // Set gemm parameters
+ _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);
+
+ // Schedule assembly kernel
+ NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
+}
+
+template <typename TypeInput, typename TypeOutput> // Sets either acl_function or arm_gemm (the arm_gemm fallback may still end up unconfigured)
+void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
+ ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr<IMemoryManager> memory_manager)
+{
+ const INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d);
+ const CPUInfo &ci = NEScheduler::get().cpu_info();
+ const unsigned int num_threads = NEScheduler::get().num_threads();
+
+ arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+
+ const arm_gemm::GemmMethod method = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args); // query arm_gemm's preferred method once and reuse it for both factories
+ acl_function = create_function_all_types(method, a, b, d, alpha, beta, pretranspose_hint, memory_manager); // try the data-type agnostic ACL factory first
+ // If the type agnostic factory failed to create an ACL function, try the specialised one:
+ if(acl_function == nullptr)
+ {
+ acl_function = create_function<TypeInput, TypeOutput>(method, a, b, d, alpha, beta, pretranspose_hint, memory_manager);
+ }
+ // If we still don't have an ACL function, fall back to running the arm_gemm kernel directly
+ if(acl_function == nullptr)
+ {
+ // The fallback stays unconfigured if arm_gemm has no kernel for these arguments
+ auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
+ fallback->configure(a, b, d, args, memory_group);
+ arm_gemm = std::move(fallback);
+ }
+}
+
+} //namespace
+
+NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager) // memory_manager backs the internal memory group and is forwarded to any ACL function created in configure()
+ : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager)
+{
+}
+
+Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint) // Checks data types only (no shapes): success does not guarantee that configure() finds a kernel
+{
+ ARM_COMPUTE_UNUSED(alpha); // alpha, beta and pretranspose_hint do not influence this check
+ ARM_COMPUTE_UNUSED(beta);
+ ARM_COMPUTE_UNUSED(pretranspose_hint);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
+#ifndef __aarch64__
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 || a->data_type() == DataType::S8 || a->data_type() == DataType::QASYMM8, "8bit integer types only supported for aarch64");
+#endif /* __aarch64__ */
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::U8, DataType::QASYMM8, DataType::S8, DataType::F16);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, "Only F32 output supported for F32 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::S32 && d->data_type() != DataType::U32, "Only U32/S32 output supported for QASYMM8 input"); // configure() runs QASYMM8 through the uint8_t -> uint32_t path even for an S32 output
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
+ return Status{};
+}
+
+void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint) // Selects an ACL function or an arm_gemm fallback from A's data type; check is_configured() afterwards
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(a);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(b);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(d);
+
+ //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
+ if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, pretranspose_hint))
+ {
+ return;
+ }
+ // Instantiate the implementation matching A's data type
+ switch(a->info()->data_type())
+ {
+ case DataType::F32:
+ create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+ break;
+#ifdef __aarch64__
+ case DataType::U8:
+ case DataType::QASYMM8:
+ create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+ break;
+ case DataType::S8:
+ create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+ break;
+#endif /* __aarch64__ */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+ break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ default: // e.g. F16 when built without __ARM_FEATURE_FP16_VECTOR_ARITHMETIC: the dispatch stays unconfigured
+ break;
+ }
+}
+
+void NEGEMMAssemblyDispatch::prepare() // Forwards to whichever backend configure() selected
+{
+ if(_function != nullptr)
+ {
+ _function->prepare();
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); // NOTE(review): callers are expected to check is_configured() before calling prepare()
+ _arm_gemm->prepare();
+ }
+}
+
+bool NEGEMMAssemblyDispatch::is_configured() const
+{
+ return (_arm_gemm != nullptr && _arm_gemm->is_configured()) || _function != nullptr; // either an ACL function or an arm_gemm fallback that obtained a kernel
+}
+
+void NEGEMMAssemblyDispatch::run() // Runs the selected backend while the memory group's memory (e.g. the fallback workspace) is acquired
+{
+ _memory_group.acquire();
+ if(_function != nullptr)
+ {
+ _function->run();
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
+ _arm_gemm->run();
+ }
+ _memory_group.release();
+}
+} //namespace arm_compute
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index 2888b43..92e641e 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -23,10 +23,10 @@
*/
#include "arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h"
-#include "arm_compute/core/PixelValue.h"
#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "support/ToolchainSupport.h"
@@ -34,98 +34,50 @@
#include <cmath>
#include <tuple>
-namespace
-{
-arm_compute::TensorShape get_reshaped_weights_shape(const arm_compute::ITensorInfo *weights, bool append_bias)
-{
- const unsigned int mat_weights_cols = weights->dimension(3);
- const unsigned int mat_weights_rows = weights->dimension(0) * weights->dimension(1) * weights->dimension(2) + (append_bias ? 1 : 0);
- return arm_compute::TensorShape(mat_weights_cols, mat_weights_rows);
-}
-} // namespace
+using namespace arm_compute;
+using namespace arm_compute::misc::shape_calculator;
-namespace arm_compute
-{
-NEConvolutionLayerReshapeWeights::NEConvolutionLayerReshapeWeights(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _weights_reshape_kernel(), _weights_transposed_kernel(), _weights_reshaped(), _transpose1xW(false)
+NEConvolutionLayerReshapeWeights::NEConvolutionLayerReshapeWeights() // no memory manager needed: no intermediate tensor is used any more
+ : _weights_reshape_kernel()
{
}
-void NEConvolutionLayerReshapeWeights::configure(const ITensor *weights, const ITensor *biases, ITensor *output, bool transpose1xW)
+void NEConvolutionLayerReshapeWeights::configure(const ITensor *weights, const ITensor *biases, ITensor *output) // Reshapes weights (plus non-quantized biases) into output with a single NEWeightsReshapeKernel
{
// Perform validation step
ARM_COMPUTE_ERROR_ON_NULLPTR(weights, output);
ARM_COMPUTE_ERROR_THROW_ON(NEConvolutionLayerReshapeWeights::validate(weights->info(),
(biases != nullptr) ? biases->info() : nullptr,
- output->info(),
- transpose1xW));
+ output->info()));
- // Check if bias are present, if yes they will be embedded to the weights matrix
- const bool append_biases = (biases != nullptr) && !is_data_type_quantized_asymmetric(weights->info()->data_type());
- //const unsigned bias_element = (append_biases) ? 1 : 0;
+ const bool append_biases = (biases != nullptr) && !is_data_type_quantized_asymmetric(weights->info()->data_type()); // biases are embedded in the reshaped weights only for non-quantized types
const ITensor *biases_to_use = (append_biases) ? biases : nullptr;
- _transpose1xW = transpose1xW;
-
- if(transpose1xW)
- {
- // Create tensor to store the reshaped weights
- TensorInfo info_wr = weights->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(get_reshaped_weights_shape(weights->info(), append_biases));
-
- _weights_reshaped.allocator()->init(info_wr);
- _memory_group.manage(&_weights_reshaped);
-
- _weights_reshape_kernel.configure(weights, biases, &_weights_reshaped);
- _weights_transposed_kernel.configure(&_weights_reshaped, output);
-
- _weights_reshaped.allocator()->allocate();
- }
- else
- {
- _weights_reshape_kernel.configure(weights, biases_to_use, output);
- }
+ _weights_reshape_kernel.configure(weights, biases_to_use, output); // no separate 1xW transpose any more: the kernel writes straight into output
output->info()->set_quantization_info(weights->info()->quantization_info());
}
-Status NEConvolutionLayerReshapeWeights::validate(const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, bool transpose1xW)
+Status NEConvolutionLayerReshapeWeights::validate(const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output) // output may be nullptr or uninitialised, in which case its checks are skipped
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(weights);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
- if(!is_data_type_quantized_asymmetric(weights->data_type()))
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(weights, output);
- }
- // Check if bias are present, if yes they will be embedded to the weights matrix
- const bool append_bias = (biases != nullptr);
- if(append_bias)
+ if(biases != nullptr) // biases: non-quantized weights only, 1D, one value per kernel
{
+ const int idx_kernels = get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::BATCHES);
ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(weights->data_type()));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(weights, biases);
- ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(3));
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(idx_kernels));
ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
}
- // Checks performed when biases are present
- if(append_bias)
+ if((output != nullptr) && (output->total_size() != 0))
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
- ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(3));
- ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
- }
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, output);
- if(transpose1xW)
- {
- TensorInfo weights_reshaped = weights->clone()->set_tensor_shape(get_reshaped_weights_shape(weights, append_bias));
- ARM_COMPUTE_RETURN_ON_ERROR(NEWeightsReshapeKernel::validate(weights, biases, &weights_reshaped));
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(&weights_reshaped, output));
- }
- else
- {
- ARM_COMPUTE_RETURN_ON_ERROR(NEWeightsReshapeKernel::validate(weights, biases, output));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEWeightsReshapeKernel::validate(weights, biases, output)); // propagate the kernel's Status instead of silently discarding it
}
return Status{};
@@ -133,110 +85,21 @@
void NEConvolutionLayerReshapeWeights::run()
{
- _memory_group.acquire();
-
NEScheduler::get().schedule(&_weights_reshape_kernel, 3);
-
- if(_transpose1xW)
- {
- NEScheduler::get().schedule(&_weights_transposed_kernel, Window::DimY);
- }
-
- _memory_group.release();
}
-namespace
-{
-TensorShape get_reshaped_weights_shape_conv(const ITensorInfo *weights, bool append_bias, bool is_fully_connected_convolution)
-{
- unsigned int mat_weights_cols = weights->dimension(3);
- unsigned int mat_weights_rows = weights->dimension(0) * weights->dimension(1) * weights->dimension(2) + (append_bias ? 1 : 0);
-
- if(is_fully_connected_convolution)
- {
- // Create tensor to store the reshaped weights
- return TensorShape(mat_weights_cols, mat_weights_rows);
- }
- else
- {
- // Create tensor to store transposed weights
- const float transpose_width = 16.0f / weights->element_size();
- return TensorShape(mat_weights_rows * static_cast<unsigned int>(transpose_width), static_cast<unsigned int>(std::ceil(mat_weights_cols / transpose_width)));
- }
-}
-
-Status validate_and_initialize_values(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const PadStrideInfo &conv_info, const WeightsInfo &weights_info,
- const ActivationLayerInfo &act_info, DataType &dt,
- bool &append_bias, bool &skip_im2col,
- bool &are_weights_reshaped, unsigned int &kernel_width, unsigned int &kernel_height,
- bool &is_fully_connected_convolution, bool &is_interleaved, bool &is_quantized, bool &is_activationlayer_enabled,
- unsigned int &mat_weights_cols, unsigned int &mat_weights_rows,
- unsigned int &conv_w, unsigned int &conv_h, const Size2D &dilation)
-{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, weights);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights);
-
- DataLayout data_layout = input->data_layout();
- const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
- const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
- const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
-
- ARM_COMPUTE_RETURN_ERROR_ON(!weights_info.are_reshaped() && weights->dimension(idx_channel) != input->dimension(idx_channel));
- ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
- ARM_COMPUTE_RETURN_ERROR_ON(weights_info.are_reshaped() && is_data_type_quantized_asymmetric(input->data_type()));
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32, "NHWC is only supported for FP32 data type.");
-
- dt = input->data_type();
- is_quantized = is_data_type_quantized_asymmetric(dt);
-
- if(biases != nullptr)
- {
- if(is_quantized)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
- }
- else
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
- }
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, biases);
- ARM_COMPUTE_RETURN_ERROR_ON(!weights_info.are_reshaped() && biases->dimension(0) != weights->dimension(3));
- ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
- }
-
- // If we have 1x1 convolution and data layout is NHWC we can disable im2col
- append_bias = (biases != nullptr) && (!is_quantized);
- are_weights_reshaped = weights_info.are_reshaped();
- kernel_width = (are_weights_reshaped) ? weights_info.kernel_size().first : weights->dimension(idx_width);
- kernel_height = (are_weights_reshaped) ? weights_info.kernel_size().second : weights->dimension(idx_height);
- mat_weights_cols = weights->dimension(3);
- mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel) + ((append_bias && !skip_im2col) ? 1 : 0);
- skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1);
-
- std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(idx_width), input->dimension(idx_height), kernel_width, kernel_height,
- conv_info, dilation);
-
- // Check if its a "fully connected" convolution
- is_fully_connected_convolution = ((conv_w == 1) && (conv_h == 1));
- is_interleaved = (!is_fully_connected_convolution && !is_quantized);
- is_activationlayer_enabled = act_info.enabled();
-
- return Status{};
-}
-} // namespace
-
NEGEMMConvolutionLayer::NEGEMMConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager)
- : _asm_glue(), _memory_group(memory_manager), _input_im2col_kernel(), _input_interleave_kernel(), _reshape_weights(), _mm_kernel(), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
- _output_col2im_kernel(), _activationlayer_function(), _add_bias_kernel(), _original_weights(nullptr), _input_im2col_reshaped(), _input_interleaved_reshaped(), _weights_reshaped(), _gemm_output(),
- _tmp_output(), _workspace(), _B_pretransposed(), _data_layout(DataLayout::NCHW), _append_bias(false), _is_fully_connected_convolution(false), _are_weights_reshaped(false), _is_quantized(false),
- _is_interleaved(false), _is_activationlayer_enabled(false), _skip_im2col(false)
+ : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _col2im_kernel(), _activationlayer_function(),
+ _add_bias_kernel(), _reshape_layer(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _append_bias(false),
+ _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
{
}
-void NEGEMMConvolutionLayer::configure_mm(const ITensor *input, const ITensor *weights, ITensor *output, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
+void NEGEMMConvolutionLayer::configure_mm(const ITensor *input, const ITensor *weights, ITensor *output, int gemm_3d_depth)
{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info(), gemm_3d_depth, _skip_im2col));
+
if(_is_quantized)
{
// Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -255,128 +118,145 @@
}
else
{
- _mm_kernel.configure(input, weights, output, 1.f, is_interleaved, reshape_info);
+ // Configure matrix multiply function
+ _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, gemm_3d_depth,
+ _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */));
}
}
-void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info,
- const Size2D &dilation, const ActivationLayerInfo &act_info)
+Status NEGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth, bool skip_im2col)
{
- // Perform validate step
- ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
+ const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
- DataType dt{};
- unsigned int kernel_width = 0;
- unsigned int kernel_height = 0;
- unsigned int mat_weights_cols = 0;
- unsigned int mat_weights_rows = 0;
- unsigned int conv_w = 0;
- unsigned int conv_h = 0;
-
- _data_layout = input->info()->data_layout();
- const bool is_nhwc = _data_layout == DataLayout::NHWC;
- const int idx_width = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH);
- const int idx_height = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
- const int idx_channel = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::CHANNEL);
-
- Status status = validate_and_initialize_values(input->info(), weights->info(), (biases == nullptr) ? nullptr : biases->info(), conv_info, weights_info, act_info, dt, _append_bias, _skip_im2col,
- _are_weights_reshaped,
- kernel_width, kernel_height,
- _is_fully_connected_convolution, _is_interleaved, _is_quantized, _is_activationlayer_enabled,
- mat_weights_cols, mat_weights_rows, conv_w, conv_h, dilation);
-
- ARM_COMPUTE_ERROR_THROW_ON(status);
-
- _original_weights = weights;
- const unsigned int fixed_point_position = input->info()->fixed_point_position();
- const ITensor *biases_to_use = (_append_bias) ? biases : nullptr;
-
- bool run_optimised = dt == DataType::F32;
-
- // Reshape weights if needed
- if(run_optimised)
+ const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, skip_im2col);
+ if(is_quantized)
{
- TensorShape reshaped_weights_shape{ mat_weights_cols, mat_weights_rows };
+ // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
+ // Extract and negate input and weights offset
+ const QuantizationInfo input_quantization_info = input->quantization_info();
+ const QuantizationInfo weights_quantization_info = weights->quantization_info();
- // Create tensor to store the reshaped weights
- _weights_reshaped.allocator()->init(TensorInfo(reshaped_weights_shape, 1, dt, fixed_point_position));
- _reshape_weights.configure(weights, biases, &_weights_reshaped, false /* 1xW transpose */);
- weights = &_weights_reshaped;
+ std::unique_ptr<ITensorInfo> input_qa = input->clone();
+ std::unique_ptr<ITensorInfo> weights_qa = weights->clone();
+ input_qa->set_quantization_info(QuantizationInfo(input_quantization_info.scale, -input_quantization_info.offset));
+ weights_qa->set_quantization_info(QuantizationInfo(weights_quantization_info.scale, -weights_quantization_info.offset));
+
+ // Perform validation step on GEMMLowp
+ return NEGEMMLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), output, gemm_info);
}
else
{
- if(_are_weights_reshaped)
- {
- if(_is_fully_connected_convolution || _is_quantized)
- {
- mat_weights_cols = weights_info.num_kernels();
- mat_weights_rows = weights->info()->dimension(idx_height);
- }
- else
- {
- mat_weights_cols = weights_info.num_kernels();
- mat_weights_rows = weights_info.kernel_size().first * weights_info.kernel_size().second * input->info()->dimension(idx_channel) + (_append_bias ? 1 : 0);
- }
- }
- else
- {
- TensorShape reshaped_weights_shape;
+ // Perform validation step on Matrix multiply function
+ return NEGEMM::validate(input, weights, nullptr, output, 1.0f, 0.0f, gemm_info);
+ }
+}
- if(_is_fully_connected_convolution || _is_quantized)
- {
- reshaped_weights_shape = TensorShape{ mat_weights_cols, mat_weights_rows };
- }
- else
- {
- // Create tensor to store transposed weights
- const float transpose_width = 16.0f / input->info()->element_size();
- reshaped_weights_shape = TensorShape{ mat_weights_rows *static_cast<unsigned int>(transpose_width),
- static_cast<unsigned int>(std::ceil(mat_weights_cols / transpose_width)) };
- }
+Status NEGEMMConvolutionLayer::validate_gemm3d(DataType data_type, int gemm_3d_depth, bool skip_im2col)
+{
+ const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
+ const DataType output_gemm_data_type = is_quantized ? DataType::S32 : data_type;
+ const unsigned int mult_y = skip_im2col ? 1U : gemm_3d_depth;
+ const unsigned int mult_z = skip_im2col ? gemm_3d_depth : 1U;
- // Create tensor to store the reshaped weights
- _weights_reshaped.allocator()->init(TensorInfo(reshaped_weights_shape, 1, dt, fixed_point_position));
- _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped, _is_interleaved /* 1xW transpose */);
- weights = &_weights_reshaped;
+ // Set dummy tensor shapes for the validation
+ const TensorInfo dummy_input_info(TensorShape(4U, 4U * mult_y, 1U * mult_z), 1, data_type);
+ const TensorInfo dummy_weights_info(TensorShape(4U, 4U), 1, data_type);
+ const TensorInfo dummy_output_info(TensorShape(4U, 4U, gemm_3d_depth), 1, output_gemm_data_type);
+
+ return validate_mm(&dummy_input_info, &dummy_weights_info, &dummy_output_info, gemm_3d_depth, skip_im2col);
+}
+
+void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info,
+ const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
+ ARM_COMPUTE_UNUSED(num_groups);
+ ARM_COMPUTE_ERROR_THROW_ON(NEGEMMConvolutionLayer::validate(input->info(),
+ weights->info(),
+ biases != nullptr ? biases->info() : nullptr,
+ output->info(),
+ conv_info,
+ weights_info,
+ dilation,
+ act_info,
+ num_groups));
+
+ const DataType data_type = input->info()->data_type();
+ const DataLayout data_layout = input->info()->data_layout();
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
+ const int idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
+
+ const unsigned int kernel_width = weights->info()->dimension(idx_width);
+ const unsigned int kernel_height = weights->info()->dimension(idx_height);
+
+ _is_prepared = weights_info.retain_internal_weights();
+ _original_weights = weights;
+ _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
+ _data_layout = data_layout;
+ _skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
+ _skip_col2im = data_layout == DataLayout::NHWC;
+ _append_bias = (biases != nullptr) && (!_is_quantized);
+
+ const ITensor *gemm_input_to_use = input;
+ ITensor *gemm_output_to_use = output;
+ ITensor *gemm_output_staged_to_use = output;
+
+ // Get convolved dimensions
+ unsigned int conv_w = 0;
+ unsigned int conv_h = 0;
+ std::tie(conv_w, conv_h) = scaled_dimensions(input->info()->dimension(idx_width),
+ input->info()->dimension(idx_height),
+ kernel_width,
+ kernel_height,
+ conv_info,
+ dilation);
+
+ // Check if GEMM3D is supported
+ if(_skip_col2im)
+ {
+ // If not supported, we need to perform im2col and col2im (or reshape layer)
+ if(!bool(validate_gemm3d(input->info()->data_type(), conv_h, _skip_im2col)))
+ {
+ _skip_im2col = false;
+ _skip_col2im = false;
}
}
- // In case we skip im2col we have to add bias
+ const unsigned bias_element = (_append_bias && !_skip_im2col) ? 1 : 0;
+ const ITensor *biases_to_use = (_append_bias && !_skip_im2col) ? biases : nullptr;
+
+ // Get parameters from conv_info
+ unsigned int stride_x = 0;
+ unsigned int stride_y = 0;
+ std::tie(stride_x, stride_y) = conv_info.stride();
+
+ unsigned int mat_weights_cols = weights->info()->dimension(idx_kernels);
+ unsigned int mat_weights_rows = weights->info()->dimension(idx_width) * weights->info()->dimension(idx_height) * weights->info()->dimension(idx_channel) + bias_element;
+
+ // _weights_reshaped will be auto configured in the kernel.
+ // Just append biases and do not transpose 1xW as it will be reshaped in NEGEMM
+ _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped);
+
+ // Create tensor to store im2col reshaped inputs
if(!_skip_im2col)
{
- const unsigned int mat_input_cols = mat_weights_rows;
- const unsigned int mat_input_rows = conv_w * conv_h;
-
- // Create tensor to store im2col reshaped inputs
- TensorShape shape_im2col(input->info()->tensor_shape());
- shape_im2col.set(0, mat_input_cols);
- shape_im2col.set(1, mat_input_rows);
+ // Calculate im2col shape
+ // For NEON the batch size is on the fourth dimension
+ TensorShape shape_im2col = input->info()->tensor_shape();
+ shape_im2col.set(0, mat_weights_rows);
+ shape_im2col.set(1, conv_w * conv_h);
shape_im2col.set(2, 1);
- _input_im2col_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
- _memory_group.manage(&_input_im2col_reshaped);
- // Create tensor (interleave) to prepare input tensor for GEMM
- if(!_is_fully_connected_convolution && !run_optimised && _is_interleaved)
- {
- TensorShape shape_interleaved(shape_im2col);
- shape_interleaved.set(idx_width, shape_interleaved.x() * 4);
- shape_interleaved.set(idx_height, std::ceil(shape_interleaved[idx_height] / 4.f));
- _input_interleaved_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_interleaved));
- _memory_group.manage(&_input_interleaved_reshaped);
- }
+ _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
+ _memory_group.manage(&_im2col_output);
- // Create GEMM output tensor
- TensorShape shape_gemm(_input_im2col_reshaped.info()->tensor_shape());
- shape_gemm.set(0, mat_weights_cols);
- shape_gemm.set(1, mat_input_rows);
- const DataType gemm_data_type = _is_quantized ? DataType::S32 : dt;
- // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input.
- TensorInfo info_gemm(shape_gemm, 1, gemm_data_type, input->info()->fixed_point_position());
- info_gemm.set_quantization_info(output->info()->quantization_info());
- _gemm_output.allocator()->init(info_gemm);
+ // Configure
+ _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, dilation);
- // Configure im2col
- _input_im2col_kernel.configure(input, &_input_im2col_reshaped, Size2D(kernel_width, kernel_height), conv_info, _append_bias, false, false, dilation);
+ // Update GEMM input
+ gemm_input_to_use = &_im2col_output;
}
else if(_append_bias)
{
@@ -384,129 +264,187 @@
_add_bias_kernel.configure(output, biases, output, ConvertPolicy::SATURATE);
}
- // Configure matrix multiply
- if(run_optimised)
+ // Create temporary GEMM output tensor in case we cannot skip col2im
+ if(!_skip_col2im)
{
- if(!setup_assembly_kernel(_skip_im2col ? input : &_input_im2col_reshaped, weights, is_nhwc ? output : &_gemm_output, 1.f, 0.f, true, _workspace, _B_pretransposed, _memory_group, _asm_glue))
- {
- ARM_COMPUTE_ERROR("setup_assembly_kernel failed.");
- }
- }
- else
- {
- if(_is_interleaved)
- {
- // Configure GEMMInterleave4x4. _input_interleaved_reshaped will be auto configured in the kernel
- _input_interleave_kernel.configure(&_input_im2col_reshaped, &_input_interleaved_reshaped);
+ // Calculate GEMM output shape
+ TensorShape shape_gemm = _im2col_output.info()->tensor_shape();
+ shape_gemm.set(0, mat_weights_cols);
+ shape_gemm.set(1, conv_w * conv_h);
- // Configure GEMM
- configure_mm(&_input_interleaved_reshaped, weights, &_gemm_output, _is_interleaved, GEMMReshapeInfo(_input_im2col_reshaped.info()->dimension(idx_height), 0 /* no transpose */,
- _input_im2col_reshaped.info()->dimension(idx_width)));
- _input_interleaved_reshaped.allocator()->allocate();
- }
- else
- {
- configure_mm(&_input_im2col_reshaped, weights, &_gemm_output, _is_interleaved);
- }
+ // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input.
+ const DataType gemm_data_type = _is_quantized ? DataType::S32 : data_type;
+ TensorInfo info_gemm(shape_gemm, 1, gemm_data_type);
+ info_gemm.set_quantization_info(output->info()->quantization_info());
+ _gemm_output.allocator()->init(info_gemm);
+ _memory_group.manage(&_gemm_output);
+
+ // Update GEMM output
+ gemm_output_to_use = &_gemm_output;
}
+ // Configure GEMM
+ configure_mm(gemm_input_to_use, &_weights_reshaped, gemm_output_to_use, _skip_col2im ? conv_h : 1);
+
if(!_skip_im2col)
{
- _input_im2col_reshaped.allocator()->allocate();
+ _im2col_output.allocator()->allocate();
+ }
- // Configure output stage for quantized case
- if(_is_quantized)
+ // Configure output stage for quantized case
+ if(_is_quantized)
+ {
+ const QuantizationInfo output_quant_info = (output->info()->total_size() == 0) ? input->info()->quantization_info() : output->info()->quantization_info();
+
+ float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale;
+ int output_multiplier, output_shift;
+ quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+
+ _memory_group.manage(&_tmp_output);
+ gemm_output_staged_to_use = &_tmp_output;
+
+ _gemmlowp_output_stage.configure(gemm_output_to_use, biases, gemm_output_staged_to_use, output_multiplier, output_shift, output_quant_info.offset);
+ }
+
+ if(!_skip_col2im)
+ {
+ if(_data_layout == DataLayout::NCHW)
{
- const QuantizationInfo output_quant_info = (output->info()->total_size() == 0) ? input->info()->quantization_info() : output->info()->quantization_info();
-
- float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale;
- int output_multiplier, output_shift;
- quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
- _memory_group.manage(&_tmp_output);
- _gemmlowp_output_stage.configure(&_gemm_output, biases, &_tmp_output, output_multiplier, output_shift, output_quant_info.offset);
+ // Configure col2im
+ _col2im_kernel.configure(_is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, output, Size2D(conv_w, conv_h));
}
-
- // Configure Col2Im
- if(!is_nhwc)
+ else
{
- _output_col2im_kernel.configure(_is_quantized ? &_tmp_output : &_gemm_output, output, Size2D(conv_w, conv_h));
+ // Configure reshape layer
+ _reshape_layer.configure(_is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, output);
}
+ }
- if(_is_quantized)
- {
- _tmp_output.allocator()->allocate();
- }
+ if(_is_quantized)
+ {
+ _tmp_output.allocator()->allocate();
+ }
+
+ if(!_skip_col2im)
+ {
_gemm_output.allocator()->allocate();
}
- ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(idx_width) != conv_w) || (output->info()->dimension(idx_height) != conv_h), "Output shape does not match the expected one");
-
- // Allocate intermediate tensor
- if(!_are_weights_reshaped)
- {
- _weights_reshaped.allocator()->allocate();
- }
+ ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(idx_width) != conv_w) || (output->info()->dimension(idx_height) != conv_h),
+ "Output shape does not match the expected one");
//Configure Activation Layer
+ _is_activationlayer_enabled = act_info.enabled();
+
if(_is_activationlayer_enabled)
{
_activationlayer_function.configure(output, nullptr, act_info);
}
+
+ ARM_COMPUTE_UNUSED(weights_info);
}
Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
- const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info)
+ const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
{
- ARM_COMPUTE_UNUSED(output);
-
- DataType dt{};
- bool append_bias{};
- bool skip_im2col{};
- bool are_weights_reshaped{};
- bool is_fully_connected_convolution{};
- bool is_interleaved{};
- bool is_quantized{};
- bool is_activationlayer_enabled{};
- unsigned int kernel_width = 0;
- unsigned int kernel_height = 0;
- unsigned int mat_weights_cols = 0;
- unsigned int mat_weights_rows = 0;
- unsigned int conv_w = 0;
- unsigned int conv_h = 0;
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!");
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups > 1, "Grouping (num_groups != 1) is not supported on NEON");
const DataLayout data_layout = input->data_layout();
- const bool is_nhwc = data_layout == DataLayout::NHWC;
+ const DataType data_type = input->data_type();
const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
+ const int idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
- Status status = validate_and_initialize_values(input, weights, biases, conv_info, weights_info, act_info, dt, append_bias, skip_im2col, are_weights_reshaped, kernel_width, kernel_height,
- is_fully_connected_convolution, is_interleaved, is_quantized, is_activationlayer_enabled, mat_weights_cols, mat_weights_rows,
- conv_w, conv_h, dilation);
+ const unsigned int kernel_width = weights->dimension(idx_width);
+ const unsigned int kernel_height = weights->dimension(idx_height);
- const Size2D kernel_weights = Size2D(kernel_width, kernel_height);
+ TensorInfo im2col_reshaped_info, info_gemm, tmp_info, weights_reshaped_info;
+ const ITensorInfo *gemm_input_to_use = input;
+ const ITensorInfo *gemm_output_to_use = output;
+ const ITensorInfo *gemm_output_staged_to_use = output;
+ const ITensorInfo *weights_to_use = weights;
- ARM_COMPUTE_RETURN_ON_ERROR(status);
+ const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
+ const bool append_bias = (biases != nullptr) && (!is_quantized);
+ bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
+ bool skip_col2im = data_layout == DataLayout::NHWC;
- std::unique_ptr<ITensorInfo> reshaped_weights = weights->clone();
- bool optimised_kernel = false;
+ // Get convolved dimensions
+ unsigned int conv_w = 0;
+ unsigned int conv_h = 0;
- if(dt == DataType::F32)
+ std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(idx_width),
+ input->dimension(idx_height),
+ kernel_width,
+ kernel_height,
+ conv_info,
+ dilation);
+
+ // Check if GEMM3D is supported
+ if(skip_col2im)
{
- optimised_kernel = true;
+ // If not supported, we need to perform im2col and col2im (or reshape layer)
+ if(!bool(validate_gemm3d(input->data_type(), conv_h, skip_im2col)))
+ {
+ skip_im2col = false;
+ skip_col2im = false;
+ }
}
- const unsigned int mat_input_cols = mat_weights_rows;
- const unsigned int mat_input_rows = conv_w * conv_h;
- TensorShape shape_im2col = input->tensor_shape();
- shape_im2col.set(0, mat_input_cols);
- shape_im2col.set(1, mat_input_rows);
- shape_im2col.set(2, 1);
- TensorInfo im2_col_info = input->clone()->set_tensor_shape(shape_im2col);
+ const unsigned bias_element = (append_bias && !skip_im2col) ? 1 : 0;
+ const ITensorInfo *biases_to_use = (append_bias && !skip_im2col) ? biases : nullptr;
+
+ ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != input->dimension(idx_channel));
+ ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
+
+ // Validate biases
+ if(biases != nullptr)
+ {
+ if(is_quantized)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
+ }
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(idx_kernels));
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
+ }
+
+ if(act_info.enabled())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(act_info.b() > act_info.a());
+ }
+
+ unsigned int mat_weights_cols = weights->dimension(idx_kernels);
+ unsigned int mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel) + bias_element;
+
+ // Validate the weights reshape (output not initialized yet, so nullptr); its shape is computed below
+ ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases_to_use, nullptr));
+ weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, (append_bias && !skip_im2col)), 1, data_type);
+ weights_to_use = &weights_reshaped_info;
if(!skip_im2col)
{
- // Validate im2col
- ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &im2_col_info, kernel_weights, conv_info, append_bias, false, false, dilation));
+ // Create tensor info for im2col reshaped inputs
+ // For NEON the batch size is on the fourth dimension
+ TensorShape shape_im2col = input->tensor_shape();
+ shape_im2col.set(0, mat_weights_rows);
+ shape_im2col.set(1, conv_w * conv_h);
+ shape_im2col.set(2, 1);
+
+ im2col_reshaped_info = TensorInfo(shape_im2col, 1, data_type);
+ im2col_reshaped_info.set_quantization_info(input->quantization_info());
+
+ ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation));
+ gemm_input_to_use = &im2col_reshaped_info;
}
else if(append_bias)
{
@@ -514,66 +452,45 @@
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(output, biases, output, ConvertPolicy::SATURATE));
}
- // Create GEMM output tensor
- TensorShape shape_gemm(im2_col_info.tensor_shape());
- shape_gemm.set(0, mat_weights_cols);
- shape_gemm.set(1, mat_input_rows);
- TensorInfo gemm_output_info = input->clone()->set_tensor_shape(shape_gemm);
-
- // Reshape weights if needed
- if(optimised_kernel)
+ // Create temporary GEMM output tensor in case we cannot skip col2im
+ if(!skip_col2im)
{
- ARM_COMPUTE_RETURN_ERROR_ON(are_weights_reshaped);
+ TensorShape shape_gemm = gemm_input_to_use->tensor_shape();
+ shape_gemm.set(0, mat_weights_cols);
+ shape_gemm.set(1, conv_w * conv_h);
+ const DataType gemm_data_type = is_quantized ? DataType::S32 : data_type;
+ // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input.
+ info_gemm = TensorInfo(shape_gemm, 1, gemm_data_type);
+ info_gemm.set_quantization_info(output->quantization_info());
- // Create tensor to store the reshaped weights
- reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution));
- ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */));
- }
- else if(!is_quantized)
- {
- TensorShape reshaped_weights_shape;
-
- if(is_fully_connected_convolution || is_quantized)
- {
- reshaped_weights_shape = TensorShape{ mat_weights_cols, mat_weights_rows };
- }
- else
- {
- // Create tensor to store transposed weights
- const float transpose_width = 16.0f / input->element_size();
- reshaped_weights_shape = TensorShape{ mat_weights_rows *static_cast<unsigned int>(transpose_width),
- static_cast<unsigned int>(std::ceil(mat_weights_cols / transpose_width)) };
- }
-
- // Create tensor to store the reshaped weights
- reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution));
- ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */));
- weights = reshaped_weights.get();
-
- // Validate GEMM interleave and multiply
- if(is_interleaved)
- {
- TensorShape shape_interleaved = shape_im2col;
- shape_interleaved.set(idx_width, shape_interleaved.x() * 4);
- shape_interleaved.set(idx_height, std::ceil(shape_interleaved.y() / 4.f));
- TensorInfo input_interleaved_info = input->clone()->set_tensor_shape(shape_interleaved);
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(&im2_col_info, &input_interleaved_info));
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&input_interleaved_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo(shape_im2col[1], // m
- weights->tensor_shape()[0], // n
- shape_im2col[0]) /* k */));
- }
- else
- {
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&im2_col_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo()));
- }
- }
- if(!is_nhwc)
- {
- ARM_COMPUTE_RETURN_ON_ERROR(NECol2ImKernel::validate(&gemm_output_info, output, Size2D(conv_w, conv_h)));
+ gemm_output_to_use = &info_gemm;
}
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((output->dimension(idx_width) != conv_w) || (output->dimension(idx_height) != conv_h), "Output shape does not match the expected one");
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, skip_col2im ? conv_h : 1, skip_im2col));
+ if(is_quantized)
+ {
+ float multiplier = input->quantization_info().scale * weights->quantization_info().scale / ((output->total_size() == 0) ? input->quantization_info().scale : output->quantization_info().scale);
+ int output_multiplier, output_shift;
+ quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+
+ tmp_info = TensorInfo(gemm_output_to_use->tensor_shape(), 1, DataType::QASYMM8);
+ tmp_info.set_quantization_info(output->quantization_info());
+ gemm_output_staged_to_use = &tmp_info;
+
+ // Validate output stage for quantized case
+ NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(gemm_output_to_use, biases, gemm_output_staged_to_use, output->quantization_info().offset);
+ }
+
+ // Validate Col2Im/ReshapeLayer
+ if(!skip_col2im && (data_layout == DataLayout::NCHW))
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NECol2ImKernel::validate(is_quantized ? gemm_output_staged_to_use : gemm_output_to_use,
+ output,
+ Size2D(conv_w, conv_h)));
+ }
+
+ //Validate Activation Layer
if(act_info.enabled())
{
ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output, nullptr, act_info));
@@ -584,54 +501,30 @@
void NEGEMMConvolutionLayer::run()
{
- // Run weights reshaping (Runs once for every configure)
- if(!_are_weights_reshaped)
- {
- ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
-
- _are_weights_reshaped = true;
- _reshape_weights.run();
-
- // Mark original weights tensor as unused
- _original_weights->mark_as_unused();
- }
+ prepare();
_memory_group.acquire();
if(!_skip_im2col)
{
// Run input reshaping
- unsigned int _y_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
- NEScheduler::get().schedule(&_input_im2col_kernel, _y_dim);
+ unsigned int y_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
+ NEScheduler::get().schedule(&_im2col_kernel, y_dim);
}
- // Runs matrix multiply on reshaped matrices
- if(_asm_glue._optimised_kernel != nullptr)
+ // Runs NEGEMM or NEGEMMLowpMatrixMultiplyCore functions
+ if(_is_quantized)
{
- _asm_glue.run();
- // Release weights in case buffer is pretransposed
- if(!_weights_reshaped.is_used())
- {
- _weights_reshaped.allocator()->free();
- }
+ // Run gemmlowp
+ _mm_gemmlowp.run();
+
+ // Run output stage
+ _gemmlowp_output_stage.run();
}
else
{
- if(_is_interleaved)
- {
- // Run interleave
- NEScheduler::get().schedule(&_input_interleave_kernel, Window::DimY);
- }
-
- // Runs matrix multiply on reshaped matrices
- if(_is_quantized)
- {
- _mm_gemmlowp.run();
- }
- else
- {
- NEScheduler::get().schedule(&_mm_kernel, Window::DimY);
- }
+ // Run gemm
+ _mm_gemm.run();
}
if(_skip_im2col && _append_bias)
@@ -639,16 +532,17 @@
NEScheduler::get().schedule(&_add_bias_kernel, Window::DimY);
}
- // Run output stage for quantized case
- if(_is_quantized)
- {
- _gemmlowp_output_stage.run();
- }
-
// Reshape output matrix
- if(_data_layout == DataLayout::NCHW)
+ if(!_skip_col2im)
{
- NEScheduler::get().schedule(&_output_col2im_kernel, Window::DimY);
+ if(_data_layout == DataLayout::NCHW)
+ {
+ NEScheduler::get().schedule(&_col2im_kernel, Window::DimY);
+ }
+ else
+ {
+ _reshape_layer.run();
+ }
}
if(_is_activationlayer_enabled)
@@ -658,4 +552,25 @@
_memory_group.release();
}
-} // namespace arm_compute
+
+void NEGEMMConvolutionLayer::prepare()
+{
+ if(!_is_prepared)
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+
+ // Run weights reshaping and mark original weights tensor as unused
+ _weights_reshaped.allocator()->allocate();
+ _reshape_weights.run();
+ _original_weights->mark_as_unused();
+
+ // Prepare GEMM
+ _is_quantized ? _mm_gemmlowp.prepare() : _mm_gemm.prepare();
+ if(!_weights_reshaped.is_used())
+ {
+ _weights_reshaped.allocator()->free();
+ }
+
+ _is_prepared = true;
+ }
+}
diff --git a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
index 98b4767..47c3358 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
@@ -38,8 +38,7 @@
using namespace arm_compute;
NEGEMMLowpAssemblyMatrixMultiplyCore::NEGEMMLowpAssemblyMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _asm_glue_unsigned(), _asm_glue_signed(), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _tmp_a(), _tmp_b(),
- _workspace(), _B_pretransposed()
+ : _memory_group(memory_manager), _asm_glue(memory_manager), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _tmp_a(), _tmp_b()
{
}
@@ -53,18 +52,14 @@
ARM_COMPUTE_ERROR_ON_MSG((b)->info()->dimension(0) != (output)->info()->dimension(0), "The output matrix must have the same number of columns as the matrix B");
bool run_optimised = false;
-#ifdef __aarch64__
switch(a->info()->data_type())
{
case DataType::S8:
- {
- run_optimised = setup_assembly_kernel(a, b, output, 1.f, 0.f, true, _workspace, _B_pretransposed, _memory_group, _asm_glue_signed);
- break;
- }
case DataType::QASYMM8:
case DataType::U8:
{
- run_optimised = setup_assembly_kernel(a, b, output, 1.f, 0.f, true, _workspace, _B_pretransposed, _memory_group, _asm_glue_unsigned);
+ _asm_glue.configure(a, b, output, 1.f, 0.f, true);
+ run_optimised = _asm_glue.is_configured();
break;
}
default:
@@ -73,7 +68,6 @@
break;
}
}
-#endif /* __aarch64__ */
if(!run_optimised)
{
// The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ]
@@ -133,13 +127,9 @@
NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY);
}
- if(_asm_glue_unsigned._optimised_kernel != nullptr)
+ if(_asm_glue.is_configured())
{
- _asm_glue_unsigned.run();
- }
- else if(_asm_glue_signed._optimised_kernel != nullptr)
- {
- _asm_glue_signed.run();
+ _asm_glue.run();
}
else
{
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index 2e06fa2..828011d 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -41,9 +41,9 @@
using namespace arm_compute::misc::shape_calculator;
NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _asm_glue_unsigned(), _asm_glue_signed(), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _mtx_a_reduction_kernel(),
- _mtx_b_reduction_kernel(), _offset_contribution_kernel(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _workspace(), _B_pretranspose(), _a_offset(0), _b_offset(0),
- _run_vector_matrix_multiplication(false), _dot_product_path(false), _is_first_run(true), _reshape_b_only_on_first_run(false)
+ : _memory_group(memory_manager), _asm_glue(memory_manager), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(),
+ _offset_contribution_kernel(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _original_b(nullptr), _a_offset(0), _b_offset(0), _run_vector_matrix_multiplication(false),
+ _dot_product_path(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
{
}
@@ -52,23 +52,27 @@
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
ARM_COMPUTE_ERROR_THROW_ON(NEGEMMLowpMatrixMultiplyCore::validate(a->info(), b->info(), output->info(), gemm_info));
+ // Clear state
+ _mtx_a_reshape_kernel = nullptr;
+ _mtx_b_reshape_kernel = nullptr;
+
+ // Set internal variables
_a_offset = a->info()->quantization_info().offset;
_b_offset = b->info()->quantization_info().offset;
_run_vector_matrix_multiplication = a->info()->dimension(1) < 2;
_reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
+ _is_prepared = false;
+ _original_b = b;
#ifdef __aarch64__
switch(a->info()->data_type())
{
- case DataType::S8:
- {
- _dot_product_path = setup_assembly_kernel(a, b, output, 1.f, 0.f, true, _workspace, _B_pretranspose, _memory_group, _asm_glue_signed);
- break;
- }
case DataType::QASYMM8:
case DataType::U8:
+ case DataType::S8:
{
- _dot_product_path = setup_assembly_kernel(a, b, output, 1.f, 0.f, true, _workspace, _B_pretranspose, _memory_group, _asm_glue_unsigned);
+ _asm_glue.configure(a, b, output, 1.f, 0.f, _reshape_b_only_on_first_run);
+ _dot_product_path = _asm_glue.is_configured();
break;
}
default:
@@ -160,10 +164,13 @@
if(!_dot_product_path && !_run_vector_matrix_multiplication)
{
_tmp_a.allocator()->allocate();
- _tmp_b.allocator()->allocate();
+ if(!_reshape_b_only_on_first_run)
+ {
+ _tmp_b.allocator()->allocate();
+ }
}
- if(_a_offset != 0)
+ if(_a_offset != 0 && !_reshape_b_only_on_first_run)
{
_vector_sum_col.allocator()->allocate();
}
@@ -188,6 +195,8 @@
ARM_COMPUTE_UNUSED(gemm_info);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.reinterpret_input_as_3d(), "NEGEMMLowpMatrixMultiplyCore cannot reinterpret the input tensor as 3D");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 1, "NEGEMMLowpMatrixMultiplyCore cannot reinterpret the output tensor as 3D");
int32_t a_offset = a->quantization_info().offset;
int32_t b_offset = b->quantization_info().offset;
@@ -248,29 +257,24 @@
void NEGEMMLowpMatrixMultiplyCore::run()
{
+ prepare();
+
_memory_group.acquire();
- // Do not reshape if we run the vector-by-matrix case and we do not have the optimized gemm with dot product instruction
- if(!_run_vector_matrix_multiplication && !_dot_product_path)
+ // Reshape inputs
+ if(_mtx_a_reshape_kernel)
{
- if(_mtx_a_reshape_kernel)
- {
- NEScheduler::get().schedule(_mtx_a_reshape_kernel.get(), Window::DimY);
- }
-
- if(_mtx_b_reshape_kernel && (_is_first_run || !_reshape_b_only_on_first_run))
- {
- NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY);
- }
+ NEScheduler::get().schedule(_mtx_a_reshape_kernel.get(), Window::DimY);
+ }
+ if(_mtx_b_reshape_kernel && !_reshape_b_only_on_first_run)
+ {
+ NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY);
}
- if(_asm_glue_unsigned._optimised_kernel != nullptr)
+ // Run GEMM
+ if(_asm_glue.is_configured())
{
- _asm_glue_unsigned.run();
- }
- else if(_asm_glue_signed._optimised_kernel != nullptr)
- {
- _asm_glue_signed.run();
+ _asm_glue.run();
}
else
{
@@ -284,7 +288,7 @@
}
// Run matrix B reduction kernel only if _a_offset is not equal to 0
- if(_a_offset != 0 && (_is_first_run || !_reshape_b_only_on_first_run))
+ if(_a_offset != 0 && !_reshape_b_only_on_first_run)
{
NEScheduler::get().schedule(&_mtx_b_reduction_kernel, Window::DimX);
}
@@ -293,6 +297,38 @@
NEScheduler::get().schedule(&_offset_contribution_kernel, Window::DimY);
_memory_group.release();
+}
- _is_first_run = false;
+void NEGEMMLowpMatrixMultiplyCore::prepare()
+{
+ if(!_is_prepared)
+ {
+ // Run assembly reshape
+ if(_asm_glue.is_configured() && _reshape_b_only_on_first_run)
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+
+ _asm_glue.prepare();
+ _original_b->mark_as_unused();
+ }
+ // Run non-assembly reshape
+ else if(_mtx_b_reshape_kernel && _reshape_b_only_on_first_run)
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+
+ // Run reshape kernel and mark original weights tensor as unused
+ _tmp_b.allocator()->allocate();
+ NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY);
+ _original_b->mark_as_unused();
+ }
+
+ // Run matrix B reduction kernel only if _a_offset is not equal to 0
+ if(_a_offset != 0 && _reshape_b_only_on_first_run)
+ {
+ _vector_sum_col.allocator()->allocate();
+ NEScheduler::get().schedule(&_mtx_b_reduction_kernel, Window::DimX);
+ }
+
+ _is_prepared = true;
+ }
}
diff --git a/src/runtime/NEON/functions/NEIm2Col.cpp b/src/runtime/NEON/functions/NEIm2Col.cpp
index 6b95cb0..4245b65 100644
--- a/src/runtime/NEON/functions/NEIm2Col.cpp
+++ b/src/runtime/NEON/functions/NEIm2Col.cpp
@@ -34,16 +34,18 @@
{
}
-void NEIm2Col::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected, bool is_flatten)
+void NEIm2Col::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const Size2D &dilation, unsigned int num_groups,
+ bool is_fully_connected, bool is_flatten)
{
_y_dim = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::HEIGHT);
- _kernel.configure(input, output, kernel_dims, conv_info, has_bias, is_fully_connected, is_flatten);
+ _kernel.configure(input, output, kernel_dims, conv_info, has_bias, dilation, num_groups, is_fully_connected, is_flatten);
}
-Status NEIm2Col::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool is_fully_connected, bool is_flatten)
+Status NEIm2Col::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const Size2D &dilation,
+ unsigned int num_groups, bool is_fully_connected, bool is_flatten)
{
- return NEIm2ColKernel::validate(input, output, kernel_dims, conv_info, has_bias, is_fully_connected, is_flatten);
+ return NEIm2ColKernel::validate(input, output, kernel_dims, conv_info, has_bias, dilation, num_groups, is_fully_connected, is_flatten);
}
void NEIm2Col::run()
diff --git a/src/runtime/NEON/functions/NELocallyConnectedLayer.cpp b/src/runtime/NEON/functions/NELocallyConnectedLayer.cpp
index 913acf8..80a2541 100644
--- a/src/runtime/NEON/functions/NELocallyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NELocallyConnectedLayer.cpp
@@ -73,7 +73,7 @@
NELocallyConnectedLayer::NELocallyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(std::move(memory_manager)), _input_im2col_kernel(), _weights_reshape_kernel(), _mm_kernel(), _output_col2im_kernel(), _input_im2col_reshaped(), _weights_reshaped(), _gemm_output(),
- _is_first_run(false), _original_weights(nullptr)
+ _is_prepared(false), _original_weights(nullptr)
{
}
@@ -113,7 +113,7 @@
TensorInfo input_im2col_reshaped_info(shape_im2col, 1, input->data_type());
TensorInfo gemm_output_info(shape_gemm, 1, input->data_type());
- ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &input_im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, has_bias, false));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &input_im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, has_bias));
ARM_COMPUTE_RETURN_ON_ERROR(NEWeightsReshapeKernel::validate(weights, biases, &weights_reshaped_info));
ARM_COMPUTE_RETURN_ON_ERROR(NELocallyConnectedMatrixMultiplyKernel::validate(&input_im2col_reshaped_info, &weights_reshaped_info, &gemm_output_info));
ARM_COMPUTE_RETURN_ON_ERROR(NECol2ImKernel::validate(&gemm_output_info, output, Size2D(conv_w, conv_h)));
@@ -127,7 +127,7 @@
ARM_COMPUTE_ERROR_THROW_ON(NELocallyConnectedLayer::validate(input->info(), weights->info(), biases == nullptr ? nullptr : biases->info(), output->info(), conv_info));
bool _has_bias = (biases != nullptr);
- _is_first_run = true;
+ _is_prepared = false;
_original_weights = weights;
const unsigned int kernel_width = weights->info()->dimension(0);
@@ -160,24 +160,13 @@
_output_col2im_kernel.configure(&_gemm_output, output, Size2D(conv_w, conv_h));
// Allocate intermediate tensors
- _weights_reshaped.allocator()->allocate();
_input_im2col_reshaped.allocator()->allocate();
_gemm_output.allocator()->allocate();
}
void NELocallyConnectedLayer::run()
{
- // Run weights reshaping (Runs once for every configure)
- if(_is_first_run)
- {
- ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
-
- _is_first_run = false;
- NEScheduler::get().schedule(&_weights_reshape_kernel, 3);
-
- // Mark original weights tensor as unused
- _original_weights->mark_as_unused();
- }
+ prepare();
_memory_group.acquire();
@@ -192,3 +181,18 @@
_memory_group.release();
}
+
+void NELocallyConnectedLayer::prepare()
+{
+ if(!_is_prepared)
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+
+ // Run weights reshaping and mark original weights tensor as unused
+ _weights_reshaped.allocator()->allocate();
+ NEScheduler::get().schedule(&_weights_reshape_kernel, 3);
+ _original_weights->mark_as_unused();
+
+ _is_prepared = true;
+ }
+}
diff --git a/src/runtime/NEON/functions/NEMagnitude.cpp b/src/runtime/NEON/functions/NEMagnitude.cpp
index f865054..2738201 100644
--- a/src/runtime/NEON/functions/NEMagnitude.cpp
+++ b/src/runtime/NEON/functions/NEMagnitude.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -31,36 +31,18 @@
using namespace arm_compute;
-void NEMagnitude::configure(const ITensor *input1, const ITensor *input2, ITensor *output, MagnitudeType mag_type, bool use_fp16)
+void NEMagnitude::configure(const ITensor *input1, const ITensor *input2, ITensor *output, MagnitudeType mag_type)
{
- if(use_fp16)
+ if(mag_type == MagnitudeType::L1NORM)
{
- if(mag_type == MagnitudeType::L1NORM)
- {
- auto k = arm_compute::support::cpp14::make_unique<NEMagnitudePhaseFP16Kernel<MagnitudeType::L1NORM, PhaseType::SIGNED>>();
- k->configure(input1, input2, output, nullptr);
- _kernel = std::move(k);
- }
- else
- {
- auto k = arm_compute::support::cpp14::make_unique<NEMagnitudePhaseFP16Kernel<MagnitudeType::L2NORM, PhaseType::SIGNED>>();
- k->configure(input1, input2, output, nullptr);
- _kernel = std::move(k);
- }
+ auto k = arm_compute::support::cpp14::make_unique<NEMagnitudePhaseKernel<MagnitudeType::L1NORM, PhaseType::SIGNED>>();
+ k->configure(input1, input2, output, nullptr);
+ _kernel = std::move(k);
}
else
{
- if(mag_type == MagnitudeType::L1NORM)
- {
- auto k = arm_compute::support::cpp14::make_unique<NEMagnitudePhaseKernel<MagnitudeType::L1NORM, PhaseType::SIGNED>>();
- k->configure(input1, input2, output, nullptr);
- _kernel = std::move(k);
- }
- else
- {
- auto k = arm_compute::support::cpp14::make_unique<NEMagnitudePhaseKernel<MagnitudeType::L2NORM, PhaseType::SIGNED>>();
- k->configure(input1, input2, output, nullptr);
- _kernel = std::move(k);
- }
+ auto k = arm_compute::support::cpp14::make_unique<NEMagnitudePhaseKernel<MagnitudeType::L2NORM, PhaseType::SIGNED>>();
+ k->configure(input1, input2, output, nullptr);
+ _kernel = std::move(k);
}
}
diff --git a/src/runtime/NEON/functions/NENormalizationLayer.cpp b/src/runtime/NEON/functions/NENormalizationLayer.cpp
index af98ac1..f00114f 100644
--- a/src/runtime/NEON/functions/NENormalizationLayer.cpp
+++ b/src/runtime/NEON/functions/NENormalizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,7 +41,7 @@
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
- TensorInfo tensor_info(input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
+ TensorInfo tensor_info(input->info()->tensor_shape(), 1, input->info()->data_type());
_input_squared.allocator()->init(tensor_info);
// Manage intermediate buffers
diff --git a/src/runtime/NEON/functions/NERNNLayer.cpp b/src/runtime/NEON/functions/NERNNLayer.cpp
new file mode 100644
index 0000000..995d5ee
--- /dev/null
+++ b/src/runtime/NEON/functions/NERNNLayer.cpp
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/runtime/NEON/functions/NERNNLayer.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+
+namespace arm_compute
+{
+// The memory manager is handed to _memory_group, which manages the intermediate tensors registered in configure().
+NERNNLayer::NERNNLayer(std::shared_ptr<IMemoryManager> memory_manager)
+ : _memory_group(std::move(memory_manager)), _gemm_state_f(), _add_kernel(), _activation_kernel(), _fully_connected_kernel(), _copy_kernel(), _fully_connected_out(), _gemm_output(), _add_output(),
+ _is_prepared(false)
+{
+}
+
+// Checks that the tensor dimensions are mutually consistent, then validates the fully connected,
+// addition and activation stages against a tensor info with the RNN output shape.
+// NOTE(review): the recurrent GEMM configured in configure() (hidden_state x recurrent_weights) is not
+// validated here, and the activation is validated against shape_info rather than hidden_state - confirm.
+Status NERNNLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *recurrent_weights, const ITensorInfo *bias, const ITensorInfo *hidden_state,
+ const ITensorInfo *output, const ActivationLayerInfo &info)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, recurrent_weights, bias, hidden_state, output);
+
+ const int idx_width = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(idx_width) != weights->dimension(idx_width));
+ ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_height) != recurrent_weights->dimension(idx_width));
+ ARM_COMPUTE_RETURN_ERROR_ON(recurrent_weights->dimension(idx_width) != recurrent_weights->dimension(idx_height));
+ ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() != 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(idx_width) != weights->dimension(idx_height));
+ ARM_COMPUTE_RETURN_ERROR_ON(hidden_state->dimension(idx_width) != weights->dimension(idx_height));
+ ARM_COMPUTE_RETURN_ERROR_ON(hidden_state->dimension(idx_height) != input->dimension(idx_height));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), hidden_state->tensor_shape());
+
+ // Tensor info for the intermediate results, shaped from recurrent_weights and the height of hidden_state
+ auto shape_info = TensorInfo(misc::shape_calculator::compute_rnn_shape(recurrent_weights, hidden_state->dimension(idx_height)), 1, input->data_type());
+
+ ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(input, weights, bias, &shape_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&shape_info, &shape_info, &shape_info, ConvertPolicy::SATURATE));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&shape_info, &shape_info, info));
+
+ return Status{};
+}
+
+// Computes hidden_state = activation(FC(input, weights, bias) + hidden_state x recurrent_weights) and copies
+// the new hidden state to output. hidden_state is read by the GEMM and then overwritten by the activation.
+void NERNNLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *recurrent_weights, const ITensor *bias, ITensor *hidden_state, ITensor *output,
+ ActivationLayerInfo &info)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, recurrent_weights, bias, hidden_state, output);
+ ARM_COMPUTE_ERROR_THROW_ON(NERNNLayer::validate(input->info(), weights->info(), recurrent_weights->info(), bias->info(), hidden_state->info(), output->info(), info));
+
+ const int idx_height = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::HEIGHT);
+ TensorShape shape = misc::shape_calculator::compute_rnn_shape(recurrent_weights->info(), hidden_state->info()->dimension(idx_height));
+
+ _is_prepared = false;
+
+ // Initialize the intermediate tensors with the RNN output shape
+ _fully_connected_out.allocator()->init(TensorInfo(shape, 1, input->info()->data_type()));
+ _gemm_output.allocator()->init(TensorInfo(shape, 1, input->info()->data_type()));
+
+ // Manage intermediate buffers and configure
+ _memory_group.manage(&_fully_connected_out);
+ _fully_connected_kernel.configure(input, weights, bias, &_fully_connected_out);
+
+ _memory_group.manage(&_gemm_output);
+ _gemm_state_f.configure(hidden_state, recurrent_weights, nullptr, &_gemm_output, 1.f, 0.f);
+
+ _add_output.allocator()->init(TensorInfo(shape, 1, input->info()->data_type()));
+ _memory_group.manage(&_add_output);
+
+ _add_kernel.configure(&_fully_connected_out, &_gemm_output, &_add_output, ConvertPolicy::SATURATE);
+
+ // Allocate each intermediate only after its last consumer kernel has been configured
+ _fully_connected_out.allocator()->allocate();
+ _gemm_output.allocator()->allocate();
+
+ _activation_kernel.configure(&_add_output, hidden_state, info);
+ _add_output.allocator()->allocate();
+
+ _copy_kernel.configure(hidden_state, output);
+}
+
+// Runs the stages in order: fully connected, recurrent GEMM, addition, activation (which writes hidden_state),
+// then the copy of hidden_state into output.
+void NERNNLayer::run()
+{
+ prepare();
+
+ _memory_group.acquire();
+
+ _fully_connected_kernel.run();
+
+ _gemm_state_f.run();
+
+ NEScheduler::get().schedule(&_add_kernel, Window::DimY);
+ NEScheduler::get().schedule(&_activation_kernel, Window::DimY);
+
+ // Copy the updated hidden state to the output
+ NEScheduler::get().schedule(&_copy_kernel, Window::DimY);
+
+ _memory_group.release();
+}
+
+// One-off preparation of the fully connected layer and the recurrent GEMM (presumably weight reshaping);
+// it runs once, and again only after configure() resets _is_prepared.
+void NERNNLayer::prepare()
+{
+ if(!_is_prepared)
+ {
+ _fully_connected_kernel.prepare();
+ _gemm_state_f.prepare();
+
+ _is_prepared = true;
+ }
+}
+} // namespace arm_compute
diff --git a/src/runtime/NEON/functions/NESimpleAssemblyFunction.cpp b/src/runtime/NEON/functions/NESimpleAssemblyFunction.cpp
new file mode 100644
index 0000000..a4b0dff
--- /dev/null
+++ b/src/runtime/NEON/functions/NESimpleAssemblyFunction.cpp
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
+
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+
+using namespace arm_compute;
+
+// A freshly constructed function owns no kernel; configure() must hand one over before run() is called.
+NESimpleAssemblyFunction::NESimpleAssemblyFunction() // NOLINT
+    : _kernel(nullptr)
+{
+}
+
+// Takes ownership of the GEMM wrapper kernel. The kernel must be non-null and its execution window must not
+// extend beyond dimension 0 - presumably because run() only splits the work along Window::DimX.
+void NESimpleAssemblyFunction::configure(std::unique_ptr<INEGEMMWrapperKernel> kernel)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(kernel.get());
+    _kernel = std::move(kernel);
+    ARM_COMPUTE_ERROR_ON_WINDOW_DIMENSIONS_GTE(_kernel->window(), 1);
+}
+
+// Schedules the owned kernel on the NEON scheduler, split along Window::DimX.
+void NESimpleAssemblyFunction::run()
+{
+    INEGEMMWrapperKernel *const kernel = _kernel.get();
+    NEScheduler::get().schedule(kernel, Window::DimX);
+}
diff --git a/src/runtime/NEON/functions/NESoftmaxLayer.cpp b/src/runtime/NEON/functions/NESoftmaxLayer.cpp
index 4fb8300..3a73f1e 100644
--- a/src/runtime/NEON/functions/NESoftmaxLayer.cpp
+++ b/src/runtime/NEON/functions/NESoftmaxLayer.cpp
@@ -62,6 +62,7 @@
{
// Perform validation step
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() > 2, "Only 2D inputs are supported");
const TensorShape max_shape = TensorShape(input->tensor_shape()).set(0, 1);
const TensorInfo tensor_info_max_sum = TensorInfo(*input).set_tensor_shape(max_shape).reset_padding();
diff --git a/src/runtime/NEON/functions/NEWarpAffine.cpp b/src/runtime/NEON/functions/NEWarpAffine.cpp
index 889d827..105646c 100644
--- a/src/runtime/NEON/functions/NEWarpAffine.cpp
+++ b/src/runtime/NEON/functions/NEWarpAffine.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -32,11 +32,10 @@
using namespace arm_compute;
-void NEWarpAffine::configure(ITensor *input, ITensor *output, const float *matrix, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value)
+void NEWarpAffine::configure(ITensor *input, ITensor *output, const std::array<float, 9> &matrix, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
- ARM_COMPUTE_ERROR_ON(nullptr == matrix);
switch(policy)
{
diff --git a/src/runtime/NEON/functions/NEWarpPerspective.cpp b/src/runtime/NEON/functions/NEWarpPerspective.cpp
index ed5d6a0..80b97ce 100644
--- a/src/runtime/NEON/functions/NEWarpPerspective.cpp
+++ b/src/runtime/NEON/functions/NEWarpPerspective.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -32,11 +32,10 @@
using namespace arm_compute;
-void NEWarpPerspective::configure(ITensor *input, ITensor *output, const float *matrix, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value)
+void NEWarpPerspective::configure(ITensor *input, ITensor *output, const std::array<float, 9> &matrix, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
- ARM_COMPUTE_ERROR_ON(nullptr == matrix);
switch(policy)
{
diff --git a/src/runtime/NEON/functions/NEWidthConcatenateLayer.cpp b/src/runtime/NEON/functions/NEWidthConcatenateLayer.cpp
new file mode 100644
index 0000000..097605c
--- /dev/null
+++ b/src/runtime/NEON/functions/NEWidthConcatenateLayer.cpp
@@ -0,0 +1,113 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/NEON/functions/NEWidthConcatenateLayer.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "arm_compute/runtime/Tensor.h"
+#include "support/ToolchainSupport.h"
+
+using namespace arm_compute;
+
+NEWidthConcatenateLayer::NEWidthConcatenateLayer()
+    : _concat_kernels_vector(),
+      _num_inputs(0)
+{
+}
+
+// Checks that at least two non-null inputs can be concatenated along the width (dimension 0) into output.
+// output may be empty: it is auto-initialized on a copy, so validate() never modifies it.
+Status NEWidthConcatenateLayer::validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
+    ARM_COMPUTE_RETURN_ERROR_ON(inputs_vector.size() < 2);
+
+    // Null-check every input first: the output shape calculation below dereferences all of them
+    for(const ITensorInfo *input : inputs_vector)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
+    }
+
+    // Output auto initialization if not yet initialized
+    TensorInfo  tmp_output_info = *output->clone();
+    TensorShape output_shape    = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(inputs_vector);
+    auto_init_if_empty(tmp_output_info, output_shape, 1, inputs_vector[0]->data_type());
+
+    // width_offset is the running sum of the widths (dimension 0) of the inputs already validated
+    unsigned int width_offset = 0;
+    for(const ITensorInfo *input : inputs_vector)
+    {
+        ARM_COMPUTE_RETURN_ON_ERROR(NEWidthConcatenateLayerKernel::validate(input, width_offset, &tmp_output_info));
+        width_offset += input->dimension(0);
+    }
+
+    return Status{};
+}
+
+// Configures one NEWidthConcatenateLayerKernel per input, each given the running width offset into output.
+void NEWidthConcatenateLayer::configure(std::vector<ITensor *> inputs_vector, ITensor *output)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(output);
+
+    _num_inputs = inputs_vector.size();
+
+    std::vector<ITensorInfo *> inputs_vector_info;
+    inputs_vector_info.reserve(inputs_vector.size());
+    for(ITensor *input : inputs_vector)
+    {
+        ARM_COMPUTE_ERROR_ON_NULLPTR(input);
+        inputs_vector_info.emplace_back(input->info());
+    }
+
+    // Validate before reading inputs_vector[0] or touching output: validate() rejects fewer than two
+    // inputs and only auto-initializes a clone of the output info, so output may still be empty here
+    ARM_COMPUTE_ERROR_THROW_ON(NEWidthConcatenateLayer::validate(inputs_vector_info, output->info()));
+
+    // Output auto initialization if not yet initialized
+    const TensorShape output_shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(inputs_vector);
+    auto_init_if_empty(*output->info(), output_shape, 1, inputs_vector[0]->info()->data_type());
+
+    _concat_kernels_vector = arm_compute::support::cpp14::make_unique<NEWidthConcatenateLayerKernel[]>(_num_inputs);
+
+    // width_offset is the running sum of the widths (dimension 0) of the inputs already configured
+    unsigned int width_offset = 0;
+    for(unsigned int i = 0; i < _num_inputs; ++i)
+    {
+        _concat_kernels_vector[i].configure(inputs_vector.at(i), width_offset, output);
+        width_offset += inputs_vector.at(i)->info()->dimension(0);
+    }
+}
+
+// Schedules each per-input concatenation kernel in turn, split along Window::DimY.
+void NEWidthConcatenateLayer::run()
+{
+    for(unsigned i = 0; i < _num_inputs; i++)
+    {
+        NEScheduler::get().schedule(_concat_kernels_vector.get() + i, Window::DimY);
+    }
+}
diff --git a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
index 8f2c4c4..828a593 100644
--- a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
@@ -24,16 +24,15 @@
#include "arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h"
#include "arm_compute/core/Error.h"
+#include "arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/runtime/NEON/AssemblyHelper.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
#include "support/ToolchainSupport.h"
-#include "arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h"
-
#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
namespace arm_compute
@@ -60,7 +59,6 @@
ARM_COMPUTE_UNUSED(output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
- ARM_COMPUTE_RETURN_ERROR_ON(data_layout != DataLayout::NCHW); // COMPMID-1162
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(width_idx) != 3 && weights->dimension(height_idx) != 5, "Only 3 and 5 kernels are supported");
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -107,12 +105,13 @@
return std::find(fast_math_winograd.begin(), fast_math_winograd.end(), p) != fast_math_winograd.end();
}
+
} //namespace
NEWinogradConvolutionLayer::NEWinogradConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _arm_gemm(nullptr), _gemm_kernel(nullptr), _transform_input_kernel(nullptr), _transform_output_kernel(nullptr), _transform_weights_kernel(nullptr),
- _activationlayer_function(), _permute_input(), _permute_weights(), _permute_output(), _input_workspace(), _output_workspace(), _kernel_storage(), _input_nhwc(), _output_nhwc(), _weights_hwio(),
- _workspace(), _input(), _weights(), _output(), _reshaped_kernel(false), _is_activationlayer_enabled(false)
+ : _memory_group(memory_manager), _asm_glue(memory_manager), _transform_input_kernel(nullptr), _transform_output_kernel(nullptr), _transform_weights_kernel(nullptr), _activationlayer_function(), // memory_manager is handed to both the memory group and the assembly GEMM glue, so it is copied rather than moved
+ _permute_input(), _permute_weights(), _permute_output(), _input_workspace(), _output_workspace(), _kernel_storage(), _input_nhwc(), _output_nhwc(), _weights_hwio(), _input(), _weights(), _output(),
+ _is_prepared(false), _is_activationlayer_enabled(false) // _is_prepared: the weights are permuted/transformed lazily by prepare()
{
} /* arm_compute */
@@ -138,9 +137,10 @@
ARM_COMPUTE_ERROR_ON_MSG(check_support_fast_math(output_tile, kernel_size), "This Winograd configuration requires enable_fast_math=true");
}
- _weights = weights;
- _input = input;
- _output = output;
+ _weights = weights;
+ _input = input;
+ _output = output;
+ _is_prepared = false;
std::unique_ptr<INEWinogradLayerTransformInputKernel<float>> transform_input_kernel;
std::unique_ptr<INEWinogradLayerTransformWeightsKernel<float>> transform_weights_kernel;
@@ -155,29 +155,32 @@
{
if(input->info()->dimension(width_idx) > 4 && input->info()->dimension(height_idx) > 4)
{
- transform_input_kernel = support::cpp14::make_unique<NEWinogradLayerTransformInputKernel<float, 4, 4, 3, 3>>();
- transform_weights_kernel = support::cpp14::make_unique<NEWinogradLayerTransformWeightsKernel<float, 4, 4, 3, 3>>();
- transform_output_kernel = support::cpp14::make_unique<NEWinogradLayerTransformOutputKernel<float, 4, 4, 3, 3>>();
- n_gemms = NEWinogradLayerBatchedGEMMKernel<float, float, 4, 4, 3, 3>::WinogradBase::N_GEMMS;
- N_BLOCK = NEWinogradLayerBatchedGEMMKernel<float, float, 4, 4, 3, 3>::WinogradConv::N_BLOCK;
+ using config = NEWinogradLayerConfiguration<float, float, 4, 4, 3, 3>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
}
else
{
- transform_input_kernel = support::cpp14::make_unique<NEWinogradLayerTransformInputKernel<float, 2, 2, 3, 3>>();
- transform_weights_kernel = support::cpp14::make_unique<NEWinogradLayerTransformWeightsKernel<float, 2, 2, 3, 3>>();
- transform_output_kernel = support::cpp14::make_unique<NEWinogradLayerTransformOutputKernel<float, 2, 2, 3, 3>>();
- n_gemms = NEWinogradLayerBatchedGEMMKernel<float, float, 2, 2, 3, 3>::WinogradBase::N_GEMMS;
- N_BLOCK = NEWinogradLayerBatchedGEMMKernel<float, float, 2, 2, 3, 3>::WinogradConv::N_BLOCK;
+ using config = NEWinogradLayerConfiguration<float, float, 2, 2, 3, 3>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
}
break;
}
case 5:
{
- transform_input_kernel = support::cpp14::make_unique<NEWinogradLayerTransformInputKernel<float, 2, 2, 5, 5>>();
- transform_weights_kernel = support::cpp14::make_unique<NEWinogradLayerTransformWeightsKernel<float, 2, 2, 5, 5>>();
- transform_output_kernel = support::cpp14::make_unique<NEWinogradLayerTransformOutputKernel<float, 2, 2, 5, 5>>();
- n_gemms = NEWinogradLayerBatchedGEMMKernel<float, float, 2, 2, 5, 5>::WinogradBase::N_GEMMS;
- N_BLOCK = NEWinogradLayerBatchedGEMMKernel<float, float, 2, 2, 5, 5>::WinogradConv::N_BLOCK;
+ using config = NEWinogradLayerConfiguration<float, float, 2, 2, 5, 5>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
break;
}
default:
@@ -195,96 +198,138 @@
const int out_channels = output->info()->dimension(channel_idx);
const Tensor4DShape in_shape(internal_get_input_shape(input));
+ const DataType data_type = input->info()->data_type();
const size_t data_type_size = input->info()->element_size();
// Get the memory required to instantiate a new Winograd operator.
- constexpr size_t storage_alignment = 64;
- const size_t kernel_storage_size = transform_weights_kernel->get_weight_storage_size(out_channels, in_channels) * data_type_size;
- _kernel_storage.allocator()->init(TensorInfo(TensorShape{ (kernel_storage_size + storage_alignment - 1) }, 1, DataType::U8));
- _kernel_storage.allocator()->allocate();
+ constexpr size_t storage_alignment = 64;
+
+ // Kernel Storage
+ const size_t kernel_storage_size = transform_weights_kernel->get_weight_storage_size(out_channels,
+ in_channels)
+ * data_type_size
+ + storage_alignment - 1;
+
// Input storage
- const size_t input_storage_size = transform_input_kernel->get_input_storage_size(in_shape.n_batches, in_shape.n_channels, in_shape.n_rows, in_shape.n_cols, use_same_padding) * data_type_size;
- _input_workspace.allocator()->init(TensorInfo(TensorShape{ (input_storage_size + storage_alignment - 1) }, 1, DataType::U8));
- _input_workspace.allocator()->allocate();
+ const size_t input_storage_size = transform_input_kernel->get_input_storage_size(in_shape.n_batches, in_shape.n_channels, in_shape.n_rows, in_shape.n_cols,
+ use_same_padding)
+ * data_type_size
+ + storage_alignment - 1;
// Output storage
- const size_t output_storage_size = transform_output_kernel->get_output_storage_size(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels, use_same_padding) * data_type_size;
- _output_workspace.allocator()->init(TensorInfo(TensorShape{ (output_storage_size + storage_alignment - 1) }, 1, DataType::U8));
- _output_workspace.allocator()->allocate();
+ const size_t output_storage_size = transform_output_kernel->get_output_storage_size(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels,
+ use_same_padding)
+ * data_type_size
+ + storage_alignment - 1;
+ ;
+ const KernelShape kernel_shape({ out_channels, static_cast<int>(kernel_size.height), static_cast<int>(kernel_size.width), in_channels });
+ const int kernel_matrix_stride = transform_weights_kernel->get_matrix_stride(kernel_shape);
+
+ const int output_matrix_stride = transform_output_kernel->get_matrix_stride(kernel_shape, in_shape, use_padding_type);
+ const auto output_shape(transform_output_kernel->get_output_shape(kernel_shape, in_shape, use_padding_type));
+
+ const int input_matrix_stride = transform_input_kernel->get_matrix_stride(kernel_shape, in_shape, use_padding_type);
+
+ // Configure GEMM
+ const int tile_rows = iceildiv(output_shape.n_rows, output_tile.height);
+ const int tile_cols = iceildiv(output_shape.n_cols, output_tile.width);
+ const int m = in_shape.n_batches * tile_rows * tile_cols;
+ const int k = in_shape.n_channels;
+ const int n = out_channels;
+ const int kernel_matrix_row_stride = roundup(out_channels, N_BLOCK);
+ const int output_matrix_row_stride = kernel_matrix_row_stride;
+
+ TensorShape a_shape(k, m, 1, n_gemms);
+ Strides a_strides(data_type_size);
+ a_strides.set(1, a_strides[0] * k);
+ a_strides.set(2, 0);
+ a_strides.set(3, data_type_size * input_matrix_stride);
+
+ TensorShape b_shape(n, k, n_gemms);
+ Strides b_strides(data_type_size);
+ b_strides.set(1, data_type_size * kernel_matrix_row_stride);
+ b_strides.set(2, data_type_size * kernel_matrix_stride);
+
+ TensorShape d_shape(n, m, 1, n_gemms);
+ Strides d_strides(data_type_size);
+ d_strides.set(1, data_type_size * output_matrix_row_stride);
+ d_strides.set(2, 0);
+ d_strides.set(3, data_type_size * output_matrix_stride);
+
+ TensorInfo a_info, b_info, d_info;
+ a_info.init(a_shape, 1, data_type, a_strides, 0, input_storage_size);
+ b_info.init(b_shape, 1, data_type, b_strides, 0, kernel_storage_size);
+ d_info.init(d_shape, 1, data_type, d_strides, 0, output_storage_size);
+
+ _input_workspace.allocator()->init(a_info, storage_alignment);
+ _kernel_storage.allocator()->init(b_info, storage_alignment);
+ _output_workspace.allocator()->init(d_info, storage_alignment);
// configure and allocate dst tensor to be used to convert from winograd domain to spatial domain when calling to reshape_output()
TensorInfo info(TensorShape(_output->info()->dimension(2), _output->info()->dimension(0),
_output->info()->dimension(1), _output->info()->dimension(3)),
1, _output->info()->data_type());
_output_nhwc.allocator()->init(info);
- _output_nhwc.allocator()->allocate();
-
- // Re-order a weight tensor from [Output feature map x Input feature map x Height x Width] to [Height x Width x Input feature map x Output feature map]
- _permute_weights.configure(weights, &_weights_hwio, PermutationVector(3U, 2U, 0U, 1U));
- _weights_hwio.allocator()->allocate();
-
- // configure the kernel to transform the input tensor from NCHW -> NHWC
- _permute_input.configure(input, &_input_nhwc, PermutationVector(2U, 0U, 1U));
- _input_nhwc.allocator()->allocate();
-
- const KernelShape kernel_shape({ out_channels, static_cast<int>(kernel_size.height), static_cast<int>(kernel_size.width), in_channels });
// Configure the InputTransform
- const int input_matrix_stride = transform_input_kernel->get_matrix_stride(kernel_shape, in_shape, use_padding_type);
- transform_input_kernel->configure(reinterpret_cast<float *>(_input_nhwc.buffer()), in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, in_shape.n_channels, use_padding_type,
- reinterpret_cast<float *>(_input_workspace.buffer()), input_matrix_stride);
+ _memory_group.manage(&_input_workspace);
+ if(data_layout == DataLayout::NCHW)
+ {
+ // configure the kernel to transform the input tensor from NCHW -> NHWC
+ _permute_input.configure(input, &_input_nhwc, PermutationVector(2U, 0U, 1U));
+ _input_nhwc.allocator()->allocate();
+ transform_input_kernel->configure(&_input_nhwc, in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, in_shape.n_channels, use_padding_type,
+ &_input_workspace, input_matrix_stride);
+ }
+ else
+ {
+ transform_input_kernel->configure(_input, in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, in_shape.n_channels, use_padding_type,
+ &_input_workspace, input_matrix_stride);
+ }
// Configure WeightsTransform
- const int kernel_matrix_stride = transform_weights_kernel->get_matrix_stride(kernel_shape);
- transform_weights_kernel->configure(&_weights_hwio, reinterpret_cast<float *>(_kernel_storage.buffer()), kernel_matrix_stride, out_channels, in_channels);
+ if(data_layout == DataLayout::NCHW)
+ {
+ // Re-order a weight tensor from [Output feature map x Input feature map x Height x Width] to [Height x Width x Input feature map x Output feature map]
+ _permute_weights.configure(weights, &_weights_hwio, PermutationVector(3U, 2U, 0U, 1U));
+
+ transform_weights_kernel->configure(&_weights_hwio, &_kernel_storage, kernel_matrix_stride, out_channels, in_channels);
+ }
+ else
+ {
+ // Re-order a weight tensor from [Output feature map x Input feature map x Height x Width] to [Height x Width x Input feature map x Output feature map]
+ _permute_weights.configure(weights, &_weights_hwio, PermutationVector(3U, 0U, 1U, 2U));
+
+ transform_weights_kernel->configure(&_weights_hwio, &_kernel_storage, kernel_matrix_stride, out_channels, in_channels);
+ }
+ _weights_hwio.allocator()->allocate();
// Configure OutputTransform
//The biases tensor has not been allocated at this point in time, the output transform will add the biases to the final result in the run() method
- const int output_matrix_stride = transform_output_kernel->get_matrix_stride(kernel_shape, in_shape, use_padding_type);
- const auto output_shape(transform_output_kernel->get_output_shape(kernel_shape, in_shape, use_padding_type));
- transform_output_kernel->configure(biases, reinterpret_cast<float *>(_output_workspace.buffer()),
- output_matrix_stride, reinterpret_cast<float *>(_output_nhwc.buffer()),
- in_shape.n_batches, output_shape.n_rows, output_shape.n_cols, out_channels);
-
- // Configure GEMM
- const int tile_rows = iceildiv(output_shape.n_rows, output_tile.height);
- const int tile_cols = iceildiv(output_shape.n_cols, output_tile.width);
- const int m = in_shape.n_batches * tile_rows * tile_cols;
- const int k = in_shape.n_channels;
- const int n = out_channels;
- const int input_matrix_row_stride = in_shape.n_channels;
- const int kernel_matrix_row_stride = roundup(out_channels, N_BLOCK);
- const int output_matrix_row_stride = kernel_matrix_row_stride;
- unsigned int num_threads = NEScheduler::get().num_threads();
-
- _arm_gemm = arm_gemm::gemm<float, float>(NEScheduler::get().cpu_info(), m, n, k, 1, n_gemms, false, false, 1.f, 0.f, num_threads, false);
- _arm_gemm->set_arrays(reinterpret_cast<float *>(_input_workspace.buffer()), input_matrix_row_stride, 0, input_matrix_stride, reinterpret_cast<float *>(_kernel_storage.buffer()),
- kernel_matrix_row_stride, kernel_matrix_stride, reinterpret_cast<float *>(_output_workspace.buffer()), output_matrix_row_stride, 0, output_matrix_stride);
-
- auto acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapper<arm_gemm::GemmCommon<float, float>>>();
- acl_gemm_wrapper->configure(_arm_gemm.get());
- const size_t workspace_size = _arm_gemm->get_working_size();
-
- // Allocate workspace
- if(workspace_size > 0)
+ _memory_group.manage(&_output_workspace);
+ if(data_layout == DataLayout::NCHW)
{
- const unsigned int alignment = 4096;
- allocate_workspace(workspace_size, _workspace, &_memory_group, alignment, 1);
- _arm_gemm->set_working_space(reinterpret_cast<float *>(_workspace.buffer()));
+ transform_output_kernel->configure(biases, &_output_workspace,
+ output_matrix_stride, &_output_nhwc,
+ in_shape.n_batches, output_shape.n_rows, output_shape.n_cols, out_channels);
+ }
+ else
+ {
+ transform_output_kernel->configure(biases, &_output_workspace,
+ output_matrix_stride, _output,
+ in_shape.n_batches, output_shape.n_rows, output_shape.n_cols, out_channels);
}
- const unsigned int window_size = _arm_gemm->get_window_size();
- if(window_size < num_threads)
- {
- num_threads = window_size;
- _arm_gemm->set_nthreads(num_threads);
- }
-
- _gemm_kernel = std::move(acl_gemm_wrapper);
+ _asm_glue.configure(&_input_workspace, &_kernel_storage, &_output_workspace, 1.0f, 0.f, false);
+ _input_workspace.allocator()->allocate();
+ _kernel_storage.allocator()->allocate();
+ _output_workspace.allocator()->allocate();
// Reorder the convoluted output to ACL's ordering NCHW
_permute_output.configure(&_output_nhwc, _output, PermutationVector(1U, 2U, 0U));
+ _output_nhwc.allocator()->allocate();
+
_transform_input_kernel = std::move(transform_input_kernel);
_transform_weights_kernel = std::move(transform_weights_kernel);
_transform_output_kernel = std::move(transform_output_kernel);
@@ -293,38 +338,43 @@
_is_activationlayer_enabled = act_info.enabled();
if(_is_activationlayer_enabled)
{
- _activationlayer_function.configure(output, nullptr, act_info);
+ _activationlayer_function.configure(_output, nullptr, act_info);
}
}
void NEWinogradConvolutionLayer::run()
{
- _memory_group.acquire();
- if(!_reshaped_kernel)
- {
- _reshaped_kernel = true;
- _permute_weights.run();
- NEScheduler::get().schedule(_transform_weights_kernel.get(), Window::DimX);
- }
- //Bring channels to the front as Winograd code expects the tensor to be in the format NHWC
- _permute_input.run();
+ const DataLayout data_layout = _input->info()->data_layout(); // NCHW tensors are permuted to/from NHWC around the Winograd stages below
+ prepare(); // the weights are permuted and transformed on the first run only (see prepare())
+ // Tensors registered with _memory_group (e.g. the Winograd input/output workspaces) get their backing memory between acquire() and release()
+ _memory_group.acquire();
+
+ if(data_layout == DataLayout::NCHW)
+ {
+ //Bring channels to the front as Winograd code expects the tensor to be in the format NHWC
+ _permute_input.run();
+ }
// Transform input tensor to the winograd domain
NEScheduler::get().schedule(_transform_input_kernel.get(), Window::DimX);
//Run 16 GEMMs in multiple threads, each kernel runs one or more GEMMs
- NEScheduler::get().schedule(_gemm_kernel.get(), Window::DimX);
+ _asm_glue.run(); // batched GEMMs: input workspace x transformed weights -> output workspace; the GEMM count depends on the tile/kernel configuration, not always 16
// Transform output tensor to the spatial domain
NEScheduler::get().schedule(_transform_output_kernel.get(), Window::DimX);
- // Reorder the convoluted output to ACL's ordering NCHW
- _permute_output.run();
+ if(data_layout == DataLayout::NCHW)
+ {
+ // Reorder the convolved output to ACL's ordering NCHW
+ _permute_output.run();
+ }
if(_is_activationlayer_enabled)
{
_activationlayer_function.run();
}
+
_memory_group.release();
}
@@ -358,6 +408,7 @@
// Validate input transform
const TensorShape input0_shape = misc::shape_calculator::compute_winograd_input_transform_shape(*input, winograd_info);
const TensorInfo input0 = input->clone()->set_tensor_shape(input0_shape);
+
switch(weights->dimension(idx_width))
{
case 3:
@@ -444,7 +495,6 @@
break;
}
}
-
// Validate Activation Layer
if(act_info.enabled())
{
@@ -453,4 +503,20 @@
return Status{};
}
+void NEWinogradConvolutionLayer::prepare()
+{
+    if(_is_prepared)
+    {
+        return;
+    }
+
+    // Reorder the weights into _weights_hwio, then flag the original weights tensor as unused
+    _permute_weights.run();
+    _weights->mark_as_unused();
+    // Transform the reordered weights into the Winograd domain, then release the intermediate buffer
+    NEScheduler::get().schedule(_transform_weights_kernel.get(), Window::DimX);
+    _weights_hwio.allocator()->free();
+    _is_prepared = true;
+}
+
} // namespace arm_compute
diff --git a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
new file mode 100644
index 0000000..b52ce66
--- /dev/null
+++ b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
@@ -0,0 +1,260 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h"
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+
+namespace arm_compute
+{
+NEGEMMInterleavedWrapper::NEGEMMInterleavedWrapper(std::shared_ptr<IMemoryManager> memory_manager)
+ : _memory_group(std::move(memory_manager)) // NOTE(review): _is_prepared and the other members are not set here; presumably they have in-class initializers in the header - confirm
+{
+}
+void NEGEMMInterleavedWrapper::run()
+{
+    prepare(); // builds the workloads (and pretransposes B if requested) on the first call, no-op afterwards
+    _memory_group.acquire();
+    auto &scheduler = NEScheduler::get();
+    scheduler.run_workloads(_workloads);
+    _memory_group.release();
+}
+
+void NEGEMMInterleavedWrapper::prepare()
+{
+    if(!_is_prepared)
+    {
+        if(_pretranspose_b)
+        {
+            // B is transformed once, up-front; the original B tensor is then marked as unused.
+            NEScheduler::get().schedule(_prepare_b.get(), Window::DimX);
+            _b->mark_as_unused();
+        }
+        else
+        {
+            // NOTE(review): _b_workloads is filled here but is not used by the workloads built below - confirm how B gets transformed on this path.
+            _prepare_b->create_workloads(_b_workloads);
+        }
+        _transform_a->create_workloads(_a_workloads);
+        _matrix_multiply->create_workloads(_mm_workloads);
+
+        // At most one workload per scheduler thread, and never more than the batch window has iterations:
+        const unsigned int num_threads    = NEScheduler::get().num_threads();
+        const unsigned int num_iterations = _batch_window.num_iterations_total();
+        const unsigned int num_windows    = std::min(num_iterations, num_threads);
+        const TensorShape  window_shape   = _batch_window.shape();
+        // Number of x blocks along N: identical for every sub-window, so computed once outside the loop.
+        const unsigned int num_x_blocks = _block_walker.num_iterations(Window::DimX);
+
+        // Create a 1D window to dynamically split the batch window:
+        Window win_1D;
+        win_1D.set(0, Window::Dimension(0, num_iterations));
+
+        // Create one workload for each sub-window:
+        for(unsigned int w = 0; w < num_windows; w++)
+        {
+            Window            win          = win_1D.split_window(0, w, num_windows);
+            const Coordinates start_offset = index2coords(window_shape, win.x().start());
+            const Coordinates end_offset   = index2coords(window_shape, win.x().end() - 1);
+
+            auto workload = [start_offset, end_offset, num_x_blocks, this](const ThreadInfo & info)
+            {
+                //For each block of rows in "M"
+                auto workload_mm = this->_mm_workloads.begin();
+                for(auto workload_a = this->_a_workloads.begin(); workload_a != this->_a_workloads.end(); ++workload_a)
+                {
+                    // Transform one k_block from A:
+                    this->_transform_a->transform(*workload_a, info, this->_batch_window, start_offset, end_offset);
+                    // Then perform the matrix multiplication for each x block along N:
+                    for(unsigned int i = 0; i < num_x_blocks; i++)
+                    {
+                        ARM_COMPUTE_ERROR_ON(workload_mm == this->_mm_workloads.end());
+                        this->_matrix_multiply->transform(*workload_mm++, info, this->_batch_window, start_offset, end_offset);
+                    }
+                }
+            };
+            _workloads.push_back(std::move(workload)); // move the closure (and its captured Coordinates) instead of copying it
+        }
+
+        _is_prepared = true;
+    }
+}
+
+namespace
+{
+// Factory to instantiate NEGEMMInterleavedPrepareBWrapperKernel:
+template <typename InputType, bool use_dot = false>
+std::unique_ptr<NEGEMMInterleavedPrepareBWrapperKernel> instantiate_prepareB(const ITensor *b, ITensor *transformed_b, const INEGEMMWrapperKernel::Params &params)
+{
+    auto prepare_b = support::cpp14::make_unique<NEGEMMInterleavedPrepareBWrapperKernelTemplate<InputType, use_dot>>();
+    prepare_b->configure(b, transformed_b, false, NEScheduler::get().cpu_info(), params);
+    return std::move(prepare_b); // explicit move: Derived -> Base unique_ptr conversion on return needs it on pre-CWG1579 (C++11) compilers
+}
+
+// Factory to instantiate NEGEMMInterleavedTransformAWrapperTemplate:
+template <typename InputType, bool use_dot = false>
+std::unique_ptr<NEGEMMInterleavedTransformAWrapper> instantiate_transformA(const ITensor *a, ITensor *transformed_a, const Window &block_walker, const INEGEMMWrapperKernel::Params &params)
+{
+    auto transform_a = support::cpp14::make_unique<NEGEMMInterleavedTransformAWrapperTemplate<InputType, use_dot>>();
+    transform_a->configure(a, transformed_a, false, block_walker, params);
+    return std::move(transform_a);
+}
+
+// Factory to instantiate NEGEMMInterleavedMatrixMultiplyWrapperTemplate:
+template <typename InputType, typename OutputType, bool use_dot = false>
+std::unique_ptr<NEGEMMInterleavedMatrixMultiplyWrapper> instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker,
+                                                                                    const BlockSizes &block_sizes, const INEGEMMWrapperKernel::Params &params, bool pretranspose_b, float alpha, float beta)
+{
+    auto matrix_multiply = support::cpp14::make_unique<NEGEMMInterleavedMatrixMultiplyWrapperTemplate<InputType, OutputType, use_dot>>();
+    matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, pretranspose_b, alpha, beta, NEScheduler::get().num_threads());
+    return std::move(matrix_multiply);
+}
+} // namespace
+
+void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b, bool use_dot)
+{
+    _params         = INEGEMMWrapperKernel::extract_parameters(a, b, c);
+    _a              = a;
+    _b              = b;
+    _c              = c;
+    _pretranspose_b = pretranspose_b;
+    // The input data type (and, on AArch64 integer types, use_dot) selects the kernel instantiations below.
+    const DataType input_type = a->info()->data_type();
+
+    // Forcing 128-byte alignment (required by 32-bit kernels)
+    const unsigned int alignment = 128;
+    _transformed_b.allocator()->init(TensorInfo{}, alignment);
+    _tmp_c.allocator()->init(TensorInfo{}, alignment);
+    if(!_pretranspose_b)
+    {
+        // If B is transposed at every iteration then transformed_B can be managed:
+        _memory_group.manage(&_transformed_b);
+    }
+    switch(input_type)
+    {
+        case DataType::F32:
+            _prepare_b = instantiate_prepareB<float>(_b, &_transformed_b, _params);
+            break;
+#ifdef __aarch64__
+        case DataType::U8:
+        case DataType::QASYMM8:
+            if(use_dot)
+            {
+                _prepare_b = instantiate_prepareB<uint8_t, true>(_b, &_transformed_b, _params);
+            }
+            else
+            {
+                _prepare_b = instantiate_prepareB<uint8_t, false>(_b, &_transformed_b, _params);
+            }
+            break;
+        case DataType::S8:
+            if(use_dot)
+            {
+                _prepare_b = instantiate_prepareB<int8_t, true>(_b, &_transformed_b, _params);
+            }
+            else
+            {
+                _prepare_b = instantiate_prepareB<int8_t, false>(_b, &_transformed_b, _params);
+            }
+            break;
+#endif /* __aarch64__ */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        case DataType::F16:
+            _prepare_b = instantiate_prepareB<__fp16>(_b, &_transformed_b, _params);
+            break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+        default:
+            ARM_COMPUTE_ERROR("DataType not supported");
+            break;
+    }
+    ARM_COMPUTE_ERROR_ON(_prepare_b == nullptr);
+    // The block sizes chosen by the B-preparation kernel drive the blocking of A, B and C below:
+    _block_sizes = _prepare_b->block_sizes();
+
+    _block_walker.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_params.N, _block_sizes.x_block), _block_sizes.x_block));
+    _block_walker.set(Window::DimY, Window::Dimension(0, ceil_to_multiple(_params.K, _block_sizes.k_block), _block_sizes.k_block));
+    _block_walker.set(Window::DimZ, Window::Dimension(0, _params.multis));
+
+    _batch_window.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_block_sizes.m_round, _block_sizes.strategy_out_height), _block_sizes.strategy_out_height));
+    _batch_window.set(Window::DimY, Window::Dimension(0, _params.batches));
+
+    _transformed_a.allocator()->init(TensorInfo(TensorShape{ _block_sizes.k_block, _block_sizes.m_round, _params.batches }, 1, input_type), alignment);
+    _memory_group.manage(&_transformed_a);
+    _memory_group.manage(&_tmp_c);
+    // Instantiate the A-transform and matrix-multiply kernels matching the B-preparation kernel selected above:
+    switch(input_type)
+    {
+        case DataType::F32:
+            _transform_a     = instantiate_transformA<float>(_a, &_transformed_a, _block_walker, _params);
+            _matrix_multiply = instantiate_matrix_multiply<float, float>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
+            break;
+#ifdef __aarch64__
+        case DataType::U8:
+        case DataType::QASYMM8:
+            if(use_dot)
+            {
+                _transform_a     = instantiate_transformA<uint8_t, true>(_a, &_transformed_a, _block_walker, _params);
+                _matrix_multiply = instantiate_matrix_multiply<uint8_t, uint32_t, true>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
+            }
+            else
+            {
+                _transform_a     = instantiate_transformA<uint8_t, false>(_a, &_transformed_a, _block_walker, _params);
+                _matrix_multiply = instantiate_matrix_multiply<uint8_t, uint32_t, false>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
+            }
+            break;
+        case DataType::S8:
+            if(use_dot)
+            {
+                _transform_a     = instantiate_transformA<int8_t, true>(_a, &_transformed_a, _block_walker, _params);
+                _matrix_multiply = instantiate_matrix_multiply<int8_t, int32_t, true>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
+            }
+            else
+            {
+                _transform_a     = instantiate_transformA<int8_t, false>(_a, &_transformed_a, _block_walker, _params);
+                _matrix_multiply = instantiate_matrix_multiply<int8_t, int32_t, false>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
+            }
+            break;
+#endif /* __aarch64__ */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        case DataType::F16:
+            _transform_a     = instantiate_transformA<__fp16>(_a, &_transformed_a, _block_walker, _params);
+            _matrix_multiply = instantiate_matrix_multiply<__fp16, __fp16>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
+            break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+        default:
+            ARM_COMPUTE_ERROR("DataType not supported"); // must stay in sync with the first switch; never leave null kernels behind in release builds
+            break;
+    }
+    ARM_COMPUTE_ERROR_ON(_transform_a == nullptr);
+    ARM_COMPUTE_ERROR_ON(_matrix_multiply == nullptr);
+    _transformed_a.allocator()->allocate();
+    _tmp_c.allocator()->allocate();
+    _transformed_b.allocator()->allocate();
+}
+} // namespace arm_compute
diff --git a/src/runtime/OMP/OMPScheduler.cpp b/src/runtime/OMP/OMPScheduler.cpp
index 795c96c..f4253c8 100644
--- a/src/runtime/OMP/OMPScheduler.cpp
+++ b/src/runtime/OMP/OMPScheduler.cpp
@@ -56,29 +56,55 @@
_num_threads = (num_threads == 0) ? num_cores : num_threads;
}
-void OMPScheduler::schedule(ICPPKernel *kernel, unsigned int split_dimension)
+void OMPScheduler::schedule(ICPPKernel *kernel, const Hints &hints)
{
ARM_COMPUTE_ERROR_ON_MSG(!kernel, "The child class didn't set the kernel");
-
- ThreadInfo info;
- info.cpu_info = &_cpu_info;
+ ARM_COMPUTE_ERROR_ON_MSG(hints.strategy() == StrategyHint::DYNAMIC, // only static splitting of the kernel window is implemented here
+ "Dynamic scheduling is not supported in OMPScheduler");
const Window &max_window = kernel->window();
- const unsigned int num_iterations = max_window.num_iterations(split_dimension);
- info.num_threads = std::min(num_iterations, _num_threads);
+ const unsigned int num_iterations = max_window.num_iterations(hints.split_dimension());
+ const unsigned int num_threads = std::min(num_iterations, _num_threads); // never use more threads than there are iterations to split
- if(!kernel->is_parallelisable() || info.num_threads == 1)
+ if(!kernel->is_parallelisable() || num_threads == 1)
{
+ ThreadInfo info;
+ info.cpu_info = &_cpu_info;
kernel->run(max_window, info);
}
else
{
- #pragma omp parallel firstprivate(info) num_threads(info.num_threads)
+ const unsigned int num_windows = num_threads; // one sub-window, and therefore one workload, per thread
+ std::vector<IScheduler::Workload> workloads(num_windows);
+ for(unsigned int t = 0; t < num_windows; t++)
{
- const int tid = omp_get_thread_num();
- Window win = max_window.split_window(split_dimension, tid, info.num_threads);
- info.thread_id = tid;
- kernel->run(win, info);
+ //Capture 't' by copy, all the other variables by reference (safe: run_workloads() only returns once its OpenMP parallel region has completed):
+ workloads[t] = [t, &hints, &max_window, &num_windows, &kernel](const ThreadInfo & info)
+ {
+ Window win = max_window.split_window(hints.split_dimension(), t, num_windows);
+ win.validate(); // check the split sub-window before running the kernel on it
+ kernel->run(win, info);
+ };
}
+ run_workloads(workloads);
+ }
+}
+
+// Runs every entry of `workloads`, distributing them over at most
+// _num_threads OpenMP threads.
+//
+// Each workload receives a ThreadInfo whose cpu_info points at this
+// scheduler's _cpu_info, whose num_threads is min(_num_threads,
+// workloads.size()) -- the requested team size, which the OpenMP runtime
+// may reduce -- and whose thread_id is omp_get_thread_num() of the thread
+// executing it.
+//
+// A parallel loop over the workloads is used, rather than indexing the
+// vector by thread id, so that every workload is executed even when
+// workloads.size() exceeds _num_threads or the runtime provides a smaller
+// team than requested. Returns immediately when there is nothing to run.
+void OMPScheduler::run_workloads(std::vector<arm_compute::IScheduler::Workload> &workloads)
+{
+ const unsigned int amount_of_work = static_cast<unsigned int>(workloads.size());
+ const unsigned int num_threads = std::min(_num_threads, amount_of_work);
+ if(num_threads < 1)
+ {
+ return;
+ }
+
+ ThreadInfo info;
+ info.cpu_info = &_cpu_info;
+ info.num_threads = num_threads;
+ // firstprivate: each thread works on its own copy of `info`, so setting
+ // thread_id below does not race. schedule(static, 1) deals the workloads
+ // out round-robin across the team; the implicit barrier at the end of the
+ // construct guarantees all workloads have finished when this returns.
+ #pragma omp parallel for firstprivate(info) num_threads(num_threads) schedule(static, 1)
+ for(unsigned int wid = 0; wid < amount_of_work; ++wid)
+ {
+ info.thread_id = omp_get_thread_num();
+ workloads[wid](info);
}
}
diff --git a/src/runtime/TensorAllocator.cpp b/src/runtime/TensorAllocator.cpp
index 993a95b..c84a271 100644
--- a/src/runtime/TensorAllocator.cpp
+++ b/src/runtime/TensorAllocator.cpp
@@ -138,7 +138,7 @@
if(_associated_memory_group == nullptr)
{
- _memory = Memory(std::make_shared<MemoryRegion>(info().total_size()));
+ _memory = Memory(std::make_shared<MemoryRegion>(info().total_size(), alignment()));
}
else
{