QS8 C16 GEMM microkernel source renamed from mull to mlal
PiperOrigin-RevId: 359842415
diff --git a/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c
index 158f244..079d6d2 100644
--- a/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c
@@ -77,112 +77,97 @@
int32x4_t vacc2x6 = vacc0x6;
int32x4_t vacc2x7 = vacc0x7;
- int k = (int) kc;
- while (k > 0) {
- const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
- const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
- const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
- const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
- const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
- const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+    // Loop over KC in steps of 16 bytes per iteration, with a remainder of up to 15 bytes.
+ size_t k = 0;
+ while (k < kc) {
+ const int8x16_t va0 = vld1q_s8(a0); a0 += 16;
+ const int8x16_t va1 = vld1q_s8(a1); a1 += 16;
+ const int8x16_t va2 = vld1q_s8(a2); a2 += 16;
- const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
- const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+ const int8x16_t vb0 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+ const int8x16_t vb1 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+ const int8x16_t vb2 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+ const int8x16_t vb3 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+ const int8x16_t vb4 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+ const int8x16_t vb5 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+ const int8x16_t vb6 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+ const int8x16_t vb7 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
- int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
- vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+ int16x8_t vprod0x0 = vmull_s8(vget_low_s8(vb0), vget_low_s8(va0));
+ int16x8_t vprod1x0 = vmull_s8(vget_low_s8(vb0), vget_low_s8(va1));
+ int16x8_t vprod2x0 = vmull_s8(vget_low_s8(vb0), vget_low_s8(va2));
+ vprod0x0 = vmlal_s8(vprod0x0, vget_high_s8(vb0), vget_high_s8(va0));
+ vprod1x0 = vmlal_s8(vprod1x0, vget_high_s8(vb0), vget_high_s8(va1));
+ vprod2x0 = vmlal_s8(vprod2x0, vget_high_s8(vb0), vget_high_s8(va2));
vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
- int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
- vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
- vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
- int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
- vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
- vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
- int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
- vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
- vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
- int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
- vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
- vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
- int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
- vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
- vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
- int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
- vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
- vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
- int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
- vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
- vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
- int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
- vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
- int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
- vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
- vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
- int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
- vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
- vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
- int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
- vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
- vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
- int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
- vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
- vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
- int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
- vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
- vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
- int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
- vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
- vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
- int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
- vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
- vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
- int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
- vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
- int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
- vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
+ int16x8_t vprod0x1 = vmull_s8(vget_low_s8(vb1), vget_low_s8(va0));
+ int16x8_t vprod1x1 = vmull_s8(vget_low_s8(vb1), vget_low_s8(va1));
+ int16x8_t vprod2x1 = vmull_s8(vget_low_s8(vb1), vget_low_s8(va2));
+ vprod0x1 = vmlal_s8(vprod0x1, vget_high_s8(vb1), vget_high_s8(va0));
+ vprod1x1 = vmlal_s8(vprod1x1, vget_high_s8(vb1), vget_high_s8(va1));
+ vprod2x1 = vmlal_s8(vprod2x1, vget_high_s8(vb1), vget_high_s8(va2));
+ vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+ vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
- int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
- vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
+ int16x8_t vprod0x2 = vmull_s8(vget_low_s8(vb2), vget_low_s8(va0));
+ int16x8_t vprod1x2 = vmull_s8(vget_low_s8(vb2), vget_low_s8(va1));
+ int16x8_t vprod2x2 = vmull_s8(vget_low_s8(vb2), vget_low_s8(va2));
+ vprod0x2 = vmlal_s8(vprod0x2, vget_high_s8(vb2), vget_high_s8(va0));
+ vprod1x2 = vmlal_s8(vprod1x2, vget_high_s8(vb2), vget_high_s8(va1));
+ vprod2x2 = vmlal_s8(vprod2x2, vget_high_s8(vb2), vget_high_s8(va2));
+ vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+ vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
- int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
- vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
+ int16x8_t vprod0x3 = vmull_s8(vget_low_s8(vb3), vget_low_s8(va0));
+ int16x8_t vprod1x3 = vmull_s8(vget_low_s8(vb3), vget_low_s8(va1));
+ int16x8_t vprod2x3 = vmull_s8(vget_low_s8(vb3), vget_low_s8(va2));
+ vprod0x3 = vmlal_s8(vprod0x3, vget_high_s8(vb3), vget_high_s8(va0));
+ vprod1x3 = vmlal_s8(vprod1x3, vget_high_s8(vb3), vget_high_s8(va1));
+ vprod2x3 = vmlal_s8(vprod2x3, vget_high_s8(vb3), vget_high_s8(va2));
+ vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+ vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
- int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
- vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
+ int16x8_t vprod0x4 = vmull_s8(vget_low_s8(vb4), vget_low_s8(va0));
+ int16x8_t vprod1x4 = vmull_s8(vget_low_s8(vb4), vget_low_s8(va1));
+ int16x8_t vprod2x4 = vmull_s8(vget_low_s8(vb4), vget_low_s8(va2));
+ vprod0x4 = vmlal_s8(vprod0x4, vget_high_s8(vb4), vget_high_s8(va0));
+ vprod1x4 = vmlal_s8(vprod1x4, vget_high_s8(vb4), vget_high_s8(va1));
+ vprod2x4 = vmlal_s8(vprod2x4, vget_high_s8(vb4), vget_high_s8(va2));
+ vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+ vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
- int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
- vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
+ int16x8_t vprod0x5 = vmull_s8(vget_low_s8(vb5), vget_low_s8(va0));
+ int16x8_t vprod1x5 = vmull_s8(vget_low_s8(vb5), vget_low_s8(va1));
+ int16x8_t vprod2x5 = vmull_s8(vget_low_s8(vb5), vget_low_s8(va2));
+ vprod0x5 = vmlal_s8(vprod0x5, vget_high_s8(vb5), vget_high_s8(va0));
+ vprod1x5 = vmlal_s8(vprod1x5, vget_high_s8(vb5), vget_high_s8(va1));
+ vprod2x5 = vmlal_s8(vprod2x5, vget_high_s8(vb5), vget_high_s8(va2));
+ vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+ vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
- int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
- vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
+ int16x8_t vprod0x6 = vmull_s8(vget_low_s8(vb6), vget_low_s8(va0));
+ int16x8_t vprod1x6 = vmull_s8(vget_low_s8(vb6), vget_low_s8(va1));
+ int16x8_t vprod2x6 = vmull_s8(vget_low_s8(vb6), vget_low_s8(va2));
+ vprod0x6 = vmlal_s8(vprod0x6, vget_high_s8(vb6), vget_high_s8(va0));
+ vprod1x6 = vmlal_s8(vprod1x6, vget_high_s8(vb6), vget_high_s8(va1));
+ vprod2x6 = vmlal_s8(vprod2x6, vget_high_s8(vb6), vget_high_s8(va2));
+ vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+ vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
- int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
- vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
+ int16x8_t vprod0x7 = vmull_s8(vget_low_s8(vb7), vget_low_s8(va0));
+ int16x8_t vprod1x7 = vmull_s8(vget_low_s8(vb7), vget_low_s8(va1));
+ int16x8_t vprod2x7 = vmull_s8(vget_low_s8(vb7), vget_low_s8(va2));
+ vprod0x7 = vmlal_s8(vprod0x7, vget_high_s8(vb7), vget_high_s8(va0));
+ vprod1x7 = vmlal_s8(vprod1x7, vget_high_s8(vb7), vget_high_s8(va1));
+ vprod2x7 = vmlal_s8(vprod2x7, vget_high_s8(vb7), vget_high_s8(va2));
+ vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+ vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
- k -= 16 * sizeof(int8_t);
+ k += 16 * sizeof(int8_t);
}
- // End of accumulation loop. The variable `k` contains the amount by which
- // we advanced the `va` pointers, so we rewind by this amount now.
- a0 = (const int8_t*)((uintptr_t)a0 + k);
- a1 = (const int8_t*)((uintptr_t)a1 + k);
- a2 = (const int8_t*)((uintptr_t)a2 + k);
#if XNN_ARCH_ARM64
const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
@@ -305,9 +290,9 @@
c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
- a0 = (const int8_t*) ((uintptr_t) a0 - kc);
- a1 = (const int8_t*) ((uintptr_t) a1 - kc);
- a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+ a0 = (const int8_t*) ((uintptr_t) a0 - k);
+ a1 = (const int8_t*) ((uintptr_t) a1 - k);
+ a2 = (const int8_t*) ((uintptr_t) a2 - k);
nc -= 8;
} else {