// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/operator.h>
#include <xnnpack/log.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/compute.h>
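
// Each xnn_compute_* function below is a task invoked by pthreadpool from
// xnn_run_operator(): it translates its grid indices into pointer offsets and
// calls the operator's micro-kernel on one tile of work.

// Grouped GEMM tile task: offsets A, the packed weights, and C by the
// per-group strides for group_index, then hands the MRxNR tile starting at
// (mr_block_start, nr_block_start) to the GEMM micro-kernel.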
void xnn_compute_ggemm(
    const struct gemm_context context[restrict static 1],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
      cm_stride,
      context->cn_stride,
      &context->params);
}
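
// Non-grouped GEMM tile task: same as xnn_compute_ggemm with a single group,
// so no per-group offsets are needed.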
void xnn_compute_gemm(
    const struct gemm_context context[restrict static 1],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      &context->params);
}
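
// Sparse matrix-dense matrix multiplication task: handles mr_block_size
// output rows of one batch image; mr_block_start is scaled by sizeof(float)
// because the A and C offsets are computed in bytes.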
void xnn_compute_spmm(
    const struct spmm_context context[restrict static 1],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size)
{
  context->ukernel(
      mr_block_size,
      context->n,
      (const void*) ((uintptr_t) context->a + batch_index * context->batched_a_stride + mr_block_start * sizeof(float)),
      context->packed_weights,
      context->input_increments,
      context->output_channel_nonzeros,
      (void*) ((uintptr_t) context->c + batch_index * context->batched_c_stride + mr_block_start * sizeof(float)),
      &context->params);
}
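
// Grouped indirect GEMM task (used for grouped convolution): A is read
// through an indirection buffer advanced by mr_block_start * ks pointers, and
// a_offset rebases the indirected pointers for the current group and batch
// image.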
void xnn_compute_gigemm(
    const struct igemm_context context[restrict static 1],
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
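
// Indirect GEMM task for non-grouped convolution: xnn_compute_gigemm without
// the per-group offsets.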
void xnn_compute_igemm(
    const struct igemm_context context[restrict static 1],
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
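
// Grouped subconvolution (deconvolution) tile task: subkernel_index selects a
// subconvolution with its own slice geometry, so tasks whose slice_y or
// slice_x_start fall outside that geometry bail out early (the parallelization
// grid is shared across all subkernels).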
void xnn_compute_gsubconv2d(
    const struct subconv_context context[restrict static 1],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel(
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
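
// Non-grouped variant of xnn_compute_gsubconv2d: identical flow without the
// per-group weight, output, and input offsets.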
void xnn_compute_subconv2d(
    const struct subconv_context context[restrict static 1],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel(
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
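
// Direct convolution task reading HWC input and writing SpCHW output:
// processes output rows [output_y_start, output_y_start + output_y_slice) of
// one batch image.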
void xnn_compute_dconv2d_hwc2spchw(
    const struct dconv2d_context context[restrict static 1],
    size_t batch_index,
    size_t output_y_start,
    size_t output_y_slice)
{
  context->hwc2spchw_ukernel(
      context->input_height,
      context->input_width,
      output_y_start,
      output_y_start + output_y_slice,
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
      context->zero,
      context->packed_weights,
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
      context->input_padding_top,
      context->output_channels,
      context->output_height_stride,
      context->output_channel_stride,
      &context->params);
}
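
// Depthwise convolution task: computes one row of output pixels across all
// groups (channels) in a single micro-kernel pass over the indirection
// buffer.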
void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict static 1],
    size_t output_y)
{
  context->unipass_ukernel(
      context->groups,
      context->output_width,
      context->indirection_buffer + output_y * context->indirection_buffer_row_stride,
      context->packed_weights,
      context->output + output_y * context->output_row_stride,
      context->indirection_buffer_col_stride,
      context->output_col_increment,
      &context->params);
}
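
// Depthwise convolution task in SpCHW layout: processes one (batch image,
// channel) plane per invocation.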
void xnn_compute_dwconv2d_spchw(
    const struct dwconv2d_context context[restrict static 1],
    size_t batch_index,
    size_t channel)
{
  context->spchw_ukernel(
      context->output_height,
      context->input_width,
      (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
      (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
      (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
      context->input_tuple_stride,
      context->output_tuple_stride,
      context->input_pixel_stride,
      context->output_pixel_stride,
      &context->params);
}
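
// Argmax pooling tasks: compute one output row of max values plus the index
// of the contributing pooling element. The multipass variant stages results
// in on-stack VLA buffers, padded by XNN_EXTRA_BYTES so vectorized kernels
// may over-read.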
void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index =
    (uint32_t*) ((uintptr_t) context->index + batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, output, index,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_argmax_pooling_multipass(
    const struct argmax_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index =
    (uint32_t*) ((uintptr_t) context->index + batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  XNN_ALIGN(16) float multipass_output_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(float)];
  XNN_ALIGN(16) uint32_t multipass_index_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint32_t)];

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, multipass_output_buffer, multipass_index_buffer, output, index,
      context->input_increment, context->output_increment,
      &context->params);
}
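
// Max pooling task: computes one output row of one batch image through the
// indirection buffer.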
void xnn_compute_max_pooling(
    const struct max_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, output,
      context->input_increment, context->output_increment,
      &context->params);
}
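
// Max unpooling task: for one input pixel, fills the pooling window with
// fill_value and scatters the input channels to the output positions recorded
// in the index buffer, via the indirect output pointers.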
void xnn_compute_unpooling(
    const struct unpooling_context context[restrict static 1],
    size_t input_y,
    size_t input_x)
{
  const void* input = (const void*) ((uintptr_t) context->input +
    input_y * context->input_height_stride + input_x * context->input_width_stride);
  const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
    input_y * context->index_height_stride + input_x * context->index_width_stride);
  void** indirect_output =
    (void**) ((uintptr_t) context->indirect_output +
      input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);

  context->ukernel(
      context->pooling_size,
      context->channels,
      context->fill_value,
      input, index, indirect_output);
}
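
// Average pooling tasks: unipass when the pooling window fits one kernel
// invocation, multipass otherwise. The on-stack buffer is declared int32_t to
// fit the Q8 variant's 32-bit accumulators and is padded by XNN_EXTRA_BYTES
// for vector over-reads.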
void xnn_compute_average_pooling_unipass(
    const struct average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, context->zero, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_average_pooling_multipass(
    const struct average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, context->zero, multipass_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}
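
// Pixelwise average pooling tasks: like average pooling, but each output
// pixel is scaled by a precomputed per-pixel factor from pixelwise_buffer,
// for the case where the effective pooling window size varies by position.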
void xnn_compute_pixelwise_average_pooling_unipass(
    const struct pixelwise_average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, context->zero, pixelwise_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_pixelwise_average_pooling_multipass(
    const struct pixelwise_average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, context->zero, pixelwise_buffer, multipass_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}
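
// Global average pooling tasks: reduce all input pixels of one batch image to
// one value per channel; the multipass variant stages 32-bit accumulators on
// the stack, and the spnchw variant reduces a slice of channels in SpNCHW
// layout.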
void xnn_compute_global_average_pooling_unipass(
    const struct global_average_pooling_context context[restrict static 1],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  context->unipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      output,
      &context->params);
}

void xnn_compute_global_average_pooling_multipass(
    const struct global_average_pooling_context context[restrict static 1],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];

  context->multipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      multipass_buffer,
      output,
      &context->params);
}

void xnn_compute_global_average_pooling_spnchw(
    const struct global_average_pooling_spnchw_context context[restrict static 1],
    size_t batch_index,
    size_t channels_start,
    size_t channels_slice)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + channels_start * context->input_channel_stride + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + channels_start * context->output_channel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
      context->input_elements,
      channels_slice,
      input,
      output,
      &context->params);
}
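
// PReLU task: applies the per-channel slopes in context->w to batch_range
// rows starting at batch_start.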
void xnn_compute_prelu(
    const struct prelu_context context[restrict static 1],
    size_t batch_start,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride, &context->params);
}
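
// Channel padding task: for a range of batch rows, writes each row's channels
// with l leading and r trailing padding entries around them.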
void xnn_compute_channel_pad(
    const struct channel_pad_context context[restrict static 1],
    size_t batch_start,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(batch_range, context->n, context->l, context->r, context->c, x, x_stride, y, y_stride);
}
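
// Elementwise addition tasks: the strided variant adds one batch row of A and
// B per invocation; the contiguous variant processes a span of flat buffers
// starting at a byte offset.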
void xnn_compute_add_strided(
    const struct add_strided_context context[restrict static 1],
    size_t batch_index,
    size_t batch_range /* always 1 */)
{
  assert(batch_range == 1);

  const size_t n = context->n;
  const size_t a_stride = context->a_stride;
  const size_t b_stride = context->b_stride;
  const size_t y_stride = context->y_stride;
  const void* a = (const void*) ((uintptr_t) context->a + a_stride * batch_index);
  const void* b = (const void*) ((uintptr_t) context->b + b_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_index);

  context->ukernel(n, a, b, y, &context->params);
}

void xnn_compute_add_contiguous(
    const struct add_contiguous_context context[restrict static 1],
    size_t offset,
    size_t size)
{
  const void* a = (const void*) ((uintptr_t) context->a + offset);
  const void* b = (const void*) ((uintptr_t) context->b + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);
  context->ukernel(size, a, b, y, &context->params);
}
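
// Channel shuffle tasks: reorder interleaved channel groups for one batch
// row; the fixed variant uses a micro-kernel specialized for a specific group
// count, while the variable variant passes the group count m explicitly.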
void xnn_compute_channel_shuffle_fixed(
    const struct channel_shuffle_context context[restrict static 1],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->fixed_ukernel(context->n, x, y);
}

void xnn_compute_channel_shuffle_variable(
    const struct channel_shuffle_context context[restrict static 1],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->variable_ukernel(context->n, context->m, x, y);
}
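
// 8-bit lookup table tasks: map every input byte through table context->t,
// one batch row (strided) or one contiguous span per invocation.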
void xnn_compute_lut_strided(
    const struct lut_strided_context context[restrict static 1],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);

  context->ukernel(context->n, x, context->t, y);
}

void xnn_compute_lut_contiguous(
    const struct lut_contiguous_context context[restrict static 1],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);

  context->ukernel(size, x, context->t, y);
}
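
// Univector (unary elementwise) tasks, e.g. clamp: apply a single-input
// micro-kernel to one batch row (strided) or one flat span (contiguous).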
void xnn_compute_univector_strided(
    const struct univector_strided_context context[restrict static 1],
    size_t batch_index,
    size_t batch_range /* always 1 */)
{
  assert(batch_range == 1);

  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
  context->ukernel(context->n, x, y, &context->params);
}

void xnn_compute_univector_contiguous(
    const struct univector_contiguous_context context[restrict static 1],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);
  context->ukernel(size, x, y, &context->params);
}
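
// U8 softargmax task: rmax_ukernel finds the row maximum, then the lookup
// table pointer is advanced by 255 - x_max entries (x_max ^ 255 == 255 - x_max
// for uint8 values) so the subsequent lookup-and-normalize pass effectively
// subtracts the maximum, the usual trick for a numerically stable softmax.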
void xnn_compute_u8_softargmax(
    const struct u8_softargmax_context context[restrict static 1],
    size_t batch_index)
{
  const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
  uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  uint8_t x_max = 0;
  context->rmax_ukernel(n, x, &x_max);
  const size_t adjustment = x_max ^ 255;
  const uint32_t* t = (const uint32_t*) context->t + adjustment;
  context->lut_norm_ukernel(n, x, t, y);
}
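
// Channelwise multiply-then-add task: y = x * m + b with per-channel m and b
// packed together in context->w, over batch_size rows starting at batch_start.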
void xnn_compute_vmulcaddc(
    const struct vmulcaddc_context context[restrict static 1],
    size_t batch_start,
    size_t batch_size)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;

  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(
      batch_size,
      context->n,
      x, x_stride,
      context->w,
      y, y_stride,
      &context->params);
}
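
// Dispatches a ready operator to pthreadpool using the parallelization shape
// recorded at setup time; every parallelize call passes
// PTHREADPOOL_FLAG_DISABLE_DENORMALS so worker threads flush denormals while
// micro-kernels run.
//
// Minimal call sequence (a sketch, assuming the operator was already created
// and set up through the usual xnn_create_*/xnn_setup_* APIs; error handling
// elided; a NULL threadpool runs the work on the calling thread):
//
//   xnn_initialize();
//   // ... xnn_create_*, xnn_setup_* ...
//   xnn_run_operator(op, threadpool);
//   xnn_delete_operator(op);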
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
  if (!xnn_params.initialized) {
    xnn_log_error("failed to run operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }
  switch (op->state) {
    case xnn_run_state_invalid:
      xnn_log_error("failed to run operator: operator was not successfully set up");
      return xnn_status_invalid_state;
    case xnn_run_state_ready:
      break;
    case xnn_run_state_skip:
      return xnn_status_success;
  }

  switch (op->compute.type) {
    case xnn_parallelization_type_invalid:
      break;
    case xnn_parallelization_type_1d:
      assert(op->compute.range[0] != 0);
      pthreadpool_parallelize_1d(
          threadpool,
          op->compute.task_1d,
          &op->context,
          op->compute.range[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_1d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_1d_tile_1d(
          threadpool,
          op->compute.task_1d_tile_1d,
          &op->context,
          op->compute.range[0],
          op->compute.tile[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      pthreadpool_parallelize_2d(
          threadpool,
          op->compute.task_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_2d_tile_1d(
          threadpool,
          op->compute.task_2d_tile_1d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d(
          threadpool,
          op->compute.task_2d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_3d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d(
          threadpool,
          op->compute.task_3d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_4d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d(
          threadpool,
          op->compute.task_4d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_5d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_5d_tile_2d(
          threadpool,
          op->compute.task_5d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_6d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.range[5] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_6d_tile_2d(
          threadpool,
          op->compute.task_6d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    default:
      XNN_UNREACHABLE;
  }
  return xnn_status_success;
}