blob: a2c3509249148e5fdbbf6223ddd10510a467bf1a [file] [log] [blame]
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001// Copyright 2021 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#include <stdint.h>
7#include <stddef.h>
8#include <assert.h>
9#include <math.h>
10
11#include <fp16.h>
12
13#include <xnnpack/math.h>
14#include <xnnpack/params-init.h>
15
16
Marat Dukhanc698c112021-07-01 18:52:10 -070017void xnn_init_qu8_conv_minmax_gemmlowp_scalar_params(
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070018 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
19 uint8_t kernel_zero_point,
20 float scale,
21 uint8_t output_zero_point,
22 uint8_t output_min,
23 uint8_t output_max)
24{
25 // Compute requantization parameters
26 const uint32_t scale_bits = fp32_to_bits(scale);
27
28 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
Marat Dukhanc698c112021-07-01 18:52:10 -070029 const int32_t multiplier = (int32_t) (((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070030 assert(multiplier >= INT32_C(0x40000000));
31 assert(multiplier <= INT32_C(0x7FFFFF80));
32
33 // Shift is in [0, 31] range.
34 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
35 assert(shift >= 0);
36 assert(shift < 32);
37
38 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
39 const uint32_t remainder_threshold = remainder_mask >> 1;
40
Marat Dukhanc698c112021-07-01 18:52:10 -070041 params->gemmlowp_scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
42 params->gemmlowp_scalar.multiplier = multiplier;
43 params->gemmlowp_scalar.remainder_mask = (int32_t) remainder_mask;
44 params->gemmlowp_scalar.remainder_threshold = (int32_t) remainder_threshold;
45 params->gemmlowp_scalar.shift = (uint32_t) shift;
46 params->gemmlowp_scalar.output_min_less_zero_point =
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070047 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
Marat Dukhanc698c112021-07-01 18:52:10 -070048 params->gemmlowp_scalar.output_max_less_zero_point =
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070049 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
Marat Dukhanc698c112021-07-01 18:52:10 -070050 params->gemmlowp_scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070051}
52
Marat Dukhan927d4742021-07-15 13:42:49 -070053void xnn_init_qu8_conv_minmax_fp32_scalar_lrint_params(
54 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
55 uint8_t kernel_zero_point,
56 float scale,
57 uint8_t output_zero_point,
58 uint8_t output_min,
59 uint8_t output_max)
60{
61 params->fp32_scalar_lrint.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
62 params->fp32_scalar_lrint.scale = scale;
63 params->fp32_scalar_lrint.output_min_less_zero_point = (long) (int32_t) ((uint32_t) output_min - (uint32_t) output_zero_point);
64 params->fp32_scalar_lrint.output_max_less_zero_point = (long) (int32_t) ((uint32_t) output_max - (uint32_t) output_zero_point);
65 params->fp32_scalar_lrint.output_zero_point = (int32_t) (uint32_t) output_zero_point;
66}
67
68void xnn_init_qu8_conv_minmax_fp32_scalar_magic_params(
69 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
70 uint8_t kernel_zero_point,
71 float scale,
72 uint8_t output_zero_point,
73 uint8_t output_min,
74 uint8_t output_max)
75{
76 params->fp32_scalar_magic.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
77 params->fp32_scalar_magic.scale = scale;
78 params->fp32_scalar_magic.output_min_less_zero_point = (float) (int32_t) ((uint32_t) output_min - (uint32_t) output_zero_point);
79 params->fp32_scalar_magic.output_max_less_zero_point = (float) (int32_t) ((uint32_t) output_max - (uint32_t) output_zero_point);
80 params->fp32_scalar_magic.magic_bias = 12582912.0f;
81 params->fp32_scalar_magic.magic_bias_less_output_zero_point = INT32_C(0x4B400000) - (int32_t) (uint32_t) output_zero_point;
82}
83
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070084#if XNN_ARCH_X86 || XNN_ARCH_X86_64
Marat Dukhanc698c112021-07-01 18:52:10 -070085void xnn_init_qu8_conv_minmax_gemmlowp_sse2_params(
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070086 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
87 uint8_t kernel_zero_point,
88 float scale,
89 uint8_t output_zero_point,
90 uint8_t output_min,
91 uint8_t output_max)
92{
93 // Compute requantization parameters.
94 const uint32_t scale_bits = fp32_to_bits(scale);
95
96 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
Marat Dukhanc698c112021-07-01 18:52:10 -070097 const int32_t multiplier = (int32_t) (((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -070098 assert(multiplier >= INT32_C(0x40000000));
99 assert(multiplier <= INT32_C(0x7FFFFF80));
100
101 // Shift is in [0, 31] range.
102 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
103 assert(shift >= 0);
104 assert(shift < 32);
105
106 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
107 const uint32_t remainder_threshold = remainder_mask >> 1;
108 for (uint32_t i = 0; i < 8; i++) {
Marat Dukhanc698c112021-07-01 18:52:10 -0700109 params->gemmlowp_sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700110 }
Marat Dukhanc698c112021-07-01 18:52:10 -0700111 params->gemmlowp_sse2.multiplier[0] = multiplier;
112 params->gemmlowp_sse2.multiplier[1] = multiplier;
113 params->gemmlowp_sse2.multiplier[2] = multiplier;
114 params->gemmlowp_sse2.multiplier[3] = multiplier;
115 params->gemmlowp_sse2.rounding[0] = UINT64_C(0x40000000);
116 params->gemmlowp_sse2.rounding[1] = UINT64_C(0x40000000);
117 params->gemmlowp_sse2.remainder_mask[0] = (int32_t) remainder_mask;
118 params->gemmlowp_sse2.remainder_mask[1] = (int32_t) remainder_mask;
119 params->gemmlowp_sse2.remainder_mask[2] = (int32_t) remainder_mask;
120 params->gemmlowp_sse2.remainder_mask[3] = (int32_t) remainder_mask;
121 params->gemmlowp_sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
122 params->gemmlowp_sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
123 params->gemmlowp_sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
124 params->gemmlowp_sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
125 params->gemmlowp_sse2.shift[0] = (uint64_t) (uint32_t) shift;
126 params->gemmlowp_sse2.shift[1] = (uint64_t) (uint32_t) shift;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700127 for (uint32_t i = 0; i < 8; i++) {
Marat Dukhanc698c112021-07-01 18:52:10 -0700128 params->gemmlowp_sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700129 }
130 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhanc698c112021-07-01 18:52:10 -0700131 params->gemmlowp_sse2.output_min[i] = output_min;
132 params->gemmlowp_sse2.output_max[i] = output_max;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700133 }
134}
Marat Dukhanef47f8d2021-07-02 15:08:32 -0700135
136void xnn_init_qu8_conv_minmax_fp32_sse2_params(
137 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
138 uint8_t kernel_zero_point,
139 float scale,
140 uint8_t output_zero_point,
141 uint8_t output_min,
142 uint8_t output_max)
143{
144 for (uint32_t i = 0; i < 4; i++) {
145 params->fp32_sse2.scale[i] = scale;
146 }
147 for (uint32_t i = 0; i < 8; i++) {
148 params->fp32_sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
149 params->fp32_sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
150 }
151 for (uint32_t i = 0; i < 16; i++) {
152 params->fp32_sse2.output_min[i] = output_min;
153 params->fp32_sse2.output_max[i] = output_max;
154 }
155}
Marat Dukhan902ef7f2021-07-02 16:11:06 -0700156
157void xnn_init_qu8_conv_minmax_fp32_avx2_params(
158 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
159 uint8_t kernel_zero_point,
160 float scale,
161 uint8_t output_zero_point,
162 uint8_t output_min,
163 uint8_t output_max)
164{
165 for (uint32_t i = 0; i < 8; i++) {
166 params->fp32_avx2.scale[i] = scale;
167 }
168 for (uint32_t i = 0; i < 16; i++) {
169 params->fp32_avx2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
170 params->fp32_avx2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
171 }
172 for (uint32_t i = 0; i < 32; i++) {
173 params->fp32_avx2.output_min[i] = output_min;
174 params->fp32_avx2.output_max[i] = output_max;
175 }
176}
Marat Dukhan3cf2e222021-07-08 11:38:45 -0700177
178void xnn_init_qu8_conv_minmax_fp32_avx512_params(
179 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
180 uint8_t kernel_zero_point,
181 float scale,
182 uint8_t output_zero_point,
183 uint8_t output_min,
184 uint8_t output_max)
185{
186 for (uint32_t i = 0; i < 16; i++) {
187 params->fp32_avx512.scale[i] = scale;
188 }
189 for (uint32_t i = 0; i < 32; i++) {
190 params->fp32_avx512.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
191 params->fp32_avx512.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
192 }
193 for (uint32_t i = 0; i < 64; i++) {
194 params->fp32_avx512.output_min[i] = output_min;
195 params->fp32_avx512.output_max[i] = output_max;
196 }
197}
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700198#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
199
200#if XNN_ARCH_ARM || XNN_ARCH_ARM64
Marat Dukhanc698c112021-07-01 18:52:10 -0700201void xnn_init_qu8_conv_minmax_gemmlowp_neon_params(
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700202 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
203 uint8_t kernel_zero_point,
204 float scale,
205 uint8_t output_zero_point,
206 uint8_t output_min,
207 uint8_t output_max)
208{
209 // Compute requantization parameters.
210 const uint32_t scale_bits = fp32_to_bits(scale);
211
212 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
Marat Dukhanc698c112021-07-01 18:52:10 -0700213 const int32_t multiplier = (int32_t) (((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700214 assert(multiplier >= INT32_C(0x40000000));
215 assert(multiplier <= INT32_C(0x7FFFFF80));
216
217 // Shift is in [0, 31] range.
218 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
219 assert(shift >= 0);
220 assert(shift < 32);
221
Marat Dukhan69c8a292021-07-14 19:34:56 -0700222 params->gemmlowp_neon.kernel_zero_point = kernel_zero_point;
Marat Dukhanc698c112021-07-01 18:52:10 -0700223 params->gemmlowp_neon.multiplier = multiplier;
224 params->gemmlowp_neon.right_shift = -shift;
225 params->gemmlowp_neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
226 params->gemmlowp_neon.output_min = output_min;
227 params->gemmlowp_neon.output_max = output_max;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700228}
Marat Dukhan69c8a292021-07-14 19:34:56 -0700229
230void xnn_init_qu8_conv_minmax_fp32_neon_params(
231 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
232 uint8_t kernel_zero_point,
233 float scale,
234 uint8_t output_zero_point,
235 uint8_t output_min,
236 uint8_t output_max)
237{
238 params->fp32_neon.kernel_zero_point = kernel_zero_point;
239 params->fp32_neon.scale = scale;
240 params->fp32_neon.output_min_less_zero_point = (float) (int32_t) ((uint32_t) output_min - (uint32_t) output_zero_point);
241 params->fp32_neon.output_max_less_zero_point = (float) (int32_t) ((uint32_t) output_max - (uint32_t) output_zero_point);
242 params->fp32_neon.magic_bias = 12582912.0f;
243 params->fp32_neon.magic_bias_less_zero_point = INT32_C(0x4B400000) - (int32_t) (uint32_t) output_zero_point;
244}
245
246void xnn_init_qu8_conv_minmax_fp32_neonv8_params(
247 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
248 uint8_t kernel_zero_point,
249 float scale,
250 uint8_t output_zero_point,
251 uint8_t output_min,
252 uint8_t output_max)
253{
254 params->fp32_neonv8.kernel_zero_point = kernel_zero_point;
255 params->fp32_neonv8.scale = scale;
256 params->fp32_neonv8.output_zero_point = (int16_t) (uint16_t) output_zero_point;
257 params->fp32_neonv8.output_min = output_min;
258 params->fp32_neonv8.output_max = output_max;
259}
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700260#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
261
Marat Dukhan43bee052021-07-14 20:57:18 -0700262#if XNN_ARCH_WASMSIMD
263void xnn_init_qu8_conv_minmax_fp32_wasmsimd_params(
264 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
265 uint8_t kernel_zero_point,
266 float scale,
267 uint8_t output_zero_point,
268 uint8_t output_min,
269 uint8_t output_max)
270{
271 for (uint32_t i = 0; i < 8; i++) {
272 params->fp32_wasmsimd.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
273 }
274 for (uint32_t i = 0; i < 4; i++) {
275 params->fp32_wasmsimd.scale[i] = scale;
276 params->fp32_wasmsimd.output_min_less_zero_point[i] = (float) (int32_t) ((uint32_t) output_min - (uint32_t) output_zero_point);
277 params->fp32_wasmsimd.output_max_less_zero_point[i] = (float) (int32_t) ((uint32_t) output_max - (uint32_t) output_zero_point);
278 params->fp32_wasmsimd.magic_bias[i] = 12582912.0f;
279 params->fp32_wasmsimd.magic_bias_less_output_zero_point[i] = INT32_C(0x4B400000) - (int32_t) (uint32_t) output_zero_point;
280 }
281}
282#endif // XNN_ARCH_WASMSIMD
283
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700284void xnn_init_qs8_conv_minmax_gemmlowp_scalar_params(
285 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
286 float scale,
287 int8_t output_zero_point,
288 int8_t output_min,
289 int8_t output_max)
290{
291 // Compute requantization parameters
292 const uint32_t scale_bits = fp32_to_bits(scale);
293
294 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
295 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
296 assert(multiplier >= INT32_C(0x40000000));
297 assert(multiplier <= INT32_C(0x7FFFFF80));
298
299 // Shift is in [0, 31] range.
300 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
301 assert(shift >= 0);
302 assert(shift < 32);
303
304 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
305 const uint32_t remainder_threshold = remainder_mask >> 1;
306
307 params->gemmlowp_scalar.multiplier = multiplier;
308 params->gemmlowp_scalar.remainder_mask = (int32_t) remainder_mask;
309 params->gemmlowp_scalar.remainder_threshold = (int32_t) remainder_threshold;
310 params->gemmlowp_scalar.shift = (uint32_t) shift;
311 params->gemmlowp_scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
312 params->gemmlowp_scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
313 params->gemmlowp_scalar.output_zero_point = (int32_t) output_zero_point;
314}
315
316void xnn_init_qs8_conv_minmax_fp32_scalar_lrint_params(
317 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
318 float scale,
319 int8_t output_zero_point,
320 int8_t output_min,
321 int8_t output_max)
322{
323 params->fp32_scalar_lrint.scale = scale;
324 params->fp32_scalar_lrint.output_min_less_zero_point = (long) ((int32_t) output_min - (int32_t) output_zero_point);
325 params->fp32_scalar_lrint.output_max_less_zero_point = (long) ((int32_t) output_max - (int32_t) output_zero_point);
326 params->fp32_scalar_lrint.output_zero_point = (int32_t) output_zero_point;
327}
328
329void xnn_init_qs8_conv_minmax_fp32_scalar_magic_params(
330 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
331 float scale,
332 int8_t output_zero_point,
333 int8_t output_min,
334 int8_t output_max)
335{
336 params->fp32_scalar_magic.scale = scale;
337 params->fp32_scalar_magic.output_min_less_zero_point = (float) ((int32_t) output_min - (int32_t) output_zero_point);
338 params->fp32_scalar_magic.output_max_less_zero_point = (float) ((int32_t) output_max - (int32_t) output_zero_point);
339 params->fp32_scalar_magic.magic_bias = 12582912.0f;
340 params->fp32_scalar_magic.magic_bias_less_output_zero_point = INT32_C(0x4B400000) - (int32_t) output_zero_point;
341}
342
343#if XNN_ARCH_X86 || XNN_ARCH_X86_64
344void xnn_init_qs8_conv_minmax_gemmlowp_sse2_params(
345 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
346 float scale,
347 int8_t output_zero_point,
348 int8_t output_min,
349 int8_t output_max)
350{
351 // Compute requantization parameters.
352 const uint32_t scale_bits = fp32_to_bits(scale);
353
354 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
355 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
356 assert(multiplier >= INT32_C(0x40000000));
357 assert(multiplier <= INT32_C(0x7FFFFF80));
358
359 // Shift is in [0, 31] range.
360 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
361 assert(shift >= 0);
362 assert(shift < 32);
363
364 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
365 const uint32_t remainder_threshold = remainder_mask >> 1;
366 params->gemmlowp_sse2.multiplier[0] = multiplier;
367 params->gemmlowp_sse2.multiplier[1] = multiplier;
368 params->gemmlowp_sse2.multiplier[2] = multiplier;
369 params->gemmlowp_sse2.multiplier[3] = multiplier;
370 params->gemmlowp_sse2.rounding[0] = UINT64_C(0x40000000);
371 params->gemmlowp_sse2.rounding[1] = UINT64_C(0x40000000);
372 params->gemmlowp_sse2.remainder_mask[0] = (int32_t) remainder_mask;
373 params->gemmlowp_sse2.remainder_mask[1] = (int32_t) remainder_mask;
374 params->gemmlowp_sse2.remainder_mask[2] = (int32_t) remainder_mask;
375 params->gemmlowp_sse2.remainder_mask[3] = (int32_t) remainder_mask;
376 params->gemmlowp_sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
377 params->gemmlowp_sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
378 params->gemmlowp_sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
379 params->gemmlowp_sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
380 params->gemmlowp_sse2.shift[0] = (uint64_t) (uint32_t) shift;
381 params->gemmlowp_sse2.shift[1] = (uint64_t) (uint32_t) shift;
382 for (uint32_t i = 0; i < 8; i++) {
383 params->gemmlowp_sse2.output_zero_point[i] = (int16_t) output_zero_point;
384 params->gemmlowp_sse2.output_min[i] = (int16_t) output_min;
385 params->gemmlowp_sse2.output_max[i] = (int16_t) output_max;
386 }
387}
388
389void xnn_init_qs8_conv_minmax_gemmlowp_sse4_params(
390 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
391 float scale,
392 int8_t output_zero_point,
393 int8_t output_min,
394 int8_t output_max)
395{
396 // Compute requantization parameters.
397 const uint32_t scale_bits = fp32_to_bits(scale);
398
399 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
400 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
401 assert(multiplier >= INT32_C(0x40000000));
402 assert(multiplier <= INT32_C(0x7FFFFF80));
403
404 // Shift is in [0, 31] range.
405 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
406 assert(shift >= 0);
407 assert(shift < 32);
408
409 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
410 const uint32_t remainder_threshold = remainder_mask >> 1;
411 params->gemmlowp_sse4.multiplier[0] = multiplier;
412 params->gemmlowp_sse4.multiplier[1] = multiplier;
413 params->gemmlowp_sse4.multiplier[2] = multiplier;
414 params->gemmlowp_sse4.multiplier[3] = multiplier;
415 params->gemmlowp_sse4.rounding[0] = UINT64_C(0x40000000);
416 params->gemmlowp_sse4.rounding[1] = UINT64_C(0x40000000);
417 params->gemmlowp_sse4.remainder_mask[0] = (int32_t) remainder_mask;
418 params->gemmlowp_sse4.remainder_mask[1] = (int32_t) remainder_mask;
419 params->gemmlowp_sse4.remainder_mask[2] = (int32_t) remainder_mask;
420 params->gemmlowp_sse4.remainder_mask[3] = (int32_t) remainder_mask;
421 params->gemmlowp_sse4.remainder_threshold[0] = (int32_t) remainder_threshold;
422 params->gemmlowp_sse4.remainder_threshold[1] = (int32_t) remainder_threshold;
423 params->gemmlowp_sse4.remainder_threshold[2] = (int32_t) remainder_threshold;
424 params->gemmlowp_sse4.remainder_threshold[3] = (int32_t) remainder_threshold;
425 params->gemmlowp_sse4.shift[0] = (uint64_t) (uint32_t) shift;
426 params->gemmlowp_sse4.shift[1] = (uint64_t) (uint32_t) shift;
427 for (uint32_t i = 0; i < 8; i++) {
428 params->gemmlowp_sse4.output_zero_point[i] = (int16_t) output_zero_point;
429 }
430 for (uint32_t i = 0; i < 16; i++) {
431 params->gemmlowp_sse4.output_min[i] = output_min;
432 params->gemmlowp_sse4.output_max[i] = output_max;
433 }
434}
435
436void xnn_init_qs8_conv_minmax_gemmlowp_avx2_params(
437 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
438 float scale,
439 int8_t output_zero_point,
440 int8_t output_min,
441 int8_t output_max)
442{
443 // Compute requantization parameters.
444 const uint32_t scale_bits = fp32_to_bits(scale);
445
446 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
447 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
448 assert(multiplier >= INT32_C(0x40000000));
449 assert(multiplier <= INT32_C(0x7FFFFF80));
450
451 // Shift is in [0, 31] range.
452 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
453 assert(shift >= 0);
454 assert(shift < 32);
455
456 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
457 const uint32_t remainder_threshold = remainder_mask >> 1;
458 for (uint32_t i = 0; i < 8; i++) {
459 params->gemmlowp_avx2.multiplier[i] = multiplier;
460 }
461 params->gemmlowp_avx2.rounding[0] = UINT64_C(0x40000000);
462 params->gemmlowp_avx2.rounding[1] = UINT64_C(0x40000000);
463 params->gemmlowp_avx2.rounding[2] = UINT64_C(0x40000000);
464 params->gemmlowp_avx2.rounding[3] = UINT64_C(0x40000000);
465 for (uint32_t i = 0; i < 8; i++) {
466 params->gemmlowp_avx2.remainder_mask[i] = (int32_t) remainder_mask;
467 params->gemmlowp_avx2.remainder_threshold[i] = (int32_t) remainder_threshold;
468 }
469 params->gemmlowp_avx2.shift[0] = (uint64_t) (uint32_t) shift;
470 params->gemmlowp_avx2.shift[1] = (uint64_t) (uint32_t) shift;
471 params->gemmlowp_avx2.shift[2] = (uint64_t) (uint32_t) shift;
472 params->gemmlowp_avx2.shift[3] = (uint64_t) (uint32_t) shift;
473 for (uint32_t i = 0; i < 16; i++) {
474 params->gemmlowp_avx2.output_zero_point[i] = (int16_t) output_zero_point;
475 }
476 for (uint32_t i = 0; i < 32; i++) {
477 params->gemmlowp_avx2.output_min[i] = output_min;
478 params->gemmlowp_avx2.output_max[i] = output_max;
479 }
480}
481
482void xnn_init_qs8_conv_minmax_gemmlowp_avx512_params(
483 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
484 float scale,
485 int8_t output_zero_point,
486 int8_t output_min,
487 int8_t output_max)
488{
489 // Compute requantization parameters.
490 const uint32_t scale_bits = fp32_to_bits(scale);
491
492 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
493 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
494 assert(multiplier >= INT32_C(0x40000000));
495 assert(multiplier <= INT32_C(0x7FFFFF80));
496
497 // Shift is in [0, 31] range.
498 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
499 assert(shift >= 0);
500 assert(shift < 32);
501
502 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
503 const uint32_t remainder_threshold = remainder_mask >> 1;
504 params->gemmlowp_avx512.multiplier = (int64_t) multiplier;
505 params->gemmlowp_avx512.rounding = UINT64_C(0x40000000);
506 params->gemmlowp_avx512.remainder_mask = (int32_t) remainder_mask;
507 params->gemmlowp_avx512.remainder_threshold = (int32_t) remainder_threshold;
508 params->gemmlowp_avx512.shift = (uint64_t) (uint32_t) shift;
509 for (uint32_t i = 0; i < 32; i++) {
510 params->gemmlowp_avx512.output_zero_point[i] = (int16_t) output_zero_point;
511 }
512 for (uint32_t i = 0; i < 64; i++) {
513 params->gemmlowp_avx512.output_min[i] = output_min;
514 params->gemmlowp_avx512.output_max[i] = output_max;
515 }
516}
517
518void xnn_init_qs8_conv_minmax_fp32_sse2_params(
519 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
520 float scale,
521 int8_t output_zero_point,
522 int8_t output_min,
523 int8_t output_max)
524{
525 for (uint32_t i = 0; i < 4; i++) {
526 params->fp32_sse2.scale[i] = scale;
527 }
528 for (uint32_t i = 0; i < 8; i++) {
529 params->fp32_sse2.output_zero_point[i] = (int16_t) output_zero_point;
530 params->fp32_sse2.output_min[i] = (int16_t) output_min;
531 params->fp32_sse2.output_max[i] = (int16_t) output_max;
532 }
533}
534
535void xnn_init_qs8_conv_minmax_fp32_sse4_params(
536 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
537 float scale,
538 int8_t output_zero_point,
539 int8_t output_min,
540 int8_t output_max)
541{
542 for (uint32_t i = 0; i < 4; i++) {
543 params->fp32_sse4.scale[i] = scale;
544 }
545 for (uint32_t i = 0; i < 8; i++) {
546 params->fp32_sse4.output_zero_point[i] = (int16_t) output_zero_point;
547 }
548 for (uint32_t i = 0; i < 16; i++) {
549 params->fp32_sse4.output_min[i] = output_min;
550 params->fp32_sse4.output_max[i] = output_max;
551 }
552}
553
554void xnn_init_qs8_conv_minmax_fp32_avx2_params(
555 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
556 float scale,
557 int8_t output_zero_point,
558 int8_t output_min,
559 int8_t output_max)
560{
561 for (uint32_t i = 0; i < 8; i++) {
562 params->fp32_avx2.scale[i] = scale;
563 }
564 for (uint32_t i = 0; i < 16; i++) {
565 params->fp32_avx2.output_zero_point[i] = (int16_t) output_zero_point;
566 }
567 for (uint32_t i = 0; i < 32; i++) {
568 params->fp32_avx2.output_min[i] = output_min;
569 params->fp32_avx2.output_max[i] = output_max;
570 }
571}
572
573void xnn_init_qs8_conv_minmax_fp32_avx512_params(
574 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
575 float scale,
576 int8_t output_zero_point,
577 int8_t output_min,
578 int8_t output_max)
579{
580 for (uint32_t i = 0; i < 16; i++) {
581 params->fp32_avx512.scale[i] = scale;
582 }
583 for (uint32_t i = 0; i < 32; i++) {
584 params->fp32_avx512.output_zero_point[i] = (int16_t) output_zero_point;
585 }
586 for (uint32_t i = 0; i < 64; i++) {
587 params->fp32_avx512.output_min[i] = output_min;
588 params->fp32_avx512.output_max[i] = output_max;
589 }
590}
591#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
592
593#if XNN_ARCH_ARM || XNN_ARCH_ARM64
594void xnn_init_qs8_conv_minmax_gemmlowp_neon_params(
595 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
596 float scale,
597 int8_t output_zero_point,
598 int8_t output_min,
599 int8_t output_max)
600{
601 // Compute requantization parameters.
602 const uint32_t scale_bits = fp32_to_bits(scale);
603
604 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
605 const int32_t multiplier = (int32_t) (((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
606 assert(multiplier >= INT32_C(0x40000000));
607 assert(multiplier <= INT32_C(0x7FFFFF80));
608
609 // Shift is in [0, 31] range.
610 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
611 assert(shift >= 0);
612 assert(shift < 32);
613
614 params->gemmlowp_neon.multiplier = multiplier;
615 params->gemmlowp_neon.right_shift = -shift;
616 params->gemmlowp_neon.output_zero_point = (int16_t) output_zero_point;
617 params->gemmlowp_neon.output_min = output_min;
618 params->gemmlowp_neon.output_max = output_max;
619}
620
621void xnn_init_qs8_conv_minmax_fp32_neon_params(
622 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
623 float scale,
624 int8_t output_zero_point,
625 int8_t output_min,
626 int8_t output_max)
627{
628 params->fp32_neon.scale = scale;
629 params->fp32_neon.output_min_less_zero_point = (float) ((int32_t) output_min - (int32_t) output_zero_point);
630 params->fp32_neon.output_max_less_zero_point = (float) ((int32_t) output_max - (int32_t) output_zero_point);
631 params->fp32_neon.magic_bias = 12582912.0f;
632 params->fp32_neon.magic_bias_less_zero_point = INT32_C(0x4B400000) - (int32_t) output_zero_point;
633}
634
635void xnn_init_qs8_conv_minmax_fp32_neonv8_params(
636 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
637 float scale,
638 int8_t output_zero_point,
639 int8_t output_min,
640 int8_t output_max)
641{
642 params->fp32_neonv8.scale = scale;
643 params->fp32_neonv8.output_zero_point = (int16_t) output_zero_point;
644 params->fp32_neonv8.output_min = output_min;
645 params->fp32_neonv8.output_max = output_max;
646}
Marat Dukhanbe18f5c2021-07-16 18:46:39 -0700647
648void xnn_init_qs8_conv_minmax_rndnu_neon_params(
649 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
650 float scale,
651 int8_t output_zero_point,
652 int8_t output_min,
653 int8_t output_max)
654{
655 // Compute requantization parameters.
656 const uint32_t scale_bits = fp32_to_bits(scale);
657
658 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
659 const int32_t multiplier = (int32_t) (((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
660 assert(multiplier >= INT32_C(0x40000000));
661 assert(multiplier <= INT32_C(0x7FFFFF80));
662
663 // Shift is in [0, 31] range.
664 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
665 assert(shift >= 0);
666 assert(shift < 32);
667
668 // Split shift into pre_shift + post_shift, post_shift in [1, 31] range.
669 const int32_t post_shift = math_max_s32(shift, 1);
670 const int32_t pre_shift = shift - post_shift;
671
672 params->rndnu_neon.right_pre_shift = -pre_shift;
673 params->rndnu_neon.multiplier = multiplier;
674 params->rndnu_neon.right_post_shift = -post_shift;
675 params->rndnu_neon.output_zero_point = (int16_t) output_zero_point;
676 params->rndnu_neon.output_min = output_min;
677 params->rndnu_neon.output_max = output_max;
678}
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700679#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
680
681#if XNN_ARCH_WASMSIMD
682void xnn_init_qs8_conv_minmax_gemmlowp_wasmsimd_params(
683 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
684 float scale,
685 int8_t output_zero_point,
686 int8_t output_min,
687 int8_t output_max)
688{
689 // Compute requantization parameters.
690 const uint32_t scale_bits = fp32_to_bits(scale);
691
692 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
693 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
694 assert(multiplier >= INT32_C(0x40000000));
695 assert(multiplier <= INT32_C(0x7FFFFF80));
696
697 // Shift is in [0, 31] range.
698 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
699 assert(shift >= 0);
700 assert(shift < 32);
701
702 const int64_t twice_multiplier = INT64_C(2) * (int64_t) multiplier;
703 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
704 const uint32_t remainder_threshold = remainder_mask >> 1;
705 params->gemmlowp_wasmsimd.multiplier[0] = twice_multiplier;
706 params->gemmlowp_wasmsimd.multiplier[1] = twice_multiplier;
707 params->gemmlowp_wasmsimd.rounding[0] = INT64_C(0x80000000);
708 params->gemmlowp_wasmsimd.rounding[1] = INT64_C(0x80000000);
709 params->gemmlowp_wasmsimd.remainder_mask[0] = (int32_t) remainder_mask;
710 params->gemmlowp_wasmsimd.remainder_mask[1] = (int32_t) remainder_mask;
711 params->gemmlowp_wasmsimd.remainder_mask[2] = (int32_t) remainder_mask;
712 params->gemmlowp_wasmsimd.remainder_mask[3] = (int32_t) remainder_mask;
713 params->gemmlowp_wasmsimd.remainder_threshold[0] = (int32_t) remainder_threshold;
714 params->gemmlowp_wasmsimd.remainder_threshold[1] = (int32_t) remainder_threshold;
715 params->gemmlowp_wasmsimd.remainder_threshold[2] = (int32_t) remainder_threshold;
716 params->gemmlowp_wasmsimd.remainder_threshold[3] = (int32_t) remainder_threshold;
717 params->gemmlowp_wasmsimd.shift = shift;
718 for (uint32_t i = 0; i < 8; i++) {
719 params->gemmlowp_wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
720 }
721 for (uint32_t i = 0; i < 16; i++) {
722 params->gemmlowp_wasmsimd.output_min[i] = output_min;
723 params->gemmlowp_wasmsimd.output_max[i] = output_max;
724 }
725}
Marat Dukhan4741e412021-06-30 13:38:06 -0700726
727void xnn_init_qs8_conv_minmax_fp32_wasmsimd_params(
728 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
729 float scale,
730 int8_t output_zero_point,
731 int8_t output_min,
732 int8_t output_max)
733{
734 for (uint32_t i = 0; i < 4; i++) {
735 params->fp32_wasmsimd.scale[i] = scale;
736 params->fp32_wasmsimd.output_min_less_zero_point[i] = (float) ((int32_t) output_min - (int32_t) output_zero_point);
737 params->fp32_wasmsimd.output_max_less_zero_point[i] = (float) ((int32_t) output_max - (int32_t) output_zero_point);
738 params->fp32_wasmsimd.magic_bias[i] = 12582912.0f;
739 params->fp32_wasmsimd.magic_bias_less_output_zero_point[i] = INT32_C(0x4B400000) - (int32_t) output_zero_point;
740 }
741}
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700742#endif // XNN_ARCH_WASMSIMD
743
// Scatters per-channel fp32 scales into packed weights.
//
// Channels are processed in tiles of `channels_tile`. For each tile, the
// scales of that tile's channels are written as consecutive floats starting
// at the current `packed_w` position, which then advances by `stride` bytes.
// The last tile may be partial; only its valid scales are written.
//
// channels      - total number of channels (number of entries in `scale`).
// channels_tile - channels per tile; must be non-zero when channels > 0.
// stride        - byte distance between consecutive tiles in `packed_w`.
// scale         - per-channel scales to copy.
// packed_w      - destination; written as float at each tile start.
void xnn_init_qc8_scale_fp32_params(
  size_t channels,
  size_t channels_tile,
  size_t stride,
  const float scale[XNN_MIN_ELEMENTS(1)],
  void* packed_w)
{
  // With a zero tile size, tile_start would never advance and the loop below
  // would never terminate.
  assert(channels == 0 || channels_tile != 0);

  for (size_t tile_start = 0; tile_start < channels; tile_start += channels_tile) {
    const size_t tile_size = min(channels - tile_start, channels_tile);
    float* packed_scale = (float*) packed_w;
    for (size_t tile_offset = 0; tile_offset < tile_size; tile_offset++) {
      packed_scale[tile_offset] = scale[tile_start + tile_offset];
    }
    packed_w = (void*) ((uintptr_t) packed_w + stride);
  }
}
759
Marat Dukhand6021542021-06-30 09:04:20 -0700760void xnn_init_qs8_minmax_scalar_lrint_params(
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700761 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
762 int8_t output_zero_point,
763 int8_t output_min,
764 int8_t output_max)
765{
Marat Dukhand6021542021-06-30 09:04:20 -0700766 params->scalar_lrint.output_min_less_zero_point = (long) ((int32_t) output_min - (int32_t) output_zero_point);
767 params->scalar_lrint.output_max_less_zero_point = (long) ((int32_t) output_max - (int32_t) output_zero_point);
768 params->scalar_lrint.output_zero_point = (int32_t) output_zero_point;
769}
770
771void xnn_init_qs8_minmax_scalar_magic_params(
772 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
773 int8_t output_zero_point,
774 int8_t output_min,
775 int8_t output_max)
776{
777 params->scalar_magic.output_min_less_zero_point = (float) ((int32_t) output_min - (int32_t) output_zero_point);
778 params->scalar_magic.output_max_less_zero_point = (float) ((int32_t) output_max - (int32_t) output_zero_point);
779 params->scalar_magic.magic_bias = 12582912.0f;
780 params->scalar_magic.magic_bias_less_output_zero_point = INT32_C(0x4B400000) - (int32_t) output_zero_point;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700781}
782
783#if XNN_ARCH_X86 || XNN_ARCH_X86_64
784void xnn_init_qs8_minmax_sse2_params(
785 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
786 int8_t output_zero_point,
787 int8_t output_min,
788 int8_t output_max)
789{
790 for (uint32_t i = 0; i < 8; i++) {
791 params->sse2.output_zero_point[i] = (int16_t) output_zero_point;
792 params->sse2.output_min[i] = (int16_t) output_min;
793 params->sse2.output_max[i] = (int16_t) output_max;
794 }
795}
796
797void xnn_init_qs8_minmax_sse4_params(
798 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
799 int8_t output_zero_point,
800 int8_t output_min,
801 int8_t output_max)
802{
803 for (uint32_t i = 0; i < 8; i++) {
804 params->sse4.output_zero_point[i] = (int16_t) output_zero_point;
805 }
806 for (uint32_t i = 0; i < 16; i++) {
807 params->sse4.output_min[i] = output_min;
808 params->sse4.output_max[i] = output_max;
809 }
810}
811
812void xnn_init_qs8_minmax_avx2_params(
813 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
814 int8_t output_zero_point,
815 int8_t output_min,
816 int8_t output_max)
817{
818 for (uint32_t i = 0; i < 16; i++) {
819 params->avx2.output_zero_point[i] = (int16_t) output_zero_point;
820 }
821 for (uint32_t i = 0; i < 32; i++) {
822 params->avx2.output_min[i] = output_min;
823 params->avx2.output_max[i] = output_max;
824 }
825}
826
827void xnn_init_qs8_minmax_avx512_params(
828 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
829 int8_t output_zero_point,
830 int8_t output_min,
831 int8_t output_max)
832{
833 for (uint32_t i = 0; i < 32; i++) {
834 params->avx512.output_zero_point[i] = (int16_t) output_zero_point;
835 }
836 for (uint32_t i = 0; i < 64; i++) {
837 params->avx512.output_min[i] = output_min;
838 params->avx512.output_max[i] = output_max;
839 }
840}
841#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
842
843#if XNN_ARCH_ARM || XNN_ARCH_ARM64
844void xnn_init_qs8_minmax_neon_params(
845 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
846 int8_t output_zero_point,
847 int8_t output_min,
848 int8_t output_max)
849{
850 params->neon.output_zero_point = (int16_t) output_zero_point;
851 params->neon.output_min = output_min;
852 params->neon.output_max = output_max;
853}
854
855void xnn_init_qs8_minmax_neon_fp32_params(
856 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
857 int8_t output_zero_point,
858 int8_t output_min,
859 int8_t output_max)
860{
861 params->neon_fp32.output_min_less_zero_point = (float) ((int32_t) output_min - (int32_t) output_zero_point);
862 params->neon_fp32.output_max_less_zero_point = (float) ((int32_t) output_max - (int32_t) output_zero_point);
863 params->neon_fp32.magic_bias = 12582912.0f;
864 params->neon_fp32.magic_bias_less_zero_point = INT32_C(0x4B400000) - (int32_t) output_zero_point;
865}
866#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
867
868#if XNN_ARCH_WASMSIMD
869void xnn_init_qs8_minmax_wasmsimd_params(
870 union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
871 int8_t output_zero_point,
872 int8_t output_min,
873 int8_t output_max)
874{
Marat Dukhan47c12202021-06-30 15:09:34 -0700875 for (uint32_t i = 0; i < 4; i++) {
876 params->wasmsimd.output_min_less_zero_point[i] = (float) ((int32_t) output_min - (int32_t) output_zero_point);
877 params->wasmsimd.output_max_less_zero_point[i] = (float) ((int32_t) output_max - (int32_t) output_zero_point);
878 params->wasmsimd.magic_bias[i] = 12582912.0f;
879 params->wasmsimd.magic_bias_less_output_zero_point[i] = INT32_C(0x4B400000) - (int32_t) output_zero_point;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -0700880 }
881}
882#endif // XNN_ARCH_WASMSIMD
883
884void xnn_init_qu8_avgpool_params(
885 union xnn_qu8_avgpool_params params[XNN_MIN_ELEMENTS(1)],
886 int32_t bias,
887 float scale,
888 uint8_t output_zero_point,
889 uint8_t output_min,
890 uint8_t output_max)
891{
892 // Compute requantization parameters.
893 assert(scale >= 0x1.0p-32f);
894 assert(scale < 256.0f);
895 const uint32_t scale_bits = fp32_to_bits(scale);
896
897 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
898 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
899 assert(multiplier >= INT32_C(0x00800000));
900 assert(multiplier <= INT32_C(0x00FFFFFF));
901
902 // Shift is in [16, 55] range.
903 const int32_t shift = 127 + 23 - (scale_bits >> 23);
904 assert(shift >= 16);
905 assert(shift < 64);
906
907 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
908 const uint32_t right_shift = (uint32_t) shift;
909 const uint64_t rounding = UINT64_C(1) << (right_shift - 1);
910 params->sse2.bias[0] = bias;
911 params->sse2.bias[1] = bias;
912 params->sse2.bias[2] = bias;
913 params->sse2.bias[3] = bias;
914 params->sse2.multiplier[0] = (uint32_t) multiplier;
915 params->sse2.multiplier[1] = (uint32_t) multiplier;
916 params->sse2.multiplier[2] = (uint32_t) multiplier;
917 params->sse2.multiplier[3] = (uint32_t) multiplier;
918 params->sse2.rounding[0] = rounding;
919 params->sse2.rounding[1] = rounding;
920 params->sse2.right_shift[0] = (uint64_t) right_shift;
921 params->sse2.right_shift[1] = (uint64_t) right_shift;
922 for (uint32_t i = 0; i < 8; i++) {
923 params->sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
924 }
925 for (uint32_t i = 0; i < 16; i++) {
926 params->sse2.output_min[i] = output_min;
927 params->sse2.output_max[i] = output_max;
928 }
929 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
930 params->neon.bias = bias;
931 params->neon.multiplier = multiplier;
932 params->neon.left_shift = (int64_t) -shift;
933 params->neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
934 params->neon.output_min = output_min;
935 params->neon.output_max = output_max;
936 #else
937 const uint32_t right_shift = (uint32_t) shift;
938 const int64_t rounding = INT64_C(1) << (right_shift - 1);
939 params->scalar.bias = bias;
940 params->scalar.multiplier = multiplier;
941 params->scalar.rounding = rounding;
942 params->scalar.right_shift = right_shift;
943 params->scalar.output_min_less_zero_point =
944 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
945 params->scalar.output_max_less_zero_point =
946 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
947 params->scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
948 #endif
949}
950
951void xnn_init_scalar_qu8_avgpool_params(
952 union xnn_qu8_avgpool_params params[XNN_MIN_ELEMENTS(1)],
953 int32_t bias,
954 float scale,
955 uint8_t output_zero_point,
956 uint8_t output_min,
957 uint8_t output_max)
958{
959 // Compute requantization parameters.
960 assert(scale >= 0x1.0p-32f);
961 assert(scale < 256.0f);
962 const uint32_t scale_bits = fp32_to_bits(scale);
963
964 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
965 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
966 assert(multiplier >= INT32_C(0x00800000));
967 assert(multiplier <= INT32_C(0x00FFFFFF));
968
969 // Shift is in [16, 55] range.
970 const int32_t shift = 127 + 23 - (scale_bits >> 23);
971 assert(shift >= 16);
972 assert(shift < 64);
973
974 const uint32_t right_shift = (uint32_t) shift;
975 const int64_t rounding = INT64_C(1) << (right_shift - 1);
976 params->scalar.bias = bias;
977 params->scalar.rounding = rounding;
978 params->scalar.multiplier = multiplier;
979 params->scalar.right_shift = right_shift;
980 params->scalar.output_min_less_zero_point =
981 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
982 params->scalar.output_max_less_zero_point =
983 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
984 params->scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
985}
986
987void xnn_update_qu8_avgpool_params(
988 union xnn_qu8_avgpool_params* params,
989 int32_t bias,
990 float scale)
991{
992 // Compute requantization parameters.
993 assert(scale >= 0x1.0p-32f);
994 assert(scale < 256.0f);
995 const uint32_t scale_bits = fp32_to_bits(scale);
996
997 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
998 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
999 assert(multiplier >= INT32_C(0x00800000));
1000 assert(multiplier <= INT32_C(0x00FFFFFF));
1001
1002 // Shift is in [16, 55] range.
1003 const int32_t shift = 127 + 23 - (scale_bits >> 23);
1004 assert(shift >= 16);
1005 assert(shift < 64);
1006
1007 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1008 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
1009 params->sse2.bias[0] = bias;
1010 params->sse2.bias[1] = bias;
1011 params->sse2.bias[2] = bias;
1012 params->sse2.bias[3] = bias;
1013 params->sse2.multiplier[0] = (uint32_t) multiplier;
1014 params->sse2.multiplier[1] = (uint32_t) multiplier;
1015 params->sse2.multiplier[2] = (uint32_t) multiplier;
1016 params->sse2.multiplier[3] = (uint32_t) multiplier;
1017 params->sse2.rounding[0] = rounding;
1018 params->sse2.rounding[1] = rounding;
1019 params->sse2.right_shift[0] = (uint64_t) (uint32_t) shift;
1020 params->sse2.right_shift[1] = (uint64_t) (uint32_t) shift;
1021 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1022 params->neon.bias = bias;
1023 params->neon.multiplier = multiplier;
1024 params->neon.left_shift = (int64_t) -shift;
1025 #else
1026 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
1027 params->scalar.bias = bias;
1028 params->scalar.multiplier = multiplier;
1029 params->scalar.rounding = rounding;
1030 params->scalar.right_shift = (uint32_t) shift;
1031 #endif
1032}
1033
1034void xnn_init_qs8_avgpool_params(
1035 union xnn_qs8_avgpool_params params[XNN_MIN_ELEMENTS(1)],
1036 int32_t bias,
1037 float scale,
1038 int8_t output_zero_point,
1039 int8_t output_min,
1040 int8_t output_max)
1041{
1042 // Compute requantization parameters.
1043 assert(scale >= 0x1.0p-32f);
1044 assert(scale < 256.0f);
1045 const uint32_t scale_bits = fp32_to_bits(scale);
1046
1047 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
1048 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
1049 assert(multiplier >= INT32_C(0x00800000));
1050 assert(multiplier <= INT32_C(0x00FFFFFF));
1051
1052 // Shift is in [16, 55] range.
1053 const int32_t shift = 127 + 23 - (scale_bits >> 23);
1054 assert(shift >= 16);
1055 assert(shift < 64);
1056
1057 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1058 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
1059 params->sse2.bias[0] = bias;
1060 params->sse2.bias[1] = bias;
1061 params->sse2.bias[2] = bias;
1062 params->sse2.bias[3] = bias;
1063 params->sse2.multiplier[0] = (uint32_t) multiplier;
1064 params->sse2.multiplier[1] = (uint32_t) multiplier;
1065 params->sse2.multiplier[2] = (uint32_t) multiplier;
1066 params->sse2.multiplier[3] = (uint32_t) multiplier;
1067 params->sse2.rounding[0] = rounding;
1068 params->sse2.rounding[1] = rounding;
1069 params->sse2.shift[0] = (uint64_t) (uint32_t) shift;
1070 params->sse2.shift[1] = (uint64_t) (uint32_t) shift;
1071 for (uint32_t i = 0; i < 8; i++) {
1072 params->sse2.output_zero_point[i] = (int16_t) output_zero_point;
1073 params->sse2.output_min[i] = (int16_t) output_min;
1074 params->sse2.output_max[i] = (int16_t) output_max;
1075 }
1076 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1077 params->neon.bias = bias;
1078 params->neon.multiplier = multiplier;
1079 params->neon.left_shift = (int64_t) -shift;
1080 params->neon.output_zero_point = (int16_t) output_zero_point;
1081 params->neon.output_min = output_min;
1082 params->neon.output_max = output_max;
1083 #elif XNN_ARCH_WASMSIMD
1084 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
1085 params->wasmsimd.bias[0] = bias;
1086 params->wasmsimd.bias[1] = bias;
1087 params->wasmsimd.bias[2] = bias;
1088 params->wasmsimd.bias[3] = bias;
1089 params->wasmsimd.multiplier[0] = (int64_t) multiplier;
1090 params->wasmsimd.multiplier[1] = (int64_t) multiplier;
1091 params->wasmsimd.rounding[0] = rounding;
1092 params->wasmsimd.rounding[1] = rounding;
1093 params->wasmsimd.shift = shift;
1094 for (uint32_t i = 0; i < 8; i++) {
1095 params->wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
1096 }
1097 for (uint32_t i = 0; i < 16; i++) {
1098 params->wasmsimd.output_min[i] = output_min;
1099 params->wasmsimd.output_max[i] = output_max;
1100 }
1101 #else
1102 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
1103 params->scalar.bias = bias;
1104 params->scalar.multiplier = multiplier;
1105 params->scalar.rounding = rounding;
1106 params->scalar.shift = (uint32_t) shift;
1107 params->scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
1108 params->scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
1109 params->scalar.output_zero_point = (int32_t) output_zero_point;
1110 #endif
1111}
1112
1113void xnn_init_scalar_qs8_avgpool_params(
1114 union xnn_qs8_avgpool_params params[XNN_MIN_ELEMENTS(1)],
1115 int32_t bias,
1116 float scale,
1117 int8_t output_zero_point,
1118 int8_t output_min,
1119 int8_t output_max)
1120{
1121 // Compute requantization parameters.
1122 assert(scale >= 0x1.0p-32f);
1123 assert(scale < 256.0f);
1124 const uint32_t scale_bits = fp32_to_bits(scale);
1125
1126 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
1127 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
1128 assert(multiplier >= INT32_C(0x00800000));
1129 assert(multiplier <= INT32_C(0x00FFFFFF));
1130
1131 // Shift is in [16, 55] range.
1132 const int32_t shift = 127 + 23 - (scale_bits >> 23);
1133 assert(shift >= 16);
1134 assert(shift < 64);
1135
1136 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
1137 params->scalar.bias = bias;
1138 params->scalar.rounding = rounding;
1139 params->scalar.multiplier = multiplier;
1140 params->scalar.shift = shift;
1141 params->scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
1142 params->scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
1143 params->scalar.output_zero_point = (int32_t) output_zero_point;
1144}
1145
1146void xnn_update_qs8_avgpool_params(
1147 union xnn_qs8_avgpool_params* params,
1148 int32_t bias,
1149 float scale)
1150{
1151 // Compute requantization parameters.
1152 assert(scale >= 0x1.0p-32f);
1153 assert(scale < 256.0f);
1154 const uint32_t scale_bits = fp32_to_bits(scale);
1155
1156 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
1157 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
1158 assert(multiplier >= INT32_C(0x00800000));
1159 assert(multiplier <= INT32_C(0x00FFFFFF));
1160
1161 // Shift is in [16, 55] range.
1162 const int32_t shift = 127 + 23 - (scale_bits >> 23);
1163 assert(shift >= 16);
1164 assert(shift < 64);
1165
1166 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1167 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
1168 params->sse2.bias[0] = bias;
1169 params->sse2.bias[1] = bias;
1170 params->sse2.bias[2] = bias;
1171 params->sse2.bias[3] = bias;
1172 params->sse2.multiplier[0] = (uint32_t) multiplier;
1173 params->sse2.multiplier[1] = (uint32_t) multiplier;
1174 params->sse2.multiplier[2] = (uint32_t) multiplier;
1175 params->sse2.multiplier[3] = (uint32_t) multiplier;
1176 params->sse2.rounding[0] = rounding;
1177 params->sse2.rounding[1] = rounding;
1178 params->sse2.shift[0] = (uint64_t) (uint32_t) shift;
1179 params->sse2.shift[1] = (uint64_t) (uint32_t) shift;
1180 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1181 params->neon.bias = bias;
1182 params->neon.multiplier = multiplier;
1183 params->neon.left_shift = (int64_t) -shift;
1184 #elif XNN_ARCH_WASMSIMD
1185 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
1186 params->wasmsimd.bias[0] = bias;
1187 params->wasmsimd.bias[1] = bias;
1188 params->wasmsimd.bias[2] = bias;
1189 params->wasmsimd.bias[3] = bias;
1190 params->wasmsimd.multiplier[0] = (int64_t) multiplier;
1191 params->wasmsimd.multiplier[1] = (int64_t) multiplier;
1192 params->wasmsimd.rounding[0] = rounding;
1193 params->wasmsimd.rounding[1] = rounding;
1194 params->wasmsimd.shift = shift;
1195 #else
1196 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
1197 params->scalar.bias = bias;
1198 params->scalar.multiplier = multiplier;
1199 params->scalar.rounding = rounding;
1200 params->scalar.shift = (uint32_t) shift;
1201 #endif
1202}
1203
1204void xnn_update_f16_scaleminmax_params(
1205 struct xnn_f16_scaleminmax_params* params,
1206 uint16_t scale)
1207{
1208 params->scale = scale;
1209}
1210
1211void xnn_update_f32_scaleminmax_params(
1212 union xnn_f32_scaleminmax_params* params,
1213 float scale)
1214{
1215 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1216 for (uint32_t i = 0; i < 4; i++) {
1217 params->sse2.scale[i] = scale;
1218 }
1219 #else
1220 params->scalar.scale = scale;
1221 #endif
1222}
1223
1224void xnn_init_f16_scaleminmax_params(
1225 struct xnn_f16_scaleminmax_params params[XNN_MIN_ELEMENTS(1)],
1226 uint16_t scale,
1227 uint16_t min,
1228 uint16_t max)
1229{
1230 params->scale = scale;
1231 params->min = min;
1232 params->max = max;
1233 params->pad = 0; // unused.
1234}
1235
1236void xnn_init_f32_scaleminmax_params(
1237 union xnn_f32_scaleminmax_params params[XNN_MIN_ELEMENTS(1)],
1238 float scale,
1239 float min,
1240 float max)
1241{
1242 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1243 for (uint32_t i = 0; i < 4; i++) {
1244 params->sse2.scale[i] = scale;
1245 params->sse2.min[i] = min;
1246 params->sse2.max[i] = max;
1247 }
1248 #else
1249 params->scalar.scale = scale;
1250 params->scalar.min = min;
1251 params->scalar.max = max;
1252 #endif
1253}
1254
1255void xnn_init_f32_gavgpool_params(
1256 union xnn_f32_gavgpool_params params[XNN_MIN_ELEMENTS(1)],
1257 float multiplier,
1258 float output_min,
1259 float output_max,
1260 uint32_t width)
1261{
1262 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1263 for (uint32_t i = 0; i < 4; i++) {
1264 params->sse.multiplier[i] = multiplier;
1265 params->sse.output_min[i] = output_min;
1266 params->sse.output_max[i] = output_max;
1267 }
1268
1269 const uint32_t w = (width - 1) & 3;
1270 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
1271 params->sse.mask[1] = -(uint32_t) (w >= 1);
1272 params->sse.mask[2] = -(uint32_t) (w >= 2);
1273 params->sse.mask[3] = -(uint32_t) (w >= 3);
1274 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1275 params->neon.multiplier = multiplier;
1276 params->neon.output_min = output_min;
1277 params->neon.output_max = output_max;
1278
1279 const uint32_t w = (width - 1) & 3;
1280 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
1281 params->neon.mask[1] = -(uint32_t) (w >= 1);
1282 params->neon.mask[2] = -(uint32_t) (w >= 2);
1283 params->neon.mask[3] = -(uint32_t) (w >= 3);
1284 #else
1285 params->scalar.multiplier = multiplier;
1286 params->scalar.output_min = output_min;
1287 params->scalar.output_max = output_max;
1288
1289 const uint32_t w = (width - 1) & 3;
1290 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1291 params->scalar.mask[1] = -(int32_t) (w >= 1);
1292 params->scalar.mask[2] = -(int32_t) (w >= 2);
1293 params->scalar.mask[3] = -(int32_t) (w >= 3);
1294 #endif
1295}
1296
1297void xnn_update_f32_gavgpool_params(
1298 union xnn_f32_gavgpool_params* params,
1299 float multiplier,
1300 uint32_t width)
1301{
1302 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1303 for (uint32_t i = 0; i < 4; i++) {
1304 params->sse.multiplier[i] = multiplier;
1305 }
1306
1307 const uint32_t w = (width - 1) & 3;
1308 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
1309 params->sse.mask[1] = -(uint32_t) (w >= 1);
1310 params->sse.mask[2] = -(uint32_t) (w >= 2);
1311 params->sse.mask[3] = -(uint32_t) (w >= 3);
1312 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1313 params->neon.multiplier = multiplier;
1314
1315 const uint32_t w = (width - 1) & 3;
1316 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
1317 params->neon.mask[1] = -(uint32_t) (w >= 1);
1318 params->neon.mask[2] = -(uint32_t) (w >= 2);
1319 params->neon.mask[3] = -(uint32_t) (w >= 3);
1320 #else
1321 params->scalar.multiplier = multiplier;
1322
1323 const uint32_t w = (width - 1) & 3;
1324 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1325 params->scalar.mask[1] = -(int32_t) (w >= 1);
1326 params->scalar.mask[2] = -(int32_t) (w >= 2);
1327 params->scalar.mask[3] = -(int32_t) (w >= 3);
1328 #endif
1329}
1330
1331void xnn_init_scalar_f32_scaleminmax_params(
1332 union xnn_f32_scaleminmax_params params[XNN_MIN_ELEMENTS(1)],
1333 float scale,
1334 float min,
1335 float max)
1336{
1337 params->scalar.scale = scale;
1338 params->scalar.min = min;
1339 params->scalar.max = max;
1340}
1341
1342void xnn_init_scalar_f32_gavgpool_params(
1343 union xnn_f32_gavgpool_params params[XNN_MIN_ELEMENTS(1)],
1344 float multiplier,
1345 float output_min,
1346 float output_max,
1347 uint32_t width)
1348{
1349 params->scalar.multiplier = multiplier;
1350 params->scalar.output_min = output_min;
1351 params->scalar.output_max = output_max;
1352
1353 const uint32_t w = (width - 1) & 3;
1354 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1355 params->scalar.mask[1] = -(int32_t) (w >= 1);
1356 params->scalar.mask[2] = -(int32_t) (w >= 2);
1357 params->scalar.mask[3] = -(int32_t) (w >= 3);
1358}
1359
1360void xnn_init_f16_minmax_params(
1361 struct xnn_f16_minmax_params params[XNN_MIN_ELEMENTS(1)],
1362 uint16_t min,
1363 uint16_t max)
1364{
1365 params->min = min;
1366 params->max = max;
1367}
1368
1369void xnn_init_f32_minmax_params(
1370 union xnn_f32_minmax_params params[XNN_MIN_ELEMENTS(1)],
1371 float output_min,
1372 float output_max)
1373{
1374 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1375 for (uint32_t i = 0; i < 4; i++) {
1376 params->sse.min[i] = output_min;
1377 params->sse.max[i] = output_max;
1378 }
1379 #else
1380 params->scalar.min = output_min;
1381 params->scalar.max = output_max;
1382 #endif
1383}
1384
1385#if XNN_ARCH_X86 || XNN_ARCH_X86_64
1386void xnn_init_f32_minmax_sse_params(
1387 union xnn_f32_minmax_params params[XNN_MIN_ELEMENTS(1)],
1388 float output_min,
1389 float output_max)
1390{
1391 for (uint32_t i = 0; i < 4; i++) {
1392 params->sse.min[i] = output_min;
1393 params->sse.max[i] = output_max;
1394 }
1395}
1396
1397void xnn_init_f32_minmax_avx_params(
1398 union xnn_f32_minmax_params params[XNN_MIN_ELEMENTS(1)],
1399 float output_min,
1400 float output_max)
1401{
1402 for (uint32_t i = 0; i < 8; i++) {
1403 params->avx.min[i] = output_min;
1404 params->avx.max[i] = output_max;
1405 }
1406}
1407#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
1408
1409void xnn_init_f32_minmax_scalar_params(
1410 union xnn_f32_minmax_params params[XNN_MIN_ELEMENTS(1)],
1411 float output_min,
1412 float output_max)
1413{
1414 params->scalar.min = output_min;
1415 params->scalar.max = output_max;
1416}
1417
1418void xnn_init_f16_hswish_params(
1419 struct xnn_f16_hswish_params params[XNN_MIN_ELEMENTS(1)])
1420{
1421 params->sixth = UINT16_C(0x3155);
1422 params->three = UINT16_C(0x4200);
1423 params->six = UINT16_C(0x4600);
1424}
1425
1426void xnn_init_f32_hswish_params(
1427 union xnn_f32_hswish_params params[XNN_MIN_ELEMENTS(1)])
1428{
1429 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1430 for (uint32_t i = 0; i < 4; i++) {
1431 params->sse.sixth[i] = 0x1.555556p-3f;
1432 params->sse.half[i] = 0.5f;
1433 params->sse.one[i] = 1.0f;
1434 }
1435 #else
1436 params->scalar.sixth = 0x1.555556p-3f;
1437 params->scalar.three = 3.0f;
1438 params->scalar.six = 6.0f;
1439 #endif
1440}
1441
1442void xnn_init_scalar_f32_hswish_params(
1443 union xnn_f32_hswish_params params[XNN_MIN_ELEMENTS(1)])
1444{
1445 params->scalar.sixth = 0x1.555556p-3f;
1446 params->scalar.three = 3.0f;
1447 params->scalar.six = 6.0f;
1448}
1449
1450void xnn_init_f32_abs_params(
1451 union xnn_f32_abs_params params[XNN_MIN_ELEMENTS(1)])
1452{
1453 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1454 for (uint32_t i = 0; i < 4; i++) {
1455 params->sse.nonsign_mask[i] = math_nonsign_mask_f32();
1456 }
1457 #elif XNN_ARCH_WASMSIMD
1458 params->wasmsimd.nonsign_mask = math_nonsign_mask_f32();
1459 #endif
1460}
1461
1462void xnn_init_scalar_f32_abs_params(
1463 union xnn_f32_abs_params params[XNN_MIN_ELEMENTS(1)])
1464{
1465}
1466
1467void xnn_init_f32_neg_params(
1468 union xnn_f32_neg_params params[XNN_MIN_ELEMENTS(1)])
1469{
1470 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1471 for (uint32_t i = 0; i < 4; i++) {
1472 params->sse.sign_mask[i] = -0.0f;
1473 }
1474 #elif XNN_ARCH_WASMSIMD
1475 params->wasmsimd.sign_mask = -0.0f;
1476 #endif
1477}
1478
1479void xnn_init_scalar_f32_neg_params(
1480 union xnn_f32_neg_params params[XNN_MIN_ELEMENTS(1)])
1481{
1482}
1483
1484void xnn_init_f32_rnd_params(
1485 union xnn_f32_rnd_params params[XNN_MIN_ELEMENTS(1)])
1486{
1487 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1488 for (uint32_t i = 0; i < 4; i++) {
1489 params->sse2.sign_mask[i] = -0.0f;
1490 }
1491 for (uint32_t i = 0; i < 4; i++) {
1492 params->sse2.one[i] = 1.0f;
1493 }
1494 #endif
1495}
1496
1497void xnn_init_scalar_f32_rnd_params(
1498 union xnn_f32_rnd_params params[XNN_MIN_ELEMENTS(1)])
1499{
1500}
1501
1502void xnn_init_f32_elu_params(
1503 union xnn_f32_elu_params params[XNN_MIN_ELEMENTS(1)],
1504 float prescale,
1505 float alpha,
1506 float beta)
1507{
1508 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1509 for (uint32_t i = 0; i < 4; i++) {
1510 params->sse.prescale[i] = prescale;
1511 params->sse.alpha[i] = alpha;
1512 params->sse.beta[i] = beta;
1513 }
1514 #else
1515 params->scalar.prescale = prescale;
1516 params->scalar.alpha = alpha;
1517 params->scalar.beta = beta;
1518 #endif
1519}
1520
1521void xnn_init_scalar_f32_elu_params(
1522 union xnn_f32_elu_params params[XNN_MIN_ELEMENTS(1)],
1523 float prescale,
1524 float alpha,
1525 float beta)
1526{
1527 params->scalar.prescale = prescale;
1528 params->scalar.alpha = alpha;
1529 params->scalar.beta = beta;
1530}
1531
1532void xnn_init_f32_lrelu_params(
1533 union xnn_f32_lrelu_params params[XNN_MIN_ELEMENTS(1)],
1534 float slope)
1535{
1536 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1537 for (uint32_t i = 0; i < 4; i++) {
1538 params->sse.slope[i] = slope;
1539 }
1540 #else
1541 params->scalar.slope = slope;
1542 #endif
1543}
1544
1545void xnn_init_scalar_f32_lrelu_params(
1546 union xnn_f32_lrelu_params params[XNN_MIN_ELEMENTS(1)],
1547 float slope)
1548{
1549 params->scalar.slope = slope;
1550}
1551
1552void xnn_init_f32_sqrt_params(
1553 union xnn_f32_sqrt_params params[XNN_MIN_ELEMENTS(1)])
1554{
1555 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1556 params->fma.half = 0.5f;
1557 #endif
1558}
1559
1560void xnn_init_scalar_f32_sqrt_params(
1561 union xnn_f32_sqrt_params params[XNN_MIN_ELEMENTS(1)])
1562{
1563}
1564
1565void xnn_init_f32_chw_params(
1566 union xnn_f32_chw_params params[XNN_MIN_ELEMENTS(1)],
1567 uint32_t width,
1568 float output_min,
1569 float output_max)
1570{
1571 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1572 for (uint32_t i = 0; i < 4; i++) {
1573 params->sse.min[i] = output_min;
1574 params->sse.max[i] = output_max;
1575 }
1576
1577 const uint32_t w4 = (width - 1) & 3;
1578 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
1579 params->sse.mask[1] = -(uint32_t) (w4 >= 1);
1580 params->sse.mask[2] = -(uint32_t) (w4 >= 2);
1581 params->sse.mask[3] = -(uint32_t) (w4 >= 3);
1582
1583 const uint32_t w8 = (width - 1) & 7;
1584 params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
1585 params->sse.mask_even[1] = -(uint32_t) (w8 >= 2);
1586 params->sse.mask_even[2] = -(uint32_t) (w8 >= 4);
1587 params->sse.mask_even[3] = -(uint32_t) (w8 >= 6);
1588 params->sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
1589 params->sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
1590 params->sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
1591 params->sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
1592 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1593 params->neon.min = output_min;
1594 params->neon.max = output_max;
1595
1596 const uint32_t w4 = (width - 1) & 3;
1597 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
1598 params->neon.mask[1] = -(uint32_t) (w4 >= 1);
1599 params->neon.mask[2] = -(uint32_t) (w4 >= 2);
1600 params->neon.mask[3] = -(uint32_t) (w4 >= 3);
1601
1602 const uint32_t w8 = (width - 1) & 7;
1603 params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
1604 params->neon.mask_even[1] = -(uint32_t) (w8 >= 2);
1605 params->neon.mask_even[2] = -(uint32_t) (w8 >= 4);
1606 params->neon.mask_even[3] = -(uint32_t) (w8 >= 6);
1607 params->neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
1608 params->neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
1609 params->neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
1610 params->neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
1611 #else
1612 params->scalar.min = output_min;
1613 params->scalar.max = output_max;
1614
1615 const uint32_t w4 = (width - 1) & 3;
1616 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1617 params->scalar.mask[1] = -(uint32_t) (w4 >= 1);
1618 params->scalar.mask[2] = -(uint32_t) (w4 >= 2);
1619 params->scalar.mask[3] = -(uint32_t) (w4 >= 3);
1620
1621 const uint32_t w8 = (width - 1) & 7;
1622 params->scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1623 params->scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1624 params->scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1625 params->scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1626 params->scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1627 params->scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1628 params->scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1629 params->scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
1630 #endif
1631}
1632
1633void xnn_update_f32_chw_params(
1634 union xnn_f32_chw_params* params,
1635 uint32_t width)
1636{
1637 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1638 const uint32_t w4 = (width - 1) & 3;
1639 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
1640 params->sse.mask[1] = -(uint32_t) (w4 >= 1);
1641 params->sse.mask[2] = -(uint32_t) (w4 >= 2);
1642 params->sse.mask[3] = -(uint32_t) (w4 >= 3);
1643
1644 const uint32_t w8 = (width - 1) & 7;
1645 params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
1646 params->sse.mask_even[1] = -(uint32_t) (w8 >= 2);
1647 params->sse.mask_even[2] = -(uint32_t) (w8 >= 4);
1648 params->sse.mask_even[3] = -(uint32_t) (w8 >= 6);
1649 params->sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
1650 params->sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
1651 params->sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
1652 params->sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
1653 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1654 const uint32_t w4 = (width - 1) & 3;
1655 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
1656 params->neon.mask[1] = -(uint32_t) (w4 >= 1);
1657 params->neon.mask[2] = -(uint32_t) (w4 >= 2);
1658 params->neon.mask[3] = -(uint32_t) (w4 >= 3);
1659
1660 const uint32_t w8 = (width - 1) & 7;
1661 params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
1662 params->neon.mask_even[1] = -(uint32_t) (w8 >= 2);
1663 params->neon.mask_even[2] = -(uint32_t) (w8 >= 4);
1664 params->neon.mask_even[3] = -(uint32_t) (w8 >= 6);
1665 params->neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
1666 params->neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
1667 params->neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
1668 params->neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
1669 #else
1670 const uint32_t w4 = (width - 1) & 3;
1671 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1672 params->scalar.mask[1] = -(uint32_t) (w4 >= 1);
1673 params->scalar.mask[2] = -(uint32_t) (w4 >= 2);
1674 params->scalar.mask[3] = -(uint32_t) (w4 >= 3);
1675
1676 const uint32_t w8 = (width - 1) & 7;
1677 params->scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1678 params->scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1679 params->scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1680 params->scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1681 params->scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1682 params->scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1683 params->scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1684 params->scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
1685 #endif
1686}
1687
1688void xnn_init_scalar_f32_chw_params(
1689 union xnn_f32_chw_params params[XNN_MIN_ELEMENTS(1)],
1690 uint32_t width,
1691 float output_min,
1692 float output_max)
1693{
1694 params->scalar.min = output_min;
1695 params->scalar.max = output_max;
1696
1697 const uint32_t w4 = (width - 1) & 3;
1698 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1699 params->scalar.mask[1] = -(uint32_t) (w4 >= 1);
1700 params->scalar.mask[2] = -(uint32_t) (w4 >= 2);
1701 params->scalar.mask[3] = -(uint32_t) (w4 >= 3);
1702
1703 const uint32_t w8 = (width - 1) & 7;
1704 params->scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1705 params->scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1706 params->scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1707 params->scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1708 params->scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1709 params->scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1710 params->scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1711 params->scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
1712}
1713
1714void xnn_init_u8_minmax_params(
1715 union xnn_u8_minmax_params params[XNN_MIN_ELEMENTS(1)],
1716 uint8_t output_min,
1717 uint8_t output_max)
1718{
1719 assert(output_min < output_max);
1720
1721 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1722 for (uint32_t i = 0; i < 16; i++) {
1723 params->sse2.min[i] = output_min;
1724 params->sse2.max[i] = output_max;
1725 }
1726 #else
1727 params->scalar.min = output_min;
1728 params->scalar.max = output_max;
1729 #endif
1730}
1731
1732void xnn_init_scalar_u8_minmax_params(
1733 union xnn_u8_minmax_params params[XNN_MIN_ELEMENTS(1)],
1734 uint8_t output_min,
1735 uint8_t output_max)
1736{
1737 assert(output_min < output_max);
1738
1739 params->scalar.min = (int32_t) (uint32_t) output_min;
1740 params->scalar.max = (int32_t) (uint32_t) output_max;
1741}
1742
Marat Dukhan6e0fc392021-07-19 18:38:24 -07001743void xnn_init_qu8_add_minmax_params(
1744 union xnn_qu8_add_minmax_params params[XNN_MIN_ELEMENTS(1)],
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001745 uint8_t a_zero_point,
1746 uint8_t b_zero_point,
1747 uint8_t output_zero_point,
1748 float a_output_scale,
1749 float b_output_scale,
1750 uint8_t output_min,
1751 uint8_t output_max)
1752{
Marat Dukhand4c478b2021-07-19 19:31:43 -07001753 assert(a_output_scale >= 0x1.0p-10f);
1754 assert(b_output_scale >= 0x1.0p-10f);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001755 assert(a_output_scale < 0x1.0p+8f);
1756 assert(b_output_scale < 0x1.0p+8f);
1757
1758 // Compute requantization parameters.
1759 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
Marat Dukhand4c478b2021-07-19 19:31:43 -07001760 assert(max_output_scale >= 0x1.0p-10f);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001761 assert(max_output_scale < 0x1.0p+8f);
1762 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1763 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001764
Marat Dukhanf0ebd4b2021-07-19 20:16:17 -07001765 // Shift is in [12, 30] range.
1766 const uint32_t shift = (uint32_t) (20 /* multiplier bits */ - max_scale_exponent);
1767 assert(shift <= 30);
1768 assert(shift >= 12);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001769
Marat Dukhanf0ebd4b2021-07-19 20:16:17 -07001770 // Multipliers are in [0, 2**21) range, largest multiplier is in [2**20, 2**21) range.
1771 const int32_t a_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
1772 const int32_t b_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
1773 assert(math_max_s32(a_multiplier, b_multiplier) >= INT32_C(0x00100000));
1774 assert(a_multiplier < INT32_C(0x00200000));
1775 assert(b_multiplier < INT32_C(0x00200000));
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001776
1777 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1778 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1779 const uint32_t remainder_threshold = remainder_mask >> 1;
1780 const int32_t zero_point_product =
1781 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1782 for (uint32_t i = 0; i < 4; i++) {
1783 params->sse2.zero_point_product[i] = zero_point_product;
1784 }
1785 for (uint32_t i = 0; i < 8; i++) {
1786 params->sse2.y_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
1787 }
1788 for (uint32_t i = 0; i < 8; i++) {
1789 params->sse2.a_multiplier_lo[i] = (uint16_t) (uint32_t) a_multiplier;
1790 params->sse2.a_multiplier_hi[i] = (uint16_t) ((uint32_t) a_multiplier >> 16);
1791 params->sse2.b_multiplier_lo[i] = (uint16_t) (uint32_t) b_multiplier;
1792 params->sse2.b_multiplier_hi[i] = (uint16_t) ((uint32_t) b_multiplier >> 16);
1793 }
1794 params->sse2.a_multiplier = a_multiplier;
1795 params->sse2.b_multiplier = b_multiplier;
1796 for (uint32_t i = 0; i < 4; i++) {
1797 params->sse2.remainder_mask[i] = remainder_mask;
1798 params->sse2.remainder_threshold[i] = remainder_threshold;
1799 }
1800 params->sse2.shift = shift;
1801 for (uint32_t i = 0; i < 16; i++) {
1802 params->sse2.y_min[i] = output_min;
1803 params->sse2.y_max[i] = output_max;
1804 }
1805 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1806 params->neon.a_zero_point = a_zero_point;
1807 params->neon.b_zero_point = b_zero_point;
1808 params->neon.y_zero_point = (int16_t) (uint16_t) output_zero_point;
1809 params->neon.a_multiplier = (int32_t) a_multiplier;
1810 params->neon.b_multiplier = (int32_t) b_multiplier;
1811 params->neon.right_shift = (int32_t) -shift;
1812 params->neon.y_min = output_min;
1813 params->neon.y_max = output_max;
1814 #else
1815 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1816 const uint32_t remainder_threshold = remainder_mask >> 1;
1817 params->scalar.zero_point_product =
1818 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1819 params->scalar.a_multiplier = a_multiplier;
1820 params->scalar.b_multiplier = b_multiplier;
1821 params->scalar.remainder_mask = (int32_t) remainder_mask;
1822 params->scalar.remainder_threshold = (int32_t) remainder_threshold;
1823 params->scalar.shift = shift;
1824 params->scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
1825 params->scalar.y_min = (int32_t) (uint32_t) output_min;
1826 params->scalar.y_max = (int32_t) (uint32_t) output_max;
1827 #endif
1828}
1829
Marat Dukhan6e0fc392021-07-19 18:38:24 -07001830void xnn_init_qu8_add_minmax_scalar_params(
1831 union xnn_qu8_add_minmax_params params[XNN_MIN_ELEMENTS(1)],
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001832 uint8_t a_zero_point,
1833 uint8_t b_zero_point,
1834 uint8_t output_zero_point,
1835 float a_output_scale,
1836 float b_output_scale,
1837 uint8_t output_min,
1838 uint8_t output_max)
1839{
Marat Dukhand4c478b2021-07-19 19:31:43 -07001840 assert(a_output_scale >= 0x1.0p-10f);
1841 assert(b_output_scale >= 0x1.0p-10f);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001842 assert(a_output_scale < 0x1.0p+8f);
1843 assert(b_output_scale < 0x1.0p+8f);
1844
1845 // Compute requantization parameters.
1846 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
1847 assert(max_output_scale >= 0x1.0p-10f);
1848 assert(max_output_scale < 0x1.0p+8f);
1849 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1850 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001851
Marat Dukhanf0ebd4b2021-07-19 20:16:17 -07001852 // Shift is in [12, 30] range.
1853 const uint32_t shift = (uint32_t) (20 /* multiplier bits */ - max_scale_exponent);
1854 assert(shift <= 30);
1855 assert(shift >= 12);
1856
1857 // Multipliers are in [0, 2**21) range, largest multiplier is in [2**20, 2**21) range.
1858 const int32_t a_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
1859 const int32_t b_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
1860 assert(math_max_s32(a_multiplier, b_multiplier) >= INT32_C(0x00100000));
1861 assert(a_multiplier < INT32_C(0x00200000));
1862 assert(b_multiplier < INT32_C(0x00200000));
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001863
1864 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1865 const uint32_t remainder_threshold = remainder_mask >> 1;
1866 params->scalar.zero_point_product =
1867 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1868 params->scalar.a_multiplier = a_multiplier;
1869 params->scalar.b_multiplier = b_multiplier;
1870 params->scalar.remainder_mask = (int32_t) remainder_mask;
1871 params->scalar.remainder_threshold = (int32_t) remainder_threshold;
1872 params->scalar.shift = shift;
1873 params->scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
1874 params->scalar.y_min = (int32_t) (uint32_t) output_min;
1875 params->scalar.y_max = (int32_t) (uint32_t) output_max;
1876}
1877
Marat Dukhan66913242021-07-20 16:11:23 -07001878#if XNN_ARCH_X86 || XNN_ARCH_X86_64
1879void xnn_init_qs8_add_minmax_sse2_params(
Marat Dukhan6e0fc392021-07-19 18:38:24 -07001880 union xnn_qs8_add_minmax_params params[XNN_MIN_ELEMENTS(1)],
Marat Dukhan49d90052021-07-19 19:59:30 -07001881 int8_t a_zero_point,
1882 int8_t b_zero_point,
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001883 int8_t output_zero_point,
Marat Dukhan49d90052021-07-19 19:59:30 -07001884 float a_output_scale,
1885 float b_output_scale,
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001886 int8_t output_min,
1887 int8_t output_max)
1888{
Marat Dukhan49d90052021-07-19 19:59:30 -07001889 assert(a_output_scale >= 0x1.0p-10f);
1890 assert(b_output_scale >= 0x1.0p-10f);
1891 assert(a_output_scale < 0x1.0p+8f);
1892 assert(b_output_scale < 0x1.0p+8f);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001893
1894 // Compute requantization parameters.
Marat Dukhan49d90052021-07-19 19:59:30 -07001895 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
Marat Dukhand4c478b2021-07-19 19:31:43 -07001896 assert(max_output_scale >= 0x1.0p-10f);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001897 assert(max_output_scale < 0x1.0p+8f);
1898 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1899 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001900
Marat Dukhanf0ebd4b2021-07-19 20:16:17 -07001901 // Shift is in [12, 30] range.
1902 const uint32_t shift = (uint32_t) (20 /* multiplier bits */ - max_scale_exponent);
1903 assert(shift <= 30);
1904 assert(shift >= 12);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001905
Marat Dukhanf0ebd4b2021-07-19 20:16:17 -07001906 // Multipliers are in [0, 2**21) range, largest multiplier is in [2**20, 2**21) range.
1907 const int32_t a_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
1908 const int32_t b_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
1909 assert(math_max_s32(a_multiplier, b_multiplier) >= INT32_C(0x00100000));
1910 assert(a_multiplier < INT32_C(0x00200000));
1911 assert(b_multiplier < INT32_C(0x00200000));
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001912
Marat Dukhan66913242021-07-20 16:11:23 -07001913 const int32_t rounding = INT32_C(1) << (shift - 1);
1914 const int32_t bias = (int32_t) -(a_multiplier * (int32_t) a_zero_point + b_multiplier * (int32_t) b_zero_point);
1915 for (uint32_t i = 0; i < 4; i++) {
1916 params->sse2.bias[i] = bias;
1917 }
1918 const uint16_t a_multiplier_lo = (uint16_t) a_multiplier;
1919 const uint16_t a_multiplier_hi = (uint16_t) ((uint32_t) a_multiplier >> 16);
1920 const uint16_t b_multiplier_lo = (uint16_t) b_multiplier;
1921 const uint16_t b_multiplier_hi = (uint16_t) ((uint32_t) b_multiplier >> 16);
1922 for (uint32_t i = 0; i < 8; i++) {
1923 params->sse2.a_multiplier_lo[i] = a_multiplier_lo;
1924 params->sse2.a_multiplier_hi[i] = a_multiplier_hi;
1925 params->sse2.b_multiplier_lo[i] = b_multiplier_lo;
1926 params->sse2.b_multiplier_hi[i] = b_multiplier_hi;
1927 }
1928 params->sse2.shift = shift;
Marat Dukhan7679b1e2021-07-20 18:32:23 -07001929 params->sse2.b_multiplier = (uint32_t) b_multiplier;
Marat Dukhan66913242021-07-20 16:11:23 -07001930 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhan66913242021-07-20 16:11:23 -07001931 params->sse2.rounding[i] = rounding;
1932 }
1933 for (uint32_t i = 0; i < 8; i++) {
1934 params->sse2.output_zero_point[i] = (int16_t) output_zero_point;
1935 params->sse2.output_min[i] = (int16_t) output_min;
1936 params->sse2.output_max[i] = (int16_t) output_max;
1937 }
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07001938}
Marat Dukhan7679b1e2021-07-20 18:32:23 -07001939
1940void xnn_init_qs8_add_minmax_sse4_mul16_params(
1941 union xnn_qs8_add_minmax_params params[XNN_MIN_ELEMENTS(1)],
1942 int8_t a_zero_point,
1943 int8_t b_zero_point,
1944 int8_t output_zero_point,
1945 float a_output_scale,
1946 float b_output_scale,
1947 int8_t output_min,
1948 int8_t output_max)
1949{
1950 assert(a_output_scale >= 0x1.0p-10f);
1951 assert(b_output_scale >= 0x1.0p-10f);
1952 assert(a_output_scale < 0x1.0p+8f);
1953 assert(b_output_scale < 0x1.0p+8f);
1954
1955 // Compute requantization parameters.
1956 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
1957 assert(max_output_scale >= 0x1.0p-10f);
1958 assert(max_output_scale < 0x1.0p+8f);
1959 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1960 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1961
1962 // Shift is in [12, 30] range.
1963 const uint32_t shift = (uint32_t) (20 /* multiplier bits */ - max_scale_exponent);
1964 assert(shift <= 30);
1965 assert(shift >= 12);
1966
1967 // Multipliers are in [0, 2**21) range, largest multiplier is in [2**20, 2**21) range.
1968 const int32_t a_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
1969 const int32_t b_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
1970 assert(math_max_s32(a_multiplier, b_multiplier) >= INT32_C(0x00100000));
1971 assert(a_multiplier < INT32_C(0x00200000));
1972 assert(b_multiplier < INT32_C(0x00200000));
1973
1974 const int32_t rounding = INT32_C(1) << (shift - 1);
1975 const int32_t bias = (int32_t) -(a_multiplier * (int32_t) a_zero_point + b_multiplier * (int32_t) b_zero_point);
1976 for (uint32_t i = 0; i < 4; i++) {
1977 params->sse4_mul16.bias[i] = bias;
1978 }
1979 const uint16_t a_multiplier_lo = (uint16_t) a_multiplier;
1980 const uint16_t a_multiplier_hi = (uint16_t) ((uint32_t) a_multiplier >> 16);
1981 const uint16_t b_multiplier_lo = (uint16_t) b_multiplier;
1982 const uint16_t b_multiplier_hi = (uint16_t) ((uint32_t) b_multiplier >> 16);
1983 for (uint32_t i = 0; i < 8; i++) {
1984 params->sse4_mul16.a_multiplier_lo[i] = a_multiplier_lo;
1985 params->sse4_mul16.a_multiplier_hi[i] = a_multiplier_hi;
1986 params->sse4_mul16.b_multiplier_lo[i] = b_multiplier_lo;
1987 params->sse4_mul16.b_multiplier_hi[i] = b_multiplier_hi;
1988 }
1989 params->sse4_mul16.shift = shift;
1990 params->sse4_mul16.b_multiplier = (uint32_t) b_multiplier;
1991 for (uint32_t i = 0; i < 4; i++) {
1992 params->sse4_mul16.rounding[i] = rounding;
1993 }
1994 for (uint32_t i = 0; i < 8; i++) {
1995 params->sse4_mul16.output_zero_point[i] = (int16_t) output_zero_point;
1996 }
1997 for (uint32_t i = 0; i < 16; i++) {
1998 params->sse4_mul16.output_min[i] = output_min;
1999 params->sse4_mul16.output_max[i] = output_max;
2000 }
2001}
2002
2003void xnn_init_qs8_add_minmax_sse4_mul32_params(
2004 union xnn_qs8_add_minmax_params params[XNN_MIN_ELEMENTS(1)],
2005 int8_t a_zero_point,
2006 int8_t b_zero_point,
2007 int8_t output_zero_point,
2008 float a_output_scale,
2009 float b_output_scale,
2010 int8_t output_min,
2011 int8_t output_max)
2012{
2013 assert(a_output_scale >= 0x1.0p-10f);
2014 assert(b_output_scale >= 0x1.0p-10f);
2015 assert(a_output_scale < 0x1.0p+8f);
2016 assert(b_output_scale < 0x1.0p+8f);
2017
2018 // Compute requantization parameters.
2019 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
2020 assert(max_output_scale >= 0x1.0p-10f);
2021 assert(max_output_scale < 0x1.0p+8f);
2022 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
2023 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
2024
2025 // Shift is in [12, 30] range.
2026 const uint32_t shift = (uint32_t) (20 /* multiplier bits */ - max_scale_exponent);
2027 assert(shift <= 30);
2028 assert(shift >= 12);
2029
2030 // Multipliers are in [0, 2**21) range, largest multiplier is in [2**20, 2**21) range.
2031 const int32_t a_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
2032 const int32_t b_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
2033 assert(math_max_s32(a_multiplier, b_multiplier) >= INT32_C(0x00100000));
2034 assert(a_multiplier < INT32_C(0x00200000));
2035 assert(b_multiplier < INT32_C(0x00200000));
2036
2037 const int32_t rounding = INT32_C(1) << (shift - 1);
2038 const int32_t bias = (int32_t) -(a_multiplier * (int32_t) a_zero_point + b_multiplier * (int32_t) b_zero_point);
2039 for (uint32_t i = 0; i < 4; i++) {
2040 params->sse4_mul32.bias[i] = bias;
2041 params->sse4_mul32.a_multiplier[i] = a_multiplier;
2042 params->sse4_mul32.b_multiplier[i] = b_multiplier;
2043 params->sse4_mul32.rounding[i] = rounding;
2044 params->sse4_mul32.shift[i] = shift;
2045 }
2046 for (uint32_t i = 0; i < 8; i++) {
2047 params->sse4_mul32.output_zero_point[i] = (int16_t) output_zero_point;
2048 }
2049 for (uint32_t i = 0; i < 16; i++) {
2050 params->sse4_mul32.output_min[i] = output_min;
2051 params->sse4_mul32.output_max[i] = output_max;
2052 }
2053}
2054
2055void xnn_init_qs8_add_minmax_avx2_params(
2056 union xnn_qs8_add_minmax_params params[XNN_MIN_ELEMENTS(1)],
2057 int8_t a_zero_point,
2058 int8_t b_zero_point,
2059 int8_t output_zero_point,
2060 float a_output_scale,
2061 float b_output_scale,
2062 int8_t output_min,
2063 int8_t output_max)
2064{
2065 assert(a_output_scale >= 0x1.0p-10f);
2066 assert(b_output_scale >= 0x1.0p-10f);
2067 assert(a_output_scale < 0x1.0p+8f);
2068 assert(b_output_scale < 0x1.0p+8f);
2069
2070 // Compute requantization parameters.
2071 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
2072 assert(max_output_scale >= 0x1.0p-10f);
2073 assert(max_output_scale < 0x1.0p+8f);
2074 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
2075 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
2076
2077 // Shift is in [12, 30] range.
2078 const uint32_t shift = (uint32_t) (20 /* multiplier bits */ - max_scale_exponent);
2079 assert(shift <= 30);
2080 assert(shift >= 12);
2081
2082 // Multipliers are in [0, 2**21) range, largest multiplier is in [2**20, 2**21) range.
2083 const int32_t a_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
2084 const int32_t b_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
2085 assert(math_max_s32(a_multiplier, b_multiplier) >= INT32_C(0x00100000));
2086 assert(a_multiplier < INT32_C(0x00200000));
2087 assert(b_multiplier < INT32_C(0x00200000));
2088
2089 const int32_t rounding = INT32_C(1) << (shift - 1);
2090 const int32_t bias = (int32_t) -(a_multiplier * (int32_t) a_zero_point + b_multiplier * (int32_t) b_zero_point);
2091 for (uint32_t i = 0; i < 8; i++) {
2092 params->avx2.bias[i] = bias;
2093 params->avx2.a_multiplier[i] = a_multiplier;
2094 params->avx2.b_multiplier[i] = b_multiplier;
2095 params->avx2.rounding[i] = rounding;
2096 params->avx2.shift[i] = shift;
2097 }
2098 for (uint32_t i = 0; i < 16; i++) {
2099 params->avx2.output_zero_point[i] = (int16_t) output_zero_point;
2100 params->avx2.output_min[i] = output_min;
2101 params->avx2.output_max[i] = output_max;
2102 }
2103}
Marat Dukhan66913242021-07-20 16:11:23 -07002104#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
2105
2106#if XNN_ARCH_ARM || XNN_ARCH_ARM64
2107void xnn_init_qs8_add_minmax_neon_params(
2108 union xnn_qs8_add_minmax_params params[XNN_MIN_ELEMENTS(1)],
2109 int8_t a_zero_point,
2110 int8_t b_zero_point,
2111 int8_t output_zero_point,
2112 float a_output_scale,
2113 float b_output_scale,
2114 int8_t output_min,
2115 int8_t output_max)
2116{
2117 assert(a_output_scale >= 0x1.0p-10f);
2118 assert(b_output_scale >= 0x1.0p-10f);
2119 assert(a_output_scale < 0x1.0p+8f);
2120 assert(b_output_scale < 0x1.0p+8f);
2121
2122 // Compute requantization parameters.
2123 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
2124 assert(max_output_scale >= 0x1.0p-10f);
2125 assert(max_output_scale < 0x1.0p+8f);
2126 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
2127 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
2128
2129 // Shift is in [12, 30] range.
2130 const uint32_t shift = (uint32_t) (20 /* multiplier bits */ - max_scale_exponent);
2131 assert(shift <= 30);
2132 assert(shift >= 12);
2133
2134 // Multipliers are in [0, 2**21) range, largest multiplier is in [2**20, 2**21) range.
2135 const int32_t a_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
2136 const int32_t b_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
2137 assert(math_max_s32(a_multiplier, b_multiplier) >= INT32_C(0x00100000));
2138 assert(a_multiplier < INT32_C(0x00200000));
2139 assert(b_multiplier < INT32_C(0x00200000));
2140
2141 params->neon.a_zero_point = a_zero_point;
2142 params->neon.b_zero_point = b_zero_point;
2143 params->neon.a_multiplier = (int32_t) a_multiplier;
2144 params->neon.b_multiplier = (int32_t) b_multiplier;
2145 params->neon.right_shift = (int32_t) -shift;
2146 params->neon.output_zero_point = (int16_t) output_zero_point;
2147 params->neon.output_min = output_min;
2148 params->neon.output_max = output_max;
2149}
2150#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
2151
2152#if XNN_ARCH_WASMSIMD
2153void xnn_init_qs8_add_minmax_wasmsimd_params(
2154 union xnn_qs8_add_minmax_params params[XNN_MIN_ELEMENTS(1)],
2155 int8_t a_zero_point,
2156 int8_t b_zero_point,
2157 int8_t output_zero_point,
2158 float a_output_scale,
2159 float b_output_scale,
2160 int8_t output_min,
2161 int8_t output_max)
2162{
2163 assert(a_output_scale >= 0x1.0p-10f);
2164 assert(b_output_scale >= 0x1.0p-10f);
2165 assert(a_output_scale < 0x1.0p+8f);
2166 assert(b_output_scale < 0x1.0p+8f);
2167
2168 // Compute requantization parameters.
2169 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
2170 assert(max_output_scale >= 0x1.0p-10f);
2171 assert(max_output_scale < 0x1.0p+8f);
2172 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
2173 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
2174
2175 // Shift is in [12, 30] range.
2176 const uint32_t shift = (uint32_t) (20 /* multiplier bits */ - max_scale_exponent);
2177 assert(shift <= 30);
2178 assert(shift >= 12);
2179
2180 // Multipliers are in [0, 2**21) range, largest multiplier is in [2**20, 2**21) range.
2181 const int32_t a_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
2182 const int32_t b_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
2183 assert(math_max_s32(a_multiplier, b_multiplier) >= INT32_C(0x00100000));
2184 assert(a_multiplier < INT32_C(0x00200000));
2185 assert(b_multiplier < INT32_C(0x00200000));
2186
2187 const int32_t rounding = INT32_C(1) << (shift - 1);
2188 const int32_t bias = (int32_t) -(a_multiplier * (int32_t) a_zero_point + b_multiplier * (int32_t) b_zero_point);
2189 for (uint32_t i = 0; i < 4; i++) {
2190 params->wasmsimd.bias[i] = bias;
2191 params->wasmsimd.a_multiplier[i] = a_multiplier;
2192 params->wasmsimd.b_multiplier[i] = b_multiplier;
2193 params->wasmsimd.rounding[i] = rounding;
2194 }
2195 params->wasmsimd.shift = shift;
2196 for (uint32_t i = 0; i < 8; i++) {
2197 params->wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
2198 }
2199 for (uint32_t i = 0; i < 16; i++) {
2200 params->wasmsimd.output_min[i] = output_min;
2201 params->wasmsimd.output_max[i] = output_max;
2202 }
2203}
2204#endif // XNN_ARCH_WASMSIMD
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07002205
Marat Dukhan6e0fc392021-07-19 18:38:24 -07002206void xnn_init_qs8_add_minmax_scalar_params(
2207 union xnn_qs8_add_minmax_params params[XNN_MIN_ELEMENTS(1)],
Marat Dukhan49d90052021-07-19 19:59:30 -07002208 int8_t a_zero_point,
2209 int8_t b_zero_point,
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07002210 int8_t output_zero_point,
Marat Dukhan49d90052021-07-19 19:59:30 -07002211 float a_output_scale,
2212 float b_output_scale,
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07002213 int8_t output_min,
2214 int8_t output_max)
2215{
Marat Dukhan49d90052021-07-19 19:59:30 -07002216 assert(a_output_scale >= 0x1.0p-10f);
2217 assert(b_output_scale >= 0x1.0p-10f);
2218 assert(a_output_scale < 0x1.0p+8f);
2219 assert(b_output_scale < 0x1.0p+8f);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07002220
2221 // Compute requantization parameters.
Marat Dukhan49d90052021-07-19 19:59:30 -07002222 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07002223 assert(max_output_scale >= 0x1.0p-10f);
2224 assert(max_output_scale < 0x1.0p+8f);
2225 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
2226 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07002227
Marat Dukhanf0ebd4b2021-07-19 20:16:17 -07002228 // Shift is in [12, 30] range.
2229 const uint32_t shift = (uint32_t) (20 /* multiplier bits */ - max_scale_exponent);
2230 assert(shift <= 30);
2231 assert(shift >= 12);
2232
2233 // Multipliers are in [0, 2**21) range, largest multiplier is in [2**20, 2**21) range.
Marat Dukhan49d90052021-07-19 19:59:30 -07002234 const int32_t a_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
2235 const int32_t b_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
Marat Dukhanf0ebd4b2021-07-19 20:16:17 -07002236 assert(math_max_s32(a_multiplier, b_multiplier) >= INT32_C(0x00100000));
2237 assert(a_multiplier < INT32_C(0x00200000));
2238 assert(b_multiplier < INT32_C(0x00200000));
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07002239
Marat Dukhane6a48052021-07-19 20:46:35 -07002240 const int32_t rounding = INT32_C(1) << (shift - 1);
2241 params->scalar.bias = (int32_t) -(a_multiplier * (int32_t) a_zero_point + b_multiplier * (int32_t) b_zero_point);
Marat Dukhan49d90052021-07-19 19:59:30 -07002242 params->scalar.a_multiplier = a_multiplier;
2243 params->scalar.b_multiplier = b_multiplier;
Marat Dukhane6a48052021-07-19 20:46:35 -07002244 params->scalar.rounding = rounding;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07002245 params->scalar.shift = shift;
Marat Dukhane6a48052021-07-19 20:46:35 -07002246 params->scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
2247 params->scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07002248 params->scalar.output_zero_point = (int32_t) output_zero_point;
Marat Dukhanfcfdd2d2021-06-29 18:57:02 -07002249}