Refactor MaxPool and ArgMaxPool micro-kernels

- Support input_offset argument in MaxPool and ArgMaxPool micro-kernels
- Use input_offset to make indirection buffer independent of batch size
- Simplify and auto-generate unit tests
- Use more descriptive names for micro-kernel parameters

PiperOrigin-RevId: 281447682
diff --git a/test/maxpool-microkernel-tester.h b/test/maxpool-microkernel-tester.h
index 94132ed..68d57d5 100644
--- a/test/maxpool-microkernel-tester.h
+++ b/test/maxpool-microkernel-tester.h
@@ -30,115 +30,103 @@
     Scalar,
   };
 
-  inline MaxPoolMicrokernelTester& n(size_t n) {
-    assert(n != 0);
-    this->n_ = n;
+  inline MaxPoolMicrokernelTester& output_pixels(size_t output_pixels) {
+    assert(output_pixels != 0);
+    this->output_pixels_ = output_pixels;
     return *this;
   }
 
-  inline size_t n() const {
-    return this->n_;
+  inline size_t output_pixels() const {
+    return this->output_pixels_;
   }
 
-  inline MaxPoolMicrokernelTester& s(size_t s) {
-    assert(s != 0);
-    this->s_ = s;
+  inline MaxPoolMicrokernelTester& step(size_t step) {
+    assert(step != 0);
+    this->step_ = step;
     return *this;
   }
 
-  inline size_t s() const {
-    return this->s_;
+  inline size_t step() const {
+    return this->step_;
   }
 
-  inline MaxPoolMicrokernelTester& kh(size_t kh) {
-    assert(kh != 0);
-    this->kh_ = kh;
+  inline MaxPoolMicrokernelTester& input_offset(size_t input_offset) {
+    assert(input_offset != 0);
+    this->input_offset_ = input_offset;
     return *this;
   }
 
-  inline size_t kh() const {
-    return this->kh_;
+  inline size_t input_offset() const {
+    return this->input_offset_;
   }
 
-  inline MaxPoolMicrokernelTester& kw(size_t kw) {
-    assert(kw != 0);
-    this->kw_ = kw;
+  inline MaxPoolMicrokernelTester& pooling_elements(size_t pooling_elements) {
+    assert(pooling_elements != 0);
+    this->pooling_elements_ = pooling_elements;
     return *this;
   }
 
-  inline size_t kw() const {
-    return this->kw_;
+  inline size_t pooling_elements() const {
+    return this->pooling_elements_;
   }
 
-  inline size_t ks() const {
-    return kh() * kw();
-  }
-
-  inline size_t packed_ks() const {
-    if (ks() <= mr()) {
-      return mr();
+  inline size_t packed_pooling_elements() const {
+    if (pooling_elements() <= primary_pooling_tile()) {
+      return primary_pooling_tile();
     } else {
-      return (ks() - mr()) % qr() == 0 ? ks() : ((ks() - mr()) / qr() + 1) * qr() + mr();
+      return (pooling_elements() - primary_pooling_tile()) % incremental_pooling_tile() == 0 ? pooling_elements() : ((pooling_elements() - primary_pooling_tile()) / incremental_pooling_tile() + 1) * incremental_pooling_tile() + primary_pooling_tile();
     }
   }
 
-  inline MaxPoolMicrokernelTester& mr(size_t mr) {
-    assert(mr != 0);
-    this->mr_ = mr;
+  inline MaxPoolMicrokernelTester& pooling_tile(size_t primary_tile, size_t incremental_tile) {
+    assert(primary_tile != 0);
+    this->primary_pooling_tile_ = primary_tile;
+    this->incremental_pooling_tile_ = incremental_tile;
     return *this;
   }
 
-  inline size_t mr() const {
-    return this->mr_;
-  }
-
-  inline MaxPoolMicrokernelTester& qr(size_t qr) {
-    assert(qr != 0);
-    this->qr_ = qr;
+  inline MaxPoolMicrokernelTester& primary_pooling_tile(size_t primary_pooling_tile) {
+    assert(primary_pooling_tile != 0);
+    this->primary_pooling_tile_ = primary_pooling_tile;
     return *this;
   }
 
-  inline size_t qr() const {
-    return this->qr_;
+  inline size_t primary_pooling_tile() const {
+    return this->primary_pooling_tile_;
   }
 
-  inline MaxPoolMicrokernelTester& kc(size_t kc) {
-    assert(kc != 0);
-    this->kc_ = kc;
+  inline MaxPoolMicrokernelTester& incremental_pooling_tile(size_t incremental_pooling_tile) {
+    assert(incremental_pooling_tile != 0);
+    this->incremental_pooling_tile_ = incremental_pooling_tile;
     return *this;
   }
 
-  inline size_t kc() const {
-    return this->kc_;
+  inline size_t incremental_pooling_tile() const {
+    return this->incremental_pooling_tile_;
   }
 
-  inline MaxPoolMicrokernelTester& x_stride(size_t x_stride) {
-    assert(x_stride != 0);
-    this->x_stride_ = x_stride;
+  inline MaxPoolMicrokernelTester& channels(size_t channels) {
+    assert(channels != 0);
+    this->channels_ = channels;
     return *this;
   }
 
-  inline size_t x_stride() const {
-    if (this->x_stride_ == 0) {
-      return kc();
+  inline size_t channels() const {
+    return this->channels_;
+  }
+
+  inline MaxPoolMicrokernelTester& output_stride(size_t output_stride) {
+    assert(output_stride != 0);
+    this->output_stride_ = output_stride;
+    return *this;
+  }
+
+  inline size_t output_stride() const {
+    if (this->output_stride_ == 0) {
+      return channels();
     } else {
-      assert(this->x_stride_ >= kc());
-      return this->x_stride_;
-    }
-  }
-
-  inline MaxPoolMicrokernelTester& y_stride(size_t y_stride) {
-    assert(y_stride != 0);
-    this->y_stride_ = y_stride;
-    return *this;
-  }
-
-  inline size_t y_stride() const {
-    if (this->y_stride_ == 0) {
-      return kc();
-    } else {
-      assert(this->y_stride_ >= kc());
-      return this->y_stride_;
+      assert(this->output_stride_ >= channels());
+      return this->output_stride_;
     }
   }
 
@@ -174,19 +162,23 @@
     auto rng = std::mt19937(random_device());
     auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);
 
-    std::vector<const uint8_t*> indirect_x(packed_ks() + (n() * s() - 1) * kh());
-    std::vector<uint8_t> x((indirect_x.size() - 1) * x_stride() + kc() + XNN_EXTRA_BYTES / sizeof(uint8_t));
-
-    std::vector<uint8_t> y((n() - 1) * y_stride() + kc() + XNN_EXTRA_BYTES / sizeof(uint8_t));
-    std::vector<uint8_t> y_ref(n() * kc());
+    std::vector<const uint8_t*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
+    std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
+      indirect_input.size() * channels());
+    std::vector<uint8_t> output(XNN_EXTRA_BYTES / sizeof(uint8_t) +
+      (output_pixels() - 1) * output_stride() + channels());
+    std::vector<uint8_t> output_ref(output_pixels() * channels());
     for (size_t iteration = 0; iteration < iterations(); iteration++) {
-      std::generate(x.begin(), x.end(), std::ref(u8rng));
-      std::fill(y.begin(), y.end(), 0xA5);
+      do {
+        std::generate(input.begin(), input.end(), std::ref(u8rng));
+      } while (input.size() > 1 && *std::max_element(input.cbegin(), input.cend()) == *std::min_element(input.cbegin(), input.cend()));
+      std::fill(output.begin(), output.end(), 0xA5);
 
-      for (size_t i = 0; i < indirect_x.size(); i++) {
-        indirect_x[i] = x.data() + i * x_stride();
+      for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
+        indirect_input[i] = input.data() + i * channels() - input_offset();
       }
-      std::shuffle(indirect_x.begin(), indirect_x.end(), rng);
+      std::shuffle(indirect_input.begin(),
+        indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
 
       // Prepare output parameters.
       xnn_u8_output_params output_params = { };
@@ -200,32 +192,40 @@
       }
 
       // Compute reference results.
-      for (size_t i = 0; i < n(); i++) {
-        for (size_t k = 0; k < kc(); k++) {
+      for (size_t x = 0; x < output_pixels(); x++) {
+        for (size_t c = 0; c < channels(); c++) {
           uint8_t max_value = 0;
-          for (size_t j = 0; j < ks(); j++) {
-            max_value = std::max(max_value,
-              indirect_x[i * s() * kh() + j][k]);
+          for (size_t p = 0; p < pooling_elements(); p++) {
+            max_value = std::max(max_value, indirect_input[x * step() + p][c + input_offset()]);
           }
           max_value = std::min(max_value, qmax());
           max_value = std::max(max_value, qmin());
-          y_ref[i * kc() + k] = max_value;
+          output_ref[x * channels() + c] = max_value;
         }
       }
 
       // Call optimized micro-kernel.
-      maxpool(n(), ks(), kc(),
-        indirect_x.data(), y.data(),
-        (kh() * s() - packed_ks()) * sizeof(void*),
-        (y_stride() - kc()) * sizeof(uint8_t),
+      maxpool(output_pixels(), pooling_elements(), channels(),
+        indirect_input.data(), input_offset() * sizeof(uint8_t), output.data(),
+        (step() - packed_pooling_elements()) * sizeof(void*),
+        (output_stride() - channels()) * sizeof(uint8_t),
         &output_params);
 
       // Verify results.
-      for (size_t i = 0; i < n(); i++) {
-        for (size_t k = 0; k < kc(); k++) {
-          ASSERT_EQ(uint32_t(y_ref[i * kc() + k]), uint32_t(y[i * y_stride() + k]))
-            << "at pixel " << i << ", channel " << k << ", n = " << n()
-            << ", ks = " << kh() << "x" << kw() << " (" << ks() << "), kc = " << kc();
+      for (size_t x = 0; x < output_pixels(); x++) {
+        for (size_t c = 0; c < channels(); c++) {
+          ASSERT_GE(uint32_t(output[x * output_stride() + c]), uint32_t(qmin()))
+            << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
+            << ", pooling elements = " << pooling_elements() << ", step = " << step()
+            << ", input offset = " << input_offset();
+          ASSERT_LE(uint32_t(output[x * output_stride() + c]), uint32_t(qmax()))
+            << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
+            << ", pooling elements = " << pooling_elements() << ", step = " << step()
+            << ", input offset = " << input_offset();
+          ASSERT_EQ(uint32_t(output_ref[x * channels() + c]), uint32_t(output[x * output_stride() + c]))
+            << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
+            << ", pooling elements = " << pooling_elements() << ", step = " << step()
+            << ", input offset = " << input_offset();
         }
       }
     }
@@ -236,87 +236,94 @@
     auto rng = std::mt19937(random_device());
     auto f32rng = std::bind(std::uniform_real_distribution<float>(0.0f, 1.0f), rng);
 
-    std::vector<const float*> indirect_x(packed_ks() + (n() * s() - 1) * kh());
-    std::vector<float> x((indirect_x.size() - 1) * x_stride() + kc() + XNN_EXTRA_BYTES / sizeof(float));
-
-    std::vector<float> y((n() - 1) * y_stride() + kc() + XNN_EXTRA_BYTES / sizeof(float));
-    std::vector<float> y_ref(n() * kc());
+    std::vector<const float*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
+    std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
+      ((output_pixels() - 1) * step() + pooling_elements()) * channels());
+    std::vector<float> output(XNN_EXTRA_BYTES / sizeof(float) +
+      (output_pixels() - 1) * output_stride() + channels());
+    std::vector<float> output_ref(output_pixels() * channels());
     for (size_t iteration = 0; iteration < iterations(); iteration++) {
-      std::generate(x.begin(), x.end(), std::ref(f32rng));
-      std::fill(y.begin(), y.end(), nanf(""));
+      std::generate(input.begin(), input.end(), std::ref(f32rng));
+      std::fill(output.begin(), output.end(), nanf(""));
 
-      for (size_t i = 0; i < indirect_x.size(); i++) {
-        indirect_x[i] = x.data() + i * x_stride();
+      for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
+        indirect_input[i] = input.data() + i * channels() - input_offset();
       }
-      std::shuffle(indirect_x.begin(), indirect_x.end(), rng);
+      std::shuffle(indirect_input.begin(),
+        indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
 
       // Compute reference results, without clamping.
-      for (size_t i = 0; i < n(); i++) {
-        for (size_t k = 0; k < kc(); k++) {
+      for (size_t x = 0; x < output_pixels(); x++) {
+        for (size_t c = 0; c < channels(); c++) {
           float max_value = -std::numeric_limits<float>::infinity();
-          for (size_t j = 0; j < ks(); j++) {
-            max_value = std::max(max_value,
-              indirect_x[i * s() * kh() + j][k]);
+          for (size_t p = 0; p < pooling_elements(); p++) {
+            max_value = std::max(max_value, indirect_input[x * step() + p][c + input_offset()]);
           }
-          y_ref[i * kc() + k] = max_value;
+          output_ref[x * channels() + c] = max_value;
         }
       }
 
       // Compute clamping parameters.
-      const float accumulated_min = *std::min_element(y_ref.cbegin(), y_ref.cend());
-      const float accumulated_max = *std::max_element(y_ref.cbegin(), y_ref.cend());
+      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
+      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
       const float accumulated_range = accumulated_max - accumulated_min;
-      const float y_min = accumulated_min + float(qmin()) / 255.0f * accumulated_range;
-      const float y_max = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range;
+      const float output_min = accumulated_min + float(qmin()) / 255.0f * accumulated_range;
+      const float output_max = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range;
 
 
       // Prepare output parameters.
       xnn_f32_output_params output_params = { };
       switch (variant) {
         case Variant::Native:
-          output_params = xnn_init_f32_output_params(y_min, y_max);
+          output_params = xnn_init_f32_output_params(output_min, output_max);
           break;
         case Variant::Scalar:
-          output_params = xnn_init_scalar_f32_output_params(y_min, y_max);
+          output_params = xnn_init_scalar_f32_output_params(output_min, output_max);
           break;
       }
 
       // Clamp reference results.
-      for (size_t i = 0; i < n(); i++) {
-        for (size_t k = 0; k < kc(); k++) {
-          y_ref[i * kc() + k] = std::max(std::min(y_ref[i * kc() + k], y_max), y_min);
-        }
+      for (float& output_value : output_ref) {
+        output_value = std::max(std::min(output_value, output_max), output_min);
       }
 
       // Call optimized micro-kernel.
-      maxpool(n(), ks(), kc(),
-        indirect_x.data(), y.data(),
-        (kh() * s() - packed_ks()) * sizeof(void*),
-        (y_stride() - kc()) * sizeof(float),
+      maxpool(output_pixels(), pooling_elements(), channels(),
+        indirect_input.data(), input_offset() * sizeof(float), output.data(),
+        (step() - packed_pooling_elements()) * sizeof(void*),
+        (output_stride() - channels()) * sizeof(float),
         &output_params);
 
       // Verify results.
-      for (size_t i = 0; i < n(); i++) {
-        for (size_t k = 0; k < kc(); k++) {
-          ASSERT_EQ(y_ref[i * kc() + k], y[i * y_stride() + k])
-            << "at pixel " << i << ", channel " << k << ", n = " << n()
-            << ", ks = " << kh() << "x" << kw() << " (" << ks() << "), kc = " << kc();
+      for (size_t x = 0; x < output_pixels(); x++) {
+        for (size_t c = 0; c < channels(); c++) {
+          ASSERT_GE(output[x * output_stride() + c], output_min)
+            << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
+            << ", pooling elements = " << pooling_elements() << ", step = " << step()
+            << ", input offset = " << input_offset();
+          ASSERT_LE(output[x * output_stride() + c], output_max)
+            << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
+            << ", pooling elements = " << pooling_elements() << ", step = " << step()
+            << ", input offset = " << input_offset();
+          ASSERT_EQ(output_ref[x * channels() + c], output[x * output_stride() + c])
+            << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
+            << ", pooling elements = " << pooling_elements() << ", step = " << step()
+            << ", input offset = " << input_offset();
         }
       }
     }
   }
 
  private:
-  size_t n_{1};
-  size_t s_{1};
-  size_t kh_{1};
-  size_t kw_{1};
-  size_t mr_{1};
-  size_t qr_{1};
-  size_t kc_{1};
-  size_t x_stride_{0};
-  size_t y_stride_{0};
+  size_t output_pixels_{1};
+  size_t pooling_elements_{1};
+  size_t channels_{1};
+  size_t input_offset_{0};
+  size_t step_{1};
+  size_t primary_pooling_tile_{1};
+  size_t incremental_pooling_tile_{1};
+  size_t output_stride_{0};
   uint8_t qmin_{0};
   uint8_t qmax_{255};
-  size_t iterations_{15};
+  size_t iterations_{3};
 };