blob: 11065c43e80476aa325a1e564908c4ad54da8d81 [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8
9#pragma once
10
11#include <stdbool.h>
12#include <stddef.h>
13#include <stdint.h>
14
15#include <cpuinfo.h>
16
17#include <xnnpack/common.h>
18
// Number of extra bytes, as a compile-time constant.
// NOTE(review): presumably padding that kernels may touch past the end of a
// buffer - confirm against the allocation sites that use it.
#define XNN_INTERNAL_EXTRA_BYTES 32
20
// Parameters handed to f16 micro-kernels (see xnn_f16_gemm_ukernel_function
// and xnn_f16_ppmm_ukernel_function): a scale and a max/min pair, each held
// in a 16-bit integer.
// NOTE(review): the fields presumably carry IEEE half-precision bit
// patterns - confirm against the code that fills this struct.
struct xnn_f16_output_params {
  uint16_t scale;
  uint16_t max;
  uint16_t min;
};
26
// Max/min parameters handed to f32 micro-kernels. Each union member holds
// the same two values in the form one kernel flavour reads:
//   scalar - one float per value;
//   sse    - (x86 / x86-64 only) each value as a 16-byte-aligned float[4].
union xnn_f32_output_params {
  struct {
    float max;
    float min;
  } scalar;
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  struct {
    XNN_ALIGN(16) float max[4];
    XNN_ALIGN(16) float min[4];
  } sse;
#endif  // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
};
39
// Parameters handed to f32 SpCHW micro-kernels (see
// xnn_f32_dwconv_spchw_ukernel_function). Members:
//   scalar - max/min as plain floats;
//   neon   - (ARM / ARM64 only) three 16-byte-aligned lane masks followed by
//            scalar min/max (note: min precedes max here);
//   sse    - (x86 / x86-64 only) the same three masks plus max/min as
//            16-byte-aligned float[4].
union xnn_f32_spchw_params {
  struct {
    float max;
    float min;
  } scalar;
#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
  struct {
    XNN_ALIGN(16) uint32_t mask_even[4];  // used by stride 2 kernels
    XNN_ALIGN(16) uint32_t mask_odd[4];  // used by stride 2 kernels
    XNN_ALIGN(16) uint32_t mask[4];  // used by stride 1 kernels
    float min;
    float max;
  } neon;
#endif  // CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  struct {
    XNN_ALIGN(16) uint32_t mask_even[4];  // used by stride 2 kernels
    XNN_ALIGN(16) uint32_t mask_odd[4];  // used by stride 2 kernels
    XNN_ALIGN(16) uint32_t mask[4];  // used by stride 1 kernels
    XNN_ALIGN(16) float max[4];
    XNN_ALIGN(16) float min[4];
  } sse;
#endif  // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
};
64
// Max/min parameters handed to u8 micro-kernels (see
// xnn_u8_clamp_ukernel_function and xnn_u8_maxpool_ukernel_function).
//   scalar - values widened to int32_t;
//   neon   - (ARM / ARM64 only) values as single bytes;
//   sse2   - (x86 / x86-64 only) each value as a 16-byte-aligned uint8_t[16].
union xnn_u8_output_params {
  struct {
    int32_t max;
    int32_t min;
  } scalar;
#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
  struct {
    uint8_t max;
    uint8_t min;
  } neon;
#endif  // CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  struct {
    XNN_ALIGN(16) uint8_t max[16];
    XNN_ALIGN(16) uint8_t min[16];
  } sse2;
#endif  // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
};
83
// Parameters handed to f32 average-pooling micro-kernels: a multiplier and
// output min/max values.
//   scalar - plain floats, ordered multiplier, output_min, output_max;
//   sse2   - (x86 / x86-64 only) 16-byte-aligned float[4] per value,
//            ordered multiplier, output_max, output_min;
//   neon   - (ARM / ARM64 only) one float per value, each 16-byte aligned,
//            ordered multiplier, output_max, output_min.
union xnn_f32_avgpool_params {
  struct {
    float multiplier;
    float output_min;
    float output_max;
  } scalar;
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  struct {
    XNN_ALIGN(16) float multiplier[4];
    XNN_ALIGN(16) float output_max[4];
    XNN_ALIGN(16) float output_min[4];
  } sse2;
#endif  // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
  struct {
    XNN_ALIGN(16) float multiplier;
    XNN_ALIGN(16) float output_max;
    XNN_ALIGN(16) float output_min;
  } neon;
#endif  // CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
};
105
// Parameters handed to f32 Global Average Pooling micro-kernels in SpCHW
// layout (see xnn_f32_gavgpool_spchw_ukernel_function): a multiplier, output
// min/max values and, in the SIMD members, a 4-lane mask.
//   scalar - plain floats, ordered multiplier, output_min, output_max;
//   sse    - (x86 / x86-64 only) 16-byte-aligned float[4] per value + mask;
//   neon   - (ARM / ARM64 only) one 16-byte-aligned float per value + mask.
union xnn_f32_gavgpool_params {
  struct {
    float multiplier;
    float output_min;
    float output_max;
  } scalar;
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  struct {
    XNN_ALIGN(16) float multiplier[4];
    XNN_ALIGN(16) float output_max[4];
    XNN_ALIGN(16) float output_min[4];
    XNN_ALIGN(16) uint32_t mask[4];
  } sse;
#endif  // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
  struct {
    XNN_ALIGN(16) float multiplier;
    XNN_ALIGN(16) float output_max;
    XNN_ALIGN(16) float output_min;
    XNN_ALIGN(16) uint32_t mask[4];
  } neon;
#endif  // CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
};
129
// Constants handed to f32 hswish micro-kernels (see
// xnn_f32_hswish_ukernel_function), named sixth, half and one.
//   scalar - one float per constant;
//   sse    - (x86 / x86-64 only) each constant as a 16-byte-aligned float[4].
union xnn_f32_hswish_params {
  struct {
    float sixth;
    float half;
    float one;
  } scalar;
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  struct {
    XNN_ALIGN(16) float sixth[4];
    XNN_ALIGN(16) float half[4];
    XNN_ALIGN(16) float one[4];
  } sse;
#endif  // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
};
144
// Parameters handed to q8 GEMM, IGEMM and DWCONV micro-kernels (see
// xnn_q8_gemm_ukernel_function, xnn_q8_igemm_ukernel_function and
// xnn_q8_dwconv_up_ukernel_function): zero points, the requantization
// multiplier/shift and output bounds, in one representation per kernel
// flavour.
//   scalar - 32-bit fields; output bounds are stored "less zero point";
//   neon   - (ARM / ARM64 only) compact 16/32/8-bit fields;
//   sse2   - (x86 / x86-64 only) every value as a 16-byte-aligned vector.
union xnn_q8_gemm_params {
  struct {
    int32_t kernel_zero_point;
    int32_t input_zero_point;
    int32_t multiplier;
    int32_t remainder_mask;
    int32_t remainder_threshold;
    uint32_t shift;
    int32_t output_min_less_zero_point;
    int32_t output_max_less_zero_point;
    int32_t output_zero_point;
  } scalar;
#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
  struct {
    int16_t kernel_zero_point;
    int16_t input_zero_point;
    int32_t multiplier;
    int32_t right_shift;
    int16_t output_zero_point;
    uint8_t output_max;
    uint8_t output_min;
  } neon;
#endif  // CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  struct {
    XNN_ALIGN(16) int16_t kernel_zero_point[8];
    XNN_ALIGN(16) int16_t input_zero_point[8];
    XNN_ALIGN(16) uint32_t multiplier[4];
    XNN_ALIGN(16) uint64_t rounding[2];
    XNN_ALIGN(16) int32_t remainder_mask[4];
    XNN_ALIGN(16) int32_t remainder_threshold[4];
    XNN_ALIGN(16) uint64_t shift[2];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) uint8_t output_max[16];
    XNN_ALIGN(16) uint8_t output_min[16];
  } sse2;
