// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <random>
#include <string>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/AlignedAllocator.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>

Marat Dukhan329da642019-11-19 21:44:39 -080024class ArgMaxPoolMicrokernelTester {
XNNPACK Teamb455b122019-09-27 18:10:33 -070025 public:
26 enum class Variant {
27 Native,
28 Scalar,
29 };
30
Marat Dukhan329da642019-11-19 21:44:39 -080031 inline ArgMaxPoolMicrokernelTester& output_pixels(size_t output_pixels) {
32 assert(output_pixels != 0);
33 this->output_pixels_ = output_pixels;
XNNPACK Teamb455b122019-09-27 18:10:33 -070034 return *this;
35 }
36
Marat Dukhan329da642019-11-19 21:44:39 -080037 inline size_t output_pixels() const {
38 return this->output_pixels_;
XNNPACK Teamb455b122019-09-27 18:10:33 -070039 }
40
Marat Dukhan329da642019-11-19 21:44:39 -080041 inline ArgMaxPoolMicrokernelTester& step(size_t step) {
42 assert(step != 0);
43 this->step_ = step;
XNNPACK Teamb455b122019-09-27 18:10:33 -070044 return *this;
45 }
46
Marat Dukhan329da642019-11-19 21:44:39 -080047 inline size_t step() const {
48 return this->step_;
XNNPACK Teamb455b122019-09-27 18:10:33 -070049 }
50
Marat Dukhan329da642019-11-19 21:44:39 -080051 inline ArgMaxPoolMicrokernelTester& input_offset(size_t input_offset) {
52 assert(input_offset != 0);
53 this->input_offset_ = input_offset;
XNNPACK Teamb455b122019-09-27 18:10:33 -070054 return *this;
55 }
56
Marat Dukhan329da642019-11-19 21:44:39 -080057 inline size_t input_offset() const {
58 return this->input_offset_;
XNNPACK Teamb455b122019-09-27 18:10:33 -070059 }
60
Marat Dukhan329da642019-11-19 21:44:39 -080061 inline ArgMaxPoolMicrokernelTester& pooling_elements(size_t pooling_elements) {
62 assert(pooling_elements != 0);
63 this->pooling_elements_ = pooling_elements;
XNNPACK Teamb455b122019-09-27 18:10:33 -070064 return *this;
65 }
66
Marat Dukhan329da642019-11-19 21:44:39 -080067 inline size_t pooling_elements() const {
68 return this->pooling_elements_;
XNNPACK Teamb455b122019-09-27 18:10:33 -070069 }
70
Marat Dukhan329da642019-11-19 21:44:39 -080071 inline size_t packed_pooling_elements() const {
72 if (pooling_elements() <= primary_pooling_tile()) {
73 return primary_pooling_tile();
XNNPACK Teamb455b122019-09-27 18:10:33 -070074 } else {
Marat Dukhan329da642019-11-19 21:44:39 -080075 return (pooling_elements() - primary_pooling_tile()) % incremental_pooling_tile() == 0 ? pooling_elements() : ((pooling_elements() - primary_pooling_tile()) / incremental_pooling_tile() + 1) * incremental_pooling_tile() + primary_pooling_tile();
XNNPACK Teamb455b122019-09-27 18:10:33 -070076 }
77 }
78
Marat Dukhan329da642019-11-19 21:44:39 -080079 inline ArgMaxPoolMicrokernelTester& pooling_tile(size_t primary_tile) {
80 assert(primary_tile != 0);
81 this->primary_pooling_tile_ = primary_tile;
82 this->incremental_pooling_tile_ = 0;
XNNPACK Teamb455b122019-09-27 18:10:33 -070083 return *this;
84 }
85
Marat Dukhan329da642019-11-19 21:44:39 -080086 inline ArgMaxPoolMicrokernelTester& pooling_tile(size_t primary_tile, size_t incremental_tile) {
87 assert(primary_tile != 0);
88 this->primary_pooling_tile_ = primary_tile;
89 this->incremental_pooling_tile_ = incremental_tile;
XNNPACK Teamb455b122019-09-27 18:10:33 -070090 return *this;
91 }
92
Marat Dukhan329da642019-11-19 21:44:39 -080093 inline ArgMaxPoolMicrokernelTester& primary_pooling_tile(size_t primary_pooling_tile) {
94 assert(primary_pooling_tile != 0);
95 this->primary_pooling_tile_ = primary_pooling_tile;
XNNPACK Teamb455b122019-09-27 18:10:33 -070096 return *this;
97 }
98
Marat Dukhan329da642019-11-19 21:44:39 -080099 inline size_t primary_pooling_tile() const {
100 return this->primary_pooling_tile_;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700101 }
102
Marat Dukhan329da642019-11-19 21:44:39 -0800103 inline ArgMaxPoolMicrokernelTester& incremental_pooling_tile(size_t incremental_pooling_tile) {
104 assert(incremental_pooling_tile != 0);
105 this->incremental_pooling_tile_ = incremental_pooling_tile;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700106 return *this;
107 }
108
Marat Dukhan329da642019-11-19 21:44:39 -0800109 inline size_t incremental_pooling_tile() const {
110 return this->incremental_pooling_tile_;
111 }
112
113 inline ArgMaxPoolMicrokernelTester& channels(size_t channels) {
114 assert(channels != 0);
115 this->channels_ = channels;
116 return *this;
117 }
118
119 inline size_t channels() const {
120 return this->channels_;
121 }
122
123 inline ArgMaxPoolMicrokernelTester& output_stride(size_t output_stride) {
124 assert(output_stride != 0);
125 this->output_stride_ = output_stride;
126 return *this;
127 }
128
129 inline size_t output_stride() const {
130 if (this->output_stride_ == 0) {
131 return channels();
XNNPACK Teamb455b122019-09-27 18:10:33 -0700132 } else {
Marat Dukhan329da642019-11-19 21:44:39 -0800133 assert(this->output_stride_ >= channels());
134 return this->output_stride_;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700135 }
136 }
137
Marat Dukhan329da642019-11-19 21:44:39 -0800138 inline ArgMaxPoolMicrokernelTester& iterations(size_t iterations) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700139 this->iterations_ = iterations;
140 return *this;
141 }
142
143 inline size_t iterations() const {
144 return this->iterations_;
145 }
146
Marat Dukhan99936602020-04-11 16:47:01 -0700147 void Test(xnn_f32_argmaxpool_unipass_ukernel_function argmaxpool, Variant variant = Variant::Native) const {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700148 std::random_device random_device;
149 auto rng = std::mt19937(random_device());
150 auto f32rng = std::bind(std::uniform_real_distribution<float>(0.0f, 1.0f), rng);
151
Marat Dukhan329da642019-11-19 21:44:39 -0800152 std::vector<const float*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
153 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
154 ((output_pixels() - 1) * step() + pooling_elements()) * channels());
155 std::vector<float> output((output_pixels() - 1) * output_stride() + channels());
156 std::vector<uint32_t> index(output_pixels() * channels());
157 std::vector<float> output_ref(output_pixels() * channels());
158 std::vector<uint32_t> index_ref(output_pixels() * channels());
XNNPACK Teamb455b122019-09-27 18:10:33 -0700159 for (size_t iteration = 0; iteration < iterations(); iteration++) {
Marat Dukhan329da642019-11-19 21:44:39 -0800160 std::generate(input.begin(), input.end(), std::ref(f32rng));
161 std::fill(output.begin(), output.end(), nanf(""));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700162
Marat Dukhan329da642019-11-19 21:44:39 -0800163 for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
164 indirect_input[i] = input.data() + i * channels() - input_offset();
XNNPACK Teamb455b122019-09-27 18:10:33 -0700165 }
Marat Dukhan329da642019-11-19 21:44:39 -0800166 std::shuffle(indirect_input.begin(),
167 indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700168
169 // Compute reference results, without clamping.
Marat Dukhan329da642019-11-19 21:44:39 -0800170 for (size_t x = 0; x < output_pixels(); x++) {
171 for (size_t c = 0; c < channels(); c++) {
172 float max_value = indirect_input[x * step()][c + input_offset()];
XNNPACK Teamb455b122019-09-27 18:10:33 -0700173 uint32_t max_index = 0;
Marat Dukhan329da642019-11-19 21:44:39 -0800174 for (size_t p = 0; p < pooling_elements(); p++) {
175 const float value = indirect_input[x * step() + p][c + input_offset()];
XNNPACK Teamb455b122019-09-27 18:10:33 -0700176 if (value > max_value) {
177 max_value = value;
Marat Dukhan329da642019-11-19 21:44:39 -0800178 max_index = p;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700179 }
180 }
Marat Dukhan329da642019-11-19 21:44:39 -0800181 output_ref[x * channels() + c] = max_value;
182 index_ref[x * channels() + c] = max_index;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700183 }
184 }
185
XNNPACK Teamb455b122019-09-27 18:10:33 -0700186 // Call optimized micro-kernel.
Marat Dukhan329da642019-11-19 21:44:39 -0800187 argmaxpool(output_pixels(), pooling_elements(), channels(),
188 indirect_input.data(), input_offset() * sizeof(float), output.data(), index.data(),
189 step() * sizeof(void*),
Marat Dukhan447c4f52020-07-17 01:07:28 -0700190 (output_stride() - channels()) * sizeof(float));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700191
192 // Verify results.
Marat Dukhan329da642019-11-19 21:44:39 -0800193 for (size_t x = 0; x < output_pixels(); x++) {
194 for (size_t c = 0; c < channels(); c++) {
Marat Dukhan329da642019-11-19 21:44:39 -0800195 ASSERT_EQ(output_ref[x * channels() + c], output[x * output_stride() + c])
196 << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
197 << ", pooling elements = " << pooling_elements() << ", step = " << step()
198 << ", input offset = " << input_offset();
199 ASSERT_EQ(
200 indirect_input[x * step() + index_ref[x * channels() + c]][c + input_offset()],
201 indirect_input[x * step() + index[x * channels() + c]][c + input_offset()])
202 << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
203 << ", pooling elements = " << pooling_elements() << ", step = " << step()
204 << ", input offset = " << input_offset();
205 ASSERT_EQ(index_ref[x * channels() + c], index[x * channels() + c])
206 << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
207 << ", pooling elements = " << pooling_elements() << ", step = " << step()
208 << ", input offset = " << input_offset();
XNNPACK Teamb455b122019-09-27 18:10:33 -0700209 }
210 }
211 }
212 }
213
Marat Dukhan99936602020-04-11 16:47:01 -0700214 void Test(xnn_f32_argmaxpool_multipass_ukernel_function argmaxpool, Variant variant = Variant::Native) const {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700215 std::random_device random_device;
216 auto rng = std::mt19937(random_device());
217 auto f32rng = std::bind(std::uniform_real_distribution<float>(0.0f, 1.0f), rng);
218
Marat Dukhan329da642019-11-19 21:44:39 -0800219 std::vector<const float*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
220 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
221 ((output_pixels() - 1) * step() + pooling_elements()) * channels());
222 std::vector<float> output((output_pixels() - 1) * output_stride() + channels());
223 std::vector<uint32_t> index(output_pixels() * channels());
Marat Dukhan9594db02019-12-05 14:32:37 -0800224 std::vector<uint32_t, AlignedAllocator<uint32_t, 64>> index_buffer(
Marat Dukhan329da642019-11-19 21:44:39 -0800225 channels() + XNN_EXTRA_BYTES / sizeof(uint32_t));
Marat Dukhan9594db02019-12-05 14:32:37 -0800226 std::vector<float, AlignedAllocator<float, 64>> output_buffer(
Marat Dukhan329da642019-11-19 21:44:39 -0800227 channels() + XNN_EXTRA_BYTES / sizeof(float));
228 std::vector<float> output_ref(output_pixels() * channels());
229 std::vector<uint32_t> index_ref(output_pixels() * channels());
XNNPACK Teamb455b122019-09-27 18:10:33 -0700230 for (size_t iteration = 0; iteration < iterations(); iteration++) {
Marat Dukhan329da642019-11-19 21:44:39 -0800231 std::generate(input.begin(), input.end(), std::ref(f32rng));
232 std::fill(output.begin(), output.end(), nanf(""));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700233
Marat Dukhan329da642019-11-19 21:44:39 -0800234 for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
235 indirect_input[i] = input.data() + i * channels() - input_offset();
XNNPACK Teamb455b122019-09-27 18:10:33 -0700236 }
Marat Dukhan329da642019-11-19 21:44:39 -0800237 std::shuffle(indirect_input.begin(),
238 indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700239
240 // Compute reference results, without clamping.
Marat Dukhan329da642019-11-19 21:44:39 -0800241 for (size_t x = 0; x < output_pixels(); x++) {
242 for (size_t c = 0; c < channels(); c++) {
243 float max_value = indirect_input[x * step()][c + input_offset()];
XNNPACK Teamb455b122019-09-27 18:10:33 -0700244 uint32_t max_index = 0;
Marat Dukhan329da642019-11-19 21:44:39 -0800245 for (size_t p = 0; p < pooling_elements(); p++) {
246 const float value = indirect_input[x * step() + p][c + input_offset()];
XNNPACK Teamb455b122019-09-27 18:10:33 -0700247 if (value > max_value) {
248 max_value = value;
Marat Dukhan329da642019-11-19 21:44:39 -0800249 max_index = p;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700250 }
251 }
Marat Dukhan329da642019-11-19 21:44:39 -0800252 output_ref[x * channels() + c] = max_value;
253 index_ref[x * channels() + c] = max_index;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700254 }
255 }
256
XNNPACK Teamb455b122019-09-27 18:10:33 -0700257 // Call optimized micro-kernel.
Marat Dukhan329da642019-11-19 21:44:39 -0800258 argmaxpool(output_pixels(), pooling_elements(), channels(),
259 indirect_input.data(), input_offset() * sizeof(float),
260 output_buffer.data(), index_buffer.data(),
261 output.data(), index.data(),
262 (step() - (packed_pooling_elements() - incremental_pooling_tile())) * sizeof(void*),
Marat Dukhan447c4f52020-07-17 01:07:28 -0700263 (output_stride() - channels()) * sizeof(float));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700264
265 // Verify results.
Marat Dukhan329da642019-11-19 21:44:39 -0800266 for (size_t x = 0; x < output_pixels(); x++) {
267 for (size_t c = 0; c < channels(); c++) {
Marat Dukhan329da642019-11-19 21:44:39 -0800268 ASSERT_EQ(output_ref[x * channels() + c], output[x * output_stride() + c])
269 << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
270 << ", pooling elements = " << pooling_elements() << ", step = " << step()
271 << ", input offset = " << input_offset();
272 ASSERT_EQ(
273 indirect_input[x * step() + index_ref[x * channels() + c]][c + input_offset()],
274 indirect_input[x * step() + index[x * channels() + c]][c + input_offset()])
275 << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
276 << ", pooling elements = " << pooling_elements() << ", step = " << step()
277 << ", input offset = " << input_offset();
278 ASSERT_EQ(index_ref[x * channels() + c], index[x * channels() + c])
279 << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
280 << ", pooling elements = " << pooling_elements() << ", step = " << step()
281 << ", input offset = " << input_offset();
XNNPACK Teamb455b122019-09-27 18:10:33 -0700282 }
283 }
284 }
285 }
286
287 private:
Marat Dukhan329da642019-11-19 21:44:39 -0800288 size_t output_pixels_{1};
289 size_t pooling_elements_{1};
290 size_t channels_{1};
291 size_t input_offset_{0};
292 size_t step_{1};
293 size_t primary_pooling_tile_{1};
294 size_t incremental_pooling_tile_{1};
295 size_t output_stride_{0};
Marat Dukhan329da642019-11-19 21:44:39 -0800296 size_t iterations_{3};
XNNPACK Teamb455b122019-09-27 18:10:33 -0700297};