// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8
9#pragma once
10
11#if defined(__cplusplus) && (__cplusplus >= 201103L)
12 #include <cstdint>
13 #include <cstddef>
14 #include <cassert>
15 #include <cmath>
16#else
17 #include <stdint.h>
18 #include <stddef.h>
19 #include <assert.h>
20 #include <math.h>
21#endif
22
23#include <fp16.h>
24
25#include <xnnpack/common.h>
Marat Dukhan2b9efd82020-06-08 01:09:31 -070026#include <xnnpack/math.h>
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070027#include <xnnpack/params.h>
28
29
Marat Dukhan08b7a972020-07-14 18:17:29 -070030static inline union xnn_qu8_gemm_params xnn_init_scalar_qu8_gemm_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070031 uint8_t kernel_zero_point,
32 float scale,
33 uint8_t output_zero_point,
34 uint8_t output_min,
35 uint8_t output_max)
36{
37 // Compute requantization parameters
38 const uint32_t scale_bits = fp32_to_bits(scale);
39
40 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
41 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
42 assert(multiplier >= INT32_C(0x40000000));
43 assert(multiplier <= INT32_C(0x7FFFFF80));
44
45 // Shift is in [0, 31] range.
46 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
47 assert(shift >= 0);
48 assert(shift < 32);
49
50 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
51 const uint32_t remainder_threshold = remainder_mask >> 1;
52
Marat Dukhan08b7a972020-07-14 18:17:29 -070053 union xnn_qu8_gemm_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070054 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
55 params.scalar.multiplier = multiplier;
56 params.scalar.remainder_mask = (int32_t) remainder_mask;
57 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
58 params.scalar.shift = (uint32_t) shift;
59 params.scalar.output_min_less_zero_point =
60 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
61 params.scalar.output_max_less_zero_point =
62 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
63 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
64 return params;
65}
66
Marat Dukhan08b7a972020-07-14 18:17:29 -070067static inline union xnn_qu8_gemm_params xnn_init_qu8_gemm_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070068 uint8_t kernel_zero_point,
69 float scale,
70 uint8_t output_zero_point,
71 uint8_t output_min,
72 uint8_t output_max)
73{
74 // Compute requantization parameters.
75 const uint32_t scale_bits = fp32_to_bits(scale);
76
77 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
78 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
79 assert(multiplier >= INT32_C(0x40000000));
80 assert(multiplier <= INT32_C(0x7FFFFF80));
81
82 // Shift is in [0, 31] range.
83 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
84 assert(shift >= 0);
85 assert(shift < 32);
86
Marat Dukhan08b7a972020-07-14 18:17:29 -070087 union xnn_qu8_gemm_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070088 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
89 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
90 const uint32_t remainder_threshold = remainder_mask >> 1;
91 for (uint32_t i = 0; i < 8; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070092 params.sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
93 }
94 params.sse2.multiplier[0] = multiplier;
95 params.sse2.multiplier[1] = multiplier;
96 params.sse2.multiplier[2] = multiplier;
97 params.sse2.multiplier[3] = multiplier;
98 params.sse2.rounding[0] = UINT64_C(0x40000000);
99 params.sse2.rounding[1] = UINT64_C(0x40000000);
100 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
101 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
102 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
103 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
104 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
105 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
106 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
107 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
108 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
109 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
110 for (uint32_t i = 0; i < 8; i++) {
111 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
112 }
113 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700114 params.sse2.output_min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700115 params.sse2.output_max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700116 }
117 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
Marat Dukhan91992462020-07-30 00:06:34 -0700118 params.neon.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700119 params.neon.multiplier = multiplier;
120 params.neon.right_shift = -shift;
121 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700122 params.neon.output_min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700123 params.neon.output_max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700124 #else
125 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
126 const uint32_t remainder_threshold = remainder_mask >> 1;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700127 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
128 params.scalar.multiplier = multiplier;
129 params.scalar.remainder_mask = (int32_t) remainder_mask;
130 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
131 params.scalar.shift = (uint32_t) shift;
132 params.scalar.output_min_less_zero_point =
133 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
134 params.scalar.output_max_less_zero_point =
135 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
136 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
137 #endif
138 return params;
139}
140
Marat Dukhan595e1702020-07-31 10:12:52 -0700141static inline union xnn_qs8_gemm_params xnn_init_scalar_qs8_gemm_params(
142 float scale,
143 int8_t output_zero_point,
144 int8_t output_min,
145 int8_t output_max)
146{
147 // Compute requantization parameters
148 const uint32_t scale_bits = fp32_to_bits(scale);
149
150 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
151 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
152 assert(multiplier >= INT32_C(0x40000000));
153 assert(multiplier <= INT32_C(0x7FFFFF80));
154
155 // Shift is in [0, 31] range.
156 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
157 assert(shift >= 0);
158 assert(shift < 32);
159
160 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
161 const uint32_t remainder_threshold = remainder_mask >> 1;
162
163 union xnn_qs8_gemm_params params;
164 params.scalar.multiplier = multiplier;
165 params.scalar.remainder_mask = (int32_t) remainder_mask;
166 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
167 params.scalar.shift = (uint32_t) shift;
168 params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
169 params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
170 params.scalar.output_zero_point = (int32_t) output_zero_point;
171 return params;
172}
173
174static inline union xnn_qs8_gemm_params xnn_init_qs8_gemm_params(
175 float scale,
176 int8_t output_zero_point,
177 int8_t output_min,
178 int8_t output_max)
179{
180 // Compute requantization parameters.
181 const uint32_t scale_bits = fp32_to_bits(scale);
182
183 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
184 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
185 assert(multiplier >= INT32_C(0x40000000));
186 assert(multiplier <= INT32_C(0x7FFFFF80));
187
188 // Shift is in [0, 31] range.
189 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
190 assert(shift >= 0);
191 assert(shift < 32);
192
193 union xnn_qs8_gemm_params params;
194 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
195 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
196 const uint32_t remainder_threshold = remainder_mask >> 1;
197 params.sse2.multiplier[0] = multiplier;
198 params.sse2.multiplier[1] = multiplier;
199 params.sse2.multiplier[2] = multiplier;
200 params.sse2.multiplier[3] = multiplier;
201 params.sse2.rounding[0] = UINT64_C(0x40000000);
202 params.sse2.rounding[1] = UINT64_C(0x40000000);
203 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
204 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
205 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
206 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
207 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
208 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
209 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
210 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
211 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
212 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
213 for (uint32_t i = 0; i < 8; i++) {
214 params.sse2.output_zero_point[i] = (int16_t) output_zero_point;
215 params.sse2.output_min[i] = (int16_t) output_min;
216 params.sse2.output_max[i] = (int16_t) output_max;
217 }
218 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
219 params.neon.multiplier = multiplier;
220 params.neon.right_shift = -shift;
221 params.neon.output_zero_point = (int16_t) output_zero_point;
222 params.neon.output_min = output_min;
223 params.neon.output_max = output_max;
Marat Dukhan27203da2020-08-05 15:19:03 -0700224 #elif XNN_ARCH_WASMSIMD
225 const int64_t twice_multiplier = INT64_C(2) * (int64_t) multiplier;
226 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
227 const uint32_t remainder_threshold = remainder_mask >> 1;
228 params.wasmsimd.multiplier[0] = twice_multiplier;
229 params.wasmsimd.multiplier[1] = twice_multiplier;
230 params.wasmsimd.rounding[0] = INT64_C(0x80000000);
231 params.wasmsimd.rounding[1] = INT64_C(0x80000000);
232 params.wasmsimd.remainder_mask[0] = (int32_t) remainder_mask;
233 params.wasmsimd.remainder_mask[1] = (int32_t) remainder_mask;
234 params.wasmsimd.remainder_mask[2] = (int32_t) remainder_mask;
235 params.wasmsimd.remainder_mask[3] = (int32_t) remainder_mask;
236 params.wasmsimd.remainder_threshold[0] = (int32_t) remainder_threshold;
237 params.wasmsimd.remainder_threshold[1] = (int32_t) remainder_threshold;
238 params.wasmsimd.remainder_threshold[2] = (int32_t) remainder_threshold;
239 params.wasmsimd.remainder_threshold[3] = (int32_t) remainder_threshold;
240 params.wasmsimd.shift = shift;
241 for (uint32_t i = 0; i < 8; i++) {
242 params.wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
243 }
244 for (uint32_t i = 0; i < 16; i++) {
245 params.wasmsimd.output_min[i] = output_min;
246 params.wasmsimd.output_max[i] = output_max;
247 }
Marat Dukhan595e1702020-07-31 10:12:52 -0700248 #else
249 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
250 const uint32_t remainder_threshold = remainder_mask >> 1;
251 params.scalar.multiplier = multiplier;
252 params.scalar.remainder_mask = (int32_t) remainder_mask;
253 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
254 params.scalar.shift = (uint32_t) shift;
255 params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
256 params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
257 params.scalar.output_zero_point = (int32_t) output_zero_point;
258 #endif
259 return params;
260}
261
Marat Dukhan683fab32020-08-03 19:42:52 -0700262static inline union xnn_qs8_gemm_xw_params xnn_init_scalar_qs8_gemm_xw_params(
263 float scale,
264 int8_t output_zero_point,
265 int8_t output_min,
266 int8_t output_max)
267{
268 union {
269 union xnn_qs8_gemm_xw_params gemm_xw;
270 union xnn_qs8_gemm_params gemm;
271 } params;
272 params.gemm = xnn_init_scalar_qs8_gemm_params(scale, output_zero_point, output_min, output_max);
273 return params.gemm_xw;
274}
275
276static inline union xnn_qs8_gemm_xw_params xnn_init_qs8_gemm_xw_params(
277 float scale,
278 int8_t output_zero_point,
279 int8_t output_min,
280 int8_t output_max)
281{
282 union {
283 union xnn_qs8_gemm_xw_params gemm_xw;
284 union xnn_qs8_gemm_params gemm;
285 } params;
286 params.gemm = xnn_init_qs8_gemm_params(scale, output_zero_point, output_min, output_max);
287 return params.gemm_xw;
288}
289
Marat Dukhan08b7a972020-07-14 18:17:29 -0700290static inline union xnn_qu8_avgpool_params xnn_init_qu8_avgpool_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700291 int32_t bias,
292 float scale,
293 uint8_t output_zero_point,
294 uint8_t output_min,
295 uint8_t output_max)
296{
297 // Compute requantization parameters.
298 assert(scale >= 0x1.0p-32f);
299 assert(scale < 256.0f);
300 const uint32_t scale_bits = fp32_to_bits(scale);
301
302 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
303 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
304 assert(multiplier >= INT32_C(0x00800000));
305 assert(multiplier <= INT32_C(0x00FFFFFF));
306
307 // Shift is in [16, 55] range.
308 const int32_t shift = 127 + 23 - (scale_bits >> 23);
309 assert(shift >= 16);
310 assert(shift < 64);
311
Marat Dukhan08b7a972020-07-14 18:17:29 -0700312 union xnn_qu8_avgpool_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700313 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
314 const uint32_t right_shift = (uint32_t) shift;
315 const uint64_t rounding = UINT64_C(1) << (right_shift - 1);
316 params.sse2.bias[0] = bias;
317 params.sse2.bias[1] = bias;
318 params.sse2.bias[2] = bias;
319 params.sse2.bias[3] = bias;
320 params.sse2.multiplier[0] = (uint32_t) multiplier;
321 params.sse2.multiplier[1] = (uint32_t) multiplier;
322 params.sse2.multiplier[2] = (uint32_t) multiplier;
323 params.sse2.multiplier[3] = (uint32_t) multiplier;
324 params.sse2.rounding[0] = rounding;
325 params.sse2.rounding[1] = rounding;
326 params.sse2.right_shift[0] = (uint64_t) right_shift;
327 params.sse2.right_shift[1] = (uint64_t) right_shift;
328 for (uint32_t i = 0; i < 8; i++) {
329 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
330 }
331 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700332 params.sse2.output_min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700333 params.sse2.output_max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700334 }
335 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
336 params.neon.bias = bias;
337 params.neon.multiplier = multiplier;
338 params.neon.left_shift = (int64_t) -shift;
339 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700340 params.neon.output_min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700341 params.neon.output_max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700342 #else
343 const uint32_t right_shift = (uint32_t) shift;
344 const int64_t rounding = INT64_C(1) << (right_shift - 1);
345 params.scalar.bias = bias;
346 params.scalar.multiplier = multiplier;
347 params.scalar.rounding = rounding;
348 params.scalar.right_shift = right_shift;
349 params.scalar.output_min_less_zero_point =
350 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
351 params.scalar.output_max_less_zero_point =
352 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
353 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
354 #endif
355 return params;
356}
357
Marat Dukhan08b7a972020-07-14 18:17:29 -0700358static inline union xnn_qu8_avgpool_params xnn_init_scalar_qu8_avgpool_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700359 int32_t bias,
360 float scale,
361 uint8_t output_zero_point,
362 uint8_t output_min,
363 uint8_t output_max)
364{
365 // Compute requantization parameters.
366 assert(scale >= 0x1.0p-32f);
367 assert(scale < 256.0f);
368 const uint32_t scale_bits = fp32_to_bits(scale);
369
370 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
371 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
372 assert(multiplier >= INT32_C(0x00800000));
373 assert(multiplier <= INT32_C(0x00FFFFFF));
374
375 // Shift is in [16, 55] range.
376 const int32_t shift = 127 + 23 - (scale_bits >> 23);
377 assert(shift >= 16);
378 assert(shift < 64);
379
Marat Dukhan08b7a972020-07-14 18:17:29 -0700380 union xnn_qu8_avgpool_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700381 const uint32_t right_shift = (uint32_t) shift;
382 const int64_t rounding = INT64_C(1) << (right_shift - 1);
383 params.scalar.bias = bias;
384 params.scalar.rounding = rounding;
385 params.scalar.multiplier = multiplier;
386 params.scalar.right_shift = right_shift;
387 params.scalar.output_min_less_zero_point =
388 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
389 params.scalar.output_max_less_zero_point =
390 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
391 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
392 return params;
393}
394
Marat Dukhan54e95a02020-08-06 23:55:13 -0700395static inline void xnn_update_qu8_avgpool_params(
396 union xnn_qu8_avgpool_params* params,
397 int32_t bias,
398 float scale)
399{
400 // Compute requantization parameters.
401 assert(scale >= 0x1.0p-32f);
402 assert(scale < 256.0f);
403 const uint32_t scale_bits = fp32_to_bits(scale);
404
405 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
406 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
407 assert(multiplier >= INT32_C(0x00800000));
408 assert(multiplier <= INT32_C(0x00FFFFFF));
409
410 // Shift is in [16, 55] range.
411 const int32_t shift = 127 + 23 - (scale_bits >> 23);
412 assert(shift >= 16);
413 assert(shift < 64);
414
415 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
416 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
417 params->sse2.bias[0] = bias;
418 params->sse2.bias[1] = bias;
419 params->sse2.bias[2] = bias;
420 params->sse2.bias[3] = bias;
421 params->sse2.multiplier[0] = (uint32_t) multiplier;
422 params->sse2.multiplier[1] = (uint32_t) multiplier;
423 params->sse2.multiplier[2] = (uint32_t) multiplier;
424 params->sse2.multiplier[3] = (uint32_t) multiplier;
425 params->sse2.rounding[0] = rounding;
426 params->sse2.rounding[1] = rounding;
427 params->sse2.right_shift[0] = (uint64_t) (uint32_t) shift;
428 params->sse2.right_shift[1] = (uint64_t) (uint32_t) shift;
429 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
430 params->neon.bias = bias;
431 params->neon.multiplier = multiplier;
432 params->neon.left_shift = (int64_t) -shift;
433 #else
434 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
435 params->scalar.bias = bias;
436 params->scalar.multiplier = multiplier;
437 params->scalar.rounding = rounding;
438 params->scalar.right_shift = (uint32_t) shift;
439 #endif
440}
441
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700442static inline union xnn_qs8_avgpool_params xnn_init_qs8_avgpool_params(
443 int32_t bias,
444 float scale,
445 int8_t output_zero_point,
446 int8_t output_min,
447 int8_t output_max)
448{
449 // Compute requantization parameters.
450 assert(scale >= 0x1.0p-32f);
451 assert(scale < 256.0f);
452 const uint32_t scale_bits = fp32_to_bits(scale);
453
454 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
455 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
456 assert(multiplier >= INT32_C(0x00800000));
457 assert(multiplier <= INT32_C(0x00FFFFFF));
458
459 // Shift is in [16, 55] range.
460 const int32_t shift = 127 + 23 - (scale_bits >> 23);
461 assert(shift >= 16);
462 assert(shift < 64);
463
464 union xnn_qs8_avgpool_params params;
465 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
Marat Dukhanef451802020-08-06 11:53:47 -0700466 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700467 params.sse2.bias[0] = bias;
468 params.sse2.bias[1] = bias;
469 params.sse2.bias[2] = bias;
470 params.sse2.bias[3] = bias;
471 params.sse2.multiplier[0] = (uint32_t) multiplier;
472 params.sse2.multiplier[1] = (uint32_t) multiplier;
473 params.sse2.multiplier[2] = (uint32_t) multiplier;
474 params.sse2.multiplier[3] = (uint32_t) multiplier;
475 params.sse2.rounding[0] = rounding;
476 params.sse2.rounding[1] = rounding;
Marat Dukhanef451802020-08-06 11:53:47 -0700477 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
478 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700479 for (uint32_t i = 0; i < 8; i++) {
480 params.sse2.output_zero_point[i] = (int16_t) output_zero_point;
481 params.sse2.output_min[i] = (int16_t) output_min;
482 params.sse2.output_max[i] = (int16_t) output_max;
483 }
484 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
485 params.neon.bias = bias;
486 params.neon.multiplier = multiplier;
487 params.neon.left_shift = (int64_t) -shift;
488 params.neon.output_zero_point = (int16_t) output_zero_point;
489 params.neon.output_min = output_min;
490 params.neon.output_max = output_max;
Marat Dukhanef451802020-08-06 11:53:47 -0700491 #elif XNN_ARCH_WASMSIMD
492 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
493 params.wasmsimd.bias[0] = bias;
494 params.wasmsimd.bias[1] = bias;
495 params.wasmsimd.bias[2] = bias;
496 params.wasmsimd.bias[3] = bias;
497 params.wasmsimd.multiplier[0] = (int64_t) multiplier;
498 params.wasmsimd.multiplier[1] = (int64_t) multiplier;
499 params.wasmsimd.rounding[0] = rounding;
500 params.wasmsimd.rounding[1] = rounding;
501 params.wasmsimd.shift = shift;
502 for (uint32_t i = 0; i < 8; i++) {
503 params.wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
504 }
505 for (uint32_t i = 0; i < 16; i++) {
506 params.wasmsimd.output_min[i] = output_min;
507 params.wasmsimd.output_max[i] = output_max;
508 }
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700509 #else
Marat Dukhanef451802020-08-06 11:53:47 -0700510 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700511 params.scalar.bias = bias;
512 params.scalar.multiplier = multiplier;
513 params.scalar.rounding = rounding;
Marat Dukhanef451802020-08-06 11:53:47 -0700514 params.scalar.shift = (uint32_t) shift;
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700515 params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
516 params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
517 params.scalar.output_zero_point = (int32_t) output_zero_point;
518 #endif
519 return params;
520}
521
522static inline union xnn_qs8_avgpool_params xnn_init_scalar_qs8_avgpool_params(
523 int32_t bias,
524 float scale,
525 int8_t output_zero_point,
526 int8_t output_min,
527 int8_t output_max)
528{
529 // Compute requantization parameters.
530 assert(scale >= 0x1.0p-32f);
531 assert(scale < 256.0f);
532 const uint32_t scale_bits = fp32_to_bits(scale);
533
534 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
535 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
536 assert(multiplier >= INT32_C(0x00800000));
537 assert(multiplier <= INT32_C(0x00FFFFFF));
538
539 // Shift is in [16, 55] range.
540 const int32_t shift = 127 + 23 - (scale_bits >> 23);
541 assert(shift >= 16);
542 assert(shift < 64);
543
544 union xnn_qs8_avgpool_params params;
Marat Dukhanef451802020-08-06 11:53:47 -0700545 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700546 params.scalar.bias = bias;
547 params.scalar.rounding = rounding;
548 params.scalar.multiplier = multiplier;
Marat Dukhanef451802020-08-06 11:53:47 -0700549 params.scalar.shift = shift;
Marat Dukhan4ed53f42020-08-06 01:12:55 -0700550 params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
551 params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
552 params.scalar.output_zero_point = (int32_t) output_zero_point;
553 return params;
554}
555
Marat Dukhan9e0b5392020-08-07 02:29:34 -0700556static inline void xnn_update_qs8_avgpool_params(
557 union xnn_qs8_avgpool_params* params,
558 int32_t bias,
559 float scale)
560{
561 // Compute requantization parameters.
562 assert(scale >= 0x1.0p-32f);
563 assert(scale < 256.0f);
564 const uint32_t scale_bits = fp32_to_bits(scale);
565
566 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
567 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
568 assert(multiplier >= INT32_C(0x00800000));
569 assert(multiplier <= INT32_C(0x00FFFFFF));
570
571 // Shift is in [16, 55] range.
572 const int32_t shift = 127 + 23 - (scale_bits >> 23);
573 assert(shift >= 16);
574 assert(shift < 64);
575
576 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
577 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
578 params->sse2.bias[0] = bias;
579 params->sse2.bias[1] = bias;
580 params->sse2.bias[2] = bias;
581 params->sse2.bias[3] = bias;
582 params->sse2.multiplier[0] = (uint32_t) multiplier;
583 params->sse2.multiplier[1] = (uint32_t) multiplier;
584 params->sse2.multiplier[2] = (uint32_t) multiplier;
585 params->sse2.multiplier[3] = (uint32_t) multiplier;
586 params->sse2.rounding[0] = rounding;
587 params->sse2.rounding[1] = rounding;
588 params->sse2.shift[0] = (uint64_t) (uint32_t) shift;
589 params->sse2.shift[1] = (uint64_t) (uint32_t) shift;
590 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
591 params->neon.bias = bias;
592 params->neon.multiplier = multiplier;
593 params->neon.left_shift = (int64_t) -shift;
Marat Dukhan4de076d2020-08-08 23:29:52 -0700594 #elif XNN_ARCH_WASMSIMD
595 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
596 params->wasmsimd.bias[0] = bias;
597 params->wasmsimd.bias[1] = bias;
598 params->wasmsimd.bias[2] = bias;
599 params->wasmsimd.bias[3] = bias;
600 params->wasmsimd.multiplier[0] = (int64_t) multiplier;
601 params->wasmsimd.multiplier[1] = (int64_t) multiplier;
602 params->wasmsimd.rounding[0] = rounding;
603 params->wasmsimd.rounding[1] = rounding;
604 params->wasmsimd.shift = shift;
Marat Dukhan9e0b5392020-08-07 02:29:34 -0700605 #else
606 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
607 params->scalar.bias = bias;
608 params->scalar.multiplier = multiplier;
609 params->scalar.rounding = rounding;
610 params->scalar.shift = (uint32_t) shift;
611 #endif
612}
613
Frank Barchard7e2cbb02020-06-12 01:22:13 -0700614static inline void xnn_update_f16_scaleminmax_params(
615 struct xnn_f16_scaleminmax_params* params,
616 uint16_t scale)
617{
618 params->scale = scale;
619}
620
Marat Dukhan8452ff52020-04-08 20:44:58 -0700621static inline void xnn_update_f32_scaleminmax_params(
622 union xnn_f32_scaleminmax_params* params,
623 float scale)
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700624{
625 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
626 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhan8452ff52020-04-08 20:44:58 -0700627 params->sse2.scale[i] = scale;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700628 }
629 #else
Marat Dukhan8452ff52020-04-08 20:44:58 -0700630 params->scalar.scale = scale;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700631 #endif
632}
633
Frank Barchard7e2cbb02020-06-12 01:22:13 -0700634static inline struct xnn_f16_scaleminmax_params xnn_init_f16_scaleminmax_params(
635 uint16_t scale,
636 uint16_t min,
637 uint16_t max)
638{
639 struct xnn_f16_scaleminmax_params params;
640 params.scale = scale;
641 params.min = min;
642 params.max = max;
643 return params;
644}
645
Marat Dukhan8452ff52020-04-08 20:44:58 -0700646static inline union xnn_f32_scaleminmax_params xnn_init_f32_scaleminmax_params(
647 float scale,
648 float min,
649 float max)
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700650{
Marat Dukhan8452ff52020-04-08 20:44:58 -0700651 union xnn_f32_scaleminmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700652 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
653 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhan8452ff52020-04-08 20:44:58 -0700654 params.sse2.scale[i] = scale;
655 params.sse2.min[i] = min;
656 params.sse2.max[i] = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700657 }
658 #else
Marat Dukhan8452ff52020-04-08 20:44:58 -0700659 params.scalar.scale = scale;
660 params.scalar.min = min;
661 params.scalar.max = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700662 #endif
663 return params;
664}
665
666static inline union xnn_f32_gavgpool_params xnn_init_f32_gavgpool_params(
667 float multiplier,
668 float output_min,
669 float output_max,
670 uint32_t width)
671{
672 union xnn_f32_gavgpool_params params;
673 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
674 for (uint32_t i = 0; i < 4; i++) {
675 params.sse.multiplier[i] = multiplier;
676 params.sse.output_min[i] = output_min;
677 params.sse.output_max[i] = output_max;
678 }
679
680 const uint32_t w = (width - 1) & 3;
681 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
682 params.sse.mask[1] = -(uint32_t) (w >= 1);
683 params.sse.mask[2] = -(uint32_t) (w >= 2);
684 params.sse.mask[3] = -(uint32_t) (w >= 3);
685 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
686 params.neon.multiplier = multiplier;
687 params.neon.output_min = output_min;
688 params.neon.output_max = output_max;
689
690 const uint32_t w = (width - 1) & 3;
691 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
692 params.neon.mask[1] = -(uint32_t) (w >= 1);
693 params.neon.mask[2] = -(uint32_t) (w >= 2);
694 params.neon.mask[3] = -(uint32_t) (w >= 3);
695 #else
696 params.scalar.multiplier = multiplier;
697 params.scalar.output_min = output_min;
698 params.scalar.output_max = output_max;
Erich Elsen6f278b52020-06-10 16:13:11 -0700699
700 const uint32_t w = (width - 1) & 3;
701 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
702 params.scalar.mask[1] = -(int32_t) (w >= 1);
703 params.scalar.mask[2] = -(int32_t) (w >= 2);
704 params.scalar.mask[3] = -(int32_t) (w >= 3);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700705 #endif
706 return params;
707}
708
709static inline void xnn_update_f32_gavgpool_params(
710 union xnn_f32_gavgpool_params* params,
711 float multiplier,
712 uint32_t width)
713{
714 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
715 for (uint32_t i = 0; i < 4; i++) {
716 params->sse.multiplier[i] = multiplier;
717 }
718
719 const uint32_t w = (width - 1) & 3;
720 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
721 params->sse.mask[1] = -(uint32_t) (w >= 1);
722 params->sse.mask[2] = -(uint32_t) (w >= 2);
723 params->sse.mask[3] = -(uint32_t) (w >= 3);
724 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
725 params->neon.multiplier = multiplier;
726
727 const uint32_t w = (width - 1) & 3;
728 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
729 params->neon.mask[1] = -(uint32_t) (w >= 1);
730 params->neon.mask[2] = -(uint32_t) (w >= 2);
731 params->neon.mask[3] = -(uint32_t) (w >= 3);
732 #else
733 params->scalar.multiplier = multiplier;
Erich Elsen6f278b52020-06-10 16:13:11 -0700734
735 const uint32_t w = (width - 1) & 3;
736 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
Marat Dukhan548e8302020-08-04 15:03:46 -0700737 params->scalar.mask[1] = -(int32_t) (w >= 1);
738 params->scalar.mask[2] = -(int32_t) (w >= 2);
739 params->scalar.mask[3] = -(int32_t) (w >= 3);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700740 #endif
741}
742
Marat Dukhan8452ff52020-04-08 20:44:58 -0700743static inline union xnn_f32_scaleminmax_params xnn_init_scalar_f32_scaleminmax_params(
744 float scale,
745 float min,
746 float max)
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700747{
Marat Dukhan8452ff52020-04-08 20:44:58 -0700748 union xnn_f32_scaleminmax_params params;
749 params.scalar.scale = scale;
750 params.scalar.min = min;
751 params.scalar.max = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700752 return params;
753}
754
755static inline union xnn_f32_gavgpool_params xnn_init_scalar_f32_gavgpool_params(
756 float multiplier,
757 float output_min,
758 float output_max,
759 uint32_t width)
760{
761 union xnn_f32_gavgpool_params params;
762 params.scalar.multiplier = multiplier;
763 params.scalar.output_min = output_min;
764 params.scalar.output_max = output_max;
Erich Elsen6f278b52020-06-10 16:13:11 -0700765
766 const uint32_t w = (width - 1) & 3;
767 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
768 params.scalar.mask[1] = -(int32_t) (w >= 1);
769 params.scalar.mask[2] = -(int32_t) (w >= 2);
770 params.scalar.mask[3] = -(int32_t) (w >= 3);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700771 return params;
772}
773
Frank Barchardd793f6c2020-05-08 13:37:43 -0700774static inline struct xnn_f16_minmax_params xnn_init_f16_minmax_params(
775 uint16_t min,
776 uint16_t max)
777{
778 struct xnn_f16_minmax_params params;
779 params.min = min;
780 params.max = max;
781 return params;
782}
783
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700784static inline union xnn_f32_minmax_params xnn_init_f32_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700785 float output_min,
786 float output_max)
787{
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700788 union xnn_f32_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700789 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
790 for (uint32_t i = 0; i < 4; i++) {
791 params.sse.min[i] = output_min;
792 params.sse.max[i] = output_max;
793 }
794 #else
795 params.scalar.min = output_min;
796 params.scalar.max = output_max;
797 #endif
798 return params;
799}
800
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700801static inline union xnn_f32_minmax_params xnn_init_scalar_f32_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700802 float output_min,
803 float output_max)
804{
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700805 union xnn_f32_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700806 params.scalar.min = output_min;
807 params.scalar.max = output_max;
808 return params;
809}
810
Frank Barchardb1966592020-05-12 13:47:06 -0700811static inline struct xnn_f16_hswish_params xnn_init_f16_hswish_params(void)
812{
813 struct xnn_f16_hswish_params params;
Marat Dukhan201ea0e2020-07-10 22:47:19 -0700814 params.sixth = UINT16_C(0x3155);
815 params.three = UINT16_C(0x4200);
816 params.six = UINT16_C(0x4600);
Frank Barchardb1966592020-05-12 13:47:06 -0700817 return params;
818}
819
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700820static inline union xnn_f32_hswish_params xnn_init_f32_hswish_params(void)
821{
822 union xnn_f32_hswish_params params;
823 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
824 for (uint32_t i = 0; i < 4; i++) {
825 params.sse.sixth[i] = 0x1.555556p-3f;
826 params.sse.half[i] = 0.5f;
827 params.sse.one[i] = 1.0f;
828 }
829 #else
830 params.scalar.sixth = 0x1.555556p-3f;
Marat Dukhan9df9dc62020-07-10 20:08:49 -0700831 params.scalar.three = 3.0f;
832 params.scalar.six = 6.0f;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700833 #endif
834 return params;
835}
836
837static inline union xnn_f32_hswish_params xnn_init_scalar_f32_hswish_params(void)
838{
839 union xnn_f32_hswish_params params;
840 params.scalar.sixth = 0x1.555556p-3f;
Marat Dukhan9df9dc62020-07-10 20:08:49 -0700841 params.scalar.three = 3.0f;
842 params.scalar.six = 6.0f;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700843 return params;
844}
845
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700846static inline union xnn_f32_abs_params xnn_init_f32_abs_params(void)
847{
848 union xnn_f32_abs_params params = { 0 };
849 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
850 for (uint32_t i = 0; i < 4; i++) {
851 params.sse.nonsign_mask[i] = math_nonsign_mask_f32();
852 }
Marat Dukhan37c83512020-06-29 13:25:53 -0700853 #elif XNN_ARCH_WASMSIMD
854 params.wasmsimd.nonsign_mask = math_nonsign_mask_f32();
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700855 #endif
856 return params;
857}
858
859static inline union xnn_f32_abs_params xnn_init_scalar_f32_abs_params(void)
860{
861 union xnn_f32_abs_params params = { 0 };
862 return params;
863}
864
865static inline union xnn_f32_neg_params xnn_init_f32_neg_params(void)
866{
867 union xnn_f32_neg_params params = { 0 };
868 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
869 for (uint32_t i = 0; i < 4; i++) {
870 params.sse.sign_mask[i] = -0.0f;
871 }
Marat Dukhan37c83512020-06-29 13:25:53 -0700872 #elif XNN_ARCH_WASMSIMD
873 params.wasmsimd.sign_mask = -0.0f;
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700874 #endif
875 return params;
876}
877
878static inline union xnn_f32_neg_params xnn_init_scalar_f32_neg_params(void)
879{
880 union xnn_f32_neg_params params = { 0 };
881 return params;
882}
883
Marat Dukhaneecf8fd2020-06-09 08:59:37 -0700884static inline union xnn_f32_rnd_params xnn_init_f32_rnd_params(void)
885{
886 union xnn_f32_rnd_params params = { 0 };
887 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
888 for (uint32_t i = 0; i < 4; i++) {
889 params.sse2.sign_mask[i] = -0.0f;
890 }
891 for (uint32_t i = 0; i < 4; i++) {
892 params.sse2.one[i] = 1.0f;
893 }
894 #endif
895 return params;
896}
897
898static inline union xnn_f32_rnd_params xnn_init_scalar_f32_rnd_params(void)
899{
900 union xnn_f32_rnd_params params = { 0 };
901 return params;
902}
903
Marat Dukhaned6baaf2020-12-01 15:07:08 -0800904static inline union xnn_f32_elu_params xnn_init_f32_elu_params(float prescale, float alpha, float beta)
905{
906 union xnn_f32_elu_params params;
907 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
908 for (uint32_t i = 0; i < 4; i++) {
909 params.sse.prescale[i] = prescale;
910 params.sse.alpha[i] = alpha;
911 params.sse.beta[i] = beta;
912 }
913 #else
914 params.scalar.prescale = prescale;
915 params.scalar.alpha = alpha;
916 params.scalar.beta = beta;
917 #endif
918 return params;
919}
920
921static inline union xnn_f32_elu_params xnn_init_scalar_f32_elu_params(float prescale, float alpha, float beta)
922{
923 union xnn_f32_elu_params params;
924 params.scalar.prescale = prescale;
925 params.scalar.alpha = alpha;
926 params.scalar.beta = beta;
927 return params;
928}
929
Marat Dukhan8cc7efe2020-06-10 16:24:27 -0700930static inline union xnn_f32_lrelu_params xnn_init_f32_lrelu_params(float slope)
931{
932 union xnn_f32_lrelu_params params;
933 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
934 for (uint32_t i = 0; i < 4; i++) {
935 params.sse.slope[i] = slope;
936 }
937 #else
938 params.scalar.slope = slope;
939 #endif
940 return params;
941}
942
943static inline union xnn_f32_lrelu_params xnn_init_scalar_f32_lrelu_params(float slope)
944{
945 union xnn_f32_lrelu_params params;
946 params.scalar.slope = slope;
947 return params;
948}
949
Marat Dukhanf4db2f32020-06-30 10:55:30 -0700950static inline union xnn_f32_sqrt_params xnn_init_f32_sqrt_params(void)
951{
952 union xnn_f32_sqrt_params params = { 0 };
953 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
954 params.fma.half = 0.5f;
955 #endif
956 return params;
957}
958
959static inline union xnn_f32_sqrt_params xnn_init_scalar_f32_sqrt_params(void)
960{
961 union xnn_f32_sqrt_params params = { 0 };
962 return params;
963}
964
Marat Dukhan1f29b802020-05-15 23:46:39 -0700965static inline union xnn_f32_chw_params xnn_init_f32_chw_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700966 uint32_t width,
967 float output_min,
968 float output_max)
969{
Marat Dukhan1f29b802020-05-15 23:46:39 -0700970 union xnn_f32_chw_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700971 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
972 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700973 params.sse.min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700974 params.sse.max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700975 }
976
977 const uint32_t w4 = (width - 1) & 3;
978 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
979 params.sse.mask[1] = -(uint32_t) (w4 >= 1);
980 params.sse.mask[2] = -(uint32_t) (w4 >= 2);
981 params.sse.mask[3] = -(uint32_t) (w4 >= 3);
982
983 const uint32_t w8 = (width - 1) & 7;
984 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
985 params.sse.mask_even[1] = -(uint32_t) (w8 >= 2);
986 params.sse.mask_even[2] = -(uint32_t) (w8 >= 4);
987 params.sse.mask_even[3] = -(uint32_t) (w8 >= 6);
988 params.sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
989 params.sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
990 params.sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
991 params.sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
992 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700993 params.neon.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700994 params.neon.max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700995
996 const uint32_t w4 = (width - 1) & 3;
997 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
998 params.neon.mask[1] = -(uint32_t) (w4 >= 1);
999 params.neon.mask[2] = -(uint32_t) (w4 >= 2);
1000 params.neon.mask[3] = -(uint32_t) (w4 >= 3);
1001
1002 const uint32_t w8 = (width - 1) & 7;
1003 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
1004 params.neon.mask_even[1] = -(uint32_t) (w8 >= 2);
1005 params.neon.mask_even[2] = -(uint32_t) (w8 >= 4);
1006 params.neon.mask_even[3] = -(uint32_t) (w8 >= 6);
1007 params.neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
1008 params.neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
1009 params.neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
1010 params.neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
1011 #else
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001012 params.scalar.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001013 params.scalar.max = output_max;
Erich Elsene6214af2020-06-10 22:17:22 -07001014
1015 const uint32_t w4 = (width - 1) & 3;
Marat Dukhanfaa2e5c2020-07-10 17:34:44 -07001016 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1017 params.scalar.mask[1] = -(uint32_t) (w4 >= 1);
1018 params.scalar.mask[2] = -(uint32_t) (w4 >= 2);
1019 params.scalar.mask[3] = -(uint32_t) (w4 >= 3);
Erich Elsenfd7a6e32020-06-11 12:04:44 -07001020
1021 const uint32_t w8 = (width - 1) & 7;
Marat Dukhanfaa2e5c2020-07-10 17:34:44 -07001022 params.scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1023 params.scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1024 params.scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1025 params.scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1026 params.scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1027 params.scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1028 params.scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1029 params.scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001030 #endif
1031 return params;
1032}
1033
Marat Dukhan1f29b802020-05-15 23:46:39 -07001034static inline void xnn_update_f32_chw_params(
1035 union xnn_f32_chw_params* params,
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001036 uint32_t width)
1037{
1038 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1039 const uint32_t w4 = (width - 1) & 3;
1040 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
1041 params->sse.mask[1] = -(uint32_t) (w4 >= 1);
1042 params->sse.mask[2] = -(uint32_t) (w4 >= 2);
1043 params->sse.mask[3] = -(uint32_t) (w4 >= 3);
1044
1045 const uint32_t w8 = (width - 1) & 7;
1046 params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
1047 params->sse.mask_even[1] = -(uint32_t) (w8 >= 2);
1048 params->sse.mask_even[2] = -(uint32_t) (w8 >= 4);
1049 params->sse.mask_even[3] = -(uint32_t) (w8 >= 6);
1050 params->sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
1051 params->sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
1052 params->sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
1053 params->sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
1054 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1055 const uint32_t w4 = (width - 1) & 3;
1056 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
1057 params->neon.mask[1] = -(uint32_t) (w4 >= 1);
1058 params->neon.mask[2] = -(uint32_t) (w4 >= 2);
1059 params->neon.mask[3] = -(uint32_t) (w4 >= 3);
1060
1061 const uint32_t w8 = (width - 1) & 7;
1062 params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
1063 params->neon.mask_even[1] = -(uint32_t) (w8 >= 2);
1064 params->neon.mask_even[2] = -(uint32_t) (w8 >= 4);
1065 params->neon.mask_even[3] = -(uint32_t) (w8 >= 6);
1066 params->neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
1067 params->neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
1068 params->neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
1069 params->neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
Marat Dukhanfaa2e5c2020-07-10 17:34:44 -07001070 #else
1071 const uint32_t w4 = (width - 1) & 3;
1072 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1073 params->scalar.mask[1] = -(uint32_t) (w4 >= 1);
1074 params->scalar.mask[2] = -(uint32_t) (w4 >= 2);
1075 params->scalar.mask[3] = -(uint32_t) (w4 >= 3);
1076
1077 const uint32_t w8 = (width - 1) & 7;
1078 params->scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1079 params->scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1080 params->scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1081 params->scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1082 params->scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1083 params->scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1084 params->scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1085 params->scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001086 #endif
1087}
1088
Marat Dukhan1f29b802020-05-15 23:46:39 -07001089static inline union xnn_f32_chw_params xnn_init_scalar_f32_chw_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001090 uint32_t width,
1091 float output_min,
1092 float output_max)
1093{
Marat Dukhan1f29b802020-05-15 23:46:39 -07001094 union xnn_f32_chw_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001095 params.scalar.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001096 params.scalar.max = output_max;
Erich Elsene6214af2020-06-10 22:17:22 -07001097
1098 const uint32_t w4 = (width - 1) & 3;
Marat Dukhanfaa2e5c2020-07-10 17:34:44 -07001099 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1100 params.scalar.mask[1] = -(uint32_t) (w4 >= 1);
1101 params.scalar.mask[2] = -(uint32_t) (w4 >= 2);
1102 params.scalar.mask[3] = -(uint32_t) (w4 >= 3);
Erich Elsenfd7a6e32020-06-11 12:04:44 -07001103
1104 const uint32_t w8 = (width - 1) & 7;
Marat Dukhanfaa2e5c2020-07-10 17:34:44 -07001105 params.scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1106 params.scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1107 params.scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1108 params.scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1109 params.scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1110 params.scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1111 params.scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1112 params.scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
Erich Elsenfd7a6e32020-06-11 12:04:44 -07001113
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001114 return params;
1115}
1116
Marat Dukhaneb09a6b2020-04-08 17:34:32 -07001117static inline union xnn_u8_minmax_params xnn_init_u8_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001118 uint8_t output_min,
1119 uint8_t output_max)
1120{
1121 assert(output_min < output_max);
1122
Marat Dukhaneb09a6b2020-04-08 17:34:32 -07001123 union xnn_u8_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001124 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1125 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001126 params.sse2.min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001127 params.sse2.max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001128 }
1129 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001130 params.neon.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001131 params.neon.max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001132 #else
1133 params.scalar.min = (int32_t) (uint32_t) output_min;
1134 params.scalar.max = (int32_t) (uint32_t) output_max;
1135 #endif
1136 return params;
1137}
1138
Marat Dukhaneb09a6b2020-04-08 17:34:32 -07001139static inline union xnn_u8_minmax_params xnn_init_scalar_u8_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001140 uint8_t output_min,
1141 uint8_t output_max)
1142{
1143 assert(output_min < output_max);
1144
Marat Dukhaneb09a6b2020-04-08 17:34:32 -07001145 union xnn_u8_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001146 params.scalar.min = (int32_t) (uint32_t) output_min;
1147 params.scalar.max = (int32_t) (uint32_t) output_max;
1148 return params;
1149}
1150
Marat Dukhan08b7a972020-07-14 18:17:29 -07001151static inline union xnn_qu8_add_params xnn_init_qu8_add_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001152 uint8_t a_zero_point,
1153 uint8_t b_zero_point,
1154 uint8_t output_zero_point,
1155 float a_output_scale,
1156 float b_output_scale,
1157 uint8_t output_min,
1158 uint8_t output_max)
1159{
1160 assert(a_output_scale >= 0x1.0p-14f);
1161 assert(b_output_scale >= 0x1.0p-14f);
1162 assert(a_output_scale < 0x1.0p+8f);
1163 assert(b_output_scale < 0x1.0p+8f);
1164
1165 // Compute requantization parameters.
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001166 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001167 assert(max_output_scale >= 0x1.0p-14f);
1168 assert(max_output_scale < 0x1.0p+8f);
1169 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1170 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1171 // Shift is in [13, 31] range.
1172 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1173 assert(shift < 32);
1174 assert(shift >= 13);
1175
1176 const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
1177
1178 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
Marat Dukhanef3e7dc2020-04-13 01:19:56 -07001179 const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(a_output_scale * scale_multiplier);
1180 const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(b_output_scale * scale_multiplier);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001181 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
1182 assert(a_multiplier < UINT32_C(0x00400000));
1183 assert(b_multiplier < UINT32_C(0x00400000));
1184
Marat Dukhan08b7a972020-07-14 18:17:29 -07001185 union xnn_qu8_add_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001186 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1187 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1188 const uint32_t remainder_threshold = remainder_mask >> 1;
1189 const int32_t zero_point_product =
1190 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1191 for (uint32_t i = 0; i < 4; i++) {
1192 params.sse2.zero_point_product[i] = zero_point_product;
1193 }
1194 for (uint32_t i = 0; i < 8; i++) {
1195 params.sse2.y_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
1196 }
1197 for (uint32_t i = 0; i < 8; i++) {
1198 params.sse2.a_multiplier_lo[i] = (uint16_t) (uint32_t) a_multiplier;
1199 params.sse2.a_multiplier_hi[i] = (uint16_t) ((uint32_t) a_multiplier >> 16);
1200 params.sse2.b_multiplier_lo[i] = (uint16_t) (uint32_t) b_multiplier;
1201 params.sse2.b_multiplier_hi[i] = (uint16_t) ((uint32_t) b_multiplier >> 16);
1202 }
1203 params.sse2.a_multiplier = a_multiplier;
1204 params.sse2.b_multiplier = b_multiplier;
1205 for (uint32_t i = 0; i < 4; i++) {
1206 params.sse2.remainder_mask[i] = remainder_mask;
1207 params.sse2.remainder_threshold[i] = remainder_threshold;
1208 }
1209 params.sse2.shift = shift;
1210 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001211 params.sse2.y_min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001212 params.sse2.y_max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001213 }
1214 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1215 params.neon.a_zero_point = a_zero_point;
1216 params.neon.b_zero_point = b_zero_point;
1217 params.neon.y_zero_point = (int16_t) (uint16_t) output_zero_point;
1218 params.neon.a_multiplier = (int32_t) a_multiplier;
1219 params.neon.b_multiplier = (int32_t) b_multiplier;
1220 params.neon.right_shift = (int32_t) -shift;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001221 params.neon.y_min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001222 params.neon.y_max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001223 #else
1224 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1225 const uint32_t remainder_threshold = remainder_mask >> 1;
1226 params.scalar.zero_point_product =
1227 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1228 params.scalar.a_multiplier = a_multiplier;
1229 params.scalar.b_multiplier = b_multiplier;
1230 params.scalar.remainder_mask = (int32_t) remainder_mask;
1231 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1232 params.scalar.shift = shift;
1233 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001234 params.scalar.y_min = (int32_t) (uint32_t) output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001235 params.scalar.y_max = (int32_t) (uint32_t) output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001236 #endif
1237 return params;
1238}
1239
Marat Dukhan08b7a972020-07-14 18:17:29 -07001240static inline union xnn_qu8_add_params xnn_init_scalar_qu8_add_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001241 uint8_t a_zero_point,
1242 uint8_t b_zero_point,
1243 uint8_t output_zero_point,
1244 float a_output_scale,
1245 float b_output_scale,
1246 uint8_t output_min,
1247 uint8_t output_max)
1248{
1249 assert(a_output_scale >= 0x1.0p-10f);
1250 assert(b_output_scale >= 0x1.0p-10f);
1251 assert(a_output_scale < 0x1.0p+8f);
1252 assert(b_output_scale < 0x1.0p+8f);
1253
1254 // Compute requantization parameters.
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001255 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001256 assert(max_output_scale >= 0x1.0p-10f);
1257 assert(max_output_scale < 0x1.0p+8f);
1258 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1259 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1260 // Shift is in [13, 31] range.
1261 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1262 assert(shift < 32);
1263 assert(shift >= 13);
1264
1265 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
Marat Dukhanef3e7dc2020-04-13 01:19:56 -07001266 const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
1267 const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001268 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
1269 assert(a_multiplier < UINT32_C(0x00400000));
1270 assert(b_multiplier < UINT32_C(0x00400000));
1271
Marat Dukhan08b7a972020-07-14 18:17:29 -07001272 union xnn_qu8_add_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001273 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1274 const uint32_t remainder_threshold = remainder_mask >> 1;
1275 params.scalar.zero_point_product =
1276 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1277 params.scalar.a_multiplier = a_multiplier;
1278 params.scalar.b_multiplier = b_multiplier;
1279 params.scalar.remainder_mask = (int32_t) remainder_mask;
1280 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1281 params.scalar.shift = shift;
1282 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001283 params.scalar.y_min = (int32_t) (uint32_t) output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -07001284 params.scalar.y_max = (int32_t) (uint32_t) output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001285 return params;
1286}
1287
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001288static inline union xnn_qs8_add_params xnn_init_qs8_add_params(
1289 int8_t x_zero_point,
1290 int8_t y_zero_point,
1291 int8_t output_zero_point,
1292 float x_output_scale,
1293 float y_output_scale,
1294 int8_t output_min,
1295 int8_t output_max)
1296{
1297 assert(x_output_scale >= 0x1.0p-14f);
1298 assert(y_output_scale >= 0x1.0p-14f);
1299 assert(x_output_scale < 0x1.0p+8f);
1300 assert(y_output_scale < 0x1.0p+8f);
1301
1302 // Compute requantization parameters.
1303 const float max_output_scale = math_max_f32(x_output_scale, y_output_scale);
1304 assert(max_output_scale >= 0x1.0p-14f);
1305 assert(max_output_scale < 0x1.0p+8f);
1306 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1307 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1308 // Shift is in [13, 31] range.
1309 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1310 assert(shift < 32);
1311 assert(shift >= 13);
1312
1313 const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
1314
1315 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1316 const int32_t x_multiplier = (int32_t) lrintf(x_output_scale * scale_multiplier);
1317 const int32_t y_multiplier = (int32_t) lrintf(y_output_scale * scale_multiplier);
1318 assert((x_multiplier > y_multiplier ? x_multiplier : y_multiplier) >= INT32_C(0x00200000));
1319 assert(x_multiplier < INT32_C(0x00400000));
1320 assert(y_multiplier < INT32_C(0x00400000));
1321
1322 union xnn_qs8_add_params params;
1323 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1324 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1325 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1326 const int32_t zero_point_product =
1327 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1328 for (uint32_t i = 0; i < 4; i++) {
1329 params.sse2.zero_point_product[i] = zero_point_product;
1330 }
1331 const uint16_t x_multiplier_lo = (uint16_t) x_multiplier;
1332 const uint16_t x_multiplier_hi = (uint16_t) ((uint32_t) x_multiplier >> 16);
1333 const uint16_t y_multiplier_lo = (uint16_t) y_multiplier;
1334 const uint16_t y_multiplier_hi = (uint16_t) ((uint32_t) y_multiplier >> 16);
1335 for (uint32_t i = 0; i < 8; i++) {
1336 params.sse2.x_multiplier_lo[i] = x_multiplier_lo;
1337 params.sse2.x_multiplier_hi[i] = x_multiplier_hi;
1338 params.sse2.y_multiplier_lo[i] = y_multiplier_lo;
1339 params.sse2.y_multiplier_hi[i] = y_multiplier_hi;
1340 }
1341 params.sse2.shift = shift;
1342 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhanbb9225e2020-09-06 22:40:56 -07001343 params.sse2.x_multiplier[i] = x_multiplier;
1344 params.sse2.y_multiplier[i] = y_multiplier;
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001345 params.sse2.remainder_mask[i] = remainder_mask;
1346 params.sse2.remainder_threshold[i] = remainder_threshold;
1347 }
1348 for (uint32_t i = 0; i < 8; i++) {
1349 params.sse2.output_zero_point[i] = (int16_t) output_zero_point;
1350 params.sse2.output_min[i] = (int16_t) output_min;
1351 params.sse2.output_max[i] = (int16_t) output_max;
1352 }
1353 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1354 params.neon.x_zero_point = x_zero_point;
1355 params.neon.y_zero_point = y_zero_point;
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001356 params.neon.x_multiplier = (int32_t) x_multiplier;
1357 params.neon.y_multiplier = (int32_t) y_multiplier;
1358 params.neon.right_shift = (int32_t) -shift;
Marat Dukhanba7b2792020-09-02 14:26:45 -07001359 params.neon.output_zero_point = (int16_t) output_zero_point;
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001360 params.neon.output_min = output_min;
1361 params.neon.output_max = output_max;
Marat Dukhan5df27f82020-09-02 23:59:21 -07001362 #elif XNN_ARCH_WASMSIMD
1363 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1364 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1365 const int32_t zero_point_product =
1366 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1367 for (uint32_t i = 0; i < 4; i++) {
1368 params.wasmsimd.zero_point_product[i] = zero_point_product;
1369 params.wasmsimd.x_multiplier[i] = x_multiplier;
1370 params.wasmsimd.y_multiplier[i] = y_multiplier;
1371 params.wasmsimd.remainder_mask[i] = remainder_mask;
1372 params.wasmsimd.remainder_threshold[i] = remainder_threshold;
1373 }
1374 params.wasmsimd.shift = shift;
1375 for (uint32_t i = 0; i < 8; i++) {
1376 params.wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
1377 }
1378 for (uint32_t i = 0; i < 16; i++) {
1379 params.wasmsimd.output_min[i] = output_min;
1380 params.wasmsimd.output_max[i] = output_max;
1381 }
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001382 #else
1383 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1384 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1385 params.scalar.zero_point_product =
1386 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1387 params.scalar.x_multiplier = x_multiplier;
1388 params.scalar.y_multiplier = y_multiplier;
1389 params.scalar.remainder_mask = (int32_t) remainder_mask;
1390 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
Marat Dukhan5df27f82020-09-02 23:59:21 -07001391 params.scalar.shift = (int32_t) shift;
Marat Dukhand9f3ad42020-08-10 12:30:58 -07001392 params.scalar.output_zero_point = (int32_t) output_zero_point;
1393 params.scalar.output_min = (int32_t) output_min;
1394 params.scalar.output_max = (int32_t) output_max;
1395 #endif
1396 return params;
1397}
1398
1399static inline union xnn_qs8_add_params xnn_init_scalar_qs8_add_params(
1400 int8_t x_zero_point,
1401 int8_t y_zero_point,
1402 int8_t output_zero_point,
1403 float x_output_scale,
1404 float y_output_scale,
1405 int8_t output_min,
1406 int8_t output_max)
1407{
1408 assert(x_output_scale >= 0x1.0p-10f);
1409 assert(y_output_scale >= 0x1.0p-10f);
1410 assert(x_output_scale < 0x1.0p+8f);
1411 assert(y_output_scale < 0x1.0p+8f);
1412
1413 // Compute requantization parameters.
1414 const float max_output_scale = math_max_f32(x_output_scale, y_output_scale);
1415 assert(max_output_scale >= 0x1.0p-10f);
1416 assert(max_output_scale < 0x1.0p+8f);
1417 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1418 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1419 // Shift is in [13, 31] range.
1420 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1421 assert(shift < 32);
1422 assert(shift >= 13);
1423
1424 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1425 const int32_t x_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(x_output_scale) + (shift << 23)));
1426 const int32_t y_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(y_output_scale) + (shift << 23)));
1427 assert((x_multiplier > y_multiplier ? x_multiplier : y_multiplier) >= INT32_C(0x00200000));
1428 assert(x_multiplier < INT32_C(0x00400000));
1429 assert(y_multiplier < INT32_C(0x00400000));
1430
1431 union xnn_qs8_add_params params;
1432 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1433 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1434 params.scalar.zero_point_product =
1435 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1436 params.scalar.x_multiplier = x_multiplier;
1437 params.scalar.y_multiplier = y_multiplier;
1438 params.scalar.remainder_mask = (int32_t) remainder_mask;
1439 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1440 params.scalar.shift = shift;
1441 params.scalar.output_zero_point = (int32_t) output_zero_point;
1442 params.scalar.output_min = (int32_t) output_min;
1443 params.scalar.output_max = (int32_t) output_max;
1444 return params;
1445}
1446
Marat Dukhanec88e272020-07-30 15:02:09 -07001447static inline union xnn_qu8_requantization_params xnn_init_scalar_qu8_requantization_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001448 float scale,
1449 uint8_t zero_point,
1450 uint8_t min,
1451 uint8_t max)
1452{
1453 // Compute requantization parameters.
1454 assert(scale < 1.0f);
1455 assert(scale >= 0x1.0p-32f);
1456 const uint32_t scale_bits = fp32_to_bits(scale);
1457
1458 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
1459 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
1460 assert(multiplier >= INT32_C(0x40000000));
1461 assert(multiplier <= INT32_C(0x7FFFFF80));
1462
1463 // Shift is in [0, 31] range.
1464 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
1465 assert(shift >= 0);
1466 assert(shift < 32);
1467
Marat Dukhanec88e272020-07-30 15:02:09 -07001468 union xnn_qu8_requantization_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001469 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1470 const uint32_t remainder_threshold = remainder_mask >> 1;
Marat Dukhanec88e272020-07-30 15:02:09 -07001471 params.q31.multiplier = multiplier;
1472 params.q31.remainder_mask = (int32_t) remainder_mask;
1473 params.q31.remainder_threshold = (int32_t) remainder_threshold;
1474 params.q31.shift = (uint32_t) shift;
1475 params.q31.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
1476 params.q31.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
1477 params.q31.zero_point = (int32_t) (uint32_t) zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001478 return params;
1479}
1480
Marat Dukhan595e1702020-07-31 10:12:52 -07001481static inline union xnn_qs8_requantization_params xnn_init_scalar_qs8_requantization_params(
1482 float scale,
1483 int8_t zero_point,
1484 int8_t min,
1485 int8_t max)
1486{
1487 // Compute requantization parameters.
1488 assert(scale < 1.0f);
1489 assert(scale >= 0x1.0p-32f);
1490 const uint32_t scale_bits = fp32_to_bits(scale);
1491
1492 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
1493 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
1494 assert(multiplier >= INT32_C(0x40000000));
1495 assert(multiplier <= INT32_C(0x7FFFFF80));
1496
1497 // Shift is in [0, 31] range.
1498 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
1499 assert(shift >= 0);
1500 assert(shift < 32);
1501
1502 union xnn_qs8_requantization_params params;
1503 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1504 const uint32_t remainder_threshold = remainder_mask >> 1;
1505 params.q31.multiplier = multiplier;
1506 params.q31.remainder_mask = (int32_t) remainder_mask;
1507 params.q31.remainder_threshold = (int32_t) remainder_threshold;
1508 params.q31.shift = (uint32_t) shift;
1509 params.q31.min_less_zero_point = (int32_t) min - (int32_t) zero_point;
1510 params.q31.max_less_zero_point = (int32_t) max - (int32_t) zero_point;
1511 params.q31.zero_point = (int32_t) zero_point;
1512 return params;
1513}
1514