blob: aa4a4656ad116cea14c823a630fd43b2e675557b [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>


25class SigmoidOperatorTester {
26 public:
27 inline SigmoidOperatorTester& channels(size_t channels) {
28 assert(channels != 0);
29 this->channels_ = channels;
30 return *this;
31 }
32
33 inline size_t channels() const {
34 return this->channels_;
35 }
36
37 inline SigmoidOperatorTester& input_stride(size_t input_stride) {
38 assert(input_stride != 0);
39 this->input_stride_ = input_stride;
40 return *this;
41 }
42
43 inline size_t input_stride() const {
44 if (this->input_stride_ == 0) {
45 return this->channels_;
46 } else {
47 assert(this->input_stride_ >= this->channels_);
48 return this->input_stride_;
49 }
50 }
51
52 inline SigmoidOperatorTester& output_stride(size_t output_stride) {
53 assert(output_stride != 0);
54 this->output_stride_ = output_stride;
55 return *this;
56 }
57
58 inline size_t output_stride() const {
59 if (this->output_stride_ == 0) {
60 return this->channels_;
61 } else {
62 assert(this->output_stride_ >= this->channels_);
63 return this->output_stride_;
64 }
65 }
66
67 inline SigmoidOperatorTester& batch_size(size_t batch_size) {
68 assert(batch_size != 0);
69 this->batch_size_ = batch_size;
70 return *this;
71 }
72
73 inline size_t batch_size() const {
74 return this->batch_size_;
75 }
76
77 inline SigmoidOperatorTester& input_scale(float input_scale) {
78 assert(input_scale > 0.0f);
79 assert(std::isnormal(input_scale));
80 this->input_scale_ = input_scale;
81 return *this;
82 }
83
84 inline float input_scale() const {
85 return this->input_scale_;
86 }
87
88 inline SigmoidOperatorTester& input_zero_point(uint8_t input_zero_point) {
89 this->input_zero_point_ = input_zero_point;
90 return *this;
91 }
92
93 inline uint8_t input_zero_point() const {
94 return this->input_zero_point_;
95 }
96
97 inline float output_scale() const {
98 return 1.0f / 256.0f;
99 }
100
101 inline uint8_t output_zero_point() const {
102 return 0;
103 }
104
105 inline SigmoidOperatorTester& qmin(uint8_t qmin) {
106 this->qmin_ = qmin;
107 return *this;
108 }
109
110 inline uint8_t qmin() const {
111 return this->qmin_;
112 }
113
114 inline SigmoidOperatorTester& qmax(uint8_t qmax) {
115 this->qmax_ = qmax;
116 return *this;
117 }
118
119 inline uint8_t qmax() const {
120 return this->qmax_;
121 }
122
123 inline SigmoidOperatorTester& iterations(size_t iterations) {
124 this->iterations_ = iterations;
125 return *this;
126 }
127
128 inline size_t iterations() const {
129 return this->iterations_;
130 }
131
132 void TestQ8() const {
133 std::random_device random_device;
134 auto rng = std::mt19937(random_device());
135 auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);
136
Marat Dukhan7bee7512019-11-18 15:15:48 -0800137 std::vector<uint8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700138 std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
139 std::vector<float> output_ref(batch_size() * channels());
140 for (size_t iteration = 0; iteration < iterations(); iteration++) {
141 std::generate(input.begin(), input.end(), std::ref(u8rng));
142 std::fill(output.begin(), output.end(), 0xA5);
143
144 // Compute reference results.
145 for (size_t i = 0; i < batch_size(); i++) {
146 for (size_t c = 0; c < channels(); c++) {
147 const float x = input_scale() *
148 (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point()));
149 const float sigmoid_x = 1.0f / (1.0f + std::exp(-x));
150 const float scaled_sigmoid_x = sigmoid_x / output_scale();
151 float y = scaled_sigmoid_x;
152 y = std::min<float>(y, int32_t(qmax()) - int32_t(output_zero_point()));
153 y = std::max<float>(y, int32_t(qmin()) - int32_t(output_zero_point()));
154 output_ref[i * channels() + c] = y + int32_t(output_zero_point());
155 }
156 }
157
158 // Create, setup, run, and destroy Sigmoid operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800159 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700160 xnn_operator_t sigmoid_op = nullptr;
161
162 ASSERT_EQ(xnn_status_success,
163 xnn_create_sigmoid_nc_q8(
164 channels(), input_stride(), output_stride(),
165 input_zero_point(), input_scale(),
166 output_zero_point(), output_scale(),
167 qmin(), qmax(),
168 0, &sigmoid_op));
169 ASSERT_NE(nullptr, sigmoid_op);
170
171 // Smart pointer to automatically delete sigmoid_op.
Marat Dukhan346a9e52019-11-15 09:06:30 -0800172 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_sigmoid_op(sigmoid_op, xnn_delete_operator);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700173
174 ASSERT_EQ(xnn_status_success,
175 xnn_setup_sigmoid_nc_q8(
176 sigmoid_op,
177 batch_size(),
178 input.data(), output.data(),
179 nullptr /* thread pool */));
180
181 ASSERT_EQ(xnn_status_success,
182 xnn_run_operator(sigmoid_op, nullptr /* thread pool */));
183
184 // Verify results.
185 for (size_t i = 0; i < batch_size(); i++) {
186 for (size_t c = 0; c < channels(); c++) {
187 ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f);
188 }
189 }
190 }
191 }
192
Marat Dukhan346a9e52019-11-15 09:06:30 -0800193 void TestF32() const {
194 std::random_device random_device;
195 auto rng = std::mt19937(random_device());
196 auto f32rng = std::bind(std::uniform_real_distribution<float>(-25.0f, 25.0f), rng);
197
Marat Dukhan7bee7512019-11-18 15:15:48 -0800198 std::vector<float> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
Marat Dukhan346a9e52019-11-15 09:06:30 -0800199 std::vector<float> output((batch_size() - 1) * output_stride() + channels());
200 std::vector<double> output_ref(batch_size() * channels());
201 for (size_t iteration = 0; iteration < iterations(); iteration++) {
202 std::generate(input.begin(), input.end(), std::ref(f32rng));
203 std::fill(output.begin(), output.end(), 0xA5);
204
205 // Compute reference results.
206 for (size_t i = 0; i < batch_size(); i++) {
207 for (size_t c = 0; c < channels(); c++) {
208 const double x = input[i * input_stride() + c];
209 const double exp_x = std::exp(x);
210 const double sigmoid_x = exp_x / (1.0 + exp_x);
211 output_ref[i * channels() + c] = sigmoid_x;
212 }
213 }
214
215 // Create, setup, run, and destroy Sigmoid operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800216 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
Marat Dukhan346a9e52019-11-15 09:06:30 -0800217 xnn_operator_t sigmoid_op = nullptr;
218
219 xnn_status status = xnn_create_sigmoid_nc_f32(
220 channels(), input_stride(), output_stride(),
221 0, &sigmoid_op);
Marat Dukhan346a9e52019-11-15 09:06:30 -0800222 ASSERT_EQ(xnn_status_success, status);
223 ASSERT_NE(nullptr, sigmoid_op);
224
225 // Smart pointer to automatically delete sigmoid_op.
226 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_sigmoid_op(sigmoid_op, xnn_delete_operator);
227
228 ASSERT_EQ(xnn_status_success,
229 xnn_setup_sigmoid_nc_f32(
230 sigmoid_op,
231 batch_size(),
232 input.data(), output.data(),
233 nullptr /* thread pool */));
234
235 ASSERT_EQ(xnn_status_success,
236 xnn_run_operator(sigmoid_op, nullptr /* thread pool */));
237
238 // Verify results.
239 for (size_t i = 0; i < batch_size(); i++) {
240 for (size_t c = 0; c < channels(); c++) {
241 ASSERT_NEAR(
242 output[i * output_stride() + c],
243 output_ref[i * channels() + c],
Erich Elsen8fd7b5f2019-11-18 10:50:41 -0800244 5.0e-6);
Marat Dukhan346a9e52019-11-15 09:06:30 -0800245 }
246 }
247 }
248 }
249
XNNPACK Teamb455b122019-09-27 18:10:33 -0700250 private:
251 size_t batch_size_{1};
252 size_t channels_{1};
253 size_t input_stride_{0};
254 size_t output_stride_{0};
255 float input_scale_{0.75f};
256 uint8_t input_zero_point_{121};
257 uint8_t qmin_{0};
258 uint8_t qmax_{255};
259 size_t iterations_{15};
260};