// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <stddef.h>
#include <stdint.h>

#include <pthreadpool.h>

#include <xnnpack/params.h>
#include <xnnpack/compute.h>


// Tag naming a family of micro-kernels; stored in the `type` field of
// struct xnn_ukernel.
//
// xnn_ukernel_type_none is explicitly 0 and each following enumerator takes
// the next consecutive value, so declaration order fixes the numeric values.
// The list is not strictly alphabetical (igemm precedes dconv2d_hwc2spchw);
// the order is kept as-is so that the values do not change.
enum xnn_ukernel_type {
  xnn_ukernel_type_none = 0,
  xnn_ukernel_type_add,
  xnn_ukernel_type_argmax_pooling,
  xnn_ukernel_type_average_pooling,
  xnn_ukernel_type_channel_shuffle,
  xnn_ukernel_type_clamp,
  xnn_ukernel_type_igemm,
  xnn_ukernel_type_dconv2d_hwc2spchw,
  xnn_ukernel_type_dwconv,
  xnn_ukernel_type_gemm,
  xnn_ukernel_type_global_average_pooling,
  xnn_ukernel_type_hswish,
  xnn_ukernel_type_lut,
  xnn_ukernel_type_max_pooling,
  xnn_ukernel_type_pad,
  xnn_ukernel_type_pixelwise_average_pooling,
  xnn_ukernel_type_prelu,
  xnn_ukernel_type_softargmax,
  xnn_ukernel_type_spmm,
  xnn_ukernel_type_subconv2d,
  xnn_ukernel_type_unpooling,
  xnn_ukernel_type_vmulcaddc,
};

// Tag identifying the kind of operator; stored in the `type` field of
// struct xnn_operator.
//
// xnn_operator_type_none is explicitly 0 and each following enumerator takes
// the next consecutive value, so declaration order fixes the numeric values.
// NOTE(review): the suffixes (f32, q8, u8, x8, x32) presumably name the element
// data type the operator works on - confirm against the operator
// implementations.
enum xnn_operator_type {
  xnn_operator_type_none = 0,
  xnn_operator_type_add_f32,
  xnn_operator_type_add_q8,
  xnn_operator_type_argmax_pooling_f32,
  xnn_operator_type_average_pooling_f32,
  xnn_operator_type_average_pooling_q8,
  xnn_operator_type_channel_pad_x32,
  xnn_operator_type_channel_shuffle_x8,
  xnn_operator_type_channel_shuffle_x32,
  xnn_operator_type_clamp_f32,
  xnn_operator_type_clamp_u8,
  xnn_operator_type_convolution_f32,
  xnn_operator_type_convolution_spnchw_f32,
  xnn_operator_type_convolution_q8,
  xnn_operator_type_deconvolution_f32,
  xnn_operator_type_deconvolution_q8,
  xnn_operator_type_fully_connected_f32,
  xnn_operator_type_fully_connected_q8,
  xnn_operator_type_global_average_pooling_f32,
  xnn_operator_type_global_average_pooling_q8,
  xnn_operator_type_global_average_pooling_spnchw_f32,
  xnn_operator_type_hswish_f32,
  xnn_operator_type_leaky_relu_q8,
  xnn_operator_type_max_pooling_f32,
  xnn_operator_type_max_pooling_u8,
  xnn_operator_type_prelu_f32,
  xnn_operator_type_sigmoid_q8,
  xnn_operator_type_softargmax_q8,
  xnn_operator_type_unpooling_x32,
};

77struct xnn_ukernel_dconv2d {
78 union {
79 xnn_conv_hwc2spchw_ukernel_function hwc2spchw_function;
80 xnn_conv_hwc_ukernel_function hwc_function;
81 };
82 uint8_t output_height_tile;
83 uint8_t output_channel_tile;
84};
85
86struct xnn_ukernel_dwconv {
87 union {
88 xnn_dwconv_up_ukernel_function unipass_function;
89 xnn_dwconv_mp_ukernel_function multipass_function;
90 };
91 uint8_t mr;
92 uint8_t qr;
93};
94
95// Direct 2D Depthwise Convolution
96struct xnn_ukernel_dwconv2d {
97 union {
98 xnn_dwconv_spchw_ukernel_function spchw_function;
99 };
100 uint8_t input_width_tile;
101 uint8_t output_width_tile;
102};
103
104struct xnn_ukernel_gemm {
105 xnn_gemm_ukernel_function default_function;
106 xnn_gemm_ukernel_function mr1_function;
107 uint8_t mr;
108 uint8_t nr;
109 uint8_t kr;
110};
111
112struct xnn_ukernel_igemm {
113 xnn_igemm_ukernel_function default_function;
114 xnn_igemm_ukernel_function mr1_function;
115 uint8_t mr;
116 uint8_t nr;
117 uint8_t kr;
118};
119
120struct xnn_ukernel_spmm {
121 xnn_spmm_ukernel_function function;
122 uint8_t mr;
123};
124
125struct xnn_ukernel_vmulcaddc {
126 xnn_vmulcaddc_ukernel_function function;
127 uint8_t mr;
128};
129
130struct xnn_ukernel {
131 enum xnn_ukernel_type type;
132 union {
133 struct xnn_ukernel_dconv2d dconv2d;
134 struct xnn_ukernel_dwconv dwconv;
135 struct xnn_ukernel_dwconv2d dwconv2d;
136 struct xnn_ukernel_gemm gemm;
137 struct xnn_ukernel_igemm igemm;
138 struct xnn_ukernel_spmm spmm;
139 struct xnn_ukernel_vmulcaddc vmulcaddc;
140 };
141};
142
// Run state of an operator; stored in the `state` field of struct xnn_operator.
//
// xnn_run_state_invalid is explicitly 0, so a zero-filled struct xnn_operator
// reads as invalid.
// NOTE(review): the precise meaning of `ready` and `skip` is defined by the
// setup/run code, which is not visible in this header.
enum xnn_run_state {
  xnn_run_state_invalid = 0,
  xnn_run_state_ready,
  xnn_run_state_skip,
};

