Graph rewriting for sparse inference in NCHW layout

PiperOrigin-RevId: 317200356
diff --git a/BUILD.bazel b/BUILD.bazel
index 1d04170..1a1860f 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -3044,6 +3044,15 @@
     }),
 )
 
+cc_library(
+    name = "enable_sparse",
+    defines = select({
+        ":xnn_enable_sparse_explicit_true": ["XNN_ENABLE_SPARSE=1"],
+        ":xnn_enable_sparse_explicit_false": ["XNN_ENABLE_SPARSE=0"],
+        "//conditions:default": ["XNN_ENABLE_SPARSE=1"],
+    }),
+)
+
 xnnpack_cc_library(
     name = "operators",
     srcs = OPERATOR_SRCS + [
@@ -3129,6 +3138,7 @@
     visibility = xnnpack_visibility(),
     deps = [
         ":enable_assembly",
+        ":enable_sparse",
         ":logging_utils",
         ":memory_planner",
         ":operator_run",
@@ -3170,6 +3180,7 @@
     visibility = xnnpack_visibility(),
     deps = [
         ":enable_assembly",
+        ":enable_sparse",
         ":logging_utils",
         ":memory_planner_test_mode",
         ":operator_run_test_mode",
@@ -3216,6 +3227,7 @@
     visibility = xnnpack_visibility(),
     deps = [
         ":enable_assembly",
+        ":enable_sparse",
         ":logging_utils",
         ":memory_planner",
         ":operator_run",
@@ -5030,6 +5042,18 @@
     define_values = {"xnn_enable_assembly": "false"},
 )
 
+# Enables usage of sparse inference.
+config_setting(
+    name = "xnn_enable_sparse_explicit_true",
+    define_values = {"xnn_enable_sparse": "true"},
+)
+
+# Disables usage of sparse inference.
+config_setting(
+    name = "xnn_enable_sparse_explicit_false",
+    define_values = {"xnn_enable_sparse": "false"},
+)
+
 # Disables usage of HMP-aware optimizations.
 config_setting(
     name = "xnn_enable_hmp_explicit_false",
diff --git a/src/operators/convolution-nchw.c b/src/operators/convolution-nchw.c
index 1eb5639..1198616 100644
--- a/src/operators/convolution-nchw.c
+++ b/src/operators/convolution-nchw.c
@@ -167,9 +167,9 @@
   // + 1x1 convolution (no groups)
   // + 3x3 stride-2 with 3 input channels and NHWC input layout
   // + 3x3 stride-2 depthwise convolution with horizontal padding 1 & no vertical padding
-  // - 3x3 stride-1 depthwise convolution with horizontal padding 1 & no vertical padding
-  // - 5x5 stride-2 depthwise convolution with horizontal padding 2 & no vertical padding
-  // - 5x5 stride-1 depthwise convolution with horizontal padding 2 & no vertical padding
+  // + 3x3 stride-1 depthwise convolution with horizontal padding 1 & no vertical padding
+  // + 5x5 stride-2 depthwise convolution with horizontal padding 2 & no vertical padding
+  // + 5x5 stride-1 depthwise convolution with horizontal padding 2 & no vertical padding
   const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
   const bool is_1x1 = kernel_width == 1 && kernel_height == 1 && subsampling_height == 1 && subsampling_width == 1;
   const bool is_3x3 = kernel_width == 3 && kernel_height == 3 && dilation_height == 1 && dilation_width == 1;
diff --git a/src/runtime.c b/src/runtime.c
index 3f5327d..c430ac7 100644
--- a/src/runtime.c
+++ b/src/runtime.c
@@ -112,8 +112,26 @@
         }
         runtime->opdata[i].shape1.num_dims = values[node->inputs[0]].shape.num_dims;
         runtime->opdata[i].shape2.num_dims = values[node->inputs[1]].shape.num_dims;
-        memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
-        memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
+        if (values[node->outputs[0]].layout == xnn_layout_type_nchw) {
+          assert(values[node->inputs[0]].layout == xnn_layout_type_nchw);
+          assert(values[node->inputs[1]].layout == xnn_layout_type_nchw);
+          runtime->opdata[i].shape1.dim[0] = values[node->inputs[0]].shape.dim[0];
+          runtime->opdata[i].shape1.dim[1] = values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1];
+          if (values[node->inputs[0]].shape.num_dims > 2) {
+            memcpy(&runtime->opdata[i].shape1.dim[2], &values[node->inputs[0]].shape.dim[1], (values[node->inputs[0]].shape.num_dims - 2) * sizeof(size_t));
+          }
+          runtime->opdata[i].shape2.dim[0] = values[node->inputs[1]].shape.dim[0];
+          runtime->opdata[i].shape2.dim[1] = values[node->inputs[1]].shape.dim[values[node->inputs[1]].shape.num_dims - 1];
+          if (values[node->inputs[1]].shape.num_dims > 2) {
+            memcpy(&runtime->opdata[i].shape2.dim[2], &values[node->inputs[1]].shape.dim[1], (values[node->inputs[1]].shape.num_dims - 2) * sizeof(size_t));
+          }
+        } else {
+          assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
+          assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
+          assert(values[node->inputs[1]].layout == xnn_layout_type_nhwc);
+          memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
+          memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
+        }
         runtime->opdata[i].inputs[0] = node->inputs[0];
         runtime->opdata[i].inputs[1] = node->inputs[1];
         runtime->opdata[i].outputs[0] = node->outputs[0];
@@ -212,28 +230,57 @@
         runtime->opdata[i].outputs[0] = node->outputs[0];
         break;
       case xnn_node_type_convolution_2d:
-        status = xnn_create_convolution2d_nhwc_f32(
-          node->params.convolution_2d.input_padding_top,
-          node->params.convolution_2d.input_padding_right,
-          node->params.convolution_2d.input_padding_bottom,
-          node->params.convolution_2d.input_padding_left,
-          node->params.convolution_2d.kernel_height,
-          node->params.convolution_2d.kernel_width,
-          node->params.convolution_2d.subsampling_height,
-          node->params.convolution_2d.subsampling_width,
-          node->params.convolution_2d.dilation_height,
-          node->params.convolution_2d.dilation_width,
-          node->params.convolution_2d.groups,
-          node->params.convolution_2d.group_input_channels,
-          node->params.convolution_2d.group_output_channels,
-          node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
-          node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
-          values[node->inputs[1]].data,
-          values[node->inputs[2]].data,
-          node->activation.output_min,
-          node->activation.output_max,
-          node->flags,
-          &runtime->opdata[i].operator_object);
+        assert(values[node->inputs[1]].data != NULL);
+        assert(values[node->inputs[2]].data != NULL);
+        if (values[node->outputs[0]].layout == xnn_layout_type_nchw) {
+          status = xnn_create_convolution2d_nchw_f32(
+            node->params.convolution_2d.input_padding_top,
+            node->params.convolution_2d.input_padding_right,
+            node->params.convolution_2d.input_padding_bottom,
+            node->params.convolution_2d.input_padding_left,
+            node->params.convolution_2d.kernel_height,
+            node->params.convolution_2d.kernel_width,
+            node->params.convolution_2d.subsampling_height,
+            node->params.convolution_2d.subsampling_width,
+            node->params.convolution_2d.dilation_height,
+            node->params.convolution_2d.dilation_width,
+            node->params.convolution_2d.groups,
+            node->params.convolution_2d.group_input_channels,
+            node->params.convolution_2d.group_output_channels,
+            node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
+            node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
+            values[node->inputs[1]].data,
+            values[node->inputs[2]].data,
+            node->activation.output_min,
+            node->activation.output_max,
+            node->flags | (values[node->inputs[0]].layout == xnn_layout_type_nhwc ? XNN_FLAG_INPUT_NHWC : 0),
+            &runtime->opdata[i].operator_object);
+        } else {
+          assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
+          assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
+          status = xnn_create_convolution2d_nhwc_f32(
+            node->params.convolution_2d.input_padding_top,
+            node->params.convolution_2d.input_padding_right,
+            node->params.convolution_2d.input_padding_bottom,
+            node->params.convolution_2d.input_padding_left,
+            node->params.convolution_2d.kernel_height,
+            node->params.convolution_2d.kernel_width,
+            node->params.convolution_2d.subsampling_height,
+            node->params.convolution_2d.subsampling_width,
+            node->params.convolution_2d.dilation_height,
+            node->params.convolution_2d.dilation_width,
+            node->params.convolution_2d.groups,
+            node->params.convolution_2d.group_input_channels,
+            node->params.convolution_2d.group_output_channels,
+            node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
+            node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
+            values[node->inputs[1]].data,
+            values[node->inputs[2]].data,
+            node->activation.output_min,
+            node->activation.output_max,
+            node->flags,
+            &runtime->opdata[i].operator_object);
+        }
         if (status != xnn_status_success) {
           goto error;
         }
@@ -260,6 +307,8 @@
         runtime->opdata[i].outputs[0] = node->outputs[0];
         break;
       case xnn_node_type_deconvolution_2d:
+        assert(values[node->inputs[1]].data != NULL);
+        assert(values[node->inputs[2]].data != NULL);
         status = xnn_create_deconvolution2d_nhwc_f32(
           node->params.deconvolution_2d.padding_top,
           node->params.deconvolution_2d.padding_right,
@@ -294,34 +343,65 @@
         runtime->opdata[i].outputs[0] = node->outputs[0];
         break;
       case xnn_node_type_depthwise_convolution_2d:
-        status = xnn_create_convolution2d_nhwc_f32(
-          node->params.depthwise_convolution_2d.input_padding_top,
-          node->params.depthwise_convolution_2d.input_padding_right,
-          node->params.depthwise_convolution_2d.input_padding_bottom,
-          node->params.depthwise_convolution_2d.input_padding_left,
-          node->params.depthwise_convolution_2d.kernel_height,
-          node->params.depthwise_convolution_2d.kernel_width,
-          node->params.depthwise_convolution_2d.subsampling_height,
-          node->params.depthwise_convolution_2d.subsampling_width,
-          node->params.depthwise_convolution_2d.dilation_height,
-          node->params.depthwise_convolution_2d.dilation_width,
-          node->params.depthwise_convolution_2d.input_channels /* groups */,
-          1 /* group_input_channels */,
-          node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
-          node->params.depthwise_convolution_2d.input_channels /* input_pixel_stride */,
-          node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_pixel_stride */,
-          values[node->inputs[1]].data,
-          values[node->inputs[2]].data,
-          node->activation.output_min,
-          node->activation.output_max,
-          node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION,
-          &runtime->opdata[i].operator_object);
+        assert(values[node->inputs[1]].data != NULL);
+        assert(values[node->inputs[2]].data != NULL);
+        if (values[node->outputs[0]].layout == xnn_layout_type_nchw) {
+          assert(values[node->inputs[0]].layout == xnn_layout_type_nchw);
+          status = xnn_create_convolution2d_nchw_f32(
+            node->params.depthwise_convolution_2d.input_padding_top,
+            node->params.depthwise_convolution_2d.input_padding_right,
+            node->params.depthwise_convolution_2d.input_padding_bottom,
+            node->params.depthwise_convolution_2d.input_padding_left,
+            node->params.depthwise_convolution_2d.kernel_height,
+            node->params.depthwise_convolution_2d.kernel_width,
+            node->params.depthwise_convolution_2d.subsampling_height,
+            node->params.depthwise_convolution_2d.subsampling_width,
+            node->params.depthwise_convolution_2d.dilation_height,
+            node->params.depthwise_convolution_2d.dilation_width,
+            node->params.depthwise_convolution_2d.input_channels /* groups */,
+            1 /* group_input_channels */,
+            node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
+            node->params.depthwise_convolution_2d.input_channels /* input_channel_stride */,
+            node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_channel_stride */,
+            values[node->inputs[1]].data,
+            values[node->inputs[2]].data,
+            node->activation.output_min,
+            node->activation.output_max,
+            node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION,
+            &runtime->opdata[i].operator_object);
+        } else {
+          assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
+          assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
+          status = xnn_create_convolution2d_nhwc_f32(
+            node->params.depthwise_convolution_2d.input_padding_top,
+            node->params.depthwise_convolution_2d.input_padding_right,
+            node->params.depthwise_convolution_2d.input_padding_bottom,
+            node->params.depthwise_convolution_2d.input_padding_left,
+            node->params.depthwise_convolution_2d.kernel_height,
+            node->params.depthwise_convolution_2d.kernel_width,
+            node->params.depthwise_convolution_2d.subsampling_height,
+            node->params.depthwise_convolution_2d.subsampling_width,
+            node->params.depthwise_convolution_2d.dilation_height,
+            node->params.depthwise_convolution_2d.dilation_width,
+            node->params.depthwise_convolution_2d.input_channels /* groups */,
+            1 /* group_input_channels */,
+            node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
+            node->params.depthwise_convolution_2d.input_channels /* input_channel_stride */,
+            node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_channel_stride */,
+            values[node->inputs[1]].data,
+            values[node->inputs[2]].data,
+            node->activation.output_min,
+            node->activation.output_max,
+            node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION,
+            &runtime->opdata[i].operator_object);
+        }
         if (status != xnn_status_success) {
           goto error;
         }
         runtime->opdata[i].batch_size = values[node->inputs[0]].shape.dim[0];
         runtime->opdata[i].input_height = values[node->inputs[0]].shape.dim[1];
         runtime->opdata[i].input_width = values[node->inputs[0]].shape.dim[2];
+
         runtime->opdata[i].inputs[0] = node->inputs[0];
         runtime->opdata[i].outputs[0] = node->outputs[0];
         break;
@@ -381,14 +461,25 @@
         runtime->opdata[i].outputs[0] = node->outputs[0];
         break;
       case xnn_node_type_global_average_pooling_2d:
-        status = xnn_create_global_average_pooling_nwc_f32(
-          values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
-          values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
-          values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
-          node->activation.output_min,
-          node->activation.output_max,
-          node->flags,
-          &runtime->opdata[i].operator_object);
+        if (values[node->inputs[0]].layout == xnn_layout_type_nchw) {
+          status = xnn_create_global_average_pooling_ncw_f32(
+            values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
+            node->activation.output_min,
+            node->activation.output_max,
+            node->flags,
+            &runtime->opdata[i].operator_object);
+        } else {
+          assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
+          assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
+          status = xnn_create_global_average_pooling_nwc_f32(
+            values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
+            values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
+            values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
+            node->activation.output_min,
+            node->activation.output_max,
+            node->flags,
+            &runtime->opdata[i].operator_object);
+        }
         if (status != xnn_status_success) {
           goto error;
         }
@@ -495,8 +586,26 @@
         }
         runtime->opdata[i].shape1.num_dims = values[node->inputs[0]].shape.num_dims;
         runtime->opdata[i].shape2.num_dims = values[node->inputs[1]].shape.num_dims;
-        memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
-        memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
+        if (values[node->outputs[0]].layout == xnn_layout_type_nchw) {
+          assert(values[node->inputs[0]].layout == xnn_layout_type_nchw);
+          assert(values[node->inputs[1]].layout == xnn_layout_type_nchw);
+          runtime->opdata[i].shape1.dim[0] = values[node->inputs[0]].shape.dim[0];
+          runtime->opdata[i].shape1.dim[1] = values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1];
+          if (values[node->inputs[0]].shape.num_dims > 2) {
+            memcpy(&runtime->opdata[i].shape1.dim[2], &values[node->inputs[0]].shape.dim[1], (values[node->inputs[0]].shape.num_dims - 2) * sizeof(size_t));
+          }
+          runtime->opdata[i].shape2.dim[0] = values[node->inputs[1]].shape.dim[0];
+          runtime->opdata[i].shape2.dim[1] = values[node->inputs[1]].shape.dim[values[node->inputs[1]].shape.num_dims - 1];
+          if (values[node->inputs[1]].shape.num_dims > 2) {
+            memcpy(&runtime->opdata[i].shape2.dim[2], &values[node->inputs[1]].shape.dim[1], (values[node->inputs[1]].shape.num_dims - 2) * sizeof(size_t));
+          }
+        } else {
+          assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
+          assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
+          assert(values[node->inputs[1]].layout == xnn_layout_type_nhwc);
+          memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
+          memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
+        }
         runtime->opdata[i].inputs[0] = node->inputs[0];
         runtime->opdata[i].inputs[1] = node->inputs[1];
         runtime->opdata[i].outputs[0] = node->outputs[0];
@@ -812,6 +921,18 @@
           runtime->blobs[opdata->outputs[0]].data,
           runtime->threadpool);
         break;
+      case xnn_operator_type_convolution_nchw_f32:
+        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
+        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
+        status = xnn_setup_convolution2d_nchw_f32(
+          opdata->operator_object,
+          opdata->batch_size,
+          opdata->input_height,
+          opdata->input_width,
+          runtime->blobs[opdata->inputs[0]].data,
+          runtime->blobs[opdata->outputs[0]].data,
+          runtime->threadpool);
+        break;
       case xnn_operator_type_convolution_nhwc_f32:
         assert(runtime->blobs[opdata->inputs[0]].data != NULL);
         assert(runtime->blobs[opdata->outputs[0]].data != NULL);
@@ -883,6 +1004,17 @@
           runtime->blobs[opdata->outputs[0]].data,
           runtime->threadpool);
         break;
