blob: bf3e10003f8149c9af833977f16950b5e2ac5df0 [file] [log] [blame]
XNNPACK Teamb455b122019-09-27 18:10:33 -07001// Copyright (c) Facebook, Inc. and its affiliates.
2// All rights reserved.
3//
4// Copyright 2019 Google LLC
5//
6// This source code is licensed under the BSD-style license found in the
7// LICENSE file in the root directory of this source tree.
8
9#pragma once
10
11#if defined(__cplusplus) && (__cplusplus >= 201103L)
12 #include <cstdint>
13 #include <cstddef>
14 #include <cassert>
15 #include <cmath>
16#else
17 #include <stdint.h>
18 #include <stddef.h>
19 #include <assert.h>
20 #include <math.h>
21#endif
22
23#include <fp16.h>
24
25#include <xnnpack/params.h>
26#include <xnnpack/scalar-utils.h>
27
28
29static inline union xnn_q8_gemm_params xnn_compute_scalar_q8_gemm_params(
30 uint8_t input_zero_point,
31 uint8_t kernel_zero_point,
32 float scale,
33 uint8_t output_zero_point,
34 uint8_t output_min,
35 uint8_t output_max)
36{
37 /* Compute requantization parameters */
38 const uint32_t scale_bits = fp32_to_bits(scale);
39
40 /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
41 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
42 assert(multiplier >= INT32_C(0x40000000));
43 assert(multiplier <= INT32_C(0x7FFFFF80));
44
45 /* Shift is in [0, 31] range */
46 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
47 assert(shift >= 0);
48 assert(shift < 32);
49
50 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
51 const uint32_t remainder_threshold = remainder_mask >> 1;
52
53 union xnn_q8_gemm_params params;
54 params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point;
55 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
56 params.scalar.multiplier = multiplier;
57 params.scalar.remainder_mask = (int32_t) remainder_mask;
58 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
59 params.scalar.shift = (uint32_t) shift;
60 params.scalar.output_min_less_zero_point =
61 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
62 params.scalar.output_max_less_zero_point =
63 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
64 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
65 return params;
66}
67
68static inline union xnn_q8_gemm_params xnn_compute_q8_gemm_params(
69 uint8_t input_zero_point,
70 uint8_t kernel_zero_point,
71 float scale,
72 uint8_t output_zero_point,
73 uint8_t output_min,
74 uint8_t output_max)
75{
76 /* Compute requantization parameters */
77 const uint32_t scale_bits = fp32_to_bits(scale);
78
79 /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
80 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
81 assert(multiplier >= INT32_C(0x40000000));
82 assert(multiplier <= INT32_C(0x7FFFFF80));
83
84 /* Shift is in [0, 31] range */
85 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
86 assert(shift >= 0);
87 assert(shift < 32);
88
89 union xnn_q8_gemm_params params;
90 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
91 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
92 const uint32_t remainder_threshold = remainder_mask >> 1;
93 for (uint32_t i = 0; i < 8; i++) {
94 params.sse2.input_zero_point[i] = (int16_t) (uint16_t) input_zero_point;
95 params.sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
96 }
97 params.sse2.multiplier[0] = multiplier;
98 params.sse2.multiplier[1] = multiplier;
99 params.sse2.multiplier[2] = multiplier;
100 params.sse2.multiplier[3] = multiplier;
101 params.sse2.rounding[0] = UINT64_C(0x40000000);
102 params.sse2.rounding[1] = UINT64_C(0x40000000);
103 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
104 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
105 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
106 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
107 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
108 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
109 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
110 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
111 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
112 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
113 for (uint32_t i = 0; i < 8; i++) {
114 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
115 }
116 for (uint32_t i = 0; i < 16; i++) {
117 params.sse2.output_max[i] = output_max;
118 params.sse2.output_min[i] = output_min;
119 }
120 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
121 params.neon.input_zero_point = (int16_t) (uint16_t) input_zero_point;
122 params.neon.kernel_zero_point = (int16_t) (uint16_t) kernel_zero_point;
123 params.neon.multiplier = multiplier;
124 params.neon.right_shift = -shift;
125 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
126 params.neon.output_max = output_max;
127 params.neon.output_min = output_min;
128 #else
129 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
130 const uint32_t remainder_threshold = remainder_mask >> 1;
131 params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point;
132 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
133 params.scalar.multiplier = multiplier;
134 params.scalar.remainder_mask = (int32_t) remainder_mask;
135 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
136 params.scalar.shift = (uint32_t) shift;
137 params.scalar.output_min_less_zero_point =
138 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
139 params.scalar.output_max_less_zero_point =
140 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
141 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
142 #endif
143 return params;
144}
145
146static inline union xnn_q8_avgpool_params xnn_compute_q8_avgpool_params(
147 int32_t bias,
148 float scale,
149 uint8_t output_zero_point,
150 uint8_t output_min,
151 uint8_t output_max)
152{
153 /* Compute requantization parameters */
154 assert(scale >= 0x1.0p-32f);
155 assert(scale < 256.0f);
156 const uint32_t scale_bits = fp32_to_bits(scale);
157
158 /* Multiplier is in [0x00800000, 0x00FFFFFF] range */
159 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
160 assert(multiplier >= INT32_C(0x00800000));
161 assert(multiplier <= INT32_C(0x00FFFFFF));
162
163 /* Shift is in [16, 55] range */
164 const int32_t shift = 127 + 23 - (scale_bits >> 23);
165 assert(shift >= 16);
166 assert(shift < 64);
167
168 union xnn_q8_avgpool_params params;
169 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
170 const uint32_t right_shift = (uint32_t) shift;
171 const uint64_t rounding = UINT64_C(1) << (right_shift - 1);
172 params.sse2.bias[0] = bias;
173 params.sse2.bias[1] = bias;
174 params.sse2.bias[2] = bias;
175 params.sse2.bias[3] = bias;
176 params.sse2.multiplier[0] = (uint32_t) multiplier;
177 params.sse2.multiplier[1] = (uint32_t) multiplier;
178 params.sse2.multiplier[2] = (uint32_t) multiplier;
179 params.sse2.multiplier[3] = (uint32_t) multiplier;
180 params.sse2.rounding[0] = rounding;
181 params.sse2.rounding[1] = rounding;
182 params.sse2.right_shift[0] = (uint64_t) right_shift;
183 params.sse2.right_shift[1] = (uint64_t) right_shift;
184 for (uint32_t i = 0; i < 8; i++) {
185 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
186 }
187 for (uint32_t i = 0; i < 16; i++) {
188 params.sse2.output_max[i] = output_max;
189 params.sse2.output_min[i] = output_min;
190 }
191 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
192 params.neon.bias = bias;
193 params.neon.multiplier = multiplier;
194 params.neon.left_shift = (int64_t) -shift;
195 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
196 params.neon.output_max = output_max;
197 params.neon.output_min = output_min;
198 #else
199 const uint32_t right_shift = (uint32_t) shift;
200 const int64_t rounding = INT64_C(1) << (right_shift - 1);
201 params.scalar.bias = bias;
202 params.scalar.multiplier = multiplier;
203 params.scalar.rounding = rounding;
204 params.scalar.right_shift = right_shift;
205 params.scalar.output_min_less_zero_point =
206 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
207 params.scalar.output_max_less_zero_point =
208 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
209 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
210 #endif
211 return params;
212}
213
214static inline union xnn_q8_avgpool_params xnn_compute_scalar_q8_avgpool_params(
215 int32_t bias,
216 float scale,
217 uint8_t output_zero_point,
218 uint8_t output_min,
219 uint8_t output_max)
220{
221 /* Compute requantization parameters */
222 assert(scale >= 0x1.0p-32f);
223 assert(scale < 256.0f);
224 const uint32_t scale_bits = fp32_to_bits(scale);
225
226 /* Multiplier is in [0x00800000, 0x00FFFFFF] range */
227 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
228 assert(multiplier >= INT32_C(0x00800000));
229 assert(multiplier <= INT32_C(0x00FFFFFF));
230
231 /* Shift is in [16, 55] range */
232 const int32_t shift = 127 + 23 - (scale_bits >> 23);
233 assert(shift >= 16);
234 assert(shift < 64);
235
236 union xnn_q8_avgpool_params params;
237 const uint32_t right_shift = (uint32_t) shift;
238 const int64_t rounding = INT64_C(1) << (right_shift - 1);
239 params.scalar.bias = bias;
240 params.scalar.rounding = rounding;
241 params.scalar.multiplier = multiplier;
242 params.scalar.right_shift = right_shift;
243 params.scalar.output_min_less_zero_point =
244 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
245 params.scalar.output_max_less_zero_point =
246 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
247 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
248 return params;
249}
250
251static inline void xnn_update_f32_avgpool_params(
252 union xnn_f32_avgpool_params* params,
253 float multiplier)
254{
255 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
256 for (uint32_t i = 0; i < 4; i++) {
257 params->sse2.multiplier[i] = multiplier;
258 }
259 #else
260 params->scalar.multiplier = multiplier;
261 #endif
262}
263
264static inline union xnn_f32_avgpool_params xnn_compute_f32_avgpool_params(
265 float multiplier,
266 float output_min,
267 float output_max)
268{
269 union xnn_f32_avgpool_params params;
270#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
271 for (uint32_t i = 0; i < 4; i++) {
272 params.sse2.multiplier[i] = multiplier;
273 params.sse2.output_min[i] = output_min;
274 params.sse2.output_max[i] = output_max;
275 }
276#else
277 params.scalar.multiplier = multiplier;
278 params.scalar.output_min = output_min;
279 params.scalar.output_max = output_max;
280#endif
281return params;
282}
283
284static inline union xnn_f32_gavgpool_params xnn_compute_f32_gavgpool_params(
285 float multiplier,
286 float output_min,
287 float output_max,
288 uint32_t width)
289{
290 union xnn_f32_gavgpool_params params;
291 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
292 for (uint32_t i = 0; i < 4; i++) {
293 params.sse.multiplier[i] = multiplier;
294 params.sse.output_min[i] = output_min;
295 params.sse.output_max[i] = output_max;
296 }
297 switch (width % 4) {
298 case 0:
299 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
300 params.sse.mask[1] = UINT32_C(0xFFFFFFFF);
301 params.sse.mask[2] = UINT32_C(0xFFFFFFFF);
302 params.sse.mask[3] = UINT32_C(0xFFFFFFFF);
303 break;
304 case 1:
305 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
306 params.sse.mask[1] = 0;
307 params.sse.mask[2] = 0;
308 params.sse.mask[3] = 0;
309 break;
310 case 2:
311 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
312 params.sse.mask[1] = UINT32_C(0xFFFFFFFF);
313 params.sse.mask[2] = 0;
314 params.sse.mask[3] = 0;
315 break;
316 case 3:
317 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
318 params.sse.mask[1] = UINT32_C(0xFFFFFFFF);
319 params.sse.mask[2] = UINT32_C(0xFFFFFFFF);
320 params.sse.mask[3] = 0;
321 break;
322 }
323#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
324 switch (width % 4) {
325 case 0:
326 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
327 params.neon.mask[1] = UINT32_C(0xFFFFFFFF);
328 params.neon.mask[2] = UINT32_C(0xFFFFFFFF);
329 params.neon.mask[3] = UINT32_C(0xFFFFFFFF);
330 break;
331 case 1:
332 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
333 params.neon.mask[1] = 0;
334 params.neon.mask[2] = 0;
335 params.neon.mask[3] = 0;
336 break;
337 case 2:
338 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
339 params.neon.mask[1] = UINT32_C(0xFFFFFFFF);
340 params.neon.mask[2] = 0;
341 params.neon.mask[3] = 0;
342 break;
343 case 3:
344 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
345 params.neon.mask[1] = UINT32_C(0xFFFFFFFF);
346 params.neon.mask[2] = UINT32_C(0xFFFFFFFF);
347 params.neon.mask[3] = 0;
348 break;
349 }
350 params.neon.multiplier = multiplier;
351 params.neon.output_min = output_min;
352 params.neon.output_max = output_max;
353 #else
354 params.scalar.multiplier = multiplier;
355 params.scalar.output_min = output_min;
356 params.scalar.output_max = output_max;
357 #endif
358 return params;
359}
360
361static inline void xnn_update_f32_gavgpool_params(
362 union xnn_f32_gavgpool_params* params,
363 float multiplier,
364 uint32_t width)
365{
366 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
367 for (uint32_t i = 0; i < 4; i++) {
368 params->sse.multiplier[i] = multiplier;
369 }
370 switch (width % 4) {
371 case 0:
372 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
373 params->sse.mask[1] = UINT32_C(0xFFFFFFFF);
374 params->sse.mask[2] = UINT32_C(0xFFFFFFFF);
375 params->sse.mask[3] = UINT32_C(0xFFFFFFFF);
376 break;
377 case 1:
378 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
379 params->sse.mask[1] = 0;
380 params->sse.mask[2] = 0;
381 params->sse.mask[3] = 0;
382 break;
383 case 2:
384 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
385 params->sse.mask[1] = UINT32_C(0xFFFFFFFF);
386 params->sse.mask[2] = 0;
387 params->sse.mask[3] = 0;
388 break;
389 case 3:
390 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
391 params->sse.mask[1] = UINT32_C(0xFFFFFFFF);
392 params->sse.mask[2] = UINT32_C(0xFFFFFFFF);
393 params->sse.mask[3] = 0;
394 break;
395 }
396 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
397 params->neon.multiplier = multiplier;
398 switch (width % 4) {
399 case 0:
400 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
401 params->neon.mask[1] = UINT32_C(0xFFFFFFFF);
402 params->neon.mask[2] = UINT32_C(0xFFFFFFFF);
403 params->neon.mask[3] = UINT32_C(0xFFFFFFFF);
404 break;
405 case 1:
406 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
407 params->neon.mask[1] = 0;
408 params->neon.mask[2] = 0;
409 params->neon.mask[3] = 0;
410 break;
411 case 2:
412 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
413 params->neon.mask[1] = UINT32_C(0xFFFFFFFF);
414 params->neon.mask[2] = 0;
415 params->neon.mask[3] = 0;
416 break;
417 case 3:
418 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
419 params->neon.mask[1] = UINT32_C(0xFFFFFFFF);
420 params->neon.mask[2] = UINT32_C(0xFFFFFFFF);
421 params->neon.mask[3] = 0;
422 break;
423 }
424 #endif
425}
426
427static inline union xnn_f32_avgpool_params xnn_compute_scalar_f32_avgpool_params(
428 float multiplier,
429 float output_min,
430 float output_max)
431{
432 union xnn_f32_avgpool_params params;
433 params.scalar.multiplier = multiplier;
434 params.scalar.output_min = output_min;
435 params.scalar.output_max = output_max;
436 return params;
437}
438
439static inline union xnn_f32_gavgpool_params xnn_compute_scalar_f32_gavgpool_params(
440 float multiplier,
441 float output_min,
442 float output_max,
443 uint32_t width)
444{
445 union xnn_f32_gavgpool_params params;
446 params.scalar.multiplier = multiplier;
447 params.scalar.output_min = output_min;
448 params.scalar.output_max = output_max;
449 return params;
450}
451
452static inline union xnn_f32_output_params xnn_compute_f32_output_params(
453 float output_min,
454 float output_max)
455{
456 union xnn_f32_output_params params;
457#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
458 for (uint32_t i = 0; i < 4; i++) {
459 params.sse.min[i] = output_min;
460 params.sse.max[i] = output_max;
461 }
462#else
463 params.scalar.min = output_min;
464 params.scalar.max = output_max;
465#endif
466 return params;
467}
468
469static inline union xnn_f32_output_params xnn_compute_scalar_f32_output_params(
470 float output_min,
471 float output_max)
472{
473 union xnn_f32_output_params params;
474 params.scalar.min = output_min;
475 params.scalar.max = output_max;
476 return params;
477}
478
479static inline union xnn_f32_hswish_params xnn_compute_f32_hswish_params(void)
480{
481 union xnn_f32_hswish_params params;
482#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
483 for (uint32_t i = 0; i < 4; i++) {
484 params.sse.sixth[i] = 0x1.555556p-3f;
485 params.sse.half[i] = 0.5f;
486 params.sse.one[i] = 1.0f;
487 }
488#else
489 params.scalar.sixth = 0x1.555556p-3f;
490 params.scalar.half = 0.5f;
491 params.scalar.one = 1.0f;
492#endif
493 return params;
494}
495
496static inline union xnn_f32_hswish_params xnn_compute_scalar_f32_hswish_params(void)
497{
498 union xnn_f32_hswish_params params;
499 params.scalar.sixth = 0x1.555556p-3f;
500 params.scalar.half = 0.5f;
501 params.scalar.one = 1.0f;
502 return params;
503}
504
505static inline union xnn_f32_spchw_params xnn_compute_f32_spchw_params(
506 uint32_t width,
507 float output_min,
508 float output_max)
509{
510 union xnn_f32_spchw_params params;
511#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
512 switch (width % 4) {
513 case 0:
514 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
515 params.sse.mask[1] = UINT32_C(0xFFFFFFFF);
516 params.sse.mask[2] = UINT32_C(0xFFFFFFFF);
517 params.sse.mask[3] = UINT32_C(0xFFFFFFFF);
518 break;
519 case 1:
520 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
521 params.sse.mask[1] = 0;
522 params.sse.mask[2] = 0;
523 params.sse.mask[3] = 0;
524 break;
525 case 2:
526 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
527 params.sse.mask[1] = UINT32_C(0xFFFFFFFF);
528 params.sse.mask[2] = 0;
529 params.sse.mask[3] = 0;
530 break;
531 case 3:
532 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
533 params.sse.mask[1] = UINT32_C(0xFFFFFFFF);
534 params.sse.mask[2] = UINT32_C(0xFFFFFFFF);
535 params.sse.mask[3] = 0;
536 break;
537 }
538 switch (width % 8) {
539 case 0:
540 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
541 params.sse.mask_even[1] = UINT32_C(0xFFFFFFFF);
542 params.sse.mask_even[2] = UINT32_C(0xFFFFFFFF);
543 params.sse.mask_even[3] = UINT32_C(0xFFFFFFFF);
544 params.sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
545 params.sse.mask_odd[1] = UINT32_C(0xFFFFFFFF);
546 params.sse.mask_odd[2] = UINT32_C(0xFFFFFFFF);
547 params.sse.mask_odd[3] = UINT32_C(0xFFFFFFFF);
548 break;
549 case 1:
550 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
551 params.sse.mask_even[1] = 0;
552 params.sse.mask_even[2] = 0;
553 params.sse.mask_even[3] = 0;
554 params.sse.mask_odd[0] = 0;
555 params.sse.mask_odd[1] = 0;
556 params.sse.mask_odd[2] = 0;
557 params.sse.mask_odd[3] = 0;
558 break;
559 case 2:
560 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
561 params.sse.mask_even[1] = 0;
562 params.sse.mask_even[2] = 0;
563 params.sse.mask_even[3] = 0;
564 params.sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
565 params.sse.mask_odd[1] = 0;
566 params.sse.mask_odd[2] = 0;
567 params.sse.mask_odd[3] = 0;
568 break;
569 case 3:
570 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
571 params.sse.mask_even[1] = UINT32_C(0xFFFFFFFF);
572 params.sse.mask_even[2] = 0;
573 params.sse.mask_even[3] = 0;
574 params.sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
575 params.sse.mask_odd[1] = 0;
576 params.sse.mask_odd[2] = 0;
577 params.sse.mask_odd[3] = 0;
578 break;
579 case 4:
580 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
581 params.sse.mask_even[1] = UINT32_C(0xFFFFFFFF);
582 params.sse.mask_even[2] = 0;
583 params.sse.mask_even[3] = 0;
584 params.sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
585 params.sse.mask_odd[1] = UINT32_C(0xFFFFFFFF);
586 params.sse.mask_odd[2] = 0;
587 params.sse.mask_odd[3] = 0;
588 break;
589 case 5:
590 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
591 params.sse.mask_even[1] = UINT32_C(0xFFFFFFFF);
592 params.sse.mask_even[2] = UINT32_C(0xFFFFFFFF);
593 params.sse.mask_even[3] = 0;
594 params.sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
595 params.sse.mask_odd[1] = UINT32_C(0xFFFFFFFF);
596 params.sse.mask_odd[2] = 0;
597 params.sse.mask_odd[3] = 0;
598 break;
599 case 6:
600 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
601 params.sse.mask_even[1] = UINT32_C(0xFFFFFFFF);
602 params.sse.mask_even[2] = UINT32_C(0xFFFFFFFF);
603 params.sse.mask_even[3] = 0;
604 params.sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
605 params.sse.mask_odd[1] = UINT32_C(0xFFFFFFFF);
606 params.sse.mask_odd[2] = UINT32_C(0xFFFFFFFF);
607 params.sse.mask_odd[3] = 0;
608 break;
609 case 7:
610 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
611 params.sse.mask_even[1] = UINT32_C(0xFFFFFFFF);
612 params.sse.mask_even[2] = UINT32_C(0xFFFFFFFF);
613 params.sse.mask_even[3] = UINT32_C(0xFFFFFFFF);
614 params.sse.mask_odd[0] = UINT32_C(0xFFFFFFFF);
615 params.sse.mask_odd[1] = UINT32_C(0xFFFFFFFF);
616 params.sse.mask_odd[2] = UINT32_C(0xFFFFFFFF);
617 params.sse.mask_odd[3] = 0;
618 break;
619 }
620 for (uint32_t i = 0; i < 4; i++) {
621 params.sse.max[i] = output_max;
622 params.sse.min[i] = output_min;
623 }
624#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
625 switch (width % 4) {
626 case 0:
627 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
628 params.neon.mask[1] = UINT32_C(0xFFFFFFFF);
629 params.neon.mask[2] = UINT32_C(0xFFFFFFFF);
630 params.neon.mask[3] = UINT32_C(0xFFFFFFFF);
631 break;
632 case 1:
633 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
634 params.neon.mask[1] = 0;
635 params.neon.mask[2] = 0;
636 params.neon.mask[3] = 0;
637 break;
638 case 2:
639 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
640 params.neon.mask[1] = UINT32_C(0xFFFFFFFF);
641 params.neon.mask[2] = 0;
642 params.neon.mask[3] = 0;
643 break;
644 case 3:
645 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
646 params.neon.mask[1] = UINT32_C(0xFFFFFFFF);
647 params.neon.mask[2] = UINT32_C(0xFFFFFFFF);
648 params.neon.mask[3] = 0;
649 break;
650 }
651 switch (width % 8) {
652 case 0:
653 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
654 params.neon.mask_even[1] = UINT32_C(0xFFFFFFFF);
655 params.neon.mask_even[2] = UINT32_C(0xFFFFFFFF);
656 params.neon.mask_even[3] = UINT32_C(0xFFFFFFFF);
657 params.neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
658 params.neon.mask_odd[1] = UINT32_C(0xFFFFFFFF);
659 params.neon.mask_odd[2] = UINT32_C(0xFFFFFFFF);
660 params.neon.mask_odd[3] = UINT32_C(0xFFFFFFFF);
661 break;
662 case 1:
663 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
664 params.neon.mask_even[1] = 0;
665 params.neon.mask_even[2] = 0;
666 params.neon.mask_even[3] = 0;
667 params.neon.mask_odd[0] = 0;
668 params.neon.mask_odd[1] = 0;
669 params.neon.mask_odd[2] = 0;
670 params.neon.mask_odd[3] = 0;
671 break;
672 case 2:
673 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
674 params.neon.mask_even[1] = 0;
675 params.neon.mask_even[2] = 0;
676 params.neon.mask_even[3] = 0;
677 params.neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
678 params.neon.mask_odd[1] = 0;
679 params.neon.mask_odd[2] = 0;
680 params.neon.mask_odd[3] = 0;
681 break;
682 case 3:
683 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
684 params.neon.mask_even[1] = UINT32_C(0xFFFFFFFF);
685 params.neon.mask_even[2] = 0;
686 params.neon.mask_even[3] = 0;
687 params.neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
688 params.neon.mask_odd[1] = 0;
689 params.neon.mask_odd[2] = 0;
690 params.neon.mask_odd[3] = 0;
691 break;
692 case 4:
693 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
694 params.neon.mask_even[1] = UINT32_C(0xFFFFFFFF);
695 params.neon.mask_even[2] = 0;
696 params.neon.mask_even[3] = 0;
697 params.neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
698 params.neon.mask_odd[1] = UINT32_C(0xFFFFFFFF);
699 params.neon.mask_odd[2] = 0;
700 params.neon.mask_odd[3] = 0;
701 break;
702 case 5:
703 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
704 params.neon.mask_even[1] = UINT32_C(0xFFFFFFFF);
705 params.neon.mask_even[2] = UINT32_C(0xFFFFFFFF);
706 params.neon.mask_even[3] = 0;
707 params.neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
708 params.neon.mask_odd[1] = UINT32_C(0xFFFFFFFF);
709 params.neon.mask_odd[2] = 0;
710 params.neon.mask_odd[3] = 0;
711 break;
712 case 6:
713 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
714 params.neon.mask_even[1] = UINT32_C(0xFFFFFFFF);
715 params.neon.mask_even[2] = UINT32_C(0xFFFFFFFF);
716 params.neon.mask_even[3] = 0;
717 params.neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
718 params.neon.mask_odd[1] = UINT32_C(0xFFFFFFFF);
719 params.neon.mask_odd[2] = UINT32_C(0xFFFFFFFF);
720 params.neon.mask_odd[3] = 0;
721 break;
722 case 7:
723 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
724 params.neon.mask_even[1] = UINT32_C(0xFFFFFFFF);
725 params.neon.mask_even[2] = UINT32_C(0xFFFFFFFF);
726 params.neon.mask_even[3] = UINT32_C(0xFFFFFFFF);
727 params.neon.mask_odd[0] = UINT32_C(0xFFFFFFFF);
728 params.neon.mask_odd[1] = UINT32_C(0xFFFFFFFF);
729 params.neon.mask_odd[2] = UINT32_C(0xFFFFFFFF);
730 params.neon.mask_odd[3] = 0;
731 break;
732 }
733 params.neon.max = output_max;
734 params.neon.min = output_min;
735#else
736 params.scalar.max = output_max;
737 params.scalar.min = output_min;
738#endif
739 return params;
740}
741
/*
 * Recomputes the width-dependent tail masks of an existing f32 SpCHW parameter block
 * on x86/x86-64 (`sse`) and ARM/ARM64 (`neon`); the clamping bounds are untouched:
 *   - mask:      enables r4 lanes, where r4 = width % 4, or 4 if that is 0;
 *   - mask_even: enables ceil(r8 / 2) lanes, where r8 = width % 8, or 8 if that is 0;
 *   - mask_odd:  enables floor(r8 / 2) lanes.
 * On other targets this function writes nothing.
 */
