Refactor pre-SSE4 versions of QS8/QC8 GEMM/IGEMM microkernels
- Replace sign-extension from 8 to 16 bits with a more efficient sequence
- Replace casts to uintptr_t with casts to typed pointer types where possible
PiperOrigin-RevId: 382654408
diff --git a/src/qc8-igemm/gen/3x4c8-minmax-fp32-xop-ld64.c b/src/qc8-igemm/gen/3x4c8-minmax-fp32-xop-ld64.c
index 0433e8f..00c7c77 100644
--- a/src/qc8-igemm/gen/3x4c8-minmax-fp32-xop-ld64.c
+++ b/src/qc8-igemm/gen/3x4c8-minmax-fp32-xop-ld64.c
@@ -69,7 +69,7 @@
__m128i vacc2x1 = vacc0x1;
__m128i vacc2x2 = vacc0x2;
__m128i vacc2x3 = vacc0x3;
- w = (const void*) ((uintptr_t) w + 4 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 4);
size_t p = ks;
do {
@@ -105,26 +105,26 @@
vacc0x0 = _mm_maddd_epi16(vxa0, vxb0, vacc0x0);
vacc1x0 = _mm_maddd_epi16(vxa1, vxb0, vacc1x0);
vacc2x0 = _mm_maddd_epi16(vxa2, vxb0, vacc2x0);
- const __m128i vb1 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 8));
+ const __m128i vb1 = _mm_loadl_epi64((const __m128i*) ((const int8_t*) w + 8));
const __m128i vxb1 = _mm_cvtepi8_epi16(vb1);
vacc0x1 = _mm_maddd_epi16(vxa0, vxb1, vacc0x1);
vacc1x1 = _mm_maddd_epi16(vxa1, vxb1, vacc1x1);
vacc2x1 = _mm_maddd_epi16(vxa2, vxb1, vacc2x1);
- const __m128i vb2 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16));
+ const __m128i vb2 = _mm_loadl_epi64((const __m128i*) ((const int8_t*) w + 16));
const __m128i vxb2 = _mm_cvtepi8_epi16(vb2);
vacc0x2 = _mm_maddd_epi16(vxa0, vxb2, vacc0x2);
vacc1x2 = _mm_maddd_epi16(vxa1, vxb2, vacc1x2);
vacc2x2 = _mm_maddd_epi16(vxa2, vxb2, vacc2x2);
- const __m128i vb3 = _mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 24));
+ const __m128i vb3 = _mm_loadl_epi64((const __m128i*) ((const int8_t*) w + 24));
const __m128i vxb3 = _mm_cvtepi8_epi16(vb3);
vacc0x3 = _mm_maddd_epi16(vxa0, vxb3, vacc0x3);
vacc1x3 = _mm_maddd_epi16(vxa1, vxb3, vacc1x3);
vacc2x3 = _mm_maddd_epi16(vxa2, vxb3, vacc2x3);
- w = (const void*) ((uintptr_t) w + 32);
+ w = (const void*) ((const int8_t*) w + 32);
k += 8 * sizeof(int8_t);
}
p -= 3 * sizeof(void*);
@@ -146,7 +146,7 @@
__m128 vscaled2x0123 = _mm_cvtepi32_ps(vacc2x0123);
const __m128 vscale0123 = _mm_load_ps((const float*) w);
- w = (const void*) ((uintptr_t) w + 4 * sizeof(float));
+ w = (const void*) ((const float*) w + 4);
vscaled0x0123 = _mm_mul_ps(vscaled0x0123, vscale0123);
vscaled1x0123 = _mm_mul_ps(vscaled1x0123, vscale0123);
vscaled2x0123 = _mm_mul_ps(vscaled2x0123, vscale0123);