Generalize Clamp NC setup implementation across datatypes

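Factor the datatype-specific Clamp NC setup paths into a single static
setup_clamp() helper parameterized by the micro-kernel, the log2 of the
element size, and the output params blob. xnn_setup_clamp_nc_u8 and
xnn_setup_clamp_nc_f32 now only validate the operator type and delegate to
the shared helper, so the contiguous/strided dispatch logic lives in one
place.

A wrapper for another datatype would only need to pass its own kernel and
params. A hypothetical sketch for illustration only (the F16 operator type,
xnn_params.f16.clamp, and f16_minmax_params are assumptions, not part of
this change):

    enum xnn_status xnn_setup_clamp_nc_f16(
        xnn_operator_t clamp_op,
        size_t batch_size,
        const void* input,
        void* output,
        pthreadpool_t threadpool)
    {
      /* Hypothetical: assumes an xnn_operator_type_clamp_nc_f16 enum value. */
      if (clamp_op->type != xnn_operator_type_clamp_nc_f16) {
        return xnn_status_invalid_parameter;
      }
      clamp_op->state = xnn_run_state_invalid;

      /* Hypothetical: assumes xnn_params.f16.clamp and f16_minmax_params exist. */
      return setup_clamp(
        clamp_op,
        batch_size, input, output,
        xnn_params.f16.clamp,
        1 /* log2(sizeof(uint16_t)) */,
        &clamp_op->f16_minmax_params,
        sizeof(clamp_op->f16_minmax_params));
    }
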
PiperOrigin-RevId: 314695433
diff --git a/src/operators/clamp-nc.c b/src/operators/clamp-nc.c
index 22aa917..f3243a5 100644
--- a/src/operators/clamp-nc.c
+++ b/src/operators/clamp-nc.c
@@ -189,24 +189,19 @@
   return status;
 }
 
-enum xnn_status xnn_setup_clamp_nc_u8(
+static enum xnn_status setup_clamp(
     xnn_operator_t clamp_op,
     size_t batch_size,
-    const uint8_t* input,
-    uint8_t* output,
-    pthreadpool_t threadpool)
+    const void* input,
+    void* output,
+    xnn_univector_ukernel_function ukernel,
+    uint32_t log2_element_size,
+    const void* params,
+    size_t params_size)
 {
-  if (clamp_op->type != xnn_operator_type_clamp_nc_u8) {
-    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
-      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8),
-      xnn_operator_type_to_string(clamp_op->type));
-    return xnn_status_invalid_parameter;
-  }
-  clamp_op->state = xnn_run_state_invalid;
-
   if (!xnn_params.initialized) {
     xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
-      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8));
+      xnn_operator_type_to_string(clamp_op->type));
     return xnn_status_uninitialized;
   }
 
@@ -222,26 +217,26 @@
     const size_t block_size = 4096;
     clamp_op->context.univector_contiguous = (struct univector_contiguous_context) {
       .x = input,
-      .x_stride = input_stride * sizeof(uint8_t),
+      .x_stride = input_stride << log2_element_size,
       .y = output,
-      .y_stride = output_stride * sizeof(uint8_t),
-      .ukernel = xnn_params.u8.clamp,
-      .params.u8_output = clamp_op->u8_minmax_params,
+      .y_stride = output_stride << log2_element_size,
+      .ukernel = ukernel,
     };
+    memcpy(&clamp_op->context.univector_contiguous.params, params, params_size);
     clamp_op->compute.type = xnn_parallelization_type_1d_tile_1d;
     clamp_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_contiguous;
-    clamp_op->compute.range[0] = batch_size * channels * sizeof(uint8_t);
+    clamp_op->compute.range[0] = (batch_size * channels) << log2_element_size;
     clamp_op->compute.tile[0] = block_size;
   } else {
     clamp_op->context.univector_strided = (struct univector_strided_context) {
-      .n = channels * sizeof(uint8_t),
+      .n = channels << log2_element_size,
       .x = input,
-      .x_stride = input_stride * sizeof(uint8_t),
+      .x_stride = input_stride << log2_element_size,
       .y = output,
-      .y_stride = output_stride * sizeof(uint8_t),
-      .ukernel = xnn_params.u8.clamp,
-      .params.u8_output = clamp_op->u8_minmax_params,
+      .y_stride = output_stride << log2_element_size,
+      .ukernel = ukernel,
     };
+    memcpy(&clamp_op->context.univector_strided.params, params, params_size);
     clamp_op->compute.type = xnn_parallelization_type_1d_tile_1d;
     clamp_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_strided;
     clamp_op->compute.range[0] = batch_size;
@@ -252,6 +247,30 @@
   return xnn_status_success;
 }
 
+enum xnn_status xnn_setup_clamp_nc_u8(
+    xnn_operator_t clamp_op,
+    size_t batch_size,
+    const uint8_t* input,
+    uint8_t* output,
+    pthreadpool_t threadpool)
+{
+  if (clamp_op->type != xnn_operator_type_clamp_nc_u8) {
+    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
+      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8),
+      xnn_operator_type_to_string(clamp_op->type));
+    return xnn_status_invalid_parameter;
+  }
+  clamp_op->state = xnn_run_state_invalid;
+
+  return setup_clamp(
+    clamp_op,
+    batch_size, input, output,
+    xnn_params.u8.clamp,
+    0 /* log2(sizeof(uint8_t)) */,
+    &clamp_op->u8_minmax_params,
+    sizeof(clamp_op->u8_minmax_params));
+}
+
 enum xnn_status xnn_setup_clamp_nc_f32(
     xnn_operator_t clamp_op,
     size_t batch_size,
@@ -267,50 +286,11 @@
   }
   clamp_op->state = xnn_run_state_invalid;
 
-  if (!xnn_params.initialized) {
-    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
-      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8));
-    return xnn_status_uninitialized;
-  }
-
-  if (batch_size == 0) {
-    clamp_op->state = xnn_run_state_skip;
-    return xnn_status_success;
-  }
-
-  const size_t channels = clamp_op->channels;
-  const size_t input_stride = clamp_op->input_pixel_stride;
-  const size_t output_stride = clamp_op->output_pixel_stride;
-  if ((((input_stride ^ channels) | (output_stride ^ channels)) == 0) || batch_size == 1) {
-    const size_t block_size = 4096;
-    clamp_op->context.univector_contiguous = (struct univector_contiguous_context) {
-      .x = input,
-      .x_stride = input_stride * sizeof(float),
-      .y = output,
-      .y_stride = output_stride * sizeof(float),
-      .ukernel = xnn_params.f32.clamp,
-      .params.f32_output = clamp_op->f32_minmax_params,
-    };
-    clamp_op->compute.type = xnn_parallelization_type_1d_tile_1d;
-    clamp_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_contiguous;
-    clamp_op->compute.range[0] = batch_size * channels * sizeof(float);
-    clamp_op->compute.tile[0] = block_size;
-  } else {
-    clamp_op->context.univector_strided = (struct univector_strided_context) {
-      .n = channels * sizeof(float),
-      .x = input,
-      .x_stride = input_stride * sizeof(float),
-      .y = output,
-      .y_stride = output_stride * sizeof(float),
-      .ukernel = xnn_params.f32.clamp,
-      .params.f32_output = clamp_op->f32_minmax_params,
-    };
-    clamp_op->compute.type = xnn_parallelization_type_1d_tile_1d;
-    clamp_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_strided;
-    clamp_op->compute.range[0] = batch_size;
-    clamp_op->compute.tile[0] = 1;
-  }
-  clamp_op->state = xnn_run_state_ready;
-
-  return xnn_status_success;
+  return setup_clamp(
+    clamp_op,
+    batch_size, input, output,
+    xnn_params.f32.clamp,
+    2 /* log2(sizeof(float)) */,
+    &clamp_op->f32_minmax_params,
+    sizeof(clamp_op->f32_minmax_params));
 }