// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/operator.h>
#include <xnnpack/log.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/compute.h>

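// Each xnn_compute_* function below is a thread-pool callback: xnn_run_operator
// (at the bottom of this file) hands it to pthreadpool together with an
// operator-specific context struct, and the pool invokes it once per (possibly
// tiled) index in the operator's iteration space. The callback's only job is to
// turn those indices into byte offsets and invoke the operator's microkernel.
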
void xnn_compute_grouped_gemm(
    const struct gemm_context context[restrict static 1],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
      cm_stride,
      context->cn_stride,
      &context->params);
}

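// Worked sketch (hypothetical helper, excluded from compilation): the byte
// offsets xnn_compute_grouped_gemm derives above for one (group, tile) pair.
// Rows of A advance by a_stride bytes and groups by k_scaled bytes; columns of
// the packed weights advance by w_stride and groups by wg_stride; rows of C
// advance by cm_stride, columns by the element size (1 << log2_csize), and
// groups by cg_stride. Field names follow struct gemm_context as used above.
#if 0
static void example_grouped_gemm_offsets(
    const struct gemm_context* context,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start)
{
  const size_t a_offset = mr_block_start * context->a_stride + group_index * context->k_scaled;
  const size_t w_offset = nr_block_start * context->w_stride + group_index * context->wg_stride;
  const size_t c_offset = mr_block_start * context->cm_stride +
      (nr_block_start << context->log2_csize) + group_index * context->cg_stride;
  (void) a_offset; (void) w_offset; (void) c_offset;
}
#endif
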
void xnn_compute_gemm(
    const struct gemm_context context[restrict static 1],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      &context->params);
}

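// Sketch (assumed tile sizes, not library code): a single-threaded equivalent
// of the 2-D tiled dispatch that drives xnn_compute_gemm. In the real path,
// pthreadpool_parallelize_2d_tile_2d (see xnn_run_operator below) partitions
// the same (M, N) space across threads; mr and nc here stand in for
// op->compute.tile[0] and op->compute.tile[1].
#if 0
static void example_gemm_dispatch(
    const struct gemm_context* context,
    size_t batch_size,       // M: rows of A and C
    size_t output_channels,  // N: columns of C
    size_t mr,               // row-tile size
    size_t nc)               // column-tile size
{
  for (size_t mr_block_start = 0; mr_block_start < batch_size; mr_block_start += mr) {
    for (size_t nr_block_start = 0; nr_block_start < output_channels; nr_block_start += nc) {
      xnn_compute_gemm(context, mr_block_start, nr_block_start,
          min(mr, batch_size - mr_block_start),
          min(nc, output_channels - nr_block_start));
    }
  }
}
#endif
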
void xnn_compute_spmm(
    const struct spmm_context context[restrict static 1],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size)
{
  context->ukernel(
      mr_block_size,
      context->n,
      (const void*) ((uintptr_t) context->a + batch_index * context->batched_a_stride + mr_block_start * sizeof(float)),
      context->packed_weights,
      context->input_increments,
      context->output_channel_nonzeros,
      (void*) ((uintptr_t) context->c + batch_index * context->batched_c_stride + mr_block_start * sizeof(float)),
      &context->params);
}

void xnn_compute_grouped_igemm(
    const struct igemm_context context[restrict static 1],
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_igemm(
    const struct igemm_context context[restrict static 1],
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

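// Note on the IGEMM variants above: instead of a dense A matrix, the
// microkernel reads input through indirect_a, an array of per-tap row pointers
// prepared by the indirection-buffer setup, and rebases each pointer by
// a_offset to select the current batch (and group). context->zero points at a
// zero vector that stands in for padding taps, so the kernel never has to
// branch on image borders.
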
void xnn_compute_grouped_subgemm2d(
    const struct subgemm_context context[restrict static 1],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + group_index * context->ga_stride + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_subgemm2d(
    const struct subgemm_context context[restrict static 1],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}

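// The subgemm/subconv callbacks above and below are over-provisioned in x: a
// tile may start past the end of a subkernel's slice or hang off its edge, so
// they early-out on out-of-range slice_y/slice_x_start and clamp the tile
// width. For example, with slice_width = 7, slice_x_start = 4, and
// slice_x_max = 5, the clamped slice_x_size = min(5, 7 - 4) = 3.
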
void xnn_compute_grouped_subconv2d(
    const struct subconv_context context[restrict static 1],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_subconv2d(
    const struct subconv_context context[restrict static 1],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_dconv2d_hwc2spchw(
    const struct dconv2d_context context[restrict static 1],
    size_t batch_index,
    size_t output_y_start,
    size_t output_y_slice)
{
  context->hwc2spchw_ukernel(
      context->input_height,
      context->input_width,
      output_y_start,
      output_y_start + output_y_slice,
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
      context->zero,
      context->packed_weights,
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
      context->input_padding_top,
      context->output_channels,
      context->output_height_stride,
      context->output_channel_stride,
      &context->params);
}

void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict static 1],
    size_t output_y)
{
  context->unipass_ukernel(
      context->groups,
      context->output_width,
      context->indirection_buffer + output_y * context->indirection_buffer_row_stride,
      context->packed_weights,
      (void*) ((uintptr_t) context->output + output_y * context->output_row_stride),
      context->indirection_buffer_col_stride,
      context->output_col_increment,
      &context->params);
}

void xnn_compute_dwconv2d_spchw(
    const struct dwconv2d_context context[restrict static 1],
    size_t batch_index,
    size_t channel)
{
  context->spchw_ukernel(
      context->output_height,
      context->input_width,
      (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
      (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
      (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
      context->input_tuple_stride,
      context->output_tuple_stride,
      context->input_pixel_stride,
      context->output_pixel_stride,
      &context->params);
}

void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, output, index,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_argmax_pooling_multipass(
    const struct argmax_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  void* multipass_accumulation_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(float) + XNN_EXTRA_BYTES);
  void* multipass_index_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(uint32_t) + XNN_EXTRA_BYTES);

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
      context->input_increment, context->output_increment,
      &context->params);
}

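// The multipass pooling variants need per-invocation scratch space: an
// accumulator (and, for argmax, an index buffer) with one element per channel.
// XNN_SIMD_ALLOCA carves it from the stack with SIMD-friendly alignment, and
// the XNN_EXTRA_BYTES slack lets vector kernels read and write a full register
// past the last channel without touching memory that matters.
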
void xnn_compute_max_pooling(
    const struct max_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_unpooling(
    const struct unpooling_context context[restrict static 1],
    size_t input_y,
    size_t input_x)
{
  const void* input = (const void*) ((uintptr_t) context->input +
      input_y * context->input_height_stride + input_x * context->input_width_stride);
  const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
      input_y * context->index_height_stride + input_x * context->index_width_stride);
  void** indirect_output =
    (void**) ((uintptr_t) context->indirect_output +
      input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);

  context->ukernel(
      context->pooling_size,
      context->channels,
      context->fill_value,
      input, index, indirect_output);
}

void xnn_compute_average_pooling_unipass(
    const struct average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, context->zero, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_average_pooling_multipass(
    const struct average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  void* multipass_buffer =
    XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, context->zero, multipass_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_pixelwise_average_pooling_unipass(
    const struct pixelwise_average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, context->zero, pixelwise_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_pixelwise_average_pooling_multipass(
    const struct pixelwise_average_pooling_context context[restrict static 1],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  void* multipass_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, context->zero, pixelwise_buffer, multipass_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_global_average_pooling_nwc_unipass(
    const struct global_average_pooling_nwc_context context[restrict static 1],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  context->unipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      output,
      &context->params);
}

void xnn_compute_global_average_pooling_nwc_multipass(
    const struct global_average_pooling_nwc_context context[restrict static 1],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  void* multipass_buffer =
    XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      multipass_buffer,
      output,
      &context->params);
}

void xnn_compute_global_average_pooling_ncw(
    const struct global_average_pooling_ncw_context context[restrict static 1],
    size_t batch_index,
    size_t channels_start,
    size_t channels_slice)
{
  const void* input = (const void*) ((uintptr_t) context->input +
    channels_start * context->input_channel_stride + batch_index * context->input_batch_stride);
  void* output = (void*) ((uintptr_t) context->output +
    channels_start * context->output_channel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
      context->input_elements,
      channels_slice,
      input,
      output,
      &context->params);
}

void xnn_compute_resize_bilinear(
    const struct resize_bilinear_context context[restrict static 1],
    size_t batch_index,
    size_t pixel_start,
    size_t pixel_range)
{
  void* output =
    (void*) ((uintptr_t) context->output + pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
      pixel_range,
      context->scaled_channels,
      context->indirect_input + pixel_start * 4,
      context->input_offset + batch_index * context->input_batch_stride,
      (const void*) ((uintptr_t) context->packed_weights + (pixel_start << context->log2_wsize)),
      output,
      context->output_pixel_stride - context->scaled_channels);
}

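// Each output pixel of the bilinear resize above consumes exactly four entries
// of the indirection buffer (pointers to the top-left, top-right, bottom-left,
// and bottom-right source pixels), hence the `pixel_start * 4` offset, plus
// one packed weight record of 1 << log2_wsize bytes, which holds the pixel's
// interpolation coefficients (for f32, presumably two floats).
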
void xnn_compute_prelu(
    const struct prelu_context context[restrict static 1],
    size_t batch_start,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride);
}

void xnn_compute_channel_pad(
    const struct channel_pad_context context[restrict static 1],
    size_t batch_start,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(batch_range, context->n, context->l, context->r, context->c, x, x_stride, y, y_stride);
}

void xnn_compute_add_strided(
    const struct add_strided_context context[restrict static 1],
    size_t batch_index,
    size_t batch_range /* always 1 */)
{
  assert(batch_range == 1);

  const size_t n = context->n;
  const size_t a_stride = context->a_stride;
  const size_t b_stride = context->b_stride;
  const size_t y_stride = context->y_stride;
  const void* a = (const void*) ((uintptr_t) context->a + a_stride * batch_index);
  const void* b = (const void*) ((uintptr_t) context->b + b_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_index);

  context->ukernel(n, a, b, y, &context->params);
}

void xnn_compute_add_contiguous(
    const struct add_contiguous_context context[restrict static 1],
    size_t offset,
    size_t size)
{
  const void* a = (const void*) ((uintptr_t) context->a + offset);
  const void* b = (const void*) ((uintptr_t) context->b + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);
  context->ukernel(size, a, b, y, &context->params);
}

void xnn_compute_elementwise_binary_5d(
    const struct elementwise_binary_context context[restrict static 1],
    size_t i, size_t j, size_t k, size_t l, size_t m,
    size_t l_range, size_t m_range)
{
  assert(l_range == 1);
  assert(m_range == 1);

  const void* a = (const void*) ((uintptr_t) context->a +
    i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2] + l * context->a_stride[3] + m * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b +
    i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2] + l * context->b_stride[3] + m * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y +
    i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2] + l * context->y_stride[3] + m * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

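// The 5-D elementwise dispatch above walks the two innermost dimensions one
// step at a time (l_range and m_range are asserted to be 1) and lets the
// per-dimension byte strides do all the indexing. Broadcasting falls out of
// the same scheme: a dimension in which an operand has extent 1 is typically
// given stride 0 at setup time, so every output position along that axis
// rereads the same operand bytes.
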
void xnn_compute_channel_shuffle_fixed(
    const struct channel_shuffle_context context[restrict static 1],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->fixed_ukernel(context->n, x, y);
}

void xnn_compute_channel_shuffle_variable(
    const struct channel_shuffle_context context[restrict static 1],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->variable_ukernel(context->n, context->m, x, y);
}

void xnn_compute_lut_strided(
    const struct lut_strided_context context[restrict static 1],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);

  context->ukernel(context->n, x, context->t, y);
}

void xnn_compute_lut_contiguous(
    const struct lut_contiguous_context context[restrict static 1],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);

  context->ukernel(size, x, context->t, y);
}

void xnn_compute_univector_strided(
    const struct univector_strided_context context[restrict static 1],
    size_t batch_index,
    size_t batch_range /* always 1 */)
{
  assert(batch_range == 1);

  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
  context->ukernel(context->n, x, y, &context->params);
}

void xnn_compute_univector_contiguous(
    const struct univector_contiguous_context context[restrict static 1],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);
  context->ukernel(size, x, y, &context->params);
}

void xnn_compute_u8_softmax(
    const struct u8_softmax_context context[restrict static 1],
    size_t batch_index)
{
  const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
  uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  uint8_t x_max = 0;
  context->rmax_ukernel(n, x, &x_max);
  const size_t adjustment = x_max ^ 255;
  const uint32_t* t = (const uint32_t*) context->t + adjustment;
  context->lut_norm_ukernel(n, x, t, y);
}

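// The `x_max ^ 255` trick above relies on x_max being an 8-bit value: XOR with
// 255 flips all eight bits, so it equals 255 - x_max. Offsetting the LUT base
// by that amount makes entry x of the shifted table correspond to
// x + (255 - x_max), i.e. the table is indexed by the input's distance below
// the row maximum. Example: x_max = 200 gives adjustment = 55, and the maximum
// element then lands on table entry 255.
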
void xnn_compute_f32_three_pass_softmax(
    const struct f32_three_pass_softmax_context context[restrict static 1],
    size_t batch_index)
{
  const float* x = (const float*) ((uintptr_t) context->x + context->x_stride * batch_index);
  float* y = (float*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  // First pass: reduce-max
  float x_max;
  context->rmax_ukernel(n, x, &x_max);

  // Second pass: reduce-add & store exp(x - x_max)
  float y_sum;
  context->raddstoreexpminusmax_ukernel(n, x, y, &y_sum, x_max);

  // Third pass: scale y
  const float y_scale = 1.0f / y_sum;
  context->vmulc_ukernel(n, y, &y_scale, y, &context->params);
}

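// The three passes above implement the numerically stable softmax
//   y[i] = exp(x[i] - x_max) / sum_j exp(x[j] - x_max).
// Subtracting x_max before exponentiation keeps every exponent <= 0, so the
// intermediate exp() values stay in (0, 1] and cannot overflow.
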
void xnn_compute_vmulcaddc(
    const struct vmulcaddc_context context[restrict static 1],
    size_t batch_start,
    size_t batch_size)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;

  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(
      batch_size,
      context->n,
      x, x_stride,
      context->w,
      y, y_stride,
      &context->params);
}

#if XNN_MAX_UARCH_TYPES > 1
  void xnn_compute_hmp_grouped_gemm(
      const struct gemm_context context[restrict static 1],
      uint32_t uarch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t k_scaled = context->k_scaled;
    const size_t a_stride = context->a_stride;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        k_scaled,
        (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
        a_stride,
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
        (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
        cm_stride,
        context->cn_stride,
        &context->params);
  }

  void xnn_compute_hmp_gemm(
      const struct gemm_context context[restrict static 1],
      uint32_t uarch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t a_stride = context->a_stride;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->k_scaled,
        (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
        a_stride,
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
        (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        &context->params);
  }

  void xnn_compute_hmp_grouped_igemm(
      const struct igemm_context context[restrict static 1],
      uint32_t uarch_index,
      size_t batch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
        (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
        context->zero,
        &context->params);
  }

  void xnn_compute_hmp_igemm(
      const struct igemm_context context[restrict static 1],
      uint32_t uarch_index,
      size_t batch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
        (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset + batch_index * context->ba_stride,
        context->zero,
        &context->params);
  }
#endif  // XNN_MAX_UARCH_TYPES > 1

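// The xnn_compute_hmp_* wrappers above are compiled only for heterogeneous
// multi-processor targets (XNN_MAX_UARCH_TYPES > 1, e.g. big.LITTLE ARM SoCs).
// They mirror the default callbacks exactly, except that pthreadpool passes in
// the microarchitecture index of the calling thread's core, which selects a
// microkernel tuned for that core from ukernel.function[].
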
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
  if (!xnn_params.initialized) {
    xnn_log_error("failed to run operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }
  switch (op->state) {
    case xnn_run_state_invalid:
      xnn_log_error("failed to run operator: operator was not successfully set up");
      return xnn_status_invalid_state;
    case xnn_run_state_ready:
      break;
    case xnn_run_state_skip:
      return xnn_status_success;
  }

  switch (op->compute.type) {
    case xnn_parallelization_type_invalid:
      break;
    case xnn_parallelization_type_1d:
      assert(op->compute.range[0] != 0);
      pthreadpool_parallelize_1d(
          threadpool,
          op->compute.task_1d,
          &op->context,
          op->compute.range[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_1d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_1d_tile_1d(
          threadpool,
          op->compute.task_1d_tile_1d,
          &op->context,
          op->compute.range[0],
          op->compute.tile[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      pthreadpool_parallelize_2d(
          threadpool,
          op->compute.task_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_2d_tile_1d(
          threadpool,
          op->compute.task_2d_tile_1d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d(
          threadpool,
          op->compute.task_2d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_3d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d(
          threadpool,
          op->compute.task_3d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_4d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d(
          threadpool,
          op->compute.task_4d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_5d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_5d_tile_2d(
          threadpool,
          op->compute.task_5d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_6d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.range[5] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_6d_tile_2d(
          threadpool,
          op->compute.task_6d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
#if XNN_MAX_UARCH_TYPES > 1
    case xnn_parallelization_type_2d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_2d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_3d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_3d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_4d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_4d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
#endif  // XNN_MAX_UARCH_TYPES > 1
    default:
      XNN_UNREACHABLE;
  }
  return xnn_status_success;
}