// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/operator.h>
#include <xnnpack/log.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/compute.h>

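// Computes one mr_block_size x nr_block_size tile of a grouped GEMM:
// group_index selects the per-group slice of the input A (stride k_scaled),
// of the packed weights (wg_stride), and of the output C (cg_stride).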
void xnn_compute_ggemm(
    const struct gemm_context context[restrict static 1],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
      cm_stride,
      context->cn_stride,
      &context->params);
}

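// Computes one mr_block_size x nr_block_size tile of an ordinary (ungrouped) GEMM.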
void xnn_compute_gemm(
    const struct gemm_context context[restrict static 1],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      &context->params);
}

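// Runs the SpMM (sparse matrix-dense matrix multiplication) micro-kernel on an
// MR-sized block of one batch element; buffer strides are in bytes, hence the
// sizeof(float) scaling of the element offsets.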
void xnn_compute_spmm(
    const struct spmm_context context[restrict static 1],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size)
{
  context->ukernel(
      mr_block_size,
      context->n,
      (const void*) ((uintptr_t) context->a + batch_index * context->batched_a_stride + mr_block_start * sizeof(float)),
      context->packed_weights,
      context->input_increments,
      context->output_channel_nonzeros,
      (void*) ((uintptr_t) context->c + batch_index * context->batched_c_stride + mr_block_start * sizeof(float)),
      &context->params);
}

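// Computes one tile of a grouped indirect GEMM (convolution addressed through
// an indirection buffer of input-row pointers). Relative to xnn_compute_igemm
// below, group_index additionally offsets the packed weights, output, and
// input through the gw_stride/gc_stride/ga_stride fields.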
void xnn_compute_gigemm(
    const struct igemm_context context[restrict static 1],
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

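// Computes one tile of an indirect GEMM: the micro-kernel gathers input rows
// through indirect_a, with a_offset selecting the batch element and zero
// standing in for rows that fall entirely into padding.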
void xnn_compute_igemm(
    const struct igemm_context context[restrict static 1],
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel(
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

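// Computes one output slice of a grouped subconvolution (the decomposition
// XNNPACK uses for deconvolution): subkernel_index selects the subconvolution
// whose output pixels share the same kernel taps, and slices outside its
// output extent return early. xnn_compute_subconv2d below is the ungrouped
// variant.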
void xnn_compute_gsubconv2d(
    const struct subconv_context context[restrict static 1],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel(
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_subconv2d(
    const struct subconv_context context[restrict static 1],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel(
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

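// Runs the direct-convolution micro-kernel that reads HWC input and writes
// SpCHW output for a range of output rows of one batch element.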
void xnn_compute_dconv2d_hwc2spchw(
    const struct dconv2d_context context[restrict static 1],
    size_t batch_index,
    size_t output_y_start,
    size_t output_y_slice)
{
  context->hwc2spchw_ukernel(
      context->input_height,
      context->input_width,
      output_y_start,
      output_y_start + output_y_slice,
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
      context->zero,
      context->packed_weights,
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
      context->input_padding_top,
      context->output_channels,
      context->output_height_stride,
      context->output_channel_stride,
      &context->params);
}

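// Runs the single-pass depthwise-convolution micro-kernel over one row of
// output pixels, locating input rows through the indirection buffer.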
void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict static 1],
    size_t output_y)
{
  context->unipass_ukernel(
      context->groups,
      context->output_width,
      context->indirection_buffer + output_y * context->indirection_buffer_row_stride,
      context->packed_weights,
      context->output + output_y * context->output_row_stride,
      context->indirection_buffer_col_stride,
      context->output_col_increment,
      &context->params);
}

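// Runs the SpCHW depthwise-convolution micro-kernel on a single channel of a
// single batch element.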
void xnn_compute_dwconv2d_spchw(
    const struct dwconv2d_context context[restrict static 1],
    size_t batch_index,
    size_t channel)
{
  context->spchw_ukernel(
      context->output_height,
      context->input_width,
      (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
      (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
      (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
      context->input_tuple_stride,
      context->output_tuple_stride,
      context->input_pixel_stride,
      context->output_pixel_stride,
      &context->params);
}

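// Single-pass argmax pooling over one output row: writes both the pooled
// values and the uint32 indices of the selected input elements.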
void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index =
    (uint32_t*) ((uintptr_t) context->index + batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, output, index,
      context->input_increment, context->output_increment,
      &context->params);
}

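// Multi-pass variant for pooling windows too large for a single pass: partial
// values and indices are staged in on-stack (VLA) buffers, padded with
// XNN_EXTRA_BYTES so vectorized micro-kernels may safely overshoot the last
// element.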
void xnn_compute_argmax_pooling_multipass(
    const struct argmax_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index =
    (uint32_t*) ((uintptr_t) context->index + batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  XNN_ALIGN(16) float multipass_output_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(float)];
  XNN_ALIGN(16) uint32_t multipass_index_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint32_t)];

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, multipass_output_buffer, multipass_index_buffer, output, index,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_max_pooling(
    const struct max_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, output,
      context->input_increment, context->output_increment,
      &context->params);
}

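// Max unpooling for one input pixel: the micro-kernel fills the pooling window
// with fill_value, then scatters each input value to its recorded index.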
void xnn_compute_unpooling(
    const struct unpooling_context context[restrict static 1],
    size_t input_y,
    size_t input_x)
{
  const void* input = (const void*) ((uintptr_t) context->input +
      input_y * context->input_height_stride + input_x * context->input_width_stride);
  const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
      input_y * context->index_height_stride + input_x * context->index_width_stride);
  void** indirect_output =
    (void**) ((uintptr_t) context->indirect_output +
      input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);

  context->ukernel(
      context->pooling_size,
      context->channels,
      context->fill_value,
      input, index, indirect_output);
}

void xnn_compute_average_pooling_unipass(
    const struct average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, context->zero, output,
      context->input_increment, context->output_increment,
      &context->params);
}

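// Multi-pass average pooling: per-channel partial sums are staged in an
// on-stack int32_t accumulator buffer before the final pass scales and writes
// the output. The pixelwise variants below additionally take per-pixel
// multipliers from pixelwise_buffer (e.g. when the divisor varies near padded
// borders).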
void xnn_compute_average_pooling_multipass(
    const struct average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, context->zero, multipass_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_pixelwise_average_pooling_unipass(
    const struct pixelwise_average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, context->zero, pixelwise_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_pixelwise_average_pooling_multipass(
    const struct pixelwise_average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input +
      batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, context->zero, pixelwise_buffer, multipass_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

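// Global average pooling reduces all pixels of each channel to a single value
// for one batch element; the multipass variant stages int32_t accumulators on
// the stack, and the spnchw variant operates on a block of channels of
// NCHW-layout data.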
void xnn_compute_global_average_pooling_unipass(
    const struct global_average_pooling_context context[restrict static 1],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  context->unipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      output,
      &context->params);
}

void xnn_compute_global_average_pooling_multipass(
    const struct global_average_pooling_context context[restrict static 1],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
  XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];

  context->multipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      multipass_buffer,
      output,
      &context->params);
}

void xnn_compute_global_average_pooling_spnchw(
    const struct global_average_pooling_spnchw_context context[restrict static 1],
    size_t batch_index,
    size_t channels_start,
    size_t channels_slice)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + channels_start * context->input_channel_stride + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + channels_start * context->output_channel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
      context->input_elements,
      channels_slice,
      input,
      output,
      &context->params);
}

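// Bilinear resize for a range of output pixels of one batch element: each
// output pixel reads four neighbor pointers from the indirection buffer
// (hence pixel_start * 4) and its interpolation weights from packed_weights.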
void xnn_compute_resize_bilinear(
    const struct resize_bilinear_context context[restrict static 1],
    size_t batch_index,
    size_t pixel_start,
    size_t pixel_range)
{
  void* output =
    (void*) ((uintptr_t) context->output + pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
      pixel_range,
      context->scaled_channels,
      context->indirect_input + pixel_start * 4,
      context->input_offset + batch_index * context->input_batch_stride,
      context->packed_weights + (pixel_start << context->log2_wsize),
      output,
      context->output_pixel_stride - context->scaled_channels);
}

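// PReLU over a range of batch rows: y = x where x is positive, and
// y = w * x otherwise, with per-channel slopes w.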
void xnn_compute_prelu(
    const struct prelu_context context[restrict static 1],
    size_t batch_start,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride, &context->params);
}

void xnn_compute_channel_pad(
    const struct channel_pad_context context[restrict static 1],
    size_t batch_start,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(batch_range, context->n, context->l, context->r, context->c, x, x_stride, y, y_stride);
}

void xnn_compute_add_strided(
    const struct add_strided_context context[restrict static 1],
    size_t batch_index,
    size_t batch_range /* always 1 */)
{
  assert(batch_range == 1);

  const size_t n = context->n;
  const size_t a_stride = context->a_stride;
  const size_t b_stride = context->b_stride;
  const size_t y_stride = context->y_stride;
  const void* a = (const void*) ((uintptr_t) context->a + a_stride * batch_index);
  const void* b = (const void*) ((uintptr_t) context->b + b_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_index);

  context->ukernel(n, a, b, y, &context->params);
}

void xnn_compute_add_contiguous(
    const struct add_contiguous_context context[restrict static 1],
    size_t offset,
    size_t size)
{
  const void* a = (const void*) ((uintptr_t) context->a + offset);
  const void* b = (const void*) ((uintptr_t) context->b + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);
  context->ukernel(size, a, b, y, &context->params);
}

void xnn_compute_channel_shuffle_fixed(
    const struct channel_shuffle_context context[restrict static 1],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->fixed_ukernel(context->n, x, y);
}

void xnn_compute_channel_shuffle_variable(
    const struct channel_shuffle_context context[restrict static 1],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->variable_ukernel(context->n, context->m, x, y);
}

void xnn_compute_lut_strided(
    const struct lut_strided_context context[restrict static 1],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);

  context->ukernel(context->n, x, context->t, y);
}

void xnn_compute_lut_contiguous(
    const struct lut_contiguous_context context[restrict static 1],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);

  context->ukernel(size, x, context->t, y);
}

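// Generic elementwise (unary) micro-kernel dispatch: the strided variant
// processes one batch row per invocation, the contiguous variant a byte range
// of a flat buffer.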
void xnn_compute_univector_strided(
    const struct univector_strided_context context[restrict static 1],
    size_t batch_index,
    size_t batch_range /* always 1 */)
{
  assert(batch_range == 1);

  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
  context->ukernel(context->n, x, y, &context->params);
}

void xnn_compute_univector_contiguous(
    const struct univector_contiguous_context context[restrict static 1],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);
  context->ukernel(size, x, y, &context->params);
}

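// U8 SoftArgMax (softmax): a first pass finds the row maximum, then a LUT pass
// normalizes the row. For uint8_t values, x_max ^ 255 == 255 - x_max, so
// shifting the table start by this adjustment makes the lookup effectively
// indexed by x - x_max while staying within the 256-entry table.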
void xnn_compute_u8_softargmax(
    const struct u8_softargmax_context context[restrict static 1],
    size_t batch_index)
{
  const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
  uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  uint8_t x_max = 0;
  context->rmax_ukernel(n, x, &x_max);
  const size_t adjustment = x_max ^ 255;
  const uint32_t* t = (const uint32_t*) context->t + adjustment;
  context->lut_norm_ukernel(n, x, t, y);
}

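// Per-channel multiply-add (y = x * scale + bias, with packed per-channel
// coefficients in w) over a block of batch rows.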
void xnn_compute_vmulcaddc(
    const struct vmulcaddc_context context[restrict static 1],
    size_t batch_start,
    size_t batch_size)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;

  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(
      batch_size,
      context->n,
      x, x_stride,
      context->w,
      y, y_stride,
      &context->params);
}

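// Dispatches a previously set-up operator to the thread pool using the
// parallelization pattern recorded in op->compute at setup time; every
// pthreadpool call disables denormalized floats in the worker threads.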
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
  if (!xnn_params.initialized) {
    xnn_log_error("failed to run operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }
  switch (op->state) {
    case xnn_run_state_invalid:
      xnn_log_error("failed to run operator: operator was not successfully set up");
      return xnn_status_invalid_state;
    case xnn_run_state_ready:
      break;
    case xnn_run_state_skip:
      return xnn_status_success;
  }

  switch (op->compute.type) {
    case xnn_parallelization_type_invalid:
      break;
    case xnn_parallelization_type_1d:
      assert(op->compute.range[0] != 0);
      pthreadpool_parallelize_1d(
          threadpool,
          op->compute.task_1d,
          &op->context,
          op->compute.range[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_1d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_1d_tile_1d(
          threadpool,
          op->compute.task_1d_tile_1d,
          &op->context,
          op->compute.range[0],
          op->compute.tile[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      pthreadpool_parallelize_2d(
          threadpool,
          op->compute.task_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_2d_tile_1d(
          threadpool,
          op->compute.task_2d_tile_1d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d(
          threadpool,
          op->compute.task_2d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_3d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d(
          threadpool,
          op->compute.task_3d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_4d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d(
          threadpool,
          op->compute.task_4d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_5d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_5d_tile_2d(
          threadpool,
          op->compute.task_5d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_6d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.range[5] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_6d_tile_2d(
          threadpool,
          op->compute.task_6d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    default:
      XNN_UNREACHABLE;
  }
  return xnn_status_success;
}
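
// A minimal caller-side sketch (for illustration only, not part of this file;
// the creation/setup calls are placeholders for the per-operator API declared
// in <xnnpack.h>):
//
//   pthreadpool_t threadpool = pthreadpool_create(0 /* one thread per core */);
//   xnn_operator_t op = NULL;
//   /* ... create and set up a concrete operator via its xnn_create_*() and
//      xnn_setup_*() functions ... */
//   enum xnn_status status = xnn_run_operator(op, threadpool);
//   /* ... check status, then xnn_delete_operator(op) and
//      pthreadpool_destroy(threadpool) ... */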