#endif  // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
};
183
// Parameters handed to q8 addition micro-kernels (see
// xnn_q8_vadd_ukernel_function): per-operand multipliers, the shift,
// zero points and output bounds.
//   scalar - 32-bit fields;
//   neon   - (ARM / ARM64 only) compact 8/16/32-bit fields;
//   sse2   - (x86 / x86-64 only) 16-byte-aligned vectors (multipliers split
//            into lo/hi 16-bit halves) followed by unvectorized copies of
//            shift, a_multiplier and b_multiplier.
union xnn_q8_add_params {
  struct {
    int32_t zero_point_product;
    uint32_t a_multiplier;
    uint32_t b_multiplier;
    uint32_t shift;
    int32_t remainder_mask;
    int32_t remainder_threshold;
    int32_t y_zero_point;
    int32_t y_max;
    int32_t y_min;
  } scalar;
#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
  struct {
    uint8_t a_zero_point;
    uint8_t b_zero_point;
    int16_t y_zero_point;
    int32_t a_multiplier;
    int32_t b_multiplier;
    int32_t right_shift;
    uint8_t y_max;
    uint8_t y_min;
  } neon;
#endif  // CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  struct {
    XNN_ALIGN(16) int32_t zero_point_product[4];
    XNN_ALIGN(16) uint16_t a_multiplier_lo[8];
    XNN_ALIGN(16) uint16_t a_multiplier_hi[8];
    XNN_ALIGN(16) uint16_t b_multiplier_lo[8];
    XNN_ALIGN(16) uint16_t b_multiplier_hi[8];
    XNN_ALIGN(16) int32_t remainder_mask[4];
    XNN_ALIGN(16) int32_t remainder_threshold[4];
    XNN_ALIGN(16) int16_t y_zero_point[8];
    XNN_ALIGN(16) uint8_t y_max[16];
    XNN_ALIGN(16) uint8_t y_min[16];
    uint32_t shift;
    uint32_t a_multiplier;
    uint32_t b_multiplier;
  } sse2;
#endif  // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
};
226
// Parameters handed to q8 average-pooling micro-kernels (see the
// xnn_q8_avgpool_* and xnn_q8_gavgpool_* function types): bias, multiplier,
// rounding/shift and output bounds.
//   scalar - 32-bit fields plus a 64-bit rounding term and a right shift;
//   neon   - (ARM / ARM64 only) stores a 64-bit *left* shift instead of a
//            right shift, and 8-bit output bounds;
//   sse2   - (x86 / x86-64 only) every value as a 16-byte-aligned vector.
union xnn_q8_avgpool_params {
  struct {
    int32_t bias;
    int32_t multiplier;
    int64_t rounding;
    uint32_t right_shift;
    int32_t output_min_less_zero_point;
    int32_t output_max_less_zero_point;
    int32_t output_zero_point;
  } scalar;
#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
  struct {
    int32_t bias;
    int32_t multiplier;
    int64_t left_shift;
    int16_t output_zero_point;
    uint8_t output_max;
    uint8_t output_min;
  } neon;
#endif  // CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  struct {
    XNN_ALIGN(16) int32_t bias[4];
    XNN_ALIGN(16) uint32_t multiplier[4];
    XNN_ALIGN(16) uint64_t rounding[2];
    XNN_ALIGN(16) uint64_t right_shift[2];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) uint8_t output_max[16];
    XNN_ALIGN(16) uint8_t output_min[16];
  } sse2;
#endif  // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
};
259
260union xnn_fp32_requantization_params {
261 struct {
262 float scale;
263 float min_less_zero_point;
264 float max_less_zero_point;
265 float magic;
266 int32_t magic_less_zero_point;
267 } scalar;
268 struct {
269 float scale;
270 float max;
271 float min;
272 float magic;
273 int32_t magic_less_zero_point;
274 } neon;
275 struct {
276 float scale;
277 int16_t zero_point;
278 uint8_t max;
279 uint8_t min;
280 } neonv8;
281 struct {
282 XNN_ALIGN(16) float scale[4];
283 XNN_ALIGN(16) int16_t zero_point[8];
284 XNN_ALIGN(16) uint8_t max[16];
285 XNN_ALIGN(16) uint8_t min[16];
286 } sse2;
287 struct {
288 XNN_ALIGN(16) float scale[4];
289 XNN_ALIGN(16) float min_less_zero_point[4];
290 XNN_ALIGN(16) float max_less_zero_point[4];
291 XNN_ALIGN(16) float magic[4];
292 XNN_ALIGN(16) int32_t magic_less_zero_point[4];
293 } psimd;
294};
295
296union xnn_precise_requantization_params {
297 struct {
298 uint32_t multiplier;
299 uint32_t rounding_lo;
300 uint32_t rounding_hi;
301 uint32_t shift_less_32;
302 int32_t min_less_zero_point;
303 int32_t max_less_zero_point;
304 int32_t zero_point;
305 } scalar;
306 struct {
307 int32_t multiplier;
308 int32_t right_shift;
309 int16_t zero_point;
310 uint8_t max;
311 uint8_t min;
312 } neon;
313 struct {
314 XNN_ALIGN(16) uint32_t multiplier[4];
315 XNN_ALIGN(16) uint64_t rounding[2];
316 XNN_ALIGN(16) uint32_t shift[4];
317 XNN_ALIGN(16) int16_t zero_point[8];
318 XNN_ALIGN(16) uint8_t max[16];
319 XNN_ALIGN(16) uint8_t min[16];
320 } sse2;
321};
322
// Parameters for the q31 requantization scheme.
//   scalar - 32-bit multiplier, remainder mask/threshold, shift and output
//            bounds stored "less zero point";
//   neon   - (ARM / ARM64 only) multiplier, right shift and 8-bit bounds;
//   sse2   - (x86 / x86-64 only) every value as a 16-byte-aligned vector.
union xnn_q31_requantization_params {
  struct {
    int32_t multiplier;
    int32_t remainder_mask;
    int32_t remainder_threshold;
    uint32_t shift;
    int32_t min_less_zero_point;
    int32_t max_less_zero_point;
    int32_t zero_point;
  } scalar;
#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
  struct {
    int32_t multiplier;
    int32_t right_shift;
    int16_t zero_point;
    uint8_t max;
    uint8_t min;
  } neon;
#endif  // CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  struct {
    XNN_ALIGN(16) uint32_t multiplier[4];
    XNN_ALIGN(16) uint64_t rounding[2];
    XNN_ALIGN(16) int32_t remainder_mask[4];
    XNN_ALIGN(16) int32_t remainder_threshold[4];
    XNN_ALIGN(16) uint64_t shift[2];
    XNN_ALIGN(16) int16_t zero_point[8];
    XNN_ALIGN(16) uint8_t max[16];
    XNN_ALIGN(16) uint8_t min[16];
  } sse2;
