// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>

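// Gtest helper that exercises the XNNPACK fully-connected operators (QS8,
// QU8, F32, and F16). Parameters are configured through chained setters, and
// one of the Test*() methods then creates, runs, and verifies the operator
// against a naive reference implementation. Typical usage:
//
//   FullyConnectedOperatorTester()
//     .batch_size(4)
//     .input_channels(37)
//     .output_channels(19)
//     .iterations(3)
//     .TestF32();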
class FullyConnectedOperatorTester {
 public:
  enum class WeightsType {
    Default,
    FP32,
  };

  inline FullyConnectedOperatorTester& input_channels(size_t input_channels) {
    assert(input_channels >= 1);
    this->input_channels_ = input_channels;
    return *this;
  }

  inline size_t input_channels() const {
    return this->input_channels_;
  }

  inline FullyConnectedOperatorTester& output_channels(size_t output_channels) {
    assert(output_channels >= 1);
    this->output_channels_ = output_channels;
    return *this;
  }

  inline size_t output_channels() const {
    return this->output_channels_;
  }

  inline FullyConnectedOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size >= 1);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

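  // An input/output stride of 0 (the default) means "tightly packed": the
  // accessors below then fall back to the corresponding channel count.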
  inline FullyConnectedOperatorTester& input_stride(size_t input_stride) {
    assert(input_stride >= 1);
    this->input_stride_ = input_stride;
    return *this;
  }

  inline size_t input_stride() const {
    if (this->input_stride_ == 0) {
      return input_channels();
    } else {
      assert(this->input_stride_ >= input_channels());
      return this->input_stride_;
    }
  }

  inline FullyConnectedOperatorTester& output_stride(size_t output_stride) {
    assert(output_stride >= 1);
    this->output_stride_ = output_stride;
    return *this;
  }

  inline size_t output_stride() const {
    if (this->output_stride_ == 0) {
      return output_channels();
    } else {
      assert(this->output_stride_ >= output_channels());
      return this->output_stride_;
    }
  }

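  // qmin/qmax are always specified in the unsigned [0, 255] range; the QS8
  // test shifts them by 0x80 into the signed [-128, 127] range, and the
  // float tests reinterpret them as fractions of the observed output range.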
  inline FullyConnectedOperatorTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline FullyConnectedOperatorTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline FullyConnectedOperatorTester& transpose_weights(bool transpose_weights) {
    this->transpose_weights_ = transpose_weights;
    return *this;
  }

  inline bool transpose_weights() const {
    return this->transpose_weights_;
  }

  inline FullyConnectedOperatorTester& has_bias(bool has_bias) {
    this->has_bias_ = has_bias;
    return *this;
  }

  inline bool has_bias() const {
    return this->has_bias_;
  }

  inline FullyConnectedOperatorTester& weights_type(WeightsType weights_type) {
    this->weights_type_ = weights_type;
    return *this;
  }

  inline WeightsType weights_type() const {
    return this->weights_type_;
  }

  inline FullyConnectedOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

  void TestQS8() const {
    ASSERT_EQ(weights_type(), WeightsType::Default);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
    auto i8rng = std::bind(std::uniform_int_distribution<int32_t>(
      std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()), std::ref(rng));
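    // Weights are drawn from [-127, 127], excluding -128, presumably to keep
    // the weight distribution symmetric around zero.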
    auto w8rng = std::bind(std::uniform_int_distribution<int32_t>(
      -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()), std::ref(rng));

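    // The input buffer is over-allocated by XNN_EXTRA_BYTES: XNNPACK kernels
    // may read (but never use) a few bytes past the last element.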
    std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<int8_t> kernel(output_channels() * input_channels());
    std::vector<int32_t> bias(output_channels());
    std::vector<int8_t> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<int32_t> accumulators(batch_size() * output_channels());
    std::vector<double> output_ref(batch_size() * output_channels());

    const int8_t input_zero_point = 127;

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(i8rng));
      std::generate(kernel.begin(), kernel.end(), std::ref(w8rng));
      std::generate(bias.begin(), bias.end(), std::ref(i32rng));
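      // Fill the output with a canary value so that stale or unwritten
      // elements are caught by the checks below.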
      std::fill(output.begin(), output.end(), 0xA5);

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            accumulators[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(accumulators.begin(), accumulators.end(), 0);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                int32_t(kernel[ic * output_channels() + oc]);
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                int32_t(kernel[oc * input_channels() + ic]);
            }
          }
        }
      }

      // Compute renormalization parameters.
      const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
      const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());

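      // Map the observed accumulator range onto the 255 representable
      // quantized steps, and pick a zero point (clamped to the int8 range)
      // so that the midpoint of that range dequantizes to roughly zero.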
      const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
      const int8_t output_zero_point = int8_t(std::max(std::min(
        lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
        long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));

      // Renormalize reference results.
      std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
        [this, output_scale, output_zero_point](int32_t x) -> double {
          return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point);
        });

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      const xnn_status status = xnn_create_fully_connected_nc_qs8(
        input_channels(), output_channels(),
        input_stride(), output_stride(),
        input_zero_point, 1.0f /* input scale */,
        1.0f /* kernel scale */,
        kernel.data(), has_bias() ? bias.data() : nullptr,
        output_zero_point, output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
        transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
        &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_qs8(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < output_channels(); c++) {
          ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax() - 0x80))
            << "batch index = " << i << ", channel = " << c;
          ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin() - 0x80))
            << "batch index = " << i << ", channel = " << c;
          ASSERT_NEAR(
            output_ref[i * output_channels() + c],
            double(output[i * output_stride() + c]) - double(output_zero_point),
            0.9)
            << "batch index = " << i << ", channel = " << c;
        }
      }
    }
  }

  void TestQU8() const {
    ASSERT_EQ(weights_type(), WeightsType::Default);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
    auto u8rng = std::bind(
      std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), std::ref(rng));

    std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<uint8_t> kernel(output_channels() * input_channels());
    std::vector<int32_t> bias(output_channels());
    std::vector<uint8_t> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<int32_t> accumulators(batch_size() * output_channels());
    std::vector<double> output_ref(batch_size() * output_channels());

    const uint8_t input_zero_point = 127;
    const uint8_t kernel_zero_point = 127;

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(u8rng));
      std::generate(kernel.begin(), kernel.end(), std::ref(u8rng));
      std::generate(bias.begin(), bias.end(), std::ref(i32rng));
      std::fill(output.begin(), output.end(), 0xA5);

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            accumulators[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(accumulators.begin(), accumulators.end(), 0);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                (int32_t(kernel[ic * output_channels() + oc]) - int32_t(kernel_zero_point));
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
            }
          }
        }
      }

      // Compute renormalization parameters.
      const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
      const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());

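      // As in the QS8 case, but the zero point is computed around the
      // unsigned midpoint 127.5 and clamped to the uint8 range.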
      const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
      const uint8_t output_zero_point = uint8_t(std::max(std::min(
        lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
        long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));

      // Renormalize reference results.
      std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
        [this, output_scale, output_zero_point](int32_t x) -> double {
          return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
        });

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      const xnn_status status = xnn_create_fully_connected_nc_qu8(
        input_channels(), output_channels(),
        input_stride(), output_stride(),
        input_zero_point, 1.0f /* input scale */,
        kernel_zero_point, 1.0f /* kernel scale */,
        kernel.data(), has_bias() ? bias.data() : nullptr,
        output_zero_point, output_scale, qmin(), qmax(),
        transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
        &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_qu8(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < output_channels(); c++) {
          ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax()))
            << "batch index = " << i << ", channel = " << c;
          ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin()))
            << "batch index = " << i << ", channel = " << c;
          ASSERT_NEAR(
            output_ref[i * output_channels() + c],
            double(output[i * output_stride() + c]) - double(output_zero_point),
            0.9)
            << "batch index = " << i << ", channel = " << c;
        }
      }
    }
  }

  void TestF32() const {
    ASSERT_EQ(weights_type(), WeightsType::Default);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(0.1f, 1.0f), std::ref(rng));

    std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<float> kernel(output_channels() * input_channels());
    std::vector<float> bias(output_channels());
    std::vector<float> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<float> output_ref(batch_size() * output_channels());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(f32rng));
      std::generate(kernel.begin(), kernel.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(output.begin(), output.end(), nanf(""));

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            output_ref[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(output_ref.begin(), output_ref.end(), 0.0f);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                input[i * input_stride() + ic] * kernel[ic * output_channels() + oc];
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
            }
          }
        }
      }

      // Compute clamping parameters.
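      // qmin()/qmax() are reinterpreted as 255ths of the observed output
      // range: qmin() == 0 and qmax() == 255 disable the lower and upper
      // bound, respectively.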
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());

      const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() :
        accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() :
        accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Clamp reference results.
      for (float& value : output_ref) {
        value = std::max(std::min(value, output_max), output_min);
      }

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      const xnn_status status = xnn_create_fully_connected_nc_f32(
        input_channels(), output_channels(),
        input_stride(), output_stride(),
        kernel.data(), has_bias() ? bias.data() : nullptr,
        output_min, output_max,
        transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
        &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_f32(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < output_channels(); c++) {
          ASSERT_LE(output[i * output_stride() + c], output_max)
            << "batch index = " << i << ", channel = " << c;
          ASSERT_GE(output[i * output_stride() + c], output_min)
            << "batch index = " << i << ", channel = " << c;
          ASSERT_NEAR(
            output_ref[i * output_channels() + c],
            output[i * output_stride() + c],
            1.0e-4 * std::abs(output_ref[i * output_channels() + c]))
            << "batch index = " << i << ", channel = " << c;
        }
      }
    }
  }

  void TestF16() const {
    switch (weights_type()) {
      case WeightsType::Default:
        break;
      case WeightsType::FP32:
        break;
      default:
        GTEST_FAIL() << "unexpected weights type";
    }

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(0.1f, 1.0f), std::ref(rng));
    auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);

    std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<uint16_t> kernel(output_channels() * input_channels());
    std::vector<float> kernel_as_float(kernel.size());
    std::vector<uint16_t> bias(output_channels());
    std::vector<float> bias_as_float(bias.size());
    std::vector<uint16_t> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<float> output_ref(batch_size() * output_channels());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(f16rng));
      std::generate(kernel.begin(), kernel.end(), std::ref(f16rng));
      std::transform(kernel.cbegin(), kernel.cend(), kernel_as_float.begin(), fp16_ieee_to_fp32_value);
      std::generate(bias.begin(), bias.end(), std::ref(f16rng));
      std::transform(bias.cbegin(), bias.cend(), bias_as_float.begin(), fp16_ieee_to_fp32_value);
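      // 0x7C00 is +infinity in IEEE half precision; use it as a canary so
      // that unwritten outputs fail the range checks below.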
      std::fill(output.begin(), output.end(), UINT16_C(0x7C00));

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            output_ref[i * output_channels() + oc] = fp16_ieee_to_fp32_value(bias[oc]);
          }
        }
      } else {
        std::fill(output_ref.begin(), output_ref.end(), 0.0f);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                fp16_ieee_to_fp32_value(input[i * input_stride() + ic]) * fp16_ieee_to_fp32_value(kernel[ic * output_channels() + oc]);
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                fp16_ieee_to_fp32_value(input[i * input_stride() + ic]) * fp16_ieee_to_fp32_value(kernel[oc * input_channels() + ic]);
            }
          }
        }
      }

      // Compute clamping parameters.
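      // The bounds are round-tripped through fp16 so that they are exactly
      // representable; if both collapse to the same value, clamping is
      // disabled instead of producing an empty range.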
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_range = accumulated_max - accumulated_min;
      const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
      const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
      const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min;
      const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max;

      // Clamp reference results.
      for (float& value : output_ref) {
        value = std::max(std::min(value, output_max), output_min);
      }

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

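      // With WeightsType::FP32 the kernel and bias are passed as fp32 values
      // and XNN_FLAG_FP32_STATIC_WEIGHTS asks XNNPACK to convert them
      // internally; otherwise the pre-converted fp16 buffers are used as-is.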
      const void* kernel_data = kernel.data();
      const void* bias_data = bias.data();
      if (weights_type() == WeightsType::FP32) {
        kernel_data = kernel_as_float.data();
        bias_data = bias_as_float.data();
      }
      uint32_t flags = 0;
      if (transpose_weights()) {
        flags |= XNN_FLAG_TRANSPOSE_WEIGHTS;
      }
      if (weights_type() == WeightsType::FP32) {
        flags |= XNN_FLAG_FP32_STATIC_WEIGHTS;
      }
      const xnn_status status = xnn_create_fully_connected_nc_f16(
        input_channels(), output_channels(),
        input_stride(), output_stride(),
        kernel_data, has_bias() ? bias_data : nullptr,
        output_min, output_max,
        flags,
        &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_f16(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < output_channels(); c++) {
          ASSERT_LE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_max)
            << "batch index = " << i << ", channel = " << c;
          ASSERT_GE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_min)
            << "batch index = " << i << ", channel = " << c;
          ASSERT_NEAR(
            output_ref[i * output_channels() + c],
            fp16_ieee_to_fp32_value(output[i * output_stride() + c]),
            1.0e-2f * std::abs(output_ref[i * output_channels() + c]))
            << "batch index = " << i << ", channel = " << c;
        }
      }
    }
  }

 private:
  size_t input_channels_{1};
  size_t input_stride_{0};
  size_t output_channels_{1};
  size_t output_stride_{0};
  size_t batch_size_{1};
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  bool transpose_weights_{false};
  bool has_bias_{true};
  WeightsType weights_type_{WeightsType::Default};
  size_t iterations_{1};
};