blob: 81370f9c63393f6d4a39ffe01441675b43c93064 [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
9#pragma once
10
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
23
24class FullyConnectedOperatorTester {
25 public:
26 inline FullyConnectedOperatorTester& input_channels(size_t input_channels) {
27 assert(input_channels >= 1);
28 this->input_channels_ = input_channels;
29 return *this;
30 }
31
32 inline size_t input_channels() const {
33 return this->input_channels_;
34 }
35
36 inline FullyConnectedOperatorTester& output_channels(size_t output_channels) {
37 assert(output_channels >= 1);
38 this->output_channels_ = output_channels;
39 return *this;
40 }
41
42 inline size_t output_channels() const {
43 return this->output_channels_;
44 }
45
46 inline FullyConnectedOperatorTester& batch_size(size_t batch_size) {
47 assert(batch_size >= 1);
48 this->batch_size_ = batch_size;
49 return *this;
50 }
51
52 inline size_t batch_size() const {
53 return this->batch_size_;
54 }
55
56 inline FullyConnectedOperatorTester& input_stride(size_t input_stride) {
57 assert(input_stride >= 1);
58 this->input_stride_ = input_stride;
59 return *this;
60 }
61
62 inline size_t input_stride() const {
63 if (this->input_stride_ == 0) {
64 return input_channels();
65 } else {
66 assert(this->input_stride_ >= input_channels());
67 return this->input_stride_;
68 }
69 }
70
71 inline FullyConnectedOperatorTester& output_stride(size_t output_stride) {
72 assert(output_stride >= 1);
73 this->output_stride_ = output_stride;
74 return *this;
75 }
76
77 inline size_t output_stride() const {
78 if (this->output_stride_ == 0) {
79 return output_channels();
80 } else {
81 assert(this->output_stride_ >= output_channels());
82 return this->output_stride_;
83 }
84 }
85
86 inline FullyConnectedOperatorTester& qmin(uint8_t qmin) {
87 this->qmin_ = qmin;
88 return *this;
89 }
90
91 inline uint8_t qmin() const {
92 return this->qmin_;
93 }
94
95 inline FullyConnectedOperatorTester& qmax(uint8_t qmax) {
96 this->qmax_ = qmax;
97 return *this;
98 }
99
100 inline uint8_t qmax() const {
101 return this->qmax_;
102 }
103
Marat Dukhanf568f082019-10-30 09:47:07 -0700104 inline FullyConnectedOperatorTester& has_bias(bool has_bias) {
105 this->has_bias_ = has_bias;
106 return *this;
107 }
108
109 inline bool has_bias() const {
110 return this->has_bias_;
111 }
112
XNNPACK Teamb455b122019-09-27 18:10:33 -0700113 inline FullyConnectedOperatorTester& iterations(size_t iterations) {
114 this->iterations_ = iterations;
115 return *this;
116 }
117
118 inline size_t iterations() const {
119 return this->iterations_;
120 }
121
122 void TestQ8() const {
123 std::random_device random_device;
124 auto rng = std::mt19937(random_device());
125 auto s32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
126 auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);
127
128 std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
129 (batch_size() - 1) * input_stride() + input_channels());
130 std::vector<uint8_t> kernel(output_channels() * input_channels());
131 std::vector<int32_t> bias(output_channels());
132 std::vector<uint8_t> output((batch_size() - 1) * output_stride() + output_channels());
133 std::vector<int32_t> accumulators(batch_size() * output_channels());
134 std::vector<double> output_ref(batch_size() * output_channels());
135
136 const uint8_t input_zero_point = 127;
137 const uint8_t kernel_zero_point = 127;
138
139 for (size_t iteration = 0; iteration < iterations(); iteration++) {
140 std::generate(input.begin(), input.end(), std::ref(u8rng));
141 std::generate(kernel.begin(), kernel.end(), std::ref(u8rng));
142 std::generate(bias.begin(), bias.end(), std::ref(s32rng));
143 std::fill(output.begin(), output.end(), 0xA5);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700144
145 // Compute reference results, without renormalization.
Marat Dukhanf568f082019-10-30 09:47:07 -0700146 if (has_bias()) {
147 for (size_t i = 0; i < batch_size(); i++) {
148 for (size_t oc = 0; oc < output_channels(); oc++) {
149 accumulators[i * output_channels() + oc] = bias[oc];
150 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700151 }
Marat Dukhanf568f082019-10-30 09:47:07 -0700152 } else {
153 std::fill(accumulators.begin(), accumulators.end(), 0);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700154 }
155 for (size_t i = 0; i < batch_size(); i++) {
156 for (size_t oc = 0; oc < output_channels(); oc++) {
157 for (size_t ic = 0; ic < input_channels(); ic++) {
158 accumulators[i * output_channels() + oc] +=
159 (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
160 (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
161 }
162 }
163 }
164
165 // Compute renormalization parameters.
166 const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
167 const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
168
169 const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
170 const uint8_t output_zero_point = uint8_t(std::max(std::min(
171 lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
172 long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
173
174 // Renormalize reference results.
175 std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
176 [this, output_scale, output_zero_point](int32_t x) -> double {
177 return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
178 });
179
180 // Create, setup, run, and destroy Fully Connected operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800181 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700182 xnn_operator_t fully_connected_op = nullptr;
183
184 ASSERT_EQ(xnn_status_success,
185 xnn_create_fully_connected_nc_q8(
186 input_channels(), output_channels(),
187 input_stride(), output_stride(),
188 input_zero_point, 1.0f /* input scale */,
189 kernel_zero_point, 1.0f /* kernel scale */,
Marat Dukhanf568f082019-10-30 09:47:07 -0700190 kernel.data(), has_bias() ? bias.data() : nullptr,
XNNPACK Teamb455b122019-09-27 18:10:33 -0700191 output_zero_point, output_scale, qmin(), qmax(),
192 0, &fully_connected_op));
193
194 // Smart pointer to automatically delete fully_connected_op.
195 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
196
197 ASSERT_EQ(xnn_status_success,
198 xnn_setup_fully_connected_nc_q8(
199 fully_connected_op,
200 batch_size(),
201 input.data(), output.data(),
202 nullptr /* thread pool */));
203
204 ASSERT_EQ(xnn_status_success,
205 xnn_run_operator(fully_connected_op, nullptr /* thread pool */));
206
207 // Verify results.
208 for (size_t i = 0; i < batch_size(); i++) {
209 for (size_t c = 0; c < output_channels(); c++) {
210 ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax()))
211 << "batch index = " << i << ", channel = " << c;
212 ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin()))
213 << "batch index = " << i << ", channel = " << c;
214 ASSERT_NEAR(
215 output_ref[i * output_channels() + c],
216 double(output[i * output_stride() + c]) - double(output_zero_point),
217 0.9)
218 << "batch index = " << i << ", channel = " << c;
219 }
220 }
221 }
222 }
223
224 void TestF32() const {
225 std::random_device random_device;
226 auto rng = std::mt19937(random_device());
227 auto f32rng = std::bind(std::uniform_real_distribution<float>(0.1f, 1.0f), rng);
228
229 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
230 (batch_size() - 1) * input_stride() + input_channels());
231 std::vector<float> kernel(output_channels() * input_channels());
232 std::vector<float> bias(output_channels());
233 std::vector<float> output((batch_size() - 1) * output_stride() + output_channels());
234 std::vector<float> output_ref(batch_size() * output_channels());
235
236 for (size_t iteration = 0; iteration < iterations(); iteration++) {
237 std::generate(input.begin(), input.end(), std::ref(f32rng));
238 std::generate(kernel.begin(), kernel.end(), std::ref(f32rng));
239 std::generate(bias.begin(), bias.end(), std::ref(f32rng));
240 std::fill(output.begin(), output.end(), nanf(""));
241
242 // Compute reference results, without renormalization.
Marat Dukhanf568f082019-10-30 09:47:07 -0700243 if (has_bias()) {
244 for (size_t i = 0; i < batch_size(); i++) {
245 for (size_t oc = 0; oc < output_channels(); oc++) {
246 output_ref[i * output_channels() + oc] = bias[oc];
247 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700248 }
Marat Dukhanf568f082019-10-30 09:47:07 -0700249 } else {
250 std::fill(output_ref.begin(), output_ref.end(), 0.0f);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700251 }
252 for (size_t i = 0; i < batch_size(); i++) {
253 for (size_t oc = 0; oc < output_channels(); oc++) {
254 for (size_t ic = 0; ic < input_channels(); ic++) {
255 output_ref[i * output_channels() + oc] +=
256 input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
257 }
258 }
259 }
260
261 // Compute clamping parameters.
Marat Dukhanc6edf922019-10-03 15:08:04 -0700262 const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
263 const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
XNNPACK Teamb455b122019-09-27 18:10:33 -0700264
265 const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
266 const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
267
268 // Clamp reference results.
269 for (float& value : output_ref) {
270 value = std::max(std::min(value, output_max), output_min);
271 }
272
273 // Create, setup, run, and destroy Fully Connected operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800274 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700275 xnn_operator_t fully_connected_op = nullptr;
276
277 ASSERT_EQ(xnn_status_success,
278 xnn_create_fully_connected_nc_f32(
279 input_channels(), output_channels(),
280 input_stride(), output_stride(),
Marat Dukhanf568f082019-10-30 09:47:07 -0700281 kernel.data(), has_bias() ? bias.data() : nullptr,
XNNPACK Teamb455b122019-09-27 18:10:33 -0700282 output_min, output_max,
283 0, &fully_connected_op));
284
285 // Smart pointer to automatically delete fully_connected_op.
286 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
287
288 ASSERT_EQ(xnn_status_success,
289 xnn_setup_fully_connected_nc_f32(
290 fully_connected_op,
291 batch_size(),
292 input.data(), output.data(),
293 nullptr /* thread pool */));
294
295 ASSERT_EQ(xnn_status_success,
296 xnn_run_operator(fully_connected_op, nullptr /* thread pool */));
297
298 // Verify results.
299 for (size_t i = 0; i < batch_size(); i++) {
300 for (size_t c = 0; c < output_channels(); c++) {
301 ASSERT_LE(output[i * output_stride() + c], output_max)
302 << "batch index = " << i << ", channel = " << c;
303 ASSERT_GE(output[i * output_stride() + c], output_min)
304 << "batch index = " << i << ", channel = " << c;
305 ASSERT_NEAR(
306 output_ref[i * output_channels() + c],
307 output[i * output_stride() + c],
308 1.0e-4 * std::abs(output_ref[i * output_channels() + c]))
309 << "batch index = " << i << ", channel = " << c;
310 }
311 }
312 }
313 }
314
315 private:
316 size_t input_channels_{1};
317 size_t input_stride_{0};
318 size_t output_channels_{1};
319 size_t output_stride_{0};
320 size_t batch_size_{1};
321 uint8_t qmin_{0};
322 uint8_t qmax_{255};
Marat Dukhanf568f082019-10-30 09:47:07 -0700323 bool has_bias_{true};
XNNPACK Teamb455b122019-09-27 18:10:33 -0700324 size_t iterations_{1};
325};