Rename QS8 C16 GEMM microkernel source from "mull" to "mlal"

PiperOrigin-RevId: 359842415
diff --git a/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c b/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c
index 158f244..079d6d2 100644
--- a/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c
+++ b/src/qs8-gemm/gen/3x8c16-minmax-neon-mlal-padal.c
@@ -77,112 +77,97 @@
     int32x4_t vacc2x6 = vacc0x6;
     int32x4_t vacc2x7 = vacc0x7;
 
-    int k = (int) kc;
-    while (k > 0) {
-      const int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
-      const int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
-      const int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
-      const int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
-      const int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
-      const int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
+    // KC loop of 16 with up to 15 remainder
+    size_t k = 0;
+    while (k < kc) {
+      const int8x16_t va0 = vld1q_s8(a0); a0 += 16;
+      const int8x16_t va1 = vld1q_s8(a1); a1 += 16;
+      const int8x16_t va2 = vld1q_s8(a2); a2 += 16;
 
-      const int8x8_t vb0x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb0x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb1x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb1x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb2x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb2x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb3x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb3x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb4x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb4x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb5x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb5x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb6x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb6x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb7x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
-      const int8x8_t vb7x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
+      const int8x16_t vb0 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+      const int8x16_t vb1 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+      const int8x16_t vb2 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+      const int8x16_t vb3 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+      const int8x16_t vb4 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+      const int8x16_t vb5 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+      const int8x16_t vb6 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
+      const int8x16_t vb7 = vld1q_s8(w); w = (const void*) ((uintptr_t) w + 16 * sizeof(int8_t));
 
-      int16x8_t vprod0x0 = vmull_s8(vb0x0, va0x0);
-      vprod0x0 = vmlal_s8(vprod0x0, vb0x1, va0x1);
+      int16x8_t vprod0x0 = vmull_s8(vget_low_s8(vb0), vget_low_s8(va0));
+      int16x8_t vprod1x0 = vmull_s8(vget_low_s8(vb0), vget_low_s8(va1));
+      int16x8_t vprod2x0 = vmull_s8(vget_low_s8(vb0), vget_low_s8(va2));
+      vprod0x0 = vmlal_s8(vprod0x0, vget_high_s8(vb0), vget_high_s8(va0));
+      vprod1x0 = vmlal_s8(vprod1x0, vget_high_s8(vb0), vget_high_s8(va1));
+      vprod2x0 = vmlal_s8(vprod2x0, vget_high_s8(vb0), vget_high_s8(va2));
       vacc0x0 = vpadalq_s16(vacc0x0, vprod0x0);
-      int16x8_t vprod0x1 = vmull_s8(vb1x0, va0x0);
-      vprod0x1 = vmlal_s8(vprod0x1, vb1x1, va0x1);
-      vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
-      int16x8_t vprod0x2 = vmull_s8(vb2x0, va0x0);
-      vprod0x2 = vmlal_s8(vprod0x2, vb2x1, va0x1);
-      vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
-      int16x8_t vprod0x3 = vmull_s8(vb3x0, va0x0);
-      vprod0x3 = vmlal_s8(vprod0x3, vb3x1, va0x1);
-      vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
-      int16x8_t vprod0x4 = vmull_s8(vb4x0, va0x0);
-      vprod0x4 = vmlal_s8(vprod0x4, vb4x1, va0x1);
-      vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
-      int16x8_t vprod0x5 = vmull_s8(vb5x0, va0x0);
-      vprod0x5 = vmlal_s8(vprod0x5, vb5x1, va0x1);
-      vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
-      int16x8_t vprod0x6 = vmull_s8(vb6x0, va0x0);
-      vprod0x6 = vmlal_s8(vprod0x6, vb6x1, va0x1);
-      vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
-      int16x8_t vprod0x7 = vmull_s8(vb7x0, va0x0);
-      vprod0x7 = vmlal_s8(vprod0x7, vb7x1, va0x1);
-      vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
-      int16x8_t vprod1x0 = vmull_s8(vb0x0, va1x0);
-      vprod1x0 = vmlal_s8(vprod1x0, vb0x1, va1x1);
       vacc1x0 = vpadalq_s16(vacc1x0, vprod1x0);
-      int16x8_t vprod1x1 = vmull_s8(vb1x0, va1x0);
-      vprod1x1 = vmlal_s8(vprod1x1, vb1x1, va1x1);
-      vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
-      int16x8_t vprod1x2 = vmull_s8(vb2x0, va1x0);
-      vprod1x2 = vmlal_s8(vprod1x2, vb2x1, va1x1);
-      vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
-      int16x8_t vprod1x3 = vmull_s8(vb3x0, va1x0);
-      vprod1x3 = vmlal_s8(vprod1x3, vb3x1, va1x1);
-      vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
-      int16x8_t vprod1x4 = vmull_s8(vb4x0, va1x0);
-      vprod1x4 = vmlal_s8(vprod1x4, vb4x1, va1x1);
-      vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
-      int16x8_t vprod1x5 = vmull_s8(vb5x0, va1x0);
-      vprod1x5 = vmlal_s8(vprod1x5, vb5x1, va1x1);
-      vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
-      int16x8_t vprod1x6 = vmull_s8(vb6x0, va1x0);
-      vprod1x6 = vmlal_s8(vprod1x6, vb6x1, va1x1);
-      vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
-      int16x8_t vprod1x7 = vmull_s8(vb7x0, va1x0);
-      vprod1x7 = vmlal_s8(vprod1x7, vb7x1, va1x1);
-      vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
-      int16x8_t vprod2x0 = vmull_s8(vb0x0, va2x0);
-      vprod2x0 = vmlal_s8(vprod2x0, vb0x1, va2x1);
       vacc2x0 = vpadalq_s16(vacc2x0, vprod2x0);
-      int16x8_t vprod2x1 = vmull_s8(vb1x0, va2x0);
-      vprod2x1 = vmlal_s8(vprod2x1, vb1x1, va2x1);
+      int16x8_t vprod0x1 = vmull_s8(vget_low_s8(vb1), vget_low_s8(va0));
+      int16x8_t vprod1x1 = vmull_s8(vget_low_s8(vb1), vget_low_s8(va1));
+      int16x8_t vprod2x1 = vmull_s8(vget_low_s8(vb1), vget_low_s8(va2));
+      vprod0x1 = vmlal_s8(vprod0x1, vget_high_s8(vb1), vget_high_s8(va0));
+      vprod1x1 = vmlal_s8(vprod1x1, vget_high_s8(vb1), vget_high_s8(va1));
+      vprod2x1 = vmlal_s8(vprod2x1, vget_high_s8(vb1), vget_high_s8(va2));
+      vacc0x1 = vpadalq_s16(vacc0x1, vprod0x1);
+      vacc1x1 = vpadalq_s16(vacc1x1, vprod1x1);
       vacc2x1 = vpadalq_s16(vacc2x1, vprod2x1);
-      int16x8_t vprod2x2 = vmull_s8(vb2x0, va2x0);
-      vprod2x2 = vmlal_s8(vprod2x2, vb2x1, va2x1);
+      int16x8_t vprod0x2 = vmull_s8(vget_low_s8(vb2), vget_low_s8(va0));
+      int16x8_t vprod1x2 = vmull_s8(vget_low_s8(vb2), vget_low_s8(va1));
+      int16x8_t vprod2x2 = vmull_s8(vget_low_s8(vb2), vget_low_s8(va2));
+      vprod0x2 = vmlal_s8(vprod0x2, vget_high_s8(vb2), vget_high_s8(va0));
+      vprod1x2 = vmlal_s8(vprod1x2, vget_high_s8(vb2), vget_high_s8(va1));
+      vprod2x2 = vmlal_s8(vprod2x2, vget_high_s8(vb2), vget_high_s8(va2));
+      vacc0x2 = vpadalq_s16(vacc0x2, vprod0x2);
+      vacc1x2 = vpadalq_s16(vacc1x2, vprod1x2);
       vacc2x2 = vpadalq_s16(vacc2x2, vprod2x2);
-      int16x8_t vprod2x3 = vmull_s8(vb3x0, va2x0);
-      vprod2x3 = vmlal_s8(vprod2x3, vb3x1, va2x1);
+      int16x8_t vprod0x3 = vmull_s8(vget_low_s8(vb3), vget_low_s8(va0));
+      int16x8_t vprod1x3 = vmull_s8(vget_low_s8(vb3), vget_low_s8(va1));
+      int16x8_t vprod2x3 = vmull_s8(vget_low_s8(vb3), vget_low_s8(va2));
+      vprod0x3 = vmlal_s8(vprod0x3, vget_high_s8(vb3), vget_high_s8(va0));
+      vprod1x3 = vmlal_s8(vprod1x3, vget_high_s8(vb3), vget_high_s8(va1));
+      vprod2x3 = vmlal_s8(vprod2x3, vget_high_s8(vb3), vget_high_s8(va2));
+      vacc0x3 = vpadalq_s16(vacc0x3, vprod0x3);
+      vacc1x3 = vpadalq_s16(vacc1x3, vprod1x3);
       vacc2x3 = vpadalq_s16(vacc2x3, vprod2x3);
-      int16x8_t vprod2x4 = vmull_s8(vb4x0, va2x0);
-      vprod2x4 = vmlal_s8(vprod2x4, vb4x1, va2x1);
+      int16x8_t vprod0x4 = vmull_s8(vget_low_s8(vb4), vget_low_s8(va0));
+      int16x8_t vprod1x4 = vmull_s8(vget_low_s8(vb4), vget_low_s8(va1));
+      int16x8_t vprod2x4 = vmull_s8(vget_low_s8(vb4), vget_low_s8(va2));
+      vprod0x4 = vmlal_s8(vprod0x4, vget_high_s8(vb4), vget_high_s8(va0));
+      vprod1x4 = vmlal_s8(vprod1x4, vget_high_s8(vb4), vget_high_s8(va1));
+      vprod2x4 = vmlal_s8(vprod2x4, vget_high_s8(vb4), vget_high_s8(va2));
+      vacc0x4 = vpadalq_s16(vacc0x4, vprod0x4);
+      vacc1x4 = vpadalq_s16(vacc1x4, vprod1x4);
       vacc2x4 = vpadalq_s16(vacc2x4, vprod2x4);
-      int16x8_t vprod2x5 = vmull_s8(vb5x0, va2x0);
-      vprod2x5 = vmlal_s8(vprod2x5, vb5x1, va2x1);
+      int16x8_t vprod0x5 = vmull_s8(vget_low_s8(vb5), vget_low_s8(va0));
+      int16x8_t vprod1x5 = vmull_s8(vget_low_s8(vb5), vget_low_s8(va1));
+      int16x8_t vprod2x5 = vmull_s8(vget_low_s8(vb5), vget_low_s8(va2));
+      vprod0x5 = vmlal_s8(vprod0x5, vget_high_s8(vb5), vget_high_s8(va0));
+      vprod1x5 = vmlal_s8(vprod1x5, vget_high_s8(vb5), vget_high_s8(va1));
+      vprod2x5 = vmlal_s8(vprod2x5, vget_high_s8(vb5), vget_high_s8(va2));
+      vacc0x5 = vpadalq_s16(vacc0x5, vprod0x5);
+      vacc1x5 = vpadalq_s16(vacc1x5, vprod1x5);
       vacc2x5 = vpadalq_s16(vacc2x5, vprod2x5);
-      int16x8_t vprod2x6 = vmull_s8(vb6x0, va2x0);
-      vprod2x6 = vmlal_s8(vprod2x6, vb6x1, va2x1);
+      int16x8_t vprod0x6 = vmull_s8(vget_low_s8(vb6), vget_low_s8(va0));
+      int16x8_t vprod1x6 = vmull_s8(vget_low_s8(vb6), vget_low_s8(va1));
+      int16x8_t vprod2x6 = vmull_s8(vget_low_s8(vb6), vget_low_s8(va2));
+      vprod0x6 = vmlal_s8(vprod0x6, vget_high_s8(vb6), vget_high_s8(va0));
+      vprod1x6 = vmlal_s8(vprod1x6, vget_high_s8(vb6), vget_high_s8(va1));
+      vprod2x6 = vmlal_s8(vprod2x6, vget_high_s8(vb6), vget_high_s8(va2));
+      vacc0x6 = vpadalq_s16(vacc0x6, vprod0x6);
+      vacc1x6 = vpadalq_s16(vacc1x6, vprod1x6);
       vacc2x6 = vpadalq_s16(vacc2x6, vprod2x6);
-      int16x8_t vprod2x7 = vmull_s8(vb7x0, va2x0);
-      vprod2x7 = vmlal_s8(vprod2x7, vb7x1, va2x1);
+      int16x8_t vprod0x7 = vmull_s8(vget_low_s8(vb7), vget_low_s8(va0));
+      int16x8_t vprod1x7 = vmull_s8(vget_low_s8(vb7), vget_low_s8(va1));
+      int16x8_t vprod2x7 = vmull_s8(vget_low_s8(vb7), vget_low_s8(va2));
+      vprod0x7 = vmlal_s8(vprod0x7, vget_high_s8(vb7), vget_high_s8(va0));
+      vprod1x7 = vmlal_s8(vprod1x7, vget_high_s8(vb7), vget_high_s8(va1));
+      vprod2x7 = vmlal_s8(vprod2x7, vget_high_s8(vb7), vget_high_s8(va2));
+      vacc0x7 = vpadalq_s16(vacc0x7, vprod0x7);
+      vacc1x7 = vpadalq_s16(vacc1x7, vprod1x7);
       vacc2x7 = vpadalq_s16(vacc2x7, vprod2x7);
 
-      k -= 16 * sizeof(int8_t);
+      k += 16 * sizeof(int8_t);
     }
-    // End of accumulation loop. The variable `k` contains the amount by which
-    // we advanced the `va` pointers, so we rewind by this amount now.
-    a0 = (const int8_t*)((uintptr_t)a0 + k);
-    a1 = (const int8_t*)((uintptr_t)a1 + k);
-    a2 = (const int8_t*)((uintptr_t)a2 + k);
 
 #if XNN_ARCH_ARM64
     const int32x4_t vsum0x01 = vpaddq_s32(vacc0x0, vacc0x1);
@@ -305,9 +290,9 @@
       c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
       c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
 
-      a0 = (const int8_t*) ((uintptr_t) a0 - kc);
-      a1 = (const int8_t*) ((uintptr_t) a1 - kc);
-      a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+      a0 = (const int8_t*) ((uintptr_t) a0 - k);
+      a1 = (const int8_t*) ((uintptr_t) a1 - k);
+      a2 = (const int8_t*) ((uintptr_t) a2 - k);
 
       nc -= 8;
     } else {