blob: a23ab409cf684e49d2cf7253d9c8f1b8456635f1 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
20
21
22class CopyOperatorTester {
23 public:
24 inline CopyOperatorTester& channels(size_t channels) {
25 assert(channels != 0);
26 this->channels_ = channels;
27 return *this;
28 }
29
30 inline size_t channels() const {
31 return this->channels_;
32 }
33
34 inline CopyOperatorTester& input_stride(size_t input_stride) {
35 assert(input_stride != 0);
36 this->input_stride_ = input_stride;
37 return *this;
38 }
39
40 inline size_t input_stride() const {
41 if (this->input_stride_ == 0) {
42 return this->channels_;
43 } else {
44 assert(this->input_stride_ >= this->channels_);
45 return this->input_stride_;
46 }
47 }
48
49 inline CopyOperatorTester& output_stride(size_t output_stride) {
50 assert(output_stride != 0);
51 this->output_stride_ = output_stride;
52 return *this;
53 }
54
55 inline size_t output_stride() const {
56 if (this->output_stride_ == 0) {
57 return this->channels_;
58 } else {
59 assert(this->output_stride_ >= this->channels_);
60 return this->output_stride_;
61 }
62 }
63
64 inline CopyOperatorTester& batch_size(size_t batch_size) {
65 assert(batch_size != 0);
66 this->batch_size_ = batch_size;
67 return *this;
68 }
69
70 inline size_t batch_size() const {
71 return this->batch_size_;
72 }
73
74 inline CopyOperatorTester& iterations(size_t iterations) {
75 this->iterations_ = iterations;
76 return *this;
77 }
78
79 inline size_t iterations() const {
80 return this->iterations_;
81 }
82
Marat Dukhan2bd2bd22022-02-04 03:34:32 -080083 void TestX8() const {
84 std::random_device random_device;
85 auto rng = std::mt19937(random_device());
86 auto u8rng = std::bind(
87 std::uniform_int_distribution<uint32_t>( std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()),
88 rng);
89
90 std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
91 (batch_size() - 1) * input_stride() + channels());
92 std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
93 std::vector<uint8_t> output_ref(batch_size() * channels());
94 for (size_t iteration = 0; iteration < iterations(); iteration++) {
95 std::generate(input.begin(), input.end(), std::ref(u8rng));
96 std::fill(output.begin(), output.end(), UINT16_C(0xFA));
97
98 // Compute reference results.
99 for (size_t i = 0; i < batch_size(); i++) {
100 for (size_t c = 0; c < channels(); c++) {
101 output_ref[i * channels() + c] = input[i * input_stride() + c];
102 }
103 }
104
105 // Create, setup, run, and destroy Copy operator.
106 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
107 xnn_operator_t copy_op = nullptr;
108
109 ASSERT_EQ(xnn_status_success,
110 xnn_create_copy_nc_x8(
111 channels(), input_stride(), output_stride(),
112 0, &copy_op));
113 ASSERT_NE(nullptr, copy_op);
114
115 // Smart pointer to automatically delete copy_op.
116 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_copy_op(copy_op, xnn_delete_operator);
117
118 ASSERT_EQ(xnn_status_success,
119 xnn_setup_copy_nc_x8(
120 copy_op,
121 batch_size(),
122 input.data(), output.data(),
123 nullptr /* thread pool */));
124
125 ASSERT_EQ(xnn_status_success,
126 xnn_run_operator(copy_op, nullptr /* thread pool */));
127
128 // Verify results.
129 for (size_t i = 0; i < batch_size(); i++) {
130 for (size_t c = 0; c < channels(); c++) {
131 ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c])
132 << "at batch " << i << " / " << batch_size() << ", channel = " << c << " / " << channels();
133 }
134 }
135 }
136 }
137
138 void TestX16() const {
139 std::random_device random_device;
140 auto rng = std::mt19937(random_device());
141 auto u16rng = std::bind(std::uniform_int_distribution<uint16_t>(), rng);
142
143 std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
144 (batch_size() - 1) * input_stride() + channels());
145 std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels());
146 std::vector<uint16_t> output_ref(batch_size() * channels());
147 for (size_t iteration = 0; iteration < iterations(); iteration++) {
148 std::generate(input.begin(), input.end(), std::ref(u16rng));
149 std::fill(output.begin(), output.end(), UINT16_C(0xDEAD));
150
151 // Compute reference results.
152 for (size_t i = 0; i < batch_size(); i++) {
153 for (size_t c = 0; c < channels(); c++) {
154 output_ref[i * channels() + c] = input[i * input_stride() + c];
155 }
156 }
157
158 // Create, setup, run, and destroy Copy operator.
159 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
160 xnn_operator_t copy_op = nullptr;
161
162 ASSERT_EQ(xnn_status_success,
163 xnn_create_copy_nc_x16(
164 channels(), input_stride(), output_stride(),
165 0, &copy_op));
166 ASSERT_NE(nullptr, copy_op);
167
168 // Smart pointer to automatically delete copy_op.
169 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_copy_op(copy_op, xnn_delete_operator);
170
171 ASSERT_EQ(xnn_status_success,
172 xnn_setup_copy_nc_x16(
173 copy_op,
174 batch_size(),
175 input.data(), output.data(),
176 nullptr /* thread pool */));
177
178 ASSERT_EQ(xnn_status_success,
179 xnn_run_operator(copy_op, nullptr /* thread pool */));
180
181 // Verify results.
182 for (size_t i = 0; i < batch_size(); i++) {
183 for (size_t c = 0; c < channels(); c++) {
184 ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c])
185 << "at batch " << i << " / " << batch_size() << ", channel = " << c << " / " << channels();
186 }
187 }
188 }
189 }
190
Marat Dukhan4e21b272020-06-04 18:45:01 -0700191 void TestX32() const {
192 std::random_device random_device;
193 auto rng = std::mt19937(random_device());
194 auto u32rng = std::bind(std::uniform_int_distribution<uint32_t>(), rng);
195
Marat Dukhan2bd2bd22022-02-04 03:34:32 -0800196 std::vector<uint32_t> input(XNN_EXTRA_BYTES / sizeof(uint32_t) +
Marat Dukhan4e21b272020-06-04 18:45:01 -0700197 (batch_size() - 1) * input_stride() + channels());
198 std::vector<uint32_t> output((batch_size() - 1) * output_stride() + channels());
199 std::vector<uint32_t> output_ref(batch_size() * channels());
200 for (size_t iteration = 0; iteration < iterations(); iteration++) {
201 std::generate(input.begin(), input.end(), std::ref(u32rng));
202 std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEEF));
203
204 // Compute reference results.
205 for (size_t i = 0; i < batch_size(); i++) {
206 for (size_t c = 0; c < channels(); c++) {
207 output_ref[i * channels() + c] = input[i * input_stride() + c];
208 }
209 }
210
211 // Create, setup, run, and destroy Copy operator.
212 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
213 xnn_operator_t copy_op = nullptr;
214
215 ASSERT_EQ(xnn_status_success,
216 xnn_create_copy_nc_x32(
217 channels(), input_stride(), output_stride(),
218 0, &copy_op));
219 ASSERT_NE(nullptr, copy_op);
220
221 // Smart pointer to automatically delete copy_op.
222 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_copy_op(copy_op, xnn_delete_operator);
223
224 ASSERT_EQ(xnn_status_success,
225 xnn_setup_copy_nc_x32(
226 copy_op,
227 batch_size(),
228 input.data(), output.data(),
229 nullptr /* thread pool */));
230
231 ASSERT_EQ(xnn_status_success,
232 xnn_run_operator(copy_op, nullptr /* thread pool */));
233
234 // Verify results.
235 for (size_t i = 0; i < batch_size(); i++) {
236 for (size_t c = 0; c < channels(); c++) {
237 ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c])
238 << "at batch " << i << " / " << batch_size() << ", channel = " << c << " / " << channels();
239 }
240 }
241 }
242 }
243
244 private:
245 size_t batch_size_{1};
246 size_t channels_{1};
247 size_t input_stride_{0};
248 size_t output_stride_{0};
249 size_t iterations_{15};
250};