// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <initializer_list>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>

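// Tester for XNNPACK's N-dimensional binary elementwise operators (Add,
// Divide, Maximum, Minimum, Multiply, Subtract, SquaredDifference). It runs
// an operator on randomized, possibly broadcast inputs and checks the output
// against a scalar reference. A minimal usage sketch (the shapes below are
// arbitrary examples, not part of the API):
//
//   BinaryElementwiseOperatorTester()
//     .operation_type(BinaryElementwiseOperatorTester::OperationType::Add)
//     .input1_shape({2, 3, 4})
//     .input2_shape({3, 4})  // broadcast against input1's leading dimension
//     .TestF32();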
class BinaryElementwiseOperatorTester {
 public:
  enum class OperationType {
    Unknown,
    Add,
    Divide,
    Maximum,
    Minimum,
    Multiply,
    Subtract,
    SquaredDifference,
  };

  inline BinaryElementwiseOperatorTester& input1_shape(std::initializer_list<size_t> input1_shape) {
    assert(input1_shape.size() <= XNN_MAX_TENSOR_DIMS);
    this->input1_shape_ = std::vector<size_t>(input1_shape);
    return *this;
  }

  inline const std::vector<size_t>& input1_shape() const {
    return this->input1_shape_;
  }

  inline size_t input1_dim(size_t i) const {
    return i < num_input1_dims() ? this->input1_shape_[i] : 1;
  }

  inline size_t num_input1_dims() const {
    return this->input1_shape_.size();
  }

  inline size_t num_input1_elements() const {
    return std::accumulate(
      this->input1_shape_.begin(), this->input1_shape_.end(), size_t(1), std::multiplies<size_t>());
  }

  inline BinaryElementwiseOperatorTester& input2_shape(std::initializer_list<size_t> input2_shape) {
    assert(input2_shape.size() <= XNN_MAX_TENSOR_DIMS);
    this->input2_shape_ = std::vector<size_t>(input2_shape);
    return *this;
  }

  inline const std::vector<size_t>& input2_shape() const {
    return this->input2_shape_;
  }

  inline size_t input2_dim(size_t i) const {
    return i < num_input2_dims() ? this->input2_shape_[i] : 1;
  }

  inline size_t num_input2_dims() const {
    return this->input2_shape_.size();
  }

  inline size_t num_input2_elements() const {
    return std::accumulate(
      this->input2_shape_.begin(), this->input2_shape_.end(), size_t(1), std::multiplies<size_t>());
  }

  inline BinaryElementwiseOperatorTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline BinaryElementwiseOperatorTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline BinaryElementwiseOperatorTester& operation_type(OperationType operation_type) {
    this->operation_type_ = operation_type;
    return *this;
  }

  inline OperationType operation_type() const {
    return this->operation_type_;
  }

  inline BinaryElementwiseOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

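  // Scalar reference implementation of the operation under test; returns NaN
  // for OperationType::Unknown.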
  float Compute(float a, float b) const {
    switch (operation_type()) {
      case OperationType::Add:
        return a + b;
      case OperationType::Divide:
        return a / b;
      case OperationType::Maximum:
        return std::max<float>(a, b);
      case OperationType::Minimum:
        return std::min<float>(a, b);
      case OperationType::Multiply:
        return a * b;
      case OperationType::Subtract:
        return a - b;
      case OperationType::SquaredDifference:
        return (a - b) * (a - b);
      default:
        return std::nanf("");
    }
  }

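  // Half-precision test path. Only Add is wired up for fp16 below; the test
  // skips (rather than fails) when the hardware lacks fp16 support.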
  void TestF16() const {
    ASSERT_NE(operation_type(), OperationType::Unknown);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(0.1f, 1.0f), rng);
    auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);

    // Compute generalized shapes.
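    // Shapes are right-aligned into XNN_MAX_TENSOR_DIMS slots and left-padded
    // with 1s; per NumPy-style broadcasting, each output dimension is the
    // maximum of the two input dimensions, which must otherwise match.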
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
    std::fill(input1_dims.begin(), input1_dims.end(), 1);
    std::fill(input2_dims.begin(), input2_dims.end(), 1);
    std::fill(output_dims.begin(), output_dims.end(), 1);
    std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
    std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
    for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
      if (input1_dims[i] != 1 && input2_dims[i] != 1) {
        ASSERT_EQ(input1_dims[i], input2_dims[i]);
      }
      output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
    }
    const size_t num_output_elements =
      std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());

    // Compute generalized strides.
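    // A broadcast (size-1) dimension gets stride 0, so the same indexing
    // expression rereads its single element while dense dimensions advance
    // normally.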
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
    size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
    for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
      input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
      input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
      output_strides[i - 1] = output_stride;
      input1_stride *= input1_dims[i - 1];
      input2_stride *= input2_dims[i - 1];
      output_stride *= output_dims[i - 1];
    }

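    // The inputs are over-allocated by XNN_EXTRA_BYTES so vectorized kernels
    // may safely read a little past the last element.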
    std::vector<uint16_t> input1(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input1_elements());
    std::vector<uint16_t> input2(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input2_elements());
    std::vector<uint16_t> output(num_output_elements);
    std::vector<float> output_ref(num_output_elements);
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input1.begin(), input1.end(), std::ref(f16rng));
      std::generate(input2.begin(), input2.end(), std::ref(f16rng));
      std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);

      // Compute reference results.
      for (size_t i = 0; i < output_dims[0]; i++) {
        for (size_t j = 0; j < output_dims[1]; j++) {
          for (size_t k = 0; k < output_dims[2]; k++) {
            for (size_t l = 0; l < output_dims[3]; l++) {
              for (size_t m = 0; m < output_dims[4]; m++) {
                for (size_t n = 0; n < output_dims[5]; n++) {
                  output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
                    fp16_ieee_to_fp32_value(input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]]),
                    fp16_ieee_to_fp32_value(input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]));
                }
              }
            }
          }
        }
      }

      // Compute clamping parameters.
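      // qmin/qmax select a [qmin/255, qmax/255] slice of the observed output
      // range; both bounds are rounded through fp16 so they are exactly
      // representable, and a collapsed range disables clamping entirely.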
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_range = accumulated_max - accumulated_min;
      const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
      const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
      const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min;
      const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max;

      for (float& output_value : output_ref) {
        output_value = std::min(std::max(output_value, output_min), output_max);
      }

      // Create, setup, run, and destroy a binary elementwise operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t binary_elementwise_op = nullptr;
      xnn_status status = xnn_status_unsupported_parameter;
      switch (operation_type()) {
        case OperationType::Add:
          status = xnn_create_add_nd_f16(output_min, output_max, 0, &binary_elementwise_op);
          break;
        default:
          FAIL() << "Unsupported operation type";
      }
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, binary_elementwise_op);

      // Smart pointer to automatically delete binary_elementwise_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);

      switch (operation_type()) {
        case OperationType::Add:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_add_nd_f16(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        default:
          FAIL() << "Unsupported operation type";
      }

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));

      // Verify results.
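      // Allow 1% relative error to absorb half-precision rounding.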
      for (size_t i = 0; i < output_dims[0]; i++) {
        for (size_t j = 0; j < output_dims[1]; j++) {
          for (size_t k = 0; k < output_dims[2]; k++) {
            for (size_t l = 0; l < output_dims[3]; l++) {
              for (size_t m = 0; m < output_dims[4]; m++) {
                for (size_t n = 0; n < output_dims[5]; n++) {
                  const size_t index =
                    i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
                  ASSERT_NEAR(fp16_ieee_to_fp32_value(output[index]), output_ref[index], 1.0e-2f * std::abs(output_ref[index]))
                    << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")";
                }
              }
            }
          }
        }
      }
    }
  }
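
  // Single-precision path: same broadcasting and clamping flow as TestF16,
  // but covering every operation type and verified to 1e-6 relative error.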
  void TestF32() const {
    ASSERT_NE(operation_type(), OperationType::Unknown);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(0.01f, 1.0f), rng);

    // Compute generalized shapes.
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
    std::fill(input1_dims.begin(), input1_dims.end(), 1);
    std::fill(input2_dims.begin(), input2_dims.end(), 1);
    std::fill(output_dims.begin(), output_dims.end(), 1);
    std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
    std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
    for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
      if (input1_dims[i] != 1 && input2_dims[i] != 1) {
        ASSERT_EQ(input1_dims[i], input2_dims[i]);
      }
      output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
    }
    const size_t num_output_elements =
      std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());

    // Compute generalized strides.
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
    std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
    size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
    for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
      input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
      input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
      output_strides[i - 1] = output_stride;
      input1_stride *= input1_dims[i - 1];
      input2_stride *= input2_dims[i - 1];
      output_stride *= output_dims[i - 1];
    }

    std::vector<float> input1(XNN_EXTRA_BYTES / sizeof(float) + num_input1_elements());
    std::vector<float> input2(XNN_EXTRA_BYTES / sizeof(float) + num_input2_elements());
    std::vector<float> output(num_output_elements);
    std::vector<float> output_ref(num_output_elements);
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input1.begin(), input1.end(), std::ref(f32rng));
      std::generate(input2.begin(), input2.end(), std::ref(f32rng));
      std::fill(output.begin(), output.end(), nanf(""));

      // Compute reference results.
      for (size_t i = 0; i < output_dims[0]; i++) {
        for (size_t j = 0; j < output_dims[1]; j++) {
          for (size_t k = 0; k < output_dims[2]; k++) {
            for (size_t l = 0; l < output_dims[3]; l++) {
              for (size_t m = 0; m < output_dims[4]; m++) {
                for (size_t n = 0; n < output_dims[5]; n++) {
                  output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
                    input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]],
                    input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]);
                }
              }
            }
          }
        }
      }
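
      // Compute clamping parameters. A single-element output collapses the
      // observed range to zero, so clamping is disabled in that case.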
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_range = accumulated_max - accumulated_min;
      const float output_min = num_output_elements == 1 ?
        -std::numeric_limits<float>::infinity() : accumulated_min + accumulated_range / 255.0f * float(qmin());
      const float output_max = num_output_elements == 1 ?
        +std::numeric_limits<float>::infinity() : accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
      for (float& output_value : output_ref) {
        output_value = std::min(std::max(output_value, output_min), output_max);
      }

      // Create, setup, run, and destroy a binary elementwise operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t binary_elementwise_op = nullptr;

      switch (operation_type()) {
        case OperationType::Add:
          ASSERT_EQ(xnn_status_success,
            xnn_create_add_nd_f32(
              output_min, output_max,
              0, &binary_elementwise_op));
          break;
        case OperationType::Divide:
          ASSERT_EQ(xnn_status_success,
            xnn_create_divide_nd_f32(
              output_min, output_max,
              0, &binary_elementwise_op));
          break;
        case OperationType::Maximum:
          ASSERT_EQ(xnn_status_success,
            xnn_create_maximum_nd_f32(
              0, &binary_elementwise_op));
          break;
        case OperationType::Minimum:
          ASSERT_EQ(xnn_status_success,
            xnn_create_minimum_nd_f32(
              0, &binary_elementwise_op));
          break;
        case OperationType::Multiply:
          ASSERT_EQ(xnn_status_success,
            xnn_create_multiply_nd_f32(
              output_min, output_max,
              0, &binary_elementwise_op));
          break;
        case OperationType::Subtract:
          ASSERT_EQ(xnn_status_success,
            xnn_create_subtract_nd_f32(
              output_min, output_max,
              0, &binary_elementwise_op));
          break;
        case OperationType::SquaredDifference:
          ASSERT_EQ(xnn_status_success,
            xnn_create_squared_difference_nd_f32(
              0, &binary_elementwise_op));
          break;
        default:
          FAIL() << "Unsupported operation type";
      }
      ASSERT_NE(nullptr, binary_elementwise_op);

      // Smart pointer to automatically delete binary_elementwise_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);

      switch (operation_type()) {
        case OperationType::Add:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_add_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Divide:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_divide_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Maximum:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_maximum_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Minimum:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_minimum_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Multiply:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_multiply_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::Subtract:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_subtract_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        case OperationType::SquaredDifference:
          ASSERT_EQ(xnn_status_success,
            xnn_setup_squared_difference_nd_f32(
              binary_elementwise_op,
              num_input1_dims(),
              input1_shape().data(),
              num_input2_dims(),
              input2_shape().data(),
              input1.data(), input2.data(), output.data(),
              nullptr /* thread pool */));
          break;
        default:
          FAIL() << "Unsupported operation type";
      }

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < output_dims[0]; i++) {
        for (size_t j = 0; j < output_dims[1]; j++) {
          for (size_t k = 0; k < output_dims[2]; k++) {
            for (size_t l = 0; l < output_dims[3]; l++) {
              for (size_t m = 0; m < output_dims[4]; m++) {
                for (size_t n = 0; n < output_dims[5]; n++) {
                  const size_t index =
                    i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
                  ASSERT_NEAR(output[index], output_ref[index], 1.0e-6f * std::abs(output_ref[index]))
                    << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")";
                }
              }
            }
          }
        }
      }
    }
  }

 private:
  std::vector<size_t> input1_shape_;
  std::vector<size_t> input2_shape_;
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  OperationType operation_type_{OperationType::Unknown};
  size_t iterations_{3};
};