blob: 8fce26f19b152ff439d1bfce17861f8c4c45e583 [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8
9#pragma once
10
11#if defined(__cplusplus) && (__cplusplus >= 201103L)
12 #include <cstdint>
13 #include <cstddef>
14 #include <cassert>
15 #include <cmath>
16#else
17 #include <stdint.h>
18 #include <stddef.h>
19 #include <assert.h>
20 #include <math.h>
21#endif
22
23#include <fp16.h>
24
25#include <xnnpack/common.h>
Marat Dukhan2b9efd82020-06-08 01:09:31 -070026#include <xnnpack/math.h>
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070027#include <xnnpack/params.h>
28
29
Marat Dukhan08b7a972020-07-14 18:17:29 -070030static inline union xnn_qu8_gemm_params xnn_init_scalar_qu8_gemm_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070031 uint8_t kernel_zero_point,
32 float scale,
33 uint8_t output_zero_point,
34 uint8_t output_min,
35 uint8_t output_max)
36{
37 // Compute requantization parameters
38 const uint32_t scale_bits = fp32_to_bits(scale);
39
40 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
41 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
42 assert(multiplier >= INT32_C(0x40000000));
43 assert(multiplier <= INT32_C(0x7FFFFF80));
44
45 // Shift is in [0, 31] range.
46 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
47 assert(shift >= 0);
48 assert(shift < 32);
49
50 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
51 const uint32_t remainder_threshold = remainder_mask >> 1;
52
Marat Dukhan08b7a972020-07-14 18:17:29 -070053 union xnn_qu8_gemm_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070054 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
55 params.scalar.multiplier = multiplier;
56 params.scalar.remainder_mask = (int32_t) remainder_mask;
57 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
58 params.scalar.shift = (uint32_t) shift;
59 params.scalar.output_min_less_zero_point =
60 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
61 params.scalar.output_max_less_zero_point =
62 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
63 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
64 return params;
65}
66
Marat Dukhan08b7a972020-07-14 18:17:29 -070067static inline union xnn_qu8_gemm_params xnn_init_qu8_gemm_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070068 uint8_t kernel_zero_point,
69 float scale,
70 uint8_t output_zero_point,
71 uint8_t output_min,
72 uint8_t output_max)
73{
74 // Compute requantization parameters.
75 const uint32_t scale_bits = fp32_to_bits(scale);
76
77 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
78 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
79 assert(multiplier >= INT32_C(0x40000000));
80 assert(multiplier <= INT32_C(0x7FFFFF80));
81
82 // Shift is in [0, 31] range.
83 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
84 assert(shift >= 0);
85 assert(shift < 32);
86
Marat Dukhan08b7a972020-07-14 18:17:29 -070087 union xnn_qu8_gemm_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070088 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
89 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
90 const uint32_t remainder_threshold = remainder_mask >> 1;
91 for (uint32_t i = 0; i < 8; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070092 params.sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
93 }
94 params.sse2.multiplier[0] = multiplier;
95 params.sse2.multiplier[1] = multiplier;
96 params.sse2.multiplier[2] = multiplier;
97 params.sse2.multiplier[3] = multiplier;
98 params.sse2.rounding[0] = UINT64_C(0x40000000);
99 params.sse2.rounding[1] = UINT64_C(0x40000000);
100 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
101 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
102 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
103 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
104 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
105 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
106 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
107 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
108 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
109 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
110 for (uint32_t i = 0; i < 8; i++) {
111 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
112 }
113 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700114 params.sse2.output_min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700115 params.sse2.output_max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700116 }
117 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
Marat Dukhan91992462020-07-30 00:06:34 -0700118 params.neon.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700119 params.neon.multiplier = multiplier;
120 params.neon.right_shift = -shift;
121 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700122 params.neon.output_min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700123 params.neon.output_max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700124 #else
125 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
126 const uint32_t remainder_threshold = remainder_mask >> 1;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700127 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
128 params.scalar.multiplier = multiplier;
129 params.scalar.remainder_mask = (int32_t) remainder_mask;
130 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
131 params.scalar.shift = (uint32_t) shift;
132 params.scalar.output_min_less_zero_point =
133 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
134 params.scalar.output_max_less_zero_point =
135 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
136 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
137 #endif
138 return params;
139}
140
Marat Dukhan595e1702020-07-31 10:12:52 -0700141static inline union xnn_qs8_gemm_params xnn_init_scalar_qs8_gemm_params(
142 float scale,
143 int8_t output_zero_point,
144 int8_t output_min,
145 int8_t output_max)
146{
147 // Compute requantization parameters
148 const uint32_t scale_bits = fp32_to_bits(scale);
149
150 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
151 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
152 assert(multiplier >= INT32_C(0x40000000));
153 assert(multiplier <= INT32_C(0x7FFFFF80));
154
155 // Shift is in [0, 31] range.
156 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
157 assert(shift >= 0);
158 assert(shift < 32);
159
160 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
161 const uint32_t remainder_threshold = remainder_mask >> 1;
162
163 union xnn_qs8_gemm_params params;
164 params.scalar.multiplier = multiplier;
165 params.scalar.remainder_mask = (int32_t) remainder_mask;
166 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
167 params.scalar.shift = (uint32_t) shift;
168 params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
169 params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
170 params.scalar.output_zero_point = (int32_t) output_zero_point;
171 return params;
172}
173
174static inline union xnn_qs8_gemm_params xnn_init_qs8_gemm_params(
175 float scale,
176 int8_t output_zero_point,
177 int8_t output_min,
178 int8_t output_max)
179{
180 // Compute requantization parameters.
181 const uint32_t scale_bits = fp32_to_bits(scale);
182
183 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
184 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
185 assert(multiplier >= INT32_C(0x40000000));
186 assert(multiplier <= INT32_C(0x7FFFFF80));
187
188 // Shift is in [0, 31] range.
189 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
190 assert(shift >= 0);
191 assert(shift < 32);
192
193 union xnn_qs8_gemm_params params;
194 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
195 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
196 const uint32_t remainder_threshold = remainder_mask >> 1;
197 params.sse2.multiplier[0] = multiplier;
198 params.sse2.multiplier[1] = multiplier;
199 params.sse2.multiplier[2] = multiplier;
200 params.sse2.multiplier[3] = multiplier;
201 params.sse2.rounding[0] = UINT64_C(0x40000000);
202 params.sse2.rounding[1] = UINT64_C(0x40000000);
203 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
204 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
205 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
206 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
207 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
208 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
209 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
210 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
211 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
212 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
213 for (uint32_t i = 0; i < 8; i++) {
214 params.sse2.output_zero_point[i] = (int16_t) output_zero_point;
215 params.sse2.output_min[i] = (int16_t) output_min;
216 params.sse2.output_max[i] = (int16_t) output_max;
217 }
218 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
219 params.neon.multiplier = multiplier;
220 params.neon.right_shift = -shift;
221 params.neon.output_zero_point = (int16_t) output_zero_point;
222 params.neon.output_min = output_min;
223 params.neon.output_max = output_max;
Marat Dukhan27203da2020-08-05 15:19:03 -0700224 #elif XNN_ARCH_WASMSIMD
225 const int64_t twice_multiplier = INT64_C(2) * (int64_t) multiplier;
226 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
227 const uint32_t remainder_threshold = remainder_mask >> 1;
228 params.wasmsimd.multiplier[0] = twice_multiplier;
229 params.wasmsimd.multiplier[1] = twice_multiplier;
230 params.wasmsimd.rounding[0] = INT64_C(0x80000000);
231 params.wasmsimd.rounding[1] = INT64_C(0x80000000);
232 params.wasmsimd.remainder_mask[0] = (int32_t) remainder_mask;
233 params.wasmsimd.remainder_mask[1] = (int32_t) remainder_mask;
234 params.wasmsimd.remainder_mask[2] = (int32_t) remainder_mask;
235 params.wasmsimd.remainder_mask[3] = (int32_t) remainder_mask;
236 params.wasmsimd.remainder_threshold[0] = (int32_t) remainder_threshold;
237 params.wasmsimd.remainder_threshold[1] = (int32_t) remainder_threshold;
238 params.wasmsimd.remainder_threshold[2] = (int32_t) remainder_threshold;
239 params.wasmsimd.remainder_threshold[3] = (int32_t) remainder_threshold;
240 params.wasmsimd.shift = shift;
241 for (uint32_t i = 0; i < 8; i++) {
242 params.wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
243 }
244 for (uint32_t i = 0; i < 16; i++) {
245 params.wasmsimd.output_min[i] = output_min;
246 params.wasmsimd.output_max[i] = output_max;
247 }
Marat Dukhan595e1702020-07-31 10:12:52 -0700248 #else
249 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
250 const uint32_t remainder_threshold = remainder_mask >> 1;
251 params.scalar.multiplier = multiplier;
252 params.scalar.remainder_mask = (int32_t) remainder_mask;
253 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
254 params.scalar.shift = (uint32_t) shift;
255 params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
256 params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
257 params.scalar.output_zero_point = (int32_t) output_zero_point;
258 #endif
259 return params;
260}
261
Marat Dukhan683fab32020-08-03 19:42:52 -0700262static inline union xnn_qs8_gemm_xw_params xnn_init_scalar_qs8_gemm_xw_params(
263 float scale,
264 int8_t output_zero_point,
265 int8_t output_min,
266 int8_t output_max)
267{
268 union {
269 union xnn_qs8_gemm_xw_params gemm_xw;
270 union xnn_qs8_gemm_params gemm;
271 } params;
272 params.gemm = xnn_init_scalar_qs8_gemm_params(scale, output_zero_point, output_min, output_max);
273 return params.gemm_xw;
274}
275
276static inline union xnn_qs8_gemm_xw_params xnn_init_qs8_gemm_xw_params(
277 float scale,
278 int8_t output_zero_point,
279 int8_t output_min,
280 int8_t output_max)
281{
282 union {
283 union xnn_qs8_gemm_xw_params gemm_xw;
284 union xnn_qs8_gemm_params gemm;
285 } params;
286 params.gemm = xnn_init_qs8_gemm_params(scale, output_zero_point, output_min, output_max);
287 return params.gemm_xw;
288}
289
Marat Dukhan08b7a972020-07-14 18:17:29 -0700290static inline union xnn_qu8_avgpool_params xnn_init_qu8_avgpool_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700291 int32_t bias,
292 float scale,
293 uint8_t output_zero_point,
294 uint8_t output_min,
295 uint8_t output_max)
296{
297 // Compute requantization parameters.
298 assert(scale >= 0x1.0p-32f);
299 assert(scale < 256.0f);
300 const uint32_t scale_bits = fp32_to_bits(scale);
301
302 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
303 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
304 assert(multiplier >= INT32_C(0x00800000));
305 assert(multiplier <= INT32_C(0x00FFFFFF));
306
307 // Shift is in [16, 55] range.
308 const int32_t shift = 127 + 23 - (scale_bits >> 23);
309 assert(shift >= 16);
310 assert(shift < 64);
311
Marat Dukhan08b7a972020-07-14 18:17:29 -0700312 union xnn_qu8_avgpool_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700313 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
314 const uint32_t right_shift = (uint32_t) shift;
315 const uint64_t rounding = UINT64_C(1) << (right_shift - 1);
316 params.sse2.bias[0] = bias;
317 params.sse2.bias[1] = bias;
318 params.sse2.bias[2] = bias;
319 params.sse2.bias[3] = bias;
320 params.sse2.multiplier[0] = (uint32_t) multiplier;
321 params.sse2.multiplier[1] = (uint32_t) multiplier;
322 params.sse2.multiplier[2] = (uint32_t) multiplier;
323 params.sse2.multiplier[3] = (uint32_t) multiplier;
324 params.sse2.rounding[0] = rounding;
325 params.sse2.rounding[1] = rounding;
326 params.sse2.right_shift[0] = (uint64_t) right_shift;
327 params.sse2.right_shift[1] = (uint64_t) right_shift;
328 for (uint32_t i = 0; i < 8; i++) {
329 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
330 }
331 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700332 params.sse2.output_min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700333 params.sse2.output_max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700334 }
335 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
336 params.neon.bias = bias;
337 params.neon.multiplier = multiplier;
338 params.neon.left_shift = (int64_t) -shift;
339 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700340 params.neon.output_min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700341 params.neon.output_max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700342 #else
343 const uint32_t right_shift = (uint32_t) shift;
344 const int64_t rounding = INT64_C(1) << (right_shift - 1);
345 params.scalar.bias = bias;
346 params.scalar.multiplier = multiplier;
347 params.scalar.rounding = rounding;
348 params.scalar.right_shift = right_shift;
349 params.scalar.output_min_less_zero_point =
350 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
351 params.scalar.output_max_less_zero_point =
352 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
353 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
354 #endif
355 return params;
356}
357
Marat Dukhan08b7a972020-07-14 18:17:29 -0700358static inline union xnn_qu8_avgpool_params xnn_init_scalar_qu8_avgpool_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700359 int32_t bias,
360 float scale,
361 uint8_t output_zero_point,
362 uint8_t output_min,
363 uint8_t output_max)
364{
365 // Compute requantization parameters.
366 assert(scale >= 0x1.0p-32f);
367 assert(scale < 256.0f);
368 const uint32_t scale_bits = fp32_to_bits(scale);
369
370 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
371 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
372 assert(multiplier >= INT32_C(0x00800000));
373 assert(multiplier <= INT32_C(0x00FFFFFF));
374
375 // Shift is in [16, 55] range.
376 const int32_t shift = 127 + 23 - (scale_bits >> 23);
377 assert(shift >= 16);
378 assert(shift < 64);
379
Marat Dukhan08b7a972020-07-14 18:17:29 -0700380 union xnn_qu8_avgpool_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700381 const uint32_t right_shift = (uint32_t) shift;
382 const int64_t rounding = INT64_C(1) << (right_shift - 1);
383 params.scalar.bias = bias;
384 params.scalar.rounding = rounding;
385 params.scalar.multiplier = multiplier;
386 params.scalar.right_shift = right_shift;
387 params.scalar.output_min_less_zero_point =
388 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
389 params.scalar.output_max_less_zero_point =
390 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
391 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
392 return params;
393}
394
Marat Dukhan54e95a02020-08-06 23:55:13 -0700395static inline void xnn_update_qu8_avgpool_params(
396 union xnn_qu8_avgpool_params* params,
397 int32_t bias,
398 float scale)
399{
400 // Compute requantization parameters.
401 assert(scale >= 0x1.0p-32f);
402 assert(scale < 256.0f);
403 const uint32_t scale_bits = fp32_to_bits(scale);
404
405 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
406 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
407 assert(multiplier >= INT32_C(0x00800000));
408 assert(multiplier <= INT32_C(0x00FFFFFF));
409
410 // Shift is in [16, 55] range.
411 const int32_t shift = 127 + 23 - (scale_bits >> 23);
412 assert(shift >= 16);
413 assert(shift < 64);
414
415 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
416 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
417 params->sse2.bias[0] = bias;
418 params->sse2.bias[1] = bias;
419 params->sse2.bias[2] = bias;
420 params->sse2.bias[3] = bias;
421 params->sse2.multiplier[0] = (uint32_t) multiplier;
422 params->sse2.multiplier[1] = (uint32_t) multiplier;
423 params->sse2.multiplier[2] = (uint32_t) multiplier;
424 params->sse2.multiplier[3] = (uint32_t) multiplier;
425 params->sse2.rounding[0] = rounding;
426 params->sse2.rounding[1] = rounding;
427 params->sse2.right_shift[0] = (uint64_t) (uint32_t) shift;
428 params->sse2.right_shift[1] = (uint64_t) (uint32_t) shift;
429 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
430 params->neon.bias = bias;
431 params->neon.multiplier = multiplier;
432 params->neon.left_shift = (int64_t) -shift;
433 #else
434 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
435 params->scalar.bias = bias;
436 params->scalar.multiplier = multiplier;
437 params->scalar.rounding = rounding;
438 params->scalar.right_shift = (uint32_t) shift;
439 #endif
440}
441
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700442static inline union xnn_qs8_avgpool_params xnn_init_qs8_avgpool_params(
443 int32_t bias,
444 float scale,
445 int8_t output_zero_point,
446 int8_t output_min,
447 int8_t output_max)
448{
449 // Compute requantization parameters.
450 assert(scale >= 0x1.0p-32f);
451 assert(scale < 256.0f);
452 const uint32_t scale_bits = fp32_to_bits(scale);
453
454 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
455 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
456 assert(multiplier >= INT32_C(0x00800000));
457 assert(multiplier <= INT32_C(0x00FFFFFF));
458
459 // Shift is in [16, 55] range.
460 const int32_t shift = 127 + 23 - (scale_bits >> 23);
461 assert(shift >= 16);
462 assert(shift < 64);
463
464 union xnn_qs8_avgpool_params params;
465 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
Marat Dukhanef451802020-08-06 11:53:47 -0700466 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700467 params.sse2.bias[0] = bias;
468 params.sse2.bias[1] = bias;
469 params.sse2.bias[2] = bias;
470 params.sse2.bias[3] = bias;
471 params.sse2.multiplier[0] = (uint32_t) multiplier;
472 params.sse2.multiplier[1] = (uint32_t) multiplier;
473 params.sse2.multiplier[2] = (uint32_t) multiplier;
474 params.sse2.multiplier[3] = (uint32_t) multiplier;
475 params.sse2.rounding[0] = rounding;
476 params.sse2.rounding[1] = rounding;
Marat Dukhanef451802020-08-06 11:53:47 -0700477 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
478 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700479 for (uint32_t i = 0; i < 8; i++) {
480 params.sse2.output_zero_point[i] = (int16_t) output_zero_point;
481 params.sse2.output_min[i] = (int16_t) output_min;
482 params.sse2.output_max[i] = (int16_t) output_max;
483 }
484 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
485 params.neon.bias = bias;
486 params.neon.multiplier = multiplier;
487 params.neon.left_shift = (int64_t) -shift;
488 params.neon.output_zero_point = (int16_t) output_zero_point;
489 params.neon.output_min = output_min;
490 params.neon.output_max = output_max;
Marat Dukhanef451802020-08-06 11:53:47 -0700491 #elif XNN_ARCH_WASMSIMD
492 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
493 params.wasmsimd.bias[0] = bias;
494 params.wasmsimd.bias[1] = bias;
495 params.wasmsimd.bias[2] = bias;
496 params.wasmsimd.bias[3] = bias;
497 params.wasmsimd.multiplier[0] = (int64_t) multiplier;
498 params.wasmsimd.multiplier[1] = (int64_t) multiplier;
499 params.wasmsimd.rounding[0] = rounding;
500 params.wasmsimd.rounding[1] = rounding;
501 params.wasmsimd.shift = shift;
502 for (uint32_t i = 0; i < 8; i++) {
503 params.wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
504 }
505 for (uint32_t i = 0; i < 16; i++) {
506 params.wasmsimd.output_min[i] = output_min;
507 params.wasmsimd.output_max[i] = output_max;
508 }
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700509 #else
Marat Dukhanef451802020-08-06 11:53:47 -0700510 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700511 params.scalar.bias = bias;
512 params.scalar.multiplier = multiplier;
513 params.scalar.rounding = rounding;
Marat Dukhanef451802020-08-06 11:53:47 -0700514 params.scalar.shift = (uint32_t) shift;
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700515 params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
516 params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
517 params.scalar.output_zero_point = (int32_t) output_zero_point;
518 #endif
519 return params;
520}
521
522static inline union xnn_qs8_avgpool_params xnn_init_scalar_qs8_avgpool_params(
523 int32_t bias,
524 float scale,
525 int8_t output_zero_point,
526 int8_t output_min,
527 int8_t output_max)
528{
529 // Compute requantization parameters.
530 assert(scale >= 0x1.0p-32f);
531 assert(scale < 256.0f);
532 const uint32_t scale_bits = fp32_to_bits(scale);
533
534 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
535 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
536 assert(multiplier >= INT32_C(0x00800000));
537 assert(multiplier <= INT32_C(0x00FFFFFF));
538
539 // Shift is in [16, 55] range.
540 const int32_t shift = 127 + 23 - (scale_bits >> 23);
541 assert(shift >= 16);
542 assert(shift < 64);
543
544 union xnn_qs8_avgpool_params params;
Marat Dukhanef451802020-08-06 11:53:47 -0700545 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700546 params.scalar.bias = bias;
547 params.scalar.rounding = rounding;
548 params.scalar.multiplier = multiplier;
Marat Dukhanef451802020-08-06 11:53:47 -0700549 params.scalar.shift = shift;
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700550 params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
551 params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
552 params.scalar.output_zero_point = (int32_t) output_zero_point;
553 return params;
554}
555
Marat Dukhan9e0b5392020-08-07 02:29:34 -0700556static inline void xnn_update_qs8_avgpool_params(
557 union xnn_qs8_avgpool_params* params,
558 int32_t bias,
559 float scale)
560{
561 // Compute requantization parameters.
562 assert(scale >= 0x1.0p-32f);
563 assert(scale < 256.0f);
564 const uint32_t scale_bits = fp32_to_bits(scale);
565
566 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
567 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
568 assert(multiplier >= INT32_C(0x00800000));
569 assert(multiplier <= INT32_C(0x00FFFFFF));
570
571 // Shift is in [16, 55] range.
572 const int32_t shift = 127 + 23 - (scale_bits >> 23);
573 assert(shift >= 16);
574 assert(shift < 64);
575
576 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
577 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
578 params->sse2.bias[0] = bias;
579 params->sse2.bias[1] = bias;
580 params->sse2.bias[2] = bias;
581 params->sse2.bias[3] = bias;
582 params->sse2.multiplier[0] = (uint32_t) multiplier;
583 params->sse2.multiplier[1] = (uint32_t) multiplier;
584 params->sse2.multiplier[2] = (uint32_t) multiplier;
585 params->sse2.multiplier[3] = (uint32_t) multiplier;
586 params->sse2.rounding[0] = rounding;
587 params->sse2.rounding[1] = rounding;
588 params->sse2.shift[0] = (uint64_t) (uint32_t) shift;
589 params->sse2.shift[1] = (uint64_t) (uint32_t) shift;
590 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
591 params->neon.bias = bias;
592 params->neon.multiplier = multiplier;
593 params->neon.left_shift = (int64_t) -shift;
Marat Dukhan4de076d2020-08-08 23:29:52 -0700594 #elif XNN_ARCH_WASMSIMD
595 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
596 params->wasmsimd.bias[0] = bias;
597 params->wasmsimd.bias[1] = bias;
598 params->wasmsimd.bias[2] = bias;
599 params->wasmsimd.bias[3] = bias;
600 params->wasmsimd.multiplier[0] = (int64_t) multiplier;
601 params->wasmsimd.multiplier[1] = (int64_t) multiplier;
602 params->wasmsimd.rounding[0] = rounding;
603 params->wasmsimd.rounding[1] = rounding;
604 params->wasmsimd.shift = shift;
Marat Dukhan9e0b5392020-08-07 02:29:34 -0700605 #else
606 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
607 params->scalar.bias = bias;
608 params->scalar.multiplier = multiplier;
609 params->scalar.rounding = rounding;
610 params->scalar.shift = (uint32_t) shift;
611 #endif
612}
613
Frank Barchard7e2cbb02020-06-12 01:22:13 -0700614static inline void xnn_update_f16_scaleminmax_params(
615 struct xnn_f16_scaleminmax_params* params,
616 uint16_t scale)
617{
618 params->scale = scale;
619}
620
Marat Dukhan8452ff52020-04-08 20:44:58 -0700621static inline void xnn_update_f32_scaleminmax_params(
622 union xnn_f32_scaleminmax_params* params,
623 float scale)
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700624{
625 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
626 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhan8452ff52020-04-08 20:44:58 -0700627 params->sse2.scale[i] = scale;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700628 }
629 #else
Marat Dukhan8452ff52020-04-08 20:44:58 -0700630 params->scalar.scale = scale;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700631 #endif
632}
633
Frank Barchard7e2cbb02020-06-12 01:22:13 -0700634static inline struct xnn_f16_scaleminmax_params xnn_init_f16_scaleminmax_params(
635 uint16_t scale,
636 uint16_t min,
637 uint16_t max)
638{
639 struct xnn_f16_scaleminmax_params params;
640 params.scale = scale;
641 params.min = min;
642 params.max = max;
643 return params;
644}
645
Marat Dukhan8452ff52020-04-08 20:44:58 -0700646static inline union xnn_f32_scaleminmax_params xnn_init_f32_scaleminmax_params(
647 float scale,
648 float min,
649 float max)
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700650{
Marat Dukhan8452ff52020-04-08 20:44:58 -0700651 union xnn_f32_scaleminmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700652 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
653 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhan8452ff52020-04-08 20:44:58 -0700654 params.sse2.scale[i] = scale;
655 params.sse2.min[i] = min;
656 params.sse2.max[i] = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700657 }
658 #else
Marat Dukhan8452ff52020-04-08 20:44:58 -0700659 params.scalar.scale = scale;
660 params.scalar.min = min;
661 params.scalar.max = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700662 #endif
663 return params;
664}
665
666static inline union xnn_f32_gavgpool_params xnn_init_f32_gavgpool_params(
667 float multiplier,
668 float output_min,
669 float output_max,
670 uint32_t width)
671{
672 union xnn_f32_gavgpool_params params;
673 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
674 for (uint32_t i = 0; i < 4; i++) {
675 params.sse.multiplier[i] = multiplier;
676 params.sse.output_min[i] = output_min;
677 params.sse.output_max[i] = output_max;
678 }
679
680 const uint32_t w = (width - 1) & 3;
681 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
682 params.sse.mask[1] = -(uint32_t) (w >= 1);
683 params.sse.mask[2] = -(uint32_t) (w >= 2);
684 params.sse.mask[3] = -(uint32_t) (w >= 3);
685 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
686 params.neon.multiplier = multiplier;
687 params.neon.output_min = output_min;
688 params.neon.output_max = output_max;
689
690 const uint32_t w = (width - 1) & 3;
691 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
692 params.neon.mask[1] = -(uint32_t) (w >= 1);
693 params.neon.mask[2] = -(uint32_t) (w >= 2);
694 params.neon.mask[3] = -(uint32_t) (w >= 3);
695 #else
696 params.scalar.multiplier = multiplier;
697 params.scalar.output_min = output_min;
698 params.scalar.output_max = output_max;
Erich Elsen6f278b52020-06-10 16:13:11 -0700699
700 const uint32_t w = (width - 1) & 3;
701 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
702 params.scalar.mask[1] = -(int32_t) (w >= 1);
703 params.scalar.mask[2] = -(int32_t) (w >= 2);
704 params.scalar.mask[3] = -(int32_t) (w >= 3);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700705 #endif
706 return params;
707}
708
709static inline void xnn_update_f32_gavgpool_params(
710 union xnn_f32_gavgpool_params* params,
711 float multiplier,
712 uint32_t width)
713{
714 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
715 for (uint32_t i = 0; i < 4; i++) {
716 params->sse.multiplier[i] = multiplier;
717 }
718
719 const uint32_t w = (width - 1) & 3;
720 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
721 params->sse.mask[1] = -(uint32_t) (w >= 1);
722 params->sse.mask[2] = -(uint32_t) (w >= 2);
723 params->sse.mask[3] = -(uint32_t) (w >= 3);
724 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
725 params->neon.multiplier = multiplier;
726
727 const uint32_t w = (width - 1) & 3;
728 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
729 params->neon.mask[1] = -(uint32_t) (w >= 1);
730 params->neon.mask[2] = -(uint32_t) (w >= 2);
731 params->neon.mask[3] = -(uint32_t) (w >= 3);
732 #else
733 params->scalar.multiplier = multiplier;
Erich Elsen6f278b52020-06-10 16:13:11 -0700734
735 const uint32_t w = (width - 1) & 3;
736 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
Marat Dukhan548e8302020-08-04 15:03:46 -0700737 params->scalar.mask[1] = -(int32_t) (w >= 1);
738 params->scalar.mask[2] = -(int32_t) (w >= 2);
739 params->scalar.mask[3] = -(int32_t) (w >= 3);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700740 #endif
741}
742
Marat Dukhan8452ff52020-04-08 20:44:58 -0700743static inline union xnn_f32_scaleminmax_params xnn_init_scalar_f32_scaleminmax_params(
744 float scale,
745 float min,
746 float max)
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700747{
Marat Dukhan8452ff52020-04-08 20:44:58 -0700748 union xnn_f32_scaleminmax_params params;
749 params.scalar.scale = scale;
750 params.scalar.min = min;
751 params.scalar.max = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700752 return params;
753}
754
755static inline union xnn_f32_gavgpool_params xnn_init_scalar_f32_gavgpool_params(
756 float multiplier,
757 float output_min,
758 float output_max,
759 uint32_t width)
760{
761 union xnn_f32_gavgpool_params params;
762 params.scalar.multiplier = multiplier;
763 params.scalar.output_min = output_min;
764 params.scalar.output_max = output_max;
Erich Elsen6f278b52020-06-10 16:13:11 -0700765
766 const uint32_t w = (width - 1) & 3;
767 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
768 params.scalar.mask[1] = -(int32_t) (w >= 1);
769 params.scalar.mask[2] = -(int32_t) (w >= 2);
770 params.scalar.mask[3] = -(int32_t) (w >= 3);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700771 return params;
772}
773
Frank Barchardd793f6c2020-05-08 13:37:43 -0700774static inline struct xnn_f16_minmax_params xnn_init_f16_minmax_params(
775 uint16_t min,
776 uint16_t max)
777{
778 struct xnn_f16_minmax_params params;
779 params.min = min;
780 params.max = max;
781 return params;
782}
783
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700784static inline union xnn_f32_minmax_params xnn_init_f32_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700785 float output_min,
786 float output_max)
787{
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700788 union xnn_f32_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700789 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
790 for (uint32_t i = 0; i < 4; i++) {
791 params.sse.min[i] = output_min;
792 params.sse.max[i] = output_max;
793 }
794 #else
795 params.scalar.min = output_min;
796 params.scalar.max = output_max;
797 #endif
798 return params;
799}
800
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700801static inline union xnn_f32_minmax_params xnn_init_scalar_f32_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700802 float output_min,
803 float output_max)
804{
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700805 union xnn_f32_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700806 params.scalar.min = output_min;
807 params.scalar.max = output_max;
808 return params;
809}
810
Frank Barchardb1966592020-05-12 13:47:06 -0700811static inline struct xnn_f16_hswish_params xnn_init_f16_hswish_params(void)
812{
813 struct xnn_f16_hswish_params params;
Marat Dukhan201ea0e2020-07-10 22:47:19 -0700814 params.sixth = UINT16_C(0x3155);
815 params.three = UINT16_C(0x4200);
816 params.six = UINT16_C(0x4600);
Frank Barchardb1966592020-05-12 13:47:06 -0700817 return params;
818}
819
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700820static inline union xnn_f32_hswish_params xnn_init_f32_hswish_params(void)
821{
822 union xnn_f32_hswish_params params;
823 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
824 for (uint32_t i = 0; i < 4; i++) {
825 params.sse.sixth[i] = 0x1.555556p-3f;
826 params.sse.half[i] = 0.5f;
827 params.sse.one[i] = 1.0f;
828 }
829 #else
830 params.scalar.sixth = 0x1.555556p-3f;
Marat Dukhan9df9dc62020-07-10 20:08:49 -0700831 params.scalar.three = 3.0f;
832 params.scalar.six = 6.0f;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700833 #endif
834 return params;
835}
836
837static inline union xnn_f32_hswish_params xnn_init_scalar_f32_hswish_params(void)
838{
839 union xnn_f32_hswish_params params;
840 params.scalar.sixth = 0x1.555556p-3f;
Marat Dukhan9df9dc62020-07-10 20:08:49 -0700841 params.scalar.three = 3.0f;
842 params.scalar.six = 6.0f;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700843 return params;
844}
845
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700846static inline union xnn_f32_abs_params xnn_init_f32_abs_params(void)
847{
848 union xnn_f32_abs_params params = { 0 };
849 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
850 for (uint32_t i = 0; i < 4; i++) {
851 params.sse.nonsign_mask[i] = math_nonsign_mask_f32();
852 }
Marat Dukhan37c83512020-06-29 13:25:53 -0700853 #elif XNN_ARCH_WASMSIMD
854 params.wasmsimd.nonsign_mask = math_nonsign_mask_f32();
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700855 #endif
856 return params;
857}
858
859static inline union xnn_f32_abs_params xnn_init_scalar_f32_abs_params(void)
860{
861 union xnn_f32_abs_params params = { 0 };
862 return params;
863}
864
865static inline union xnn_f32_neg_params xnn_init_f32_neg_params(void)
866{
867 union xnn_f32_neg_params params = { 0 };
868 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
869 for (uint32_t i = 0; i < 4; i++) {
870 params.sse.sign_mask[i] = -0.0f;
871 }
Marat Dukhan37c83512020-06-29 13:25:53 -0700872 #elif XNN_ARCH_WASMSIMD
873 params.wasmsimd.sign_mask = -0.0f;
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700874 #endif
875 return params;
876}
877
878static inline union xnn_f32_neg_params xnn_init_scalar_f32_neg_params(void)
879{
880 union xnn_f32_neg_params params = { 0 };
881 return params;
882}
883
Marat Dukhaneecf8fd2020-06-09 08:59:37 -0700884static inline union xnn_f32_rnd_params xnn_init_f32_rnd_params(void)
885{
886 union xnn_f32_rnd_params params = { 0 };
887 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
888 for (uint32_t i = 0; i < 4; i++) {
889 params.sse2.sign_mask[i] = -0.0f;
890 }
891 for (uint32_t i = 0; i < 4; i++) {
892 params.sse2.one[i] = 1.0f;
893 }
894 #endif
895 return params;
896}
897
898static inline union xnn_f32_rnd_params xnn_init_scalar_f32_rnd_params(void)
899{
900 union xnn_f32_rnd_params params = { 0 };
901 return params;
902}
903
Marat Dukhan8cc7efe2020-06-10 16:24:27 -0700904static inline union xnn_f32_lrelu_params xnn_init_f32_lrelu_params(float slope)
905{
906 union xnn_f32_lrelu_params params;
907 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
908 for (uint32_t i = 0; i < 4; i++) {
909 params.sse.slope[i] = slope;
910 }
911 #else
912 params.scalar.slope = slope;
913 #endif
914 return params;
915}
916
917static inline union xnn_f32_lrelu_params xnn_init_scalar_f32_lrelu_params(float slope)
918{
919 union xnn_f32_lrelu_params params;
920 params.scalar.slope = slope;
921 return params;
922}
923
Marat Dukhanf4db2f32020-06-30 10:55:30 -0700924static inline union xnn_f32_sqrt_params xnn_init_f32_sqrt_params(void)
925{
926 union xnn_f32_sqrt_params params = { 0 };
927 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
928 params.fma.half = 0.5f;
929 #endif
930 return params;
931}
932
933static inline union xnn_f32_sqrt_params xnn_init_scalar_f32_sqrt_params(void)
934{
935 union xnn_f32_sqrt_params params = { 0 };
936 return params;
937}
938
Marat Dukhan1f29b802020-05-15 23:46:39 -0700939static inline union xnn_f32_chw_params xnn_init_f32_chw_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700940 uint32_t width,
941 float output_min,
942 float output_max)
943{
Marat Dukhan1f29b802020-05-15 23:46:39 -0700944 union xnn_f32_chw_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700945 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
946 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700947 params.sse.min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700948 params.sse.max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700949 }
950
951 const uint32_t w4 = (width - 1) & 3;
952 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
953 params.sse.mask[1] = -(uint32_t) (w4 >= 1);
954 params.sse.mask[2] = -(uint32_t) (w4 >= 2);
955 params.sse.mask[3] = -(uint32_t) (w4 >= 3);
956
957 const uint32_t w8 = (width - 1) & 7;
958 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
959 params.sse.mask_even[1] = -(uint32_t) (w8 >= 2);
960 params.sse.mask_even[2] = -(uint32_t) (w8 >= 4);
961 params.sse.mask_even[3] = -(uint32_t) (w8 >= 6);
962 params.sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
963 params.sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
964 params.sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
965 params.sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
966 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700967 params.neon.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700968 params.neon.max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700969
970 const uint32_t w4 = (width - 1) & 3;
971 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
972 params.neon.mask[1] = -(uint32_t) (w4 >= 1);
973 params.neon.mask[2] = -(uint32_t) (w4 >= 2);
974 params.neon.mask[3] = -(uint32_t) (w4 >= 3);
975
976 const uint32_t w8 = (width - 1) & 7;
977 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
978 params.neon.mask_even[1] = -(uint32_t) (w8 >= 2);
979 params.neon.mask_even[2] = -(uint32_t) (w8 >= 4);
980 params.neon.mask_even[3] = -(uint32_t) (w8 >= 6);
981 params.neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
982 params.neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
983 params.neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
984 params.neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
985 #else
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700986 params.scalar.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700987 params.scalar.max = output_max;
Erich Elsene6214af2020-06-10 22:17:22 -0700988
989 const uint32_t w4 = (width - 1) & 3;
Marat Dukhanfaa2e5c2020-07-10 17:34:44 -0700990 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
991 params.scalar.mask[1] = -(uint32_t) (w4 >= 1);
992 params.scalar.mask[2] = -(uint32_t) (w4 >= 2);
993 params.scalar.mask[3] = -(uint32_t) (w4 >= 3);
Erich Elsenfd7a6e32020-06-11 12:04:44 -0700994
995 const uint32_t w8 = (width - 1) & 7;
Marat Dukhanfaa2e5c2020-07-10 17:34:44 -0700996 params.scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
997 params.scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
998 params.scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
999 params.scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1000 params.scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1001 params.scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1002 params.scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1003 params.scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001004 #endif
1005 return params;
1006}
1007
Marat Dukhan1f29b802020-05-15 23:46:39 -07001008static inline void xnn_update_f32_chw_params(
1009 union xnn_f32_chw_params* params,
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001010 uint32_t width)
1011{
1012 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1013 const uint32_t w4 = (width - 1) & 3;
1014 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
1015 params->sse.mask[1] = -(uint32_t) (w4 >= 1);
1016 params->sse.mask[2] = -(uint32_t) (w4 >= 2);
1017 params->sse.mask[3] = -(uint32_t) (w4 >= 3);
1018
1019 const uint32_t w8 = (width - 1) & 7;
1020 params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
1021 params->sse.mask_even[1] = -(uint32_t) (w8 >= 2);
1022 params->sse.mask_even[2] = -(uint32_t) (w8 >= 4);
1023 params->sse.mask_even[3] = -(uint32_t) (w8 >= 6);
1024 params->sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
1025 params->sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
1026 params->sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
1027 params->sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
1028 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1029 const uint32_t w4 = (width - 1) & 3;
1030 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
1031 params->neon.mask[1] = -(uint32_t) (w4 >= 1);
1032 params->neon.mask[2] = -(uint32_t) (w4 >= 2);
1033 params->neon.mask[3] = -(uint32_t) (w4 >= 3);
1034
1035 const uint32_t w8 = (width - 1) & 7;
1036 params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
1037 params->neon.mask_even[1] = -(uint32_t) (w8 >= 2);
1038 params->neon.mask_even[2] = -(uint32_t) (w8 >= 4);
1039 params->neon.mask_even[3] = -(uint32_t) (w8 >= 6);
1040 params->neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
1041 params->neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
1042 params->neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
1043 params->neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
Marat Dukhanfaa2e5c2020-07-10 17:34:44 -07001044 #else
1045 const uint32_t w4 = (width - 1) & 3;
1046 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1047 params->scalar.mask[1] = -(uint32_t) (w4 >= 1);
1048 params->scalar.mask[2] = -(uint32_t) (w4 >= 2);
1049 params->scalar.mask[3] = -(uint32_t) (w4 >= 3);
1050
1051 const uint32_t w8 = (width - 1) & 7;
1052 params->scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1053 params->scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1054 params->scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1055 params->scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1056 params->scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1057 params->scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1058 params->scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1059 params->scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001060 #endif
1061}
1062
Marat Dukhan1f29b802020-05-15 23:46:39 -07001063static inline union xnn_f32_chw_params xnn_init_scalar_f32_chw_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001064 uint32_t width,
1065 float output_min,
1066 float output_max)
1067{
Marat Dukhan1f29b802020-05-15 23:46:39 -07001068 union xnn_f32_chw_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001069 params.scalar.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001070 params.scalar.max = output_max;
Erich Elsene6214af2020-06-10 22:17:22 -07001071
1072 const uint32_t w4 = (width - 1) & 3;
Marat Dukhanfaa2e5c2020-07-10 17:34:44 -07001073 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1074 params.scalar.mask[1] = -(uint32_t) (w4 >= 1);
1075 params.scalar.mask[2] = -(uint32_t) (w4 >= 2);
1076 params.scalar.mask[3] = -(uint32_t) (w4 >= 3);
Erich Elsenfd7a6e32020-06-11 12:04:44 -07001077
1078 const uint32_t w8 = (width - 1) & 7;
Marat Dukhanfaa2e5c2020-07-10 17:34:44 -07001079 params.scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1080 params.scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1081 params.scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1082 params.scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1083 params.scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1084 params.scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1085 params.scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1086 params.scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
Erich Elsenfd7a6e32020-06-11 12:04:44 -07001087
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001088 return params;
1089}
1090
Marat Dukhaneb09a6b2020-04-08 17:34:32 -07001091static inline union xnn_u8_minmax_params xnn_init_u8_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001092 uint8_t output_min,
1093 uint8_t output_max)
1094{
1095 assert(output_min < output_max);
1096
Marat Dukhaneb09a6b2020-04-08 17:34:32 -07001097 union xnn_u8_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001098 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1099 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001100 params.sse2.min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001101 params.sse2.max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001102 }
1103 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001104 params.neon.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001105 params.neon.max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001106 #else
1107 params.scalar.min = (int32_t) (uint32_t) output_min;
1108 params.scalar.max = (int32_t) (uint32_t) output_max;
1109 #endif
1110 return params;
1111}
1112
Marat Dukhaneb09a6b2020-04-08 17:34:32 -07001113static inline union xnn_u8_minmax_params xnn_init_scalar_u8_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001114 uint8_t output_min,
1115 uint8_t output_max)
1116{
1117 assert(output_min < output_max);
1118
Marat Dukhaneb09a6b2020-04-08 17:34:32 -07001119 union xnn_u8_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001120 params.scalar.min = (int32_t) (uint32_t) output_min;
1121 params.scalar.max = (int32_t) (uint32_t) output_max;
1122 return params;
1123}
1124
Marat Dukhan08b7a972020-07-14 18:17:29 -07001125static inline union xnn_qu8_add_params xnn_init_qu8_add_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001126 uint8_t a_zero_point,
1127 uint8_t b_zero_point,
1128 uint8_t output_zero_point,
1129 float a_output_scale,
1130 float b_output_scale,
1131 uint8_t output_min,
1132 uint8_t output_max)
1133{
1134 assert(a_output_scale >= 0x1.0p-14f);
1135 assert(b_output_scale >= 0x1.0p-14f);
1136 assert(a_output_scale < 0x1.0p+8f);
1137 assert(b_output_scale < 0x1.0p+8f);
1138
1139 // Compute requantization parameters.
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001140 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001141 assert(max_output_scale >= 0x1.0p-14f);
1142 assert(max_output_scale < 0x1.0p+8f);
1143 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1144 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1145 // Shift is in [13, 31] range.
1146 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1147 assert(shift < 32);
1148 assert(shift >= 13);
1149
1150 const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
1151
1152 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
Marat Dukhanef3e7dc2020-04-13 01:19:56 -07001153 const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(a_output_scale * scale_multiplier);
1154 const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(b_output_scale * scale_multiplier);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001155 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
1156 assert(a_multiplier < UINT32_C(0x00400000));
1157 assert(b_multiplier < UINT32_C(0x00400000));
1158
Marat Dukhan08b7a972020-07-14 18:17:29 -07001159 union xnn_qu8_add_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001160 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1161 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1162 const uint32_t remainder_threshold = remainder_mask >> 1;
1163 const int32_t zero_point_product =
1164 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1165 for (uint32_t i = 0; i < 4; i++) {
1166 params.sse2.zero_point_product[i] = zero_point_product;
1167 }
1168 for (uint32_t i = 0; i < 8; i++) {
1169 params.sse2.y_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
1170 }
1171 for (uint32_t i = 0; i < 8; i++) {
1172 params.sse2.a_multiplier_lo[i] = (uint16_t) (uint32_t) a_multiplier;
1173 params.sse2.a_multiplier_hi[i] = (uint16_t) ((uint32_t) a_multiplier >> 16);
1174 params.sse2.b_multiplier_lo[i] = (uint16_t) (uint32_t) b_multiplier;
1175 params.sse2.b_multiplier_hi[i] = (uint16_t) ((uint32_t) b_multiplier >> 16);
1176 }
1177 params.sse2.a_multiplier = a_multiplier;
1178 params.sse2.b_multiplier = b_multiplier;
1179 for (uint32_t i = 0; i < 4; i++) {
1180 params.sse2.remainder_mask[i] = remainder_mask;
1181 params.sse2.remainder_threshold[i] = remainder_threshold;
1182 }
1183 params.sse2.shift = shift;
1184 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001185 params.sse2.y_min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001186 params.sse2.y_max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001187 }
1188 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1189 params.neon.a_zero_point = a_zero_point;
1190 params.neon.b_zero_point = b_zero_point;
1191 params.neon.y_zero_point = (int16_t) (uint16_t) output_zero_point;
1192 params.neon.a_multiplier = (int32_t) a_multiplier;
1193 params.neon.b_multiplier = (int32_t) b_multiplier;
1194 params.neon.right_shift = (int32_t) -shift;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001195 params.neon.y_min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001196 params.neon.y_max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001197 #else
1198 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1199 const uint32_t remainder_threshold = remainder_mask >> 1;
1200 params.scalar.zero_point_product =
1201 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1202 params.scalar.a_multiplier = a_multiplier;
1203 params.scalar.b_multiplier = b_multiplier;
1204 params.scalar.remainder_mask = (int32_t) remainder_mask;
1205 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1206 params.scalar.shift = shift;
1207 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001208 params.scalar.y_min = (int32_t) (uint32_t) output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001209 params.scalar.y_max = (int32_t) (uint32_t) output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001210 #endif
1211 return params;
1212}
1213
Marat Dukhan08b7a972020-07-14 18:17:29 -07001214static inline union xnn_qu8_add_params xnn_init_scalar_qu8_add_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001215 uint8_t a_zero_point,
1216 uint8_t b_zero_point,
1217 uint8_t output_zero_point,
1218 float a_output_scale,
1219 float b_output_scale,
1220 uint8_t output_min,
1221 uint8_t output_max)
1222{
1223 assert(a_output_scale >= 0x1.0p-10f);
1224 assert(b_output_scale >= 0x1.0p-10f);
1225 assert(a_output_scale < 0x1.0p+8f);
1226 assert(b_output_scale < 0x1.0p+8f);
1227
1228 // Compute requantization parameters.
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001229 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001230 assert(max_output_scale >= 0x1.0p-10f);
1231 assert(max_output_scale < 0x1.0p+8f);
1232 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1233 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1234 // Shift is in [13, 31] range.
1235 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1236 assert(shift < 32);
1237 assert(shift >= 13);
1238
1239 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
Marat Dukhanef3e7dc2020-04-13 01:19:56 -07001240 const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
1241 const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001242 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
1243 assert(a_multiplier < UINT32_C(0x00400000));
1244 assert(b_multiplier < UINT32_C(0x00400000));
1245
Marat Dukhan08b7a972020-07-14 18:17:29 -07001246 union xnn_qu8_add_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001247 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1248 const uint32_t remainder_threshold = remainder_mask >> 1;
1249 params.scalar.zero_point_product =
1250 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1251 params.scalar.a_multiplier = a_multiplier;
1252 params.scalar.b_multiplier = b_multiplier;
1253 params.scalar.remainder_mask = (int32_t) remainder_mask;
1254 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1255 params.scalar.shift = shift;
1256 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001257 params.scalar.y_min = (int32_t) (uint32_t) output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001258 params.scalar.y_max = (int32_t) (uint32_t) output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001259 return params;
1260}
1261
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001262static inline union xnn_qs8_add_params xnn_init_qs8_add_params(
1263 int8_t x_zero_point,
1264 int8_t y_zero_point,
1265 int8_t output_zero_point,
1266 float x_output_scale,
1267 float y_output_scale,
1268 int8_t output_min,
1269 int8_t output_max)
1270{
1271 assert(x_output_scale >= 0x1.0p-14f);
1272 assert(y_output_scale >= 0x1.0p-14f);
1273 assert(x_output_scale < 0x1.0p+8f);
1274 assert(y_output_scale < 0x1.0p+8f);
1275
1276 // Compute requantization parameters.
1277 const float max_output_scale = math_max_f32(x_output_scale, y_output_scale);
1278 assert(max_output_scale >= 0x1.0p-14f);
1279 assert(max_output_scale < 0x1.0p+8f);
1280 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1281 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1282 // Shift is in [13, 31] range.
1283 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1284 assert(shift < 32);
1285 assert(shift >= 13);
1286
1287 const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
1288
1289 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1290 const int32_t x_multiplier = (int32_t) lrintf(x_output_scale * scale_multiplier);
1291 const int32_t y_multiplier = (int32_t) lrintf(y_output_scale * scale_multiplier);
1292 assert((x_multiplier > y_multiplier ? x_multiplier : y_multiplier) >= INT32_C(0x00200000));
1293 assert(x_multiplier < INT32_C(0x00400000));
1294 assert(y_multiplier < INT32_C(0x00400000));
1295
1296 union xnn_qs8_add_params params;
1297 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1298 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1299 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1300 const int32_t zero_point_product =
1301 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1302 for (uint32_t i = 0; i < 4; i++) {
1303 params.sse2.zero_point_product[i] = zero_point_product;
1304 }
1305 const uint16_t x_multiplier_lo = (uint16_t) x_multiplier;
1306 const uint16_t x_multiplier_hi = (uint16_t) ((uint32_t) x_multiplier >> 16);
1307 const uint16_t y_multiplier_lo = (uint16_t) y_multiplier;
1308 const uint16_t y_multiplier_hi = (uint16_t) ((uint32_t) y_multiplier >> 16);
1309 for (uint32_t i = 0; i < 8; i++) {
1310 params.sse2.x_multiplier_lo[i] = x_multiplier_lo;
1311 params.sse2.x_multiplier_hi[i] = x_multiplier_hi;
1312 params.sse2.y_multiplier_lo[i] = y_multiplier_lo;
1313 params.sse2.y_multiplier_hi[i] = y_multiplier_hi;
1314 }
1315 params.sse2.shift = shift;
1316 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhanbb9225e2020-09-06 22:40:56 -07001317 params.sse2.x_multiplier[i] = x_multiplier;
1318 params.sse2.y_multiplier[i] = y_multiplier;
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001319 params.sse2.remainder_mask[i] = remainder_mask;
1320 params.sse2.remainder_threshold[i] = remainder_threshold;
1321 }
1322 for (uint32_t i = 0; i < 8; i++) {
1323 params.sse2.output_zero_point[i] = (int16_t) output_zero_point;
1324 params.sse2.output_min[i] = (int16_t) output_min;
1325 params.sse2.output_max[i] = (int16_t) output_max;
1326 }
1327 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1328 params.neon.x_zero_point = x_zero_point;
1329 params.neon.y_zero_point = y_zero_point;
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001330 params.neon.x_multiplier = (int32_t) x_multiplier;
1331 params.neon.y_multiplier = (int32_t) y_multiplier;
1332 params.neon.right_shift = (int32_t) -shift;
Marat Dukhanba7b2792020-09-02 14:26:45 -07001333 params.neon.output_zero_point = (int16_t) output_zero_point;
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001334 params.neon.output_min = output_min;
1335 params.neon.output_max = output_max;
Marat Dukhan5df27f82020-09-02 23:59:21 -07001336 #elif XNN_ARCH_WASMSIMD
1337 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1338 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1339 const int32_t zero_point_product =
1340 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1341 for (uint32_t i = 0; i < 4; i++) {
1342 params.wasmsimd.zero_point_product[i] = zero_point_product;
1343 params.wasmsimd.x_multiplier[i] = x_multiplier;
1344 params.wasmsimd.y_multiplier[i] = y_multiplier;
1345 params.wasmsimd.remainder_mask[i] = remainder_mask;
1346 params.wasmsimd.remainder_threshold[i] = remainder_threshold;
1347 }
1348 params.wasmsimd.shift = shift;
1349 for (uint32_t i = 0; i < 8; i++) {
1350 params.wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
1351 }
1352 for (uint32_t i = 0; i < 16; i++) {
1353 params.wasmsimd.output_min[i] = output_min;
1354 params.wasmsimd.output_max[i] = output_max;
1355 }
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001356 #else
1357 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1358 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1359 params.scalar.zero_point_product =
1360 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1361 params.scalar.x_multiplier = x_multiplier;
1362 params.scalar.y_multiplier = y_multiplier;
1363 params.scalar.remainder_mask = (int32_t) remainder_mask;
1364 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
Marat Dukhan5df27f82020-09-02 23:59:21 -07001365 params.scalar.shift = (int32_t) shift;
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001366 params.scalar.output_zero_point = (int32_t) output_zero_point;
1367 params.scalar.output_min = (int32_t) output_min;
1368 params.scalar.output_max = (int32_t) output_max;
1369 #endif
1370 return params;
1371}
1372
1373static inline union xnn_qs8_add_params xnn_init_scalar_qs8_add_params(
1374 int8_t x_zero_point,
1375 int8_t y_zero_point,
1376 int8_t output_zero_point,
1377 float x_output_scale,
1378 float y_output_scale,
1379 int8_t output_min,
1380 int8_t output_max)
1381{
1382 assert(x_output_scale >= 0x1.0p-10f);
1383 assert(y_output_scale >= 0x1.0p-10f);
1384 assert(x_output_scale < 0x1.0p+8f);
1385 assert(y_output_scale < 0x1.0p+8f);
1386
1387 // Compute requantization parameters.
1388 const float max_output_scale = math_max_f32(x_output_scale, y_output_scale);
1389 assert(max_output_scale >= 0x1.0p-10f);
1390 assert(max_output_scale < 0x1.0p+8f);
1391 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1392 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1393 // Shift is in [13, 31] range.
1394 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1395 assert(shift < 32);
1396 assert(shift >= 13);
1397
1398 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1399 const int32_t x_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(x_output_scale) + (shift << 23)));
1400 const int32_t y_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(y_output_scale) + (shift << 23)));
1401 assert((x_multiplier > y_multiplier ? x_multiplier : y_multiplier) >= INT32_C(0x00200000));
1402 assert(x_multiplier < INT32_C(0x00400000));
1403 assert(y_multiplier < INT32_C(0x00400000));
1404
1405 union xnn_qs8_add_params params;
1406 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1407 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1408 params.scalar.zero_point_product =
1409 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1410 params.scalar.x_multiplier = x_multiplier;
1411 params.scalar.y_multiplier = y_multiplier;
1412 params.scalar.remainder_mask = (int32_t) remainder_mask;
1413 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1414 params.scalar.shift = shift;
1415 params.scalar.output_zero_point = (int32_t) output_zero_point;
1416 params.scalar.output_min = (int32_t) output_min;
1417 params.scalar.output_max = (int32_t) output_max;
1418 return params;
1419}
1420
Marat Dukhanec88e272020-07-30 15:02:09 -07001421static inline union xnn_qu8_requantization_params xnn_init_scalar_qu8_requantization_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001422 float scale,
1423 uint8_t zero_point,
1424 uint8_t min,
1425 uint8_t max)
1426{
1427 // Compute requantization parameters.
1428 assert(scale < 1.0f);
1429 assert(scale >= 0x1.0p-32f);
1430 const uint32_t scale_bits = fp32_to_bits(scale);
1431
1432 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
1433 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
1434 assert(multiplier >= INT32_C(0x40000000));
1435 assert(multiplier <= INT32_C(0x7FFFFF80));
1436
1437 // Shift is in [0, 31] range.
1438 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
1439 assert(shift >= 0);
1440 assert(shift < 32);
1441
Marat Dukhanec88e272020-07-30 15:02:09 -07001442 union xnn_qu8_requantization_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001443 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1444 const uint32_t remainder_threshold = remainder_mask >> 1;
Marat Dukhanec88e272020-07-30 15:02:09 -07001445 params.q31.multiplier = multiplier;
1446 params.q31.remainder_mask = (int32_t) remainder_mask;
1447 params.q31.remainder_threshold = (int32_t) remainder_threshold;
1448 params.q31.shift = (uint32_t) shift;
1449 params.q31.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
1450 params.q31.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
1451 params.q31.zero_point = (int32_t) (uint32_t) zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001452 return params;
1453}
1454
Marat Dukhan595e1702020-07-31 10:12:52 -07001455static inline union xnn_qs8_requantization_params xnn_init_scalar_qs8_requantization_params(
1456 float scale,
1457 int8_t zero_point,
1458 int8_t min,
1459 int8_t max)
1460{
1461 // Compute requantization parameters.
1462 assert(scale < 1.0f);
1463 assert(scale >= 0x1.0p-32f);
1464 const uint32_t scale_bits = fp32_to_bits(scale);
1465
1466 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
1467 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
1468 assert(multiplier >= INT32_C(0x40000000));
1469 assert(multiplier <= INT32_C(0x7FFFFF80));
1470
1471 // Shift is in [0, 31] range.
1472 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
1473 assert(shift >= 0);
1474 assert(shift < 32);
1475
1476 union xnn_qs8_requantization_params params;
1477 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1478 const uint32_t remainder_threshold = remainder_mask >> 1;
1479 params.q31.multiplier = multiplier;
1480 params.q31.remainder_mask = (int32_t) remainder_mask;
1481 params.q31.remainder_threshold = (int32_t) remainder_threshold;
1482 params.q31.shift = (uint32_t) shift;
1483 params.q31.min_less_zero_point = (int32_t) min - (int32_t) zero_point;
1484 params.q31.max_less_zero_point = (int32_t) max - (int32_t) zero_point;
1485 params.q31.zero_point = (int32_t) zero_point;
1486 return params;
1487}
1488