// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/operator.h>
#include <xnnpack/log.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/compute.h>

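// GEMM (dense matrix-matrix multiplication) compute functions. Each call
// computes one mr_block_size x nr_block_size tile of the output: the pointer
// arithmetic below selects the A rows, the packed-weight panel, and the C
// block for that tile, and the microkernel does the rest. The "grouped"
// variant adds a group_index dimension, used e.g. for grouped convolutions
// lowered to GEMM.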
void xnn_compute_grouped_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
      cm_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      &context->params);
}

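// SPMM (sparse matrix * dense matrix) compute function. The weights are
// pre-compressed into nonzero values plus per-output-channel nonzero counts
// and input increments, so the kernel only needs the dense input/output
// slices selected by batch_index and mr_block_start.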
void xnn_compute_spmm(
    const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size)
{
  context->ukernel(
      mr_block_size,
      context->n,
      (const void*) ((uintptr_t) context->input + batch_index * context->batched_input_stride + mr_block_start),
      context->nonzero_weights,
      context->input_increments,
      context->output_channel_nonzeros,
      (void*) ((uintptr_t) context->output + batch_index * context->batched_output_stride + mr_block_start),
      context->scaled_m,
      &context->params);
}

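// IGEMM (indirect GEMM) compute functions, which implement convolution
// without an im2col buffer: the microkernel gathers input rows through the
// indirection buffer indirect_a, with a_offset and the zero buffer handling
// per-batch offsets and implicit padding. The variants differ only in which
// of the batch/group dimensions they iterate over.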
void xnn_compute_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride,
      context->zero,
      &context->params);
}

void xnn_compute_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset,
      context->zero,
      &context->params);
}

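// Subconvolution compute functions, used for deconvolution (transposed
// convolution): the output is split into subkernels, each described by its
// own subconvolution_params (slice geometry, weights, indirection buffer).
// The early returns skip slices that fall outside the current subkernel,
// since the parallelized range must cover the largest one. The subgemm
// variants read the input directly; the subconv variants read it through an
// indirection buffer.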
void xnn_compute_grouped_subgemm2d(
    const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + group_index * context->ga_stride + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_subgemm2d(
    const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_grouped_subconv2d(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_subconv2d(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

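// Direct 2D convolution consuming HWC-layout input and producing CHW-layout
// output; each invocation handles the band of output rows
// [output_y_start, output_y_start + output_y_slice) of one batch element.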
void xnn_compute_conv2d_hwc2chw(
    const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y_start,
    size_t output_y_slice)
{
  context->hwc2chw_ukernel(
    context->input_height,
    context->input_width,
    output_y_start,
    output_y_start + output_y_slice,
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
    context->zero,
    context->packed_weights,
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
    context->input_padding_top,
    context->output_channels,
    context->output_height_stride,
    context->output_channel_stride,
    &context->params);
}

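// Depthwise convolution (NHWC): one output row of one batch element per
// invocation. "Unipass" means the microkernel accumulates the whole kernel
// window in a single pass; input rows are located through an indirection
// buffer and offset by input_offset for the current batch element.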
void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
    context->groups, context->output_width,
    indirect_input, context->packed_weights, output,
    context->indirect_input_width_stride, context->output_increment,
    input_offset, context->zero,
    &context->params);
}

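// Depthwise convolution (CHW): channels are independent, so work is
// parallelized over (batch, channel) pairs and the microkernel processes an
// entire 2D plane per call.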
void xnn_compute_dwconv2d_chw(
    const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel)
{
  context->chw_ukernel(
    context->input_height,
    context->input_width,
    (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
    (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
    context->zero,
    (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
    context->input_padding_top,
    &context->params);
}

void xnn_compute_depthtospace2d_chw2hwc(
    const struct depthtospace2d_chw2hwc_context* context,
    size_t batch_index)
{
  context->ukernel(
    context->output_channels,
    context->input_height,
    context->input_width,
    context->block_size,
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
    context->input_channel_stride,
    context->input_height_stride,
    context->output_height_stride,
    context->output_width_stride);
}

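// Argmax pooling produces both the pooled maxima and the index of the
// maximum within each pooling window. The multipass variant, used for
// pooling windows too large for a single microkernel pass, carries running
// maxima and indices in stack-allocated scratch buffers.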
void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  context->unipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, output, index,
    context->input_increment, context->output_increment);
}

void xnn_compute_argmax_pooling_multipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  void* multipass_accumulation_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(float) + XNN_EXTRA_BYTES);
  void* multipass_index_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(uint32_t) + XNN_EXTRA_BYTES);

  context->multipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
    context->input_increment, context->output_increment);
}

void xnn_compute_max_pooling(
    const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, output,
    context->input_increment, context->output_increment,
    &context->params);
}

void xnn_compute_unpooling(
    const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t input_y,
    size_t input_x)
{
  const void* input = (const void*) ((uintptr_t) context->input +
      input_y * context->input_height_stride + input_x * context->input_width_stride);
  const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
      input_y * context->index_height_stride + input_x * context->index_width_stride);
  void** indirect_output =
    (void**) ((uintptr_t) context->indirect_output +
      input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);

  context->ukernel(
    context->pooling_size,
    context->channels,
    context->fill_value,
    input, index, indirect_output);
}

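// Average pooling (NHWC): one output row per invocation. The pixelwise
// variants additionally pass a per-output-pixel multiplier row from
// pixelwise_buffer (prepared during operator setup) so that each pixel can
// be scaled by its own effective pooling-window size. Multipass variants
// accumulate through a stack-allocated scratch buffer.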
void xnn_compute_average_pooling_unipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, context->zero, output,
    context->input_increment, context->output_increment,
    &context->params);
}

void xnn_compute_average_pooling_multipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  void* multipass_buffer =
    XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, context->zero, multipass_buffer, output,
    context->input_increment, context->output_increment,
    &context->params);
}

void xnn_compute_pixelwise_average_pooling_unipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, context->zero, pixelwise_buffer, output,
    context->input_increment, context->output_increment,
    &context->params);
}

void xnn_compute_pixelwise_average_pooling_multipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  void* multipass_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, context->zero, pixelwise_buffer, multipass_buffer, output,
    context->input_increment, context->output_increment,
    &context->params);
}

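// Global average pooling reduces the full spatial extent to one value per
// channel. The NWC variants parallelize over batch elements (with a scratch
// accumulator in the multipass case); the NCW variant also parallelizes over
// slices of channels.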
void xnn_compute_global_average_pooling_nwc_unipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  context->unipass_ukernel(
    context->input_elements,
    context->channels,
    input,
    context->input_pixel_stride,
    context->zero,
    output,
    &context->params);
}

void xnn_compute_global_average_pooling_nwc_multipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  void* multipass_buffer =
    XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
    context->input_elements,
    context->channels,
    input,
    context->input_pixel_stride,
    context->zero,
    multipass_buffer,
    output,
    &context->params);
}

void xnn_compute_global_average_pooling_ncw(
    const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channels_start,
    size_t channels_slice)
{
  const void* input = (const void*) ((uintptr_t) context->input +
    channels_start * context->input_channel_stride + batch_index * context->input_batch_stride);
  void* output = (void*) ((uintptr_t) context->output +
    channels_start * context->output_channel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
    context->input_elements,
    channels_slice,
    input,
    output,
    &context->params);
}

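// Bilinear resize: every output pixel reads its four input neighbors through
// the indirection buffer (hence "pixel_start * 4") and blends them with
// precomputed interpolation weights from packed_weights. The CHW variant
// walks channels instead of pixels in the outer parallel loop.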
void xnn_compute_resize_bilinear(
    const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t pixel_start,
    size_t pixel_range)
{
  void* output =
    (void*) ((uintptr_t) context->output + pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
    pixel_range,
    context->scaled_channels,
    context->indirect_input + pixel_start * 4,
    context->input_offset + batch_index * context->input_batch_stride,
    (const void*) ((uintptr_t) context->packed_weights + (pixel_start << context->log2_wsize)),
    output,
    context->output_pixel_stride - context->scaled_channels);
}

void xnn_compute_resize_bilinear_chw(
    const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel_start,
    size_t channel_range)
{
  void* output =
    (void*) ((uintptr_t) context->output + channel_start * context->output_channel_stride + batch_index * context->output_batch_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride + channel_start * context->input_channel_stride;

  context->ukernel(
    context->output_pixels,
    channel_range,
    context->indirect_input,
    input_offset,
    context->packed_weights,
    output,
    context->input_channel_stride);
}

void xnn_compute_prelu(
    const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride);
}

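// 5D constant padding. pthreadpool tiles the two innermost dimensions with
// tile size 1, hence the asserted l_range/m_range. The bounds check below
// relies on unsigned wraparound: for i < i_padding the subtraction
// i - i_padding wraps to a huge value, so a single comparison rejects both
// the leading and trailing padding regions. In-bounds rows are padded along
// the innermost dimension by pad_ukernel; fully out-of-bounds rows are
// filled with the padding value by fill_ukernel.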
void xnn_compute_pad_5d(
    const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m,
    size_t l_range, size_t m_range)
{
  assert(l_range == 1);
  assert(m_range == 1);

  const void* input = (const void*) ((uintptr_t) context->input +
    i * context->input_stride[4] + j * context->input_stride[3] + k * context->input_stride[2] + l * context->input_stride[1] + m * context->input_stride[0]);
  void* output = (void*) ((uintptr_t) context->output +
    i * context->output_stride[4] + j * context->output_stride[3] + k * context->output_stride[2] + l * context->output_stride[1] + m * context->output_stride[0]);

  const size_t i_padding = context->pre_paddings[5];
  const size_t j_padding = context->pre_paddings[4];
  const size_t k_padding = context->pre_paddings[3];
  const size_t l_padding = context->pre_paddings[2];
  const size_t m_padding = context->pre_paddings[1];

  const size_t i_size = context->input_size[5];
  const size_t j_size = context->input_size[4];
  const size_t k_size = context->input_size[3];
  const size_t l_size = context->input_size[2];
  const size_t m_size = context->input_size[1];

  if XNN_LIKELY(i - i_padding < i_size && j - j_padding < j_size && k - k_padding < k_size &&
                l - l_padding < l_size && m - m_padding < m_size)
  {
    context->pad_ukernel(
      1 /* rows */,
      context->input_size[0], context->pre_paddings[0], context->post_paddings[0],
      &context->padding_value,
      input, 0 /* input stride */, output, 0 /* output stride */);
  } else {
    context->fill_ukernel(1 /* rows */, context->output_size[0], output, 0 /* output stride */, &context->padding_value);
  }
}

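// Elementwise binary operation (e.g. add or multiply) over up-to-5D shapes.
// Broadcasting is expressed through the per-dimension strides prepared
// during operator setup (a zero stride repeats the same operand along that
// dimension); the innermost contiguous run of context->elements elements is
// handled by the microkernel in one call.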
void xnn_compute_elementwise_binary_5d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m,
    size_t l_range, size_t m_range)
{
  assert(l_range == 1);
  assert(m_range == 1);

  const void* a = (const void*) ((uintptr_t) context->a +
    i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2] + l * context->a_stride[3] + m * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b +
    i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2] + l * context->b_stride[3] + m * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y +
    i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2] + l * context->y_stride[3] + m * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_channel_shuffle_fixed(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->fixed_ukernel(context->n, x, y);
}

void xnn_compute_channel_shuffle_variable(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->variable_ukernel(context->n, context->m, x, y);
}

void xnn_compute_lut_strided(
    const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);

  context->ukernel(context->n, x, context->t, y);
}

void xnn_compute_lut_contiguous(
    const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);

  context->ukernel(size, x, context->t, y);
}

void xnn_compute_univector_strided(
    const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t batch_range /* always 1 */)
{
  assert(batch_range == 1);

  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
  context->ukernel(context->n, x, y, &context->params);
}

void xnn_compute_univector_contiguous(
    const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);
  context->ukernel(size, x, y, &context->params);
}

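// Softmax over uint8 inputs via a lookup table of precomputed exponentials:
// after the reduce-max pass, offsetting the table base by (x_max ^ 255),
// i.e. 255 - x_max, shifts every lookup as if x_max had been subtracted from
// the inputs, and lut_norm_ukernel then normalizes the summed values.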
void xnn_compute_u8_softmax(
    const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
  uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  uint8_t x_max = 0;
  context->rmax_ukernel(n, x, &x_max);
  const size_t adjustment = x_max ^ 255;
  const uint32_t* t = (const uint32_t*) context->t + adjustment;
  context->lut_norm_ukernel(n, x, t, y);
}

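// Floating-point softmax in three passes over each batch row: reduce-max for
// numerical stability, store exp(x - x_max) while accumulating its sum, then
// scale the stored values by the reciprocal of that sum.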
void xnn_compute_f32_three_pass_softmax(
    const struct f32_three_pass_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const float* x = (const float*) ((uintptr_t) context->x + context->x_stride * batch_index);
  float* y = (float*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  // First pass: reduce-max
  float x_max;
  context->rmax_ukernel(n, x, &x_max);

  // Second pass: reduce-add & store exp(x-x_max)
  float y_sum;
  context->raddstoreexpminusmax_ukernel(n, x, y, &y_sum, x_max);

  // Third pass: scale y
  const float y_scale = 1.0f / y_sum;
  context->vmulc_ukernel(n, y, &y_scale, y, &context->params);
}

void xnn_compute_vmulcaddc(
    const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_size)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;

  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(
    batch_size,
    context->n,
    x, x_stride,
    context->w,
    y, y_stride,
    &context->params);
}

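// Heterogeneous multi-processor (HMP) variants of the GEMM/IGEMM compute
// functions. On builds with more than one supported microarchitecture
// (e.g. ARM big.LITTLE), pthreadpool reports the uarch index of the worker
// thread, and these wrappers select the matching microkernel specialization
// from ukernel.function[uarch_index] instead of XNN_UARCH_DEFAULT.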
#if XNN_MAX_UARCH_TYPES > 1
  void xnn_compute_hmp_grouped_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t k_scaled = context->k_scaled;
    const size_t a_stride = context->a_stride;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        k_scaled,
        (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
        a_stride,
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
        (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
        cm_stride,
        context->cn_stride,
        &context->params);
  }

  void xnn_compute_hmp_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t a_stride = context->a_stride;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->k_scaled,
        (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
        a_stride,
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
        (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        &context->params);
  }

  void xnn_compute_hmp_grouped_batch_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t batch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
        (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
        context->zero,
        &context->params);
  }

  void xnn_compute_hmp_grouped_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
        (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset + group_index * context->ga_stride,
        context->zero,
        &context->params);
  }

  void xnn_compute_batch_hmp_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t batch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
        (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset + batch_index * context->ba_stride,
        context->zero,
        &context->params);
  }

  void xnn_compute_hmp_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
        (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset,
        context->zero,
        &context->params);
  }
#endif  // XNN_MAX_UARCH_TYPES > 1

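// Runs a set-up operator on the given thread pool. Operator setup records
// the parallelization type, compute function, iteration ranges, and tile
// sizes in op->compute; this dispatcher forwards them to the matching
// pthreadpool_parallelize_*() call. A NULL threadpool runs the work serially
// on the calling thread. A typical call sequence (sketch only, with the
// operator-specific setup arguments elided):
//
//   enum xnn_status status = xnn_setup_convolution2d_nhwc_f32(op, /* ... */, threadpool);
//   if (status == xnn_status_success) {
//     status = xnn_run_operator(op, threadpool);
//   }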
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to run operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }
  switch (op->state) {
    case xnn_run_state_invalid:
      xnn_log_error("failed to run operator: operator was not successfully setup");
      return xnn_status_invalid_state;
    case xnn_run_state_ready:
      break;
    case xnn_run_state_skip:
      return xnn_status_success;
  }

  switch (op->compute.type) {
    case xnn_parallelization_type_invalid:
      break;
    case xnn_parallelization_type_1d:
      assert(op->compute.range[0] != 0);
      pthreadpool_parallelize_1d(
          threadpool,
          op->compute.task_1d,
          &op->context,
          op->compute.range[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_1d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_1d_tile_1d(
          threadpool,
          op->compute.task_1d_tile_1d,
          &op->context,
          op->compute.range[0],
          op->compute.tile[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      pthreadpool_parallelize_2d(
          threadpool,
          op->compute.task_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_2d_tile_1d(
          threadpool,
          op->compute.task_2d_tile_1d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d(
          threadpool,
          op->compute.task_2d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_3d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d(
          threadpool,
          op->compute.task_3d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_4d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d(
          threadpool,
          op->compute.task_4d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_5d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_5d_tile_2d(
          threadpool,
          op->compute.task_5d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_6d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.range[5] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_6d_tile_2d(
          threadpool,
          op->compute.task_6d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
#if XNN_MAX_UARCH_TYPES > 1
    case xnn_parallelization_type_2d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_2d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_3d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_3d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_4d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_4d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
#endif  // XNN_MAX_UARCH_TYPES > 1
    default:
      XNN_UNREACHABLE;
  }
  return xnn_status_success;
}