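Clamp NC operator for S8 data type

- New API functions: xnn_create_clamp_nc_s8 and xnn_setup_clamp_nc_s8
- Register S8 clamp microkernels in src/init.c for ARM, x86, WAsm SIMD, and scalar builds
- WAsm SIMD: use the wasmsimd_x64 microkernel for U8 clamp instead of scalar_x4
- Unit tests

PiperOrigin-RevId: 391216240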
diff --git a/src/init.c b/src/init.c
index 003d89b..213797f 100644
--- a/src/init.c
+++ b/src/init.c
@@ -270,6 +270,11 @@
#ifndef XNN_NO_S8_OPERATORS
init_flags |= XNN_INIT_FLAG_S8;
+ xnn_params.s8.clamp = (struct vunary_parameters) {  // S8 clamp: NEON ukernel paired with the NEON params init
+ .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__neon_x64,
+ .init.s8_minmax = xnn_init_s8_minmax_neon_params,
+ .element_tile = 64,  // matches the _x64 suffix of the ukernel above
+ };
xnn_params.s8.maxpool = (struct maxpool_parameters) {
.ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__neon_c16,
.init.s8 = xnn_init_s8_minmax_neon_params,
@@ -740,6 +745,11 @@
#ifndef XNN_NO_S8_OPERATORS
init_flags |= XNN_INIT_FLAG_S8;
+ xnn_params.s8.clamp = (struct vunary_parameters) {  // S8 clamp: portable scalar ukernel for this fallback path
+ .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__scalar_x4,
+ .init.s8_minmax = xnn_init_s8_minmax_scalar_params,  // scalar ukernel -> scalar params, as for the scalar maxpool below
+ .element_tile = 4,  // matches the _x4 suffix of the ukernel above
+ };
xnn_params.s8.maxpool = (struct maxpool_parameters) {
.ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__scalar_c1,
.init.s8 = xnn_init_s8_minmax_scalar_params,
@@ -1540,6 +1550,11 @@
#ifndef XNN_NO_S8_OPERATORS
init_flags |= XNN_INIT_FLAG_S8;
+ xnn_params.s8.clamp = (struct vunary_parameters) {  // S8 clamp: NEON ukernel paired with the NEON params init
+ .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__neon_x64,
+ .init.s8_minmax = xnn_init_s8_minmax_neon_params,
+ .element_tile = 64,  // matches the _x64 suffix of the ukernel above
+ };
xnn_params.s8.maxpool = (struct maxpool_parameters) {
.ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__neon_c16,
.init.s8 = xnn_init_s8_minmax_neon_params,
@@ -2639,6 +2654,11 @@
init_flags |= XNN_INIT_FLAG_S8;
if (cpuinfo_has_x86_sse4_1()) {
+ xnn_params.s8.clamp = (struct vunary_parameters) {  // S8 clamp: SSE4.1 ukernel paired with the SSE4 params init
+ .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__sse41_x64,
+ .init.s8_minmax = xnn_init_s8_minmax_sse4_params,
+ .element_tile = 64,  // matches the _x64 suffix of the ukernel above
+ };
xnn_params.s8.maxpool = (struct maxpool_parameters) {
.ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__sse41_c16,
.init.s8 = xnn_init_s8_minmax_sse4_params,
@@ -2646,6 +2666,11 @@
.qr = 8,
};
} else {
+ xnn_params.s8.clamp = (struct vunary_parameters) {  // S8 clamp without SSE4.1: SSE2 ukernel + SSE2 params init
+ .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__sse2_x64,
+ .init.s8_minmax = xnn_init_s8_minmax_sse2_params,
+ .element_tile = 64,  // matches the _x64 suffix of the ukernel above
+ };
xnn_params.s8.maxpool = (struct maxpool_parameters) {
.ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__sse2_c16,
.init.s8 = xnn_init_s8_minmax_sse2_params,
@@ -3332,6 +3357,11 @@
#ifndef XNN_NO_S8_OPERATORS
init_flags |= XNN_INIT_FLAG_S8;
+ xnn_params.s8.clamp = (struct vunary_parameters) {  // S8 clamp: WAsm SIMD ukernel paired with the wasmsimd params init
+ .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__wasmsimd_x64,
+ .init.s8_minmax = xnn_init_s8_minmax_wasmsimd_params,
+ .element_tile = 64,  // matches the _x64 suffix of the ukernel above
+ };
xnn_params.s8.maxpool = (struct maxpool_parameters) {
.ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__wasmsimd_c16,
.init.s8 = xnn_init_s8_minmax_wasmsimd_params,
@@ -3345,9 +3375,9 @@
init_flags |= XNN_INIT_FLAG_U8;
xnn_params.u8.clamp = (struct vunary_parameters) {
- .ukernel = (xnn_univector_ukernel_function) xnn_u8_vclamp_ukernel__scalar_x4,
- .init.u8_minmax = xnn_init_u8_minmax_scalar_params,
- .element_tile = 4,
+ .ukernel = (xnn_univector_ukernel_function) xnn_u8_vclamp_ukernel__wasmsimd_x64,  // WAsm SIMD U8 clamp ukernel
+ .init.u8_minmax = xnn_init_u8_minmax_wasmsimd_params,  // params init matching the wasmsimd ukernel
+ .element_tile = 64,  // matches the _x64 suffix of the ukernel above
};
xnn_params.u8.maxpool = (struct maxpool_parameters) {
.ukernel = (xnn_maxpool_ukernel_function) xnn_u8_maxpool_minmax_ukernel_9p8x__wasmsimd_c16,
@@ -3974,6 +4004,11 @@
#ifndef XNN_NO_S8_OPERATORS
init_flags |= XNN_INIT_FLAG_S8;
+ xnn_params.s8.clamp = (struct vunary_parameters) {  // S8 clamp: portable scalar ukernel paired with scalar params init
+ .ukernel = (xnn_univector_ukernel_function) xnn_s8_vclamp_ukernel__scalar_x4,
+ .init.s8_minmax = xnn_init_s8_minmax_scalar_params,
+ .element_tile = 4,  // matches the _x4 suffix of the ukernel above
+ };
xnn_params.s8.maxpool = (struct maxpool_parameters) {
.ukernel = (xnn_maxpool_ukernel_function) xnn_s8_maxpool_minmax_ukernel_9p8x__scalar_c1,
.init.s8 = xnn_init_s8_minmax_scalar_params,
diff --git a/src/operator-strings.c b/src/operator-strings.c
index fcab61d..30e8fe9 100644
--- a/src/operator-strings.c
+++ b/src/operator-strings.c
@@ -44,6 +44,8 @@
return "Channel Shuffle (NC, X32)";
case xnn_operator_type_clamp_nc_f32:
return "Clamp (NC, F32)";
+ case xnn_operator_type_clamp_nc_s8:
+ return "Clamp (NC, S8)";  // name used in the S8 clamp create/setup error logs
case xnn_operator_type_clamp_nc_u8:
return "Clamp (NC, U8)";
case xnn_operator_type_constant_pad_nd_x8:
diff --git a/src/operators/unary-elementwise-nc.c b/src/operators/unary-elementwise-nc.c
index bfbd5d8..f76fe58 100644
--- a/src/operators/unary-elementwise-nc.c
+++ b/src/operators/unary-elementwise-nc.c
@@ -148,6 +148,41 @@
return xnn_status_success;
}
+// Creates a Clamp (NC, S8) operator that limits each int8_t element to
+// [output_min, output_max].
+//
+// Returns xnn_status_invalid_parameter when output_min >= output_max; all
+// other argument handling is delegated to create_unary_elementwise_nc.
+enum xnn_status xnn_create_clamp_nc_s8(
+    size_t channels,
+    size_t input_stride,
+    size_t output_stride,
+    int8_t output_min,
+    int8_t output_max,
+    uint32_t flags,
+    xnn_operator_t* clamp_op_out)
+{
+  if (output_min >= output_max) {
+    xnn_log_error(
+      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
+      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_s8), output_min, output_max);
+    return xnn_status_invalid_parameter;
+  }
+
+  // Convert the bounds into the params layout of the registered S8 clamp
+  // microkernel; the NULL check tolerates builds that register no init hook.
+  union xnn_s8_minmax_params params;
+  if (xnn_params.s8.clamp.init.s8_minmax != NULL) {
+    xnn_params.s8.clamp.init.s8_minmax(&params, output_min, output_max);
+  }
+  return create_unary_elementwise_nc(
+    channels, input_stride, output_stride, flags,
+    &params, sizeof(params),
+    xnn_operator_type_clamp_nc_s8,
+    xnn_params.s8.clamp.ukernel,
+    clamp_op_out);
+}
+
enum xnn_status xnn_create_clamp_nc_u8(
size_t channels,
size_t input_stride,
@@ -549,6 +577,39 @@
&ceiling_op->params.f32_rnd, sizeof(ceiling_op->params.f32_rnd));
}
+// Binds batch size and input/output pointers to a Clamp (NC, S8) operator.
+//
+// Fails with xnn_status_invalid_parameter unless clamp_op was created as a
+// Clamp (NC, S8) operator; otherwise forwards to setup_unary_elementwise_nc
+// together with the min/max params held in clamp_op->params.s8_minmax.
+enum xnn_status xnn_setup_clamp_nc_s8(
+    xnn_operator_t clamp_op,
+    size_t batch_size,
+    const int8_t* input,
+    int8_t* output,
+    pthreadpool_t threadpool)
+{
+  // Reject operators of any other type before touching their state.
+  if (clamp_op->type != xnn_operator_type_clamp_nc_s8) {
+    xnn_log_error(
+      "failed to setup operator: operator type mismatch (expected %s, got %s)",
+      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_s8),
+      xnn_operator_type_to_string(clamp_op->type));
+    return xnn_status_invalid_parameter;
+  }
+
+  // NOTE(review): presumably setup_unary_elementwise_nc marks the operator
+  // runnable again on success; that helper is not visible here.
+  clamp_op->state = xnn_run_state_invalid;
+
+  const uint32_t log2_element_size = 0;  // log2(sizeof(int8_t))
+  return setup_unary_elementwise_nc(
+    clamp_op,
+    batch_size, input, output,
+    log2_element_size,
+    &clamp_op->params.s8_minmax, sizeof(clamp_op->params.s8_minmax));
+}
+
enum xnn_status xnn_setup_clamp_nc_u8(
xnn_operator_t clamp_op,
size_t batch_size,
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index e579e68..60738e2 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -44,6 +44,7 @@
xnn_operator_type_channel_shuffle_nc_x8,
xnn_operator_type_channel_shuffle_nc_x32,
xnn_operator_type_clamp_nc_f32,
+ xnn_operator_type_clamp_nc_s8,  // Clamp (NC, S8): tag set by xnn_create_clamp_nc_s8, checked by xnn_setup_clamp_nc_s8
xnn_operator_type_clamp_nc_u8,
xnn_operator_type_ceiling_nc_f32,
xnn_operator_type_constant_pad_nd_x8,