// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <initializer_list>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

#include <xnnpack.h>

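// Tester for XNNPACK binary element-wise operators (Add, Maximum, Minimum,
// Multiply, Subtract) in F32 precision with NumPy-style broadcasting.
// qmin()/qmax() control how tightly the reference output is clamped, and
// iterations() sets the number of randomized trials.
//
// Illustrative usage (a sketch; shapes and operation are arbitrary):
//   BinaryElementwiseOperatorTester()
//     .operation_type(BinaryElementwiseOperatorTester::OperationType::Add)
//     .input1_shape({3, 1, 5})
//     .input2_shape({4, 5})
//     .TestF32();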
class BinaryElementwiseOperatorTester {
 public:
  enum class OperationType {
    Unknown,
    Add,
    Maximum,
    Minimum,
    Multiply,
    Subtract,
  };

  inline BinaryElementwiseOperatorTester& input1_shape(std::initializer_list<size_t> input1_shape) {
    assert(input1_shape.size() <= XNN_MAX_TENSOR_DIMS);
    this->input1_shape_ = std::vector<size_t>(input1_shape);
    return *this;
  }

  inline const std::vector<size_t>& input1_shape() const {
    return this->input1_shape_;
  }

  inline size_t input1_dim(size_t i) const {
    return i < num_input1_dims() ? this->input1_shape_[i] : 1;
  }

  inline size_t num_input1_dims() const {
    return this->input1_shape_.size();
  }

  inline size_t num_input1_elements() const {
    return std::accumulate(
      this->input1_shape_.begin(), this->input1_shape_.end(), size_t(1), std::multiplies<size_t>());
  }

  inline BinaryElementwiseOperatorTester& input2_shape(std::initializer_list<size_t> input2_shape) {
    assert(input2_shape.size() <= XNN_MAX_TENSOR_DIMS);
    this->input2_shape_ = std::vector<size_t>(input2_shape);
    return *this;
  }

  inline const std::vector<size_t>& input2_shape() const {
    return this->input2_shape_;
  }

  inline size_t input2_dim(size_t i) const {
    return i < num_input2_dims() ? this->input2_shape_[i] : 1;
  }

  inline size_t num_input2_dims() const {
    return this->input2_shape_.size();
  }

  inline size_t num_input2_elements() const {
    return std::accumulate(
      this->input2_shape_.begin(), this->input2_shape_.end(), size_t(1), std::multiplies<size_t>());
  }

  inline BinaryElementwiseOperatorTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline BinaryElementwiseOperatorTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline BinaryElementwiseOperatorTester& operation_type(OperationType operation_type) {
    this->operation_type_ = operation_type;
    return *this;
  }

  inline OperationType operation_type() const {
    return this->operation_type_;
  }

  inline BinaryElementwiseOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

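  // Scalar reference implementation of the selected element-wise operation.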
  float Compute(float a, float b) const {
    switch (operation_type()) {
      case OperationType::Add:
        return a + b;
      case OperationType::Maximum:
        return std::max<float>(a, b);
      case OperationType::Minimum:
        return std::min<float>(a, b);
      case OperationType::Multiply:
        return a * b;
      case OperationType::Subtract:
        return a - b;
      default:
        return std::nanf("");
    }
  }

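  // Runs the full F32 test: builds broadcast-compatible shapes for the two
  // inputs, computes a clamped reference result, then creates, sets up, and
  // runs the corresponding XNNPACK operator and compares its output against
  // the reference with a relative tolerance.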
  void TestF32() const {
    ASSERT_NE(operation_type(), OperationType::Unknown);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(0.0f, 1.0f), rng);

    // Compute generalized shapes.
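    // Each input shape is right-aligned into an XNN_MAX_TENSOR_DIMS-element
    // array padded with leading 1s; along every dimension the two sizes must
    // either match or one of them must be 1 (NumPy-style broadcasting), and
    // the output takes the larger of the two.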
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
    std::fill(input1_dims.begin(), input1_dims.end(), 1);
    std::fill(input2_dims.begin(), input2_dims.end(), 1);
    std::fill(output_dims.begin(), output_dims.end(), 1);
    std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
    std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
    for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
      if (input1_dims[i] != 1 && input2_dims[i] != 1) {
        ASSERT_EQ(input1_dims[i], input2_dims[i]);
      }
      output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
    }
    const size_t num_output_elements =
      std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());

    // Compute generalized strides.
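    // Strides are accumulated from the innermost to the outermost dimension;
    // a broadcast dimension (size 1) gets stride 0 so the same element is
    // reused across that dimension of the output.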
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
    size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
    for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
      input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
      input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
      output_strides[i - 1] = output_stride;
      input1_stride *= input1_dims[i - 1];
      input2_stride *= input2_dims[i - 1];
      output_stride *= output_dims[i - 1];
    }

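    // Input buffers are over-allocated by XNN_EXTRA_BYTES because XNNPACK
    // micro-kernels may read (but never use) a few bytes past the last element.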
    std::vector<float> input1(XNN_EXTRA_BYTES / sizeof(float) + num_input1_elements());
    std::vector<float> input2(XNN_EXTRA_BYTES / sizeof(float) + num_input2_elements());
    std::vector<float> output(num_output_elements);
    std::vector<float> output_ref(num_output_elements);
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input1.begin(), input1.end(), std::ref(f32rng));
      std::generate(input2.begin(), input2.end(), std::ref(f32rng));
      std::fill(output.begin(), output.end(), nanf(""));

      // Compute reference results.
      for (size_t i = 0; i < output_dims[0]; i++) {
        for (size_t j = 0; j < output_dims[1]; j++) {
          for (size_t k = 0; k < output_dims[2]; k++) {
            for (size_t l = 0; l < output_dims[3]; l++) {
              for (size_t m = 0; m < output_dims[4]; m++) {
                for (size_t n = 0; n < output_dims[5]; n++) {
                  output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
                    input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]],
                    input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]);
                }
              }
            }
          }
        }
      }
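      // Map [qmin, qmax] from the [0, 255] range onto the observed output range
      // to derive the clamping bounds; a single-element output is left unclamped.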
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_range = accumulated_max - accumulated_min;
      const float output_min = num_output_elements == 1 ?
        -std::numeric_limits<float>::infinity() : accumulated_min + accumulated_range / 255.0f * float(qmin());
      const float output_max = num_output_elements == 1 ?
        +std::numeric_limits<float>::infinity() : accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
      for (float& output_value : output_ref) {
        output_value = std::min(std::max(output_value, output_min), output_max);
      }

      // Create, setup, run, and destroy a binary elementwise operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t binary_elementwise_op = nullptr;

      switch (operation_type()) {
        case OperationType::Add:
          ASSERT_EQ(xnn_status_success,
            xnn_create_add_nd_f32(
              output_min, output_max,
              0, &binary_elementwise_op));
          break;
        case OperationType::Maximum:
          ASSERT_EQ(xnn_status_success,
            xnn_create_maximum_nd_f32(
              0, &binary_elementwise_op));
          break;
        case OperationType::Minimum:
          ASSERT_EQ(xnn_status_success,
            xnn_create_minimum_nd_f32(
              0, &binary_elementwise_op));
          break;
        case OperationType::Multiply:
          ASSERT_EQ(xnn_status_success,
            xnn_create_multiply_nd_f32(
              output_min, output_max,
              0, &binary_elementwise_op));
          break;
        case OperationType::Subtract:
          ASSERT_EQ(xnn_status_success,
            xnn_create_subtract_nd_f32(
              output_min, output_max,
              0, &binary_elementwise_op));
          break;
        default:
          FAIL() << "Unsupported operation type";
      }
      ASSERT_NE(nullptr, binary_elementwise_op);

      // Smart pointer to automatically delete binary_elementwise_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);

      switch (operation_type()) {
        case OperationType::Add:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_add_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Maximum:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_maximum_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Minimum:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_minimum_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Multiply:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_multiply_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Subtract:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_subtract_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        default:
          FAIL() << "Unsupported operation type";
      }

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < output_dims[0]; i++) {
        for (size_t j = 0; j < output_dims[1]; j++) {
          for (size_t k = 0; k < output_dims[2]; k++) {
            for (size_t l = 0; l < output_dims[3]; l++) {
              for (size_t m = 0; m < output_dims[4]; m++) {
                for (size_t n = 0; n < output_dims[5]; n++) {
                  const size_t index =
                    i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
                  ASSERT_NEAR(output[index], output_ref[index], 1.0e-6f * std::abs(output_ref[index]))
                    << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")";
                }
              }
            }
          }
        }
      }
    }
  }

 private:
  std::vector<size_t> input1_shape_;
  std::vector<size_t> input2_shape_;
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  OperationType operation_type_{OperationType::Unknown};
  size_t iterations_{3};
};