+      case xnn_operator_type_global_average_pooling_ncw_f32:
+        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
+        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
+        status = xnn_setup_global_average_pooling_ncw_f32(
+          opdata->operator_object,
+          opdata->batch_size,
+          opdata->input_width,
+          runtime->blobs[opdata->inputs[0]].data,
+          runtime->blobs[opdata->outputs[0]].data,
+          runtime->threadpool);
+        break;
       case xnn_operator_type_global_average_pooling_nwc_f32:
         assert(runtime->blobs[opdata->inputs[0]].data != NULL);
         assert(runtime->blobs[opdata->outputs[0]].data != NULL);
@@ -1066,7 +1198,8 @@
           runtime->threadpool);
         break;
       default:
-        xnn_log_fatal("unexpected operator type %d in operator #%zu", opdata->operator_object->type, i);
+        xnn_log_fatal("unexpected operator type %s in operator #%zu",
+          xnn_operator_type_to_string(opdata->operator_object->type), i);
         XNN_UNREACHABLE;
     }
     if (status != xnn_status_success) {
diff --git a/src/subgraph.c b/src/subgraph.c
index 99be2a4..65a0de7 100644
--- a/src/subgraph.c
+++ b/src/subgraph.c
@@ -122,6 +122,320 @@
   return new_node;
 }
 
