Prepare xnn_params for variations in fused activations
- Define structures for a set of micro-kernels with the same fused activation
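- Move existing micro-kernels into .minmax member
- Rename depthwise-convolution parameter fields to describe their role: mr -> primary_tile, cr -> channel_tile, qr -> incremental_tile (the unipass function pointer `up` becomes minmax.unipass)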
PiperOrigin-RevId: 305678097
diff --git a/src/convolution-nhwc.c b/src/convolution-nhwc.c
index e0d73b3..f8c73ce 100644
--- a/src/convolution-nhwc.c
+++ b/src/convolution-nhwc.c
@@ -50,7 +50,7 @@
size_t num_ukernels)
{
while (num_ukernels-- != 0) {
- if (ukernel->mr == kernel_size) {
+ if (ukernel->primary_tile == kernel_size) {
return ukernel;
}
ukernel++;
@@ -246,9 +246,9 @@
case xnn_ukernel_type_dwconv:
{
assert(dwconv_parameters != NULL);
- assert(dwconv_parameters->mr == kernel_size);
+ assert(dwconv_parameters->primary_tile == kernel_size);
- const uint32_t c_stride = round_up_po2(groups, dwconv_parameters->cr);
+ const uint32_t c_stride = round_up_po2(groups, dwconv_parameters->channel_tile);
const size_t packed_weights_size = (sizeof(uint8_t) * kernel_size + sizeof(int32_t)) * c_stride;
convolution_op->packed_weights = xnn_allocate_simd_memory(packed_weights_size);
if (convolution_op->packed_weights == NULL) {
@@ -259,21 +259,21 @@
if (flags & XNN_FLAG_DEPTHWISE_CONVOLUTION) {
xnn_pack_q8_dwconv_hwg_w(
kernel_height, kernel_width,
- groups, dwconv_parameters->cr,
+ groups, dwconv_parameters->channel_tile,
input_zero_point, kernel_zero_point,
kernel, bias, convolution_op->packed_weights);
} else {
xnn_pack_q8_dwconv_ghw_w(
kernel_height, kernel_width,
- groups, dwconv_parameters->cr,
+ groups, dwconv_parameters->channel_tile,
input_zero_point, kernel_zero_point,
kernel, bias, convolution_op->packed_weights);
}
convolution_op->ukernel.dwconv = (struct xnn_ukernel_dwconv) {
- .unipass_function = dwconv_parameters->up,
- .mr = dwconv_parameters->mr,
- .qr = dwconv_parameters->qr,
+ .unipass_function = dwconv_parameters->minmax.unipass,
+ .primary_tile = dwconv_parameters->primary_tile,
+ .incremental_tile = dwconv_parameters->incremental_tile,
};
zero_size = sizeof(uint8_t) * c_stride + XNN_EXTRA_BYTES;
@@ -307,7 +307,7 @@
.mr = xnn_params.q8.gemm.mr,
.nr = nr,
.kr = kr,
- .general_case = xnn_params.q8.gemm.gemm,
+ .general_case = xnn_params.q8.gemm.minmax.gemm,
};
break;
case xnn_ukernel_type_igemm:
@@ -328,7 +328,7 @@
.mr = xnn_params.q8.gemm.mr,
.nr = nr,
.kr = kr,
- .general_case = xnn_params.q8.gemm.igemm,
+ .general_case = xnn_params.q8.gemm.minmax.igemm,
};
break;
default:
@@ -578,9 +578,9 @@
case xnn_ukernel_type_dwconv:
{
assert(dwconv_parameters != NULL);
- assert(dwconv_parameters->mr == kernel_size);
+ assert(dwconv_parameters->primary_tile == kernel_size);
- const uint32_t c_stride = round_up_po2(groups, dwconv_parameters->cr);
+ const uint32_t c_stride = round_up_po2(groups, dwconv_parameters->channel_tile);
const size_t packed_weights_size = (kernel_size + 1) * sizeof(float) * c_stride;
convolution_op->packed_weights = xnn_allocate_simd_memory(packed_weights_size);
if (convolution_op->packed_weights == NULL) {
@@ -591,19 +591,19 @@
if (flags & XNN_FLAG_DEPTHWISE_CONVOLUTION) {
xnn_pack_f32_dwconv_hwg_w(
kernel_height, kernel_width,
- groups, dwconv_parameters->cr,
+ groups, dwconv_parameters->channel_tile,
kernel, bias, convolution_op->packed_weights);
} else {
xnn_pack_f32_dwconv_ghw_w(
kernel_height, kernel_width,
- groups, dwconv_parameters->cr,
+ groups, dwconv_parameters->channel_tile,
kernel, bias, convolution_op->packed_weights);
}
convolution_op->ukernel.dwconv = (struct xnn_ukernel_dwconv) {
- .unipass_function = dwconv_parameters->up,
- .mr = dwconv_parameters->mr,
- .qr = dwconv_parameters->qr,
+ .unipass_function = dwconv_parameters->minmax.unipass,
+ .primary_tile = dwconv_parameters->primary_tile,
+ .incremental_tile = dwconv_parameters->incremental_tile,
};
zero_size = sizeof(float) * c_stride;
@@ -636,8 +636,8 @@
.mr = xnn_params.f32.gemm.mr,
.nr = nr,
.kr = kr,
- .general_case = xnn_params.f32.gemm.gemm,
- .mr1_case = xnn_params.f32.gemm.gemm1,
+ .general_case = xnn_params.f32.gemm.minmax.gemm,
+ .mr1_case = xnn_params.f32.gemm.minmax.gemm1,
};
break;
case xnn_ukernel_type_igemm:
@@ -656,8 +656,8 @@
.mr = xnn_params.f32.gemm.mr,
.nr = nr,
.kr = kr,
- .general_case = xnn_params.f32.gemm.igemm,
- .mr1_case = xnn_params.f32.gemm.igemm1,
+ .general_case = xnn_params.f32.gemm.minmax.igemm,
+ .mr1_case = xnn_params.f32.gemm.minmax.igemm1,
};
break;
default: