Refactor DWCONV micro-kernels
- Fix bugs in generation of micro-kernels with large channel tiles
- Add missing unit tests
- Generate, test, and benchmark micro-kernels with 2 accumulators, with a 2X
channel tile, and with their combinations
PiperOrigin-RevId: 279137161
diff --git a/src/f32-dwconv/up4x9-psimd.c b/src/f32-dwconv/up4x9-psimd.c
index 8f820b7..6b41fe8 100644
--- a/src/f32-dwconv/up4x9-psimd.c
+++ b/src/f32-dwconv/up4x9-psimd.c
@@ -44,114 +44,122 @@
size_t c = channels;
const float* w = weights;
for (; c >= 4; c -= 4) {
- psimd_f32 vacc0 = psimd_load_f32(w);
+ psimd_f32 vacc0123p0 = psimd_load_f32(w);
- const psimd_f32 vi0 = psimd_load_f32(i0);
- const psimd_f32 vk0 = psimd_load_f32(w + 4);
- vacc0 = psimd_qfma_f32(vacc0, vi0, vk0);
+
+ const psimd_f32 vi0x0123 = psimd_load_f32(i0);
i0 += 4;
- const psimd_f32 vi1 = psimd_load_f32(i1);
- const psimd_f32 vk1 = psimd_load_f32(w + 8);
- psimd_f32 vacc1 = psimd_mul_f32(vi1, vk1);
+ const psimd_f32 vk0x0123 = psimd_load_f32(w + 4);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi0x0123, vk0x0123);
+
+ const psimd_f32 vi1x0123 = psimd_load_f32(i1);
i1 += 4;
- const psimd_f32 vi2 = psimd_load_f32(i2);
- const psimd_f32 vk2 = psimd_load_f32(w + 12);
- vacc0 = psimd_qfma_f32(vacc0, vi2, vk2);
+ const psimd_f32 vk1x0123 = psimd_load_f32(w + 8);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi1x0123, vk1x0123);
+
+ const psimd_f32 vi2x0123 = psimd_load_f32(i2);
i2 += 4;
- const psimd_f32 vi3 = psimd_load_f32(i3);
- const psimd_f32 vk3 = psimd_load_f32(w + 16);
- vacc1 = psimd_qfma_f32(vacc1, vi3, vk3);
+ const psimd_f32 vk2x0123 = psimd_load_f32(w + 12);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi2x0123, vk2x0123);
+
+ const psimd_f32 vi3x0123 = psimd_load_f32(i3);
i3 += 4;
- const psimd_f32 vi4 = psimd_load_f32(i4);
- const psimd_f32 vk4 = psimd_load_f32(w + 20);
- vacc0 = psimd_qfma_f32(vacc0, vi4, vk4);
+ const psimd_f32 vk3x0123 = psimd_load_f32(w + 16);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi3x0123, vk3x0123);
+
+ const psimd_f32 vi4x0123 = psimd_load_f32(i4);
i4 += 4;
- const psimd_f32 vi5 = psimd_load_f32(i5);
- const psimd_f32 vk5 = psimd_load_f32(w + 24);
- vacc1 = psimd_qfma_f32(vacc1, vi5, vk5);
+ const psimd_f32 vk4x0123 = psimd_load_f32(w + 20);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi4x0123, vk4x0123);
+
+ const psimd_f32 vi5x0123 = psimd_load_f32(i5);
i5 += 4;
- const psimd_f32 vi6 = psimd_load_f32(i6);
- const psimd_f32 vk6 = psimd_load_f32(w + 28);
- vacc0 = psimd_qfma_f32(vacc0, vi6, vk6);
+ const psimd_f32 vk5x0123 = psimd_load_f32(w + 24);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi5x0123, vk5x0123);
+
+ const psimd_f32 vi6x0123 = psimd_load_f32(i6);
i6 += 4;
- const psimd_f32 vi7 = psimd_load_f32(i7);
- const psimd_f32 vk7 = psimd_load_f32(w + 32);
- vacc1 = psimd_qfma_f32(vacc1, vi7, vk7);
+ const psimd_f32 vk6x0123 = psimd_load_f32(w + 28);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi6x0123, vk6x0123);
+
+ const psimd_f32 vi7x0123 = psimd_load_f32(i7);
i7 += 4;
- const psimd_f32 vi8 = psimd_load_f32(i8);
- const psimd_f32 vk8 = psimd_load_f32(w + 36);
- vacc0 = psimd_qfma_f32(vacc0, vi8, vk8);
+ const psimd_f32 vk7x0123 = psimd_load_f32(w + 32);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi7x0123, vk7x0123);
+
+ const psimd_f32 vi8x0123 = psimd_load_f32(i8);
i8 += 4;
+ const psimd_f32 vk8x0123 = psimd_load_f32(w + 36);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi8x0123, vk8x0123);
+
w += 40;
- vacc0 = psimd_add_f32(vacc0, vacc1);
- vacc0 = psimd_max_f32(vacc0, vmin);
- vacc0 = psimd_min_f32(vacc0, vmax);
+ psimd_f32 vacc0123 = psimd_max_f32(vacc0123p0, vmin);
+ vacc0123 = psimd_min_f32(vacc0123, vmax);
- psimd_store_f32(output, vacc0);
+ psimd_store_f32(output, vacc0123);
output += 4;
}
if XNN_UNLIKELY(c != 0) {
- psimd_f32 vacc = psimd_load_f32(w);
+ psimd_f32 vacc0123p0 = psimd_load_f32(w);
- const psimd_f32 vi0 = psimd_load_f32(i0);
- const psimd_f32 vk0 = psimd_load_f32(w + 4);
- vacc = psimd_qfma_f32(vacc, vi0, vk0);
+ const psimd_f32 vi0x0123 = psimd_load_f32(i0);
+ const psimd_f32 vk0x0123 = psimd_load_f32(w + 4);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi0x0123, vk0x0123);
- const psimd_f32 vi1 = psimd_load_f32(i1);
- const psimd_f32 vk1 = psimd_load_f32(w + 8);
- vacc = psimd_qfma_f32(vacc, vi1, vk1);
+ const psimd_f32 vi1x0123 = psimd_load_f32(i1);
+ const psimd_f32 vk1x0123 = psimd_load_f32(w + 8);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi1x0123, vk1x0123);
- const psimd_f32 vi2 = psimd_load_f32(i2);
- const psimd_f32 vk2 = psimd_load_f32(w + 12);
- vacc = psimd_qfma_f32(vacc, vi2, vk2);
+ const psimd_f32 vi2x0123 = psimd_load_f32(i2);
+ const psimd_f32 vk2x0123 = psimd_load_f32(w + 12);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi2x0123, vk2x0123);
- const psimd_f32 vi3 = psimd_load_f32(i3);
- const psimd_f32 vk3 = psimd_load_f32(w + 16);
- vacc = psimd_qfma_f32(vacc, vi3, vk3);
+ const psimd_f32 vi3x0123 = psimd_load_f32(i3);
+ const psimd_f32 vk3x0123 = psimd_load_f32(w + 16);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi3x0123, vk3x0123);
- const psimd_f32 vi4 = psimd_load_f32(i4);
- const psimd_f32 vk4 = psimd_load_f32(w + 20);
- vacc = psimd_qfma_f32(vacc, vi4, vk4);
+ const psimd_f32 vi4x0123 = psimd_load_f32(i4);
+ const psimd_f32 vk4x0123 = psimd_load_f32(w + 20);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi4x0123, vk4x0123);
- const psimd_f32 vi5 = psimd_load_f32(i5);
- const psimd_f32 vk5 = psimd_load_f32(w + 24);
- vacc = psimd_qfma_f32(vacc, vi5, vk5);
+ const psimd_f32 vi5x0123 = psimd_load_f32(i5);
+ const psimd_f32 vk5x0123 = psimd_load_f32(w + 24);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi5x0123, vk5x0123);
- const psimd_f32 vi6 = psimd_load_f32(i6);
- const psimd_f32 vk6 = psimd_load_f32(w + 28);
- vacc = psimd_qfma_f32(vacc, vi6, vk6);
+ const psimd_f32 vi6x0123 = psimd_load_f32(i6);
+ const psimd_f32 vk6x0123 = psimd_load_f32(w + 28);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi6x0123, vk6x0123);
- const psimd_f32 vi7 = psimd_load_f32(i7);
- const psimd_f32 vk7 = psimd_load_f32(w + 32);
- vacc = psimd_qfma_f32(vacc, vi7, vk7);
+ const psimd_f32 vi7x0123 = psimd_load_f32(i7);
+ const psimd_f32 vk7x0123 = psimd_load_f32(w + 32);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi7x0123, vk7x0123);
- const psimd_f32 vi8 = psimd_load_f32(i8);
- const psimd_f32 vk8 = psimd_load_f32(w + 36);
- vacc = psimd_qfma_f32(vacc, vi8, vk8);
+ const psimd_f32 vi8x0123 = psimd_load_f32(i8);
+ const psimd_f32 vk8x0123 = psimd_load_f32(w + 36);
+ vacc0123p0 = psimd_qfma_f32(vacc0123p0, vi8x0123, vk8x0123);
- w += 40;
- vacc = psimd_max_f32(vacc, vmin);
- vacc = psimd_min_f32(vacc, vmax);
+ psimd_f32 vacc0123 = psimd_max_f32(vacc0123p0, vmin);
+ vacc0123 = psimd_min_f32(vacc0123, vmax);
if (c & 2) {
- psimd_store2_f32(output, vacc);
- vacc = psimd_concat_hi_f32(vacc, vacc);
+ psimd_store2_f32(output, vacc0123);
+ vacc0123 = psimd_concat_hi_f32(vacc0123, vacc0123);
output += 2;
}
if (c & 1) {
- psimd_store1_f32(output, vacc);
+ psimd_store1_f32(output, vacc0123);
output += 1;
}
}