+#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW      1
+#define XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW 2
+#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC 4
+#define XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER 8
+
+#if XNN_ENABLE_SPARSE
+uint32_t xnn_check_nchw_compatibility(xnn_subgraph_t subgraph, struct xnn_node* node) {
+  switch (node->type) {
+    case xnn_node_type_convolution_2d:
+      // Supported cases:
+      // - 1x1 convolution (no stride, no dilation, no padding, no groups)
+      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side, no groups, 3 input channels)
+      if (node->params.convolution_2d.groups != 1) {
+        return 0;
+      }
+      if ((node->params.convolution_2d.dilation_height | node->params.convolution_2d.dilation_width) != 1) {
+        return 0;
+      }
+      if ((node->params.convolution_2d.kernel_height | node->params.convolution_2d.kernel_width) == 1) {
+        if ((node->params.convolution_2d.input_padding_top | node->params.convolution_2d.input_padding_right |
+             node->params.convolution_2d.input_padding_bottom | node->params.convolution_2d.input_padding_left) != 0)
+        {
+          return 0;
+        }
+        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 1) {
+          return 0;
+        }
+        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
+      } else if (node->params.convolution_2d.kernel_height == 3 && node->params.convolution_2d.kernel_width == 3) {
+        if (node->params.convolution_2d.input_padding_top != 1 || node->params.convolution_2d.input_padding_right != 1 ||
+            node->params.convolution_2d.input_padding_bottom != 1 || node->params.convolution_2d.input_padding_left != 1)
+        {
+          return 0;
+        }
+        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 2) {
+          return 0;
+        }
+        if (node->params.convolution_2d.group_input_channels != 3) {
+          return 0;
+        }
+        return XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW;
+      }
+      return 0;
+    case xnn_node_type_depthwise_convolution_2d:
+      // Supported cases:
+      // - 3x3 stride-1 convolution (no dilation, padding 1 on each side)
+      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side)
+      // - 5x5 stride-1 convolution (no dilation, padding 2 on each side)
+      // - 5x5 stride-2 convolution (no dilation, padding 2 on each side)
+      if ((node->params.depthwise_convolution_2d.dilation_height | node->params.depthwise_convolution_2d.dilation_width) != 1) {
+        return 0;
+      }
+      if (node->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
+        return 0;
+      }
+      if (node->params.depthwise_convolution_2d.depth_multiplier != 1) {
+        return 0;
+      }
+      if (node->params.depthwise_convolution_2d.subsampling_height != node->params.depthwise_convolution_2d.subsampling_width) {
+        return 0;
+      }
+      switch (node->params.depthwise_convolution_2d.subsampling_height) {
+        case 1:
+        case 2:
+          break;
+        default:
+          return 0;
+      }
+      if (node->params.depthwise_convolution_2d.kernel_height != node->params.depthwise_convolution_2d.kernel_width) {
+        return 0;
+      }
+      switch (node->params.depthwise_convolution_2d.kernel_height) {
+        case 3:
+          return node->params.depthwise_convolution_2d.input_padding_top == 1 &&
+                 node->params.depthwise_convolution_2d.input_padding_right == 1 &&
+                 node->params.depthwise_convolution_2d.input_padding_bottom == 1 &&
+                 node->params.depthwise_convolution_2d.input_padding_left == 1 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
+        case 5:
+          return node->params.depthwise_convolution_2d.input_padding_top == 2 &&
+                 node->params.depthwise_convolution_2d.input_padding_right == 2 &&
+                 node->params.depthwise_convolution_2d.input_padding_bottom == 2 &&
+                 node->params.depthwise_convolution_2d.input_padding_left == 2 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
+        default:
+          return 0;
+      }
+    case xnn_node_type_global_average_pooling_2d:
+      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
+    case xnn_node_type_add2:
+    case xnn_node_type_multiply2:
+      assert(node->num_inputs == 2);
+      assert(node->num_outputs == 1);
+      if (subgraph->values[node->inputs[0]].shape.num_dims != 4 ||
+          subgraph->values[node->inputs[1]].shape.num_dims != 4)
+      {
+        return 0;
+      }
+
+      if (subgraph->values[node->inputs[0]].data != NULL) {
+        // Check that the first input is representable as either a scalar, or a vector
+        size_t num_nonunit_dims = 0;
+        for (uint32_t i = 0; i < subgraph->values[node->inputs[0]].shape.num_dims; i++) {
+          if (subgraph->values[node->inputs[0]].shape.dim[i] != 1) {
+            num_nonunit_dims += 1;
+          }
+        }
+        if (num_nonunit_dims > 1) {
+          return 0;
+        }
+      }
+
+      if (subgraph->values[node->inputs[1]].data != NULL) {
+        // Check that the second input is representable as either a scalar, or a vector
+        size_t num_nonunit_dims = 0;
+        for (uint32_t i = 0; i < subgraph->values[node->inputs[1]].shape.num_dims; i++) {
+          if (subgraph->values[node->inputs[1]].shape.dim[i] != 1) {
+            num_nonunit_dims += 1;
+          }
+        }
+        if (num_nonunit_dims > 1) {
+          return 0;
+        }
+      }
+
+      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
+    case xnn_node_type_abs:
+    case xnn_node_type_bankers_rounding:
+    case xnn_node_type_ceiling:
+    case xnn_node_type_clamp:
+    case xnn_node_type_floor:
+    case xnn_node_type_hardswish:
+    case xnn_node_type_leaky_relu:
+    case xnn_node_type_negate:
+    case xnn_node_type_sigmoid:
+    case xnn_node_type_square:
+      assert(node->num_inputs == 1);
+      assert(node->num_outputs == 1);
+      return subgraph->values[node->inputs[0]].shape.num_dims == 4 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
+    default:
+      return 0;
+  }
+}
+
+static void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph)
+{
+  // Convert parts of the subgraph to NCHW for sparse inference
+  // Step 1: detect NCHW-compatible Nodes
+  // Step 2: detect NCHW-compatible clusters (run connected components graph algorithm)
+  // Step 3: check that all NCHW-compatible Values are consumed only by NCHW-compatible Nodes
+  // Step 4: switch Values' layout to NCHW
+  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+    struct xnn_node* node = &subgraph->nodes[n];
+    node->layout_flags = xnn_check_nchw_compatibility(subgraph, node);
+    xnn_log_debug("Node #%" PRIu32 ": %s (NCHW: %s, NHWC->NCHW: %s, NCHW->NHWC: %s)",
+      n, xnn_node_type_to_string(node->type),
+      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW ? "yes" : "no",
+      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW ? "yes" : "no",
+      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC ? "yes" : "no");
+  }
+
+  // Run Shiloach-Vishkin connected components algorithm
+  bool update = false;
+  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+    struct xnn_node* node = &subgraph->nodes[n];
+    node->cluster_leader = n;
+    if (node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC) {
+      for (uint32_t i = 0; i < node->num_inputs; i++) {
+        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
+        if (value->data != NULL) {
+          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
+          // during the initial NCHW compatibility check for the Node.
+          continue;
+        }
+        if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) {
+          // External value, invalid cluster
+          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
+          continue;
+        }
+        const uint32_t producer_id = value->producer;
+        assert(producer_id != XNN_INVALID_NODE_ID);
+        assert(producer_id < n);
+        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
+        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
+            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
+        {
+          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
+          if (producer_node->cluster_leader != node->cluster_leader) {
+            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
+            update = true;
+          }
+        } else {
+          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
+        }
+      }
+    }
+  }
+  while (update) {
+    update = false;
+    for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+      struct xnn_node* node = &subgraph->nodes[n];
+      if (node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) {
+        continue;
+      }
+
+      if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC)) == 0) {
+        continue;
+      }
+
+      for (uint32_t i = 0; i < node->num_inputs; i++) {
+        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
+        if (value->data != NULL) {
+          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
+          // during the initial NCHW compatibility check for the Node.
+          continue;
+        }
+        if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) {
+          // External value, invalid cluster
+          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
+          continue;
+        }
+        const uint32_t producer_id = value->producer;
+        assert(producer_id != XNN_INVALID_NODE_ID);
+        assert(producer_id < n);
+        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
+        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
+            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
+        {
+          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
+          if (producer_node->cluster_leader != node->cluster_leader) {
+            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
+            update = true;
+          }
+        } else {
+          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
+        }
+      }
+    }
+  }
+  // Propagate XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER flags up to the cluster leaders
+  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+    struct xnn_node* node = &subgraph->nodes[n];
+    subgraph->nodes[node->cluster_leader].layout_flags |= node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
+  }
+  // Check that all Values consumed by NCHW-compatible cluster don't have NCHW-incompatible consumers
+  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+    struct xnn_node* node = &subgraph->nodes[n];
+    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
+      continue;
+    }
+
+    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
+      continue;
+    }
+
+    for (uint32_t i = 0; i < node->num_inputs; i++) {
+      struct xnn_value* value = &subgraph->values[node->inputs[i]];
+      if (value->data != NULL) {
+        // Static data, skip this input value because it doesn't have a producer Node.
+        continue;
+      }
+      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
+      value->num_nchw_compatible_consumers += 1;
+    }
+  }
+  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+    struct xnn_node* node = &subgraph->nodes[n];
+    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
+      continue;
+    }
+
+    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
+      continue;
+    }
+
+    for (uint32_t i = 0; i < node->num_inputs; i++) {
+      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
+      if (value->data != NULL) {
+        // Static data, skip this input value because it doesn't have a producer Node.
+        continue;
+      }
+      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
+      assert(value->num_nchw_compatible_consumers > 0);
+      if (value->num_nchw_compatible_consumers != value->num_consumers) {
+        subgraph->nodes[node->cluster_leader].layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
+      }
+    }
+  }
+  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
+    struct xnn_node* node = &subgraph->nodes[n];
+    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
+      continue;
+    }
+
+    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
+      continue;
+    }
+
+    for (uint32_t i = 0; i < node->num_inputs; i++) {
+      struct xnn_value* value = &subgraph->values[node->inputs[i]];
+      if (value->data != NULL) {
+        // Static data, skip this input value because it doesn't have a producer Node.
+        continue;
+      }
+      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
+      assert(value->num_nchw_compatible_consumers > 0);
+      assert(value->num_nchw_compatible_consumers == value->num_consumers);
+      if (value->layout != xnn_layout_type_nchw) {
+        value->layout = xnn_layout_type_nchw;
+        xnn_log_info("set Value #%"PRIu32" layout to NCHW", node->inputs[i]);
+      }
+    }
+  }
+}
+#endif  // XNN_ENABLE_SPARSE
+
 enum xnn_status xnn_subgraph_optimize(
   xnn_subgraph_t subgraph,
   uint32_t flags)
