// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
23
24class FullyConnectedOperatorTester {
25 public:
26 inline FullyConnectedOperatorTester& input_channels(size_t input_channels) {
27 assert(input_channels >= 1);
28 this->input_channels_ = input_channels;
29 return *this;
30 }
31
32 inline size_t input_channels() const {
33 return this->input_channels_;
34 }
35
36 inline FullyConnectedOperatorTester& output_channels(size_t output_channels) {
37 assert(output_channels >= 1);
38 this->output_channels_ = output_channels;
39 return *this;
40 }
41
42 inline size_t output_channels() const {
43 return this->output_channels_;
44 }
45
46 inline FullyConnectedOperatorTester& batch_size(size_t batch_size) {
47 assert(batch_size >= 1);
48 this->batch_size_ = batch_size;
49 return *this;
50 }
51
52 inline size_t batch_size() const {
53 return this->batch_size_;
54 }
55
56 inline FullyConnectedOperatorTester& input_stride(size_t input_stride) {
57 assert(input_stride >= 1);
58 this->input_stride_ = input_stride;
59 return *this;
60 }
61
62 inline size_t input_stride() const {
63 if (this->input_stride_ == 0) {
64 return input_channels();
65 } else {
66 assert(this->input_stride_ >= input_channels());
67 return this->input_stride_;
68 }
69 }
70
71 inline FullyConnectedOperatorTester& output_stride(size_t output_stride) {
72 assert(output_stride >= 1);
73 this->output_stride_ = output_stride;
74 return *this;
75 }
76
77 inline size_t output_stride() const {
78 if (this->output_stride_ == 0) {
79 return output_channels();
80 } else {
81 assert(this->output_stride_ >= output_channels());
82 return this->output_stride_;
83 }
84 }
85
86 inline FullyConnectedOperatorTester& qmin(uint8_t qmin) {
87 this->qmin_ = qmin;
88 return *this;
89 }
90
91 inline uint8_t qmin() const {
92 return this->qmin_;
93 }
94
95 inline FullyConnectedOperatorTester& qmax(uint8_t qmax) {
96 this->qmax_ = qmax;
97 return *this;
98 }
99
100 inline uint8_t qmax() const {
101 return this->qmax_;
102 }
103
Marat Dukhanc4f0ff92019-12-03 14:59:08 -0800104 inline FullyConnectedOperatorTester& transpose_weights(bool transpose_weights) {
105 this->transpose_weights_ = transpose_weights;
106 return *this;
107 }
108
109 inline bool transpose_weights() const {
110 return this->transpose_weights_;
111 }
112
Marat Dukhanf568f082019-10-30 09:47:07 -0700113 inline FullyConnectedOperatorTester& has_bias(bool has_bias) {
114 this->has_bias_ = has_bias;
115 return *this;
116 }
117
118 inline bool has_bias() const {
119 return this->has_bias_;
120 }
121
XNNPACK Teamb455b122019-09-27 18:10:33 -0700122 inline FullyConnectedOperatorTester& iterations(size_t iterations) {
123 this->iterations_ = iterations;
124 return *this;
125 }
126
127 inline size_t iterations() const {
128 return this->iterations_;
129 }
130
131 void TestQ8() const {
132 std::random_device random_device;
133 auto rng = std::mt19937(random_device());
134 auto s32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
135 auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);
136
137 std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
138 (batch_size() - 1) * input_stride() + input_channels());
139 std::vector<uint8_t> kernel(output_channels() * input_channels());
140 std::vector<int32_t> bias(output_channels());
141 std::vector<uint8_t> output((batch_size() - 1) * output_stride() + output_channels());
142 std::vector<int32_t> accumulators(batch_size() * output_channels());
143 std::vector<double> output_ref(batch_size() * output_channels());
144
145 const uint8_t input_zero_point = 127;
146 const uint8_t kernel_zero_point = 127;
147
148 for (size_t iteration = 0; iteration < iterations(); iteration++) {
149 std::generate(input.begin(), input.end(), std::ref(u8rng));
150 std::generate(kernel.begin(), kernel.end(), std::ref(u8rng));
151 std::generate(bias.begin(), bias.end(), std::ref(s32rng));
152 std::fill(output.begin(), output.end(), 0xA5);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700153
154 // Compute reference results, without renormalization.
Marat Dukhanf568f082019-10-30 09:47:07 -0700155 if (has_bias()) {
156 for (size_t i = 0; i < batch_size(); i++) {
157 for (size_t oc = 0; oc < output_channels(); oc++) {
158 accumulators[i * output_channels() + oc] = bias[oc];
159 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700160 }
Marat Dukhanf568f082019-10-30 09:47:07 -0700161 } else {
162 std::fill(accumulators.begin(), accumulators.end(), 0);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700163 }
Marat Dukhanc4f0ff92019-12-03 14:59:08 -0800164 if (transpose_weights()) {
165 for (size_t i = 0; i < batch_size(); i++) {
166 for (size_t oc = 0; oc < output_channels(); oc++) {
167 for (size_t ic = 0; ic < input_channels(); ic++) {
168 accumulators[i * output_channels() + oc] +=
169 (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
170 (int32_t(kernel[ic * output_channels() + oc]) - int32_t(kernel_zero_point));
171 }
172 }
173 }
174 } else {
175 for (size_t i = 0; i < batch_size(); i++) {
176 for (size_t oc = 0; oc < output_channels(); oc++) {
177 for (size_t ic = 0; ic < input_channels(); ic++) {
178 accumulators[i * output_channels() + oc] +=
179 (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
180 (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
181 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700182 }
183 }
184 }
185
186 // Compute renormalization parameters.
187 const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
188 const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
189
190 const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
191 const uint8_t output_zero_point = uint8_t(std::max(std::min(
192 lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
193 long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
194
195 // Renormalize reference results.
196 std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
197 [this, output_scale, output_zero_point](int32_t x) -> double {
198 return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
199 });
200
201 // Create, setup, run, and destroy Fully Connected operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800202 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700203 xnn_operator_t fully_connected_op = nullptr;
204
205 ASSERT_EQ(xnn_status_success,
206 xnn_create_fully_connected_nc_q8(
207 input_channels(), output_channels(),
208 input_stride(), output_stride(),
209 input_zero_point, 1.0f /* input scale */,
210 kernel_zero_point, 1.0f /* kernel scale */,
Marat Dukhanf568f082019-10-30 09:47:07 -0700211 kernel.data(), has_bias() ? bias.data() : nullptr,
XNNPACK Teamb455b122019-09-27 18:10:33 -0700212 output_zero_point, output_scale, qmin(), qmax(),
Marat Dukhanc4f0ff92019-12-03 14:59:08 -0800213 transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
214 &fully_connected_op));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700215
216 // Smart pointer to automatically delete fully_connected_op.
217 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
218
219 ASSERT_EQ(xnn_status_success,
220 xnn_setup_fully_connected_nc_q8(
221 fully_connected_op,
222 batch_size(),
223 input.data(), output.data(),
224 nullptr /* thread pool */));
225
226 ASSERT_EQ(xnn_status_success,
227 xnn_run_operator(fully_connected_op, nullptr /* thread pool */));
228
229 // Verify results.
230 for (size_t i = 0; i < batch_size(); i++) {
231 for (size_t c = 0; c < output_channels(); c++) {
232 ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax()))
233 << "batch index = " << i << ", channel = " << c;
234 ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin()))
235 << "batch index = " << i << ", channel = " << c;
236 ASSERT_NEAR(
237 output_ref[i * output_channels() + c],
238 double(output[i * output_stride() + c]) - double(output_zero_point),
239 0.9)
240 << "batch index = " << i << ", channel = " << c;
241 }
242 }
243 }
244 }
245
246 void TestF32() const {
247 std::random_device random_device;
248 auto rng = std::mt19937(random_device());
249 auto f32rng = std::bind(std::uniform_real_distribution<float>(0.1f, 1.0f), rng);
250
251 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
252 (batch_size() - 1) * input_stride() + input_channels());
253 std::vector<float> kernel(output_channels() * input_channels());
254 std::vector<float> bias(output_channels());
255 std::vector<float> output((batch_size() - 1) * output_stride() + output_channels());
256 std::vector<float> output_ref(batch_size() * output_channels());
257
258 for (size_t iteration = 0; iteration < iterations(); iteration++) {
259 std::generate(input.begin(), input.end(), std::ref(f32rng));
260 std::generate(kernel.begin(), kernel.end(), std::ref(f32rng));
261 std::generate(bias.begin(), bias.end(), std::ref(f32rng));
262 std::fill(output.begin(), output.end(), nanf(""));
263
264 // Compute reference results, without renormalization.
Marat Dukhanf568f082019-10-30 09:47:07 -0700265 if (has_bias()) {
266 for (size_t i = 0; i < batch_size(); i++) {
267 for (size_t oc = 0; oc < output_channels(); oc++) {
268 output_ref[i * output_channels() + oc] = bias[oc];
269 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700270 }
Marat Dukhanf568f082019-10-30 09:47:07 -0700271 } else {
272 std::fill(output_ref.begin(), output_ref.end(), 0.0f);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700273 }
Marat Dukhanc4f0ff92019-12-03 14:59:08 -0800274 if (transpose_weights()) {
275 for (size_t i = 0; i < batch_size(); i++) {
276 for (size_t oc = 0; oc < output_channels(); oc++) {
277 for (size_t ic = 0; ic < input_channels(); ic++) {
278 output_ref[i * output_channels() + oc] +=
279 input[i * input_stride() + ic] * kernel[ic * output_channels() + oc];
280 }
281 }
282 }
283 } else {
284 for (size_t i = 0; i < batch_size(); i++) {
285 for (size_t oc = 0; oc < output_channels(); oc++) {
286 for (size_t ic = 0; ic < input_channels(); ic++) {
287 output_ref[i * output_channels() + oc] +=
288 input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
289 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700290 }
291 }
292 }
293
294 // Compute clamping parameters.
Marat Dukhanc6edf922019-10-03 15:08:04 -0700295 const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
296 const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
XNNPACK Teamb455b122019-09-27 18:10:33 -0700297
298 const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
299 const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
300
301 // Clamp reference results.
302 for (float& value : output_ref) {
303 value = std::max(std::min(value, output_max), output_min);
304 }
305
306 // Create, setup, run, and destroy Fully Connected operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800307 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700308 xnn_operator_t fully_connected_op = nullptr;
309
310 ASSERT_EQ(xnn_status_success,
311 xnn_create_fully_connected_nc_f32(
312 input_channels(), output_channels(),
313 input_stride(), output_stride(),
Marat Dukhanf568f082019-10-30 09:47:07 -0700314 kernel.data(), has_bias() ? bias.data() : nullptr,
XNNPACK Teamb455b122019-09-27 18:10:33 -0700315 output_min, output_max,
Marat Dukhanc4f0ff92019-12-03 14:59:08 -0800316 transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
317 &fully_connected_op));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700318
319 // Smart pointer to automatically delete fully_connected_op.
320 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
321
322 ASSERT_EQ(xnn_status_success,
323 xnn_setup_fully_connected_nc_f32(
324 fully_connected_op,
325 batch_size(),
326 input.data(), output.data(),
327 nullptr /* thread pool */));
328
329 ASSERT_EQ(xnn_status_success,
330 xnn_run_operator(fully_connected_op, nullptr /* thread pool */));
331
332 // Verify results.
333 for (size_t i = 0; i < batch_size(); i++) {
334 for (size_t c = 0; c < output_channels(); c++) {
335 ASSERT_LE(output[i * output_stride() + c], output_max)
336 << "batch index = " << i << ", channel = " << c;
337 ASSERT_GE(output[i * output_stride() + c], output_min)
338 << "batch index = " << i << ", channel = " << c;
339 ASSERT_NEAR(
340 output_ref[i * output_channels() + c],
341 output[i * output_stride() + c],
342 1.0e-4 * std::abs(output_ref[i * output_channels() + c]))
343 << "batch index = " << i << ", channel = " << c;
344 }
345 }
346 }
347 }
348
349 private:
350 size_t input_channels_{1};
351 size_t input_stride_{0};
352 size_t output_channels_{1};
353 size_t output_stride_{0};
354 size_t batch_size_{1};
355 uint8_t qmin_{0};
356 uint8_t qmax_{255};
Marat Dukhanc4f0ff92019-12-03 14:59:08 -0800357 bool transpose_weights_{false};
Marat Dukhanf568f082019-10-30 09:47:07 -0700358 bool has_bias_{true};
XNNPACK Teamb455b122019-09-27 18:10:33 -0700359 size_t iterations_{1};
360};