Generalize Clamp NC setup implementation across datatypes
PiperOrigin-RevId: 314695433
diff --git a/src/operators/clamp-nc.c b/src/operators/clamp-nc.c
index 22aa917..f3243a5 100644
--- a/src/operators/clamp-nc.c
+++ b/src/operators/clamp-nc.c
@@ -189,24 +189,19 @@
return status;
}
-enum xnn_status xnn_setup_clamp_nc_u8(
+static enum xnn_status setup_clamp(
xnn_operator_t clamp_op,
size_t batch_size,
- const uint8_t* input,
- uint8_t* output,
- pthreadpool_t threadpool)
+ const void* input,
+ void* output,
+ xnn_univector_ukernel_function ukernel,
+ uint32_t log2_element_size,
+ const void* params,
+ size_t params_size)
{
- if (clamp_op->type != xnn_operator_type_clamp_nc_u8) {
- xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
- xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8),
- xnn_operator_type_to_string(clamp_op->type));
- return xnn_status_invalid_parameter;
- }
- clamp_op->state = xnn_run_state_invalid;
-
if (!xnn_params.initialized) {
xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
- xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8));
+ xnn_operator_type_to_string(clamp_op->type));
return xnn_status_uninitialized;
}
@@ -222,26 +217,26 @@
const size_t block_size = 4096;
clamp_op->context.univector_contiguous = (struct univector_contiguous_context) {
.x = input,
- .x_stride = input_stride * sizeof(uint8_t),
+ .x_stride = input_stride << log2_element_size,
.y = output,
- .y_stride = output_stride * sizeof(uint8_t),
- .ukernel = xnn_params.u8.clamp,
- .params.u8_output = clamp_op->u8_minmax_params,
+ .y_stride = output_stride << log2_element_size,
+ .ukernel = ukernel,
};
+ memcpy(&clamp_op->context.univector_contiguous.params, params, params_size);
clamp_op->compute.type = xnn_parallelization_type_1d_tile_1d;
clamp_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_contiguous;
- clamp_op->compute.range[0] = batch_size * channels * sizeof(uint8_t);
+ clamp_op->compute.range[0] = (batch_size * channels) << log2_element_size;
clamp_op->compute.tile[0] = block_size;
} else {
clamp_op->context.univector_strided = (struct univector_strided_context) {
- .n = channels * sizeof(uint8_t),
+ .n = channels << log2_element_size,
.x = input,
- .x_stride = input_stride * sizeof(uint8_t),
+ .x_stride = input_stride << log2_element_size,
.y = output,
- .y_stride = output_stride * sizeof(uint8_t),
- .ukernel = xnn_params.u8.clamp,
- .params.u8_output = clamp_op->u8_minmax_params,
+ .y_stride = output_stride << log2_element_size,
+ .ukernel = ukernel,
};
+ memcpy(&clamp_op->context.univector_strided.params, params, params_size);
clamp_op->compute.type = xnn_parallelization_type_1d_tile_1d;
clamp_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_strided;
clamp_op->compute.range[0] = batch_size;
@@ -252,6 +247,30 @@
return xnn_status_success;
}
+enum xnn_status xnn_setup_clamp_nc_u8( /* uint8 entry point: checks the operator type, then delegates to setup_clamp() */
+ xnn_operator_t clamp_op,
+ size_t batch_size,
+ const uint8_t* input,
+ uint8_t* output,
+ pthreadpool_t threadpool) /* not referenced in this function body */
+{
+ if (clamp_op->type != xnn_operator_type_clamp_nc_u8) { /* reject operators of any other type */
+ xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
+ xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8),
+ xnn_operator_type_to_string(clamp_op->type));
+ return xnn_status_invalid_parameter;
+ }
+ clamp_op->state = xnn_run_state_invalid; /* reset before delegating; setup_clamp()'s uninitialized-library return leaves it as is */
+
+ return setup_clamp( /* returns setup_clamp()'s status unchanged */
+ clamp_op,
+ batch_size, input, output,
+ xnn_params.u8.clamp, /* ukernel stored in the compute context */
+ 0 /* log2(sizeof(uint8_t)) */,
+ &clamp_op->u8_minmax_params, /* setup_clamp() memcpy's these into the context's params */
+ sizeof(clamp_op->u8_minmax_params));
+}
+
enum xnn_status xnn_setup_clamp_nc_f32(
xnn_operator_t clamp_op,
size_t batch_size,
@@ -267,50 +286,11 @@
}
clamp_op->state = xnn_run_state_invalid;
- if (!xnn_params.initialized) {
- xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
- xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8));
- return xnn_status_uninitialized;
- }
-
- if (batch_size == 0) {
- clamp_op->state = xnn_run_state_skip;
- return xnn_status_success;
- }
-
- const size_t channels = clamp_op->channels;
- const size_t input_stride = clamp_op->input_pixel_stride;
- const size_t output_stride = clamp_op->output_pixel_stride;
- if ((((input_stride ^ channels) | (output_stride ^ channels)) == 0) || batch_size == 1) {
- const size_t block_size = 4096;
- clamp_op->context.univector_contiguous = (struct univector_contiguous_context) {
- .x = input,
- .x_stride = input_stride * sizeof(float),
- .y = output,
- .y_stride = output_stride * sizeof(float),
- .ukernel = xnn_params.f32.clamp,
- .params.f32_output = clamp_op->f32_minmax_params,
- };
- clamp_op->compute.type = xnn_parallelization_type_1d_tile_1d;
- clamp_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_contiguous;
- clamp_op->compute.range[0] = batch_size * channels * sizeof(float);
- clamp_op->compute.tile[0] = block_size;
- } else {
- clamp_op->context.univector_strided = (struct univector_strided_context) {
- .n = channels * sizeof(float),
- .x = input,
- .x_stride = input_stride * sizeof(float),
- .y = output,
- .y_stride = output_stride * sizeof(float),
- .ukernel = xnn_params.f32.clamp,
- .params.f32_output = clamp_op->f32_minmax_params,
- };
- clamp_op->compute.type = xnn_parallelization_type_1d_tile_1d;
- clamp_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_strided;
- clamp_op->compute.range[0] = batch_size;
- clamp_op->compute.tile[0] = 1;
- }
- clamp_op->state = xnn_run_state_ready;
-
- return xnn_status_success;
+ return setup_clamp(
+ clamp_op,
+ batch_size, input, output,
+ xnn_params.f32.clamp,
+ 2 /* log2(sizeof(float)) */,
+ &clamp_op->f32_minmax_params,
+ sizeof(clamp_op->f32_minmax_params));
}