#endif  // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
};
355
356union xnn_requantization_params {
357 union xnn_precise_requantization_params precise;
358 union xnn_fp32_requantization_params fp32;
359 union xnn_q31_requantization_params q31;
360};
361
// Type-erased signature of a PPMM micro-kernel: operands `a`, `w`, `c` and
// `params` are passed as untyped pointers.
typedef void (*xnn_ppmm_ukernel_function)(
    size_t mr,
    size_t nc,
    size_t kc,
    const void* a,
    const void* w,
    void* c,
    size_t cm_stride,
    size_t cn_stride,
    const void* params);
372
// Signature of an f32 PPMM micro-kernel: float operands, parameters passed
// as union xnn_f32_output_params.
typedef void (*xnn_f32_ppmm_ukernel_function)(
    size_t mr,
    size_t nc,
    size_t kc,
    const float* a,
    const float* w,
    float* c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f32_output_params* params);
383
// Signature of an f16 PPMM micro-kernel: operands are untyped pointers,
// parameters are passed as struct xnn_f16_output_params.
typedef void (*xnn_f16_ppmm_ukernel_function)(
    size_t mr,
    size_t nc,
    size_t kc,
    const void* a,
    const void* w,
    void* c,
    size_t cm_stride,
    size_t cn_stride,
    const struct xnn_f16_output_params* params);
394
// Type-erased signature of a GEMM micro-kernel: `a`, `w`, `c` and `params`
// are untyped pointers. This is the type stored in struct gemm_parameters.
typedef void (*xnn_gemm_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const void* a,
    size_t a_stride,
    const void* w,
    void* c,
    size_t cm_stride,
    size_t cn_stride,
    const void* params);
406
// Signature of an f32 GEMM micro-kernel: float operands, parameters passed
// as union xnn_f32_output_params.
typedef void (*xnn_f32_gemm_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const float* a,
    size_t a_stride,
    const float* w,
    float* c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f32_output_params* params);
418
// Signature of an f32 GEMMINC micro-kernel. Identical to
// xnn_f32_gemm_ukernel_function except for the extra read-only `acc`
// pointer placed before `params`.
typedef void (*xnn_f32_gemminc_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const float* a,
    size_t a_stride,
    const float* w,
    float* c,
    size_t cm_stride,
    size_t cn_stride,
    const float* acc,
    const union xnn_f32_output_params* params);
431
// Signature of an f16 GEMM micro-kernel: operands are untyped pointers,
// parameters are passed as struct xnn_f16_output_params.
typedef void (*xnn_f16_gemm_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const void* a,
    size_t a_stride,
    const void* w,
    void* c,
    size_t cm_stride,
    size_t cn_stride,
    const struct xnn_f16_output_params* params);
443
// Signature of a q8 GEMM micro-kernel: uint8_t input and output, untyped
// weights `w`, parameters passed as union xnn_q8_gemm_params.
typedef void (*xnn_q8_gemm_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const uint8_t* a,
    size_t a_stride,
    const void* w,
    uint8_t* c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_q8_gemm_params* params);
455
// Type-erased signature of an IGEMM micro-kernel. Unlike GEMM, the input
// `a` is an array of pointers, accompanied by `a_offset` and a `zero`
// pointer. This is the type stored in struct gemm_parameters.
typedef void (*xnn_igemm_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t kc,
    size_t ks,
    const void** a,
    const void* w,
    void* c,
    size_t cm_stride,
    size_t cn_stride,
    size_t a_offset,
    const void* zero,
    const void* params);
469
// Signature of an f32 IGEMM micro-kernel: `a` is an array of float
// pointers; parameters are passed as union xnn_f32_output_params.
typedef void (*xnn_f32_igemm_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t kc,
    size_t ks,
    const float** a,
    const float* w,
    float* c,
    size_t cm_stride,
    size_t cn_stride,
    size_t a_offset,
    const float* zero,
    const union xnn_f32_output_params* params);
483
// Signature of a q8 IGEMM micro-kernel: `a` is an array of uint8_t
// pointers, weights `w` are untyped; parameters are passed as
// union xnn_q8_gemm_params.
typedef void (*xnn_q8_igemm_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t kc,
    size_t ks,
    const uint8_t** a,
    const void* w,
    uint8_t* c,
    size_t cm_stride,
    size_t cn_stride,
    size_t a_offset,
    const uint8_t* zero,
    const union xnn_q8_gemm_params* params);
497
// Type-erased signature of a CONV HWC micro-kernel. It receives the input
// extent, the [output_y_start, output_y_end) row range, untyped
// input/zero/weights/output pointers, and output height/width strides.
typedef void (*xnn_conv_hwc_ukernel_function)(
    size_t input_height,
    size_t input_width,
    size_t output_y_start,
    size_t output_y_end,
    const void* input,
    const void* zero,
    const void* weights,
    void* output,
    size_t input_padding_top,
    size_t output_channels,
    size_t output_height_stride,
    size_t output_width_stride,
    const void* params);