// Per-subconvolution parameter record; struct xnn_operator points at these
// through its `subconvolution_buffer` field.
// NOTE(review): ownership and lifetime of the pointed-to buffers are not
// visible in this header - confirm in the operator create/delete code.
struct subconvolution_params {
  void* weights;
  size_t w_stride;
  const void** indirection_buffer;
  void* output;
  size_t slice_width;
  size_t slice_height;
  size_t indirection_y_stride;
  size_t indirection_x_stride;
  // scaled_kernel_size := kernel_size * mr * sizeof(void*).
  size_t scaled_kernel_size;
};

162struct xnn_operator {
163 size_t batch_size;
164 uint32_t padding_top;
165 uint32_t padding_right;
166 uint32_t padding_bottom;
167 uint32_t padding_left;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700168 uint32_t kernel_height;
169 uint32_t kernel_width;
170 uint32_t stride_height;
171 uint32_t stride_width;
172 uint32_t dilation_height;
173 uint32_t dilation_width;
174 uint32_t groups;
175 size_t group_channels;
176 size_t group_input_channels;
177 size_t group_output_channels;
178 size_t channels;
179
180 size_t pad_before_channels;
181 size_t pad_after_channels;
182 uint32_t pad_value;
183
184 size_t input_height;
185 size_t input_width;
186 size_t input_pixel_stride;
187 const void* input;
188 const void** indirection_buffer;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700189
190 size_t input2_pixel_stride;
191 const void* input2;
192
193 size_t output_height;
194 size_t output_width;
195 size_t output_pixel_stride;
196 void* output;
197
198 void* packed_weights;
199 // Total number of non-zero kernel elements when weights use sparse representation.
200 size_t num_nonzero_values;
201 // Total number of non-zero kernel blocks when weights use sparse representation.
202 size_t num_nonzero_blocks;
203 // Total number of output channel blocks when weights use sparse representation.
204 size_t num_output_channel_blocks;
205 // Input channel corresponding to the first non-zero kernel element.
206 size_t first_input_channel;
207
208 float input_scale;
209 float output_scale;
210 uint8_t input_zero_point;
211 uint8_t kernel_zero_point;
212 uint8_t output_zero_point;
213 uint8_t output_min;
214 uint8_t output_max;
215
216 size_t valid_batch_size;
217 size_t last_input_height;
218 size_t last_input_width;
219 const void* last_input;
220 void* last_output;
221
222 void* zero_buffer;
223 void* lookup_table;
224 void* pixelwise_buffer;
225 struct subconvolution_params* subconvolution_buffer;
Marat Dukhan8440fde2019-10-24 12:46:13 -0700226 uint32_t flags;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700227
228 union {
229 union xnn_f32_avgpool_params f32_avgpool_params;
230 union xnn_f32_gavgpool_params f32_gavgpool_params;
231 union xnn_f32_hswish_params f32_hswish_params;
232 union xnn_f32_output_params f32_output_params;
233 union xnn_f32_spchw_params f32_spchw_params;
234 union xnn_q8_add_params q8_add_params;
235 union xnn_q8_avgpool_params q8_avgpool_params;
236 union xnn_q8_gemm_params q8_gemm_params;
237 union xnn_u8_output_params u8_output_params;
238 };
239 enum xnn_operator_type type;
240 struct xnn_ukernel ukernel;
241
242 struct compute_parameters compute;
243 struct compute_parameters compute2;
244 union {
245 struct add_contiguous_context add_contiguous;
246 struct add_strided_context add_strided;
247 struct argmax_pooling_context argmax_pooling;
248 struct average_pooling_context average_pooling;
249 struct channel_pad_context channel_pad;
250 struct channel_shuffle_context channel_shuffle;
251 struct dconv2d_context dconv2d;
252 struct dwconv2d_context dwconv2d;
253 struct dwconv_context dwconv;
254 struct gemm_context gemm;
255 struct global_average_pooling_context global_average_pooling;
256 struct global_average_pooling_spnchw_context global_average_pooling_spnchw;
257 struct igemm_context igemm;
258 struct lut_contiguous_context lut_contiguous;
259 struct lut_strided_context lut_strided;
260 struct max_pooling_context max_pooling;
261 struct pixelwise_average_pooling_context pixelwise_average_pooling;
262 struct prelu_context prelu;
263 struct spmm_context spmm;
264 struct subconv_context subconv;
265 struct u8_softargmax_context u8_softargmax;
266 struct univector_contiguous_context univector_contiguous;
267 struct univector_strided_context univector_strided;
268 struct unpooling_context unpooling;
269 struct vmulcaddc_context vmulcaddc;
270 } context;
271
272 enum xnn_run_state state;
273};