Copy NC operator
PiperOrigin-RevId: 314845678
diff --git a/src/operator-strings.c b/src/operator-strings.c
index 358af55..31300c6 100644
--- a/src/operator-strings.c
+++ b/src/operator-strings.c
@@ -48,6 +48,8 @@
return "Convolution (NHWC, Q8)";
case xnn_operator_type_convolution_nchw_f32:
return "Convolution (NCHW, F32)";
+ case xnn_operator_type_copy_nc_x32:
+ return "Copy (NC, X32)";
case xnn_operator_type_deconvolution_nhwc_f32:
return "Deconvolution (NHWC, F32)";
case xnn_operator_type_deconvolution_nhwc_q8:
diff --git a/src/operators/unary-elementwise-nc.c b/src/operators/unary-elementwise-nc.c
index 9a41088..9521bbf 100644
--- a/src/operators/unary-elementwise-nc.c
+++ b/src/operators/unary-elementwise-nc.c
@@ -206,6 +206,20 @@
clamp_op_out);
}
+enum xnn_status xnn_create_copy_nc_x32(  /* Forwards to create_unary_elementwise_nc with operator type xnn_operator_type_copy_nc_x32 and returns its status. */
+ size_t channels,
+ size_t input_stride,
+ size_t output_stride,
+ uint32_t flags,
+ xnn_operator_t* copy_op_out)  /* passed through unchanged; presumably receives the created operator (confirm against the helper) */
+{
+ return create_unary_elementwise_nc(
+ channels, input_stride, output_stride, flags,
+ NULL, 0,  /* presumably the params pointer and its size; copy passes none (confirm against the helper's signature) */
+ xnn_operator_type_copy_nc_x32,
+ copy_op_out);
+}
+
enum xnn_status xnn_create_hardswish_nc_f32(
size_t channels,
size_t input_stride,
@@ -281,6 +295,33 @@
&clamp_op->params.f32_minmax, sizeof(clamp_op->params.f32_minmax));
}
+static void memcpy_ukernel(size_t size, const void* input, void* output, const void* params) {  /* Copies size bytes from input to output; params is accepted but never read. */
+ memcpy(output, input, size);  /* NOTE(review): memcpy is undefined for overlapping buffers; confirm callers never alias input and output */
+}
+
+enum xnn_status xnn_setup_copy_nc_x32(  /* Validates the operator type, then delegates to setup_unary_elementwise_nc with memcpy_ukernel and returns its status. */
+ xnn_operator_t copy_op,  /* dereferenced below without a NULL check */
+ size_t batch_size,
+ const void* input,
+ void* output,
+ pthreadpool_t threadpool)  /* not referenced anywhere in this function body */
+{
+ if (copy_op->type != xnn_operator_type_copy_nc_x32) {  /* reject operators of any other type */
+ xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
+ xnn_operator_type_to_string(xnn_operator_type_copy_nc_x32),
+ xnn_operator_type_to_string(copy_op->type));
+ return xnn_status_invalid_parameter;  /* nothing on copy_op has been modified at this point */
+ }
+ copy_op->state = xnn_run_state_invalid;  /* marked invalid before delegating; presumably the helper updates the state on success (confirm) */
+
+ return setup_unary_elementwise_nc(
+ copy_op,
+ batch_size, input, output,
+ memcpy_ukernel,  /* the copy itself is a plain memcpy */
+ 2 /* log2(sizeof(uint32_t)) */,
+ NULL, 0);  /* presumably the params pointer and its size; copy passes none (confirm against the helper's signature) */
+}
+
enum xnn_status xnn_setup_hardswish_nc_f32(
xnn_operator_t hardswish_op,
size_t batch_size,
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index 4c20e8b..c876fd0 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -56,9 +56,10 @@
xnn_operator_type_clamp_nc_f32,
xnn_operator_type_clamp_nc_u8,
xnn_operator_type_constant_pad_nd_x32,
+ xnn_operator_type_convolution_nchw_f32,
xnn_operator_type_convolution_nhwc_f32,
xnn_operator_type_convolution_nhwc_q8,
- xnn_operator_type_convolution_nchw_f32,
+ xnn_operator_type_copy_nc_x32,
xnn_operator_type_deconvolution_nhwc_f32,
xnn_operator_type_deconvolution_nhwc_q8,
xnn_operator_type_divide_nd_f32,