512
// Signature of an f32 CONV HWC micro-kernel: same argument list as
// xnn_conv_hwc_ukernel_function with float pointers and
// union xnn_f32_output_params.
typedef void (*xnn_f32_conv_hwc_ukernel_function)(
    size_t input_height,
    size_t input_width,
    size_t output_y_start,
    size_t output_y_end,
    const float* input,
    const float* zero,
    const float* weights,
    float* output,
    size_t input_padding_top,
    size_t output_channels,
    size_t output_height_stride,
    size_t output_width_stride,
    const union xnn_f32_output_params* params);
527
// Type-erased signature of a CONV micro-kernel with HWC->SpCHW layout
// conversion. Differs from xnn_conv_hwc_ukernel_function in that the last
// stride is `output_channel_stride` rather than `output_width_stride`.
// This is the type stored in struct hwc2spchw_dconv_parameters.
typedef void (*xnn_conv_hwc2spchw_ukernel_function)(
    size_t input_height,
    size_t input_width,
    size_t output_y_start,
    size_t output_y_end,
    const void* input,
    const void* zero,
    const void* weights,
    void* output,
    size_t input_padding_top,
    size_t output_channels,
    size_t output_height_stride,
    size_t output_channel_stride,
    const void* params);
542
// Signature of an f32 CONV micro-kernel with HWC->SpCHW layout conversion:
// same argument list as xnn_conv_hwc2spchw_ukernel_function with float
// pointers and union xnn_f32_output_params.
typedef void (*xnn_f32_conv_hwc2spchw_ukernel_function)(
    size_t input_height,
    size_t input_width,
    size_t output_y_start,
    size_t output_y_end,
    const float* input,
    const float* zero,
    const float* weights,
    float* output,
    size_t input_padding_top,
    size_t output_channels,
    size_t output_height_stride,
    size_t output_channel_stride,
    const union xnn_f32_output_params* params);
557
// Type-erased signature of a Sparse Matrix-Dense Matrix Multiplication
// (SpMM) micro-kernel. Besides the untyped operands it takes two index
// arrays, `dmap` (int32_t) and `nmap` (uint32_t). This is the type stored
// in struct spmm_parameters.
typedef void (*xnn_spmm_ukernel_function)(
    uint32_t m,
    uint32_t n,
    const void* a,
    const void* w,
    const int32_t* dmap,
    const uint32_t* nmap,
    void* c,
    const void* params);
567
// Signature of an f32 SpMM micro-kernel: same argument list as
// xnn_spmm_ukernel_function with float operands and
// union xnn_f32_output_params.
typedef void (*xnn_f32_spmm_ukernel_function)(
    uint32_t m,
    uint32_t n,
    const float* a,
    const float* w,
    const int32_t* dmap,
    const uint32_t* nmap,
    float* c,
    const union xnn_f32_output_params* params);
577
// Type-erased signature of a PACKX micro-kernel: reads `x` (with stride
// `x_stride`) and writes `y`; takes no params argument.
typedef void (*xnn_packx_ukernel_function)(
    size_t m,
    size_t k,
    const void* x,
    size_t x_stride,
    void* y);
584
// Signature of an x32 PACKX micro-kernel: same argument list as
// xnn_packx_ukernel_function with uint32_t elements.
typedef void (*xnn_x32_packx_ukernel_function)(
    size_t m,
    size_t k,
    const uint32_t* x,
    size_t x_stride,
    uint32_t* y);
591
// Type-erased signature of a PAD micro-kernel: reads `x`, writes `y`, each
// with its own stride, and additionally receives sizes `l` and `r` and a
// 32-bit value `c`. This is the type stored in struct pad_parameters.
// NOTE(review): `l`/`r` are presumably left/right padding amounts and `c`
// the fill value - confirm against a kernel implementation.
typedef void (*xnn_pad_ukernel_function)(
    size_t m,
    size_t n,
    size_t l,
    size_t r,
    uint32_t c,
    const void* x,
    size_t x_stride,
    void* y,
    size_t y_stride);
602
// Type-erased signature of an UNPOOL micro-kernel: takes an untyped
// `input`, a uint32_t `index` array, a 32-bit value `f`, and an array of
// output pointers.
typedef void (*xnn_unpool_ukernel_function)(
    size_t p,
    size_t c,
    uint32_t f,
    const void* input,
    const uint32_t* index,
    void** output);
610
// Signature of an x32 UNPOOL micro-kernel: same argument list as
// xnn_unpool_ukernel_function with uint32_t elements.
typedef void (*xnn_x32_unpool_ukernel_function)(
    size_t p,
    size_t c,
    uint32_t f,
    const uint32_t* input,
    const uint32_t* index,
    uint32_t** output);
618
// Type-erased signature of a ZIPC micro-kernel: size `n`, untyped input `x`
// and output `y`. Stored as x2/x3/x4 in struct zip_parameters.
typedef void (*xnn_zipc_ukernel_function)(
    size_t n,
    const void* x,
    void* y);
623
// Signature of an x8 ZIPC micro-kernel: same argument list as
// xnn_zipc_ukernel_function with uint8_t elements.
typedef void (*xnn_x8_zipc_ukernel_function)(
    size_t n,
    const uint8_t* x,
    uint8_t* y);
628
// Signature of an x32 ZIPC micro-kernel: same argument list as
// xnn_zipc_ukernel_function with uint32_t elements.
typedef void (*xnn_x32_zipc_ukernel_function)(
    size_t n,
    const uint32_t* x,
    uint32_t* y);
633
// Type-erased signature of a ZIPV micro-kernel: like ZIPC but with a second
// size argument `m`. Stored as `xm` in struct zip_parameters.
typedef void (*xnn_zipv_ukernel_function)(
    size_t n,
    size_t m,
    const void* x,
    void* y);
639
// Signature of an x8 ZIPV micro-kernel: same argument list as
// xnn_zipv_ukernel_function with uint8_t elements.
typedef void (*xnn_x8_zipv_ukernel_function)(
    size_t n,
    size_t m,
    const uint8_t* x,
    uint8_t* y);
645
// Signature of an x32 ZIPV micro-kernel: same argument list as
// xnn_zipv_ukernel_function with uint32_t elements.
typedef void (*xnn_x32_zipv_ukernel_function)(
    size_t n,
    size_t m,
    const uint32_t* x,
    uint32_t* y);
651
// Signature of an x8 LUT micro-kernel: reads uint8_t input `x` and a
// uint8_t table `t`, writes uint8_t output `y`.
typedef void (*xnn_x8_lut_ukernel_function)(
    size_t n,
    const uint8_t* x,
    const uint8_t* t,
    uint8_t* y);
657
// Type-erased signature of a DWCONV micro-kernel in SpCHW layout. Takes
// the output height and input width, untyped input/weights/output, and
// separate tuple and height strides for input and output. This is the
// type stored in struct spchw_dwconv_parameters.
typedef void (*xnn_dwconv_spchw_ukernel_function)(
    size_t output_height,
    size_t input_width,
    const void* input,
    const void* weights,
    void* output,
    size_t input_tuple_stride,
    size_t output_tuple_stride,
    size_t input_height_stride,
    size_t output_height_stride,
    const void* params);
