// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>


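// Fluent test helper for XNNPACK Global Average Pooling operators: each
// Test*() method fills a random input, computes a scalar reference result,
// runs the operator under test, and compares its output with the reference.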
class GlobalAveragePoolingOperatorTester {
 public:
  inline GlobalAveragePoolingOperatorTester& channels(size_t channels) {
    assert(channels != 0);
    this->channels_ = channels;
    return *this;
  }

  inline size_t channels() const {
    return this->channels_;
  }

  inline GlobalAveragePoolingOperatorTester& width(size_t width) {
    assert(width != 0);
    this->width_ = width;
    return *this;
  }

  inline size_t width() const {
    return this->width_;
  }

  inline GlobalAveragePoolingOperatorTester& input_stride(size_t input_stride) {
    assert(input_stride != 0);
    this->input_stride_ = input_stride;
    return *this;
  }

  inline size_t input_stride() const {
    if (this->input_stride_ == 0) {
      return channels();
    } else {
      assert(this->input_stride_ >= channels());
      return this->input_stride_;
    }
  }

  inline GlobalAveragePoolingOperatorTester& output_stride(size_t output_stride) {
    assert(output_stride != 0);
    this->output_stride_ = output_stride;
    return *this;
  }

  inline size_t output_stride() const {
    if (this->output_stride_ == 0) {
      return channels();
    } else {
      assert(this->output_stride_ >= channels());
      return this->output_stride_;
    }
  }

  inline GlobalAveragePoolingOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size != 0);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

  inline GlobalAveragePoolingOperatorTester& input_scale(float input_scale) {
    assert(input_scale > 0.0f);
    assert(std::isnormal(input_scale));
    this->input_scale_ = input_scale;
    return *this;
  }

  inline float input_scale() const {
    return this->input_scale_;
  }

  inline GlobalAveragePoolingOperatorTester& input_zero_point(uint8_t input_zero_point) {
    this->input_zero_point_ = input_zero_point;
    return *this;
  }

  inline uint8_t input_zero_point() const {
    return this->input_zero_point_;
  }

  inline GlobalAveragePoolingOperatorTester& output_scale(float output_scale) {
    assert(output_scale > 0.0f);
    assert(std::isnormal(output_scale));
    this->output_scale_ = output_scale;
    return *this;
  }

  inline float output_scale() const {
    return this->output_scale_;
  }

  inline GlobalAveragePoolingOperatorTester& output_zero_point(uint8_t output_zero_point) {
    this->output_zero_point_ = output_zero_point;
    return *this;
  }

  inline uint8_t output_zero_point() const {
    return this->output_zero_point_;
  }

  inline GlobalAveragePoolingOperatorTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline GlobalAveragePoolingOperatorTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline GlobalAveragePoolingOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

  void TestNWCxQU8() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);

    std::vector<uint8_t> input((batch_size() * width() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t));
    std::vector<uint8_t> output(batch_size() * output_stride());
    std::vector<float> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(u8rng));
      std::fill(output.begin(), output.end(), 0xA5);

      // Compute reference results.
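      // The mean of the dequantized inputs, requantized to the output scale, is
      //   (sum_k (x[k] - input_zero_point)) * input_scale / (width * output_scale) + output_zero_point,
      // so the per-element factor folds into the single fixed `scale` below.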
      const double scale = double(input_scale()) / (double(width()) * double(output_scale()));
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t j = 0; j < channels(); j++) {
          double acc = 0.0;
          for (size_t k = 0; k < width(); k++) {
            acc += double(int32_t(input[(i * width() + k) * input_stride() + j]) - int32_t(input_zero_point()));
          }
          output_ref[i * channels() + j] = float(acc * scale + double(output_zero_point()));
          output_ref[i * channels() + j] = std::min<float>(output_ref[i * channels() + j], float(qmax()));
          output_ref[i * channels() + j] = std::max<float>(output_ref[i * channels() + j], float(qmin()));
        }
      }

182 // Create, setup, run, and destroy Global Average Pooling operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800183 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700184 xnn_operator_t global_average_pooling_op = nullptr;
185
Marat Dukhan9e0b5392020-08-07 02:29:34 -0700186 xnn_status status = xnn_create_global_average_pooling_nwc_qu8(
XNNPACK Teamb455b122019-09-27 18:10:33 -0700187 channels(), input_stride(), output_stride(),
188 input_zero_point(), input_scale(),
189 output_zero_point(), output_scale(),
190 qmin(), qmax(),
Marat Dukhan9e0b5392020-08-07 02:29:34 -0700191 0, &global_average_pooling_op);
192 if (status == xnn_status_unsupported_hardware) {
193 GTEST_SKIP();
194 }
195 ASSERT_EQ(xnn_status_success, status);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700196 ASSERT_NE(nullptr, global_average_pooling_op);
197
198 // Smart pointer to automatically delete global_average_pooling_op.
199 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator);
200
201 ASSERT_EQ(xnn_status_success,
Marat Dukhan08b7a972020-07-14 18:17:29 -0700202 xnn_setup_global_average_pooling_nwc_qu8(
XNNPACK Teamb455b122019-09-27 18:10:33 -0700203 global_average_pooling_op,
204 batch_size(), width(),
205 input.data(), output.data(),
206 nullptr /* thread pool */));
207
208 ASSERT_EQ(xnn_status_success,
209 xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */));
210
211 // Verify results.
212 for (size_t i = 0; i < batch_size(); i++) {
213 for (size_t c = 0; c < channels(); c++) {
214 ASSERT_LE(uint32_t(output[i * output_stride() + c]), uint32_t(qmax()));
215 ASSERT_GE(uint32_t(output[i * output_stride() + c]), uint32_t(qmin()));
Marat Dukhan9e0b5392020-08-07 02:29:34 -0700216 ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.80f)
217 << "at batch index " << i << " / " << batch_size()
218 << ", channel " << c << " / " << channels();
219 }
220 }
221 }
222 }
223
  void TestNWCxQS8() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto i8rng = std::bind(
      std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()), rng);

    std::vector<int8_t> input((batch_size() * width() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(int8_t));
    std::vector<int8_t> output(batch_size() * output_stride());
    std::vector<float> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(i8rng));
      std::fill(output.begin(), output.end(), 0xA5);

      // Compute reference results.
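      // The tester stores quantization parameters as uint8_t; for the signed
      // QS8 path they are shifted by 0x80 to map [0, 255] onto [-128, 127],
      // both here and in the operator creation call below.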
      const double scale = double(input_scale()) / (double(width()) * double(output_scale()));
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t j = 0; j < channels(); j++) {
          double acc = 0.0;
          for (size_t k = 0; k < width(); k++) {
            acc += double(int32_t(input[(i * width() + k) * input_stride() + j]) - int32_t(input_zero_point() - 0x80));
          }
          output_ref[i * channels() + j] = float(acc * scale + double(output_zero_point() - 0x80));
          output_ref[i * channels() + j] = std::min<float>(output_ref[i * channels() + j], float(qmax() - 0x80));
          output_ref[i * channels() + j] = std::max<float>(output_ref[i * channels() + j], float(qmin() - 0x80));
        }
      }

      // Create, setup, run, and destroy Global Average Pooling operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t global_average_pooling_op = nullptr;

      xnn_status status = xnn_create_global_average_pooling_nwc_qs8(
        channels(), input_stride(), output_stride(),
        int8_t(input_zero_point() - 0x80), input_scale(),
        int8_t(output_zero_point() - 0x80), output_scale(),
        int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
        0, &global_average_pooling_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, global_average_pooling_op);

      // Smart pointer to automatically delete global_average_pooling_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_global_average_pooling_nwc_qs8(
          global_average_pooling_op,
          batch_size(), width(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax() - 0x80));
          ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin() - 0x80));
          ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.80f)
            << "at batch index " << i << " / " << batch_size()
            << ", channel " << c << " / " << channels();
        }
      }
    }
  }

  void TestNWCxF16() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(1.0e-3f, 1.0f), rng);
    auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);

    std::vector<uint16_t> input((batch_size() * width() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
    std::vector<uint16_t> output(batch_size() * output_stride());
    std::vector<float> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(f16rng));
      std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);

      // Compute reference results, without clamping.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t j = 0; j < channels(); j++) {
          float acc = 0.0f;
          for (size_t k = 0; k < width(); k++) {
            acc += fp16_ieee_to_fp32_value(input[(i * width() + k) * input_stride() + j]);
          }
          output_ref[i * channels() + j] = acc / float(width());
        }
      }

      // Compute clamping parameters.
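      // qmin()/qmax() (in 0-255 units) carve proportional slices off the
      // observed output range. The bounds are rounded to representable fp16
      // values; if they collapse to the same value, clamping is effectively
      // disabled by substituting +/-infinity.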
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_range = accumulated_max - accumulated_min;
      const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
      const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
      const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min;
      const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max;

      // Clamp reference results.
      for (float& value : output_ref) {
        value = std::max(std::min(value, output_max), output_min);
      }

      // Create, setup, run, and destroy Global Average Pooling operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t global_average_pooling_op = nullptr;

      xnn_status status = xnn_create_global_average_pooling_nwc_f16(
        channels(), input_stride(), output_stride(),
        output_min, output_max,
        0, &global_average_pooling_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, global_average_pooling_op);

      // Smart pointer to automatically delete global_average_pooling_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_global_average_pooling_nwc_f16(
          global_average_pooling_op,
          batch_size(), width(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_LE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_max);
          ASSERT_GE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_min);
          ASSERT_NEAR(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_ref[i * channels() + c], std::max(1.0e-4f, std::abs(output_ref[i * channels() + c]) * 1.0e-2f))
            << "at batch index " << i << " / " << batch_size()
            << ", channel " << c << " / " << channels();
        }
      }
    }
  }

  void TestNWCxF32() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> input((batch_size() * width() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> output(batch_size() * output_stride());
    std::vector<float> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(f32rng));
      std::fill(output.begin(), output.end(), std::nanf(""));

      // Compute reference results, without clamping.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t j = 0; j < channels(); j++) {
          float acc = 0.0f;
          for (size_t k = 0; k < width(); k++) {
            acc += input[(i * width() + k) * input_stride() + j];
          }
          output_ref[i * channels() + j] = acc / float(width());
        }
      }

      // Compute clamping parameters.
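      // As in the fp16 path, qmin()/qmax() select a clamping window inside the
      // observed output range; a zero range disables clamping via +/-infinity.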
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_range = accumulated_max - accumulated_min;
      const float output_min = accumulated_range == 0.0f ?
        -std::numeric_limits<float>::infinity() :
        accumulated_min + accumulated_range / 255.0f * float(qmin());
      const float output_max = accumulated_range == 0.0f ?
        +std::numeric_limits<float>::infinity() :
        accumulated_max - accumulated_range / 255.0f * float(255 - qmax());

      // Clamp reference results.
      for (float& value : output_ref) {
        value = std::max(std::min(value, output_max), output_min);
      }

      // Create, setup, run, and destroy Global Average Pooling operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t global_average_pooling_op = nullptr;

      xnn_status status = xnn_create_global_average_pooling_nwc_f32(
        channels(), input_stride(), output_stride(),
        output_min, output_max,
        0, &global_average_pooling_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, global_average_pooling_op);

      // Smart pointer to automatically delete global_average_pooling_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_global_average_pooling_nwc_f32(
          global_average_pooling_op,
          batch_size(), width(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_LE(output[i * output_stride() + c], output_max);
          ASSERT_GE(output[i * output_stride() + c], output_min);
          ASSERT_NEAR(output[i * output_stride() + c], output_ref[i * channels() + c], std::abs(output_ref[i * channels() + c]) * 1.0e-6f)
            << "at batch index " << i << " / " << batch_size()
            << ", channel " << c << " / " << channels();
        }
      }
    }
  }

  void TestNCWxF32() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> input(batch_size() * channels() * width() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> output(batch_size() * channels());
    std::vector<float> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(f32rng));
      std::fill(output.begin(), output.end(), std::nanf(""));

      // Compute reference results, without clamping.
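      // In NCW layout the width elements of each (batch, channel) pair are
      // contiguous, hence the (i * channels() + j) * width() + k indexing and
      // the densely packed output with no separate output stride.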
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t j = 0; j < channels(); j++) {
          float acc = 0.0f;
          for (size_t k = 0; k < width(); k++) {
            acc += input[(i * channels() + j) * width() + k];
          }
          output_ref[i * channels() + j] = acc / float(width());
        }
      }

      // Compute clamping parameters.
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_range = accumulated_max - accumulated_min;
      const float output_min = accumulated_range == 0.0f ?
        -std::numeric_limits<float>::infinity() :
        accumulated_min + accumulated_range / 255.0f * float(qmin());
      const float output_max = accumulated_range == 0.0f ?
        +std::numeric_limits<float>::infinity() :
        accumulated_max - accumulated_range / 255.0f * float(255 - qmax());

      // Clamp reference results.
      for (float& value : output_ref) {
        value = std::max(std::min(value, output_max), output_min);
      }

      // Create, setup, run, and destroy Global Average Pooling operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t global_average_pooling_op = nullptr;

      xnn_status status = xnn_create_global_average_pooling_ncw_f32(
        channels(), output_min, output_max,
        0, &global_average_pooling_op);
      if (status == xnn_status_unsupported_parameter) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);

      // Smart pointer to automatically delete global_average_pooling_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_global_average_pooling_ncw_f32(
          global_average_pooling_op,
          batch_size(), width(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_LE(output[i * channels() + c], output_max);
          ASSERT_GE(output[i * channels() + c], output_min);
          ASSERT_NEAR(output[i * channels() + c], output_ref[i * channels() + c], std::abs(output_ref[i * channels() + c]) * 1.0e-5f)
            << "at batch index " << i << " / " << batch_size()
            << ", channel " << c << " / " << channels();
        }
      }
    }
  }

 private:
  size_t batch_size_{1};
  size_t width_{1};
  size_t channels_{1};
  size_t input_stride_{0};
  size_t output_stride_{0};
  float input_scale_{1.0f};
  float output_scale_{1.0f};
  uint8_t input_zero_point_{121};
  uint8_t output_zero_point_{133};
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  size_t iterations_{1};
};
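
// A minimal usage sketch (illustration only; the test name below is
// hypothetical, not part of this header). A gtest case built on this tester
// chains the fluent setters and ends with one of the Test*() methods:
//
//   TEST(GLOBAL_AVERAGE_POOLING_NWC_F32, example) {
//     GlobalAveragePoolingOperatorTester()
//       .batch_size(3)
//       .width(7)
//       .channels(19)
//       .TestNWCxF32();
//   }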