// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/operator.h>
#include <xnnpack/log.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/compute.h>

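// The xnn_compute_* functions in this file are task callbacks: xnn_run_operator
// (at the bottom of this file) hands them to pthreadpool together with an
// iteration range, and pthreadpool invokes them once per (tile of the)
// iteration space with the operator's context as the first argument.

// xnn_compute_ggemm computes one mr_block_size x nr_block_size output tile of a
// grouped GEMM. All pointer arithmetic is done in bytes via uintptr_t: rows of
// A advance by a_stride (plus k_scaled per group), packed weights advance by
// w_stride per NR block (plus wg_stride per group), and the C tile starts at
//   c + mr_block_start * cm_stride + (nr_block_start << log2_csize),
// where log2_csize is, judging by its name and use, log2 of the output element
// size (2 for fp32, so the shift converts an element index to a byte offset).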
void xnn_compute_ggemm(
    const struct gemm_context context[restrict static 1],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
      cm_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_gemm(
    const struct gemm_context context[restrict static 1],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      &context->params);
}

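// xnn_compute_spmm multiplies a sparse weight matrix by a dense input for one
// batch element. mr_block_start/mr_block_size index dense f32 elements, hence
// the explicit * sizeof(float) scaling into A and C; the sparse weights come
// pre-packed together with per-output-channel nonzero counts and the input
// increments the microkernel follows to walk the nonzero entries.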
void xnn_compute_spmm(
    const struct spmm_context context[restrict static 1],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size)
{
  context->ukernel(
      mr_block_size,
      context->n,
      (const void*) ((uintptr_t) context->a + batch_index * context->batched_a_stride + mr_block_start * sizeof(float)),
      context->packed_weights,
      context->input_increments,
      context->output_channel_nonzeros,
      (void*) ((uintptr_t) context->c + batch_index * context->batched_c_stride + mr_block_start * sizeof(float)),
      &context->params);
}

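// xnn_compute_gigemm and xnn_compute_igemm below are indirect GEMMs, used for
// convolution: instead of a dense A matrix, the microkernel loads ks pointers
// per output row from indirect_a. Note that batch_index and group_index enter
// the input side only through a_offset, so a single indirection buffer serves
// every batch and group, with the byte offset applied to the loaded pointers;
// context->zero is a zero-filled buffer that padding entries point into.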
void xnn_compute_gigemm(
    const struct igemm_context context[restrict static 1],
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_igemm(
    const struct igemm_context context[restrict static 1],
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

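// The subconv2d variants implement subkernel-decomposed deconvolution: each
// subconvolution_params entry describes one stride phase of the transposed
// convolution, with its own indirection buffer, weights, and output slice
// geometry. Because the parallelized range apparently covers the maximum
// slice extent over all subkernels, tiles falling outside this subkernel's
// slice_height/slice_width are rejected up front.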
void xnn_compute_gsubconv2d(
    const struct subconv_context context[restrict static 1],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel(
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_subconv2d(
    const struct subconv_context context[restrict static 1],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel(
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

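// xnn_compute_dconv2d_hwc2spchw runs a direct convolution over the band of
// output rows [output_y_start, output_y_start + output_y_slice), reading HWC
// input and writing channel-major SpCHW output, as the hwc2spchw suffix and
// the separate output height/channel strides indicate.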
void xnn_compute_dconv2d_hwc2spchw(
    const struct dconv2d_context context[restrict static 1],
    size_t batch_index,
    size_t output_y_start,
    size_t output_y_slice)
{
  context->hwc2spchw_ukernel(
      context->input_height,
      context->input_width,
      output_y_start,
      output_y_start + output_y_slice,
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
      context->zero,
      context->packed_weights,
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
      context->input_padding_top,
      context->output_channels,
      context->output_height_stride,
      context->output_channel_stride,
      &context->params);
}

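// Depthwise convolution. The unipass variant computes one output row per
// invocation through an indirection buffer covering the whole kernel window;
// xnn_compute_dwconv2d_spchw below instead processes one (batch, channel)
// plane at a time, which fits the SpCHW layout where each channel is a
// contiguous 2D image.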
void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict static 1],
    size_t output_y)
{
  context->unipass_ukernel(
      context->groups,
      context->output_width,
      context->indirection_buffer + output_y * context->indirection_buffer_row_stride,
      context->packed_weights,
      context->output + output_y * context->output_row_stride,
      context->indirection_buffer_col_stride,
      context->output_col_increment,
      &context->params);
}

void xnn_compute_dwconv2d_spchw(
    const struct dwconv2d_context context[restrict static 1],
    size_t batch_index,
    size_t channel)
{
  context->spchw_ukernel(
      context->output_height,
      context->input_width,
      (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
      (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
      (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
      context->input_tuple_stride,
      context->output_tuple_stride,
      context->input_pixel_stride,
      context->output_pixel_stride,
      &context->params);
}

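// Argmax pooling returns both the pooled values and the uint32 position of
// the selected element within each pooling window. The multipass variant,
// for windows larger than the unipass microkernel handles, carries partial
// maxima and indices across passes in stack VLAs; the XNN_EXTRA_BYTES slack
// matches the over-read/over-write allowance used for vectorized kernels
// elsewhere in XNNPACK.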
void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, output, index,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_argmax_pooling_multipass(
    const struct argmax_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  XNN_ALIGN(16) float multipass_accumulation_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(float)];
  XNN_ALIGN(16) uint32_t multipass_index_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint32_t)];

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_max_pooling(
    const struct max_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, output,
      context->input_increment, context->output_increment,
      &context->params);
}

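// xnn_compute_unpooling inverts max pooling for one input pixel: the
// microkernel receives fill_value for the non-selected positions and
// scatters the input values through indirect output pointers to the
// locations recorded in the index buffer.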
void xnn_compute_unpooling(
    const struct unpooling_context context[restrict static 1],
    size_t input_y,
    size_t input_x)
{
  const void* input = (const void*) ((uintptr_t) context->input +
      input_y * context->input_height_stride + input_x * context->input_width_stride);
  const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
      input_y * context->index_height_stride + input_x * context->index_width_stride);
  void** indirect_output =
    (void**) ((uintptr_t) context->indirect_output +
      input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);

  context->ukernel(
      context->pooling_size,
      context->channels,
      context->fill_value,
      input, index, indirect_output);
}

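// Average pooling comes in four flavors below: plain and pixelwise (the
// pixelwise_buffer supplies a per-output-pixel scale, presumably because
// padding makes the effective divisor vary across pixels), each in unipass
// and multipass form. The multipass accumulator is int32_t yet its extra
// slack is computed as XNN_EXTRA_BYTES / sizeof(uint8_t); that only
// over-allocates, conservatively covering the smallest element type this
// context is used with.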
void xnn_compute_average_pooling_unipass(
    const struct average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, context->zero, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_average_pooling_multipass(
    const struct average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, context->zero, multipass_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_pixelwise_average_pooling_unipass(
    const struct pixelwise_average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, context->zero, pixelwise_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_pixelwise_average_pooling_multipass(
    const struct pixelwise_average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, context->zero, pixelwise_buffer, multipass_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

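// Global average pooling reduces the entire spatial extent per channel. The
// NWC variants are parallelized over batch and stride through input_elements
// pixels of input_pixel_stride bytes each; the NCW variant further down is
// parallelized over (batch, channel block) instead, since in that layout
// each channel is a contiguous run that the microkernel can reduce
// independently.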
void xnn_compute_global_average_pooling_nwc_unipass(
    const struct global_average_pooling_nwc_context context[restrict static 1],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  context->unipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      output,
      &context->params);
}

void xnn_compute_global_average_pooling_nwc_multipass(
    const struct global_average_pooling_nwc_context context[restrict static 1],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];

  context->multipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      multipass_buffer,
      output,
      &context->params);
}

void xnn_compute_global_average_pooling_ncw(
    const struct global_average_pooling_ncw_context context[restrict static 1],
    size_t batch_index,
    size_t channels_start,
    size_t channels_slice)
{
  const void* input = (const void*) ((uintptr_t) context->input +
    channels_start * context->input_channel_stride + batch_index * context->input_batch_stride);
  void* output = (void*) ((uintptr_t) context->output +
    channels_start * context->output_channel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
      context->input_elements,
      channels_slice,
      input,
      output,
      &context->params);
}

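// xnn_compute_resize_bilinear interpolates pixel_range output pixels. Each
// output pixel consumes four indirection pointers (its 2x2 input
// neighborhood, hence pixel_start * 4) and packed interpolation weights at
// (pixel_start << log2_wsize); the last argument is the byte gap between the
// end of one written pixel and the start of the next.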
void xnn_compute_resize_bilinear(
    const struct resize_bilinear_context context[restrict static 1],
    size_t batch_index,
    size_t pixel_start,
    size_t pixel_range)
{
  void* output =
    (void*) ((uintptr_t) context->output + pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
      pixel_range,
      context->scaled_channels,
      context->indirect_input + pixel_start * 4,
      context->input_offset + batch_index * context->input_batch_stride,
      context->packed_weights + (pixel_start << context->log2_wsize),
      output,
      context->output_pixel_stride - context->scaled_channels);
}

void xnn_compute_prelu(
    const struct prelu_context context[restrict static 1],
    size_t batch_start,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride, &context->params);
}

void xnn_compute_channel_pad(
    const struct channel_pad_context context[restrict static 1],
    size_t batch_start,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(batch_range, context->n, context->l, context->r, context->c, x, x_stride, y, y_stride);
}

void xnn_compute_add_strided(
    const struct add_strided_context context[restrict static 1],
    size_t batch_index,
    size_t batch_range /* always 1 */)
{
  assert(batch_range == 1);

  const size_t n = context->n;
  const size_t a_stride = context->a_stride;
  const size_t b_stride = context->b_stride;
  const size_t y_stride = context->y_stride;
  const void* a = (const void*) ((uintptr_t) context->a + a_stride * batch_index);
  const void* b = (const void*) ((uintptr_t) context->b + b_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_index);

  context->ukernel(n, a, b, y, &context->params);
}

void xnn_compute_add_contiguous(
    const struct add_contiguous_context context[restrict static 1],
    size_t offset,
    size_t size)
{
  const void* a = (const void*) ((uintptr_t) context->a + offset);
  const void* b = (const void*) ((uintptr_t) context->b + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);
  context->ukernel(size, a, b, y, &context->params);
}

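// xnn_compute_elementwise_binary_3d applies a binary microkernel to one
// contiguous run of elements inside a 3D outer iteration space. Per-dimension
// byte strides live in small arrays; a zero entry in a_stride[]/b_stride[]
// would realize broadcasting of that operand along that dimension, which is
// presumably how the operator setup expresses broadcast shapes.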
void xnn_compute_elementwise_binary_3d(
    const struct elementwise_binary_context context[restrict static 1],
    size_t i, size_t j, size_t k,
    size_t j_range, size_t k_range)
{
  assert(j_range == 1);
  assert(k_range == 1);

  const void* a = (const void*) ((uintptr_t) context->a +
    i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2]);
  const void* b = (const void*) ((uintptr_t) context->b +
    i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2]);
  void* y = (void*) ((uintptr_t) context->y +
    i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_channel_shuffle_fixed(
    const struct channel_shuffle_context context[restrict static 1],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->fixed_ukernel(context->n, x, y);
}

void xnn_compute_channel_shuffle_variable(
    const struct channel_shuffle_context context[restrict static 1],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->variable_ukernel(context->n, context->m, x, y);
}

void xnn_compute_lut_strided(
    const struct lut_strided_context context[restrict static 1],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);

  context->ukernel(context->n, x, context->t, y);
}

void xnn_compute_lut_contiguous(
    const struct lut_contiguous_context context[restrict static 1],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);

  context->ukernel(size, x, context->t, y);
}

void xnn_compute_univector_strided(
    const struct univector_strided_context context[restrict static 1],
    size_t batch_index,
    size_t batch_range /* always 1 */)
{
  assert(batch_range == 1);

  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
  context->ukernel(context->n, x, y, &context->params);
}

void xnn_compute_univector_contiguous(
    const struct univector_contiguous_context context[restrict static 1],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);
  context->ukernel(size, x, y, &context->params);
}

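// xnn_compute_u8_softargmax computes a uint8 softmax in two microkernel
// passes: rmax finds the row maximum, then a LUT-and-normalize pass maps
// each x through t[x + (255 - x_max)] (for a uint8 value, x_max ^ 255 is
// exactly 255 - x_max). Shifting the table this way makes the row maximum
// hit the top entry, i.e. the usual max-subtraction that keeps softmax
// numerically stable, done at table-lookup precision.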
void xnn_compute_u8_softargmax(
    const struct u8_softargmax_context context[restrict static 1],
    size_t batch_index)
{
  const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
  uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  uint8_t x_max = 0;
  context->rmax_ukernel(n, x, &x_max);
  const size_t adjustment = x_max ^ 255;
  const uint32_t* t = (const uint32_t*) context->t + adjustment;
  context->lut_norm_ukernel(n, x, t, y);
}

void xnn_compute_vmulcaddc(
    const struct vmulcaddc_context context[restrict static 1],
    size_t batch_start,
    size_t batch_size)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;

  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(
      batch_size,
      context->n,
      x, x_stride,
      context->w,
      y, y_stride,
      &context->params);
}

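// xnn_run_operator dispatches a prepared operator to pthreadpool. Operator
// setup selected a parallelization type plus an N-dimensional range and tile
// sizes; pthreadpool then invokes op->compute.task_* once per tile across its
// worker threads. As a sketch (not the pthreadpool implementation), the
// 2d_tile_2d case amounts to distributing these iterations:
//
//   for (size_t i = 0; i < range[0]; i += tile[0]) {
//     for (size_t j = 0; j < range[1]; j += tile[1]) {
//       task_2d_tile_2d(&op->context, i, j,
//                       min(range[0] - i, tile[0]), min(range[1] - j, tile[1]));
//     }
//   }
//
// PTHREADPOOL_FLAG_DISABLE_DENORMALS makes the tasks run with denormal
// floats flushed to zero.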
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
  if (!xnn_params.initialized) {
    xnn_log_error("failed to run operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }
  switch (op->state) {
    case xnn_run_state_invalid:
      xnn_log_error("failed to run operator: operator was not successfully set up");
      return xnn_status_invalid_state;
    case xnn_run_state_ready:
      break;
    case xnn_run_state_skip:
      return xnn_status_success;
  }

  switch (op->compute.type) {
    case xnn_parallelization_type_invalid:
      break;
    case xnn_parallelization_type_1d:
      assert(op->compute.range[0] != 0);
      pthreadpool_parallelize_1d(
          threadpool,
          op->compute.task_1d,
          &op->context,
          op->compute.range[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_1d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_1d_tile_1d(
          threadpool,
          op->compute.task_1d_tile_1d,
          &op->context,
          op->compute.range[0],
          op->compute.tile[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      pthreadpool_parallelize_2d(
          threadpool,
          op->compute.task_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_2d_tile_1d(
          threadpool,
          op->compute.task_2d_tile_1d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d(
          threadpool,
          op->compute.task_2d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_3d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d(
          threadpool,
          op->compute.task_3d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_4d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d(
          threadpool,
          op->compute.task_4d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_5d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_5d_tile_2d(
          threadpool,
          op->compute.task_5d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_6d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.range[5] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_6d_tile_2d(
          threadpool,
          op->compute.task_6d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    default:
      XNN_UNREACHABLE;
  }
  return xnn_status_success;
}