669
// Signature of an f32 DWCONV micro-kernel in SpCHW layout: same argument
// list as xnn_dwconv_spchw_ukernel_function with float pointers and
// union xnn_f32_spchw_params.
typedef void (*xnn_f32_dwconv_spchw_ukernel_function)(
    size_t output_height,
    size_t input_width,
    const float* input,
    const float* weights,
    float* output,
    size_t input_tuple_stride,
    size_t output_tuple_stride,
    size_t input_height_stride,
    size_t output_height_stride,
    const union xnn_f32_spchw_params* params);
681
// Type-erased signature of a DWCONV "up" micro-kernel: `input` is an array
// of pointers; weights, output and params are untyped. Stored as the `up`
// member of struct dwconv_parameters.
typedef void (*xnn_dwconv_up_ukernel_function)(
    size_t channels,
    size_t output_width,
    const void** input,
    const void* weights,
    void* output,
    size_t input_stride,
    size_t output_increment,
    const void* params);
691
// Signature of an f32 DWCONV "up" micro-kernel: same argument list as
// xnn_dwconv_up_ukernel_function with float pointers and
// union xnn_f32_output_params.
typedef void (*xnn_f32_dwconv_up_ukernel_function)(
    size_t channels,
    size_t output_width,
    const float** input,
    const float* weights,
    float* output,
    size_t input_stride,
    size_t output_increment,
    const union xnn_f32_output_params* params);
701
// Signature of a q8 DWCONV "up" micro-kernel: uint8_t input pointers and
// output, untyped weights; parameters are passed as
// union xnn_q8_gemm_params (the same union the q8 GEMM kernels take).
typedef void (*xnn_q8_dwconv_up_ukernel_function)(
    size_t channels,
    size_t output_width,
    const uint8_t** input,
    const void* weights,
    uint8_t* output,
    size_t input_stride,
    size_t output_increment,
    const union xnn_q8_gemm_params* params);
711
// Type-erased signature of a DWCONV "mp" micro-kernel. Same argument list
// as xnn_dwconv_up_ukernel_function plus a writable `buffer` placed before
// `output`. Stored as the `mp` member of struct dwconv_parameters.
typedef void (*xnn_dwconv_mp_ukernel_function)(
    size_t channels,
    size_t output_width,
    const void** input,
    const void* weights,
    void* buffer,
    void* output,
    size_t input_stride,
    size_t output_increment,
    const void* params);
722
// Type-erased signature of a Global Average Pooling "up" micro-kernel:
// untyped input `x` with stride `x_stride`, a `zero` pointer, and output
// `y`. Stored as the `up` member of struct gavgpool_parameters.
typedef void (*xnn_gavgpool_up_ukernel_function)(
    size_t m,
    size_t n,
    const void* x,
    size_t x_stride,
    const void* zero,
    void* y,
    const void* params);
731
// Signature of an f32 Global Average Pooling "up" micro-kernel: float
// pointers, parameters passed as union xnn_f32_avgpool_params.
typedef void (*xnn_f32_gavgpool_up_ukernel_function)(
    size_t m,
    size_t n,
    const float* x,
    size_t x_stride,
    const float* zero,
    float* y,
    const union xnn_f32_avgpool_params* params);
740
// Generic signature of a Global Average Pooling micro-kernel in SpCHW
// layout. Note that, unlike the other type-erased signatures in this
// header, `input` and `output` are already float pointers; only `params`
// is untyped. This is the type stored in struct spchw_gavgpool_parameters.
typedef void (*xnn_gavgpool_spchw_ukernel_function)(
    size_t elements,
    size_t channels,
    const float* input,
    float* output,
    const void* params);
747
// Signature of an f32 Global Average Pooling micro-kernel in SpCHW layout;
// parameters are passed as union xnn_f32_gavgpool_params.
typedef void (*xnn_f32_gavgpool_spchw_ukernel_function)(
    size_t elements,
    size_t channels,
    const float* input,
    float* output,
    const union xnn_f32_gavgpool_params* params);
754
// Signature of a q8 Global Average Pooling "up" micro-kernel: uint8_t
// pointers, parameters passed as union xnn_q8_avgpool_params.
typedef void (*xnn_q8_gavgpool_up_ukernel_function)(
    size_t m,
    size_t n,
    const uint8_t* x,
    size_t x_stride,
    const uint8_t* zero,
    uint8_t* y,
    const union xnn_q8_avgpool_params* params);
763
// Type-erased signature of a Global Average Pooling "mp" micro-kernel.
// Same argument list as xnn_gavgpool_up_ukernel_function plus a writable
// `buffer` placed before `y`. Stored as the `mp` member of
// struct gavgpool_parameters.
typedef void (*xnn_gavgpool_mp_ukernel_function)(
    size_t m,
    size_t n,
    const void* x,
    size_t x_stride,
    const void* zero,
    void* buffer,
    void* y,
    const void* params);
773
// Signature of an f32 Global Average Pooling "mp" micro-kernel: float
// pointers including a float `buffer`; parameters are passed as
// union xnn_f32_avgpool_params.
typedef void (*xnn_f32_gavgpool_mp_ukernel_function)(
    size_t m,
    size_t n,
    const float* x,
    size_t x_stride,
    const float* zero,
    float* buffer,
    float* y,
    const union xnn_f32_avgpool_params* params);
783
// Signature of a q8 Global Average Pooling "mp" micro-kernel: uint8_t
// input/output with an int32_t `buffer`; parameters are passed as
// union xnn_q8_avgpool_params.
typedef void (*xnn_q8_gavgpool_mp_ukernel_function)(
    size_t m,
    size_t n,
    const uint8_t* x,
    size_t x_stride,
    const uint8_t* zero,
    int32_t* buffer,
    uint8_t* y,
    const union xnn_q8_avgpool_params* params);
793
// Type-erased signature of an AVGPOOL "up" micro-kernel: `x` is an array
// of input pointers, with a `zero` pointer, output `y`, and x/y
// increments. Stored as the `up` member of struct avgpool_parameters.
typedef void (*xnn_avgpool_up_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const void** x,
    const void* zero,
    void* y,
    size_t x_increment,
    size_t y_increment,
    const void* params);
804
// Signature of an f32 AVGPOOL "up" micro-kernel: float pointers,
// parameters passed as union xnn_f32_avgpool_params.
typedef void (*xnn_f32_avgpool_up_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const float** x,
    const float* zero,
    float* y,
    size_t x_increment,
    size_t y_increment,
    const union xnn_f32_avgpool_params* params);
