// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>

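// Helper for exercising XNNPACK's fully-connected (NC-layout) operators.
// Tests are configured through the chained setters below and executed by
// TestQS8/TestQU8/TestF32, which compare the operator's output against a
// scalar reference computation.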
class FullyConnectedOperatorTester {
 public:
  inline FullyConnectedOperatorTester& input_channels(size_t input_channels) {
    assert(input_channels >= 1);
    this->input_channels_ = input_channels;
    return *this;
  }

  inline size_t input_channels() const {
    return this->input_channels_;
  }

  inline FullyConnectedOperatorTester& output_channels(size_t output_channels) {
    assert(output_channels >= 1);
    this->output_channels_ = output_channels;
    return *this;
  }

  inline size_t output_channels() const {
    return this->output_channels_;
  }

  inline FullyConnectedOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size >= 1);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

  inline FullyConnectedOperatorTester& input_stride(size_t input_stride) {
    assert(input_stride >= 1);
    this->input_stride_ = input_stride;
    return *this;
  }

  inline size_t input_stride() const {
    if (this->input_stride_ == 0) {
      return input_channels();
    } else {
      assert(this->input_stride_ >= input_channels());
      return this->input_stride_;
    }
  }

  inline FullyConnectedOperatorTester& output_stride(size_t output_stride) {
    assert(output_stride >= 1);
    this->output_stride_ = output_stride;
    return *this;
  }

  inline size_t output_stride() const {
    if (this->output_stride_ == 0) {
      return output_channels();
    } else {
      assert(this->output_stride_ >= output_channels());
      return this->output_stride_;
    }
  }

  inline FullyConnectedOperatorTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline FullyConnectedOperatorTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline FullyConnectedOperatorTester& transpose_weights(bool transpose_weights) {
    this->transpose_weights_ = transpose_weights;
    return *this;
  }

  inline bool transpose_weights() const {
    return this->transpose_weights_;
  }

  inline FullyConnectedOperatorTester& has_bias(bool has_bias) {
    this->has_bias_ = has_bias;
    return *this;
  }

  inline bool has_bias() const {
    return this->has_bias_;
  }

  inline FullyConnectedOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

  void TestQS8() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto i8rng = std::bind(std::uniform_int_distribution<int32_t>(
      -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()), rng);

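    // The input is padded with XNN_EXTRA_BYTES so that vectorized
    // micro-kernels may safely read a few bytes past the last element.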
    std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<int8_t> kernel(output_channels() * input_channels());
    std::vector<int32_t> bias(output_channels());
    std::vector<int8_t> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<int32_t> accumulators(batch_size() * output_channels());
    std::vector<double> output_ref(batch_size() * output_channels());

    const int8_t input_zero_point = 127;

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(i8rng));
      std::generate(kernel.begin(), kernel.end(), std::ref(i8rng));
      std::generate(bias.begin(), bias.end(), std::ref(i32rng));
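      // 0xA5 is an arbitrary canary value; any element the operator fails to
      // overwrite will be caught by the range checks below.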
      std::fill(output.begin(), output.end(), 0xA5);

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            accumulators[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(accumulators.begin(), accumulators.end(), 0);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                int32_t(kernel[ic * output_channels() + oc]);
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                int32_t(kernel[oc * input_channels() + ic]);
            }
          }
        }
      }

      // Compute renormalization parameters.
      const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
      const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());

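      // Spread the observed accumulator range over the 255 representable
      // quantization steps, and pick a zero point that centers that range
      // within the signed [-128, 127] output domain.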
      const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
      const int8_t output_zero_point = int8_t(std::max(std::min(
        lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
        long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));

      // Renormalize reference results.
      std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
        [this, output_scale, output_zero_point](int32_t x) -> double {
          return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point);
        });

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      const xnn_status status = xnn_create_fully_connected_nc_qs8(
        input_channels(), output_channels(),
        input_stride(), output_stride(),
        input_zero_point, 1.0f /* input scale */,
        1.0f /* kernel scale */,
        kernel.data(), has_bias() ? bias.data() : nullptr,
        output_zero_point, output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
        transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
        &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_qs8(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < output_channels(); c++) {
          ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax() - 0x80))
            << "batch index = " << i << ", channel = " << c;
          ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin() - 0x80))
            << "batch index = " << i << ", channel = " << c;
          ASSERT_NEAR(
            output_ref[i * output_channels() + c],
            double(output[i * output_stride() + c]) - double(output_zero_point),
            0.9)
            << "batch index = " << i << ", channel = " << c;
        }
      }
    }
  }

  void TestQU8() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);

    std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<uint8_t> kernel(output_channels() * input_channels());
    std::vector<int32_t> bias(output_channels());
    std::vector<uint8_t> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<int32_t> accumulators(batch_size() * output_channels());
    std::vector<double> output_ref(batch_size() * output_channels());

    const uint8_t input_zero_point = 127;
    const uint8_t kernel_zero_point = 127;

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(u8rng));
      std::generate(kernel.begin(), kernel.end(), std::ref(u8rng));
      std::generate(bias.begin(), bias.end(), std::ref(i32rng));
      std::fill(output.begin(), output.end(), 0xA5);

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            accumulators[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(accumulators.begin(), accumulators.end(), 0);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                (int32_t(kernel[ic * output_channels() + oc]) - int32_t(kernel_zero_point));
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
            }
          }
        }
      }

      // Compute renormalization parameters.
      const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
      const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());

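      // As in the QS8 path, but the zero point is centered within the
      // unsigned [0, 255] output domain (hence the 127.5 midpoint).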
      const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
      const uint8_t output_zero_point = uint8_t(std::max(std::min(
        lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
        long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));

      // Renormalize reference results.
      std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
        [this, output_scale, output_zero_point](int32_t x) -> double {
          return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
        });

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      const xnn_status status = xnn_create_fully_connected_nc_qu8(
        input_channels(), output_channels(),
        input_stride(), output_stride(),
        input_zero_point, 1.0f /* input scale */,
        kernel_zero_point, 1.0f /* kernel scale */,
        kernel.data(), has_bias() ? bias.data() : nullptr,
        output_zero_point, output_scale, qmin(), qmax(),
        transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
        &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_qu8(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < output_channels(); c++) {
          ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax()))
            << "batch index = " << i << ", channel = " << c;
          ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin()))
            << "batch index = " << i << ", channel = " << c;
          ASSERT_NEAR(
            output_ref[i * output_channels() + c],
            double(output[i * output_stride() + c]) - double(output_zero_point),
            0.9)
            << "batch index = " << i << ", channel = " << c;
        }
      }
    }
  }

  void TestF32() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(0.1f, 1.0f), rng);

    std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<float> kernel(output_channels() * input_channels());
    std::vector<float> bias(output_channels());
    std::vector<float> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<float> output_ref(batch_size() * output_channels());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(f32rng));
      std::generate(kernel.begin(), kernel.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(output.begin(), output.end(), nanf(""));

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            output_ref[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(output_ref.begin(), output_ref.end(), 0.0f);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                input[i * input_stride() + ic] * kernel[ic * output_channels() + oc];
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
            }
          }
        }
      }

      // Compute clamping parameters.
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());

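      // Reinterpret qmin/qmax as fractional positions within the accumulated
      // range to derive float clamping bounds; 0 and 255 mean "unclamped".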
      const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() :
        accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() :
        accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Clamp reference results.
      for (float& value : output_ref) {
        value = std::max(std::min(value, output_max), output_min);
      }

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      const xnn_status status = xnn_create_fully_connected_nc_f32(
        input_channels(), output_channels(),
        input_stride(), output_stride(),
        kernel.data(), has_bias() ? bias.data() : nullptr,
        output_min, output_max,
        transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
        &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_f32(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < output_channels(); c++) {
          ASSERT_LE(output[i * output_stride() + c], output_max)
            << "batch index = " << i << ", channel = " << c;
          ASSERT_GE(output[i * output_stride() + c], output_min)
            << "batch index = " << i << ", channel = " << c;
          ASSERT_NEAR(
            output_ref[i * output_channels() + c],
            output[i * output_stride() + c],
            1.0e-4 * std::abs(output_ref[i * output_channels() + c]))
            << "batch index = " << i << ", channel = " << c;
        }
      }
    }
  }

 private:
  size_t input_channels_{1};
  size_t input_stride_{0};
  size_t output_channels_{1};
  size_t output_stride_{0};
  size_t batch_size_{1};
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  bool transpose_weights_{false};
  bool has_bias_{true};
  size_t iterations_{1};
};
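
// A minimal usage sketch (hypothetical test case, not taken from the actual
// test suite):
//
//   TEST(FULLY_CONNECTED_NC_F32, batch_with_transposed_weights) {
//     FullyConnectedOperatorTester()
//       .batch_size(12)
//       .input_channels(23)
//       .output_channels(19)
//       .transpose_weights(true)
//       .iterations(3)
//       .TestF32();
//   }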