// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>


22class ArgmaxPoolingOperatorTester {
23 public:
Marat Dukhane4b8e572020-05-05 11:35:02 -070024 inline ArgmaxPoolingOperatorTester& padding_tf_same(bool padding_same) {
25 if (padding_same) {
26 assert(padding_top() == 0);
27 assert(padding_left() == 0);
28 assert(padding_bottom() == 0);
29 assert(padding_right() == 0);
30 }
31 this->padding_tf_same_ = padding_same;
32 return *this;
33 }
34
35 inline bool padding_tf_same() const {
36 return this->padding_tf_same_;
37 }
38
XNNPACK Teamb455b122019-09-27 18:10:33 -070039 inline ArgmaxPoolingOperatorTester& padding(uint32_t padding) {
Marat Dukhane4b8e572020-05-05 11:35:02 -070040 assert(!padding_tf_same());
XNNPACK Teamb455b122019-09-27 18:10:33 -070041 this->padding_top_ = padding;
42 this->padding_right_ = padding;
43 this->padding_bottom_ = padding;
44 this->padding_left_ = padding;
45 return *this;
46 }
47
Marat Dukhane4b8e572020-05-05 11:35:02 -070048 inline ArgmaxPoolingOperatorTester& padding(uint32_t padding_height, uint32_t padding_width) {
49 assert(!padding_tf_same());
50 this->padding_top_ = padding_height;
51 this->padding_right_ = padding_width;
52 this->padding_bottom_ = padding_height;
53 this->padding_left_ = padding_width;
54 return *this;
55 }
56
XNNPACK Teamb455b122019-09-27 18:10:33 -070057 inline ArgmaxPoolingOperatorTester& padding_height(uint32_t padding_height) {
Marat Dukhane4b8e572020-05-05 11:35:02 -070058 assert(!padding_tf_same());
XNNPACK Teamb455b122019-09-27 18:10:33 -070059 this->padding_top_ = padding_height;
60 this->padding_bottom_ = padding_height;
61 return *this;
62 }
63
64 inline ArgmaxPoolingOperatorTester& padding_width(uint32_t padding_width) {
Marat Dukhane4b8e572020-05-05 11:35:02 -070065 assert(!padding_tf_same());
XNNPACK Teamb455b122019-09-27 18:10:33 -070066 this->padding_right_ = padding_width;
67 this->padding_left_ = padding_width;
68 return *this;
69 }
70
71 inline ArgmaxPoolingOperatorTester& padding_top(uint32_t padding_top) {
Marat Dukhane4b8e572020-05-05 11:35:02 -070072 assert(!padding_tf_same());
XNNPACK Teamb455b122019-09-27 18:10:33 -070073 this->padding_top_ = padding_top;
74 return *this;
75 }
76
77 inline uint32_t padding_top() const {
Marat Dukhane4b8e572020-05-05 11:35:02 -070078 if (padding_tf_same()) {
79 const uint32_t total_padding_height = output_height() * pooling_height() - input_height();
80 return total_padding_height / 2;
81 } else {
82 return this->padding_top_;
83 }
XNNPACK Teamb455b122019-09-27 18:10:33 -070084 }
85
86 inline ArgmaxPoolingOperatorTester& padding_left(uint32_t padding_left) {
Marat Dukhane4b8e572020-05-05 11:35:02 -070087 assert(!padding_tf_same());
XNNPACK Teamb455b122019-09-27 18:10:33 -070088 this->padding_left_ = padding_left;
89 return *this;
90 }
91
92 inline uint32_t padding_left() const {
Marat Dukhane4b8e572020-05-05 11:35:02 -070093 if (padding_tf_same()) {
94 const uint32_t total_padding_width = output_width() * pooling_width() - input_width();
95 return total_padding_width / 2;
96 } else {
97 return this->padding_left_;
98 }
99 }
100
101 inline ArgmaxPoolingOperatorTester& padding_bottom(uint32_t padding_bottom) {
102 assert(!padding_tf_same());
103 this->padding_bottom_ = padding_bottom;
104 return *this;
105 }
106
107 inline uint32_t padding_bottom() const {
108 if (padding_tf_same()) {
109 const uint32_t total_padding_height = output_height() * pooling_height() - input_height();
110 return total_padding_height - total_padding_height / 2;
111 } else {
112 return this->padding_bottom_;
113 }
114 }
115
116 inline ArgmaxPoolingOperatorTester& padding_right(uint32_t padding_right) {
117 assert(!padding_tf_same());
118 this->padding_right_ = padding_right;
119 return *this;
120 }
121
122 inline uint32_t padding_right() const {
123 if (padding_tf_same()) {
124 const uint32_t total_padding_width = output_width() * pooling_width() - input_width();
125 return total_padding_width - total_padding_width / 2;
126 } else {
127 return this->padding_right_;
128 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700129 }
130
131 inline ArgmaxPoolingOperatorTester& input_size(size_t input_height, size_t input_width) {
132 assert(input_height >= 1);
133 assert(input_width >= 1);
134 this->input_height_ = input_height;
135 this->input_width_ = input_width;
136 return *this;
137 }
138
139 inline ArgmaxPoolingOperatorTester& input_height(size_t input_height) {
140 assert(input_height >= 1);
141 this->input_height_ = input_height;
142 return *this;
143 }
144
145 inline size_t input_height() const {
146 return this->input_height_;
147 }
148
149 inline ArgmaxPoolingOperatorTester& input_width(size_t input_width) {
150 assert(input_width >= 1);
151 this->input_width_ = input_width;
152 return *this;
153 }
154
155 inline size_t input_width() const {
156 return this->input_width_;
157 }
158
159 inline ArgmaxPoolingOperatorTester& channels(size_t channels) {
160 assert(channels != 0);
161 this->channels_ = channels;
162 return *this;
163 }
164
165 inline size_t channels() const {
166 return this->channels_;
167 }
168
169 inline ArgmaxPoolingOperatorTester& batch_size(size_t batch_size) {
170 assert(batch_size != 0);
171 this->batch_size_ = batch_size;
172 return *this;
173 }
174
175 inline size_t batch_size() const {
176 return this->batch_size_;
177 }
178
179 inline ArgmaxPoolingOperatorTester& pooling_size(uint32_t pooling_size) {
180 assert(pooling_size >= 1);
181 this->pooling_height_ = pooling_size;
182 this->pooling_width_ = pooling_size;
183 return *this;
184 }
185
186 inline ArgmaxPoolingOperatorTester& pooling_size(uint32_t pooling_height, uint32_t pooling_width) {
187 assert(pooling_height >= 1);
188 assert(pooling_width >= 1);
189 this->pooling_height_ = pooling_height;
190 this->pooling_width_ = pooling_width;
191 return *this;
192 }
193
194 inline ArgmaxPoolingOperatorTester& pooling_height(uint32_t pooling_height) {
195 assert(pooling_height >= 1);
196 this->pooling_height_ = pooling_height;
197 return *this;
198 }
199
200 inline uint32_t pooling_height() const {
201 return this->pooling_height_;
202 }
203
204 inline ArgmaxPoolingOperatorTester& pooling_width(uint32_t pooling_width) {
205 assert(pooling_width >= 1);
206 this->pooling_width_ = pooling_width;
207 return *this;
208 }
209
210 inline uint32_t pooling_width() const {
211 return this->pooling_width_;
212 }
213
214 inline size_t output_height() const {
Marat Dukhane4b8e572020-05-05 11:35:02 -0700215 if (padding_tf_same()) {
216 return (input_height() + pooling_height() - 1) / pooling_height();
217 } else {
218 const size_t padded_input_height = padding_top() + input_height() + padding_bottom();
219 return padded_input_height / pooling_height();
220 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700221 }
222
223 inline size_t output_width() const {
Marat Dukhane4b8e572020-05-05 11:35:02 -0700224 if (padding_tf_same()) {
225 return (input_width() + pooling_width() - 1) / pooling_width();
226 } else {
227 const size_t padded_input_width = padding_left() + input_width() + padding_right();
228 return padded_input_width / pooling_width();
229 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700230 }
231
232 inline ArgmaxPoolingOperatorTester& input_pixel_stride(size_t input_pixel_stride) {
233 assert(input_pixel_stride != 0);
234 this->input_pixel_stride_ = input_pixel_stride;
235 return *this;
236 }
237
238 inline size_t input_pixel_stride() const {
239 if (this->input_pixel_stride_ == 0) {
240 return channels();
241 } else {
242 assert(this->input_pixel_stride_ >= channels());
243 return this->input_pixel_stride_;
244 }
245 }
246
247 inline ArgmaxPoolingOperatorTester& output_pixel_stride(size_t output_pixel_stride) {
248 assert(output_pixel_stride != 0);
249 this->output_pixel_stride_ = output_pixel_stride;
250 return *this;
251 }
252
253 inline size_t output_pixel_stride() const {
254 if (this->output_pixel_stride_ == 0) {
255 return channels();
256 } else {
257 assert(this->output_pixel_stride_ >= channels());
258 return this->output_pixel_stride_;
259 }
260 }
261
262 inline ArgmaxPoolingOperatorTester& next_input_size(uint32_t next_input_height, uint32_t next_input_width) {
263 assert(next_input_height >= 1);
264 assert(next_input_width >= 1);
265 this->next_input_height_ = next_input_height;
266 this->next_input_width_ = next_input_width;
267 return *this;
268 }
269
270 inline ArgmaxPoolingOperatorTester& next_input_height(uint32_t next_input_height) {
271 assert(next_input_height >= 1);
272 this->next_input_height_ = next_input_height;
273 return *this;
274 }
275
276 inline uint32_t next_input_height() const {
277 if (this->next_input_height_ == 0) {
278 return input_height();
279 } else {
280 return this->next_input_height_;
281 }
282 }
283
284 inline ArgmaxPoolingOperatorTester& next_input_width(uint32_t next_input_width) {
285 assert(next_input_width >= 1);
286 this->next_input_width_ = next_input_width;
287 return *this;
288 }
289
290 inline uint32_t next_input_width() const {
291 if (this->next_input_width_ == 0) {
292 return input_width();
293 } else {
294 return this->next_input_width_;
295 }
296 }
297
298 inline size_t next_output_height() const {
299 const size_t padded_next_input_height = padding_top() + next_input_height() + padding_bottom();
300 return padded_next_input_height / pooling_height();
301 }
302
303 inline size_t next_output_width() const {
304 const size_t padded_next_input_width = padding_left() + next_input_width() + padding_right();
305 return padded_next_input_width / pooling_width();
306 }
307
308 inline ArgmaxPoolingOperatorTester& next_batch_size(size_t next_batch_size) {
309 assert(next_batch_size >= 1);
310 this->next_batch_size_ = next_batch_size;
311 return *this;
312 }
313
314 inline size_t next_batch_size() const {
315 if (this->next_batch_size_ == 0) {
316 return batch_size();
317 } else {
318 return this->next_batch_size_;
319 }
320 }
321
XNNPACK Teamb455b122019-09-27 18:10:33 -0700322 inline ArgmaxPoolingOperatorTester& iterations(size_t iterations) {
323 this->iterations_ = iterations;
324 return *this;
325 }
326
327 inline size_t iterations() const {
328 return this->iterations_;
329 }
330
331 void TestF32() const {
332 std::random_device random_device;
333 auto rng = std::mt19937(random_device());
334 auto f32rng = std::bind(std::uniform_real_distribution<float>(0.0f, 1.0f), rng);
335
336 std::vector<float> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
337 std::vector<float> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels());
338 std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels());
339 std::vector<uint32_t> index(batch_size() * output_height() * output_width() * channels());
340 std::vector<uint32_t> index_ref(batch_size() * output_height() * output_width() * channels());
341 for (size_t iteration = 0; iteration < iterations(); iteration++) {
342 std::generate(input.begin(), input.end(), std::ref(f32rng));
343 std::fill(output.begin(), output.end(), nanf(""));
344
345 // Compute reference results, without clamping.
346 for (size_t i = 0; i < batch_size(); i++) {
347 for (size_t oy = 0; oy < output_height(); oy++) {
348 for (size_t ox = 0; ox < output_width(); ox++) {
349 for (size_t c = 0; c < channels(); c++) {
350 const size_t iy_top_left = std::max<size_t>(oy * pooling_height(), padding_top()) - padding_top();
351 const size_t ix_top_left = std::max<size_t>(ox * pooling_width(), padding_left()) - padding_left();
352 float max_value =
353 input[((i * input_height() + iy_top_left) * input_width() + ix_top_left) * input_pixel_stride() + c];
354 uint32_t max_index = 0;
355 for (size_t py = 0; py < pooling_height(); py++) {
356 const size_t iy = oy * pooling_height() + py - padding_top();
357 for (size_t px = 0; px < pooling_width(); px++) {
358 const size_t ix = ox * pooling_width() + px - padding_left();
359 if (ix < input_width() && iy < input_height()) {
360 const float value = input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c];
361 if (value > max_value) {
362 max_value = value;
363 max_index = uint32_t(px * pooling_height() + py);
364 }
365 }
366 }
367 }
368 output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_value;
369 index_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_index;
370 }
371 }
372 }
373 }
374
XNNPACK Teamb455b122019-09-27 18:10:33 -0700375 // Create, setup, run, and destroy Argmax Pooling operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800376 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700377 xnn_operator_t argmax_pooling_op = nullptr;
378
379 ASSERT_EQ(xnn_status_success,
380 xnn_create_argmax_pooling2d_nhwc_f32(
Marat Dukhane4b8e572020-05-05 11:35:02 -0700381 padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(),
382 padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(),
XNNPACK Teamb455b122019-09-27 18:10:33 -0700383 pooling_height(), pooling_width(),
384 channels(), input_pixel_stride(), output_pixel_stride(),
Marat Dukhane4b8e572020-05-05 11:35:02 -0700385 padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0,
386 &argmax_pooling_op));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700387 ASSERT_NE(nullptr, argmax_pooling_op);
388
389 // Smart pointer to automatically delete argmax_pooling_op.
390 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_argmax_pooling_op(argmax_pooling_op, xnn_delete_operator);
391
392 ASSERT_EQ(xnn_status_success,
393 xnn_setup_argmax_pooling2d_nhwc_f32(
394 argmax_pooling_op,
395 batch_size(), input_height(), input_width(),
396 input.data(), output.data(), index.data(),
397 nullptr /* thread pool */));
398
399 ASSERT_EQ(xnn_status_success,
400 xnn_run_operator(argmax_pooling_op, nullptr /* thread pool */));
401
402 // Verify results.
403 for (size_t i = 0; i < batch_size(); i++) {
404 for (size_t y = 0; y < output_height(); y++) {
405 for (size_t x = 0; x < output_width(); x++) {
406 for (size_t c = 0; c < channels(); c++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700407 ASSERT_EQ(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c],
408 output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]) <<
409 "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
410 ASSERT_EQ(index_ref[((i * output_height() + y) * output_width() + x) * channels() + c],
411 index[((i * output_height() + y) * output_width() + x) * channels() + c]) <<
412 "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
413 }
414 }
415 }
416 }
417 }
418 }
419
420 void TestSetupF32() const {
421 std::random_device random_device;
422 auto rng = std::mt19937(random_device());
423 auto f32rng = std::bind(std::uniform_real_distribution<float>(0.0f, 1.0f), rng);
424
425 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + std::max(
426 (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels(),
427 (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + channels()));
428 std::vector<float> output(std::max(
429 (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels(),
430 (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + channels()));
431 std::vector<uint32_t> index(std::max(
432 batch_size() * output_height() * output_width() * channels(),
433 next_batch_size() * next_output_height() * next_output_width() * channels()));
434 std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels());
435 std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * channels());
436 std::vector<uint32_t> index_ref(batch_size() * output_height() * output_width() * channels());
437 std::vector<uint32_t> next_index_ref(next_batch_size() * next_output_height() * next_output_width() * channels());
438 for (size_t iteration = 0; iteration < iterations(); iteration++) {
439 std::generate(input.begin(), input.end(), std::ref(f32rng));
440 std::fill(output.begin(), output.end(), nanf(""));
441
442 // Compute reference results, without clamping.
443 for (size_t i = 0; i < batch_size(); i++) {
444 for (size_t oy = 0; oy < output_height(); oy++) {
445 for (size_t ox = 0; ox < output_width(); ox++) {
446 for (size_t c = 0; c < channels(); c++) {
447 const size_t iy_top_left = std::max<size_t>(oy * pooling_height(), padding_top()) - padding_top();
448 const size_t ix_top_left = std::max<size_t>(ox * pooling_width(), padding_left()) - padding_left();
449 float max_value =
450 input[((i * input_height() + iy_top_left) * input_width() + ix_top_left) * input_pixel_stride() + c];
451 uint32_t max_index = 0;
452 for (size_t py = 0; py < pooling_height(); py++) {
453 const size_t iy = oy * pooling_height() + py - padding_top();
454 for (size_t px = 0; px < pooling_width(); px++) {
455 const size_t ix = ox * pooling_width() + px - padding_left();
456 if (ix < input_width() && iy < input_height()) {
457 const float value = input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c];
458 if (value > max_value) {
459 max_value = value;
460 max_index = uint32_t(px * pooling_height() + py);
461 }
462 }
463 }
464 }
465 output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_value;
466 index_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_index;
467 }
468 }
469 }
470 }
471
XNNPACK Teamb455b122019-09-27 18:10:33 -0700472 // Create, setup, and run Argmax Pooling operator once.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800473 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700474 xnn_operator_t argmax_pooling_op = nullptr;
475
476 ASSERT_EQ(xnn_status_success,
477 xnn_create_argmax_pooling2d_nhwc_f32(
478 padding_top(), padding_right(), padding_bottom(), padding_left(),
479 pooling_height(), pooling_width(),
480 channels(), input_pixel_stride(), output_pixel_stride(),
XNNPACK Teamb455b122019-09-27 18:10:33 -0700481 0, &argmax_pooling_op));
482 ASSERT_NE(nullptr, argmax_pooling_op);
483
484 ASSERT_EQ(xnn_status_success,
485 xnn_setup_argmax_pooling2d_nhwc_f32(
486 argmax_pooling_op,
487 batch_size(), input_height(), input_width(),
488 input.data(), output.data(), index.data(),
489 nullptr /* thread pool */));
490
491 ASSERT_EQ(xnn_status_success,
492 xnn_run_operator(argmax_pooling_op, nullptr /* thread pool */));
493
494 // Verify results of the first run.
495 for (size_t i = 0; i < batch_size(); i++) {
496 for (size_t y = 0; y < output_height(); y++) {
497 for (size_t x = 0; x < output_width(); x++) {
498 for (size_t c = 0; c < channels(); c++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700499 ASSERT_EQ(
500 output_ref[((i * output_height() + y) * output_width() + x) * channels() + c],
501 output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c])
502 << "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
503 ASSERT_EQ(
504 index_ref[((i * output_height() + y) * output_width() + x) * channels() + c],
505 index[((i * output_height() + y) * output_width() + x) * channels() + c])
506 << "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
507 }
508 }
509 }
510 }
511
512 // Re-generate data for the second run.
513 std::generate(input.begin(), input.end(), std::ref(f32rng));
514 std::fill(output.begin(), output.end(), 0xA5);
515
516 // Compute reference results for the second run, including clamping.
517 for (size_t i = 0; i < next_batch_size(); i++) {
518 for (size_t oy = 0; oy < next_output_height(); oy++) {
519 for (size_t ox = 0; ox < next_output_width(); ox++) {
520 for (size_t c = 0; c < channels(); c++) {
521 const size_t iy_top_left = std::max<size_t>(oy * pooling_height(), padding_top()) - padding_top();
522 const size_t ix_top_left = std::max<size_t>(ox * pooling_width(), padding_left()) - padding_left();
523 float max_value =
524 input[((i * next_input_height() + iy_top_left) * next_input_width() + ix_top_left) * input_pixel_stride() + c];
525 uint32_t max_index = 0;
526 for (size_t py = 0; py < pooling_height(); py++) {
527 const size_t iy = oy * pooling_height() + py - padding_top();
528 for (size_t px = 0; px < pooling_width(); px++) {
529 const size_t ix = ox * pooling_width() + px - padding_left();
530 if (ix < next_input_width() && iy < next_input_height()) {
531 const float value = input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + c];
532 if (value > max_value) {
533 max_value = value;
534 max_index = uint32_t(px * pooling_height() + py);
535 }
536 }
537 }
538 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700539 next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = max_value;
540 next_index_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = max_index;
541 }
542 }
543 }
544 }
545
546 // Setup and run Argmax Pooling operator the second time, and destroy the operator.
547 ASSERT_EQ(xnn_status_success,
548 xnn_setup_argmax_pooling2d_nhwc_f32(
549 argmax_pooling_op,
550 next_batch_size(), next_input_height(), next_input_width(),
551 input.data(), output.data(), index.data(),
552 nullptr /* thread pool */));
553
554 ASSERT_EQ(xnn_status_success,
555 xnn_run_operator(argmax_pooling_op, nullptr /* thread pool */));
556
557 ASSERT_EQ(xnn_status_success,
558 xnn_delete_operator(argmax_pooling_op));
559 argmax_pooling_op = nullptr;
560
561 // Verify results of the second run.
562 for (size_t i = 0; i < next_batch_size(); i++) {
563 for (size_t y = 0; y < next_output_height(); y++) {
564 for (size_t x = 0; x < next_output_width(); x++) {
565 for (size_t c = 0; c < channels(); c++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700566 ASSERT_EQ(
567 next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c],
568 output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c])
569 << "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
570 ASSERT_EQ(
571 next_index_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c],
572 index[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c])
573 << "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
574 }
575 }
576 }
577 }
578 }
579 }
580
581 private:
582 uint32_t padding_top_{0};
583 uint32_t padding_right_{0};
584 uint32_t padding_bottom_{0};
585 uint32_t padding_left_{0};
Marat Dukhane4b8e572020-05-05 11:35:02 -0700586 bool padding_tf_same_{false};
XNNPACK Teamb455b122019-09-27 18:10:33 -0700587 size_t input_height_{1};
588 size_t input_width_{1};
589 size_t channels_{1};
590 size_t batch_size_{1};
591 size_t input_pixel_stride_{0};
592 size_t output_pixel_stride_{0};
593 uint32_t pooling_height_{1};
594 uint32_t pooling_width_{1};
595 size_t next_input_height_{0};
596 size_t next_input_width_{0};
597 size_t next_batch_size_{0};
598 uint8_t qmin_{0};
599 uint8_t qmax_{255};
600 size_t iterations_{1};
601};