815
// Signature of a q8 AVGPOOL "up" micro-kernel: uint8_t pointers,
// parameters passed as union xnn_q8_avgpool_params.
typedef void (*xnn_q8_avgpool_up_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const uint8_t** x,
    const uint8_t* zero,
    uint8_t* y,
    size_t x_increment,
    size_t y_increment,
    const union xnn_q8_avgpool_params* params);
826
// Type-erased signature of an AVGPOOL "mp" micro-kernel. Same argument
// list as xnn_avgpool_up_ukernel_function plus a writable `buffer` placed
// before `y`. Stored as the `mp` member of struct avgpool_parameters.
typedef void (*xnn_avgpool_mp_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const void** x,
    const void* zero,
    void* buffer,
    void* y,
    size_t x_increment,
    size_t y_increment,
    const void* params);
838
// Signature of an f32 AVGPOOL "mp" micro-kernel: float pointers including
// a float `buffer`; parameters are passed as union xnn_f32_avgpool_params.
typedef void (*xnn_f32_avgpool_mp_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const float** x,
    const float* zero,
    float* buffer,
    float* y,
    size_t x_increment,
    size_t y_increment,
    const union xnn_f32_avgpool_params* params);
850
// Signature of a q8 AVGPOOL "mp" micro-kernel: uint8_t input/output with
// an int32_t `buffer`; parameters are passed as
// union xnn_q8_avgpool_params.
typedef void (*xnn_q8_avgpool_mp_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const uint8_t** x,
    const uint8_t* zero,
    int32_t* buffer,
    uint8_t* y,
    size_t x_increment,
    size_t y_increment,
    const union xnn_q8_avgpool_params* params);
862
// Type-erased signature of a PAVGPOOL "up" micro-kernel. Same argument
// list as xnn_avgpool_up_ukernel_function plus a read-only `multiplier`
// pointer placed before `y`. Stored as the `up` member of
// struct pavgpool_parameters.
typedef void (*xnn_pavgpool_up_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const void** x,
    const void* zero,
    const void* multiplier,
    void* y,
    size_t x_increment,
    size_t y_increment,
    const void* params);
874
// Signature of an f32 PAVGPOOL "up" micro-kernel: float pointers including
// the `multiplier` array; parameters are passed as
// union xnn_f32_output_params (not the avgpool params union).
typedef void (*xnn_f32_pavgpool_up_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const float** x,
    const float* zero,
    const float* multiplier,
    float* y,
    size_t x_increment,
    size_t y_increment,
    const union xnn_f32_output_params* params);
886
// Type-erased signature of a PAVGPOOL "mp" micro-kernel. Same argument
// list as xnn_pavgpool_up_ukernel_function plus a writable `buffer` placed
// before `y`. Stored as the `mp` member of struct pavgpool_parameters.
typedef void (*xnn_pavgpool_mp_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const void** x,
    const void* zero,
    const void* multiplier,
    void* buffer,
    void* y,
    size_t x_increment,
    size_t y_increment,
    const void* params);
899
// Signature of an f32 PAVGPOOL "mp" micro-kernel: float pointers including
// `multiplier` and `buffer`; parameters are passed as
// union xnn_f32_output_params.
typedef void (*xnn_f32_pavgpool_mp_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const float** x,
    const float* zero,
    const float* multiplier,
    float* buffer,
    float* y,
    size_t x_increment,
    size_t y_increment,
    const union xnn_f32_output_params* params);
912
// Type-erased signature of a MAXPOOL micro-kernel: `x` is an array of
// input pointers, `y` the output, with x/y increments. This is the type
// stored in struct maxpool_parameters.
typedef void (*xnn_maxpool_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const void** x,
    void* y,
    size_t x_increment,
    size_t y_increment,
    const void* params);
922
// Signature of an f32 MAXPOOL micro-kernel: float pointers, parameters
// passed as union xnn_f32_output_params.
typedef void (*xnn_f32_maxpool_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const float** x,
    float* y,
    size_t x_increment,
    size_t y_increment,
    const union xnn_f32_output_params* params);
932
// Signature of a u8 MAXPOOL micro-kernel: uint8_t pointers, parameters
// passed as union xnn_u8_output_params.
typedef void (*xnn_u8_maxpool_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const uint8_t** x,
    uint8_t* y,
    size_t x_increment,
    size_t y_increment,
    const union xnn_u8_output_params* params);
942
// Type-erased signature of an ARGMAXPOOL "up" micro-kernel: like MAXPOOL,
// but with an additional uint32_t output array `i` next to the value
// output `y`. Stored as the `up` member of struct argmaxpool_parameters.
typedef void (*xnn_argmaxpool_up_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const void** x,
    void* y,
    uint32_t* i,
    size_t x_increment,
    size_t y_increment,
    const void* params);
953
// Signature of an f32 ARGMAXPOOL "up" micro-kernel: float values, uint32_t
// index output `i`; parameters are passed as union xnn_f32_output_params.
typedef void (*xnn_f32_argmaxpool_up_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const float** x,
    float* y,
    uint32_t* i,
    size_t x_increment,
    size_t y_increment,
    const union xnn_f32_output_params* params);
964
// Type-erased signature of an ARGMAXPOOL "mp" micro-kernel. Same argument
// list as xnn_argmaxpool_up_ukernel_function plus two writable buffers,
// `ab` (untyped) and `ib` (uint32_t), placed before `y`. Stored as the
// `mp` member of struct argmaxpool_parameters.
typedef void (*xnn_argmaxpool_mp_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const void** x,
    void* ab,
    uint32_t* ib,
    void* y,
    uint32_t* i,
    size_t x_increment,
    size_t y_increment,
    const void* params);
977
// Signature of an f32 ARGMAXPOOL "mp" micro-kernel: float values and float
// buffer `ab`, uint32_t buffer `ib` and index output `i`; parameters are
// passed as union xnn_f32_output_params.
typedef void (*xnn_f32_argmaxpool_mp_ukernel_function)(
    size_t n,
    size_t ks,
    size_t kc,
    const float** x,
    float* ab,
    uint32_t* ib,
    float* y,
    uint32_t* i,
    size_t x_increment,
    size_t y_increment,
    const union xnn_f32_output_params* params);
990
// Type-erased signature of a single-input, single-output vector
// micro-kernel: size `n`, input `x`, output `y`, untyped `params`.
// struct xnn_parameters uses this type for the clamp and hswish slots.
typedef void (*xnn_univector_ukernel_function)(
    size_t n,
    const void* x,
    void* y,
    const void* params);
