// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>


// Test harness for XNNPACK Fully Connected operators (Q8 quantized and F32
// variants). Configure a test case with the builder-style setters, then call
// TestQ8() or TestF32(): each creates the operator, runs it on random data,
// and compares the output against a naively computed reference result.
class FullyConnectedOperatorTester {
 public:
  // Number of input channels per batch element (the reduction dimension of
  // the matrix product). Must be at least 1.
  inline FullyConnectedOperatorTester& input_channels(size_t input_channels) {
    assert(input_channels >= 1);
    this->input_channels_ = input_channels;
    return *this;
  }

  inline size_t input_channels() const {
    return this->input_channels_;
  }

  // Number of output channels per batch element. Must be at least 1.
  inline FullyConnectedOperatorTester& output_channels(size_t output_channels) {
    assert(output_channels >= 1);
    this->output_channels_ = output_channels;
    return *this;
  }

  inline size_t output_channels() const {
    return this->output_channels_;
  }

  // Number of batch elements (rows of the input/output matrices). Must be at
  // least 1.
  inline FullyConnectedOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size >= 1);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

  // Distance, in elements, between consecutive batch rows of the input.
  // When set explicitly it must be >= input_channels().
  inline FullyConnectedOperatorTester& input_stride(size_t input_stride) {
    assert(input_stride >= 1);
    this->input_stride_ = input_stride;
    return *this;
  }

  inline size_t input_stride() const {
    // 0 means "unset": default to a dense layout where stride == channels.
    if (this->input_stride_ == 0) {
      return input_channels();
    } else {
      assert(this->input_stride_ >= input_channels());
      return this->input_stride_;
    }
  }

  // Distance, in elements, between consecutive batch rows of the output.
  // When set explicitly it must be >= output_channels().
  inline FullyConnectedOperatorTester& output_stride(size_t output_stride) {
    assert(output_stride >= 1);
    this->output_stride_ = output_stride;
    return *this;
  }

  inline size_t output_stride() const {
    // 0 means "unset": default to a dense layout where stride == channels.
    if (this->output_stride_ == 0) {
      return output_channels();
    } else {
      assert(this->output_stride_ >= output_channels());
      return this->output_stride_;
    }
  }

  // Lower bound for output clamping (quantized value in [0, 255]).
  inline FullyConnectedOperatorTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  // Upper bound for output clamping (quantized value in [0, 255]).
  inline FullyConnectedOperatorTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  // When true, the kernel is laid out as [input_channels, output_channels]
  // and the operator is created with XNN_FLAG_TRANSPOSE_WEIGHTS; otherwise
  // the layout is [output_channels, input_channels].
  inline FullyConnectedOperatorTester& transpose_weights(bool transpose_weights) {
    this->transpose_weights_ = transpose_weights;
    return *this;
  }

  inline bool transpose_weights() const {
    return this->transpose_weights_;
  }

  // When false, a null bias pointer is passed to the operator and the
  // reference accumulators start from zero.
  inline FullyConnectedOperatorTester& has_bias(bool has_bias) {
    this->has_bias_ = has_bias;
    return *this;
  }

  inline bool has_bias() const {
    return this->has_bias_;
  }

  // Number of independent test iterations, each with fresh random data.
  inline FullyConnectedOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

  // Tests the quantized (Q8) fully connected operator against a double-
  // precision reference computed with the same zero points.
  void TestQ8() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto s32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);

    // XNN_EXTRA_BYTES of padding lets the operator over-read the input tail.
    std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<uint8_t> kernel(output_channels() * input_channels());
    std::vector<int32_t> bias(output_channels());
    std::vector<uint8_t> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<int32_t> accumulators(batch_size() * output_channels());
    std::vector<double> output_ref(batch_size() * output_channels());

    // Zero points at the middle of the uint8 range.
    const uint8_t input_zero_point = 127;
    const uint8_t kernel_zero_point = 127;

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(u8rng));
      std::generate(kernel.begin(), kernel.end(), std::ref(u8rng));
      std::generate(bias.begin(), bias.end(), std::ref(s32rng));
      // Canary value: makes stale (never-written) output elements detectable.
      std::fill(output.begin(), output.end(), 0xA5);

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            accumulators[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(accumulators.begin(), accumulators.end(), 0);
      }
      // Accumulate the zero-point-adjusted products; kernel indexing depends
      // on whether the weights matrix is transposed.
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                (int32_t(kernel[ic * output_channels() + oc]) - int32_t(kernel_zero_point));
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
            }
          }
        }
      }

      // Compute renormalization parameters.
      const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
      const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());

      // Scale maps the observed accumulator range onto the 255-step uint8
      // range; the zero point centers that range, clamped to [0, 255].
      const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
      const uint8_t output_zero_point = uint8_t(std::max(std::min(
        lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
        long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));

      // Renormalize reference results.
      // Reference values are kept zero-point-relative (hence the comparison
      // below subtracts output_zero_point from the operator's output).
      std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
        [this, output_scale, output_zero_point](int32_t x) -> double {
          return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
        });

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      ASSERT_EQ(xnn_status_success,
        xnn_create_fully_connected_nc_q8(
          input_channels(), output_channels(),
          input_stride(), output_stride(),
          input_zero_point, 1.0f /* input scale */,
          kernel_zero_point, 1.0f /* kernel scale */,
          kernel.data(), has_bias() ? bias.data() : nullptr,
          output_zero_point, output_scale, qmin(), qmax(),
          transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
          &fully_connected_op));

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_q8(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < output_channels(); c++) {
          ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax()))
            << "batch index = " << i << ", channel = " << c;
          ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin()))
            << "batch index = " << i << ", channel = " << c;
          // 0.9 tolerance: allows for less than one quantization step of
          // rounding difference between operator and reference.
          ASSERT_NEAR(
            output_ref[i * output_channels() + c],
            double(output[i * output_stride() + c]) - double(output_zero_point),
            0.9)
            << "batch index = " << i << ", channel = " << c;
        }
      }
    }
  }

  // Tests the single-precision (F32) fully connected operator against a
  // float reference with a relative error tolerance.
  void TestF32() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(0.1f, 1.0f), rng);

    // XNN_EXTRA_BYTES of padding lets the operator over-read the input tail.
    std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<float> kernel(output_channels() * input_channels());
    std::vector<float> bias(output_channels());
    std::vector<float> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<float> output_ref(batch_size() * output_channels());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(f32rng));
      std::generate(kernel.begin(), kernel.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      // Canary value: NaN makes stale (never-written) output elements fail.
      std::fill(output.begin(), output.end(), nanf(""));

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            output_ref[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(output_ref.begin(), output_ref.end(), 0.0f);
      }
      // Accumulate the products; kernel indexing depends on whether the
      // weights matrix is transposed.
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                input[i * input_stride() + ic] * kernel[ic * output_channels() + oc];
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
            }
          }
        }
      }

      // Compute clamping parameters.
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());

      // Derive clamping bounds from the observed range so that qmin/qmax cut
      // off a proportional slice of it; the degenerate settings (qmin == 0,
      // qmax == 255) disable clamping on that side.
      const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() :
        accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() :
        accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Clamp reference results.
      for (float& value : output_ref) {
        value = std::max(std::min(value, output_max), output_min);
      }

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      ASSERT_EQ(xnn_status_success,
        xnn_create_fully_connected_nc_f32(
          input_channels(), output_channels(),
          input_stride(), output_stride(),
          kernel.data(), has_bias() ? bias.data() : nullptr,
          output_min, output_max,
          transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
          &fully_connected_op));

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_f32(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < output_channels(); c++) {
          ASSERT_LE(output[i * output_stride() + c], output_max)
            << "batch index = " << i << ", channel = " << c;
          ASSERT_GE(output[i * output_stride() + c], output_min)
            << "batch index = " << i << ", channel = " << c;
          // Relative tolerance of 1e-4 of the reference magnitude.
          ASSERT_NEAR(
            output_ref[i * output_channels() + c],
            output[i * output_stride() + c],
            1.0e-4 * std::abs(output_ref[i * output_channels() + c]))
            << "batch index = " << i << ", channel = " << c;
        }
      }
    }
  }

 private:
  size_t input_channels_{1};
  size_t input_stride_{0};   // 0 = unset; accessor falls back to input_channels()
  size_t output_channels_{1};
  size_t output_stride_{0};  // 0 = unset; accessor falls back to output_channels()
  size_t batch_size_{1};
  uint8_t qmin_{0};          // 0 = no lower clamp (F32) / full range (Q8)
  uint8_t qmax_{255};        // 255 = no upper clamp (F32) / full range (Q8)
  bool transpose_weights_{false};
  bool has_bias_{true};
  size_t iterations_{1};
};