blob: 8940f9a0f0dfe5952b20902bed40578f08cfaff1 [file] [log] [blame]
XNNPACK Teamb455b122019-09-27 18:10:33 -07001// Copyright (c) Facebook, Inc. and its affiliates.
2// All rights reserved.
3//
4// Copyright 2019 Google LLC
5//
6// This source code is licensed under the BSD-style license found in the
7// LICENSE file in the root directory of this source tree.
8
9#pragma once
10
11#include <stdbool.h>
12#include <stddef.h>
13#include <stdint.h>
14
15#include <pthreadpool.h>
16
17#ifdef __cplusplus
18extern "C" {
19#endif
20
21// The number of bytes XNNPACK may read beyond array bounds.
22// The caller must allocate at least this many extra bytes after the tensor data passed to XNNPACK.
23//
24// Note: XNNPACK reads, but never writes beyond array bounds.
25#define XNN_EXTRA_BYTES 16
26
27// The convolution operator represents a depthwise convolution, and uses HWGo layout for filters.
28#define XNN_CONVOLUTION_FLAG_DEPTHWISE 0x00000001
29
30// The operator assumes NHWC layout for the input, regardless of the output layout.
// The bit value (0x2) is distinct from XNN_CONVOLUTION_FLAG_DEPTHWISE (0x1).
// NOTE(review): presumably passed through the `uint32_t flags` parameter of the create
// calls; which operators honor it is not stated in this header.
31#define XNN_FLAG_INPUT_NHWC 0x00000002
32
33// Status code for any XNNPACK function call.
34enum xnn_status {
35 // The call succeeded, and all output arguments now contain valid data.
36 xnn_status_success = 0,
  // NOTE(review): the non-zero codes below are not documented in this header. The
  // descriptions are inferred from the enumerator names only; confirm each against
  // the implementation.
  // Presumably: the library was used before a successful xnn_initialize() call.
37 xnn_status_uninitialized = 1,
  // Presumably: an argument value is invalid.
38 xnn_status_invalid_parameter = 2,
  // Presumably: the operator or library is not in a state that allows the call.
39 xnn_status_invalid_state = 3,
  // Presumably: an argument value is valid but not supported.
40 xnn_status_unsupported_parameter = 4,
  // Presumably: the current hardware is not supported.
41 xnn_status_unsupported_hardware = 5,
  // Presumably: a memory allocation failed.
42 xnn_status_out_of_memory = 6,
43};
44
// Library initialization entry point. Takes no arguments and returns an xnn_status code.
// NOTE(review): presumably must succeed before any other call (see xnn_status_uninitialized);
// the exact requirement is not stated in this header.
45enum xnn_status xnn_initialize(void);
46
// Counterpart of xnn_initialize(). Takes no arguments and returns an xnn_status code.
// NOTE(review): what is released, and whether calls must be balanced, is not visible here.
47enum xnn_status xnn_deinitialize(void);
48
// Operator handle: a pointer to struct xnn_operator. The struct is only named here,
// not defined in the visible part of this header, so callers treat it as opaque.
49typedef struct xnn_operator* xnn_operator_t;
50
// Declares the create call for the "convolution2d_nhwc_q8" operator.
// Visible in the signature: per-side input padding, kernel/subsampling/dilation sizes,
// group and per-group channel counts, input/output pixel strides, a uint8_t zero point
// plus float scale for each of input, kernel and output, a uint8_t kernel with int32_t
// bias, a uint8_t output clamp range, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to convolution_op_out;
// confirm against the implementation.
51enum xnn_status xnn_create_convolution2d_nhwc_q8(
52 uint32_t input_padding_top,
53 uint32_t input_padding_right,
54 uint32_t input_padding_bottom,
55 uint32_t input_padding_left,
56 uint32_t kernel_height,
57 uint32_t kernel_width,
58 uint32_t subsampling_height,
59 uint32_t subsampling_width,
60 uint32_t dilation_height,
61 uint32_t dilation_width,
62 uint32_t groups,
63 size_t group_input_channels,
64 size_t group_output_channels,
65 size_t input_pixel_stride,
66 size_t output_pixel_stride,
67 uint8_t input_zero_point,
68 float input_scale,
69 uint8_t kernel_zero_point,
70 float kernel_scale,
71 const uint8_t* kernel,
72 const int32_t* bias,
73 uint8_t output_zero_point,
74 float output_scale,
75 uint8_t output_min,
76 uint8_t output_max,
77 uint32_t flags,
78 xnn_operator_t* convolution_op_out);
79
// Declares the setup call for the "convolution2d_nhwc_q8" operator.
// Takes the operator handle, batch size, input height/width, a read-only uint8_t input
// pointer, a writable uint8_t output pointer, and a pthreadpool_t. Returns an xnn_status code.
// NOTE(review): presumably execution happens later via xnn_run_operator(); not shown here.
80enum xnn_status xnn_setup_convolution2d_nhwc_q8(
81 xnn_operator_t convolution_op,
82 size_t batch_size,
83 size_t input_height,
84 size_t input_width,
85 const uint8_t* input,
86 uint8_t* output,
87 pthreadpool_t threadpool);
88
// Declares the create call for the "convolution2d_nhwc_f32" operator.
// Same geometry, group/channel and pixel-stride parameters as
// xnn_create_convolution2d_nhwc_q8, but with no zero-point/scale parameters:
// kernel and bias are const float*, and the output clamp range is float.
// Returns an xnn_status code. The new handle is presumably written to convolution_op_out.
89enum xnn_status xnn_create_convolution2d_nhwc_f32(
90 uint32_t input_padding_top,
91 uint32_t input_padding_right,
92 uint32_t input_padding_bottom,
93 uint32_t input_padding_left,
94 uint32_t kernel_height,
95 uint32_t kernel_width,
96 uint32_t subsampling_height,
97 uint32_t subsampling_width,
98 uint32_t dilation_height,
99 uint32_t dilation_width,
100 uint32_t groups,
101 size_t group_input_channels,
102 size_t group_output_channels,
103 size_t input_pixel_stride,
104 size_t output_pixel_stride,
105 const float* kernel,
106 const float* bias,
107 float output_min,
108 float output_max,
109 uint32_t flags,
110 xnn_operator_t* convolution_op_out);
111
// Declares the setup call for the "convolution2d_nhwc_f32" operator.
// Takes the operator handle, batch size, input height/width, a read-only float input
// pointer, a writable float output pointer, and a pthreadpool_t. Returns an xnn_status code.
112enum xnn_status xnn_setup_convolution2d_nhwc_f32(
113 xnn_operator_t convolution_op,
114 size_t batch_size,
115 size_t input_height,
116 size_t input_width,
117 const float* input,
118 float* output,
119 pthreadpool_t threadpool);
120
// Declares the create call for the "convolution2d_spnchw_f32" operator.
// Parameter list matches xnn_create_convolution2d_nhwc_f32 except that it has no
// input_pixel_stride/output_pixel_stride parameters.
// Returns an xnn_status code. The new handle is presumably written to convolution_op_out.
// NOTE(review): the meaning of the "spnchw" layout is not described in this header.
121enum xnn_status xnn_create_convolution2d_spnchw_f32(
122 uint32_t input_padding_top,
123 uint32_t input_padding_right,
124 uint32_t input_padding_bottom,
125 uint32_t input_padding_left,
126 uint32_t kernel_height,
127 uint32_t kernel_width,
128 uint32_t subsampling_height,
129 uint32_t subsampling_width,
130 uint32_t dilation_height,
131 uint32_t dilation_width,
132 uint32_t groups,
133 size_t group_input_channels,
134 size_t group_output_channels,
135 const float* kernel,
136 const float* bias,
137 float output_min,
138 float output_max,
139 uint32_t flags,
140 xnn_operator_t* convolution_op_out);
141
// Declares the setup call for the "convolution2d_spnchw_f32" operator.
// Unlike the nhwc setup calls, it also takes input_batch_stride and output_batch_stride
// in addition to batch size, input height/width, float input/output pointers and a
// pthreadpool_t. Returns an xnn_status code.
// NOTE(review): the unit of the batch strides (elements vs. bytes) is not stated here.
142enum xnn_status xnn_setup_convolution2d_spnchw_f32(
143 xnn_operator_t convolution_op,
144 size_t batch_size,
145 size_t input_batch_stride,
146 size_t output_batch_stride,
147 size_t input_height,
148 size_t input_width,
149 const float* input,
150 float* output,
151 pthreadpool_t threadpool);
152
// Declares the create call for the "deconvolution2d_nhwc_q8" operator.
// Visible in the signature: per-side *output* padding, height/width adjustment,
// kernel/stride/dilation sizes, group and per-group channel counts, pixel strides,
// a uint8_t zero point plus float scale for each of input, kernel and output, a uint8_t
// kernel with int32_t bias, a uint8_t output clamp range, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to deconvolution_op_out.
153enum xnn_status xnn_create_deconvolution2d_nhwc_q8(
154 uint32_t output_padding_top,
155 uint32_t output_padding_right,
156 uint32_t output_padding_bottom,
157 uint32_t output_padding_left,
158 uint32_t adjustment_height,
159 uint32_t adjustment_width,
160 uint32_t kernel_height,
161 uint32_t kernel_width,
162 uint32_t stride_height,
163 uint32_t stride_width,
164 uint32_t dilation_height,
165 uint32_t dilation_width,
166 uint32_t groups,
167 size_t group_input_channels,
168 size_t group_output_channels,
169 size_t input_pixel_stride,
170 size_t output_pixel_stride,
171 uint8_t input_zero_point,
172 float input_scale,
173 uint8_t kernel_zero_point,
174 float kernel_scale,
175 const uint8_t* kernel,
176 const int32_t* bias,
177 uint8_t output_zero_point,
178 float output_scale,
179 uint8_t output_min,
180 uint8_t output_max,
181 uint32_t flags,
182 xnn_operator_t* deconvolution_op_out);
183
// Declares the setup call for the "deconvolution2d_nhwc_q8" operator.
// Takes the operator handle, batch size, input height/width, a read-only uint8_t input
// pointer, a writable uint8_t output pointer, and a pthreadpool_t. Returns an xnn_status code.
184enum xnn_status xnn_setup_deconvolution2d_nhwc_q8(
185 xnn_operator_t deconvolution_op,
186 size_t batch_size,
187 size_t input_height,
188 size_t input_width,
189 const uint8_t* input,
190 uint8_t* output,
191 pthreadpool_t threadpool);
192
// Declares the create call for the "deconvolution2d_nhwc_f32" operator.
// Same geometry, group/channel and pixel-stride parameters as
// xnn_create_deconvolution2d_nhwc_q8, but with no zero-point/scale parameters:
// kernel and bias are const float*, and the output clamp range is float.
// Returns an xnn_status code. The new handle is presumably written to deconvolution_op_out.
193enum xnn_status xnn_create_deconvolution2d_nhwc_f32(
194 uint32_t output_padding_top,
195 uint32_t output_padding_right,
196 uint32_t output_padding_bottom,
197 uint32_t output_padding_left,
198 uint32_t adjustment_height,
199 uint32_t adjustment_width,
200 uint32_t kernel_height,
201 uint32_t kernel_width,
202 uint32_t stride_height,
203 uint32_t stride_width,
204 uint32_t dilation_height,
205 uint32_t dilation_width,
206 uint32_t groups,
207 size_t group_input_channels,
208 size_t group_output_channels,
209 size_t input_pixel_stride,
210 size_t output_pixel_stride,
211 const float* kernel,
212 const float* bias,
213 float output_min,
214 float output_max,
215 uint32_t flags,
216 xnn_operator_t* deconvolution_op_out);
217
// Declares the setup call for the "deconvolution2d_nhwc_f32" operator.
// Takes the operator handle, batch size, input height/width, a read-only float input
// pointer, a writable float output pointer, and a pthreadpool_t. Returns an xnn_status code.
218enum xnn_status xnn_setup_deconvolution2d_nhwc_f32(
219 xnn_operator_t deconvolution_op,
220 size_t batch_size,
221 size_t input_height,
222 size_t input_width,
223 const float* input,
224 float* output,
225 pthreadpool_t threadpool);
226
// Declares the create call for the "fully_connected_nc_q8" operator.
// Visible in the signature: input/output channel counts and strides, a uint8_t zero point
// plus float scale for each of input, kernel and output, a uint8_t kernel with int32_t
// bias, a uint8_t output clamp range, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to fully_connected_op_out.
227enum xnn_status xnn_create_fully_connected_nc_q8(
228 size_t input_channels,
229 size_t output_channels,
230 size_t input_stride,
231 size_t output_stride,
232 uint8_t input_zero_point,
233 float input_scale,
234 uint8_t kernel_zero_point,
235 float kernel_scale,
236 const uint8_t* kernel,
237 const int32_t* bias,
238 uint8_t output_zero_point,
239 float output_scale,
240 uint8_t output_min,
241 uint8_t output_max,
242 uint32_t flags,
243 xnn_operator_t* fully_connected_op_out);
244
// Declares the setup call for the "fully_connected_nc_q8" operator.
// Takes the operator handle, batch size, a read-only uint8_t input pointer, a writable
// uint8_t output pointer, and a pthreadpool_t. Returns an xnn_status code.
245enum xnn_status xnn_setup_fully_connected_nc_q8(
246 xnn_operator_t fully_connected_op,
247 size_t batch_size,
248 const uint8_t* input,
249 uint8_t* output,
250 pthreadpool_t threadpool);
251
// Declares the create call for the "fully_connected_nc_f32" operator.
// Visible in the signature: input/output channel counts and strides, const float* kernel
// and bias, a float output clamp range, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to fully_connected_op_out.
252enum xnn_status xnn_create_fully_connected_nc_f32(
253 size_t input_channels,
254 size_t output_channels,
255 size_t input_stride,
256 size_t output_stride,
257 const float* kernel,
258 const float* bias,
259 float output_min,
260 float output_max,
261 uint32_t flags,
262 xnn_operator_t* fully_connected_op_out);
263
// Declares the setup call for the "fully_connected_nc_f32" operator.
// Takes the operator handle, batch size, a read-only float input pointer, a writable
// float output pointer, and a pthreadpool_t. Returns an xnn_status code.
264enum xnn_status xnn_setup_fully_connected_nc_f32(
265 xnn_operator_t fully_connected_op,
266 size_t batch_size,
267 const float* input,
268 float* output,
269 pthreadpool_t threadpool);
270
// Declares the create call for the "global_average_pooling_nwc_q8" operator.
// Visible in the signature: channel count, input/output strides, a uint8_t zero point plus
// float scale for input and output, a uint8_t output clamp range, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to
// global_average_pooling_op_out.
271enum xnn_status xnn_create_global_average_pooling_nwc_q8(
272 size_t channels,
273 size_t input_stride,
274 size_t output_stride,
275 uint8_t input_zero_point,
276 float input_scale,
277 uint8_t output_zero_point,
278 float output_scale,
279 uint8_t output_min,
280 uint8_t output_max,
281 uint32_t flags,
282 xnn_operator_t* global_average_pooling_op_out);
283
// Declares the setup call for the "global_average_pooling_nwc_q8" operator.
// Takes the operator handle, batch size, a single spatial extent (width), a read-only
// uint8_t input pointer, a writable uint8_t output pointer, and a pthreadpool_t.
// Returns an xnn_status code.
284enum xnn_status xnn_setup_global_average_pooling_nwc_q8(
285 xnn_operator_t global_average_pooling_op,
286 size_t batch_size,
287 size_t width,
288 const uint8_t* input,
289 uint8_t* output,
290 pthreadpool_t threadpool);
291
// Declares the create call for the "global_average_pooling_nwc_f32" operator.
// Visible in the signature: channel count, input/output strides, a float output clamp
// range, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to
// global_average_pooling_op_out.
292enum xnn_status xnn_create_global_average_pooling_nwc_f32(
293 size_t channels,
294 size_t input_stride,
295 size_t output_stride,
296 float output_min,
297 float output_max,
298 uint32_t flags,
299 xnn_operator_t* global_average_pooling_op_out);
300
// Declares the setup call for the "global_average_pooling_nwc_f32" operator.
// Takes the operator handle, batch size, a single spatial extent (width), a read-only
// float input pointer, a writable float output pointer, and a pthreadpool_t.
// Returns an xnn_status code.
301enum xnn_status xnn_setup_global_average_pooling_nwc_f32(
302 xnn_operator_t global_average_pooling_op,
303 size_t batch_size,
304 size_t width,
305 const float* input,
306 float* output,
307 pthreadpool_t threadpool);
308
// Declares the create call for the "global_average_pooling_spnchw_f32" operator.
// Unlike the nwc variants it takes no stride parameters: only the channel count, a float
// output clamp range, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to
// global_average_pooling_op_out.
309enum xnn_status xnn_create_global_average_pooling_spnchw_f32(
310 size_t channels,
311 float output_min,
312 float output_max,
313 uint32_t flags,
314 xnn_operator_t* global_average_pooling_op_out);
315
// Declares the setup call for the "global_average_pooling_spnchw_f32" operator.
// Takes the operator handle, batch size, both height and width, a read-only float input
// pointer, a writable float output pointer, and a pthreadpool_t. Returns an xnn_status code.
316enum xnn_status xnn_setup_global_average_pooling_spnchw_f32(
317 xnn_operator_t global_average_pooling_op,
318 size_t batch_size,
319 size_t height,
320 size_t width,
321 const float* input,
322 float* output,
323 pthreadpool_t threadpool);
324
// Declares the create call for the "average_pooling2d_nhwc_q8" operator.
// Visible in the signature: per-side input padding, pooling window and stride sizes,
// channel count, pixel strides, a uint8_t zero point plus float scale for input and
// output, a uint8_t output clamp range, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to average_pooling_op_out.
325enum xnn_status xnn_create_average_pooling2d_nhwc_q8(
326 uint32_t input_padding_top,
327 uint32_t input_padding_right,
328 uint32_t input_padding_bottom,
329 uint32_t input_padding_left,
330 uint32_t pooling_height,
331 uint32_t pooling_width,
332 uint32_t stride_height,
333 uint32_t stride_width,
334 size_t channels,
335 size_t input_pixel_stride,
336 size_t output_pixel_stride,
337 uint8_t input_zero_point,
338 float input_scale,
339 uint8_t output_zero_point,
340 float output_scale,
341 uint8_t output_min,
342 uint8_t output_max,
343 uint32_t flags,
344 xnn_operator_t* average_pooling_op_out);
345
// Declares the setup call for the "average_pooling2d_nhwc_q8" operator.
// Takes the operator handle, batch size, input height/width, a read-only uint8_t input
// pointer, a writable uint8_t output pointer, and a pthreadpool_t. Returns an xnn_status code.
346enum xnn_status xnn_setup_average_pooling2d_nhwc_q8(
347 xnn_operator_t average_pooling_op,
348 size_t batch_size,
349 size_t input_height,
350 size_t input_width,
351 const uint8_t* input,
352 uint8_t* output,
353 pthreadpool_t threadpool);
354
// Declares the create call for the "average_pooling2d_nhwc_f32" operator.
// Visible in the signature: per-side input padding, pooling window and stride sizes,
// channel count, pixel strides, a float output clamp range, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to average_pooling_op_out.
355enum xnn_status xnn_create_average_pooling2d_nhwc_f32(
356 uint32_t input_padding_top,
357 uint32_t input_padding_right,
358 uint32_t input_padding_bottom,
359 uint32_t input_padding_left,
360 uint32_t pooling_height,
361 uint32_t pooling_width,
362 uint32_t stride_height,
363 uint32_t stride_width,
364 size_t channels,
365 size_t input_pixel_stride,
366 size_t output_pixel_stride,
367 float output_min,
368 float output_max,
369 uint32_t flags,
370 xnn_operator_t* average_pooling_op_out);
371
// Declares the setup call for the "average_pooling2d_nhwc_f32" operator.
// Takes the operator handle, batch size, input height/width, a read-only float input
// pointer, a writable float output pointer, and a pthreadpool_t. Returns an xnn_status code.
372enum xnn_status xnn_setup_average_pooling2d_nhwc_f32(
373 xnn_operator_t average_pooling_op,
374 size_t batch_size,
375 size_t input_height,
376 size_t input_width,
377 const float* input,
378 float* output,
379 pthreadpool_t threadpool);
380
// Declares the create call for the "max_pooling2d_nhwc_u8" operator.
// Visible in the signature: per-side input padding, pooling window, stride and dilation
// sizes, channel count, pixel strides, a uint8_t output clamp range, and a flags word.
// No zero-point/scale parameters are taken.
// Returns an xnn_status code. The new handle is presumably written to max_pooling_op_out.
381enum xnn_status xnn_create_max_pooling2d_nhwc_u8(
382 uint32_t input_padding_top,
383 uint32_t input_padding_right,
384 uint32_t input_padding_bottom,
385 uint32_t input_padding_left,
386 uint32_t pooling_height,
387 uint32_t pooling_width,
388 uint32_t stride_height,
389 uint32_t stride_width,
390 uint32_t dilation_height,
391 uint32_t dilation_width,
392 size_t channels,
393 size_t input_pixel_stride,
394 size_t output_pixel_stride,
395 uint8_t output_min,
396 uint8_t output_max,
397 uint32_t flags,
398 xnn_operator_t* max_pooling_op_out);
399
// Declares the setup call for the "max_pooling2d_nhwc_u8" operator.
// Takes the operator handle, batch size, input height/width, a read-only uint8_t input
// pointer, a writable uint8_t output pointer, and a pthreadpool_t. Returns an xnn_status code.
400enum xnn_status xnn_setup_max_pooling2d_nhwc_u8(
401 xnn_operator_t max_pooling_op,
402 size_t batch_size,
403 size_t input_height,
404 size_t input_width,
405 const uint8_t* input,
406 uint8_t* output,
407 pthreadpool_t threadpool);
408
// Declares the create call for the "max_pooling2d_nhwc_f32" operator.
// Same parameter list as xnn_create_max_pooling2d_nhwc_u8 except that the output clamp
// range is float rather than uint8_t.
// Returns an xnn_status code. The new handle is presumably written to max_pooling_op_out.
409enum xnn_status xnn_create_max_pooling2d_nhwc_f32(
410 uint32_t input_padding_top,
411 uint32_t input_padding_right,
412 uint32_t input_padding_bottom,
413 uint32_t input_padding_left,
414 uint32_t pooling_height,
415 uint32_t pooling_width,
416 uint32_t stride_height,
417 uint32_t stride_width,
418 uint32_t dilation_height,
419 uint32_t dilation_width,
420 size_t channels,
421 size_t input_pixel_stride,
422 size_t output_pixel_stride,
423 float output_min,
424 float output_max,
425 uint32_t flags,
426 xnn_operator_t* max_pooling_op_out);
427
// Declares the setup call for the "max_pooling2d_nhwc_f32" operator.
// Takes the operator handle, batch size, input height/width, a read-only float input
// pointer, a writable float output pointer, and a pthreadpool_t. Returns an xnn_status code.
428enum xnn_status xnn_setup_max_pooling2d_nhwc_f32(
429 xnn_operator_t max_pooling_op,
430 size_t batch_size,
431 size_t input_height,
432 size_t input_width,
433 const float* input,
434 float* output,
435 pthreadpool_t threadpool);
436
// Declares the create call for the "argmax_pooling2d_nhwc_f32" operator.
// Visible in the signature: per-side input padding, pooling window size (no separate
// stride or dilation parameters), channel count, pixel strides, a float output clamp
// range, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to argmax_pooling_op_out.
437enum xnn_status xnn_create_argmax_pooling2d_nhwc_f32(
438 uint32_t input_padding_top,
439 uint32_t input_padding_right,
440 uint32_t input_padding_bottom,
441 uint32_t input_padding_left,
442 uint32_t pooling_height,
443 uint32_t pooling_width,
444 size_t channels,
445 size_t input_pixel_stride,
446 size_t output_pixel_stride,
447 float output_min,
448 float output_max,
449 uint32_t flags,
450 xnn_operator_t* argmax_pooling_op_out);
451
// Declares the setup call for the "argmax_pooling2d_nhwc_f32" operator.
// Besides the usual handle, batch size, input height/width, float input/output pointers
// and pthreadpool_t, it takes a writable uint32_t* index buffer.
// Returns an xnn_status code.
// NOTE(review): index is non-const, so it is presumably a second output; its size and
// encoding are not stated in this header.
452enum xnn_status xnn_setup_argmax_pooling2d_nhwc_f32(
453 xnn_operator_t argmax_pooling_op,
454 size_t batch_size,
455 size_t input_height,
456 size_t input_width,
457 const float* input,
458 float* output,
459 uint32_t* index,
460 pthreadpool_t threadpool);
461
// Declares the create call for the "unpooling2d_nhwc_x32" operator.
// Visible in the signature: per-side input padding, pooling window size, channel count,
// pixel strides, and a flags word. No clamp range is taken.
// Returns an xnn_status code. The new handle is presumably written to unpooling_op_out.
462enum xnn_status xnn_create_unpooling2d_nhwc_x32(
463 uint32_t input_padding_top,
464 uint32_t input_padding_right,
465 uint32_t input_padding_bottom,
466 uint32_t input_padding_left,
467 uint32_t pooling_height,
468 uint32_t pooling_width,
469 size_t channels,
470 size_t input_pixel_stride,
471 size_t output_pixel_stride,
472 uint32_t flags,
473 xnn_operator_t* unpooling_op_out);
474
// Declares the setup call for the "unpooling2d_nhwc_x32" operator.
// Takes the operator handle, batch size, input height/width, an untyped read-only input
// pointer, a read-only uint32_t index pointer, an untyped writable output pointer, and a
// pthreadpool_t. Returns an xnn_status code.
// NOTE(review): input/output are void*; the "x32" suffix presumably means 32-bit
// elements of any type. Confirm against the implementation.
475enum xnn_status xnn_setup_unpooling2d_nhwc_x32(
476 xnn_operator_t unpooling_op,
477 size_t batch_size,
478 size_t input_height,
479 size_t input_width,
480 const void* input,
481 const uint32_t* index,
482 void* output,
483 pthreadpool_t threadpool);
484
// Declares the create call for the "channel_shuffle_nc_x8" operator.
// Visible in the signature: group count, channels per group, input/output strides, and a
// flags word.
// Returns an xnn_status code. The new handle is presumably written to channel_shuffle_op_out.
485enum xnn_status xnn_create_channel_shuffle_nc_x8(
486 size_t groups,
487 size_t group_channels,
488 size_t input_stride,
489 size_t output_stride,
490 uint32_t flags,
491 xnn_operator_t* channel_shuffle_op_out);
492
// Declares the setup call for the "channel_shuffle_nc_x8" operator.
// Takes the operator handle, batch size, an untyped read-only input pointer, an untyped
// writable output pointer, and a pthreadpool_t. Returns an xnn_status code.
// NOTE(review): the "x8" suffix presumably means 8-bit elements; confirm.
493enum xnn_status xnn_setup_channel_shuffle_nc_x8(
494 xnn_operator_t channel_shuffle_op,
495 size_t batch_size,
496 const void* input,
497 void* output,
498 pthreadpool_t threadpool);
499
// Declares the create call for the "channel_shuffle_nc_x32" operator.
// Parameter list is identical to xnn_create_channel_shuffle_nc_x8.
// Returns an xnn_status code. The new handle is presumably written to channel_shuffle_op_out.
500enum xnn_status xnn_create_channel_shuffle_nc_x32(
501 size_t groups,
502 size_t group_channels,
503 size_t input_stride,
504 size_t output_stride,
505 uint32_t flags,
506 xnn_operator_t* channel_shuffle_op_out);
507
// Declares the setup call for the "channel_shuffle_nc_x32" operator.
// Takes the operator handle, batch size, an untyped read-only input pointer, an untyped
// writable output pointer, and a pthreadpool_t. Returns an xnn_status code.
// NOTE(review): the "x32" suffix presumably means 32-bit elements; confirm.
508enum xnn_status xnn_setup_channel_shuffle_nc_x32(
509 xnn_operator_t channel_shuffle_op,
510 size_t batch_size,
511 const void* input,
512 void* output,
513 pthreadpool_t threadpool);
514
// Declares the create call for the "add_nc_q8" operator.
// Visible in the signature: channel count, separate strides for operands a and b and for
// the sum, a uint8_t zero point plus float scale for each of a, b and sum, a uint8_t
// clamp range for the sum, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to add_op_out.
515enum xnn_status xnn_create_add_nc_q8(
516 size_t channels,
517 size_t a_stride,
518 size_t b_stride,
519 size_t sum_stride,
520 uint8_t a_zero_point,
521 float a_scale,
522 uint8_t b_zero_point,
523 float b_scale,
524 uint8_t sum_zero_point,
525 float sum_scale,
526 uint8_t sum_min,
527 uint8_t sum_max,
528 uint32_t flags,
529 xnn_operator_t* add_op_out);
530
// Declares the setup call for the "add_nc_q8" operator.
// Takes the operator handle, batch size, two read-only uint8_t operand pointers (a, b),
// a writable uint8_t sum pointer, and a pthreadpool_t. Returns an xnn_status code.
531enum xnn_status xnn_setup_add_nc_q8(
532 xnn_operator_t add_op,
533 size_t batch_size,
534 const uint8_t* a,
535 const uint8_t* b,
536 uint8_t* sum,
537 pthreadpool_t threadpool);
538
// Declares the create call for the "add_nc_f32" operator.
// Visible in the signature: channel count, separate strides for operands a and b and for
// the sum, a float clamp range for the sum, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to add_op_out.
539enum xnn_status xnn_create_add_nc_f32(
540 size_t channels,
541 size_t a_stride,
542 size_t b_stride,
543 size_t sum_stride,
544 float sum_min,
545 float sum_max,
546 uint32_t flags,
547 xnn_operator_t* add_op_out);
548
// Declares the setup call for the "add_nc_f32" operator.
// Takes the operator handle, batch size, two read-only float operand pointers (a, b),
// a writable float sum pointer, and a pthreadpool_t. Returns an xnn_status code.
549enum xnn_status xnn_setup_add_nc_f32(
550 xnn_operator_t add_op,
551 size_t batch_size,
552 const float* a,
553 const float* b,
554 float* sum,
555 pthreadpool_t threadpool);
556
// Declares the create call for the "channel_pad_nc_x32" operator.
// Visible in the signature: input channel count, channel counts to pad before and after,
// input/output strides, an untyped read-only pad_value pointer, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to channel_pad_op_out.
// NOTE(review): pad_value presumably points at a single 32-bit element; whether it is
// copied at create time or must outlive the operator is not stated here.
557enum xnn_status xnn_create_channel_pad_nc_x32(
558 size_t input_channels,
559 size_t pad_before_channels,
560 size_t pad_after_channels,
561 size_t input_stride,
562 size_t output_stride,
563 const void* pad_value,
564 uint32_t flags,
565 xnn_operator_t* channel_pad_op_out);
566
// Declares the setup call for the "channel_pad_nc_x32" operator.
// Takes the operator handle, batch size, an untyped read-only input pointer, an untyped
// writable output pointer, and a pthreadpool_t. Returns an xnn_status code.
567enum xnn_status xnn_setup_channel_pad_nc_x32(
568 xnn_operator_t channel_pad_op,
569 size_t batch_size,
570 const void* input,
571 void* output,
572 pthreadpool_t threadpool);
573
// Declares the create call for the "clamp_nc_u8" operator.
// Visible in the signature: channel count, input/output strides, a uint8_t clamp range
// (output_min, output_max), and a flags word.
// Returns an xnn_status code. The new handle is presumably written to clamp_op_out.
574enum xnn_status xnn_create_clamp_nc_u8(
575 size_t channels,
576 size_t input_stride,
577 size_t output_stride,
578 uint8_t output_min,
579 uint8_t output_max,
580 uint32_t flags,
581 xnn_operator_t* clamp_op_out);
582
// Declares the setup call for the "clamp_nc_u8" operator.
// Takes the operator handle, batch size, a read-only uint8_t input pointer, a writable
// uint8_t output pointer, and a pthreadpool_t. Returns an xnn_status code.
583enum xnn_status xnn_setup_clamp_nc_u8(
584 xnn_operator_t clamp_op,
585 size_t batch_size,
586 const uint8_t* input,
587 uint8_t* output,
588 pthreadpool_t threadpool);
589
// Declares the create call for the "clamp_nc_f32" operator.
// Visible in the signature: channel count, input/output strides, a float clamp range
// (output_min, output_max), and a flags word.
// Returns an xnn_status code. The new handle is presumably written to clamp_op_out.
590enum xnn_status xnn_create_clamp_nc_f32(
591 size_t channels,
592 size_t input_stride,
593 size_t output_stride,
594 float output_min,
595 float output_max,
596 uint32_t flags,
597 xnn_operator_t* clamp_op_out);
598
// Declares the setup call for the "clamp_nc_f32" operator.
// Takes the operator handle, batch size, a read-only float input pointer, a writable
// float output pointer, and a pthreadpool_t. Returns an xnn_status code.
599enum xnn_status xnn_setup_clamp_nc_f32(
600 xnn_operator_t clamp_op,
601 size_t batch_size,
602 const float* input,
603 float* output,
604 pthreadpool_t threadpool);
605
// Declares the create call for the "hardswish_nc_f32" operator.
// Visible in the signature: channel count, input/output strides, and a flags word.
// No clamp range or other numeric parameters are taken.
// Returns an xnn_status code. The new handle is presumably written to hardswish_op_out.
606enum xnn_status xnn_create_hardswish_nc_f32(
607 size_t channels,
608 size_t input_stride,
609 size_t output_stride,
610 uint32_t flags,
611 xnn_operator_t* hardswish_op_out);
612
// Declares the setup call for the "hardswish_nc_f32" operator.
// Takes the operator handle, batch size, a read-only float input pointer, a writable
// float output pointer, and a pthreadpool_t. Returns an xnn_status code.
613enum xnn_status xnn_setup_hardswish_nc_f32(
614 xnn_operator_t hardswish_op,
615 size_t batch_size,
616 const float* input,
617 float* output,
618 pthreadpool_t threadpool);
619
// Declares the create call for the "sigmoid_nc_q8" operator.
// Visible in the signature: channel count, input/output strides, a uint8_t zero point plus
// float scale for input and output, a uint8_t output clamp range, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to sigmoid_op_out.
620enum xnn_status xnn_create_sigmoid_nc_q8(
621 size_t channels,
622 size_t input_stride,
623 size_t output_stride,
624 uint8_t input_zero_point,
625 float input_scale,
626 uint8_t output_zero_point,
627 float output_scale,
628 uint8_t output_min,
629 uint8_t output_max,
630 uint32_t flags,
631 xnn_operator_t* sigmoid_op_out);
632
// Declares the setup call for the "sigmoid_nc_q8" operator.
// Takes the operator handle, batch size, a read-only uint8_t input pointer, a writable
// uint8_t output pointer, and a pthreadpool_t. Returns an xnn_status code.
633enum xnn_status xnn_setup_sigmoid_nc_q8(
634 xnn_operator_t sigmoid_op,
635 size_t batch_size,
636 const uint8_t* input,
637 uint8_t* output,
638 pthreadpool_t threadpool);
639
// Declares the create call for the "leaky_relu_nc_q8" operator.
// Visible in the signature: channel count, input/output strides, a single scalar float
// negative_slope, a uint8_t zero point plus float scale for input and output, a uint8_t
// output clamp range, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to leaky_relu_op_out.
640enum xnn_status xnn_create_leaky_relu_nc_q8(
641 size_t channels,
642 size_t input_stride,
643 size_t output_stride,
644 float negative_slope,
645 uint8_t input_zero_point,
646 float input_scale,
647 uint8_t output_zero_point,
648 float output_scale,
649 uint8_t output_min,
650 uint8_t output_max,
651 uint32_t flags,
652 xnn_operator_t* leaky_relu_op_out);
653
// Declares the setup call for the "leaky_relu_nc_q8" operator.
// Takes the operator handle, batch size, a read-only uint8_t input pointer, a writable
// uint8_t output pointer, and a pthreadpool_t. Returns an xnn_status code.
654enum xnn_status xnn_setup_leaky_relu_nc_q8(
655 xnn_operator_t leaky_relu_op,
656 size_t batch_size,
657 const uint8_t* input,
658 uint8_t* output,
659 pthreadpool_t threadpool);
660
// Declares the create call for the "prelu_nc_f32" operator.
// Visible in the signature: channel count, input/output strides, a read-only float
// pointer negative_slope, a float output clamp range, and a flags word.
// Returns an xnn_status code. The new handle is presumably written to prelu_op_out.
// NOTE(review): negative_slope is a pointer here (a scalar in the leaky_relu create
// call), so it presumably holds one slope per channel. Confirm the expected length.
661enum xnn_status xnn_create_prelu_nc_f32(
662 size_t channels,
663 size_t input_stride,
664 size_t output_stride,
665 const float* negative_slope,
666 float output_min,
667 float output_max,
668 uint32_t flags,
669 xnn_operator_t* prelu_op_out);
670
// Declares the setup call for the "prelu_nc_f32" operator.
// Takes the operator handle, batch size, a read-only float input pointer, a writable
// float output pointer, and a pthreadpool_t. Returns an xnn_status code.
671enum xnn_status xnn_setup_prelu_nc_f32(
672 xnn_operator_t prelu_op,
673 size_t batch_size,
674 const float* input,
675 float* output,
676 pthreadpool_t threadpool);
677
// Declares the create call for the "softargmax_nc_q8" operator.
// Visible in the signature: channel count, input/output strides, a float input scale,
// a uint8_t output zero point with float output scale, and a flags word. Unlike the
// other q8 create calls, it takes no input zero point and no output clamp range.
// Returns an xnn_status code. The new handle is presumably written to softargmax_op_out.
678enum xnn_status xnn_create_softargmax_nc_q8(
679 size_t channels,
680 size_t input_stride,
681 size_t output_stride,
682 float input_scale,
683 uint8_t output_zero_point,
684 float output_scale,
685 uint32_t flags,
686 xnn_operator_t* softargmax_op_out);
687
// Declares the setup call for the "softargmax_nc_q8" operator.
// Takes the operator handle, batch size, a read-only uint8_t input pointer, a writable
// uint8_t output pointer, and a pthreadpool_t. Returns an xnn_status code.
688enum xnn_status xnn_setup_softargmax_nc_q8(
689 xnn_operator_t softargmax_op,
690 size_t batch_size,
691 const uint8_t* input,
692 uint8_t* output,
693 pthreadpool_t threadpool);
694
// Declares the call that runs an operator: takes any operator handle and a
// pthreadpool_t, and returns an xnn_status code.
// NOTE(review): presumably the operator must first be configured with its matching
// xnn_setup_* call; whether a null threadpool is accepted is not stated here.
695enum xnn_status xnn_run_operator(
696 xnn_operator_t op,
697 pthreadpool_t threadpool);
698
// Declares the call that deletes an operator: takes an operator handle and returns an
// xnn_status code.
// NOTE(review): presumably releases a handle obtained from an xnn_create_* call;
// behavior for a null handle is not stated in this header.
699enum xnn_status xnn_delete_operator(
700 xnn_operator_t op);
701
702#ifdef __cplusplus
703} // extern "C"
704#endif