blob: a2a13070b862a474f8083e6ae21abc467e2290eb [file] [log] [blame]
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -07001// Copyright (c) Facebook, Inc. and its affiliates.
2// All rights reserved.
3//
4// Copyright 2019 Google LLC
5//
6// This source code is licensed under the BSD-style license found in the
7// LICENSE file in the root directory of this source tree.
8
9#pragma once
10
11#if defined(__cplusplus) && (__cplusplus >= 201103L)
12 #include <cstdint>
13 #include <cstddef>
14 #include <cassert>
15 #include <cmath>
16#else
17 #include <stdint.h>
18 #include <stddef.h>
19 #include <assert.h>
20 #include <math.h>
21#endif
22
23#include <fp16.h>
24
25#include <xnnpack/common.h>
Marat Dukhan2b9efd82020-06-08 01:09:31 -070026#include <xnnpack/math.h>
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -070027#include <xnnpack/params.h>
28
29
30static inline union xnn_q8_gemm_params xnn_init_scalar_q8_gemm_params(
31 uint8_t input_zero_point,
32 uint8_t kernel_zero_point,
33 float scale,
34 uint8_t output_zero_point,
35 uint8_t output_min,
36 uint8_t output_max)
37{
38 // Compute requantization parameters
39 const uint32_t scale_bits = fp32_to_bits(scale);
40
41 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
42 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
43 assert(multiplier >= INT32_C(0x40000000));
44 assert(multiplier <= INT32_C(0x7FFFFF80));
45
46 // Shift is in [0, 31] range.
47 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
48 assert(shift >= 0);
49 assert(shift < 32);
50
51 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
52 const uint32_t remainder_threshold = remainder_mask >> 1;
53
54 union xnn_q8_gemm_params params;
55 params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point;
56 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
57 params.scalar.multiplier = multiplier;
58 params.scalar.remainder_mask = (int32_t) remainder_mask;
59 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
60 params.scalar.shift = (uint32_t) shift;
61 params.scalar.output_min_less_zero_point =
62 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
63 params.scalar.output_max_less_zero_point =
64 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
65 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
66 return params;
67}
68
69static inline union xnn_q8_gemm_params xnn_init_q8_gemm_params(
70 uint8_t input_zero_point,
71 uint8_t kernel_zero_point,
72 float scale,
73 uint8_t output_zero_point,
74 uint8_t output_min,
75 uint8_t output_max)
76{
77 // Compute requantization parameters.
78 const uint32_t scale_bits = fp32_to_bits(scale);
79
80 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
81 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
82 assert(multiplier >= INT32_C(0x40000000));
83 assert(multiplier <= INT32_C(0x7FFFFF80));
84
85 // Shift is in [0, 31] range.
86 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
87 assert(shift >= 0);
88 assert(shift < 32);
89
90 union xnn_q8_gemm_params params;
91 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
92 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
93 const uint32_t remainder_threshold = remainder_mask >> 1;
94 for (uint32_t i = 0; i < 8; i++) {
95 params.sse2.input_zero_point[i] = (int16_t) (uint16_t) input_zero_point;
96 params.sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
97 }
98 params.sse2.multiplier[0] = multiplier;
99 params.sse2.multiplier[1] = multiplier;
100 params.sse2.multiplier[2] = multiplier;
101 params.sse2.multiplier[3] = multiplier;
102 params.sse2.rounding[0] = UINT64_C(0x40000000);
103 params.sse2.rounding[1] = UINT64_C(0x40000000);
104 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
105 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
106 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
107 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
108 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
109 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
110 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
111 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
112 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
113 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
114 for (uint32_t i = 0; i < 8; i++) {
115 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
116 }
117 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700118 params.sse2.output_min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700119 params.sse2.output_max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700120 }
121 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
122 params.neon.input_zero_point = (int16_t) (uint16_t) input_zero_point;
123 params.neon.kernel_zero_point = (int16_t) (uint16_t) kernel_zero_point;
124 params.neon.multiplier = multiplier;
125 params.neon.right_shift = -shift;
126 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700127 params.neon.output_min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700128 params.neon.output_max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700129 #else
130 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
131 const uint32_t remainder_threshold = remainder_mask >> 1;
132 params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point;
133 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
134 params.scalar.multiplier = multiplier;
135 params.scalar.remainder_mask = (int32_t) remainder_mask;
136 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
137 params.scalar.shift = (uint32_t) shift;
138 params.scalar.output_min_less_zero_point =
139 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
140 params.scalar.output_max_less_zero_point =
141 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
142 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
143 #endif
144 return params;
145}
146
147static inline union xnn_q8_avgpool_params xnn_init_q8_avgpool_params(
148 int32_t bias,
149 float scale,
150 uint8_t output_zero_point,
151 uint8_t output_min,
152 uint8_t output_max)
153{
154 // Compute requantization parameters.
155 assert(scale >= 0x1.0p-32f);
156 assert(scale < 256.0f);
157 const uint32_t scale_bits = fp32_to_bits(scale);
158
159 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
160 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
161 assert(multiplier >= INT32_C(0x00800000));
162 assert(multiplier <= INT32_C(0x00FFFFFF));
163
164 // Shift is in [16, 55] range.
165 const int32_t shift = 127 + 23 - (scale_bits >> 23);
166 assert(shift >= 16);
167 assert(shift < 64);
168
169 union xnn_q8_avgpool_params params;
170 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
171 const uint32_t right_shift = (uint32_t) shift;
172 const uint64_t rounding = UINT64_C(1) << (right_shift - 1);
173 params.sse2.bias[0] = bias;
174 params.sse2.bias[1] = bias;
175 params.sse2.bias[2] = bias;
176 params.sse2.bias[3] = bias;
177 params.sse2.multiplier[0] = (uint32_t) multiplier;
178 params.sse2.multiplier[1] = (uint32_t) multiplier;
179 params.sse2.multiplier[2] = (uint32_t) multiplier;
180 params.sse2.multiplier[3] = (uint32_t) multiplier;
181 params.sse2.rounding[0] = rounding;
182 params.sse2.rounding[1] = rounding;
183 params.sse2.right_shift[0] = (uint64_t) right_shift;
184 params.sse2.right_shift[1] = (uint64_t) right_shift;
185 for (uint32_t i = 0; i < 8; i++) {
186 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
187 }
188 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700189 params.sse2.output_min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700190 params.sse2.output_max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700191 }
192 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
193 params.neon.bias = bias;
194 params.neon.multiplier = multiplier;
195 params.neon.left_shift = (int64_t) -shift;
196 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700197 params.neon.output_min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700198 params.neon.output_max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700199 #else
200 const uint32_t right_shift = (uint32_t) shift;
201 const int64_t rounding = INT64_C(1) << (right_shift - 1);
202 params.scalar.bias = bias;
203 params.scalar.multiplier = multiplier;
204 params.scalar.rounding = rounding;
205 params.scalar.right_shift = right_shift;
206 params.scalar.output_min_less_zero_point =
207 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
208 params.scalar.output_max_less_zero_point =
209 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
210 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
211 #endif
212 return params;
213}
214
215static inline union xnn_q8_avgpool_params xnn_init_scalar_q8_avgpool_params(
216 int32_t bias,
217 float scale,
218 uint8_t output_zero_point,
219 uint8_t output_min,
220 uint8_t output_max)
221{
222 // Compute requantization parameters.
223 assert(scale >= 0x1.0p-32f);
224 assert(scale < 256.0f);
225 const uint32_t scale_bits = fp32_to_bits(scale);
226
227 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
228 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
229 assert(multiplier >= INT32_C(0x00800000));
230 assert(multiplier <= INT32_C(0x00FFFFFF));
231
232 // Shift is in [16, 55] range.
233 const int32_t shift = 127 + 23 - (scale_bits >> 23);
234 assert(shift >= 16);
235 assert(shift < 64);
236
237 union xnn_q8_avgpool_params params;
238 const uint32_t right_shift = (uint32_t) shift;
239 const int64_t rounding = INT64_C(1) << (right_shift - 1);
240 params.scalar.bias = bias;
241 params.scalar.rounding = rounding;
242 params.scalar.multiplier = multiplier;
243 params.scalar.right_shift = right_shift;
244 params.scalar.output_min_less_zero_point =
245 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
246 params.scalar.output_max_less_zero_point =
247 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
248 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
249 return params;
250}
251
Frank Barchard7e2cbb02020-06-12 01:22:13 -0700252static inline void xnn_update_f16_scaleminmax_params(
253 struct xnn_f16_scaleminmax_params* params,
254 uint16_t scale)
255{
256 params->scale = scale;
257}
258
Marat Dukhan8452ff52020-04-08 20:44:58 -0700259static inline void xnn_update_f32_scaleminmax_params(
260 union xnn_f32_scaleminmax_params* params,
261 float scale)
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700262{
263 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
264 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhan8452ff52020-04-08 20:44:58 -0700265 params->sse2.scale[i] = scale;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700266 }
267 #else
Marat Dukhan8452ff52020-04-08 20:44:58 -0700268 params->scalar.scale = scale;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700269 #endif
270}
271
Frank Barchard7e2cbb02020-06-12 01:22:13 -0700272static inline struct xnn_f16_scaleminmax_params xnn_init_f16_scaleminmax_params(
273 uint16_t scale,
274 uint16_t min,
275 uint16_t max)
276{
277 struct xnn_f16_scaleminmax_params params;
278 params.scale = scale;
279 params.min = min;
280 params.max = max;
281 return params;
282}
283
Marat Dukhan8452ff52020-04-08 20:44:58 -0700284static inline union xnn_f32_scaleminmax_params xnn_init_f32_scaleminmax_params(
285 float scale,
286 float min,
287 float max)
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700288{
Marat Dukhan8452ff52020-04-08 20:44:58 -0700289 union xnn_f32_scaleminmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700290 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
291 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhan8452ff52020-04-08 20:44:58 -0700292 params.sse2.scale[i] = scale;
293 params.sse2.min[i] = min;
294 params.sse2.max[i] = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700295 }
296 #else
Marat Dukhan8452ff52020-04-08 20:44:58 -0700297 params.scalar.scale = scale;
298 params.scalar.min = min;
299 params.scalar.max = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700300 #endif
301 return params;
302}
303
304static inline union xnn_f32_gavgpool_params xnn_init_f32_gavgpool_params(
305 float multiplier,
306 float output_min,
307 float output_max,
308 uint32_t width)
309{
310 union xnn_f32_gavgpool_params params;
311 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
312 for (uint32_t i = 0; i < 4; i++) {
313 params.sse.multiplier[i] = multiplier;
314 params.sse.output_min[i] = output_min;
315 params.sse.output_max[i] = output_max;
316 }
317
318 const uint32_t w = (width - 1) & 3;
319 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
320 params.sse.mask[1] = -(uint32_t) (w >= 1);
321 params.sse.mask[2] = -(uint32_t) (w >= 2);
322 params.sse.mask[3] = -(uint32_t) (w >= 3);
323 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
324 params.neon.multiplier = multiplier;
325 params.neon.output_min = output_min;
326 params.neon.output_max = output_max;
327
328 const uint32_t w = (width - 1) & 3;
329 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
330 params.neon.mask[1] = -(uint32_t) (w >= 1);
331 params.neon.mask[2] = -(uint32_t) (w >= 2);
332 params.neon.mask[3] = -(uint32_t) (w >= 3);
333 #else
334 params.scalar.multiplier = multiplier;
335 params.scalar.output_min = output_min;
336 params.scalar.output_max = output_max;
Erich Elsen6f278b52020-06-10 16:13:11 -0700337
338 const uint32_t w = (width - 1) & 3;
339 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
340 params.scalar.mask[1] = -(int32_t) (w >= 1);
341 params.scalar.mask[2] = -(int32_t) (w >= 2);
342 params.scalar.mask[3] = -(int32_t) (w >= 3);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700343 #endif
344 return params;
345}
346
347static inline void xnn_update_f32_gavgpool_params(
348 union xnn_f32_gavgpool_params* params,
349 float multiplier,
350 uint32_t width)
351{
352 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
353 for (uint32_t i = 0; i < 4; i++) {
354 params->sse.multiplier[i] = multiplier;
355 }
356
357 const uint32_t w = (width - 1) & 3;
358 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
359 params->sse.mask[1] = -(uint32_t) (w >= 1);
360 params->sse.mask[2] = -(uint32_t) (w >= 2);
361 params->sse.mask[3] = -(uint32_t) (w >= 3);
362 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
363 params->neon.multiplier = multiplier;
364
365 const uint32_t w = (width - 1) & 3;
366 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
367 params->neon.mask[1] = -(uint32_t) (w >= 1);
368 params->neon.mask[2] = -(uint32_t) (w >= 2);
369 params->neon.mask[3] = -(uint32_t) (w >= 3);
370 #else
371 params->scalar.multiplier = multiplier;
Erich Elsen6f278b52020-06-10 16:13:11 -0700372
373 const uint32_t w = (width - 1) & 3;
374 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
375 params->scalar.mask[1] = (int32_t) (w >= 1);
376 params->scalar.mask[2] = (int32_t) (w >= 2);
377 params->scalar.mask[3] = (int32_t) (w >= 3);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700378 #endif
379}
380
Marat Dukhan8452ff52020-04-08 20:44:58 -0700381static inline union xnn_f32_scaleminmax_params xnn_init_scalar_f32_scaleminmax_params(
382 float scale,
383 float min,
384 float max)
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700385{
Marat Dukhan8452ff52020-04-08 20:44:58 -0700386 union xnn_f32_scaleminmax_params params;
387 params.scalar.scale = scale;
388 params.scalar.min = min;
389 params.scalar.max = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700390 return params;
391}
392
393static inline union xnn_f32_gavgpool_params xnn_init_scalar_f32_gavgpool_params(
394 float multiplier,
395 float output_min,
396 float output_max,
397 uint32_t width)
398{
399 union xnn_f32_gavgpool_params params;
400 params.scalar.multiplier = multiplier;
401 params.scalar.output_min = output_min;
402 params.scalar.output_max = output_max;
Erich Elsen6f278b52020-06-10 16:13:11 -0700403
404 const uint32_t w = (width - 1) & 3;
405 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
406 params.scalar.mask[1] = -(int32_t) (w >= 1);
407 params.scalar.mask[2] = -(int32_t) (w >= 2);
408 params.scalar.mask[3] = -(int32_t) (w >= 3);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700409 return params;
410}
411
Frank Barchardd793f6c2020-05-08 13:37:43 -0700412static inline struct xnn_f16_minmax_params xnn_init_f16_minmax_params(
413 uint16_t min,
414 uint16_t max)
415{
416 struct xnn_f16_minmax_params params;
417 params.min = min;
418 params.max = max;
419 return params;
420}
421
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700422static inline union xnn_f32_minmax_params xnn_init_f32_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700423 float output_min,
424 float output_max)
425{
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700426 union xnn_f32_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700427 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
428 for (uint32_t i = 0; i < 4; i++) {
429 params.sse.min[i] = output_min;
430 params.sse.max[i] = output_max;
431 }
432 #else
433 params.scalar.min = output_min;
434 params.scalar.max = output_max;
435 #endif
436 return params;
437}
438
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700439static inline union xnn_f32_minmax_params xnn_init_scalar_f32_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700440 float output_min,
441 float output_max)
442{
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700443 union xnn_f32_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700444 params.scalar.min = output_min;
445 params.scalar.max = output_max;
446 return params;
447}
448
Frank Barchardb1966592020-05-12 13:47:06 -0700449static inline struct xnn_f16_hswish_params xnn_init_f16_hswish_params(void)
450{
451 struct xnn_f16_hswish_params params;
452 params.sixth = fp16_ieee_from_fp32_value(0x1.555556p-3f);
453 params.half = fp16_ieee_from_fp32_value(0.5f);
454 params.one = fp16_ieee_from_fp32_value(1.0f);
455 return params;
456}
457
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700458static inline union xnn_f32_hswish_params xnn_init_f32_hswish_params(void)
459{
460 union xnn_f32_hswish_params params;
461 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
462 for (uint32_t i = 0; i < 4; i++) {
463 params.sse.sixth[i] = 0x1.555556p-3f;
464 params.sse.half[i] = 0.5f;
465 params.sse.one[i] = 1.0f;
466 }
467 #else
468 params.scalar.sixth = 0x1.555556p-3f;
469 params.scalar.half = 0.5f;
470 params.scalar.one = 1.0f;
471 #endif
472 return params;
473}
474
475static inline union xnn_f32_hswish_params xnn_init_scalar_f32_hswish_params(void)
476{
477 union xnn_f32_hswish_params params;
478 params.scalar.sixth = 0x1.555556p-3f;
479 params.scalar.half = 0.5f;
480 params.scalar.one = 1.0f;
481 return params;
482}
483
Marat Dukhan2b9efd82020-06-08 01:09:31 -0700484static inline union xnn_f32_abs_params xnn_init_f32_abs_params(void)
485{
486 union xnn_f32_abs_params params = { 0 };
487 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
488 for (uint32_t i = 0; i < 4; i++) {
489 params.sse.nonsign_mask[i] = math_nonsign_mask_f32();
490 }
491 #endif
492 return params;
493}
494
495static inline union xnn_f32_abs_params xnn_init_scalar_f32_abs_params(void)
496{
497 union xnn_f32_abs_params params = { 0 };
498 return params;
499}
500
501static inline union xnn_f32_neg_params xnn_init_f32_neg_params(void)
502{
503 union xnn_f32_neg_params params = { 0 };
504 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
505 for (uint32_t i = 0; i < 4; i++) {
506 params.sse.sign_mask[i] = -0.0f;
507 }
508 #endif
509 return params;
510}
511
512static inline union xnn_f32_neg_params xnn_init_scalar_f32_neg_params(void)
513{
514 union xnn_f32_neg_params params = { 0 };
515 return params;
516}
517
Marat Dukhaneecf8fd2020-06-09 08:59:37 -0700518static inline union xnn_f32_rnd_params xnn_init_f32_rnd_params(void)
519{
520 union xnn_f32_rnd_params params = { 0 };
521 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
522 for (uint32_t i = 0; i < 4; i++) {
523 params.sse2.sign_mask[i] = -0.0f;
524 }
525 for (uint32_t i = 0; i < 4; i++) {
526 params.sse2.one[i] = 1.0f;
527 }
528 #endif
529 return params;
530}
531
532static inline union xnn_f32_rnd_params xnn_init_scalar_f32_rnd_params(void)
533{
534 union xnn_f32_rnd_params params = { 0 };
535 return params;
536}
537
Marat Dukhan8cc7efe2020-06-10 16:24:27 -0700538static inline union xnn_f32_lrelu_params xnn_init_f32_lrelu_params(float slope)
539{
540 union xnn_f32_lrelu_params params;
541 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
542 for (uint32_t i = 0; i < 4; i++) {
543 params.sse.slope[i] = slope;
544 }
545 #else
546 params.scalar.slope = slope;
547 #endif
548 return params;
549}
550
551static inline union xnn_f32_lrelu_params xnn_init_scalar_f32_lrelu_params(float slope)
552{
553 union xnn_f32_lrelu_params params;
554 params.scalar.slope = slope;
555 return params;
556}
557
Marat Dukhan1f29b802020-05-15 23:46:39 -0700558static inline union xnn_f32_chw_params xnn_init_f32_chw_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700559 uint32_t width,
560 float output_min,
561 float output_max)
562{
Marat Dukhan1f29b802020-05-15 23:46:39 -0700563 union xnn_f32_chw_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700564 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
565 for (uint32_t i = 0; i < 4; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700566 params.sse.min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700567 params.sse.max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700568 }
569
570 const uint32_t w4 = (width - 1) & 3;
571 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
572 params.sse.mask[1] = -(uint32_t) (w4 >= 1);
573 params.sse.mask[2] = -(uint32_t) (w4 >= 2);
574 params.sse.mask[3] = -(uint32_t) (w4 >= 3);
575
576 const uint32_t w8 = (width - 1) & 7;
577 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
578 params.sse.mask_even[1] = -(uint32_t) (w8 >= 2);
579 params.sse.mask_even[2] = -(uint32_t) (w8 >= 4);
580 params.sse.mask_even[3] = -(uint32_t) (w8 >= 6);
581 params.sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
582 params.sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
583 params.sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
584 params.sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
585 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700586 params.neon.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700587 params.neon.max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700588
589 const uint32_t w4 = (width - 1) & 3;
590 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
591 params.neon.mask[1] = -(uint32_t) (w4 >= 1);
592 params.neon.mask[2] = -(uint32_t) (w4 >= 2);
593 params.neon.mask[3] = -(uint32_t) (w4 >= 3);
594
595 const uint32_t w8 = (width - 1) & 7;
596 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
597 params.neon.mask_even[1] = -(uint32_t) (w8 >= 2);
598 params.neon.mask_even[2] = -(uint32_t) (w8 >= 4);
599 params.neon.mask_even[3] = -(uint32_t) (w8 >= 6);
600 params.neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
601 params.neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
602 params.neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
603 params.neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
604 #else
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700605 params.scalar.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700606 params.scalar.max = output_max;
Erich Elsene6214af2020-06-10 22:17:22 -0700607
608 const uint32_t w4 = (width - 1) & 3;
609 params.scalar.mask[0] = INT32_C(0xFFFFFFFF);
610 params.scalar.mask[1] = -(int32_t) (w4 >= 1);
611 params.scalar.mask[2] = -(int32_t) (w4 >= 2);
612 params.scalar.mask[3] = -(int32_t) (w4 >= 3);
Erich Elsenfd7a6e32020-06-11 12:04:44 -0700613
614 const uint32_t w8 = (width - 1) & 7;
615 params.scalar.mask_even[0] = INT32_C(0xFFFFFFFF);
616 params.scalar.mask_even[1] = -(int32_t) (w8 >= 2);
617 params.scalar.mask_even[2] = -(int32_t) (w8 >= 4);
618 params.scalar.mask_even[3] = -(int32_t) (w8 >= 6);
619 params.scalar.mask_odd[0] = -(int32_t) (w8 >= 1);
620 params.scalar.mask_odd[1] = -(int32_t) (w8 >= 3);
621 params.scalar.mask_odd[2] = -(int32_t) (w8 >= 5);
622 params.scalar.mask_odd[3] = -(int32_t) (w8 >= 7);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700623 #endif
624 return params;
625}
626
Marat Dukhan1f29b802020-05-15 23:46:39 -0700627static inline void xnn_update_f32_chw_params(
628 union xnn_f32_chw_params* params,
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700629 uint32_t width)
630{
631 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
632 const uint32_t w4 = (width - 1) & 3;
633 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
634 params->sse.mask[1] = -(uint32_t) (w4 >= 1);
635 params->sse.mask[2] = -(uint32_t) (w4 >= 2);
636 params->sse.mask[3] = -(uint32_t) (w4 >= 3);
637
638 const uint32_t w8 = (width - 1) & 7;
639 params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
640 params->sse.mask_even[1] = -(uint32_t) (w8 >= 2);
641 params->sse.mask_even[2] = -(uint32_t) (w8 >= 4);
642 params->sse.mask_even[3] = -(uint32_t) (w8 >= 6);
643 params->sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
644 params->sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
645 params->sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
646 params->sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
647 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
648 const uint32_t w4 = (width - 1) & 3;
649 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
650 params->neon.mask[1] = -(uint32_t) (w4 >= 1);
651 params->neon.mask[2] = -(uint32_t) (w4 >= 2);
652 params->neon.mask[3] = -(uint32_t) (w4 >= 3);
653
654 const uint32_t w8 = (width - 1) & 7;
655 params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
656 params->neon.mask_even[1] = -(uint32_t) (w8 >= 2);
657 params->neon.mask_even[2] = -(uint32_t) (w8 >= 4);
658 params->neon.mask_even[3] = -(uint32_t) (w8 >= 6);
659 params->neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
660 params->neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
661 params->neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
662 params->neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
663 #endif
664}
665
Marat Dukhan1f29b802020-05-15 23:46:39 -0700666static inline union xnn_f32_chw_params xnn_init_scalar_f32_chw_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700667 uint32_t width,
668 float output_min,
669 float output_max)
670{
Marat Dukhan1f29b802020-05-15 23:46:39 -0700671 union xnn_f32_chw_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700672 params.scalar.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700673 params.scalar.max = output_max;
Erich Elsene6214af2020-06-10 22:17:22 -0700674
675 const uint32_t w4 = (width - 1) & 3;
676 params.scalar.mask[0] = INT32_C(0xFFFFFFFF);
677 params.scalar.mask[1] = -(int32_t) (w4 >= 1);
678 params.scalar.mask[2] = -(int32_t) (w4 >= 2);
679 params.scalar.mask[3] = -(int32_t) (w4 >= 3);
Erich Elsenfd7a6e32020-06-11 12:04:44 -0700680
681 const uint32_t w8 = (width - 1) & 7;
682 params.scalar.mask_even[0] = INT32_C(0xFFFFFFFF);
683 params.scalar.mask_even[1] = -(int32_t) (w8 >= 2);
684 params.scalar.mask_even[2] = -(int32_t) (w8 >= 4);
685 params.scalar.mask_even[3] = -(int32_t) (w8 >= 6);
686 params.scalar.mask_odd[0] = -(int32_t) (w8 >= 1);
687 params.scalar.mask_odd[1] = -(int32_t) (w8 >= 3);
688 params.scalar.mask_odd[2] = -(int32_t) (w8 >= 5);
689 params.scalar.mask_odd[3] = -(int32_t) (w8 >= 7);
690
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700691 return params;
692}
693
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700694static inline union xnn_u8_minmax_params xnn_init_u8_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700695 uint8_t output_min,
696 uint8_t output_max)
697{
698 assert(output_min < output_max);
699
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700700 union xnn_u8_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700701 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
702 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700703 params.sse2.min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700704 params.sse2.max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700705 }
706 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700707 params.neon.min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700708 params.neon.max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700709 #else
710 params.scalar.min = (int32_t) (uint32_t) output_min;
711 params.scalar.max = (int32_t) (uint32_t) output_max;
712 #endif
713 return params;
714}
715
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700716static inline union xnn_u8_minmax_params xnn_init_scalar_u8_minmax_params(
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700717 uint8_t output_min,
718 uint8_t output_max)
719{
720 assert(output_min < output_max);
721
Marat Dukhaneb09a6b2020-04-08 17:34:32 -0700722 union xnn_u8_minmax_params params;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700723 params.scalar.min = (int32_t) (uint32_t) output_min;
724 params.scalar.max = (int32_t) (uint32_t) output_max;
725 return params;
726}
727
728static inline union xnn_q8_add_params xnn_init_q8_add_params(
729 uint8_t a_zero_point,
730 uint8_t b_zero_point,
731 uint8_t output_zero_point,
732 float a_output_scale,
733 float b_output_scale,
734 uint8_t output_min,
735 uint8_t output_max)
736{
737 assert(a_output_scale >= 0x1.0p-14f);
738 assert(b_output_scale >= 0x1.0p-14f);
739 assert(a_output_scale < 0x1.0p+8f);
740 assert(b_output_scale < 0x1.0p+8f);
741
742 // Compute requantization parameters.
743 const float max_output_scale = a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
744 assert(max_output_scale >= 0x1.0p-14f);
745 assert(max_output_scale < 0x1.0p+8f);
746 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
747 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
748 // Shift is in [13, 31] range.
749 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
750 assert(shift < 32);
751 assert(shift >= 13);
752
753 const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
754
755 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
Marat Dukhanef3e7dc2020-04-13 01:19:56 -0700756 const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(a_output_scale * scale_multiplier);
757 const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(b_output_scale * scale_multiplier);
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700758 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
759 assert(a_multiplier < UINT32_C(0x00400000));
760 assert(b_multiplier < UINT32_C(0x00400000));
761
762 union xnn_q8_add_params params;
763 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
764 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
765 const uint32_t remainder_threshold = remainder_mask >> 1;
766 const int32_t zero_point_product =
767 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
768 for (uint32_t i = 0; i < 4; i++) {
769 params.sse2.zero_point_product[i] = zero_point_product;
770 }
771 for (uint32_t i = 0; i < 8; i++) {
772 params.sse2.y_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
773 }
774 for (uint32_t i = 0; i < 8; i++) {
775 params.sse2.a_multiplier_lo[i] = (uint16_t) (uint32_t) a_multiplier;
776 params.sse2.a_multiplier_hi[i] = (uint16_t) ((uint32_t) a_multiplier >> 16);
777 params.sse2.b_multiplier_lo[i] = (uint16_t) (uint32_t) b_multiplier;
778 params.sse2.b_multiplier_hi[i] = (uint16_t) ((uint32_t) b_multiplier >> 16);
779 }
780 params.sse2.a_multiplier = a_multiplier;
781 params.sse2.b_multiplier = b_multiplier;
782 for (uint32_t i = 0; i < 4; i++) {
783 params.sse2.remainder_mask[i] = remainder_mask;
784 params.sse2.remainder_threshold[i] = remainder_threshold;
785 }
786 params.sse2.shift = shift;
787 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700788 params.sse2.y_min[i] = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700789 params.sse2.y_max[i] = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700790 }
791 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
792 params.neon.a_zero_point = a_zero_point;
793 params.neon.b_zero_point = b_zero_point;
794 params.neon.y_zero_point = (int16_t) (uint16_t) output_zero_point;
795 params.neon.a_multiplier = (int32_t) a_multiplier;
796 params.neon.b_multiplier = (int32_t) b_multiplier;
797 params.neon.right_shift = (int32_t) -shift;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700798 params.neon.y_min = output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700799 params.neon.y_max = output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700800 #else
801 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
802 const uint32_t remainder_threshold = remainder_mask >> 1;
803 params.scalar.zero_point_product =
804 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
805 params.scalar.a_multiplier = a_multiplier;
806 params.scalar.b_multiplier = b_multiplier;
807 params.scalar.remainder_mask = (int32_t) remainder_mask;
808 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
809 params.scalar.shift = shift;
810 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700811 params.scalar.y_min = (int32_t) (uint32_t) output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700812 params.scalar.y_max = (int32_t) (uint32_t) output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700813 #endif
814 return params;
815}
816
817static inline union xnn_q8_add_params xnn_init_scalar_q8_add_params(
818 uint8_t a_zero_point,
819 uint8_t b_zero_point,
820 uint8_t output_zero_point,
821 float a_output_scale,
822 float b_output_scale,
823 uint8_t output_min,
824 uint8_t output_max)
825{
826 assert(a_output_scale >= 0x1.0p-10f);
827 assert(b_output_scale >= 0x1.0p-10f);
828 assert(a_output_scale < 0x1.0p+8f);
829 assert(b_output_scale < 0x1.0p+8f);
830
831 // Compute requantization parameters.
832 const float max_output_scale = a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
833 assert(max_output_scale >= 0x1.0p-10f);
834 assert(max_output_scale < 0x1.0p+8f);
835 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
836 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
837 // Shift is in [13, 31] range.
838 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
839 assert(shift < 32);
840 assert(shift >= 13);
841
842 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
Marat Dukhanef3e7dc2020-04-13 01:19:56 -0700843 const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
844 const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700845 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
846 assert(a_multiplier < UINT32_C(0x00400000));
847 assert(b_multiplier < UINT32_C(0x00400000));
848
849 union xnn_q8_add_params params;
850 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
851 const uint32_t remainder_threshold = remainder_mask >> 1;
852 params.scalar.zero_point_product =
853 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
854 params.scalar.a_multiplier = a_multiplier;
855 params.scalar.b_multiplier = b_multiplier;
856 params.scalar.remainder_mask = (int32_t) remainder_mask;
857 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
858 params.scalar.shift = shift;
859 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700860 params.scalar.y_min = (int32_t) (uint32_t) output_min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700861 params.scalar.y_max = (int32_t) (uint32_t) output_max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700862 return params;
863}
864
865static inline union xnn_q31_requantization_params xnn_init_scalar_requantization_params(
866 float scale,
867 uint8_t zero_point,
868 uint8_t min,
869 uint8_t max)
870{
871 // Compute requantization parameters.
872 assert(scale < 1.0f);
873 assert(scale >= 0x1.0p-32f);
874 const uint32_t scale_bits = fp32_to_bits(scale);
875
876 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
877 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
878 assert(multiplier >= INT32_C(0x40000000));
879 assert(multiplier <= INT32_C(0x7FFFFF80));
880
881 // Shift is in [0, 31] range.
882 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
883 assert(shift >= 0);
884 assert(shift < 32);
885
886 union xnn_q31_requantization_params params;
887 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
888 const uint32_t remainder_threshold = remainder_mask >> 1;
889 params.scalar.multiplier = multiplier;
890 params.scalar.remainder_mask = (int32_t) remainder_mask;
891 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
892 params.scalar.shift = (uint32_t) shift;
893 params.scalar.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
894 params.scalar.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
895 params.scalar.zero_point = (int32_t) (uint32_t) zero_point;
896 return params;
897}
898
899static inline union xnn_q31_requantization_params xnn_init_requantization_params(
900 float scale,
901 uint8_t zero_point,
902 uint8_t min,
903 uint8_t max)
904{
905 // Compute requantization parameters.
906 const uint32_t scale_bits = fp32_to_bits(scale);
907
908 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
909 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
910 assert(multiplier >= INT32_C(0x40000000));
911 assert(multiplier <= INT32_C(0x7FFFFF80));
912
913 // Shift is in [0, 31] range.
914 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
915 assert(shift >= 0);
916 assert(shift < 32);
917
918 union xnn_q31_requantization_params params;
919 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
920 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
921 const uint32_t remainder_threshold = remainder_mask >> 1;
922 params.sse2.multiplier[0] = multiplier;
923 params.sse2.multiplier[1] = multiplier;
924 params.sse2.multiplier[2] = multiplier;
925 params.sse2.multiplier[3] = multiplier;
926 params.sse2.rounding[0] = UINT64_C(0x40000000);
927 params.sse2.rounding[1] = UINT64_C(0x40000000);
928 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
929 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
930 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
931 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
932 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
933 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
934 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
935 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
936 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
937 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
938 for (uint32_t i = 0; i < 8; i++) {
939 params.sse2.zero_point[i] = (int16_t) (uint16_t) zero_point;
940 }
941 for (uint32_t i = 0; i < 16; i++) {
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700942 params.sse2.min[i] = min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700943 params.sse2.max[i] = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700944 }
945 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
946 params.neon.multiplier = multiplier;
947 params.neon.right_shift = -shift;
948 params.neon.zero_point = (int16_t) (uint16_t) zero_point;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700949 params.neon.min = min;
Marat Dukhana51cf482020-04-08 16:16:19 -0700950 params.neon.max = max;
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700951 #else
952 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
953 const uint32_t remainder_threshold = remainder_mask >> 1;
954 params.scalar.multiplier = multiplier;
955 params.scalar.remainder_mask = (int32_t) remainder_mask;
956 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
957 params.scalar.shift = (uint32_t) shift;
958 params.scalar.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
959 params.scalar.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
960 params.scalar.zero_point = (int32_t) (uint32_t) zero_point;
961 #endif
962 return params;
963}