@@ -294,6 +608,11 @@
       }
     }
   }
+
+  #if XNN_ENABLE_SPARSE
+    xnn_subgraph_rewrite_for_nchw(subgraph);
+  #endif
+
   return xnn_status_success;
 }
 
diff --git a/src/xnnpack/subgraph.h b/src/xnnpack/subgraph.h
index dceef12..4d9475d 100644
--- a/src/xnnpack/subgraph.h
+++ b/src/xnnpack/subgraph.h
@@ -28,6 +28,11 @@
   xnn_value_type_dense_tensor = 1,
 };
 
+enum xnn_layout_type {
+  xnn_layout_type_nhwc = 0,
+  xnn_layout_type_nchw = 1,
+};
+
 /// Abstraction for a collections of elements produced and consumed by nodes.
 struct xnn_value {
   /// Unique ID for the value.
@@ -56,6 +61,8 @@
   /// If multiple inputs in a Node refer to this Value as input, the Node is counted as consumer multiple times.
   /// If the Value is an external output, it counts as having an extra consumer.
   uint32_t num_consumers;
+  uint32_t num_nchw_compatible_consumers;
+  enum xnn_layout_type layout;
 };
 
 struct xnn_blob {
@@ -182,6 +189,8 @@
   uint32_t outputs[XNN_MAX_OUTPUTS];
   uint32_t num_outputs;
   uint32_t flags;
+  uint32_t layout_flags;
+  uint32_t cluster_leader;
 };
 
 struct xnn_operator_data {