static inline void xnn_update_f32_spchw_params(
  union xnn_f32_spchw_params* params,
  uint32_t width)
{
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  const uint32_t tail4 = (width % 4 != 0) ? width % 4 : 4;
  const uint32_t tail8 = (width % 8 != 0) ? width % 8 : 8;
  const uint32_t even_lanes = (tail8 + 1) / 2;
  const uint32_t odd_lanes = tail8 / 2;
  for (uint32_t lane = 0; lane != 4; ++lane) {
    params->sse.mask[lane] = (lane < tail4) ? UINT32_C(0xFFFFFFFF) : 0;
    params->sse.mask_even[lane] = (lane < even_lanes) ? UINT32_C(0xFFFFFFFF) : 0;
    params->sse.mask_odd[lane] = (lane < odd_lanes) ? UINT32_C(0xFFFFFFFF) : 0;
  }
#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
  const uint32_t tail4 = (width % 4 != 0) ? width % 4 : 4;
  const uint32_t tail8 = (width % 8 != 0) ? width % 8 : 8;
  const uint32_t even_lanes = (tail8 + 1) / 2;
  const uint32_t odd_lanes = tail8 / 2;
  for (uint32_t lane = 0; lane != 4; ++lane) {
    params->neon.mask[lane] = (lane < tail4) ? UINT32_C(0xFFFFFFFF) : 0;
    params->neon.mask_even[lane] = (lane < even_lanes) ? UINT32_C(0xFFFFFFFF) : 0;
    params->neon.mask_odd[lane] = (lane < odd_lanes) ? UINT32_C(0xFFFFFFFF) : 0;
  }
#endif
}
966
967static inline union xnn_f32_spchw_params xnn_compute_scalar_f32_spchw_params(
968 uint32_t width,
969 float output_min,
970 float output_max)
971{
972 union xnn_f32_spchw_params params;
973 params.scalar.max = output_max;
974 params.scalar.min = output_min;
975 return params;
976}
977
978static inline union xnn_u8_output_params xnn_compute_u8_output_params(
979 uint8_t output_min,
980 uint8_t output_max)
981{
982 assert(output_min < output_max);
983
984 union xnn_u8_output_params params;
985 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
986 for (uint32_t i = 0; i < 16; i++) {
987 params.sse2.max[i] = output_max;
988 params.sse2.min[i] = output_min;
989 }
990 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
991 params.neon.max = output_max;
992 params.neon.min = output_min;
993 #else
994 params.scalar.min = (int32_t) (uint32_t) output_min;
995 params.scalar.max = (int32_t) (uint32_t) output_max;
996 #endif
997 return params;
998}
999
1000static inline union xnn_u8_output_params xnn_compute_scalar_u8_output_params(
1001 uint8_t output_min,
1002 uint8_t output_max)
1003{
1004 assert(output_min < output_max);
1005
1006 union xnn_u8_output_params params;
1007 params.scalar.min = (int32_t) (uint32_t) output_min;
1008 params.scalar.max = (int32_t) (uint32_t) output_max;
1009 return params;
1010}
1011
1012static inline union xnn_q8_add_params xnn_compute_q8_add_params(
1013 uint8_t a_zero_point,
1014 uint8_t b_zero_point,
1015 uint8_t output_zero_point,
1016 float a_output_scale,
1017 float b_output_scale,
1018 uint8_t output_min,
1019 uint8_t output_max)
1020{
1021 assert(a_output_scale >= 0x1.0p-14f);
1022 assert(b_output_scale >= 0x1.0p-14f);
1023 assert(a_output_scale < 0x1.0p+8f);
1024 assert(b_output_scale < 0x1.0p+8f);
1025
1026 /* Compute requantization parameters */
1027 const float max_output_scale = a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
1028 assert(max_output_scale >= 0x1.0p-14f);
1029 assert(max_output_scale < 0x1.0p+8f);
1030 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1031 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1032 /* Shift is in [13, 31] range */
1033 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1034 assert(shift < 32);
1035 assert(shift >= 13);
1036
1037 const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
1038
1039 /* Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range */
1040 const uint32_t a_multiplier = (uint32_t) (int32_t) __builtin_lrintf(a_output_scale * scale_multiplier);
1041 const uint32_t b_multiplier = (uint32_t) (int32_t) __builtin_lrintf(b_output_scale * scale_multiplier);
1042 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
1043 assert(a_multiplier < UINT32_C(0x00400000));
1044 assert(b_multiplier < UINT32_C(0x00400000));
1045
1046 union xnn_q8_add_params params;
1047 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
1048 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1049 const uint32_t remainder_threshold = remainder_mask >> 1;
1050 const int32_t zero_point_product =
1051 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1052 for (uint32_t i = 0; i < 4; i++) {
1053 params.sse2.zero_point_product[i] = zero_point_product;
1054 }
1055 for (uint32_t i = 0; i < 8; i++) {
1056 params.sse2.y_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
1057 }
1058 for (uint32_t i = 0; i < 8; i++) {
1059 params.sse2.a_multiplier_lo[i] = (uint16_t) (uint32_t) a_multiplier;
1060 params.sse2.a_multiplier_hi[i] = (uint16_t) ((uint32_t) a_multiplier >> 16);
1061 params.sse2.b_multiplier_lo[i] = (uint16_t) (uint32_t) b_multiplier;
1062 params.sse2.b_multiplier_hi[i] = (uint16_t) ((uint32_t) b_multiplier >> 16);
1063 }
1064 params.sse2.a_multiplier = a_multiplier;
1065 params.sse2.b_multiplier = b_multiplier;
1066 for (uint32_t i = 0; i < 4; i++) {
1067 params.sse2.remainder_mask[i] = remainder_mask;
1068 params.sse2.remainder_threshold[i] = remainder_threshold;
1069 }
1070 params.sse2.shift = shift;
1071 for (uint32_t i = 0; i < 16; i++) {
1072 params.sse2.y_max[i] = output_max;
1073 params.sse2.y_min[i] = output_min;
1074 }
1075 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
1076 params.neon.a_zero_point = a_zero_point;
1077 params.neon.b_zero_point = b_zero_point;
1078 params.neon.y_zero_point = (int16_t) (uint16_t) output_zero_point;
1079 params.neon.a_multiplier = (int32_t) a_multiplier;
1080 params.neon.b_multiplier = (int32_t) b_multiplier;
1081 params.neon.right_shift = (int32_t) -shift;
1082 params.neon.y_max = output_max;
1083 params.neon.y_min = output_min;
1084 #else
1085 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1086 const uint32_t remainder_threshold = remainder_mask >> 1;
1087 params.scalar.zero_point_product =
1088 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1089 params.scalar.a_multiplier = a_multiplier;
1090 params.scalar.b_multiplier = b_multiplier;
1091 params.scalar.remainder_mask = (int32_t) remainder_mask;
1092 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1093 params.scalar.shift = shift;
1094 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
1095 params.scalar.y_max = (int32_t) (uint32_t) output_max;
1096 params.scalar.y_min = (int32_t) (uint32_t) output_min;
1097 #endif
1098 return params;
1099}
1100
1101static inline union xnn_q8_add_params xnn_compute_scalar_q8_add_params(
1102 uint8_t a_zero_point,
1103 uint8_t b_zero_point,
1104 uint8_t output_zero_point,
1105 float a_output_scale,
1106 float b_output_scale,
1107 uint8_t output_min,
1108 uint8_t output_max)
1109{
1110 assert(a_output_scale >= 0x1.0p-10f);
1111 assert(b_output_scale >= 0x1.0p-10f);
1112 assert(a_output_scale < 0x1.0p+8f);
1113 assert(b_output_scale < 0x1.0p+8f);
1114
1115 /* Compute requantization parameters */
1116 const float max_output_scale = a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
1117 assert(max_output_scale >= 0x1.0p-10f);
1118 assert(max_output_scale < 0x1.0p+8f);
1119 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1120 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1121 /* Shift is in [13, 31] range */
1122 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1123 assert(shift < 32);
1124 assert(shift >= 13);
1125
1126 /* Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range */
1127 const uint32_t a_multiplier = (uint32_t) (int32_t) __builtin_lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
1128 const uint32_t b_multiplier = (uint32_t) (int32_t) __builtin_lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
1129 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
1130 assert(a_multiplier < UINT32_C(0x00400000));
1131 assert(b_multiplier < UINT32_C(0x00400000));
1132
1133 union xnn_q8_add_params params;
1134 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1135 const uint32_t remainder_threshold = remainder_mask >> 1;
1136 params.scalar.zero_point_product =
1137 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1138 params.scalar.a_multiplier = a_multiplier;
1139 params.scalar.b_multiplier = b_multiplier;
1140 params.scalar.remainder_mask = (int32_t) remainder_mask;
1141 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1142 params.scalar.shift = shift;
1143 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
1144 params.scalar.y_max = (int32_t) (uint32_t) output_max;
1145 params.scalar.y_min = (int32_t) (uint32_t) output_min;
1146 return params;
1147}
1148
1149static inline union xnn_q31_requantization_params xnn_compute_scalar_requantization_params(
1150 float scale,
1151 uint8_t zero_point,
1152 uint8_t min,
1153 uint8_t max)
1154{
1155 /* Compute requantization parameters */
1156 assert(scale < 1.0f);
1157 assert(scale >= 0x1.0p-32f);
1158 const uint32_t scale_bits = fp32_to_bits(scale);
1159
1160 /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
1161 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
1162 assert(multiplier >= INT32_C(0x40000000));
1163 assert(multiplier <= INT32_C(0x7FFFFF80));
1164
1165 /* Shift is in [0, 31] range */
1166 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
1167 assert(shift >= 0);
1168 assert(shift < 32);
1169
1170 union xnn_q31_requantization_params params;
1171 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1172 const uint32_t remainder_threshold = remainder_mask >> 1;
1173 params.scalar.multiplier = multiplier;
1174 params.scalar.remainder_mask = (int32_t) remainder_mask;
1175 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1176 params.scalar.shift = (uint32_t) shift;
1177 params.scalar.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
1178 params.scalar.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
1179 params.scalar.zero_point = (int32_t) (uint32_t) zero_point;
1180 return params;
1181}
1182
1183static inline union xnn_q31_requantization_params xnn_compute_requantization_params(
1184 float scale,
1185 uint8_t zero_point,
1186 uint8_t min,
1187 uint8_t max)
1188{
1189 /* Compute requantization parameters */
1190 const uint32_t scale_bits = fp32_to_bits(scale);
1191
1192 /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
1193 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
1194 assert(multiplier >= INT32_C(0x40000000));
1195 assert(multiplier <= INT32_C(0x7FFFFF80));
1196
1197 /* Shift is in [0, 31] range */
1198 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
1199 assert(shift >= 0);
1200 assert(shift < 32);
1201
1202 union xnn_q31_requantization_params params;
1203 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
1204 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1205 const uint32_t remainder_threshold = remainder_mask >> 1;
1206 params.sse2.multiplier[0] = multiplier;
1207 params.sse2.multiplier[1] = multiplier;
1208 params.sse2.multiplier[2] = multiplier;
1209 params.sse2.multiplier[3] = multiplier;
1210 params.sse2.rounding[0] = UINT64_C(0x40000000);
1211 params.sse2.rounding[1] = UINT64_C(0x40000000);
1212 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
1213 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
1214 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
1215 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
1216 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
1217 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
1218 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
1219 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
1220 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
1221 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
1222 for (uint32_t i = 0; i < 8; i++) {
1223 params.sse2.zero_point[i] = (int16_t) (uint16_t) zero_point;
1224 }
1225 for (uint32_t i = 0; i < 16; i++) {
1226 params.sse2.max[i] = max;
1227 params.sse2.min[i] = min;
1228 }
1229 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
1230 params.neon.multiplier = multiplier;
1231 params.neon.right_shift = -shift;
1232 params.neon.zero_point = (int16_t) (uint16_t) zero_point;
1233 params.neon.max = max;
1234 params.neon.min = min;
1235 #else
1236 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1237 const uint32_t remainder_threshold = remainder_mask >> 1;
1238 params.scalar.multiplier = multiplier;
1239 params.scalar.remainder_mask = (int32_t) remainder_mask;
1240 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1241 params.scalar.shift = (uint32_t) shift;
1242 params.scalar.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
1243 params.scalar.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
1244 params.scalar.zero_point = (int32_t) (uint32_t) zero_point;
1245 #endif
1246 return params;
1247}
1248
1249static inline uint8_t xnn_q31_requantize(
1250 int32_t n,
1251 union xnn_q31_requantization_params params)
1252{
1253 const int64_t product = (int64_t) n * (int64_t) params.scalar.multiplier;
1254 const int32_t q31product = (int32_t) (uint32_t) ((uint64_t) (product + INT64_C(0x40000000)) >> 31);
1255 const int32_t remainder = (q31product & params.scalar.remainder_mask) - (int32_t) (n < 0);
1256 n = asr_s32(q31product, params.scalar.shift) + (int32_t) (remainder > params.scalar.remainder_threshold);
1257 if (n < params.scalar.min_less_zero_point) {
1258 n = params.scalar.min_less_zero_point;
1259 }
1260 if (n > params.scalar.max_less_zero_point) {
1261 n = params.scalar.max_less_zero_point;
1262 }
1263
1264 return (uint8_t) (n + params.scalar.zero_point);
1265}
1266
1267static inline uint8_t xnn_avgpool_quantize(
1268 int32_t n,
1269 union xnn_q8_avgpool_params params)
1270{
1271 const int64_t product = (int64_t) n * (int64_t) params.scalar.multiplier;
1272 const int64_t adjusted_product = product - (int64_t) (n < 0);
1273
1274 n = (int32_t) asr_s64(adjusted_product + params.scalar.rounding, params.scalar.right_shift);
1275 if (n < params.scalar.output_min_less_zero_point) {
1276 n = params.scalar.output_min_less_zero_point;
1277 }
1278 if (n > params.scalar.output_max_less_zero_point) {
1279 n = params.scalar.output_max_less_zero_point;
1280 }
1281
1282 return (uint8_t) (n + params.scalar.output_zero_point);
1283}
1284
1285static inline uint8_t xnn_add_quantize(
1286 uint8_t a, uint8_t b,
1287 union xnn_q8_add_params params)
1288{
1289 /* Multiply by factors and accumulate products */
1290 int32_t acc = params.scalar.zero_point_product +
1291 (int32_t) ((uint32_t) a * params.scalar.a_multiplier) +
1292 (int32_t) ((uint32_t) b * params.scalar.b_multiplier);
1293
1294 /* Shift right and round */
1295 const int32_t rem = (acc & params.scalar.remainder_mask) - (int32_t) (acc < 0);
1296 acc = asr_s32(acc, params.scalar.shift) + (int32_t) (rem > params.scalar.remainder_threshold);
1297
1298 /* Clamp and add output zero point */
1299 int32_t y = acc + params.scalar.y_zero_point;
1300 if (y >= params.scalar.y_max) {
1301 y = params.scalar.y_max;
1302 }
1303 if (y <= params.scalar.y_min) {
1304 y = params.scalar.y_min;
1305 }
1306 return (uint8_t) y;
1307}