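Refactor Depth-To-Space operator

- Rename the `channels` argument of xnn_create_depth_to_space_nchw2nhwc_x32
  to `output_channels`.
- Drop the `output_height`/`output_width` arguments of
  xnn_setup_depth_to_space_nchw2nhwc_x32; setup now derives them as
  input size * block size.
- Take the context's output channel count from the operator's channels
  instead of the output pixel stride, and index the output stride in
  runtime.c by the output tensor's own rank.
- Rework the operator tester around output channels and int32 data, and
  rename the misnamed RESIZE_BILINEAR_NHWC_F32 batch-size test.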

PiperOrigin-RevId: 342119823
diff --git a/include/xnnpack.h b/include/xnnpack.h
index 86ad344..fae1274 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -1707,7 +1707,7 @@
   pthreadpool_t threadpool);
 
 enum xnn_status xnn_create_depth_to_space_nchw2nhwc_x32(
-  size_t channels,
+  size_t output_channels,
   size_t input_pixel_stride,
   size_t output_pixel_stride,
   uint32_t block_size,
@@ -1719,8 +1719,6 @@
   size_t batch_size,
   size_t input_height,
   size_t input_width,
-  size_t output_height,
-  size_t output_width,
   const void* input,
   void* output,
   pthreadpool_t threadpool);
diff --git a/src/operator-strings.c b/src/operator-strings.c
index 5971126..050f29d 100644
--- a/src/operator-strings.c
+++ b/src/operator-strings.c
@@ -63,7 +63,7 @@
     case xnn_operator_type_deconvolution_nhwc_qu8:
       return "Deconvolution (NHWC, QU8)";
     case xnn_operator_type_depth_to_space_nchw2nhwc_x32:
-      return "DepthToSpace (NCHW2NHWC, X32)";
+      return "Depth To Space (NCHW2NHWC, X32)";
     case xnn_operator_type_divide_nd_f32:
       return "Divide (ND, F32)";
     case xnn_operator_type_floor_nc_f32:
diff --git a/src/operators/depth-to-space-nchw2nhwc.c b/src/operators/depth-to-space-nchw2nhwc.c
index 3b6fd82..b3e3fab 100644
--- a/src/operators/depth-to-space-nchw2nhwc.c
+++ b/src/operators/depth-to-space-nchw2nhwc.c
@@ -14,7 +14,7 @@
 #include <xnnpack/params.h>
 
 enum xnn_status xnn_create_depth_to_space_nchw2nhwc_x32(
-    size_t channels,
+    size_t output_channels,
     size_t input_pixel_stride,
     size_t output_pixel_stride,
     uint32_t block_size,
@@ -32,14 +32,14 @@
 
   status = xnn_status_invalid_parameter;
 
-  if (channels == 0) {
+  if (output_channels == 0) {
     xnn_log_error(
-        "failed to create %s operator with %zu channels: number of channels must be non-zero",
-        xnn_operator_type_to_string(xnn_operator_type_depth_to_space_nchw2nhwc_x32), channels);
+        "failed to create %s operator with %zu output channels: number of channels must be non-zero",
+        xnn_operator_type_to_string(xnn_operator_type_depth_to_space_nchw2nhwc_x32), output_channels);
     goto error;
   }
 
-  if (block_size < 2) {
+  if (block_size <= 1) {
     xnn_log_error(
         "failed to create %s operator with %u block size: block size must be greater than 1",
         xnn_operator_type_to_string(xnn_operator_type_depth_to_space_nchw2nhwc_x32), block_size);
@@ -56,7 +56,7 @@
     goto error;
   }
 
-  depth_to_space_op->channels = channels;
+  depth_to_space_op->channels = output_channels;
   depth_to_space_op->input_pixel_stride = input_pixel_stride;
   depth_to_space_op->output_pixel_stride = output_pixel_stride;
   depth_to_space_op->block_size = block_size;
@@ -80,8 +80,6 @@
     size_t batch_size,
     size_t input_height,
     size_t input_width,
-    size_t output_height,
-    size_t output_width,
     const void* input,
     void* output,
     pthreadpool_t threadpool)
@@ -107,26 +105,21 @@
     return xnn_status_invalid_parameter;
   }
 
-  if (output_width == 0 || output_height == 0) {
-    xnn_log_error(
-        "failed to setup %s operator with %zux%zu output: output dimensions must be non-zero",
-        xnn_operator_type_to_string(xnn_operator_type_depth_to_space_nchw2nhwc_x32), output_width, output_height);
-    return xnn_status_invalid_parameter;
-  }
-
   if (batch_size == 0) {
     depth_to_space_op->state = xnn_run_state_skip;
     return xnn_status_success;
   }
 
+  const size_t block_size = depth_to_space_op->block_size;
+  const size_t output_height = input_height * block_size;
+  const size_t output_width = input_width * block_size;
   depth_to_space_op->context.depth_to_space_chw = (struct depth_to_space_chw2hwc_context) {
-    .output_channels = depth_to_space_op->output_pixel_stride,
+    .output_channels = depth_to_space_op->channels,
     .input_height = input_height,
     .input_width = input_width,
-    .block_size = depth_to_space_op->block_size,
+    .block_size = block_size,
     .input = input,
     .output = output,
-    // TODO(artsiom,kartynnik): Check with maratek@ for additional padding at the end of the image
     .input_batch_stride = depth_to_space_op->input_pixel_stride * input_height * input_width * sizeof(float),
     .output_batch_stride = depth_to_space_op->output_pixel_stride * output_height * output_width * sizeof(float),
     .input_channel_stride = input_height * input_width * sizeof(float),
diff --git a/src/runtime.c b/src/runtime.c
index 462904a..1a4abb2 100644
--- a/src/runtime.c
+++ b/src/runtime.c
@@ -393,9 +393,9 @@
         status = xnn_status_unsupported_parameter;
         if (values[node->inputs[0]].layout == xnn_layout_type_nchw) {
           status = xnn_create_depth_to_space_nchw2nhwc_x32(
-              values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
+              values[node->outputs[0]].shape.dim[values[node->outputs[0]].shape.num_dims - 1] /* output channels */,
               values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
-              values[node->outputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
+              values[node->outputs[0]].shape.dim[values[node->outputs[0]].shape.num_dims - 1] /* output stride */,
               node->params.depth_to_space.block_size,
               node->flags,
               &runtime->opdata[i].operator_object);
@@ -1068,8 +1068,6 @@
             opdata->batch_size,
             opdata->input_height,
             opdata->input_width,
-            opdata->output_height,
-            opdata->output_width,
             runtime->blobs[opdata->inputs[0]].data,
             runtime->blobs[opdata->outputs[0]].data,
             runtime->threadpool);
diff --git a/src/subgraph-strings.c b/src/subgraph-strings.c
index eb423e6..0b8d2e9 100644
--- a/src/subgraph-strings.c
+++ b/src/subgraph-strings.c
@@ -39,7 +39,7 @@
     case xnn_node_type_depthwise_convolution_2d:
       return "Depthwise Convolution 2D";
     case xnn_node_type_depth_to_space:
-      return "Depth to Space";
+      return "Depth To Space";
     case xnn_node_type_divide:
       return "Divide";
     case xnn_node_type_fully_connected:
diff --git a/test/depth-to-space-operator-tester.h b/test/depth-to-space-operator-tester.h
index 1f575d2..b660b54 100644
--- a/test/depth-to-space-operator-tester.h
+++ b/test/depth-to-space-operator-tester.h
@@ -49,6 +49,14 @@
     return this->input_width_;
   }
 
+  inline size_t output_height() const {
+    return input_height() * block_size();
+  }
+
+  inline size_t output_width() const {
+    return input_width() * block_size();
+  }
+
   inline DepthToSpaceOperatorTester& block_size(size_t block_size) {
     assert(block_size >= 2);
     this->block_size_ = block_size;
@@ -59,14 +67,18 @@
     return this->block_size_;
   }
 
-  inline DepthToSpaceOperatorTester& input_channels(size_t input_channels) {
-    assert(input_channels != 0);
-    this->input_channels_ = input_channels;
+  inline size_t input_channels() const {
+    return output_channels() * block_size() * block_size();
+  }
+
+  inline DepthToSpaceOperatorTester& output_channels(size_t output_channels) {
+    assert(output_channels != 0);
+    this->output_channels_ = output_channels;
     return *this;
   }
 
-  inline size_t input_channels() const {
-    return this->input_channels_;
+  inline size_t output_channels() const {
+    return this->output_channels_;
   }
 
   inline DepthToSpaceOperatorTester& batch_size(size_t batch_size) {
@@ -99,21 +111,16 @@
   void TestNCHW2NHWCxF32() const {
     std::random_device random_device;
     auto rng = std::mt19937(random_device());
-    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
+    auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(), rng);
 
-    ASSERT_EQ(0, input_channels() %  (block_size() * block_size()));
+    size_t output_height_stride = output_width() * output_channels();
+    size_t output_width_stride = output_channels();
 
-    size_t output_height = input_height() * block_size();
-    size_t output_width = input_width() * block_size();
-    size_t output_channels =  input_channels() / block_size() / block_size();
-    size_t output_height_stride = output_width * output_channels;
-    size_t output_width_stride = output_channels;
-
-    std::vector<float> input((batch_size() * input_height() * input_width() - 1) * input_channels() + input_channels() + XNN_EXTRA_BYTES / sizeof(float));
-    std::vector<float> output((batch_size() * output_height * output_width - 1) * output_channels + output_channels);
+    std::vector<int32_t> input(batch_size() * input_height() * input_width() * input_channels() + XNN_EXTRA_BYTES / sizeof(uint32_t));
+    std::vector<int32_t> output(batch_size() * output_height() * output_width() * output_channels());
     for (size_t iteration = 0; iteration < iterations(); iteration++) {
-      std::generate(input.begin(), input.end(), std::ref(f32rng));
-      std::fill(output.begin(), output.end(), std::nanf(""));
+      std::generate(input.begin(), input.end(), std::ref(i32rng));
+      std::fill(output.begin(), output.end(), INT32_C(0xDEADBEAF));
 
       // Create, setup, run, and destroy Depth To Space operator.
       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
@@ -121,7 +128,7 @@
 
       ASSERT_EQ(xnn_status_success,
                 xnn_create_depth_to_space_nchw2nhwc_x32(
-                    input_channels(), input_channels(), output_channels,
+                    output_channels(), input_channels(), output_channels(),
                     block_size(), 0, &depth_to_space_op));
       ASSERT_NE(nullptr, depth_to_space_op);
 
@@ -130,32 +137,34 @@
 
       ASSERT_EQ(xnn_status_success,
                 xnn_setup_depth_to_space_nchw2nhwc_x32(
-                    depth_to_space_op, batch_size(), input_height(),
-                    input_width(), output_height, output_width,
+                    depth_to_space_op,
+                    batch_size(), input_height(), input_width(),
                     input.data(), output.data(), nullptr /* thread pool */));
 
       ASSERT_EQ(xnn_status_success,
         xnn_run_operator(depth_to_space_op, nullptr /* thread pool */));
 
       // Verify results.
-      for (size_t batch_index = 0; batch_index < batch_size(); batch_index++) {
-        for (size_t iy = 0; iy < input_height(); ++iy) {
-          for (size_t by = 0; by < block_size(); ++by) {
-            for (size_t ix = 0; ix < input_width(); ++ix) {
-              for (size_t bx = 0; bx < block_size(); ++bx) {
-                for (size_t c = 0; c < output_channels; ++c) {
-                  size_t input_batch_offset = batch_index * input_height() * input_width() * input_channels();
-                  size_t input_offset = input_batch_offset + (c * block_size() * block_size() + by * block_size() + bx) * input_channel_stride() + iy * input_height_stride() + ix;
+      for (size_t i = 0; i < batch_size(); i++) {
+        for (size_t iy = 0; iy < input_height(); iy++) {
+          for (size_t by = 0; by < block_size(); by++) {
+            for (size_t ix = 0; ix < input_width(); ix++) {
+              for (size_t bx = 0; bx < block_size(); bx++) {
+                for (size_t oc = 0; oc < output_channels(); oc++) {
+                  const size_t input_offset = i * input_height() * input_width() * input_channels() +
+                    (oc * block_size() * block_size() + by * block_size() + bx) * input_channel_stride() + iy * input_height_stride() + ix;
                   ASSERT_LT(input_offset, input.size());
 
-                  size_t output_batch_offset = batch_index * output_height * output_width * output_channels;
-                  size_t output_offset = output_batch_offset + (iy * block_size() + by) * output_height_stride + (ix * block_size() + bx) * output_width_stride + c;
+                  const size_t output_offset = i * output_height() * output_width() * output_channels() + (iy * block_size() + by) * output_height_stride + (ix * block_size() + bx) * output_width_stride + oc;
                   ASSERT_LT(output_offset, output.size());
 
                   ASSERT_EQ(output[output_offset], input[input_offset])
-                                << "iy = " << iy << ", " << "by = " << by << ", "
-                                << "ix = " << ix << ", " << "bx = " << bx << ", "
-                                << "c = " << c;
+                    << "batch " << i << " / " << batch_size()
+                    << ", input x " << ix << " / " << input_width()
+                    << ", input y " << iy << " / " << input_height()
+                    << ", block x " << bx << " / " << block_size()
+                    << ", block y " << by << " / " << block_size()
+                    << ", output channel " << oc << " / " << output_channels();
                 }
               }
             }
@@ -168,7 +177,7 @@
  private:
   size_t input_height_{1};
   size_t input_width_{1};
-  size_t input_channels_{1};
+  size_t output_channels_{1};
   size_t block_size_{2};
   size_t batch_size_{1};
   size_t iterations_{1};
diff --git a/test/depth-to-space.cc b/test/depth-to-space.cc
index 08089d5..c85255e 100644
--- a/test/depth-to-space.cc
+++ b/test/depth-to-space.cc
@@ -7,14 +7,14 @@
 
 #include <gtest/gtest.h>
 
-TEST(RESIZE_DEPTH_TO_SPACE_NCHW2NHWC_F32, one_column) {
+TEST(RESIZE_DEPTH_TO_SPACE_NCHW2NHWC_X32, one_column) {
     for (size_t input_height = 1; input_height <= 3; input_height++) {
       for (size_t block_size = 2; block_size <= 5; block_size++) {
-        for (size_t c = 1; c <= 7; c++) {
+        for (size_t output_channels = 1; output_channels <= 7; output_channels++) {
           DepthToSpaceOperatorTester()
             .input_size(input_height, 1)
             .block_size(block_size)
-            .input_channels(c * block_size * block_size)
+            .output_channels(output_channels)
             .iterations(3)
             .TestNCHW2NHWCxF32();
         }
@@ -22,14 +22,14 @@
     }
 }
 
-TEST(RESIZE_DEPTH_TO_SPACE_NCHW2NHWC_F32, one_row) {
+TEST(RESIZE_DEPTH_TO_SPACE_NCHW2NHWC_X32, one_row) {
     for (size_t input_width = 1; input_width <= 3; input_width++) {
       for (size_t block_size = 2; block_size <= 5; block_size++) {
-        for (size_t c = 1; c <= 3; c++) {
+        for (size_t output_channels = 1; output_channels <= 3; output_channels++) {
           DepthToSpaceOperatorTester()
             .input_size(1, input_width)
             .block_size(block_size)
-            .input_channels(c * block_size * block_size)
+            .output_channels(output_channels)
             .iterations(3)
             .TestNCHW2NHWCxF32();
         }
@@ -37,15 +37,15 @@
     }
 }
 
-TEST(RESIZE_DEPTH_TO_SPACE_NCHW2NHWC_F32, varying_input_size) {
+TEST(RESIZE_DEPTH_TO_SPACE_NCHW2NHWC_X32, varying_input_size) {
     for (size_t input_height = 1; input_height <= 7; input_height++) {
        for (size_t input_width = 1; input_width <= 7; input_width++) {
          for (size_t block_size = 2; block_size <= 5; block_size++) {
-           for (size_t c = 1; c <= 3; c++) {
+           for (size_t output_channels = 1; output_channels <= 3; output_channels++) {
              DepthToSpaceOperatorTester()
                .input_size(input_height, input_width)
                .block_size(block_size)
-               .input_channels(c * block_size * block_size)
+               .output_channels(output_channels)
                .iterations(3)
                .TestNCHW2NHWCxF32();
            }
@@ -54,16 +54,16 @@
     }
 }
 
-TEST(RESIZE_BILINEAR_NHWC_F32, varying_batch_size) {
-  for (size_t input_size = 2; input_size <= 6; input_size += 2) {
-    for (size_t block_size = 2; block_size <= 6; block_size += 2) {
-      for (size_t batch_size = 2; batch_size <= 3; batch_size++) {
-        for (size_t c = 1; c <= 3; c++) {
+TEST(RESIZE_DEPTH_TO_SPACE_NCHW2NHWC_X32, varying_batch_size) {
+  for (size_t batch_size = 2; batch_size <= 3; batch_size++) {
+    for (size_t input_size = 2; input_size <= 6; input_size += 2) {
+      for (size_t block_size = 2; block_size <= 6; block_size += 2) {
+        for (size_t output_channels = 1; output_channels <= 3; output_channels++) {
           DepthToSpaceOperatorTester()
             .batch_size(batch_size)
             .input_size(input_size, input_size)
             .block_size(block_size)
-            .input_channels(c * block_size * block_size)
+            .output_channels(output_channels)
             .iterations(3)
             .TestNCHW2NHWCxF32();
         }