blob: 4a23391ec8b2692d9f21562ae7a483110530a359 [file] [log] [blame]
Marat Dukhan5de7bc02021-09-09 19:04:01 -07001// Copyright 2021 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
21
22
23class TanhOperatorTester {
24 public:
25 inline TanhOperatorTester& channels(size_t channels) {
26 assert(channels != 0);
27 this->channels_ = channels;
28 return *this;
29 }
30
31 inline size_t channels() const {
32 return this->channels_;
33 }
34
35 inline TanhOperatorTester& input_stride(size_t input_stride) {
36 assert(input_stride != 0);
37 this->input_stride_ = input_stride;
38 return *this;
39 }
40
41 inline size_t input_stride() const {
42 if (this->input_stride_ == 0) {
43 return this->channels_;
44 } else {
45 assert(this->input_stride_ >= this->channels_);
46 return this->input_stride_;
47 }
48 }
49
50 inline TanhOperatorTester& output_stride(size_t output_stride) {
51 assert(output_stride != 0);
52 this->output_stride_ = output_stride;
53 return *this;
54 }
55
56 inline size_t output_stride() const {
57 if (this->output_stride_ == 0) {
58 return this->channels_;
59 } else {
60 assert(this->output_stride_ >= this->channels_);
61 return this->output_stride_;
62 }
63 }
64
65 inline TanhOperatorTester& batch_size(size_t batch_size) {
66 assert(batch_size != 0);
67 this->batch_size_ = batch_size;
68 return *this;
69 }
70
71 inline size_t batch_size() const {
72 return this->batch_size_;
73 }
74
75 inline TanhOperatorTester& input_scale(float input_scale) {
76 assert(input_scale > 0.0f);
77 assert(std::isnormal(input_scale));
78 this->input_scale_ = input_scale;
79 return *this;
80 }
81
82 inline float input_scale() const {
83 return this->input_scale_;
84 }
85
86 inline TanhOperatorTester& input_zero_point(uint8_t input_zero_point) {
87 this->input_zero_point_ = input_zero_point;
88 return *this;
89 }
90
91 inline uint8_t input_zero_point() const {
92 return this->input_zero_point_;
93 }
94
95 inline float output_scale() const {
96 return 1.0f / 128.0f;
97 }
98
99 inline uint8_t output_zero_point() const {
100 return 128;
101 }
102
103 inline TanhOperatorTester& qmin(uint8_t qmin) {
104 this->qmin_ = qmin;
105 return *this;
106 }
107
108 inline uint8_t qmin() const {
109 return this->qmin_;
110 }
111
112 inline TanhOperatorTester& qmax(uint8_t qmax) {
113 this->qmax_ = qmax;
114 return *this;
115 }
116
117 inline uint8_t qmax() const {
118 return this->qmax_;
119 }
120
121 inline TanhOperatorTester& iterations(size_t iterations) {
122 this->iterations_ = iterations;
123 return *this;
124 }
125
126 inline size_t iterations() const {
127 return this->iterations_;
128 }
129
130 void TestQS8() const {
131 std::random_device random_device;
132 auto rng = std::mt19937(random_device());
133 auto i8rng = std::bind(
134 std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
135 std::ref(rng));
136
137 std::vector<int8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(int8_t));
138 std::vector<int8_t> output((batch_size() - 1) * output_stride() + channels());
139 std::vector<float> output_ref(batch_size() * channels());
140 for (size_t iteration = 0; iteration < iterations(); iteration++) {
141 std::generate(input.begin(), input.end(), std::ref(i8rng));
142 std::fill(output.begin(), output.end(), 0xA5);
143
144 // Compute reference results.
145 for (size_t i = 0; i < batch_size(); i++) {
146 for (size_t c = 0; c < channels(); c++) {
147 const float x = input_scale() *
148 (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point() - 0x80));
149 const float tanh_x = std::tanh(x);
150 const float scaled_tanh_x = tanh_x / output_scale();
151 float y = scaled_tanh_x;
152 y = std::min<float>(y, int32_t(qmax() - 0x80) - int32_t(output_zero_point() - 0x80));
153 y = std::max<float>(y, int32_t(qmin() - 0x80) - int32_t(output_zero_point() - 0x80));
154 output_ref[i * channels() + c] = y + int32_t(output_zero_point() - 0x80);
155 }
156 }
157
158 // Create, setup, run, and destroy Sigmoid operator.
159 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
160 xnn_operator_t tanh_op = nullptr;
161
162 ASSERT_EQ(xnn_status_success,
163 xnn_create_tanh_nc_qs8(
164 channels(), input_stride(), output_stride(),
165 int8_t(input_zero_point() - 0x80), input_scale(),
166 int8_t(output_zero_point() - 0x80), output_scale(),
167 int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
168 0, &tanh_op));
169 ASSERT_NE(nullptr, tanh_op);
170
171 // Smart pointer to automatically delete tanh_op.
172 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_tanh_op(tanh_op, xnn_delete_operator);
173
174 ASSERT_EQ(xnn_status_success,
175 xnn_setup_tanh_nc_qs8(
176 tanh_op,
177 batch_size(),
178 input.data(), output.data(),
179 nullptr /* thread pool */));
180
181 ASSERT_EQ(xnn_status_success,
182 xnn_run_operator(tanh_op, nullptr /* thread pool */));
183
184 // Verify results.
185 for (size_t i = 0; i < batch_size(); i++) {
186 for (size_t c = 0; c < channels(); c++) {
187 ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f);
188 }
189 }
190 }
191 }
192
193 void TestQU8() const {
194 std::random_device random_device;
195 auto rng = std::mt19937(random_device());
196 auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
197
198 std::vector<uint8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t));
199 std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
200 std::vector<float> output_ref(batch_size() * channels());
201 for (size_t iteration = 0; iteration < iterations(); iteration++) {
202 std::generate(input.begin(), input.end(), std::ref(u8rng));
203 std::fill(output.begin(), output.end(), 0xA5);
204
205 // Compute reference results.
206 for (size_t i = 0; i < batch_size(); i++) {
207 for (size_t c = 0; c < channels(); c++) {
208 const float x = input_scale() *
209 (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point()));
210 const float tanh_x = std::tanh(x);
211 const float scaled_tanh_x = tanh_x / output_scale();
212 float y = scaled_tanh_x;
213 y = std::min<float>(y, int32_t(qmax()) - int32_t(output_zero_point()));
214 y = std::max<float>(y, int32_t(qmin()) - int32_t(output_zero_point()));
215 output_ref[i * channels() + c] = y + int32_t(output_zero_point());
216 }
217 }
218
219 // Create, setup, run, and destroy Sigmoid operator.
220 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
221 xnn_operator_t tanh_op = nullptr;
222
223 ASSERT_EQ(xnn_status_success,
224 xnn_create_tanh_nc_qu8(
225 channels(), input_stride(), output_stride(),
226 input_zero_point(), input_scale(),
227 output_zero_point(), output_scale(),
228 qmin(), qmax(),
229 0, &tanh_op));
230 ASSERT_NE(nullptr, tanh_op);
231
232 // Smart pointer to automatically delete tanh_op.
233 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_tanh_op(tanh_op, xnn_delete_operator);
234
235 ASSERT_EQ(xnn_status_success,
236 xnn_setup_tanh_nc_qu8(
237 tanh_op,
238 batch_size(),
239 input.data(), output.data(),
240 nullptr /* thread pool */));
241
242 ASSERT_EQ(xnn_status_success,
243 xnn_run_operator(tanh_op, nullptr /* thread pool */));
244
245 // Verify results.
246 for (size_t i = 0; i < batch_size(); i++) {
247 for (size_t c = 0; c < channels(); c++) {
248 ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f);
249 }
250 }
251 }
252 }
253
254 private:
255 size_t batch_size_{1};
256 size_t channels_{1};
257 size_t input_stride_{0};
258 size_t output_stride_{0};
259 float input_scale_{0.75f};
260 uint8_t input_zero_point_{121};
261 uint8_t qmin_{0};
262 uint8_t qmax_{255};
263 size_t iterations_{15};
264};