blob: 9105e07dda2b91ec7b9966554536730f982bff26 [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8
#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
24
25
26class SigmoidOperatorTester {
27 public:
28 inline SigmoidOperatorTester& channels(size_t channels) {
29 assert(channels != 0);
30 this->channels_ = channels;
31 return *this;
32 }
33
34 inline size_t channels() const {
35 return this->channels_;
36 }
37
38 inline SigmoidOperatorTester& input_stride(size_t input_stride) {
39 assert(input_stride != 0);
40 this->input_stride_ = input_stride;
41 return *this;
42 }
43
44 inline size_t input_stride() const {
45 if (this->input_stride_ == 0) {
46 return this->channels_;
47 } else {
48 assert(this->input_stride_ >= this->channels_);
49 return this->input_stride_;
50 }
51 }
52
53 inline SigmoidOperatorTester& output_stride(size_t output_stride) {
54 assert(output_stride != 0);
55 this->output_stride_ = output_stride;
56 return *this;
57 }
58
59 inline size_t output_stride() const {
60 if (this->output_stride_ == 0) {
61 return this->channels_;
62 } else {
63 assert(this->output_stride_ >= this->channels_);
64 return this->output_stride_;
65 }
66 }
67
68 inline SigmoidOperatorTester& batch_size(size_t batch_size) {
69 assert(batch_size != 0);
70 this->batch_size_ = batch_size;
71 return *this;
72 }
73
74 inline size_t batch_size() const {
75 return this->batch_size_;
76 }
77
78 inline SigmoidOperatorTester& input_scale(float input_scale) {
79 assert(input_scale > 0.0f);
80 assert(std::isnormal(input_scale));
81 this->input_scale_ = input_scale;
82 return *this;
83 }
84
85 inline float input_scale() const {
86 return this->input_scale_;
87 }
88
89 inline SigmoidOperatorTester& input_zero_point(uint8_t input_zero_point) {
90 this->input_zero_point_ = input_zero_point;
91 return *this;
92 }
93
94 inline uint8_t input_zero_point() const {
95 return this->input_zero_point_;
96 }
97
98 inline float output_scale() const {
99 return 1.0f / 256.0f;
100 }
101
102 inline uint8_t output_zero_point() const {
103 return 0;
104 }
105
106 inline SigmoidOperatorTester& qmin(uint8_t qmin) {
107 this->qmin_ = qmin;
108 return *this;
109 }
110
111 inline uint8_t qmin() const {
112 return this->qmin_;
113 }
114
115 inline SigmoidOperatorTester& qmax(uint8_t qmax) {
116 this->qmax_ = qmax;
117 return *this;
118 }
119
120 inline uint8_t qmax() const {
121 return this->qmax_;
122 }
123
124 inline SigmoidOperatorTester& iterations(size_t iterations) {
125 this->iterations_ = iterations;
126 return *this;
127 }
128
129 inline size_t iterations() const {
130 return this->iterations_;
131 }
132
Marat Dukhan08b7a972020-07-14 18:17:29 -0700133 void TestQU8() const {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700134 std::random_device random_device;
135 auto rng = std::mt19937(random_device());
Marat Dukhan5ce30d92020-04-14 03:31:26 -0700136 auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700137
Marat Dukhan7bee7512019-11-18 15:15:48 -0800138 std::vector<uint8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700139 std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
140 std::vector<float> output_ref(batch_size() * channels());
141 for (size_t iteration = 0; iteration < iterations(); iteration++) {
142 std::generate(input.begin(), input.end(), std::ref(u8rng));
143 std::fill(output.begin(), output.end(), 0xA5);
144
145 // Compute reference results.
146 for (size_t i = 0; i < batch_size(); i++) {
147 for (size_t c = 0; c < channels(); c++) {
148 const float x = input_scale() *
149 (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point()));
150 const float sigmoid_x = 1.0f / (1.0f + std::exp(-x));
151 const float scaled_sigmoid_x = sigmoid_x / output_scale();
152 float y = scaled_sigmoid_x;
153 y = std::min<float>(y, int32_t(qmax()) - int32_t(output_zero_point()));
154 y = std::max<float>(y, int32_t(qmin()) - int32_t(output_zero_point()));
155 output_ref[i * channels() + c] = y + int32_t(output_zero_point());
156 }
157 }
158
159 // Create, setup, run, and destroy Sigmoid operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800160 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700161 xnn_operator_t sigmoid_op = nullptr;
162
163 ASSERT_EQ(xnn_status_success,
Marat Dukhan08b7a972020-07-14 18:17:29 -0700164 xnn_create_sigmoid_nc_qu8(
XNNPACK Teamb455b122019-09-27 18:10:33 -0700165 channels(), input_stride(), output_stride(),
166 input_zero_point(), input_scale(),
167 output_zero_point(), output_scale(),
168 qmin(), qmax(),
169 0, &sigmoid_op));
170 ASSERT_NE(nullptr, sigmoid_op);
171
172 // Smart pointer to automatically delete sigmoid_op.
Marat Dukhan346a9e52019-11-15 09:06:30 -0800173 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_sigmoid_op(sigmoid_op, xnn_delete_operator);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700174
175 ASSERT_EQ(xnn_status_success,
Marat Dukhan08b7a972020-07-14 18:17:29 -0700176 xnn_setup_sigmoid_nc_qu8(
XNNPACK Teamb455b122019-09-27 18:10:33 -0700177 sigmoid_op,
178 batch_size(),
179 input.data(), output.data(),
180 nullptr /* thread pool */));
181
182 ASSERT_EQ(xnn_status_success,
183 xnn_run_operator(sigmoid_op, nullptr /* thread pool */));
184
185 // Verify results.
186 for (size_t i = 0; i < batch_size(); i++) {
187 for (size_t c = 0; c < channels(); c++) {
188 ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f);
189 }
190 }
191 }
192 }
193
Marat Dukhan346a9e52019-11-15 09:06:30 -0800194 void TestF32() const {
195 std::random_device random_device;
196 auto rng = std::mt19937(random_device());
197 auto f32rng = std::bind(std::uniform_real_distribution<float>(-25.0f, 25.0f), rng);
198
Marat Dukhan7bee7512019-11-18 15:15:48 -0800199 std::vector<float> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
Marat Dukhan346a9e52019-11-15 09:06:30 -0800200 std::vector<float> output((batch_size() - 1) * output_stride() + channels());
201 std::vector<double> output_ref(batch_size() * channels());
202 for (size_t iteration = 0; iteration < iterations(); iteration++) {
203 std::generate(input.begin(), input.end(), std::ref(f32rng));
204 std::fill(output.begin(), output.end(), 0xA5);
205
206 // Compute reference results.
207 for (size_t i = 0; i < batch_size(); i++) {
208 for (size_t c = 0; c < channels(); c++) {
209 const double x = input[i * input_stride() + c];
210 const double exp_x = std::exp(x);
211 const double sigmoid_x = exp_x / (1.0 + exp_x);
212 output_ref[i * channels() + c] = sigmoid_x;
213 }
214 }
215
216 // Create, setup, run, and destroy Sigmoid operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800217 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
Marat Dukhan346a9e52019-11-15 09:06:30 -0800218 xnn_operator_t sigmoid_op = nullptr;
219
220 xnn_status status = xnn_create_sigmoid_nc_f32(
221 channels(), input_stride(), output_stride(),
222 0, &sigmoid_op);
Marat Dukhan346a9e52019-11-15 09:06:30 -0800223 ASSERT_EQ(xnn_status_success, status);
224 ASSERT_NE(nullptr, sigmoid_op);
225
226 // Smart pointer to automatically delete sigmoid_op.
227 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_sigmoid_op(sigmoid_op, xnn_delete_operator);
228
229 ASSERT_EQ(xnn_status_success,
230 xnn_setup_sigmoid_nc_f32(
231 sigmoid_op,
232 batch_size(),
233 input.data(), output.data(),
234 nullptr /* thread pool */));
235
236 ASSERT_EQ(xnn_status_success,
237 xnn_run_operator(sigmoid_op, nullptr /* thread pool */));
238
239 // Verify results.
240 for (size_t i = 0; i < batch_size(); i++) {
241 for (size_t c = 0; c < channels(); c++) {
242 ASSERT_NEAR(
243 output[i * output_stride() + c],
244 output_ref[i * channels() + c],
Erich Elsen8fd7b5f2019-11-18 10:50:41 -0800245 5.0e-6);
Marat Dukhan346a9e52019-11-15 09:06:30 -0800246 }
247 }
248 }
249 }
250
XNNPACK Teamb455b122019-09-27 18:10:33 -0700251 private:
252 size_t batch_size_{1};
253 size_t channels_{1};
254 size_t input_stride_{0};
255 size_t output_stride_{0};
256 float input_scale_{0.75f};
257 uint8_t input_zero_point_{121};
258 uint8_t qmin_{0};
259 uint8_t qmax_{255};
260 size_t iterations_{15};
261};