blob: 16b8a5245bac635d629c48c07b23de21292cc374 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <initializer_list>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <string>
#include <vector>

#include <xnnpack.h>


Marat Dukhanb1a0fc32019-12-02 19:32:02 -080024class BinaryElementwiseOperatorTester {
Marat Dukhanca2733c2019-11-15 23:21:17 -080025 public:
Marat Dukhanb1a0fc32019-12-02 19:32:02 -080026 enum class OperationType {
27 Unknown,
28 Add,
Marat Dukhan69180502019-12-06 15:00:31 -080029 Divide,
Marat Dukhan79e7f842019-12-05 14:35:50 -080030 Maximum,
31 Minimum,
Marat Dukhanb1a0fc32019-12-02 19:32:02 -080032 Multiply,
Marat Dukhan05f3f6d2019-12-03 15:13:53 -080033 Subtract,
Marat Dukhanb1a0fc32019-12-02 19:32:02 -080034 };
35
36 inline BinaryElementwiseOperatorTester& input1_shape(std::initializer_list<size_t> input1_shape) {
Marat Dukhanca2733c2019-11-15 23:21:17 -080037 assert(input1_shape.size() <= XNN_MAX_TENSOR_DIMS);
38 this->input1_shape_ = std::vector<size_t>(input1_shape);
39 return *this;
40 }
41
42 inline const std::vector<size_t>& input1_shape() const {
43 return this->input1_shape_;
44 }
45
46 inline size_t input1_dim(size_t i) const {
47 return i < num_input1_dims() ? this->input1_shape_[i] : 1;
48 }
49
50 inline size_t num_input1_dims() const {
51 return this->input1_shape_.size();
52 }
53
54 inline size_t num_input1_elements() const {
55 return std::accumulate(
56 this->input1_shape_.begin(), this->input1_shape_.end(), size_t(1), std::multiplies<size_t>());
57 }
58
Marat Dukhanb1a0fc32019-12-02 19:32:02 -080059 inline BinaryElementwiseOperatorTester& input2_shape(std::initializer_list<size_t> input2_shape) {
Marat Dukhanca2733c2019-11-15 23:21:17 -080060 assert(input2_shape.size() <= XNN_MAX_TENSOR_DIMS);
61 this->input2_shape_ = std::vector<size_t>(input2_shape);
62 return *this;
63 }
64
65 inline const std::vector<size_t>& input2_shape() const {
66 return this->input2_shape_;
67 }
68
69 inline size_t input2_dim(size_t i) const {
70 return i < num_input2_dims() ? this->input2_shape_[i] : 1;
71 }
72
73 inline size_t num_input2_dims() const {
74 return this->input2_shape_.size();
75 }
76
77 inline size_t num_input2_elements() const {
78 return std::accumulate(
79 this->input2_shape_.begin(), this->input2_shape_.end(), size_t(1), std::multiplies<size_t>());
80 }
81
Marat Dukhanb1a0fc32019-12-02 19:32:02 -080082 inline BinaryElementwiseOperatorTester& qmin(uint8_t qmin) {
Marat Dukhanca2733c2019-11-15 23:21:17 -080083 this->qmin_ = qmin;
84 return *this;
85 }
86
87 inline uint8_t qmin() const {
88 return this->qmin_;
89 }
90
Marat Dukhanb1a0fc32019-12-02 19:32:02 -080091 inline BinaryElementwiseOperatorTester& qmax(uint8_t qmax) {
Marat Dukhanca2733c2019-11-15 23:21:17 -080092 this->qmax_ = qmax;
93 return *this;
94 }
95
96 inline uint8_t qmax() const {
97 return this->qmax_;
98 }
99
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800100 inline BinaryElementwiseOperatorTester& operation_type(OperationType operation_type) {
101 this->operation_type_ = operation_type;
102 return *this;
103 }
104
105 inline OperationType operation_type() const {
106 return this->operation_type_;
107 }
108
109 inline BinaryElementwiseOperatorTester& iterations(size_t iterations) {
Marat Dukhanca2733c2019-11-15 23:21:17 -0800110 this->iterations_ = iterations;
111 return *this;
112 }
113
114 inline size_t iterations() const {
115 return this->iterations_;
116 }
117
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800118 float Compute(float a, float b) const {
119 switch (operation_type()) {
120 case OperationType::Add:
121 return a + b;
Marat Dukhan69180502019-12-06 15:00:31 -0800122 case OperationType::Divide:
123 return a / b;
Marat Dukhan79e7f842019-12-05 14:35:50 -0800124 case OperationType::Maximum:
125 return std::max<float>(a, b);
126 case OperationType::Minimum:
127 return std::min<float>(a, b);
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800128 case OperationType::Multiply:
129 return a * b;
Marat Dukhan05f3f6d2019-12-03 15:13:53 -0800130 case OperationType::Subtract:
131 return a - b;
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800132 default:
133 return std::nanf("");
134 }
135 }
136
Marat Dukhanca2733c2019-11-15 23:21:17 -0800137 void TestF32() const {
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800138 ASSERT_NE(operation_type(), OperationType::Unknown);
139
Marat Dukhanca2733c2019-11-15 23:21:17 -0800140 std::random_device random_device;
141 auto rng = std::mt19937(random_device());
Marat Dukhan69180502019-12-06 15:00:31 -0800142 auto f32rng = std::bind(std::uniform_real_distribution<float>(0.01f, 1.0f), rng);
Marat Dukhanca2733c2019-11-15 23:21:17 -0800143
144 // Compute generalized shapes.
145 std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
146 std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
147 std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
148 std::fill(input1_dims.begin(), input1_dims.end(), 1);
149 std::fill(input2_dims.begin(), input2_dims.end(), 1);
150 std::fill(output_dims.begin(), output_dims.end(), 1);
151 std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
152 std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
153 for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
154 if (input1_dims[i] != 1 && input2_dims[i] != 1) {
155 ASSERT_EQ(input1_dims[i], input2_dims[i]);
156 }
157 output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
158 }
159 const size_t num_output_elements =
160 std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());
161
162 // Compute generalized strides.
163 std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
164 std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
165 std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
166 size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
167 for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
168 input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
169 input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
170 output_strides[i - 1] = output_stride;
171 input1_stride *= input1_dims[i - 1];
172 input2_stride *= input2_dims[i - 1];
173 output_stride *= output_dims[i - 1];
174 }
175
176 std::vector<float> input1(XNN_EXTRA_BYTES / sizeof(float) + num_input1_elements());
177 std::vector<float> input2(XNN_EXTRA_BYTES / sizeof(float) + num_input2_elements());
178 std::vector<float> output(num_output_elements);
179 std::vector<float> output_ref(num_output_elements);
180 for (size_t iteration = 0; iteration < iterations(); iteration++) {
181 std::generate(input1.begin(), input1.end(), std::ref(f32rng));
182 std::generate(input2.begin(), input2.end(), std::ref(f32rng));
183 std::fill(output.begin(), output.end(), nanf(""));
184
185 // Compute reference results.
186 for (size_t i = 0; i < output_dims[0]; i++) {
187 for (size_t j = 0; j < output_dims[1]; j++) {
188 for (size_t k = 0; k < output_dims[2]; k++) {
189 for (size_t l = 0; l < output_dims[3]; l++) {
Marat Dukhanfc2b96e2019-12-03 12:04:04 -0800190 for (size_t m = 0; m < output_dims[4]; m++) {
191 for (size_t n = 0; n < output_dims[5]; n++) {
192 output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
193 input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]],
194 input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]);
195 }
196 }
Marat Dukhanca2733c2019-11-15 23:21:17 -0800197 }
198 }
199 }
200 }
201 const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
202 const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
203 const float accumulated_range = accumulated_max - accumulated_min;
204 const float output_min = num_output_elements == 1 ?
205 -std::numeric_limits<float>::infinity() : accumulated_min + accumulated_range / 255.0f * float(qmin());
206 const float output_max = num_output_elements == 1 ?
207 +std::numeric_limits<float>::infinity() : accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
208 for (float& output_value : output_ref) {
209 output_value = std::min(std::max(output_value, output_min), output_max);
210 }
211
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800212 // Create, setup, run, and destroy a binary elementwise operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800213 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800214 xnn_operator_t binary_elementwise_op = nullptr;
215
216 switch (operation_type()) {
217 case OperationType::Add:
218 ASSERT_EQ(xnn_status_success,
219 xnn_create_add_nd_f32(
220 output_min, output_max,
221 0, &binary_elementwise_op));
222 break;
Marat Dukhan69180502019-12-06 15:00:31 -0800223 case OperationType::Divide:
224 ASSERT_EQ(xnn_status_success,
225 xnn_create_divide_nd_f32(
226 output_min, output_max,
227 0, &binary_elementwise_op));
228 break;
Marat Dukhan79e7f842019-12-05 14:35:50 -0800229 case OperationType::Maximum:
230 ASSERT_EQ(xnn_status_success,
231 xnn_create_maximum_nd_f32(
232 0, &binary_elementwise_op));
233 break;
234 case OperationType::Minimum:
235 ASSERT_EQ(xnn_status_success,
236 xnn_create_minimum_nd_f32(
237 0, &binary_elementwise_op));
238 break;
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800239 case OperationType::Multiply:
240 ASSERT_EQ(xnn_status_success,
241 xnn_create_multiply_nd_f32(
242 output_min, output_max,
243 0, &binary_elementwise_op));
244 break;
Marat Dukhan05f3f6d2019-12-03 15:13:53 -0800245 case OperationType::Subtract:
246 ASSERT_EQ(xnn_status_success,
247 xnn_create_subtract_nd_f32(
248 output_min, output_max,
249 0, &binary_elementwise_op));
250 break;
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800251 default:
252 FAIL() << "Unsupported operation type";
253 }
254 ASSERT_NE(nullptr, binary_elementwise_op);
255
256 // Smart pointer to automatically delete binary_elementwise_op.
257 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);
258
259 switch (operation_type()) {
260 case OperationType::Add:
261 ASSERT_EQ(xnn_status_success,
262 xnn_setup_add_nd_f32(
263 binary_elementwise_op,
264 num_input1_dims(),
265 input1_shape().data(),
266 num_input2_dims(),
267 input2_shape().data(),
268 input1.data(), input2.data(), output.data(),
269 nullptr /* thread pool */));
270 break;
Marat Dukhan69180502019-12-06 15:00:31 -0800271 case OperationType::Divide:
272 ASSERT_EQ(xnn_status_success,
273 xnn_setup_divide_nd_f32(
274 binary_elementwise_op,
275 num_input1_dims(),
276 input1_shape().data(),
277 num_input2_dims(),
278 input2_shape().data(),
279 input1.data(), input2.data(), output.data(),
280 nullptr /* thread pool */));
281 break;
Marat Dukhan79e7f842019-12-05 14:35:50 -0800282 case OperationType::Maximum:
283 ASSERT_EQ(xnn_status_success,
284 xnn_setup_maximum_nd_f32(
285 binary_elementwise_op,
286 num_input1_dims(),
287 input1_shape().data(),
288 num_input2_dims(),
289 input2_shape().data(),
290 input1.data(), input2.data(), output.data(),
291 nullptr /* thread pool */));
292 break;
293 case OperationType::Minimum:
294 ASSERT_EQ(xnn_status_success,
295 xnn_setup_minimum_nd_f32(
296 binary_elementwise_op,
297 num_input1_dims(),
298 input1_shape().data(),
299 num_input2_dims(),
300 input2_shape().data(),
301 input1.data(), input2.data(), output.data(),
302 nullptr /* thread pool */));
303 break;
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800304 case OperationType::Multiply:
305 ASSERT_EQ(xnn_status_success,
306 xnn_setup_multiply_nd_f32(
307 binary_elementwise_op,
308 num_input1_dims(),
309 input1_shape().data(),
310 num_input2_dims(),
311 input2_shape().data(),
312 input1.data(), input2.data(), output.data(),
313 nullptr /* thread pool */));
314 break;
Marat Dukhan05f3f6d2019-12-03 15:13:53 -0800315 case OperationType::Subtract:
316 ASSERT_EQ(xnn_status_success,
317 xnn_setup_subtract_nd_f32(
318 binary_elementwise_op,
319 num_input1_dims(),
320 input1_shape().data(),
321 num_input2_dims(),
322 input2_shape().data(),
323 input1.data(), input2.data(), output.data(),
324 nullptr /* thread pool */));
325 break;
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800326 default:
327 FAIL() << "Unsupported operation type";
328 }
Marat Dukhanca2733c2019-11-15 23:21:17 -0800329
330 ASSERT_EQ(xnn_status_success,
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800331 xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));
Marat Dukhanca2733c2019-11-15 23:21:17 -0800332
333 // Verify results.
334 for (size_t i = 0; i < output_dims[0]; i++) {
335 for (size_t j = 0; j < output_dims[1]; j++) {
336 for (size_t k = 0; k < output_dims[2]; k++) {
337 for (size_t l = 0; l < output_dims[3]; l++) {
Marat Dukhanfc2b96e2019-12-03 12:04:04 -0800338 for (size_t m = 0; m < output_dims[4]; m++) {
339 for (size_t n = 0; n < output_dims[5]; n++) {
340 const size_t index =
341 i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
342 ASSERT_NEAR(output[index], output_ref[index], 1.0e-6f * std::abs(output_ref[index]))
343 << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")";
344 }
345 }
Marat Dukhanca2733c2019-11-15 23:21:17 -0800346 }
347 }
348 }
349 }
350 }
351 }
352
353 private:
354 std::vector<size_t> input1_shape_;
355 std::vector<size_t> input2_shape_;
356 uint8_t qmin_{0};
357 uint8_t qmax_{255};
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800358 OperationType operation_type_{OperationType::Unknown};
Marat Dukhanab4af572019-12-03 11:11:18 -0800359 size_t iterations_{3};
Marat Dukhanca2733c2019-11-15 23:21:17 -0800360};