// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>
21
22
23class VUnOpMicrokernelTester {
24 public:
25 enum class OpType {
Marat Dukhan2b9efd82020-06-08 01:09:31 -070026 Abs,
Marat Dukhaned6baaf2020-12-01 15:07:08 -080027 ELU,
Marat Dukhan8cc7efe2020-06-10 16:24:27 -070028 LeakyReLU,
Marat Dukhaneecf8fd2020-06-09 08:59:37 -070029 Negate,
Frank Barchardfb158e22020-07-15 16:10:10 -070030 ReLU,
Marat Dukhaneecf8fd2020-06-09 08:59:37 -070031 RoundToNearestEven,
32 RoundTowardsZero,
33 RoundUp,
34 RoundDown,
35 Square,
Marat Dukhanf4db2f32020-06-30 10:55:30 -070036 SquareRoot,
Marat Dukhan346a9e52019-11-15 09:06:30 -080037 Sigmoid,
38 };
39
40 enum class Variant {
41 Native,
42 Scalar,
43 };
44
45 inline VUnOpMicrokernelTester& batch_size(size_t batch_size) {
46 assert(batch_size != 0);
47 this->batch_size_ = batch_size;
48 return *this;
49 }
50
51 inline size_t batch_size() const {
52 return this->batch_size_;
53 }
54
55 inline VUnOpMicrokernelTester& inplace(bool inplace) {
56 this->inplace_ = inplace;
57 return *this;
58 }
59
60 inline bool inplace() const {
61 return this->inplace_;
62 }
63
Marat Dukhan8cc7efe2020-06-10 16:24:27 -070064 inline VUnOpMicrokernelTester& slope(float slope) {
65 this->slope_ = slope;
66 return *this;
67 }
68
69 inline float slope() const {
70 return this->slope_;
71 }
72
Marat Dukhaned6baaf2020-12-01 15:07:08 -080073 inline VUnOpMicrokernelTester& prescale(float prescale) {
74 this->prescale_ = prescale;
75 return *this;
76 }
77
78 inline float prescale() const {
79 return this->prescale_;
80 }
81
82 inline VUnOpMicrokernelTester& alpha(float alpha) {
83 this->alpha_ = alpha;
84 return *this;
85 }
86
87 inline float alpha() const {
88 return this->alpha_;
89 }
90
91 inline VUnOpMicrokernelTester& beta(float beta) {
92 this->beta_ = beta;
93 return *this;
94 }
95
96 inline float beta() const {
97 return this->beta_;
98 }
99
Marat Dukhan346a9e52019-11-15 09:06:30 -0800100 inline VUnOpMicrokernelTester& qmin(uint8_t qmin) {
101 this->qmin_ = qmin;
102 return *this;
103 }
104
105 inline uint8_t qmin() const {
106 return this->qmin_;
107 }
108
109 inline VUnOpMicrokernelTester& qmax(uint8_t qmax) {
110 this->qmax_ = qmax;
111 return *this;
112 }
113
114 inline uint8_t qmax() const {
115 return this->qmax_;
116 }
117
118 inline VUnOpMicrokernelTester& iterations(size_t iterations) {
119 this->iterations_ = iterations;
120 return *this;
121 }
122
123 inline size_t iterations() const {
124 return this->iterations_;
125 }
126
Marat Dukhan1e782c42019-11-21 17:02:40 -0800127 void Test(xnn_f32_vunary_ukernel_function vunary, OpType op_type, Variant variant = Variant::Native) const {
Marat Dukhan346a9e52019-11-15 09:06:30 -0800128 std::random_device random_device;
129 auto rng = std::mt19937(random_device());
Marat Dukhanf4db2f32020-06-30 10:55:30 -0700130 auto distribution = std::uniform_real_distribution<float>(-125.0f, 125.0f);
131 switch (op_type) {
Marat Dukhaned6baaf2020-12-01 15:07:08 -0800132 case OpType::ELU:
133 distribution = std::uniform_real_distribution<float>(-20.0f, 20.0f);
134 break;
Marat Dukhanf4db2f32020-06-30 10:55:30 -0700135 case OpType::SquareRoot:
136 distribution = std::uniform_real_distribution<float>(0.0f, 10.0f);
137 break;
138 default:
139 break;
140 }
141 auto f32rng = std::bind(distribution, std::ref(rng));
Marat Dukhan346a9e52019-11-15 09:06:30 -0800142
143 std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
144 std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
145 std::vector<double> y_ref(batch_size());
146 for (size_t iteration = 0; iteration < iterations(); iteration++) {
147 if (inplace()) {
148 std::generate(y.begin(), y.end(), std::ref(f32rng));
149 } else {
150 std::generate(x.begin(), x.end(), std::ref(f32rng));
151 std::fill(y.begin(), y.end(), nanf(""));
152 }
153 const float* x_data = inplace() ? y.data() : x.data();
154
155 // Compute reference results.
156 for (size_t i = 0; i < batch_size(); i++) {
157 switch (op_type) {
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700158 case OpType::Abs:
Marat Dukhan34ccfba2020-06-09 08:46:51 -0700159 y_ref[i] = std::abs(x_data[i]);
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700160 break;
Marat Dukhaned6baaf2020-12-01 15:07:08 -0800161 case OpType::ELU:
162 {
163 y_ref[i] = std::signbit(x_data[i]) ? alpha() * std::expm1(double(x_data[i]) * prescale()) : double(x_data[i]) * beta();
164 break;
165 }
Marat Dukhan8cc7efe2020-06-10 16:24:27 -0700166 case OpType::LeakyReLU:
167 y_ref[i] = std::signbit(x_data[i]) ? x_data[i] * slope() : x_data[i];
168 break;
Marat Dukhaneecf8fd2020-06-09 08:59:37 -0700169 case OpType::Negate:
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700170 y_ref[i] = -x_data[i];
171 break;
Frank Barchardfb158e22020-07-15 16:10:10 -0700172 case OpType::ReLU:
173 y_ref[i] = std::max(x_data[i], 0.0f);
174 break;
Marat Dukhaneecf8fd2020-06-09 08:59:37 -0700175 case OpType::RoundToNearestEven:
176 y_ref[i] = std::nearbyint(double(x_data[i]));
177 break;
178 case OpType::RoundTowardsZero:
179 y_ref[i] = std::trunc(double(x_data[i]));
180 break;
181 case OpType::RoundUp:
182 y_ref[i] = std::ceil(double(x_data[i]));
183 break;
184 case OpType::RoundDown:
185 y_ref[i] = std::floor(double(x_data[i]));
186 break;
187 case OpType::Square:
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700188 y_ref[i] = double(x_data[i]) * double(x_data[i]);
189 break;
Marat Dukhanf4db2f32020-06-30 10:55:30 -0700190 case OpType::SquareRoot:
191 y_ref[i] = std::sqrt(double(x_data[i]));
192 break;
Marat Dukhan346a9e52019-11-15 09:06:30 -0800193 case OpType::Sigmoid:
194 {
195 const double e = std::exp(double(x_data[i]));
196 y_ref[i] = e / (1.0 + e);
197 break;
198 }
199 }
200 }
Marat Dukhan346a9e52019-11-15 09:06:30 -0800201
Frank Barchard9f3a8432020-06-02 13:59:35 -0700202 // Prepare parameters.
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700203 union {
204 union xnn_f32_abs_params abs;
Marat Dukhaned6baaf2020-12-01 15:07:08 -0800205 union xnn_f32_elu_params elu;
Frank Barchardfb158e22020-07-15 16:10:10 -0700206 union xnn_f32_relu_params relu;
Marat Dukhan8cc7efe2020-06-10 16:24:27 -0700207 union xnn_f32_lrelu_params lrelu;
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700208 union xnn_f32_neg_params neg;
Marat Dukhaneecf8fd2020-06-09 08:59:37 -0700209 union xnn_f32_rnd_params rnd;
Marat Dukhanf4db2f32020-06-30 10:55:30 -0700210 union xnn_f32_sqrt_params sqrt;
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700211 } params;
212 switch (op_type) {
213 case OpType::Abs:
214 switch (variant) {
215 case Variant::Native:
216 params.abs = xnn_init_f32_abs_params();
217 break;
218 case Variant::Scalar:
219 params.abs = xnn_init_scalar_f32_abs_params();
220 break;
221 }
Marat Dukhan346a9e52019-11-15 09:06:30 -0800222 break;
Marat Dukhaned6baaf2020-12-01 15:07:08 -0800223 case OpType::ELU:
224 switch (variant) {
225 case Variant::Native:
226 params.elu = xnn_init_f32_elu_params(prescale(), alpha(), beta());
227 break;
228 case Variant::Scalar:
229 params.elu = xnn_init_scalar_f32_elu_params(prescale(), alpha(), beta());
230 break;
231 }
232 break;
Marat Dukhan8cc7efe2020-06-10 16:24:27 -0700233 case OpType::LeakyReLU:
234 switch (variant) {
235 case Variant::Native:
236 params.lrelu = xnn_init_f32_lrelu_params(slope());
237 break;
238 case Variant::Scalar:
239 params.lrelu = xnn_init_scalar_f32_lrelu_params(slope());
240 break;
241 }
242 break;
Marat Dukhaneecf8fd2020-06-09 08:59:37 -0700243 case OpType::Negate:
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700244 switch (variant) {
245 case Variant::Native:
246 params.neg = xnn_init_f32_neg_params();
247 break;
248 case Variant::Scalar:
249 params.neg = xnn_init_scalar_f32_neg_params();
250 break;
251 }
252 break;
Marat Dukhaneecf8fd2020-06-09 08:59:37 -0700253 case OpType::RoundToNearestEven:
254 case OpType::RoundTowardsZero:
255 case OpType::RoundUp:
256 case OpType::RoundDown:
257 switch (variant) {
258 case Variant::Native:
259 params.rnd = xnn_init_f32_rnd_params();
260 break;
261 case Variant::Scalar:
262 params.rnd = xnn_init_scalar_f32_rnd_params();
263 break;
264 }
265 break;
Frank Barchardfb158e22020-07-15 16:10:10 -0700266 case OpType::ReLU:
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700267 case OpType::Sigmoid:
Marat Dukhaneecf8fd2020-06-09 08:59:37 -0700268 case OpType::Square:
Marat Dukhan346a9e52019-11-15 09:06:30 -0800269 break;
Marat Dukhanf4db2f32020-06-30 10:55:30 -0700270 case OpType::SquareRoot:
271 switch (variant) {
272 case Variant::Native:
273 params.sqrt = xnn_init_f32_sqrt_params();
274 break;
275 case Variant::Scalar:
276 params.sqrt = xnn_init_scalar_f32_sqrt_params();
277 break;
278 }
279 break;
Marat Dukhan346a9e52019-11-15 09:06:30 -0800280 }
281
282 // Call optimized micro-kernel.
Frank Barcharde70dbeb2020-05-01 15:46:41 -0700283 vunary(batch_size() * sizeof(float), x_data, y.data(), &params);
Marat Dukhan346a9e52019-11-15 09:06:30 -0800284
285 // Verify results.
286 for (size_t i = 0; i < batch_size(); i++) {
Frank Barchard2b9d29b2020-09-17 12:03:39 -0700287 ASSERT_NEAR(y[i], y_ref[i], std::max(5.0e-6, std::abs(y_ref[i]) * 1.0e-5))
Marat Dukhan8d3c07e2020-01-02 01:20:59 -0800288 << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
Marat Dukhan346a9e52019-11-15 09:06:30 -0800289 }
290 }
291 }
292
293 private:
Marat Dukhaned6baaf2020-12-01 15:07:08 -0800294 size_t batch_size_ = 1;
295 bool inplace_ = false;
296 float slope_ = 0.5f;
297 float prescale_ = 1.0f;
298 float alpha_ = 1.0f;
299 float beta_ = 1.0f;
300 uint8_t qmin_ = 0;
301 uint8_t qmax_ = 255;
302 size_t iterations_ = 15;
Marat Dukhan346a9e52019-11-15 09:06:30 -0800303};