blob: 4a5410a13f3d9b496feda14d0c4487cd41e3fc23 [file] [log] [blame]
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001// Copyright (c) Facebook, Inc. and its affiliates.
2// All rights reserved.
3//
4// Copyright 2019 Google LLC
5//
6// This source code is licensed under the BSD-style license found in the
7// LICENSE file in the root directory of this source tree.
8
9#pragma once
10
11#if defined(__cplusplus) && (__cplusplus >= 201103L)
12 #include <cstdint>
13 #include <cstddef>
14 #include <cassert>
15 #include <cmath>
16#else
17 #include <stdint.h>
18 #include <stddef.h>
19 #include <assert.h>
20 #include <math.h>
21#endif
22
23#include <fp16.h>
24
25#include <xnnpack/common.h>
Marat Dukhan2b9efd82020-06-08 01:09:31 -070026#include <xnnpack/math.h>
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070027#include <xnnpack/params.h>
28
29
30static inline union xnn_q8_gemm_params xnn_init_scalar_q8_gemm_params(
31 uint8_t input_zero_point,
32 uint8_t kernel_zero_point,
33 float scale,
34 uint8_t output_zero_point,
35 uint8_t output_min,
36 uint8_t output_max)
37{
38 // Compute requantization parameters
39 const uint32_t scale_bits = fp32_to_bits(scale);
40
41 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
42 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
43 assert(multiplier >= INT32_C(0x40000000));
44 assert(multiplier <= INT32_C(0x7FFFFF80));
45
46 // Shift is in [0, 31] range.
47 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
48 assert(shift >= 0);
49 assert(shift < 32);
50
51 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
52 const uint32_t remainder_threshold = remainder_mask >> 1;
53
54 union xnn_q8_gemm_params params;
55 params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point;
56 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
57 params.scalar.multiplier = multiplier;
58 params.scalar.remainder_mask = (int32_t) remainder_mask;
59 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
60 params.scalar.shift = (uint32_t) shift;
61 params.scalar.output_min_less_zero_point =
62 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
63 params.scalar.output_max_less_zero_point =
64 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
65 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
66 return params;
67}
68
69static inline union xnn_q8_gemm_params xnn_init_q8_gemm_params(
70 uint8_t input_zero_point,
71 uint8_t kernel_zero_point,
72 float scale,
73 uint8_t output_zero_point,
74 uint8_t output_min,
75 uint8_t output_max)
76{
77 // Compute requantization parameters.
78 const uint32_t scale_bits = fp32_to_bits(scale);
79
80 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
81 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
82 assert(multiplier >= INT32_C(0x40000000));
83 assert(multiplier <= INT32_C(0x7FFFFF80));
84
85 // Shift is in [0, 31] range.
86 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
87 assert(shift >= 0);
88 assert(shift < 32);
89
90 union xnn_q8_gemm_params params;
91 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
92 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
93 const uint32_t remainder_threshold = remainder_mask >> 1;
94 for (uint32_t i = 0; i < 8; i++) {
95 params.sse2.input_zero_point[i] = (int16_t) (uint16_t) input_zero_point;
96 params.sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
97 }
98 params.sse2.multiplier[0] = multiplier;
99 params.sse2.multiplier[1] = multiplier;
100 params.sse2.multiplier[2] = multiplier;
101 params.sse2.multiplier[3] = multiplier;
102 params.sse2.rounding[0] = UINT64_C(0x40000000);
103 params.sse2.rounding[1] = UINT64_C(0x40000000);
104 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
105 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
106 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
107 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
108 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
109 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
110 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
111 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
112 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
113 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
114 for (uint32_t i = 0; i < 8; i++) {
115 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
116 }
117 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700118 params.sse2.output_min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700119 params.sse2.output_max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700120 }
121 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
122 params.neon.input_zero_point = (int16_t) (uint16_t) input_zero_point;
123 params.neon.kernel_zero_point = (int16_t) (uint16_t) kernel_zero_point;
124 params.neon.multiplier = multiplier;
125 params.neon.right_shift = -shift;
126 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700127 params.neon.output_min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700128 params.neon.output_max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700129 #else
130 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
131 const uint32_t remainder_threshold = remainder_mask >> 1;
132 params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point;
133 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
134 params.scalar.multiplier = multiplier;
135 params.scalar.remainder_mask = (int32_t) remainder_mask;
136 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
137 params.scalar.shift = (uint32_t) shift;
138 params.scalar.output_min_less_zero_point =
139 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
140 params.scalar.output_max_less_zero_point =
141 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
142 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
143 #endif
144 return params;
145}
146
147static inline union xnn_q8_avgpool_params xnn_init_q8_avgpool_params(
148 int32_t bias,
149 float scale,
150 uint8_t output_zero_point,
151 uint8_t output_min,
152 uint8_t output_max)
153{
154 // Compute requantization parameters.
155 assert(scale >= 0x1.0p-32f);
156 assert(scale < 256.0f);
157 const uint32_t scale_bits = fp32_to_bits(scale);
158
159 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
160 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
161 assert(multiplier >= INT32_C(0x00800000));
162 assert(multiplier <= INT32_C(0x00FFFFFF));
163
164 // Shift is in [16, 55] range.
165 const int32_t shift = 127 + 23 - (scale_bits >> 23);
166 assert(shift >= 16);
167 assert(shift < 64);
168
169 union xnn_q8_avgpool_params params;
170 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
171 const uint32_t right_shift = (uint32_t) shift;
172 const uint64_t rounding = UINT64_C(1) << (right_shift - 1);
173 params.sse2.bias[0] = bias;
174 params.sse2.bias[1] = bias;
175 params.sse2.bias[2] = bias;
176 params.sse2.bias[3] = bias;
177 params.sse2.multiplier[0] = (uint32_t) multiplier;
178 params.sse2.multiplier[1] = (uint32_t) multiplier;
179 params.sse2.multiplier[2] = (uint32_t) multiplier;
180 params.sse2.multiplier[3] = (uint32_t) multiplier;
181 params.sse2.rounding[0] = rounding;
182 params.sse2.rounding[1] = rounding;
183 params.sse2.right_shift[0] = (uint64_t) right_shift;
184 params.sse2.right_shift[1] = (uint64_t) right_shift;
185 for (uint32_t i = 0; i < 8; i++) {
186 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
187 }
188 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700189 params.sse2.output_min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700190 params.sse2.output_max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700191 }
192 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
193 params.neon.bias = bias;
194 params.neon.multiplier = multiplier;
195 params.neon.left_shift = (int64_t) -shift;
196 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700197 params.neon.output_min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700198 params.neon.output_max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700199 #else
200 const uint32_t right_shift = (uint32_t) shift;
201 const int64_t rounding = INT64_C(1) << (right_shift - 1);
202 params.scalar.bias = bias;
203 params.scalar.multiplier = multiplier;
204 params.scalar.rounding = rounding;
205 params.scalar.right_shift = right_shift;
206 params.scalar.output_min_less_zero_point =
207 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
208 params.scalar.output_max_less_zero_point =
209 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
210 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
211 #endif
212 return params;
213}
214
215static inline union xnn_q8_avgpool_params xnn_init_scalar_q8_avgpool_params(
216 int32_t bias,
217 float scale,
218 uint8_t output_zero_point,
219 uint8_t output_min,
220 uint8_t output_max)
221{
222 // Compute requantization parameters.
223 assert(scale >= 0x1.0p-32f);
224 assert(scale < 256.0f);
225 const uint32_t scale_bits = fp32_to_bits(scale);
226
227 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
228 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
229 assert(multiplier >= INT32_C(0x00800000));
230 assert(multiplier <= INT32_C(0x00FFFFFF));
231
232 // Shift is in [16, 55] range.
233 const int32_t shift = 127 + 23 - (scale_bits >> 23);
234 assert(shift >= 16);
235 assert(shift < 64);
236
237 union xnn_q8_avgpool_params params;
238 const uint32_t right_shift = (uint32_t) shift;
239 const int64_t rounding = INT64_C(1) << (right_shift - 1);
240 params.scalar.bias = bias;
241 params.scalar.rounding = rounding;
242 params.scalar.multiplier = multiplier;
243 params.scalar.right_shift = right_shift;
244 params.scalar.output_min_less_zero_point =
245 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
246 params.scalar.output_max_less_zero_point =
247 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
248 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
249 return params;
250}
251
Marat Dukhan8452ff52020-04-08 20:44:58 -0700252static inline void xnn_update_f32_scaleminmax_params(
253 union xnn_f32_scaleminmax_params* params,
254 float scale)
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700255{
256 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
257 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhan8452ff52020-04-08 20:44:58 -0700258 params->sse2.scale[i] = scale;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700259 }
260 #else
Marat Dukhan8452ff52020-04-08 20:44:58 -0700261 params->scalar.scale = scale;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700262 #endif
263}
264
Marat Dukhan8452ff52020-04-08 20:44:58 -0700265static inline union xnn_f32_scaleminmax_params xnn_init_f32_scaleminmax_params(
266 float scale,
267 float min,
268 float max)
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700269{
Marat Dukhan8452ff52020-04-08 20:44:58 -0700270 union xnn_f32_scaleminmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700271 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
272 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhan8452ff52020-04-08 20:44:58 -0700273 params.sse2.scale[i] = scale;
274 params.sse2.min[i] = min;
275 params.sse2.max[i] = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700276 }
277 #else
Marat Dukhan8452ff52020-04-08 20:44:58 -0700278 params.scalar.scale = scale;
279 params.scalar.min = min;
280 params.scalar.max = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700281 #endif
282 return params;
283}
284
285static inline union xnn_f32_gavgpool_params xnn_init_f32_gavgpool_params(
286 float multiplier,
287 float output_min,
288 float output_max,
289 uint32_t width)
290{
291 union xnn_f32_gavgpool_params params;
292 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
293 for (uint32_t i = 0; i < 4; i++) {
294 params.sse.multiplier[i] = multiplier;
295 params.sse.output_min[i] = output_min;
296 params.sse.output_max[i] = output_max;
297 }
298
299 const uint32_t w = (width - 1) & 3;
300 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
301 params.sse.mask[1] = -(uint32_t) (w >= 1);
302 params.sse.mask[2] = -(uint32_t) (w >= 2);
303 params.sse.mask[3] = -(uint32_t) (w >= 3);
304 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
305 params.neon.multiplier = multiplier;
306 params.neon.output_min = output_min;
307 params.neon.output_max = output_max;
308
309 const uint32_t w = (width - 1) & 3;
310 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
311 params.neon.mask[1] = -(uint32_t) (w >= 1);
312 params.neon.mask[2] = -(uint32_t) (w >= 2);
313 params.neon.mask[3] = -(uint32_t) (w >= 3);
314 #else
315 params.scalar.multiplier = multiplier;
316 params.scalar.output_min = output_min;
317 params.scalar.output_max = output_max;
318 #endif
319 return params;
320}
321
322static inline void xnn_update_f32_gavgpool_params(
323 union xnn_f32_gavgpool_params* params,
324 float multiplier,
325 uint32_t width)
326{
327 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
328 for (uint32_t i = 0; i < 4; i++) {
329 params->sse.multiplier[i] = multiplier;
330 }
331
332 const uint32_t w = (width - 1) & 3;
333 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
334 params->sse.mask[1] = -(uint32_t) (w >= 1);
335 params->sse.mask[2] = -(uint32_t) (w >= 2);
336 params->sse.mask[3] = -(uint32_t) (w >= 3);
337 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
338 params->neon.multiplier = multiplier;
339
340 const uint32_t w = (width - 1) & 3;
341 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
342 params->neon.mask[1] = -(uint32_t) (w >= 1);
343 params->neon.mask[2] = -(uint32_t) (w >= 2);
344 params->neon.mask[3] = -(uint32_t) (w >= 3);
345 #else
346 params->scalar.multiplier = multiplier;
347 #endif
348}
349
Marat Dukhan8452ff52020-04-08 20:44:58 -0700350static inline union xnn_f32_scaleminmax_params xnn_init_scalar_f32_scaleminmax_params(
351 float scale,
352 float min,
353 float max)
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700354{
Marat Dukhan8452ff52020-04-08 20:44:58 -0700355 union xnn_f32_scaleminmax_params params;
356 params.scalar.scale = scale;
357 params.scalar.min = min;
358 params.scalar.max = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700359 return params;
360}
361
362static inline union xnn_f32_gavgpool_params xnn_init_scalar_f32_gavgpool_params(
363 float multiplier,
364 float output_min,
365 float output_max,
366 uint32_t width)
367{
368 union xnn_f32_gavgpool_params params;
369 params.scalar.multiplier = multiplier;
370 params.scalar.output_min = output_min;
371 params.scalar.output_max = output_max;
372 return params;
373}
374
Frank Barchard99003a82020-05-04 10:39:38 -0700375static inline struct xnn_f16_scaleminmax_params xnn_init_f16_scaleminmax_params(
376 uint16_t scale,
377 uint16_t min,
378 uint16_t max)
379{
380 struct xnn_f16_scaleminmax_params params;
381 params.scale = scale;
382 params.min = min;
383 params.max = max;
384 return params;
385}
386
Frank Barchardd793f6c2020-05-08 13:37:43 -0700387static inline struct xnn_f16_minmax_params xnn_init_f16_minmax_params(
388 uint16_t min,
389 uint16_t max)
390{
391 struct xnn_f16_minmax_params params;
392 params.min = min;
393 params.max = max;
394 return params;
395}
396
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700397static inline union xnn_f32_minmax_params xnn_init_f32_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700398 float output_min,
399 float output_max)
400{
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700401 union xnn_f32_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700402 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
403 for (uint32_t i = 0; i < 4; i++) {
404 params.sse.min[i] = output_min;
405 params.sse.max[i] = output_max;
406 }
407 #else
408 params.scalar.min = output_min;
409 params.scalar.max = output_max;
410 #endif
411 return params;
412}
413
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700414static inline union xnn_f32_minmax_params xnn_init_scalar_f32_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700415 float output_min,
416 float output_max)
417{
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700418 union xnn_f32_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700419 params.scalar.min = output_min;
420 params.scalar.max = output_max;
421 return params;
422}
423
Frank Barchardb1966592020-05-12 13:47:06 -0700424static inline struct xnn_f16_hswish_params xnn_init_f16_hswish_params(void)
425{
426 struct xnn_f16_hswish_params params;
427 params.sixth = fp16_ieee_from_fp32_value(0x1.555556p-3f);
428 params.half = fp16_ieee_from_fp32_value(0.5f);
429 params.one = fp16_ieee_from_fp32_value(1.0f);
430 return params;
431}
432
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700433static inline union xnn_f32_hswish_params xnn_init_f32_hswish_params(void)
434{
435 union xnn_f32_hswish_params params;
436 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
437 for (uint32_t i = 0; i < 4; i++) {
438 params.sse.sixth[i] = 0x1.555556p-3f;
439 params.sse.half[i] = 0.5f;
440 params.sse.one[i] = 1.0f;
441 }
442 #else
443 params.scalar.sixth = 0x1.555556p-3f;
444 params.scalar.half = 0.5f;
445 params.scalar.one = 1.0f;
446 #endif
447 return params;
448}
449
450static inline union xnn_f32_hswish_params xnn_init_scalar_f32_hswish_params(void)
451{
452 union xnn_f32_hswish_params params;
453 params.scalar.sixth = 0x1.555556p-3f;
454 params.scalar.half = 0.5f;
455 params.scalar.one = 1.0f;
456 return params;
457}
458
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700459static inline union xnn_f32_abs_params xnn_init_f32_abs_params(void)
460{
461 union xnn_f32_abs_params params = { 0 };
462 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
463 for (uint32_t i = 0; i < 4; i++) {
464 params.sse.nonsign_mask[i] = math_nonsign_mask_f32();
465 }
466 #endif
467 return params;
468}
469
470static inline union xnn_f32_abs_params xnn_init_scalar_f32_abs_params(void)
471{
472 union xnn_f32_abs_params params = { 0 };
473 return params;
474}
475
476static inline union xnn_f32_neg_params xnn_init_f32_neg_params(void)
477{
478 union xnn_f32_neg_params params = { 0 };
479 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
480 for (uint32_t i = 0; i < 4; i++) {
481 params.sse.sign_mask[i] = -0.0f;
482 }
483 #endif
484 return params;
485}
486
487static inline union xnn_f32_neg_params xnn_init_scalar_f32_neg_params(void)
488{
489 union xnn_f32_neg_params params = { 0 };
490 return params;
491}
492
Marat Dukhaneecf8fd2020-06-09 08:59:37 -0700493static inline union xnn_f32_rnd_params xnn_init_f32_rnd_params(void)
494{
495 union xnn_f32_rnd_params params = { 0 };
496 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
497 for (uint32_t i = 0; i < 4; i++) {
498 params.sse2.sign_mask[i] = -0.0f;
499 }
500 for (uint32_t i = 0; i < 4; i++) {
501 params.sse2.one[i] = 1.0f;
502 }
503 #endif
504 return params;
505}
506
507static inline union xnn_f32_rnd_params xnn_init_scalar_f32_rnd_params(void)
508{
509 union xnn_f32_rnd_params params = { 0 };
510 return params;
511}
512
Marat Dukhan1f29b802020-05-15 23:46:39 -0700513static inline union xnn_f32_chw_params xnn_init_f32_chw_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700514 uint32_t width,
515 float output_min,
516 float output_max)
517{
Marat Dukhan1f29b802020-05-15 23:46:39 -0700518 union xnn_f32_chw_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700519 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
520 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700521 params.sse.min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700522 params.sse.max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700523 }
524
525 const uint32_t w4 = (width - 1) & 3;
526 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
527 params.sse.mask[1] = -(uint32_t) (w4 >= 1);
528 params.sse.mask[2] = -(uint32_t) (w4 >= 2);
529 params.sse.mask[3] = -(uint32_t) (w4 >= 3);
530
531 const uint32_t w8 = (width - 1) & 7;
532 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
533 params.sse.mask_even[1] = -(uint32_t) (w8 >= 2);
534 params.sse.mask_even[2] = -(uint32_t) (w8 >= 4);
535 params.sse.mask_even[3] = -(uint32_t) (w8 >= 6);
536 params.sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
537 params.sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
538 params.sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
539 params.sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
540 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700541 params.neon.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700542 params.neon.max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700543
544 const uint32_t w4 = (width - 1) & 3;
545 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
546 params.neon.mask[1] = -(uint32_t) (w4 >= 1);
547 params.neon.mask[2] = -(uint32_t) (w4 >= 2);
548 params.neon.mask[3] = -(uint32_t) (w4 >= 3);
549
550 const uint32_t w8 = (width - 1) & 7;
551 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
552 params.neon.mask_even[1] = -(uint32_t) (w8 >= 2);
553 params.neon.mask_even[2] = -(uint32_t) (w8 >= 4);
554 params.neon.mask_even[3] = -(uint32_t) (w8 >= 6);
555 params.neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
556 params.neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
557 params.neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
558 params.neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
559 #else
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700560 params.scalar.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700561 params.scalar.max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700562 #endif
563 return params;
564}
565
// Recomputes only the width-dependent lane masks of previously initialized f32 CHW parameters
// on SIMD targets (x86/x86-64 and ARM/ARM64). On other targets it does nothing.
//
// With tail4 = (width - 1) mod 4 and tail8 = (width - 1) mod 8, lane k is all-ones when
// tail4 >= k (mask), tail8 >= 2k (mask_even) or tail8 >= 2k + 1 (mask_odd), and zero otherwise.
static inline void xnn_update_f32_chw_params(
  union xnn_f32_chw_params* params,
  uint32_t width)
{
  #if XNN_ARCH_X86 || XNN_ARCH_X86_64
    const uint32_t tail4 = (width - 1) & 3;
    const uint32_t tail8 = (width - 1) & 7;
    for (uint32_t lane = 0; lane < 4; lane++) {
      params->sse.mask[lane] = -(uint32_t) (tail4 >= lane);
      params->sse.mask_even[lane] = -(uint32_t) (tail8 >= 2 * lane);
      params->sse.mask_odd[lane] = -(uint32_t) (tail8 >= 2 * lane + 1);
    }
  #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
    const uint32_t tail4 = (width - 1) & 3;
    const uint32_t tail8 = (width - 1) & 7;
    for (uint32_t lane = 0; lane < 4; lane++) {
      params->neon.mask[lane] = -(uint32_t) (tail4 >= lane);
      params->neon.mask_even[lane] = -(uint32_t) (tail8 >= 2 * lane);
      params->neon.mask_odd[lane] = -(uint32_t) (tail8 >= 2 * lane + 1);
    }
  #endif
}
604
Marat Dukhan1f29b802020-05-15 23:46:39 -0700605static inline union xnn_f32_chw_params xnn_init_scalar_f32_chw_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700606 uint32_t width,
607 float output_min,
608 float output_max)
609{
Marat Dukhan1f29b802020-05-15 23:46:39 -0700610 union xnn_f32_chw_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700611 params.scalar.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700612 params.scalar.max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700613 return params;
614}
615
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700616static inline union xnn_u8_minmax_params xnn_init_u8_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700617 uint8_t output_min,
618 uint8_t output_max)
619{
620 assert(output_min < output_max);
621
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700622 union xnn_u8_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700623 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
624 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700625 params.sse2.min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700626 params.sse2.max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700627 }
628 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700629 params.neon.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700630 params.neon.max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700631 #else
632 params.scalar.min = (int32_t) (uint32_t) output_min;
633 params.scalar.max = (int32_t) (uint32_t) output_max;
634 #endif
635 return params;
636}
637
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700638static inline union xnn_u8_minmax_params xnn_init_scalar_u8_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700639 uint8_t output_min,
640 uint8_t output_max)
641{
642 assert(output_min < output_max);
643
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700644 union xnn_u8_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700645 params.scalar.min = (int32_t) (uint32_t) output_min;
646 params.scalar.max = (int32_t) (uint32_t) output_max;
647 return params;
648}
649
650static inline union xnn_q8_add_params xnn_init_q8_add_params(
651 uint8_t a_zero_point,
652 uint8_t b_zero_point,
653 uint8_t output_zero_point,
654 float a_output_scale,
655 float b_output_scale,
656 uint8_t output_min,
657 uint8_t output_max)
658{
659 assert(a_output_scale >= 0x1.0p-14f);
660 assert(b_output_scale >= 0x1.0p-14f);
661 assert(a_output_scale < 0x1.0p+8f);
662 assert(b_output_scale < 0x1.0p+8f);
663
664 // Compute requantization parameters.
665 const float max_output_scale = a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
666 assert(max_output_scale >= 0x1.0p-14f);
667 assert(max_output_scale < 0x1.0p+8f);
668 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
669 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
670 // Shift is in [13, 31] range.
671 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
672 assert(shift < 32);
673 assert(shift >= 13);
674
675 const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
676
677 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
Marat Dukhanef3e7dc2020-04-13 01:19:56 -0700678 const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(a_output_scale * scale_multiplier);
679 const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(b_output_scale * scale_multiplier);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700680 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
681 assert(a_multiplier < UINT32_C(0x00400000));
682 assert(b_multiplier < UINT32_C(0x00400000));
683
684 union xnn_q8_add_params params;
685 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
686 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
687 const uint32_t remainder_threshold = remainder_mask >> 1;
688 const int32_t zero_point_product =
689 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
690 for (uint32_t i = 0; i < 4; i++) {
691 params.sse2.zero_point_product[i] = zero_point_product;
692 }
693 for (uint32_t i = 0; i < 8; i++) {
694 params.sse2.y_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
695 }
696 for (uint32_t i = 0; i < 8; i++) {
697 params.sse2.a_multiplier_lo[i] = (uint16_t) (uint32_t) a_multiplier;
698 params.sse2.a_multiplier_hi[i] = (uint16_t) ((uint32_t) a_multiplier >> 16);
699 params.sse2.b_multiplier_lo[i] = (uint16_t) (uint32_t) b_multiplier;
700 params.sse2.b_multiplier_hi[i] = (uint16_t) ((uint32_t) b_multiplier >> 16);
701 }
702 params.sse2.a_multiplier = a_multiplier;
703 params.sse2.b_multiplier = b_multiplier;
704 for (uint32_t i = 0; i < 4; i++) {
705 params.sse2.remainder_mask[i] = remainder_mask;
706 params.sse2.remainder_threshold[i] = remainder_threshold;
707 }
708 params.sse2.shift = shift;
709 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700710 params.sse2.y_min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700711 params.sse2.y_max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700712 }
713 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
714 params.neon.a_zero_point = a_zero_point;
715 params.neon.b_zero_point = b_zero_point;
716 params.neon.y_zero_point = (int16_t) (uint16_t) output_zero_point;
717 params.neon.a_multiplier = (int32_t) a_multiplier;
718 params.neon.b_multiplier = (int32_t) b_multiplier;
719 params.neon.right_shift = (int32_t) -shift;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700720 params.neon.y_min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700721 params.neon.y_max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700722 #else
723 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
724 const uint32_t remainder_threshold = remainder_mask >> 1;
725 params.scalar.zero_point_product =
726 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
727 params.scalar.a_multiplier = a_multiplier;
728 params.scalar.b_multiplier = b_multiplier;
729 params.scalar.remainder_mask = (int32_t) remainder_mask;
730 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
731 params.scalar.shift = shift;
732 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700733 params.scalar.y_min = (int32_t) (uint32_t) output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700734 params.scalar.y_max = (int32_t) (uint32_t) output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700735 #endif
736 return params;
737}
738
739static inline union xnn_q8_add_params xnn_init_scalar_q8_add_params(
740 uint8_t a_zero_point,
741 uint8_t b_zero_point,
742 uint8_t output_zero_point,
743 float a_output_scale,
744 float b_output_scale,
745 uint8_t output_min,
746 uint8_t output_max)
747{
748 assert(a_output_scale >= 0x1.0p-10f);
749 assert(b_output_scale >= 0x1.0p-10f);
750 assert(a_output_scale < 0x1.0p+8f);
751 assert(b_output_scale < 0x1.0p+8f);
752
753 // Compute requantization parameters.
754 const float max_output_scale = a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
755 assert(max_output_scale >= 0x1.0p-10f);
756 assert(max_output_scale < 0x1.0p+8f);
757 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
758 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
759 // Shift is in [13, 31] range.
760 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
761 assert(shift < 32);
762 assert(shift >= 13);
763
764 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
Marat Dukhanef3e7dc2020-04-13 01:19:56 -0700765 const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
766 const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700767 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
768 assert(a_multiplier < UINT32_C(0x00400000));
769 assert(b_multiplier < UINT32_C(0x00400000));
770
771 union xnn_q8_add_params params;
772 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
773 const uint32_t remainder_threshold = remainder_mask >> 1;
774 params.scalar.zero_point_product =
775 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
776 params.scalar.a_multiplier = a_multiplier;
777 params.scalar.b_multiplier = b_multiplier;
778 params.scalar.remainder_mask = (int32_t) remainder_mask;
779 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
780 params.scalar.shift = shift;
781 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700782 params.scalar.y_min = (int32_t) (uint32_t) output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700783 params.scalar.y_max = (int32_t) (uint32_t) output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700784 return params;
785}
786
787static inline union xnn_q31_requantization_params xnn_init_scalar_requantization_params(
788 float scale,
789 uint8_t zero_point,
790 uint8_t min,
791 uint8_t max)
792{
793 // Compute requantization parameters.
794 assert(scale < 1.0f);
795 assert(scale >= 0x1.0p-32f);
796 const uint32_t scale_bits = fp32_to_bits(scale);
797
798 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
799 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
800 assert(multiplier >= INT32_C(0x40000000));
801 assert(multiplier <= INT32_C(0x7FFFFF80));
802
803 // Shift is in [0, 31] range.
804 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
805 assert(shift >= 0);
806 assert(shift < 32);
807
808 union xnn_q31_requantization_params params;
809 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
810 const uint32_t remainder_threshold = remainder_mask >> 1;
811 params.scalar.multiplier = multiplier;
812 params.scalar.remainder_mask = (int32_t) remainder_mask;
813 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
814 params.scalar.shift = (uint32_t) shift;
815 params.scalar.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
816 params.scalar.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
817 params.scalar.zero_point = (int32_t) (uint32_t) zero_point;
818 return params;
819}
820
821static inline union xnn_q31_requantization_params xnn_init_requantization_params(
822 float scale,
823 uint8_t zero_point,
824 uint8_t min,
825 uint8_t max)
826{
827 // Compute requantization parameters.
828 const uint32_t scale_bits = fp32_to_bits(scale);
829
830 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
831 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
832 assert(multiplier >= INT32_C(0x40000000));
833 assert(multiplier <= INT32_C(0x7FFFFF80));
834
835 // Shift is in [0, 31] range.
836 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
837 assert(shift >= 0);
838 assert(shift < 32);
839
840 union xnn_q31_requantization_params params;
841 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
842 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
843 const uint32_t remainder_threshold = remainder_mask >> 1;
844 params.sse2.multiplier[0] = multiplier;
845 params.sse2.multiplier[1] = multiplier;
846 params.sse2.multiplier[2] = multiplier;
847 params.sse2.multiplier[3] = multiplier;
848 params.sse2.rounding[0] = UINT64_C(0x40000000);
849 params.sse2.rounding[1] = UINT64_C(0x40000000);
850 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
851 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
852 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
853 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
854 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
855 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
856 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
857 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
858 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
859 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
860 for (uint32_t i = 0; i < 8; i++) {
861 params.sse2.zero_point[i] = (int16_t) (uint16_t) zero_point;
862 }
863 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700864 params.sse2.min[i] = min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700865 params.sse2.max[i] = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700866 }
867 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
868 params.neon.multiplier = multiplier;
869 params.neon.right_shift = -shift;
870 params.neon.zero_point = (int16_t) (uint16_t) zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700871 params.neon.min = min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700872 params.neon.max = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700873 #else
874 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
875 const uint32_t remainder_threshold = remainder_mask >> 1;
876 params.scalar.multiplier = multiplier;
877 params.scalar.remainder_mask = (int32_t) remainder_mask;
878 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
879 params.scalar.shift = (uint32_t) shift;
880 params.scalar.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
881 params.scalar.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
882 params.scalar.zero_point = (int32_t) (uint32_t) zero_point;
883 #endif
884 return params;
885}