Graph rewriting for sparse inference in NCHW layout
PiperOrigin-RevId: 317200356
diff --git a/BUILD.bazel b/BUILD.bazel
index 1d04170..1a1860f 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -3044,6 +3044,15 @@
}),
)
+cc_library(
+ name = "enable_sparse",
+ defines = select({
+ ":xnn_enable_sparse_explicit_true": ["XNN_ENABLE_SPARSE=1"],
+ ":xnn_enable_sparse_explicit_false": ["XNN_ENABLE_SPARSE=0"],
+ "//conditions:default": ["XNN_ENABLE_SPARSE=1"],
+ }),
+)
+
xnnpack_cc_library(
name = "operators",
srcs = OPERATOR_SRCS + [
@@ -3129,6 +3138,7 @@
visibility = xnnpack_visibility(),
deps = [
":enable_assembly",
+ ":enable_sparse",
":logging_utils",
":memory_planner",
":operator_run",
@@ -3170,6 +3180,7 @@
visibility = xnnpack_visibility(),
deps = [
":enable_assembly",
+ ":enable_sparse",
":logging_utils",
":memory_planner_test_mode",
":operator_run_test_mode",
@@ -3216,6 +3227,7 @@
visibility = xnnpack_visibility(),
deps = [
":enable_assembly",
+ ":enable_sparse",
":logging_utils",
":memory_planner",
":operator_run",
@@ -5030,6 +5042,18 @@
define_values = {"xnn_enable_assembly": "false"},
)
+# Enables usage of sparse inference.
+config_setting(
+ name = "xnn_enable_sparse_explicit_true",
+ define_values = {"xnn_enable_sparse": "true"},
+)
+
+# Disables usage of sparse inference.
+config_setting(
+ name = "xnn_enable_sparse_explicit_false",
+ define_values = {"xnn_enable_sparse": "false"},
+)
+
# Disables usage of HMP-aware optimizations.
config_setting(
name = "xnn_enable_hmp_explicit_false",
diff --git a/src/operators/convolution-nchw.c b/src/operators/convolution-nchw.c
index 1eb5639..1198616 100644
--- a/src/operators/convolution-nchw.c
+++ b/src/operators/convolution-nchw.c
@@ -167,9 +167,9 @@
// + 1x1 convolution (no groups)
// + 3x3 stride-2 with 3 input channels and NHWC input layout
// + 3x3 stride-2 depthwise convolution with horizontal padding 1 & no vertical padding
- // - 3x3 stride-1 depthwise convolution with horizontal padding 1 & no vertical padding
- // - 5x5 stride-2 depthwise convolution with horizontal padding 2 & no vertical padding
- // - 5x5 stride-1 depthwise convolution with horizontal padding 2 & no vertical padding
+ // + 3x3 stride-1 depthwise convolution with horizontal padding 1 & no vertical padding
+ // + 5x5 stride-2 depthwise convolution with horizontal padding 2 & no vertical padding
+ // + 5x5 stride-1 depthwise convolution with horizontal padding 2 & no vertical padding
const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
const bool is_1x1 = kernel_width == 1 && kernel_height == 1 && subsampling_height == 1 && subsampling_width == 1;
const bool is_3x3 = kernel_width == 3 && kernel_height == 3 && dilation_height == 1 && dilation_width == 1;
diff --git a/src/runtime.c b/src/runtime.c
index 3f5327d..c430ac7 100644
--- a/src/runtime.c
+++ b/src/runtime.c
@@ -112,8 +112,26 @@
}
runtime->opdata[i].shape1.num_dims = values[node->inputs[0]].shape.num_dims;
runtime->opdata[i].shape2.num_dims = values[node->inputs[1]].shape.num_dims;
- memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
- memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
+ if (values[node->outputs[0]].layout == xnn_layout_type_nchw) {
+ assert(values[node->inputs[0]].layout == xnn_layout_type_nchw);
+ assert(values[node->inputs[1]].layout == xnn_layout_type_nchw);
+ runtime->opdata[i].shape1.dim[0] = values[node->inputs[0]].shape.dim[0];
+ runtime->opdata[i].shape1.dim[1] = values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1];
+ if (values[node->inputs[0]].shape.num_dims > 2) {
+ memcpy(&runtime->opdata[i].shape1.dim[2], &values[node->inputs[0]].shape.dim[1], (values[node->inputs[0]].shape.num_dims - 2) * sizeof(size_t));
+ }
+ runtime->opdata[i].shape2.dim[0] = values[node->inputs[1]].shape.dim[0];
+ runtime->opdata[i].shape2.dim[1] = values[node->inputs[1]].shape.dim[values[node->inputs[1]].shape.num_dims - 1];
+ if (values[node->inputs[1]].shape.num_dims > 2) {
+ memcpy(&runtime->opdata[i].shape2.dim[2], &values[node->inputs[1]].shape.dim[1], (values[node->inputs[1]].shape.num_dims - 2) * sizeof(size_t));
+ }
+ } else {
+ assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
+ assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
+ assert(values[node->inputs[1]].layout == xnn_layout_type_nhwc);
+ memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
+ memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
+ }
runtime->opdata[i].inputs[0] = node->inputs[0];
runtime->opdata[i].inputs[1] = node->inputs[1];
runtime->opdata[i].outputs[0] = node->outputs[0];
@@ -212,28 +230,57 @@
runtime->opdata[i].outputs[0] = node->outputs[0];
break;
case xnn_node_type_convolution_2d:
- status = xnn_create_convolution2d_nhwc_f32(
- node->params.convolution_2d.input_padding_top,
- node->params.convolution_2d.input_padding_right,
- node->params.convolution_2d.input_padding_bottom,
- node->params.convolution_2d.input_padding_left,
- node->params.convolution_2d.kernel_height,
- node->params.convolution_2d.kernel_width,
- node->params.convolution_2d.subsampling_height,
- node->params.convolution_2d.subsampling_width,
- node->params.convolution_2d.dilation_height,
- node->params.convolution_2d.dilation_width,
- node->params.convolution_2d.groups,
- node->params.convolution_2d.group_input_channels,
- node->params.convolution_2d.group_output_channels,
- node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
- node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
- values[node->inputs[1]].data,
- values[node->inputs[2]].data,
- node->activation.output_min,
- node->activation.output_max,
- node->flags,
- &runtime->opdata[i].operator_object);
+ assert(values[node->inputs[1]].data != NULL);
+ assert(values[node->inputs[2]].data != NULL);
+ if (values[node->outputs[0]].layout == xnn_layout_type_nchw) {
+ status = xnn_create_convolution2d_nchw_f32(
+ node->params.convolution_2d.input_padding_top,
+ node->params.convolution_2d.input_padding_right,
+ node->params.convolution_2d.input_padding_bottom,
+ node->params.convolution_2d.input_padding_left,
+ node->params.convolution_2d.kernel_height,
+ node->params.convolution_2d.kernel_width,
+ node->params.convolution_2d.subsampling_height,
+ node->params.convolution_2d.subsampling_width,
+ node->params.convolution_2d.dilation_height,
+ node->params.convolution_2d.dilation_width,
+ node->params.convolution_2d.groups,
+ node->params.convolution_2d.group_input_channels,
+ node->params.convolution_2d.group_output_channels,
+ node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
+ node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
+ values[node->inputs[1]].data,
+ values[node->inputs[2]].data,
+ node->activation.output_min,
+ node->activation.output_max,
+ node->flags | (values[node->inputs[0]].layout == xnn_layout_type_nhwc ? XNN_FLAG_INPUT_NHWC : 0),
+ &runtime->opdata[i].operator_object);
+ } else {
+ assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
+ assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
+ status = xnn_create_convolution2d_nhwc_f32(
+ node->params.convolution_2d.input_padding_top,
+ node->params.convolution_2d.input_padding_right,
+ node->params.convolution_2d.input_padding_bottom,
+ node->params.convolution_2d.input_padding_left,
+ node->params.convolution_2d.kernel_height,
+ node->params.convolution_2d.kernel_width,
+ node->params.convolution_2d.subsampling_height,
+ node->params.convolution_2d.subsampling_width,
+ node->params.convolution_2d.dilation_height,
+ node->params.convolution_2d.dilation_width,
+ node->params.convolution_2d.groups,
+ node->params.convolution_2d.group_input_channels,
+ node->params.convolution_2d.group_output_channels,
+ node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
+ node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
+ values[node->inputs[1]].data,
+ values[node->inputs[2]].data,
+ node->activation.output_min,
+ node->activation.output_max,
+ node->flags,
+ &runtime->opdata[i].operator_object);
+ }
if (status != xnn_status_success) {
goto error;
}
@@ -260,6 +307,8 @@
runtime->opdata[i].outputs[0] = node->outputs[0];
break;
case xnn_node_type_deconvolution_2d:
+ assert(values[node->inputs[1]].data != NULL);
+ assert(values[node->inputs[2]].data != NULL);
status = xnn_create_deconvolution2d_nhwc_f32(
node->params.deconvolution_2d.padding_top,
node->params.deconvolution_2d.padding_right,
@@ -294,34 +343,65 @@
runtime->opdata[i].outputs[0] = node->outputs[0];
break;
case xnn_node_type_depthwise_convolution_2d:
- status = xnn_create_convolution2d_nhwc_f32(
- node->params.depthwise_convolution_2d.input_padding_top,
- node->params.depthwise_convolution_2d.input_padding_right,
- node->params.depthwise_convolution_2d.input_padding_bottom,
- node->params.depthwise_convolution_2d.input_padding_left,
- node->params.depthwise_convolution_2d.kernel_height,
- node->params.depthwise_convolution_2d.kernel_width,
- node->params.depthwise_convolution_2d.subsampling_height,
- node->params.depthwise_convolution_2d.subsampling_width,
- node->params.depthwise_convolution_2d.dilation_height,
- node->params.depthwise_convolution_2d.dilation_width,
- node->params.depthwise_convolution_2d.input_channels /* groups */,
- 1 /* group_input_channels */,
- node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
- node->params.depthwise_convolution_2d.input_channels /* input_pixel_stride */,
- node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_pixel_stride */,
- values[node->inputs[1]].data,
- values[node->inputs[2]].data,
- node->activation.output_min,
- node->activation.output_max,
- node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION,
- &runtime->opdata[i].operator_object);
+ assert(values[node->inputs[1]].data != NULL);
+ assert(values[node->inputs[2]].data != NULL);
+ if (values[node->outputs[0]].layout == xnn_layout_type_nchw) {
+ assert(values[node->inputs[0]].layout == xnn_layout_type_nchw);
+ status = xnn_create_convolution2d_nchw_f32(
+ node->params.depthwise_convolution_2d.input_padding_top,
+ node->params.depthwise_convolution_2d.input_padding_right,
+ node->params.depthwise_convolution_2d.input_padding_bottom,
+ node->params.depthwise_convolution_2d.input_padding_left,
+ node->params.depthwise_convolution_2d.kernel_height,
+ node->params.depthwise_convolution_2d.kernel_width,
+ node->params.depthwise_convolution_2d.subsampling_height,
+ node->params.depthwise_convolution_2d.subsampling_width,
+ node->params.depthwise_convolution_2d.dilation_height,
+ node->params.depthwise_convolution_2d.dilation_width,
+ node->params.depthwise_convolution_2d.input_channels /* groups */,
+ 1 /* group_input_channels */,
+ node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
+ node->params.depthwise_convolution_2d.input_channels /* input_channel_stride */,
+ node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_channel_stride */,
+ values[node->inputs[1]].data,
+ values[node->inputs[2]].data,
+ node->activation.output_min,
+ node->activation.output_max,
+ node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION,
+ &runtime->opdata[i].operator_object);
+ } else {
+ assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
+ assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
+ status = xnn_create_convolution2d_nhwc_f32(
+ node->params.depthwise_convolution_2d.input_padding_top,
+ node->params.depthwise_convolution_2d.input_padding_right,
+ node->params.depthwise_convolution_2d.input_padding_bottom,
+ node->params.depthwise_convolution_2d.input_padding_left,
+ node->params.depthwise_convolution_2d.kernel_height,
+ node->params.depthwise_convolution_2d.kernel_width,
+ node->params.depthwise_convolution_2d.subsampling_height,
+ node->params.depthwise_convolution_2d.subsampling_width,
+ node->params.depthwise_convolution_2d.dilation_height,
+ node->params.depthwise_convolution_2d.dilation_width,
+ node->params.depthwise_convolution_2d.input_channels /* groups */,
+ 1 /* group_input_channels */,
+ node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
+ node->params.depthwise_convolution_2d.input_channels /* input_channel_stride */,
+ node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_channel_stride */,
+ values[node->inputs[1]].data,
+ values[node->inputs[2]].data,
+ node->activation.output_min,
+ node->activation.output_max,
+ node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION,
+ &runtime->opdata[i].operator_object);
+ }
if (status != xnn_status_success) {
goto error;
}
runtime->opdata[i].batch_size = values[node->inputs[0]].shape.dim[0];
runtime->opdata[i].input_height = values[node->inputs[0]].shape.dim[1];
runtime->opdata[i].input_width = values[node->inputs[0]].shape.dim[2];
+
runtime->opdata[i].inputs[0] = node->inputs[0];
runtime->opdata[i].outputs[0] = node->outputs[0];
break;
@@ -381,14 +461,25 @@
runtime->opdata[i].outputs[0] = node->outputs[0];
break;
case xnn_node_type_global_average_pooling_2d:
- status = xnn_create_global_average_pooling_nwc_f32(
- values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
- values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
- values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
- node->activation.output_min,
- node->activation.output_max,
- node->flags,
- &runtime->opdata[i].operator_object);
+ if (values[node->inputs[0]].layout == xnn_layout_type_nchw) {
+ status = xnn_create_global_average_pooling_ncw_f32(
+ values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
+ node->activation.output_min,
+ node->activation.output_max,
+ node->flags,
+ &runtime->opdata[i].operator_object);
+ } else {
+ assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
+ assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
+ status = xnn_create_global_average_pooling_nwc_f32(
+ values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
+ values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
+ values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
+ node->activation.output_min,
+ node->activation.output_max,
+ node->flags,
+ &runtime->opdata[i].operator_object);
+ }
if (status != xnn_status_success) {
goto error;
}
@@ -495,8 +586,26 @@
}
runtime->opdata[i].shape1.num_dims = values[node->inputs[0]].shape.num_dims;
runtime->opdata[i].shape2.num_dims = values[node->inputs[1]].shape.num_dims;
- memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
- memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
+ if (values[node->outputs[0]].layout == xnn_layout_type_nchw) {
+ assert(values[node->inputs[0]].layout == xnn_layout_type_nchw);
+ assert(values[node->inputs[1]].layout == xnn_layout_type_nchw);
+ runtime->opdata[i].shape1.dim[0] = values[node->inputs[0]].shape.dim[0];
+ runtime->opdata[i].shape1.dim[1] = values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1];
+ if (values[node->inputs[0]].shape.num_dims > 2) {
+ memcpy(&runtime->opdata[i].shape1.dim[2], &values[node->inputs[0]].shape.dim[1], (values[node->inputs[0]].shape.num_dims - 2) * sizeof(size_t));
+ }
+ runtime->opdata[i].shape2.dim[0] = values[node->inputs[1]].shape.dim[0];
+ runtime->opdata[i].shape2.dim[1] = values[node->inputs[1]].shape.dim[values[node->inputs[1]].shape.num_dims - 1];
+ if (values[node->inputs[1]].shape.num_dims > 2) {
+ memcpy(&runtime->opdata[i].shape2.dim[2], &values[node->inputs[1]].shape.dim[1], (values[node->inputs[1]].shape.num_dims - 2) * sizeof(size_t));
+ }
+ } else {
+ assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
+ assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
+ assert(values[node->inputs[1]].layout == xnn_layout_type_nhwc);
+ memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
+ memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
+ }
runtime->opdata[i].inputs[0] = node->inputs[0];
runtime->opdata[i].inputs[1] = node->inputs[1];
runtime->opdata[i].outputs[0] = node->outputs[0];
@@ -812,6 +921,18 @@
runtime->blobs[opdata->outputs[0]].data,
runtime->threadpool);
break;
+ case xnn_operator_type_convolution_nchw_f32:
+ assert(runtime->blobs[opdata->inputs[0]].data != NULL);
+ assert(runtime->blobs[opdata->outputs[0]].data != NULL);
+ status = xnn_setup_convolution2d_nchw_f32(
+ opdata->operator_object,
+ opdata->batch_size,
+ opdata->input_height,
+ opdata->input_width,
+ runtime->blobs[opdata->inputs[0]].data,
+ runtime->blobs[opdata->outputs[0]].data,
+ runtime->threadpool);
+ break;
case xnn_operator_type_convolution_nhwc_f32:
assert(runtime->blobs[opdata->inputs[0]].data != NULL);
assert(runtime->blobs[opdata->outputs[0]].data != NULL);
@@ -883,6 +1004,17 @@
runtime->blobs[opdata->outputs[0]].data,
runtime->threadpool);
break;
+ case xnn_operator_type_global_average_pooling_ncw_f32:
+ assert(runtime->blobs[opdata->inputs[0]].data != NULL);
+ assert(runtime->blobs[opdata->outputs[0]].data != NULL);
+ status = xnn_setup_global_average_pooling_ncw_f32(
+ opdata->operator_object,
+ opdata->batch_size,
+ opdata->input_width,
+ runtime->blobs[opdata->inputs[0]].data,
+ runtime->blobs[opdata->outputs[0]].data,
+ runtime->threadpool);
+ break;
case xnn_operator_type_global_average_pooling_nwc_f32:
assert(runtime->blobs[opdata->inputs[0]].data != NULL);
assert(runtime->blobs[opdata->outputs[0]].data != NULL);
@@ -1066,7 +1198,8 @@
runtime->threadpool);
break;
default:
- xnn_log_fatal("unexpected operator type %d in operator #%zu", opdata->operator_object->type, i);
+ xnn_log_fatal("unexpected operator type %s in operator #%zu",
+ xnn_operator_type_to_string(opdata->operator_object->type), i);
XNN_UNREACHABLE;
}
if (status != xnn_status_success) {
diff --git a/src/subgraph.c b/src/subgraph.c
index 99be2a4..65a0de7 100644
--- a/src/subgraph.c
+++ b/src/subgraph.c
@@ -122,6 +122,320 @@
return new_node;
}
+#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW 1
+#define XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW 2
+#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC 4
+#define XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER 8
+
+#if XNN_ENABLE_SPARSE
+uint32_t xnn_check_nchw_compatibility(xnn_subgraph_t subgraph, struct xnn_node* node) {
+ switch (node->type) {
+ case xnn_node_type_convolution_2d:
+ // Supported cases:
+ // - 1x1 convolution (no stride, no dilation, no padding, no groups)
+ // - 3x3 stride-2 convolution (no dilation, padding 1 on each side, no groups, 3 input channels)
+ if (node->params.convolution_2d.groups != 1) {
+ return 0;
+ }
+ if ((node->params.convolution_2d.dilation_height | node->params.convolution_2d.dilation_width) != 1) {
+ return 0;
+ }
+ if ((node->params.convolution_2d.kernel_height | node->params.convolution_2d.kernel_width) == 1) {
+ if ((node->params.convolution_2d.input_padding_top | node->params.convolution_2d.input_padding_right |
+ node->params.convolution_2d.input_padding_bottom | node->params.convolution_2d.input_padding_left) != 0)
+ {
+ return 0;
+ }
+ if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 1) {
+ return 0;
+ }
+ return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
+ } else if (node->params.convolution_2d.kernel_height == 3 && node->params.convolution_2d.kernel_width == 3) {
+ if (node->params.convolution_2d.input_padding_top != 1 || node->params.convolution_2d.input_padding_right != 1 ||
+ node->params.convolution_2d.input_padding_bottom != 1 || node->params.convolution_2d.input_padding_left != 1)
+ {
+ return 0;
+ }
+ if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 2) {
+ return 0;
+ }
+ if (node->params.convolution_2d.group_input_channels != 3) {
+ return 0;
+ }
+ return XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW;
+ }
+ return 0;
+ case xnn_node_type_depthwise_convolution_2d:
+ // Supported cases:
+ // - 3x3 stride-1 convolution (no dilation, padding 1 on each side)
+ // - 3x3 stride-2 convolution (no dilation, padding 1 on each side)
+ // - 5x5 stride-1 convolution (no dilation, padding 2 on each side)
+ // - 5x5 stride-2 convolution (no dilation, padding 2 on each side)
+ if ((node->params.depthwise_convolution_2d.dilation_height | node->params.depthwise_convolution_2d.dilation_width) != 1) {
+ return 0;
+ }
+ if (node->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
+ return 0;
+ }
+ if (node->params.depthwise_convolution_2d.depth_multiplier != 1) {
+ return 0;
+ }
+ if (node->params.depthwise_convolution_2d.subsampling_height != node->params.depthwise_convolution_2d.subsampling_width) {
+ return 0;
+ }
+ switch (node->params.depthwise_convolution_2d.subsampling_height) {
+ case 1:
+ case 2:
+ break;
+ default:
+ return 0;
+ }
+ if (node->params.depthwise_convolution_2d.kernel_height != node->params.depthwise_convolution_2d.kernel_width) {
+ return 0;
+ }
+ switch (node->params.depthwise_convolution_2d.kernel_height) {
+ case 3:
+ return node->params.depthwise_convolution_2d.input_padding_top == 1 &&
+ node->params.depthwise_convolution_2d.input_padding_right == 1 &&
+ node->params.depthwise_convolution_2d.input_padding_bottom == 1 &&
+ node->params.depthwise_convolution_2d.input_padding_left == 1 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
+ case 5:
+ return node->params.depthwise_convolution_2d.input_padding_top == 2 &&
+ node->params.depthwise_convolution_2d.input_padding_right == 2 &&
+ node->params.depthwise_convolution_2d.input_padding_bottom == 2 &&
+ node->params.depthwise_convolution_2d.input_padding_left == 2 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
+ default:
+ return 0;
+ }
+ case xnn_node_type_global_average_pooling_2d:
+ return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
+ case xnn_node_type_add2:
+ case xnn_node_type_multiply2:
+ assert(node->num_inputs == 2);
+ assert(node->num_outputs == 1);
+ if (subgraph->values[node->inputs[0]].shape.num_dims != 4 ||
+ subgraph->values[node->inputs[1]].shape.num_dims != 4)
+ {
+ return 0;
+ }
+
+ if (subgraph->values[node->inputs[0]].data != NULL) {
+ // Check that the first input is representable as either a scalar, or a vector
+ size_t num_nonunit_dims = 0;
+ for (uint32_t i = 0; i < subgraph->values[node->inputs[0]].shape.num_dims; i++) {
+ if (subgraph->values[node->inputs[0]].shape.dim[i] != 1) {
+ num_nonunit_dims += 1;
+ }
+ }
+ if (num_nonunit_dims > 1) {
+ return 0;
+ }
+ }
+
+ if (subgraph->values[node->inputs[1]].data != NULL) {
+ // Check that the second input is representable as either a scalar, or a vector
+ size_t num_nonunit_dims = 0;
+ for (uint32_t i = 0; i < subgraph->values[node->inputs[1]].shape.num_dims; i++) {
+ if (subgraph->values[node->inputs[1]].shape.dim[i] != 1) {
+ num_nonunit_dims += 1;
+ }
+ }
+ if (num_nonunit_dims > 1) {
+ return 0;
+ }
+ }
+
+ return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
+ case xnn_node_type_abs:
+ case xnn_node_type_bankers_rounding:
+ case xnn_node_type_ceiling:
+ case xnn_node_type_clamp:
+ case xnn_node_type_floor:
+ case xnn_node_type_hardswish:
+ case xnn_node_type_leaky_relu:
+ case xnn_node_type_negate:
+ case xnn_node_type_sigmoid:
+ case xnn_node_type_square:
+ assert(node->num_inputs == 1);
+ assert(node->num_outputs == 1);
+ return subgraph->values[node->inputs[0]].shape.num_dims == 4 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
+ default:
+ return 0;
+ }
+}
+
+static void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph)
+{
+ // Convert parts of the subgraph to NCHW for sparse inference
+ // Step 1: detect NCHW-compatible Nodes
+ // Step 2: detect NCHW-compatible clusters (run connected components graph algorithm)
+ // Step 3: check that all NCHW-compatible Values are consumed only by NCHW-compatible Nodes
+ // Step 4: switch Values' layout to NCHW
+ for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+ struct xnn_node* node = &subgraph->nodes[n];
+ node->layout_flags = xnn_check_nchw_compatibility(subgraph, node);
+ xnn_log_debug("Node #%" PRIu32 ": %s (NCHW: %s, NHWC->NCHW: %s, NCHW->NHWC: %s)",
+ n, xnn_node_type_to_string(node->type),
+ node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW ? "yes" : "no",
+ node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW ? "yes" : "no",
+ node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC ? "yes" : "no");
+ }
+
+ // Run Shiloach-Vishkin connected components algorithm
+ bool update = false;
+ for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+ struct xnn_node* node = &subgraph->nodes[n];
+ node->cluster_leader = n;
+ if (node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC) {
+ for (uint32_t i = 0; i < node->num_inputs; i++) {
+ const struct xnn_value* value = &subgraph->values[node->inputs[i]];
+ if (value->data != NULL) {
+ // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
+ // during the initial NCHW compatibility check for the Node.
+ continue;
+ }
+ if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) {
+ // External value, invalid cluster
+ node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
+ continue;
+ }
+ const uint32_t producer_id = value->producer;
+ assert(producer_id != XNN_INVALID_NODE_ID);
+ assert(producer_id < n);
+ struct xnn_node* producer_node = &subgraph->nodes[producer_id];
+ if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
+ (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
+ {
+ producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
+ if (producer_node->cluster_leader != node->cluster_leader) {
+ producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
+ update = true;
+ }
+ } else {
+ node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
+ }
+ }
+ }
+ }
+ while (update) {
+ update = false;
+ for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+ struct xnn_node* node = &subgraph->nodes[n];
+ if (node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) {
+ continue;
+ }
+
+ if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC)) == 0) {
+ continue;
+ }
+
+ for (uint32_t i = 0; i < node->num_inputs; i++) {
+ const struct xnn_value* value = &subgraph->values[node->inputs[i]];
+ if (value->data != NULL) {
+ // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
+ // during the initial NCHW compatibility check for the Node.
+ continue;
+ }
+ if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) {
+ // External value, invalid cluster
+ node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
+ continue;
+ }
+ const uint32_t producer_id = value->producer;
+ assert(producer_id != XNN_INVALID_NODE_ID);
+ assert(producer_id < n);
+ struct xnn_node* producer_node = &subgraph->nodes[producer_id];
+ if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
+ (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
+ {
+ producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
+ if (producer_node->cluster_leader != node->cluster_leader) {
+ producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
+ update = true;
+ }
+ } else {
+ node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
+ }
+ }
+ }
+ }
+ // Propagate XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER flags up to the cluster leaders
+ for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+ struct xnn_node* node = &subgraph->nodes[n];
+ subgraph->nodes[node->cluster_leader].layout_flags |= node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
+ }
+ // Check that all Values consumed by NCHW-compatible cluster don't have NCHW-incompatible consumers
+ for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+ struct xnn_node* node = &subgraph->nodes[n];
+ if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
+ continue;
+ }
+
+ if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
+ continue;
+ }
+
+ for (uint32_t i = 0; i < node->num_inputs; i++) {
+ struct xnn_value* value = &subgraph->values[node->inputs[i]];
+ if (value->data != NULL) {
+ // Static data, skip this input value because it doesn't have a producer Node.
+ continue;
+ }
+ assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
+ value->num_nchw_compatible_consumers += 1;
+ }
+ }
+ for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+ struct xnn_node* node = &subgraph->nodes[n];
+ if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
+ continue;
+ }
+
+ if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
+ continue;
+ }
+
+ for (uint32_t i = 0; i < node->num_inputs; i++) {
+ const struct xnn_value* value = &subgraph->values[node->inputs[i]];
+ if (value->data != NULL) {
+ // Static data, skip this input value because it doesn't have a producer Node.
+ continue;
+ }
+ assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
+ assert(value->num_nchw_compatible_consumers > 0);
+ if (value->num_nchw_compatible_consumers != value->num_consumers) {
+ subgraph->nodes[node->cluster_leader].layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
+ }
+ }
+ }
+ for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+ struct xnn_node* node = &subgraph->nodes[n];
+ if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
+ continue;
+ }
+
+ if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
+ continue;
+ }
+
+ for (uint32_t i = 0; i < node->num_inputs; i++) {
+ struct xnn_value* value = &subgraph->values[node->inputs[i]];
+ if (value->data != NULL) {
+ // Static data, skip this input value because it doesn't have a producer Node.
+ continue;
+ }
+ assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
+ assert(value->num_nchw_compatible_consumers > 0);
+ assert(value->num_nchw_compatible_consumers == value->num_consumers);
+ if (value->layout != xnn_layout_type_nchw) {
+ value->layout = xnn_layout_type_nchw;
+ xnn_log_info("set Value #%"PRIu32" layout to NCHW", node->inputs[i]);
+ }
+ }
+ }
+}
+#endif // XNN_ENABLE_SPARSE
+
enum xnn_status xnn_subgraph_optimize(
xnn_subgraph_t subgraph,
uint32_t flags)
@@ -294,6 +608,11 @@
}
}
}
+
+ #if XNN_ENABLE_SPARSE
+ xnn_subgraph_rewrite_for_nchw(subgraph);
+ #endif
+
return xnn_status_success;
}
diff --git a/src/xnnpack/subgraph.h b/src/xnnpack/subgraph.h
index dceef12..4d9475d 100644
--- a/src/xnnpack/subgraph.h
+++ b/src/xnnpack/subgraph.h
@@ -28,6 +28,11 @@
xnn_value_type_dense_tensor = 1,
};
+enum xnn_layout_type {
+ xnn_layout_type_nhwc = 0,
+ xnn_layout_type_nchw = 1,
+};
+
/// Abstraction for a collections of elements produced and consumed by nodes.
struct xnn_value {
/// Unique ID for the value.
@@ -56,6 +61,8 @@
/// If multiple inputs in a Node refer to this Value as input, the Node is counted as consumer multiple times.
/// If the Value is an external output, it counts as having an extra consumer.
uint32_t num_consumers;
+ uint32_t num_nchw_compatible_consumers;
+ enum xnn_layout_type layout;
};
struct xnn_blob {
@@ -182,6 +189,8 @@
uint32_t outputs[XNN_MAX_OUTPUTS];
uint32_t num_outputs;
uint32_t flags;
+ uint32_t layout_flags;
+ uint32_t cluster_leader;
};
struct xnn_operator_data {