// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>

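// Test helper for the XNNPACK SoftMax (NC-layout) operator: parameters are set
// via chainable setters, and TestQU8()/TestF32() run the operator on random
// inputs and check it against a scalar reference implementation.
//
// A minimal usage sketch (the test suite and test names below are illustrative,
// not taken from the actual XNNPACK test files):
//
//   TEST(SOFTMAX_NC_F32, small_batch) {
//     SoftMaxOperatorTester()
//       .batch_size(3)
//       .channels(100)
//       .iterations(3)
//       .TestF32();
//   }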
class SoftMaxOperatorTester {
 public:
  inline SoftMaxOperatorTester& channels(size_t channels) {
    assert(channels != 0);
    this->channels_ = channels;
    return *this;
  }

  inline size_t channels() const {
    return this->channels_;
  }

  inline SoftMaxOperatorTester& input_stride(size_t input_stride) {
    assert(input_stride != 0);
    this->input_stride_ = input_stride;
    return *this;
  }

  inline size_t input_stride() const {
    if (this->input_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->input_stride_ >= this->channels_);
      return this->input_stride_;
    }
  }

  inline SoftMaxOperatorTester& output_stride(size_t output_stride) {
    assert(output_stride != 0);
    this->output_stride_ = output_stride;
    return *this;
  }

  inline size_t output_stride() const {
    if (this->output_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->output_stride_ >= this->channels_);
      return this->output_stride_;
    }
  }

  inline SoftMaxOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size != 0);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

  inline SoftMaxOperatorTester& input_scale(float input_scale) {
    assert(input_scale > 0.0f);
    assert(std::isnormal(input_scale));
    this->input_scale_ = input_scale;
    return *this;
  }

  inline float input_scale() const {
    return this->input_scale_;
  }

  inline SoftMaxOperatorTester& input_zero_point(uint8_t input_zero_point) {
    this->input_zero_point_ = input_zero_point;
    return *this;
  }

  inline uint8_t input_zero_point() const {
    return this->input_zero_point_;
  }

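  // Output quantization parameters for the QU8 path are hard-coded: a scale of
  // 1/256 with zero point 0 maps the full [0, 1) range of SoftMax outputs onto
  // uint8 values.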
  inline float output_scale() const {
    return 1.0f / 256.0f;
  }

  inline uint8_t output_zero_point() const {
    return 0;
  }

  inline SoftMaxOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

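  // Runs the QU8 (quantized uint8) SoftMax operator on random inputs and
  // compares its output, in quantized units, against a float reference.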
  void TestQU8() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);

    std::vector<uint8_t> input((batch_size() - 1) * input_stride() + channels());
    std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
    std::vector<float> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(u8rng));
      std::fill(output.begin(), output.end(), 0xA5);

      // Compute reference results.
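      // The reference value exp(x - max) / (sum_exp * output_scale()) is
      // expressed directly in quantized output units (zero point 0) and is
      // clamped to 255.0f, the largest value representable in uint8.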
      for (size_t i = 0; i < batch_size(); i++) {
        const int32_t max_input = *std::max_element(
          input.data() + i * input_stride(),
          input.data() + i * input_stride() + channels());
        float sum_exp = 0.0f;
        for (size_t c = 0; c < channels(); c++) {
          sum_exp +=
            std::exp((int32_t(input[i * input_stride() + c]) - max_input) *
                     input_scale());
        }
        for (size_t c = 0; c < channels(); c++) {
          output_ref[i * channels() + c] =
            std::exp((int32_t(input[i * input_stride() + c]) - max_input) *
                     input_scale()) /
            (sum_exp * output_scale());
          output_ref[i * channels() + c] = std::min(output_ref[i * channels() + c], 255.0f);
        }
      }

      // Create, setup, run, and destroy SoftMax operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t softmax_op = nullptr;

      ASSERT_EQ(xnn_status_success,
        xnn_create_softmax_nc_qu8(
          channels(), input_stride(), output_stride(),
          input_scale(),
          output_zero_point(), output_scale(),
          0, &softmax_op));
      ASSERT_NE(nullptr, softmax_op);

      // Smart pointer to automatically delete softmax_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_softmax_op(softmax_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_softmax_nc_qu8(
          softmax_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(softmax_op, nullptr /* thread pool */));

      // Verify results.
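      // Outputs are compared in quantized units; the 0.6 tolerance admits
      // slightly more than half a quantization step of rounding error in the
      // fixed-point implementation.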
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f);
        }
      }
    }
  }

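  // Runs the F32 SoftMax operator on random inputs and compares its output
  // against a double-precision reference within a relative tolerance.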
  void TestF32() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> output((batch_size() - 1) * output_stride() + channels());
    std::vector<double> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(f32rng));
      std::fill(output.begin(), output.end(), std::nanf(""));

      // Compute reference results.
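      // The reference is computed in double precision; subtracting the row
      // maximum before exponentiation is the standard numerically stable
      // SoftMax formulation.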
      for (size_t i = 0; i < batch_size(); i++) {
        const double max_input = *std::max_element(
          input.data() + i * input_stride(),
          input.data() + i * input_stride() + channels());
        double sum_exp = 0.0;
        for (size_t c = 0; c < channels(); c++) {
          sum_exp += std::exp(double(input[i * input_stride() + c]) - max_input);
        }
        for (size_t c = 0; c < channels(); c++) {
          output_ref[i * channels() + c] =
            std::exp(double(input[i * input_stride() + c]) - max_input) / sum_exp;
        }
      }

      // Create, setup, run, and destroy SoftMax operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t softmax_op = nullptr;

      ASSERT_EQ(xnn_status_success,
        xnn_create_softmax_nc_f32(
          channels(), input_stride(), output_stride(),
          0, &softmax_op));
      ASSERT_NE(nullptr, softmax_op);

      // Smart pointer to automatically delete softmax_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_softmax_op(softmax_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_softmax_nc_f32(
          softmax_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(softmax_op, nullptr /* thread pool */));

      // Verify results.
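      // Each output must match the double-precision reference to within a
      // relative error of 1.0e-4.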
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_NEAR(
            double(output[i * output_stride() + c]),
            output_ref[i * channels() + c],
            output_ref[i * channels() + c] * 1.0e-4);
        }
      }
    }
  }

 private:
  size_t batch_size_{1};
  size_t channels_{1};
  size_t input_stride_{0};
  size_t output_stride_{0};
  float input_scale_{0.176080093};
  uint8_t input_zero_point_{121};
  size_t iterations_{15};
};