blob: d7db75c5cc2d741adbd95947f98b088e4a3aaf43 [file] [log] [blame]
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001// Copyright 2021 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack/math.h>
#include <xnnpack/params-init.h>
15
16
Marat Dukhanc698c112021-07-01 18:52:10 -070017void xnn_init_qu8_conv_minmax_gemmlowp_scalar_params(
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070018 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
19 uint8_t kernel_zero_point,
20 float scale,
21 uint8_t output_zero_point,
22 uint8_t output_min,
23 uint8_t output_max)
24{
25 // Compute requantization parameters
26 const uint32_t scale_bits = fp32_to_bits(scale);
27
28 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
Marat Dukhanc698c112021-07-01 18:52:10 -070029 const int32_t multiplier = (int32_t) (((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070030 assert(multiplier >= INT32_C(0x40000000));
31 assert(multiplier <= INT32_C(0x7FFFFF80));
32
33 // Shift is in [0, 31] range.
34 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
35 assert(shift >= 0);
36 assert(shift < 32);
37
38 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
39 const uint32_t remainder_threshold = remainder_mask >> 1;
40
Marat Dukhanc698c112021-07-01 18:52:10 -070041 params->gemmlowp_scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
42 params->gemmlowp_scalar.multiplier = multiplier;
43 params->gemmlowp_scalar.remainder_mask = (int32_t) remainder_mask;
44 params->gemmlowp_scalar.remainder_threshold = (int32_t) remainder_threshold;
45 params->gemmlowp_scalar.shift = (uint32_t) shift;
46 params->gemmlowp_scalar.output_min_less_zero_point =
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070047 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
Marat Dukhanc698c112021-07-01 18:52:10 -070048 params->gemmlowp_scalar.output_max_less_zero_point =
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070049 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
Marat Dukhanc698c112021-07-01 18:52:10 -070050 params->gemmlowp_scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070051}
52
53#if XNN_ARCH_X86 || XNN_ARCH_X86_64
Marat Dukhanc698c112021-07-01 18:52:10 -070054void xnn_init_qu8_conv_minmax_gemmlowp_sse2_params(
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070055 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
56 uint8_t kernel_zero_point,
57 float scale,
58 uint8_t output_zero_point,
59 uint8_t output_min,
60 uint8_t output_max)
61{
62 // Compute requantization parameters.
63 const uint32_t scale_bits = fp32_to_bits(scale);
64
65 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
Marat Dukhanc698c112021-07-01 18:52:10 -070066 const int32_t multiplier = (int32_t) (((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070067 assert(multiplier >= INT32_C(0x40000000));
68 assert(multiplier <= INT32_C(0x7FFFFF80));
69
70 // Shift is in [0, 31] range.
71 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
72 assert(shift >= 0);
73 assert(shift < 32);
74
75 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
76 const uint32_t remainder_threshold = remainder_mask >> 1;
77 for (uint32_t i = 0; i < 8; i++) {
Marat Dukhanc698c112021-07-01 18:52:10 -070078 params->gemmlowp_sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070079 }
Marat Dukhanc698c112021-07-01 18:52:10 -070080 params->gemmlowp_sse2.multiplier[0] = multiplier;
81 params->gemmlowp_sse2.multiplier[1] = multiplier;
82 params->gemmlowp_sse2.multiplier[2] = multiplier;
83 params->gemmlowp_sse2.multiplier[3] = multiplier;
84 params->gemmlowp_sse2.rounding[0] = UINT64_C(0x40000000);
85 params->gemmlowp_sse2.rounding[1] = UINT64_C(0x40000000);
86 params->gemmlowp_sse2.remainder_mask[0] = (int32_t) remainder_mask;
87 params->gemmlowp_sse2.remainder_mask[1] = (int32_t) remainder_mask;
88 params->gemmlowp_sse2.remainder_mask[2] = (int32_t) remainder_mask;
89 params->gemmlowp_sse2.remainder_mask[3] = (int32_t) remainder_mask;
90 params->gemmlowp_sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
91 params->gemmlowp_sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
92 params->gemmlowp_sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
93 params->gemmlowp_sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
94 params->gemmlowp_sse2.shift[0] = (uint64_t) (uint32_t) shift;
95 params->gemmlowp_sse2.shift[1] = (uint64_t) (uint32_t) shift;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070096 for (uint32_t i = 0; i < 8; i++) {
Marat Dukhanc698c112021-07-01 18:52:10 -070097 params->gemmlowp_sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070098 }
99 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhanc698c112021-07-01 18:52:10 -0700100 params->gemmlowp_sse2.output_min[i] = output_min;
101 params->gemmlowp_sse2.output_max[i] = output_max;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700102 }
103}
Marat Dukhanef47f8d2021-07-02 15:08:32 -0700104
105void xnn_init_qu8_conv_minmax_fp32_sse2_params(
106 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
107 uint8_t kernel_zero_point,
108 float scale,
109 uint8_t output_zero_point,
110 uint8_t output_min,
111 uint8_t output_max)
112{
113 for (uint32_t i = 0; i < 4; i++) {
114 params->fp32_sse2.scale[i] = scale;
115 }
116 for (uint32_t i = 0; i < 8; i++) {
117 params->fp32_sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
118 params->fp32_sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
119 }
120 for (uint32_t i = 0; i < 16; i++) {
121 params->fp32_sse2.output_min[i] = output_min;
122 params->fp32_sse2.output_max[i] = output_max;
123 }
124}
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700125#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
126
127#if XNN_ARCH_ARM || XNN_ARCH_ARM64
Marat Dukhanc698c112021-07-01 18:52:10 -0700128void xnn_init_qu8_conv_minmax_gemmlowp_neon_params(
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700129 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
130 uint8_t kernel_zero_point,
131 float scale,
132 uint8_t output_zero_point,
133 uint8_t output_min,
134 uint8_t output_max)
135{
136 // Compute requantization parameters.
137 const uint32_t scale_bits = fp32_to_bits(scale);
138
139 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
Marat Dukhanc698c112021-07-01 18:52:10 -0700140 const int32_t multiplier = (int32_t) (((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700141 assert(multiplier >= INT32_C(0x40000000));
142 assert(multiplier <= INT32_C(0x7FFFFF80));
143
144 // Shift is in [0, 31] range.
145 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
146 assert(shift >= 0);
147 assert(shift < 32);
148
Marat Dukhanc698c112021-07-01 18:52:10 -0700149 params->gemmlowp_neon.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
150 params->gemmlowp_neon.multiplier = multiplier;
151 params->gemmlowp_neon.right_shift = -shift;
152 params->gemmlowp_neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
153 params->gemmlowp_neon.output_min = output_min;
154 params->gemmlowp_neon.output_max = output_max;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700155}
156#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
157
158void xnn_init_qs8_conv_minmax_gemmlowp_scalar_params(
159 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
160 float scale,
161 int8_t output_zero_point,
162 int8_t output_min,
163 int8_t output_max)
164{
165 // Compute requantization parameters
166 const uint32_t scale_bits = fp32_to_bits(scale);
167
168 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
169 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
170 assert(multiplier >= INT32_C(0x40000000));
171 assert(multiplier <= INT32_C(0x7FFFFF80));
172
173 // Shift is in [0, 31] range.
174 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
175 assert(shift >= 0);
176 assert(shift < 32);
177
178 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
179 const uint32_t remainder_threshold = remainder_mask >> 1;
180
181 params->gemmlowp_scalar.multiplier = multiplier;
182 params->gemmlowp_scalar.remainder_mask = (int32_t) remainder_mask;
183 params->gemmlowp_scalar.remainder_threshold = (int32_t) remainder_threshold;
184 params->gemmlowp_scalar.shift = (uint32_t) shift;
185 params->gemmlowp_scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
186 params->gemmlowp_scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
187 params->gemmlowp_scalar.output_zero_point = (int32_t) output_zero_point;
188}
189
190void xnn_init_qs8_conv_minmax_fp32_scalar_lrint_params(
191 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
192 float scale,
193 int8_t output_zero_point,
194 int8_t output_min,
195 int8_t output_max)
196{
197 params->fp32_scalar_lrint.scale = scale;
198 params->fp32_scalar_lrint.output_min_less_zero_point = (long) ((int32_t) output_min - (int32_t) output_zero_point);
199 params->fp32_scalar_lrint.output_max_less_zero_point = (long) ((int32_t) output_max - (int32_t) output_zero_point);
200 params->fp32_scalar_lrint.output_zero_point = (int32_t) output_zero_point;
201}
202
203void xnn_init_qs8_conv_minmax_fp32_scalar_magic_params(
204 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
205 float scale,
206 int8_t output_zero_point,
207 int8_t output_min,
208 int8_t output_max)
209{
210 params->fp32_scalar_magic.scale = scale;
211 params->fp32_scalar_magic.output_min_less_zero_point = (float) ((int32_t) output_min - (int32_t) output_zero_point);
212 params->fp32_scalar_magic.output_max_less_zero_point = (float) ((int32_t) output_max - (int32_t) output_zero_point);
213 params->fp32_scalar_magic.magic_bias = 12582912.0f;
214 params->fp32_scalar_magic.magic_bias_less_output_zero_point = INT32_C(0x4B400000) - (int32_t) output_zero_point;
215}
216
217#if XNN_ARCH_X86 || XNN_ARCH_X86_64
218void xnn_init_qs8_conv_minmax_gemmlowp_sse2_params(
219 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
220 float scale,
221 int8_t output_zero_point,
222 int8_t output_min,
223 int8_t output_max)
224{
225 // Compute requantization parameters.
226 const uint32_t scale_bits = fp32_to_bits(scale);
227
228 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
229 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
230 assert(multiplier >= INT32_C(0x40000000));
231 assert(multiplier <= INT32_C(0x7FFFFF80));
232
233 // Shift is in [0, 31] range.
234 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
235 assert(shift >= 0);
236 assert(shift < 32);
237
238 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
239 const uint32_t remainder_threshold = remainder_mask >> 1;
240 params->gemmlowp_sse2.multiplier[0] = multiplier;
241 params->gemmlowp_sse2.multiplier[1] = multiplier;
242 params->gemmlowp_sse2.multiplier[2] = multiplier;
243 params->gemmlowp_sse2.multiplier[3] = multiplier;
244 params->gemmlowp_sse2.rounding[0] = UINT64_C(0x40000000);
245 params->gemmlowp_sse2.rounding[1] = UINT64_C(0x40000000);
246 params->gemmlowp_sse2.remainder_mask[0] = (int32_t) remainder_mask;
247 params->gemmlowp_sse2.remainder_mask[1] = (int32_t) remainder_mask;
248 params->gemmlowp_sse2.remainder_mask[2] = (int32_t) remainder_mask;
249 params->gemmlowp_sse2.remainder_mask[3] = (int32_t) remainder_mask;
250 params->gemmlowp_sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
251 params->gemmlowp_sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
252 params->gemmlowp_sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
253 params->gemmlowp_sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
254 params->gemmlowp_sse2.shift[0] = (uint64_t) (uint32_t) shift;
255 params->gemmlowp_sse2.shift[1] = (uint64_t) (uint32_t) shift;
256 for (uint32_t i = 0; i < 8; i++) {
257 params->gemmlowp_sse2.output_zero_point[i] = (int16_t) output_zero_point;
258 params->gemmlowp_sse2.output_min[i] = (int16_t) output_min;
259 params->gemmlowp_sse2.output_max[i] = (int16_t) output_max;
260 }
261}
262
263void xnn_init_qs8_conv_minmax_gemmlowp_sse4_params(
264 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
265 float scale,
266 int8_t output_zero_point,
267 int8_t output_min,
268 int8_t output_max)
269{
270 // Compute requantization parameters.
271 const uint32_t scale_bits = fp32_to_bits(scale);
272
273 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
274 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
275 assert(multiplier >= INT32_C(0x40000000));
276 assert(multiplier <= INT32_C(0x7FFFFF80));
277
278 // Shift is in [0, 31] range.
279 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
280 assert(shift >= 0);
281 assert(shift < 32);
282
283 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
284 const uint32_t remainder_threshold = remainder_mask >> 1;
285 params->gemmlowp_sse4.multiplier[0] = multiplier;
286 params->gemmlowp_sse4.multiplier[1] = multiplier;
287 params->gemmlowp_sse4.multiplier[2] = multiplier;
288 params->gemmlowp_sse4.multiplier[3] = multiplier;
289 params->gemmlowp_sse4.rounding[0] = UINT64_C(0x40000000);
290 params->gemmlowp_sse4.rounding[1] = UINT64_C(0x40000000);
291 params->gemmlowp_sse4.remainder_mask[0] = (int32_t) remainder_mask;
292 params->gemmlowp_sse4.remainder_mask[1] = (int32_t) remainder_mask;
293 params->gemmlowp_sse4.remainder_mask[2] = (int32_t) remainder_mask;
294 params->gemmlowp_sse4.remainder_mask[3] = (int32_t) remainder_mask;
295 params->gemmlowp_sse4.remainder_threshold[0] = (int32_t) remainder_threshold;
296 params->gemmlowp_sse4.remainder_threshold[1] = (int32_t) remainder_threshold;
297 params->gemmlowp_sse4.remainder_threshold[2] = (int32_t) remainder_threshold;
298 params->gemmlowp_sse4.remainder_threshold[3] = (int32_t) remainder_threshold;
299 params->gemmlowp_sse4.shift[0] = (uint64_t) (uint32_t) shift;
300 params->gemmlowp_sse4.shift[1] = (uint64_t) (uint32_t) shift;
301 for (uint32_t i = 0; i < 8; i++) {
302 params->gemmlowp_sse4.output_zero_point[i] = (int16_t) output_zero_point;
303 }
304 for (uint32_t i = 0; i < 16; i++) {
305 params->gemmlowp_sse4.output_min[i] = output_min;
306 params->gemmlowp_sse4.output_max[i] = output_max;
307 }
308}
309
310void xnn_init_qs8_conv_minmax_gemmlowp_avx2_params(
311 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
312 float scale,
313 int8_t output_zero_point,
314 int8_t output_min,
315 int8_t output_max)
316{
317 // Compute requantization parameters.
318 const uint32_t scale_bits = fp32_to_bits(scale);
319
320 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
321 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
322 assert(multiplier >= INT32_C(0x40000000));
323 assert(multiplier <= INT32_C(0x7FFFFF80));
324
325 // Shift is in [0, 31] range.
326 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
327 assert(shift >= 0);
328 assert(shift < 32);
329
330 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
331 const uint32_t remainder_threshold = remainder_mask >> 1;
332 for (uint32_t i = 0; i < 8; i++) {
333 params->gemmlowp_avx2.multiplier[i] = multiplier;
334 }
335 params->gemmlowp_avx2.rounding[0] = UINT64_C(0x40000000);
336 params->gemmlowp_avx2.rounding[1] = UINT64_C(0x40000000);
337 params->gemmlowp_avx2.rounding[2] = UINT64_C(0x40000000);
338 params->gemmlowp_avx2.rounding[3] = UINT64_C(0x40000000);
339 for (uint32_t i = 0; i < 8; i++) {
340 params->gemmlowp_avx2.remainder_mask[i] = (int32_t) remainder_mask;
341 params->gemmlowp_avx2.remainder_threshold[i] = (int32_t) remainder_threshold;
342 }
343 params->gemmlowp_avx2.shift[0] = (uint64_t) (uint32_t) shift;
344 params->gemmlowp_avx2.shift[1] = (uint64_t) (uint32_t) shift;
345 params->gemmlowp_avx2.shift[2] = (uint64_t) (uint32_t) shift;
346 params->gemmlowp_avx2.shift[3] = (uint64_t) (uint32_t) shift;
347 for (uint32_t i = 0; i < 16; i++) {
348 params->gemmlowp_avx2.output_zero_point[i] = (int16_t) output_zero_point;
349 }
350 for (uint32_t i = 0; i < 32; i++) {
351 params->gemmlowp_avx2.output_min[i] = output_min;
352 params->gemmlowp_avx2.output_max[i] = output_max;
353 }
354}
355
356void xnn_init_qs8_conv_minmax_gemmlowp_avx512_params(
357 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
358 float scale,
359 int8_t output_zero_point,
360 int8_t output_min,
361 int8_t output_max)
362{
363 // Compute requantization parameters.
364 const uint32_t scale_bits = fp32_to_bits(scale);
365
366 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
367 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
368 assert(multiplier >= INT32_C(0x40000000));
369 assert(multiplier <= INT32_C(0x7FFFFF80));
370
371 // Shift is in [0, 31] range.
372 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
373 assert(shift >= 0);
374 assert(shift < 32);
375
376 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
377 const uint32_t remainder_threshold = remainder_mask >> 1;
378 params->gemmlowp_avx512.multiplier = (int64_t) multiplier;
379 params->gemmlowp_avx512.rounding = UINT64_C(0x40000000);
380 params->gemmlowp_avx512.remainder_mask = (int32_t) remainder_mask;
381 params->gemmlowp_avx512.remainder_threshold = (int32_t) remainder_threshold;
382 params->gemmlowp_avx512.shift = (uint64_t) (uint32_t) shift;
383 for (uint32_t i = 0; i < 32; i++) {
384 params->gemmlowp_avx512.output_zero_point[i] = (int16_t) output_zero_point;
385 }
386 for (uint32_t i = 0; i < 64; i++) {
387 params->gemmlowp_avx512.output_min[i] = output_min;
388 params->gemmlowp_avx512.output_max[i] = output_max;
389 }
390}
391
392void xnn_init_qs8_conv_minmax_fp32_sse2_params(
393 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
394 float scale,
395 int8_t output_zero_point,
396 int8_t output_min,
397 int8_t output_max)
398{
399 for (uint32_t i = 0; i < 4; i++) {
400 params->fp32_sse2.scale[i] = scale;
401 }
402 for (uint32_t i = 0; i < 8; i++) {
403 params->fp32_sse2.output_zero_point[i] = (int16_t) output_zero_point;
404 params->fp32_sse2.output_min[i] = (int16_t) output_min;
405 params->fp32_sse2.output_max[i] = (int16_t) output_max;
406 }
407}
408
409void xnn_init_qs8_conv_minmax_fp32_sse4_params(
410 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
411 float scale,
412 int8_t output_zero_point,
413 int8_t output_min,
414 int8_t output_max)
415{
416 for (uint32_t i = 0; i < 4; i++) {
417 params->fp32_sse4.scale[i] = scale;
418 }
419 for (uint32_t i = 0; i < 8; i++) {
420 params->fp32_sse4.output_zero_point[i] = (int16_t) output_zero_point;
421 }
422 for (uint32_t i = 0; i < 16; i++) {
423 params->fp32_sse4.output_min[i] = output_min;
424 params->fp32_sse4.output_max[i] = output_max;
425 }
426}
427
428void xnn_init_qs8_conv_minmax_fp32_avx2_params(
429 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
430 float scale,
431 int8_t output_zero_point,
432 int8_t output_min,
433 int8_t output_max)
434{
435 for (uint32_t i = 0; i < 8; i++) {
436 params->fp32_avx2.scale[i] = scale;
437 }
438 for (uint32_t i = 0; i < 16; i++) {
439 params->fp32_avx2.output_zero_point[i] = (int16_t) output_zero_point;
440 }
441 for (uint32_t i = 0; i < 32; i++) {
442 params->fp32_avx2.output_min[i] = output_min;
443 params->fp32_avx2.output_max[i] = output_max;
444 }
445}
446
447void xnn_init_qs8_conv_minmax_fp32_avx512_params(
448 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
449 float scale,
450 int8_t output_zero_point,
451 int8_t output_min,
452 int8_t output_max)
453{
454 for (uint32_t i = 0; i < 16; i++) {
455 params->fp32_avx512.scale[i] = scale;
456 }
457 for (uint32_t i = 0; i < 32; i++) {
458 params->fp32_avx512.output_zero_point[i] = (int16_t) output_zero_point;
459 }
460 for (uint32_t i = 0; i < 64; i++) {
461 params->fp32_avx512.output_min[i] = output_min;
462 params->fp32_avx512.output_max[i] = output_max;
463 }
464}
465#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
466
467#if XNN_ARCH_ARM || XNN_ARCH_ARM64
468void xnn_init_qs8_conv_minmax_gemmlowp_neon_params(
469 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
470 float scale,
471 int8_t output_zero_point,
472 int8_t output_min,
473 int8_t output_max)
474{
475 // Compute requantization parameters.
476 const uint32_t scale_bits = fp32_to_bits(scale);
477
478 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
479 const int32_t multiplier = (int32_t) (((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
480 assert(multiplier >= INT32_C(0x40000000));
481 assert(multiplier <= INT32_C(0x7FFFFF80));
482
483 // Shift is in [0, 31] range.
484 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
485 assert(shift >= 0);
486 assert(shift < 32);
487
488 params->gemmlowp_neon.multiplier = multiplier;
489 params->gemmlowp_neon.right_shift = -shift;
490 params->gemmlowp_neon.output_zero_point = (int16_t) output_zero_point;
491 params->gemmlowp_neon.output_min = output_min;
492 params->gemmlowp_neon.output_max = output_max;
493}
494
495void xnn_init_qs8_conv_minmax_fp32_neon_params(
496 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
497 float scale,
498 int8_t output_zero_point,
499 int8_t output_min,
500 int8_t output_max)
501{
502 params->fp32_neon.scale = scale;
503 params->fp32_neon.output_min_less_zero_point = (float) ((int32_t) output_min - (int32_t) output_zero_point);
504 params->fp32_neon.output_max_less_zero_point = (float) ((int32_t) output_max - (int32_t) output_zero_point);
505 params->fp32_neon.magic_bias = 12582912.0f;
506 params->fp32_neon.magic_bias_less_zero_point = INT32_C(0x4B400000) - (int32_t) output_zero_point;
507}
508
509void xnn_init_qs8_conv_minmax_fp32_neonv8_params(
510 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
511 float scale,
512 int8_t output_zero_point,
513 int8_t output_min,
514 int8_t output_max)
515{
516 params->fp32_neonv8.scale = scale;
517 params->fp32_neonv8.output_zero_point = (int16_t) output_zero_point;
518 params->fp32_neonv8.output_min = output_min;
519 params->fp32_neonv8.output_max = output_max;
520}
521#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
522
523#if XNN_ARCH_WASMSIMD
524void xnn_init_qs8_conv_minmax_gemmlowp_wasmsimd_params(
525 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
526 float scale,
527 int8_t output_zero_point,
528 int8_t output_min,
529 int8_t output_max)
530{
531 // Compute requantization parameters.
532 const uint32_t scale_bits = fp32_to_bits(scale);
533
534 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
535 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
536 assert(multiplier >= INT32_C(0x40000000));
537 assert(multiplier <= INT32_C(0x7FFFFF80));
538
539 // Shift is in [0, 31] range.
540 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
541 assert(shift >= 0);
542 assert(shift < 32);
543
544 const int64_t twice_multiplier = INT64_C(2) * (int64_t) multiplier;
545 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
546 const uint32_t remainder_threshold = remainder_mask >> 1;
547 params->gemmlowp_wasmsimd.multiplier[0] = twice_multiplier;
548 params->gemmlowp_wasmsimd.multiplier[1] = twice_multiplier;
549 params->gemmlowp_wasmsimd.rounding[0] = INT64_C(0x80000000);
550 params->gemmlowp_wasmsimd.rounding[1] = INT64_C(0x80000000);
551 params->gemmlowp_wasmsimd.remainder_mask[0] = (int32_t) remainder_mask;
552 params->gemmlowp_wasmsimd.remainder_mask[1] = (int32_t) remainder_mask;
553 params->gemmlowp_wasmsimd.remainder_mask[2] = (int32_t) remainder_mask;
554 params->gemmlowp_wasmsimd.remainder_mask[3] = (int32_t) remainder_mask;
555 params->gemmlowp_wasmsimd.remainder_threshold[0] = (int32_t) remainder_threshold;
556 params->gemmlowp_wasmsimd.remainder_threshold[1] = (int32_t) remainder_threshold;
557 params->gemmlowp_wasmsimd.remainder_threshold[2] = (int32_t) remainder_threshold;
558 params->gemmlowp_wasmsimd.remainder_threshold[3] = (int32_t) remainder_threshold;
559 params->gemmlowp_wasmsimd.shift = shift;
560 for (uint32_t i = 0; i < 8; i++) {
561 params->gemmlowp_wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
562 }
563 for (uint32_t i = 0; i < 16; i++) {
564 params->gemmlowp_wasmsimd.output_min[i] = output_min;
565 params->gemmlowp_wasmsimd.output_max[i] = output_max;
566 }
567}
Marat Dukhan4741e412021-06-30 13:38:06 -0700568
569void xnn_init_qs8_conv_minmax_fp32_wasmsimd_params(
570 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
571 float scale,
572 int8_t output_zero_point,
573 int8_t output_min,
574 int8_t output_max)
575{
576 for (uint32_t i = 0; i < 4; i++) {
577 params->fp32_wasmsimd.scale[i] = scale;
578 params->fp32_wasmsimd.output_min_less_zero_point[i] = (float) ((int32_t) output_min - (int32_t) output_zero_point);
579 params->fp32_wasmsimd.output_max_less_zero_point[i] = (float) ((int32_t) output_max - (int32_t) output_zero_point);
580 params->fp32_wasmsimd.magic_bias[i] = 12582912.0f;
581 params->fp32_wasmsimd.magic_bias_less_output_zero_point[i] = INT32_C(0x4B400000) - (int32_t) output_zero_point;
582 }
583}
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700584#endif // XNN_ARCH_WASMSIMD
585
// Copies per-channel float scales into a packed buffer, one tile of channels at a time.
//
// For each tile of up to channels_tile consecutive channels, the tile's scales are written
// contiguously at the current packed_w position, after which packed_w advances by stride bytes.
// The last tile may be shorter than channels_tile.
//
//   channels      - total number of scale values to copy.
//   channels_tile - number of channels per tile; must be non-zero when channels is non-zero.
//   stride        - byte distance between the starts of consecutive tiles in packed_w.
//   scale         - source array of at least `channels` floats; must not overlap packed_w.
//   packed_w      - destination buffer. No float alignment is required.
void xnn_init_qc8_scale_fp32_params(
  size_t channels,
  size_t channels_tile,
  size_t stride,
  const float scale[XNN_MIN_ELEMENTS(1)],
  void* packed_w)
{
  // A zero tile width would never advance tile_start and loop forever.
  assert(channels == 0 || channels_tile != 0);

  for (size_t tile_start = 0; tile_start < channels; tile_start += channels_tile) {
    const size_t tile_size = min(channels - tile_start, channels_tile);
    // packed_w moves by an arbitrary byte stride, so it may not be aligned for float;
    // memcpy avoids a misaligned store through a float pointer.
    memcpy(packed_w, &scale[tile_start], tile_size * sizeof(float));
    packed_w = (void*) ((uintptr_t) packed_w + stride);
  }
}
601
Marat Dukhand6021542021-06-30 09:04:20 -0700602void xnn_init_qs8_minmax_scalar_lrint_params(
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700603 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
604 int8_t output_zero_point,
605 int8_t output_min,
606 int8_t output_max)
607{
Marat Dukhand6021542021-06-30 09:04:20 -0700608 params->scalar_lrint.output_min_less_zero_point = (long) ((int32_t) output_min - (int32_t) output_zero_point);
609 params->scalar_lrint.output_max_less_zero_point = (long) ((int32_t) output_max - (int32_t) output_zero_point);
610 params->scalar_lrint.output_zero_point = (int32_t) output_zero_point;
611}
612
613void xnn_init_qs8_minmax_scalar_magic_params(
614 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
615 int8_t output_zero_point,
616 int8_t output_min,
617 int8_t output_max)
618{
619 params->scalar_magic.output_min_less_zero_point = (float) ((int32_t) output_min - (int32_t) output_zero_point);
620 params->scalar_magic.output_max_less_zero_point = (float) ((int32_t) output_max - (int32_t) output_zero_point);
621 params->scalar_magic.magic_bias = 12582912.0f;
622 params->scalar_magic.magic_bias_less_output_zero_point = INT32_C(0x4B400000) - (int32_t) output_zero_point;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700623}
624
625#if XNN_ARCH_X86 || XNN_ARCH_X86_64
626void xnn_init_qs8_minmax_sse2_params(
627 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
628 int8_t output_zero_point,
629 int8_t output_min,
630 int8_t output_max)
631{
632 for (uint32_t i = 0; i < 8; i++) {
633 params->sse2.output_zero_point[i] = (int16_t) output_zero_point;
634 params->sse2.output_min[i] = (int16_t) output_min;
635 params->sse2.output_max[i] = (int16_t) output_max;
636 }
637}
638
639void xnn_init_qs8_minmax_sse4_params(
640 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
641 int8_t output_zero_point,
642 int8_t output_min,
643 int8_t output_max)
644{
645 for (uint32_t i = 0; i < 8; i++) {
646 params->sse4.output_zero_point[i] = (int16_t) output_zero_point;
647 }
648 for (uint32_t i = 0; i < 16; i++) {
649 params->sse4.output_min[i] = output_min;
650 params->sse4.output_max[i] = output_max;
651 }
652}
653
654void xnn_init_qs8_minmax_avx2_params(
655 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
656 int8_t output_zero_point,
657 int8_t output_min,
658 int8_t output_max)
659{
660 for (uint32_t i = 0; i < 16; i++) {
661 params->avx2.output_zero_point[i] = (int16_t) output_zero_point;
662 }
663 for (uint32_t i = 0; i < 32; i++) {
664 params->avx2.output_min[i] = output_min;
665 params->avx2.output_max[i] = output_max;
666 }
667}
668
669void xnn_init_qs8_minmax_avx512_params(
670 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
671 int8_t output_zero_point,
672 int8_t output_min,
673 int8_t output_max)
674{
675 for (uint32_t i = 0; i < 32; i++) {
676 params->avx512.output_zero_point[i] = (int16_t) output_zero_point;
677 }
678 for (uint32_t i = 0; i < 64; i++) {
679 params->avx512.output_min[i] = output_min;
680 params->avx512.output_max[i] = output_max;
681 }
682}
683#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
684
685#if XNN_ARCH_ARM || XNN_ARCH_ARM64
686void xnn_init_qs8_minmax_neon_params(
687 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
688 int8_t output_zero_point,
689 int8_t output_min,
690 int8_t output_max)
691{
692 params->neon.output_zero_point = (int16_t) output_zero_point;
693 params->neon.output_min = output_min;
694 params->neon.output_max = output_max;
695}
696
697void xnn_init_qs8_minmax_neon_fp32_params(
698 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
699 int8_t output_zero_point,
700 int8_t output_min,
701 int8_t output_max)
702{
703 params->neon_fp32.output_min_less_zero_point = (float) ((int32_t) output_min - (int32_t) output_zero_point);
704 params->neon_fp32.output_max_less_zero_point = (float) ((int32_t) output_max - (int32_t) output_zero_point);
705 params->neon_fp32.magic_bias = 12582912.0f;
706 params->neon_fp32.magic_bias_less_zero_point = INT32_C(0x4B400000) - (int32_t) output_zero_point;
707}
708#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
709
710#if XNN_ARCH_WASMSIMD
711void xnn_init_qs8_minmax_wasmsimd_params(
712 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
713 int8_t output_zero_point,
714 int8_t output_min,
715 int8_t output_max)
716{
Marat Dukhan47c12202021-06-30 15:09:34 -0700717 for (uint32_t i = 0; i < 4; i++) {
718 params->wasmsimd.output_min_less_zero_point[i] = (float) ((int32_t) output_min - (int32_t) output_zero_point);
719 params->wasmsimd.output_max_less_zero_point[i] = (float) ((int32_t) output_max - (int32_t) output_zero_point);
720 params->wasmsimd.magic_bias[i] = 12582912.0f;
721 params->wasmsimd.magic_bias_less_output_zero_point[i] = INT32_C(0x4B400000) - (int32_t) output_zero_point;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700722 }
723}
724#endif // XNN_ARCH_WASMSIMD
725
726void xnn_init_qu8_avgpool_params(
727 union xnn_qu8_avgpool_params params[XNN_MIN_ELEMENTS(1)],
728 int32_t bias,
729 float scale,
730 uint8_t output_zero_point,
731 uint8_t output_min,
732 uint8_t output_max)
733{
734 // Compute requantization parameters.
735 assert(scale >= 0x1.0p-32f);
736 assert(scale < 256.0f);
737 const uint32_t scale_bits = fp32_to_bits(scale);
738
739 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
740 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
741 assert(multiplier >= INT32_C(0x00800000));
742 assert(multiplier <= INT32_C(0x00FFFFFF));
743
744 // Shift is in [16, 55] range.
745 const int32_t shift = 127 + 23 - (scale_bits >> 23);
746 assert(shift >= 16);
747 assert(shift < 64);
748
749 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
750 const uint32_t right_shift = (uint32_t) shift;
751 const uint64_t rounding = UINT64_C(1) << (right_shift - 1);
752 params->sse2.bias[0] = bias;
753 params->sse2.bias[1] = bias;
754 params->sse2.bias[2] = bias;
755 params->sse2.bias[3] = bias;
756 params->sse2.multiplier[0] = (uint32_t) multiplier;
757 params->sse2.multiplier[1] = (uint32_t) multiplier;
758 params->sse2.multiplier[2] = (uint32_t) multiplier;
759 params->sse2.multiplier[3] = (uint32_t) multiplier;
760 params->sse2.rounding[0] = rounding;
761 params->sse2.rounding[1] = rounding;
762 params->sse2.right_shift[0] = (uint64_t) right_shift;
763 params->sse2.right_shift[1] = (uint64_t) right_shift;
764 for (uint32_t i = 0; i < 8; i++) {
765 params->sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
766 }
767 for (uint32_t i = 0; i < 16; i++) {
768 params->sse2.output_min[i] = output_min;
769 params->sse2.output_max[i] = output_max;
770 }
771 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
772 params->neon.bias = bias;
773 params->neon.multiplier = multiplier;
774 params->neon.left_shift = (int64_t) -shift;
775 params->neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
776 params->neon.output_min = output_min;
777 params->neon.output_max = output_max;
778 #else
779 const uint32_t right_shift = (uint32_t) shift;
780 const int64_t rounding = INT64_C(1) << (right_shift - 1);
781 params->scalar.bias = bias;
782 params->scalar.multiplier = multiplier;
783 params->scalar.rounding = rounding;
784 params->scalar.right_shift = right_shift;
785 params->scalar.output_min_less_zero_point =
786 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
787 params->scalar.output_max_less_zero_point =
788 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
789 params->scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
790 #endif
791}
792
793void xnn_init_scalar_qu8_avgpool_params(
794 union xnn_qu8_avgpool_params params[XNN_MIN_ELEMENTS(1)],
795 int32_t bias,
796 float scale,
797 uint8_t output_zero_point,
798 uint8_t output_min,
799 uint8_t output_max)
800{
801 // Compute requantization parameters.
802 assert(scale >= 0x1.0p-32f);
803 assert(scale < 256.0f);
804 const uint32_t scale_bits = fp32_to_bits(scale);
805
806 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
807 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
808 assert(multiplier >= INT32_C(0x00800000));
809 assert(multiplier <= INT32_C(0x00FFFFFF));
810
811 // Shift is in [16, 55] range.
812 const int32_t shift = 127 + 23 - (scale_bits >> 23);
813 assert(shift >= 16);
814 assert(shift < 64);
815
816 const uint32_t right_shift = (uint32_t) shift;
817 const int64_t rounding = INT64_C(1) << (right_shift - 1);
818 params->scalar.bias = bias;
819 params->scalar.rounding = rounding;
820 params->scalar.multiplier = multiplier;
821 params->scalar.right_shift = right_shift;
822 params->scalar.output_min_less_zero_point =
823 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
824 params->scalar.output_max_less_zero_point =
825 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
826 params->scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
827}
828
829void xnn_update_qu8_avgpool_params(
830 union xnn_qu8_avgpool_params* params,
831 int32_t bias,
832 float scale)
833{
834 // Compute requantization parameters.
835 assert(scale >= 0x1.0p-32f);
836 assert(scale < 256.0f);
837 const uint32_t scale_bits = fp32_to_bits(scale);
838
839 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
840 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
841 assert(multiplier >= INT32_C(0x00800000));
842 assert(multiplier <= INT32_C(0x00FFFFFF));
843
844 // Shift is in [16, 55] range.
845 const int32_t shift = 127 + 23 - (scale_bits >> 23);
846 assert(shift >= 16);
847 assert(shift < 64);
848
849 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
850 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
851 params->sse2.bias[0] = bias;
852 params->sse2.bias[1] = bias;
853 params->sse2.bias[2] = bias;
854 params->sse2.bias[3] = bias;
855 params->sse2.multiplier[0] = (uint32_t) multiplier;
856 params->sse2.multiplier[1] = (uint32_t) multiplier;
857 params->sse2.multiplier[2] = (uint32_t) multiplier;
858 params->sse2.multiplier[3] = (uint32_t) multiplier;
859 params->sse2.rounding[0] = rounding;
860 params->sse2.rounding[1] = rounding;
861 params->sse2.right_shift[0] = (uint64_t) (uint32_t) shift;
862 params->sse2.right_shift[1] = (uint64_t) (uint32_t) shift;
863 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
864 params->neon.bias = bias;
865 params->neon.multiplier = multiplier;
866 params->neon.left_shift = (int64_t) -shift;
867 #else
868 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
869 params->scalar.bias = bias;
870 params->scalar.multiplier = multiplier;
871 params->scalar.rounding = rounding;
872 params->scalar.right_shift = (uint32_t) shift;
873 #endif
874}
875
876void xnn_init_qs8_avgpool_params(
877 union xnn_qs8_avgpool_params params[XNN_MIN_ELEMENTS(1)],
878 int32_t bias,
879 float scale,
880 int8_t output_zero_point,
881 int8_t output_min,
882 int8_t output_max)
883{
884 // Compute requantization parameters.
885 assert(scale >= 0x1.0p-32f);
886 assert(scale < 256.0f);
887 const uint32_t scale_bits = fp32_to_bits(scale);
888
889 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
890 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
891 assert(multiplier >= INT32_C(0x00800000));
892 assert(multiplier <= INT32_C(0x00FFFFFF));
893
894 // Shift is in [16, 55] range.
895 const int32_t shift = 127 + 23 - (scale_bits >> 23);
896 assert(shift >= 16);
897 assert(shift < 64);
898
899 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
900 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
901 params->sse2.bias[0] = bias;
902 params->sse2.bias[1] = bias;
903 params->sse2.bias[2] = bias;
904 params->sse2.bias[3] = bias;
905 params->sse2.multiplier[0] = (uint32_t) multiplier;
906 params->sse2.multiplier[1] = (uint32_t) multiplier;
907 params->sse2.multiplier[2] = (uint32_t) multiplier;
908 params->sse2.multiplier[3] = (uint32_t) multiplier;
909 params->sse2.rounding[0] = rounding;
910 params->sse2.rounding[1] = rounding;
911 params->sse2.shift[0] = (uint64_t) (uint32_t) shift;
912 params->sse2.shift[1] = (uint64_t) (uint32_t) shift;
913 for (uint32_t i = 0; i < 8; i++) {
914 params->sse2.output_zero_point[i] = (int16_t) output_zero_point;
915 params->sse2.output_min[i] = (int16_t) output_min;
916 params->sse2.output_max[i] = (int16_t) output_max;
917 }
918 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
919 params->neon.bias = bias;
920 params->neon.multiplier = multiplier;
921 params->neon.left_shift = (int64_t) -shift;
922 params->neon.output_zero_point = (int16_t) output_zero_point;
923 params->neon.output_min = output_min;
924 params->neon.output_max = output_max;
925 #elif XNN_ARCH_WASMSIMD
926 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
927 params->wasmsimd.bias[0] = bias;
928 params->wasmsimd.bias[1] = bias;
929 params->wasmsimd.bias[2] = bias;
930 params->wasmsimd.bias[3] = bias;
931 params->wasmsimd.multiplier[0] = (int64_t) multiplier;
932 params->wasmsimd.multiplier[1] = (int64_t) multiplier;
933 params->wasmsimd.rounding[0] = rounding;
934 params->wasmsimd.rounding[1] = rounding;
935 params->wasmsimd.shift = shift;
936 for (uint32_t i = 0; i < 8; i++) {
937 params->wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
938 }
939 for (uint32_t i = 0; i < 16; i++) {
940 params->wasmsimd.output_min[i] = output_min;
941 params->wasmsimd.output_max[i] = output_max;
942 }
943 #else
944 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
945 params->scalar.bias = bias;
946 params->scalar.multiplier = multiplier;
947 params->scalar.rounding = rounding;
948 params->scalar.shift = (uint32_t) shift;
949 params->scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
950 params->scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
951 params->scalar.output_zero_point = (int32_t) output_zero_point;
952 #endif
953}
954
955void xnn_init_scalar_qs8_avgpool_params(
956 union xnn_qs8_avgpool_params params[XNN_MIN_ELEMENTS(1)],
957 int32_t bias,
958 float scale,
959 int8_t output_zero_point,
960 int8_t output_min,
961 int8_t output_max)
962{
963 // Compute requantization parameters.
964 assert(scale >= 0x1.0p-32f);
965 assert(scale < 256.0f);
966 const uint32_t scale_bits = fp32_to_bits(scale);
967
968 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
969 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
970 assert(multiplier >= INT32_C(0x00800000));
971 assert(multiplier <= INT32_C(0x00FFFFFF));
972
973 // Shift is in [16, 55] range.
974 const int32_t shift = 127 + 23 - (scale_bits >> 23);
975 assert(shift >= 16);
976 assert(shift < 64);
977
978 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
979 params->scalar.bias = bias;
980 params->scalar.rounding = rounding;
981 params->scalar.multiplier = multiplier;
982 params->scalar.shift = shift;
983 params->scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
984 params->scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
985 params->scalar.output_zero_point = (int32_t) output_zero_point;
986}
987
988void xnn_update_qs8_avgpool_params(
989 union xnn_qs8_avgpool_params* params,
990 int32_t bias,
991 float scale)
992{
993 // Compute requantization parameters.
994 assert(scale >= 0x1.0p-32f);
995 assert(scale < 256.0f);
996 const uint32_t scale_bits = fp32_to_bits(scale);
997
998 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
999 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
1000 assert(multiplier >= INT32_C(0x00800000));
1001 assert(multiplier <= INT32_C(0x00FFFFFF));
1002
1003 // Shift is in [16, 55] range.
1004 const int32_t shift = 127 + 23 - (scale_bits >> 23);
1005 assert(shift >= 16);
1006 assert(shift < 64);
1007
1008 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1009 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
1010 params->sse2.bias[0] = bias;
1011 params->sse2.bias[1] = bias;
1012 params->sse2.bias[2] = bias;
1013 params->sse2.bias[3] = bias;
1014 params->sse2.multiplier[0] = (uint32_t) multiplier;
1015 params->sse2.multiplier[1] = (uint32_t) multiplier;
1016 params->sse2.multiplier[2] = (uint32_t) multiplier;
1017 params->sse2.multiplier[3] = (uint32_t) multiplier;
1018 params->sse2.rounding[0] = rounding;
1019 params->sse2.rounding[1] = rounding;
1020 params->sse2.shift[0] = (uint64_t) (uint32_t) shift;
1021 params->sse2.shift[1] = (uint64_t) (uint32_t) shift;
1022 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1023 params->neon.bias = bias;
1024 params->neon.multiplier = multiplier;
1025 params->neon.left_shift = (int64_t) -shift;
1026 #elif XNN_ARCH_WASMSIMD
1027 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
1028 params->wasmsimd.bias[0] = bias;
1029 params->wasmsimd.bias[1] = bias;
1030 params->wasmsimd.bias[2] = bias;
1031 params->wasmsimd.bias[3] = bias;
1032 params->wasmsimd.multiplier[0] = (int64_t) multiplier;
1033 params->wasmsimd.multiplier[1] = (int64_t) multiplier;
1034 params->wasmsimd.rounding[0] = rounding;
1035 params->wasmsimd.rounding[1] = rounding;
1036 params->wasmsimd.shift = shift;
1037 #else
1038 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
1039 params->scalar.bias = bias;
1040 params->scalar.multiplier = multiplier;
1041 params->scalar.rounding = rounding;
1042 params->scalar.shift = (uint32_t) shift;
1043 #endif
1044}
1045
1046void xnn_update_f16_scaleminmax_params(
1047 struct xnn_f16_scaleminmax_params* params,
1048 uint16_t scale)
1049{
1050 params->scale = scale;
1051}
1052
1053void xnn_update_f32_scaleminmax_params(
1054 union xnn_f32_scaleminmax_params* params,
1055 float scale)
1056{
1057 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1058 for (uint32_t i = 0; i < 4; i++) {
1059 params->sse2.scale[i] = scale;
1060 }
1061 #else
1062 params->scalar.scale = scale;
1063 #endif
1064}
1065
1066void xnn_init_f16_scaleminmax_params(
1067 struct xnn_f16_scaleminmax_params params[XNN_MIN_ELEMENTS(1)],
1068 uint16_t scale,
1069 uint16_t min,
1070 uint16_t max)
1071{
1072 params->scale = scale;
1073 params->min = min;
1074 params->max = max;
1075 params->pad = 0; // unused.
1076}
1077
1078void xnn_init_f32_scaleminmax_params(
1079 union xnn_f32_scaleminmax_params params[XNN_MIN_ELEMENTS(1)],
1080 float scale,
1081 float min,
1082 float max)
1083{
1084 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1085 for (uint32_t i = 0; i < 4; i++) {
1086 params->sse2.scale[i] = scale;
1087 params->sse2.min[i] = min;
1088 params->sse2.max[i] = max;
1089 }
1090 #else
1091 params->scalar.scale = scale;
1092 params->scalar.min = min;
1093 params->scalar.max = max;
1094 #endif
1095}
1096
1097void xnn_init_f32_gavgpool_params(
1098 union xnn_f32_gavgpool_params params[XNN_MIN_ELEMENTS(1)],
1099 float multiplier,
1100 float output_min,
1101 float output_max,
1102 uint32_t width)
1103{
1104 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1105 for (uint32_t i = 0; i < 4; i++) {
1106 params->sse.multiplier[i] = multiplier;
1107 params->sse.output_min[i] = output_min;
1108 params->sse.output_max[i] = output_max;
1109 }
1110
1111 const uint32_t w = (width - 1) & 3;
1112 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
1113 params->sse.mask[1] = -(uint32_t) (w >= 1);
1114 params->sse.mask[2] = -(uint32_t) (w >= 2);
1115 params->sse.mask[3] = -(uint32_t) (w >= 3);
1116 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1117 params->neon.multiplier = multiplier;
1118 params->neon.output_min = output_min;
1119 params->neon.output_max = output_max;
1120
1121 const uint32_t w = (width - 1) & 3;
1122 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
1123 params->neon.mask[1] = -(uint32_t) (w >= 1);
1124 params->neon.mask[2] = -(uint32_t) (w >= 2);
1125 params->neon.mask[3] = -(uint32_t) (w >= 3);
1126 #else
1127 params->scalar.multiplier = multiplier;
1128 params->scalar.output_min = output_min;
1129 params->scalar.output_max = output_max;
1130
1131 const uint32_t w = (width - 1) & 3;
1132 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1133 params->scalar.mask[1] = -(int32_t) (w >= 1);
1134 params->scalar.mask[2] = -(int32_t) (w >= 2);
1135 params->scalar.mask[3] = -(int32_t) (w >= 3);
1136 #endif
1137}
1138
1139void xnn_update_f32_gavgpool_params(
1140 union xnn_f32_gavgpool_params* params,
1141 float multiplier,
1142 uint32_t width)
1143{
1144 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1145 for (uint32_t i = 0; i < 4; i++) {
1146 params->sse.multiplier[i] = multiplier;
1147 }
1148
1149 const uint32_t w = (width - 1) & 3;
1150 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
1151 params->sse.mask[1] = -(uint32_t) (w >= 1);
1152 params->sse.mask[2] = -(uint32_t) (w >= 2);
1153 params->sse.mask[3] = -(uint32_t) (w >= 3);
1154 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1155 params->neon.multiplier = multiplier;
1156
1157 const uint32_t w = (width - 1) & 3;
1158 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
1159 params->neon.mask[1] = -(uint32_t) (w >= 1);
1160 params->neon.mask[2] = -(uint32_t) (w >= 2);
1161 params->neon.mask[3] = -(uint32_t) (w >= 3);
1162 #else
1163 params->scalar.multiplier = multiplier;
1164
1165 const uint32_t w = (width - 1) & 3;
1166 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1167 params->scalar.mask[1] = -(int32_t) (w >= 1);
1168 params->scalar.mask[2] = -(int32_t) (w >= 2);
1169 params->scalar.mask[3] = -(int32_t) (w >= 3);
1170 #endif
1171}
1172
1173void xnn_init_scalar_f32_scaleminmax_params(
1174 union xnn_f32_scaleminmax_params params[XNN_MIN_ELEMENTS(1)],
1175 float scale,
1176 float min,
1177 float max)
1178{
1179 params->scalar.scale = scale;
1180 params->scalar.min = min;
1181 params->scalar.max = max;
1182}
1183
1184void xnn_init_scalar_f32_gavgpool_params(
1185 union xnn_f32_gavgpool_params params[XNN_MIN_ELEMENTS(1)],
1186 float multiplier,
1187 float output_min,
1188 float output_max,
1189 uint32_t width)
1190{
1191 params->scalar.multiplier = multiplier;
1192 params->scalar.output_min = output_min;
1193 params->scalar.output_max = output_max;
1194
1195 const uint32_t w = (width - 1) & 3;
1196 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1197 params->scalar.mask[1] = -(int32_t) (w >= 1);
1198 params->scalar.mask[2] = -(int32_t) (w >= 2);
1199 params->scalar.mask[3] = -(int32_t) (w >= 3);
1200}
1201
1202void xnn_init_f16_minmax_params(
1203 struct xnn_f16_minmax_params params[XNN_MIN_ELEMENTS(1)],
1204 uint16_t min,
1205 uint16_t max)
1206{
1207 params->min = min;
1208 params->max = max;
1209}
1210
1211void xnn_init_f32_minmax_params(
1212 union xnn_f32_minmax_params params[XNN_MIN_ELEMENTS(1)],
1213 float output_min,
1214 float output_max)
1215{
1216 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1217 for (uint32_t i = 0; i < 4; i++) {
1218 params->sse.min[i] = output_min;
1219 params->sse.max[i] = output_max;
1220 }
1221 #else
1222 params->scalar.min = output_min;
1223 params->scalar.max = output_max;
1224 #endif
1225}
1226
1227#if XNN_ARCH_X86 || XNN_ARCH_X86_64
1228void xnn_init_f32_minmax_sse_params(
1229 union xnn_f32_minmax_params params[XNN_MIN_ELEMENTS(1)],
1230 float output_min,
1231 float output_max)
1232{
1233 for (uint32_t i = 0; i < 4; i++) {
1234 params->sse.min[i] = output_min;
1235 params->sse.max[i] = output_max;
1236 }
1237}
1238
1239void xnn_init_f32_minmax_avx_params(
1240 union xnn_f32_minmax_params params[XNN_MIN_ELEMENTS(1)],
1241 float output_min,
1242 float output_max)
1243{
1244 for (uint32_t i = 0; i < 8; i++) {
1245 params->avx.min[i] = output_min;
1246 params->avx.max[i] = output_max;
1247 }
1248}
1249#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
1250
1251void xnn_init_f32_minmax_scalar_params(
1252 union xnn_f32_minmax_params params[XNN_MIN_ELEMENTS(1)],
1253 float output_min,
1254 float output_max)
1255{
1256 params->scalar.min = output_min;
1257 params->scalar.max = output_max;
1258}
1259
1260void xnn_init_f16_hswish_params(
1261 struct xnn_f16_hswish_params params[XNN_MIN_ELEMENTS(1)])
1262{
1263 params->sixth = UINT16_C(0x3155);
1264 params->three = UINT16_C(0x4200);
1265 params->six = UINT16_C(0x4600);
1266}
1267
1268void xnn_init_f32_hswish_params(
1269 union xnn_f32_hswish_params params[XNN_MIN_ELEMENTS(1)])
1270{
1271 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1272 for (uint32_t i = 0; i < 4; i++) {
1273 params->sse.sixth[i] = 0x1.555556p-3f;
1274 params->sse.half[i] = 0.5f;
1275 params->sse.one[i] = 1.0f;
1276 }
1277 #else
1278 params->scalar.sixth = 0x1.555556p-3f;
1279 params->scalar.three = 3.0f;
1280 params->scalar.six = 6.0f;
1281 #endif
1282}
1283
1284void xnn_init_scalar_f32_hswish_params(
1285 union xnn_f32_hswish_params params[XNN_MIN_ELEMENTS(1)])
1286{
1287 params->scalar.sixth = 0x1.555556p-3f;
1288 params->scalar.three = 3.0f;
1289 params->scalar.six = 6.0f;
1290}
1291
1292void xnn_init_f32_abs_params(
1293 union xnn_f32_abs_params params[XNN_MIN_ELEMENTS(1)])
1294{
1295 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1296 for (uint32_t i = 0; i < 4; i++) {
1297 params->sse.nonsign_mask[i] = math_nonsign_mask_f32();
1298 }
1299 #elif XNN_ARCH_WASMSIMD
1300 params->wasmsimd.nonsign_mask = math_nonsign_mask_f32();
1301 #endif
1302}
1303
1304void xnn_init_scalar_f32_abs_params(
1305 union xnn_f32_abs_params params[XNN_MIN_ELEMENTS(1)])
1306{
1307}
1308
1309void xnn_init_f32_neg_params(
1310 union xnn_f32_neg_params params[XNN_MIN_ELEMENTS(1)])
1311{
1312 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1313 for (uint32_t i = 0; i < 4; i++) {
1314 params->sse.sign_mask[i] = -0.0f;
1315 }
1316 #elif XNN_ARCH_WASMSIMD
1317 params->wasmsimd.sign_mask = -0.0f;
1318 #endif
1319}
1320
1321void xnn_init_scalar_f32_neg_params(
1322 union xnn_f32_neg_params params[XNN_MIN_ELEMENTS(1)])
1323{
1324}
1325
1326void xnn_init_f32_rnd_params(
1327 union xnn_f32_rnd_params params[XNN_MIN_ELEMENTS(1)])
1328{
1329 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1330 for (uint32_t i = 0; i < 4; i++) {
1331 params->sse2.sign_mask[i] = -0.0f;
1332 }
1333 for (uint32_t i = 0; i < 4; i++) {
1334 params->sse2.one[i] = 1.0f;
1335 }
1336 #endif
1337}
1338
1339void xnn_init_scalar_f32_rnd_params(
1340 union xnn_f32_rnd_params params[XNN_MIN_ELEMENTS(1)])
1341{
1342}
1343
1344void xnn_init_f32_elu_params(
1345 union xnn_f32_elu_params params[XNN_MIN_ELEMENTS(1)],
1346 float prescale,
1347 float alpha,
1348 float beta)
1349{
1350 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1351 for (uint32_t i = 0; i < 4; i++) {
1352 params->sse.prescale[i] = prescale;
1353 params->sse.alpha[i] = alpha;
1354 params->sse.beta[i] = beta;
1355 }
1356 #else
1357 params->scalar.prescale = prescale;
1358 params->scalar.alpha = alpha;
1359 params->scalar.beta = beta;
1360 #endif
1361}
1362
1363void xnn_init_scalar_f32_elu_params(
1364 union xnn_f32_elu_params params[XNN_MIN_ELEMENTS(1)],
1365 float prescale,
1366 float alpha,
1367 float beta)
1368{
1369 params->scalar.prescale = prescale;
1370 params->scalar.alpha = alpha;
1371 params->scalar.beta = beta;
1372}
1373
1374void xnn_init_f32_lrelu_params(
1375 union xnn_f32_lrelu_params params[XNN_MIN_ELEMENTS(1)],
1376 float slope)
1377{
1378 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1379 for (uint32_t i = 0; i < 4; i++) {
1380 params->sse.slope[i] = slope;
1381 }
1382 #else
1383 params->scalar.slope = slope;
1384 #endif
1385}
1386
1387void xnn_init_scalar_f32_lrelu_params(
1388 union xnn_f32_lrelu_params params[XNN_MIN_ELEMENTS(1)],
1389 float slope)
1390{
1391 params->scalar.slope = slope;
1392}
1393
1394void xnn_init_f32_sqrt_params(
1395 union xnn_f32_sqrt_params params[XNN_MIN_ELEMENTS(1)])
1396{
1397 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1398 params->fma.half = 0.5f;
1399 #endif
1400}
1401
1402void xnn_init_scalar_f32_sqrt_params(
1403 union xnn_f32_sqrt_params params[XNN_MIN_ELEMENTS(1)])
1404{
1405}
1406
1407void xnn_init_f32_chw_params(
1408 union xnn_f32_chw_params params[XNN_MIN_ELEMENTS(1)],
1409 uint32_t width,
1410 float output_min,
1411 float output_max)
1412{
1413 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1414 for (uint32_t i = 0; i < 4; i++) {
1415 params->sse.min[i] = output_min;
1416 params->sse.max[i] = output_max;
1417 }
1418
1419 const uint32_t w4 = (width - 1) & 3;
1420 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
1421 params->sse.mask[1] = -(uint32_t) (w4 >= 1);
1422 params->sse.mask[2] = -(uint32_t) (w4 >= 2);
1423 params->sse.mask[3] = -(uint32_t) (w4 >= 3);
1424
1425 const uint32_t w8 = (width - 1) & 7;
1426 params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
1427 params->sse.mask_even[1] = -(uint32_t) (w8 >= 2);
1428 params->sse.mask_even[2] = -(uint32_t) (w8 >= 4);
1429 params->sse.mask_even[3] = -(uint32_t) (w8 >= 6);
1430 params->sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
1431 params->sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
1432 params->sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
1433 params->sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
1434 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1435 params->neon.min = output_min;
1436 params->neon.max = output_max;
1437
1438 const uint32_t w4 = (width - 1) & 3;
1439 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
1440 params->neon.mask[1] = -(uint32_t) (w4 >= 1);
1441 params->neon.mask[2] = -(uint32_t) (w4 >= 2);
1442 params->neon.mask[3] = -(uint32_t) (w4 >= 3);
1443
1444 const uint32_t w8 = (width - 1) & 7;
1445 params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
1446 params->neon.mask_even[1] = -(uint32_t) (w8 >= 2);
1447 params->neon.mask_even[2] = -(uint32_t) (w8 >= 4);
1448 params->neon.mask_even[3] = -(uint32_t) (w8 >= 6);
1449 params->neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
1450 params->neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
1451 params->neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
1452 params->neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
1453 #else
1454 params->scalar.min = output_min;
1455 params->scalar.max = output_max;
1456
1457 const uint32_t w4 = (width - 1) & 3;
1458 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1459 params->scalar.mask[1] = -(uint32_t) (w4 >= 1);
1460 params->scalar.mask[2] = -(uint32_t) (w4 >= 2);
1461 params->scalar.mask[3] = -(uint32_t) (w4 >= 3);
1462
1463 const uint32_t w8 = (width - 1) & 7;
1464 params->scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1465 params->scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1466 params->scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1467 params->scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1468 params->scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1469 params->scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1470 params->scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1471 params->scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
1472 #endif
1473}
1474
1475void xnn_update_f32_chw_params(
1476 union xnn_f32_chw_params* params,
1477 uint32_t width)
1478{
1479 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1480 const uint32_t w4 = (width - 1) & 3;
1481 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
1482 params->sse.mask[1] = -(uint32_t) (w4 >= 1);
1483 params->sse.mask[2] = -(uint32_t) (w4 >= 2);
1484 params->sse.mask[3] = -(uint32_t) (w4 >= 3);
1485
1486 const uint32_t w8 = (width - 1) & 7;
1487 params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
1488 params->sse.mask_even[1] = -(uint32_t) (w8 >= 2);
1489 params->sse.mask_even[2] = -(uint32_t) (w8 >= 4);
1490 params->sse.mask_even[3] = -(uint32_t) (w8 >= 6);
1491 params->sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
1492 params->sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
1493 params->sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
1494 params->sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
1495 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1496 const uint32_t w4 = (width - 1) & 3;
1497 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
1498 params->neon.mask[1] = -(uint32_t) (w4 >= 1);
1499 params->neon.mask[2] = -(uint32_t) (w4 >= 2);
1500 params->neon.mask[3] = -(uint32_t) (w4 >= 3);
1501
1502 const uint32_t w8 = (width - 1) & 7;
1503 params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
1504 params->neon.mask_even[1] = -(uint32_t) (w8 >= 2);
1505 params->neon.mask_even[2] = -(uint32_t) (w8 >= 4);
1506 params->neon.mask_even[3] = -(uint32_t) (w8 >= 6);
1507 params->neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
1508 params->neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
1509 params->neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
1510 params->neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
1511 #else
1512 const uint32_t w4 = (width - 1) & 3;
1513 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1514 params->scalar.mask[1] = -(uint32_t) (w4 >= 1);
1515 params->scalar.mask[2] = -(uint32_t) (w4 >= 2);
1516 params->scalar.mask[3] = -(uint32_t) (w4 >= 3);
1517
1518 const uint32_t w8 = (width - 1) & 7;
1519 params->scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1520 params->scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1521 params->scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1522 params->scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1523 params->scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1524 params->scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1525 params->scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1526 params->scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
1527 #endif
1528}
1529
1530void xnn_init_scalar_f32_chw_params(
1531 union xnn_f32_chw_params params[XNN_MIN_ELEMENTS(1)],
1532 uint32_t width,
1533 float output_min,
1534 float output_max)
1535{
1536 params->scalar.min = output_min;
1537 params->scalar.max = output_max;
1538
1539 const uint32_t w4 = (width - 1) & 3;
1540 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1541 params->scalar.mask[1] = -(uint32_t) (w4 >= 1);
1542 params->scalar.mask[2] = -(uint32_t) (w4 >= 2);
1543 params->scalar.mask[3] = -(uint32_t) (w4 >= 3);
1544
1545 const uint32_t w8 = (width - 1) & 7;
1546 params->scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1547 params->scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1548 params->scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1549 params->scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1550 params->scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1551 params->scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1552 params->scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1553 params->scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
1554}
1555
1556void xnn_init_u8_minmax_params(
1557 union xnn_u8_minmax_params params[XNN_MIN_ELEMENTS(1)],
1558 uint8_t output_min,
1559 uint8_t output_max)
1560{
1561 assert(output_min < output_max);
1562
1563 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1564 for (uint32_t i = 0; i < 16; i++) {
1565 params->sse2.min[i] = output_min;
1566 params->sse2.max[i] = output_max;
1567 }
1568 #else
1569 params->scalar.min = output_min;
1570 params->scalar.max = output_max;
1571 #endif
1572}
1573
1574void xnn_init_scalar_u8_minmax_params(
1575 union xnn_u8_minmax_params params[XNN_MIN_ELEMENTS(1)],
1576 uint8_t output_min,
1577 uint8_t output_max)
1578{
1579 assert(output_min < output_max);
1580
1581 params->scalar.min = (int32_t) (uint32_t) output_min;
1582 params->scalar.max = (int32_t) (uint32_t) output_max;
1583}
1584
1585void xnn_init_qu8_add_params(
1586 union xnn_qu8_add_params params[XNN_MIN_ELEMENTS(1)],
1587 uint8_t a_zero_point,
1588 uint8_t b_zero_point,
1589 uint8_t output_zero_point,
1590 float a_output_scale,
1591 float b_output_scale,
1592 uint8_t output_min,
1593 uint8_t output_max)
1594{
1595 assert(a_output_scale >= 0x1.0p-14f);
1596 assert(b_output_scale >= 0x1.0p-14f);
1597 assert(a_output_scale < 0x1.0p+8f);
1598 assert(b_output_scale < 0x1.0p+8f);
1599
1600 // Compute requantization parameters.
1601 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
1602 assert(max_output_scale >= 0x1.0p-14f);
1603 assert(max_output_scale < 0x1.0p+8f);
1604 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1605 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1606 // Shift is in [13, 31] range.
1607 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1608 assert(shift < 32);
1609 assert(shift >= 13);
1610
1611 const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
1612
1613 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1614 const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(a_output_scale * scale_multiplier);
1615 const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(b_output_scale * scale_multiplier);
1616 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
1617 assert(a_multiplier < UINT32_C(0x00400000));
1618 assert(b_multiplier < UINT32_C(0x00400000));
1619
1620 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1621 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1622 const uint32_t remainder_threshold = remainder_mask >> 1;
1623 const int32_t zero_point_product =
1624 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1625 for (uint32_t i = 0; i < 4; i++) {
1626 params->sse2.zero_point_product[i] = zero_point_product;
1627 }
1628 for (uint32_t i = 0; i < 8; i++) {
1629 params->sse2.y_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
1630 }
1631 for (uint32_t i = 0; i < 8; i++) {
1632 params->sse2.a_multiplier_lo[i] = (uint16_t) (uint32_t) a_multiplier;
1633 params->sse2.a_multiplier_hi[i] = (uint16_t) ((uint32_t) a_multiplier >> 16);
1634 params->sse2.b_multiplier_lo[i] = (uint16_t) (uint32_t) b_multiplier;
1635 params->sse2.b_multiplier_hi[i] = (uint16_t) ((uint32_t) b_multiplier >> 16);
1636 }
1637 params->sse2.a_multiplier = a_multiplier;
1638 params->sse2.b_multiplier = b_multiplier;
1639 for (uint32_t i = 0; i < 4; i++) {
1640 params->sse2.remainder_mask[i] = remainder_mask;
1641 params->sse2.remainder_threshold[i] = remainder_threshold;
1642 }
1643 params->sse2.shift = shift;
1644 for (uint32_t i = 0; i < 16; i++) {
1645 params->sse2.y_min[i] = output_min;
1646 params->sse2.y_max[i] = output_max;
1647 }
1648 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1649 params->neon.a_zero_point = a_zero_point;
1650 params->neon.b_zero_point = b_zero_point;
1651 params->neon.y_zero_point = (int16_t) (uint16_t) output_zero_point;
1652 params->neon.a_multiplier = (int32_t) a_multiplier;
1653 params->neon.b_multiplier = (int32_t) b_multiplier;
1654 params->neon.right_shift = (int32_t) -shift;
1655 params->neon.y_min = output_min;
1656 params->neon.y_max = output_max;
1657 #else
1658 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1659 const uint32_t remainder_threshold = remainder_mask >> 1;
1660 params->scalar.zero_point_product =
1661 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1662 params->scalar.a_multiplier = a_multiplier;
1663 params->scalar.b_multiplier = b_multiplier;
1664 params->scalar.remainder_mask = (int32_t) remainder_mask;
1665 params->scalar.remainder_threshold = (int32_t) remainder_threshold;
1666 params->scalar.shift = shift;
1667 params->scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
1668 params->scalar.y_min = (int32_t) (uint32_t) output_min;
1669 params->scalar.y_max = (int32_t) (uint32_t) output_max;
1670 #endif
1671}
1672
1673void xnn_init_scalar_qu8_add_params(
1674 union xnn_qu8_add_params params[XNN_MIN_ELEMENTS(1)],
1675 uint8_t a_zero_point,
1676 uint8_t b_zero_point,
1677 uint8_t output_zero_point,
1678 float a_output_scale,
1679 float b_output_scale,
1680 uint8_t output_min,
1681 uint8_t output_max)
1682{
1683 assert(a_output_scale >= 0x1.0p-10f);
1684 assert(b_output_scale >= 0x1.0p-10f);
1685 assert(a_output_scale < 0x1.0p+8f);
1686 assert(b_output_scale < 0x1.0p+8f);
1687
1688 // Compute requantization parameters.
1689 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
1690 assert(max_output_scale >= 0x1.0p-10f);
1691 assert(max_output_scale < 0x1.0p+8f);
1692 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1693 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1694 // Shift is in [13, 31] range.
1695 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1696 assert(shift < 32);
1697 assert(shift >= 13);
1698
1699 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1700 const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
1701 const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
1702 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
1703 assert(a_multiplier < UINT32_C(0x00400000));
1704 assert(b_multiplier < UINT32_C(0x00400000));
1705
1706 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1707 const uint32_t remainder_threshold = remainder_mask >> 1;
1708 params->scalar.zero_point_product =
1709 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1710 params->scalar.a_multiplier = a_multiplier;
1711 params->scalar.b_multiplier = b_multiplier;
1712 params->scalar.remainder_mask = (int32_t) remainder_mask;
1713 params->scalar.remainder_threshold = (int32_t) remainder_threshold;
1714 params->scalar.shift = shift;
1715 params->scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
1716 params->scalar.y_min = (int32_t) (uint32_t) output_min;
1717 params->scalar.y_max = (int32_t) (uint32_t) output_max;
1718}
1719
1720void xnn_init_qs8_add_params(
1721 union xnn_qs8_add_params params[XNN_MIN_ELEMENTS(1)],
1722 int8_t x_zero_point,
1723 int8_t y_zero_point,
1724 int8_t output_zero_point,
1725 float x_output_scale,
1726 float y_output_scale,
1727 int8_t output_min,
1728 int8_t output_max)
1729{
1730 assert(x_output_scale >= 0x1.0p-14f);
1731 assert(y_output_scale >= 0x1.0p-14f);
1732 assert(x_output_scale < 0x1.0p+8f);
1733 assert(y_output_scale < 0x1.0p+8f);
1734
1735 // Compute requantization parameters.
1736 const float max_output_scale = math_max_f32(x_output_scale, y_output_scale);
1737 assert(max_output_scale >= 0x1.0p-14f);
1738 assert(max_output_scale < 0x1.0p+8f);
1739 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1740 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1741 // Shift is in [13, 31] range.
1742 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1743 assert(shift < 32);
1744 assert(shift >= 13);
1745
1746 const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
1747
1748 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1749 const int32_t x_multiplier = (int32_t) lrintf(x_output_scale * scale_multiplier);
1750 const int32_t y_multiplier = (int32_t) lrintf(y_output_scale * scale_multiplier);
1751 assert((x_multiplier > y_multiplier ? x_multiplier : y_multiplier) >= INT32_C(0x00200000));
1752 assert(x_multiplier < INT32_C(0x00400000));
1753 assert(y_multiplier < INT32_C(0x00400000));
1754
1755 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1756 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1757 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1758 const int32_t zero_point_product =
1759 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1760 for (uint32_t i = 0; i < 4; i++) {
1761 params->sse2.zero_point_product[i] = zero_point_product;
1762 }
1763 const uint16_t x_multiplier_lo = (uint16_t) x_multiplier;
1764 const uint16_t x_multiplier_hi = (uint16_t) ((uint32_t) x_multiplier >> 16);
1765 const uint16_t y_multiplier_lo = (uint16_t) y_multiplier;
1766 const uint16_t y_multiplier_hi = (uint16_t) ((uint32_t) y_multiplier >> 16);
1767 for (uint32_t i = 0; i < 8; i++) {
1768 params->sse2.x_multiplier_lo[i] = x_multiplier_lo;
1769 params->sse2.x_multiplier_hi[i] = x_multiplier_hi;
1770 params->sse2.y_multiplier_lo[i] = y_multiplier_lo;
1771 params->sse2.y_multiplier_hi[i] = y_multiplier_hi;
1772 }
1773 params->sse2.shift = shift;
1774 for (uint32_t i = 0; i < 4; i++) {
1775 params->sse2.x_multiplier[i] = x_multiplier;
1776 params->sse2.y_multiplier[i] = y_multiplier;
1777 params->sse2.remainder_mask[i] = remainder_mask;
1778 params->sse2.remainder_threshold[i] = remainder_threshold;
1779 }
1780 for (uint32_t i = 0; i < 8; i++) {
1781 params->sse2.output_zero_point[i] = (int16_t) output_zero_point;
1782 params->sse2.output_min[i] = (int16_t) output_min;
1783 params->sse2.output_max[i] = (int16_t) output_max;
1784 }
1785 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1786 params->neon.x_zero_point = x_zero_point;
1787 params->neon.y_zero_point = y_zero_point;
1788 params->neon.x_multiplier = (int32_t) x_multiplier;
1789 params->neon.y_multiplier = (int32_t) y_multiplier;
1790 params->neon.right_shift = (int32_t) -shift;
1791 params->neon.output_zero_point = (int16_t) output_zero_point;
1792 params->neon.output_min = output_min;
1793 params->neon.output_max = output_max;
1794 #elif XNN_ARCH_WASMSIMD
1795 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1796 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1797 const int32_t zero_point_product =
1798 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1799 for (uint32_t i = 0; i < 4; i++) {
1800 params->wasmsimd.zero_point_product[i] = zero_point_product;
1801 params->wasmsimd.x_multiplier[i] = x_multiplier;
1802 params->wasmsimd.y_multiplier[i] = y_multiplier;
1803 params->wasmsimd.remainder_mask[i] = remainder_mask;
1804 params->wasmsimd.remainder_threshold[i] = remainder_threshold;
1805 }
1806 params->wasmsimd.shift = shift;
1807 for (uint32_t i = 0; i < 8; i++) {
1808 params->wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
1809 }
1810 for (uint32_t i = 0; i < 16; i++) {
1811 params->wasmsimd.output_min[i] = output_min;
1812 params->wasmsimd.output_max[i] = output_max;
1813 }
1814 #else
1815 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1816 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1817 params->scalar.zero_point_product =
1818 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1819 params->scalar.x_multiplier = x_multiplier;
1820 params->scalar.y_multiplier = y_multiplier;
1821 params->scalar.remainder_mask = (int32_t) remainder_mask;
1822 params->scalar.remainder_threshold = (int32_t) remainder_threshold;
1823 params->scalar.shift = (int32_t) shift;
1824 params->scalar.output_zero_point = (int32_t) output_zero_point;
1825 params->scalar.output_min = (int32_t) output_min;
1826 params->scalar.output_max = (int32_t) output_max;
1827 #endif
1828}
1829
1830void xnn_init_scalar_qs8_add_params(
1831 union xnn_qs8_add_params params[XNN_MIN_ELEMENTS(1)],
1832 int8_t x_zero_point,
1833 int8_t y_zero_point,
1834 int8_t output_zero_point,
1835 float x_output_scale,
1836 float y_output_scale,
1837 int8_t output_min,
1838 int8_t output_max)
1839{
1840 assert(x_output_scale >= 0x1.0p-10f);
1841 assert(y_output_scale >= 0x1.0p-10f);
1842 assert(x_output_scale < 0x1.0p+8f);
1843 assert(y_output_scale < 0x1.0p+8f);
1844
1845 // Compute requantization parameters.
1846 const float max_output_scale = math_max_f32(x_output_scale, y_output_scale);
1847 assert(max_output_scale >= 0x1.0p-10f);
1848 assert(max_output_scale < 0x1.0p+8f);
1849 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1850 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1851 // Shift is in [13, 31] range.
1852 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1853 assert(shift < 32);
1854 assert(shift >= 13);
1855
1856 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1857 const int32_t x_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(x_output_scale) + (shift << 23)));
1858 const int32_t y_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(y_output_scale) + (shift << 23)));
1859 assert((x_multiplier > y_multiplier ? x_multiplier : y_multiplier) >= INT32_C(0x00200000));
1860 assert(x_multiplier < INT32_C(0x00400000));
1861 assert(y_multiplier < INT32_C(0x00400000));
1862
1863 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1864 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1865 params->scalar.zero_point_product =
1866 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1867 params->scalar.x_multiplier = x_multiplier;
1868 params->scalar.y_multiplier = y_multiplier;
1869 params->scalar.remainder_mask = (int32_t) remainder_mask;
1870 params->scalar.remainder_threshold = (int32_t) remainder_threshold;
1871 params->scalar.shift = shift;
1872 params->scalar.output_zero_point = (int32_t) output_zero_point;
1873 params->scalar.output_min = (int32_t) output_min;
1874 params->scalar.output_max = (int32_t) output_max;
1875}