blob: 5f757ab2c4583fda918c68159e320965e99976a2 [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8
9#include <assert.h>
10#include <math.h>
11#include <stddef.h>
12#include <stdint.h>
13#include <stdlib.h>
14
15#include <xnnpack.h>
16#include <xnnpack/allocator.h>
XNNPACK Teamb455b122019-09-27 18:10:33 -070017#include <xnnpack/log.h>
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070018#include <xnnpack/operator.h>
19#include <xnnpack/params-init.h>
XNNPACK Teamb455b122019-09-27 18:10:33 -070020#include <xnnpack/params.h>
21
22
23enum xnn_status xnn_create_global_average_pooling_nwc_q8(
24 size_t channels,
25 size_t input_stride,
26 size_t output_stride,
27 uint8_t input_zero_point,
28 float input_scale,
29 uint8_t output_zero_point,
30 float output_scale,
31 uint8_t output_min,
32 uint8_t output_max,
33 uint32_t flags,
34 xnn_operator_t* global_average_pooling_op_out)
35{
36 xnn_operator_t global_average_pooling_op = NULL;
37 enum xnn_status status = xnn_status_uninitialized;
38
39 if (!xnn_params.initialized) {
40 xnn_log_error("failed to create Global Average Pooling operator: XNNPACK is not initialized");
41 goto error;
42 }
43
44 status = xnn_status_invalid_parameter;
45
46 if (channels == 0) {
47 xnn_log_error(
48 "failed to create Global Average Pooling operator with %zu channels: number of channels must be non-zero",
49 channels);
50 goto error;
51 }
52
53 if (input_stride < channels) {
54 xnn_log_error(
55 "failed to create Global Average Pooling operator with input element stride of %zu: "
56 "stride must be at least as large as the number of channels (%zu)",
57 input_stride, channels);
58 goto error;
59 }
60
61 if (output_stride < channels) {
62 xnn_log_error(
63 "failed to create Global Average Pooling operator with output element stride of %zu: "
64 "stride must be at least as large as the number of channels (%zu)",
65 output_stride, channels);
66 goto error;
67 }
68
69 if (input_scale <= 0.0f || !isnormal(input_scale)) {
70 xnn_log_error(
71 "failed to create Global Average Pooling operator with %.7g input scale: "
72 "scale must be finite, normalized, and positive",
73 input_scale);
74 goto error;
75 }
76
77 if (output_scale <= 0.0f || !isnormal(output_scale)) {
78 xnn_log_error(
79 "failed to create Global Average Pooling operator with %.7g output scale: "
80 "scale must be finite, normalized, and positive",
81 output_scale);
82 goto error;
83 }
84
85 if (output_min >= output_max) {
86 xnn_log_error(
87 "failed to create Global Average Pooling operator with [%" PRIu8 ", %" PRIu8 "] output range: "
88 "range min must be below range max",
89 output_min, output_max);
90 goto error;
91 }
92
93 status = xnn_status_unsupported_parameter;
94
95 const float input_output_scale = input_scale / output_scale;
96 if (input_output_scale < 0x1.0p-8f || input_output_scale >= 0x1.0p+8f) {
97 xnn_log_error(
98 "failed to create Global Average Pooling operator with %.7g input-to-output scale ratio: "
99 "scale ratio must be in [2**-8, 2**8) range",
100 input_output_scale);
101 goto error;
102 }
103
104 status = xnn_status_out_of_memory;
105
Marat Dukhan04f03be2019-11-19 12:36:47 -0800106 global_average_pooling_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700107 if (global_average_pooling_op == NULL) {
108 xnn_log_error("failed to allocate %zu bytes for Global Average Pooling operator descriptor", sizeof(struct xnn_operator));
109 goto error;
110 }
111
Marat Dukhan04f03be2019-11-19 12:36:47 -0800112 void* zero_buffer = xnn_allocate_zero_simd_memory(channels * sizeof(uint8_t) + XNN_EXTRA_BYTES);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700113 if (zero_buffer == NULL) {
114 xnn_log_error("failed to allocate %zu bytes for Global Average Pooling zero padding",
115 channels * sizeof(uint8_t) + XNN_EXTRA_BYTES);
116 goto error;
117 }
118 global_average_pooling_op->zero_buffer = zero_buffer;
119
120 global_average_pooling_op->channels = channels;
121 global_average_pooling_op->input_pixel_stride = input_stride;
122 global_average_pooling_op->output_pixel_stride = output_stride;
123 global_average_pooling_op->input_zero_point = input_zero_point;
124 global_average_pooling_op->output_zero_point = output_zero_point;
125 global_average_pooling_op->input_scale = input_scale;
126 global_average_pooling_op->output_scale = output_scale;
127 global_average_pooling_op->output_min = output_min;
128 global_average_pooling_op->output_max = output_max;
129
Marat Dukhanefc47b82019-11-18 09:25:38 -0800130 global_average_pooling_op->type = xnn_operator_type_global_average_pooling_nwc_q8;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700131 global_average_pooling_op->ukernel.type = xnn_ukernel_type_global_average_pooling;
132
133 global_average_pooling_op->state = xnn_run_state_invalid;
134
135 *global_average_pooling_op_out = global_average_pooling_op;
136 return xnn_status_success;
137
138error:
139 xnn_delete_operator(global_average_pooling_op);
140 return status;
141}
142
143enum xnn_status xnn_create_global_average_pooling_nwc_f32(
144 size_t channels,
145 size_t input_stride,
146 size_t output_stride,
147 float output_min,
148 float output_max,
149 uint32_t flags,
150 xnn_operator_t* global_average_pooling_op_out)
151{
152 xnn_operator_t global_average_pooling_op = NULL;
153 enum xnn_status status = xnn_status_uninitialized;
154
155 if (!xnn_params.initialized) {
156 xnn_log_error("failed to create Global Average Pooling operator: XNNPACK is not initialized");
157 goto error;
158 }
159
160 status = xnn_status_invalid_parameter;
161
162 if (channels == 0) {
163 xnn_log_error(
164 "failed to create Global Average Pooling operator with %zu channels: number of channels must be non-zero",
165 channels);
166 goto error;
167 }
168
169 if (input_stride < channels) {
170 xnn_log_error(
171 "failed to create Global Average Pooling operator with input element stride of %zu: "
172 "stride must be at least as large as the number of channels (%zu)",
173 input_stride, channels);
174 goto error;
175 }
176
177 if (output_stride < channels) {
178 xnn_log_error(
179 "failed to create Global Average Pooling operator with output element stride of %zu: "
180 "stride must be at least as large as the number of channels (%zu)",
181 output_stride, channels);
182 goto error;
183 }
184
185 if (isnan(output_min)) {
186 xnn_log_error(
187 "failed to create Global Average Pooling operator with NaN output lower bound: lower bound must be non-NaN");
188 goto error;
189 }
190
191 if (isnan(output_max)) {
192 xnn_log_error(
193 "failed to create Global Average Pooling operator with NaN output upper bound: upper bound must be non-NaN");
194 goto error;
195 }
196
197 if (output_min >= output_max) {
198 xnn_log_error(
199 "failed to create Global Average Pooling operator with [%.7g, %.7g] output range: "
200 "lower bound must be below upper bound",
201 output_min, output_max);
202 goto error;
203 }
204
205 status = xnn_status_out_of_memory;
206
Marat Dukhan04f03be2019-11-19 12:36:47 -0800207 global_average_pooling_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700208 if (global_average_pooling_op == NULL) {
209 xnn_log_error("failed to allocate %zu bytes for Global Average Pooling operator descriptor", sizeof(struct xnn_operator));
210 goto error;
211 }
212
Marat Dukhan04f03be2019-11-19 12:36:47 -0800213 void* zero_buffer = xnn_allocate_zero_simd_memory(channels * sizeof(float) + XNN_EXTRA_BYTES);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700214 if (zero_buffer == NULL) {
215 xnn_log_error("failed to allocate %zu bytes for Global Average Pooling zero padding",
216 channels * sizeof(float) + XNN_EXTRA_BYTES);
217 goto error;
218 }
219 global_average_pooling_op->zero_buffer = zero_buffer;
220
221 global_average_pooling_op->channels = channels;
222 global_average_pooling_op->input_pixel_stride = input_stride;
223 global_average_pooling_op->output_pixel_stride = output_stride;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700224 global_average_pooling_op->f32_avgpool_params = xnn_init_f32_avgpool_params(nanf(""), output_min, output_max);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700225
Marat Dukhanefc47b82019-11-18 09:25:38 -0800226 global_average_pooling_op->type = xnn_operator_type_global_average_pooling_nwc_f32;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700227 global_average_pooling_op->ukernel.type = xnn_ukernel_type_global_average_pooling;
228
229 global_average_pooling_op->state = xnn_run_state_invalid;
230
231 *global_average_pooling_op_out = global_average_pooling_op;
232 return xnn_status_success;
233
234error:
235 xnn_delete_operator(global_average_pooling_op);
236 return status;
237}
238
239enum xnn_status xnn_setup_global_average_pooling_nwc_q8(
240 xnn_operator_t global_average_pooling_op,
241 size_t batch_size,
242 size_t width,
243 const uint8_t* input,
244 uint8_t* output,
245 pthreadpool_t threadpool)
246{
Marat Dukhanefc47b82019-11-18 09:25:38 -0800247 if (global_average_pooling_op->type != xnn_operator_type_global_average_pooling_nwc_q8) {
248 xnn_log_error("failed to setup Global Average Pooling (NWC, Q8) operator: operator type mismatch");
XNNPACK Teamb455b122019-09-27 18:10:33 -0700249 return xnn_status_invalid_parameter;
250 }
251 global_average_pooling_op->state = xnn_run_state_invalid;
252
253 if (!xnn_params.initialized) {
254 xnn_log_error("failed to setup Global Average Pooling operator: XNNPACK is not initialized");
255 return xnn_status_uninitialized;
256 }
257
258 if (width == 0) {
259 xnn_log_error("failed to setup Global Average Pooling operator with width %zu: width must be non-zero", width);
260 return xnn_status_invalid_parameter;
261 }
262
263 if (batch_size == 0) {
264 global_average_pooling_op->state = xnn_run_state_skip;
265 return xnn_status_success;
266 }
267
268 global_average_pooling_op->batch_size = batch_size;
269 global_average_pooling_op->input_width = width;
270 global_average_pooling_op->input = input;
271 global_average_pooling_op->output = output;
272
273 global_average_pooling_op->q8_avgpool_params =
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700274 xnn_init_q8_avgpool_params(
XNNPACK Teamb455b122019-09-27 18:10:33 -0700275 -(int32_t) width * (int32_t) (uint32_t) global_average_pooling_op->input_zero_point,
276 global_average_pooling_op->input_scale / (global_average_pooling_op->output_scale * (float) width),
277 global_average_pooling_op->output_zero_point,
278 global_average_pooling_op->output_min,
279 global_average_pooling_op->output_max);
280
281 const size_t input_stride_in_bytes = global_average_pooling_op->input_pixel_stride * sizeof(uint8_t);
282 const size_t channels = global_average_pooling_op->channels;
Marat Dukhanefc47b82019-11-18 09:25:38 -0800283 global_average_pooling_op->context.global_average_pooling_nwc = (struct global_average_pooling_nwc_context) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700284 .input = input,
285 .zero = global_average_pooling_op->zero_buffer,
286 .input_pixel_stride = input_stride_in_bytes,
287 .input_batch_stride = input_stride_in_bytes * width,
288 .input_elements = width,
289 .channels = channels,
290 .output = output,
291 .output_batch_stride = global_average_pooling_op->output_pixel_stride * sizeof(uint8_t),
292 .params.q8 = global_average_pooling_op->q8_avgpool_params,
293 };
294 global_average_pooling_op->compute.type = xnn_parallelization_type_1d;
295 global_average_pooling_op->compute.range[0] = batch_size;
296
297 if (width <= xnn_params.q8.gavgpool.mr) {
Marat Dukhanefc47b82019-11-18 09:25:38 -0800298 global_average_pooling_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_global_average_pooling_nwc_unipass;
299 global_average_pooling_op->context.global_average_pooling_nwc.unipass_ukernel = xnn_params.q8.gavgpool.up;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700300 } else {
Marat Dukhanefc47b82019-11-18 09:25:38 -0800301 global_average_pooling_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_global_average_pooling_nwc_multipass;
302 global_average_pooling_op->context.global_average_pooling_nwc.multipass_ukernel = xnn_params.q8.gavgpool.mp;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700303 }
304 global_average_pooling_op->state = xnn_run_state_ready;
305
306 return xnn_status_success;
307}
308
309enum xnn_status xnn_setup_global_average_pooling_nwc_f32(
310 xnn_operator_t global_average_pooling_op,
311 size_t batch_size,
312 size_t width,
313 const float* input,
314 float* output,
315 pthreadpool_t threadpool)
316{
Marat Dukhanefc47b82019-11-18 09:25:38 -0800317 if (global_average_pooling_op->type != xnn_operator_type_global_average_pooling_nwc_f32) {
318 xnn_log_error("failed to setup Global Average Pooling (NWC, F32) operator: operator type mismatch");
XNNPACK Teamb455b122019-09-27 18:10:33 -0700319 return xnn_status_invalid_parameter;
320 }
321 global_average_pooling_op->state = xnn_run_state_invalid;
322
323 if (!xnn_params.initialized) {
324 xnn_log_error("failed to setup Global Average Pooling operator: XNNPACK is not initialized");
325 return xnn_status_uninitialized;
326 }
327
328 if (width == 0) {
329 xnn_log_error("failed to setup Global Average Pooling operator with width %zu: width must be non-zero", width);
330 return xnn_status_invalid_parameter;
331 }
332
333 if (batch_size == 0) {
334 global_average_pooling_op->state = xnn_run_state_skip;
335 return xnn_status_success;
336 }
337
338 global_average_pooling_op->batch_size = batch_size;
339 global_average_pooling_op->input_width = width;
340 global_average_pooling_op->input = input;
341 global_average_pooling_op->output = output;
342
343 xnn_update_f32_avgpool_params(&global_average_pooling_op->f32_avgpool_params, 1.0f / (float) width);
344
345 const size_t input_stride_in_bytes = global_average_pooling_op->input_pixel_stride * sizeof(float);
346 const size_t channels = global_average_pooling_op->channels;
Marat Dukhanefc47b82019-11-18 09:25:38 -0800347 global_average_pooling_op->context.global_average_pooling_nwc = (struct global_average_pooling_nwc_context) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700348 .input = input,
349 .zero = global_average_pooling_op->zero_buffer,
350 .input_pixel_stride = input_stride_in_bytes,
351 .input_batch_stride = input_stride_in_bytes * width,
352 .input_elements = width,
353 .channels = channels,
354 .output = output,
355 .output_batch_stride = global_average_pooling_op->output_pixel_stride * sizeof(float),
356 .params.f32 = global_average_pooling_op->f32_avgpool_params,
357 };
358 global_average_pooling_op->compute.type = xnn_parallelization_type_1d;
359 global_average_pooling_op->compute.range[0] = batch_size;
360
361 if (width <= xnn_params.f32.gavgpool.mr) {
Marat Dukhanefc47b82019-11-18 09:25:38 -0800362 global_average_pooling_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_global_average_pooling_nwc_unipass;
363 global_average_pooling_op->context.global_average_pooling_nwc.unipass_ukernel = xnn_params.f32.gavgpool.up;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700364 } else {
Marat Dukhanefc47b82019-11-18 09:25:38 -0800365 global_average_pooling_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_global_average_pooling_nwc_multipass;
366 global_average_pooling_op->context.global_average_pooling_nwc.multipass_ukernel = xnn_params.f32.gavgpool.mp;
XNNPACK Teamb455b122019-09-27 18:10:33 -0700367 }
368 global_average_pooling_op->state = xnn_run_state_ready;
369
370 return xnn_status_success;
371}