996
// Signature of an f32 CLAMP micro-kernel: float input/output, parameters
// passed as union xnn_f32_output_params.
typedef void (*xnn_f32_clamp_ukernel_function)(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_output_params* params);
1002
// Signature of a u8 CLAMP micro-kernel: uint8_t input/output, parameters
// passed as union xnn_u8_output_params.
typedef void (*xnn_u8_clamp_ukernel_function)(
    size_t n,
    const uint8_t* x,
    uint8_t* y,
    const union xnn_u8_output_params* params);
1008
// Signature of an f32 HSWISH micro-kernel: float input/output, parameters
// passed as union xnn_f32_hswish_params.
typedef void (*xnn_f32_hswish_ukernel_function)(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_hswish_params* params);
1014
// Type-erased signature of an RMAX micro-kernel: size `n`, untyped input
// `x` and output `y`; takes no params argument.
typedef void (*xnn_rmax_ukernel_function)(
    size_t n,
    const void* x,
    void* y);
1019
// Signature of a u8 RMAX micro-kernel: same argument list as
// xnn_rmax_ukernel_function with uint8_t elements.
typedef void (*xnn_u8_rmax_ukernel_function)(
    size_t n,
    const uint8_t* x,
    uint8_t* y);
1024
// Signature of an f32 RMAX micro-kernel: same argument list as
// xnn_rmax_ukernel_function with float elements.
typedef void (*xnn_f32_rmax_ukernel_function)(
    size_t n,
    const float* x,
    float* y);
1029
// Signature of a u8 LUT32NORM micro-kernel: uint8_t input `x`, a uint32_t
// table `t`, and uint8_t output `y`.
typedef void (*xnn_u8_lut32norm_ukernel_function)(
    size_t n,
    const uint8_t* x,
    const uint32_t* t,
    uint8_t* y);
1035
// Type-erased signature of a VADD micro-kernel: two untyped inputs `a` and
// `b`, output `y`, untyped `params`. struct xnn_parameters uses this type
// for the q8 and f32 vadd slots.
typedef void (*xnn_vadd_ukernel_function)(
    size_t n,
    const void* a,
    const void* b,
    void* y,
    const void* params);
1042
// Signature of an f32 VADD micro-kernel: float operands, parameters passed
// as union xnn_f32_output_params.
typedef void (*xnn_f32_vadd_ukernel_function)(
    size_t n,
    const float* a,
    const float* b,
    float* y,
    const union xnn_f32_output_params* params);
1049
// Signature of a q8 VADD micro-kernel: uint8_t operands, parameters passed
// as union xnn_q8_add_params.
typedef void (*xnn_q8_vadd_ukernel_function)(
    size_t n,
    const uint8_t* a,
    const uint8_t* b,
    uint8_t* y,
    const union xnn_q8_add_params* params);
1056
// Type-erased signature of a VMUL micro-kernel: two untyped inputs `a` and
// `b`, output `y`, untyped `params`.
typedef void (*xnn_vmul_ukernel_function)(
    size_t n,
    const void* a,
    const void* b,
    void* y,
    const void* params);
1063
// Signature of an f32 VMUL micro-kernel: float operands, parameters passed
// as union xnn_f32_output_params.
typedef void (*xnn_f32_vmul_ukernel_function)(
    size_t n,
    const float* a,
    const float* b,
    float* y,
    const union xnn_f32_output_params* params);
1070
// Type-erased signature of a VSUB micro-kernel: two untyped inputs `a` and
// `b`, output `y`, untyped `params`.
typedef void (*xnn_vsub_ukernel_function)(
    size_t n,
    const void* a,
    const void* b,
    void* y,
    const void* params);
1077
// Signature of an f32 VSUB micro-kernel: float operands, parameters passed
// as union xnn_f32_output_params.
typedef void (*xnn_f32_vsub_ukernel_function)(
    size_t n,
    const float* a,
    const float* b,
    float* y,
    const union xnn_f32_output_params* params);
1084
// Type-erased signature of a VMULCADDC micro-kernel: strided untyped input
// `x`, weights `w`, strided output `y`. This is the type stored in
// struct vmulcaddc_parameters.
typedef void (*xnn_vmulcaddc_ukernel_function)(
    size_t m,
    size_t c,
    const void* x,
    size_t x_stride,
    const void* w,
    void* y,
    size_t y_stride,
    const void* params);
1094
// Signature of an f32 VMULCADDC micro-kernel: same argument list as
// xnn_vmulcaddc_ukernel_function with float pointers and
// union xnn_f32_output_params.
typedef void (*xnn_f32_vmulcaddc_ukernel_function)(
    size_t m,
    size_t c,
    const float* x,
    size_t x_stride,
    const float* w,
    float* y,
    size_t y_stride,
    const union xnn_f32_output_params* params);
1104
// Type-erased signature of a PRELU micro-kernel: strided untyped input
// `x`, weights `w`, strided output `y`. This is the type stored in
// struct prelu_parameters.
typedef void (*xnn_prelu_ukernel_function)(
    size_t mr,
    size_t n,
    const void* x,
    size_t x_stride,
    const void* w,
    void* y,
    size_t y_stride,
    const void* params);
1114
// Signature of an f32 PRELU micro-kernel: same argument list as
// xnn_prelu_ukernel_function with float pointers and
// union xnn_f32_output_params.
typedef void (*xnn_f32_prelu_ukernel_function)(
    size_t mr,
    size_t n,
    const float* x,
    size_t x_stride,
    const float* w,
    float* y,
    size_t y_stride,
    const union xnn_f32_output_params* params);
