// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>


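// Shared constructor for all F32 binary elementwise operators (Add, Subtract,
// Multiply, Divide, Minimum, Maximum): validates the output clamping range,
// allocates the operator descriptor, and stores the clamping parameters and
// operator type on it.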
static enum xnn_status create_binary_elementwise_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    enum xnn_operator_type operator_type,
    xnn_operator_t* binary_elementwise_op_out)
{
  xnn_operator_t binary_elementwise_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to create Add/Subtract/Multiply/Divide/Minimum/Maximum operator: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create Add/Subtract/Multiply/Divide/Minimum/Maximum operator with NaN output lower bound: lower bound must be non-NaN");
    goto error;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create Add/Subtract/Multiply/Divide/Minimum/Maximum operator with NaN output upper bound: upper bound must be non-NaN");
    goto error;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create Add/Subtract/Multiply/Divide/Minimum/Maximum operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      output_min, output_max);
    goto error;
  }

  status = xnn_status_out_of_memory;

  binary_elementwise_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (binary_elementwise_op == NULL) {
    xnn_log_error("failed to allocate %zu bytes for Add/Subtract/Multiply/Divide/Minimum/Maximum operator descriptor", sizeof(struct xnn_operator));
    goto error;
  }

  binary_elementwise_op->f32_output_params = xnn_init_f32_output_params(output_min, output_max);

  binary_elementwise_op->type = operator_type;
  binary_elementwise_op->ukernel.type = xnn_ukernel_type_binary_elementwise;

  binary_elementwise_op->state = xnn_run_state_invalid;

  *binary_elementwise_op_out = binary_elementwise_op;
  return xnn_status_success;

error:
  xnn_delete_operator(binary_elementwise_op);
  return status;
}

enum xnn_status xnn_create_add_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min, output_max, flags, xnn_operator_type_add_nd_f32, add_op_out);
}

enum xnn_status xnn_create_divide_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* divide_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min, output_max, flags, xnn_operator_type_divide_nd_f32, divide_op_out);
}

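// Minimum and Maximum take no user-specified output bounds; they are created
// with an unbounded [-INFINITY, +INFINITY] clamping range.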
enum xnn_status xnn_create_maximum_nd_f32(
    uint32_t flags,
    xnn_operator_t* maximum_op_out)
{
  return create_binary_elementwise_nd_f32(
    -INFINITY /* output_min */, INFINITY /* output_max */,
    flags, xnn_operator_type_maximum_nd_f32, maximum_op_out);
}

enum xnn_status xnn_create_minimum_nd_f32(
    uint32_t flags,
    xnn_operator_t* minimum_op_out)
{
  return create_binary_elementwise_nd_f32(
    -INFINITY /* output_min */, INFINITY /* output_max */,
    flags, xnn_operator_type_minimum_nd_f32, minimum_op_out);
}

enum xnn_status xnn_create_multiply_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* multiply_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min, output_max, flags, xnn_operator_type_multiply_nd_f32, multiply_op_out);
}

enum xnn_status xnn_create_subtract_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* subtract_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min, output_max, flags, xnn_operator_type_subtract_nd_f32, subtract_op_out);
}

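// Shared setup for all F32 binary elementwise operators: validates the
// operator type and input shapes, folds ("compresses") broadcast-compatible
// dimensions, selects a microkernel variant, and records the parallelization
// plan on the operator.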
static enum xnn_status setup_binary_elementwise_nd_f32(
    xnn_operator_t binary_elementwise_op,
    enum xnn_operator_type expected_operator_type,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    const struct vbinary_parameters vbinary[restrict static 1],
    size_t num_threads)
{
  if (binary_elementwise_op->type != expected_operator_type) {
    xnn_log_error("failed to setup Add/Subtract/Multiply/Divide/Minimum/Maximum (ND, F32) operator: operator type mismatch");
    return xnn_status_invalid_parameter;
  }
  binary_elementwise_op->state = xnn_run_state_invalid;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to setup Add/Subtract/Multiply/Divide/Minimum/Maximum operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (max(num_input1_dims, num_input2_dims) > XNN_MAX_TENSOR_DIMS) {
    xnn_log_error(
      "failed to setup Add/Subtract/Multiply/Divide/Minimum/Maximum operator with %zu and %zu dimensions in input shapes: "
      "the number of input dimensions must not exceed %d",
      num_input1_dims, num_input2_dims, XNN_MAX_TENSOR_DIMS);
    return xnn_status_unsupported_parameter;
  }

  for (size_t i = 0; i < num_input1_dims; i++) {
    if (input1_shape[i] == 0) {
      xnn_log_error("failed to setup Add/Subtract/Multiply/Divide/Minimum/Maximum operator: shape dimension #%zu of input #1 is zero", i);
      return xnn_status_invalid_parameter;
    }
  }

  for (size_t i = 0; i < num_input2_dims; i++) {
    if (input2_shape[i] == 0) {
      xnn_log_error("failed to setup Add/Subtract/Multiply/Divide/Minimum/Maximum operator: shape dimension #%zu of input #2 is zero", i);
      return xnn_status_invalid_parameter;
    }
  }

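  // Compress the input shapes for broadcasting: walking from the innermost
  // dimension outward, adjacent dimensions with the same broadcast pattern
  // (equal extents, input1 broadcast, or input2 broadcast) are folded into a
  // single compressed dimension. Unused compressed dimensions stay at 1.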
  size_t num_compressed_dims = 0;
  size_t compressed_input1_shape[XNN_MAX_TENSOR_DIMS];
  size_t compressed_input2_shape[XNN_MAX_TENSOR_DIMS];
  size_t compressed_output_shape[XNN_MAX_TENSOR_DIMS];
  for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
    compressed_input1_shape[i] = 1;
    compressed_input2_shape[i] = 1;
    compressed_output_shape[i] = 1;
  }
  bool broadcast_input1 = false;
  bool broadcast_input2 = false;
  bool first_nonunit = true;
  const size_t num_common_dims = min(num_input1_dims, num_input2_dims);
  for (size_t i = 1; i <= num_common_dims; i++) {
    const size_t input1_dim = input1_shape[num_input1_dims - i];
    const size_t input2_dim = input2_shape[num_input2_dims - i];
    if (input1_dim == 1 && input2_dim == 1) {
      continue;
    }
    assert(!broadcast_input1 || !broadcast_input2);

    if (input1_dim == 1) {
      if (!broadcast_input1) {
        broadcast_input1 = true;
        broadcast_input2 = false;
        num_compressed_dims++;
      }
      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
    } else if (input2_dim == 1) {
      if (!broadcast_input2) {
        broadcast_input1 = false;
        broadcast_input2 = true;
        num_compressed_dims++;
      }
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    } else if (input1_dim == input2_dim) {
      if (broadcast_input1 || broadcast_input2 || first_nonunit) {
        broadcast_input1 = false;
        broadcast_input2 = false;
        num_compressed_dims++;
      }
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_input2_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    } else {
      xnn_log_error("failed to setup Add/Subtract/Multiply/Divide/Minimum/Maximum operator: "
        "shape dimension #%zu of input1 (%zu) does not match shape dimension #%zu of input2 (%zu)",
        num_input1_dims - i, input1_dim, num_input2_dims - i, input2_dim);
      return xnn_status_invalid_parameter;
    }
    first_nonunit = false;
  }
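  // Leading dimensions present in only one input are implicitly broadcast for
  // the other input; fold them into the outermost compressed dimension.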
  if (num_input1_dims > num_input2_dims) {
    if (!broadcast_input2) {
      num_compressed_dims++;
    }
    for (size_t i = 0; i < num_input1_dims - num_input2_dims; i++) {
      const size_t input1_dim = input1_shape[i];
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    }
  } else if (num_input2_dims > num_input1_dims) {
    if (!broadcast_input1) {
      num_compressed_dims++;
    }
    for (size_t i = 0; i < num_input2_dims - num_input1_dims; i++) {
      const size_t input2_dim = input2_shape[i];
      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
    }
  }
  num_compressed_dims = max(num_compressed_dims, 1);

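  // Populate the microkernel context: operand and output pointers, the
  // innermost (contiguous) extent in bytes, and the clamping parameters
  // captured at creation time. Fields not named here, including the stride
  // arrays, are zero-initialized by the compound literal.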
  binary_elementwise_op->context.elementwise_binary = (struct elementwise_binary_context) {
    .a = input1,
    .b = input2,
    .y = output,
    .elements = compressed_output_shape[0] * sizeof(float),
    .params.f32 = binary_elementwise_op->f32_output_params,
  };
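  // Select the microkernel variant from the innermost compressed dimension:
  // if input1 is broadcast there, swap the operands and use the reversed
  // operand-and-constant kernel (ropc_ukernel); if input2 is broadcast, use
  // the operand-and-constant kernel (opc_ukernel); otherwise both operands
  // advance element-by-element and the two-operand kernel (op_ukernel) is used.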
  const size_t* compressed_a_shape = compressed_input1_shape;
  const size_t* compressed_b_shape = compressed_input2_shape;
  if (compressed_input1_shape[0] == 1) {
    binary_elementwise_op->context.elementwise_binary.ukernel = vbinary->ropc_ukernel;
    binary_elementwise_op->context.elementwise_binary.a = input2;
    binary_elementwise_op->context.elementwise_binary.b = input1;
    compressed_a_shape = compressed_input2_shape;
    compressed_b_shape = compressed_input1_shape;
  } else if (compressed_input2_shape[0] == 1) {
    binary_elementwise_op->context.elementwise_binary.ukernel = vbinary->opc_ukernel;
  } else if (compressed_input1_shape[0] == compressed_input2_shape[0]) {
    binary_elementwise_op->context.elementwise_binary.ukernel = vbinary->op_ukernel;
  }
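  // Compute byte strides for the outer compressed dimensions. Where an operand
  // has extent 1, its stride is left at zero so the same elements are re-read
  // while the other operand and the output advance.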
  size_t a_stride = compressed_a_shape[0], b_stride = compressed_b_shape[0], y_stride = compressed_output_shape[0];
  for (size_t i = 1; i < num_compressed_dims; i++) {
    if (compressed_a_shape[i] != 1) {
      binary_elementwise_op->context.elementwise_binary.a_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = a_stride * sizeof(float);
    }
    if (compressed_b_shape[i] != 1) {
      binary_elementwise_op->context.elementwise_binary.b_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = b_stride * sizeof(float);
    }
    binary_elementwise_op->context.elementwise_binary.y_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = y_stride * sizeof(float);
    a_stride *= compressed_a_shape[i];
    b_stride *= compressed_b_shape[i];
    y_stride *= compressed_output_shape[i];
  }

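  // Parallelize over the outer compressed output dimensions as a 5D range with
  // 2D tiling; the innermost dimension is processed inside the microkernel via
  // the .elements field of the context.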
  binary_elementwise_op->compute.type = xnn_parallelization_type_5d_tile_2d;
  binary_elementwise_op->compute.task_5d_tile_2d = (pthreadpool_task_5d_tile_2d_t) xnn_compute_elementwise_binary_5d;
  binary_elementwise_op->compute.range[0] = compressed_output_shape[5];
  binary_elementwise_op->compute.range[1] = compressed_output_shape[4];
  binary_elementwise_op->compute.range[2] = compressed_output_shape[3];
  binary_elementwise_op->compute.range[3] = compressed_output_shape[2];
  binary_elementwise_op->compute.range[4] = compressed_output_shape[1];
  binary_elementwise_op->compute.tile[0] = 1;
  binary_elementwise_op->compute.tile[1] = 1;
  binary_elementwise_op->state = xnn_run_state_ready;

  return xnn_status_success;
}

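// Each public setup entry point forwards to the shared helper together with
// the operator-specific F32 vbinary microkernel parameters from xnn_params.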
enum xnn_status xnn_setup_add_nd_f32(
    xnn_operator_t add_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    add_op, xnn_operator_type_add_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_divide_nd_f32(
    xnn_operator_t divide_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    divide_op, xnn_operator_type_divide_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vdiv,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_maximum_nd_f32(
    xnn_operator_t maximum_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    maximum_op, xnn_operator_type_maximum_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vmax,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_minimum_nd_f32(
    xnn_operator_t minimum_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    minimum_op, xnn_operator_type_minimum_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vmin,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_multiply_nd_f32(
    xnn_operator_t multiply_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    multiply_op, xnn_operator_type_multiply_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vmul,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_subtract_nd_f32(
    xnn_operator_t subtract_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    subtract_op, xnn_operator_type_subtract_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vsub,
    pthreadpool_get_threads_count(threadpool));
}