// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <initializer_list>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

#include <xnnpack.h>

23
Marat Dukhanb1a0fc32019-12-02 19:32:02 -080024class BinaryElementwiseOperatorTester {
Marat Dukhanca2733c2019-11-15 23:21:17 -080025 public:
Marat Dukhanb1a0fc32019-12-02 19:32:02 -080026 enum class OperationType {
27 Unknown,
28 Add,
29 Multiply,
30 };
31
32 inline BinaryElementwiseOperatorTester& input1_shape(std::initializer_list<size_t> input1_shape) {
Marat Dukhanca2733c2019-11-15 23:21:17 -080033 assert(input1_shape.size() <= XNN_MAX_TENSOR_DIMS);
34 this->input1_shape_ = std::vector<size_t>(input1_shape);
35 return *this;
36 }
37
38 inline const std::vector<size_t>& input1_shape() const {
39 return this->input1_shape_;
40 }
41
42 inline size_t input1_dim(size_t i) const {
43 return i < num_input1_dims() ? this->input1_shape_[i] : 1;
44 }
45
46 inline size_t num_input1_dims() const {
47 return this->input1_shape_.size();
48 }
49
50 inline size_t num_input1_elements() const {
51 return std::accumulate(
52 this->input1_shape_.begin(), this->input1_shape_.end(), size_t(1), std::multiplies<size_t>());
53 }
54
Marat Dukhanb1a0fc32019-12-02 19:32:02 -080055 inline BinaryElementwiseOperatorTester& input2_shape(std::initializer_list<size_t> input2_shape) {
Marat Dukhanca2733c2019-11-15 23:21:17 -080056 assert(input2_shape.size() <= XNN_MAX_TENSOR_DIMS);
57 this->input2_shape_ = std::vector<size_t>(input2_shape);
58 return *this;
59 }
60
61 inline const std::vector<size_t>& input2_shape() const {
62 return this->input2_shape_;
63 }
64
65 inline size_t input2_dim(size_t i) const {
66 return i < num_input2_dims() ? this->input2_shape_[i] : 1;
67 }
68
69 inline size_t num_input2_dims() const {
70 return this->input2_shape_.size();
71 }
72
73 inline size_t num_input2_elements() const {
74 return std::accumulate(
75 this->input2_shape_.begin(), this->input2_shape_.end(), size_t(1), std::multiplies<size_t>());
76 }
77
Marat Dukhanb1a0fc32019-12-02 19:32:02 -080078 inline BinaryElementwiseOperatorTester& qmin(uint8_t qmin) {
Marat Dukhanca2733c2019-11-15 23:21:17 -080079 this->qmin_ = qmin;
80 return *this;
81 }
82
83 inline uint8_t qmin() const {
84 return this->qmin_;
85 }
86
Marat Dukhanb1a0fc32019-12-02 19:32:02 -080087 inline BinaryElementwiseOperatorTester& qmax(uint8_t qmax) {
Marat Dukhanca2733c2019-11-15 23:21:17 -080088 this->qmax_ = qmax;
89 return *this;
90 }
91
92 inline uint8_t qmax() const {
93 return this->qmax_;
94 }
95
Marat Dukhanb1a0fc32019-12-02 19:32:02 -080096 inline BinaryElementwiseOperatorTester& operation_type(OperationType operation_type) {
97 this->operation_type_ = operation_type;
98 return *this;
99 }
100
101 inline OperationType operation_type() const {
102 return this->operation_type_;
103 }
104
105 inline BinaryElementwiseOperatorTester& iterations(size_t iterations) {
Marat Dukhanca2733c2019-11-15 23:21:17 -0800106 this->iterations_ = iterations;
107 return *this;
108 }
109
110 inline size_t iterations() const {
111 return this->iterations_;
112 }
113
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800114 float Compute(float a, float b) const {
115 switch (operation_type()) {
116 case OperationType::Add:
117 return a + b;
118 case OperationType::Multiply:
119 return a * b;
120 default:
121 return std::nanf("");
122 }
123 }
124
Marat Dukhanca2733c2019-11-15 23:21:17 -0800125 void TestF32() const {
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800126 ASSERT_NE(operation_type(), OperationType::Unknown);
127
Marat Dukhanca2733c2019-11-15 23:21:17 -0800128 std::random_device random_device;
129 auto rng = std::mt19937(random_device());
130 auto f32rng = std::bind(std::uniform_real_distribution<float>(0.0f, 1.0f), rng);
131
132 // Compute generalized shapes.
133 std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
134 std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
135 std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
136 std::fill(input1_dims.begin(), input1_dims.end(), 1);
137 std::fill(input2_dims.begin(), input2_dims.end(), 1);
138 std::fill(output_dims.begin(), output_dims.end(), 1);
139 std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
140 std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
141 for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
142 if (input1_dims[i] != 1 && input2_dims[i] != 1) {
143 ASSERT_EQ(input1_dims[i], input2_dims[i]);
144 }
145 output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
146 }
147 const size_t num_output_elements =
148 std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());
149
150 // Compute generalized strides.
151 std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
152 std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
153 std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
154 size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
155 for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
156 input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
157 input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
158 output_strides[i - 1] = output_stride;
159 input1_stride *= input1_dims[i - 1];
160 input2_stride *= input2_dims[i - 1];
161 output_stride *= output_dims[i - 1];
162 }
163
164 std::vector<float> input1(XNN_EXTRA_BYTES / sizeof(float) + num_input1_elements());
165 std::vector<float> input2(XNN_EXTRA_BYTES / sizeof(float) + num_input2_elements());
166 std::vector<float> output(num_output_elements);
167 std::vector<float> output_ref(num_output_elements);
168 for (size_t iteration = 0; iteration < iterations(); iteration++) {
169 std::generate(input1.begin(), input1.end(), std::ref(f32rng));
170 std::generate(input2.begin(), input2.end(), std::ref(f32rng));
171 std::fill(output.begin(), output.end(), nanf(""));
172
173 // Compute reference results.
174 for (size_t i = 0; i < output_dims[0]; i++) {
175 for (size_t j = 0; j < output_dims[1]; j++) {
176 for (size_t k = 0; k < output_dims[2]; k++) {
177 for (size_t l = 0; l < output_dims[3]; l++) {
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800178 output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3]] = Compute(
179 input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3]],
180 input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3]]);
Marat Dukhanca2733c2019-11-15 23:21:17 -0800181 }
182 }
183 }
184 }
185 const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
186 const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
187 const float accumulated_range = accumulated_max - accumulated_min;
188 const float output_min = num_output_elements == 1 ?
189 -std::numeric_limits<float>::infinity() : accumulated_min + accumulated_range / 255.0f * float(qmin());
190 const float output_max = num_output_elements == 1 ?
191 +std::numeric_limits<float>::infinity() : accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
192 for (float& output_value : output_ref) {
193 output_value = std::min(std::max(output_value, output_min), output_max);
194 }
195
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800196 // Create, setup, run, and destroy a binary elementwise operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800197 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800198 xnn_operator_t binary_elementwise_op = nullptr;
199
200 switch (operation_type()) {
201 case OperationType::Add:
202 ASSERT_EQ(xnn_status_success,
203 xnn_create_add_nd_f32(
204 output_min, output_max,
205 0, &binary_elementwise_op));
206 break;
207 case OperationType::Multiply:
208 ASSERT_EQ(xnn_status_success,
209 xnn_create_multiply_nd_f32(
210 output_min, output_max,
211 0, &binary_elementwise_op));
212 break;
213 default:
214 FAIL() << "Unsupported operation type";
215 }
216 ASSERT_NE(nullptr, binary_elementwise_op);
217
218 // Smart pointer to automatically delete binary_elementwise_op.
219 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);
220
221 switch (operation_type()) {
222 case OperationType::Add:
223 ASSERT_EQ(xnn_status_success,
224 xnn_setup_add_nd_f32(
225 binary_elementwise_op,
226 num_input1_dims(),
227 input1_shape().data(),
228 num_input2_dims(),
229 input2_shape().data(),
230 input1.data(), input2.data(), output.data(),
231 nullptr /* thread pool */));
232 break;
233 case OperationType::Multiply:
234 ASSERT_EQ(xnn_status_success,
235 xnn_setup_multiply_nd_f32(
236 binary_elementwise_op,
237 num_input1_dims(),
238 input1_shape().data(),
239 num_input2_dims(),
240 input2_shape().data(),
241 input1.data(), input2.data(), output.data(),
242 nullptr /* thread pool */));
243 break;
244 default:
245 FAIL() << "Unsupported operation type";
246 }
Marat Dukhanca2733c2019-11-15 23:21:17 -0800247
248 ASSERT_EQ(xnn_status_success,
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800249 xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));
Marat Dukhanca2733c2019-11-15 23:21:17 -0800250
251 // Verify results.
252 for (size_t i = 0; i < output_dims[0]; i++) {
253 for (size_t j = 0; j < output_dims[1]; j++) {
254 for (size_t k = 0; k < output_dims[2]; k++) {
255 for (size_t l = 0; l < output_dims[3]; l++) {
256 const size_t index =
257 i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3];
258 ASSERT_NEAR(output[index], output_ref[index], 1.0e-6f * std::abs(output_ref[index]))
259 << "(i, j, k, l) = (" << i << ", " << j << ", " << k << ", " << l << ")";
260 }
261 }
262 }
263 }
264 }
265 }
266
267 private:
268 std::vector<size_t> input1_shape_;
269 std::vector<size_t> input2_shape_;
270 uint8_t qmin_{0};
271 uint8_t qmax_{255};
Marat Dukhanb1a0fc32019-12-02 19:32:02 -0800272 OperationType operation_type_{OperationType::Unknown};
Marat Dukhanab4af572019-12-03 11:11:18 -0800273 size_t iterations_{3};
Marat Dukhanca2733c2019-11-15 23:21:17 -0800274};