1124
1125
1126struct gemm_parameters {
1127 xnn_gemm_ukernel_function gemm;
1128 xnn_igemm_ukernel_function igemm;
Marat Dukhan80fc9322019-09-29 21:06:36 -07001129 // Optional GEMM and IGEMM micro-kernels with MR=1 and the same NR and KR parameters.
XNNPACK Teamb455b122019-09-27 18:10:33 -07001130 xnn_gemm_ukernel_function gemm1;
1131 xnn_igemm_ukernel_function igemm1;
1132 uint8_t mr;
1133 uint8_t nr;
1134 uint8_t log2_kr;
1135 uint8_t log2_sr;
1136};
1137
1138struct spmm_parameters {
1139 xnn_spmm_ukernel_function ukernel;
1140 // Number of M-dimension elements in a tile.
1141 // Corresponds to a block of pixels in 1x1 Convolution and a block of batch size in Fully Connected operator.
1142 uint8_t mr;
1143 // Number of N-dimension elements in a tile.
1144 // Corresponds to a block of output channels/features in 1x1 Convolution and Fully Connected operator.
1145 uint8_t nr;
1146};
1147
1148struct hwc2spchw_dconv_parameters {
1149 xnn_conv_hwc2spchw_ukernel_function ukernel_with_symm_padding;
1150 // Number of output channels in a tile.
1151 // This parameter must be passed as is to weight packing function.
1152 uint8_t output_channel_tile;
1153 // Number of output height pixels in a tile.
1154 // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call.
1155 uint8_t output_height_tile;
1156 // Number of output width pixes in a tile.
1157 uint8_t output_width_tile;
1158};
1159
1160struct spchw_dwconv_parameters {
1161 xnn_dwconv_spchw_ukernel_function ukernel;
1162 // Number of input width pixels in a tile.
1163 uint8_t input_width_tile;
1164 // Number of output width pixels in a tile.
1165 uint8_t output_width_tile;
1166 // Number of output height pixels in a tile.
1167 // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call.
1168 uint8_t output_height_tile;
1169};
1170
1171struct spchw_gavgpool_parameters {
1172 xnn_gavgpool_spchw_ukernel_function ukernel;
1173 // Number of channels in a tile.
1174 // For best efficiency, micro-kernel must process a multiple of this number of channels in each call.
1175 uint8_t channel_tile;
1176};
1177
1178struct dwconv_parameters {
1179 union {
1180 xnn_dwconv_up_ukernel_function up;
1181 xnn_dwconv_mp_ukernel_function mp;
1182 };
1183 uint8_t cr;
1184 uint8_t mr;
1185 uint8_t qr;
1186};
1187
1188struct gavgpool_parameters {
1189 xnn_gavgpool_up_ukernel_function up;
1190 xnn_gavgpool_mp_ukernel_function mp;
1191 uint8_t mr;
1192};
1193
1194struct avgpool_parameters {
1195 xnn_avgpool_up_ukernel_function up;
1196 xnn_avgpool_mp_ukernel_function mp;
1197 uint8_t mr;
1198 uint8_t qr;
1199};
1200
1201struct pavgpool_parameters {
1202 xnn_pavgpool_up_ukernel_function up;
1203 xnn_pavgpool_mp_ukernel_function mp;
1204 uint8_t mr;
1205 uint8_t qr;
1206};
1207
1208struct argmaxpool_parameters {
1209 union {
1210 xnn_argmaxpool_up_ukernel_function up;
1211 xnn_argmaxpool_mp_ukernel_function mp;
1212 };
1213 uint8_t mr;
1214 uint8_t qr;
1215};
1216
1217struct maxpool_parameters {
1218 xnn_maxpool_ukernel_function ukernel;
1219 uint8_t mr;
1220 uint8_t qr;
1221};
1222
1223struct zip_parameters {
1224 xnn_zipc_ukernel_function x2;
1225 xnn_zipc_ukernel_function x3;
1226 xnn_zipc_ukernel_function x4;
1227 xnn_zipv_ukernel_function xm;
1228};
1229
1230struct prelu_parameters {
1231 xnn_prelu_ukernel_function ukernel;
1232 uint8_t mr;
1233};
1234
1235struct pad_parameters {
1236 xnn_pad_ukernel_function ukernel;
1237 uint8_t mr;
1238};
1239
1240struct vmulcaddc_parameters {
1241 xnn_vmulcaddc_ukernel_function ukernel;
1242 uint8_t cr;
1243 uint8_t mr;
1244};
1245
// Capacities of the per-datatype micro-kernel tables declared in
// struct xnn_parameters (q8.dwconv, f32.dwconv and f32.argmaxpool).
#define XNN_MAX_Q8_DWCONV_UKERNELS 1
#define XNN_MAX_F32_DWCONV_UKERNELS 3
#define XNN_MAX_F32_ARGMAXPOOL_UKERNELS 3
1249
1250struct xnn_parameters {
1251 bool initialized;
1252 struct {
1253 struct gemm_parameters gemm;
1254 struct dwconv_parameters dwconv[XNN_MAX_Q8_DWCONV_UKERNELS];
1255 struct avgpool_parameters avgpool;
1256 struct gavgpool_parameters gavgpool;
1257 xnn_vadd_ukernel_function vadd;
1258 } q8;
1259 struct {
1260 struct maxpool_parameters maxpool;
1261 xnn_univector_ukernel_function clamp;
1262 xnn_u8_lut32norm_ukernel_function lut32norm;
1263 xnn_u8_rmax_ukernel_function rmax;
1264 } u8;
1265 struct {
1266 xnn_x8_lut_ukernel_function lut;
1267 struct zip_parameters zip;
1268 } x8;
1269 struct {
1270 struct gemm_parameters gemm;
1271 struct gemm_parameters gemm2;
1272 struct dwconv_parameters dwconv[XNN_MAX_F32_DWCONV_UKERNELS];
1273 struct avgpool_parameters avgpool;
1274 struct pavgpool_parameters pavgpool;
1275 struct gavgpool_parameters gavgpool;
1276 struct maxpool_parameters maxpool;
1277 struct argmaxpool_parameters argmaxpool[XNN_MAX_F32_ARGMAXPOOL_UKERNELS];
1278 xnn_univector_ukernel_function clamp;
1279 xnn_univector_ukernel_function hswish;
1280 struct prelu_parameters prelu;
1281 xnn_vadd_ukernel_function vadd;
1282 struct vmulcaddc_parameters vmulcaddc;
1283 // Sparse Matrix-Dense Matrix Multiplication (NR=1 block).
1284 struct spmm_parameters spmm;
1285 // Sparse Matrix-Dense Matrix Multiplication (NR=2 block).
1286 struct spmm_parameters spmm2;
1287 // Sparse Matrix-Dense Matrix Multiplication (NR=4 block).
1288 struct spmm_parameters spmm4;
1289 // Direct 3x3 stride-2 Convolution with 3 input channels and HWC->SpCHW layout conversion.
1290 struct hwc2spchw_dconv_parameters hwc2spchw_dconv3x3c3s2;
1291 // Direct 3x3 stride-1 Convolution with padding 1 on left and right in SpCHW layout.
1292 struct spchw_dwconv_parameters spchw_dwconv3x3;
1293 // Direct 3x3 stride-2 Convolution with padding 1 on left and right in SpCHW layout.
1294 struct spchw_dwconv_parameters spchw_dwconv3x3s2;
1295 // Global Average Pooling in SpCHW layout.
1296 struct spchw_gavgpool_parameters spchw_gavgpool;
1297 } f32;
1298 struct {
1299 struct pad_parameters pad;
1300 xnn_unpool_ukernel_function unpool;
1301 struct zip_parameters zip;
1302 } x32;
1303};
1304
1305extern XNN_INTERNAL struct xnn_parameters xnn_params;