blob: c8f33826440d55d4e4d942b9cae6aef6a3a3b748 [file] [log] [blame]
XNNPACK Teamb455b122019-09-27 18:10:33 -07001// Copyright (c) Facebook, Inc. and its affiliates.
2// All rights reserved.
3//
4// Copyright 2019 Google LLC
5//
6// This source code is licensed under the BSD-style license found in the
7// LICENSE file in the root directory of this source tree.
8
9#include <assert.h>
10#include <stddef.h>
11#include <stdint.h>
12#include <string.h>
13
14#include <xnnpack.h>
15#include <xnnpack/operator.h>
16#include <xnnpack/log.h>
17#include <xnnpack/common.h>
18#include <xnnpack/math.h>
19#include <xnnpack/params.h>
20#include <xnnpack/compute.h>
21
22
23void xnn_compute_ggemm(
24 const struct gemm_context context[restrict static 1],
25 size_t group_index,
26 size_t mr_block_start,
27 size_t nr_block_start,
28 size_t mr_block_size,
29 size_t nr_block_size)
30{
31 const size_t k_scaled = context->k_scaled;
32 const size_t a_stride = context->a_stride;
33 const size_t cm_stride = context->cm_stride;
34
35 context->ukernel(
36 mr_block_size,
37 nr_block_size,
38 k_scaled,
39 (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
40 a_stride,
41 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
42 (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
43 cm_stride,
44 context->cn_stride,
45 &context->params);
46}
47
48void xnn_compute_gemm(
49 const struct gemm_context context[restrict static 1],
50 size_t mr_block_start,
51 size_t nr_block_start,
52 size_t mr_block_size,
53 size_t nr_block_size)
54{
55 const size_t a_stride = context->a_stride;
56 const size_t cm_stride = context->cm_stride;
57
58 context->ukernel(
59 mr_block_size,
60 nr_block_size,
61 context->k_scaled,
62 (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
63 a_stride,
64 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
65 (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
66 cm_stride,
67 context->cn_stride,
68 &context->params);
69}
70
71void xnn_compute_spmm(
72 const struct spmm_context context[restrict static 1],
73 size_t batch_index,
74 size_t mr_block_start,
75 size_t mr_block_size)
76{
77 context->ukernel(
78 mr_block_size,
79 context->n,
80 (const void*) ((uintptr_t) context->a + batch_index * context->batched_a_stride + mr_block_start * sizeof(float)),
81 context->packed_weights,
82 context->input_increments,
83 context->output_channel_nonzeros,
84 (void*) ((uintptr_t) context->c + batch_index * context->batched_c_stride + mr_block_start * sizeof(float)),
85 &context->params);
86}
87
88void xnn_compute_gigemm(
89 const struct igemm_context context[restrict static 1],
90 size_t batch_index,
91 size_t group_index,
92 size_t mr_block_start,
93 size_t nr_block_start,
94 size_t mr_block_size,
95 size_t nr_block_size)
96{
97 const size_t ks = context->ks;
98 const size_t cm_stride = context->cm_stride;
99
100 context->ukernel(
101 mr_block_size,
102 nr_block_size,
103 context->kc,
104 context->ks_scaled,
105 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
106 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
107 (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
108 cm_stride,
109 context->cn_stride,
110 context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
111 context->zero,
112 &context->params);
113}
114
115void xnn_compute_igemm(
116 const struct igemm_context context[restrict static 1],
117 size_t batch_index,
118 size_t mr_block_start,
119 size_t nr_block_start,
120 size_t mr_block_size,
121 size_t nr_block_size)
122{
123 const size_t ks = context->ks;
124 const size_t cm_stride = context->cm_stride;
125
126 context->ukernel(
127 mr_block_size,
128 nr_block_size,
129 context->kc,
130 context->ks_scaled,
131 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
132 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
133 (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
134 cm_stride,
135 context->cn_stride,
136 context->a_offset + batch_index * context->ba_stride,
137 context->zero,
138 &context->params);
139}
140
141void xnn_compute_gsubconv2d(
142 const struct subconv_context context[restrict static 1],
143 size_t batch_index,
144 size_t group_index,
145 size_t subkernel_index,
146 size_t slice_y,
147 size_t slice_x_start,
148 size_t nc_block_start,
149 size_t slice_x_max,
150 size_t nc_block_size)
151{
152 const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
153
154 if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
155 return;
156 }
157
158 const size_t slice_width = subconvolution_params->slice_width;
159 if XNN_UNLIKELY(slice_x_start >= slice_width) {
160 return;
161 }
162 const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
163
164 const size_t cx_stride = context->cx_stride;
165 context->ukernel(
166 slice_x_size,
167 nc_block_size,
168 context->kc,
169 subconvolution_params->scaled_kernel_size,
170 (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
171 (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
172 (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
173 cx_stride,
174 context->cn_stride,
175 context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
176 context->zero,
177 &context->params);
178}
179
180void xnn_compute_subconv2d(
181 const struct subconv_context context[restrict static 1],
182 size_t batch_index,
183 size_t subkernel_index,
184 size_t slice_y,
185 size_t slice_x_start,
186 size_t nc_block_start,
187 size_t slice_x_max,
188 size_t nc_block_size)
189{
190 const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
191
192 if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
193 return;
194 }
195
196 const size_t slice_width = subconvolution_params->slice_width;
197 if XNN_UNLIKELY(slice_x_start >= slice_width) {
198 return;
199 }
200 const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
201
202 const size_t cx_stride = context->cx_stride;
203 context->ukernel(
204 slice_x_size,
205 nc_block_size,
206 context->kc,
207 subconvolution_params->scaled_kernel_size,
208 (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
209 (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
210 (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
211 cx_stride,
212 context->cn_stride,
213 context->a_offset + batch_index * context->ba_stride,
214 context->zero,
215 &context->params);
216}
217
218void xnn_compute_dconv2d_hwc2spchw(
219 const struct dconv2d_context context[restrict static 1],
220 size_t batch_index,
221 size_t output_y_start,
222 size_t output_y_slice)
223{
224 context->hwc2spchw_ukernel(
225 context->input_height,
226 context->input_width,
227 output_y_start,
228 output_y_start + output_y_slice,
229 (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
230 context->zero,
231 context->packed_weights,
232 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
233 context->input_padding_top,
234 context->output_channels,
235 context->output_height_stride,
236 context->output_channel_stride,
237 &context->params);
238}
239
240void xnn_compute_dwconv_unipass(
241 const struct dwconv_context context[restrict static 1],
242 size_t output_y)
243{
244 context->unipass_ukernel(
245 context->groups,
246 context->output_width,
247 context->indirection_buffer + output_y * context->indirection_buffer_row_stride,
248 context->packed_weights,
249 context->output + output_y * context->output_row_stride,
250 context->indirection_buffer_col_stride,
251 context->output_col_increment,
252 &context->params);
253}
254
255void xnn_compute_dwconv2d_spchw(
256 const struct dwconv2d_context context[restrict static 1],
257 size_t batch_index,
258 size_t channel)
259{
260 context->spchw_ukernel(
261 context->output_height,
262 context->input_width,
263 (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
264 (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
265 (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
266 context->input_tuple_stride,
267 context->output_tuple_stride,
268 context->input_pixel_stride,
269 context->output_pixel_stride,
270 &context->params);
271}
272
273void xnn_compute_argmax_pooling_unipass(
274 const struct argmax_pooling_context context[restrict static 1],
275 size_t batch_index,
276 size_t output_y)
277{
Marat Dukhan329da642019-11-19 21:44:39 -0800278 const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
279 output_y * context->indirect_input_height_stride);
280 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
281 void* output = (void*) ((uintptr_t) context->output +
282 batch_index * context->output_batch_stride + output_y * context->output_height_stride);
283 uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
284 batch_index * context->index_batch_stride + output_y * context->index_height_stride);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700285
286 context->unipass_ukernel(
287 context->output_width, context->pooling_size, context->channels,
Marat Dukhan329da642019-11-19 21:44:39 -0800288 indirect_input, input_offset, output, index,
XNNPACK Teamb455b122019-09-27 18:10:33 -0700289 context->input_increment, context->output_increment,
290 &context->params);
291}
292
293void xnn_compute_argmax_pooling_multipass(
294 const struct argmax_pooling_context context[restrict static 1],
295 size_t batch_index,
296 size_t output_y)
297{
Marat Dukhan329da642019-11-19 21:44:39 -0800298 const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
299 output_y * context->indirect_input_height_stride);
300 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
301 void* output = (void*) ((uintptr_t) context->output +
302 batch_index * context->output_batch_stride + output_y * context->output_height_stride);
303 uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
304 batch_index * context->index_batch_stride + output_y * context->index_height_stride);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700305
Marat Dukhan329da642019-11-19 21:44:39 -0800306 XNN_ALIGN(16) float multipass_accumulation_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(float)];
XNNPACK Teamb455b122019-09-27 18:10:33 -0700307 XNN_ALIGN(16) uint32_t multipass_index_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint32_t)];
308
309 context->multipass_ukernel(
310 context->output_width, context->pooling_size, context->channels,
Marat Dukhan329da642019-11-19 21:44:39 -0800311 indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
XNNPACK Teamb455b122019-09-27 18:10:33 -0700312 context->input_increment, context->output_increment,
313 &context->params);
314}
315
316void xnn_compute_max_pooling(
317 const struct max_pooling_context context[restrict static 1],
318 size_t batch_index,
319 size_t output_y)
320{
Marat Dukhan329da642019-11-19 21:44:39 -0800321 const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
322 output_y * context->indirect_input_height_stride);
323 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
324 void* output = (void*) ((uintptr_t) context->output +
325 batch_index * context->output_batch_stride + output_y * context->output_height_stride);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700326
327 context->ukernel(
328 context->output_width, context->pooling_size, context->channels,
Marat Dukhan329da642019-11-19 21:44:39 -0800329 indirect_input, input_offset, output,
XNNPACK Teamb455b122019-09-27 18:10:33 -0700330 context->input_increment, context->output_increment,
331 &context->params);
332}
333
334void xnn_compute_unpooling(
335 const struct unpooling_context context[restrict static 1],
336 size_t input_y,
337 size_t input_x)
338{
339 const void* input = (const void*) ((uintptr_t) context->input +
340 input_y * context->input_height_stride + input_x * context->input_width_stride);
341 const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
342 input_y * context->index_height_stride + input_x * context->index_width_stride);
343 void** indirect_output =
344 (void**) ((uintptr_t) context->indirect_output +
345 input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);
346
347 context->ukernel(
348 context->pooling_size,
349 context->channels,
350 context->fill_value,
351 input, index, indirect_output);
352}
353
354void xnn_compute_average_pooling_unipass(
355 const struct average_pooling_context context[restrict static 1],
356 size_t batch_index,
357 size_t output_y)
358{
359 const void** indirect_input =
360 (const void**) ((uintptr_t) context->indirect_input +
361 batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
362 void* output =
363 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
364
365 context->unipass_ukernel(
366 context->output_width, context->pooling_size, context->channels,
367 indirect_input, context->zero, output,
368 context->input_increment, context->output_increment,
369 &context->params);
370}
371
372void xnn_compute_average_pooling_multipass(
373 const struct average_pooling_context context[restrict static 1],
374 size_t batch_index,
375 size_t output_y)
376{
377 const void** indirect_input =
378 (const void**) ((uintptr_t) context->indirect_input +
379 batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
380 void* output =
381 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
382 XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];
383
384 context->multipass_ukernel(
385 context->output_width, context->pooling_size, context->channels,
386 indirect_input, context->zero, multipass_buffer, output,
387 context->input_increment, context->output_increment,
388 &context->params);
389}
390
391void xnn_compute_pixelwise_average_pooling_unipass(
392 const struct pixelwise_average_pooling_context context[restrict static 1],
393 size_t batch_index,
394 size_t output_y)
395{
396 const void** indirect_input =
397 (const void**) ((uintptr_t) context->indirect_input +
398 batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
399 const void* pixelwise_buffer =
400 (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
401 void* output =
402 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
403
404 context->unipass_ukernel(
405 context->output_width, context->pooling_size, context->channels,
406 indirect_input, context->zero, pixelwise_buffer, output,
407 context->input_increment, context->output_increment,
408 &context->params);
409}
410
411void xnn_compute_pixelwise_average_pooling_multipass(
412 const struct pixelwise_average_pooling_context context[restrict static 1],
413 size_t batch_index,
414 size_t output_y)
415{
416 const void** indirect_input =
417 (const void**) ((uintptr_t) context->indirect_input +
418 batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
419 const void* pixelwise_buffer =
420 (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
421 void* output =
422 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
423 XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];
424
425 context->multipass_ukernel(
426 context->output_width, context->pooling_size, context->channels,
427 indirect_input, context->zero, pixelwise_buffer, multipass_buffer, output,
428 context->input_increment, context->output_increment,
429 &context->params);
430}
431
Marat Dukhanefc47b82019-11-18 09:25:38 -0800432void xnn_compute_global_average_pooling_nwc_unipass(
433 const struct global_average_pooling_nwc_context context[restrict static 1],
XNNPACK Teamb455b122019-09-27 18:10:33 -0700434 size_t batch_index)
435{
436 const void* input =
437 (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
438 void* output =
439 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
440
441 context->unipass_ukernel(
442 context->input_elements,
443 context->channels,
444 input,
445 context->input_pixel_stride,
446 context->zero,
447 output,
448 &context->params);
449}
450
Marat Dukhanefc47b82019-11-18 09:25:38 -0800451void xnn_compute_global_average_pooling_nwc_multipass(
452 const struct global_average_pooling_nwc_context context[restrict static 1],
XNNPACK Teamb455b122019-09-27 18:10:33 -0700453 size_t batch_index)
454{
455 const void* input =
456 (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
457 void* output =
458 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
459 XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];
460
461 context->multipass_ukernel(
462 context->input_elements,
463 context->channels,
464 input,
465 context->input_pixel_stride,
466 context->zero,
467 multipass_buffer,
468 output,
469 &context->params);
470}
471
Marat Dukhanefc47b82019-11-18 09:25:38 -0800472void xnn_compute_global_average_pooling_ncw(
473 const struct global_average_pooling_ncw_context context[restrict static 1],
XNNPACK Teamb455b122019-09-27 18:10:33 -0700474 size_t batch_index,
475 size_t channels_start,
476 size_t channels_slice)
477{
Marat Dukhanefc47b82019-11-18 09:25:38 -0800478 const void* input = (const void*) ((uintptr_t) context->input +
479 channels_start * context->input_channel_stride + batch_index * context->input_batch_stride);
480 void* output = (void*) ((uintptr_t) context->output +
481 channels_start * context->output_channel_stride + batch_index * context->output_batch_stride);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700482
483 context->ukernel(
484 context->input_elements,
485 channels_slice,
486 input,
487 output,
488 &context->params);
489}
490
Marat Dukhan69722492019-11-11 19:55:50 -0800491void xnn_compute_resize_bilinear(
492 const struct resize_bilinear_context context[restrict static 1],
493 size_t batch_index,
494 size_t pixel_start,
495 size_t pixel_range)
496{
497 void* output =
498 (void*) ((uintptr_t) context->output + pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride);
499
500 context->ukernel(
501 pixel_range,
502 context->scaled_channels,
503 context->indirect_input + pixel_start * 4,
504 context->input_offset + batch_index * context->input_batch_stride,
505 context->packed_weights + (pixel_start << context->log2_wsize),
506 output,
507 context->output_pixel_stride - context->scaled_channels);
508}
509
XNNPACK Teamb455b122019-09-27 18:10:33 -0700510void xnn_compute_prelu(
511 const struct prelu_context context[restrict static 1],
512 size_t batch_start,
513 size_t batch_range)
514{
515 const size_t x_stride = context->x_stride;
516 const size_t y_stride = context->y_stride;
517 const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
518 void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);
519
520 context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride, &context->params);
521}
522
523void xnn_compute_channel_pad(
524 const struct channel_pad_context context[restrict static 1],
525 size_t batch_start,
526 size_t batch_range)
527{
528 const size_t x_stride = context->x_stride;
529 const size_t y_stride = context->y_stride;
530 const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
531 void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);
532
533 context->ukernel(batch_range, context->n, context->l, context->r, context->c, x, x_stride, y, y_stride);
534}
535
536void xnn_compute_add_strided(
537 const struct add_strided_context context[restrict static 1],
538 size_t batch_index,
539 size_t batch_range /* always 1 */)
540{
541 assert(batch_range == 1);
542
543 const size_t n = context->n;
544 const size_t a_stride = context->a_stride;
545 const size_t b_stride = context->b_stride;
546 const size_t y_stride = context->y_stride;
547 const void* a = (const void*) ((uintptr_t) context->a + a_stride * batch_index);
548 const void* b = (const void*) ((uintptr_t) context->b + b_stride * batch_index);
549 void* y = (void*) ((uintptr_t) context->y + y_stride * batch_index);
550
551 context->ukernel(n, a, b, y, &context->params);
552}
553
554void xnn_compute_add_contiguous(
555 const struct add_contiguous_context context[restrict static 1],
556 size_t offset,
557 size_t size)
558{
559 const void* a = (const void*) ((uintptr_t) context->a + offset);
560 const void* b = (const void*) ((uintptr_t) context->b + offset);
561 void* y = (void*) ((uintptr_t) context->y + offset);
562 context->ukernel(size, a, b, y, &context->params);
563}
564
Marat Dukhanfc2b96e2019-12-03 12:04:04 -0800565void xnn_compute_elementwise_binary_5d(
Marat Dukhanca2733c2019-11-15 23:21:17 -0800566 const struct elementwise_binary_context context[restrict static 1],
Marat Dukhanfc2b96e2019-12-03 12:04:04 -0800567 size_t i, size_t j, size_t k, size_t l, size_t m,
568 size_t l_range, size_t m_range)
Marat Dukhanca2733c2019-11-15 23:21:17 -0800569{
Marat Dukhanfc2b96e2019-12-03 12:04:04 -0800570 assert(l_range == 1);
571 assert(m_range == 1);
Marat Dukhanca2733c2019-11-15 23:21:17 -0800572
573 const void* a = (const void*) ((uintptr_t) context->a +
Marat Dukhanfc2b96e2019-12-03 12:04:04 -0800574 i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2] + l * context->a_stride[3] + m * context->a_stride[4]);
Marat Dukhanca2733c2019-11-15 23:21:17 -0800575 const void* b = (const void*) ((uintptr_t) context->b +
Marat Dukhanfc2b96e2019-12-03 12:04:04 -0800576 i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2] + l * context->b_stride[3] + m * context->b_stride[4]);
Marat Dukhanca2733c2019-11-15 23:21:17 -0800577 void* y = (void*) ((uintptr_t) context->y +
Marat Dukhanfc2b96e2019-12-03 12:04:04 -0800578 i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2] + l * context->y_stride[3] + m * context->y_stride[4]);
Marat Dukhanca2733c2019-11-15 23:21:17 -0800579 context->ukernel(context->elements, a, b, y, &context->params);
580}
581
XNNPACK Teamb455b122019-09-27 18:10:33 -0700582void xnn_compute_channel_shuffle_fixed(
583 const struct channel_shuffle_context context[restrict static 1],
584 size_t index)
585{
586 const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
587 void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);
588
589 context->fixed_ukernel(context->n, x, y);
590}
591
592void xnn_compute_channel_shuffle_variable(
593 const struct channel_shuffle_context context[restrict static 1],
594 size_t index)
595{
596 const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
597 void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);
598
599 context->variable_ukernel(context->n, context->m, x, y);
600}
601
602void xnn_compute_lut_strided(
603 const struct lut_strided_context context[restrict static 1],
604 size_t batch_index)
605{
606 const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
607 void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
608
609 context->ukernel(context->n, x, context->t, y);
610}
611
612void xnn_compute_lut_contiguous(
613 const struct lut_contiguous_context context[restrict static 1],
614 size_t offset,
615 size_t size)
616{
617 const void* x = (const void*) ((uintptr_t) context->x + offset);
618 void* y = (void*) ((uintptr_t) context->y + offset);
619
620 context->ukernel(size, x, context->t, y);
621}
622
623void xnn_compute_univector_strided(
624 const struct univector_strided_context context[restrict static 1],
625 size_t batch_index,
626 size_t batch_range /* always 1 */)
627{
628 assert(batch_range == 1);
629
630 const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
631 void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
632 context->ukernel(context->n, x, y, &context->params);
633}
634
635void xnn_compute_univector_contiguous(
636 const struct univector_contiguous_context context[restrict static 1],
637 size_t offset,
638 size_t size)
639{
640 const void* x = (const void*) ((uintptr_t) context->x + offset);
641 void* y = (void*) ((uintptr_t) context->y + offset);
642 context->ukernel(size, x, y, &context->params);
643}
644
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800645void xnn_compute_u8_softmax(
646 const struct u8_softmax_context context[restrict static 1],
XNNPACK Teamb455b122019-09-27 18:10:33 -0700647 size_t batch_index)
648{
649 const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
650 uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
651 const size_t n = context->n;
652
653 uint8_t x_max = 0;
654 context->rmax_ukernel(n, x, &x_max);
655 const size_t adjustment = x_max ^ 255;
656 const uint32_t* t = (const uint32_t*) context->t + adjustment;
657 context->lut_norm_ukernel(n, x, t, y);
658}
659
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800660void xnn_compute_f32_three_pass_softmax(
661 const struct f32_three_pass_softmax_context context[restrict static 1],
Marat Dukhan1edc4542020-01-27 12:40:13 -0800662 size_t batch_index)
663{
664 const float* x = (const float*) ((uintptr_t) context->x + context->x_stride * batch_index);
665 float* y = (float*) ((uintptr_t) context->y + context->y_stride * batch_index);
666 const size_t n = context->n;
667
668 // First pass: reduce-max
669 float x_max;
670 context->rmax_ukernel(n, x, &x_max);
671
672 // Second pass: reduce-add & store exp(x-x_max)
673 float y_sum;
674 context->raddstoreexpminusmax_ukernel(n, x, y, &y_sum, x_max);
675
676 // Third pass: scale y
677 const float y_scale = 1.0f / y_sum;
678 context->vmulc_ukernel(n, y, &y_scale, y, &context->params);
679}
680
XNNPACK Teamb455b122019-09-27 18:10:33 -0700681void xnn_compute_vmulcaddc(
682 const struct vmulcaddc_context context[restrict static 1],
683 size_t batch_start,
684 size_t batch_size)
685{
686 const size_t x_stride = context->x_stride;
687 const size_t y_stride = context->y_stride;
688
689 const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
690 void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);
691
692 context->ukernel(
693 batch_size,
694 context->n,
695 x, x_stride,
696 context->w,
697 y, y_stride,
698 &context->params);
699}
700
701enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
702{
703 if (!xnn_params.initialized) {
704 xnn_log_error("failed to run operator: XNNPACK is not initialized");
705 return xnn_status_uninitialized;
706 }
707 switch (op->state) {
708 case xnn_run_state_invalid:
709 xnn_log_error("failed to run operator: operator was not successfully setup");
710 return xnn_status_invalid_state;
711 case xnn_run_state_ready:
712 break;
713 case xnn_run_state_skip:
714 return xnn_status_success;
715 }
716
717 switch (op->compute.type) {
718 case xnn_parallelization_type_invalid:
719 break;
720 case xnn_parallelization_type_1d:
721 assert(op->compute.range[0] != 0);
722 pthreadpool_parallelize_1d(
723 threadpool,
724 op->compute.task_1d,
725 &op->context,
726 op->compute.range[0],
727 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
728 break;
729 case xnn_parallelization_type_1d_tile_1d:
730 assert(op->compute.range[0] != 0);
731 assert(op->compute.tile[0] != 0);
732 pthreadpool_parallelize_1d_tile_1d(
733 threadpool,
734 op->compute.task_1d_tile_1d,
735 &op->context,
736 op->compute.range[0],
737 op->compute.tile[0],
738 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
739 break;
740 case xnn_parallelization_type_2d:
741 assert(op->compute.range[0] != 0);
742 assert(op->compute.range[1] != 0);
743 pthreadpool_parallelize_2d(
744 threadpool,
745 op->compute.task_2d,
746 &op->context,
747 op->compute.range[0], op->compute.range[1],
748 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
749 break;
750 case xnn_parallelization_type_2d_tile_1d:
751 assert(op->compute.range[0] != 0);
752 assert(op->compute.range[1] != 0);
753 assert(op->compute.tile[0] != 0);
754 pthreadpool_parallelize_2d_tile_1d(
755 threadpool,
756 op->compute.task_2d_tile_1d,
757 &op->context,
758 op->compute.range[0], op->compute.range[1],
759 op->compute.tile[0],
760 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
761 break;
762 case xnn_parallelization_type_2d_tile_2d:
763 assert(op->compute.range[0] != 0);
764 assert(op->compute.range[1] != 0);
765 assert(op->compute.tile[0] != 0);
766 assert(op->compute.tile[1] != 0);
767 pthreadpool_parallelize_2d_tile_2d(
768 threadpool,
769 op->compute.task_2d_tile_2d,
770 &op->context,
771 op->compute.range[0], op->compute.range[1],
772 op->compute.tile[0], op->compute.tile[1],
773 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
774 break;
775 case xnn_parallelization_type_3d_tile_2d:
776 assert(op->compute.range[0] != 0);
777 assert(op->compute.range[1] != 0);
778 assert(op->compute.range[2] != 0);
779 assert(op->compute.tile[0] != 0);
780 assert(op->compute.tile[1] != 0);
781 pthreadpool_parallelize_3d_tile_2d(
782 threadpool,
783 op->compute.task_3d_tile_2d,
784 &op->context,
785 op->compute.range[0], op->compute.range[1], op->compute.range[2],
786 op->compute.tile[0], op->compute.tile[1],
787 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
788 break;
789 case xnn_parallelization_type_4d_tile_2d:
790 assert(op->compute.range[0] != 0);
791 assert(op->compute.range[1] != 0);
792 assert(op->compute.range[2] != 0);
793 assert(op->compute.range[3] != 0);
794 assert(op->compute.tile[0] != 0);
795 assert(op->compute.tile[1] != 0);
796 pthreadpool_parallelize_4d_tile_2d(
797 threadpool,
798 op->compute.task_4d_tile_2d,
799 &op->context,
800 op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
801 op->compute.tile[0], op->compute.tile[1],
802 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
803 break;
804 case xnn_parallelization_type_5d_tile_2d:
805 assert(op->compute.range[0] != 0);
806 assert(op->compute.range[1] != 0);
807 assert(op->compute.range[2] != 0);
808 assert(op->compute.range[3] != 0);
809 assert(op->compute.range[4] != 0);
810 assert(op->compute.tile[0] != 0);
811 assert(op->compute.tile[1] != 0);
812 pthreadpool_parallelize_5d_tile_2d(
813 threadpool,
814 op->compute.task_5d_tile_2d,
815 &op->context,
816 op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
817 op->compute.tile[0], op->compute.tile[1],
818 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
819 break;
820 case xnn_parallelization_type_6d_tile_2d:
821 assert(op->compute.range[0] != 0);
822 assert(op->compute.range[1] != 0);
823 assert(op->compute.range[2] != 0);
824 assert(op->compute.range[3] != 0);
825 assert(op->compute.range[4] != 0);
826 assert(op->compute.range[5] != 0);
827 assert(op->compute.tile[0] != 0);
828 assert(op->compute.tile[1] != 0);
829 pthreadpool_parallelize_6d_tile_2d(
830 threadpool,
831 op->compute.task_6d_tile_2d,
832 &op->context,
833 op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
834 op->compute.tile[0], op->compute.tile[1],
835 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
836 break;
837 default:
838 XNN_UNREACHABLE;
839 }
840 return xnn_status_success;
841}