blob: 6ecf25c0e19fbc549f3c7afabf2b33c72e896d3b [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8
9#pragma once
10
11#include <stddef.h>
12#include <stdint.h>
13
14#include <pthreadpool.h>
15
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070016#include <xnnpack/params.h>
XNNPACK Teamb455b122019-09-27 18:10:33 -070017#include <xnnpack/compute.h>
18
19
// Micro-kernel category, stored in the `type` field of struct xnn_ukernel.
//
// Enumerators are numbered implicitly in declaration order starting at 0.
// Inserting or reordering entries therefore renumbers every later value.
enum xnn_ukernel_type {
  // Zero value; presumably means "no micro-kernel selected" (TODO confirm).
  xnn_ukernel_type_none = 0,

  xnn_ukernel_type_add,
  xnn_ukernel_type_argmax_pooling,
  xnn_ukernel_type_average_pooling,
  xnn_ukernel_type_binary_elementwise,
  xnn_ukernel_type_channel_shuffle,
  xnn_ukernel_type_clamp,
  xnn_ukernel_type_dconv2d_hwc2spchw,
  xnn_ukernel_type_dwconv,
  xnn_ukernel_type_gemm,
  xnn_ukernel_type_global_average_pooling,
  xnn_ukernel_type_hswish,
  xnn_ukernel_type_igemm,
  xnn_ukernel_type_lut,
  xnn_ukernel_type_max_pooling,
  xnn_ukernel_type_pad,
  xnn_ukernel_type_pixelwise_average_pooling,
  xnn_ukernel_type_prelu,
  xnn_ukernel_type_sigmoid,
  xnn_ukernel_type_softmax,
  xnn_ukernel_type_spmm,
  xnn_ukernel_type_subconv2d,
  xnn_ukernel_type_unpooling,
  xnn_ukernel_type_vmulcaddc,
};
46
// Operator kind, stored in the `type` field of struct xnn_operator.
//
// Names appear to follow the pattern <operation>_<layout>_<datatype>
// (e.g. nc / nd / nhwc / nchw / nwc / ncw layouts; f32 / q8 / u8 / x8 / x32
// element types). This is inferred from the names only; confirm before relying on it.
// Values are implicit and sequential from 0, so the declaration order is significant.
enum xnn_operator_type {
  xnn_operator_type_none = 0,

  xnn_operator_type_add_nc_f32,
  xnn_operator_type_add_nd_f32,
  xnn_operator_type_add_nc_q8,
  xnn_operator_type_argmax_pooling_nhwc_f32,
  xnn_operator_type_average_pooling_nhwc_f32,
  xnn_operator_type_average_pooling_nhwc_q8,
  xnn_operator_type_channel_pad_nc_x32,
  xnn_operator_type_channel_shuffle_nc_x32,
  xnn_operator_type_channel_shuffle_nc_x8,
  xnn_operator_type_clamp_nc_f32,
  xnn_operator_type_clamp_nc_u8,
  xnn_operator_type_convolution_nhwc_f32,
  xnn_operator_type_convolution_nhwc_q8,
  xnn_operator_type_convolution_nchw_f32,
  xnn_operator_type_deconvolution_nhwc_f32,
  xnn_operator_type_deconvolution_nhwc_q8,
  xnn_operator_type_divide_nd_f32,
  xnn_operator_type_fully_connected_nc_f32,
  xnn_operator_type_fully_connected_nc_q8,
  xnn_operator_type_global_average_pooling_nwc_f32,
  xnn_operator_type_global_average_pooling_nwc_q8,
  xnn_operator_type_global_average_pooling_ncw_f32,
  xnn_operator_type_hardswish_nc_f32,
  xnn_operator_type_leaky_relu_nc_q8,
  xnn_operator_type_max_pooling_nhwc_f32,
  xnn_operator_type_max_pooling_nhwc_u8,
  xnn_operator_type_maximum_nd_f32,
  xnn_operator_type_minimum_nd_f32,
  xnn_operator_type_multiply_nd_f32,
  xnn_operator_type_prelu_nc_f32,
  xnn_operator_type_resize_bilinear_nhwc_f32,
  xnn_operator_type_sigmoid_nc_f32,
  xnn_operator_type_sigmoid_nc_q8,
  xnn_operator_type_softmax_nc_f32,
  xnn_operator_type_softmax_nc_q8,
  xnn_operator_type_subtract_nd_f32,
  xnn_operator_type_unpooling_nhwc_x32,
};
87
88struct xnn_ukernel_dconv2d {
89 union {
90 xnn_conv_hwc2spchw_ukernel_function hwc2spchw_function;
91 xnn_conv_hwc_ukernel_function hwc_function;
92 };
93 uint8_t output_height_tile;
94 uint8_t output_channel_tile;
95};
96
97struct xnn_ukernel_dwconv {
98 union {
99 xnn_dwconv_up_ukernel_function unipass_function;
100 xnn_dwconv_mp_ukernel_function multipass_function;
101 };
102 uint8_t mr;
103 uint8_t qr;
104};
105
106// Direct 2D Depthwise Convolution
107struct xnn_ukernel_dwconv2d {
108 union {
109 xnn_dwconv_spchw_ukernel_function spchw_function;
110 };
111 uint8_t input_width_tile;
112 uint8_t output_width_tile;
113};
114
115struct xnn_ukernel_gemm {
Marat Dukhan05702cf2020-03-26 15:41:33 -0700116 struct xnn_hmp_gemm_ukernel general_case;
117 struct xnn_hmp_gemm_ukernel mr1_case;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700118 uint8_t mr;
119 uint8_t nr;
120 uint8_t kr;
121};
122
123struct xnn_ukernel_igemm {
Marat Dukhan05702cf2020-03-26 15:41:33 -0700124 struct xnn_hmp_igemm_ukernel general_case;
125 struct xnn_hmp_igemm_ukernel mr1_case;
126 struct xnn_hmp_gemm_ukernel gemm_case;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700127 uint8_t mr;
128 uint8_t nr;
129 uint8_t kr;
130};
131
132struct xnn_ukernel_spmm {
133 xnn_spmm_ukernel_function function;
134 uint8_t mr;
135};
136
137struct xnn_ukernel_vmulcaddc {
138 xnn_vmulcaddc_ukernel_function function;
139 uint8_t mr;
140};
141
142struct xnn_ukernel {
143 enum xnn_ukernel_type type;
144 union {
145 struct xnn_ukernel_dconv2d dconv2d;
146 struct xnn_ukernel_dwconv dwconv;
147 struct xnn_ukernel_dwconv2d dwconv2d;
148 struct xnn_ukernel_gemm gemm;
149 struct xnn_ukernel_igemm igemm;
150 struct xnn_ukernel_spmm spmm;
151 struct xnn_ukernel_vmulcaddc vmulcaddc;
152 };
153};
154
// Execution state of an operator, stored in struct xnn_operator::state.
// The zero value is `invalid`, so a zero-initialized operator starts out invalid.
enum xnn_run_state {
  xnn_run_state_invalid = 0,
  xnn_run_state_ready,  // == 1
  xnn_run_state_skip,   // == 2
};
160
// Per-subconvolution setup data. The struct xnn_operator field
// `subconvolution_buffer` points at objects of this type.
struct subconvolution_params {
  // Weights for this subconvolution and the stride used to step through them.
  void* weights;
  size_t w_stride;
  // Indirection pointers and output destination for this subconvolution.
  const void** indirection_buffer;
  void* output;
  // Extent of the output slice covered by this subconvolution.
  size_t slice_width;
  size_t slice_height;
  // Strides for walking the indirection buffer in y and x.
  size_t indirection_y_stride;
  size_t indirection_x_stride;
  // Precomputed product: kernel_size * mr * sizeof(void*).
  size_t scaled_kernel_size;
};
173
174struct xnn_operator {
175 size_t batch_size;
176 uint32_t padding_top;
177 uint32_t padding_right;
178 uint32_t padding_bottom;
179 uint32_t padding_left;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700180 uint32_t kernel_height;
181 uint32_t kernel_width;
182 uint32_t stride_height;
183 uint32_t stride_width;
184 uint32_t dilation_height;
185 uint32_t dilation_width;
186 uint32_t groups;
187 size_t group_channels;
188 size_t group_input_channels;
189 size_t group_output_channels;
190 size_t channels;
191
192 size_t pad_before_channels;
193 size_t pad_after_channels;
194 uint32_t pad_value;
195
196 size_t input_height;
197 size_t input_width;
198 size_t input_pixel_stride;
199 const void* input;
200 const void** indirection_buffer;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700201
202 size_t input2_pixel_stride;
203 const void* input2;
204
205 size_t output_height;
206 size_t output_width;
207 size_t output_pixel_stride;
208 void* output;
209
210 void* packed_weights;
211 // Total number of non-zero kernel elements when weights use sparse representation.
212 size_t num_nonzero_values;
213 // Total number of non-zero kernel blocks when weights use sparse representation.
214 size_t num_nonzero_blocks;
215 // Total number of output channel blocks when weights use sparse representation.
216 size_t num_output_channel_blocks;
217 // Input channel corresponding to the first non-zero kernel element.
218 size_t first_input_channel;
219
220 float input_scale;
221 float output_scale;
222 uint8_t input_zero_point;
223 uint8_t kernel_zero_point;
224 uint8_t output_zero_point;
225 uint8_t output_min;
226 uint8_t output_max;
227
228 size_t valid_batch_size;
229 size_t last_input_height;
230 size_t last_input_width;
231 const void* last_input;
Marat Dukhan69722492019-11-11 19:55:50 -0800232 size_t last_output_height;
233 size_t last_output_width;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700234 void* last_output;
235
236 void* zero_buffer;
237 void* lookup_table;
238 void* pixelwise_buffer;
239 struct subconvolution_params* subconvolution_buffer;
Marat Dukhan8440fde2019-10-24 12:46:13 -0700240 uint32_t flags;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700241
242 union {
Marat Dukhan5868d802020-03-19 17:18:45 -0700243 // Parameters for Global Average Pooling in CHW layout
XNNPACK Teamb455b122019-09-27 18:10:33 -0700244 union xnn_f32_gavgpool_params f32_gavgpool_params;
245 union xnn_f32_hswish_params f32_hswish_params;
Marat Dukhan5868d802020-03-19 17:18:45 -0700246 // Pixelwise Average Pooling normally use f32_output_params, but also initialize f32_avgpool_params in case it needs
247 // to switch to Global Average Pooling operation.
248 struct {
249 union xnn_f32_avgpool_params f32_avgpool_params;
250 union xnn_f32_output_params f32_output_params;
251 };
XNNPACK Teamb455b122019-09-27 18:10:33 -0700252 union xnn_f32_spchw_params f32_spchw_params;
253 union xnn_q8_add_params q8_add_params;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700254 union xnn_q8_gemm_params q8_gemm_params;
Marat Dukhan5868d802020-03-19 17:18:45 -0700255 // Average Pooling normally use q8_avgpool_params, but also initialize q8_gavgpool_params in case it needs to switch
256 // to Global Average Pooling operation.
257 struct {
258 union xnn_q8_avgpool_params q8_avgpool_params;
259 union xnn_q8_avgpool_params q8_gavgpool_params;
260 };
XNNPACK Teamb455b122019-09-27 18:10:33 -0700261 union xnn_u8_output_params u8_output_params;
262 };
263 enum xnn_operator_type type;
264 struct xnn_ukernel ukernel;
265
266 struct compute_parameters compute;
267 struct compute_parameters compute2;
268 union {
269 struct add_contiguous_context add_contiguous;
270 struct add_strided_context add_strided;
271 struct argmax_pooling_context argmax_pooling;
272 struct average_pooling_context average_pooling;
273 struct channel_pad_context channel_pad;
274 struct channel_shuffle_context channel_shuffle;
275 struct dconv2d_context dconv2d;
276 struct dwconv2d_context dwconv2d;
277 struct dwconv_context dwconv;
Marat Dukhanca2733c2019-11-15 23:21:17 -0800278 struct elementwise_binary_context elementwise_binary;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700279 struct gemm_context gemm;
Marat Dukhanefc47b82019-11-18 09:25:38 -0800280 struct global_average_pooling_nwc_context global_average_pooling_nwc;
281 struct global_average_pooling_ncw_context global_average_pooling_ncw;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700282 struct igemm_context igemm;
283 struct lut_contiguous_context lut_contiguous;
284 struct lut_strided_context lut_strided;
285 struct max_pooling_context max_pooling;
286 struct pixelwise_average_pooling_context pixelwise_average_pooling;
287 struct prelu_context prelu;
Marat Dukhan69722492019-11-11 19:55:50 -0800288 struct resize_bilinear_context resize_bilinear;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700289 struct spmm_context spmm;
290 struct subconv_context subconv;
Marat Dukhan29954272020-02-13 17:56:11 -0800291 struct subgemm_context subgemm;
Marat Dukhanfd8e6892020-01-27 15:25:25 -0800292 struct f32_three_pass_softmax_context f32_three_pass_softmax;
293 struct u8_softmax_context u8_softmax;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700294 struct univector_contiguous_context univector_contiguous;
295 struct univector_strided_context univector_strided;
296 struct unpooling_context unpooling;
297 struct vmulcaddc_context vmulcaddc;
298 } context;
299
300 enum xnn_run_state state;
301};