Fuse rounding term into bias in QS8 & QU8 VADD[C] microkernels
PiperOrigin-RevId: 395397695
diff --git a/src/qs8-vaddc/avx512skx-mul32-ld128.c.in b/src/qs8-vaddc/avx512skx-mul32-ld128.c.in
index 6095515..0726468 100644
--- a/src/qs8-vaddc/avx512skx-mul32-ld128.c.in
+++ b/src/qs8-vaddc/avx512skx-mul32-ld128.c.in
@@ -31,7 +31,6 @@
const union xnn_${DATATYPE.lower()}_addsub_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
{
const __m512i va_multiplier = _mm512_load_si512(params->avx512.a_multiplier);
- const __m512i vrounding = _mm512_load_si512(params->avx512.rounding);
const __m128i vshift = _mm_loadu_si32(params->avx512.shift);
$if BATCH_TILE > 16:
const __m512i voutput_zero_point = _mm512_load_si512(params->avx512.output_zero_point);
@@ -55,7 +54,7 @@
__m512i vacc${ABC[N:N+16]} = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va${ABC[N:N+16]}, va_multiplier));
$for N in range(0, BATCH_TILE, 16):
- vacc${ABC[N:N+16]} = _mm512_sra_epi32(_mm512_add_epi32(vacc${ABC[N:N+16]}, vrounding), vshift);
+ vacc${ABC[N:N+16]} = _mm512_sra_epi32(vacc${ABC[N:N+16]}, vshift);
$for N in range(0, BATCH_TILE, 32):
$if N + 16 < BATCH_TILE:
@@ -109,7 +108,7 @@
__m512i vacc${ABC[0:16]} = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va${ABC[0:16]}, va_multiplier));
- vacc${ABC[0:16]} = _mm512_sra_epi32(_mm512_add_epi32(vacc${ABC[0:16]}, vrounding), vshift);
+ vacc${ABC[0:16]} = _mm512_sra_epi32(vacc${ABC[0:16]}, vshift);
$if BATCH_TILE > 16:
__m256i vout${ABC[0:4]}${ABC[8:12]}${ABC[4:8]}${ABC[12:16]} = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc${ABC[0:16]}), _mm512_extracti32x8_epi32(vacc${ABC[0:16]}, 1)), _mm512_castsi512_si256(voutput_zero_point));