// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
10
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
23
24
Marat Dukhanfd8e6892020-01-27 15:25:25 -080025class SoftMaxOperatorTester {
XNNPACK Teamb455b122019-09-27 18:10:33 -070026 public:
Marat Dukhanfd8e6892020-01-27 15:25:25 -080027 inline SoftMaxOperatorTester& channels(size_t channels) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070028 assert(channels != 0);
29 this->channels_ = channels;
30 return *this;
31 }
32
33 inline size_t channels() const {
34 return this->channels_;
35 }
36
Marat Dukhanfd8e6892020-01-27 15:25:25 -080037 inline SoftMaxOperatorTester& input_stride(size_t input_stride) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070038 assert(input_stride != 0);
39 this->input_stride_ = input_stride;
40 return *this;
41 }
42
43 inline size_t input_stride() const {
44 if (this->input_stride_ == 0) {
45 return this->channels_;
46 } else {
47 assert(this->input_stride_ >= this->channels_);
48 return this->input_stride_;
49 }
50 }
51
Marat Dukhanfd8e6892020-01-27 15:25:25 -080052 inline SoftMaxOperatorTester& output_stride(size_t output_stride) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070053 assert(output_stride != 0);
54 this->output_stride_ = output_stride;
55 return *this;
56 }
57
58 inline size_t output_stride() const {
59 if (this->output_stride_ == 0) {
60 return this->channels_;
61 } else {
62 assert(this->output_stride_ >= this->channels_);
63 return this->output_stride_;
64 }
65 }
66
Marat Dukhanfd8e6892020-01-27 15:25:25 -080067 inline SoftMaxOperatorTester& batch_size(size_t batch_size) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070068 assert(batch_size != 0);
69 this->batch_size_ = batch_size;
70 return *this;
71 }
72
73 inline size_t batch_size() const {
74 return this->batch_size_;
75 }
76
Marat Dukhanfd8e6892020-01-27 15:25:25 -080077 inline SoftMaxOperatorTester& input_scale(float input_scale) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070078 assert(input_scale > 0.0f);
79 assert(std::isnormal(input_scale));
80 this->input_scale_ = input_scale;
81 return *this;
82 }
83
84 inline float input_scale() const {
85 return this->input_scale_;
86 }
87
Marat Dukhanfd8e6892020-01-27 15:25:25 -080088 inline SoftMaxOperatorTester& input_zero_point(uint8_t input_zero_point) {
XNNPACK Teamb455b122019-09-27 18:10:33 -070089 this->input_zero_point_ = input_zero_point;
90 return *this;
91 }
92
93 inline uint8_t input_zero_point() const {
94 return this->input_zero_point_;
95 }
96
97 inline float output_scale() const {
98 return 1.0f / 256.0f;
99 }
100
101 inline uint8_t output_zero_point() const {
102 return 0;
103 }
104
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800105 inline SoftMaxOperatorTester& iterations(size_t iterations) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700106 this->iterations_ = iterations;
107 return *this;
108 }
109
110 inline size_t iterations() const {
111 return this->iterations_;
112 }
113
114 void TestQ8() const {
115 std::random_device random_device;
116 auto rng = std::mt19937(random_device());
117 auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);
118
119 std::vector<uint8_t> input((batch_size() - 1) * input_stride() + channels());
120 std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
121 std::vector<float> output_ref(batch_size() * channels());
122 for (size_t iteration = 0; iteration < iterations(); iteration++) {
123 std::generate(input.begin(), input.end(), std::ref(u8rng));
124 std::fill(output.begin(), output.end(), 0xA5);
125
126 // Compute reference results.
127 for (size_t i = 0; i < batch_size(); i++) {
128 const int32_t max_input = *std::max_element(
129 input.data() + i * input_stride(),
130 input.data() + i * input_stride() + channels());
131 float sum_exp = 0.0f;
132 for (size_t c = 0; c < channels(); c++) {
133 sum_exp +=
134 std::exp((int32_t(input[i * input_stride() + c]) - max_input) *
135 input_scale());
136 }
137 for (size_t c = 0; c < channels(); c++) {
138 output_ref[i * channels() + c] =
139 std::exp((int32_t(input[i * input_stride() + c]) - max_input) *
140 input_scale()) /
141 (sum_exp * output_scale());
142 output_ref[i * channels() + c] = std::min(output_ref[i * channels() + c], 255.0f);
143 }
144 }
145
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800146 // Create, setup, run, and destroy SoftMax operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800147 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800148 xnn_operator_t softmax_op = nullptr;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700149
150 ASSERT_EQ(xnn_status_success,
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800151 xnn_create_softmax_nc_q8(
XNNPACK Teamb455b122019-09-27 18:10:33 -0700152 channels(), input_stride(), output_stride(),
153 input_scale(),
154 output_zero_point(), output_scale(),
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800155 0, &softmax_op));
156 ASSERT_NE(nullptr, softmax_op);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700157
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800158 // Smart pointer to automatically delete softmax_op.
159 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_softmax_op(softmax_op, xnn_delete_operator);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700160
161 ASSERT_EQ(xnn_status_success,
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800162 xnn_setup_softmax_nc_q8(
163 softmax_op,
XNNPACK Teamb455b122019-09-27 18:10:33 -0700164 batch_size(),
165 input.data(), output.data(),
166 nullptr /* thread pool */));
167
168 ASSERT_EQ(xnn_status_success,
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800169 xnn_run_operator(softmax_op, nullptr /* thread pool */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700170
171 // Verify results.
172 for (size_t i = 0; i < batch_size(); i++) {
173 for (size_t c = 0; c < channels(); c++) {
174 ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f);
175 }
176 }
177 }
178 }
179
Marat Dukhan1edc4542020-01-27 12:40:13 -0800180 void TestF32() const {
181 std::random_device random_device;
182 auto rng = std::mt19937(random_device());
183 auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
184
185 std::vector<float> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
186 std::vector<float> output((batch_size() - 1) * output_stride() + channels());
187 std::vector<double> output_ref(batch_size() * channels());
188 for (size_t iteration = 0; iteration < iterations(); iteration++) {
189 std::generate(input.begin(), input.end(), std::ref(f32rng));
190 std::fill(output.begin(), output.end(), std::nanf(""));
191
192 // Compute reference results.
193 for (size_t i = 0; i < batch_size(); i++) {
194 const double max_input = *std::max_element(
195 input.data() + i * input_stride(),
196 input.data() + i * input_stride() + channels());
197 double sum_exp = 0.0;
198 for (size_t c = 0; c < channels(); c++) {
199 sum_exp += std::exp(double(input[i * input_stride() + c]) - max_input);
200 }
201 for (size_t c = 0; c < channels(); c++) {
202 output_ref[i * channels() + c] =
203 std::exp(double(input[i * input_stride() + c]) - max_input) / sum_exp;
204 }
205 }
206
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800207 // Create, setup, run, and destroy SoftMax operator.
Marat Dukhan1edc4542020-01-27 12:40:13 -0800208 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800209 xnn_operator_t softmax_op = nullptr;
Marat Dukhan1edc4542020-01-27 12:40:13 -0800210
211 ASSERT_EQ(xnn_status_success,
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800212 xnn_create_softmax_nc_f32(
Marat Dukhan1edc4542020-01-27 12:40:13 -0800213 channels(), input_stride(), output_stride(),
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800214 0, &softmax_op));
215 ASSERT_NE(nullptr, softmax_op);
Marat Dukhan1edc4542020-01-27 12:40:13 -0800216
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800217 // Smart pointer to automatically delete softmax_op.
218 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_softmax_op(softmax_op, xnn_delete_operator);
Marat Dukhan1edc4542020-01-27 12:40:13 -0800219
220 ASSERT_EQ(xnn_status_success,
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800221 xnn_setup_softmax_nc_f32(
222 softmax_op,
Marat Dukhan1edc4542020-01-27 12:40:13 -0800223 batch_size(),
224 input.data(), output.data(),
225 nullptr /* thread pool */));
226
227 ASSERT_EQ(xnn_status_success,
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800228 xnn_run_operator(softmax_op, nullptr /* thread pool */));
Marat Dukhan1edc4542020-01-27 12:40:13 -0800229
230 // Verify results.
231 for (size_t i = 0; i < batch_size(); i++) {
232 for (size_t c = 0; c < channels(); c++) {
233 ASSERT_NEAR(
234 double(output[i * output_stride() + c]),
235 output_ref[i * channels() + c],
236 output_ref[i * channels() + c] * 1.0e-4);
237 }
238 }
239 }
240 }
241
XNNPACK Teamb455b122019-09-27 18:10:33 -0700242 private:
243 size_t batch_size_{1};
244 size_t channels_{1};
245 size_t input_stride_{0};
246 size_t output_stride_{0};
247 float input_scale_{0.176080093};
248 uint8_t input_zero_point_{121};
249 size_t iterations_{15};
250};