X8 version of Constant Pad ND operator
Generalize implementation of Constant Pad operator to 8-bit element types.
PiperOrigin-RevId: 389646484
diff --git a/src/operator-strings.c b/src/operator-strings.c
index 4f706d4..8f92ce1 100644
--- a/src/operator-strings.c
+++ b/src/operator-strings.c
@@ -38,14 +38,16 @@
return "Bankers Rounding (NC, F32)";
case xnn_operator_type_ceiling_nc_f32:
return "Ceiling (NC, F32)";
- case xnn_operator_type_channel_shuffle_nc_x32:
- return "Channel Shuffle (NC, X32)";
case xnn_operator_type_channel_shuffle_nc_x8:
return "Channel Shuffle (NC, X8)";
+ case xnn_operator_type_channel_shuffle_nc_x32:
+ return "Channel Shuffle (NC, X32)";
case xnn_operator_type_clamp_nc_f32:
return "Clamp (NC, F32)";
case xnn_operator_type_clamp_nc_u8:
return "Clamp (NC, U8)";
+ case xnn_operator_type_constant_pad_nd_x8:
+ return "Constant Pad (ND, X8)";
case xnn_operator_type_constant_pad_nd_x32:
return "Constant Pad (ND, X32)";
case xnn_operator_type_convolution_nhwc_f16:
diff --git a/src/operators/constant-pad-nd.c b/src/operators/constant-pad-nd.c
index a4dd524..9e4e9d0 100644
--- a/src/operators/constant-pad-nd.c
+++ b/src/operators/constant-pad-nd.c
@@ -58,6 +58,16 @@
return status;
}
+enum xnn_status xnn_create_constant_pad_nd_x8(
+ const void* padding_value,
+ uint32_t flags,
+ xnn_operator_t* constant_pad_op_out)
+{
+  const uint32_t padding_pattern = *((const uint8_t*) padding_value);
+ return create_constant_pad_nd(
+ padding_pattern * UINT32_C(0x01010101), flags, xnn_operator_type_constant_pad_nd_x8, constant_pad_op_out);
+}
+
enum xnn_status xnn_create_constant_pad_nd_x32(
const void* padding_value,
uint32_t flags,
@@ -76,6 +86,7 @@
const size_t* post_paddings,
const void* input,
void* output,
+ uint32_t log2_element_size,
size_t num_threads)
{
if (constant_pad_op->type != expected_operator_type) {
@@ -160,15 +171,15 @@
size_t output_stride = normalized_output_shape[XNN_MAX_TENSOR_DIMS - 1];
for (size_t i = 1; i < XNN_MAX_TENSOR_DIMS; i++) {
constant_pad_op->context.pad.input = (const void*)
- ((uintptr_t) constant_pad_op->context.pad.input - constant_pad_op->context.pad.pre_paddings[i] * input_stride * sizeof(float));
- constant_pad_op->context.pad.input_stride[i - 1] = input_stride * sizeof(float);
- constant_pad_op->context.pad.output_stride[i - 1] = output_stride * sizeof(float);
+ ((uintptr_t) constant_pad_op->context.pad.input - (constant_pad_op->context.pad.pre_paddings[i] * input_stride << log2_element_size));
+ constant_pad_op->context.pad.input_stride[i - 1] = input_stride << log2_element_size;
+ constant_pad_op->context.pad.output_stride[i - 1] = output_stride << log2_element_size;
input_stride *= normalized_input_shape[XNN_MAX_TENSOR_DIMS - 1 - i];
output_stride *= normalized_output_shape[XNN_MAX_TENSOR_DIMS - 1 - i];
}
- constant_pad_op->context.pad.input_size[0] *= sizeof(float);
- constant_pad_op->context.pad.output_size[0] = normalized_output_shape[XNN_MAX_TENSOR_DIMS - 1] * sizeof(float);
- constant_pad_op->context.pad.pre_paddings[0] *= sizeof(float);
+ constant_pad_op->context.pad.input_size[0] <<= log2_element_size;
+ constant_pad_op->context.pad.output_size[0] = normalized_output_shape[XNN_MAX_TENSOR_DIMS - 1] << log2_element_size;
+ constant_pad_op->context.pad.pre_paddings[0] <<= log2_element_size;
constant_pad_op->context.pad.post_paddings[0] =
constant_pad_op->context.pad.output_size[0] - constant_pad_op->context.pad.pre_paddings[0] - constant_pad_op->context.pad.input_size[0];
@@ -184,6 +195,23 @@
return xnn_status_success;
}
+enum xnn_status xnn_setup_constant_pad_nd_x8(
+ xnn_operator_t constant_pad_op,
+ size_t num_dims,
+ const size_t* input_shape,
+ const size_t* pre_padding,
+ const size_t* post_padding,
+ const void* input,
+ void* output,
+ pthreadpool_t threadpool)
+{
+ return setup_constant_pad_nd(
+ constant_pad_op, xnn_operator_type_constant_pad_nd_x8,
+ num_dims, input_shape, pre_padding, post_padding,
+ input, output, 0 /* log2(element size) */,
+ pthreadpool_get_threads_count(threadpool));
+}
+
enum xnn_status xnn_setup_constant_pad_nd_x32(
xnn_operator_t constant_pad_op,
size_t num_dims,
@@ -197,6 +225,6 @@
return setup_constant_pad_nd(
constant_pad_op, xnn_operator_type_constant_pad_nd_x32,
num_dims, input_shape, pre_padding, post_padding,
- input, output,
+ input, output, 2 /* log2(element size) */,
pthreadpool_get_threads_count(threadpool));
}
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index 8a08e28..328dbaf 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -41,11 +41,12 @@
xnn_operator_type_average_pooling_nhwc_f32,
xnn_operator_type_average_pooling_nhwc_qu8,
xnn_operator_type_bankers_rounding_nc_f32,
- xnn_operator_type_channel_shuffle_nc_x32,
xnn_operator_type_channel_shuffle_nc_x8,
+ xnn_operator_type_channel_shuffle_nc_x32,
xnn_operator_type_clamp_nc_f32,
xnn_operator_type_clamp_nc_u8,
xnn_operator_type_ceiling_nc_f32,
+ xnn_operator_type_constant_pad_nd_x8,
xnn_operator_type_constant_pad_nd_x32,
xnn_operator_type_convolution_nchw_f32,
xnn_operator_type_convolution_nhwc_f16,