// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <initializer_list>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>


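// Tester for XNNPACK binary elementwise operators (Add, Divide, Maximum,
// Minimum, Multiply, Subtract, SquaredDifference) with NumPy-style
// broadcasting. Both input shapes are padded to XNN_MAX_TENSOR_DIMS
// dimensions, non-broadcast dimensions are checked for compatibility, the
// operator is run, and its output is compared against a scalar float
// reference.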
class BinaryElementwiseOperatorTester {
 public:
  enum class OperationType {
    Unknown,
    Add,
    Divide,
    Maximum,
    Minimum,
    Multiply,
    Subtract,
    SquaredDifference,
  };

  inline BinaryElementwiseOperatorTester& input1_shape(std::initializer_list<size_t> input1_shape) {
    assert(input1_shape.size() <= XNN_MAX_TENSOR_DIMS);
    this->input1_shape_ = std::vector<size_t>(input1_shape);
    return *this;
  }

  inline const std::vector<size_t>& input1_shape() const {
    return this->input1_shape_;
  }

  inline size_t input1_dim(size_t i) const {
    return i < num_input1_dims() ? this->input1_shape_[i] : 1;
  }

  inline size_t num_input1_dims() const {
    return this->input1_shape_.size();
  }

  inline size_t num_input1_elements() const {
    return std::accumulate(
      this->input1_shape_.begin(), this->input1_shape_.end(), size_t(1), std::multiplies<size_t>());
  }

  inline BinaryElementwiseOperatorTester& input1_zero_point(int16_t input1_zero_point) {
    this->input1_zero_point_ = input1_zero_point;
    return *this;
  }

  inline int16_t input1_zero_point() const {
    return this->input1_zero_point_;
  }

  inline BinaryElementwiseOperatorTester& input1_scale(float input1_scale) {
    assert(std::isfinite(input1_scale));
    this->input1_scale_ = input1_scale;
    return *this;
  }

  inline float input1_scale() const {
    return this->input1_scale_;
  }

  inline BinaryElementwiseOperatorTester& input2_shape(std::initializer_list<size_t> input2_shape) {
    assert(input2_shape.size() <= XNN_MAX_TENSOR_DIMS);
    this->input2_shape_ = std::vector<size_t>(input2_shape);
    return *this;
  }

  inline const std::vector<size_t>& input2_shape() const {
    return this->input2_shape_;
  }

  inline size_t input2_dim(size_t i) const {
    return i < num_input2_dims() ? this->input2_shape_[i] : 1;
  }

  inline size_t num_input2_dims() const {
    return this->input2_shape_.size();
  }

  inline size_t num_input2_elements() const {
    return std::accumulate(
      this->input2_shape_.begin(), this->input2_shape_.end(), size_t(1), std::multiplies<size_t>());
  }

  inline BinaryElementwiseOperatorTester& input2_zero_point(int16_t input2_zero_point) {
    this->input2_zero_point_ = input2_zero_point;
    return *this;
  }

  inline int16_t input2_zero_point() const {
    return this->input2_zero_point_;
  }

  inline BinaryElementwiseOperatorTester& input2_scale(float input2_scale) {
    assert(std::isfinite(input2_scale));
    this->input2_scale_ = input2_scale;
    return *this;
  }

  inline float input2_scale() const {
    return this->input2_scale_;
  }

  inline BinaryElementwiseOperatorTester& output_zero_point(int16_t output_zero_point) {
    this->output_zero_point_ = output_zero_point;
    return *this;
  }

  inline int16_t output_zero_point() const {
    return this->output_zero_point_;
  }

  inline BinaryElementwiseOperatorTester& output_scale(float output_scale) {
    assert(std::isfinite(output_scale));
    this->output_scale_ = output_scale;
    return *this;
  }

  inline float output_scale() const {
    return this->output_scale_;
  }

  inline BinaryElementwiseOperatorTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline BinaryElementwiseOperatorTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline BinaryElementwiseOperatorTester& operation_type(OperationType operation_type) {
    this->operation_type_ = operation_type;
    return *this;
  }

  inline OperationType operation_type() const {
    return this->operation_type_;
  }

  inline BinaryElementwiseOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

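  // Scalar reference implementation: computes one output element from a pair
  // of (already dequantized) float input values for the selected operation.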
  float Compute(float a, float b) const {
    switch (operation_type()) {
      case OperationType::Add:
        return a + b;
      case OperationType::Divide:
        return a / b;
      case OperationType::Maximum:
        return std::max<float>(a, b);
      case OperationType::Minimum:
        return std::min<float>(a, b);
      case OperationType::Multiply:
        return a * b;
      case OperationType::Subtract:
        return a - b;
      case OperationType::SquaredDifference:
        return (a - b) * (a - b);
      default:
        return std::nanf("");
    }
  }

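  // Tests the signed 8-bit quantized (QS8) operator path. Inputs are random
  // int8 values; the reference result is computed in float from the
  // dequantized inputs, then requantized with the output scale and zero
  // point. Only Add currently has a QS8 path in this tester.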
  void TestQS8() const {
    ASSERT_NE(operation_type(), OperationType::Unknown);
    ASSERT_GE(input1_zero_point(), std::numeric_limits<int8_t>::min());
    ASSERT_LE(input1_zero_point(), std::numeric_limits<int8_t>::max());
    ASSERT_GE(input2_zero_point(), std::numeric_limits<int8_t>::min());
    ASSERT_LE(input2_zero_point(), std::numeric_limits<int8_t>::max());
    ASSERT_GE(output_zero_point(), std::numeric_limits<int8_t>::min());
    ASSERT_LE(output_zero_point(), std::numeric_limits<int8_t>::max());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto i8rng = std::bind(
      std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()), std::ref(rng));

    // Compute generalized shapes.
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
    std::fill(input1_dims.begin(), input1_dims.end(), 1);
    std::fill(input2_dims.begin(), input2_dims.end(), 1);
    std::fill(output_dims.begin(), output_dims.end(), 1);
    std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
    std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
    for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
      if (input1_dims[i] != 1 && input2_dims[i] != 1) {
        ASSERT_EQ(input1_dims[i], input2_dims[i]);
      }
      output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
    }
    const size_t num_output_elements =
      std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());

    // Compute generalized strides.
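    // A dimension of size 1 is broadcast: its stride is set to 0, so the same
    // element is read for every index along that dimension.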
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
    size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
    for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
      input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
      input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
      output_strides[i - 1] = output_stride;
      input1_stride *= input1_dims[i - 1];
      input2_stride *= input2_dims[i - 1];
      output_stride *= output_dims[i - 1];
    }

    std::vector<int8_t> input1(XNN_EXTRA_BYTES / sizeof(int8_t) + num_input1_elements());
    std::vector<int8_t> input2(XNN_EXTRA_BYTES / sizeof(int8_t) + num_input2_elements());
    std::vector<int8_t> output(num_output_elements);
    std::vector<float> output_ref(num_output_elements);
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input1.begin(), input1.end(), std::ref(i8rng));
      std::generate(input2.begin(), input2.end(), std::ref(i8rng));
      std::fill(output.begin(), output.end(), 0xAA);

      // Compute reference results.
      for (size_t i = 0; i < output_dims[0]; i++) {
        for (size_t j = 0; j < output_dims[1]; j++) {
          for (size_t k = 0; k < output_dims[2]; k++) {
            for (size_t l = 0; l < output_dims[3]; l++) {
              for (size_t m = 0; m < output_dims[4]; m++) {
                for (size_t n = 0; n < output_dims[5]; n++) {
                  output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
                    input1_scale() * (int32_t(input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]]) - input1_zero_point()),
                    input2_scale() * (int32_t(input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]) - input2_zero_point())) /
                    output_scale() + float(output_zero_point());
                }
              }
            }
          }
        }
      }

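      // Clamp the reference results to the signed quantized output range
      // [qmin() - 128, qmax() - 128].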
      for (float& output_value : output_ref) {
        output_value = std::min(std::max(output_value, float(int8_t(qmin() - 0x80))), float(int8_t(qmax() - 0x80)));
      }

      // Create, setup, run, and destroy a binary elementwise operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t binary_elementwise_op = nullptr;
      xnn_status status = xnn_status_unsupported_parameter;
      switch (operation_type()) {
        case OperationType::Add:
          status = xnn_create_add_nd_qs8(
            input1_zero_point(), input1_scale(),
            input2_zero_point(), input2_scale(),
            output_zero_point(), output_scale(),
            int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
            0, &binary_elementwise_op);
          break;
        default:
          FAIL() << "Unsupported operation type";
      }
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, binary_elementwise_op);

      // Smart pointer to automatically delete binary_elementwise_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);

      switch (operation_type()) {
        case OperationType::Add:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_add_nd_qs8(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        default:
          FAIL() << "Unsupported operation type";
      }

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < output_dims[0]; i++) {
        for (size_t j = 0; j < output_dims[1]; j++) {
          for (size_t k = 0; k < output_dims[2]; k++) {
            for (size_t l = 0; l < output_dims[3]; l++) {
              for (size_t m = 0; m < output_dims[4]; m++) {
                for (size_t n = 0; n < output_dims[5]; n++) {
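                  // A tolerance of 0.6 quantization steps absorbs rounding
                  // differences between the float reference and the
                  // fixed-point kernel.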
                  const size_t index =
                    i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
                  ASSERT_NEAR(float(output[index]), output_ref[index], 0.6f)
                    << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")"
                    << ", input1 zero point = " << input1_zero_point() << ", input1 scale = " << input1_scale()
                    << ", input2 zero point = " << input2_zero_point() << ", input2 scale = " << input2_scale()
                    << ", output zero point = " << output_zero_point() << ", output scale = " << output_scale();
                }
              }
            }
          }
        }
      }
    }
  }

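  // Tests the half-precision (F16) operator path. Inputs are random values in
  // [0, 1] converted to IEEE half precision; only Add and Multiply have F16
  // entry points wired up in this tester.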
  void TestF16() const {
    ASSERT_NE(operation_type(), OperationType::Unknown);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(0.0f, 1.0f), rng);
    auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);

    // Compute generalized shapes.
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
    std::fill(input1_dims.begin(), input1_dims.end(), 1);
    std::fill(input2_dims.begin(), input2_dims.end(), 1);
    std::fill(output_dims.begin(), output_dims.end(), 1);
    std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
    std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
    for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
      if (input1_dims[i] != 1 && input2_dims[i] != 1) {
        ASSERT_EQ(input1_dims[i], input2_dims[i]);
      }
      output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
    }
    const size_t num_output_elements =
      std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());

    // Compute generalized strides.
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
    size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
    for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
      input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
      input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
      output_strides[i - 1] = output_stride;
      input1_stride *= input1_dims[i - 1];
      input2_stride *= input2_dims[i - 1];
      output_stride *= output_dims[i - 1];
    }

    std::vector<uint16_t> input1(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input1_elements());
    std::vector<uint16_t> input2(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input2_elements());
    std::vector<uint16_t> output(num_output_elements);
    std::vector<float> output_ref(num_output_elements);
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input1.begin(), input1.end(), std::ref(f16rng));
      std::generate(input2.begin(), input2.end(), std::ref(f16rng));
      std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);

      // Compute reference results.
      for (size_t i = 0; i < output_dims[0]; i++) {
        for (size_t j = 0; j < output_dims[1]; j++) {
          for (size_t k = 0; k < output_dims[2]; k++) {
            for (size_t l = 0; l < output_dims[3]; l++) {
              for (size_t m = 0; m < output_dims[4]; m++) {
                for (size_t n = 0; n < output_dims[5]; n++) {
                  output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
                    fp16_ieee_to_fp32_value(input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]]),
                    fp16_ieee_to_fp32_value(input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]));
                }
              }
            }
          }
        }
      }

      // Compute clamping parameters.
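      // qmin()/qmax() carve the accumulated output range into 255 steps; the
      // resulting bounds are rounded through half precision so the reference
      // and the kernel clamp to identical values, and a degenerate range
      // disables clamping entirely.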
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_range = accumulated_max - accumulated_min;
      const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
      const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
      const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min;
      const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max;

      for (float& output_value : output_ref) {
        output_value = std::min(std::max(output_value, output_min), output_max);
      }

      // Create, setup, run, and destroy a binary elementwise operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t binary_elementwise_op = nullptr;
      xnn_status status = xnn_status_unsupported_parameter;
      switch (operation_type()) {
        case OperationType::Add:
          status = xnn_create_add_nd_f16(output_min, output_max, 0, &binary_elementwise_op);
          break;
        case OperationType::Multiply:
          status = xnn_create_multiply_nd_f16(output_min, output_max, 0, &binary_elementwise_op);
          break;
        default:
          FAIL() << "Unsupported operation type";
      }
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, binary_elementwise_op);

      // Smart pointer to automatically delete binary_elementwise_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);

      switch (operation_type()) {
        case OperationType::Add:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_add_nd_f16(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Multiply:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_multiply_nd_f16(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        default:
          FAIL() << "Unsupported operation type";
      }

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < output_dims[0]; i++) {
        for (size_t j = 0; j < output_dims[1]; j++) {
          for (size_t k = 0; k < output_dims[2]; k++) {
            for (size_t l = 0; l < output_dims[3]; l++) {
              for (size_t m = 0; m < output_dims[4]; m++) {
                for (size_t n = 0; n < output_dims[5]; n++) {
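                  // Half precision has only ~11 significand bits, so allow a
                  // mixed absolute (1e-4) / relative (1%) tolerance.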
                  const size_t index =
                    i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
                  ASSERT_NEAR(fp16_ieee_to_fp32_value(output[index]), output_ref[index], std::max(1.0e-4f, std::abs(output_ref[index]) * 1.0e-2f))
                    << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")";
                }
              }
            }
          }
        }
      }
    }
  }

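  // Tests the single-precision (F32) operator path for all supported
  // operation types. Inputs are random values in [0.01, 1], keeping Divide
  // away from division by zero.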
  void TestF32() const {
    ASSERT_NE(operation_type(), OperationType::Unknown);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(0.01f, 1.0f), rng);

    // Compute generalized shapes.
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
    std::fill(input1_dims.begin(), input1_dims.end(), 1);
    std::fill(input2_dims.begin(), input2_dims.end(), 1);
    std::fill(output_dims.begin(), output_dims.end(), 1);
    std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
    std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
    for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
      if (input1_dims[i] != 1 && input2_dims[i] != 1) {
        ASSERT_EQ(input1_dims[i], input2_dims[i]);
      }
      output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
    }
    const size_t num_output_elements =
      std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());

    // Compute generalized strides.
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
    size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
    for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
      input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
      input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
      output_strides[i - 1] = output_stride;
      input1_stride *= input1_dims[i - 1];
      input2_stride *= input2_dims[i - 1];
      output_stride *= output_dims[i - 1];
    }

    std::vector<float> input1(XNN_EXTRA_BYTES / sizeof(float) + num_input1_elements());
    std::vector<float> input2(XNN_EXTRA_BYTES / sizeof(float) + num_input2_elements());
    std::vector<float> output(num_output_elements);
    std::vector<float> output_ref(num_output_elements);
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input1.begin(), input1.end(), std::ref(f32rng));
      std::generate(input2.begin(), input2.end(), std::ref(f32rng));
      std::fill(output.begin(), output.end(), nanf(""));

      // Compute reference results.
      for (size_t i = 0; i < output_dims[0]; i++) {
        for (size_t j = 0; j < output_dims[1]; j++) {
          for (size_t k = 0; k < output_dims[2]; k++) {
            for (size_t l = 0; l < output_dims[3]; l++) {
              for (size_t m = 0; m < output_dims[4]; m++) {
                for (size_t n = 0; n < output_dims[5]; n++) {
                  output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
                    input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]],
                    input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]);
                }
              }
            }
          }
        }
      }
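      // Compute clamping parameters: qmin()/qmax() carve the accumulated
      // output range into 255 steps; with a single output element the range
      // is degenerate, so clamping is disabled.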
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_range = accumulated_max - accumulated_min;
      const float output_min = num_output_elements == 1 ?
        -std::numeric_limits<float>::infinity() : accumulated_min + accumulated_range / 255.0f * float(qmin());
      const float output_max = num_output_elements == 1 ?
        +std::numeric_limits<float>::infinity() : accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
      for (float& output_value : output_ref) {
        output_value = std::min(std::max(output_value, output_min), output_max);
      }

      // Create, setup, run, and destroy a binary elementwise operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t binary_elementwise_op = nullptr;

      switch (operation_type()) {
        case OperationType::Add:
          ASSERT_EQ(xnn_status_success,
            xnn_create_add_nd_f32(
              output_min, output_max,
              0, &binary_elementwise_op));
          break;
        case OperationType::Divide:
          ASSERT_EQ(xnn_status_success,
            xnn_create_divide_nd_f32(
              output_min, output_max,
              0, &binary_elementwise_op));
          break;
        case OperationType::Maximum:
          ASSERT_EQ(xnn_status_success,
            xnn_create_maximum_nd_f32(
              0, &binary_elementwise_op));
          break;
        case OperationType::Minimum:
          ASSERT_EQ(xnn_status_success,
            xnn_create_minimum_nd_f32(
              0, &binary_elementwise_op));
          break;
        case OperationType::Multiply:
          ASSERT_EQ(xnn_status_success,
            xnn_create_multiply_nd_f32(
              output_min, output_max,
              0, &binary_elementwise_op));
          break;
        case OperationType::Subtract:
          ASSERT_EQ(xnn_status_success,
            xnn_create_subtract_nd_f32(
              output_min, output_max,
              0, &binary_elementwise_op));
          break;
        case OperationType::SquaredDifference:
          ASSERT_EQ(xnn_status_success,
            xnn_create_squared_difference_nd_f32(
              0, &binary_elementwise_op));
          break;
        default:
          FAIL() << "Unsupported operation type";
      }
      ASSERT_NE(nullptr, binary_elementwise_op);

      // Smart pointer to automatically delete binary_elementwise_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);

      switch (operation_type()) {
        case OperationType::Add:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_add_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Divide:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_divide_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Maximum:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_maximum_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Minimum:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_minimum_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Multiply:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_multiply_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Subtract:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_subtract_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::SquaredDifference:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_squared_difference_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        default:
          FAIL() << "Unsupported operation type";
      }

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < output_dims[0]; i++) {
        for (size_t j = 0; j < output_dims[1]; j++) {
          for (size_t k = 0; k < output_dims[2]; k++) {
            for (size_t l = 0; l < output_dims[3]; l++) {
              for (size_t m = 0; m < output_dims[4]; m++) {
                for (size_t n = 0; n < output_dims[5]; n++) {
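                  // Allow a relative tolerance of 1.0e-6 for floating-point
                  // rounding differences in vectorized kernels.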
                  const size_t index =
                    i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
                  ASSERT_NEAR(output[index], output_ref[index], 1.0e-6f * std::abs(output_ref[index]))
                    << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")";
                }
              }
            }
          }
        }
      }
    }
  }

 private:
  std::vector<size_t> input1_shape_;
  std::vector<size_t> input2_shape_;
  int16_t input1_zero_point_{0};
  float input1_scale_{1.0f};
  int16_t input2_zero_point_{0};
  float input2_scale_{1.0f};
  int16_t output_zero_point_{0};
  float output_scale_{1.0f};
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  OperationType operation_type_{OperationType::Unknown};
  size_t iterations_{3};
};
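
// Example usage (a minimal sketch; the TEST name is illustrative, and the
// tester is driven exactly as in the XNNPACK operator tests):
//
//   TEST(ADD_ND_QS8, example_broadcast) {
//     BinaryElementwiseOperatorTester()
//       .operation_type(BinaryElementwiseOperatorTester::OperationType::Add)
//       .input1_shape({2, 3, 4})
//       .input2_shape({4})  // broadcast across the two outer dimensions
//       .TestQS8();
//   }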